In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers

In [3]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
torch.multiprocessing.set_sharing_strategy('file_system') # this is important
# ulimit -n 500000
def set_worker_sharing_strategy(worker_id: int) -> None:
    torch.multiprocessing.set_sharing_strategy('file_system')

In [4]:
import sys
sys.path = ['/home/chrisw/Documents/projects/2021/graph-transformer/src'] + sys.path
from dfs_transformer import EarlyStopping, DFSCodeSeq2SeqFC

In [9]:
config = wandb.config
config.max_nodes = 100
config.max_edges = 200
config.nlayers = 6
config.emb_dim = 50
config.nhead = 5
config.dim_feedforward = 2*(5*config.emb_dim)
config.lr = 0.0003
config.n_epochs = 10000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
config.target_idx = 7
config.batch_size = 512#256
config.valid_patience = 100
config.valid_minimal_improvement=0.00
config.model_dir = '../models/chembl/transformer/mini/'
config.num_workers = 4
config.dfs_codes = None

In [10]:
os.makedirs(config.model_dir, exist_ok=True)

In [11]:
torch.multiprocessing.get_all_sharing_strategies()

{'file_descriptor', 'file_system'}

In [12]:
path = "/mnt/ssd/datasets/ChEMBL/ChEMBL100_noH/"

In [13]:
class ChEMBL100NoH(Dataset):
    """ChEMBL dataset of molecules and minimal DFS codes."""
    # create data structure that says which id is in which file...
    def __init__(self, path, transform=None):
        self.path = path
        self.fnames = glob.glob(path+"CHEMBL*.pkl")
        
    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        with open(self.fnames[idx], 'rb') as f:
            d = pickle.load(f)
        data = Data(x=torch.tensor(d['x']),
                    z=torch.tensor(d['z']),
                    edge_attr=torch.tensor(d['edge_attr']),
                    edge_index=torch.tensor(d['edge_index'], dtype=torch.long),
                    name=d['name'],
                    min_dfs_code=torch.tensor(d['min_dfs_code']),
                    min_dfs_index=torch.tensor(d['min_dfs_index'], dtype=torch.long),
                    smiles=d['smiles'])
        return data

In [14]:
def collate_fn(dlist):
    x_batch = [] 
    z_batch = []
    edge_attr_batch = []
    rnd_code_batch = []
    min_code_batch = []
    for d in dlist:
        rnd_code, rnd_index = dfs_code.rnd_dfs_code_from_torch_geometric(d, 
                                                                         d.z.numpy().tolist(), 
                                                                         np.argmax(d.edge_attr.numpy(), axis=1))
        x_batch += [d.x]
        z_batch += [d.z]#[nn.functional.one_hot(d.z, 118)]#118 elements in periodic table
        edge_attr_batch += [d.edge_attr]
        rnd_code_batch += [torch.tensor(rnd_code)]
        min_code_batch += [d.min_dfs_code]
    return rnd_code_batch, x_batch, z_batch, edge_attr_batch, min_code_batch

In [15]:
ngpu=1
device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

In [16]:
dataset = ChEMBL100NoH(path)

In [17]:
loader = DataLoader(dataset, batch_size=256, shuffle=False, pin_memory=False, collate_fn=collate_fn)
                  # worker_init_fn=set_worker_sharing_strategy)
#shuffle False -> huge speedup

In [18]:
to_cuda = lambda T: [t.cuda() for t in T]

In [19]:
model = DFSCodeSeq2SeqFC(n_atoms=118, n_bonds=4, emb_dim=50, nhead=5, nlayers=6, max_nodes=100, max_edges=400,
                         atom_encoder=nn.Embedding(118, 50), bond_encoder=nn.Linear(4, 50))

In [20]:
model.load_state_dict(torch.load(config.model_dir+'checkpoint.pt'))

<All keys matched successfully>

In [21]:
optim = optimizers.Adam(model.parameters(), lr=config.lr)

