In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [2]:
import random, os

import numpy as np
import torch
from rdkit import RDLogger
from socket import gethostname

from grover.util.parsing import parse_args, get_newest_train_args
from grover.util.utils import create_logger
from task.cross_validate import cross_validate, randomsearch, gridsearch, make_confusion_matrix
from task.fingerprint import generate_fingerprints, generate_embvec
from task.predict import make_predictions, write_prediction
from task.pretrain import pretrain_model, subset_learning
from grover.data.torchvocab import MolVocab

from grover.topology.mol_tree import *

#add for gridsearch
from argparse import ArgumentParser, Namespace

import torch.distributed as dist

In [3]:
def setup(seed):
    # frozen random seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [4]:
# setup random seed
setup(seed=42)
# Avoid the pylint warning.
a = MolVocab
# supress rdkit logger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Initialize MolVocab
mol_vocab = MolVocab

## args

In [5]:
from grover.util.parsing import *

In [6]:
def parse_args() -> Namespace:
    """
    Parses arguments for training and testing (includes modifying/validating arguments).

    :return: A Namespace containing the parsed, modified, and validated args.
    """
    parser = ArgumentParser()
    subparser = parser.add_subparsers(title="subcommands",
                                      dest="parser_name",
                                      help="Subcommands for fintune, prediction, and fingerprint.")
    parser_pretrain = subparser.add_parser('pretrain', help="Pretrain with unlabelled SMILES.")
    add_pretrain_args(parser_pretrain)

    args = parser.parse_args(['pretrain','--data_path','data/zinc10M_0','--save_dir','model/zinc10M_0','--atom_vocab_path','data/zinc10M/zinc10M_atom_vocab.pkl','--bond_vocab_path','data/zinc10M/zinc10M_bond_vocab.pkl',
                          '--batch_size','100','--dropout','0.1','--depth','3','--num_attn_head','4','--hidden_size','1200','--epochs','20','--activation','PReLU','--backbone','gtrans','--embedding_output_type','both',
                          '--save_interval','5','--init_lr', '0.0002', '--max_lr', '0.0004', '--final_lr', '0.0001', '--weight_decay', '0.0000001', 
                          '--topology','--motif_vocab_path','data/zinc10M/clique.txt','--motif_hidden_size','1200','--motif_latent_size','56','--motif_order','dfs'])
    
    if args.parser_name == 'finetune' or args.parser_name == 'eval':
        modify_train_args(args)
    elif args.parser_name == "pretrain":
        modify_pretrain_args(args)
    elif args.parser_name == 'predict':
        modify_predict_args(args)
    elif args.parser_name == 'fingerprint':
        modify_fingerprint_args(args)

    return args

In [6]:
def parse_args() -> Namespace:
    """
    Parses arguments for training and testing (includes modifying/validating arguments).

    :return: A Namespace containing the parsed, modified, and validated args.
    """
    parser = ArgumentParser()
    subparser = parser.add_subparsers(title="subcommands",
                                      dest="parser_name",
                                      help="Subcommands for fintune, prediction, and fingerprint.")
    parser_finetune = subparser.add_parser('finetune', help="Fine tune the pre-trained model.")
    add_finetune_args(parser_finetune)
    parser_eval = subparser.add_parser('eval', help="Evaluate the results of the pre-trained model.")
    add_finetune_args(parser_eval)
    parser_predict = subparser.add_parser('predict', help="Predict results from fine tuned model.")
    add_predict_args(parser_predict)
    parser_fp = subparser.add_parser('fingerprint', help="Get the fingerprints of SMILES.")
    add_fingerprint_args(parser_fp)
    parser_pretrain = subparser.add_parser('pretrain', help="Pretrain with unlabelled SMILES.")
    add_pretrain_args(parser_pretrain)

    args = parser.parse_args(['finetune', '--data_path', 'data/tox21.csv', '--features_path', 'data/tox21.npz', '--save_dir', 'model/test/', '--checkpoint_path', 'grover_large.pt', '--no_features_scaling', '--split_type', 'scaffold_balanced', '--epochs', '2', '--ffn_hidden_size', '1300', '--num_folds', '2',
                              '--randomsearch', '--n_iters', '2', '--batch_size', '96','--confusionmatrix'])
    
    if args.parser_name == 'finetune' or args.parser_name == 'eval':
        modify_train_args(args)
    elif args.parser_name == "pretrain":
        modify_pretrain_args(args)
    elif args.parser_name == 'predict':
        modify_predict_args(args)
    elif args.parser_name == 'fingerprint':
        modify_fingerprint_args(args)

    return args

    args = parser.parse_args(['pretrain','--data_path','data/zinc10M_0','--save_dir','model/zinc10M_0','--atom_vocab_path','data/zinc10M/zinc10M_atom_vocab.pkl','--bond_vocab_path','data/zinc10M/zinc10M_bond_vocab.pkl',
                              '--batch_size','100','--dropout','0.1','--depth','3','--num_attn_head','4','--hidden_size','1200','--epochs','20','--activation','PReLU','--backbone','gtrans','--embedding_output_type','both',
                              '--save_interval','5','--init_lr', '0.0002', '--max_lr', '0.0004', '--final_lr', '0.0001', '--weight_decay', '0.0000001', 
                              '--topology','--motif_vocab_path','data/zinc10M/clique.txt','--motif_hidden_size','1200','--motif_latent_size','56','--motif_order','dfs',
                             '--wandb','--wandb_name', 'jupyter_zinc10M'])
                             


    args = parser.parse_args(['pretrain','--data_path','data/mgssl','--save_dir','model/mgssl','--atom_vocab_path','data/mgssl/mgssl_atom_vocab.pkl','--bond_vocab_path','data/mgssl/mgssl_bond_vocab.pkl',
                              '--batch_size','100','--dropout','0.1','--depth','3','--num_attn_head','4','--hidden_size','1200','--epochs','20','--activation','PReLU','--backbone','gtrans','--embedding_output_type','both',
                              '--save_interval','5','--init_lr', '0.0002', '--max_lr', '0.0004', '--final_lr', '0.0001', '--weight_decay', '0.0000001', 
                              '--topology','--motif_vocab_path','data/mgssl/clique.txt','--motif_hidden_size','1200','--motif_latent_size','56','--motif_order','dfs'])

In [7]:
args = parse_args()
args

Namespace(activation='ReLU', attn_hidden=128, attn_out=4, batch_size=96, bond_drop_rate=0, checkpoint_dir=None, checkpoint_path='grover_large.pt', checkpoint_paths=['grover_large.pt'], confusionmatrix=True, crossval_index_dir=None, crossval_index_file=None, cuda=True, data_path='data/tox21.csv', dataset_type='classification', dist_coff=0.1, distinct_init=False, dropout=0.0, early_stop_epoch=1000, embedding_output_type='atom', enbl_multi_gpu=False, ensemble_size=1, epochs=2, features_generator=None, features_only=False, features_path=['data/tox21.npz'], features_scaling=False, ffn_hidden_size=1300, ffn_last_size=None, ffn_mid_size=None, ffn_num_layers=2, final_lr=0.0001, fine_tune_coff=1, fingerprint=False, folds_file=None, gpu=0, gridsearch=False, init_lr=0.0001, max_data_size=None, max_lr=0.001, metric='auc', minimize_score=False, multi_class=False, multi_class_num=3, n_iters=2, no_cache=True, num_folds=2, num_lrs=1, parser_name='finetune', randomsearch=True, save_dir='model/test/', s

In [8]:
logger = create_logger(name='train', save_dir=args.save_dir, quiet=False)

In [9]:
# grovermotiftrainer

In [10]:
import os
import time
from argparse import Namespace
from logging import Logger
from typing import Tuple

import numpy as np

from grover.util.utils import get_task_names
from grover.util.utils import makedirs
from task.run_evaluation import run_evaluation, run_evaluation_cfm
from task.train import run_training

import random
import torch

In [11]:
info = logger.info if logger is not None else print

# Initialize relevant variables
init_seed = args.seed
save_dir = args.save_dir
task_names = get_task_names(args.data_path)

#randomize parameter list
max_lr_list = [0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007]#, 0.0009, 0.001]
lr_rate=[2,3,4,5,6,7,8,9,10]
dropout_list = [0, 0.05, 0.1, 0.15, 0.2]
attn_hidden_list = 128
attn_out_list = [4, 8]
dist_coff_list = [0.05, 0.1, 0.15]
bond_drop_rate_list = [0, 0.2, 0.4, 0.6]
ffn_num_layers_list = [2, 3]
ffn_num_layers_list = [2, 3, 4, 5]
ffn_dense_list = [300, 500, 700, 900, 1100, 1300]
smote_rate_list = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

# Run training with different random seeds for each fold
all_scores = []
params = []
time_start = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())

In [12]:
iter_num=0
np.random.seed()
random.seed()
args.init_lr = args.max_lr / 10
args.max_lr = np.random.choice(max_lr_list, 1)[0]
args.final_lr = args.max_lr / np.random.choice(lr_rate, 1)[0]
args.dropout = np.random.choice(dropout_list, 1)[0]
args.attn_out = np.random.choice(attn_out_list, 1)[0]
args.dist_coff = np.random.choice(dist_coff_list, 1)[0]
args.bond_drop_rate = np.random.choice(bond_drop_rate_list, 1)[0]
args.ffn_num_layers = np.random.choice(ffn_num_layers_list, 1)[0]
args.ffn_hidden_size = np.random.choice(ffn_dense_list, 1)[0]
if args.smote==True : 
    args.smote_rate = np.random.choice(smote_rate_list, 1)[0]
    params.append(f'\n{iter_num}th search parameter : init_lr is {args.init_lr} \n final_lr rate is {args.final_lr} \n dropout is {args.dropout} \n attn_out is {args.attn_out} \n dist_coff is {args.dist_coff} \n bond_drop_rate is {args.bond_drop_rate} \n ffn_num_layers is {args.ffn_num_layers} \n ffn_hidden_size is {args.ffn_hidden_size} \n batch_size is {args.batch_size} \n smote_rate is {args.smote_rate}')
else : 
    params.append(f'\n{iter_num}th search parameter : init_lr is {args.init_lr} \n final_lr rate is {args.final_lr} \n dropout is {args.dropout} \n attn_out is {args.attn_out} \n dist_coff is {args.dist_coff} \n bond_drop_rate is {args.bond_drop_rate} \n ffn_num_layers is {args.ffn_num_layers} \n ffn_hidden_size is {args.ffn_hidden_size} \n batch_size is {args.batch_size}')
info(params[iter_num])

args.seed = init_seed                        # if change this, result will be change
iter_dir = os.path.join(save_dir, f'iter_{iter_num}')
args.save_dir = iter_dir
makedirs(args.save_dir)

