In [1]:
#python pretrain_multi.py --emb_dim 300 --hidden_size 300 --epochs 100 --dropout_ratio 0.1 --dataset data/merge_0 --vocab data/merge_0/clique.txt --output_path saved_model/grover --batch_size 40 --order dfs --grover_dataset --multi
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
import numpy as np
from optparse import OptionParser
from gnn_model import GNN, GNN_grover

sys.path.append('./util')
sys.path.append('./grover')

from util.mol_tree import *
from util.nnutils import *
from util.datautils import *
from util.motif_generation import *

import rdkit

# add for grover
import os, time
import wandb
from grover.topology.mol_tree import *
from grover.topology.grover_datasets import *
from sklearn.model_selection import train_test_split

#for torch ddp
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os


In [2]:
def parse_args(argv=None):
    """Build the command-line parser for motif-generation pre-training and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) the fixed notebook configuration below is parsed instead of
            ``sys.argv`` (presumably because ``sys.argv`` inside a Jupyter kernel
            holds the kernel's own launch flags).

    Returns:
        argparse.Namespace holding every training setting.
    """
    if argv is None:
        # Notebook defaults: dataset + vocab under data/zinc15_250K, GROVER dataset mode.
        argv = ['--dataset', 'data/zinc15_250K', '--vocab', 'data/zinc15_250K/clique.txt', '--grover_dataset']

    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--decay', type=float, default=0,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.1,
                        help='dropout ratio (default: 0.1)')
    parser.add_argument('--graph_pooling', type=str, default="sum",
                        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--dataset', type=str, default='./data/zinc/all.txt',
                        help='root directory of dataset. For now, only classification.')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--input_model_file', type=str, default="", help='filename to read the model (if there is any)')
    parser.add_argument('--output_path', type=str, default='./saved_model/grover',
                        help='filename to output the pre-trained model')
    parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataset loading')
    parser.add_argument("--hidden_size", type=int, default=300, help='hidden size')
    parser.add_argument("--vocab", type=str, default='./data/zinc/clique.txt', help='vocab path')
    parser.add_argument('--order', type=str, default="dfs",
                        help='motif tree generation order (bfs or dfs)')
    parser.add_argument('--seed', type=int, default=0,
                        help='setting seed number')
    # for wandb
    parser.add_argument('--wandb', action='store_true', default=False, help='add wandb log')
    parser.add_argument('--wandb_name', type=str, default = 'MGSSL_Grover', help='wandb name')
    # for grover mode
    parser.add_argument('--grover_dataset', action='store_true', default=False, help='grover dataset mode')
    parser.add_argument('--multi', action='store_true', default=False, help='use multiprocess mode')
    # Overwritten after parsing from LOCAL_RANK when --multi is set.
    parser.add_argument('--rank', type=int, default=0)
    # NOTE(review): store_true with default=True can never be switched off from the CLI;
    # the notebook recomputes args.master_worker from the rank after parsing anyway.
    parser.add_argument('--master_worker', action='store_true', default=True)

    return parser.parse_args(argv)


In [3]:
def group_node_rep(node_rep, batch_index, batch_size):
    """Split a flat, graph-by-graph stack of node representations into per-graph chunks.

    Args:
        node_rep: Node-level representations, sliceable along the first dimension
            (e.g. a tensor of shape (num_nodes, dim)).
        batch_index: Graph id of every node, as a NumPy array (a torch tensor is also
            accepted). The nodes of graph i must be contiguous, with graphs in order 0, 1, ...
        batch_size: Number of graphs. Ids outside [0, batch_size) are not counted.

    Returns:
        List of length ``batch_size``. Entry i is the slice of ``node_rep`` holding
        graph i's nodes (an empty slice when graph i has no nodes).
    """
    if hasattr(batch_index, "detach"):
        # torch tensor -> NumPy (moving it off the GPU first if needed).
        batch_index = batch_index.detach().cpu().numpy()
    batch_index = np.asarray(batch_index, dtype=np.int64)
    in_range = batch_index[(batch_index >= 0) & (batch_index < batch_size)]
    # One vectorised pass instead of a Python-level sum() over all nodes for every graph.
    counts = np.bincount(in_range, minlength=batch_size)

    group = []
    start = 0
    for count in counts.tolist():
        group.append(node_rep[start:start + count])
        start += count
    return group


