In [1]:
# Automatically re-import edited modules (e.g. the dfs_transformer sources) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers

In [3]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd

In [4]:
import sys
# Make the project sources importable. Guarded so re-running this cell does not
# keep prepending duplicate '../../src' entries to sys.path.
if '../../src' not in sys.path:
    sys.path.insert(0, '../../src')
from dfs_transformer import EarlyStopping, DFSCodeSeq2SeqFC, smiles2graph, BERTize



In [5]:
from dfs_transformer import DFSCodeSeq2SeqFCFeatures, Trainer, PubChem, get_n_files
from dfs_transformer.training.utils import seq_loss, seq_acc, collate_BERT, collate_rnd2min
import argparse
import yaml
import functools
from ml_collections import ConfigDict
# Experiment config; this and the other relative paths resolve against the kernel's working directory.
fname = '../../config/selfattn/bert10K.yaml'

In [6]:
# Parse the YAML experiment config into an attribute-accessible ConfigDict
# (sections are read below as config.model / config.training / config.data).
with open(fname) as cfg_file:
    raw_config = yaml.load(cfg_file, Loader=yaml.FullLoader)
config = ConfigDict(raw_config)

In [7]:
# Log this experiment to W&B locally (offline); the full config is attached to the run.
run = wandb.init(
    project="pubchem-experimental",
    entity="chrisxx",
    name="bert-10K",
    mode="offline",
    config=config.to_dict(),
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
2021-09-15 19:42:23.968033: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/opt/cuda/extras/CUPTI/lib64/:/opt/intel/lib:/opt/intel/mkl/lib/intel64:/opt/intel:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/bin/x86-64_linux:/opt/ibm/ILOG/CPLEX_Studio1210/cplex/python/3.7/x86-64_linux:/opt/intel/clck_latest/lib:/opt/intel/daal/lib:/opt/intel/intelpython3/lib:/opt/intel/ipp/lib:/opt/intel/itac_2019/lib:/opt/intel/itac_latest/lib:/opt/intel/mkl/lib:/opt/intel/mkl_/lib:/opt/intel/mpirt/lib:/opt/intel/tbb/lib:/opt/intel/clck/2019.0/lib:/opt/intel/compilers_and_libraries_2019/linux/lib:/opt/intel/compilers_and_libraries/linux/lib:/opt/intel/itac/2019.0.018/lib:/opt/intel/ita

In [8]:
# Short aliases for the three config sections used throughout the notebook.
m, t, d = config.model, config.training, config.data
if torch.cuda.is_available():
    device = torch.device('cuda:%d' % t.gpu_id)
else:
    device = torch.device('cpu')

# collate_fn pads targets with -1, so those positions are ignored by the CE terms.
ce = nn.CrossEntropyLoss(ignore_index=-1)
bce = nn.BCEWithLogitsLoss()

# One accuracy metric per predicted DFS-code column.
# NOTE(review): loss() unpacks predictions as (dfs1, dfs2, atm1, atm2, bnd) but reads
# the matching targets from columns (0, 1, 2, 4, 3). If seq_acc uses `idx` both for
# the prediction tuple and the target column, 'acc-atm2' and 'acc-bnd' compare
# mismatched heads/labels — confirm against seq_acc.
fields = ['acc-dfs1', 'acc-dfs2', 'acc-atm1', 'acc-atm2', 'acc-bnd']
metrics = {name: functools.partial(seq_acc, idx=col) for col, name in enumerate(fields)}

In [9]:
def collate_fn(dlist, fraction_missing=0.1):
    """Collate a list of PubChem graphs into a BERT-style batch.

    Each graph's minimal DFS code is extended column-wise with the node
    features of both endpoint atoms and the edge features of the bond, then
    the project helper ``BERTize`` is applied with ``fraction_missing``
    (presumably the share of rows to mask — see BERTize).

    Args:
        dlist: graphs exposing ``node_features``, ``edge_features`` and
            ``min_dfs_code``.
        fraction_missing: forwarded unchanged to ``BERTize``.

    Returns:
        (inputs, node_batch, edge_batch, targets) where ``inputs`` holds the
        first 8 columns of each BERTized code as LongTensors, ``node_batch`` /
        ``edge_batch`` are the per-graph feature tensors, and ``targets`` are
        the BERTize outputs padded with -1 via ``pad_sequence`` (sequence dim
        first, since batch_first is left at its default).
    """
    node_batch = [graph.node_features for graph in dlist]
    edge_batch = [graph.edge_features for graph in dlist]

    min_code_batch = []
    for graph in dlist:
        code = graph.min_dfs_code
        # the last three code columns index atom 1 (-3), the bond (-2) and atom 2 (-1)
        src_feats = graph.node_features[code[:, -3]]
        dst_feats = graph.node_features[code[:, -1]]
        bond_feats = graph.edge_features[code[:, -2]]
        min_code_batch.append(torch.cat((code, src_feats, dst_feats, bond_feats), dim=1))

    inputs, outputs = BERTize(min_code_batch, fraction_missing=fraction_missing)
    inputs = [inp[:, :8].long() for inp in inputs]
    targets = nn.utils.rnn.pad_sequence(outputs, padding_value=-1)
    return inputs, node_batch, edge_batch, targets

In [10]:
def loss(pred, target):
    """Masked-reconstruction loss for the BERT-style DFS-code objective.

    Args:
        pred: 6-tuple of logits (dfs1, dfs2, atm1, atm2, bnd, feat). Each is
            flattened to (-1, n_classes) below, so their leading dims must
            line up with the first two dims of ``target`` in the same order.
        target: float tensor with 8 + n_feat columns in its last dim, as built
            by ``collate_fn`` (sequence dim first, presumably unchanged by the
            Trainer). Columns 0-4 hold the dfs1, dfs2, atm1, bnd, atm2 labels;
            columns 8: hold the concatenated atom/atom/bond features. -1 marks
            padding (and presumably positions BERTize did not mask).

    Returns:
        Scalar tensor: five cross-entropy terms plus a BCE term on the features
        of target positions (tgt_dfs1 != -1). If the batch has no target
        position at all, a zero that is still attached to the autograd graph
        is returned instead of NaN.
    """
    dfs1, dfs2, atm1, atm2, bnd, feat = pred
    n_feat = 2*m.n_node_features + m.n_edge_features

    # reshape rather than view: view raises if target arrives non-contiguous
    tgt_dfs1 = target[:, :, 0].reshape(-1).long()
    tgt_dfs2 = target[:, :, 1].reshape(-1).long()
    tgt_atm1 = target[:, :, 2].reshape(-1).long()
    tgt_bnd = target[:, :, 3].reshape(-1).long()
    tgt_atm2 = target[:, :, 4].reshape(-1).long()
    tgt_feat = target[:, :, 8:].reshape(-1, n_feat)

    mask = tgt_dfs1 != -1
    if not mask.any():
        # Mean-reduced CE with every label ignored, and BCE over an empty
        # selection, are both 0/0 = NaN; a NaN loss would corrupt the weights
        # on backward. Return a differentiable zero instead.
        return feat.sum() * 0.0

    total = ce(torch.reshape(dfs1, (-1, m.max_nodes)), tgt_dfs1)
    total = total + ce(torch.reshape(dfs2, (-1, m.max_nodes)), tgt_dfs2)
    total = total + ce(torch.reshape(atm1, (-1, m.n_atoms)), tgt_atm1)
    total = total + ce(torch.reshape(bnd, (-1, m.n_bonds)), tgt_bnd)
    total = total + ce(torch.reshape(atm2, (-1, m.n_atoms)), tgt_atm2)

    pred_feat = torch.reshape(feat, (-1, n_feat))
    total = total + bce(pred_feat[mask], tgt_feat[mask])
    return total

In [11]:
model = DFSCodeSeq2SeqFCFeatures(**m)

# Warm-start: the early-stopping checkpoint when resuming (load_last),
# otherwise the pretrained weights, if any are configured.
# NOTE(review): the save cell treats trainer.es_path as a directory
# (es_path + 'config.yaml', add_dir(es_path)); torch.load(t.es_path) here
# needs a file — confirm the two refer to different paths.
if t.load_last and t.es_path is not None:
    ckpt_path = t.es_path
else:
    ckpt_path = t.pretrained_dir
if ckpt_path is not None:
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [12]:
# Optional validation set. Its SMILES are passed as `exclude=` to every training
# PubChem split built later, so `exclude` must be bound even when no validation
# set is configured (an empty list excludes nothing).
validloader = None
exclude = []
if d.valid_path is not None:
    validset = PubChem('../.'+d.valid_path, max_nodes=m.max_nodes, max_edges=m.max_edges)
    validloader = DataLoader(validset, batch_size=d.batch_size, shuffle=True, 
                             pin_memory=False, collate_fn=collate_fn)
    exclude = validset.smiles

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 16.18it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9978/9978 [00:00<00:00, 14780.83it/s]


In [13]:
# Sanity check: draw one batch from the validation loader; collate_fn returns
# (inputs, node_batch, edge_batch, targets). Fails if no validation set is
# configured (validloader is None).
data = next(iter(validloader))

In [14]:
# data[-1] are the padded targets (pad_sequence, padding_value=-1, sequence dim
# first); [1] selects the second sequence position for every graph in the batch.
data[-1][1]

tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        ...,
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.],
        [-1., -1., -1.,  ..., -1., -1., -1.]])

