In [1]:
import torch
import sklearn
import torch.nn as nn
import os

from torch.utils.data import DataLoader
from torch.optim import Adam, lr_scheduler

from dgl.data.utils import Subset
from dgllife.utils import WeaveAtomFeaturizer, CanonicalBondFeaturizer, smiles_to_bigraph, EarlyStopping

from models import LocalRetro

from utils import init_featurizer, mkdir_p, get_configure, load_model, load_dataloader, predict
import tqdm as notebook_tqdm

def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_loss = 0
    train_acc = 0
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, atom_labels, bond_labels = batch_data
        if len(smiles) == 1:
            continue
           
        atom_labels, bond_labels = atom_labels.to(args['device']), bond_labels.to(args['device'])
        atom_logits, bond_logits, _ = predict(args, model, bg)

        loss_a = loss_criterion(atom_logits, atom_labels)
        loss_b = loss_criterion(bond_logits, bond_labels)
        total_loss = torch.cat([loss_a, loss_b]).mean()
        train_loss += total_loss.item()
        
        optimizer.zero_grad()      
        total_loss.backward() 
        nn.utils.clip_grad_norm_(model.parameters(), args['max_clip'])
        optimizer.step()
                
        if batch_id % args['print_every'] == 0:
            print('\repoch %d/%d, batch %d/%d, loss %.4f' % (epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), total_loss), end='', flush=True)

    print('\nepoch %d/%d, training loss: %.4f' % (epoch + 1, args['num_epochs'], train_loss/batch_id))

def run_an_eval_epoch(args, model, data_loader, loss_criterion):
    model.eval()
    val_loss = 0
    val_acc = 0
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, atom_labels, bond_labels = batch_data
            atom_labels, bond_labels = atom_labels.to(args['device']), bond_labels.to(args['device'])
            atom_logits, bond_logits, _ = predict(args, model, bg)
            
            loss_a = loss_criterion(atom_logits, atom_labels)
            loss_b = loss_criterion(bond_logits, bond_labels)
            total_loss = torch.cat([loss_a, loss_b]).mean()
            val_loss += total_loss.item()
    return val_loss/batch_id


def main(args):
    #KKK
    model_path = '../modelsC'
    model_name = 'LocalRetro_%s.pth' % args['dataset']
    args['model_path'] = model_path +'/' + model_name
    args['config_path'] = '../data/configs/%s' % args['config']
    args['data_dir'] = '../data/%s' % args['dataset']
    mkdir_p(model_path)                          
    args = init_featurizer(args)
    model, loss_criterion, optimizer, scheduler, stopper = load_model(args)   
    train_loader, val_loader, test_loader = load_dataloader(args)
    for epoch in range(args['num_epochs']):
        print("Running",epoch)
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
        val_loss = run_an_eval_epoch(args, model, val_loader, loss_criterion)
        early_stop = stopper.step(val_loss, model) 
        scheduler.step()
        print('epoch %d/%d, validation loss: %.4f' %  (epoch + 1, args['num_epochs'], val_loss))
        print('epoch %d/%d, Best loss: %.4f' % (epoch + 1, args['num_epochs'], stopper.best_score))
        if early_stop:
            print ('Early stopped!!')
            break

    stopper.load_checkpoint(model)
    test_loss = run_an_eval_epoch(args, model, test_loader, loss_criterion)
    print('test loss: %.4f' % test_loss)

import tqdm as notebook_tqdm
args = {'gpu': 'cuda:0', 'dataset': 'USPTO_50K', 'config': 'default_config.json', 'batch_size': 1, 'num_epochs': 50, 'patience': 5, 'max_clip': 20, 'learning_rate': 0.0001, 'weight_decay': 1e-06, 'schedule_step': 10, 'num_workers': 0, 'print_every': 20, 'mode': 'train'}
args['device'] = torch.device(args['gpu']) if torch.cuda.is_available() else torch.device('cpu')    