In [4]:
args = parse_args()
# In multi-process mode the rank and world size are read from LOCAL_RANK / WORLD_SIZE.
args.rank = int(os.environ["LOCAL_RANK"]) if args.multi else 0
args.master_worker = (args.rank == 0) if args.multi else True
if args.master_worker:
    # makedirs also creates missing parents (e.g. 'saved_model/') and tolerates an
    # existing directory, where os.mkdir raised.
    os.makedirs(args.output_path, exist_ok=True)
logger = create_logger('pretrain', args)
debug = logger.debug
info = logger.info

# for distributed
world_size = int(os.environ["WORLD_SIZE"]) if args.multi else 1

if args.master_worker:
    info(f'emb_dim : {args.emb_dim}, lr : {args.lr}, dropout : {args.dropout_ratio}, batch_size : {args.batch_size}')
info(f'rank : {args.rank}')

# Seed every RNG in play: Python's `random`, NumPy and torch (CPU and all visible GPUs).
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

if args.grover_dataset:
    grover_data, sample_per_file = get_motif_data(data_path=args.dataset, logger=logger)
    # 90% train / 10% validation / 0% test.
    train_dataset, val_dataset = split_data_grover(grover_data, sizes=(0.9, 0.1, 0), seed=args.seed, logger=logger)
    shared_dict = {}
    GMC = GroverMotifCollator(shared_dict=shared_dict, args=args)

    pre_load_data(dataset=train_dataset, rank=args.rank, num_replicas=world_size, sample_per_file=sample_per_file, logger=logger)
    pre_load_data(dataset=val_dataset, rank=args.rank, num_replicas=world_size, sample_per_file=sample_per_file, logger=logger)

    train_sampler = DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=args.rank, shuffle=True, sample_per_file=sample_per_file)
    val_sampler = DistributedSampler(dataset=val_dataset, num_replicas=world_size, rank=args.rank, shuffle=False, sample_per_file=sample_per_file)
    # NOTE(review): the train sampler's epoch is set once, to the total epoch count. If the
    # sampler reshuffles per epoch, set_epoch(epoch) presumably belongs in the epoch loop — confirm.
    train_sampler.set_epoch(args.epochs)
    val_sampler.set_epoch(1)
    # Eagerly load the validation items this rank's sampler yields.
    for sample_idx in val_sampler.get_indices():
        val_dataset.load_data(sample_idx)

    # Ordering is delegated to the samplers, so DataLoader-level shuffle stays off.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=GMC, sampler=train_sampler)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=GMC, sampler=val_sampler)

else:
    dataset = MoleculeDataset_grover(args.dataset)
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.1, random_state=args.seed)
    # Identity collate: batches are plain lists of dataset items.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=lambda x: x)
    # Validation does not need random order; keep it fixed, as in the GROVER branch.
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=lambda x: x)

emb_dim : 300, lr : 0.001, dropout : 0.1, batch_size : 32
rank : 0
Loading data:
Number of files: 250
Number of samples: 249624
Samples/file: 1000
train size : 225, val size : 25
total 225000 data pre-loading
total 24624 data pre-loading


In [5]:
# Atom-level GNN encoder. Architecture hyper-parameters come from args
# (defaults: num_layer=5, JK='last', gnn_type='gin').
model = GNN_grover(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio,
                   gnn_type=args.gnn_type).to(args.rank)
# NOTE(review): DDP wrapping is disabled, so with --multi each rank trains its own
# unsynchronised copy of the model.
# model = DDP(model, device_ids=[args.rank])

# Motif vocabulary: one entry per line of args.vocab. `with` closes the file handle.
with open(args.vocab) as vocab_file:
    vocab = Vocab([line.strip("\r\n ") for line in vocab_file])
motif_model = Motif_Generation_Grover(vocab, args.hidden_size, args.rank, args.order).to(args.rank)
# motif_model = DDP(motif_model, device_ids=[args.rank])

optimizer_model = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
optimizer_motif = optim.Adam(motif_model.parameters(), lr=args.lr, weight_decay=args.decay)

# Resume from the rolling checkpoint if one exists; otherwise start from scratch.
cp_path = f'{args.output_path}/temp.pth'
if os.path.exists(cp_path):
    resume_epoch, resume_batch, best_val_loss = load_cp(model, motif_model, optimizer_model, optimizer_motif, cp_path)
    if args.master_worker:
        info(f'load checkpoint : {resume_epoch}epoch, batch : {resume_batch}')
    else:
        debug(f'rank : {args.rank} load checkpoint : {resume_epoch}epoch, batch : {resume_batch}')
