In [1]:
# Re-import edited modules (e.g. the project code added to sys.path from
# ../../src below) before executing code, so changes apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd
import torch.nn.functional as F
import sys
import yaml
import functools
from ml_collections import ConfigDict
sys.path = ['../../src'] + sys.path
from dfs_transformer import DFSCodeSeq2SeqFC, Deepchem2TorchGeometric, Trainer, to_cuda, Enzymes
from graphein.protein.resi_atoms import RESI_THREE_TO_1, AMINO_ACIDS

# Download the pretrained model from W&B

In [None]:
# The W&B run exists only to resolve and download the model artifact; close it
# even if resolving/downloading fails so no dangling run is left behind.
run = wandb.init(mode="online",
                 project="pubchem",
                 entity="dfstransformer",
                 job_type="inference")
try:
    # Alternative checkpoint used previously: "bertloops0.5-10M-nofeats"
    model_at = run.use_artifact("rnd2min-10M-nofeats" + ":latest")
    model_dir = model_at.download()
finally:
    run.finish()

features = None  # alternative value tried here: "chemprop"
# Fingerprint mode passed to model.encode below. 'min-mean-max-std' was the
# other value assigned here (it used to be set and then immediately overwritten).
fingerprint = 'cls'
load_flag = True  # when True, the downloaded checkpoint.pt weights are loaded into the model

In [None]:
# Model/training configuration stored alongside the downloaded checkpoint.
# NOTE(review): FullLoader is applied to a downloaded artifact; consider
# yaml.safe_load if the config never needs non-standard YAML tags.
config_path = model_dir + "/config.yaml"
with open(config_path) as config_file:
    raw_config = yaml.load(config_file, Loader=yaml.FullLoader)
config = ConfigDict(raw_config)

In [None]:
# Use the GPU index recorded in the training config when CUDA is present.
if torch.cuda.is_available():
    device = torch.device('cuda:%d' % config.training.gpu_id)
else:
    device = torch.device('cpu')

In [None]:
# Model hyper-parameters; unpacked as keyword arguments into DFSCodeSeq2SeqFC.
m = getattr(config, "model")

In [None]:
model = DFSCodeSeq2SeqFC(**m)
if load_flag:
    # Restore the pretrained weights; map_location lets a GPU checkpoint load on CPU.
    checkpoint_path = model_dir + '/checkpoint.pt'
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)

In [None]:
# nn.Module.to moves parameters/buffers in place and returns the module (shown as cell output).
model.to(device=device)

In [None]:
# Count only parameters that receive gradients.
trainable = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in trainable)
print("number of trainable parameters %d" % params)

# Load the BBBP training set

In [None]:
train_csv = "../../datasets/mymoleculenet/bbbp/0/train.csv"
trainset = pd.read_csv(train_csv)
train_X = trainset["smiles"].to_numpy()
train_y = trainset["target"].to_numpy()
# NOTE(review): the CSV path uses ".../bbbp/0/" while loaddir uses ".../bbbp/1/";
# confirm the data cached in loaddir belongs to the same split as train.csv.
traindata = Deepchem2TorchGeometric(
    train_X,
    train_y,
    loaddir="../../results/mymoleculenet_plus_features/bbbp/1/",
    features=features,
)

In [None]:
import networkx as nx

In [None]:
def nmatch(n1, n2):
    """Node matcher for NetworkX edit distance: True iff the attribute dicts are equal."""
    same = n1 == n2
    return same


def ematch(e1, e2):
    """Edge matcher for NetworkX edit distance: True iff the attribute dicts are equal."""
    same = e1 == e2
    return same


def edit_distance(g1, g2):
    """Exact graph edit distance between g1 and g2, honouring node and edge attributes.

    Exact computation; expensive on larger graphs, which is why the
    approximate variant is used for the dataset.
    """
    return nx.graph_edit_distance(
        g1,
        g2,
        node_match=nmatch,
        edge_match=ematch,
    )


# Exact edit distance is too expensive for the dataset, so we use NetworkX's
# anytime approximation instead.
def edit_distance_approx(g1, g2, nsteps=1):
    """Approximate graph edit distance between g1 and g2.

    nx.optimize_graph_edit_distance yields consecutive, improving
    approximations of the edit distance. This consumes at most ``nsteps`` of
    them, using nmatch/ematch so node and edge attributes are taken into account.

    Args:
        g1, g2: networkx graphs to compare.
        nsteps: maximum number of approximations to consume; more steps means
            a tighter value at higher cost.

    Returns:
        The last approximation obtained, or None if none was produced
        (e.g. ``nsteps <= 0``).
    """
    import itertools

    approximations = nx.optimize_graph_edit_distance(
        g1, g2, node_match=nmatch, edge_match=ematch
    )
    result = None
    # islice stops cleanly after nsteps values or when the generator is
    # exhausted. Unlike a bare `except:`, genuine errors (and KeyboardInterrupt)
    # propagate instead of being swallowed, and `result` is always bound.
    for result in itertools.islice(approximations, max(nsteps, 0)):
        pass
    return result