fold_scores = []
if args.confusionmatrix:
    scores_AUC = []
    scores_ACC = []
    scores_REC = []
    scores_PREC = []
    scores_SPEC = []
    scores_F1 = []
    scores_BA = []
    scores_TP = []
    scores_FP = []
    scores_TN = []
    scores_FN = []



0th search parameter : init_lr is 0.0001 
 final_lr rate is 3.3333333333333335e-05 
 dropout is 0.0 
 attn_out is 8 
 dist_coff is 0.05 
 bond_drop_rate is 0.6 
 ffn_num_layers is 5 
 ffn_hidden_size is 300 
 batch_size is 96


In [13]:
fold_num=0
info(f'Fold {fold_num}')
args.seed = init_seed + fold_num
args.save_dir = os.path.join(iter_dir, f'fold_{fold_num}')
makedirs(args.save_dir)

Fold 0


# train.py

In [14]:
import csv
import logging
import os
import pickle
import time
from argparse import Namespace
from logging import Logger
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader

from grover.data import MolCollator
from grover.data import StandardScaler
from grover.util.metrics import get_metric_func
from grover.util.nn_utils import initialize_weights, param_count
from grover.util.scheduler import NoamLR
from grover.util.utils import build_optimizer, build_lr_scheduler, makedirs, load_checkpoint, get_loss_func, \
    save_checkpoint, build_model
from grover.util.utils import get_class_sizes, get_data, split_data, get_task_names
from task.predict import predict, evaluate, evaluate_predictions, evaluate_predictions_cfm

import wandb

In [15]:
def load_data(args, debug, logger):
    """
    load the training data.
    :param args:
    :param debug:
    :param logger:
    :return:
    """
    # Get data
    debug('Loading data')
    args.task_names = get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    if data.data[0].features is not None:
        args.features_dim = len(data.data[0].features)
    else:
        args.features_dim = 0
    shared_dict = {}
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    debug(f'Number of tasks = {args.num_tasks}')
    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(path=args.separate_test_path, args=args,
                             features_path=args.separate_test_features_path, logger=logger)
    if args.separate_val_path:
        val_data = get_data(path=args.separate_val_path, args=args,
                            features_path=args.separate_val_features_path, logger=logger)
    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        train_data, _, test_data = split_data(data=data, split_type=args.split_type,
                                              sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data, split_type=args.split_type,
                                             sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
    else:
        train_data, val_data, test_data = split_data(data=data, split_type=args.split_type,
                                                     sizes=args.split_sizes, seed=args.seed, args=args, logger=logger)
    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    if args.smote == True:
        if args.dataset_type == 'classification':
            class_sizes = get_class_sizes(data)
            debug('Origin Class sizes')
            for i, task_class_sizes in enumerate(class_sizes):
                debug(f'{args.task_names[i]} '
                      f'{", ".join(f"{cls}: {int(size*args.train_data_size)}({size * 100:.2f}%)" for cls, size in enumerate(task_class_sizes))}')
        debug(f'Total size = {len(data):,} | '
              f'train size = {args.train_data_size:,} | val size = {len(val_data):,} | test size = {len(test_data):,}')
        args.train_data_size = len(train_data)
    
        debug('Smoted Class sizes')
        smoted_class_sizes = get_class_sizes(train_data)
        for i, task_class_sizes in enumerate(smoted_class_sizes):
            debug(f'{args.task_names[i]} '
                  f'{", ".join(f"{cls}: {int(size*args.train_data_size)}({size * 100:.2f}%)" for cls, size in enumerate(task_class_sizes))}')
        #note : there is some error of number because class_count is class_rate*data_length
        debug(f'Total size = {len(test_data)+len(train_data)+len(val_data):,} | '
              f'train size = {args.train_data_size:,} | val size = {len(val_data):,} | test size = {len(test_data):,}')
    else:
        if args.dataset_type == 'classification':
            class_sizes = get_class_sizes(data, args)
            debug('Class sizes')
            if not args.multi_class:
                for i, task_class_sizes in enumerate(class_sizes):
                    debug(f'{args.task_names[i]} '
                          f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')
            elif args.multi_class:
                for i in range(args.multi_class_num):
                    print(f'{i} : {class_sizes[i][0]:.2f}')
                    
        args.train_data_size = len(train_data)
        debug(f'Total size = {len(data):,} | '
              f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')

    # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        _, train_targets = train_data.smiles(), train_data.targets()
        scaler = StandardScaler().fit(train_targets)
        scaled_targets = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets)

        val_targets = val_data.targets()
        scaled_val_targets = scaler.transform(val_targets).tolist()
        val_data.set_targets(scaled_val_targets)
    else:
        scaler = None
    return features_scaler, scaler, shared_dict, test_data, train_data, val_data


def save_splits(args, test_data, train_data, val_data):
    """
    Save the splits.
    :param args:
    :param test_data:
    :param train_data:
    :param val_data:
    :return:
    """
    with open(args.data_path, 'r') as f:
        reader = csv.reader(f)
        header = next(reader)

        lines_by_smiles = {}
        indices_by_smiles = {}
        for i, line in enumerate(reader):
            smiles = line[0]
            lines_by_smiles[smiles] = line
            indices_by_smiles[smiles] = i

    all_split_indices = []
    for dataset, name in [(train_data, 'train'), (val_data, 'val'), (test_data, 'test')]:
        with open(os.path.join(args.save_dir, name + '_smiles.csv'), 'w') as f:
            writer = csv.writer(f)
            writer.writerow(['smiles'])
            for smiles in dataset.smiles():
                writer.writerow([smiles])
        with open(os.path.join(args.save_dir, name + '_full.csv'), 'w') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            for smiles in dataset.smiles():
                writer.writerow(lines_by_smiles[smiles])
        split_indices = []
        for smiles in dataset.smiles():
            split_indices.append(indices_by_smiles[smiles])
            split_indices = sorted(split_indices)
        all_split_indices.append(split_indices)
    with open(os.path.join(args.save_dir, 'split_indices.pckl'), 'wb') as f:
        pickle.dump(all_split_indices, f)
    return writer
def train(epoch, model, data, loss_func, optimizer, scheduler,
          shared_dict, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None):
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    # debug = logger.debug if logger is not None else print

    model.train()

    # data.shuffle()

    loss_sum, iter_count = 0, 0
    cum_loss_sum, cum_iter_count = 0, 0


    mol_collator = MolCollator(shared_dict=shared_dict, args=args)

    num_workers = 4
    if type(data) == DataLoader:
        mol_loader = data
    else:
        mol_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                            num_workers=num_workers, collate_fn=mol_collator)

    for _, item in enumerate(mol_loader):
        _, batch, features_batch, mask, targets = item
        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()
        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += args.batch_size

        cum_loss_sum += loss.item()
        cum_iter_count += 1

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += args.batch_size

        #if (n_iter // args.batch_size) % args.log_frequency == 0:
        #    lrs = scheduler.get_lr()
        #    loss_avg = loss_sum / iter_count
        #    loss_sum, iter_count = 0, 0
        #    lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))

    return n_iter, cum_loss_sum / cum_iter_count

In [16]:
if logger is not None:
    debug, info = logger.debug, logger.info
else:
    debug = info = print


# pin GPU to local rank.
idx = args.gpu
if args.gpu is not None:
    torch.cuda.set_device(idx)

features_scaler, scaler, shared_dict, test_data, train_data, val_data = load_data(args, debug, logger)
metric_func = get_metric_func(metric=args.metric)

# Set up test set evaluation
test_smiles, test_targets = test_data.smiles(), test_data.targets()
if args.multi_class:
    sum_test_preds = np.zeros((len(test_smiles), args.multi_class_num))
else:
    sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

model_idx=0
save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
makedirs(save_dir)

# Load/build model
if args.checkpoint_paths is not None:
    if len(args.checkpoint_paths) == 1:
        cur_model = 0
    else:
        cur_model = model_idx
    debug(f'Loading model {cur_model} from {args.checkpoint_paths[cur_model]}')
    model = load_checkpoint(args.checkpoint_paths[cur_model], current_args=args, logger=logger)
else:
    debug(f'Building model {model_idx}')
    model = build_model(model_idx=model_idx, args=args)

if args.fine_tune_coff != 1 and args.checkpoint_paths is not None:
    debug("Fine tune fc layer with different lr")
    initialize_weights(model_idx=model_idx, model=model.ffn, distinct_init=args.distinct_init)

# Get loss and metric functions
loss_func = get_loss_func(args, model)

optimizer = build_optimizer(model, args)

debug(model)
debug(f'Number of parameters = {param_count(model):,}')
if args.cuda:
    debug('Moving model to cuda')
    model = model.cuda()

# set up wandb
if args.wandb : 
    wandb.init(project=args.wandb_name)
    wandb.config = args
    wandb.watch(model)


# Ensure that model is saved in correct location for evaluation if 0 epochs
save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

# Learning rate schedulers
scheduler = build_lr_scheduler(optimizer, args)

# Bulid data_loader
shuffle = True
mol_collator = MolCollator(shared_dict={}, args=args)
train_data = DataLoader(train_data,
                        batch_size=args.batch_size,
                        shuffle=shuffle,
                        num_workers=10,
                        collate_fn=mol_collator)

# Run training
best_score = float('inf') if args.minimize_score else -float('inf')
best_epoch, n_iter = 0, 0
min_val_loss = float('inf')

Loading data
Number of tasks = 12
Splitting data with seed 0
100%|##########| 7831/7831 [00:01<00:00, 5433.96it/s]
Total scaffolds = 2,326 | train scaffolds = 1,535 | val scaffolds = 367 | test scaffolds = 424
  target_avgs.append(np.nanmean(targets, axis=0))
Label averages per scaffold, in decreasing order of scaffold frequency,capped at 10 scaffolds and 20 labels: [(array([0.01736614, 0.0152439 , 0.12600321, 0.02879581, 0.11440329,
       0.03917221, 0.01794072, 0.11601307, 0.0257732 , 0.04442771,
       0.16479724, 0.03924528]), array([1382, 1312, 1246, 1146, 1215, 1353, 1282, 1224, 1358, 1328, 1159,
       1325])), (array([0.01056958, 0.01393939, 0.01355514, 0.01047806, 0.06081946,
       0.02207637, 0.02308627, 0.09354414, 0.02058824, 0.04649721,
       0.03804348, 0.02112251]), array([1703, 1650, 1623, 1527, 1562, 1676, 1646, 1518, 1700, 1613, 1472,
       1657])), (array([nan, nan,  0., nan, nan,  0.,  1., nan, nan,  0.,  1.,  1.]), array([0, 0, 4, 0, 0, 3, 1, 0, 0, 1, 4, 3])), 

In [17]:
for epoch in range(args.epochs):
    s_time = time.time()
    n_iter, train_loss = train(
        epoch=epoch,
        model=model,
        data=train_data,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=scheduler,
        args=args,
        n_iter=n_iter,
        shared_dict=shared_dict,
        logger=logger
    )
    t_time = time.time() - s_time
    s_time = time.time()
    val_scores, val_loss = evaluate(
        model=model,
        data=val_data,
        loss_func=loss_func,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        batch_size=args.batch_size,
        dataset_type=args.dataset_type,
        scaler=scaler,
        shared_dict=shared_dict,
        logger=logger,
        args=args
    )
    v_time = time.time() - s_time


    # Average validation score
    avg_val_score = np.nanmean(val_scores)


    # Logged after lr step
    if isinstance(scheduler, ExponentialLR):
        scheduler.step()

    if args.show_individual_scores:
        # Individual validation scores
        for task_name, val_score in zip(args.task_names, val_scores):
            debug(f'Validation {task_name} {args.metric} = {val_score:.6f}')
    print('Epoch: {:04d}'.format(epoch),
          'loss_train: {:.6f}'.format(train_loss),
          'loss_val: {:.6f}'.format(val_loss),
          f'{args.metric}_val: {avg_val_score:.4f}',
          # 'auc_val: {:.4f}'.format(avg_val_score),
          'cur_lr: {:.5f}'.format(scheduler.get_lr()[-1]),
          't_time: {:.4f}s'.format(t_time),
          'v_time: {:.4f}s'.format(v_time))

    if args.tensorboard:
        writer.add_scalar('loss/train', train_loss, epoch)
        writer.add_scalar('loss/val', val_loss, epoch)
        writer.add_scalar(f'{args.metric}_val', avg_val_score, epoch)

    if args.wandb :         
        wandb.log({"val_loss" : val_loss, "val_metrics" : val_scores})


    # Save model checkpoint if improved validation score
    if args.select_by_loss:
        if val_loss < min_val_loss:
            min_val_loss, best_epoch = val_loss, epoch
            save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)
    else:
        if args.minimize_score and avg_val_score < best_score or \
                not args.minimize_score and avg_val_score > best_score:
            best_score, best_epoch = avg_val_score, epoch
            save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

    if epoch - best_epoch > args.early_stop_epoch:
        break

Epoch: 0000 loss_train: 0.527613 loss_val: 0.715554 auc_val: 0.7568 cur_lr: 0.00015 t_time: 25.6708s v_time: 2.0701s
Epoch: 0001 loss_train: 0.421411 loss_val: 0.715770 auc_val: 0.8045 cur_lr: 0.00003 t_time: 24.2523s v_time: 2.0245s


In [18]:

ensemble_scores = 0.0

# Evaluate on test set using model with best validation score
if args.select_by_loss:
    info(f'Model {model_idx} best val loss = {min_val_loss:.6f} on epoch {best_epoch}')
else:
    info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda, logger=logger)