else:
    resume_epoch = 0
    resume_batch = 0
    best_val_loss = 1e+10

model_list = [model, motif_model]
optimizer_list = [optimizer_model, optimizer_motif]



In [6]:
model, motif_model = model_list
optimizer_model, optimizer_motif = optimizer_list

# Debug knob: stop after this many DataLoader steps. The value 1 reproduces the
# former hard-coded break after step 0. Set to None to run a full epoch.
MAX_DEBUG_STEPS = 1

model.train()
motif_model.train()
word_acc, topo_acc = 0, 0
for step, batch in enumerate(train_loader):
    # When resuming, batches before `resume_batch` are skipped. The DataLoader still
    # loads and collates them.
    if step >= resume_batch:
        batch_size = len(batch)

        graph_batch = moltree_to_grover_data(batch)
        # Per-node graph ids, read on the CPU before the batch is moved to the device.
        batch_index = graph_batch.batch.numpy()
        graph_batch = graph_batch.to(args.rank)

        optimizer_model.zero_grad()
        optimizer_motif.zero_grad()

        node_rep = model(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
        node_rep = group_node_rep(node_rep, batch_index, batch_size)
        loss, word_loss, topo_loss, word_acc, topo_acc = motif_model(batch, node_rep)

        loss.backward()

        optimizer_model.step()
        optimizer_motif.step()
    if MAX_DEBUG_STEPS is not None and step + 1 >= MAX_DEBUG_STEPS:
        break



TypeError: full() received an invalid combination of arguments - got (tuple, int, device=NoneType), but expected one of:
 * (tuple of ints size, Number fill_value, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, Number fill_value, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [81]:
# SMARTS definitions for the donor / acceptor / acidic / basic atom flags passed to
# atom_features. They are compiled once, on first use, rather than on every call.
_GROVER_SMARTS = {
    "hydrogen_donor": "[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]",
    "hydrogen_acceptor": (
        "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
        "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]"),
    "acidic": "[$([C,S](=[O,S,P])-[O;H1,-1])]",
    "basic": (
        "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
        "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]"),
}
_grover_pattern_cache = {}


def _grover_smarts_patterns():
    """Return the compiled SMARTS query molecules, compiling them on first use."""
    if not _grover_pattern_cache:
        _grover_pattern_cache.update(
            {name: Chem.MolFromSmarts(smarts) for name, smarts in _GROVER_SMARTS.items()})
    return _grover_pattern_cache


def _flat_match_indices(mol, pattern):
    """Flatten mol.GetSubstructMatches(pattern) into one tuple of atom indices, in linear time."""
    return tuple(idx for match in mol.GetSubstructMatches(pattern) for idx in match)


def mol_to_graph_data_obj_grover(mol):
    """Convert a molecule into a PyG ``Data`` graph with GROVER-style features.

    Args:
        mol: RDKit-style molecule (callers in this notebook pass ``tree.mol``).

    Returns:
        ``Data`` with:
          - ``x``: one row per atom, built by ``atom_features`` from the donor,
            acceptor, acidic and basic match indices plus the ring info.
          - ``edge_index``: long tensor of shape ``(2, 2 * num_bonds)``. Every bond is
            stored in both directions, ordered by (lower atom idx, higher atom idx).
            Bond-free molecules get a ``(2, 0)`` tensor.
          - ``edge_attr``: one row per directed edge, holding the source atom's
            features followed by ``bond_features(bond)``.
    """
    patterns = _grover_smarts_patterns()
    hydrogen_donor_match = _flat_match_indices(mol, patterns["hydrogen_donor"])
    hydrogen_acceptor_match = _flat_match_indices(mol, patterns["hydrogen_acceptor"])
    acidic_match = _flat_match_indices(mol, patterns["acidic"])
    basic_match = _flat_match_indices(mol, patterns["basic"])
    ring_info = mol.GetRingInfo()

    f_atoms = [
        atom_features(atom, hydrogen_donor_match, hydrogen_acceptor_match, acidic_match, basic_match, ring_info)
        for atom in mol.GetAtoms()
    ]

    n_atoms = mol.GetNumAtoms()
    f_bonds = []
    bond_list = []
    for a1 in range(n_atoms):
        for a2 in range(a1 + 1, n_atoms):
            bond = mol.GetBondBetweenAtoms(a1, a2)
            if bond is None:
                continue

            f_bond = bond_features(bond)

            # Always treat the bond as directed: add a1 -> a2 and a2 -> a1.
            f_bonds.append(f_atoms[a1] + f_bond)
            bond_list.append([a1, a2])
            f_bonds.append(f_atoms[a2] + f_bond)
            bond_list.append([a2, a1])

    if bond_list:
        edge_index = torch.tensor(bond_list, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]).T would be a 1-D float tensor; PyG expects (2, 0) long.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    # NOTE(review): for bond-free molecules edge_attr stays a 1-D empty tensor, because
    # the bond-feature width cannot be derived here without a bond.
    data = Data(x=torch.tensor(f_atoms), edge_index=edge_index, edge_attr=torch.tensor(f_bonds))
    return data


In [82]:
# Collate the per-molecule Data objects into one PyG Batch. Batch.from_data_list is
# the collation entry point (the Batch(...) constructor does not take a graph list).
# NOTE(review): graph_data_batch is defined by a cell further down this notebook; run that cell first.
new_batch = Batch.from_data_list(graph_data_batch)

In [86]:
# Inspect the GROVER feature vector of the first atom of the first converted molecule.
first_grover_graph = graph_data_batch[0]
first_grover_graph.x[0]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [71]:
# Inspect the 2-row edge_index of the first molecule featurized with mol_to_graph_data_obj_simple.
first_simple_graph = graph_data_batch2[0]
first_simple_graph.edge_index

tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  4,  6,  6,  7,  7,  8,  8,  9,
          9, 10, 10, 11,  1, 12, 12, 13, 13, 14, 14, 15, 14, 16, 16, 17, 17, 18,
         18, 19, 19, 20, 19, 21, 11,  7, 21, 16],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  4,  7,  6,  8,  7,  9,  8,
         10,  9, 11, 10, 12,  1, 13, 12, 14, 13, 15, 14, 16, 14, 17, 16, 18, 17,
         19, 18, 20, 19, 21, 19,  7, 11, 16, 21]])

