In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers

In [3]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd
#torch.multiprocessing.set_sharing_strategy('file_system') # this is important on local machine
#def set_worker_sharing_strategy(worker_id: int) -> None:
#    torch.multiprocessing.set_sharing_strategy('file_system')

In [4]:
import sys
sys.path = ['../../src'] + sys.path
from dfs_transformer import EarlyStopping, DFSCodeSeq2SeqFC

2021-09-01 14:41:40.010184: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-09-01 14:41:40.010213: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [5]:
wandb.init(project='chembl', entity='chrisxx')

config = wandb.config
config.n_atoms = 118
config.n_bonds = 4
config.emb_dim = 120
config.nhead = 12
config.nlayers = 6
config.max_nodes = 400
config.max_edges = 600
config.dim_feedforward = 2048
config.lr = 0.000011111
config.n_epochs = 10000
config.lr_adjustment_period = 500
config.patience = 5
config.factor = 0.96
config.minimal_lr = 6e-8
config.batch_size = 100
config.accumulate_grads = 2
config.valid_patience = 10000
config.valid_minimal_improvement=0.00
config.model_dir = '../../models/chemblH/better_transformer/medium/'
config.path = "/mnt/project/chembl64/leq200_timeout10_sanity"
config.graph_pattern = "/home/wendlerc/ChEMBL/preprocessedPlusHs_split%d.pt"
config.csv_file = "/home/wendlerc/ChEMBL/small_molecules.csv"
config.pretrained_dir = "../../models/chembl/better_transformer/medium/"
config.num_workers = 0
config.prefetch_factor = 2
config.persistent_workers = False
config.load_last = True


[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
2021-09-01 14:41:57.182046: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-09-01 14:41:57.182073: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [6]:
os.makedirs(config.model_dir, exist_ok=True)

In [7]:
torch.multiprocessing.get_all_sharing_strategies()

{'file_descriptor', 'file_system'}

In [8]:
class ChEMBL(InMemoryDataset):
    def __init__(self, fname):
        super().__init__()
        self.data, self.slices = torch.load(fname)

class ChEMBL100NoH(Dataset):
    """ChEMBL dataset of molecules and minimal DFS codes."""
    # create data structure that says which id is in which file...
    def __init__(self, path=config.path,
                       graph_pattern=config.graph_pattern, 
                       csv_file=config.csv_file, 
                       idx_start=0,
                       idx_max=64,
                       transform=None):
        self.path = path
        self.fnames = glob.glob(path+"CHEMBL*.pkl")
        self.data = []
        self.csv_file = csv_file
        self.path = path
        self.graph_pattern = graph_pattern 
        self.idx_start = idx_start
        self.idx_max = idx_max
        self.prepare()
        
        
    def prepare(self):
        codes_all = {}
        i2didx = {}
        for i in tqdm.tqdm(range(self.idx_start, self.idx_max)):
            dname = glob.glob(config.path+"/%d/min_dfs_codes_split*.json"%(i+1))[0]
            didx = int(dname.split("split")[-1][:-5])
            i2didx[i] = didx
            with open(dname, 'r') as f:
                codes = json.load(f)
                for key, val in codes.items():
                    codes_all[key] = val
        graph_pattern = self.graph_pattern
        df = pd.read_csv(self.csv_file, delimiter=';', low_memory=False)
        chembl2smiles = {cid:smiles for cid, smiles in zip(df['ChEMBL ID'], df['Smiles'])}
        for i in range(self.idx_start, self.idx_max):
            dataset = ChEMBL(graph_pattern%i2didx[i])
            print(i)
            for data in tqdm.tqdm(dataset):
                if data.name in codes_all:
                    code = codes_all[data.name]
                    d = {'x':data.x.numpy(),
                          'z':data.z.numpy(),
                          'edge_attr': data.edge_attr.numpy(),
                          'edge_index': data.edge_index.numpy(),
                          'name':data.name,
                          'min_dfs_code': np.asarray(code['min_dfs_code']),
                          'min_dfs_index': np.asarray(code['dfs_index']),
                          'smiles': chembl2smiles[data.name]}
                    data_ = Data(x=torch.tensor(d['x']),
                                z=torch.tensor(d['z']),
                                edge_attr=torch.tensor(d['edge_attr']),
                                edge_index=torch.tensor(d['edge_index'], dtype=torch.long),
                                name=d['name'],
                                min_dfs_code=torch.tensor(d['min_dfs_code']),
                                min_dfs_index=torch.tensor(d['min_dfs_index'], dtype=torch.long),
                                smiles=d['smiles'])
                    self.data += [data_]   
            del dataset
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [9]:
def collate_fn(dlist):
    x_batch = [] 
    z_batch = []
    edge_attr_batch = []
    rnd_code_batch = []
    min_code_batch = []
    for d in dlist:
        rnd_code, rnd_index = dfs_code.rnd_dfs_code_from_torch_geometric(d, 
                                                                         d.z.numpy().tolist(), 
                                                                         np.argmax(d.edge_attr.numpy(), axis=1))
        x_batch += [d.x]
        z_batch += [d.z]#[nn.functional.one_hot(d.z, 118)]#118 elements in periodic table
        edge_attr_batch += [d.edge_attr]
        rnd_code_batch += [torch.tensor(rnd_code)]
        min_code_batch += [d.min_dfs_code]
    return rnd_code_batch, x_batch, z_batch, edge_attr_batch, min_code_batch

In [10]:
ngpu=1
device = torch.device('cuda:1' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

In [11]:
dataset = ChEMBL100NoH(idx_start=16, idx_max=16+32)

100%|███████████████████████████████████████████| 32/32 [01:46<00:00,  3.34s/it]


16


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4755.55it/s]


17


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4828.56it/s]