In [None]:
# Two toy 3-node path graphs that differ in one node label (id 5 vs 2) and one
# edge label (type 'a' vs 'c'); used to sanity-check edit_distance.
G = nx.Graph()
G.add_nodes_from([(1, {'id': 1}), (2, {'id': 5}), (3, {'id': 3})])
G.add_edges_from([(1, 2, {'type': 'a'}), (2, 3, {'type': 'b'})])

G2 = nx.Graph()
G2.add_nodes_from([(1, {'id': 1}), (2, {'id': 2}), (3, {'id': 3})])
G2.add_edges_from([(1, 2, {'type': 'c'}), (2, 3, {'type': 'b'})])

In [None]:
# Under NetworkX's default unit costs this should be 2.0 (one node relabel + one edge relabel).
edit_distance(g1=G, g2=G2)

In [None]:
def collate_graph(dlist):
    """Turn a batch of PyG ``Data`` items into ``(smiles, list of nx.Graph)``.

    Each molecule becomes an undirected ``nx.Graph``. Nodes are numbered by
    position and carry an ``atomic_number`` attribute taken from ``d.z``.
    Edges carry a ``bond_type`` attribute: the argmax over each row of
    ``d.edge_attr`` (presumably a one-hot bond encoding).

    ``smiles`` is grown by list extension, so if ``d.smiles`` is a plain string
    it contributes one element per character. Callers re-join with ``''.join``.
    """
    smiles = []
    nx_batch = []
    for data in dlist:
        smiles.extend(data.smiles)
        graph = nx.Graph()
        graph.add_nodes_from(
            (node_id, {"atomic_number": atomic_number})
            for node_id, atomic_number in enumerate(data.z.numpy())
        )
        bond_types = np.argmax(data.edge_attr.numpy(), axis=1)
        # Transposing edge_index yields one (src, dst) row per edge.
        graph.add_edges_from(
            (src, dst, {"bond_type": bond_type})
            for (src, dst), bond_type in zip(data.edge_index.numpy().T, bond_types)
        )
        nx_batch.append(graph)
    return smiles, nx_batch

In [None]:
def collate_fn(dlist):
    """Collate a list of PyG ``Data`` items into per-molecule model inputs.

    Args:
        dlist: items exposing ``smiles``, ``min_dfs_code``, ``node_features``,
            ``edge_features`` and ``y``.

    Returns:
        smiles: list extended with each item's ``smiles`` (a plain string
            contributes one element per character; callers re-join).
        code_batch: list of cloned ``min_dfs_code`` tensors.
        node_batch: list of cloned ``node_features`` tensors.
        edge_batch: list of cloned ``edge_features`` tensors.
        y: ``torch.cat`` of the items' ``y`` tensors with a trailing axis added.
    """
    smiles = []
    code_batch, node_batch, edge_batch, y_batch = [], [], [], []
    for d in dlist:
        smiles += d.smiles
        # Tensors are cloned, presumably so downstream in-place ops cannot touch
        # the dataset's own copies. The former clone of `min_dfs_index` was
        # never used and has been dropped.
        code_batch.append(d.min_dfs_code.clone())
        node_batch.append(d.node_features.clone())
        edge_batch.append(d.edge_features.clone())
        y_batch.append(d.y.clone())

    y = torch.cat(y_batch).unsqueeze(1)
    return smiles, code_batch, node_batch, edge_batch, y

In [None]:
# Both loaders walk the dataset in the same unshuffled order, one molecule per batch;
# they differ only in how a batch is collated (networkx graphs vs. model inputs).
loader_kwargs = dict(batch_size=1, shuffle=False, pin_memory=False)
trainloaderg = DataLoader(traindata, collate_fn=collate_graph, **loader_kwargs)
trainloader = DataLoader(traindata, collate_fn=collate_fn, **loader_kwargs)

In [None]:
# SMILES (re-joined from the collated pieces) -> list of networkx graphs in that batch.
graphs = {''.join(smiles_parts): nx_graphs for smiles_parts, nx_graphs in trainloaderg}

In [None]:
# With batch_size=1 each value is a one-element list of graphs; unwrap it.
# The isinstance guard makes the cell safe to re-run: indexing an nx.Graph
# with [0] would silently return node 0's adjacency instead of a graph.
graphs = {key: (value[0] if isinstance(value, list) else value)
          for key, value in graphs.items()}

In [None]:
# Reference molecule for the distance comparisons below: 'CCN(CC)CCNC(=O)c1cc(Cl)cc(Cl)c1OC'

In [None]:
reference = graphs['CCN(CC)CCNC(=O)c1cc(Cl)cc(Cl)c1OC']

from rdkit import Chem
# `from rdkit import Chem` does not load the Draw submodule, so `Chem.Draw`
# raises AttributeError unless it was imported elsewhere; import it explicitly.
from rdkit.Chem import Draw

# NOTE(review): this draws a different molecule than `reference` above;
# confirm which molecule is meant to be shown here.
m1 = Chem.MolFromSmiles('CCN(CC)CC(=O)OCC(=O)C1(O)CCC2C3CCC4=CC(=O)C=CC4(C)C3C(O)CC21C')

Draw.MolToMPL(m1)