In [84]:
# Featurize every molecule of the current motif-tree batch with the GROVER featurizer,
# then collate. from_data_list is a classmethod, so no throwaway Batch() instance is needed.
graph_data_batch = [mol_to_graph_data_obj_grover(tree.mol) for tree in batch]
new_batch = Batch.from_data_list(graph_data_batch)

In [87]:
# Display the collated GROVER batch (x, edge_index, edge_attr, batch and ptr fields).
new_batch

DataDataBatch(x=[692, 171], edge_index=[2, 1470], edge_attr=[1470, 185], batch=[692], ptr=[33])

In [85]:
# Single forward pass on the GROVER-featurized batch built from `batch`.
# NOTE(review): this raised "mat1 and mat2 shapes cannot be multiplied (692x171 and 151x300)".
# The model's input layer appears to expect 151 atom features, but
# mol_to_graph_data_obj_grover produced 171. Confirm which feature set is intended.
# .cpu() keeps this working if new_batch already lives on the GPU (e.g. on a re-run).
batch_index = new_batch.batch.cpu().numpy()
graph_batch = new_batch.to(args.rank)
node_rep = model(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)
# Use this batch's own size, not the `batch_size` global left over from the training-loop cell.
node_rep = group_node_rep(node_rep, batch_index, len(batch))
loss, word_loss, topo_loss, word_acc, topo_acc = motif_model(batch, node_rep)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (692x171 and 151x300)

In [40]:
# Featurize the same molecules with mol_to_graph_data_obj_simple for comparison, then
# collate. from_data_list is a classmethod, so no throwaway Batch() instance is needed.
graph_data_batch2 = [mol_to_graph_data_obj_simple(tree.mol) for tree in batch]
new_batch2 = Batch.from_data_list(graph_data_batch2)

In [67]:
# Display the batch built with mol_to_graph_data_obj_simple. Its x / edge_attr widths
# differ from the GROVER-featurized batch.
new_batch2

DataDataBatch(x=[692, 2], edge_index=[2, 1470], edge_attr=[1470, 2], batch=[692], ptr=[33])