# main.py

In [1]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
import numpy as np
from optparse import OptionParser
from gnn_model import GNN, GNN_grover

sys.path.append('./util/')

from mol_tree import *
from nnutils import *
from datautils import *
from motif_generation import *

import rdkit

# add for grover
import os, time
import wandb
from grover.topology.mol_tree import *
from sklearn.model_selection import train_test_split

In [2]:
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
# 치명적 오류가 발생되면 로그기록해라

def group_node_rep(node_rep, batch_index, batch_size):
    group = []
    count = 0
    for i in range(batch_size):
        num = sum(batch_index == i)
        group.append(node_rep[count:count + num])		# count += num번째 node의 표현을 그룹에 더해라
        count += num
    return group						# 최종 그룹을 출력

## args.parser

In [3]:
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--decay', type=float, default=0,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: 300)')
parser.add_argument('--dropout_ratio', type=float, default=0.2,
                    help='dropout ratio (default: 0.2)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--dataset', type=str, default='./data/merge_0',
                    help='root directory of dataset. For now, only classification.')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--input_model_file', type=str, default="", help='filename to read the model (if there is any)')
parser.add_argument('--output_path', type=str, default='./saved_model/grover',
                    help='filename to output the pre-trained model')
parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataset loading')   #원래는 8이었음 오류로 0으로 바꿈
parser.add_argument("--hidden_size", type=int, default=300, help='hidden size')
parser.add_argument("--vocab", type=str, default='./data/merge/clique.txt', help='vocab path')
parser.add_argument('--order', type=str, default="bfs",
                    help='motif tree generation order (bfs or dfs)')
#for wandb
parser.add_argument('--wandb', action='store_true', default=False, help='add wandb log')
parser.add_argument('--wandb_name', type=str, default = 'MGSSL_Grover', help='wandb name')
args = parser.parse_args([])

In [4]:
args

Namespace(JK='last', batch_size=32, dataset='./data/merge_0', decay=0, device=0, dropout_ratio=0.2, emb_dim=300, epochs=100, gnn_type='gin', graph_pooling='mean', hidden_size=300, input_model_file='', lr=0.001, num_layer=5, num_workers=0, order='bfs', output_path='./saved_model/grover', vocab='./data/merge/clique.txt', wandb=False, wandb_name='MGSSL_Grover')

# change features list
- 나와야될 형태는 atom_feature, edge_index, edge_feature

In [5]:
import torch
from torch.utils.data import Dataset
from mol_tree import MolTree
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from torch_geometric.data import Batch
from torch_geometric.data import Data

In [6]:
from argparse import Namespace
from typing import List, Tuple, Union
import pickle

ATOM_FEATURES = {
    'atomic_num': list(range(120)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_Hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ],
}
atom_fdim = 151
bond_fdim = 165

##  함수 및 사전설정

In [7]:
def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """
    Creates a one-hot encoding.

    :param value: The value for which the encoding should be one.
    :param choices: A list of possible values.
    :return: A one-hot encoding of the value in a list of length len(choices) + 1.
    If value is not in the list of choices, then the final element in the encoding is 1.
    """
    encoding = [0] * (len(choices) + 1)
    if min(choices) < 0:
        index = value
    else:
        index = choices.index(value) if value in choices else -1
    encoding[index] = 1

    return encoding

In [8]:
def atom_features(atom: Chem.rdchem.Atom, hydrogen_acceptor_match, hydrogen_donor_match, acidic_match, basic_match, ring_info) -> List[Union[bool, int, float]]:
    """
    Builds a feature vector for an atom.

    :param atom: An RDKit atom.
    :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to.
    :return: A list containing the atom features.
    """
    features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
               onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
               onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
               onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
               onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
               onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
               [1 if atom.GetIsAromatic() else 0] + \
               [atom.GetMass() * 0.01]
    atom_idx = atom.GetIdx()
    features = features + \
               onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
               [atom_idx in hydrogen_acceptor_match] + \
               [atom_idx in hydrogen_donor_match] + \
               [atom_idx in acidic_match] + \
               [atom_idx in basic_match] + \
               [ring_info.IsAtomInRingOfSize(atom_idx, 3),
                ring_info.IsAtomInRingOfSize(atom_idx, 4),
                ring_info.IsAtomInRingOfSize(atom_idx, 5),
                ring_info.IsAtomInRingOfSize(atom_idx, 6),
                ring_info.IsAtomInRingOfSize(atom_idx, 7),
                ring_info.IsAtomInRingOfSize(atom_idx, 8)]
    return features

def bond_features(bond: Chem.rdchem.Bond
                  ) -> List[Union[bool, int, float]]:
    """
    Builds a feature vector for a bond.

    :param bond: A RDKit bond.
    :return: A list containing the bond features.
    """

    if bond is None:
        fbond = [1] + [0] * (BOND_FDIM - 1)
    else:
        bt = bond.GetBondType()
        fbond = [
            0,  # bond is not None
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
            (bond.GetIsConjugated() if bt is not None else 0),
            (bond.IsInRing() if bt is not None else 0)
        ]
        fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
    return fbond

In [9]:
ATOM_FEATURES = {
    'atomic_num': list(range(MAX_ATOMIC_NUM)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_Hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ],
}

NameError: name 'MAX_ATOMIC_NUM' is not defined

## mol_to_graph_data_obj_grover

