In [4]:
import random

import numpy as np
import torch
from rdkit import RDLogger

from grover.util.parsing import parse_args, get_newest_train_args
from grover.util.utils import create_logger
from task.cross_validate import cross_validate, randomsearch, gridsearch, make_confusion_matrix
from task.fingerprint import generate_fingerprints
from task.predict import make_predictions, write_prediction
from task.pretrain import pretrain_model
from grover.data.torchvocab import MolVocab

#add for gridsearch
from argparse import ArgumentParser, Namespace

In [5]:
def setup(seed):
    # frozen random seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [6]:
# setup random seed
setup(seed=42)
# Avoid the pylint warning.
a = MolVocab
# supress rdkit logger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Initialize MolVocab
mol_vocab = MolVocab

# 0. save_features.py

In [7]:
import os
import shutil
import sys
from argparse import ArgumentParser, Namespace
from multiprocessing import Pool
from typing import List, Tuple

from tqdm import tqdm

from grover.util.utils import get_data, makedirs, load_features, save_features
from grover.data.molfeaturegenerator import get_available_features_generators, \
    get_features_generator
from grover.data.task_labels import rdkit_functional_group_label_features_generator

def load_temp(temp_dir: str) -> Tuple[List[List[float]], int]:
    """
    Loads all features saved as .npz files in load_dir.

    Assumes temporary files are named in order 0.npz, 1.npz, ...

    :param temp_dir: Directory in which temporary .npz files containing features are stored.
    :return: A tuple with a list of molecule features, where each molecule's features is a list of floats,
    and the number of temporary files.
    """
    features = []
    temp_num = 0
    temp_path = os.path.join(temp_dir, f'{temp_num}.npz')

    while os.path.exists(temp_path):
        features.extend(load_features(temp_path))
        temp_num += 1
        temp_path = os.path.join(temp_dir, f'{temp_num}.npz')

    return features, temp_num

In [2]:
parser = ArgumentParser()
parser.add_argument('--data_path', type=str, required=True,
                    help='Path to data CSV')
parser.add_argument('--features_generator', type=str, required=True,
                    choices=get_available_features_generators(),
                    help='Type of features to generate')
parser.add_argument('--save_path', type=str, default=None,
                    help='Path to .npz file where features will be saved as a compressed numpy archive')
parser.add_argument('--save_frequency', type=int, default=10000,
                    help='Frequency with which to save the features')
parser.add_argument('--restart', action='store_true', default=False,
                    help='Whether to not load partially complete featurization and instead start from scratch')
parser.add_argument('--max_data_size', type=int,
                    help='Maximum number of data points to load')
parser.add_argument('--sequential', action='store_true', default=False,
                    help='Whether to task sequentially rather than in parallel')
args = parser.parse_args(['--data_path','data/testfiles/mgssl.csv','--save_path', 'data/pretrain/mgssl/CO2.npz', '--features_generator','fgtasklabel','--restart'])
args

Namespace(data_path='data/CO2.csv', features_generator='fgtasklabel', max_data_size=None, restart=True, save_frequency=10000, save_path='data/CO2.npz', sequential=False)

## 0-1. generate_and_save_features()

## 0-1-1. fgtasklabel
- RDKIT을 통해 smiles에서 찾을 수 있는 FG(모티프)들을 feature로써 출력한다

In [6]:
from collections import Counter
from typing import Callable, Union

import numpy as np
from rdkit import Chem
from descriptastorus.descriptors import rdDescriptors

from grover.data.molfeaturegenerator import register_features_generator

Molecule = Union[str, Chem.Mol]
FeaturesGenerator = Callable[[Molecule], np.ndarray]

In [7]:
# The functional group descriptors in RDkit.
RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN',
               'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2',
               'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0',
               'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2',
               'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide',
               'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl',
               'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine',
               'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester',
               'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone',
               'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone',
               'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine',
               'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho',
               'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol',
               'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine',
               'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN',
               'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole',
               'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea']

In [8]:
len(RDKIT_PROPS)

85

### 0-1-1-1 rdkit_fg_label_feature

In [9]:
@register_features_generator('fgtasklabel')
def rdkit_functional_group_label_features_generator(mol: Molecule) -> np.ndarray:
    """
    Generates functional group label for a molecule using RDKit.

    :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
    :return: A 1D numpy array containing the RDKit 2D features.
    """
    smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
    generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
    features = generator.process(smiles)[1:]
    features = np.array(features)
    features[features != 0] = 1
    return features

### 0-1-1-2 rdkit_fg_label_feature 예시
- 아래의 글은 참고용이다. 그대로 쓰면 오류떠서 안된다.

In [11]:
smiles = 'C(O)O'
generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
features = generator.process(smiles)[1:]
features2 = np.array(features)
features2[features2 != 0] = 1
features2

array([0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [12]:
features2.shape

(85,)

## 0-1. generate_and_save_features 함수

In [13]:
def generate_and_save_features(args: Namespace):
    """
    Computes and saves features for a dataset of molecules as a 2D array in a .npz file.

    :param args: Arguments.
    """
    # Create directory for save_path
    makedirs(args.save_path, isfile=True)

    # Get data and features function
    data = get_data(path=args.data_path, max_data_size=None)
    features_generator = get_features_generator(args.features_generator)
    temp_save_dir = args.save_path + '_temp'

    # Load partially complete data
    if args.restart:
        if os.path.exists(args.save_path):
            os.remove(args.save_path)
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)
    else:
        if os.path.exists(args.save_path):
            raise ValueError(f'"{args.save_path}" already exists and args.restart is False.')

        if os.path.exists(temp_save_dir):
            features, temp_num = load_temp(temp_save_dir)

    if not os.path.exists(temp_save_dir):
        makedirs(temp_save_dir)
        features, temp_num = [], 0

    # Build features map function
    data = data[len(features):]  # restrict to data for which features have not been computed yet
    mols = (d.smiles for d in data)

    if args.sequential:
        features_map = map(features_generator, mols)
    else:
        features_map = Pool(30).imap(features_generator, mols)
        
        

    # Get features
    temp_features = []
    for i, feats in tqdm(enumerate(features_map), total=len(data)):
        temp_features.append(feats)

        # Save temporary features every save_frequency
        if (i > 0 and (i + 1) % args.save_frequency == 0) or i == len(data) - 1:
            save_features(os.path.join(temp_save_dir, f'{temp_num}.npz'), temp_features)
            features.extend(temp_features)
            temp_features = []
            temp_num += 1

    try:
        # Save all features
        save_features(args.save_path, features)

        # Remove temporary features
        shutil.rmtree(temp_save_dir)
    except OverflowError:
        print('Features array is too large to save as a single file. Instead keeping features as a directory of files.')

In [13]:
if args.save_path is None:
    args.save_path = args.data_path.split('csv')[0] + 'npz'
generate_and_save_features(args)

100%|█████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 6821.12it/s]
Process ForkPoolWorker-8:
Process ForkPoolWorker-7:
Process ForkPoolWorker-2:
Process ForkPoolWorker-3:
Process ForkPoolWorker-10:
Process ForkPoolWorker-9:
Process ForkPoolWorker-6:
Process ForkPoolWorker-4:
Process ForkPoolWorker-18:
Traceback (most recent call last):
Process ForkPoolWorker-11:
  File "/home/rt/anaconda3/envs/tox/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Process ForkPoolWorker-17:
Process ForkPoolWorker-19:
Process ForkPoolWorker-5:
Traceback (most recent call last):
  File "/home/rt/anaconda3/envs/tox/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/rt/anaconda3/envs/tox/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Process ForkPoolWorker-22:
Traceback (most

# 1. build_vocab.py

In [4]:
import os

import sys

#from grover.data.torchvocab import MolVocab

## 1-1.MolVocab

In [15]:
"""
The contextual property.
"""
import pickle
from collections import Counter
from multiprocessing import Pool

import tqdm
from rdkit import Chem

#from grover.data.task_labels import atom_to_vocab
#from grover.data.task_labels import bond_to_vocab

In [16]:
class TorchVocab(object):
    """
    Defines the vocabulary for atoms/bonds in molecular.
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<other>'), vocab_type='atom'):
        """

        :param counter:
        :param max_size:
        :param min_freq:
        :param specials:
        :param vocab_type: 'atom': atom atom_vocab; 'bond': bond atom_vocab.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)
        if vocab_type in ('atom', 'bond'):
            self.vocab_type = vocab_type
        else:
            raise ValueError('Wrong input for vocab_type!')
        self.itos = list(specials)

        max_size = None if max_size is None else max_size + len(self.itos)
        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)
        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        self.other_index = 1
        self.pad_index = 0

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        # if self.vectors != other.vectors:
        #    return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
                self.freqs[w] = 0
            self.freqs[w] += v.freqs[w]

    def mol_to_seq(self, mol, with_len=False):
        mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
        if self.vocab_type == 'atom':
            seq = [self.stoi.get(atom_to_vocab(mol, atom), self.other_index) for i, atom in enumerate(mol.GetAtoms())]
        else:
            seq = [self.stoi.get(bond_to_vocab(mol, bond), self.other_index) for i, bond in enumerate(mol.GetBonds())]
        return (seq, len(seq)) if with_len else seq

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)

### 1-1-1. atom/bond_to_vocab

In [17]:
from collections import Counter
from typing import Callable, Union

import numpy as np
from rdkit import Chem
from descriptastorus.descriptors import rdDescriptors

#from grover.data.molfeaturegenerator import register_features_generator

BOND_FEATURES = ['BondType', 'Stereo', 'BondDir']

In [18]:
def atom_to_vocab(mol, atom):
    """
    Convert atom to vocabulary. The convention is based on atom type and bond type.
    :param mol: the molecular.
    :param atom: the target atom.
    :return: the generated atom vocabulary with its contexts.
    """
    nei = Counter()
    for a in atom.GetNeighbors():
        bond = mol.GetBondBetweenAtoms(atom.GetIdx(), a.GetIdx())
        nei[str(a.GetSymbol()) + "-" + str(bond.GetBondType())] += 1
    keys = nei.keys()
    keys = list(keys)
    keys.sort()
    output = atom.GetSymbol()
    for k in keys:
        output = "%s_%s%d" % (output, k, nei[k])

    # The generated atom_vocab is too long?
    return output


def bond_to_vocab(mol, bond):
    """
    Convert bond to vocabulary. The convention is based on atom type and bond type.
    Considering one-hop neighbor atoms
    :param mol: the molecular.
    :param atom: the target atom.
    :return: the generated bond vocabulary with its contexts.
    """
    nei = Counter()
    two_neighbors = (bond.GetBeginAtom(), bond.GetEndAtom())
    two_indices = [a.GetIdx() for a in two_neighbors]
    for nei_atom in two_neighbors:
        for a in nei_atom.GetNeighbors():
            a_idx = a.GetIdx()
            if a_idx in two_indices:
                continue
            tmp_bond = mol.GetBondBetweenAtoms(nei_atom.GetIdx(), a_idx)
            nei[str(nei_atom.GetSymbol()) + '-' + get_bond_feature_name(tmp_bond)] += 1
    keys = list(nei.keys())
    keys.sort()
    output = get_bond_feature_name(bond)
    for k in keys:
        output = "%s_%s%d" % (output, k, nei[k])
    return output

def get_bond_feature_name(bond):
    """
    Return the string format of bond features.
    Bond features are surrounded with ()

    """
    ret = []
    for bond_feature in BOND_FEATURES:
        fea = eval(f"bond.Get{bond_feature}")()
        ret.append(str(fea))

    return '(' + '-'.join(ret) + ')'

In [19]:
mol = Chem.MolFromSmiles('C(O)O')
print(atom_to_vocab(mol, mol.GetAtoms()[0]), atom_to_vocab(mol, mol.GetAtoms()[1]), atom_to_vocab(mol, mol.GetAtoms()[2]))

C_O-SINGLE2 O_C-SINGLE1 O_C-SINGLE1


In [20]:
print(bond_to_vocab(mol, mol.GetBonds()[0]), bond_to_vocab(mol, mol.GetBonds()[1]))

(SINGLE-STEREONONE-NONE)_C-(SINGLE-STEREONONE-NONE)1 (SINGLE-STEREONONE-NONE)_C-(SINGLE-STEREONONE-NONE)1


## 1-1. MolVocab클래스

In [21]:
class MolVocab(TorchVocab):
    def __init__(self, smiles, max_size=None, min_freq=1, vocab_type='atom'):
        if vocab_type in ('atom', 'bond'):
            self.vocab_type = vocab_type
        else:
            raise ValueError('Wrong input for vocab_type!')

        print("Building %s vocab from smiles: %d" % (self.vocab_type, len(smiles)))
        counter = Counter()

        for smi in tqdm.tqdm(smiles):
            mol = Chem.MolFromSmiles(smi)
            if self.vocab_type == 'atom':
                for _, atom in enumerate(mol.GetAtoms()):
                    v = atom_to_vocab(mol, atom)
                    counter[v] += 1
            else:
                for _, bond in enumerate(mol.GetBonds()):
                    v = bond_to_vocab(mol, bond)
                    counter[v] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)

    def __init__(self, file_path, max_size=None, min_freq=1, num_workers=1, total_lines=None, vocab_type='atom'):
        if vocab_type in ('atom', 'bond'):
            self.vocab_type = vocab_type
        else:
            raise ValueError('Wrong input for vocab_type!')
        print("Building %s vocab from file: %s" % (self.vocab_type, file_path))

        from rdkit import RDLogger
        lg = RDLogger.logger()
        lg.setLevel(RDLogger.CRITICAL)

        if total_lines is None:
            def file_len(fname):
                f_len = 0
                with open(fname) as f:
                    for f_len, _ in enumerate(f):
                        pass
                return f_len + 1

            total_lines = file_len(file_path)

        counter = Counter()
        pbar = tqdm.tqdm(total=total_lines)
        pool = Pool(num_workers)
        res = []
        batch = 50000
        callback = lambda a: pbar.update(batch)
        for i in range(int(total_lines / batch + 1)):
            start = int(batch * i)
            end = min(total_lines, batch * (i + 1))
            # print("Start: %d, End: %d"%(start, end))
            res.append(pool.apply_async(MolVocab.read_smiles_from_file,
                                        args=(file_path, start, end, vocab_type,),
                                        callback=callback))
            # read_smiles_from_file(lock, file_path, start, end)
        pool.close()
        pool.join()
        for r in res:
            sub_counter = r.get()
            for k in sub_counter:
                if k not in counter:
                    counter[k] = 0
                counter[k] += sub_counter[k]
        # print(counter)
        super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)

    @staticmethod
    def read_smiles_from_file(file_path, start, end, vocab_type):
        # print("start")
        smiles = open(file_path, "r")
        smiles.readline()
        sub_counter = Counter()
        for i, smi in enumerate(smiles):
            if i < start:
                continue
            if i >= end:
                break
            mol = Chem.MolFromSmiles(smi)
            if vocab_type == 'atom':
                for atom in mol.GetAtoms():
                    v = atom_to_vocab(mol, atom)
                    sub_counter[v] += 1
            else:
                for bond in mol.GetBonds():
                    v = bond_to_vocab(mol, bond)
                    sub_counter[v] += 1
        # print("end")
        return sub_counter

    @staticmethod
    def load_vocab(vocab_path: str) -> 'MolVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)


## 실행코드

In [24]:
parser = ArgumentParser()
parser.add_argument('--data_path', default="../../dataset/grover_new_dataset/druglike_merged_refine2.csv", type=str)
parser.add_argument('--vocab_save_folder', default="../../dataset/grover_new_dataset", type=str)
parser.add_argument('--dataset_name', type=str, default=None,
                    help="Will be the first part of the vocab file name. If it is None,"
                         "the vocab files will be: atom_vocab.pkl and bond_vocab.pkl")
parser.add_argument('--vocab_max_size', type=int, default=None)
parser.add_argument('--vocab_min_freq', type=int, default=1)
args = parser.parse_args(['--data_path','data/CO2.csv','--vocab_save_folder', 'data/CO2', '--dataset_name','CO2'])

In [25]:
for vocab_type in ['atom', 'bond']:
    vocab_file = f"{vocab_type}_vocab.pkl"
    if args.dataset_name is not None:
        vocab_file = args.dataset_name + '_' + vocab_file
    vocab_save_path = os.path.join(args.vocab_save_folder, vocab_file)

    os.makedirs(os.path.dirname(vocab_save_path), exist_ok=True)
    vocab = MolVocab(file_path=args.data_path,
                     max_size=args.vocab_max_size,
                     min_freq=args.vocab_min_freq,
                     num_workers=100,
                     vocab_type=vocab_type)
    print(f"{vocab_type} vocab size", len(vocab))
    vocab.save_vocab(vocab_save_path)

