In [2]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_geometric.data import Data
from torch_scatter import scatter
import torch_geometric
import tqdm
import numpy as np
import wandb
import random
import pandas as pd

import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

import dfs_code
import _dfs_codes
import networkx as nx

manualSeed = 43
# Seed Python's `random`, PyTorch and NumPy (in that order) from the same value
# so that repeated runs of the notebook draw the same random numbers.
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(manualSeed)
print("Random Seed: ", manualSeed)

Random Seed:  43


In [3]:
from torch_geometric.data import InMemoryDataset
class ChEMBL(InMemoryDataset):
    """Preprocessed ChEMBL graphs loaded from a single ``.pt`` file.

    Args:
        path: Location of the file whose contents are unpacked into
            ``self.data`` and ``self.slices``. Defaults to the
            ``preprocessedPlusHs_split1.pt`` file this notebook was written
            against, so ``ChEMBL()`` behaves exactly as before.
    """

    def __init__(self, path='../../datasets/ChEMBL/preprocessedPlusHs_split1.pt'):
        super().__init__()
        # NOTE: torch.load unpickles the file, so only point this at trusted data.
        self.data, self.slices = torch.load(path)

In [4]:
# Instantiate the preprocessed ChEMBL dataset defined in the previous cell.
dataset = ChEMBL()

In [5]:
# Inspect the first sample to see which attributes each graph carries.
dataset[0]

Data(edge_attr=[106, 4], edge_index=[2, 106], idx=[1], name="CHEMBL3394127", x=[50, 39], z=[50])

In [6]:
# Look at the raw edge_index tensor of a single sample (index 10000).
dataset[10000].edge_index

tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,
          4,  5,  5,  5,  6,  7,  7,  7,  7,  8,  8,  9,  9,  9, 10, 10, 11, 11,
         11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17,
         17, 18, 19, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 23,
         23, 23, 24, 25, 25, 25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29,
         29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 33, 33, 33, 33, 34, 34, 34, 35,
         35, 36, 36, 36, 36, 37, 37, 37, 38, 38, 39, 39, 39, 39, 40, 40, 41, 42,
         43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
         61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],
        [ 1, 41, 42, 43,  0,  2, 44, 45,  1,  3,  4, 46,  2, 47, 48, 49,  2,  5,
         50,  4,  6,  7,  5,  5,  8, 51, 52,  7,  9,  8, 10, 19,  9, 11, 10, 12,
         16, 11, 13, 53, 12, 14, 40, 13, 15, 38, 14, 16, 54, 11, 15, 17, 16, 18,
         19, 17,  9, 17, 20, 19, 21

In [None]:
# Compute the minimal DFS code of every graph in `dataset` that has at most
# MAX_NUM_VERTICES vertices. Successful results are stored in `dfs_codes`
# keyed by `data.name`; graphs for which the computation raises a KeyError are
# recorded in `bad_examples` and a short diagnosis is printed for each.
MAX_NUM_VERTICES = 30  # size cut-off for graphs that are processed

dfs_codes = {}
bad_examples = []
for data in tqdm.tqdm(dataset):
    # Filter on size first so skipped graphs never pay for the
    # tensor -> numpy conversions below.
    if data.x.shape[0] > MAX_NUM_VERTICES:
        continue
    vertex_labels = data.z.detach().cpu().numpy().tolist()
    edge_features = data.edge_attr.detach().cpu().numpy()
    # One label per edge: the index of the largest entry in its feature row
    # (presumably a one-hot bond-type encoding — TODO confirm).
    edge_labels = np.argmax(edge_features, axis=1).tolist()
    try:
        code, dfs_indices = dfs_code.min_dfs_code_from_torch_geometric(data, vertex_labels, edge_labels)
        dfs_codes[data.name] = {'min_dfs_code': code, 'dfs_indices': dfs_indices}
    except KeyboardInterrupt:
        # Manual stop: leave the loop but keep everything computed so far.
        break
    except KeyError:
        print(data.name, 'failed')
        bad_examples.append(data.name)
        print('diagnosis: ')
        # Map each directed edge (source, target) to its column in edge_index.
        edges_coo = data.edge_index.detach().cpu().numpy().T
        edge2id = {tuple(e.tolist()): idx for idx, e in enumerate(edges_coo)}
        edge_list = dfs_code.torch_geometric_2_lists(data, vertex_labels, edge_labels)
        code = _dfs_codes.compute_minimal_dfs_code(edge_list, vertex_labels)
        dfs_indices = {}
        for idx, row in enumerate(code):
            # NOTE(review): if this lookup is what failed inside
            # min_dfs_code_from_torch_geometric, it will raise again here and
            # abort the whole loop — confirm against dfs_code's implementation.
            code[idx][-2] = edge2id[(row[-3], row[-1])]
            dfs_indices[row[-3]] = row[0]
            dfs_indices[row[-1]] = row[1]
        # Rebuild the graph with networkx to check whether it is connected.
        g = nx.Graph()
        g.add_nodes_from(list(range(len(vertex_labels))))
        g.add_edges_from(edges_coo.tolist())
        print('has correct length?', len(code) == len(edge_list)//2,'is valid?', dfs_code.isValidDFSCode(code), end='')
        print(' is connected?', nx.is_connected(g))
    
        

  0%|                                                 | 0/29917 [00:00<?, ?it/s]

In [None]:
# Number of graphs whose minimal-DFS-code computation raised a KeyError.
len(bad_examples)

In [None]:
# Number of graphs for which a minimal DFS code was stored.
len(dfs_codes)

In [None]:
# Names (`data.name`) of the graphs that failed.
bad_examples

In [None]:
# Create the 'leq30' output directory. Unlike a plain `mkdir`, this does not
# fail when the directory already exists, so the cell is safe to re-run.
import os
os.makedirs('../../datasets/ChEMBL/leq30', exist_ok=True)

In [None]:
import json
# Persist the names of the failed graphs as a JSON list.
bad_examples_path = '../../datasets/ChEMBL/leq30/bad_examples.json'
with open(bad_examples_path, 'w') as out_file:
    json.dump(bad_examples, out_file)

In [None]:
# Persist the computed minimal DFS codes (keyed by graph name) as JSON.
min_dfs_codes_path = '../../datasets/ChEMBL/leq30/min_dfs_codes.json'
with open(min_dfs_codes_path, 'w') as out_file:
    json.dump(dfs_codes, out_file)

In [None]:
# Preview a few entries only: displaying the whole dictionary would flood the
# notebook output with one record per processed graph.
dict(list(dfs_codes.items())[:3])