lr_scheduler = optimizers.lr_scheduler.ReduceLROnPlateau(optim, mode='min', verbose=True, patience=config.patience, factor=config.factor)
early_stopping = EarlyStopping(patience=config.valid_patience, delta=config.valid_minimal_improvement,
                              path=config.model_dir+'checkpoint.pt')
bce = torch.nn.BCEWithLogitsLoss()
ce = torch.nn.CrossEntropyLoss(ignore_index=-1)
softmax = nn.Softmax(dim=2)

In [22]:
model.to(device)

DFSCodeSeq2SeqFC(
  (encoder): DFSCodeEncoder(
    (emb_dfs): PositionalEncoding(
      (dropout): Dropout(p=0, inplace=False)
    )
    (emb_seq): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_atom): Embedding(118, 50)
    (emb_bond): Linear(in_features=4, out_features=50, bias=True)
    (enc): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=250, out_features=250, bias=True)
          )
          (linear1): Linear(in_features=250, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=250, bias=True)
          (norm1): LayerNorm((250,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((250,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dro

In [23]:
data = next(iter(loader))

In [87]:
rndc, x, z, eattr, minc = data
rndc = to_cuda(rndc)
z = to_cuda(z)
eattr = to_cuda(eattr)
minc = to_cuda(minc)
#prepare labels
minc = [torch.cat((c, (-1)*torch.ones((1, 8), dtype=torch.long, device=device)), dim=0) for c in minc]
minc_seq = nn.utils.rnn.pad_sequence(minc, padding_value=-1)

rndc = [torch.cat((c, (-1)*torch.ones((1, 8), dtype=torch.long, device=device)), dim=0) for c in rndc]
rndc_seq = nn.utils.rnn.pad_sequence(rndc, padding_value=-1)
#prediction
dfs1, dfs2, atm1, atm2, bnd, eos = model(rndc, z, eattr)
eos_label = (minc_seq[:,:,0] == (-1))

In [32]:
dfsidx1 = torch.argmax(softmax(dfs1), dim=2)
dfsidx2 = torch.argmax(softmax(dfs2), dim=2)
atmnr1 = torch.argmax(softmax(atm1), dim=2)
atmnr2 = torch.argmax(softmax(atm2), dim=2)
bndnr = torch.argmax(softmax(bnd), dim=2)

In [159]:
j = 123

In [160]:
mask = minc_seq[:, :, 0][:, j] != -1

In [161]:
rndc_seq[:, :, 0][:, j][mask]

tensor([ 0,  1,  2,  3,  3,  0,  6,  0,  7,  8,  8,  8, 11, 12, 11],
       device='cuda:0')

In [162]:
dfsidx1[:, j][mask]

tensor([ 0,  1,  2,  3,  3,  1,  0,  7,  8,  9, 10, 11, 11, 13, 14],
       device='cuda:0')

In [163]:
minc_seq[:, :, 0][:, j][mask]

tensor([ 0,  1,  2,  3,  2,  1,  1,  7,  8,  9, 10, 11, 11, 13, 14],
       device='cuda:0')

In [164]:
rndc_seq[:, :, 1][:, j][mask]

tensor([ 1,  2,  3,  4,  5,  6,  5,  7,  8,  9, 10, 11, 12, 13, 14],
       device='cuda:0')

In [165]:
dfsidx2[:, j][mask]

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  8],
       device='cuda:0')

In [166]:
minc_seq[:, :, 1][:, j][mask]

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  8],
       device='cuda:0')

In [167]:
rndc_seq[:, :, 2][:, j][mask]

tensor([6, 6, 6, 6, 6, 6, 6, 6, 8, 6, 6, 6, 6, 7, 6], device='cuda:0')

In [168]:
atmnr1[:, j][mask]