Building atom vocab from file: data/CO2.csv



  0%|                                                                                            | 0/11 [00:00<?, ?it/s][A
50000it [00:00, 77882.49it/s]                                                                                           [A


atom vocab size 4
Building bond vocab from file: data/CO2.csv



  0%|                                                                                            | 0/11 [00:00<?, ?it/s][A
50000it [00:00, 84081.55it/s]                                                                                           [A

bond vocab size 3





# 2. pretrain_model()
- pretrain.py파일

In [8]:
import os
import time
from argparse import Namespace
from logging import Logger

import torch
from torch.utils.data import DataLoader

from grover.data.dist_sampler import DistributedSampler
from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset
from grover.data.torchvocab import MolVocab
from grover.model.models import GROVEREmbedding
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
from grover.util.nn_utils import param_count
from grover.util.utils import build_optimizer, build_lr_scheduler
from task.grovertrainer import GROVERTrainer

## 2-0. pre_load_data()

In [9]:
import math
import time
import torch
from torch.utils.data.sampler import Sampler
import torch.distributed as dist

### 2-0-1. DistributedSampler
- 부분집합 만드는거

In [10]:
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, sample_per_file=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.sample_per_file = sample_per_file
        self.shuffle = shuffle

    def get_indices(self):

        indices = list(range(len(self.dataset)))

        if self.sample_per_file is not None:
            indices = self.sub_indices_of_rank(indices)
        else:
            # add extra samples to make it evenly divisible
            indices += indices[:(self.total_size - len(indices))]
            assert len(indices) == self.total_size
            # subsample
            s = self.rank * self.num_samples
            e = min((self.rank + 1) * self.num_samples, len(indices))

            # indices = indices[self.rank:self.total_size:self.num_replicas]
            indices = indices[s:e]

        if self.shuffle:
            g = torch.Generator()
            # the seed need to be considered.
            g.manual_seed((self.epoch + 1) * (self.rank + 1) * time.time())
            idx = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in idx]

        # disable this since sub_indices_of_rank.
        # assert len(indices) == self.num_samples

        return indices

    def sub_indices_of_rank(self, indices):

        # fix generator for each epoch
        g = torch.Generator()
        # All data should be loaded in each epoch.
        g.manual_seed((self.epoch + 1) * 2 + 3)

        # the fake file indices to cache
        f_indices = list(range(int(math.ceil(len(indices) * 1.0 / self.sample_per_file))))
        idx = torch.randperm(len(f_indices), generator=g).tolist()
        f_indices = [f_indices[i] for i in idx]

        file_per_rank = int(math.ceil(len(f_indices) * 1.0 / self.num_replicas))
        # add extra fake file to make it evenly divisible
        f_indices += f_indices[:(file_per_rank * self.num_replicas - len(f_indices))]

        # divide index by rank
        rank_s = self.rank * file_per_rank
        rank_e = min((self.rank + 1) * file_per_rank, len(f_indices))

        # get file index for this rank
        f_indices = f_indices[rank_s:rank_e]
        # print("f_indices")
        # print(f_indices)
        res_indices = []
        for fi in f_indices:
            # get real indices for this rank
            si = fi * self.sample_per_file
            ei = min((fi + 1) * self.sample_per_file, len(indices))
            cur_idx = [indices[i] for i in range(si, ei)]
            res_indices += cur_idx

        self.num_samples = len(res_indices)
        return res_indices

    def __iter__(self):
        return iter(self.get_indices())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

In [11]:
def pre_load_data(dataset: BatchMolDataset, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):
    """
    Pre-load data at the beginning of each epoch.
    :param dataset: the training dataset.
    :param rank: the rank of the current worker.
    :param num_replicas: the replicas.
    :param sample_per_file: the number of the data points in each file. When sample_per_file is None, all data will be
    loaded. It implies the testing phase. (TODO: bad design here.)
    :param epoch: the epoch number.
    :return:
    """
    mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
                                      sample_per_file=sample_per_file)
    mock_sampler.set_epoch(epoch)
    pre_indices = mock_sampler.get_indices()
    for i in pre_indices:
        dataset.load_data(i)

## 2-1. run_training()

In [12]:
import os
import time
from argparse import Namespace
from logging import Logger

import torch
from torch.utils.data import DataLoader

from grover.data.dist_sampler import DistributedSampler
from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset
from grover.data.torchvocab import MolVocab
from grover.model.models import GROVEREmbedding
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
from grover.util.nn_utils import param_count
from grover.util.utils import build_optimizer, build_lr_scheduler
from task.grovertrainer import GROVERTrainer

from grover.topology.mol_tree import Motif_Vocab
from grover.topology.motif_generation import Motif_Generation

### 2-1-1. get_data()
- summary 예 : n_files:60, n_samples:5970, sample_per_file:100
- graph : smiles를 잘게 나눈 csv들이 있는 곳
- feature : 분자 안에 어떤 motif가 포함되어 있는가를 feature라고 표현하네

In [13]:
import math
import os
import csv
from typing import Union, List
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from rdkit import Chem

import grover.util.utils as feautils
from grover.data import mol2graph
from grover.data.moldataset import MoleculeDatapoint
from grover.data.task_labels import atom_to_vocab, bond_to_vocab

In [14]:
class BatchDatapoint:
    def __init__(self,
                 smiles_file,
                 feature_file,
                 n_samples,
                 ):
        self.smiles_file = smiles_file
        self.feature_file = feature_file
        # deal with the last batch graph numbers.
        self.n_samples = n_samples
        self.datapoints = None

    def load_datapoints(self):
        features = self.load_feature()
        self.datapoints = []

        with open(self.smiles_file) as f:
            reader = csv.reader(f)
            next(reader)
            for i, line in enumerate(reader):
                # line = line[0]
                d = MoleculeDatapoint(line=line,
                                      features=features[i])
                self.datapoints.append(d)

        assert len(self.datapoints) == self.n_samples

    def load_feature(self):
        return feautils.load_features(self.feature_file)

    def shuffle(self):
        pass

    def clean_cache(self):
        del self.datapoints
        self.datapoints = None

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        assert self.datapoints is not None
        return self.datapoints[idx]

    def is_loaded(self):
        return self.datapoints is not None

In [15]:
class BatchMolDataset(Dataset):
    def __init__(self, data: List[BatchDatapoint],
                 graph_per_file=None):
        self.data = data

        self.len = 0
        for d in self.data:
            self.len += len(d)
        if graph_per_file is not None:
            self.sample_per_file = graph_per_file
        else:
            self.sample_per_file = len(self.data[0]) if len(self.data) != 0 else None

    def shuffle(self, seed: int = None):
        pass

    def clean_cache(self):
        for d in self.data:
            d.clean_cache()

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx) -> Union[MoleculeDatapoint, List[MoleculeDatapoint]]:
        # print(idx)
        dp_idx = int(idx / self.sample_per_file)
        real_idx = idx % self.sample_per_file
        return self.data[dp_idx][real_idx]

    def load_data(self, idx):
        dp_idx = int(idx / self.sample_per_file)
        if not self.data[dp_idx].is_loaded():
            self.data[dp_idx].load_datapoints()

    def count_loaded_datapoints(self):
        res = 0
        for d in self.data:
            if d.is_loaded():
                res += 1
        return res

In [16]:
def get_data(data_path, logger=None):
    """
    Load data from the data_path.
    :param data_path: the data_path.
    :param logger: the logger.
    :return:
    """
    debug = logger.debug if logger is not None else print
    summary_path = os.path.join(data_path, "summary.txt")
    smiles_path = os.path.join(data_path, "graph")
    feature_path = os.path.join(data_path, "feature")

    fin = open(summary_path)
    n_files = int(fin.readline().strip().split(":")[-1])
    n_samples = int(fin.readline().strip().split(":")[-1])
    sample_per_file = int(fin.readline().strip().split(":")[-1])
    debug("Loading data:")
    debug("Number of files: %d" % n_files)
    debug("Number of samples: %d" % n_samples)
    debug("Samples/file: %d" % sample_per_file)

    datapoints = []
    for i in range(n_files):
        smiles_path_i = os.path.join(smiles_path, str(i) + ".csv")
        feature_path_i = os.path.join(feature_path, str(i) + ".npz")
        n_samples_i = sample_per_file if i != (n_files - 1) else n_samples % sample_per_file
        datapoints.append(BatchDatapoint(smiles_path_i, feature_path_i, n_samples_i))
    return BatchMolDataset(datapoints), sample_per_file

### 2-1-2. GroverCollator

#### 2-1-2-1. mol2graph(finetune과 같음)

In [17]:
from typing import List, Tuple, Union

MAX_ATOMIC_NUM = 100


ATOM_FEATURES = {
    'atomic_num': list(range(MAX_ATOMIC_NUM)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_Hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ],
}

# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2    # 참고로 이거 133이다... 왜지?
BOND_FDIM = 14
BOND_FDIM_3D = 15


def get_atom_fdim() -> int:
    """
    Gets the dimensionality of atom features.

    :param: Arguments.
    """
    return ATOM_FDIM + 18


def get_bond_fdim() -> int:
    """
    Gets the dimensionality of bond features.

    :param: Arguments.
    """
    return BOND_FDIM


def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """
    Creates a one-hot encoding.

    :param value: The value for which the encoding should be one.
    :param choices: A list of possible values.
    :return: A one-hot encoding of the value in a list of length len(choices) + 1.
    If value is not in the list of choices, then the final element in the encoding is 1.
    """
    encoding = [0] * (len(choices) + 1)
    if min(choices) < 0:
        index = value
    else:
        index = choices.index(value) if value in choices else -1
    encoding[index] = 1

    return encoding

In [18]:
class MolGraph:
    """
    A MolGraph represents the graph structure and featurization of a single molecule.

    A MolGraph computes the following attributes:
    - smiles: Smiles string.
    - n_atoms: The number of atoms in the molecule.
    - n_bonds: The number of bonds in the molecule.
    - f_atoms: A mapping from an atom index to a list atom features.
    - f_bonds: A mapping from a bond index to a list of bond features.
    - a2b: A mapping from an atom index to a list of incoming bond indices.
    - b2a: A mapping from a bond index to the index of the atom the bond originates from.
    - b2revb: A mapping from a bond index to the index of the reverse bond.
    """

    def __init__(self, smiles: str,  args: Namespace):
        """
        Computes the graph structure and featurization of a molecule.

        :param smiles: A smiles string.
        :param args: Arguments.
        """
        self.smiles = smiles
        self.args = args
        self.n_atoms = 0  # number of atoms
        self.n_bonds = 0  # number of bonds
        self.f_atoms = []  # mapping from atom index to atom features
        self.f_bonds = []  # mapping from bond index to concat(in_atom, bond) features
        self.a2b = []  # mapping from atom index to incoming bond indices
        self.b2a = []  # mapping from bond index to the index of the atom the bond is coming from
        self.b2revb = []  # mapping from bond index to the index of the reverse bond

        # Convert smiles to molecule
        mol = Chem.MolFromSmiles(smiles)

        self.hydrogen_donor = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
        self.hydrogen_acceptor = Chem.MolFromSmarts(
            "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
            "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
        self.acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
        self.basic = Chem.MolFromSmarts(
            "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
            "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")

        self.hydrogen_donor_match = sum(mol.GetSubstructMatches(self.hydrogen_donor), ())
        self.hydrogen_acceptor_match = sum(mol.GetSubstructMatches(self.hydrogen_acceptor), ())
        self.acidic_match = sum(mol.GetSubstructMatches(self.acidic), ())
        self.basic_match = sum(mol.GetSubstructMatches(self.basic), ())
        self.ring_info = mol.GetRingInfo()


        # fake the number of "atoms" if we are collapsing substructures
        self.n_atoms = mol.GetNumAtoms()

        # Get atom features
        for _, atom in enumerate(mol.GetAtoms()):
            self.f_atoms.append(self.atom_features(atom))
        self.f_atoms = [self.f_atoms[i] for i in range(self.n_atoms)]

        for _ in range(self.n_atoms):
            self.a2b.append([])

        # Get bond features
        for a1 in range(self.n_atoms):
            for a2 in range(a1 + 1, self.n_atoms):
                bond = mol.GetBondBetweenAtoms(a1, a2)

                if bond is None:
                    continue

                if args.bond_drop_rate > 0:
                    if np.random.binomial(1, args.bond_drop_rate):
                        continue

                f_bond = self.bond_features(bond)

                # Always treat the bond as directed.
                self.f_bonds.append(self.f_atoms[a1] + f_bond)
                self.f_bonds.append(self.f_atoms[a2] + f_bond)

                # Update index mappings
                b1 = self.n_bonds
                b2 = b1 + 1
                self.a2b[a2].append(b1)  # b1 = a1 --> a2
                self.b2a.append(a1)
                self.a2b[a1].append(b2)  # b2 = a2 --> a1
                self.b2a.append(a2)
                self.b2revb.append(b2)
                self.b2revb.append(b1)
                self.n_bonds += 2

    def atom_features(self, atom: Chem.rdchem.Atom) -> List[Union[bool, int, float]]:
        """
        Builds a feature vector for an atom.

        :param atom: An RDKit atom.
        :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to.
        :return: A list containing the atom features.
        """
        features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
                   onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
                   onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
                   onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
                   onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
                   onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
                   [1 if atom.GetIsAromatic() else 0] + \
                   [atom.GetMass() * 0.01]
        atom_idx = atom.GetIdx()
        features = features + \
                   onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
                   [atom_idx in self.hydrogen_acceptor_match] + \
                   [atom_idx in self.hydrogen_donor_match] + \
                   [atom_idx in self.acidic_match] + \
                   [atom_idx in self.basic_match] + \
                   [self.ring_info.IsAtomInRingOfSize(atom_idx, 3),
                    self.ring_info.IsAtomInRingOfSize(atom_idx, 4),
                    self.ring_info.IsAtomInRingOfSize(atom_idx, 5),
                    self.ring_info.IsAtomInRingOfSize(atom_idx, 6),
                    self.ring_info.IsAtomInRingOfSize(atom_idx, 7),
                    self.ring_info.IsAtomInRingOfSize(atom_idx, 8)]
        return features

    def bond_features(self, bond: Chem.rdchem.Bond
                      ) -> List[Union[bool, int, float]]:
        """
        Builds a feature vector for a bond.

        :param bond: A RDKit bond.
        :return: A list containing the bond features.
        """

        if bond is None:
            fbond = [1] + [0] * (BOND_FDIM - 1)
        else:
            bt = bond.GetBondType()
            fbond = [
                0,  # bond is not None
                bt == Chem.rdchem.BondType.SINGLE,
                bt == Chem.rdchem.BondType.DOUBLE,
                bt == Chem.rdchem.BondType.TRIPLE,
                bt == Chem.rdchem.BondType.AROMATIC,
                (bond.GetIsConjugated() if bt is not None else 0),
                (bond.IsInRing() if bt is not None else 0)
            ]
            fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
        return fbond

In [19]:
class BatchMolGraph:
    """
    A BatchMolGraph represents the graph structure and featurization of a batch of molecules.

    A BatchMolGraph contains the attributes of a MolGraph plus:
    - smiles_batch: A list of smiles strings.
    - n_mols: The number of molecules in the batch.
    - atom_fdim: The dimensionality of the atom features.
    - bond_fdim: The dimensionality of the bond features (technically the combined atom/bond features).
    - a_scope: A list of tuples indicating the start and end atom indices for each molecule.
    - b_scope: A list of tuples indicating the start and end bond indices for each molecule.
    - max_num_bonds: The maximum number of bonds neighboring an atom in this batch.
    - b2b: (Optional) A mapping from a bond index to incoming bond indices.
    - a2a: (Optional): A mapping from an atom index to neighboring atom indices.
    """

    def __init__(self, mol_graphs: List[MolGraph], args: Namespace):
        self.smiles_batch = [mol_graph.smiles for mol_graph in mol_graphs]
        self.n_mols = len(self.smiles_batch)

        self.atom_fdim = get_atom_fdim()
        self.bond_fdim = get_bond_fdim() + self.atom_fdim

        # Start n_atoms and n_bonds at 1 b/c zero padding
        self.n_atoms = 1  # number of atoms (start at 1 b/c need index 0 as padding)
        self.n_bonds = 1  # number of bonds (start at 1 b/c need index 0 as padding)
        self.a_scope = []  # list of tuples indicating (start_atom_index, num_atoms) for each molecule
        self.b_scope = []  # list of tuples indicating (start_bond_index, num_bonds) for each molecule

        # All start with zero padding so that indexing with zero padding returns zeros
        f_atoms = [[0] * self.atom_fdim]  # atom features
        f_bonds = [[0] * self.bond_fdim]  # combined atom/bond features
        a2b = [[]]  # mapping from atom index to incoming bond indices
        b2a = [0]  # mapping from bond index to the index of the atom the bond is coming from
        b2revb = [0]  # mapping from bond index to the index of the reverse bond

        for mol_graph in mol_graphs:
            f_atoms.extend(mol_graph.f_atoms)
            f_bonds.extend(mol_graph.f_bonds)

            for a in range(mol_graph.n_atoms):
                a2b.append([b + self.n_bonds for b in mol_graph.a2b[a]])

            for b in range(mol_graph.n_bonds):
                b2a.append(self.n_atoms + mol_graph.b2a[b])
                b2revb.append(self.n_bonds + mol_graph.b2revb[b])

            self.a_scope.append((self.n_atoms, mol_graph.n_atoms))
            self.b_scope.append((self.n_bonds, mol_graph.n_bonds))
            self.n_atoms += mol_graph.n_atoms
            self.n_bonds += mol_graph.n_bonds

        # max with 1 to fix a crash in rare case of all single-heavy-atom mols
        self.max_num_bonds = max(1, max(len(in_bonds) for in_bonds in a2b))

        self.f_atoms = torch.FloatTensor(f_atoms)
        self.f_bonds = torch.FloatTensor(f_bonds)
        self.a2b = torch.LongTensor([a2b[a] + [0] * (self.max_num_bonds - len(a2b[a])) for a in range(self.n_atoms)])
        self.b2a = torch.LongTensor(b2a)
        self.b2revb = torch.LongTensor(b2revb)
        self.b2b = None  # try to avoid computing b2b b/c O(n_atoms^3)
        self.a2a = self.b2a[self.a2b]  # only needed if using atom messages
        self.a_scope = torch.LongTensor(self.a_scope)
        self.b_scope = torch.LongTensor(self.b_scope)

    def set_new_atom_feature(self, f_atoms):
        """
        Set the new atom feature. Do not update bond feature.
        :param f_atoms:
        """
        self.f_atoms = f_atoms

    def get_components(self) -> Tuple[torch.FloatTensor, torch.FloatTensor,
                                      torch.LongTensor, torch.LongTensor, torch.LongTensor,
                                      List[Tuple[int, int]], List[Tuple[int, int]]]:
        """
        Returns the components of the BatchMolGraph.

        :return: A tuple containing PyTorch tensors with the atom features, bond features, and graph structure
        and two lists indicating the scope of the atoms and bonds (i.e. which molecules they belong to).
        """
        return self.f_atoms, self.f_bonds, self.a2b, self.b2a, self.b2revb, self.a_scope, self.b_scope, self.a2a

    def get_b2b(self) -> torch.LongTensor:
        """
        Computes (if necessary) and returns a mapping from each bond index to all the incoming bond indices.

        :return: A PyTorch tensor containing the mapping from each bond index to all the incoming bond indices.
        """

        if self.b2b is None:
            b2b = self.a2b[self.b2a]  # num_bonds x max_num_bonds
            # b2b includes reverse edge for each bond so need to mask out
            revmask = (b2b != self.b2revb.unsqueeze(1).repeat(1, b2b.size(1))).long()  # num_bonds x max_num_bonds
            self.b2b = b2b * revmask

        return self.b2b

    def get_a2a(self) -> torch.LongTensor:
        """
        Computes (if necessary) and returns a mapping from each atom index to all neighboring atom indices.

        :return: A PyTorch tensor containing the mapping from each bond index to all the incodming bond indices.
        """
        if self.a2a is None:
            # b = a1 --> a2
            # a2b maps a2 to all incoming bonds b
            # b2a maps each bond b to the atom it comes from a1
            # thus b2a[a2b] maps atom a2 to neighboring atoms a1
            self.a2a = self.b2a[self.a2b]  # num_atoms x max_num_bonds

        return self.a2a

In [20]:
def mol2graph(smiles_batch: List[str], shared_dict,
              args: Namespace) -> BatchMolGraph:
    """
    Converts a list of SMILES strings to a BatchMolGraph containing the batch of molecular graphs.

    :param smiles_batch: A list of SMILES strings.
    :param args: Arguments.
    :return: A BatchMolGraph containing the combined molecular graph for the molecules
    """
    mol_graphs = []
    for smiles in smiles_batch:
        if smiles in shared_dict:
            mol_graph = shared_dict[smiles]
        else:
            mol_graph = MolGraph(smiles, args)
            if not args.no_cache:
                shared_dict[smiles] = mol_graph
        mol_graphs.append(mol_graph)

    return BatchMolGraph(mol_graphs, args)

#### 2-1-2-2. Collator()함수
- 여기의 percent를 수정하면 몇퍼센트를 알아맞출지를 결정한다.
- 15%를 가린다는 의미인데, 실제로는 15%만 타겟으로 선정하여 맞추는 형식,,,

In [21]:
class GroverCollator(object):
    def __init__(self, shared_dict, atom_vocab, bond_vocab, args):
        self.args = args
        self.shared_dict = shared_dict
        self.atom_vocab = atom_vocab
        self.bond_vocab = bond_vocab

    def atom_random_mask(self, smiles_batch):
        """
        Perform the random mask operation on atoms.
        :param smiles_batch:
        :return: The corresponding atom labels.
        """
        # There is a zero padding.
        vocab_label = [0]
        percent = 0.15
        for smi in smiles_batch:
            mol = Chem.MolFromSmiles(smi)
            mlabel = [0] * mol.GetNumAtoms()
            n_mask = math.ceil(mol.GetNumAtoms() * percent)
            perm = np.random.permutation(mol.GetNumAtoms())[:n_mask]
            for p in perm:
                atom = mol.GetAtomWithIdx(int(p))
                mlabel[p] = self.atom_vocab.stoi.get(atom_to_vocab(mol, atom), self.atom_vocab.other_index)

            vocab_label.extend(mlabel)
        return vocab_label

    def bond_random_mask(self, smiles_batch):
        """
        Perform the random mask operaiion on bonds.
        :param smiles_batch:
        :return: The corresponding bond labels.
        """
        # There is a zero padding.
        vocab_label = [0]
        percent = 0.15
        for smi in smiles_batch:
            mol = Chem.MolFromSmiles(smi)
            nm_atoms = mol.GetNumAtoms()
            nm_bonds = mol.GetNumBonds()
            mlabel = []
            n_mask = math.ceil(nm_bonds * percent)
            perm = np.random.permutation(nm_bonds)[:n_mask]
            virtual_bond_id = 0
            for a1 in range(nm_atoms):
                for a2 in range(a1 + 1, nm_atoms):
                    bond = mol.GetBondBetweenAtoms(a1, a2)

                    if bond is None:
                        continue
                    if virtual_bond_id in perm:
                        label = self.bond_vocab.stoi.get(bond_to_vocab(mol, bond), self.bond_vocab.other_index)
                        mlabel.extend([label])
                    else:
                        mlabel.extend([0])

                    virtual_bond_id += 1
            # todo: might need to consider bond_drop_rate
            # todo: double check reverse bond
            vocab_label.extend(mlabel)
        return vocab_label

    def __call__(self, batch):
        smiles_batch = [d.smiles for d in batch]
        batchgraph = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()

        atom_vocab_label = torch.Tensor(self.atom_random_mask(smiles_batch)).long()
        bond_vocab_label = torch.Tensor(self.bond_random_mask(smiles_batch)).long()
        fgroup_label = torch.Tensor([d.features for d in batch]).float()
        # may be some mask here
        res = {"graph_input": batchgraph,
               "targets": {"av_task": atom_vocab_label,
                           "bv_task": bond_vocab_label,
                           "fg_task": fgroup_label}
               }
        return res

### 2-1-3. GROVEREmbedding()

#### 2-1-3-1. MPNEncoder
- 먼저 input feature(해당 블락의 feature 뿐만 아니라 - a2a, a2b, b2a, b2arevb 등 다양한 정보를 활용한다.)를 여기서 사용할 feature크기로 Linear(Dense)를 통과시키고, 활성화 함수를 적용시킨다.
- 자신을 제외한 node, edge들의 메시지를 임의의 hop만큼 종합하여 Linear(Dense) 1개를 통과시켜서 Message를 구한다.
- dyMPN은 truncated normal로 3을 기준으로 +-3에서 하는게 기본이나 0~6사이의 uniform분포에서 hop수를 지정할 수도 있다.

In [22]:
import scipy.stats as stats
from torch import nn as nn

In [23]:
class MPNEncoder(nn.Module):
    """A message passing neural network for encoding a molecule."""

    def __init__(self, args: Namespace,
                 atom_messages: bool,
                 init_message_dim: int,
                 attached_fea_fdim: int,
                 hidden_size: int,
                 bias: bool,
                 depth: int,
                 dropout: float,
                 undirected: bool,
                 dense: bool,
                 aggregate_to_atom: bool,
                 attach_fea: bool,
                 input_layer="fc",
                 dynamic_depth='none'
                 ):
        """
        Initializes the MPNEncoder.
        :param args: the arguments.
        :param atom_messages: enables atom_messages or not.
        :param init_message_dim:  the initial input message dimension.
        :param attached_fea_fdim:  the attached feature dimension.
        :param hidden_size: the output message dimension during message passing.
        :param bias: the bias in the message passing.
        :param depth: the message passing depth.
        :param dropout: the dropout rate.
        :param undirected: the message passing is undirected or not.
        :param dense: enables the dense connections.
        :param attach_fea: enables the feature attachment during the message passing process.
        :param dynamic_depth: enables the dynamic depth. Possible choices: "none", "uniform" and "truncnorm"
        """
        super(MPNEncoder, self).__init__()
        self.init_message_dim = init_message_dim
        self.attached_fea_fdim = attached_fea_fdim
        self.hidden_size = hidden_size
        self.bias = bias
        self.depth = depth
        self.dropout = dropout
        self.input_layer = input_layer
        self.layers_per_message = 1
        self.undirected = undirected
        self.atom_messages = atom_messages
        self.dense = dense
        self.aggreate_to_atom = aggregate_to_atom
        self.attached_fea = attach_fea
        self.dynamic_depth = dynamic_depth

        # Dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Activation
        self.act_func = get_activation_function(args.activation)

        # Input
        if self.input_layer == "fc":
            input_dim = self.init_message_dim
            self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)

        if self.attached_fea:
            w_h_input_size = self.hidden_size + self.attached_fea_fdim
        else:
            w_h_input_size = self.hidden_size

        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)

    def forward(self,
                init_messages,
                init_attached_features,
                a2nei,
                a2attached,
                b2a=None,
                b2revb=None,
                adjs=None
                ) -> torch.FloatTensor:
        """
        The forward function.
        :param init_messages:  initial massages, can be atom features or bond features.
        :param init_attached_features: initial attached_features.
        :param a2nei: the relation of item to its neighbors. For the atom message passing, a2nei = a2a. For bond
        messages a2nei = a2b
        :param a2attached: the relation of item to the attached features during message passing. For the atom message
        passing, a2attached = a2b. For the bond message passing a2attached = a2a
        :param b2a: remove the reversed bond in bond message passing
        :param b2revb: remove the revered atom in bond message passing
        :return: if aggreate_to_atom or self.atom_messages, return num_atoms x hidden.
        Otherwise, return num_bonds x hidden
        """

        # Input
        if self.input_layer == 'fc':
            input = self.W_i(init_messages)  # num_bonds x hidden_size # f_bond
            message = self.act_func(input)  # num_bonds x hidden_size
        elif self.input_layer == 'none':
            input = init_messages
            message = input

        attached_fea = init_attached_features  # f_atom / f_bond

        # dynamic depth
        # uniform sampling from depth - 1 to depth + 1
        # only works in training.
        if self.training and self.dynamic_depth != "none":
            if self.dynamic_depth == "uniform":
                # uniform sampling
                ndepth = numpy.random.randint(self.depth - 3, self.depth + 3)
            else:
                # truncnorm
                mu = self.depth
                sigma = 1
                lower = mu - 3 * sigma
                upper = mu + 3 * sigma
                X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
                ndepth = int(X.rvs(1))
        else:
            ndepth = self.depth

        # Message passing
        for _ in range(ndepth - 1):
            if self.undirected:
                # two directions should be the same
                message = (message + message[b2revb]) / 2

            nei_message = select_neighbor_and_aggregate(message, a2nei)
            a_message = nei_message
            if self.attached_fea:
                attached_nei_fea = select_neighbor_and_aggregate(attached_fea, a2attached)
                a_message = torch.cat((nei_message, attached_nei_fea), dim=1)

            if not self.atom_messages:
                rev_message = message[b2revb]
                if self.attached_fea:
                    atom_rev_message = attached_fea[b2a[b2revb]]
                    rev_message = torch.cat((rev_message, atom_rev_message), dim=1)
                # Except reverse bond its-self(w) ! \sum_{k\in N(u) \ w}
                message = a_message[b2a] - rev_message  # num_bonds x hidden
            else:
                message = a_message

            message = self.W_h(message)

            # BUG here, by default MPNEncoder use the dense connection in the message passing step.
            # The correct form should if not self.dense
            if self.dense:
                message = self.act_func(message)  # num_bonds x hidden_size
            else:
                message = self.act_func(input + message)
            message = self.dropout_layer(message)  # num_bonds x hidden

        output = message
        print(output)

        return output  # num_atoms x hidden


#### 2-1-3-2. 멀티헤드어텐션
- 위의 MPN을 Q,K,V로 3개에 대해 Head수만큼 만든다(4 또는 8)
- 그리고 각각의 Head수만큼 Self-Attention을 점곱하여 계산해낸다.

In [24]:
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product SelfAttention
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        :param query:
        :param key:
        :param value:
        :param mask:
        :param dropout:
        :return:
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn

In [25]:
class MultiHeadedAttention(nn.Module):
    """
    The multi-head attention module. Take in model size and number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1, bias=False):
        """

        :param h:
        :param d_model:
        :param dropout:
        :param bias:
        """
        super().__init__()
        print(f'd_model is {d_model} and h is {h}')
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h  # number of heads

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])  # why 3: query, key, value
        self.output_linear = nn.Linear(d_model, d_model, bias)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """

        :param query:
        :param key:
        :param value:
        :param mask:
        :return:
        """
        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)

In [26]:
class Head(nn.Module):
    """
    One head for multi-headed attention.
    :return: (query, key, value)
    """

    def __init__(self, args, hidden_size, atom_messages=False):
        """
        Initialization.
        :param args: The argument.
        :param hidden_size: the dimension of hidden layer in Head.
        :param atom_messages: the MPNEncoder type.
        """
        super(Head, self).__init__()
        atom_fdim = hidden_size
        bond_fdim = hidden_size
        hidden_size = hidden_size
        self.atom_messages = atom_messages
        if self.atom_messages:
            init_message_dim = atom_fdim
            attached_fea_dim = bond_fdim
        else:
            init_message_dim = bond_fdim
            attached_fea_dim = atom_fdim

        # Here we use the message passing network as query, key and value.
        self.mpn_q = MPNEncoder(args=args,
                                atom_messages=atom_messages,
                                init_message_dim=init_message_dim,
                                attached_fea_fdim=attached_fea_dim,
                                hidden_size=hidden_size,
                                bias=args.bias,
                                depth=args.depth,
                                dropout=args.dropout,
                                undirected=args.undirected,
                                dense=args.dense,
                                aggregate_to_atom=False,
                                attach_fea=False,
                                input_layer="none",
                                dynamic_depth="truncnorm")
        self.mpn_k = MPNEncoder(args=args,
                                atom_messages=atom_messages,
                                init_message_dim=init_message_dim,
                                attached_fea_fdim=attached_fea_dim,
                                hidden_size=hidden_size,
                                bias=args.bias,
                                depth=args.depth,
                                dropout=args.dropout,
                                undirected=args.undirected,
                                dense=args.dense,
                                aggregate_to_atom=False,
                                attach_fea=False,
                                input_layer="none",
                                dynamic_depth="truncnorm")
        self.mpn_v = MPNEncoder(args=args,
                                atom_messages=atom_messages,
                                init_message_dim=init_message_dim,
                                attached_fea_fdim=attached_fea_dim,
                                hidden_size=hidden_size,
                                bias=args.bias,
                                depth=args.depth,
                                dropout=args.dropout,
                                undirected=args.undirected,
                                dense=args.dense,
                                aggregate_to_atom=False,
                                attach_fea=False,
                                input_layer="none",
                                dynamic_depth="truncnorm")

    def forward(self, f_atoms, f_bonds, a2b, a2a, b2a, b2revb):
        """
        The forward function.
        :param f_atoms: the atom features, num_atoms * atom_dim
        :param f_bonds: the bond features, num_bonds * bond_dim
        :param a2b: mapping from atom index to incoming bond indices.
        :param a2a: mapping from atom index to its neighbors. num_atoms * max_num_bonds
        :param b2a: mapping from bond index to the index of the atom the bond is coming from.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return:
        """
        if self.atom_messages:
            init_messages = f_atoms
            init_attached_features = f_bonds
            a2nei = a2a
            a2attached = a2b
            b2a = b2a
            b2revb = b2revb
        else:
            init_messages = f_bonds
            init_attached_features = f_atoms
            a2nei = a2b
            a2attached = a2a
            b2a = b2a
            b2revb = b2revb

        q = self.mpn_q(init_messages=init_messages,
                       init_attached_features=init_attached_features,
                       a2nei=a2nei,
                       a2attached=a2attached,
                       b2a=b2a,
                       b2revb=b2revb)
        k = self.mpn_k(init_messages=init_messages,
                       init_attached_features=init_attached_features,
                       a2nei=a2nei,
                       a2attached=a2attached,
                       b2a=b2a,
                       b2revb=b2revb)
        v = self.mpn_v(init_messages=init_messages,
                       init_attached_features=init_attached_features,
                       a2nei=a2nei,
                       a2attached=a2attached,
                       b2a=b2a,
                       b2revb=b2revb)
        return q, k, v

#### 2-1-3-3 MTBLOCK

In [27]:
from grover.util.nn_utils import get_activation_function, select_neighbor_and_aggregate
from torch.nn import LayerNorm, functional as F

In [28]:
class MTBlock(nn.Module):
    """
    The Multi-headed attention block.
    """

    def __init__(self,
                 args,
                 num_attn_head,
                 input_dim,
                 hidden_size,
                 activation="ReLU",
                 dropout=0.0,
                 bias=True,
                 atom_messages=False,
                 cuda=True,
                 res_connection=False):
        """

        :param args: the arguments.
        :param num_attn_head: the number of attention head.
        :param input_dim: the input dimension.
        :param hidden_size: the hidden size of the model.
        :param activation: the activation function.
        :param dropout: the dropout ratio
        :param bias: if true: all linear layer contains bias term.
        :param atom_messages: the MPNEncoder type
        :param cuda: if true, the model run with GPU.
        :param res_connection: enables the skip-connection in MTBlock.
        """
        super(MTBlock, self).__init__()
        # self.args = args
        self.atom_messages = atom_messages
        self.hidden_size = hidden_size
        self.heads = nn.ModuleList()
        self.input_dim = input_dim
        self.cuda = cuda
        self.res_connection = res_connection
        self.act_func = get_activation_function(activation)
        self.dropout_layer = nn.Dropout(p=dropout)
        # Note: elementwise_affine has to be consistent with the pre-training phase
        self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=True)

        self.W_i = nn.Linear(self.input_dim, self.hidden_size, bias=bias)
        self.attn = MultiHeadedAttention(h=num_attn_head,
                                         d_model=self.hidden_size,
                                         bias=bias,
                                         dropout=dropout)
        self.W_o = nn.Linear(self.hidden_size * num_attn_head, self.hidden_size, bias=bias)
        self.sublayer = SublayerConnection(self.hidden_size, dropout)
        for _ in range(num_attn_head):
            self.heads.append(Head(args, hidden_size=hidden_size, atom_messages=atom_messages))

    def forward(self, batch, features_batch=None):
        """

        :param batch: the graph batch generated by GroverCollator.
        :param features_batch: the additional features of molecules. (deprecated)
        :return:
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch

        if self.atom_messages:
            # Only add linear transformation in the input feature.
            if f_atoms.shape[1] != self.hidden_size:
                f_atoms = self.W_i(f_atoms)
                f_atoms = self.dropout_layer(self.layernorm(self.act_func(f_atoms)))

        else:  # bond messages
            if f_bonds.shape[1] != self.hidden_size:
                f_bonds = self.W_i(f_bonds)
                f_bonds = self.dropout_layer(self.layernorm(self.act_func(f_bonds)))

        queries = []
        keys = []
        values = []
        for head in self.heads:
            q, k, v = head(f_atoms, f_bonds, a2b, a2a, b2a, b2revb)
            queries.append(q.unsqueeze(1))
            keys.append(k.unsqueeze(1))
            values.append(v.unsqueeze(1))
        queries = torch.cat(queries, dim=1)
        keys = torch.cat(keys, dim=1)
        values = torch.cat(values, dim=1)

        x_out = self.attn(queries, keys, values)  # multi-headed attention
        x_out = x_out.view(x_out.shape[0], -1)
        x_out = self.W_o(x_out)

        x_in = None
        # support no residual connection in MTBlock.
        if self.res_connection:
            if self.atom_messages:
                x_in = f_atoms
            else:
                x_in = f_bonds

        if self.atom_messages:
            f_atoms = self.sublayer(x_in, x_out)
        else:
            f_bonds = self.sublayer(x_in, x_out)

        batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
        features_batch = features_batch
        return batch, features_batch

In [29]:
class PositionwiseFeedForward(nn.Module):
    """Implements FFN equation."""

    def __init__(self, d_model, d_ff, activation="PReLU", dropout=0.1, d_out=None):
        """Initialization.

        :param d_model: the input dimension.
        :param d_ff: the hidden dimension.
        :param activation: the activation function.
        :param dropout: the dropout rate.
        :param d_out: the output dimension, the default value is equal to d_model.
        """
        super(PositionwiseFeedForward, self).__init__()
        if d_out is None:
            d_out = d_model
        # By default, bias is on.
        self.W_1 = nn.Linear(d_model, d_ff)
        self.W_2 = nn.Linear(d_ff, d_out)
        self.dropout = nn.Dropout(dropout)
        self.act_func = get_activation_function(activation)

    def forward(self, x):
        """
        The forward function
        :param x: input tensor.
        :return:
        """
        return self.W_2(self.dropout(self.act_func(self.W_1(x))))

In [30]:
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        """Initialization.

        :param size: the input dimension.
        :param dropout: the dropout ratio.
        """
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size, elementwise_affine=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, outputs):
        """Apply residual connection to any sublayer with the same size."""
        # return x + self.dropout(self.norm(x))
        if inputs is None:
            return self.dropout(self.norm(outputs))
        return inputs + self.dropout(self.norm(outputs))