In [25]:
from dgllife.model import MPNNGNN
from model_utils import pair_atom_feats, unbatch_mask, unbatch_feats, Global_Reactivity_Attention
def pair_atom_feats(g, node_feats):
    sg = g.remove_self_loop() # in case g includes self-loop
    atom_pair_list = torch.transpose(sg.adjacency_matrix().coalesce().indices(), 0, 1)
    atom_pair_idx1 = atom_pair_list[:,0]
    atom_pair_idx2 = atom_pair_list[:,1]
    atom_pair_feats = torch.cat((node_feats[atom_pair_idx1], node_feats[atom_pair_idx2]), dim = 1)
    print("pair_atom_feats",atom_pair_feats)
    return atom_pair_feats

class LocalRetroC(nn.Module):
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats,
                 edge_hidden_feats,
                 num_step_message_passing,
                 attention_heads,
                 attention_layers,
                 AtomTemplate_n, 
                 BondTemplate_n):
        super(LocalRetroC, self).__init__()
        print("node_out_feats",node_out_feats)                
        self.mpnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        
        self.linearB = nn.Linear(node_out_feats*2, node_out_feats)

        self.att = Global_Reactivity_Attention(node_out_feats, attention_heads, attention_layers)
        
        self.atom_linear =  nn.Sequential(
                            nn.Linear(node_out_feats, node_out_feats), 
                            nn.ReLU(), 
                            nn.Dropout(0.2),
                            nn.Linear(node_out_feats, AtomTemplate_n+1))
        self.bond_linear =  nn.Sequential(
                            nn.Linear(node_out_feats, node_out_feats), 
                            nn.ReLU(), 
                            nn.Dropout(0.2),
                            nn.Linear(node_out_feats, BondTemplate_n+1))

    def forward(self, g, node_feats, edge_feats):
        node_feats = self.mpnn(g, node_feats, edge_feats)
        atom_feats = node_feats
        bond_feats = self.linearB(pair_atom_feats(g, node_feats))
        edit_feats, mask = unbatch_mask(g, atom_feats, bond_feats)
        attention_score, edit_feats = self.att(edit_feats, mask)
           
        atom_feats, bond_feats = unbatch_feats(g, edit_feats)
        atom_outs = self.atom_linear(atom_feats) 
        bond_outs = self.bond_linear(bond_feats) 

        return atom_outs, bond_outs, attention_score


def load_model(args):
    exp_config = get_configure(args)
    model = LocalRetroC(
        node_in_feats=exp_config['in_node_feats'],
        edge_in_feats=exp_config['in_edge_feats'],
        node_out_feats=exp_config['node_out_feats'],
        edge_hidden_feats=exp_config['edge_hidden_feats'],
        num_step_message_passing=exp_config['num_step_message_passing'],
        attention_heads = exp_config['attention_heads'],
        attention_layers = exp_config['attention_layers'],
        AtomTemplate_n = exp_config['AtomTemplate_n'],
        BondTemplate_n = exp_config['BondTemplate_n'])
    model = model.to(args['device'])
    print ('Parameters of loaded LocalRetro:')
    print (exp_config)
    # print("KKKK")
    # exit(0)    
    if args['mode'] == 'train':
        loss_criterion = nn.CrossEntropyLoss(reduction = 'none')
        optimizer = Adam(model.parameters(), lr=args['learning_rate'], weight_decay=args['weight_decay'])
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args['schedule_step'])
        
        if os.path.exists(args['model_path']):
            user_answer = input('%s exists, want to (a) overlap (b) continue from checkpoint (c) make a new model?' % args['model_path'])
            if user_answer == 'a':
                stopper = EarlyStopping(mode = 'lower', patience=args['patience'], filename=args['model_path'])
                print ('Overlap exsited model and training a new model...')
            elif user_answer == 'b':
                stopper = EarlyStopping(mode = 'lower', patience=args['patience'], filename=args['model_path'])
                stopper.load_checkpoint(model)
                print ('Train from exsited model checkpoint...')
            elif user_answer == 'c':
                model_name = input('Enter new model name: ')
                args['model_path'] = args['model_path'].replace('%s.pth' % args['dataset'], '%s.pth' % model_name)
                stopper = EarlyStopping(mode = 'lower', patience=args['patience'], filename=args['model_path'])
                print ('Training a new model %s.pth' % model_name)
        else:
            stopper = EarlyStopping(mode = 'lower', patience=args['patience'], filename=args['model_path'])
        return model, loss_criterion, optimizer, scheduler, stopper
    
    else:
        model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
        return model