18


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4872.23it/s]


19


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4852.55it/s]


20


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4827.16it/s]


21


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4899.36it/s]


22


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4810.64it/s]


23


100%|███████████████████████████████████| 29917/29917 [00:06<00:00, 4728.33it/s]


24


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4767.87it/s]


25


100%|███████████████████████████████████| 29916/29916 [00:22<00:00, 1328.67it/s]


26


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4956.65it/s]


27


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4934.52it/s]


28


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4760.40it/s]


29


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4710.16it/s]


30


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4692.14it/s]


31


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4626.13it/s]


32


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4652.47it/s]


33


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4529.35it/s]


34


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4707.69it/s]


35


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4657.29it/s]


36


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4707.65it/s]


37


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4605.50it/s]


38


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4571.53it/s]


39


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4640.01it/s]


40


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4645.40it/s]


41


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4701.60it/s]


42


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4717.70it/s]


43


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4668.36it/s]


44


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4739.26it/s]


45


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4782.97it/s]


46


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4736.16it/s]


47


100%|███████████████████████████████████| 29916/29916 [00:06<00:00, 4725.15it/s]


In [12]:
loader = DataLoader(dataset, 
                    batch_size=config.batch_size, 
                    shuffle=True, 
                    pin_memory=False, 
                    collate_fn=collate_fn,
                    num_workers=config.num_workers, 
                    prefetch_factor=config.prefetch_factor, 
                    persistent_workers=config.persistent_workers)

In [13]:
to_cuda = lambda T: [t.to(device) for t in T]

In [14]:
model = DFSCodeSeq2SeqFC(n_atoms=config.n_atoms,
                         n_bonds=config.n_bonds, 
                         emb_dim=config.emb_dim, 
                         nhead=config.nhead, 
                         nlayers=config.nlayers, 
                         max_nodes=config.max_nodes, 
                         max_edges=config.max_edges,
                         atom_encoder=nn.Embedding(config.n_atoms, config.emb_dim), 
                         bond_encoder=nn.Linear(config.n_bonds, config.emb_dim))