In [31]:
def index_select_nd(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    Selects the message features from source corresponding to the atom or bond indices in index.

    :param source: A tensor of shape (num_bonds, hidden_size) containing message features.
    :param index: A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond
    indices to select from source.
    :return: A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message
    features corresponding to the atoms/bonds specified in index.
    """
    index_size = index.size()  # (num_atoms/num_bonds, max_num_bonds)
    suffix_dim = source.size()[1:]  # (hidden_size,)
    final_size = index_size + suffix_dim  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    target = source.index_select(dim=0, index=index.view(-1))  # (num_atoms/num_bonds * max_num_bonds, hidden_size)
    target = target.view(final_size)  # (num_atoms/num_bonds, max_num_bonds, hidden_size)

    return target


#### 2-1-4-4. GTransEncoder and Embedding

In [32]:
from typing import List, Dict, Callable

In [33]:
class GTransEncoder(nn.Module):
    def __init__(self,
                 args,
                 hidden_size,
                 edge_fdim,
                 node_fdim,
                 dropout=0.0,
                 activation="ReLU",
                 num_mt_block=1,
                 num_attn_head=4,
                 atom_emb_output: Union[bool, str] = False,  # options: True, False, None, "atom", "bond", "both"
                 bias=False,
                 cuda=True,
                 res_connection=False):
        """

        :param args: the arguments.
        :param hidden_size: the hidden size of the model.
        :param edge_fdim: the dimension of additional feature for edge/bond.
        :param node_fdim: the dimension of additional feature for node/atom.
        :param dropout: the dropout ratio
        :param activation: the activation function
        :param num_mt_block: the number of mt block.
        :param num_attn_head: the number of attention head.
        :param atom_emb_output:  enable the output aggregation after message passing.
                                              atom_messages:      True                      False
        -False: no aggregating to atom. output size:     (num_atoms, hidden_size)    (num_bonds, hidden_size)
        -True:  aggregating to atom.    output size:     (num_atoms, hidden_size)    (num_atoms, hidden_size)
        -None:                         same as False
        -"atom":                       same as True
        -"bond": aggragating to bond.   output size:     (num_bonds, hidden_size)    (num_bonds, hidden_size)
        -"both": aggregating to atom&bond. output size:  (num_atoms, hidden_size)    (num_bonds, hidden_size)
                                                         (num_bonds, hidden_size)    (num_atoms, hidden_size)
        :param bias: enable bias term in all linear layers.
        :param cuda: run with cuda.
        :param res_connection: enables the skip-connection in MTBlock.
        """
        super(GTransEncoder, self).__init__()

        # For the compatibility issue.
        if atom_emb_output is False:
            atom_emb_output = None
        if atom_emb_output is True:
            atom_emb_output = 'atom'

        self.hidden_size = hidden_size
        self.dropout = dropout
        self.activation = activation
        self.cuda = cuda
        self.bias = bias
        self.res_connection = res_connection
        self.edge_blocks = nn.ModuleList()
        self.node_blocks = nn.ModuleList()

        edge_input_dim = edge_fdim
        node_input_dim = node_fdim
        edge_input_dim_i = edge_input_dim
        node_input_dim_i = node_input_dim

        for i in range(num_mt_block):
            if i != 0:
                edge_input_dim_i = self.hidden_size
                node_input_dim_i = self.hidden_size
            self.edge_blocks.append(MTBlock(args=args,
                                            num_attn_head=num_attn_head,
                                            input_dim=edge_input_dim_i,
                                            hidden_size=self.hidden_size,
                                            activation=activation,
                                            dropout=dropout,
                                            bias=self.bias,
                                            atom_messages=False,
                                            cuda=cuda))
            self.node_blocks.append(MTBlock(args=args,
                                            num_attn_head=num_attn_head,
                                            input_dim=node_input_dim_i,
                                            hidden_size=self.hidden_size,
                                            activation=activation,
                                            dropout=dropout,
                                            bias=self.bias,
                                            atom_messages=True,
                                            cuda=cuda))

        self.atom_emb_output = atom_emb_output

        self.ffn_atom_from_atom = PositionwiseFeedForward(self.hidden_size + node_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.ffn_atom_from_bond = PositionwiseFeedForward(self.hidden_size + node_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.ffn_bond_from_atom = PositionwiseFeedForward(self.hidden_size + edge_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.ffn_bond_from_bond = PositionwiseFeedForward(self.hidden_size + edge_fdim,
                                                          self.hidden_size * 4,
                                                          activation=self.activation,
                                                          dropout=self.dropout,
                                                          d_out=self.hidden_size)

        self.atom_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.atom_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.bond_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.bond_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)

        self.act_func_node = get_activation_function(self.activation)
        self.act_func_edge = get_activation_function(self.activation)

        self.dropout_layer = nn.Dropout(p=args.dropout)

    def pointwise_feed_forward_to_atom_embedding(self, emb_output, atom_fea, index, ffn_layer):
        """
        The point-wise feed forward and long-range residual connection for atom view.
        aggregate to atom.
        :param emb_output: the output embedding from the previous multi-head attentions.
        :param atom_fea: the atom/node feature embedding.
        :param index: the index of neighborhood relations.
        :param ffn_layer: the feed forward layer
        :return:
        """
        aggr_output = select_neighbor_and_aggregate(emb_output, index)
        aggr_outputx = torch.cat([atom_fea, aggr_output], dim=1)
        return ffn_layer(aggr_outputx), aggr_output

    def pointwise_feed_forward_to_bond_embedding(self, emb_output, bond_fea, a2nei, b2revb, ffn_layer):
        """
        The point-wise feed forward and long-range residual connection for bond view.
        aggregate to bond.
        :param emb_output: the output embedding from the previous multi-head attentions.
        :param bond_fea: the bond/edge feature embedding.
        :param index: the index of neighborhood relations.
        :param ffn_layer: the feed forward layer
        :return:
        """
        aggr_output = select_neighbor_and_aggregate(emb_output, a2nei)
        # remove rev bond / atom --- need for bond view
        aggr_output = self.remove_rev_bond_message(emb_output, aggr_output, b2revb)
        aggr_outputx = torch.cat([bond_fea, aggr_output], dim=1)
        return ffn_layer(aggr_outputx), aggr_output

    @staticmethod
    def remove_rev_bond_message(orginal_message, aggr_message, b2revb):
        """

        :param orginal_message:
        :param aggr_message:
        :param b2revb:
        :return:
        """
        rev_message = orginal_message[b2revb]
        return aggr_message - rev_message

    def atom_bond_transform(self,
                            to_atom=True,  # False: to bond
                            atomwise_input=None,
                            bondwise_input=None,
                            original_f_atoms=None,
                            original_f_bonds=None,
                            a2a=None,
                            a2b=None,
                            b2a=None,
                            b2revb=None
                            ):
        """
        Transfer the output of atom/bond multi-head attention to the final atom/bond output.
        :param to_atom: if true, the output is atom emebedding, otherwise, the output is bond embedding.
        :param atomwise_input: the input embedding of atom/node.
        :param bondwise_input: the input embedding of bond/edge.
        :param original_f_atoms: the initial atom features.
        :param original_f_bonds: the initial bond features.
        :param a2a: mapping from atom index to its neighbors. num_atoms * max_num_bonds
        :param a2b: mapping from atom index to incoming bond indices.
        :param b2a: mapping from bond index to the index of the atom the bond is coming from.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return:
        """

        if to_atom:
            # atom input to atom output
            atomwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(atomwise_input, original_f_atoms, a2a,
                                                                              self.ffn_atom_from_atom)
            atom_in_atom_out = self.atom_from_atom_sublayer(None, atomwise_input)
            # bond to atom
            bondwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(bondwise_input, original_f_atoms, a2b,
                                                                              self.ffn_atom_from_bond)
            bond_in_atom_out = self.atom_from_bond_sublayer(None, bondwise_input)
            return atom_in_atom_out, bond_in_atom_out
        else:  # to bond embeddings

            # atom input to bond output
            atom_list_for_bond = torch.cat([b2a.unsqueeze(dim=1), a2a[b2a]], dim=1)
            atomwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(atomwise_input, original_f_bonds,
                                                                              atom_list_for_bond,
                                                                              b2a[b2revb], self.ffn_bond_from_atom)
            atom_in_bond_out = self.bond_from_atom_sublayer(None, atomwise_input)
            # bond input to bond output
            bond_list_for_bond = a2b[b2a]
            bondwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(bondwise_input, original_f_bonds,
                                                                              bond_list_for_bond,
                                                                              b2revb, self.ffn_bond_from_bond)
            bond_in_bond_out = self.bond_from_bond_sublayer(None, bondwise_input)
            return atom_in_bond_out, bond_in_bond_out

    def forward(self, batch, features_batch = None):
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch
        if self.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda()
            a2a = a2a.cuda()

        node_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
        edge_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a

        # opt pointwise_feed_forward
        original_f_atoms, original_f_bonds = f_atoms, f_bonds

        # Note: features_batch is not used here.
        for nb in self.node_blocks:  # atom messages. Multi-headed attention
            node_batch, features_batch = nb(node_batch, features_batch)
        for eb in self.edge_blocks:  # bond messages. Multi-headed attention
            edge_batch, features_batch = eb(edge_batch, features_batch)

        atom_output, _, _, _, _, _, _, _ = node_batch  # atom hidden states
        _, bond_output, _, _, _, _, _, _ = edge_batch  # bond hidden states

        if self.atom_emb_output is None:
            # output the embedding from multi-head attention directly.
            return atom_output, bond_output

        if self.atom_emb_output == 'atom':
            return self.atom_bond_transform(to_atom=True,  # False: to bond
                                            atomwise_input=atom_output,
                                            bondwise_input=bond_output,
                                            original_f_atoms=original_f_atoms,
                                            original_f_bonds=original_f_bonds,
                                            a2a=a2a,
                                            a2b=a2b,
                                            b2a=b2a,
                                            b2revb=b2revb)
        elif self.atom_emb_output == 'bond':
            return self.atom_bond_transform(to_atom=False,  # False: to bond
                                            atomwise_input=atom_output,
                                            bondwise_input=bond_output,
                                            original_f_atoms=original_f_atoms,
                                            original_f_bonds=original_f_bonds,
                                            a2a=a2a,
                                            a2b=a2b,
                                            b2a=b2a,
                                            b2revb=b2revb)
        else:  # 'both'
            atom_embeddings = self.atom_bond_transform(to_atom=True,  # False: to bond
                                                       atomwise_input=atom_output,
                                                       bondwise_input=bond_output,
                                                       original_f_atoms=original_f_atoms,
                                                       original_f_bonds=original_f_bonds,
                                                       a2a=a2a,
                                                       a2b=a2b,
                                                       b2a=b2a,
                                                       b2revb=b2revb)

            bond_embeddings = self.atom_bond_transform(to_atom=False,  # False: to bond
                                                       atomwise_input=atom_output,
                                                       bondwise_input=bond_output,
                                                       original_f_atoms=original_f_atoms,
                                                       original_f_bonds=original_f_bonds,
                                                       a2a=a2a,
                                                       a2b=a2b,
                                                       b2a=b2a,
                                                       b2revb=b2revb)
            # Notice: need to be consistent with output format of DualMPNN encoder
            return ((atom_embeddings[0], bond_embeddings[0]),
                    (atom_embeddings[1], bond_embeddings[1]))


In [34]:
class GROVEREmbedding(nn.Module):
    """
    The GROVER Embedding class. It contains the GTransEncoder.
    This GTransEncoder can be replaced by any validate encoders.
    """

    def __init__(self, args: Namespace):
        """
        Initialize the GROVEREmbedding class.
        :param args:
        """
        super(GROVEREmbedding, self).__init__()
        self.embedding_output_type = args.embedding_output_type
        edge_dim = get_bond_fdim() + get_atom_fdim()  # fdim에 대한건 4-3-4-1. mol2graph()참조
        node_dim = get_atom_fdim()
        if not hasattr(args, "backbone"):
            print("No backbone specified in args, use gtrans backbone.")
            args.backbone = "gtrans"
        if args.backbone == "gtrans" or args.backbone == "dualtrans":
            # dualtrans is the old name.
            self.encoders = GTransEncoder(args,
                                          hidden_size=args.hidden_size,
                                          edge_fdim=edge_dim,
                                          node_fdim=node_dim,
                                          dropout=args.dropout,
                                          activation=args.activation,
                                          num_mt_block=args.num_mt_block,
                                          num_attn_head=args.num_attn_head,
                                          atom_emb_output=self.embedding_output_type,
                                          bias=args.bias,
                                          cuda=args.cuda)

    def forward(self, graph_batch: List) -> Dict:
        """
        The forward function takes graph_batch as input and output a dict. The content of the dict is decided by
        self.embedding_output_type.

        :param graph_batch: the input graph batch generated by MolCollator.
        :return: a dict containing the embedding results.
        """
        output = self.encoders(graph_batch)
        if self.embedding_output_type == 'atom':
            return {"atom_from_atom": output[0], "atom_from_bond": output[1],
                    "bond_from_atom": None, "bond_from_bond": None}  # atom_from_atom, atom_from_bond
        elif self.embedding_output_type == 'bond':
            return {"atom_from_atom": None, "atom_from_bond": None,
                    "bond_from_atom": output[0], "bond_from_bond": output[1]}  # bond_from_atom, bond_from_bond
        elif self.embedding_output_type == "both":
            return {"atom_from_atom": output[0][0], "bond_from_atom": output[0][1],
                    "atom_from_bond": output[1][0], "bond_from_bond": output[1][1]}

### 2-1-4. GROVERTrainer()

In [35]:
import os
import time
from logging import Logger
from typing import List, Tuple
from collections.abc import Callable
import torch
from torch.nn import Module
from torch.utils.data import DataLoader

from grover.model.models import GroverTask
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw

#### 2-2-1-1. GroverTask() *중요! Loss함수코드
- 순서 : batch의 graph를 embedding_model=grover_model=GROVEREmbedding(args)를 통과시킨다. 이건 finetune의 인코더와 똑같다.

- 이때 embedding은 both로 5개의 결과가 나온다. atom/bond from bond/atom과 fg_task_all

- 손실함수
  - av_loss, bv_loss : NLLLoss(pred, target)
  - fg_atom, bond : BinaryCrossEntropyWithLogitsLoss(pred, target)
  - dist_loss : MSELoss(atom or bond from atom/bond)

In [36]:
from argparse import Namespace
from typing import List, Dict, Callable

import numpy as np
import torch
from torch import nn as nn

from grover.data import get_atom_fdim, get_bond_fdim
from grover.model.layers import Readout, GTransEncoder
from grover.util.nn_utils import get_activation_function

##### 2-2-1-1-1. Atom, Bond, FG predict
- 바로 예측을 하네.
- 그리고 예측값에 logsoftmax

In [37]:
class AtomVocabPrediction(nn.Module):
    """
    The atom-wise vocabulary prediction task. The atom vocabulary is constructed by the context.
    """
    def __init__(self, args, vocab_size, hidden_size=None):
        """
        :param args: the argument.
        :param vocab_size: the size of atom vocabulary.
        """
        super(AtomVocabPrediction, self).__init__()
        if not hidden_size:
            hidden_size = args.hidden_size
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, embeddings):
        """
        If embeddings is None: do not go through forward pass.
        :param embeddings: the atom embeddings, num_atom X fea_dim.
        :return: the prediction for each atom, num_atom X vocab_size.
        """
        if embeddings is None:
            return None
        return self.logsoftmax(self.linear(embeddings))
    
class BondVocabPrediction(nn.Module):
    """
    The bond-wise vocabulary prediction task. The bond vocabulary is constructed by the context.
    """
    def __init__(self, args, vocab_size, hidden_size=None):
        """
        Might need to use different architecture for bond vocab prediction.
        :param args:
        :param vocab_size: size of bond vocab.
        :param hidden_size: hidden size
        """
        super(BondVocabPrediction, self).__init__()
        if not hidden_size:
            hidden_size = args.hidden_size
        self.linear = nn.Linear(hidden_size, vocab_size)

        # ad-hoc here
        # If TWO_FC_4_BOND_VOCAB, we will use two distinct fc layer to deal with the bond and rev bond.
        self.TWO_FC_4_BOND_VOCAB = True
        if self.TWO_FC_4_BOND_VOCAB:
            self.linear_rev = nn.Linear(hidden_size, vocab_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, embeddings):
        """
        If embeddings is None: do not go through forward pass.
        :param embeddings: the atom embeddings, num_bond X fea_dim.
        :return: the prediction for each atom, num_bond X vocab_size.
        """
        if embeddings is None:
            return None
        nm_bonds = embeddings.shape[0]  # must be an odd number
        # The bond and rev bond have odd and even ids respectively. See definition in molgraph.
        ids1 = [0] + list(range(1, nm_bonds, 2))
        ids2 = list(range(0, nm_bonds, 2))
        if self.TWO_FC_4_BOND_VOCAB:
            logits = self.linear(embeddings[ids1]) + self.linear_rev(embeddings[ids2])
        else:
            logits = self.linear(embeddings[ids1] + embeddings[ids2])

        return self.logsoftmax(logits)


class FunctionalGroupPrediction(nn.Module):
    """
    The functional group (semantic motifs) prediction task. This is a graph-level task.
    """
    def __init__(self, args, fg_size):
        """
        :param args: The arguments.
        :param fg_size: The size of semantic motifs.
        """
        super(FunctionalGroupPrediction, self).__init__()
        first_linear_dim = args.hidden_size
        hidden_size = args.hidden_size

        # In order to retain maximal information in the encoder, we use a simple readout function here.
        self.readout = Readout(rtype="mean", hidden_size=hidden_size)
        # We have four branches here. But the input with less than four branch is OK.
        # Since we use BCEWithLogitsLoss as the loss function, we only need to output logits here.
        self.linear_atom_from_atom = nn.Linear(first_linear_dim, fg_size)
        self.linear_atom_from_bond = nn.Linear(first_linear_dim, fg_size)
        self.linear_bond_from_atom = nn.Linear(first_linear_dim, fg_size)
        self.linear_bond_from_bond = nn.Linear(first_linear_dim, fg_size)

    def forward(self, embeddings: Dict, ascope: List, bscope: List) -> Dict:
        """
        The forward function of semantic motif prediction. It takes the node/bond embeddings, and the corresponding
        atom/bond scope as input and produce the prediction logits for different branches.
        :param embeddings: The input embeddings are organized as dict. The output of GROVEREmbedding.
        :param ascope: The scope for bonds. Please refer BatchMolGraph for more details.
        :param bscope: The scope for aotms. Please refer BatchMolGraph for more details.
        :return: a dict contains the predicted logits.
        """

        preds_atom_from_atom, preds_atom_from_bond, preds_bond_from_atom, preds_bond_from_bond = \
            None, None, None, None

        if embeddings["bond_from_atom"] is not None:
            preds_bond_from_atom = self.linear_bond_from_atom(self.readout(embeddings["bond_from_atom"], bscope))
        if embeddings["bond_from_bond"] is not None:
            preds_bond_from_bond = self.linear_bond_from_bond(self.readout(embeddings["bond_from_bond"], bscope))

        if embeddings["atom_from_atom"] is not None:
            preds_atom_from_atom = self.linear_atom_from_atom(self.readout(embeddings["atom_from_atom"], ascope))
        if embeddings["atom_from_bond"] is not None:
            preds_atom_from_bond = self.linear_atom_from_bond(self.readout(embeddings["atom_from_bond"], ascope))

        return {"atom_from_atom": preds_atom_from_atom, "atom_from_bond": preds_atom_from_bond,
                "bond_from_atom": preds_bond_from_atom, "bond_from_bond": preds_bond_from_bond}


In [38]:
class GroverTask(nn.Module):
    """
    The pretrain module.
    """
    def __init__(self, args, grover, atom_vocab_size, bond_vocab_size, fg_size):
        super(GroverTask, self).__init__()
        self.grover = grover
        self.av_task_atom = AtomVocabPrediction(args, atom_vocab_size)
        self.av_task_bond = AtomVocabPrediction(args, atom_vocab_size)
        self.bv_task_atom = BondVocabPrediction(args, bond_vocab_size)
        self.bv_task_bond = BondVocabPrediction(args, bond_vocab_size)

        self.fg_task_all = FunctionalGroupPrediction(args, fg_size)

        self.embedding_output_type = args.embedding_output_type

    @staticmethod
    def get_loss_func(args: Namespace) -> Callable:
        """
        The loss function generator.
        :param args: the arguments.
        :return: the loss fucntion for GroverTask.
        """
        def loss_func(preds, targets, dist_coff=args.dist_coff):
            """
            The loss function for GroverTask.
            :param preds: the predictions.
            :param targets: the targets.
            :param dist_coff: the default disagreement coefficient for the distances between different branches.
            :return:
            """
            av_task_loss = nn.NLLLoss(ignore_index=0, reduction="mean")  # same for av and bv

            fg_task_loss = nn.BCEWithLogitsLoss(reduction="mean")
            # av_task_dist_loss = nn.KLDivLoss(reduction="mean")
            av_task_dist_loss = nn.MSELoss(reduction="mean")
            fg_task_dist_loss = nn.MSELoss(reduction="mean")
            sigmoid = nn.Sigmoid()

            av_atom_loss, av_bond_loss, av_dist_loss = 0.0, 0.0, 0.0
            fg_atom_from_atom_loss, fg_atom_from_bond_loss, fg_atom_dist_loss = 0.0, 0.0, 0.0
            bv_atom_loss, bv_bond_loss, bv_dist_loss = 0.0, 0.0, 0.0
            fg_bond_from_atom_loss, fg_bond_from_bond_loss, fg_bond_dist_loss = 0.0, 0.0, 0.0

            if preds["av_task"][0] is not None:
                av_atom_loss = av_task_loss(preds['av_task'][0], targets["av_task"])
                fg_atom_from_atom_loss = fg_task_loss(preds["fg_task"]["atom_from_atom"], targets["fg_task"])

            if preds["av_task"][1] is not None:
                av_bond_loss = av_task_loss(preds['av_task'][1], targets["av_task"])
                fg_atom_from_bond_loss = fg_task_loss(preds["fg_task"]["atom_from_bond"], targets["fg_task"])

            if preds["bv_task"][0] is not None:
                bv_atom_loss = av_task_loss(preds['bv_task'][0], targets["bv_task"])
                fg_bond_from_atom_loss = fg_task_loss(preds["fg_task"]["bond_from_atom"], targets["fg_task"])

            if preds["bv_task"][1] is not None:
                bv_bond_loss = av_task_loss(preds['bv_task'][1], targets["bv_task"])
                fg_bond_from_bond_loss = fg_task_loss(preds["fg_task"]["bond_from_bond"], targets["fg_task"])

            if preds["av_task"][0] is not None and preds["av_task"][1] is not None:
                av_dist_loss = av_task_dist_loss(preds['av_task'][0], preds['av_task'][1])
                fg_atom_dist_loss = fg_task_dist_loss(sigmoid(preds["fg_task"]["atom_from_atom"]),
                                                      sigmoid(preds["fg_task"]["atom_from_bond"]))

            if preds["bv_task"][0] is not None and preds["bv_task"][1] is not None:
                bv_dist_loss = av_task_dist_loss(preds['bv_task'][0], preds['bv_task'][1])
                fg_bond_dist_loss = fg_task_dist_loss(sigmoid(preds["fg_task"]["bond_from_atom"]),
                                                      sigmoid(preds["fg_task"]["bond_from_bond"]))

            av_loss = av_atom_loss + av_bond_loss
            bv_loss = bv_atom_loss + bv_bond_loss
            fg_atom_loss = fg_atom_from_atom_loss + fg_atom_from_bond_loss
            fg_bond_loss = fg_bond_from_atom_loss + fg_bond_from_bond_loss

            fg_loss = fg_atom_loss + fg_bond_loss
            fg_dist_loss = fg_atom_dist_loss + fg_bond_dist_loss

            # dist_loss = av_dist_loss + bv_dist_loss + fg_dist_loss
            # print("%.4f %.4f %.4f %.4f %.4f %.4f"%(av_atom_loss,
            #                                       av_bond_loss,
            #                                       fg_atom_loss,
            #                                       fg_bond_loss,
            #                                       av_dist_loss,
            #                                       fg_dist_loss))
            # return av_loss + fg_loss + dist_coff * dist_loss
            overall_loss = av_loss + bv_loss + fg_loss + dist_coff * av_dist_loss + \
                           dist_coff * bv_dist_loss + fg_dist_loss

            return overall_loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss

        return loss_func

    def forward(self, graph_batch: List):
        """
        The forward function.
        :param graph_batch:
        :return:
        """
        _, _, _, _, _, a_scope, b_scope, _ = graph_batch
        a_scope = a_scope.data.cpu().numpy().tolist()

        embeddings = self.grover(graph_batch)

        av_task_pred_atom = self.av_task_atom(
            embeddings["atom_from_atom"])  # if None: means not go through this fowward
        av_task_pred_bond = self.av_task_bond(embeddings["atom_from_bond"])

        bv_task_pred_atom = self.bv_task_atom(embeddings["bond_from_atom"])
        bv_task_pred_bond = self.bv_task_bond(embeddings["bond_from_bond"])

        fg_task_pred_all = self.fg_task_all(embeddings, a_scope, b_scope)

        return {"av_task": (av_task_pred_atom, av_task_pred_bond),
                "bv_task": (bv_task_pred_atom, bv_task_pred_bond),
                "fg_task": fg_task_pred_all}

#### 2-2-1 GroverTrainer 코드

In [39]:
class GROVERTrainer:
    def __init__(self,
                 args,
                 embedding_model: Module,
                 atom_vocab_size: int,  # atom vocab size
                 bond_vocab_size: int,
                 fg_szie: int,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader,
                 optimizer_builder: Callable,
                 scheduler_builder: Callable,
                 logger: Logger = None,
                 with_cuda: bool = False,
                 enable_multi_gpu: bool = False):
        """
        The init function of GROVERTrainer
        :param args: the input arguments.
        :param embedding_model: the model to generate atom/bond embeddings.
        :param atom_vocab_size: the vocabulary size of atoms.
        :param bond_vocab_size: the vocabulary size of bonds.
        :param fg_szie: the size of semantic motifs (functional groups)
        :param train_dataloader: the data loader of train data.
        :param test_dataloader: the data loader of validation data.
        :param optimizer_builder: the function of building the optimizer.
        :param scheduler_builder: the function of building the scheduler.
        :param logger: the logger
        :param with_cuda: enable gpu training.
        :param enable_multi_gpu: enable multi_gpu traning.
        """

        self.args = args
        self.with_cuda = with_cuda
        self.grover = embedding_model
        self.model = GroverTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_szie)
        self.loss_func = self.model.get_loss_func(args)
        self.enable_multi_gpu = enable_multi_gpu

        self.atom_vocab_size = atom_vocab_size
        self.bond_vocab_size = bond_vocab_size
        self.debug = logger.debug if logger is not None else print

        if self.with_cuda:
            # print("Using %d GPUs for training." % (torch.cuda.device_count()))
            self.model = self.model.cuda()

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optimizer = optimizer_builder(self.model, self.args)
        self.scheduler = scheduler_builder(self.optimizer, self.args)
        if self.enable_multi_gpu:
            self.optimizer = mgw.DistributedOptimizer(self.optimizer,
                                                      named_parameters=self.model.named_parameters())
        self.args = args
        self.n_iter = 0

    def broadcast_parameters(self) -> None:
        """
        Broadcast parameters before training.
        :return: no return.
        """
        if self.enable_multi_gpu:
            # broadcast parameters & optimizer state.
            mgw.broadcast_parameters(self.model.state_dict(), root_rank=0)
            mgw.broadcast_optimizer_state(self.optimizer, root_rank=0)

    def train(self, epoch: int) -> List:
        """
        The training iteration
        :param epoch: the current epoch number.
        :return: the loss terms of current epoch.
        """
        # return self.mock_iter(epoch, self.train_data, train=True)
        return self.iter(epoch, self.train_data, train=True)

    def test(self, epoch: int) -> List:
        """
        The test/validaiion iteration
        :param epoch: the current epoch number.
        :return:  the loss terms as a list
        """
        # return self.mock_iter(epoch, self.test_data, train=False)
        return self.iter(epoch, self.test_data, train=False)

    def mock_iter(self, epoch: int, data_loader: DataLoader, train: bool = True) -> List:
        """
        Perform a mock iteration. For test only.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: the loss terms as a list
        """

        for _, _ in enumerate(data_loader):
            self.scheduler.step()
        cum_loss_sum = 0.0
        self.n_iter += self.args.batch_size
        return self.n_iter, cum_loss_sum, (0, 0, 0, 0, 0, 0)

    def iter(self, epoch, data_loader, train=True) -> List:
        """
        Perform a training / validation iteration.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: the loss terms as a list
        """

        if train:
            self.model.train()
        else:
            self.model.eval()

        loss_sum, iter_count = 0, 0
        cum_loss_sum, cum_iter_count = 0, 0
        av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum, bv_dist_loss_sum, fg_dist_loss_sum = 0, 0, 0, 0, 0, 0
        # loss_func = self.model.get_loss_func(self.args)

        for _, item in enumerate(data_loader):
            batch_graph = item["graph_input"]
            targets = item["targets"]

            if next(self.model.parameters()).is_cuda:
                targets["av_task"] = targets["av_task"].cuda()
                targets["bv_task"] = targets["bv_task"].cuda()
                targets["fg_task"] = targets["fg_task"].cuda()

            preds = self.model(batch_graph)

            # # ad-hoc code, for visualizing a model, comment this block when it is not needed
            # import dglt.contrib.grover.vis_model as vis_model
            # for task in ['av_task', 'bv_task', 'fg_task']:
            #     vis_graph = vis_model.make_dot(self.model(batch_graph)[task],
            #                                    params=dict(self.model.named_parameters()))
            #     # vis_graph.view()
            #     vis_graph.render(f"{self.args.backbone}_model_{task}_vis.png", format="png")
            # exit()

            loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss = self.loss_func(preds, targets)

            loss_sum += loss.item()
            iter_count += self.args.batch_size

            if train:
                cum_loss_sum += loss.item()
                # Run model
                self.model.zero_grad()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
            else:
                # For eval model, only consider the loss of three task.
                cum_loss_sum += av_loss.item()
                cum_loss_sum += bv_loss.item()
                cum_loss_sum += fg_loss.item()

            av_loss_sum += av_loss.item()
            bv_loss_sum += bv_loss.item()
            fg_loss_sum += fg_loss.item()
            av_dist_loss_sum += av_dist_loss.item() if type(av_dist_loss) != float else av_dist_loss
            bv_dist_loss_sum += bv_dist_loss.item() if type(bv_dist_loss) != float else bv_dist_loss
            fg_dist_loss_sum += fg_dist_loss.item() if type(fg_dist_loss) != float else fg_dist_loss

            cum_iter_count += 1
            self.n_iter += self.args.batch_size

            # Debug only.
            # if i % 50 == 0:
            #     print(f"epoch: {epoch}, batch_id: {i}, av_loss: {av_loss}, bv_loss: {bv_loss}, "
            #           f"fg_loss: {fg_loss}, av_dist_loss: {av_dist_loss}, bv_dist_loss: {bv_dist_loss}, "
            #           f"fg_dist_loss: {fg_dist_loss}")

        cum_loss_sum /= cum_iter_count
        av_loss_sum /= cum_iter_count
        bv_loss_sum /= cum_iter_count
        fg_loss_sum /= cum_iter_count
        av_dist_loss_sum /= cum_iter_count
        bv_dist_loss_sum /= cum_iter_count
        fg_dist_loss_sum /= cum_iter_count

        return self.n_iter, cum_loss_sum, (av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum,
                                           bv_dist_loss_sum, fg_dist_loss_sum)

    def save(self, epoch, file_path, name=None) -> str:
        """
        Save the intermediate models during training.
        :param epoch: the epoch number.
        :param file_path: the file_path to save the model.
        :return: the output path.
        """
        # add specific time in model fine name, in order to distinguish different saved models
        now = time.localtime()
        if name is None:
            name = "_%04d_%02d_%02d_%02d_%02d_%02d" % (
                now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
        output_path = file_path + name + ".ep%d" % epoch
        scaler = None
        features_scaler = None
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch,
            'data_scaler': {
                'means': scaler.means,
                'stds': scaler.stds
            } if scaler is not None else None,
            'features_scaler': {
                'means': features_scaler.means,
                'stds': features_scaler.stds
            } if features_scaler is not None else None
        }
        torch.save(state, output_path)

        # Is this necessary?
        # if self.with_cuda:
        #    self.model = self.model.cuda()
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def save_tmp(self, epoch, file_path, rank=0):
        """
        Save the models for auto-restore during training.
        The model are stored in file_path/tmp folder and will replaced on each epoch.
        :param epoch: the epoch number.
        :param file_path: the file_path to store the model.
        :param rank: the current rank (decrypted).
        :return:
        """
        store_path = os.path.join(file_path, "tmp")
        if not os.path.exists(store_path):
            os.makedirs(store_path, exist_ok=True)
        store_path = os.path.join(store_path, "model.%d" % rank)
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch
        }
        torch.save(state, store_path)

    def restore(self, file_path, rank=0) -> Tuple[int, int]:
        """
        Restore the training state saved by save_tmp.
        :param file_path: the file_path to store the model.
        :param rank: the current rank (decrypted).
        :return: the restored epoch number and the scheduler_step in scheduler.
        """
        cpt_path = os.path.join(file_path, "tmp", "model.%d" % rank)
        if not os.path.exists(cpt_path):
            print("No checkpoint found %d")
            return 0, 0
        cpt = torch.load(cpt_path)
        self.model.load_state_dict(cpt["state_dict"])
        self.optimizer.load_state_dict(cpt["optimizer"])
        epoch = cpt["epoch"]
        scheduler_step = cpt["scheduler_step"]
        self.scheduler.current_step = scheduler_step
        print("Restore checkpoint, current epoch: %d" % (epoch))
        return epoch, scheduler_step


### 2-1-5. run_training 함수

In [40]:
def run_training(args, logger):
    """
    Run the pretrain task.
    :param args:
    :param logger:
    :return:
    """

    # initalize the logger.
    if logger is not None:
        debug, _ = logger.debug, logger.info
    else:
        debug = print

    # initialize the horovod library
    if args.enable_multi_gpu:
        mgw.init()

    # binding training to GPUs.
    master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
    # pin GPU to local rank. By default, we use gpu:0 for training.
    local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
    with_cuda = args.cuda
    if with_cuda:
        torch.cuda.set_device(local_gpu_idx)

    # get rank an  number of workers
    rank = mgw.rank() if args.enable_multi_gpu else 0
    num_replicas = mgw.size() if args.enable_multi_gpu else 1
    # print("Rank: %d Rep: %d" % (rank, num_replicas))

    # load file paths of the data.
    if master_worker:
        print(args)
        if args.enable_multi_gpu:
            debug("Total workers: %d" % (mgw.size()))
        debug('Loading data')
    data, sample_per_file = get_data(data_path=args.data_path)

    # data splitting
    if master_worker:
        debug(f'Splitting data with seed 0.')
    train_data, test_data, _ = split_data(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

    # Here the true train data size is the train_data divided by #GPUs
    if args.enable_multi_gpu:
        args.train_data_size = len(train_data) // mgw.size()
    else:
        args.train_data_size = len(train_data)
    if master_worker:
        debug(f'Total size = {len(data):,} | '
              f'train size = {len(train_data):,} | val size = {len(test_data):,}')

    # load atom and bond vocabulary and the semantic motif labels.
    atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
    bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
    atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)

    # Load motif vocabulary for pretrain
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.parser_name == 'pretrain':
        motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
        motif_vocab = Motif_Vocab(motif_vocab)
        #see below motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order).to(device)
        
        
    # Hard coding here, since we haven't load any data yet!
    fg_size = 85
    shared_dict = {}
    mol_collator = GroverCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
    if master_worker:
        debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                    bond_vocab_size, fg_size))

    # Define the distributed sampler. If using the single card, the sampler will be None.
    train_sampler = None
    test_sampler = None
    shuffle = True
    if args.enable_multi_gpu:
        # If not shuffle, the performance may decayed.
        train_sampler = DistributedSampler(
            train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
        # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
        # rank. (TODO: bad design here.)
        test_sampler = DistributedSampler(
            test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
        train_sampler.set_epoch(args.epochs)
        test_sampler.set_epoch(1)
        # if we enables multi_gpu training. shuffle should be disabled.
        shuffle = False

    # Pre load data. (Maybe unnecessary. )
    pre_load_data(train_data, rank, num_replicas, sample_per_file)
    pre_load_data(test_data, rank, num_replicas)
    if master_worker:
        # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
        print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())

    # Build dataloader
    train_data_dl = DataLoader(train_data,
                               batch_size=args.batch_size,
                               shuffle=shuffle,
                               num_workers=12,
                               sampler=train_sampler,
                               collate_fn=mol_collator)
    test_data_dl = DataLoader(test_data,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=10,
                              sampler=test_sampler,
                              collate_fn=mol_collator)

    # Build the embedding model.
    grover_model = GROVEREmbedding(args)

    #  Build the trainer.
    trainer = GROVERTrainer(args=args,
                            embedding_model=grover_model,
                            atom_vocab_size=atom_vocab_size,
                            bond_vocab_size=bond_vocab_size,
                            fg_szie=fg_size,
                            train_dataloader=train_data_dl,
                            test_dataloader=test_data_dl,
                            optimizer_builder=build_optimizer,
                            scheduler_builder=build_lr_scheduler,
                            logger=logger,
                            with_cuda=with_cuda,
                            enable_multi_gpu=args.enable_multi_gpu)

    # Restore the interrupted training.
    model_dir = os.path.join(args.save_dir, "model")
    resume_from_epoch = 0
    resume_scheduler_step = 0
    if master_worker:
        resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
    if args.enable_multi_gpu:
        resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
        resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                              root_rank=0, name="resume_scheduler_step").item()
        trainer.scheduler.current_step = resume_scheduler_step
        print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
    trainer.broadcast_parameters()

    # Print model details.
    if master_worker:
        # Change order here.
        print(grover_model)
        print("Total parameters: %d" % param_count(trainer.grover))

    # Perform training.
    for epoch in range(resume_from_epoch + 1, args.epochs):
        s_time = time.time()

        # Data pre-loading.
        if args.enable_multi_gpu:
            train_sampler.set_epoch(epoch)
            train_data.clean_cache()
            idxs = train_sampler.get_indices()
            for local_gpu_idx in idxs:
                train_data.load_data(local_gpu_idx)
        d_time = time.time() - s_time

        # perform training and validation.
        s_time = time.time()
        _, train_loss, _ = trainer.train(epoch)
        t_time = time.time() - s_time
        s_time = time.time()
        _, val_loss, detailed_loss_val = trainer.test(epoch)
        val_av_loss, val_bv_loss, val_fg_loss, _, _, _ = detailed_loss_val
        v_time = time.time() - s_time

        # print information.
        if master_worker:
            print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.6f}'.format(train_loss),
                  'loss_val: {:.6f}'.format(val_loss),
                  'loss_val_av: {:.6f}'.format(val_av_loss),
                  'loss_val_bv: {:.6f}'.format(val_bv_loss),
                  'loss_val_fg: {:.6f}'.format(val_fg_loss),
                  'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
                  't_time: {:.4f}s'.format(t_time),
                  'v_time: {:.4f}s'.format(v_time),
                  'd_time: {:.4f}s'.format(d_time), flush=True)

            if epoch % args.save_interval == 0:
                trainer.save(epoch, model_dir)


            trainer.save_tmp(epoch, model_dir, rank)

    # Only save final version.
    if master_worker:
        trainer.save(args.epochs, model_dir, "")

## 3. save_moltrees

In [41]:
"""
Computes and saves molecular features for a dataset.
"""
import os
import shutil
import sys
from argparse import ArgumentParser, Namespace
from multiprocessing import Pool
from typing import List, Tuple

from tqdm import tqdm

#sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

from grover.util.utils import get_data, makedirs, load_features, save_features
from grover.data.molfeaturegenerator import get_available_features_generators, \
    get_features_generator
from grover.data.task_labels import rdkit_functional_group_label_features_generator

import pickle
from grover.topology.mol_tree import *

In [42]:
def load_temp(temp_dir: str) -> Tuple[List[List[float]], int]:
    """
    Loads all features saved as .npz files in load_dir.

    Assumes temporary files are named in order 0.npz, 1.npz, ...

    :param temp_dir: Directory in which temporary .npz files containing features are stored.
    :return: A tuple with a list of molecule features, where each molecule's features is a list of floats,
    and the number of temporary files.
    """
    features = []
    temp_num = 0
    temp_path = os.path.join(temp_dir, f'{temp_num}.npz')

    while os.path.exists(temp_path):
        features.extend(load_features(temp_path))
        temp_num += 1
        temp_path = os.path.join(temp_dir, f'{temp_num}.npz')

    return features, temp_num

In [43]:
def make_moltree(smiles):
    mol_tree = MolTree(smiles)
    mol_tree.recover()
    mol_tree.assemble()
    return mol_tree

In [44]:
def generate_and_save_moltrees(args: Namespace):
    """
    Computes and saves features for a dataset of molecules as a 2D array in a .npz file.

    :param args: Arguments.
    """
    # Create directory for save_path
    makedirs(args.save_path, isfile=True)

    # Get data and features function
    data = get_data(path=args.data_path, max_data_size=None)
    temp_save_dir = args.save_path + '_temp'

    # Load partially complete data
    if args.restart:
        if os.path.exists(args.save_path):
            os.remove(args.save_path)
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)
    else:
        if os.path.exists(args.save_path):
            raise ValueError(f'"{args.save_path}" already exists and args.restart is False.')

        if os.path.exists(temp_save_dir):
            moltrees, temp_num = load_temp(temp_save_dir)

    if not os.path.exists(temp_save_dir):
        makedirs(temp_save_dir)
        moltrees, temp_num = [], 0

    # Build features map function
    data = data[len(moltrees):]  # restrict to data for which features have not been computed yet
    mols = (d.smiles for d in data)
    
    if args.sequential:
        moltrees_map = map(make_moltree, mols)
    else:
        moltrees_map = Pool(30).imap(make_moltree, mols)
        
    # Get features
    temp_moltrees = []
    for i, moltree in tqdm(enumerate(moltrees_map), total=len(data)):
        temp_moltrees.append(moltree)

        # Save temporary features every save_frequency
        if (i > 0 and (i + 1) % args.save_frequency == 0) or i == len(data) - 1:
            #save_features(os.path.join(temp_save_dir, f'{temp_num}.npz'), temp_moltrees)
            moltrees.extend(temp_moltrees)
            temp_moltrees = []
            temp_num += 1

    try:
        # Save all features
        with open('mgssl_moltree.p', 'wb') as file: 
            pickle.dump(moltrees, file)

        # Remove temporary features
        shutil.rmtree(temp_save_dir)
    except OverflowError:
        print('moltree object is too large to save as a single file. Instead keeping features as a directory of files.')

In [41]:
parser = ArgumentParser()
parser.add_argument('--data_path', type=str, required=True,
                    help='Path to data CSV')
parser.add_argument('--features_generator', type=str, required=True,
                    choices=get_available_features_generators(),
                    help='Type of features to generate')
parser.add_argument('--save_path', type=str, default=None,
                    help='Path to .npz file where features will be saved as a compressed numpy archive')
parser.add_argument('--save_frequency', type=int, default=10000,
                    help='Frequency with which to save the features')
parser.add_argument('--restart', action='store_true', default=False,
                    help='Whether to not load partially complete featurization and instead start from scratch')
parser.add_argument('--max_data_size', type=int,
                    help='Maximum number of data points to load')
parser.add_argument('--sequential', action='store_true', default=False,
                    help='Whether to task sequentially rather than in parallel')
args = parser.parse_args(['--data_path', 'data/mgssl.csv', '--features_generator','fgtasklabel', '--save_path', 'data/test/mgssl2'])
if args.save_path is None:
    args.save_path = args.data_path.split('csv')[0] + 'npz'
#generate_and_save_features(args)

In [42]:
args

Namespace(data_path='data/mgssl.csv', features_generator='fgtasklabel', max_data_size=None, restart=False, save_frequency=10000, save_path='data/test/mgssl2', sequential=False)

In [43]:
makedirs(args.save_path, isfile=True)

# Get data and features function
data = get_data(path=args.data_path, max_data_size=None)
temp_save_dir = args.save_path + '_temp'

# Load partially complete data
if args.restart:
    if os.path.exists(args.save_path):
        os.remove(args.save_path)
    if os.path.exists(temp_save_dir):
        shutil.rmtree(temp_save_dir)
else:
    if os.path.exists(args.save_path):
        raise ValueError(f'"{args.save_path}" already exists and args.restart is False.')

    if os.path.exists(temp_save_dir):
        moltrees, temp_num = load_temp(temp_save_dir)

if not os.path.exists(temp_save_dir):
    makedirs(temp_save_dir)
    moltrees, temp_num = [], 0

# Build features map function
data = data[len(moltrees):]  # restrict to data for which features have not been computed yet
mols = (d.smiles for d in data)

In [44]:
if args.sequential:
    moltrees_map = map(make_moltree, mols)
else:
    moltrees_map = Pool(30).imap(make_moltree, mols)

In [45]:
# Get features
temp_moltrees = []
for i, moltree in tqdm(enumerate(moltrees_map), total=len(data)):
    temp_moltrees.append(moltree)

    # Save temporary features every save_frequency
    if (i > 0 and (i + 1) % args.save_frequency == 0) or i == len(data) - 1:
        #save_features(os.path.join(temp_save_dir, f'{temp_num}.npz'), temp_moltrees)
        moltrees.extend(temp_moltrees)
        temp_moltrees = []
        temp_num += 1

try:
    # Save all features
    #save_features(args.save_path, moltrees)

    # Remove temporary features
    shutil.rmtree(temp_save_dir)
except OverflowError:
    print('moltree object is too large to save as a single file. Instead keeping features as a directory of files.')

100%|█████████████████████████████████████████| 293/293 [00:07<00:00, 40.94it/s]


In [21]:
moltrees

[<grover.topology.mol_tree.MolTree at 0x7f70ee7f0550>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee6b6ed0>,
 <grover.topology.mol_tree.MolTree at 0x7f70f02ea310>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee610dd0>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee6a1650>,
 <grover.topology.mol_tree.MolTree at 0x7f70ad091e50>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee502f10>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee6e2450>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee4f2510>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee505b50>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee68b690>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee4fe850>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee5eb490>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee6d1f90>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee67a450>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee5eb410>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee874bd0>,
 <grover.topology.mol_tree.MolTree at 0x7f70ee658510>,
 <grover.t

In [22]:
import pickle
#mol_tree = MolTree(smiles)
#mol_tree.recover()
#mol_tree.assemble()
with open('mgssl_moltree.p', 'wb') as file: 
    pickle.dump(moltrees, file)

In [23]:
with open("mgssl_moltree.p", 'rb') as f:
    moltreedata = pickle.load(f)

## 4. Grover_motiftrain코드 나눠서 실행

### 3-0. parse_args

In [45]:
from grover.util.parsing import *

In [47]:
def parse_args() -> Namespace:
    """
    Parses arguments for training and testing (includes modifying/validating arguments).

    :return: A Namespace containing the parsed, modified, and validated args.
    """
    parser = ArgumentParser()
    subparser = parser.add_subparsers(title="subcommands",
                                      dest="parser_name",
                                      help="Subcommands for fintune, prediction, and fingerprint.")
    parser_finetune = subparser.add_parser('finetune', help="Fine tune the pre-trained model.")
    add_finetune_args(parser_finetune)
    parser_eval = subparser.add_parser('eval', help="Evaluate the results of the pre-trained model.")
    add_finetune_args(parser_eval)
    parser_predict = subparser.add_parser('predict', help="Predict results from fine tuned model.")
    add_predict_args(parser_predict)
    parser_fp = subparser.add_parser('fingerprint', help="Get the fingerprints of SMILES.")
    add_fingerprint_args(parser_fp)
    parser_pretrain = subparser.add_parser('pretrain', help="Pretrain with unlabelled SMILES.")
    add_pretrain_args(parser_pretrain)

    args = parser.parse_args(['pretrain','--data_path','data/zinc10M','--save_dir','model/zinc10M','--atom_vocab_path','data/zinc10M/zinc10M_atom_vocab.pkl','--bond_vocab_path','data/zinc10M/zinc10M_bond_vocab.pkl',
                              '--batch_size','100','--dropout','0.1','--depth','3','--num_attn_head','4','--hidden_size','1200','--epochs','20','--activation','PReLU','--backbone','gtrans','--embedding_output_type','both',
                              '--save_interval','5','--init_lr', '0.0002', '--max_lr', '0.0004', '--final_lr', '0.0001', '--weight_decay', '0.0000001', 
                              '--topology','--motif_vocab_path','data/zinc10M/clique.txt','--motif_hidden_size','1200','--motif_latent_size','56','--motif_order','dfs'])
    
    if args.parser_name == 'finetune' or args.parser_name == 'eval':
        modify_train_args(args)
    elif args.parser_name == "pretrain":
        modify_pretrain_args(args)
    elif args.parser_name == 'predict':
        modify_predict_args(args)
    elif args.parser_name == 'fingerprint':
        modify_fingerprint_args(args)

    return args

In [48]:
args = parse_args()
args
logger = create_logger(name='pretrain', save_dir=args.save_dir)

In [49]:
args

Namespace(activation='PReLU', atom_vocab_path='data/zinc10M/zinc10M_atom_vocab.pkl', backbone='gtrans', batch_size=100, bias=False, bond_drop_rate=0, bond_vocab_path='data/zinc10M/zinc10M_bond_vocab.pkl', cuda=True, data_path='data/zinc10M', dense=False, depth=3, dist_coff=0.1, dropout=0.1, each_epochs=5, embedding_output_type='both', enable_multi_gpu=False, epochs=20, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=1200, init_lr=0.0002, max_lr=0.0004, motif_hidden_size=1200, motif_latent_size=56, motif_order='dfs', motif_vocab_path='data/zinc10M/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/zinc10M', save_interval=5, subset_learning=False, topology=True, undirected=False, wandb=False, wandb_name='pretrain', warmup_epochs=2.0, weight_decay=1e-07)

### 3-1. MolTree

In [45]:
from grover.topology.chemutils import *
from grover.topology.mol_tree import *
from grover.topology.motif_generation import *
from grover.topology.dfs import *
from grover.topology.bfs import *

In [46]:
import rdkit
import rdkit.Chem as Chem
import numpy as np
import copy

from grover.topology.chemutils import get_clique_mol, tree_decomp, brics_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo

In [47]:
def get_slots(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]

class Motif_Vocab(object):

    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x:i for i,x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]
        
    def get_index(self, smiles):
        return self.vmap[smiles]
    

    def get_smiles(self, idx):
        return self.vocab[idx]

    def get_slots(self, idx):
        return copy.deepcopy(self.slots[idx])

    def size(self):
        return len(self.vocab)
    
    def add_motif(self, smiles):
        self.vocab.append(smiles)

In [48]:
class MolTreeNode(object):

    def __init__(self, smiles, clique=[]):
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        #self.mol = cmol

        self.clique = [x for x in clique] #copy
        self.neighbors = []
        
    def add_neighbor(self, nei_node):
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf: #Leaf node, no need to mark 
                continue
            for cidx in nei_node.clique:
                #allow singleton node override the atom mapping
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
        self.label_mol = get_mol(self.label)

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label
    
    def assemble(self):
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors
        
        cands = enum_assemble(self, neighbors)
        
        if len(cands) > 0:
            self.cands, self.cand_mols, _ = zip(*cands)
            self.cands = list(self.cands)
            self.cand_mols = list(self.cand_mols)
        else:
            self.cands = []
            self.cand_mols = []


In [49]:
class MolTree(object):

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)

        '''
        #Stereo Generation
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(mol)
        self.stereo_cands = decode_stereo(self.smiles2D)
        '''

        cliques, edges = brics_decomp(self.mol)
        if len(edges) <= 1:
            cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for i,c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            node = MolTreeNode(get_smiles(cmol), c)
            self.nodes.append(node)
            if min(c) == 0:
                root = i

        for x,y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])
        
        if root > 0:
            self.nodes[0],self.nodes[root] = self.nodes[root],self.nodes[0]

        for i,node in enumerate(self.nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1: #Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        return len(self.nodes)

    def recover(self):
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        for node in self.nodes:
            node.assemble()

### 3-2. get_motif_data():

In [50]:
from torch.utils.data.dataset import Dataset
from typing import Union, List

In [51]:
def load_moltrees(path: str) -> np.ndarray:
    """
    Loads features saved in a variety of formats.

    Supported formats:
    - .npz compressed (assumes features are saved with name "features")

    All formats assume that the SMILES strings loaded elsewhere in the code are in the same
    order as the features loaded here.

    :param path: Path to a file containing features.
    :return: A 2D numpy array of size (num_molecules, features_size) containing the features.
    """
    extension = os.path.splitext(path)[1]

    if extension == '.p':
        with open(path, 'rb') as f:
            moltrees = pickle.load(f)
    else:
        raise ValueError(f'Features path extension {extension} not supported.')

    return moltrees

#### 3-2-1 MoleculeDataPoint_motif

In [52]:
class MoleculeDatapoint_motif:
    """A MoleculeDatapoint contains a single molecule and its associated features and targets."""

    def __init__(self,
                 line: List[str],
                 args: Namespace = None,
                 features: np.ndarray = None,
                 moltrees: object = None,
                 use_compound_names: bool = False):
        """
        Initializes a MoleculeDatapoint, which contains a single molecule.

        :param line: A list of strings generated by separating a line in a data CSV file by comma.
        :param args: Arguments.
        :param features: A numpy array containing additional features (ex. Morgan fingerprint).
        :param use_compound_names: Whether the data CSV includes the compound name on each line.
        """
        self.features_generator = None
        self.args = None
        if args is not None:
            if hasattr(args, "features_generator"):
                self.features_generator = args.features_generator
            self.args = args

        if features is not None and self.features_generator is not None:
            raise ValueError('Currently cannot provide both loaded features and a features generator.')

        self.features = features
        self.moltrees = moltrees

        if use_compound_names:
            self.compound_name = line[0]  # str
            line = line[1:]
        else:
            self.compound_name = None

        self.smiles = line[0]  # str


        # Generate additional features if given a generator
        if self.features_generator is not None:
            self.features = []
            mol = Chem.MolFromSmiles(self.smiles)
            for fg in self.features_generator:
                features_generator = get_features_generator(fg)
                if mol is not None and mol.GetNumHeavyAtoms() > 0:
                    if fg in ['morgan', 'morgan_count']:
                        self.features.extend(features_generator(mol, num_bits=args.num_bits))
                    else:
                        self.features.extend(features_generator(mol))

            self.features = np.array(self.features)

        # Fix nans in features
        if self.features is not None:
            replace_token = 0
            self.features = np.where(np.isnan(self.features), replace_token, self.features)

        # Create targets
        self.targets = [float(x) if x != '' else None for x in line[1:]]

    def set_features(self, features: np.ndarray):
        """
        Sets the features of the molecule.

        :param features: A 1-D numpy array of features for the molecule.
        """
        self.features = features
        
    def set_moltrees(self, moltrees: list):
        """
        Sets the features of the molecule.

        :param features: A 1-D numpy array of features for the molecule.
        """
        self.moltrees = moltrees

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        :return: The number of tasks.
        """
        return len(self.targets)

    def set_targets(self, targets: List[float]):
        """
        Sets the targets of a molecule.

        :param targets: A list of floats containing the targets.
        """
        self.targets = targets

In [53]:
class BatchDatapoint_motif:
    def __init__(self,
                 smiles_file,
                 feature_file,
                 moltree_file,
                 n_samples,
                 ):
        self.smiles_file = smiles_file
        self.feature_file = feature_file
        self.moltree_file = moltree_file
        # deal with the last batch graph numbers.
        self.n_samples = n_samples
        self.datapoints = None

    def load_datapoints(self):
        features = self.load_feature()
        moltrees = self.load_moltree()
        self.datapoints = []

        with open(self.smiles_file) as f:
            reader = csv.reader(f)
            next(reader)
            for i, line in enumerate(reader):
                # line = line[0]
                d = MoleculeDatapoint_motif(line=line,
                                      features=features[i],
                                      moltrees=moltrees[i])
                self.datapoints.append(d)

        assert len(self.datapoints) == self.n_samples

    def load_feature(self):
        return feautils.load_features(self.feature_file)
    
    def load_moltree(self):
        return feautils.load_moltrees(self.moltree_file)

    def shuffle(self):
        pass

    def clean_cache(self):
        del self.datapoints
        self.datapoints = None

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        assert self.datapoints is not None
        return self.datapoints[idx]

    def is_loaded(self):
        return self.datapoints is not None

In [56]:
class BatchMolDataset_motif(Dataset):
    def __init__(self, data: List[BatchDatapoint_motif],
                 graph_per_file=None):
        self.data = data

        self.len = 0
        for d in self.data:
            self.len += len(d)
        if graph_per_file is not None:
            self.sample_per_file = graph_per_file
        else:
            self.sample_per_file = len(self.data[0]) if len(self.data) != 0 else None

    def shuffle(self, seed: int = None):
        pass

    def clean_cache(self):
        for d in self.data:
            d.clean_cache()

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx) -> Union[MoleculeDatapoint_motif, List[MoleculeDatapoint_motif]]:
        # print(idx)
        dp_idx = int(idx / self.sample_per_file)
        real_idx = idx % self.sample_per_file
        return self.data[dp_idx][real_idx]

    def load_data(self, idx):
        dp_idx = int(idx / self.sample_per_file)
        if not self.data[dp_idx].is_loaded():
            self.data[dp_idx].load_datapoints()

    def count_loaded_datapoints(self):
        res = 0
        for d in self.data:
            if d.is_loaded():
                res += 1
        return res