test_preds, _ = predict(
    model=model,
    data=test_data,
    loss_func=loss_func,
    batch_size=args.batch_size,
    logger=logger,
    shared_dict=shared_dict,
    scaler=scaler,
    args=args
)

Model 0 best val loss = 0.715554 on epoch 0
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_q.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_q.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_k.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_k.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_v.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_v.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_q.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_q.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_k.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_k.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.h

In [19]:
test_scores = evaluate_predictions(
            preds=test_preds,
            targets=test_targets,
            num_tasks=args.num_tasks,
            metric_func=metric_func,
            dataset_type=args.dataset_type,
            arg=args,
            logger=logger
        )

In [20]:
test_scores

[0.723381314356569,
 0.6980876628871214,
 0.7870563674321502,
 0.68391994478951,
 0.7033121916842847,
 0.7877631868534958,
 0.6172839506172839,
 0.6417086481947943,
 0.7322822105430801,
 0.6724259974259974,
 0.7397490386561425,
 0.7045931008195159]

In [25]:
if len(test_preds) != 0:
    sum_test_preds += np.array(test_preds, dtype=float)

# Average test score
avg_test_score = np.nanmean(test_scores)
info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')

if args.show_individual_scores:
    # Individual test scores
    for task_name, test_score in zip(args.task_names, test_scores):
        info(f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}')

# Evaluate ensemble on test set
avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

ensemble_scores = test_scores

ind = [['preds'] * args.num_tasks + ['targets'] * args.num_tasks, args.task_names * 2]
ind = pd.MultiIndex.from_tuples(list(zip(*ind)))
if args.multi_class:
    data = np.concatenate([np.array([np.argmax(x) for x in avg_test_preds]).reshape(-1,1), np.array(test_targets)], 1)
else:
    data = np.concatenate([np.array(avg_test_preds), np.array(test_targets)], 1)
test_result = pd.DataFrame(data, index=test_smiles, columns=ind)
test_result.to_csv(os.path.join(args.save_dir, 'test_result.csv'))

# Average ensemble score
avg_ensemble_test_score = np.nanmean(ensemble_scores)
info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')

# Individual ensemble scores
if args.show_individual_scores:
    for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
        info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}')

Model 0 test auc = 0.707630
Ensemble test auc = 0.707630


In [24]:
    else:
        for fold_num in range(args.num_folds):
            info(f'Fold {fold_num}')
            args.seed = init_seed + fold_num
            args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
            makedirs(args.save_dir)
            
            AUC, ACC, REC, PREC, SPEC, F1, BA, TP, FP, TN, FN = run_evaluation_cfm(args, logger)
            scores_AUC.append(AUC)
            scores_ACC.append(ACC)
            scores_REC.append(REC)
            scores_PREC.append(PREC)
            scores_SPEC.append(SPEC)
            scores_F1.append(F1)
            scores_BA.append(BA)
            scores_TP.append(TP)
            scores_FP.append(FP)
            scores_TN.append(TN)
            scores_FN.append(FN)
        scores_AUC = np.array(scores_AUC)
        scores_ACC = np.array(scores_ACC)
        scores_REC = np.array(scores_REC)
        scores_PREC = np.array(scores_PREC)
        scores_SPEC = np.array(scores_SPEC)
        scores_F1 = np.array(scores_F1)
        scores_BA = np.array(scores_BA)
        scores_TN = np.array(scores_TN)
        scores_FN = np.array(scores_FN)
        scores_TP = np.array(scores_TP)
        scores_FP = np.array(scores_FP)

        # Report scores for each fold
        info(f'{args.num_folds}-fold cross validation')

        # Report scores across models
        avg_scores_AUC = np.nanmean(scores_AUC, axis=1)  # average score for each model across tasks
        mean_score_AUC, std_score_AUC = np.nanmean(avg_scores_AUC), np.nanstd(avg_scores_AUC)
        info(f'overall_{args.split_type}_test_AUC={mean_score_AUC:.6f}')
        info(f'std={std_score_AUC:.6f}')

        avg_scores_ACC = np.nanmean(scores_ACC, axis=1)  # average score for each model across tasks
        mean_score_ACC, std_score_ACC = np.nanmean(avg_scores_ACC), np.nanstd(avg_scores_ACC)
        info(f'overall_{args.split_type}_test_Accuracy={mean_score_ACC:.6f}')
        info(f'std={std_score_ACC:.6f}')

        avg_scores_REC = np.nanmean(scores_REC, axis=1)  # average score for each model across tasks
        mean_score_REC, std_score_REC = np.nanmean(avg_scores_REC), np.nanstd(avg_scores_REC)
        info(f'overall_{args.split_type}_test_Recall={mean_score_REC:.6f}')
        info(f'std={std_score_REC:.6f}')

        avg_scores_PREC = np.nanmean(scores_PREC, axis=1)  # average score for each model across tasks
        mean_score_PREC, std_score_PREC = np.nanmean(avg_scores_PREC), np.nanstd(avg_scores_PREC)
        info(f'overall_{args.split_type}_test_Precision={mean_score_PREC:.6f}')
        info(f'std={std_score_PREC:.6f}')

        avg_scores_SPEC = np.nanmean(scores_SPEC, axis=1)  # average score for each model across tasks
        mean_score_SPEC, std_score_SPEC = np.nanmean(avg_scores_SPEC), np.nanstd(avg_scores_SPEC)
        info(f'overall_{args.split_type}_test_Specificity={mean_score_SPEC:.6f}')
        info(f'std={std_score_SPEC:.6f}')

        avg_scores_F1 = np.nanmean(scores_F1, axis=1)  # average score for each model across tasks
        mean_score_F1, std_score_F1 = np.nanmean(avg_scores_F1), np.nanstd(avg_scores_F1)
        info(f'overall_{args.split_type}_test_F1={mean_score_F1:.6f}')
        info(f'std={std_score_F1:.6f}')

        avg_scores_BA = np.nanmean(scores_BA, axis=1)  # average score for each model across tasks
        mean_score_BA, std_score_BA = np.nanmean(avg_scores_BA), np.nanstd(avg_scores_BA)
        info(f'overall_{args.split_type}_test_BA={mean_score_BA:.6f}')
        info(f'std={std_score_BA:.6f}')

        avg_scores_TP = np.nanmean(scores_TP)  # average score for each model across tasks
        mean_score_TP, std_score_TP = np.nanmean(avg_scores_TP), np.nanstd(avg_scores_TP)

        avg_scores_FP = np.nanmean(scores_FP)  # average score for each model across tasks
        mean_score_FP, std_score_FP = np.nanmean(avg_scores_FP), np.nanstd(avg_scores_FP)

        avg_scores_TN = np.nanmean(scores_TN)  # average score for each model across tasks
        mean_score_TN, std_score_TN = np.nanmean(avg_scores_TN), np.nanstd(avg_scores_TN)

        avg_scores_FN = np.nanmean(scores_FN)  # average score for each model across tasks
        mean_score_FN, std_score_FN = np.nanmean(avg_scores_FN), np.nanstd(avg_scores_FN)
        info(f'TP : {mean_score_TP:.6f}\tFP : {mean_score_FP:.6f}')
        info(f'FN : {mean_score_FN:.6f}\tTN : {mean_score_TN:.6f}')


        if args.show_individual_scores:
            for task_num, task_name in enumerate(task_names):
                info(f'Overall test {task_name} {args.metric} = '
                     f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}')

        return mean_score_AUC, std_score_AUC

[0.723381314356569,
 0.6980876628871214,
 0.7870563674321502,
 0.68391994478951,
 0.7033121916842847,
 0.7877631868534958,
 0.6172839506172839,
 0.6417086481947943,
 0.7322822105430801,
 0.6724259974259974,
 0.7397490386561425,
 0.7045931008195159]

