In [9]:
from smartgd.common.data.graph_drawing_data import *
from smartgd.common.data.transforms import NormalizeGraph, AddAdjacencyInfo, ComputeShortestPath
from smartgd.common.nn import NormalizeRotation
from smartgd.constants import EPS

from dataclasses import fields

from scipy import spatial, sparse
import networkx as nx
import torch_sparse
import torch_scatter

In [3]:
# Four small test graphs with different structure, each converted to GraphDrawingData
# and collated into a single batch.
G_list = [
    nx.wheel_graph(10),
    nx.ladder_graph(10),
    nx.grid_graph((4, 4)),
    nx.lollipop_graph(6, 6),
]
# Graph-level metadata; presumably required by GraphDrawingData.new (TODO confirm).
for G in G_list:
    G.graph.update(name="name", dataset="dataset")
data_list = [GraphDrawingData.new(G).post_transform() for G in G_list]
wheel, ladder, grid, lollipop = data_list
batch = Batch.from_data_list(data_list)

In [4]:
# Display the collated batch of all four graphs.
batch

GraphDrawingDataBatch(G=[4], num_nodes=58, perm_index=[2, 842], edge_metaindex=[182], apsp_attr=[842], perm_weight=[842], laplacian_eigenvector_pe=[58, 3], name=[4], dataset=[4], n=[4], m=[4], aggr_metaindex=[842], pos=[58, 2], face=[3, 74], edge_pair_metaindex=[2, 8318], gabriel_index=[2, 172], rng_index=[2, 122], batch=[58], ptr=[5])

In [7]:
# Build the GraphStruct representation of the batch and display it.
struct = batch.struct()
struct

GraphStruct(pos=[58, 2], n=[4], m=[4], x=[58, 3], batch=[58], perm_index=[2, 842], perm_attr=[842, 1], perm_weight=[842], edge_index=[2, 182], edge_attr=[182, 1], edge_weight=[182], aggr_index=[2, 842], aggr_attr=[842, 1], aggr_weight=[842], apsp_attr=[842], gabriel_index=[2, 172], rng_index=[2, 122], edge_pair_index=[2, 2, 8318])

In [13]:
# GraphStruct is a dataclass; inspect the declaration of its first field.
fields(struct)[0]

Field(name='pos',type=<class 'torch.FloatTensor'>,default=<dataclasses._MISSING_TYPE object at 0x10d7c75e0>,default_factory=<dataclasses._MISSING_TYPE object at 0x10d7c75e0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD)

In [8]:
# Smoke test: apply NormalizeRotation to the batched struct (the repr only shows
# field shapes, so this mainly checks that the call succeeds).
NormalizeRotation()(struct)

GraphStruct(pos=[58, 2], n=[4], m=[4], x=[58, 3], batch=[58], perm_index=[2, 842], perm_attr=[842, 1], perm_weight=[842], edge_index=[2, 182], edge_attr=[182, 1], edge_weight=[182], aggr_index=[2, 842], aggr_attr=[842, 1], aggr_weight=[842], apsp_attr=[842], gabriel_index=[2, 172], rng_index=[2, 122], edge_pair_index=[2, 2, 8318])

In [89]:
# Gabriel graph of the wheel layout, computed from its Delaunay edges: an edge is kept
# when the nearest layout point to its midpoint is at least half the edge length away
# (up to EPS), i.e. no point lies inside the circle having the edge as diameter.
data = wheel
# Every ordered vertex pair of every triangle, deduplicated into (num_edges, 2)
# (assumes `face` is (3, num_faces), as the batch repr suggests).
delaunay_edges = (
    data.face[list(permutations(range(3), 2)), :]
    .transpose(1, 2)
    .flatten(end_dim=1)
    .unique(dim=0)
)
tree = spatial.KDTree(data.pos.detach().cpu().numpy())
c = data.pos[delaunay_edges]               # endpoint coordinates, (num_edges, 2, dim)
m = c.mean(dim=1)                          # edge midpoints
d = (c[:, 0, :] - c[:, 1, :]).norm(dim=1)  # edge lengths
# Nearest-point distance for every midpoint; the edge's own endpoints sit at exactly d / 2.
dm = torch.tensor(
    tree.query(x=m.detach().cpu().numpy(), k=1)[0]
).to(m)
delaunay_edges[dm >= d / 2 * (1 - EPS)].T

tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 7, 8, 8, 9, 9],
        [1, 2, 0, 5, 7, 0, 4, 9, 4, 8, 2, 3, 5, 6, 1, 4, 4, 8, 9, 1, 3, 6, 2, 6]])

In [100]:
# Relative neighbourhood graph (RNG) of the wheel layout, computed from its Delaunay
# edges: edge (u, v) of length d is kept unless some point lies within d * (1 - EPS)
# of both u and v. The (1 - EPS) factor keeps u and v themselves out of each other's ball.
data = wheel
# Every ordered vertex pair of every triangle, deduplicated into (num_edges, 2).
delaunay_edges = data.face[list(permutations(range(3), 2)), :].transpose(1, 2).flatten(end_dim=1).unique(dim=0)
tree = spatial.KDTree(data.pos.detach().cpu().numpy())
c = data.pos[delaunay_edges]
src, dst = c[:, 0, :], c[:, 1, :]
d = (src - dst).norm(dim=1)
r = (d * (1 - EPS)).detach().cpu().numpy()

# Indices of all points within r of each edge's source / destination endpoint.
p0 = tree.query_ball_point(x=src.detach().cpu().numpy(), r=r)
p1 = tree.query_ball_point(x=dst.detach().cpu().numpy(), r=r)

# An edge survives iff no point is in both balls. Comparing the short neighbour lists
# directly avoids overwriting lil_matrix internals and materialising two dense
# (num_edges x num_nodes) boolean masks.
keep = torch.tensor(
    [set(a).isdisjoint(b) for a, b in zip(p0, p1)],
    dtype=torch.bool,
    device=delaunay_edges.device,
)
data.rng_index = delaunay_edges[keep].T

In [6]:
# Per-graph mean of the per-node Jaccard index between the graph edges and the RNG edges,
# computed for the whole batch at once.
cat_index = torch.cat([batch.edge_index, batch.rng_index], dim=1)
# coalesce sums values of duplicate (row, col) entries, so a merged value > 1 marks an edge
# present in both sets (assumes neither index has duplicates on its own - TODO confirm).
merged_index, merged_value = torch_sparse.coalesce(
    index=cat_index,
    value=torch.ones_like(cat_index[0]),
    m=batch.num_nodes,
    n=batch.num_nodes
)
# dim_size gives exactly one slot per node. Without it the output length is max(row) + 1,
# which no longer lines up with batch.batch when the last node(s) have no edges in either set.
intersection = torch_scatter.scatter(
    (merged_value > 1).to(float), merged_index[0], dim_size=batch.num_nodes
)
union = torch_scatter.scatter(
    (merged_value > 0).to(float), merged_index[0], dim_size=batch.num_nodes
)
# A node with no edges in either set contributes 0/0 = nan to its graph's mean.
torch_scatter.scatter(intersection / union, batch.batch, reduce="mean")

tensor([0.1872, 0.0183, 0.0677, 0.1369], dtype=torch.float64)

In [11]:
# Dense boolean adjacency matrices of the lollipop graph's edges and of its RNG edges.
data = lollipop
size = data.num_nodes, data.num_nodes
adj, rng_adj = (
    torch.sparse_coo_tensor(index, torch.ones_like(index[0]), size=size, dtype=bool).to_dense()
    for index in (data.edge_index, data.rng_index)
)

In [12]:
# Mean per-node Jaccard index between the edge set and the RNG edge set of `data`
# (|N & S| / |N | S| per row); should match the lollipop entry of the batched result.
torch.mean((adj & rng_adj).sum(dim=1) / (adj | rng_adj).sum(dim=1))

tensor(0.1369)

In [None]:
# NumPy cross-check of the torch Jaccard computation. The original referenced
# `shape_adj`, which is undefined at this point and fails on a fresh kernel.
adj_np, rng_adj_np = adj.cpu().numpy(), rng_adj.cpu().numpy()
np.mean((adj_np & rng_adj_np).sum(axis=1) / (adj_np | rng_adj_np).sum(axis=1))