In [57]:
def get_motif_data(data_path, logger=None):
    """
    Load data from the data_path.
    :param data_path: the data_path.
    :param logger: the logger.
    :return:
    """
    debug = logger.debug if logger is not None else print
    summary_path = os.path.join(data_path, "summary.txt")
    smiles_path = os.path.join(data_path, "graph")
    feature_path = os.path.join(data_path, "feature")
    moltree_path = os.path.join(data_path, "moltrees")

    fin = open(summary_path)
    n_files = int(fin.readline().strip().split(":")[-1])
    n_samples = int(fin.readline().strip().split(":")[-1])
    sample_per_file = int(fin.readline().strip().split(":")[-1])
    debug("Loading data:")
    debug("Number of files: %d" % n_files)
    debug("Number of samples: %d" % n_samples)
    debug("Samples/file: %d" % sample_per_file)

    datapoints = []
    for i in range(n_files):
        smiles_path_i = os.path.join(smiles_path, str(i) + ".csv")
        feature_path_i = os.path.join(feature_path, str(i) + ".npz")
        moltree_path_i = os.path.join(moltree_path, str(i) + ".p")
        n_samples_i = sample_per_file if i != (n_files - 1) else n_samples % sample_per_file
        datapoints.append(BatchDatapoint_motif(smiles_path_i, feature_path_i, moltree_path_i, n_samples_i))
    return BatchMolDataset_motif(datapoints), sample_per_file

### 3-2. Grover_MotifGeneration

#### 3-2-1. GroverMotifcollator

In [65]:
class GroverMotifCollator(object):
    def __init__(self, shared_dict, atom_vocab, bond_vocab, args):
        self.args = args
        self.shared_dict = shared_dict
        self.atom_vocab = atom_vocab
        self.bond_vocab = bond_vocab

    def atom_random_mask(self, smiles_batch):
        """
        Perform the random mask operation on atoms.
        :param smiles_batch:
        :return: The corresponding atom labels.
        """
        # There is a zero padding.
        vocab_label = [0]
        percent = 0.15
        for smi in smiles_batch:
            mol = Chem.MolFromSmiles(smi)
            mlabel = [0] * mol.GetNumAtoms()
            n_mask = math.ceil(mol.GetNumAtoms() * percent)
            perm = np.random.permutation(mol.GetNumAtoms())[:n_mask]
            for p in perm:
                atom = mol.GetAtomWithIdx(int(p))
                mlabel[p] = self.atom_vocab.stoi.get(atom_to_vocab(mol, atom), self.atom_vocab.other_index)

            vocab_label.extend(mlabel)
        return vocab_label

    def bond_random_mask(self, smiles_batch):
        """
        Perform the random mask operaiion on bonds.
        :param smiles_batch:
        :return: The corresponding bond labels.
        """
        # There is a zero padding.
        vocab_label = [0]
        percent = 0.15
        for smi in smiles_batch:
            mol = Chem.MolFromSmiles(smi)
            nm_atoms = mol.GetNumAtoms()
            nm_bonds = mol.GetNumBonds()
            mlabel = []
            n_mask = math.ceil(nm_bonds * percent)
            perm = np.random.permutation(nm_bonds)[:n_mask]
            virtual_bond_id = 0
            for a1 in range(nm_atoms):
                for a2 in range(a1 + 1, nm_atoms):
                    bond = mol.GetBondBetweenAtoms(a1, a2)

                    if bond is None:
                        continue
                    if virtual_bond_id in perm:
                        label = self.bond_vocab.stoi.get(bond_to_vocab(mol, bond), self.bond_vocab.other_index)
                        mlabel.extend([label])
                    else:
                        mlabel.extend([0])

                    virtual_bond_id += 1
            # todo: might need to consider bond_drop_rate
            # todo: double check reverse bond
            vocab_label.extend(mlabel)
        return vocab_label

    def __call__(self, batch):
        smiles_batch = [d.smiles for d in batch]
        batchgraph = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()

        atom_vocab_label = torch.Tensor(self.atom_random_mask(smiles_batch)).long()
        bond_vocab_label = torch.Tensor(self.bond_random_mask(smiles_batch)).long()
        fgroup_label = torch.Tensor([d.features for d in batch]).float()
        moltree_batch = [d.moltrees for d in batch]
        
        # may be some mask here
        res = {"graph_input": batchgraph,
               "targets": {"av_task": atom_vocab_label,
                           "bv_task": bond_vocab_label,
                           "fg_task": fgroup_label},
               "moltree" : moltree_batch
               }
        return res

#### 3-2-2. dfs

In [66]:
import torch
import torch.nn as nn
from grover.topology.mol_tree import Motif_Vocab, MolTree, MolTreeNode
from grover.topology.chemutils import enum_assemble
# add this directly below (from nnutils import create_var, GRU)
import copy

from torch.autograd import Variable

MAX_NB = 8
MAX_DECODE_LEN = 100

def create_var(tensor, requires_grad=None):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'   #"cuda:" + "1" 다중 처리 때문에 이렇게 했나봐 ㅡㅡ 원본 : torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    if requires_grad is None:
        return Variable(tensor).to(device)
    else:
        return Variable(tensor, requires_grad=requires_grad).to(device)

def index_select_ND(source, dim, index):
    index_size = index.size()
    suffix_dim = source.size()[1:]
    final_size = index_size + suffix_dim
    target = source.index_select(dim, index.view(-1))
    return target.view(final_size)

def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    hidden_size = x.size()[-1]
    sum_h = h_nei.sum(dim=1)
    z_input = torch.cat([x,sum_h], dim=1)
    z = nn.Sigmoid()(W_z(z_input))

    r_1 = W_r(x).view(-1,1,hidden_size)
    r_2 = U_r(h_nei)
    r = nn.Sigmoid()(r_1 + r_2)
    
    gated_h = r * h_nei
    sum_gated_h = gated_h.sum(dim=1)
    h_input = torch.cat([x,sum_gated_h], dim=1)
    pre_h = nn.Tanh()(W_h(h_input))
    new_h = (1.0 - z) * sum_h + z * pre_h
    return new_h

def dfs(stack, x, fa):
    for y in x.neighbors:
        if y.idx == fa.idx:
            continue
        stack.append((x, y, 1))
        dfs(stack, y, x)
        stack.append((y, x, 0))


def have_slots(fa_slots, ch_slots):
    if len(fa_slots) > 2 and len(ch_slots) > 2:
        return True
    matches = []
    for i, s1 in enumerate(fa_slots):
        a1, c1, h1 = s1
        for j, s2 in enumerate(ch_slots):
            a2, c2, h2 = s2
            if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
                matches.append((i, j))

    if len(matches) == 0: return False

    fa_match, ch_match = zip(*matches)
    if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2:  # never remove atom from ring
        fa_slots.pop(fa_match[0])
    if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2:  # never remove atom from ring
        ch_slots.pop(ch_match[0])

    return True


