In [1]:
%load_ext autoreload
%autoreload 2
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.nn as tgnn
import tqdm
import numpy as np
import dfs_code
import json
import torch

# Minimum DFS codes on the molecular (bond) graph

In [2]:
# Root directory of the PyG QM9 dataset; the output JSON at the end of the notebook is
# also written here.
path = '../datasets/qm9_geometric_work/'
dataset = QM9(root=path)

In [5]:
# Minimum DFS codes of the molecular (bond) graphs, keyed by molecule name.
# Kept under its own name so that (re-)running this cell can never overwrite the
# distance-graph `dfs_codes` that the save cell at the end of the notebook writes to disk.
# NOTE(review): no cell shown here persists these bond-graph codes — TODO save them if needed.
mol_dfs_codes = {}
for data in tqdm.tqdm(dataset):
    vertex_features = data.x.detach().cpu().numpy()
    edge_features = data.edge_attr.detach().cpu().numpy()
    # Vertex label = argmax over the first 5 node-feature columns
    # (presumably the one-hot atom type in PyG's QM9 — confirm).
    vertex_labels = np.argmax(vertex_features[:, :5], axis=1).tolist()
    # Edge label = argmax over all edge-feature columns (presumably one-hot bond type — confirm).
    edge_labels = np.argmax(edge_features, axis=1).tolist()
    code, dfs_indices = dfs_code.min_dfs_code_from_torch_geometric_new(data, vertex_labels, edge_labels)
    mol_dfs_codes[data.name] = {'min_dfs_code': code, 'dfs_indices': dfs_indices}

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 130831/130831 [02:07<00:00, 1030.16it/s]


# Minimum DFS codes on the distance (radius) graph

In [3]:
cutoff = 5.

In [4]:
dfs_codes = {}
for data in tqdm.tqdm(dataset):
    data.edge_index = tgnn.radius_graph(data.pos, r=cutoff)
    row, col = data.edge_index
    ew = (data.pos[row] - data.pos[col]).norm(dim=-1).numpy()
    edge_labels = np.zeros(len(ew), dtype=np.int32)
    edge_labels[ew < 2] = 0
    edge_labels[(ew >= 2)*(ew < 3)] = 1
    edge_labels[(ew >= 3)*(ew < 4)] = 2
    edge_labels[ew >= 4] = 3
    
    vertex_features = data.x.detach().cpu().numpy()
    vertex_labels = np.argmax(vertex_features[:, :5], axis=1).tolist()
    edge_labels = edge_labels.tolist()
    
    code, dfs_indices = dfs_code.min_dfs_code_from_torch_geometric_new(data, vertex_labels, edge_labels)
    dfs_codes[data.name] = {'min_dfs_code':code, 'dfs_indices':dfs_indices}

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 130831/130831 [11:56<00:00, 182.51it/s]


In [5]:
with open(path+'dist5_min_dfs_codes.json', 'w') as f:
    json.dump(dfs_codes, f)