In [15]:
if config.pretrained_dir is not None:
    model.load_state_dict(torch.load(config.pretrained_dir+'checkpoint.pt', map_location=device))
if config.load_last:
    model.load_state_dict(torch.load(config.model_dir+'checkpoint.pt', map_location=device))

In [16]:
optim = optimizers.Adam(model.parameters(), lr=config.lr)

#lr_scheduler = optimizers.lr_scheduler.ReduceLROnPlateau(optim, mode='min', verbose=True, patience=config.patience, factor=config.factor)
lr_scheduler = optimizers.lr_scheduler.ExponentialLR(optim, gamma=config.factor)

early_stopping = EarlyStopping(patience=config.valid_patience, delta=config.valid_minimal_improvement,
                              path=config.model_dir+'checkpoint.pt')
bce = torch.nn.BCEWithLogitsLoss()
ce = torch.nn.CrossEntropyLoss(ignore_index=-1)
softmax = nn.Softmax(dim=2)

In [17]:
model.to(device)

DFSCodeSeq2SeqFC(
  (encoder): DFSCodeEncoder(
    (emb_dfs): PositionalEncoding(
      (dropout): Dropout(p=0, inplace=False)
    )
    (emb_seq): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_atom): Embedding(118, 120)
    (emb_bond): Linear(in_features=4, out_features=120, bias=True)
    (mixer): Linear(in_features=600, out_features=600, bias=True)
    (enc): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=600, out_features=600, bias=True)
          )
          (linear1): Linear(in_features=600, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=600, bias=True)
          (norm1): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
         

In [None]:
try:
    for epoch in range(config.n_epochs):  
        epoch_loss = 0
        pbar = tqdm.tqdm(loader)
        for i, data in enumerate(pbar):
            if i % config.accumulate_grads == 0: #bei 0 wollen wir das
                optim.zero_grad()
            rndc, x, z, eattr, minc = data
            rndc = to_cuda(rndc)
            z = to_cuda(z)
            eattr = to_cuda(eattr)
            minc = to_cuda(minc)
            #prepare labels
            minc = [torch.cat((c, (-1)*torch.ones((1, 8), dtype=torch.long, device=device)), dim=0) for c in minc]
            minc_seq = nn.utils.rnn.pad_sequence(minc, padding_value=-1)
            
            #prediction
            dfs1, dfs2, atm1, atm2, bnd, eos = model(rndc, z, eattr)
            eos_label = (minc_seq[:,:,0] == (-1))
            loss = ce(torch.reshape(dfs1, (-1, config.max_nodes)), minc_seq[:, :, 0].view(-1)) 
            loss += ce(torch.reshape(dfs2, (-1, config.max_nodes)), minc_seq[:, :, 1].view(-1))
            loss += ce(torch.reshape(atm1, (-1, config.n_atoms)), minc_seq[:, :, 2].view(-1))
            loss += ce(torch.reshape(bnd, (-1, config.n_bonds)), minc_seq[:, :, 3].view(-1))
            loss += ce(torch.reshape(atm2, (-1, config.n_atoms)), minc_seq[:, :, 4].view(-1))
            loss += bce(eos, torch.unsqueeze(eos_label.float(), -1))
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            if (i+1) % config.accumulate_grads == 0:
                optim.step() # bei 2 wollen wir das
            epoch_loss = (epoch_loss*i + loss.item())/(i+1)
            
            curr_lr = list(optim.param_groups)[0]['lr']
            wandb.log({'loss':epoch_loss, 
                   'learning rate':curr_lr})
            pbar.set_description('Epoch %d: CE %2.6f'%(epoch+1, epoch_loss))
            
            if i % config.lr_adjustment_period == 0:
                early_stopping(epoch_loss, model)
                
        lr_scheduler.step()#(epoch_loss)
        
        if early_stopping.early_stop:
            break

        if curr_lr < config.minimal_lr:
            break

except KeyboardInterrupt:
    torch.save(model.state_dict(), config.model_dir+'_keyboardinterrupt.pt')
    print('keyboard interrupt caught')