In [15]:
trainer = Trainer(
    model, None, loss,  # no training loader yet; one is attached per data split below
    validloader=validloader,
    metrics=metrics,
    wandb_run=run,
    **t,
)
# fit() is called once per data split, so this is presumably the number of
# Trainer epochs run on each split.
trainer.n_epochs = d.n_iter_per_split

In [16]:
# Number of PubChem subsets ("splits") visited per epoch; d.n_used is presumably
# the number of files loaded per split.
n_files = get_n_files('../.'+d.path)
if d.n_used is None:
    n_splits = 1  # use everything as a single split
else:
    # Floor division gives 0 when n_used > n_files, which would silently skip
    # all training; always run at least one split.
    n_splits = max(1, n_files // d.n_used)

In [None]:
# Each epoch walks over n_splits freshly built PubChem subsets; the Trainer
# fits on each one and may request early stopping, which ends both loops.
for epoch in range(t.n_epochs):
    print(f'starting epoch {epoch + 1}')
    for split in range(n_splits):
        dataset = PubChem('../.' + d.path, n_used=d.n_used,
                          max_nodes=m.max_nodes, max_edges=m.max_edges,
                          exclude=exclude)
        loader = DataLoader(dataset, batch_size=d.batch_size, shuffle=True,
                            pin_memory=False, collate_fn=collate_fn)
        trainer.loader = loader
        trainer.fit()
        if trainer.stop_training:
            break  # leave the split loop ...
    if trainer.stop_training:
        break  # ... and the epoch loop

starting epoch 1


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 14.77it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9983/9983 [00:00<00:00, 14008.04it/s]
Epoch 1: loss 10.721580 0.0417 0.0556 0.7222 0.7222 0.5556: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:42<00:00,  4.66it/s]
Valid 1: loss 9.652095 0.0545 0.0909 0.6909 0.7636 0.5818: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:17<00:00, 11.45it/s]
Epoch 2: loss 9.462198 0.0349 0.0581 0.8605 0.6512 0.5233: 100%|████████████████████████████████████████████████████████████████████████████████████████

In [None]:
d = next(iter(loader))

In [None]:
np.unique(torch.cat(d[1]).numpy())

In [None]:
np.unique(torch.cat(d[2]).numpy())

In [None]:
d[2]

In [None]:
# Store the config next to the checkpoint directory and, when the run is not
# offline, upload that directory as a W&B model artifact.
# `args` only exists in the CLI training script; in this notebook the run name
# comes from the W&B run and the mode mirrors wandb.init (keep in sync).
wandb_name = run.name
wandb_mode = "offline"
with open(os.path.join(trainer.es_path, 'config.yaml'), 'w') as f:
    yaml.dump(config.to_dict(), f, default_flow_style=False)
if wandb_name is not None and wandb_mode != "offline":
    trained_model_artifact = wandb.Artifact(wandb_name, type="model", description="trained selfattn model")
    trained_model_artifact.add_dir(trainer.es_path)
    run.log_artifact(trained_model_artifact)