def can_assemble(node_x, node_y):
    neis = node_x.neighbors + [node_y]
    for i, nei in enumerate(neis):
        nei.nid = i

    neighbors = [nei for nei in neis if nei.mol.GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
    singletons = [nei for nei in neis if nei.mol.GetNumAtoms() == 1]
    neighbors = singletons + neighbors
    cands = enum_assemble(node_x, neighbors)
    return len(cands) > 0

In [67]:
class Motif_Generation_dfs(nn.Module):

    def __init__(self, vocab, hidden_size, device):       #각종 초기값들 설정해두기
        super(Motif_Generation_dfs, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab.size()
        self.vocab = vocab
        self.device = device

        # GRU Weights
        self.W_z = nn.Linear(2 * hidden_size, hidden_size)
        self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_r = nn.Linear(hidden_size, hidden_size)
        self.W_h = nn.Linear(2 * hidden_size, hidden_size)

        # Feature Aggregate Weights
        self.W = nn.Linear(hidden_size, hidden_size)
        self.U = nn.Linear(2 * hidden_size, hidden_size)

        # Output Weights
        self.W_o = nn.Linear(hidden_size, self.vocab_size)
        self.U_s = nn.Linear(hidden_size, 1)

        # Loss Functions
        self.pred_loss = nn.CrossEntropyLoss(size_average=False)
        self.stop_loss = nn.BCEWithLogitsLoss(size_average=False)

    def get_trace(self, node):		#trace란것을 얻을건데, 
        super_root = MolTreeNode("")
        super_root.idx = -1
        trace = []
        dfs(trace, node, super_root)
        return [(x.smiles, y.smiles, z) for x, y, z in trace]

    def forward(self, mol_batch, node_rep):
        super_root = MolTreeNode("")
        super_root.idx = -1

        # Initialize
        pred_hiddens, pred_targets = [], []
        stop_hiddens, stop_targets = [], []
        traces = []
        for mol_tree in mol_batch:
            s = []
            dfs(s, mol_tree.nodes[0], super_root)
            traces.append(s)
            for node in mol_tree.nodes:
                node.neighbors = []
        '''
        # Predict Root
        pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
        pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
        pred_mol_vecs.append(mol_vec)
        '''

        max_iter = max([len(tr) for tr in traces])
        padding = create_var(torch.zeros(self.hidden_size), False)
        h = {}

        for t in range(max_iter):
            prop_list = []
            batch_list = []
            for i, plist in enumerate(traces):
                if t < len(plist):
                    prop_list.append(plist[t])
                    batch_list.append(i)
                else:
                    prop_list.append(None)

            em_list = []
            cur_h_nei, cur_o_nei = [], []

            for mol_index, prop in enumerate(prop_list):
                if prop is None:
                    continue
                node_x, real_y, _ = prop
                # Neighbors for message passing (target not included)
                cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
                pad_len = MAX_NB - len(cur_nei)
                if pad_len>= 0:
                    cur_h_nei.extend(cur_nei)
                    cur_h_nei.extend([padding] * pad_len)
                else:
                    cur_h_nei.extend(cur_nei[:MAX_NB])

                # Neighbors for stop prediction (all neighbors)
                cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
                pad_len = MAX_NB - len(cur_nei)
                if pad_len >= 0:
                    cur_o_nei.extend(cur_nei)
                    cur_o_nei.extend([padding] * pad_len)
                else:
                    cur_o_nei.extend(cur_nei[:MAX_NB])


                # Current clique embedding
                em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(self.device)), dim=0))

            # Clique embedding
            cur_x = torch.stack(em_list, dim=0)

            # Message passing
            cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
            new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)

            # Node Aggregate
            cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
            cur_o = cur_o_nei.sum(dim=1)

            # Gather targets
            pred_target, pred_list = [], []
            stop_target = []
            prop_list = [x for x in prop_list if x is not None]
            for i, m in enumerate(prop_list):
                node_x, node_y, direction = m
                x, y = node_x.idx, node_y.idx
                h[(x, y)] = new_h[i]
                node_y.neighbors.append(node_x)
                if direction == 1:
                    pred_target.append(node_y.wid)
                    pred_list.append(i)
                stop_target.append(direction)

            # Hidden states for stop prediction
            stop_hidden = torch.cat([cur_x, cur_o], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction
            if len(pred_list) > 0:
                #batch_list = [batch_list[i] for i in pred_list]
                #cur_batch = create_var(torch.LongTensor(batch_list))
                #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

                cur_pred = create_var(torch.LongTensor(pred_list))
                pred_hiddens.append(new_h.index_select(0, cur_pred))
                pred_targets.extend(pred_target)

        # Last stop at root
        em_list, cur_o_nei = [], []
        for mol_index, mol_tree in enumerate(mol_batch):
            node_x = mol_tree.nodes[0]
            em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(self.device)), dim=0))
            cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
            pad_len = MAX_NB - len(cur_nei)
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)

        cur_x = torch.stack(em_list, dim=0)
        cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
        cur_o = cur_o_nei.sum(dim=1)

        stop_hidden = torch.cat([cur_x, cur_o], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend([0] * len(mol_batch))

        # Predict next clique
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        #pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        #pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = pred_hiddens
        pred_vecs = nn.ReLU()(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = create_var(torch.LongTensor(pred_targets))

        pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = nn.ReLU()(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = create_var(torch.Tensor(stop_targets))

        stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()

#### 3-2-3. MOotif_Generation

In [68]:
class Motif_Generation(nn.Module):

    def __init__(self, vocab, hidden_size, latent_size, depth, device, order):
        super(Motif_Generation, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth
        self.device = device
        if order == 'dfs':
            self.decoder = Motif_Generation_dfs(vocab, hidden_size, self.device)
        elif order == 'bfs':
            self.decoder = Motif_Generation_bfs(vocab, hidden_size, self.device)

    def forward(self, mol_batch, node_rep):
        set_batch_nodeID(mol_batch, self.vocab)

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, node_rep)

        return word_loss, topo_loss, word_acc, topo_acc

#### 3-2-4. Grover_MotifTask

In [69]:
class GroverMotifTask(nn.Module):
    """
    The pretrain module.
    """
    def __init__(self, args, grover, atom_vocab_size, bond_vocab_size, fg_size):
        super(GroverMotifTask, self).__init__()
        self.grover = grover
        self.av_task_atom = AtomVocabPrediction(args, atom_vocab_size)
        self.av_task_bond = AtomVocabPrediction(args, atom_vocab_size)
        self.bv_task_atom = BondVocabPrediction(args, bond_vocab_size)
        self.bv_task_bond = BondVocabPrediction(args, bond_vocab_size)

        self.fg_task_all = FunctionalGroupPrediction(args, fg_size)

        self.embedding_output_type = args.embedding_output_type

    @staticmethod
    def get_loss_func(args: Namespace) -> Callable:
        """
        The loss function generator.
        :param args: the arguments.
        :return: the loss fucntion for GroverTask.
        """
        def loss_func(preds, targets, dist_coff=args.dist_coff):
            """
            The loss function for GroverTask.
            :param preds: the predictions.
            :param targets: the targets.
            :param dist_coff: the default disagreement coefficient for the distances between different branches.
            :return:
            """
            av_task_loss = nn.NLLLoss(ignore_index=0, reduction="mean")  # same for av and bv

            fg_task_loss = nn.BCEWithLogitsLoss(reduction="mean")
            # av_task_dist_loss = nn.KLDivLoss(reduction="mean")
            av_task_dist_loss = nn.MSELoss(reduction="mean")
            fg_task_dist_loss = nn.MSELoss(reduction="mean")
            sigmoid = nn.Sigmoid()

            av_atom_loss, av_bond_loss, av_dist_loss = 0.0, 0.0, 0.0
            fg_atom_from_atom_loss, fg_atom_from_bond_loss, fg_atom_dist_loss = 0.0, 0.0, 0.0
            bv_atom_loss, bv_bond_loss, bv_dist_loss = 0.0, 0.0, 0.0
            fg_bond_from_atom_loss, fg_bond_from_bond_loss, fg_bond_dist_loss = 0.0, 0.0, 0.0

            if preds["av_task"][0] is not None:
                av_atom_loss = av_task_loss(preds['av_task'][0], targets["av_task"])
                fg_atom_from_atom_loss = fg_task_loss(preds["fg_task"]["atom_from_atom"], targets["fg_task"])

            if preds["av_task"][1] is not None:
                av_bond_loss = av_task_loss(preds['av_task'][1], targets["av_task"])
                fg_atom_from_bond_loss = fg_task_loss(preds["fg_task"]["atom_from_bond"], targets["fg_task"])

            if preds["bv_task"][0] is not None:
                bv_atom_loss = av_task_loss(preds['bv_task'][0], targets["bv_task"])
                fg_bond_from_atom_loss = fg_task_loss(preds["fg_task"]["bond_from_atom"], targets["fg_task"])

            if preds["bv_task"][1] is not None:
                bv_bond_loss = av_task_loss(preds['bv_task'][1], targets["bv_task"])
                fg_bond_from_bond_loss = fg_task_loss(preds["fg_task"]["bond_from_bond"], targets["fg_task"])

            if preds["av_task"][0] is not None and preds["av_task"][1] is not None:
                av_dist_loss = av_task_dist_loss(preds['av_task'][0], preds['av_task'][1])
                fg_atom_dist_loss = fg_task_dist_loss(sigmoid(preds["fg_task"]["atom_from_atom"]),
                                                      sigmoid(preds["fg_task"]["atom_from_bond"]))

            if preds["bv_task"][0] is not None and preds["bv_task"][1] is not None:
                bv_dist_loss = av_task_dist_loss(preds['bv_task'][0], preds['bv_task'][1])
                fg_bond_dist_loss = fg_task_dist_loss(sigmoid(preds["fg_task"]["bond_from_atom"]),
                                                      sigmoid(preds["fg_task"]["bond_from_bond"]))
                
            #if 

            av_loss = av_atom_loss + av_bond_loss
            bv_loss = bv_atom_loss + bv_bond_loss
            fg_atom_loss = fg_atom_from_atom_loss + fg_atom_from_bond_loss
            fg_bond_loss = fg_bond_from_atom_loss + fg_bond_from_bond_loss

            fg_loss = fg_atom_loss + fg_bond_loss
            fg_dist_loss = fg_atom_dist_loss + fg_bond_dist_loss

            # dist_loss = av_dist_loss + bv_dist_loss + fg_dist_loss
            # print("%.4f %.4f %.4f %.4f %.4f %.4f"%(av_atom_loss,
            #                                       av_bond_loss,
            #                                       fg_atom_loss,
            #                                       fg_bond_loss,
            #                                       av_dist_loss,
            #                                       fg_dist_loss))
            # return av_loss + fg_loss + dist_coff * dist_loss
            overall_loss = av_loss + bv_loss + fg_loss + dist_coff * av_dist_loss + \
                           dist_coff * bv_dist_loss + fg_dist_loss

            return overall_loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss

        return loss_func

    def forward(self, graph_batch: List):
        """
        The forward function.
        :param graph_batch:
        :return:
        """
        _, _, _, _, _, a_scope, b_scope, _ = graph_batch
        a_scope = a_scope.data.cpu().numpy().tolist()

        embeddings = self.grover(graph_batch)

        av_task_pred_atom = self.av_task_atom(
            embeddings["atom_from_atom"])  # if None: means not go through this fowward
        av_task_pred_bond = self.av_task_bond(embeddings["atom_from_bond"])

        bv_task_pred_atom = self.bv_task_atom(embeddings["bond_from_atom"])
        bv_task_pred_bond = self.bv_task_bond(embeddings["bond_from_bond"])

        fg_task_pred_all = self.fg_task_all(embeddings, a_scope, b_scope)

        return {"av_task": (av_task_pred_atom, av_task_pred_bond),
                "bv_task": (bv_task_pred_atom, bv_task_pred_bond),
                "fg_task": fg_task_pred_all,
                "emb_vec": embeddings}

#### group_node_rep

In [70]:
def group_node_rep(moltree, node_rep, batch_graph):
    group = []
    count = 1
    for i in range(len(moltree)):
        num=batch_graph[5][i][1]
        group.append(node_rep[count:count + num])		# count += num번째 node의 표현을 그룹에 더해라
        count += num
    return group						# 최종 그룹을 출력

In [71]:
def group_edge_rep(moltree, edge_rep, batch_graph):
    group = []
    count = 1
    for i in range(len(moltree)):
        num=batch_graph[6][i][1]
        group.append(edge_rep[count:count + num])		# count += num번째 node의 표현을 그룹에 더해라
        count += num
    return group						# 최종 그룹을 출력

#### 3-2-5. Grover_motiftrainer

In [72]:
class GROVERMotifTrainer:
    def __init__(self,
                 args,
                 embedding_model: Module,
                 topology_model: Module,
                 atom_vocab_size: int,  # atom vocab size
                 bond_vocab_size: int,
                 fg_size: int,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader,
                 optimizer_builder: Callable,
                 scheduler_builder: Callable,
                 logger: Logger = None,
                 with_cuda: bool = False,
                 enable_multi_gpu: bool = False):
        """
        The init function of GROVERTrainer
        :param args: the input arguments.
        :param embedding_model: the model to generate atom/bond embeddings.
        :param topology_model : the model to predict topology of molecule from embeddings
        :param atom_vocab_size: the vocabulary size of atoms.
        :param bond_vocab_size: the vocabulary size of bonds.
        :param fg_size: the size of semantic motifs (functional groups)
        :param train_dataloader: the data loader of train data.
        :param test_dataloader: the data loader of validation data.
        :param optimizer_builder: the function of building the optimizer.
        :param scheduler_builder: the function of building the scheduler.
        :param logger: the logger
        :param with_cuda: enable gpu training.
        :param enable_multi_gpu: enable multi_gpu traning.
        """

        self.args = args
        self.with_cuda = with_cuda
        self.grover = embedding_model
        self.model = GroverMotifTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_size)
        self.motif_model = topology_model
        self.loss_func = self.model.get_loss_func(args)
        self.enable_multi_gpu = enable_multi_gpu

        self.atom_vocab_size = atom_vocab_size
        self.bond_vocab_size = bond_vocab_size
        self.debug = logger.debug if logger is not None else print

        if self.with_cuda:
            # print("Using %d GPUs for training." % (torch.cuda.device_count()))
            self.model = self.model.cuda()
            self.motif_model = self.motif_model.cuda()

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optimizer = optimizer_builder(self.model, self.args)
        self.motif_optimizer = torch.optim.Adam(self.motif_model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
        self.scheduler = scheduler_builder(self.optimizer, self.args)
        if self.enable_multi_gpu:
            self.optimizer = mgw.DistributedOptimizer(self.optimizer,
                                                      named_parameters=self.model.named_parameters())
        self.args = args
        self.n_iter = 0

    def broadcast_parameters(self) -> None:
        """
        Broadcast parameters before training.
        :return: no return.
        """
        if self.enable_multi_gpu:
            # broadcast parameters & optimizer state.
            mgw.broadcast_parameters(self.model.state_dict(), root_rank=0)
            mgw.broadcast_optimizer_state(self.optimizer, root_rank=0)

    def train(self, epoch: int) -> List:
        """
        The training iteration
        :param epoch: the current epoch number.
        :return: the loss terms of current epoch.
        """
        # return self.mock_iter(epoch, self.train_data, train=True)
        return self.iter(epoch, self.train_data, train=True)

    def test(self, epoch: int) -> List:
        """
        The test/validaiion iteration
        :param epoch: the current epoch number.
        :return:  the loss terms as a list
        """
        # return self.mock_iter(epoch, self.test_data, train=False)
        return self.iter(epoch, self.test_data, train=False)

    def mock_iter(self, epoch: int, data_loader: DataLoader, train: bool = True) -> List:
        """
        Perform a mock iteration. For test only.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: the loss terms as a list
        """

        for _, _ in enumerate(data_loader):
            self.scheduler.step()
        cum_loss_sum = 0.0
        self.n_iter += self.args.batch_size
        return self.n_iter, cum_loss_sum, (0, 0, 0, 0, 0, 0)

    def iter(self, epoch, data_loader, train=True) -> List:
        """
        Perform a training / validation iteration.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: the loss terms as a list
        """

        if train:
            self.model.train()
            self.motif_model.train()
        else:
            self.model.eval()
            self.motif_model.eval()
            
        time1 = time.time()
        print(f'iter start')
        loss_sum, iter_count = 0, 0
        cum_loss_sum, cum_iter_count = 0, 0
        av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum, bv_dist_loss_sum, fg_dist_loss_sum, topo_loss_sum, node_loss_sum = 0, 0, 0, 0, 0, 0, 0, 0
        # loss_func = self.model.get_loss_func(self.args)

        for _, item in enumerate(data_loader):
            batch_graph = item["graph_input"]
            targets = item["targets"]
            
            # add this for motif generation
            moltree = item["moltree"]
            #list_graph = list(batch_graph)
            #new_graph=list_graph[0][1:],list_graph[1][1:], list_graph[2][1:], list_graph[3][1:], list_graph[4][1:], list_graph[5][:], list_graph[6][:], list_graph[7][1:]
            
            time2 = time.time()
            print(f'dataloader time is {time2-time1}')
            
            if next(self.model.parameters()).is_cuda:
                targets["av_task"] = targets["av_task"].cuda()
                targets["bv_task"] = targets["bv_task"].cuda()
                targets["fg_task"] = targets["fg_task"].cuda()
            
            preds = self.model(batch_graph)
            emb_vector = preds['emb_vec']
            
            time3 = time.time()
            print(f'pred time is {time3-time2}')

            # add this for motif generation
            if self.args.embedding_output_type == 'atom':
                emb_afa_grouped = group_node_rep(moltree, emb_vector['atom_from_atom'],batch_graph)
                emb_afb_grouped = group_node_rep(moltree, emb_vector['atom_from_bond'],batch_graph)
                
                node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = self.motif_model(moltree, emb_afa_grouped)
                node_afb_loss, topo_afb_loss, node_afb_acc, topo_afb_acc = self.motif_model(moltree, emb_afb_grouped)
                
                node_loss = node_afa_loss + node_afb_loss
                topo_loss = topo_afa_loss + topo_afb_loss
                node_acc = (node_afa_acc + node_afb_acc)/2
                topo_acc = (topo_afa_acc + topo_afb_acc)/2
                
            elif self.args.embedding_output_type == 'bond':
                emb_bfa_grouped = group_node_rep(moltree, emb_vector['bond_from_atom'],batch_graph)
                emb_bfb_grouped = group_node_rep(moltree, emb_vector['bond_from_bond'],batch_graph)
                
                node_bfa_loss, topo_bfa_loss, node_bfa_acc, topo_bfa_acc = self.motif_model(moltree, emb_bfa_grouped)
                node_bfb_loss, topo_bfb_loss, node_bfb_acc, topo_bfb_acc = self.motif_model(moltree, emb_bfb_grouped)
                
                node_loss = node_bfa_loss + node_bfb_loss
                topo_loss = topo_bfa_loss + topo_bfb_loss
                node_acc = (node_bfa_acc + node_bfb_acc)/2
                topo_acc = (topo_bfa_acc + topo_bfb_acc)/2
                
            elif self.args.embedding_output_type == "both":
                emb_afa_grouped = group_node_rep(moltree, emb_vector['atom_from_atom'],batch_graph)
                emb_afb_grouped = group_node_rep(moltree, emb_vector['atom_from_bond'],batch_graph)
                emb_bfa_grouped = group_node_rep(moltree, emb_vector['bond_from_atom'],batch_graph)
                emb_bfb_grouped = group_node_rep(moltree, emb_vector['bond_from_bond'],batch_graph)
                
                node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = self.motif_model(moltree, emb_afa_grouped)
                node_afb_loss, topo_afb_loss, node_afb_acc, topo_afb_acc = self.motif_model(moltree, emb_afb_grouped)
                node_bfa_loss, topo_bfa_loss, node_bfa_acc, topo_bfa_acc = self.motif_model(moltree, emb_bfa_grouped)
                node_bfb_loss, topo_bfb_loss, node_bfb_acc, topo_bfb_acc = self.motif_model(moltree, emb_bfb_grouped)
                
                node_loss = node_afa_loss + node_afb_loss + node_bfa_loss + node_bfb_loss
                topo_loss = topo_afa_loss + topo_afb_loss + topo_bfa_loss + topo_bfb_loss
                node_acc = (node_afa_acc + node_afb_acc + node_bfa_acc + node_bfb_acc)/4
                topo_acc = (topo_afa_acc + topo_afb_acc + topo_bfa_acc + topo_bfb_acc)/4

            # # ad-hoc code, for visualizing a model, comment this block when it is not needed
            # import dglt.contrib.grover.vis_model as vis_model
            # for task in ['av_task', 'bv_task', 'fg_task']:
            #     vis_graph = vis_model.make_dot(self.model(batch_graph)[task],
            #                                    params=dict(self.model.named_parameters()))
            #     # vis_graph.view()
            #     vis_graph.render(f"{self.args.backbone}_model_{task}_vis.png", format="png")
            # exit()
            
            time4 = time.time()
            print(f'motif_model time is {time4-time3}')

            loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss = self.loss_func(preds, targets)

            loss_sum += loss.item()
            iter_count += self.args.batch_size
            
            # add for topology loss
            loss += topo_loss.item()
            loss += node_loss.item()
            topo_loss_sum += topo_loss.item()
            node_loss_sum += node_loss.item()

            if train:
                cum_loss_sum += loss.item()
                # Run model
                self.model.zero_grad()
                self.motif_model.zero_grad()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
            else:
                # For eval model, only consider the loss of three task.
                cum_loss_sum += av_loss.item()
                cum_loss_sum += bv_loss.item()
                cum_loss_sum += fg_loss.item()

            av_loss_sum += av_loss.item()
            bv_loss_sum += bv_loss.item()
            fg_loss_sum += fg_loss.item()
            av_dist_loss_sum += av_dist_loss.item() if type(av_dist_loss) != float else av_dist_loss
            bv_dist_loss_sum += bv_dist_loss.item() if type(bv_dist_loss) != float else bv_dist_loss
            fg_dist_loss_sum += fg_dist_loss.item() if type(fg_dist_loss) != float else fg_dist_loss

            cum_iter_count += 1
            self.n_iter += self.args.batch_size

            # Debug only.
            # if i % 50 == 0:
            #     print(f"epoch: {epoch}, batch_id: {i}, av_loss: {av_loss}, bv_loss: {bv_loss}, "
            #           f"fg_loss: {fg_loss}, av_dist_loss: {av_dist_loss}, bv_dist_loss: {bv_dist_loss}, "
            #           f"fg_dist_loss: {fg_dist_loss}")
            
            time1=time.time()
            print(f'loss time is {time1-time4}')

        cum_loss_sum /= cum_iter_count
        av_loss_sum /= cum_iter_count
        bv_loss_sum /= cum_iter_count
        fg_loss_sum /= cum_iter_count
        av_dist_loss_sum /= cum_iter_count
        bv_dist_loss_sum /= cum_iter_count
        fg_dist_loss_sum /= cum_iter_count
        
        topo_loss_sum /= cum_iter_count
        node_loss_sum /= cum_iter_count

        return self.n_iter, cum_loss_sum, (av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum,
                                           bv_dist_loss_sum, fg_dist_loss_sum, topo_loss_sum, node_loss_sum)

    def save(self, epoch, file_path, name=None) -> str:
        """
        Save the intermediate models during training.
        :param epoch: the epoch number.
        :param file_path: the file_path to save the model.
        :return: the output path.
        """
        # add specific time in model fine name, in order to distinguish different saved models
        now = time.localtime()
        if name is None:
            name = "_%04d_%02d_%02d_%02d_%02d_%02d" % (
                now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
        output_path = file_path + name + ".ep%d" % epoch
        scaler = None
        features_scaler = None
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch,
            'data_scaler': {
                'means': scaler.means,
                'stds': scaler.stds
            } if scaler is not None else None,
            'features_scaler': {
                'means': features_scaler.means,
                'stds': features_scaler.stds
            } if features_scaler is not None else None
        }
        torch.save(state, output_path)

        # Is this necessary?
        # if self.with_cuda:
        #    self.model = self.model.cuda()
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def save_tmp(self, epoch, file_path, rank=0):
        """
        Save the models for auto-restore during training.
        The model are stored in file_path/tmp folder and will replaced on each epoch.
        :param epoch: the epoch number.
        :param file_path: the file_path to store the model.
        :param rank: the current rank (decrypted).
        :return:
        """
        store_path = os.path.join(file_path, "tmp")
        if not os.path.exists(store_path):
            os.makedirs(store_path, exist_ok=True)
        store_path = os.path.join(store_path, "model.%d" % rank)
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch
        }
        torch.save(state, store_path)

    def restore(self, file_path, rank=0) -> Tuple[int, int]:
        """
        Restore the training state saved by save_tmp.
        :param file_path: the file_path to store the model.
        :param rank: the current rank (decrypted).
        :return: the restored epoch number and the scheduler_step in scheduler.
        """
        cpt_path = os.path.join(file_path, "tmp", "model.%d" % rank)
        if not os.path.exists(cpt_path):
            print("No checkpoint found %d")
            return 0, 0
        cpt = torch.load(cpt_path)
        self.model.load_state_dict(cpt["state_dict"])
        self.optimizer.load_state_dict(cpt["optimizer"])
        epoch = cpt["epoch"]
        scheduler_step = cpt["scheduler_step"]
        self.scheduler.current_step = scheduler_step
        print("Restore checkpoint, current epoch: %d" % (epoch))
        return epoch, scheduler_step


##### 3-2-3-1 trainer 세부 실행

In [62]:
def get_slots(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
        
class Motif_Vocab(object):

    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x:i for i,x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]
        
    def get_index(self, smiles):
        try : return self.vmap[smiles]
        except : return [0]
    
    def get_smiles(self, idx):
        return self.vocab[idx]

    def get_slots(self, idx):
        return copy.deepcopy(self.slots[idx])

    def size(self):
        return len(self.vocab)
    
    def add_motif(self, smiles):
        self.vocab.append(smiles)

In [63]:
if args.parser_name == 'pretrain':
    motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
    motif_vocab = Motif_Vocab(motif_vocab)
    #see below motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order).to(device)

In [64]:
embedding_model = GROVEREmbedding(args)
embedding_model.cuda()
embedding_model_test = GROVEREmbedding(args)
embedding_model_test.cuda()
embedding_model_test.eval()
embed_model = GroverMotifTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_size)
embed_model.cuda()
embed_model_test = GroverMotifTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_size)
embed_model_test.cuda()
embed_model_test.eval()
motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order).to(device)