In [None]:
def mol_to_graph_data_obj_grover(mol):
    mol = Chem.MolFromSmiles(mol)
    hydrogen_donor = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
    hydrogen_acceptor = Chem.MolFromSmarts(
        "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
        "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
    acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
    basic = Chem.MolFromSmarts(
        "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
        "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")

    hydrogen_donor_match = sum(mol.GetSubstructMatches(hydrogen_donor), ())
    hydrogen_acceptor_match = sum(mol.GetSubstructMatches(hydrogen_acceptor), ())
    acidic_match = sum(mol.GetSubstructMatches(acidic), ())
    basic_match = sum(mol.GetSubstructMatches(basic), ())
    ring_info = mol.GetRingInfo()

    n_atoms = mol.GetNumAtoms()
    
    f_atoms = []
    for _, atom in enumerate(mol.GetAtoms()):
        f_atoms.append(atom_features(atom, hydrogen_donor_match, hydrogen_acceptor_match, acidic_match, basic_match, ring_info))
    f_atoms = [f_atoms[i] for i in range(n_atoms)]
    
    f_bonds = []
    bond_list = []
    for a1 in range(n_atoms):
        for a2 in range(a1 + 1, n_atoms):
            bond = mol.GetBondBetweenAtoms(a1, a2)

            if bond is None:
                continue

            f_bond = bond_features(bond)

            # Always treat the bond as directed.
            f_bonds.append(f_atoms[a1] + f_bond)
            bond_list.append([a1, a2])
            f_bonds.append(f_atoms[a2] + f_bond)
            bond_list.append([a2, a1])
    
    data = [f_atoms, bond_list, f_bonds]
    return data

In [10]:
data = mol_to_graph_data_obj_grover('C=CCOc1cccnc1C(=O)NC[C@H]1CC[C@@H](C(=O)N(C)C)O1')

### 세부실행 예시

In [11]:
mol = Chem.MolFromSmiles('C=CCOc1cccnc1C(=O)NC[C@H]1CC[C@@H](C(=O)N(C)C)O1')
# 이건 mol단계에 인식되게 넣어야한다.
hydrogen_donor = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
hydrogen_acceptor = Chem.MolFromSmarts(
    "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
    "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
basic = Chem.MolFromSmarts(
    "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
    "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")

hydrogen_donor_match = sum(mol.GetSubstructMatches(hydrogen_donor), ())
hydrogen_acceptor_match = sum(mol.GetSubstructMatches(hydrogen_acceptor), ())
acidic_match = sum(mol.GetSubstructMatches(acidic), ())
basic_match = sum(mol.GetSubstructMatches(basic), ())
ring_info = mol.GetRingInfo()

n_atoms = mol.GetNumAtoms()

### atom feature 생성

In [12]:
f_atoms = []
for _, atom in enumerate(mol.GetAtoms()):
    f_atoms.append(atom_features(atom))
f_atoms = [f_atoms[i] for i in range(n_atoms)]
# f_atoms는 atom개수 x 151

TypeError: atom_features() missing 5 required positional arguments: 'hydrogen_acceptor_match', 'hydrogen_donor_match', 'acidic_match', 'basic_match', and 'ring_info'

### bond feature 생성으로 atom과 이어서 있어야함

In [37]:
f_bonds = []
bond_list = []
for a1 in range(n_atoms):
    for a2 in range(a1 + 1, n_atoms):
        bond = mol.GetBondBetweenAtoms(a1, a2)

        if bond is None:
            continue

        f_bond = bond_features(bond)

        # Always treat the bond as directed.
        f_bonds.append(f_atoms[a1] + f_bond)
        bond_list.append([a1, a2])
        f_bonds.append(f_atoms[a2] + f_bond)
        bond_list.append([a2, a1])

### 아래는 참조용 MGSSL코드

In [56]:
edges_list = []
edge_features_list = []
for bond in mol.GetBonds():
    i = bond.GetBeginAtomIdx()
    j = bond.GetEndAtomIdx()
    bt = bond.GetBondType()
    edge_feature = [allowable_features['possible_bonds'].index(bt)] + \
                    [int(bond.GetIsConjugated() if bt is not None else 0)] + \
                    [int(bond.IsInRing() if bt is not None else 0)] + \
                    [GROVER_FEATURES['stereo'].index(bond.GetStereo())]
    edges_list.append((i, j))
    edge_features_list.append(edge_feature)
    edges_list.append((j, i))
    edge_features_list.append(edge_feature)

In [143]:
isatominringofsize(atom_idx)

3

In [145]:
[ring_info.IsAtomInRingOfSize(atom_idx, 3),
                ring_info.IsAtomInRingOfSize(atom_idx, 4),
                ring_info.IsAtomInRingOfSize(atom_idx, 5),
                ring_info.IsAtomInRingOfSize(atom_idx, 6),
                ring_info.IsAtomInRingOfSize(atom_idx, 7),
                ring_info.IsAtomInRingOfSize(atom_idx, 8)]

[False, False, True, False, False, False]

In [152]:
atom.GetAtomicNum()

8

In [26]:
num_atom_type = 100 #including the extra mask tokens
num_chirality_tag = 3
emb_dim = 300
x=torch.tensor(f_atoms)
x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
x_embedding2 = torch.nn.Embedding(2, emb_dim)
x_embedding3 = torch.nn.Embedding(2, emb_dim)

# Dataset 변경

In [35]:
import torch
from torch.utils.data import Dataset
from mol_tree import MolTree
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from torch_geometric.data import Batch
from torch_geometric.data import Data

import math
import os
import csv
from typing import Union, List
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from rdkit import Chem

from grover.topology.mol_tree import MolTree, MolTree_break

In [36]:
import os, pickle
from grover.topology.mol_tree import *
from sklearn.model_selection import train_test_split

In [8]:
#원본

In [18]:
class MoleculeDataset_grover(Dataset):

    def __init__(self, data):
        with open(data, 'rb') as f:
            self.data = pickle.load(f)
        self.n_samples = len(data)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        mol_tree = self.data[idx]
        return mol_tree

In [10]:
dataset = MoleculeDataset_grover('data/merge/total.p')

In [11]:
dataset[3]

<grover.topology.mol_tree.MolTree_break at 0x7fb87216b350>

In [12]:
train_dataset, val_dataset = train_test_split(dataset, test_size=0.1, random_state=42)

In [13]:
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=args.num_workers, collate_fn=lambda x:x, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=True, num_workers=args.num_workers, collate_fn=lambda x:x, drop_last=True)

In [14]:
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
model = GNN(5, args.emb_dim, JK='last', drop_ratio=args.dropout_ratio, gnn_type='gin').to(device)
if os.path.exists(args.input_model_file):
    model.load_state_dict(torch.load(args.input_model_file))

vocab = [x.strip("\r\n ") for x in open(args.vocab)]
vocab = Vocab(vocab)
motif_model = Motif_Generation_Grover(vocab, args.hidden_size, device, args.order).to(device)

model_list = [model, motif_model]
optimizer_model = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
optimizer_motif = optim.Adam(motif_model.parameters(), lr=1e-3, weight_decay=args.decay)

optimizer_list = [optimizer_model, optimizer_motif]



In [15]:
for step, batch in enumerate(train_loader):	# 데이터로더에서 순회 진행바 표시형태로 순회해서 step과 batch대로 반복하자
    print(batch)
    batch_size = len(batch)

    graph_batch = moltree_to_graph_data(batch)		# 분자식을 파이토치 지오메트릭 패키지에서 요구되는 그래프 데이터 형태로 변경해서 배치단위로 저장   /datautils에 있음
    #store graph object data in the process stage	
    batch_index = graph_batch.batch.numpy()			# 배치내의 배치텐서를 넘파이로 인덱스에 넘겨라
    graph_batch = graph_batch.to(device)			# 그래프배치는 GPU로 되게

[<grover.topology.mol_tree.MolTree_break object at 0x7fb87217cd90>, <grover.topology.mol_tree.MolTree object at 0x7fb8720d3fd0>]
[<grover.topology.mol_tree.MolTree object at 0x7fb8720e46d0>, <grover.topology.mol_tree.MolTree_break object at 0x7fb87216b350>]
[<grover.topology.mol_tree.MolTree object at 0x7fb8720ea890>, <grover.topology.mol_tree.MolTree object at 0x7fb87214a690>]
[<grover.topology.mol_tree.MolTree object at 0x7fb872150790>, <grover.topology.mol_tree.MolTree_break object at 0x7fb8721d3e10>]
[<grover.topology.mol_tree.MolTree_break object at 0x7fb87218f310>, <grover.topology.mol_tree.MolTree_break object at 0x7fb87212f3d0>]
[<grover.topology.mol_tree.MolTree object at 0x7fb8721464d0>, <grover.topology.mol_tree.MolTree object at 0x7fb8720dd310>]
[<grover.topology.mol_tree.MolTree object at 0x7fb872142310>, <grover.topology.mol_tree.MolTree_break object at 0x7fb87215fd10>]
[<grover.topology.mol_tree.MolTree_break object at 0x7fb872138f10>, <grover.topology.mol_tree.MolTree_b

In [16]:
graph_batch

DataDataBatch(x=[43, 2], edge_index=[2, 92], edge_attr=[92, 2], batch=[43], ptr=[3])

In [17]:
#grover꺼 추가

In [7]:
def split_data_grover(data,
               split_type='random',
               sizes=(0.8, 0.1, 0.1),
               seed=0,
               logger=None):
    """
    Split data with given train/validation/test ratio.
    :param data:
    :param split_type:
    :param sizes:
    :param seed:
    :param logger:
    :return:
    """
    assert len(sizes) == 3 and sum(sizes) == 1

    if split_type == "random":
        data.shuffle(seed=seed)
        data = data.data

        train_size = int(sizes[0] * len(data))
        train_val_size = int((sizes[0] + sizes[1]) * len(data))

        train = data[:train_size]
        val = data[train_size:train_val_size]
        test = data[train_val_size:]

        return BatchMolDataset_motif(train), BatchMolDataset_motif(val), BatchMolDataset_motif(test)
    else:
        raise NotImplementedError("Do not support %s splits" % split_type)

In [8]:
import math
import time
import torch
from torch.utils.data.sampler import Sampler
import torch.distributed as dist

In [9]:
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, sample_per_file=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.sample_per_file = sample_per_file
        self.shuffle = shuffle

    def get_indices(self):

        indices = list(range(len(self.dataset)))

        if self.sample_per_file is not None:
            indices = self.sub_indices_of_rank(indices)
        else:
            # add extra samples to make it evenly divisible
            indices += indices[:(self.total_size - len(indices))]
            assert len(indices) == self.total_size
            # subsample
            s = self.rank * self.num_samples
            e = min((self.rank + 1) * self.num_samples, len(indices))

            # indices = indices[self.rank:self.total_size:self.num_replicas]
            indices = indices[s:e]

        if self.shuffle:
            g = torch.Generator()
            # the seed need to be considered.
            g.manual_seed((self.epoch + 1) * (self.rank + 1) * time.time())
            idx = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in idx]

        # disable this since sub_indices_of_rank.
        # assert len(indices) == self.num_samples

        return indices

    def sub_indices_of_rank(self, indices):

        # fix generator for each epoch
        g = torch.Generator()
        # All data should be loaded in each epoch.
        g.manual_seed((self.epoch + 1) * 2 + 3)

        # the fake file indices to cache
        f_indices = list(range(int(math.ceil(len(indices) * 1.0 / self.sample_per_file))))
        idx = torch.randperm(len(f_indices), generator=g).tolist()
        f_indices = [f_indices[i] for i in idx]

        file_per_rank = int(math.ceil(len(f_indices) * 1.0 / self.num_replicas))
        # add extra fake file to make it evenly divisible
        f_indices += f_indices[:(file_per_rank * self.num_replicas - len(f_indices))]

        # divide index by rank
        rank_s = self.rank * file_per_rank
        rank_e = min((self.rank + 1) * file_per_rank, len(f_indices))

        # get file index for this rank
        f_indices = f_indices[rank_s:rank_e]
        # print("f_indices")
        # print(f_indices)
        res_indices = []
        for fi in f_indices:
            # get real indices for this rank
            si = fi * self.sample_per_file
            ei = min((fi + 1) * self.sample_per_file, len(indices))
            cur_idx = [indices[i] for i in range(si, ei)]
            res_indices += cur_idx

        self.num_samples = len(res_indices)
        return res_indices

    def __iter__(self):
        return iter(self.get_indices())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

In [10]:
class MoleculeDatapoint_motif:
    """A MoleculeDatapoint contains a single molecule and its associated features and targets."""

    def __init__(self,
                 line: List[str],
                 args: Namespace = None,
                 moltrees: object = None,
                 use_compound_names: bool = False):
        """
        Initializes a MoleculeDatapoint, which contains a single molecule.

        :param line: A list of strings generated by separating a line in a data CSV file by comma.
        :param args: Arguments.
        :param features: A numpy array containing additional features (ex. Morgan fingerprint).
        :param use_compound_names: Whether the data CSV includes the compound name on each line.
        """
        self.args = None
        if args is not None:
            self.args = args

        self.moltrees = moltrees

        if use_compound_names:
            self.compound_name = line[0]  # str
            line = line[1:]
        else:
            self.compound_name = None

        self.smiles = line[0]  # str

        # Create targets
        self.targets = [float(x) if x != '' else None for x in line[1:]]
        
    def set_moltrees(self, moltrees: list):
        """
        Sets the moltree of the molecule.

        :param moltree: moltree object
        """
        self.moltrees = moltrees
        
    def clean_moltree(self):
        """
        clean moltree for memory
        """
        self.moltrees = None

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        :return: The number of tasks.
        """
        return len(self.targets)

    def set_targets(self, targets: List[float]):
        """
        Sets the targets of a molecule.

        :param targets: A list of floats containing the targets.
        """
        self.targets = targets

In [11]:
class BatchDatapoint_motif:
    def __init__(self,
                 smiles_file,
                 moltree_file,
                 n_samples,
                 ):
        self.smiles_file = smiles_file
        self.moltree_file = moltree_file
        # deal with the last batch graph numbers.
        self.n_samples = n_samples
        self.datapoints = None

    def load_datapoints(self):
        moltrees = self.load_moltree()
        self.datapoints = []

        with open(self.smiles_file) as f:
            reader = csv.reader(f)
            next(reader)
            for i, line in enumerate(reader):
                # line = line[0]
                d = MoleculeDatapoint_motif(line=line,
                                      moltrees=moltrees[i])
                self.datapoints.append(d)
        f.close()

        assert len(self.datapoints) == self.n_samples
    
    def load_moltree(self):
        with open(self.moltree_file, 'rb') as f:
            moltrees = pickle.load(f)            
        return moltrees

    def shuffle(self):
        pass

    def clean_cache(self):
        del self.datapoints
        self.datapoints = None

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        assert self.datapoints is not None
        return self.datapoints[idx]

    def is_loaded(self):
        return self.datapoints is not None

In [12]:
class BatchMolDataset_motif(Dataset):
    def __init__(self, data: List[BatchDatapoint_motif],
                 graph_per_file=None):
        self.data = data

        self.len = 0
        for d in self.data:
            self.len += len(d)
        if graph_per_file is not None:
            self.sample_per_file = graph_per_file
        else:
            self.sample_per_file = len(self.data[0]) if len(self.data) != 0 else None

    def shuffle(self, seed: int = None):
        pass

    def clean_cache(self):
        for d in self.data:
            d.clean_cache()

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx) -> Union[MoleculeDatapoint_motif, List[MoleculeDatapoint_motif]]:
        # print(idx)
        dp_idx = int(idx / self.sample_per_file)
        real_idx = idx % self.sample_per_file
        return self.data[dp_idx][real_idx]

    def load_data(self, idx):
        dp_idx = int(idx / self.sample_per_file)
        if not self.data[dp_idx].is_loaded():
            self.data[dp_idx].load_datapoints()

    def count_loaded_datapoints(self):
        res = 0
        for d in self.data:
            if d.is_loaded():
                res += 1
        return res

In [13]:
def get_motif_data(data_path, logger=None):
    """
    Load data from the data_path.
    :param data_path: the data_path.
    :param logger: the logger.
    :return:
    """
    debug = logger.debug if logger is not None else print
    summary_path = os.path.join(data_path, "summary.txt")
    smiles_path = os.path.join(data_path, "graph")
    moltree_path = os.path.join(data_path, "moltrees")

    fin = open(summary_path)
    n_files = int(fin.readline().strip().split(":")[-1])
    n_samples = int(fin.readline().strip().split(":")[-1])
    sample_per_file = int(fin.readline().strip().split(":")[-1])
    debug("Loading data:")
    debug("Number of files: %d" % n_files)
    debug("Number of samples: %d" % n_samples)
    debug("Samples/file: %d" % sample_per_file)

    datapoints = []
    for i in range(n_files):
        smiles_path_i = os.path.join(smiles_path, str(i) + ".csv")
        moltree_path_i = os.path.join(moltree_path, str(i) + ".p")
        n_samples_i = sample_per_file if i != (n_files - 1) else n_samples % sample_per_file
        datapoints.append(BatchDatapoint_motif(smiles_path_i, moltree_path_i, n_samples_i))
    return BatchMolDataset_motif(datapoints), sample_per_file


In [14]:
class GroverMotifCollator(object):
    def __init__(self, shared_dict, args):
        self.args = args
        self.shared_dict = shared_dict

    def __call__(self, batch):
        smiles_batch = [d.smiles for d in batch] # 여기서 말하는 batch는 batchmoldataset_motif다 그리고 d는 batchdatapoint_motif고
        #batchgraph = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()

        #fgroup_label = torch.Tensor(np.array([d.features for d in batch])).float()
        moltree_batch = [d.moltrees for d in batch]
        
        # may be some mask here

        return moltree_batch

In [15]:
def pre_load_data(dataset: BatchMolDataset_motif, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):
    """
    Pre-load data at the beginning of each epoch.
    :param dataset: the training dataset.
    :param rank: the rank of the current worker.
    :param num_replicas: the replicas.
    :param sample_per_file: the number of the data points in each file. When sample_per_file is None, all data will be
    loaded. It implies the testing phase. (TODO: bad design here.)
    :param epoch: the epoch number.
    :return:
    """
    mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
                                      sample_per_file=sample_per_file)
    mock_sampler.set_epoch(epoch)
    pre_indices = mock_sampler.get_indices()
    for i in pre_indices:
        dataset.load_data(i)

In [16]:
grover_data, sample_per_file = get_motif_data(args.dataset)

Loading data:
Number of files: 20
Number of samples: 20000
Samples/file: 1000


In [17]:
train, test, _ = split_data_grover(grover_data, sizes=(0.5,0.5,0), seed=0)

In [18]:
shared_dict = {}
GMC = GroverMotifCollator(shared_dict=shared_dict, args=args)

In [19]:
train

<__main__.BatchMolDataset_motif at 0x7f097019af90>

In [20]:
import csv

In [21]:
pre_load_data(train, rank=0, num_replicas=1)
train_grover_loader = DataLoader(train, batch_size=2, shuffle=True, num_workers=args.num_workers, collate_fn=GMC)

In [22]:
for _, batch in enumerate(train_grover_loader):
    print(batch)
    batch_size = len(batch)

    graph_batch = moltree_to_graph_data(batch)		# 분자식을 파이토치 지오메트릭 패키지에서 요구되는 그래프 데이터 형태로 변경해서 배치단위로 저장   /datautils에 있음
    #store graph object data in the process stage	
    batch_index = graph_batch.batch.numpy()			# 배치내의 배치텐서를 넘파이로 인덱스에 넘겨라
    graph_batch = graph_batch.to(device)			# 그래프배치는 GPU로 되게

[<grover.topology.mol_tree.MolTree object at 0x7f089996bf50>, <grover.topology.mol_tree.MolTree object at 0x7f0899a951d0>]


NameError: name 'device' is not defined

In [33]:
graph_batch

DataDataBatch(x=[43, 2], edge_index=[2, 92], edge_attr=[92, 2], batch=[43], ptr=[3])

# model

In [10]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros

In [11]:
class GINConv_grover(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding = torch.nn.Linear(bond_fdim, emb_dim)

        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding(edge_attr)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)

In [12]:
class GNN_grover(torch.nn.Module):
    """
    

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    """
    def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding = torch.nn.Linear(atom_fdim, emb_dim, bias=True)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr = "add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding(x)

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat   #레이어간 노드 기능들을 어떻게 할건지, 기본은 Last다
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]

        return node_representation

In [None]:
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
model = GNN(5, 1200, JK='last', drop_ratio=0.2, gnn_type='gin').to(device)

vocab = [x.strip("\r\n ") for x in open('../../grover/data/merge/clique.txt')]
vocab = Vocab(vocab)
motif_model = Motif_Generation(vocab, 1200, 56, 3, device, 'dfs').to(device)

In [None]:
model.load_

In [None]:
for step, batch in enumerate(tqdm(loader, desc="Iteration")):	# 데이터로더에서 순회 진행바 표시형태로 순회해서 step과 batch대로 반복하자

    batch_size = len(batch)

    graph_batch = moltree_to_graph_data(batch)		# 분자식을 파이토치 지오메트릭 패키지에서 요구되는 그래프 데이터 형태로 변경해서 배치단위로 저장   /datautils에 있음
    #store graph object data in the process stage	
    batch_index = graph_batch.batch.numpy()			# 배치내의 배치텐서를 넘파이로 인덱스에 넘겨라
    graph_batch = graph_batch.cuda()			# 그래프배치는 GPU로 되게
    node_rep = model(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)	# GNN모델에 그래프(x, 엣지인덱스, 엣지의 특성) 투입
    node_rep = group_node_rep(node_rep, batch_index, batch_size)			# rep는 representation의 줄임말로 노드 표현을 의미
    loss, wacc, tacc = motif_model(batch, node_rep)		# motif모델에서 손실, motif정확도, 위상 정확도 출력
    if step == 1 : break

In [81]:
node_rep

[tensor([[-0.2578,  0.1756, -0.4891,  ...,  0.2525,  0.0000, -1.0834],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.7313, -0.0132, -0.3913],
         [-0.0000,  0.8658, -0.4621,  ..., -0.1646,  0.2450, -0.5645],
         ...,
         [-0.5576, -0.3981,  0.0000,  ..., -0.4005,  0.0000,  0.2969],
         [-0.0000, -0.0094,  0.7077,  ..., -0.0000,  1.3705,  0.0000],
         [-0.1629, -0.0094,  0.7077,  ..., -0.4874,  1.3705,  0.1655]],
        device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([[-6.1990e-01,  1.8907e-01, -0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  0.0000e+00, -3.6910e-01,  ...,  1.4765e+00,
           0.0000e+00,  4.6965e-01],
         [-6.5697e-01,  1.9685e-01, -0.0000e+00,  ...,  0.0000e+00,
           1.1914e+00,  3.4645e-01],
         ...,
         [ 8.1637e-01,  4.8261e-01, -2.5783e+00,  ...,  4.8366e-01,
           7.1742e-01,  7.3027e-01],
         [ 0.0000e+00,  4.0174e-01, -0.0000e+00,  ..., -2.3614e-01,
     

# test

In [2]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
from tqdm import tqdm
import numpy as np
from optparse import OptionParser
from gnn_model import GNN, GNN_grover

sys.path.append('./util/')

from mol_tree import *
from nnutils import *
from datautils import *
from motif_generation import *

import rdkit

# add for grover
import wandb
from grover.topology.mol_tree import *

lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
# 치명적 오류가 발생되면 로그기록해라


In [None]:
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--decay', type=float, default=0,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: 300)')
parser.add_argument('--dropout_ratio', type=float, default=0.2,
                    help='dropout ratio (default: 0.2)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--dataset', type=str, default='./data/zinc/all.txt',
                    help='root directory of dataset. For now, only classification.')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--input_model_file', type=str, default="", help='filename to read the model (if there is any)')
parser.add_argument('--output_model_file', type=str, default='./saved_model/motif_pretrain',
                    help='filename to output the pre-trained model')
parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataset loading')   #원래는 8이었음 오류로 0으로 바꿈
parser.add_argument("--hidden_size", type=int, default=300, help='hidden size')
parser.add_argument("--latent_size", type=int, default=56, help='latent size')
parser.add_argument("--vocab", type=str, default='./data/zinc/clique.txt', help='vocab path')
parser.add_argument('--order', type=str, default="bfs",
                    help='motif tree generation order (bfs or dfs)')
#for wandb
parser.add_argument('--wandb', action='store_true', default=False, help='add wandb log')
parser.add_argument('--wandb_name', type=str, default = 'MGSSL_Grover', help='wandb name')
args = parser.parse_args(['--emb_dim', '1200', '--hidden_size', '1200', '--dataset', '../../data/merge_16/total.p'])

torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

dataset = MoleculeDataset_grover(args.dataset)

loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=lambda x:x, drop_last=True)

model = GNN(5, args.emb_dim, JK='last', drop_ratio=0.2, gnn_type='gin').to(device)
if os.path.exists(args.input_model_file):
    model.load_state_dict(torch.load(args.input_model_file))

In [None]:
def validation(args, model_list, loader, optimizer_list, device):
    model, motif_model = model_list                             # 훈련간 사용 모델은 GNN모델과 motif모델이다.
    optimizer_model, optimizer_motif = optimizer_list        # 옵티마이저도 둘에 대해 각각 사용하라.

    model.eval()					#모델, 모티프 모델 훈련!
    motif_model.eval()
    word_acc, topo_acc = 0, 0			# 분자와 위상 정확도 변수 설정
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):	# 데이터로더에서 순회 진행바 표시형태로 순회해서 step과 batch대로 반복하자

        batch_size = len(batch)

        graph_batch = moltree_to_graph_data(batch)		# 분자식을 파이토치 지오메트릭 패키지에서 요구되는 그래프 데이터 형태로 변경해서 배치단위로 저장   /datautils에 있음
        #store graph object data in the process stage	
        batch_index = graph_batch.batch.numpy()			# 배치내의 배치텐서를 넘파이로 인덱스에 넘겨라
        graph_batch = graph_batch.to(device)			# 그래프배치는 GPU로 되게
        node_rep = model(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)	# GNN모델에 그래프(x, 엣지인덱스, 엣지의 특성) 투입
        node_rep = group_node_rep(node_rep, batch_index, batch_size)			# rep는 representation의 줄임말로 노드 표현을 의미
        loss, wacc, tacc = motif_model(batch, node_rep)		# motif모델에서 손실, motif정확도, 위상 정확도 출력

        word_acc += wacc
        topo_acc += tacc					#위상 정확도
        
        if args.wandb :         
            wandb.log({"validation_loss" : loss})

        if (step+1) % 20 == 0:
            word_acc = word_acc / 20 * 100
            topo_acc = topo_acc / 20 * 100
            print("Loss: %.1f, Word: %.2f, Topo: %.2f" % (loss, word_acc, topo_acc))
            word_acc, topo_acc = 0, 0

In [7]:
data = 1234.556789
epoch = 1

In [6]:
print(f'{data:.4f}')

1234.5568


In [9]:
print(f'{epoch:04d}')

0001


In [12]:
1e-100>9999

False

In [15]:
import math, random, sys
sys.path.append('./util/')
from datautils import *
dataset = MoleculeDataset_grover('../../data/merge_0/moltrees/0.p')

In [16]:
from sklearn.model_selection import train_test_split
train, val = train_test_split(dataset, test_size=0.1, random_state=42)

In [22]:
len(dataset)

1000

In [19]:
len(train)

900

In [20]:
len(val)

100

# main.py

In [1]:
#python pretrain_grovermotif.py --dataset data/merge_0 --vocab data/merge_0/clique.txt --grover_dataset --output_path saved_model/grover
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
import numpy as np
from optparse import OptionParser
from gnn_model import GNN, GNN_grover

sys.path.append('./util/')

from mol_tree import *
from nnutils import *
from datautils import *
from motif_generation import *

import rdkit

# add for grover
import os, time
import wandb
from grover.topology.mol_tree import *
from grover.topology.grover_datasets import *
from sklearn.model_selection import train_test_split

lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
# 치명적 오류가 발생되면 로그기록해라

In [2]:
def group_node_rep(node_rep, batch_index, batch_size):
    group = []
    count = 0
    for i in range(batch_size):
        num = sum(batch_index == i)
        group.append(node_rep[count:count + num])		# count += num번째 node의 표현을 그룹에 더해라
        count += num
    return group						# 최종 그룹을 출력

def train(args, model_list, loader, optimizer_list, device):
    model, motif_model = model_list                             # 훈련간 사용 모델은 GNN모델과 motif모델이다.
    optimizer_model, optimizer_motif = optimizer_list        # 옵티마이저도 둘에 대해 각각 사용하라.

    model.train()					#모델, 모티프 모델 훈련!
    motif_model.train()
    word_acc, topo_acc = 0, 0			# 분자와 위상 정확도 변수 설정
    for step, batch in enumerate(loader):	# 데이터로더에서 순회 진행바 표시형태로 순회해서 step과 batch대로 반복하자

        batch_size = len(batch)

        graph_batch = moltree_to_graph_data(batch)		# 분자식을 파이토치 지오메트릭 패키지에서 요구되는 그래프 데이터 형태로 변경해서 배치단위로 저장   /datautils에 있음
        #store graph object data in the process stage	
        batch_index = graph_batch.batch.numpy()			# 배치내의 배치텐서를 넘파이로 인덱스에 넘겨라
        graph_batch = graph_batch.to(device)			# 그래프배치는 GPU로 되게
        node_rep = model(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)	# GNN모델에 그래프(x, 엣지인덱스, 엣지의 특성) 투입
        node_rep = group_node_rep(node_rep, batch_index, batch_size)			# rep는 representation의 줄임말로 노드 표현을 의미
        loss, word_loss, topo_loss, wacc, tacc = motif_model(batch, node_rep)		# motif모델에서 손실, motif정확도, 위상 정확도 출력

        optimizer_model.zero_grad()				#옵티마이저 0으로
        optimizer_motif.zero_grad()
        loss.backward()					#손실 역전파

        optimizer_model.step()				#옵티마이저 시행
        optimizer_motif.step()

        word_acc += wacc
        topo_acc += tacc					#위상 정확도
            
    return loss, word_loss, topo_loss, word_acc*100, topo_acc*100

def validation(args, model_list, loader, device):
    model, motif_model = model_list                             # 훈련간 사용 모델은 GNN모델과 motif모델이다.

    model.eval()					#모델, 모티프 모델 훈련!
    motif_model.eval()
    word_acc, topo_acc = 0, 0			# 분자와 위상 정확도 변수 설정
    for step, batch in enumerate(loader):	# 데이터로더에서 순회 진행바 표시형태로 순회해서 step과 batch대로 반복하자

        batch_size = len(batch)

        graph_batch = moltree_to_graph_data(batch)		# 분자식을 파이토치 지오메트릭 패키지에서 요구되는 그래프 데이터 형태로 변경해서 배치단위로 저장   /datautils에 있음
        #store graph object data in the process stage	
        batch_index = graph_batch.batch.numpy()			# 배치내의 배치텐서를 넘파이로 인덱스에 넘겨라
        graph_batch = graph_batch.to(device)			# 그래프배치는 GPU로 되게
        node_rep = model(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr)	# GNN모델에 그래프(x, 엣지인덱스, 엣지의 특성) 투입
        node_rep = group_node_rep(node_rep, batch_index, batch_size)			# rep는 representation의 줄임말로 노드 표현을 의미
        loss, word_loss, topo_loss, wacc, tacc = motif_model(batch, node_rep)		# motif모델에서 손실, motif정확도, 위상 정확도 출력

        word_acc += wacc
        topo_acc += tacc					#위상 정확도

    return loss, word_loss, topo_loss, word_acc*100, topo_acc*100


In [6]:
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--decay', type=float, default=0,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: 300)')
parser.add_argument('--dropout_ratio', type=float, default=0.2,
                    help='dropout ratio (default: 0.2)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--dataset', type=str, default='./data/zinc/all.txt',
                    help='root directory of dataset. For now, only classification.')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--input_model_file', type=str, default="", help='filename to read the model (if there is any)')
parser.add_argument('--output_path', type=str, default='./saved_model/grover',
                    help='filename to output the pre-trained model')
parser.add_argument('--num_workers', type=int, default=0, help='number of workers for dataset loading')   #원래는 8이었음 오류로 0으로 바꿈
parser.add_argument("--hidden_size", type=int, default=300, help='hidden size')
parser.add_argument("--latent_size", type=int, default=56, help='latent size')
parser.add_argument("--vocab", type=str, default='./data/zinc/clique.txt', help='vocab path')
parser.add_argument('--order', type=str, default="bfs",
                    help='motif tree generation order (bfs or dfs)')
parser.add_argument('--seed', type=int, default=0,
                    help='setting seed number')
#for wandb
parser.add_argument('--wandb', action='store_true', default=False, help='add wandb log')
parser.add_argument('--wandb_name', type=str, default = 'MGSSL_Grover', help='wandb name')
parser.add_argument('--grover_dataset', action='store_true', default=False, help='grover dataset mode')

_StoreTrueAction(option_strings=['--grover_dataset'], dest='grover_dataset', nargs=0, const=True, default=False, type=None, choices=None, help='grover dataset mode', metavar=None)

In [32]:
args = parser.parse_args(['--emb_dim', '1200', '--hidden_size', '1200', '--epochs', '100', '--batch_size', '40', '--grover_dataset',
                          '--dropout_ratio', '0.1', '--vocab', 'data/merge_0/clique.txt', '--order', 'dfs', '--dataset', 'data/merge_0', 
                          '--output_path', 'output/grover'])
args

Namespace(JK='last', batch_size=40, dataset='data/merge_0', decay=0, device=0, dropout_ratio=0.1, emb_dim=1200, epochs=100, gnn_type='gin', graph_pooling='mean', grover_dataset=True, hidden_size=1200, input_model_file='', latent_size=56, lr=0.001, num_layer=5, num_workers=0, order='dfs', output_path='output/grover', seed=0, vocab='data/merge_0/clique.txt', wandb=False, wandb_name='MGSSL_Grover')

In [33]:
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

rank = 0
num_replicas = 1
if args.grover_dataset:
    grover_data, sample_per_file = get_motif_data(args.dataset)
    train_dataset, val_dataset, _ = split_data_grover(grover_data, sizes=(0.9,0.1,0), seed=0)
    shared_dict = {}
    GMC = GroverMotifCollator(shared_dict=shared_dict, args=args)
    pre_load_data(train_dataset, rank = rank, num_replicas = num_replicas)
    pre_load_data(val_dataset, rank = rank, num_replicas = num_replicas)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=GMC)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=GMC)

else : 
    dataset = MoleculeDataset_grover(args.dataset)
    train_dataset, val_dataset = train_test_split(dataset, test_size=0.5, random_state=42)
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=args.num_workers, collate_fn=lambda x:x, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=True, num_workers=args.num_workers, collate_fn=lambda x:x, drop_last=True)

Loading data:
Number of files: 20
Number of samples: 20000
Samples/file: 1000


In [30]:
model = GNN(5, args.emb_dim, JK='last', drop_ratio=args.dropout_ratio, gnn_type='gin').to(device)
if os.path.exists(args.input_model_file):
    model.load_state_dict(torch.load(args.input_model_file))

vocab = [x.strip("\r\n ") for x in open(args.vocab)]
vocab = Vocab(vocab)
motif_model = Motif_Generation_Grover(vocab, args.hidden_size, device, args.order).to(device)

model_list = [model, motif_model]
optimizer_model = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
optimizer_motif = optim.Adam(motif_model.parameters(), lr=1e-3, weight_decay=args.decay)

optimizer_list = [optimizer_model, optimizer_motif]

if args.wandb :
    wandb.init(project=args.wandb_name)
    wandb.config = args
    #wandb.watch(model)

In [34]:
#train start
best_val_loss = 1e+10
for epoch in range(1, args.epochs + 1):
    print("====epoch " + str(epoch))

    #training
    train_start = time.time()
    train_loss, train_node_loss, train_topo_loss, train_node_acc, train_topo_acc = train(args, model_list, train_loader, optimizer_list, device)
    train_end = time.time() - train_start
    print(f'epoch : {epoch:04d} train_loss : {train_loss:.4f} train_node_loss : {train_node_loss:.4f} train_topo_loss : {train_topo_loss:.4f} train_node_acc : {train_node_acc:.2f} train_topo_acc : {train_topo_acc:.2f} train_time : {train_end:.2f}s')

    #validation
    val_start = time.time()
    val_loss, val_node_loss, val_topo_loss, val_node_acc, val_topo_acc = validation(args, model_list, val_loader, device)
    val_end = time.time() - val_start
    print(f'epoch : {epoch:04d} val_loss : {val_loss:.4f} val_node_loss : {val_node_loss:.4f} val_topo_loss : {val_topo_loss:.4f} val_node_acc : {val_node_acc:.2f} val_topo_acc : {val_topo_acc:.2f} val_tim : {val_end:.2f}s')

    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)

    if args.wandb :         
        wandb.log({"train_loss" : train_loss, "train_node_loss" : train_node_loss, "train_topo_loss" : train_topo_loss, 
                   "val_loss" : val_loss, "val_node_loss" : val_node_loss, "val_topo_loss" : val_topo_loss})

    torch.save(model.state_dict(), os.path.join(args.output_path, 'temp.pth'))
    if best_val_loss > val_loss:
        torch.save(model.state_dict(), os.path.join(args.output_path, f'best.pth'))
    if epoch % 5 == 0:
        torch.save(model.state_dict(), os.path.join(args.output_path, f'{epoch}.pth'))

print('all train clear')

====epoch 1
epoch : 0001 train_loss : 20.1973 train_node_loss : 17.1413 train_topo_loss : 3.0560 train_node_acc : 26300.13 train_topo_acc : 41731.20 train_time : 120.79s
epoch : 0001 val_loss : 19.6792 val_node_loss : 15.5086 val_topo_loss : 4.1706 val_node_acc : 1612.44 val_topo_acc : 2380.62 val_tim : 4.58s
====epoch 2


KeyboardInterrupt: 