In [21]:
test_scores_AUC, test_scores_ACC, test_scores_REC, test_scores_PREC, test_scores_SPEC, test_scores_F1, test_scores_BA, test_TP, test_FP, test_TN, test_FN = evaluate_predictions_cfm(
        preds=test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        args=args, logger=logger
    )

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
if args.confusionmatrix:
    test_scores_AUC, test_scores_ACC, test_scores_REC, test_scores_PREC, test_scores_SPEC, test_scores_F1, test_scores_BA, test_TP, test_FP, test_TN, test_FN = evaluate_predictions_cfm(
        preds=test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        args=args, logger=logger
    )
    # Average test score
    avg_test_score_AUC = np.nanmean(test_scores_AUC)
    avg_test_score_ACC = np.nanmean(test_scores_ACC)
    avg_test_score_REC = np.nanmean(test_scores_REC)
    avg_test_score_PREC = np.nanmean(test_scores_PREC)
    avg_test_score_SPEC = np.nanmean(test_scores_SPEC)
    avg_test_score_F1 = np.nanmean(test_scores_F1)
    avg_test_score_BA = np.nanmean(test_scores_BA)
    avg_test_TP = np.nanmean(test_TP)
    avg_test_TN = np.nanmean(test_TN)
    avg_test_FP = np.nanmean(test_FP)
    avg_test_FN = np.nanmean(test_FN)
    info(f'Model test AUC = {avg_test_score_AUC:.6f}')
    info(f'Model test ACC = {avg_test_score_ACC:.6f}')
    info(f'Model test REC = {avg_test_score_REC:.6f}')
    info(f'Model test PREC = {avg_test_score_PREC:.6f}')
    info(f'Model test SPEC = {avg_test_score_SPEC:.6f}')
    info(f'Model test F1 = {avg_test_score_F1:.6f}')
    info(f'Model test BA = {avg_test_score_BA:.6f}')
    info(f'Confusion matrix\nTP : {avg_test_TP:.6f}\tFP : {avg_test_FP:.6f}')
    info(f'FN : {avg_test_FN:.6f}\tTN : {avg_test_TN:.6f}')

    if args.metric=='f1':
        test_score = test_scores_F1
        avg_test_score = avg_test_score_F1
    elif args.metric=='recall':
        test_score = test_scores_REC
        avg_test_score = avg_test_score_REC
    elif args.metric=='auc':
        test_score = test_scores_AUC
        avg_test_score = avg_test_score_AUC
    else:
        raise ValueError(f'in confusionmatrix, Metric "{metric}" not supported. add the metric in code')

else:
    test_scores = evaluate_predictions(
        preds=test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        arg=args,
        logger=logger
    )
    # Average test score
    avg_test_score = np.nanmean(test_scores)
    info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')

if len(test_preds) != 0:
    sum_test_preds += np.array(test_preds, dtype=float)

if args.show_individual_scores:
    # Individual test scores
    for task_name, test_score in zip(args.task_names, test_scores):
        info(f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}')

# Evaluate ensemble on test set
avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

"""ensemble_scores = evaluate_predictions(
    preds=avg_test_preds,
    targets=test_targets,
    num_tasks=args.num_tasks,
    metric_func=metric_func,
    dataset_type=args.dataset_type,
    arg=args,
    logger=logger
)
"""

ind = [['preds'] * args.num_tasks + ['targets'] * args.num_tasks, args.task_names * 2]
ind = pd.MultiIndex.from_tuples(list(zip(*ind)))
if args.multi_class:
    data = np.concatenate([np.array([np.argmax(x) for x in avg_test_preds]).reshape(-1,1), np.array(test_targets)], 1)
else:
    data = np.concatenate([np.array(avg_test_preds), np.array(test_targets)], 1)
test_result = pd.DataFrame(data, index=test_smiles, columns=ind)
test_result.to_csv(os.path.join(args.save_dir, 'test_result.csv'))

In [40]:
test_score

[0.6245438013237361]

In [36]:
test_score

[0.6245438013237361]

# run_motif_training

In [12]:
import os
import time
from argparse import Namespace
from logging import Logger

import torch
from torch.utils.data import DataLoader
import wandb

from grover.data.dist_sampler import DistributedSampler
from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset, get_motif_data, split_data_motif, GroverMotifCollator, BatchMolDataset_motif
from grover.data.torchvocab import MolVocab
from grover.model.models import GROVEREmbedding
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
from grover.util.nn_utils import param_count
from grover.util.utils import build_optimizer, build_lr_scheduler
from task.grovertrainer import GROVERTrainer, GROVERMotifTrainer

from grover.topology.mol_tree import Motif_Vocab
from grover.topology.motif_generation import Motif_Generation


In [13]:
def split_data_motif(data,
               split_type='random',
               sizes=(0.8, 0.1, 0.1),
               seed=0,
               logger=None):
    """
    Split data with given train/validation/test ratio.
    :param data:
    :param split_type:
    :param sizes:
    :param seed:
    :param logger:
    :return:
    """
    assert len(sizes) == 3 and sum(sizes) == 1

    if split_type == "random":
        data.shuffle(seed=seed)
        data = data.data

        train_size = int(sizes[0] * len(data))
        train_val_size = int((sizes[0] + sizes[1]) * len(data))

        train = data[:train_size]
        val = data[train_size:train_val_size]
        test = data[train_val_size:]

        return BatchMolDataset_motif(train), BatchMolDataset_motif(val), BatchMolDataset_motif(test)
    else:
        raise NotImplementedError("Do not support %s splits" % split_type)

In [14]:
def pre_load_data(dataset: BatchMolDataset, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):

    mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
                                      sample_per_file=sample_per_file)
    mock_sampler.set_epoch(epoch)
    pre_indices = mock_sampler.get_indices()
    for i in pre_indices:
        dataset.load_data(i)


In [15]:
def run_motif_training(args, logger):
    """
    Run the pretrain task with topology predict.
    :param args:
    :param logger:
    :return:
    """
    
    # initalize the logger.
    if logger is not None:
        debug, _ = logger.debug, logger.info
    else:
        debug = print

    # initialize the horovod library
    if args.enable_multi_gpu:
        mgw.init()

    # binding training to GPUs.
    master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
    # pin GPU to local rank. By default, we use gpu:0 for training.
    local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
    with_cuda = args.cuda
    if with_cuda:
        torch.cuda.set_device(local_gpu_idx)

    # get rank an  number of workers
    rank = mgw.rank() if args.enable_multi_gpu else 0
    num_replicas = mgw.size() if args.enable_multi_gpu else 1
    # print("Rank: %d Rep: %d" % (rank, num_replicas))

    # load file paths of the data.
    if master_worker:
        print(args)
        if args.enable_multi_gpu:
            debug("Total workers: %d" % (mgw.size()))
        debug('Loading data')
    data, sample_per_file = get_motif_data(data_path=args.data_path)

    # data splitting
    if master_worker:
        debug(f'Splitting data with seed 0.')
    train_data, test_data, _ = split_data_motif(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

    # Here the true train data size is the train_data divided by #GPUs
    if args.enable_multi_gpu:
        args.train_data_size = len(train_data) // mgw.size()
    else:
        args.train_data_size = len(train_data)
    if master_worker:
        debug(f'Total size = {len(data):,} | '
              f'train size = {len(train_data):,} | val size = {len(test_data):,}')

    # load atom and bond vocabulary and the semantic motif labels.
    atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
    bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
    atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)

    # Load motif vocabulary for pretrain
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
    motif_vocab = Motif_Vocab(motif_vocab)

    # Hard coding here, since we haven't load any data yet!
    fg_size = 85
    shared_dict = {}
    motif_collator = GroverMotifCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
    if master_worker:
        debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                    bond_vocab_size, fg_size))

    # Define the distributed sampler. If using the single card, the sampler will be None.
    train_sampler = None
    test_sampler = None
    shuffle = True
    if args.enable_multi_gpu:
        # If not shuffle, the performance may decayed.
        train_sampler = DistributedSampler(
            train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
        # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
        # rank. (TODO: bad design here.)
        test_sampler = DistributedSampler(
            test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
        train_sampler.set_epoch(args.epochs)
        test_sampler.set_epoch(1)
        # if we enables multi_gpu training. shuffle should be disabled.
        shuffle = False

    # Pre load data. (Maybe unnecessary. )
    pre_load_data(train_data, rank, num_replicas, sample_per_file)
    pre_load_data(test_data, rank, num_replicas)
    if master_worker:
        # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
        print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())

    # Build dataloader
    train_data_dl = DataLoader(train_data,
                               batch_size=args.batch_size,
                               shuffle=shuffle,
                               num_workers=12,
                               sampler=train_sampler,
                               collate_fn=motif_collator)
    test_data_dl = DataLoader(test_data,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=10,
                              sampler=test_sampler,
                              collate_fn=motif_collator)

    # Build the embedding model.
    grover_model = GROVEREmbedding(args)
    
    # build the topology predict model.
    motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order)

    #  Build the trainer.
    trainer = GROVERMotifTrainer(args=args,
                            embedding_model=grover_model,
                            topology_model=motif_model,
                            atom_vocab_size=atom_vocab_size,
                            bond_vocab_size=bond_vocab_size,
                            fg_size=fg_size,
                            train_dataloader=train_data_dl,
                            test_dataloader=test_data_dl,
                            optimizer_builder=build_optimizer,
                            scheduler_builder=build_lr_scheduler,
                            logger=logger,
                            with_cuda=with_cuda,
                            enable_multi_gpu=args.enable_multi_gpu)

    # Restore the interrupted training.
    model_dir = os.path.join(args.save_dir, "model")
    resume_from_epoch = 0
    resume_scheduler_step = 0
    if master_worker:
        resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
    if args.enable_multi_gpu:
        resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
        resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                              root_rank=0, name="resume_scheduler_step").item()
        trainer.scheduler.current_step = resume_scheduler_step
        print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
    trainer.broadcast_parameters()

    # Print model details.
    if master_worker:
        # Change order here.
        print(grover_model)
        print("Total parameters: %d" % param_count(trainer.grover))

    #wandb
    if args.wandb :
        wandb.init(project=args.wandb_name)
        wandb.config = args
        wandb.watch(grover_model)
        
    # Perform training.
    best_val_loss = 0
    best_val_epoch = 0
    best_model_dir = os.path.join(args.save_dir, "model_best")
    for epoch in range(resume_from_epoch + 1, args.epochs):
        s_time = time.time()

        # Data pre-loading.
        if args.enable_multi_gpu:
            train_sampler.set_epoch(epoch)
            train_data.clean_cache()
            idxs = train_sampler.get_indices()
            for local_gpu_idx in idxs:
                train_data.load_data(local_gpu_idx)
        d_time = time.time() - s_time

        # perform training and validation.
        s_time = time.time()
        _, train_loss, _ = trainer.train(epoch)
        t_time = time.time() - s_time
        s_time = time.time()
        _, val_loss, detailed_loss_val = trainer.test(epoch)
        val_av_loss, val_bv_loss, val_fg_loss, _, _, _, val_topo_loss, val_node_loss, topo_acc, node_acc = detailed_loss_val
        v_time = time.time() - s_time
        
        if best_val_loss > val_loss:
            best_val_loss = val_loss
            best_val_epoch = epoch
            trainer.save(epoch, best_model_dir)

        if args.wandb :         
            wandb.log({"train_loss" : train_loss, "val_loss" : val_loss, "topo_loss" : val_topo_loss})
        
        # print information.
        if master_worker:
            print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.6f}'.format(train_loss),
                  'loss_val: {:.6f}'.format(val_loss),
                  'loss_val_av: {:.6f}'.format(val_av_loss),
                  'loss_val_bv: {:.6f}'.format(val_bv_loss),
                  'loss_val_fg: {:.6f}'.format(val_fg_loss),
                  'loss_val_topo: {:.6f}'.format(val_topo_loss),
                  'loss_val_node: {:.6f}'.format(val_node_loss),
                  'acc_topo: {:.6f}'.format(topo_acc),
                  'acc_node: {:.6f}'.format(node_acc),
                  'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
                  't_time: {:.4f}s'.format(t_time),
                  'v_time: {:.4f}s'.format(v_time),
                  'd_time: {:.4f}s'.format(d_time), flush=True)
            
        
            if epoch % args.save_interval == 0:
                trainer.save(epoch, model_dir)


            trainer.save_tmp(epoch, model_dir, rank)

    # Only save final version.
    if master_worker:
        trainer.save(args.epochs, model_dir, "")