Epoch 1: CE 0.726162:   6%|█                 | 501/8611 [03:22<55:49,  2.42it/s]

EarlyStopping counter: 1 out of 10000


Epoch 1: CE 0.727141:  12%|█▉               | 1001/8611 [06:39<47:49,  2.65it/s]

EarlyStopping counter: 2 out of 10000


Epoch 1: CE 0.724648:  17%|██▉              | 1501/8611 [09:55<48:29,  2.44it/s]

EarlyStopping counter: 3 out of 10000


Epoch 1: CE 0.723486:  23%|███▉             | 2001/8611 [13:10<40:48,  2.70it/s]

EarlyStopping counter: 4 out of 10000


Epoch 1: CE 0.722601:  29%|████▉            | 2501/8611 [16:27<40:24,  2.52it/s]

EarlyStopping counter: 5 out of 10000


Epoch 1: CE 0.722343:  35%|█████▉           | 3001/8611 [19:42<36:50,  2.54it/s]

EarlyStopping counter: 6 out of 10000


Epoch 1: CE 0.721426:  41%|██████▉          | 3501/8611 [22:59<37:21,  2.28it/s]

EarlyStopping counter: 7 out of 10000


Epoch 1: CE 0.720557:  46%|███████▉         | 4001/8611 [26:15<31:03,  2.47it/s]

EarlyStopping counter: 8 out of 10000


Epoch 1: CE 0.719677:  52%|████████▉        | 4501/8611 [29:31<27:14,  2.51it/s]

EarlyStopping counter: 9 out of 10000


Epoch 1: CE 0.719547:  58%|█████████▊       | 5001/8611 [32:48<22:53,  2.63it/s]

EarlyStopping counter: 10 out of 10000


Epoch 1: CE 0.718849:  64%|██████████▊      | 5501/8611 [36:05<21:54,  2.37it/s]

EarlyStopping counter: 11 out of 10000


Epoch 1: CE 0.718569:  70%|███████████▊     | 6001/8611 [39:21<17:26,  2.49it/s]

EarlyStopping counter: 12 out of 10000


Epoch 1: CE 0.718231:  75%|████████████▊    | 6501/8611 [42:38<13:31,  2.60it/s]

EarlyStopping counter: 13 out of 10000


Epoch 1: CE 0.718152:  81%|█████████████▊   | 7001/8611 [45:55<10:25,  2.57it/s]

EarlyStopping counter: 14 out of 10000


Epoch 1: CE 0.718951:  87%|██████████████▊  | 7501/8611 [49:12<07:06,  2.60it/s]

EarlyStopping counter: 15 out of 10000


Epoch 1: CE 0.718562:  93%|███████████████▊ | 8001/8611 [52:29<04:02,  2.52it/s]

EarlyStopping counter: 16 out of 10000


Epoch 1: CE 0.719005:  99%|████████████████▊| 8501/8611 [55:45<00:43,  2.54it/s]

EarlyStopping counter: 17 out of 10000


Epoch 1: CE 0.719210: 100%|█████████████████| 8611/8611 [56:28<00:00,  2.54it/s]
Epoch 2: CE 0.642413:   0%|                  | 1/8611 [00:00<1:01:00,  2.35it/s]

EarlyStopping counter: 18 out of 10000


Epoch 2: CE 0.722655:   6%|█                 | 501/8611 [03:15<52:00,  2.60it/s]

EarlyStopping counter: 19 out of 10000


Epoch 2: CE 0.719823:  12%|█▉               | 1001/8611 [06:31<52:50,  2.40it/s]

EarlyStopping counter: 20 out of 10000


Epoch 2: CE 0.717396:  17%|██▉              | 1501/8611 [09:48<44:20,  2.67it/s]

EarlyStopping counter: 21 out of 10000


Epoch 2: CE 0.717074:  19%|███▏             | 1628/8611 [10:38<45:40,  2.55it/s]