def mainX(args):
    #KKK
    model_path = '../modelsC'
    model_name = 'LocalRetro_%s.pth' % args['dataset']
    args['model_path'] = model_path +'/' + model_name
    args['config_path'] = '../data/configs/%s' % args['config']
    args['data_dir'] = '../data/%s' % args['dataset']
    mkdir_p(model_path)                          
    args = init_featurizer(args)
    model, loss_criterion, optimizer, scheduler, stopper = load_model(args)   
    
    return model,loss_criterion#,train_loader

    # for epoch in range(args['num_epochs']):
    #     print("Running",epoch)
    #     run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
    #     val_loss = run_an_eval_epoch(args, model, val_loader, loss_criterion)
    #     early_stop = stopper.step(val_loss, model) 
    #     scheduler.step()
    #     print('epoch %d/%d, validation loss: %.4f' %  (epoch + 1, args['num_epochs'], val_loss))
    #     print('epoch %d/%d, Best loss: %.4f' % (epoch + 1, args['num_epochs'], stopper.best_score))
    #     if early_stop:
    #         print ('Early stopped!!')
    #         break

    # stopper.load_checkpoint(model)
    # test_loss = run_an_eval_epoch(args, model, test_loader, loss_criterion)
    # print('test loss: %.4f' % test_loss)
#    return (train_loader, val_loader, test_loader)

#train_loader, val_loader, test_loader = mainX(args)    
model,loss_criterion =  mainX(args)

Directory ../modelsC already exists.
node_out_feats 320
Parameters of loaded LocalRetro:
{'attention_heads': 8, 'attention_layers': 1, 'batch_size': 16, 'edge_hidden_feats': 64, 'node_out_feats': 320, 'num_step_message_passing': 6, 'AtomTemplate_n': 124, 'BondTemplate_n': 548, 'in_node_feats': 80, 'in_edge_feats': 13}


In [3]:
train_loader, val_loader, test_loader = load_dataloader(args)

Loading previously saved dgl graphs...../data/saved_graphs/USPTO_50K_dglgraph.bin


In [36]:
# print(model_path)
for batch_id, batch_data in enumerate(train_loader):
    smiles, bg, atom_labels, bond_labels = batch_data
    # print(smiles,len(smiles))
    print(bg)
    # print(atom_labels,len(atom_labels))
    # print(bond_labels,len(bond_labels))

    if len(smiles) == 1:
        continue
           
    atom_labels, bond_labels = atom_labels.to(args['device']), bond_labels.to(args['device'])
    atom_logits, bond_logits, _ = predict(args, model, bg)    
    print(atom_logits)
    print(bond_logits)
    break