# dataset 관련

In [16]:
import math
import os
import csv
import random
from argparse import Namespace
from typing import Callable, List, Union

import numpy as np
from rdkit import Chem
import torch
from torch.utils.data.dataset import Dataset

from grover.data.molfeaturegenerator import get_features_generator
from grover.data.scaler import StandardScaler

import grover.util.utils as feautils
from grover.data import mol2graph
from grover.data.moldataset import MoleculeDatapoint
from grover.data.task_labels import atom_to_vocab, bond_to_vocab

from grover.topology.mol_tree import MolTree, MolTree_break

In [17]:
# 이건 pretrain.py로
def pre_load_data_motif(dataset: BatchMolDataset, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):
    """
    Pre-load data at the beginning of each epoch.
    :param dataset: the training dataset.
    :param rank: the rank of the current worker.
    :param num_replicas: the replicas.
    :param sample_per_file: the number of the data points in each file. When sample_per_file is None, all data will be
    loaded. It implies the testing phase. (TODO: bad design here.)
    :param epoch: the epoch number.
    :return:
    """
    mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
                                      sample_per_file=sample_per_file)
    mock_sampler.set_epoch(epoch)
    pre_indices = mock_sampler.get_indices()
    for i in pre_indices:
        dataset.load_data(i)

In [18]:
class MoleculeDatapoint_motif:
    """A MoleculeDatapoint contains a single molecule and its associated features and targets."""

    def __init__(self,
                 line: List[str],
                 args: Namespace = None,
                 features: np.ndarray = None,
                 moltrees: object = None,
                 use_compound_names: bool = False):
        """
        Initializes a MoleculeDatapoint, which contains a single molecule.

        :param line: A list of strings generated by separating a line in a data CSV file by comma.
        :param args: Arguments.
        :param features: A numpy array containing additional features (ex. Morgan fingerprint).
        :param use_compound_names: Whether the data CSV includes the compound name on each line.
        """
        self.features_generator = None
        self.args = None
        if args is not None:
            if hasattr(args, "features_generator"):
                self.features_generator = args.features_generator
            self.args = args

        if features is not None and self.features_generator is not None:
            raise ValueError('Currently cannot provide both loaded features and a features generator.')

        self.features = features
        self.moltree_path = moltrees[0]
        self.moltree_index = moltrees[1]
        self.moltrees = moltrees

        if use_compound_names:
            self.compound_name = line[0]  # str
            line = line[1:]
        else:
            self.compound_name = None

        self.smiles = line[0]  # str


        # Generate additional features if given a generator
        if self.features_generator is not None:
            self.features = []
            mol = Chem.MolFromSmiles(self.smiles)
            for fg in self.features_generator:
                features_generator = get_features_generator(fg)
                if mol is not None and mol.GetNumHeavyAtoms() > 0:
                    if fg in ['morgan', 'morgan_count']:
                        self.features.extend(features_generator(mol, num_bits=args.num_bits))
                    else:
                        self.features.extend(features_generator(mol))

            self.features = np.array(self.features)

        # Fix nans in features
        if self.features is not None:
            replace_token = 0
            self.features = np.where(np.isnan(self.features), replace_token, self.features)

        # Create targets
        self.targets = [float(x) if x != '' else None for x in line[1:]]

    def set_features(self, features: np.ndarray):
        """
        Sets the features of the molecule.

        :param features: A 1-D numpy array of features for the molecule.
        """
        self.features = features
        
    def set_moltrees(self, moltrees: list):
        """
        Sets the moltree of the molecule.

        :param moltree: moltree object
        """
        self.moltrees = moltrees
        
    def load_moltree(self):
        """
        load moltree of the molecule.
        """
        with open(self.moltree_path, 'rb') as f:
            moltreefile = pickle.load(f)
            self.moltrees = moltreefile[self.moltree_index]
        f.close()
        
    def clean_moltree(self):
        """
        clean moltree for memory
        """
        self.moltrees = None

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        :return: The number of tasks.
        """
        return len(self.targets)

    def set_targets(self, targets: List[float]):
        """
        Sets the targets of a molecule.

        :param targets: A list of floats containing the targets.
        """
        self.targets = targets
        
    def __getitem__(self):
        

In [19]:
class BatchDatapoint_motif:
    def __init__(self,
                 smiles_file,
                 feature_file,
                 moltree_file,
                 n_samples,
                 ):
        self.smiles_file = smiles_file
        self.feature_file = feature_file
        self.moltree_file = moltree_file
        # deal with the last batch graph numbers.
        self.n_samples = n_samples
        self.datapoints = None

    def load_datapoints(self):
        features = self.load_feature()
        #moltrees = self.load_moltree()
        moltrees = self.moltree_file
        self.datapoints = []

        with open(self.smiles_file) as f:
            reader = csv.reader(f)
            next(reader)
            for i, line in enumerate(reader):
                # line = line[0]
#                d = MoleculeDatapoint_motif(line=line,
#                                      features=features[i],
#                                      moltrees=moltrees[i])
                d = MoleculeDatapoint_motif(line=line,
                                      features=features[i],
                                      moltrees=[str(self.moltree_file),i])
                self.datapoints.append(d)
        f.close()

        assert len(self.datapoints) == self.n_samples

    def load_feature(self):
        return feautils.load_features(self.feature_file)
    
    def load_moltree(self):
        return feautils.load_moltrees(self.moltree_file)

    def shuffle(self):
        pass

    def clean_cache(self):
        del self.datapoints
        self.datapoints = None

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        assert self.datapoints is not None
        return self.datapoints[idx]

    def is_loaded(self):
        return self.datapoints is not None

In [20]:
class BatchMolDataset_motif(Dataset):
    def __init__(self, data: List[BatchDatapoint_motif],
                 graph_per_file=None):
        self.data = data

        self.len = 0
        for d in self.data:
            self.len += len(d)
        if graph_per_file is not None:
            self.sample_per_file = graph_per_file
        else:
            self.sample_per_file = len(self.data[0]) if len(self.data) != 0 else None

    def shuffle(self, seed: int = None):
        pass

    def clean_cache(self):
        for d in self.data:
            d.clean_cache()

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx) -> Union[MoleculeDatapoint_motif, List[MoleculeDatapoint_motif]]:
        # print(idx)
        dp_idx = int(idx / self.sample_per_file)
        real_idx = idx % self.sample_per_file
        return self.data[dp_idx][real_idx]

    def load_data(self, idx):
        dp_idx = int(idx / self.sample_per_file)
        if not self.data[dp_idx].is_loaded():
            self.data[dp_idx].load_datapoints()

    def count_loaded_datapoints(self):
        res = 0
        for d in self.data:
            if d.is_loaded():
                res += 1
        return res

In [21]:
class GroverMotifCollator(object):
    def __init__(self, shared_dict, atom_vocab, bond_vocab, args):
        self.args = args
        self.shared_dict = shared_dict
        self.atom_vocab = atom_vocab
        self.bond_vocab = bond_vocab

    def atom_random_mask(self, smiles_batch):
        """
        Perform the random mask operation on atoms.
        :param smiles_batch:
        :return: The corresponding atom labels.
        """
        # There is a zero padding.
        vocab_label = [0]
        percent = 0.15
        for smi in smiles_batch:
            mol = Chem.MolFromSmiles(smi)
            mlabel = [0] * mol.GetNumAtoms()
            n_mask = math.ceil(mol.GetNumAtoms() * percent)
            perm = np.random.permutation(mol.GetNumAtoms())[:n_mask]
            for p in perm:
                atom = mol.GetAtomWithIdx(int(p))
                mlabel[p] = self.atom_vocab.stoi.get(atom_to_vocab(mol, atom), self.atom_vocab.other_index)

            vocab_label.extend(mlabel)
        return vocab_label

    def bond_random_mask(self, smiles_batch):
        """
        Perform the random mask operaiion on bonds.
        :param smiles_batch:
        :return: The corresponding bond labels.
        """
        # There is a zero padding.
        vocab_label = [0]
        percent = 0.15
        for smi in smiles_batch:
            mol = Chem.MolFromSmiles(smi)
            nm_atoms = mol.GetNumAtoms()
            nm_bonds = mol.GetNumBonds()
            mlabel = []
            n_mask = math.ceil(nm_bonds * percent)
            perm = np.random.permutation(nm_bonds)[:n_mask]
            virtual_bond_id = 0
            for a1 in range(nm_atoms):
                for a2 in range(a1 + 1, nm_atoms):
                    bond = mol.GetBondBetweenAtoms(a1, a2)

                    if bond is None:
                        continue
                    if virtual_bond_id in perm:
                        label = self.bond_vocab.stoi.get(bond_to_vocab(mol, bond), self.bond_vocab.other_index)
                        mlabel.extend([label])
                    else:
                        mlabel.extend([0])

                    virtual_bond_id += 1
            # todo: might need to consider bond_drop_rate
            # todo: double check reverse bond
            vocab_label.extend(mlabel)
        return vocab_label

    def __call__(self, batch):
        smiles_batch = [d.smiles for d in batch] # 여기서 말하는 batch는 batchmoldataset_motif다 그리고 d는 batchdatapoint_motif고
        batchgraph = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()

        atom_vocab_label = torch.Tensor(self.atom_random_mask(smiles_batch)).long()
        bond_vocab_label = torch.Tensor(self.bond_random_mask(smiles_batch)).long()
        fgroup_label = torch.Tensor(np.array([d.features for d in batch])).float()
        moltree_batch = [d.moltrees for d in batch]
        
        # may be some mask here
        res = {"graph_input": batchgraph,
               "targets": {"av_task": atom_vocab_label,
                           "bv_task": bond_vocab_label,
                           "fg_task": fgroup_label},
               "moltree" : moltree_batch
               }
        return res

## 혹시 이거 불러들여서?

In [22]:
import csv

In [23]:
def get_motif_data(data_path, logger=None):
    """
    Load data from the data_path.
    :param data_path: the data_path.
    :param logger: the logger.
    :return:
    """
    debug = logger.debug if logger is not None else print
    summary_path = os.path.join(data_path, "summary.txt")
    smiles_path = os.path.join(data_path, "graph")
    feature_path = os.path.join(data_path, "feature")
    moltree_path = os.path.join(data_path, "moltrees")

    fin = open(summary_path)
    n_files = int(fin.readline().strip().split(":")[-1])
    n_samples = int(fin.readline().strip().split(":")[-1])
    sample_per_file = int(fin.readline().strip().split(":")[-1])
    debug("Loading data:")
    debug("Number of files: %d" % n_files)
    debug("Number of samples: %d" % n_samples)
    debug("Samples/file: %d" % sample_per_file)

    datapoints = []
    for i in range(n_files):
        smiles_path_i = os.path.join(smiles_path, str(i) + ".csv")
        feature_path_i = os.path.join(feature_path, str(i) + ".npz")
        moltree_path_i = os.path.join(moltree_path, str(i) + ".p")
        n_samples_i = sample_per_file if i != (n_files - 1) else n_samples % sample_per_file
        datapoints.append(BatchDatapoint_motif(smiles_path_i, feature_path_i, moltree_path_i, n_samples_i))
    return BatchMolDataset_motif(datapoints), sample_per_file

In [24]:
def split_data_motif(data,
               split_type='random',
               sizes=(0.8, 0.1, 0.1),
               seed=0,
               logger=None):
    """
    Split data with given train/validation/test ratio.
    :param data:
    :param split_type:
    :param sizes:
    :param seed:
    :param logger:
    :return:
    """
    assert len(sizes) == 3 and sum(sizes) == 1

    if split_type == "random":
        data.shuffle(seed=seed)
        data = data.data

        train_size = int(sizes[0] * len(data))
        train_val_size = int((sizes[0] + sizes[1]) * len(data))

        train = data[:train_size]
        val = data[train_size:train_val_size]
        test = data[train_val_size:]

        return BatchMolDataset_motif(train), BatchMolDataset_motif(val), BatchMolDataset_motif(test)
    else:
        raise NotImplementedError("Do not support %s splits" % split_type)

In [25]:
class GROVERMotifTrainer:
    def __init__(self,
                 args,
                 embedding_model: Module,
                 topology_model: Module,
                 atom_vocab_size: int,  # atom vocab size
                 bond_vocab_size: int,
                 fg_size: int,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader,
                 optimizer_builder: Callable,
                 scheduler_builder: Callable,
                 logger: Logger = None,
                 with_cuda: bool = False,
                 enable_multi_gpu: bool = False):
        """
        The init function of GROVERTrainer
        :param args: the input arguments.
        :param embedding_model: the model to generate atom/bond embeddings.
        :param topology_model : the model to predict topology of molecule from embeddings
        :param atom_vocab_size: the vocabulary size of atoms.
        :param bond_vocab_size: the vocabulary size of bonds.
        :param fg_size: the size of semantic motifs (functional groups)
        :param train_dataloader: the data loader of train data.
        :param test_dataloader: the data loader of validation data.
        :param optimizer_builder: the function of building the optimizer.
        :param scheduler_builder: the function of building the scheduler.
        :param logger: the logger
        :param with_cuda: enable gpu training.
        :param enable_multi_gpu: enable multi_gpu traning.
        """

        self.args = args
        self.with_cuda = with_cuda
        self.grover = embedding_model
        self.model = GroverMotifTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_size)
        self.motif_model = topology_model
        self.loss_func = self.model.get_loss_func(args)
        self.enable_multi_gpu = enable_multi_gpu

        self.atom_vocab_size = atom_vocab_size
        self.bond_vocab_size = bond_vocab_size
        self.debug = logger.debug if logger is not None else print

        if self.with_cuda:
            # print("Using %d GPUs for training." % (torch.cuda.device_count()))
            self.model = self.model.cuda()
            self.motif_model = self.motif_model.cuda()

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optimizer = optimizer_builder(self.model, self.args)
        self.motif_optimizer = torch.optim.Adam(self.motif_model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
        self.scheduler = scheduler_builder(self.optimizer, self.args)
        if self.enable_multi_gpu:
            self.optimizer = mgw.DistributedOptimizer(self.optimizer,
                                                      named_parameters=self.model.named_parameters())
        self.args = args
        self.n_iter = 0

    def broadcast_parameters(self) -> None:
        """
        Broadcast parameters before training.
        :return: no return.
        """
        if self.enable_multi_gpu:
            # broadcast parameters & optimizer state.
            mgw.broadcast_parameters(self.model.state_dict(), root_rank=0)
            mgw.broadcast_optimizer_state(self.optimizer, root_rank=0)

    def train(self, epoch: int) -> List:
        """
        The training iteration
        :param epoch: the current epoch number.
        :return: the loss terms of current epoch.
        """
        # return self.mock_iter(epoch, self.train_data, train=True)
        return self.iter(epoch, self.train_data, train=True)

    def test(self, epoch: int) -> List:
        """
        The test/validaiion iteration
        :param epoch: the current epoch number.
        :return:  the loss terms as a list
        """
        # return self.mock_iter(epoch, self.test_data, train=False)
        return self.iter(epoch, self.test_data, train=False)

    def mock_iter(self, epoch: int, data_loader: DataLoader, train: bool = True) -> List:
        """
        Perform a mock iteration. For test only.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: the loss terms as a list
        """

        for _, _ in enumerate(data_loader):
            self.scheduler.step()
        cum_loss_sum = 0.0
        self.n_iter += self.args.batch_size
        return self.n_iter, cum_loss_sum, (0, 0, 0, 0, 0, 0)

    def iter(self, epoch, data_loader, train=True) -> List:
        """
        Perform a training / validation iteration.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: the loss terms as a list
        """

        if train:
            self.model.train()
            self.motif_model.train()
        else:
            self.model.eval()
            self.motif_model.eval()

        loss_sum, iter_count = 0, 0
        cum_loss_sum, cum_iter_count = 0, 0
        av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum, bv_dist_loss_sum, fg_dist_loss_sum, node_loss_sum, topo_loss_sum = 0, 0, 0, 0, 0, 0, 0, 0
        
        topo_acc_avg, node_acc_avg = 0, 0
        # loss_func = self.model.get_loss_func(self.args)

        for _, item in enumerate(data_loader):
            batch_graph = item["graph_input"]
            targets = item["targets"]
            
            # add this for motif generation
            moltree_paths = item["moltree"]
            
            moltree = list()
            for _, item in enumerate(moltree_paths):
                moltree.append(load_moltree(item[0],item[1]))

            if next(self.model.parameters()).is_cuda:
                targets["av_task"] = targets["av_task"].cuda()
                targets["bv_task"] = targets["bv_task"].cuda()
                targets["fg_task"] = targets["fg_task"].cuda()
            
            preds = self.model(batch_graph)
            emb_vector = preds['emb_vec']

            # add this for motif generation
            if self.args.embedding_output_type == 'atom':
                emb_afa_grouped = group_node_rep(moltree, emb_vector['atom_from_atom'],batch_graph)
                emb_afb_grouped = group_node_rep(moltree, emb_vector['atom_from_bond'],batch_graph)
                
                node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = self.motif_model(moltree, emb_afa_grouped)
                node_afb_loss, topo_afb_loss, node_afb_acc, topo_afb_acc = self.motif_model(moltree, emb_afb_grouped)
                
                node_loss = node_afa_loss + node_afb_loss
                topo_loss = topo_afa_loss + topo_afb_loss
                node_acc = (node_afa_acc + node_afb_acc)/2
                topo_acc = (topo_afa_acc + topo_afb_acc)/2
                
            elif self.args.embedding_output_type == 'bond':
                emb_bfa_grouped = group_node_rep(moltree, emb_vector['bond_from_atom'],batch_graph)
                emb_bfb_grouped = group_node_rep(moltree, emb_vector['bond_from_bond'],batch_graph)
                
                node_bfa_loss, topo_bfa_loss, node_bfa_acc, topo_bfa_acc = self.motif_model(moltree, emb_bfa_grouped)
                node_bfb_loss, topo_bfb_loss, node_bfb_acc, topo_bfb_acc = self.motif_model(moltree, emb_bfb_grouped)
                
                node_loss = node_bfa_loss + node_bfb_loss
                topo_loss = topo_bfa_loss + topo_bfb_loss
                node_acc = (node_bfa_acc + node_bfb_acc)/2
                topo_acc = (topo_bfa_acc + topo_bfb_acc)/2
                
            elif self.args.embedding_output_type == "both":
                emb_afa_grouped = group_node_rep(moltree, emb_vector['atom_from_atom'],batch_graph)
                emb_afb_grouped = group_node_rep(moltree, emb_vector['atom_from_bond'],batch_graph)
                emb_bfa_grouped = group_node_rep(moltree, emb_vector['bond_from_atom'],batch_graph)
                emb_bfb_grouped = group_node_rep(moltree, emb_vector['bond_from_bond'],batch_graph)
                
                node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = self.motif_model(moltree, emb_afa_grouped)
                node_afb_loss, topo_afb_loss, node_afb_acc, topo_afb_acc = self.motif_model(moltree, emb_afb_grouped)
                node_bfa_loss, topo_bfa_loss, node_bfa_acc, topo_bfa_acc = self.motif_model(moltree, emb_bfa_grouped)
                node_bfb_loss, topo_bfb_loss, node_bfb_acc, topo_bfb_acc = self.motif_model(moltree, emb_bfb_grouped)
                
                node_loss = node_afa_loss + node_afb_loss + node_bfa_loss + node_bfb_loss
                topo_loss = topo_afa_loss + topo_afb_loss + topo_bfa_loss + topo_bfb_loss
                node_acc = (node_afa_acc + node_afb_acc + node_bfa_acc + node_bfb_acc)/4
                topo_acc = (topo_afa_acc + topo_afb_acc + topo_bfa_acc + topo_bfb_acc)/4

            # # ad-hoc code, for visualizing a model, comment this block when it is not needed
            # import dglt.contrib.grover.vis_model as vis_model
            # for task in ['av_task', 'bv_task', 'fg_task']:
            #     vis_graph = vis_model.make_dot(self.model(batch_graph)[task],
            #                                    params=dict(self.model.named_parameters()))
            #     # vis_graph.view()
            #     vis_graph.render(f"{self.args.backbone}_model_{task}_vis.png", format="png")
            # exit()

            loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss = self.loss_func(preds, targets)

            loss_sum += loss.item()
            iter_count += self.args.batch_size
            
            # add for topology loss
            loss += topo_loss
            loss += node_loss
            topo_loss_sum += topo_loss.item()
            node_loss_sum += node_loss.item()

            if train:
                cum_loss_sum += loss.item()
                # Run model
                self.model.zero_grad()
                self.motif_model.zero_grad()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.motif_optimizer.step()
                self.scheduler.step()
            else:
                # For eval model, only consider the loss of three task.
                cum_loss_sum += av_loss.item()
                cum_loss_sum += bv_loss.item()
                cum_loss_sum += fg_loss.item()

            av_loss_sum += av_loss.item()
            bv_loss_sum += bv_loss.item()
            fg_loss_sum += fg_loss.item()
            av_dist_loss_sum += av_dist_loss.item() if type(av_dist_loss) != float else av_dist_loss
            bv_dist_loss_sum += bv_dist_loss.item() if type(bv_dist_loss) != float else bv_dist_loss
            fg_dist_loss_sum += fg_dist_loss.item() if type(fg_dist_loss) != float else fg_dist_loss

            cum_iter_count += 1
            self.n_iter += self.args.batch_size

            # Debug only.
            # if i % 50 == 0:
            #     print(f"epoch: {epoch}, batch_id: {i}, av_loss: {av_loss}, bv_loss: {bv_loss}, "
            #           f"fg_loss: {fg_loss}, av_dist_loss: {av_dist_loss}, bv_dist_loss: {bv_dist_loss}, "
            #           f"fg_dist_loss: {fg_dist_loss}")

        cum_loss_sum /= cum_iter_count
        av_loss_sum /= cum_iter_count
        bv_loss_sum /= cum_iter_count
        fg_loss_sum /= cum_iter_count
        av_dist_loss_sum /= cum_iter_count
        bv_dist_loss_sum /= cum_iter_count
        fg_dist_loss_sum /= cum_iter_count
        
        topo_loss_sum /= cum_iter_count
        node_loss_sum /= cum_iter_count

        return self.n_iter, cum_loss_sum, (av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum,
                                           bv_dist_loss_sum, fg_dist_loss_sum, topo_loss_sum, node_loss_sum, topo_acc, node_acc)

    def save(self, epoch, file_path, name=None) -> str:
        """
        Save the intermediate models during training.
        :param epoch: the epoch number.
        :param file_path: the file_path to save the model.
        :return: the output path.
        """
        # add specific time in model fine name, in order to distinguish different saved models
        now = time.localtime()
        if name is None:
            name = "_%04d_%02d_%02d_%02d_%02d_%02d" % (
                now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
        output_path = file_path + name + ".ep%d" % epoch
        scaler = None
        features_scaler = None
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch,
            'data_scaler': {
                'means': scaler.means,
                'stds': scaler.stds
            } if scaler is not None else None,
            'features_scaler': {
                'means': features_scaler.means,
                'stds': features_scaler.stds
            } if features_scaler is not None else None
        }
        torch.save(state, output_path)

        # Is this necessary?
        # if self.with_cuda:
        #    self.model = self.model.cuda()
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def save_tmp(self, epoch, file_path, rank=0):
        """
        Save the models for auto-restore during training.
        The model are stored in file_path/tmp folder and will replaced on each epoch.
        :param epoch: the epoch number.
        :param file_path: the file_path to store the model.
        :param rank: the current rank (decrypted).
        :return:
        """
        store_path = os.path.join(file_path, "tmp")
        if not os.path.exists(store_path):
            os.makedirs(store_path, exist_ok=True)
        store_path = os.path.join(store_path, "model.%d" % rank)
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch
        }
        torch.save(state, store_path)

    def restore(self, file_path, rank=0) -> Tuple[int, int]:
        """
        Restore the training state saved by save_tmp.
        :param file_path: the file_path to store the model.
        :param rank: the current rank (decrypted).
        :return: the restored epoch number and the scheduler_step in scheduler.
        """
        cpt_path = os.path.join(file_path, "tmp", "model.%d" % rank)
        if not os.path.exists(cpt_path):
            print("No checkpoint found %d")
            return 0, 0
        cpt = torch.load(cpt_path)
        self.model.load_state_dict(cpt["state_dict"])
        self.optimizer.load_state_dict(cpt["optimizer"])
        epoch = cpt["epoch"]
        scheduler_step = cpt["scheduler_step"]
        self.scheduler.current_step = scheduler_step
        print("Restore checkpoint, current epoch: %d" % (epoch))
        return epoch, scheduler_step


# 실험장

In [68]:
if logger is not None:
    debug, _ = logger.debug, logger.info
else:
    debug = print

# initialize the horovod library
if args.enable_multi_gpu:
    mgw.init()

# binding training to GPUs.
master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
# pin GPU to local rank. By default, we use gpu:0 for training.
local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
with_cuda = args.cuda
if with_cuda:
    torch.cuda.set_device(local_gpu_idx)

# get rank an  number of workers
rank = mgw.rank() if args.enable_multi_gpu else 0
num_replicas = mgw.size() if args.enable_multi_gpu else 1
# print("Rank: %d Rep: %d" % (rank, num_replicas))

# load file paths of the data.
if master_worker:
    print(args)
    if args.enable_multi_gpu:
        debug("Total workers: %d" % (mgw.size()))
    debug('Loading data')
data, sample_per_file = get_motif_data(data_path=args.data_path)

# data splitting
if master_worker:
    debug(f'Splitting data with seed 0.')
train_data, test_data, _ = split_data_motif(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

# Here the true train data size is the train_data divided by #GPUs
if args.enable_multi_gpu:
    args.train_data_size = len(train_data) // mgw.size()
else:
    args.train_data_size = len(train_data)
if master_worker:
    debug(f'Total size = {len(data):,} | '
          f'train size = {len(train_data):,} | val size = {len(test_data):,}')

# load atom and bond vocabulary and the semantic motif labels.
atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)

# Load motif vocabulary for pretrain
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
motif_vocab = Motif_Vocab(motif_vocab)

# Hard coding here, since we haven't load any data yet!
fg_size = 85
shared_dict = {}
motif_collator = GroverMotifCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
if master_worker:
    debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                bond_vocab_size, fg_size))

# Define the distributed sampler. If using the single card, the sampler will be None.
train_sampler = None
test_sampler = None
shuffle = True
if args.enable_multi_gpu:
    # If not shuffle, the performance may decayed.
    train_sampler = DistributedSampler(
        train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
    # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
    # rank. (TODO: bad design here.)
    test_sampler = DistributedSampler(
        test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
    train_sampler.set_epoch(args.epochs)
    test_sampler.set_epoch(1)
    # if we enables multi_gpu training. shuffle should be disabled.
    shuffle = False
    
    # Pre load data. (Maybe unnecessary. )
pre_load_data(train_data, rank, num_replicas, sample_per_file)
pre_load_data(test_data, rank, num_replicas)

Loading data
Loading data
Loading data
Loading data
Splitting data with seed 0.
Splitting data with seed 0.
Splitting data with seed 0.
Splitting data with seed 0.
Total size = 500,000 | train size = 450,000 | val size = 50,000
Total size = 500,000 | train size = 450,000 | val size = 50,000
Total size = 500,000 | train size = 450,000 | val size = 50,000
Total size = 500,000 | train size = 450,000 | val size = 50,000


Namespace(activation='PReLU', atom_vocab_path='data/zinc10M/zinc10M_atom_vocab.pkl', backbone='gtrans', batch_size=100, bias=False, bond_drop_rate=0, bond_vocab_path='data/zinc10M/zinc10M_bond_vocab.pkl', cuda=True, data_path='data/zinc10M_0', dense=False, depth=3, dist_coff=0.1, dropout=0.1, embedding_output_type='both', enable_multi_gpu=False, epochs=20, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=1200, init_lr=0.0002, max_lr=0.0004, motif_hidden_size=1200, motif_latent_size=56, motif_order='dfs', motif_vocab_path='data/zinc10M/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/zinc10M_0', save_interval=5, topology=True, undirected=False, wandb=False, wandb_name='pretrain', warmup_epochs=2.0, weight_decay=1e-07)
Loading data:
Number of files: 501
Number of samples: 500000
Samples/file: 1000


atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85
atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85
atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85
atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85


In [69]:
# Build dataloader
train_data_dl = DataLoader(train_data,
                           batch_size=150,
                           shuffle=shuffle,
                           num_workers=10,
                           sampler=train_sampler,
                           collate_fn=motif_collator)
test_data_dl = DataLoader(test_data,
                          batch_size=150,
                          shuffle=shuffle,
                          num_workers=10,
                          sampler=test_sampler,
                          collate_fn=motif_collator)

# Build the embedding model.
grover_model = GROVEREmbedding(args)

# build the topology predict model.
motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order)

#  Build the trainer.
trainer = GROVERMotifTrainer(args=args,
                        embedding_model=grover_model,
                        topology_model=motif_model,
                        atom_vocab_size=atom_vocab_size,
                        bond_vocab_size=bond_vocab_size,
                        fg_size=fg_size,
                        train_dataloader=train_data_dl,
                        test_dataloader=test_data_dl,
                        optimizer_builder=build_optimizer,
                        scheduler_builder=build_lr_scheduler,
                        logger=logger,
                        with_cuda=with_cuda,
                        enable_multi_gpu=args.enable_multi_gpu)

# Restore the interrupted training.
model_dir = os.path.join(args.save_dir, "model")
resume_from_epoch = 0
resume_scheduler_step = 0
if master_worker:
    resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
if args.enable_multi_gpu:
    resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
    resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                          root_rank=0, name="resume_scheduler_step").item()
    trainer.scheduler.current_step = resume_scheduler_step
    print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
trainer.broadcast_parameters()

# Print model details.
#if master_worker:
    # Change order here.
    #print(grover_model)
    #print("Total parameters: %d" % param_count(trainer.grover))

No checkpoint found %d


In [70]:
    #wandb
wandb.init(project='load_after_loader')
wandb.config = args
wandb.watch(grover_model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mppn0303[0m. Use [1m`wandb login --relogin`[0m to force relogin


[]

In [71]:
def load_moltree(path: str, index: int) -> np.ndarray:
    """
    Loads features saved in a variety of formats.

    Supported formats:
    - .npz compressed (assumes features are saved with name "features")

    All formats assume that the SMILES strings loaded elsewhere in the code are in the same
    order as the features loaded here.

    :param path: Path to a file containing features.
    :return: A 2D numpy array of size (num_molecules, features_size) containing the features.
    """
    extension = os.path.splitext(path)[1]

    with open(path, 'rb') as f:
        moltrees = pickle.load(f)[index]
    f.close()
    return moltrees

In [80]:
stime = time.time()
moltree = []
for _, item in enumerate(moltree_paths):
    moltree.append(load_moltrees(item[0],item[1]))
dtime = time.time()
print(dtime-stime)

55.592586278915405


In [79]:
stime = time.time()
moltree = list(0 for i in range(len(moltree_paths)))
for _, item in enumerate(moltree_paths):
    moltree.append(load_moltrees(item[0],item[1]))
dtime = time.time()
print(dtime-stime)

55.522074699401855


In [72]:
from memory_profiler import memory_usage

In [None]:
# Perform training.
best_val_loss = 0
best_val_epoch = 0
best_model_dir = os.path.join(args.save_dir, "model_best")
print(f"before train : {memory_usage()[0]:.2f} MiB")
for epoch in range(resume_from_epoch + 1, args.epochs):
    s_time = time.time()
    print(f"epochs start memory is {memory_usage()[0]:.2f} MiB")
    # Data pre-loading.
    if args.enable_multi_gpu:
        train_sampler.set_epoch(epoch)
        train_data.clean_cache()
        idxs = train_sampler.get_indices()
        for local_gpu_idx in idxs:
            train_data.load_data(local_gpu_idx)
    
    

    d_time = time.time() - s_time
    

    # perform training and validation.
    s_time = time.time()
    _, train_loss, _ = trainer.train(epoch)
    print(f"after train memory is {memory_usage()[0]:.2f} MiB")
    t_time = time.time() - s_time
    s_time = time.time()
    _, val_loss, detailed_loss_val = trainer.test(epoch)
    print(f"after validation memory is {memory_usage()[0]:.2f} MiB")
    val_av_loss, val_bv_loss, val_fg_loss, _, _, _, val_topo_loss, val_node_loss, topo_acc, node_acc = detailed_loss_val
    v_time = time.time() - s_time

    if best_val_loss > val_loss:
        best_val_loss = val_loss
        best_val_epoch = epoch
        trainer.save(epoch, best_model_dir)

    if args.wandb :         
        wandb.log({"train_loss" : train_loss, "val_loss" : val_loss, "topo_loss" : val_topo_loss})

    # print information.
    if master_worker:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.6f}'.format(train_loss),
              'loss_val: {:.6f}'.format(val_loss),
              'loss_val_av: {:.6f}'.format(val_av_loss),
              'loss_val_bv: {:.6f}'.format(val_bv_loss),
              'loss_val_fg: {:.6f}'.format(val_fg_loss),
              'loss_val_topo: {:.6f}'.format(val_topo_loss),
              'loss_val_node: {:.6f}'.format(val_node_loss),
              'acc_topo: {:.6f}'.format(topo_acc),
              'acc_node: {:.6f}'.format(node_acc),
              'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
              't_time: {:.4f}s'.format(t_time),
              'v_time: {:.4f}s'.format(v_time),
              'd_time: {:.4f}s'.format(d_time), flush=True)


        if epoch % args.save_interval == 0:
            trainer.save(epoch, model_dir)


        trainer.save_tmp(epoch, model_dir, rank)
        print(f"after save cp memory is {memory_usage()[0]:.2f} MiB")

# Only save final version.
if master_worker:
    trainer.save(args.epochs, model_dir, "")

before train : 5779.71 MiB
epochs start memory is 5779.71 MiB


In [36]:
load_moltrees(item[0],item[1])

<grover.topology.mol_tree.MolTree at 0x7f7701cd5490>

In [34]:
preds

{'atom_from_atom': tensor([[ 0.5286,  0.1908,  0.1559,  ...,  0.0000,  0.2753, -1.2593],
         [ 0.3252, -0.7777,  0.7843,  ..., -0.4200, -0.5239,  0.0000],
         [ 0.8270, -0.7058,  1.4171,  ...,  0.0000, -0.8003, -1.1189],
         ...,
         [-0.0897, -0.7822,  0.4744,  ...,  0.5085,  0.0036, -0.2264],
         [ 0.0317, -0.1056,  0.0000,  ...,  0.8393, -0.3412, -1.7130],
         [ 0.1424,  0.0848,  0.7505,  ..., -0.0120, -0.7686, -0.9019]],
        device='cuda:0', grad_fn=<FusedDropoutBackward>),
 'bond_from_atom': tensor([[-0.8544,  1.2712,  1.2164,  ...,  0.0085,  1.2820,  0.0000],
         [-0.3456,  0.9824,  0.7136,  ..., -0.1406,  1.1398,  0.0440],
         [-1.0875,  0.2222,  0.0810,  ...,  0.0000,  0.1001, -0.5310],
         ...,
         [-0.6528,  0.9574,  1.1351,  ...,  0.6902,  1.4957, -2.3826],
         [-1.0220,  0.9658,  0.4182,  ..., -0.7550,  0.5864, -1.0845],
         [-1.2348,  1.4471,  1.8378,  ..., -0.2258,  0.9261, -1.6320]],
        device='cuda:0',

In [35]:
loss_sum, iter_count = 0, 0
cum_loss_sum, cum_iter_count = 0, 0
av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum, bv_dist_loss_sum, fg_dist_loss_sum, node_loss_sum, topo_loss_sum = 0, 0, 0, 0, 0, 0, 0, 0

topo_acc_avg, node_acc_avg = 0, 0
# loss_func = self.model.get_loss_func(self.args)

for _, item in enumerate(train_data_dl):
    batch_graph = item["graph_input"]
    targets = item["targets"]

    # add this for motif generation
    moltree_paths = item["moltree"]

    moltree = list()
    for _, item in enumerate(moltree_paths):
        moltree.append(load_moltrees(item[0],item[1]))


    targets["av_task"] = targets["av_task"].cuda()
    targets["bv_task"] = targets["bv_task"].cuda()
    targets["fg_task"] = targets["fg_task"].cuda()

    preds = grover_model(batch_graph)
    emb_vector = preds#['emb_vec']

    emb_afa_grouped = group_node_rep(moltree, emb_vector['atom_from_atom'],batch_graph)
    emb_afb_grouped = group_node_rep(moltree, emb_vector['atom_from_bond'],batch_graph)
    emb_bfa_grouped = group_node_rep(moltree, emb_vector['bond_from_atom'],batch_graph)
    emb_bfb_grouped = group_node_rep(moltree, emb_vector['bond_from_bond'],batch_graph)

    node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = motif_model(moltree, emb_afa_grouped)
    node_afb_loss, topo_afb_loss, node_afb_acc, topo_afb_acc = motif_model(moltree, emb_afb_grouped)
    node_bfa_loss, topo_bfa_loss, node_bfa_acc, topo_bfa_acc = motif_model(moltree, emb_bfa_grouped)
    node_bfb_loss, topo_bfb_loss, node_bfb_acc, topo_bfb_acc = motif_model(moltree, emb_bfb_grouped)

    node_loss = node_afa_loss + node_afb_loss + node_bfa_loss + node_bfb_loss
    topo_loss = topo_afa_loss + topo_afb_loss + topo_bfa_loss + topo_bfb_loss
    node_acc = (node_afa_acc + node_afb_acc + node_bfa_acc + node_bfb_acc)/4
    topo_acc = (topo_afa_acc + topo_afb_acc + topo_bfa_acc + topo_bfb_acc)/4


NameError: name 'self' is not defined

In [None]:
def pretrain_model(args: Namespace, logger: Logger = None):
    """
    The entrey of pretrain.
    :param args: the argument.
    :param logger: the logger.
    :return:
    """

    # avoid auto optimized import by pycharm.
    a = MolVocab
    s_time = time.time()
    if args.topology : 
        run_motif_training(args=args, logger=logger)
    else : 
        run_training(args=args, logger=logger)
    e_time = time.time()
    print("Total Time: %.3f" % (e_time - s_time))

In [47]:
if args.parser_name == 'pretrain':
    logger = create_logger(name='pretrain', save_dir=args.save_dir)
    pretrain_model(args, logger)


Loading data
Loading data
Splitting data with seed 0.
Splitting data with seed 0.
Total size = 500,000 | train size = 450,000 | val size = 50,000
Total size = 500,000 | train size = 450,000 | val size = 50,000


Namespace(activation='PReLU', atom_vocab_path='data/zinc10M/zinc10M_atom_vocab.pkl', backbone='gtrans', batch_size=100, bias=False, bond_drop_rate=0, bond_vocab_path='data/zinc10M/zinc10M_bond_vocab.pkl', cuda=True, data_path='data/zinc10M_0', dense=False, depth=3, dist_coff=0.1, dropout=0.1, embedding_output_type='both', enable_multi_gpu=False, epochs=20, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=1200, init_lr=0.0002, max_lr=0.0004, motif_hidden_size=1200, motif_latent_size=56, motif_order='dfs', motif_vocab_path='data/zinc10M/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/ChEMBL', save_interval=5, topology=True, train_data_size=450000, undirected=False, wandb=False, wandb_name='pretrain', warmup_epochs=2.0, weight_decay=1e-07)
Loading data:
Number of files: 501
Number of samples: 500000
Samples/file: 1000


atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85
atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f71212227a0>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/opt/conda/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.7/selectors.py", line 415, in select
    fd_event_list

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_4023/3879309671.py", line 3, in <module>
    pretrain_model(args, logger)
  File "/root/grover/task/pretrain.py", line 37, in pretrain_model
    run_motif_training(args=args, logger=logger)
  File "/root/grover/task/pretrain.py", line 339, in run_motif_training
    pre_load_data(train_data, rank, num_replicas, sample_per_file)
  File "/root/grover/task/pretrain.py", line 60, in pre_load_data
    dataset.load_data(i)
  File "/root/grover/grover/data/groverdataset.py", line 163, in load_data
    def __init__(self, data: List[BatchDatapoint],
  File "/root/grover/grover/data/groverdataset.py", line 299, in load_datapoints
    moltrees = self.load_moltree()
  File "/root/grover/grover/data/groverdataset.py", line 319, in load_moltree
    return feautils.load_moltrees(self.molt

TypeError: object of type 'NoneType' has no len()