NameError: name 'GroverMotifTask' is not defined

In [None]:
train_data.__getitem__

In [None]:
motif_collator = GroverMotifCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
train_data_motif_dl = DataLoader(train_data,
                           batch_size=args.batch_size,
                           shuffle=shuffle,
                           num_workers=12,
                           sampler=train_sampler,
                           collate_fn=motif_collator)
test_data_motif_dl = DataLoader(test_data,
                          batch_size=args.batch_size,
                          shuffle=shuffle,
                          num_workers=10,
                          sampler=test_sampler,
                          collate_fn=motif_collator)

In [113]:
for i, item in enumerate(train_data_motif_dl):
    batch_graph = item["graph_input"]
    targets = item["targets"]
    moltree = item["moltree"]
    
    if next(embed_model.parameters()).is_cuda:
        targets["av_task"] = targets["av_task"].cuda()
        targets["bv_task"] = targets["bv_task"].cuda()
        targets["fg_task"] = targets["fg_task"].cuda()
    #preds = embed_model(batch_graph)
    #emb = preds['emb_vec']
    #_, motif_loss, _ = motif_model(emb)
    if i == 0 : break

AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/grover/grover/data/groverdataset.py", line 158, in __getitem__
    return self.data[dp_idx][real_idx]
  File "/tmp/ipykernel_200/3388227146.py", line 49, in __getitem__
    assert self.datapoints is not None
AssertionError


In [135]:
moltree[0].nodes[0].smiles

'CCC'

In [128]:
for i, item in enumerate(train_data_dl):
    batch_graph_test = item["graph_input"]
    targets_test = item["targets"]
    moltree_test = item["moltree"]

    if next(embed_model.parameters()).is_cuda:
        targets_test["av_task"] = targets_test["av_task"].cuda()
        targets_test["bv_task"] = targets_test["bv_task"].cuda()
        targets_test["fg_task"] = targets_test["fg_task"].cuda()
    preds_test = embed_model_test(batch_graph)
    emb_test = preds_test['emb_vec']
    #_, motif_loss, _ = motif_model(emb)
    if i == 0 : break



In [136]:
emb