In [None]:
# Approximate edit distance from the reference graph to every training molecule
# (single approximation step per molecule to keep this tractable).
edit_distances = {
    smiles: edit_distance_approx(reference, graph, nsteps=1)
    for smiles, graph in tqdm.tqdm(graphs.items())
}

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Distribution of approximate edit distances to the reference molecule.
fig, ax = plt.subplots()
ax.hist(list(edit_distances.values()), bins='rice')
ax.set_xlabel('approximate edit distance to reference')
ax.set_ylabel('number of molecules')
ax.set_title('Edit distance to reference molecule (BBBP train)');

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Encode every training molecule once. This is pure inference, so put the model
# in eval mode (layers like dropout, if present, must not perturb the encodings)
# and disable autograd (no computation graph is needed).
# Previously each batch was produced twice: once by the for-loop (then discarded)
# and once by a second iterator; iterate the loader a single time instead.
model.eval()
encodings = {}
encodings2 = {}
with torch.no_grad():
    for batch in tqdm.tqdm(trainloader):
        smiles = ''.join(batch[0])
        # batch[1:] is (codes, node features, edge features, y); y is not a model input.
        inputs = [to_cuda(item, device) for item in batch[1:]]
        model_inputs = inputs[:-1]
        encodings[smiles] = model.encode(*model_inputs, fingerprint).detach().cpu().numpy()
        dfs1, dfs2, atm1, atm2, bnd = model(*model_inputs)
        # Average each model output over dim 0, then concatenate them along dim 1.
        enc2 = torch.cat((dfs1.mean(dim=0), dfs2.mean(dim=0), atm1.mean(dim=0),
                          atm2.mean(dim=0), bnd.mean(dim=0)), dim=1)
        encodings2[smiles] = enc2.detach().cpu().numpy()

In [None]:
# NOTE(review): sklearn requires n_components <= min(n_samples, n_features) of the
# matrix passed to fit; confirm the encodings have at least 600 dimensions.
N_PCA_COMPONENTS = 600
pca = PCA(n_components=N_PCA_COMPONENTS)

In [None]:
# Stack all encodings row-wise. Each is presumably (1, dim): later cells index the
# PCA output with [0, k]. TODO confirm.
X = np.concatenate([*encodings.values()], axis=0)
pca.fit(X)

In [None]:
# Project each molecule's encoding onto the fitted principal components.
encodings_pca = dict(zip(encodings.keys(), map(pca.transform, encodings.values())))

In [None]:
# SMILES -> mean Euclidean distance of its encoding to the reference encoding (filled below).
distances = dict()

In [None]:
ref_enc = encodings['CCN(CC)CCNC(=O)c1cc(Cl)cc(Cl)c1OC']
# Row-wise Euclidean norm of the difference to the reference, averaged over rows.
distances.update(
    (smiles, np.mean(np.linalg.norm(ref_enc - enc, axis=1)))
    for smiles, enc in tqdm.tqdm(encodings.items())
)

In [None]:
# Align both distance lists on the same SMILES order and materialise them as
# arrays. dict views are not array-like for matplotlib (plt.scatter on
# dict_values fails), and pairing two dicts by iteration order is fragile.
smiles_order = list(edit_distances)
edit_dists = np.array([edit_distances[s] for s in smiles_order])
transf_dists = np.array([distances[s] for s in smiles_order])

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(edit_dists, transf_dists, c='crimson')
# Visual guide: straight line from the origin to (max edit distance, max embedding distance).
ax.plot([0, max(edit_dists)], [0, max(transf_dists)], 'b-')
ax.set_xlabel('approximate edit distance', fontsize=15)
ax.set_ylabel('Euclidean distance', fontsize=15)
plt.show()

In [None]:
# Per molecule: first two PCA components (x, y) and edit distance to the reference (z).
ordered_smiles = list(edit_distances.keys())
x = [encodings_pca[s][0, 0] for s in ordered_smiles]
y = [encodings_pca[s][0, 1] for s in ordered_smiles]
z = [edit_distances[s] for s in ordered_smiles]
    

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x, y, z)
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('approximate edit distance')
# NOTE(review): Axes3D.dist is deprecated in recent Matplotlib releases (zoom is
# set via set_box_aspect instead); kept here to preserve the original view.
ax.dist = 5
ax.azim = 60
plt.show()

# What happens along an edit path?

In [None]:
# Distinct approximate edit distances observed against the reference.
print(np.unique(list(edit_distances.values())))

# Show one molecule at approximate edit distance 4, if any.
example_smiles = next((s for s, dist in edit_distances.items() if dist == 4), None)
if example_smiles is not None:
    print(example_smiles)

# Exact optimal edit paths to a chosen target molecule. Pass the same node/edge
# matchers as the edit-distance helpers so atom and bond labels are respected;
# without them NetworkX ignores attributes and the cost is not comparable to the
# distances computed above. Exact search can be very slow on molecule-sized graphs.
paths, cost = nx.optimal_edit_paths(
    reference,
    graphs['CCN(CC)CCNC(=O)c1cc(Br)c(N)cc1OC'],
    node_match=nmatch,
    edge_match=ematch,
)