tensor([6, 6, 6, 6, 7, 6, 6, 8, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')

In [169]:
minc_seq[:, :, 2][:, j][mask]

tensor([6, 6, 6, 7, 6, 6, 6, 8, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')

In [170]:
rndc_seq[:, :, 4][:, j][mask]

tensor([ 6,  6,  6, 17,  6,  6,  6,  8,  6,  6,  6,  6,  7,  7,  8],
       device='cuda:0')

In [171]:
atmnr2[:, j][mask]

tensor([ 6,  6,  6,  7,  8,  8,  8,  6,  6,  6,  6, 17,  6,  6,  6],
       device='cuda:0')

In [172]:
minc_seq[:, :, 4][:, j][mask]

tensor([ 6,  6,  7,  7,  8,  6,  8,  6,  6,  6,  6, 17,  6,  6,  6],
       device='cuda:0')

In [173]:
rndc_seq[:, :, 3][:, j][mask]

tensor([2, 2, 2, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 1], device='cuda:0')

In [174]:
bndnr[:, j][mask]

tensor([0, 0, 0, 0, 1, 0, 0, 0, 2, 2, 2, 0, 2, 2, 2], device='cuda:0')

In [175]:
minc_seq[:, :, 3][:, j][mask]

tensor([0, 0, 0, 0, 1, 0, 0, 0, 2, 2, 2, 0, 2, 2, 2], device='cuda:0')

In [176]:
(eos_label[:, j]==False) == mask

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True], device='cuda:0')

In [None]:
try:
    for epoch in range(config.n_epochs):  
        epoch_loss = 0
        pbar = tqdm.tqdm(loader)
        for i, data in enumerate(pbar):
            rndc, x, z, eattr, minc = data
            rndc = to_cuda(rndc)
            z = to_cuda(z)
            eattr = to_cuda(eattr)
            minc = to_cuda(minc)
            #prepare labels
            minc = [torch.cat((c, (-1)*torch.ones((1, 8), dtype=torch.long, device=device)), dim=0) for c in minc]
            minc_seq = nn.utils.rnn.pad_sequence(minc, padding_value=-1)
            
            #prediction
            dfs1, dfs2, atm1, atm2, bnd, eos = model(rndc, z, eattr)
            eos_label = (minc_seq[:,:,0] == (-1))
            #minc_seq[eos_label] = 0
            #print(np.unique(minc_seq[:,:,0].cpu().numpy()))
            #print(np.unique(minc_seq[:,:,1].cpu().numpy()))
            #print(np.unique(minc_seq[:,:,2].cpu().numpy()))
            #print(np.unique(minc_seq[:,:,3].cpu().numpy()))
            #print(np.unique(minc_seq[:,:,4].cpu().numpy()))
            #TODO: use ignore_index
            loss = ce(torch.reshape(dfs1, (-1, 100)), minc_seq[:, :, 0].view(-1)) 
            loss += ce(torch.reshape(dfs2, (-1, 100)), minc_seq[:, :, 1].view(-1))
            loss += ce(torch.reshape(atm1, (-1, 118)), minc_seq[:, :, 2].view(-1))
            loss += ce(torch.reshape(bnd, (-1, 4)), minc_seq[:, :, 3].view(-1))
            loss += ce(torch.reshape(atm2, (-1, 118)), minc_seq[:, :, 4].view(-1))
            loss += bce(eos, torch.unsqueeze(eos_label.float(), -1))
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            epoch_loss = (epoch_loss*i + loss.item())/(i+1)
            
            pbar.set_description('Epoch %d: CE %2.6f'%(epoch+1, epoch_loss))


        lr_scheduler.step(epoch_loss)
        early_stopping(epoch_loss, model)
        curr_lr = list(optim.param_groups)[0]['lr']

        if early_stopping.early_stop:
            break

        if curr_lr < config.minimal_lr:
            break

except KeyboardInterrupt:
    print('keyboard interrupt caught')


Epoch 1: MAE/CA 4.238406: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7454/7454 [1:11:54<00:00,  1.73it/s]
Epoch 2: MAE/CA 3.486167: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7454/7454 [1:11:04<00:00,  1.75it/s]
Epoch 3: MAE/CA 3.134876: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7454/7454 [1:11:06<00:00,  1.75it/s]
Epoch 4: MAE/CA 2.903316: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7454/7454 [1:11:04<00:00,  1.75it/s]
Epoch 5: MAE/CA 2.726842: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████