{'atom_from_atom': tensor([[ 1.1859,  0.1740, -0.9452,  ...,  1.4035,  0.3065,  1.4557],
         [ 1.3639,  0.3965, -1.2151,  ...,  1.1539,  0.4090,  1.6906],
         [ 1.1003,  1.0086, -1.0413,  ...,  0.5360,  0.8957,  1.4942],
         ...,
         [ 1.4797,  0.4890, -1.2851,  ...,  0.3559,  0.5014,  1.6542],
         [ 1.4237,  0.5486, -1.3769,  ...,  0.3725,  0.4631,  1.7000],
         [ 1.0127,  0.6267, -1.4408,  ...,  0.7031,  0.3877,  1.7992]],
        device='cuda:0', grad_fn=<NativeLayerNormBackward0>),
 'bond_from_atom': tensor([[ 1.9872,  0.0787,  3.0122,  ..., -0.4602,  1.9317, -0.2936],
         [ 1.7526, -0.0993,  2.8577,  ..., -0.3381,  1.6139, -0.0265],
         [ 0.8892,  0.1347,  2.2546,  ...,  0.2686,  1.3640,  1.0550],
         ...,
         [ 1.0481, -0.2610,  2.6476,  ...,  0.0519,  1.5092,  0.8602],
         [ 1.0814, -0.2726,  2.7533,  ...,  0.0844,  1.5442,  0.8107],
         [ 1.3965, -0.5333,  2.4764,  ...,  0.0368,  1.5384,  0.8662]],
        device='cuda

In [137]:
emb_test

{'atom_from_atom': tensor([[ 1.1859,  0.1740, -0.9452,  ...,  1.4035,  0.3065,  1.4557],
         [ 1.2336,  0.4660, -1.1345,  ...,  1.1749,  0.5299,  1.5460],
         [ 0.7674,  0.3911, -1.2583,  ...,  0.8974,  0.8591,  1.3976],
         ...,
         [ 0.3979,  0.6365, -2.1722,  ...,  0.5994,  0.5248,  1.4194],
         [ 0.6119,  1.2031, -1.9464,  ...,  0.2268,  0.8142,  1.4497],
         [ 0.3979,  0.6365, -2.1722,  ...,  0.5994,  0.5248,  1.4194]],
        device='cuda:0', grad_fn=<NativeLayerNormBackward0>),
 'bond_from_atom': tensor([[ 1.9872,  0.0787,  3.0122,  ..., -0.4602,  1.9317, -0.2936],
         [ 1.7658, -0.0789,  2.8364,  ..., -0.3204,  1.5859, -0.0390],
         [ 1.0875,  0.2264,  2.3448,  ...,  0.2457,  1.4572,  0.8970],
         ...,
         [ 0.5334,  0.7396,  1.5271,  ...,  0.1219,  1.0316,  1.1022],
         [ 0.5334,  0.7396,  1.5271,  ...,  0.1219,  1.0316,  1.1022],
         [ 0.7041,  0.5886,  1.4270,  ...,  0.0203,  1.0522,  1.4338]],
        device='cuda

In [138]:
batch_graph

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[   0,    0,    0,    0],
         [   2,    0,    0,    0],
         [   1,    4,    0,    0],
         ...,
         [5241, 5244,    0,    0],
         [5243, 5246,    0,    0],
         [5237, 5245,    0,    0]]),
 tensor([   0,    1,    2,  ..., 2424, 2424, 2425]),
 tensor([   0,    2,    1,  ..., 5243, 5246, 5245]),
 tensor([[   1,   20],
         [  21,   27],
         [  48,   32],
         [  80,   21],
         [ 101,   21],
         [ 122,   27],
      

In [139]:
batch_graph_test

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.]]),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[   0,    0,    0,    0],
         [   2,    0,    0,    0],
         [   1,    4,    0,    0],
         ...,
         [5021, 5031, 5039,    0],
         [5005, 5042,    0,    0],
         [4997, 5041,    0,    0]]),
 tensor([   0,    1,    2,  ..., 2339, 2340, 2341]),
 tensor([   0,    2,    1,  ..., 5039, 5042, 5041]),
 tensor([[   1,   27],
         [  28,   26],
         [  54,   21],
         [  75,   24],
         [  99,   29],
         [ 128,   18],
      

In [140]:
moltree[0].mol

<rdkit.Chem.rdchem.Mol at 0x7f5096622f30>

In [141]:
moltree_test[0].mol

<rdkit.Chem.rdchem.Mol at 0x7f5095c904b0>

In [211]:
for i in range(10):
    print(moltree[i].smiles)
    for j in range(len(moltree[i].nodes)):
        print(moltree[i].nodes[j].smiles)

CC(=O)Nc1nc2ccc(NC(=O)NCc3ccccc3)cc2s1
C
C1=CC=CC=C1
C
N
C
N
N
C
C1=CC=C2SC=NC2=C1
O
O
Cc1cccc(C)c1NC(=O)C[NH+]1CCC(OCc2ccc(F)cc2)CC1
C
C1CC[NH2+]CC1
C1=CC=CC=C1
C
N
O
C
C
C1=CC=CC=C1
C
F
O
Cc1occc1C(=O)/C(C#N)=C\c1cccc(C(F)(F)F)c1
C
C#N
C1=CC=CC=C1
C
C
C1=COC=C1
C
C
O
F
F
F
Cc1cccn2c(=O)c(C(=O)NC[C@H]3CCO[C@@H]3C(C)C)cnc12
C
C1CCOC1
C
N
C
C
C1=CC2=NC=CCN2C=C1
O
O
C
C
COc1ccc(OC)c(/C=C2\Oc3cc(OC(=O)c4ccncc4)cc(C)c3C2=O)c1
CO
C1=CC=CC=C1
CO
C1=CC=C2OCCC2=C1
C1=CC=NC=C1
C
O
C
C
O
O
Cc1ccc(-c2cc(NC(=O)C(C)C)c(=O)n(CC(=O)Nc3cccc(C)c3)n2)cc1
C
C1=CC=CC=C1
C
N
C
N
C1=CC(C2=CC=CC=C2)=NNC1
O
C
C
O
C
C
C
O
CC(=O)NCCC(=O)N1CCC[C@@H](C)C1
C
C1CCNCC1
C
N
C
C
CC
O
O
CCC[NH2+]C1CCC(O)(Cc2nc(C)cs2)CC1
CCC
C1CCCCC1
CC1=NC=CS1
[NH4+]
O
C
CCn1cc(C(=O)N[C@H]2CC(=O)N(C)C2)c(C(C)C)n1
CC
C1=CNN=C1
C1CCNC1
C
N
O
C
C
O
C
C
Cc1c(F)cc(N)cc1S(=O)(=O)N[C@@H](C)C1CC1
C
C1CC1
C
N
S
C1=CC=CC=C1
F
N
O
O
C


In [214]:
set_batch_nodeID(moltree, motif_vocab)

In [296]:
len(moltree[1].mol.GetBonds())

22

In [328]:
batch_graph[5][0][1]

tensor(21)

In [344]:
emb['atom_from_atom'].shape

torch.Size([233, 100])

In [345]:
emb['atom_from_bond'].shape

torch.Size([233, 100])

In [337]:
batch_graph[6]

tensor([[  1,  46],
        [ 47,  68],
        [115,  54],
        [169,  44],
        [213,  38],
        [251,  76],
        [327,  44],
        [371,  38],
        [409,  60],
        [469,  32]])

In [142]:
emb_edge_grouped = group_edge_rep(moltree, emb['bond_from_atom'],batch_graph)
emb_edge_grouped

[tensor([[ 1.7526, -0.0993,  2.8577,  ..., -0.3381,  1.6139, -0.0265],
         [ 0.8892,  0.1347,  2.2546,  ...,  0.2686,  1.3640,  1.0550],
         [ 0.9989,  0.0948,  2.0923,  ...,  0.0882,  1.2366,  0.7802],
         ...,
         [ 0.1021,  0.1924,  1.4996,  ...,  0.6037,  0.2489,  1.1135],
         [ 0.7596,  0.0699,  1.5142,  ...,  0.7303,  0.1724,  1.0258],
         [ 1.6360, -0.0166,  2.9026,  ..., -0.2519,  1.5676,  0.0445]],
        device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([[ 1.6737e+00, -2.5817e-03,  2.9433e+00,  ..., -2.7101e-01,
           1.5551e+00,  8.2253e-02],
         [ 6.8865e-01,  2.6351e-01,  1.5407e+00,  ...,  2.8127e-01,
           7.3498e-01,  7.7785e-01],
         [-1.1575e-01,  5.9401e-01,  1.4826e+00,  ...,  1.5004e-01,
           8.0959e-01,  8.8538e-01],
         ...,
         [ 6.6971e-01,  4.9209e-01,  2.2563e+00,  ..., -4.9288e-01,
           1.0688e+00,  2.5702e-01],
         [ 3.4176e-01,  6.9647e-01,  1.9208e+00,  ...,  8.3852e-01,
     

In [335]:
emb['bond_from_atom'].shape

torch.Size([501, 100])

In [325]:
emb_edgegrouped[6].shape

torch.Size([21, 100])

In [313]:
batch_graph[6]

tensor([[  1,  46],
        [ 47,  68],
        [115,  54],
        [169,  44],
        [213,  38],
        [251,  76],
        [327,  44],
        [371,  38],
        [409,  60],
        [469,  32]])

In [316]:
emb_edge_grouped = group_edge_rep(moltree, emb['bond_from_atom'],batch_graph)
emb_edge_grouped

[tensor([[-8.5820e-01,  1.6547e+00,  9.4280e-01,  8.4202e-01, -1.9089e-01,
          -5.1609e-01,  8.8267e-02,  2.9910e-02,  2.4479e-01, -9.5501e-02,
           1.0021e+00,  8.2473e-01, -1.1580e+00,  7.7930e-01, -9.4704e-01,
           1.4190e+00,  9.8389e-01, -1.0539e+00,  2.8490e-01,  1.2800e+00,
          -1.5545e+00,  7.8484e-01,  7.0774e-01,  1.9000e+00, -9.9584e-01,
           1.5108e+00, -9.1161e-01,  1.5985e+00, -5.7780e-01,  1.0407e+00,
          -5.9209e-01,  1.9629e+00, -3.4021e-01,  9.2519e-01, -3.4118e-01,
          -2.5694e-01,  5.4684e-01, -1.2189e+00, -1.3635e+00,  3.2851e-02,
          -7.1505e-01,  1.6838e-01,  2.8715e-02,  7.5729e-01, -1.1052e+00,
          -6.0632e-01,  8.8837e-02,  6.8328e-01, -2.6477e+00, -5.6554e-03,
           6.6163e-01, -9.7289e-01, -1.0368e+00,  8.4469e-01, -7.1414e-01,
           5.7456e-01, -7.3513e-01, -1.2843e-01, -2.7095e-01,  4.1770e-01,
          -9.9255e-01,  7.8305e-01, -2.0609e+00,  7.2598e-01,  4.0311e-01,
          -2.7848e-01,  1

In [317]:
emb['bond_from_atom'].shape

torch.Size([501, 100])

In [292]:
emb['atom_from_atom'].shape

torch.Size([231, 100])

In [293]:
emb['bond_from_atom'].shape

torch.Size([491, 100])

In [299]:
count_test=0
for i in range(10):
    count_test+=len(emb_edge_grouped[i])
    
print(count_test)

225


In [219]:
motif_model(moltree, emb_grouped)

(tensor(53.9152, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(15.8666, device='cuda:0', grad_fn=<DivBackward0>),
 0.010101010091602802,
 0.44711539149284363)

In [259]:
len(emb['atom_from_bond'])

228

In [233]:
for mol_index, mol_tree in enumerate(mol_batch):
    print(mol_index)
    print(mol_tree)

0
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd5b5a90>
1
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd598990>
2
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd5cb890>
3
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd5fb750>
4
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd607090>
5
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd4b2150>
6
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd4b2b90>
7
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd4b5b50>
8
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd498350>
9
<grover.pretrain_motif.mol_tree.MolTree object at 0x7f0cbd49d210>


[<grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd5b5a90>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd598990>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd5cb890>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd5fb750>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd607090>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd4b2150>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd4b2b90>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd4b5b50>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd498350>,
 <grover.pretrain_motif.mol_tree.MolTree at 0x7f0cbd49d210>]

##### 데이터 구조 및 training

In [108]:
emb['atom_from_atom'].shape

torch.Size([496, 100])

In [109]:
preds["av_task"][0].shape

torch.Size([496, 324])

In [110]:
preds["bv_task"][0].shape

torch.Size([531, 353])

In [71]:
len(batch_graph[0])

514

In [84]:
len(train_data_dl)

169

In [88]:
# Perform training.
print(f'resume_from_epoch is {resume_from_epoch}')
for epoch in range(resume_from_epoch + 1, args.epochs):
    s_time = time.time()

    # Data pre-loading.
    if args.enable_multi_gpu:
        train_sampler.set_epoch(epoch)
        train_data.clean_cache()
        idxs = train_sampler.get_indices()
        for local_gpu_idx in idxs:
            train_data.load_data(local_gpu_idx)
    d_time = time.time() - s_time

    # perform training and validation.
    s_time = time.time()
    _, train_loss, _ = trainer.train(epoch)
    t_time = time.time() - s_time
    s_time = time.time()
    _, val_loss, detailed_loss_val = trainer.test(epoch)
    val_av_loss, val_bv_loss, val_fg_loss, _, _, _ = detailed_loss_val
    v_time = time.time() - s_time

    '''    # print information.
    if master_worker:
        print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.6f}'.format(train_loss),
              'loss_val: {:.6f}'.format(val_loss),
              'loss_val_av: {:.6f}'.format(val_av_loss),
              'loss_val_bv: {:.6f}'.format(val_bv_loss),
              'loss_val_fg: {:.6f}'.format(val_fg_loss),
              'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
              't_time: {:.4f}s'.format(t_time),
              'v_time: {:.4f}s'.format(v_time),
              'd_time: {:.4f}s'.format(d_time), flush=True)

        if epoch % args.save_interval == 0:
            trainer.save(epoch, model_dir)


        trainer.save_tmp(epoch, model_dir, rank)'''

# Only save final version.
if master_worker:
    trainer.save(args.epochs, model_dir, "")

resume_from_epoch is 0




tensor([[ 6.6085e-03, -3.0400e-03,  0.0000e+00,  ..., -5.4301e-03,
          2.4351e-03,  3.8916e-02],
        [ 7.1513e-01, -5.5678e-01,  2.3509e+00,  ..., -7.1710e-01,
         -4.2909e-01,  4.4382e+00],
        [ 8.0203e-01, -1.1722e+00, -2.2792e-01,  ..., -1.0905e+00,
         -1.1718e+00,  0.0000e+00],
        ...,
        [ 9.9632e-02,  1.5563e+00,  8.1791e-02,  ..., -2.7661e-01,
         -9.3305e-01, -5.8124e-01],
        [-5.3853e-01,  6.1812e-01, -2.7733e-01,  ..., -9.3885e-01,
         -9.6680e-01, -7.6313e-02],
        [ 3.4529e-01,  2.7850e+00, -0.0000e+00,  ..., -2.7454e-01,
         -4.0088e-01,  3.0151e+00]], device='cuda:0',
       grad_fn=<FusedDropoutBackward0>)
tensor([[ 0.1022, -0.0084,  0.0221,  ...,  0.0205, -0.0240, -0.0036],
        [ 1.0024, -0.1991,  1.9776,  ..., -0.0144, -0.5643,  2.4008],
        [ 1.5665,  0.5782,  0.2745,  ..., -0.8584, -1.9646, -0.4699],
        ...,
        [-0.0081, -0.3479, -0.0000,  ..., -0.4042, -0.9947,  0.4765],
        [ 1.6408, 

KeyboardInterrupt: 

In [59]:
for i, item in enumerate(train_data_dl):
    if i==0:
        items = item
        batch_graph = item["graph_input"]
        targets = item["targets"]
    else : break



#### 2-1-4-1 배치 내부 구조
- graph_input(graph상태를 의미)
  - f_atoms
  - f_bonds
  - a2b
  - b2a
  - b2revb
  - a_scope
  - b_scope
  - a2a
- targets : Label값
  - av_task : atom 맞추기
  - bv_task : bond 맞추기
  - fg_task : motif 맞추기(정확한걸 맞추는게 아니라 이 분자에 포함된 motif가 이거이거다 라는 식)

In [39]:
items['targets']

NameError: name 'items' is not defined

In [40]:
mol2graph(['C(O)O'], shared_dict, args).get_components()

(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000

In [41]:
mol2graph(['C(O)O'], shared_dict, args).a2a

tensor([[0, 0],
        [2, 3],
        [1, 0],
        [1, 0]])

In [63]:
mol2graph(['C(O)O'], shared_dict, args).b2b

In [65]:
batch_graph

(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000

In [66]:
len(batch_graph)

8

In [67]:
batch_graph[0][1]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
        1.0000, 0.0000, 0.0000, 0.0000, 

In [68]:
batch_graph[1][1]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,
        1.0000, 0.0000, 0.0000, 0.0000, 

In [69]:
batch_graph[2][1]

tensor([2, 4])

In [70]:
batch_graph[3][1]

tensor(1)

In [71]:
batch_graph[4][1]

tensor(2)

In [72]:
batch_graph[5][0]

tensor([1, 3])

In [73]:
batch_graph[6][0]

tensor([1, 4])

In [74]:
batch_graph[7][1]

tensor([2, 3])

In [75]:
targets['av_task'].shape

torch.Size([4])

### 3.3. run_motif_training

In [71]:
def run_motif_training(args, logger):
    """
    Run the pretrain task.
    :param args:
    :param logger:
    :return:
    """

    # initalize the logger.
    if logger is not None:
        debug, _ = logger.debug, logger.info
    else:
        debug = print

    # initialize the horovod library
    if args.enable_multi_gpu:
        mgw.init()

    # binding training to GPUs.
    master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
    # pin GPU to local rank. By default, we use gpu:0 for training.
    local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
    with_cuda = args.cuda
    if with_cuda:
        torch.cuda.set_device(local_gpu_idx)

    # get rank an  number of workers
    rank = mgw.rank() if args.enable_multi_gpu else 0
    num_replicas = mgw.size() if args.enable_multi_gpu else 1
    # print("Rank: %d Rep: %d" % (rank, num_replicas))

    # load file paths of the data.
    if master_worker:
        print(args)
        if args.enable_multi_gpu:
            debug("Total workers: %d" % (mgw.size()))
        debug('Loading data')
    data, sample_per_file = get_data(data_path=args.data_path)

    # data splitting
    if master_worker:
        debug(f'Splitting data with seed 0.')
    train_data, test_data, _ = split_data(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

    # Here the true train data size is the train_data divided by #GPUs
    if args.enable_multi_gpu:
        args.train_data_size = len(train_data) // mgw.size()
    else:
        args.train_data_size = len(train_data)
    if master_worker:
        debug(f'Total size = {len(data):,} | '
              f'train size = {len(train_data):,} | val size = {len(test_data):,}')

    # load atom and bond vocabulary and the semantic motif labels.
    atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
    bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
    atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)

    # Load motif vocabulary for pretrain
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.parser_name == 'pretrain':
        motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
        motif_vocab = Motif_Vocab(motif_vocab)
        #see below motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order).to(device)
        
        
    # Hard coding here, since we haven't load any data yet!
    fg_size = 85
    shared_dict = {}
    mol_collator = GroverCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
    if master_worker:
        debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                    bond_vocab_size, fg_size))

    # Define the distributed sampler. If using the single card, the sampler will be None.
    train_sampler = None
    test_sampler = None
    shuffle = True
    if args.enable_multi_gpu:
        # If not shuffle, the performance may decayed.
        train_sampler = DistributedSampler(
            train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
        # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
        # rank. (TODO: bad design here.)
        test_sampler = DistributedSampler(
            test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
        train_sampler.set_epoch(args.epochs)
        test_sampler.set_epoch(1)
        # if we enables multi_gpu training. shuffle should be disabled.
        shuffle = False

    # Pre load data. (Maybe unnecessary. )
    pre_load_data(train_data, rank, num_replicas, sample_per_file)
    pre_load_data(test_data, rank, num_replicas)
    if master_worker:
        # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
        print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())

    # Build dataloader
    train_data_dl = DataLoader(train_data,
                               batch_size=args.batch_size,
                               shuffle=shuffle,
                               num_workers=12,
                               sampler=train_sampler,
                               collate_fn=mol_collator)
    test_data_dl = DataLoader(test_data,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=10,
                              sampler=test_sampler,
                              collate_fn=mol_collator)

    # Build the embedding model.
    grover_model = GROVEREmbedding(args)

    #  Build the trainer.
    trainer = GROVERMotifTrainer(args=args,
                            embedding_model=grover_model,
                            atom_vocab_size=atom_vocab_size,
                            bond_vocab_size=bond_vocab_size,
                            fg_szie=fg_size,
                            train_dataloader=train_data_dl,
                            test_dataloader=test_data_dl,
                            optimizer_builder=build_optimizer,
                            scheduler_builder=build_lr_scheduler,
                            logger=logger,
                            with_cuda=with_cuda,
                            enable_multi_gpu=args.enable_multi_gpu)

    # Restore the interrupted training.
    model_dir = os.path.join(args.save_dir, "model")
    resume_from_epoch = 0
    resume_scheduler_step = 0
    if master_worker:
        resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
    if args.enable_multi_gpu:
        resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
        resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                              root_rank=0, name="resume_scheduler_step").item()
        trainer.scheduler.current_step = resume_scheduler_step
        print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
    trainer.broadcast_parameters()

    # Print model details.
    if master_worker:
        # Change order here.
        print(grover_model)
        print("Total parameters: %d" % param_count(trainer.grover))

    # Perform training.
    for epoch in range(resume_from_epoch + 1, args.epochs):
        s_time = time.time()

        # Data pre-loading.
        if args.enable_multi_gpu:
            train_sampler.set_epoch(epoch)
            train_data.clean_cache()
            idxs = train_sampler.get_indices()
            for local_gpu_idx in idxs:
                train_data.load_data(local_gpu_idx)
        d_time = time.time() - s_time

        # perform training and validation.
        s_time = time.time()
        _, train_loss, _ = trainer.train(epoch)
        t_time = time.time() - s_time
        s_time = time.time()
        _, val_loss, detailed_loss_val = trainer.test(epoch)
        val_av_loss, val_bv_loss, val_fg_loss, _, _, val_topo_loss, val_node_loss = detailed_loss_val
        v_time = time.time() - s_time

        # print information.
        if master_worker:
            print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.6f}'.format(train_loss),
                  'loss_val: {:.6f}'.format(val_loss),
                  'loss_val_av: {:.6f}'.format(val_av_loss),
                  'loss_val_bv: {:.6f}'.format(val_bv_loss),
                  'loss_val_fg: {:.6f}'.format(val_fg_loss),
                  'loss_val_topo: {:.6f}'.format(val_topo_loss),
                  'loss_val_node: {:.6f}'.format(val_node_loss),
                  'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
                  't_time: {:.4f}s'.format(t_time),
                  'v_time: {:.4f}s'.format(v_time),
                  'd_time: {:.4f}s'.format(d_time), flush=True)

            if epoch % args.save_interval == 0:
                trainer.save(epoch, model_dir)


            trainer.save_tmp(epoch, model_dir, rank)

    # Only save final version.
    if master_worker:
        trainer.save(args.epochs, model_dir, "")

In [77]:
data

<__main__.BatchMolDataset_motif at 0x7fe258304f10>

In [59]:
logger = create_logger(name='pretrain', save_dir=args.save_dir)
if logger is not None:
    debug, _ = logger.debug, logger.info
else:
    debug = print

# initialize the horovod library
if args.enable_multi_gpu:
    mgw.init()

# binding training to GPUs.
master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
# pin GPU to local rank. By default, we use gpu:0 for training.
local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
with_cuda = args.cuda
if with_cuda:
    torch.cuda.set_device(local_gpu_idx)

# get rank an  number of workers
rank = mgw.rank() if args.enable_multi_gpu else 0
num_replicas = mgw.size() if args.enable_multi_gpu else 1
# print("Rank: %d Rep: %d" % (rank, num_replicas))

# load file paths of the data.
if master_worker:
    print(args)
    if args.enable_multi_gpu:
        debug("Total workers: %d" % (mgw.size()))
    debug('Loading data')
#data, sample_per_file = get_data(data_path=args.data_path)
data, sample_per_file = get_motif_data(data_path=args.data_path)

# data splitting
if master_worker:
    debug(f'Splitting data with seed 0.')
train_data, test_data, _ = split_data(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

# Here the true train data size is the train_data divided by #GPUs
if args.enable_multi_gpu:
    args.train_data_size = len(train_data) // mgw.size()
else:
    args.train_data_size = len(train_data)
if master_worker:
    debug(f'Total size = {len(data):,} | '
          f'train size = {len(train_data):,} | val size = {len(test_data):,}')

# load atom and bond vocabulary and the semantic motif labels.
atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)

# Load motif vocabulary for pretrain
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if args.parser_name == 'pretrain':
    motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
    motif_vocab = Motif_Vocab(motif_vocab)
    #see below motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order).to(device)


# Hard coding here, since we haven't load any data yet!
fg_size = 85
shared_dict = {}
motif_collator = GroverMotifCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)

if master_worker:
    debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                bond_vocab_size, fg_size))

# Define the distributed sampler. If using the single card, the sampler will be None.
train_sampler = None
test_sampler = None
shuffle = True
if args.enable_multi_gpu:
    # If not shuffle, the performance may decayed.
    train_sampler = DistributedSampler(
        train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
    # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
    # rank. (TODO: bad design here.)
    test_sampler = DistributedSampler(
        test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
    train_sampler.set_epoch(args.epochs)
    test_sampler.set_epoch(1)
    # if we enables multi_gpu training. shuffle should be disabled.
    shuffle = False

# Pre load data. (Maybe unnecessary. )
#pre_load_data(train_data, rank, num_replicas, sample_per_file)
#pre_load_data(test_data, rank, num_replicas)
#if master_worker:
    # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
#    print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())

Loading data
Loading data
Splitting data with seed 0.
Splitting data with seed 0.
Total size = 500,000 | train size = 450,000 | val size = 50,000
Total size = 500,000 | train size = 450,000 | val size = 50,000


Namespace(activation='PReLU', atom_vocab_path='data/zinc10M/zinc10M_atom_vocab.pkl', backbone='gtrans', batch_size=100, bias=False, bond_drop_rate=0, bond_vocab_path='data/zinc10M/zinc10M_bond_vocab.pkl', cuda=True, data_path='data/zinc10M_0', dense=False, depth=3, dist_coff=0.1, dropout=0.1, embedding_output_type='both', enable_multi_gpu=False, epochs=20, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=1200, init_lr=0.0002, max_lr=0.0004, motif_hidden_size=1200, motif_latent_size=56, motif_order='dfs', motif_vocab_path='data/zinc10M/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/ChEMBL', save_interval=5, topology=True, undirected=False, wandb=False, wandb_name='pretrain', warmup_epochs=2.0, weight_decay=1e-07)
Loading data:
Number of files: 501
Number of samples: 500000
Samples/file: 1000


atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85
atom vocab size: 521, bond vocab size: 942, Number of FG tasks: 85


In [75]:
# Build dataloader
train_data_dl = DataLoader(train_data,
                           batch_size=args.batch_size,
                           shuffle=shuffle,
                           num_workers=0,
                           sampler=train_sampler,
                           collate_fn=motif_collator)
test_data_dl = DataLoader(test_data,
                          batch_size=args.batch_size,
                          shuffle=shuffle,
                          num_workers=0,
                          sampler=test_sampler,
                          collate_fn=motif_collator)
args.train_data_size=len(train_data)
# Build the embedding model.
grover_model = GROVEREmbedding(args)

# build the topology predict model.
motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order).to(device)

#  Build the trainer.
trainer = GROVERMotifTrainer(args=args,
                        embedding_model=grover_model,
                        topology_model = motif_model,
                        atom_vocab_size=atom_vocab_size,
                        bond_vocab_size=bond_vocab_size,
                        fg_size=fg_size,
                        train_dataloader=train_data_dl,
                        test_dataloader=test_data_dl,
                        optimizer_builder=build_optimizer,
                        scheduler_builder=build_lr_scheduler,
                        logger=logger,
                        with_cuda=with_cuda,
                        enable_multi_gpu=args.enable_multi_gpu)

# Restore the interrupted training.
model_dir = os.path.join(args.save_dir, "model")
resume_from_epoch = 0
resume_scheduler_step = 0
if master_worker:
    resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
if args.enable_multi_gpu:
    resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
    resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                          root_rank=0, name="resume_scheduler_step").item()
    trainer.scheduler.current_step = resume_scheduler_step
    print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
trainer.broadcast_parameters()

# Print model details.
if master_worker:
    # Change order here.
    print(grover_model)
    print("Total parameters: %d" % param_count(trainer.grover))

No checkpoint found %d
GROVEREmbedding(
  (encoders): GTransEncoder(
    (edge_blocks): ModuleList(
      (0): MTBlock(
        (heads): ModuleList(
          (0): Head(
            (mpn_q): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False)
              (act_func): PReLU(num_parameters=1)
              (W_h): Linear(in_features=1200, out_features=1200, bias=False)
            )
            (mpn_k): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False)
              (act_func): PReLU(num_parameters=1)
              (W_h): Linear(in_features=1200, out_features=1200, bias=False)
            )
            (mpn_v): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False)
              (act_func): PReLU(num_parameters=1)
              (W_h): Linear(in_features=1200, out_features=1200, bias=False)
            )
          )
          (1): Head(
            (mpn_q): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False

In [76]:
stime = time.time()
for i, item in enumerate(train_data_dl):
    batch_graph_test = item["graph_input"]
    targets_test = item["targets"]
    moltree_test = item["moltree"]

    if next(embed_model.parameters()).is_cuda:
        targets_test["av_task"] = targets_test["av_task"].cuda()
        targets_test["bv_task"] = targets_test["bv_task"].cuda()
        targets_test["fg_task"] = targets_test["fg_task"].cuda()
    preds_test = embed_model_test(batch_graph)
    emb_test = preds_test['emb_vec']
    #_, motif_loss, _ = motif_model(emb)
    print(time.time()-stime)
    if i == 0 : break

AssertionError: 

In [69]:
# Perform training.
best_val_loss = 0
best_val_epoch = 0
best_model_dir = os.path.join(args.save_dir, "model_best")
for epoch in range(args.epochs):
    s_time = time.time()

    # Data pre-loading.
    if args.enable_multi_gpu:
        train_sampler.set_epoch(epoch)
        train_data.clean_cache()
        idxs = train_sampler.get_indices()
        for local_gpu_idx in idxs:
            train_data.load_data(local_gpu_idx)
    d_time = time.time() - s_time

    # perform training and validation.
    s_time = time.time()
    _, train_loss, _ = trainer.train(epoch)
    t_time = time.time() - s_time
    s_time = time.time()
    _, val_loss, detailed_loss_val = trainer.test(epoch)
    val_av_loss, val_bv_loss, val_fg_loss, _, _, _, val_topo_loss, val_node_loss = detailed_loss_val
    v_time = time.time() - s_time
    
    if best_val_loss > val_loss:
        best_val_loss = val_loss
        best_val_epoch = epoch
        trainer.save(epoch, best_model_dir)
    
    wandb.log({"train_loss" : train_loss, "val_loss" : val_loss, "topo_loss" : val_topo_loss, "epochs" : epoch})
    
    # print information.
    if master_worker:
        print('Epoch: {:04d}\n'.format(epoch),
              'loss_train: {:.6f}'.format(train_loss),
              'loss_val: {:.6f}'.format(val_loss),
              'loss_val_av: {:.6f}'.format(val_av_loss),
              'loss_val_bv: {:.6f}'.format(val_bv_loss),
              'loss_val_fg: {:.6f}'.format(val_fg_loss),
              'loss_val_topo: {:.6f}'.format(val_topo_loss),
              'loss_val_node: {:.6f}'.format(val_node_loss),
              'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
              't_time: {:.4f}s'.format(t_time),
              'v_time: {:.4f}s'.format(v_time),
              'd_time: {:.4f}s'.format(d_time), flush=True)

        if epoch % args.save_interval == 0:
            trainer.save(epoch, model_dir)


        trainer.save_tmp(epoch, model_dir, rank)

# Only save final version.
if master_worker:
    trainer.save(args.epochs, model_dir, "")

iter start


AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/grover/grover/data/groverdataset.py", line 158, in __getitem__
    return self.data[dp_idx][real_idx]
  File "/tmp/ipykernel_1191/3388227146.py", line 49, in __getitem__
    assert self.datapoints is not None
AssertionError


In [65]:
writer.close()

In [269]:
detailed_loss_val

(5.672407817840576,
 5.451661777496338,
 1.284084439277649,
 0.4274427592754364,
 0.47347733080387117,
 0.00922257686033845,
 10.510714721679687)

In [126]:
!rm -r model/tryout

## 2-2. pretrain_model() 코드

In [None]:
def pretrain_model(args: Namespace, logger: Logger = None):
    """
    The entrey of pretrain.
    :param args: the argument.
    :param logger: the logger.
    :return:
    """

    # avoid auto optimized import by pycharm.
    a = MolVocab
    s_time = time.time()
    run_training(args=args, logger=logger)
    e_time = time.time()
    print("Total Time: %.3f" % (e_time - s_time))

# main.py 실행코드

In [94]:
pretrain_model(args, logger)

Loading data
Loading data
Splitting data with seed 0.
Splitting data with seed 0.
Total size = 5,970 | train size = 5,400 | val size = 570
Total size = 5,970 | train size = 5,400 | val size = 570


Namespace(activation='PReLU', atom_vocab_path='data/pretrain/tryout_atom_vocab.pkl', backbone='gtrans', batch_size=32, bias=False, bond_drop_rate=0, bond_vocab_path='data/pretrain/tryout_bond_vocab.pkl', cuda=True, data_path='data/pretrain/tryout', dense=False, depth=5, dist_coff=0.1, dropout=0.1, embedding_output_type='both', enable_multi_gpu=False, epochs=3, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=100, init_lr=0.0001, max_lr=0.001, motif_hidden_size=300, motif_latent_size=56, motif_order='bfs', motif_vocab_path='data/pretrain/clique.txt', no_cache=True, num_attn_head=1, num_mt_block=1, parser_name='pretrain', save_dir='model/tryout', save_interval=9999999999, train_data_size=5400, undirected=False, warmup_epochs=2.0, weight_decay=0.0)
Loading data:
Number of files: 60
Number of samples: 5970
Samples/file: 100


atom vocab size: 324, bond vocab size: 353, Number of FG tasks: 85
atom vocab size: 324, bond vocab size: 353, Number of FG tasks: 85


Pre-loaded test data: 6
Restore checkpoint, current epoch: 1
GROVEREmbedding(
  (encoders): GTransEncoder(
    (edge_blocks): ModuleList(
      (0): MTBlock(
        (heads): ModuleList(
          (0): Head(
            (mpn_q): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False)
              (act_func): PReLU(num_parameters=1)
              (W_h): Linear(in_features=100, out_features=100, bias=False)
            )
            (mpn_k): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False)
              (act_func): PReLU(num_parameters=1)
              (W_h): Linear(in_features=100, out_features=100, bias=False)
            )
            (mpn_v): MPNEncoder(
              (dropout_layer): Dropout(p=0.1, inplace=False)
              (act_func): PReLU(num_parameters=1)
              (W_h): Linear(in_features=100, out_features=100, bias=False)
            )
          )
        )
        (act_func): PReLU(num_parameters=1)
        (dropout_layer): Dr

  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).float()
  fgroup_label = torch.Tensor([d.features for d in batch]).flo

Epoch: 0002 loss_train: 3.373643 loss_val: 2.182687 loss_val_av: 0.507859 loss_val_bv: 0.846167 loss_val_fg: 0.828661 cur_lr: 0.00017 t_time: 10.2503s v_time: 0.8160s d_time: 0.0000s
EP:3 Model Saved on: model/tryout/model.ep3
Total Time: 87.221


In [79]:
args

Namespace(activation='PReLU', atom_vocab_path='data/pretrain/CO2_atom_vocab.pkl', backbone='gtrans', batch_size=1, bias=False, bond_drop_rate=0, bond_vocab_path='data/pretrain/CO2_bond_vocab.pkl', cuda=True, data_path='data/pretrain/CO2', dense=False, depth=5, dist_coff=0.1, dropout=0.0, embedding_output_type='both', enable_multi_gpu=False, epochs=3, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=3, init_lr=0.0001, max_lr=0.001, motif_hidden_size=300, motif_latent_size=56, motif_order='bfs', motif_vocab_path='data/pretrain/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/tryout', save_interval=9999999999, train_data_size=10, undirected=False, warmup_epochs=2.0, weight_decay=0.0)

In [None]:
npzfile = np.load('data/pretrain/tryout/feature/0.npz')
npzfile

In [7]:
mol = Chem.MolFromSmiles('C1CCC23C4=C5C6=C2C2=C7C8=C9C%10=C(C%11=C%12C%13=C%10C%10=C%14C%15=C%16C%17=C%18C%19=C%15C(=C%149)C7=C7C2=C2C6=C6C9=C%14C%15=C%20C(=C%12C(=C%114)C%15=C59)C4=C%13C%10C%16C5=C%17C9=C(C%20=C45)C%14=C4C9=C%18C(=C7%19)C2=C64)C83C1')
mol2 = Chem.MolFromSmiles('CC(C)(C)C1=CC(O)=CC=C1O')

In [8]:
smiles='CC(C)(C)C1=CC(O)=CC=C1O'

In [11]:
mol_tree = MolTree(smiles)
mol_tree.recover()
mol_tree.assemble()

AttributeError: 'MolTree' object has no attribute 'savez_compressed'

In [13]:
import pickle
mol_tree = MolTree(smiles)
mol_tree.recover()
mol_tree.assemble()
with open('moltree.p', 'wb') as file: 
    pickle.dump(mol_tree, file)

In [18]:
mol_tree

<grover.topology.mol_tree.MolTree at 0x7f04defd7550>

In [15]:
with open('moltree.p', 'wb') as file: 
    pickle.dump(mol_tree, file)

In [16]:
with open('moltree.p', 'rb') as file:
    mol_tree2 = pickle.load(file)

In [17]:
mol_tree2

<grover.topology.mol_tree.MolTree at 0x7f04df042590>

In [None]:
import argparse
import time
import sys
import csv
import pandas as pd
from rdkit import Chem

sys.path.append('./')
from grover.topology.chemutils import *
from grover.topology.mol_tree import *

parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--datapath', type=str, default='./data/zinc/all.txt',
                        help='root directory of dataset. For now, only classification.')
parser.add_argument('--output_clique', type=str, default='./clique.txt',
                        help='filename to output the pre-trained model')
parser.add_argument('--output_data', type=str, default='./data/zinc/all_edit.txt',
                        help='filename to output deleted data')
args = parser.parse_args(['--datapath','data/pretraindata])

lg = rdkit.RDLogger.logger() 
lg.setLevel(rdkit.RDLogger.CRITICAL)

data = pd.read_csv(args.datapath)
data_len = len(data)
print(data_len)

num=0
cset = set()
counts = {}

print("start")
s_time = time.time()

for i in range(data_len):
    if num%10000==0:print(f'process : {num} / {data_len}')
    smiles = data.smiles[num]
    try : 
        mol = Chem.MolFromSmiles(smiles)
        mol.GetNumHeavyAtoms()
        moltree = MolTree(smiles)
        for node in moltree.nodes:
            cset.add(node.smiles)
            if node.smiles not in counts:
                counts[node.smiles] = 1
            else:
                counts[node.smiles] += 1
    except : 
        print(f'error smiles is {smiles}')
        data=data.drop(num)
    num += 1


print("Preprocessing Completed!")
t_time = time.time() - s_time
print(f'total time is {t_time:.4f}s, data length is {data_len} -> {len(data)}')

clique_list = list(cset)

data.to_csv(args.output_data, index=False)

with open(args.output_clique, 'w') as file:
    for c in clique_list:
        file.write(c)
        file.write('\n')

In [1]:
import numpy as np

In [6]:
data = load_features('data/zinc10M/feature/1.npz')

In [7]:
len(data)

999

In [17]:
import pandas as pd
csv = pd.read_csv('data/mgssl_test/graph/0.csv')

In [18]:
csv

Unnamed: 0,smiles
0,CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1
1,C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1
2,N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...
3,CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...
4,N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...
...,...
95,Cc1cc(Cl)ccc1OCC(=O)N/N=C/c1ccccn1
96,O=C1NC(=S)NC(=O)C1=CNc1ccc([N+](=O)[O-])cc1O
97,Cc1c(C(=O)N2CCOCC2)oc2c1-c1nn(CC(=O)NCc3ccco3)...
98,CCc1ccc(CNC(=O)c2ccc(-c3nccnc3N3CCCCC3)cc2)cc1


In [19]:
len(csv)

100

# test subset

In [61]:
class process_tracker():
    def __init__(self, args):
        self.args = args
        self.num_subset = 0
        self.now_subset = 0
        self.now_iter = 0
        self.origin_data_path = args.data_path
        
        
    def save_process(self):
        path = os.path.join(self.args.save_dir, "process.txt")
        txt = open(path, 'w')
        txt.write("num_subset:%d\n" % (self.num_subset))
        txt.write("now_subset:%d\n" % (self.now_subset))
        txt.write("now_iter:%d\n" % (self.now_iter))
        txt.close()
        print('process saved')
        
    def load_process(self):
        '''
        if you don't have saved data, you must make txt file like below
        
        num_subset:0
        now_subset:0
        now_iter:0
        '''
        path = os.path.join(self.args.save_dir, "process.txt")
        f = open(path, 'r')
        lines = f.readlines()
        self.num_subset = np.int(lines[0].split(':')[1].split('\n')[0])
        self.now_subset = np.int(lines[1].split(':')[1].split('\n')[0])
        self.now_iter = np.int(lines[2].split(':')[1].split('\n')[0])
        f.close()


In [70]:
pt = process_tracker(args)
pt.load_process()


In [73]:
pt.now_subset

1

In [69]:
num_subset

20

In [76]:
(args.epochs % args.each_epochs) == 0

True

In [78]:
def subset_learning(args: Namespace, logger: Logger = None, process = None):
    a = MolVocab
    assert (args.epochs % args.each_epochs) == 0, 'you must make args.epochs % args.each_epochs = 0'
    
    left_iters = args.epochs/args.each_epochs - process.now_iter
    for iters in range(left_iters):
        left_subsets = process.num_subset - process.now_subset
        
        # run_motif_training until all subset
        for num_subset in range(left_subsets):
            args.data_path = process.origin_data_path+f'_{process.now_subset}'
            run_motif_training(args=args, logger=logger)
            process.now_subset += 1
            process.save_process()
            
        process.now_subset = 0
        process.now_iter += 1
        process.save_process()
        
    print('all process is end')


In [86]:
a = MolVocab
assert (args.epochs % args.each_epochs) == 0, 'you must make args.epochs % args.each_epochs = 0'

left_iters = np.int(args.epochs/args.each_epochs - pt.now_iter)
for iters in range(left_iters):
    left_subsets = pt.num_subset - pt.now_subset

    # run_motif_training until all subset
    for num_subset in range(left_subsets):
        args.data_path = pt.origin_data_path+f'_{pt.now_subset}'
        #run_motif_training(args=args, logger=logger)
        pt.now_subset += 1
        pt.save_process()

    pt.now_subset = 0
    pt.now_iter += 1
    pt.save_process()

process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
process saved
proces

In [79]:
logger = create_logger(name='pretrain', save_dir=args.save_dir)
pt = process_tracker(args)
pt.load_process()
subset_learning(args, logger, pt)


TypeError: 'float' object cannot be interpreted as an integer