Graph(num_nodes=230, num_edges=736,
      ndata_schemes={'h': Scheme(shape=(80,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)})


RuntimeError: shape '[8, -1, 8, 1]' is invalid for input of size 11280

In [9]:
import dgl
import torch

g1 = dgl.graph(([0, 1], [1, 0]))
g1.ndata['h'] = torch.tensor([1., 2.])
g2 = dgl.graph(([0, 1], [1, 2]))
g2.ndata['h'] = torch.tensor([1., 2., 3.])

dgl.readout_nodes(g1, 'h')
# tensor([3.])  # 1 + 2

bg = dgl.batch([g1, g2])
dgl.readout_nodes(bg, 'h')
# tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]

tensor([3., 6.])

In [4]:
enumTrainLoader = enumerate(train_loader)

In [31]:
batch_id, batch_data = next(enumTrainLoader)
smiles, bg, atom_labels, bond_labels = batch_data
atom_labels, bond_labels = atom_labels.to(args['device']), bond_labels.to(args['device'])

In [6]:
batch_data

(['[CH2:1]([CH2:2][CH2:3][c:4]1[cH:5][cH:6][cH:7][cH:8][cH:9]1)[NH:10][CH2:11][CH2:12][CH2:13][CH2:14][CH2:15][n:16]1[cH:17][cH:18][n:19][cH:20]1'],
 Graph(num_nodes=20, num_edges=62,
       ndata_schemes={'h': Scheme(shape=(80,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0, 545,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]))

In [32]:
batch_id, batch_data = next(enumTrainLoader)
smiles, bg, atom_labels, bond_labels = batch_data
atom_labels, bond_labels = atom_labels.to(args['device']), bond_labels.to(args['device'])
bg = bg.to(args['device'])
node_feats = bg.ndata.pop('h').to(args['device'])
edge_feats = bg.edata.pop('e').to(args['device'])
    # return model(bg, node_feats, edge_feats)



In [33]:
print(node_feats.size(),edge_feats.size())
atom_logits, bond_logits, _= model(bg, node_feats, edge_feats)    

torch.Size([17, 80]) torch.Size([53, 13])
pair_atom_feats tensor([[ 0.7240, -0.8444,  0.5521,  ..., -0.2495, -0.9935,  0.1580],
        [ 0.7240, -0.8444,  0.5521,  ..., -0.1911, -0.9167,  0.6131],
        [-0.1346, -0.9412,  0.5618,  ..., -0.0635, -0.9048,  0.6645],
        ...,
        [-0.2897, -0.9593,  0.5600,  ..., -0.1647, -0.9951, -0.1582],
        [-0.2592, -0.9607,  0.5665,  ..., -0.2900, -0.9978,  0.3605],
        [-0.2592, -0.9607,  0.5665,  ..., -0.1593, -0.9939, -0.0664]],
       device='cuda:0', grad_fn=<CatBackward0>)


In [20]:
print(atom_logits.size(),bond_logits.size())

torch.Size([20, 125]) torch.Size([42, 549])


In [34]:
print(atom_logits.size(),bond_logits.size())
print(atom_labels.size(),bond_labels.size())
loss_a = loss_criterion(atom_logits, atom_labels)

torch.Size([17, 125]) torch.Size([36, 549])
torch.Size([17]) torch.Size([36])


In [41]:
_atom_logits = nn.Softmax(dim = 1)(atom_logits)
_bond_logits = nn.Softmax(dim = 1)(bond_logits) 
print(_atom_logits.size(),_bond_logits.size())

torch.Size([17, 125]) torch.Size([36, 549])


In [49]:
import dgl
def get_bg_partition(bg):
    sg = bg.remove_self_loop()
    gs = dgl.unbatch(sg)
    nodes_sep = [0]
    edges_sep = [0]
    for g in gs:
        nodes_sep.append(nodes_sep[-1] + g.num_nodes())
        edges_sep.append(edges_sep[-1] + g.num_edges())
    return gs, nodes_sep[1:], edges_sep[1:]

graphs, nodes_sep, edges_sep = get_bg_partition(bg)    
genum = enumerate(zip(graphs, nodes_sep, edges_sep))

In [50]:
print(nodes_sep,edges_sep)
print(graphs)
next(genum)


[17] [36]
[Graph(num_nodes=17, num_edges=36,
      ndata_schemes={}
      edata_schemes={})]


(0,
 (Graph(num_nodes=17, num_edges=36,
        ndata_schemes={}
        edata_schemes={}),
  17,
  36))

In [60]:
import numpy as np
def combined_edit(graph, atom_out, bond_out, top_num):
    edit_id_a, edit_proba_a = output2edit(atom_out, top_num)
    edit_id_b, edit_proba_b = output2edit(bond_out, top_num)
    edit_id_c = edit_id_a + edit_id_b
    edit_type_c = ['a'] * top_num + ['b'] * top_num
    edit_proba_c = edit_proba_a + edit_proba_b
    edit_rank_c = np.flip(np.argsort(edit_proba_c))[:top_num]
    edit_type_c = [edit_type_c[r] for r in edit_rank_c]
    edit_id_c = [edit_id_c[r] for r in edit_rank_c]
    edit_proba_c = [edit_proba_c[r] for r in edit_rank_c]
    return edit_type_c, edit_id_c, edit_proba_c

def output2edit(out, top_num):
    class_n = out.size(-1)
    readout = out.cpu().detach().numpy()
    readout = readout.reshape(-1)
    output_rank = np.flip(np.argsort(readout))
    output_rank = [r for r in output_rank if get_id_template(r, class_n)[1] != 0][:top_num]
    selected_edit = [get_id_template(a, class_n) for a in output_rank]
    selected_proba = [readout[a] for a in output_rank]
    return selected_edit, selected_proba

def get_id_template(a, class_n):
    class_n = class_n # no template
    edit_idx = a//class_n
    template = a%class_n
    return (edit_idx, template)

start_node=0
start_edge=0
args['top_num'] = 5
for single_id, (graph, end_node, end_edge) in enumerate(zip(graphs, nodes_sep, edges_sep)):
    # smiles = smiles_list[single_id]
    test_id = (batch_id * args['batch_size']) + single_id
    pred_types, pred_sites, pred_scores = combined_edit(graph, _atom_logits[start_node:end_node], _bond_logits[start_edge:end_edge], args['top_num'])
    start_node = end_node
    start_edge = end_edge
    print(pred_types,pred_sites,pred_scores)

['a', 'a', 'a', 'a', 'a'] [(9, 13), (8, 13), (5, 13), (1, 13), (0, 13)] [0.017958768, 0.016977806, 0.016885, 0.016703364, 0.016456636]


In [73]:
_atom_logits[0]

tensor([0.0074, 0.0094, 0.0100, 0.0101, 0.0094, 0.0136, 0.0089, 0.0120, 0.0088,
        0.0061, 0.0109, 0.0092, 0.0076, 0.0165, 0.0072, 0.0064, 0.0069, 0.0070,
        0.0086, 0.0086, 0.0105, 0.0061, 0.0077, 0.0060, 0.0093, 0.0060, 0.0073,
        0.0086, 0.0053, 0.0092, 0.0107, 0.0083, 0.0074, 0.0056, 0.0069, 0.0066,
        0.0089, 0.0137, 0.0070, 0.0063, 0.0088, 0.0084, 0.0079, 0.0085, 0.0087,
        0.0096, 0.0100, 0.0055, 0.0093, 0.0060, 0.0081, 0.0065, 0.0106, 0.0054,
        0.0056, 0.0096, 0.0052, 0.0061, 0.0073, 0.0075, 0.0098, 0.0067, 0.0046,
        0.0093, 0.0054, 0.0036, 0.0105, 0.0080, 0.0059, 0.0062, 0.0067, 0.0058,
        0.0087, 0.0121, 0.0078, 0.0087, 0.0048, 0.0087, 0.0090, 0.0071, 0.0074,
        0.0067, 0.0063, 0.0076, 0.0083, 0.0058, 0.0069, 0.0053, 0.0040, 0.0060,
        0.0039, 0.0082, 0.0076, 0.0050, 0.0067, 0.0096, 0.0092, 0.0101, 0.0082,
        0.0092, 0.0082, 0.0057, 0.0103, 0.0109, 0.0100, 0.0070, 0.0091, 0.0113,
        0.0061, 0.0087, 0.0086, 0.0095, 