In [None]:
# Same Jaccard check built from edge lists with SciPy (mirrors the jaccard_index helper).
# The original cell used a top-level `return` (SyntaxError) and undefined edge arrays.
edges = data.edge_index.T.cpu().numpy()        # (E, 2)
shape_edges = data.rng_index.T.cpu().numpy()   # (E', 2)
# True node count: includes isolated nodes and avoids .max() on an empty array.
n = data.num_nodes
np_adj = sparse.coo_matrix((np.ones_like(edges[:, 0]), edges.T), (n, n)).astype(bool).toarray()
np_shape_adj = sparse.coo_matrix((np.ones_like(shape_edges[:, 0]), shape_edges.T), (n, n)).astype(bool).toarray()
assert np.all(np_adj.T == np_adj) and np.all(np_shape_adj.T == np_shape_adj), "edge sets must be symmetric"
np.mean((np_adj & np_shape_adj).sum(axis=1) / (np_adj | np_shape_adj).sum(axis=1))

In [None]:


# TODO: torchfy
def rng(pos, edge_set, eps=1e-5):
    """Filter candidate edges down to the relative neighbourhood graph (RNG).

    An edge ``(u, v)`` of length ``d`` is kept unless some point lies within
    distance ``d * (1 - eps)`` of both ``u`` and ``v``. The candidate set is
    e.g. the Delaunay edge set, which contains the RNG.

    Args:
        pos: array-like of shape ``(N, D)`` with point coordinates.
        edge_set: integer array-like of shape ``(E, 2)``; candidate edges as
            index pairs into ``pos``.
        eps: relative shrink of the search radius so that points at exactly
            distance ``d`` (including the edge's own endpoints) do not remove
            the edge.

    Returns:
        ``np.ndarray`` of shape ``(E', 2)``: the rows of ``edge_set`` that are
        RNG edges, in their original order.
    """
    pos = np.asarray(pos)
    edge_set = np.asarray(edge_set).reshape(-1, 2)
    if len(edge_set) == 0:
        return edge_set
    tree = spatial.KDTree(pos)
    src, dst = pos[edge_set[:, 0]], pos[edge_set[:, 1]]
    radius = np.linalg.norm(src - dst, axis=1) * (1 - eps)
    near_src = tree.query_ball_point(x=src, r=radius)
    near_dst = tree.query_ball_point(x=dst, r=radius)
    # Keep an edge iff no point is in both balls. Comparing the short neighbour
    # lists avoids building two dense (E x N) masks via lil_matrix internals.
    keep = np.array(
        [set(a).isdisjoint(b) for a, b in zip(near_src, near_dst)],
        dtype=bool,
    )
    return edge_set[keep]


# TODO: torchfy
def jaccard_index(edges, shape_edges, num_nodes=None):
    """Mean per-node Jaccard index between two edge sets over the same nodes.

    For each node ``i`` with neighbour set ``A_i`` in ``edges`` and ``B_i`` in
    ``shape_edges`` this computes ``|A_i & B_i| / |A_i | B_i|``, then averages
    over all nodes.

    Args:
        edges: integer array-like of shape ``(E, 2)``. Both directions of every
            edge must be listed, i.e. the adjacency matrix must be symmetric.
        shape_edges: integer array-like of shape ``(E', 2)``, same convention.
        num_nodes: number of nodes. Defaults to one more than the largest index
            in either edge set, so trailing isolated nodes are not counted
            unless this is given.

    Returns:
        The mean Jaccard index. A node with no neighbours in either set
        contributes 0/0 = nan, which makes the mean nan. Zero nodes also
        yields nan.

    Raises:
        AssertionError: if either adjacency matrix is not symmetric.
    """
    edges = np.asarray(edges).reshape(-1, 2)
    shape_edges = np.asarray(shape_edges).reshape(-1, 2)
    if num_nodes is None:
        # ndarray.max() raises on an empty array, so only inspect non-empty edge sets.
        num_nodes = max((int(e.max()) + 1 for e in (edges, shape_edges) if e.size), default=0)
    n = num_nodes
    if n == 0:
        return np.nan
    adj = sparse.coo_matrix((np.ones_like(edges[:, 0]), edges.T), (n, n)).astype(bool).toarray()
    shape_adj = sparse.coo_matrix((np.ones_like(shape_edges[:, 0]), shape_edges.T), (n, n)).astype(bool).toarray()
    assert np.all(adj.T == adj) and np.all(shape_adj.T == shape_adj), "edge sets must be symmetric"
    return np.mean((adj & shape_adj).sum(axis=1) / (adj | shape_adj).sum(axis=1))
