In [1]:
import random, os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
import numpy as np
import torch
from rdkit import RDLogger

from grover.util.parsing import parse_args, get_newest_train_args
from grover.util.utils import create_logger
from task.cross_validate import cross_validate, randomsearch, gridsearch, make_confusion_matrix
from task.fingerprint import generate_fingerprints
from task.predict import make_predictions, write_prediction
from task.pretrain import pretrain_model, subset_learning
from grover.data.torchvocab import MolVocab

from grover.topology.mol_tree import *
from grover.topology.dfs import *

#add for gridsearch
from argparse import ArgumentParser, Namespace

def setup(seed):
    # frozen random seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

using Horovod for multi-GPU training


In [3]:
# setup random seed
setup(seed=42)
# Avoid the pylint warning.
a = MolVocab
# supress rdkit logger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Initialize MolVocab
mol_vocab = MolVocab

# parse_args()

In [4]:
from grover.util.parsing import *

In [5]:
def parse_args() -> Namespace:
    """
    Parses arguments for training and testing (includes modifying/validating arguments).

    :return: A Namespace containing the parsed, modified, and validated args.
    """
    parser = ArgumentParser()
    subparser = parser.add_subparsers(title="subcommands",
                                      dest="parser_name",
                                      help="Subcommands for fintune, prediction, and fingerprint.")
    parser_pretrain = subparser.add_parser('pretrain', help="Pretrain with unlabelled SMILES.")
    add_pretrain_args(parser_pretrain)

    args = parser.parse_args(['pretrain','--data_path','data/merge','--save_dir','model/merge_test','--atom_vocab_path','data/merge/merge_atom_vocab.pkl','--bond_vocab_path','data/merge/merge_bond_vocab.pkl',
                              '--batch_size','100','--dropout','0.1','--depth','3','--num_attn_head','4','--hidden_size','1200','--epochs','20','--activation','PReLU','--backbone','gtrans','--embedding_output_type','both',
                              '--save_interval','5','--init_lr', '0.0002', '--max_lr', '0.0004', '--final_lr', '0.0001', '--weight_decay', '0.0000001', 
                              '--topology','--motif_vocab_path','data/merge/clique.txt','--motif_hidden_size','1200','--motif_latent_size','56','--motif_order','dfs'])
    
    if args.parser_name == 'finetune' or args.parser_name == 'eval':
        modify_train_args(args)
    elif args.parser_name == "pretrain":
        modify_pretrain_args(args)
    elif args.parser_name == 'predict':
        modify_predict_args(args)
    elif args.parser_name == 'fingerprint':
        modify_fingerprint_args(args)

    return args

In [6]:
args = parse_args()
args

Namespace(activation='PReLU', atom_vocab_path='data/merge/merge_atom_vocab.pkl', backbone='gtrans', batch_size=100, bias=False, bond_drop_rate=0, bond_vocab_path='data/merge/merge_bond_vocab.pkl', cuda=True, data_path='data/merge', dense=False, depth=3, dist_coff=0.1, dropout=0.1, each_epochs=5, embedding_output_type='both', enable_multi_gpu=False, epochs=20, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=1200, init_lr=0.0002, max_lr=0.0004, motif_hidden_size=1200, motif_latent_size=56, motif_order='dfs', motif_vocab_path='data/merge/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/merge_test', save_interval=5, subset_learning=False, topology=True, undirected=False, wandb=False, wandb_name='pretrain', warmup_epochs=2.0, weight_decay=1e-07)

In [7]:
logger = create_logger(name='pretrain', save_dir=args.save_dir)

# pretrain

In [8]:
import os
import time
from argparse import Namespace
from logging import Logger

import numpy as np

import torch
from torch.utils.data import DataLoader
import wandb

from grover.data.dist_sampler import DistributedSampler
from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset, get_motif_data, split_data_motif, GroverMotifCollator, BatchMolDataset_motif
from grover.data.torchvocab import MolVocab
from grover.model.models import GROVEREmbedding
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
from grover.util.nn_utils import param_count
from grover.util.utils import build_optimizer, build_lr_scheduler
from task.grovertrainer import GROVERTrainer, GROVERMotifTrainer

from grover.topology.mol_tree import Motif_Vocab
from grover.topology.motif_generation import Motif_Generation

import os
import time
from logging import Logger
from typing import List, Tuple
from collections.abc import Callable
import torch
from torch.nn import Module
from torch.utils.data import DataLoader

from grover.model.models import GroverTask, GroverMotifTask
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
from grover.util.utils import load_moltree

#add for Topology predict
from grover.topology.chemutils import group_node_rep


In [9]:
def pre_load_data(dataset: BatchMolDataset, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):
    """
    Pre-load data at the beginning of each epoch.
    :param dataset: the training dataset.
    :param rank: the rank of the current worker.
    :param num_replicas: the replicas.
    :param sample_per_file: the number of the data points in each file. When sample_per_file is None, all data will be
    loaded. It implies the testing phase. (TODO: bad design here.)
    :param epoch: the epoch number.
    :return:
    """
    mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
                                      sample_per_file=sample_per_file)
    mock_sampler.set_epoch(epoch)
    pre_indices = mock_sampler.get_indices()
    for i in pre_indices:
        dataset.load_data(i)


In [10]:
if logger is not None:
    debug, info = logger.debug, logger.info
else:
    debug = print

# initialize the horovod library
if args.enable_multi_gpu:
    mgw.init()

# binding training to GPUs.
master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
# pin GPU to local rank. By default, we use gpu:0 for training.
local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
with_cuda = args.cuda
if with_cuda:
    torch.cuda.set_device(local_gpu_idx)

# get rank an  number of workers
rank = mgw.rank() if args.enable_multi_gpu else 0
num_replicas = mgw.size() if args.enable_multi_gpu else 1
# print("Rank: %d Rep: %d" % (rank, num_replicas))

# load file paths of the data.
if master_worker:
    info(args)
    if args.enable_multi_gpu:
        debug("Total workers: %d" % (mgw.size()))
    debug('Loading data')
    print(f'data path is {args.data_path}')
data, sample_per_file = get_motif_data(data_path=args.data_path)

# data splitting
if master_worker:
    debug(f'Splitting data with seed 0.')
train_data, test_data, _ = split_data_motif(data=data, sizes=(0.5, 0.5, 0.0), seed=0, logger=logger)

# Here the true train data size is the train_data divided by #GPUs
if args.enable_multi_gpu:
    args.train_data_size = len(train_data) // mgw.size()
else:
    args.train_data_size = len(train_data)
if master_worker:
    info(f'Total size = {len(data):,} | '
          f'train size = {len(train_data):,} | val size = {len(test_data):,}')

# load atom and bond vocabulary and the semantic motif labels.
atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)

# Load motif vocabulary for pretrain
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
motif_vocab = [x.strip("\r\n ") for x in open(args.motif_vocab_path)]
motif_vocab = Motif_Vocab(motif_vocab)

# Hard coding here, since we haven't load any data yet!
fg_size = 85
shared_dict = {}
motif_collator = GroverMotifCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
if master_worker:
    debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                bond_vocab_size, fg_size))

# Define the distributed sampler. If using the single card, the sampler will be None.
train_sampler = None
test_sampler = None
shuffle = True
if args.enable_multi_gpu:
    # If not shuffle, the performance may decayed.
    train_sampler = DistributedSampler(
        train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
    # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
    # rank. (TODO: bad design here.)
    test_sampler = DistributedSampler(
        test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
    train_sampler.set_epoch(args.epochs)
    test_sampler.set_epoch(1)
    # if we enables multi_gpu training. shuffle should be disabled.
    shuffle = False

Namespace(activation='PReLU', atom_vocab_path='data/merge/merge_atom_vocab.pkl', backbone='gtrans', batch_size=100, bias=False, bond_drop_rate=0, bond_vocab_path='data/merge/merge_bond_vocab.pkl', cuda=True, data_path='data/merge', dense=False, depth=3, dist_coff=0.1, dropout=0.1, each_epochs=5, embedding_output_type='both', enable_multi_gpu=False, epochs=20, fg_label_path=None, final_lr=0.0001, fine_tune_coff=1, hidden_size=1200, init_lr=0.0002, max_lr=0.0004, motif_hidden_size=1200, motif_latent_size=56, motif_order='dfs', motif_vocab_path='data/merge/clique.txt', no_cache=True, num_attn_head=4, num_mt_block=1, parser_name='pretrain', save_dir='model/merge_test', save_interval=5, subset_learning=False, topology=True, undirected=False, wandb=False, wandb_name='pretrain', warmup_epochs=2.0, weight_decay=1e-07)
Loading data
Splitting data with seed 0.
Total size = 20 | train size = 10 | val size = 10


data path is data/merge
Loading data:
Number of files: 3
Number of samples: 20
Samples/file: 10


atom vocab size: 1533, bond vocab size: 1988, Number of FG tasks: 85


model = GNN(5, 300, JK='last', drop_ratio=0.2, gnn_type='gin').to(device)

In [14]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros

#이건 grover랑 다른데?
num_atom_type = 120 #including the extra mask tokens
num_chirality_tag = 3

num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
num_bond_direction = 3 

In [25]:
add_self_loops?

[0;31mSignature:[0m
[0madd_self_loops[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0medge_index[0m[0;34m:[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0medge_attr[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfill_value[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mfloat[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mstr[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_nodes[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mint[0m[0;34m,[0m [0mTuple[0m[0;34m[[0m[0mint[0m[0;34m,[0m [0mint[0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0mTuple[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;

In [None]:
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)

In [15]:
test = [1] + [2]
test

[1, 2]

In [24]:
emb(torch.tensor(1))+emb(torch.tensor(2))

tensor([-1.1506, -0.3253,  0.0729,  1.1212,  0.5103,  2.0589,  2.2528,  1.1751,
        -0.7290,  0.4813], grad_fn=<AddBackward0>)

In [13]:
class GNN_Grover(torch.nn.Module):
    """
    

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    """
    def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        """
        # nn.Embedding은 좌측의 수만큼의 카테고리를 우측의 임베딩 공간으로 임베딩하는거다.
        # 그러나 grover의 경우엔 원핫 인코더등을 적용한 복잡한 구조이다.
        #self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        #self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)
        """
        
        self.embedding = nn.Linear(165, emb_dim, bias=bias)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr = "add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        """
        기존 코드
        if len(argv) == 3:
            # 이게 MGSSL의 기본 input인데, 흠,,,
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")
        """
        
        node_attr, edge_attr = argv
        
        # x shape : (num_batch, emb_dim)
        x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat   #레이어간 노드 기능들을 어떻게 할건지, 기본은 Last다
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]

        return node_representation

In [12]:
# Pre load data. (Maybe unnecessary. )
pre_load_data(train_data, rank, num_replicas, sample_per_file)
pre_load_data(test_data, rank, num_replicas)
if master_worker:
    # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
    info("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())

# Build dataloader
train_data_dl = DataLoader(train_data,
                           batch_size=2,
                           shuffle=shuffle,
                           num_workers=0,
                           sampler=train_sampler,
                           collate_fn=motif_collator)
test_data_dl = DataLoader(test_data,
                          batch_size=2,
                          shuffle=shuffle,
                          num_workers=0,
                          sampler=test_sampler,
                          collate_fn=motif_collator)

# Build the embedding model.
grover_model = GROVEREmbedding(args)

# build the topology predict model.
motif_model = Motif_Generation(motif_vocab, args.motif_hidden_size, args.motif_latent_size, 3, device, args.motif_order)

#  Build the trainer.
trainer = GROVERMotifTrainer(args=args,
                        embedding_model=grover_model,
                        topology_model=motif_model,
                        atom_vocab_size=atom_vocab_size,
                        bond_vocab_size=bond_vocab_size,
                        fg_size=fg_size,
                        train_dataloader=train_data_dl,
                        test_dataloader=test_data_dl,
                        optimizer_builder=build_optimizer,
                        scheduler_builder=build_lr_scheduler,
                        logger=logger,
                        with_cuda=with_cuda,
                        enable_multi_gpu=args.enable_multi_gpu)

# Restore the interrupted training.
model_dir = os.path.join(args.save_dir, "model")
resume_from_epoch = 0
resume_scheduler_step = 0
if master_worker:
    resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
if args.enable_multi_gpu:
    resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
    resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                          root_rank=0, name="resume_scheduler_step").item()
    trainer.scheduler.current_step = resume_scheduler_step
    info("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))

trainer.broadcast_parameters()

Pre-loaded test data: 1


No checkpoint found %d


  self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps
  self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps))


In [13]:
model = GroverMotifTask(args, grover_model, atom_vocab_size, bond_vocab_size, fg_size)
model.cuda()

GroverMotifTask(
  (grover): GROVEREmbedding(
    (encoders): GTransEncoder(
      (edge_blocks): ModuleList(
        (0): MTBlock(
          (heads): ModuleList(
            (0): Head(
              (mpn_q): MPNEncoder(
                (dropout_layer): Dropout(p=0.1, inplace=False)
                (act_func): PReLU(num_parameters=1)
                (W_h): Linear(in_features=1200, out_features=1200, bias=False)
              )
              (mpn_k): MPNEncoder(
                (dropout_layer): Dropout(p=0.1, inplace=False)
                (act_func): PReLU(num_parameters=1)
                (W_h): Linear(in_features=1200, out_features=1200, bias=False)
              )
              (mpn_v): MPNEncoder(
                (dropout_layer): Dropout(p=0.1, inplace=False)
                (act_func): PReLU(num_parameters=1)
                (W_h): Linear(in_features=1200, out_features=1200, bias=False)
              )
            )
            (1): Head(
              (mpn_q): MPNEncoder(
       

In [14]:
train_data_dl = DataLoader(train_data,
                           batch_size=2,
                           shuffle=shuffle,
                           num_workers=0,
                           sampler=train_sampler,
                           collate_fn=motif_collator)
test_data_dl = DataLoader(test_data,
                          batch_size=2,
                          shuffle=shuffle,
                          num_workers=0,
                          sampler=test_sampler,
                          collate_fn=motif_collator)

In [436]:
model.train()
motif_model.train()

for i, item in enumerate(train_data_dl):
    batch_graph = item["graph_input"]
    targets = item["targets"]

    # add this for motif generation
    moltree = item["moltree"]

    if next(model.parameters()).is_cuda:
        targets["av_task"] = targets["av_task"].cuda()
        targets["bv_task"] = targets["bv_task"].cuda()
        targets["fg_task"] = targets["fg_task"].cuda()

    preds = model(batch_graph)
    emb_vector = preds['emb_vec']
    
    emb_afa_grouped = group_node_rep(moltree, emb_vector['atom_from_atom'], batch_graph)
    if i==0:break

In [437]:
for i in range(len(moltree)):
    print(moltree[i].smiles)

Oc1cc(Nc2ccnc3cc(Cl)ccc23)cc(-c2ccc(C(F)(F)F)cc2)c1
Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O


In [438]:
node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = motif_model(moltree, emb_afa_grouped)

mol_tree smiles is Oc1cc(Nc2ccnc3cc(Cl)ccc23)cc(-c2ccc(C(F)(F)F)cc2)c1
mol_tree smiles is Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
error smiles is Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
[tensor([ 1.6271, -0.1634, -0.1997,  ...,  0.7849, -0.9937, -1.2304],
       device='cuda:0', grad_fn=<SelectBackward0>), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([ 0.6828,  0.1067, -0.1724,  ...,  0.1375, -0.4109,  0.0683],
       device='cuda:0', grad_fn=<SelectBackward0>), tensor([ 0.4688,  0.1302,  0

AttributeError: 'list' object has no attribute 'shape'

# dfs 과정 추적

In [440]:
from grover.topology.dfs import *
vocab_size = motif_vocab.size()
hidden_size=1200
pred_loss = nn.CrossEntropyLoss(reduction="mean")
stop_loss = nn.BCEWithLogitsLoss(reduction="mean")
W_z = nn.Linear(2 * hidden_size, hidden_size).cuda()
U_r = nn.Linear(hidden_size, hidden_size, bias=False).cuda()
W_r = nn.Linear(hidden_size, hidden_size).cuda()
W_h = nn.Linear(2 * hidden_size, hidden_size).cuda()
W = nn.Linear(hidden_size, hidden_size).cuda()
U = nn.Linear(2 * hidden_size, hidden_size).cuda()
W_o = nn.Linear(hidden_size, vocab_size).cuda()
U_s = nn.Linear(hidden_size, 1).cuda()

In [441]:
def set_batch_nodeID(mol_batch, vocab):
    tot = 0
    for mol_tree in mol_batch:
        for node in mol_tree.nodes:
            node.idx = tot
            node.wid = vocab.get_index(node.smiles)
            tot += 1

In [442]:
W_z.cuda()

Linear(in_features=2400, out_features=1200, bias=True)

In [443]:
mol_batch = moltree.copy()
node_rep = emb_afa_grouped.copy()

In [444]:
mol_batch

[<grover.topology.mol_tree.MolTree_break at 0x7f9a49a97cd0>,
 <grover.topology.mol_tree.MolTree_break at 0x7f9a49a70410>]

In [445]:
node_rep[0].shape

torch.Size([29, 1200])

In [446]:
node_rep1[0].shape

torch.Size([29, 1200])

In [447]:
set_batch_nodeID(mol_batch, motif_vocab)

In [448]:
super_root = MolTreeNode("")
super_root.idx = -1

# Initialize
pred_hiddens, pred_targets = [], []
stop_hiddens, stop_targets = [], []
traces = []
for mol_tree in mol_batch:
    s = []
    try : 
        dfs(s, mol_tree.nodes[0], super_root)
    except : 
        print(f'smiles is {mol_tree.smiles}')
        print(f'moltree is {mol_tree}')
        print(f'moltreenodes is {mol_tree.nodes}')
        print(f'superroot is {super_root}')
    traces.append(s)
    for node in mol_tree.nodes:
        node.neighbors = []
'''
# Predict Root
pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
pred_mol_vecs.append(mol_vec)
'''

max_iter = max([len(tr) for tr in traces])
padding = create_var(torch.zeros(hidden_size), False)
h = {}

for t in range(max_iter):
    prop_list = []
    batch_list = []
    for i, plist in enumerate(traces):
        if t < len(plist):
            prop_list.append(plist[t])
            batch_list.append(i)
        else:
            prop_list.append(None)

    em_list = []
    cur_h_nei, cur_o_nei = [], []

    for mol_index, prop in enumerate(prop_list):
        if prop is None:
            continue
        node_x, real_y, _ = prop
        # Neighbors for message passing (target not included)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
        pad_len = MAX_NB - len(cur_nei)
        if pad_len>= 0:
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)
        else:
            cur_h_nei.extend(cur_nei[:MAX_NB])

        # Neighbors for stop prediction (all neighbors)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        if pad_len >= 0:
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)
        else:
            cur_o_nei.extend(cur_nei[:MAX_NB])


        # Current clique embedding
        em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(device)), dim=0))

    # Clique embedding
    cur_x = torch.stack(em_list, dim=0)  #gpu code

    # Message passing
    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, hidden_size)
    new_h = GRU(cur_x, cur_h_nei, W_z, W_r, U_r, W_h)

    # Node Aggregate
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size)
    cur_o = cur_o_nei.sum(dim=1)

    # Gather targets
    pred_target, pred_list = [], []
    stop_target = []
    prop_list = [x for x in prop_list if x is not None]
    for i, m in enumerate(prop_list):
        node_x, node_y, direction = m
        x, y = node_x.idx, node_y.idx
        h[(x, y)] = new_h[i]
        node_y.neighbors.append(node_x)
        if direction == 1:
            pred_target.append(node_y.wid)
            pred_list.append(i)
        stop_target.append(direction)

    # Hidden states for stop prediction
    stop_hidden = torch.cat([cur_x, cur_o], dim=1) 
    stop_hiddens.append(stop_hidden)
    stop_targets.extend(stop_target)

    # Hidden states for clique prediction
    if len(pred_list) > 0:
        #batch_list = [batch_list[i] for i in pred_list]
        #cur_batch = create_var(torch.LongTensor(batch_list))
        #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

        cur_pred = create_var(torch.LongTensor(pred_list))
        pred_hiddens.append(new_h.index_select(0, cur_pred))
        pred_targets.extend(pred_target)

# Last stop at root
em_list, cur_o_nei = [], []
for mol_index, mol_tree in enumerate(mol_batch):
#            try : 
    node_x = mol_tree.nodes[0]
#            except : 
#                print(f'error smiles is {mol_tree.smiles}')

    em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(device)), dim=0))
    cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
    pad_len = MAX_NB - len(cur_nei)
    cur_o_nei.extend(cur_nei)
    cur_o_nei.extend([padding] * pad_len)

cur_x = torch.stack(em_list, dim=0)
try : 
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size)
    cur_o = cur_o_nei.sum(dim=1)
except : 
    for mol_index, mol_tree in enumerate(mol_batch):
        print(f'mol_tree smiles is {mol_tree.smiles}')
    print(f'error smiles is {mol_tree.smiles}')
    print(cur_o_nei)
    print(cur_o_nei.shape)
    print(cur_x)
    print(cur_x.shape)
    print(torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size))
    print(torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size).sum(dim=1))


stop_hidden = torch.cat([cur_x, cur_o], dim=1)
stop_hiddens.append(stop_hidden)
stop_targets.extend([0] * len(mol_batch))

# Predict next clique
# pred hiddens와 pred_vecs는 앞의 node들에 대한 hiddenstate로 node x hidden size다
# 마지막에 pred_scores로 전체 vocab에 대한 완전연결층으로 예측
pred_hiddens = torch.cat(pred_hiddens, dim=0)
#pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
#pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
pred_vecs = pred_hiddens
pred_vecs = nn.ReLU()(W(pred_vecs))
pred_scores = W_o(pred_vecs)
pred_targets = create_var(torch.LongTensor(pred_targets))

pred_loss = pred_loss(pred_scores, pred_targets) #/ len(mol_batch)
_, preds = torch.max(pred_scores, dim=1)
pred_acc = torch.eq(preds.to(device), pred_targets.to(device)).float()
pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

# Predict stop
stop_hiddens = torch.cat(stop_hiddens, dim=0)
stop_vecs = nn.ReLU()(U(stop_hiddens))
stop_scores = U_s(stop_vecs).squeeze()
stop_targets = create_var(torch.Tensor(stop_targets))

stop_loss = stop_loss(stop_scores, stop_targets) #/ len(mol_batch)
stops = torch.ge(stop_scores, 0).float()
stop_acc = torch.eq(stops, stop_targets).float()
stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

mol_tree smiles is Oc1cc(Nc2ccnc3cc(Cl)ccc23)cc(-c2ccc(C(F)(F)F)cc2)c1
mol_tree smiles is Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
error smiles is Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
[tensor([-1.0053, -0.6278,  0.9503,  ..., -0.9165, -0.9564,  0.9179],
       device='cuda:0', grad_fn=<SelectBackward0>), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'), tensor([-0.1486, -0.1987, -0.2134,  ...,  0.1359, -0.1658, -0.1233],
       device='cuda:0', grad_fn=<SelectBackward0>), tensor([-0.2673, -0.1789, -0

AttributeError: 'list' object has no attribute 'shape'

In [449]:
pred_targets

[57097,
 22255,
 3695,
 3695,
 25255,
 23858,
 57097,
 22255,
 5604,
 25255,
 25255,
 25255,
 31864,
 57097,
 25255,
 31864,
 71304,
 25255,
 17777,
 25255,
 17777,
 25255,
 17777,
 25255,
 22255,
 25255]

## trace 생성

In [450]:
super_root = MolTreeNode("")
super_root.idx = -1

# Initialize
pred_hiddens, pred_targets = [], []
stop_hiddens, stop_targets = [], []
traces = []
for mol_tree in mol_batch:
    s = []
    dfs(s, mol_tree.nodes[0], super_root)
    traces.append(s)
    for node in mol_tree.nodes:
        node.neighbors = []
'''
# Predict Root
pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
pred_mol_vecs.append(mol_vec)
'''

max_iter = max([len(tr) for tr in traces])
padding = create_var(torch.zeros(1200), False)
h = {}

In [451]:
# dfs로 돌아가는 순서가 적힌 것
print(mol_batch[0].smiles)
for i in range(len(traces[0])):
    print(traces[0][i][0].smiles)
    print(traces[0][i][0].clique)
    for j in range(len(traces[0][i][0].neighbors)):
        print(traces[0][i][0].neighbors[j].smiles)
    
    print('\n')

Oc1cc(Nc2ccnc3cc(Cl)ccc23)cc(-c2ccc(C(F)(F)F)cc2)c1
CO
[0, 1]


C1=CC=CC=C1
[1, 28, 17, 16, 3, 2]


CN
[3, 4]


CN
[4, 5]


C1=CC=NC=C1
[6, 7, 8, 9, 15, 5]


C1=CC=CC=C1
[10, 11, 13, 14, 15, 9]


CCl
[11, 12]


C1=CC=CC=C1
[10, 11, 13, 14, 15, 9]


C1=CC=NC=C1
[6, 7, 8, 9, 15, 5]


CN
[4, 5]


CN
[3, 4]


C1=CC=CC=C1
[1, 28, 17, 16, 3, 2]


CC
[17, 18]


C1=CC=CC=C1
[19, 20, 21, 26, 27, 18]


CC
[21, 22]


C
[22]


CF
[22, 23]


C
[22]


CF
[22, 24]


C
[22]


CF
[22, 25]


C
[22]


CC
[21, 22]


C1=CC=CC=C1
[19, 20, 21, 26, 27, 18]


CC
[17, 18]


C1=CC=CC=C1
[1, 28, 17, 16, 3, 2]




In [452]:
len(traces[0])

26

In [453]:
len(mol_batch[0].nodes)

14

In [454]:
# dfs로 돌아가는 순서가 적힌 것
print(mol_batch[1].smiles)
for i in range(len(traces[1])):
    print(traces[1][i][0].smiles)
    print(traces[1][i][0].clique)
    for j in range(len(traces[1][i][0].neighbors)):
        print(traces[1][i][0].neighbors[j].smiles)
    
    print('\n')

Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
[2, 3, 4, 5, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 44, 47, 48, 49]


CO
[24, 23]


C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
[2, 3, 4, 5, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 44, 47, 48, 49]


O
[1]


C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
[2, 3, 4, 5, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 44, 47, 48, 49]


CO
[6, 7]


C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
[2, 3, 4, 5, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 44, 47, 48, 49]


O
[10]


C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5

## dfs 메인코드 분석
- 먼저 각 노드들을 1열로 세운다(prop_list)
- 그 다음 각 노드의 conjugate graph를 생성
- 그리고 cur_h_nei는 타겟을 제외한 인접 노드들의 message를 종합
- cur_o_nei는 타겟을 포함한 모든 인접 레이어의 message를 종합하고
- em_list는 해당 노드의 clique들의 임베딩 벡터를 가져온다. 

In [455]:
# 이건 각 노드들을 순환하며 motif 예측하기 위한 hidden state들을 종합하고, target들을 리스트화 시키는 작업
# 여기서 최대 인접 수를 8개로 제한한다. 적으면 zero padding, 많으면 중간에 자름.
for t in range(max_iter):
    prop_list = []
    batch_list = []
    # traces는 결합관계를 1번째 노드에서 시작해서 끝까지 포함하는 것
    # 아래의 for문은 해당 결합관계를 순서대로 prop_list로 추가하는 것이다.
    # 그리고 이 분자안의 결합관계를 배치 개념으로 봐서 배치의 index로 추가한다.
    # prop_list에는 각 분자의 t번째 node의 결합 정보가 포함이 된다.
    for i, plist in enumerate(traces):
        if t < len(plist):
            prop_list.append(plist[t])
            batch_list.append(i)
        else:
            prop_list.append(None)

    em_list = []
    cur_h_nei, cur_o_nei = [], []

    for mol_index, prop in enumerate(prop_list):
        if prop is None:
            continue
        node_x, real_y, _ = prop
        # Neighbors for message passing (target not included)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
        pad_len = MAX_NB - len(cur_nei)
        if pad_len>= 0:
            cur_h_nei.extend(cur_nei)
            cur_h_nei.extend([padding] * pad_len)
        else:
            cur_h_nei.extend(cur_nei[:MAX_NB])

        # Neighbors for stop prediction (all neighbors)
        cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
        pad_len = MAX_NB - len(cur_nei)
        if pad_len >= 0:
            cur_o_nei.extend(cur_nei)
            cur_o_nei.extend([padding] * pad_len)
        else:
            cur_o_nei.extend(cur_nei[:MAX_NB])


        # Current clique embedding
        em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(device)), dim=0))  #gpu code

    # Clique embedding
    cur_x = torch.stack(em_list, dim=0)    #gpu code

    # Message passing
    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, 1200)
    new_h = GRU(cur_x, cur_h_nei, W_z, W_r, U_r, W_h)    #gpu code

    # Node Aggregate
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, 1200)
    cur_o = cur_o_nei.sum(dim=1)

    # Gather targets
    pred_target, pred_list = [], []
    stop_target = []
    prop_list = [x for x in prop_list if x is not None]
    for i, m in enumerate(prop_list):
        node_x, node_y, direction = m
        x, y = node_x.idx, node_y.idx
        h[(x, y)] = new_h[i]
        node_y.neighbors.append(node_x)
        if direction == 1:
            pred_target.append(node_y.wid)
            pred_list.append(i)
        stop_target.append(direction)

    # Hidden states for stop prediction
    stop_hidden = torch.cat([cur_x, cur_o], dim=1)
    stop_hiddens.append(stop_hidden)
    stop_targets.extend(stop_target)

    # Hidden states for clique prediction
    if len(pred_list) > 0:
        #batch_list = [batch_list[i] for i in pred_list]
        #cur_batch = create_var(torch.LongTensor(batch_list))
        #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

        cur_pred = create_var(torch.LongTensor(pred_list))
        pred_hiddens.append(new_h.index_select(0, cur_pred))
        pred_targets.extend(pred_target)

In [503]:
em_list, cur_o_nei = [], []
for mol_index, mol_tree in enumerate(mol_batch):
    node_x = mol_tree.nodes[0]
    em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(device)), dim=0))
    cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
    pad_len = MAX_NB - len(cur_nei)
    cur_o_nei.extend(cur_nei)
    cur_o_nei.extend([padding] * pad_len)

cur_x = torch.stack(em_list, dim=0)

In [504]:
cur_o_nei

[tensor([-1.0053, -0.6278,  0.9503,  ..., -0.9165, -0.9564,  0.9179],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([-0.1486, -0.1987, -0.2134,  ...,  0.1359, -0.1658, -0.1233],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.2673, -0.1789, -0.1111,  ...,  0.1779,  0.0260, -0.0916],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.3187,  0.2006,  0.3408,  ..., -0.1273, -0.1980, -0.0009],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([ 0.0025, -0.0889,  0.0007,  ...,  0.1633, -0.0735, -0.0748],
        device='cuda:0', grad_fn

In [496]:
cur_o_nei = torch.stack(cur_o_nei[:MAX_NB], dim=0).view(-1, MAX_NB, hidden_size)
cur_o_nei

tensor([[[-1.0053, -0.6278,  0.9503,  ..., -0.9165, -0.9564,  0.9179],
         [-0.1486, -0.1987, -0.2134,  ...,  0.1359, -0.1658, -0.1233],
         [-0.2673, -0.1789, -0.1111,  ...,  0.1779,  0.0260, -0.0916],
         ...,
         [-0.3237,  0.2433, -0.1976,  ..., -0.0109, -0.1962, -0.0697],
         [ 0.1379, -0.2395, -0.1053,  ...,  0.2151, -0.2555, -0.1679],
         [-0.0308, -0.2504, -0.0852,  ...,  0.2045, -0.0680, -0.1314]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [489]:
cur_nei

[tensor([-0.1486, -0.1987, -0.2134,  ...,  0.1359, -0.1658, -0.1233],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.2673, -0.1789, -0.1111,  ...,  0.1779,  0.0260, -0.0916],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.3187,  0.2006,  0.3408,  ..., -0.1273, -0.1980, -0.0009],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([ 0.0025, -0.0889,  0.0007,  ...,  0.1633, -0.0735, -0.0748],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.3237,  0.2433, -0.1976,  ..., -0.0109, -0.1962, -0.0697],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([ 0.1379, -0.2395, -0.1053,  ...,  0.2151, -0.2555, -0.1679],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.0308, -0.2504, -0.0852,  ...,  0.2045, -0.0680, -0.1314],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.0794,  0.1330, -0.0175,  ...,  0.0113, -0.1469, -0.0161],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.2083

In [488]:
cur_o_nei

tensor([[[-1.0053e+00, -6.2779e-01,  9.5025e-01,  ..., -9.1653e-01,
          -9.5636e-01,  9.1786e-01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-1.4856e-01, -1.9872e-01, -2.1339e-01,  ...,  1.3593e-01,
          -1.6576e-01, -1.2327e-01],
         [-2.6727e-01, -1.7890e-01, -1.1111e-01,  ...,  1.7788e-01,
           2.5965e-02, -9.1610e-02],
         [-3.1872e-01,  2.0057e-01,  3.4085e-01,  ..., -1.2734e-01,
          -1.9799e-01, -8.6986e-04],
         ...,
         [ 1.3793e-01, -2

In [487]:
try : 
    cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size)
    cur_o = cur_o_nei.sum(dim=1)
except : 
    for mol_index, mol_tree in enumerate(mol_batch):
        print(f'mol_tree smiles is {mol_tree.smiles}')
    print(f'error smiles is {mol_tree.smiles}')
    print(cur_o_nei)
    print(cur_o_nei.shape)
    print(cur_x)
    print(cur_x.shape)
    print(torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size))
    print(torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, hidden_size).sum(dim=1))

### prop_list

In [38]:
prop_list = []
batch_list = []
# traces는 결합관계를 1번째 노드에서 시작해서 끝까지 포함하는 것
# 아래의 for문은 해당 결합관계를 순서대로 prop_list로 추가하는 것이다.
# 그리고 이 분자안의 결합관계를 배치 개념으로 봐서 배치의 index로 추가한다.
for i, plist in enumerate(traces):
    if 0 < len(plist):
        prop_list.append(plist[0])
        batch_list.append(i)
    else:
        prop_list.append(None)

In [39]:
prop_list[1]==traces[1][0]

True

In [40]:
prop_list

[(<grover.topology.mol_tree.MolTreeNode_break at 0x7f89d957a490>,
  <grover.topology.mol_tree.MolTreeNode_break at 0x7f89d957a590>,
  1),
 (<grover.topology.mol_tree.MolTreeNode_break at 0x7f89d9546fd0>,
  <grover.topology.mol_tree.MolTreeNode_break at 0x7f89d954b110>,
  1)]

## em_list, cur_nei, cur_h_nei, cur_o_nei

In [41]:
em_list = []
cur_h_nei, cur_o_nei = [], []

for mol_index, prop in enumerate(prop_list):
    if prop is None:
        continue
    #node_x : 현재 노드, real_y : 다음에 올 정답 y
    node_x, real_y, _ = prop
    # Neighbors for message passing (target not included)
    # 메시지 패싱 네트워크를 위하여 인접 노드들의 리스트를 만들되 이때 정답의 값은 미포함한다.
    cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
    pad_len = MAX_NB - len(cur_nei)
    # cur_h_nei : 현재 노드의 정답을 제외한 인접 노드들의 인덱스 리스트
    if pad_len>= 0:
        cur_h_nei.extend(cur_nei)
        cur_h_nei.extend([padding] * pad_len)
    else:
        cur_h_nei.extend(cur_nei[:MAX_NB])

    # Neighbors for stop prediction (all neighbors)
    # cur_o_nei는 멈출지에 대해 판단하기 위해 모든 인접 노드에 대한 리스트 생성
    cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
    pad_len = MAX_NB - len(cur_nei)
    if pad_len >= 0:
        cur_o_nei.extend(cur_nei)
        cur_o_nei.extend([padding] * pad_len)
    else:
        cur_o_nei.extend(cur_nei[:MAX_NB])


    # Current clique embedding (size = hidden_size)
    # 현재 motif에서의 임베딩. 근데 여기서의 node_x의 clique가 뭐지?
    # chatgpt와의 상담결과 clique가 1개인 것은 이 node가 clique 1개에 포함되고, 인접 clique가 없다는 것이고,
    # clique가 여러개임은 이 node가 포함되는 clique가 있고, 인접에 또 다른 clique가 존재함을 의미
    em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(device)), dim=0))

In [42]:
node_x = prop_list[0][0]
real_y = prop_list[0][1]

In [244]:
node_rep[mol_index].index_select(0, torch.tensor([0,2,10]).to(device))

tensor([[-0.1079, -0.9163,  0.2216,  ...,  0.4369, -1.8563, -0.0000],
        [-0.1206, -0.3257,  1.0068,  ..., -0.2150, -0.3437,  0.0495],
        [ 1.4415,  0.7626,  1.8714,  ..., -0.3281, -1.3577, -0.7717]],
       device='cuda:0', grad_fn=<IndexSelectBackward0>)

In [96]:
em_list

[tensor([ 2.4738,  1.0263,  1.4388,  ...,  0.0704, -1.6111, -3.1228],
        device='cuda:0', grad_fn=<SumBackward1>),
 tensor([ 1.7689, -0.9816,  1.9727,  ...,  0.7898, -1.8478, -0.6169],
        device='cuda:0', grad_fn=<SumBackward1>)]

## message passing

In [91]:
# 앞의 현 원자의 임베딩 값을 쌓는다. 사실상 따로노는 리스트상의 벡터를 매트릭스로 변환하는 작업
cur_x = torch.stack(em_list, dim=0)

# Message passing
# 현재 노드와 타겟을 제외한 노드들의 정보를 종합, message passing을 한다.
cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, 1200)
new_h = GRU(cur_x, cur_h_nei, W_z, W_r, U_r, W_h)

In [93]:
cur_x.shape

torch.Size([2, 1200])

In [100]:
cur_h_nei.shape

torch.Size([2, 8, 1200])

In [102]:
new_h

tensor([[ 0.0725, -0.0012,  0.0523,  ...,  0.1012, -0.0299,  0.3923],
        [ 0.2292,  0.0733,  0.0864,  ...,  0.1459, -0.0062,  0.4880]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [103]:
# Node Aggregate
cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, 1200)
cur_o = cur_o_nei.sum(dim=1)

In [105]:
cur_o_nei.shape

torch.Size([2, 8, 1200])

In [107]:
cur_o.shape

torch.Size([2, 1200])

In [108]:
# Gather target
# pred_target : y의 정답을 clique 사전에서 찾기
# pred_list : 현 분자에서의 예측에 쓸 분자 index
# stop_target : 0이면 돌아가는 것을 의미함. 1이면 들어왔음을 의미하고.

pred_target, pred_list = [], []
stop_target = []
prop_list = [x for x in prop_list if x is not None]
for i, m in enumerate(prop_list):
    node_x, node_y, direction = m
    x, y = node_x.idx, node_y.idx
    h[(x, y)] = new_h[i]
    node_y.neighbors.append(node_x)
    if direction == 1:
        pred_target.append(node_y.wid)
        pred_list.append(i)
    stop_target.append(direction)

In [114]:
pred_target, pred_list, stop_target, prop_list

([71304, 57097],
 [0, 1],
 [1, 1],
 [(<grover.topology.mol_tree.MolTreeNode_break at 0x7f89d957a490>,
   <grover.topology.mol_tree.MolTreeNode_break at 0x7f89d957a590>,
   1),
  (<grover.topology.mol_tree.MolTreeNode_break at 0x7f89d9546fd0>,
   <grover.topology.mol_tree.MolTreeNode_break at 0x7f89d954b110>,
   1)])

In [115]:
# Hidden states for stop prediction
# stop_hidden : 현재의 노드의 타겟을 제외한 인접 메시지 종합본에 우측?에 모든 인접 메시지를 종합한걸 이어 붙임
# stop_hiddens는 앞의 것들을 계속 확장
# stop_targets는 여기서 종료 여부를 모으는 작업
# 여기는 사실상 임베딩 벡터를 종합하고, 종료 여부를 종합하는 작업하는 곳
stop_hidden = torch.cat([cur_x, cur_o], dim=1)
stop_hiddens.append(stop_hidden)
stop_targets.extend(stop_target)

In [124]:
# Hidden states for clique prediction
if len(pred_list) > 0:
    #batch_list = [batch_list[i] for i in pred_list]
    #cur_batch = create_var(torch.LongTensor(batch_list))
    #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))
    
    # cur_pred는 현재에서 예측할 대상들
    # pred_hiddens는 예측할 대상들의 임베딩 벡터
    # 예측할 타겟
    cur_pred = create_var(torch.LongTensor(pred_list))
    pred_hiddens.append(new_h.index_select(0, cur_pred))
    pred_targets.extend(pred_target)

## 마지막 예측(여기부터 다시 돌리기)

In [152]:
# Last stop at root
# 마지막에 멈추는 것을 잘 예측할 것인지 마지막 node에 대한 정보를 종합
em_list, cur_o_nei = [], []
for mol_index, mol_tree in enumerate(mol_batch):
    print(mol_tree.smiles)
    node_x = mol_tree.nodes[0]
    em_list.append(torch.sum(node_rep[mol_index].index_select(0, torch.tensor(node_x.clique).to(device)), dim=0))
    cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
    print(cur_nei)
    pad_len = MAX_NB - len(cur_nei)
    cur_o_nei.extend(cur_nei)
    cur_o_nei.extend([padding] * pad_len)

cur_x = torch.stack(em_list, dim=0)
cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, 1200)
cur_o = cur_o_nei.sum(dim=1)


FC(F)(F)c1ccc(-c2[nH]c(N3CCN(c4ncccc4C(F)(F)F)CC3)nc2-c2ccccc2)cc1
[tensor([ 0.4527,  0.3099,  0.0178,  ...,  0.6078, -0.2452,  0.7329],
       device='cuda:0', grad_fn=<SelectBackward0>)]
Oc1cc(Nc2ccnc3cc(Cl)ccc23)cc(-c2ccc(C(F)(F)F)cc2)c1
[tensor([ 0.6466,  0.5866,  0.5675,  ...,  1.4837, -0.1425,  0.8750],
       device='cuda:0', grad_fn=<SelectBackward0>)]


In [153]:
cur_x.shape

torch.Size([2, 1200])

In [154]:
stop_hidden = torch.cat([cur_x, cur_o], dim=1)
stop_hiddens.append(stop_hidden)
stop_targets.extend([0] * len(mol_batch))

In [156]:
stop_hidden.shape

torch.Size([2, 2400])

In [159]:
len(stop_hiddens)

39

In [160]:
stop_targets

[1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [162]:
# Predict stop
stop_hiddens = torch.cat(stop_hiddens, dim=0)
stop_vecs = nn.ReLU()(U(stop_hiddens))
stop_scores = U_s(stop_vecs).squeeze()
stop_targets = create_var(torch.Tensor(stop_targets))

stop_loss = stop_loss(stop_scores, stop_targets) #/ len(mol_batch)
stops = torch.ge(stop_scores, 0).float()
stop_acc = torch.eq(stops, stop_targets).float()
stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

TypeError: cat() received an invalid combination of arguments - got (Tensor, dim=int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


In [155]:
# Predict next clique
pred_hiddens = torch.cat(pred_hiddens, dim=0)
#pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
#pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
pred_vecs = pred_hiddens
pred_vecs = nn.ReLU()(W(pred_vecs))
pred_scores = W_o(pred_vecs)
pred_targets = create_var(torch.LongTensor(pred_targets))

TypeError: cat() received an invalid combination of arguments - got (Tensor, dim=int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


## 또다른 분자식

In [15]:
from grover.topology.dfs import *
vocab_size = motif_vocab.size()
hidden_size=1200
pred_loss = nn.CrossEntropyLoss(reduction="mean")
stop_loss = nn.BCEWithLogitsLoss(reduction="mean")
W_z = nn.Linear(2 * hidden_size, hidden_size).cuda()
U_r = nn.Linear(hidden_size, hidden_size, bias=False).cuda()
W_r = nn.Linear(hidden_size, hidden_size).cuda()
W_h = nn.Linear(2 * hidden_size, hidden_size).cuda()
W = nn.Linear(hidden_size, hidden_size).cuda()
U = nn.Linear(2 * hidden_size, hidden_size).cuda()
W_o = nn.Linear(hidden_size, vocab_size).cuda()
U_s = nn.Linear(hidden_size, 1).cuda()

In [17]:
MAX_NB=8

In [18]:
def set_batch_nodeID(mol_batch, vocab):
    tot = 0
    for mol_tree in mol_batch:
        for node in mol_tree.nodes:
            node.idx = tot
            node.wid = vocab.get_index(node.smiles)
            tot += 1

In [19]:
train_data_dl = DataLoader(train_data,
                           batch_size=2,
                           shuffle=shuffle,
                           num_workers=0,
                           sampler=train_sampler,
                           collate_fn=motif_collator)
test_data_dl = DataLoader(test_data,
                          batch_size=2,
                          shuffle=shuffle,
                          num_workers=0,
                          sampler=test_sampler,
                          collate_fn=motif_collator)

In [20]:
model.train()
motif_model.train()

for i, item1 in enumerate(train_data_dl):
    batch_graph1 = item1["graph_input"]
    targets1 = item1["targets"]

    # add this for motif generation
    moltree1 = item1["moltree"]

    if next(model.parameters()).is_cuda:
        targets1["av_task"] = targets1["av_task"]
        targets1["bv_task"] = targets1["bv_task"]
        targets1["fg_task"] = targets1["fg_task"]

    preds1 = model(batch_graph1)
    emb_vector1 = preds1['emb_vec']
    
    emb_afa_grouped1 = group_node_rep(moltree1, emb_vector1['atom_from_atom'], batch_graph1)
    
    #node_afa_loss, topo_afa_loss, node_afa_acc, topo_afa_acc = motif_model(moltree1, emb_afa_grouped1)
    if i==0:break

In [21]:
for i in range(len(moltree1)):
    print(moltree1[i].smiles)

Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
Cc1cccc(C)c1NC(=O)C[N+](C)(C)CC(O)COc1cccc2ccccc12.[Cl-]


In [22]:
mol_batch1 = moltree1.copy()
node_rep1 = emb_afa_grouped1.copy()

In [23]:
node_rep1

[tensor([[ 0.9657, -0.0000, -0.5911,  ...,  0.7817, -1.2535, -0.0000],
         [ 0.9558, -0.2084, -0.4887,  ...,  1.1971, -1.3918, -0.0000],
         [ 1.7787,  1.2965,  0.3958,  ...,  0.0126, -0.9428, -1.7812],
         ...,
         [ 0.0140,  0.0612, -0.1296,  ...,  0.9470, -0.7918, -0.0593],
         [ 0.0878,  0.7641,  0.0867,  ..., -0.6685, -1.9724, -0.4697],
         [ 1.7776, -0.9270, -0.0000,  ...,  1.6140, -1.2967, -1.2059]],
        device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([[ 0.6236, -0.4460, -0.4338,  ...,  1.2252, -0.0000, -0.6305],
         [ 0.9030, -0.3929,  0.6735,  ...,  1.0632, -1.2189, -0.0000],
         [ 0.8281,  0.0021, -0.6424,  ..., -0.3869, -1.8624, -1.0977],
         ...,
         [ 0.6406,  0.0101,  0.0466,  ...,  0.2641, -0.4310, -0.7308],
         [ 0.9161,  0.4658, -0.0000,  ..., -0.1305, -0.0000, -0.9166],
         [ 1.3393, -0.6038, -0.2350,  ...,  0.6411, -1.4211, -0.0000]],
        device='cuda:0', grad_fn=<SliceBackward0>)]

In [24]:
mol_batch1

[<grover.topology.mol_tree.MolTree_break at 0x7f60526a6110>,
 <grover.topology.mol_tree.MolTree_break at 0x7f605267c4d0>]

In [25]:
set_batch_nodeID(mol_batch1, motif_vocab)

In [26]:
super_root1 = MolTreeNode("")
super_root1.idx = -1

# Initialize
pred_hiddens1, pred_targets1 = [], []
stop_hiddens1, stop_targets1 = [], []
traces1 = []
for mol_tree1 in mol_batch1:
    print(mol_tree1.smiles)
    s1 = []
    dfs(s1, mol_tree1.nodes[0], super_root1)
    traces1.append(s1)
    for node1 in mol_tree1.nodes:
        node1.neighbors = []
'''
# Predict Root
pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
pred_mol_vecs.append(mol_vec)
'''

max_iter1 = max([len(tr) for tr in traces1])
padding1 = create_var(torch.zeros(1200), False)
h1 = {}

Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
Cc1cccc(C)c1NC(=O)C[N+](C)(C)CC(O)COc1cccc2ccccc12.[Cl-]


In [30]:
# dfs로 돌아가는 순서가 적힌 것
print(mol_batch1[0].smiles)
for i in range(len(traces1[0])):
    print(f'{i}th node smiles : {traces1[0][i][0].smiles}')
    print(f'{i}th node clique : {traces1[0][i][0].clique}')
    for j in range(len(traces1[0][i][0].neighbors)):
        print(f'{j}th neighbor : {traces1[0][i][0].neighbors[j].smiles}')
    print('\n')

Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
0th node smiles : C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
0th node clique : [2, 3, 4, 5, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 44, 47, 48, 49]
0th neighbor : CO
1th neighbor : O
2th neighbor : CO
3th neighbor : O
4th neighbor : O
5th neighbor : O
6th neighbor : O
7th neighbor : O
8th neighbor : O
9th neighbor : O
10th neighbor : O
11th neighbor : CO
12th neighbor : O


1th node smiles : CO
1th node clique : [24, 23]
0th neighbor : C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1


2th node smiles : C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
2th node clique : [2, 3, 4, 5, 8, 9, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 25, 27, 29, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 44, 47, 48, 49]
0th neighbor : CO
1th neighbor : O
2th neighbor : CO
3th ne

In [31]:
# dfs로 돌아가는 순서가 적힌 것
print(mol_batch1[1].smiles)
for i in range(len(traces1[1])):
    print(f'{i}th node smiles : {traces1[1][i][0].smiles}')
    print(f'{i}th node clique : {traces1[1][i][0].clique}')
    for j in range(len(traces1[1][i][0].neighbors)):
        print(f'{j}th neighbor : {traces1[1][i][0].neighbors[j].smiles}')
    print('\n')

Cc1cccc(C)c1NC(=O)C[N+](C)(C)CC(O)COc1cccc2ccccc12.[Cl-]
0th node smiles : C
0th node clique : [0]
0th neighbor : C1=CC=CC=C1


1th node smiles : C1=CC=CC=C1
1th node clique : [1, 2, 3, 4, 5, 7]
0th neighbor : C
1th neighbor : C
2th neighbor : N


2th node smiles : C
2th node clique : [6]
0th neighbor : C1=CC=CC=C1


3th node smiles : C1=CC=CC=C1
3th node clique : [1, 2, 3, 4, 5, 7]
0th neighbor : C
1th neighbor : C
2th neighbor : N


4th node smiles : N
4th node clique : [8]
0th neighbor : C1=CC=CC=C1
1th neighbor : C


5th node smiles : C
5th node clique : [9]
0th neighbor : N
1th neighbor : O
2th neighbor : C


6th node smiles : O
6th node clique : [10]
0th neighbor : C


7th node smiles : C
7th node clique : [9]
0th neighbor : N
1th neighbor : O
2th neighbor : C


8th node smiles : C
8th node clique : [11]
0th neighbor : C
1th neighbor : [NH4+]


9th node smiles : [NH4+]
9th node clique : [12]
0th neighbor : C
1th neighbor : C
2th neighbor : C
3th neighbor : C


10th node smiles : 

### 아래꺼에서 문제 발생

In [29]:
for t in range(max_iter1):
    prop_list1 = []
    batch_list1 = []
    for i, plist in enumerate(traces1):
        if t < len(plist):
            prop_list1.append(plist[t])
            batch_list1.append(i)
        else:
            prop_list1.append(None)

    em_list1 = []
    cur_h_nei1, cur_o_nei1 = [], []

    for mol_index, prop in enumerate(prop_list1):
        if prop is None:
            continue
        node_x1, real_y1, _ = prop
        # Neighbors for message passing (target not included)
        cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors if node_y1.idx != real_y1.idx]
        pad_len1 = MAX_NB - len(cur_nei1)
        if pad_len1 >= 0:
            cur_h_nei1.extend(cur_nei1)
            cur_h_nei1.extend([padding1] * pad_len1)
        else:
            cur_h_nei1.extend(cur_nei1[:MAX_NB])

        # Neighbors for stop prediction (all neighbors)
        cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors]
        pad_len1 = MAX_NB - len(cur_nei1)
        if pad_len1 >= 0:
            cur_o_nei1.extend(cur_nei1)
            cur_o_nei1.extend([padding1] * pad_len1)
        else:
            cur_o_nei1.extend(cur_nei1[:MAX_NB])


        # Current clique embedding
        em_list1.append(torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))
        

    # Clique embedding
    cur_x1 = torch.stack(em_list1, dim=0)

    # Message passing
    cur_h_nei1 = torch.stack(cur_h_nei1, dim=0).view(-1, MAX_NB, 1200)
    new_h1 = GRU(cur_x1, cur_h_nei1, W_z, W_r, U_r, W_h)

    # Node Aggregate
    cur_o_nei1 = torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, 1200)
    cur_o1 = cur_o_nei1.sum(dim=1)

    # Gather targets
    pred_target1, pred_list1 = [], []
    stop_target1 = []
    prop_list1 = [x for x in prop_list1 if x is not None]
    for i, m in enumerate(prop_list1):
        node_x1, node_y1, direction1 = m
        x1, y1 = node_x1.idx, node_y1.idx
        h1[(x1, y1)] = new_h1[i]
        node_y1.neighbors.append(node_x1)
        if direction1 == 1:
            pred_target1.append(node_y1.wid)
            pred_list1.append(i)
        stop_target1.append(direction1)

    # Hidden states for stop prediction
    stop_hidden1 = torch.cat([cur_x1, cur_o1], dim=1)
    stop_hiddens1.append(stop_hidden1)
    stop_targets1.extend(stop_target1)

    # Hidden states for clique prediction
    if len(pred_list1) > 0:
        #batch_list = [batch_list[i] for i in pred_list]
        #cur_batch = create_var(torch.LongTensor(batch_list))
        #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

        cur_pred1 = create_var(torch.LongTensor(pred_list1))
        pred_hiddens1.append(new_h1.index_select(0, cur_pred1))
        pred_targets1.extend(pred_target1)

In [406]:
em_list1, cur_o_nei1 = [], []
for mol_index1, mol_tree1 in enumerate(mol_batch1):
    node_x1 = mol_tree1.nodes[0]
    em_list1.append(torch.sum(node_rep1[mol_index1].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))
    cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors]
    pad_len1 = MAX_NB - len(cur_nei1)
    cur_o_nei1.extend(cur_nei1)
    cur_o_nei1.extend([padding1] * pad_len1)

cur_x1 = torch.stack(em_list1, dim=0)
try : 
    cur_o_nei1 = torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, hidden_size)
    cur_o1 = cur_o_nei1.sum(dim=1)
except : 
    for mol_index1, mol_tree1 in enumerate(mol_batch1):
        print(f'mol_tree smiles is {mol_tree1.smiles}')
    print(f'error smiles is {mol_tree1.smiles}')
    print(cur_o_nei1)
    print(cur_o_nei1.shape)
    print(cur_x1)
    print(cur_x1.shape)
    print(torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, hidden_size))
    print(torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, hidden_size).sum(dim=1))

In [407]:
cur_o_nei1.shape

torch.Size([2, 15, 1200])

In [408]:
t

25

#### trace에서 1번째 node

In [92]:
#위에서 for문 생략하고 돌리기
t=0

In [93]:
prop_list1 = []
batch_list1 = []
for i, plist in enumerate(traces1):
    if t < len(plist):
        prop_list1.append(plist[t])
        batch_list1.append(i)
    else:
        prop_list1.append(None)

em_list1 = []
cur_h_nei1, cur_o_nei1 = [], []

In [94]:
prop_list1[0][0].smiles

'O'

In [95]:
#for mol_index, prop in enumerate(prop_list1):
mol_index = 0
prop = prop_list1[0]
node_x1, real_y1, _ = prop

In [96]:
node_x1, real_y1

(<grover.topology.mol_tree.MolTreeNode_break at 0x7f9a493f7b10>,
 <grover.topology.mol_tree.MolTreeNode_break at 0x7f9a493f7bd0>)

In [97]:
# Neighbors for message passing (target not included)
cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors if node_y1.idx != real_y1.idx]
pad_len1 = MAX_NB - len(cur_nei1)
if pad_len1 >= 0:
    cur_h_nei1.extend(cur_nei1)
    cur_h_nei1.extend([padding1] * pad_len1)
else:
    cur_h_nei1.extend(cur_nei1[:MAX_NB])

In [98]:
cur_nei1

[]

In [99]:
cur_h_nei1

[tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')]

In [100]:
# Neighbors for stop prediction (all neighbors)
cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors]
pad_len1 = MAX_NB - len(cur_nei1)
if pad_len1 >= 0:
    cur_o_nei1.extend(cur_nei1)
    cur_o_nei1.extend([padding1] * pad_len1)
else:
    cur_o_nei1.extend(cur_nei1[:MAX_NB])

In [101]:
cur_nei1

[]

In [102]:
cur_o_nei1

[tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')]

In [103]:
# Current clique embedding  <- 여기서 첫 에러 뜬다!!!
#em_list1.append(torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))

In [104]:
node_x1.clique

[0]

In [105]:
print(node_rep1[mol_index], node_rep1[mol_index].shape)

tensor([[ 0.0000, -0.2322, -0.0604,  ...,  1.2353, -1.4820, -0.2532],
        [ 0.9791,  0.0000,  0.2418,  ..., -0.7371, -0.9990, -1.7210],
        [ 0.5931,  0.0000,  0.5181,  ..., -0.3096,  0.2467, -0.7316],
        ...,
        [ 0.0000, -0.1858,  0.7433,  ...,  1.0620, -1.2294, -0.9201],
        [ 0.6799,  0.0000,  0.4741,  ..., -0.6495, -1.2709, -0.1777],
        [ 0.1575,  0.4379,  0.4260,  ...,  0.8630, -0.0000,  0.7927]],
       device='cuda:0', grad_fn=<SliceBackward0>) torch.Size([24, 1200])


In [106]:
em_list1.append(torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))
em_list1

[tensor([ 0.0000, -0.2322, -0.0604,  ...,  1.2353, -1.4820, -0.2532],
        device='cuda:0', grad_fn=<SumBackward1>)]

In [107]:
# Clique embedding
cur_x1 = torch.stack(em_list1, dim=0)
cur_x1

tensor([[ 0.0000, -0.2322, -0.0604,  ...,  1.2353, -1.4820, -0.2532]],
       device='cuda:0', grad_fn=<StackBackward0>)

In [108]:
# Message passing
cur_h_nei1 = torch.stack(cur_h_nei1, dim=0).view(-1, MAX_NB, 1200)
new_h1 = GRU(cur_x1, cur_h_nei1, W_z, W_r, U_r, W_h)
cur_h_nei1

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')

In [109]:
cur_h_nei1.shape

torch.Size([1, 8, 1200])

In [110]:
new_h1

tensor([[ 0.3083,  0.3083, -0.1806,  ..., -0.1225, -0.0994,  0.0763]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [111]:
new_h1.shape

torch.Size([1, 1200])

In [112]:
# Node Aggregate
cur_o_nei1 = torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, 1200)
cur_o1 = cur_o_nei1.sum(dim=1)
cur_o_nei1

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')

In [113]:
cur_o_nei1.shape

torch.Size([1, 8, 1200])

In [114]:
cur_o1

tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [115]:
cur_o1.shape

torch.Size([1, 1200])

In [116]:
# Gather targets
pred_target1, pred_list1 = [], []
stop_target1 = []
prop_list1 = [x for x in prop_list1 if x is not None]
prop_list1

[(<grover.topology.mol_tree.MolTreeNode_break at 0x7f9a493f7b10>,
  <grover.topology.mol_tree.MolTreeNode_break at 0x7f9a493f7bd0>,
  1),
 (<grover.topology.mol_tree.MolTreeNode_break at 0x7f9a4944ced0>,
  <grover.topology.mol_tree.MolTreeNode_break at 0x7f9a4944cfd0>,
  1)]

In [117]:
for i, m in enumerate(prop_list1):
    node_x1, node_y1, direction1 = m
    x1, y1 = node_x1.idx, node_y1.idx
    h1[(x1, y1)] = new_h1[i]
    node_y1.neighbors.append(node_x1)
    if direction1 == 1:
        pred_target1.append(node_y1.wid)
        pred_list1.append(i)
    stop_target1.append(direction1)

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [None]:
# Gather targets
pred_target1, pred_list1 = [], []
stop_target1 = []
prop_list1 = [x for x in prop_list1 if x is not None]
for i, m in enumerate(prop_list1):
    node_x1, node_y1, direction1 = m
    x1, y1 = node_x1.idx, node_y1.idx
    h1[(x1, y1)] = new_h1[i]
    node_y1.neighbors.append(node_x1)
    if direction1 == 1:
        pred_target1.append(node_y1.wid)
        pred_list1.append(i)
    stop_target1.append(direction1)

# Hidden states for stop prediction
stop_hidden1 = torch.cat([cur_x1, cur_o1], dim=1)
stop_hiddens1.append(stop_hidden1)
stop_targets1.extend(stop_target1)

# Hidden states for clique prediction
if len(pred_list1) > 0:
    #batch_list = [batch_list[i] for i in pred_list]
    #cur_batch = create_var(torch.LongTensor(batch_list))
    #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

    cur_pred1 = create_var(torch.LongTensor(pred_list1))
    pred_hiddens1.append(new_h1.index_select(0, cur_pred1))
    pred_targets1.extend(pred_target1)

In [106]:
t=0

In [107]:
prop_list1 = []
batch_list1 = []
for i, plist in enumerate(traces1):
    if t < len(plist):
        prop_list1.append(plist[t])
        batch_list1.append(i)
    else:
        prop_list1.append(None)

em_list1 = []
cur_h_nei1, cur_o_nei1 = [], []

In [108]:
mol_index = 0
prop = prop_list1[0]
node_x1, real_y1, _ = prop

In [109]:
node_x1, real_y1

(<grover.topology.mol_tree.MolTreeNode_break at 0x7fc16c5d4b50>,
 <grover.topology.mol_tree.MolTreeNode_break at 0x7fc16c5d4c10>)

In [110]:
cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors if node_y1.idx != real_y1.idx]
cur_nei1

[]

In [111]:
pad_len1 = MAX_NB - len(cur_nei1)
if pad_len1 >= 0:
    cur_h_nei1.extend(cur_nei1)
    cur_h_nei1.extend([padding1] * pad_len1)
else:
    cur_h_nei1.extend(cur_nei1[:MAX_NB])
cur_h_nei1

[tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')]

In [112]:
cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors]
cur_nei1

[]

In [113]:
pad_len1 = MAX_NB - len(cur_nei1)
if pad_len1 >= 0:
    cur_o_nei1.extend(cur_nei1)
    cur_o_nei1.extend([padding1] * pad_len1)
else:
    cur_o_nei1.extend(cur_nei1[:MAX_NB])
cur_o_nei1

[tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')]

In [114]:
node_x1.clique

[2,
 3,
 4,
 5,
 8,
 9,
 11,
 12,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 25,
 27,
 29,
 31,
 32,
 33,
 34,
 35,
 37,
 38,
 40,
 41,
 43,
 44,
 47,
 48,
 49]

In [115]:
# 일단 노드의 표현은 정상
node_rep1[mol_index]

tensor([[ 0.3272,  0.1187, -0.0000,  ...,  0.9684, -0.6193, -0.9083],
        [ 0.0805,  0.3285,  0.2019,  ...,  0.2417, -0.6316,  0.2726],
        [ 1.2339,  1.5807,  0.9370,  ...,  0.0000, -0.5533, -0.3055],
        ...,
        [-0.3315,  0.7579,  0.8269,  ..., -0.6230, -0.0919, -0.0000],
        [ 0.2876,  1.0661,  2.0430,  ..., -0.6365, -1.2341, -0.0990],
        [ 0.1914,  0.1172,  0.0374,  ...,  0.4083, -1.1632, -0.6278]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [116]:
node_rep1[mol_index].shape

torch.Size([51, 1200])

In [117]:
torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0)

tensor([ 10.2837,  16.5959,  25.5474,  ..., -16.6817, -23.7419,  -4.6524],
       device='cuda:0', grad_fn=<SumBackward1>)

In [118]:
em_list1.append(torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))

In [119]:
em_list1

[tensor([ 10.2837,  16.5959,  25.5474,  ..., -16.6817, -23.7419,  -4.6524],
        device='cuda:0', grad_fn=<SumBackward1>)]

In [120]:
# Clique embedding
cur_x1 = torch.stack(em_list1, dim=0)
cur_x1

tensor([[ 10.2837,  16.5959,  25.5474,  ..., -16.6817, -23.7419,  -4.6524]],
       device='cuda:0', grad_fn=<StackBackward0>)

In [124]:
# Message passing
cur_h_nei1 = torch.stack(cur_h_nei1, dim=0).view(-1, MAX_NB, 1200)
new_h1 = GRU(cur_x1, cur_h_nei1, W_z, W_r, U_r, W_h)

new_h1

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

In [None]:
# Clique embedding
cur_x1 = torch.stack(em_list1, dim=0)

# Message passing
cur_h_nei1 = torch.stack(cur_h_nei1, dim=0).view(-1, MAX_NB, 1200)
new_h1 = GRU(cur_x1, cur_h_nei1, W_z, W_r, U_r, W_h)

# Node Aggregate
cur_o_nei1 = torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, 1200)
cur_o1 = cur_o_nei1.sum(dim=1)

# Gather targets
pred_target1, pred_list1 = [], []
stop_target1 = []
prop_list1 = [x for x in prop_list1 if x is not None]
for i, m in enumerate(prop_list1):
    node_x1, node_y1, direction1 = m
    x1, y1 = node_x1.idx, node_y1.idx
    h1[(x1, y1)] = new_h1[i]
    node_y1.neighbors.append(node_x1)
    if direction1 == 1:
        pred_target1.append(node_y1.wid)
        pred_list1.append(i)
    stop_target1.append(direction1)

# Hidden states for stop prediction
stop_hidden1 = torch.cat([cur_x1, cur_o1], dim=1)
stop_hiddens1.append(stop_hidden1)
stop_targets1.extend(stop_target1)

# Hidden states for clique prediction
if len(pred_list1) > 0:
    #batch_list = [batch_list[i] for i in pred_list]
    #cur_batch = create_var(torch.LongTensor(batch_list))
    #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

    cur_pred1 = create_var(torch.LongTensor(pred_list1))
    pred_hiddens1.append(new_h1.index_select(0, cur_pred1))
    pred_targets1.extend(pred_target1)

In [314]:
for t in range(max_iter1):
    prop_list1 = []
    batch_list1 = []
    for i, plist in enumerate(traces1):
        if t < len(plist):
            prop_list1.append(plist[t])
            batch_list1.append(i)
        else:
            prop_list1.append(None)

    em_list1 = []
    cur_h_nei1, cur_o_nei1 = [], []

In [312]:
for i, plist in enumerate(traces1):
    if t < len(plist):
        
        prop_list1.append(plist[t])
        
        batch_list1.append(i)
    else:
        prop_list1.append(None)

In [315]:
prop_list1

[None,
 (<grover.topology.mol_tree.MolTreeNode_break at 0x7f89d95f1e50>,
  <grover.topology.mol_tree.MolTreeNode_break at 0x7f89d95f1d90>,
  0)]

In [252]:
for mol_index, prop in enumerate(prop_list1):
    if prop is None:
        continue
    node_x1, real_y1, _ = prop
    # Neighbors for message passing (target not included)
    cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors if node_y1.idx != real_y1.idx]
    pad_len1 = MAX_NB - len(cur_nei1)
    if pad_len1 >= 0:
        cur_h_nei1.extend(cur_nei1)
        cur_h_nei1.extend([padding1] * pad_len1)
    else:
        cur_h_nei1.extend(cur_nei1[:MAX_NB])

    # Neighbors for stop prediction (all neighbors)
    cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors]
    pad_len1 = MAX_NB - len(cur_nei1)
    if pad_len1 >= 0:
        cur_o_nei1.extend(cur_nei1)
        cur_o_nei1.extend([padding1] * pad_len1)
    else:
        cur_o_nei1.extend(cur_nei1[:MAX_NB])


    # Current clique embedding
    em_list1.append(torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))

In [256]:
em_list1

[tensor([-0.1315,  1.8345,  4.1327,  ...,  1.7675, -7.3927, -5.3204],
        device='cuda:0', grad_fn=<SumBackward1>)]

In [None]:
    







    # Hidden states for stop prediction
    stop_hidden1 = torch.cat([cur_x1, cur_o1], dim=1)
    stop_hiddens1.append(stop_hidden1)
    stop_targets1.extend(stop_target1)

    # Hidden states for clique prediction
    if len(pred_list1) > 0:
        #batch_list = [batch_list[i] for i in pred_list]
        #cur_batch = create_var(torch.LongTensor(batch_list))
        #pred_mol_vecs.append(mol_vec.index_select(0, cur_batch))

        cur_pred1 = create_var(torch.LongTensor(pred_list1))
        pred_hiddens1.append(new_h1.index_select(0, cur_pred1))
        pred_targets1.extend(pred_target1)

In [257]:
# Clique embedding
cur_x1 = torch.stack(em_list1, dim=0)
cur_x1

tensor([[-0.1315,  1.8345,  4.1327,  ...,  1.7675, -7.3927, -5.3204]],
       device='cuda:0', grad_fn=<StackBackward0>)

In [258]:
# Message passing
cur_h_nei1 = torch.stack(cur_h_nei1, dim=0).view(-1, MAX_NB, 1200)
new_h1 = GRU(cur_x1, cur_h_nei1, W_z, W_r, U_r, W_h)
new_h1

tensor([[-1.0000e+00,  1.0000e+00, -1.0000e+00,  ..., -2.9606e+03,
          2.0296e+02,  3.3766e+00]], device='cuda:0', grad_fn=<AddBackward0>)

In [259]:
# Node Aggregate
cur_o_nei1 = torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, 1200)
cur_o1 = cur_o_nei1.sum(dim=1)
cur_o1

tensor([[-1.4324e+00, -2.0613e+03, -3.9814e-01,  ..., -1.9745e+03,
          1.3582e+02,  2.9715e+00]], device='cuda:0', grad_fn=<SumBackward1>)

In [270]:
# Gather targets
pred_target1, pred_list1 = [], []
stop_target1 = []
prop_list1 = [x for x in prop_list1 if x is not None]
for i, m in enumerate(prop_list1):
    node_x1, node_y1, direction1 = m
    x1, y1 = node_x1.idx, node_y1.idx
    h1[(x1, y1)] = new_h1[i]
    node_y1.neighbors.append(node_x1)
    if direction1 == 1:
        pred_target1.append(node_y1.wid)
        pred_list1.append(i)
    stop_target1.append(direction1)

In [283]:
prop_list1

[(<grover.topology.mol_tree.MolTreeNode_break at 0x7f89d95f1e50>,
  <grover.topology.mol_tree.MolTreeNode_break at 0x7f89d95f1d90>,
  0)]

In [381]:
# Last stop at root
em_list1, cur_o_nei1 = [], []
for mol_index, mol_tree in enumerate(mol_batch1):
    print(mol_tree.smiles)
    node_x1 = mol_tree.nodes[0]
    em_list1.append(torch.sum(node_rep1[mol_index].index_select(0, torch.tensor(node_x1.clique).to(device)), dim=0))
    cur_nei1 = [h1[(node_y1.idx, node_x1.idx)] for node_y1 in node_x1.neighbors]
    print(cur_nei1)
    pad_len1 = MAX_NB - len(cur_nei1)
    cur_o_nei1.extend(cur_nei1)
    cur_o_nei1.extend([padding1] * pad_len1)

#cur_x1 = torch.stack(em_list1, dim=0)
#cur_o_nei1 = torch.stack(cur_o_nei1, dim=0).view(-1, MAX_NB, 1200)
#cur_o1 = cur_o_nei1.sum(dim=1)


O=C1COc2ccc(C(=O)COc3ccc([N+](=O)[O-])cc3)cc2N1
[tensor([ 0.9287, -0.7753,  0.2930,  ...,  0.6038, -0.3346, -0.9487],
       device='cuda:0', grad_fn=<SelectBackward0>)]
Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
[]


# 분자 3D로 그리기

In [382]:
import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem

smiles = "Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O"
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)
AllChem.EmbedMolecule(mol, useExpTorsionAnglePrefs=True, useBasicKnowledge=True)
AllChem.MMFFOptimizeMolecule(mol)
pdb = Chem.MolToPDBBlock(mol)

In [383]:
import py3Dmol

def MolTo3DView(mol, size=(600, 600), style="stick", surface=False, opacity=0.5):
    """Draw molecule in 3D
    
    Args:
    ----
        mol: rdMol, molecule to show
        size: tuple(int, int), canvas size
        style: str, type of drawing molecule
               style can be 'line', 'stick', 'sphere', 'carton'
        surface, bool, display SAS
        opacity, float, opacity of surface, range 0.0-1.0
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
    """
    assert style in ('line', 'stick', 'sphere', 'carton')
    mblock = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle({style:{}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

In [384]:
from rdkit import Chem
from rdkit.Chem import AllChem

def smi2conf(smiles):
    '''Convert SMILES to rdkit.Mol with 3D coordinates'''
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        mol = Chem.AddHs(mol)
        AllChem.EmbedMolecule(mol)
        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)
        return mol
    else:
        return None

smi = 'Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O'
conf = smi2conf(smi)
viewer = MolTo3DView(conf, size=(600, 300), style='sphere')
viewer.show()

In [224]:
from ipywidgets import interact,fixed,IntSlider
import ipywidgets


def conf_viewer():
    mol = conf
    return MolTo3DView(mol).show()

interact(conf_viewer, idx=ipywidgets.IntSlider(step=1))

interactive(children=(Output(),), _dom_classes=('widget-interact',))

<function __main__.conf_viewer()>

In [225]:
from ipywidgets import interact,fixed,IntSlider
import ipywidgets

smis = [ 'COc3nc(OCc2ccc(C#N)c(c1ccc(C(=O)O)cc1)c2P(=O)(O)O)ccc3C[NH2+]CC(I)NC(=O)C(F)(Cl)Br',
    'CC(NCCNCC1=CC=C(OCC2=C(C)C(C3=CC=CC=C3)=CC=C2)N=C1OC)=O',
    'Cc1c(COc2cc(OCc3cccc(c3)C#N)c(CN3C[C@H](O)C[C@H]3C(O)=O)cc2Cl)cccc1-c1ccc2OCCOc2c1',
    'CCCCC(=O)NCCCCC(=O)NCCCCCC(=O)[O-]',
    "CC(NCCNCC1=CC=C(OCC2=C(C)C(C3=CC=CC=C3)=CC=C2)N=C1OC)=O"]

def style_selector(s):
    return MolTo3DView(conf, style=s).show()

interact(style_selector, 
         idx=ipywidgets.IntSlider(step=1),
         s=ipywidgets.Dropdown(
            options=['line', 'stick', 'sphere'],
            value='line',
            description='Style:'))

interactive(children=(Dropdown(description='Style:', options=('line', 'stick', 'sphere'), value='line'), Outpu…

<function __main__.style_selector(s)>

# allinone으로 moltree label 생성과정 확인

In [1]:
'''
history
1. this was made for just preprocess features all in one
2. make temp file and can continue process
3. if some smiles use too many time to process, then skip
3.1 add logger
'''

import pickle
import os
import csv
import shutil
import pandas as pd
import time
#import torch
from collections import Counter
from typing import Callable, Union

from argparse import ArgumentParser, Namespace
import numpy as np
from multiprocessing import Pool
from typing import List, Tuple
from tqdm import tqdm
from rdkit import RDLogger

import sys
#sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import grover.util.utils as fea_utils
from grover.util.utils import get_data, makedirs, load_features, save_features, create_logger
from grover.data.molfeaturegenerator import get_available_features_generators, \
    get_features_generator
from grover.data.task_labels import rdkit_functional_group_label_features_generator
from grover.topology.mol_tree import *
from grover.topology.chemutils import *

from rdkit import Chem
from descriptastorus.descriptors import rdDescriptors

from grover.data.molfeaturegenerator import register_features_generator

def load_smiles(data_path):
    with open(data_path) as f:
        reader = csv.reader(f)
        header = next(reader)
        res = []
        for line in reader:
            res.append(line)
    return res, header

def save_smiles(data_path, index, data, header='smiles'):
    fn = os.path.join(data_path, str(index) + ".csv")
    with open(fn, "w") as f:
        fw = csv.writer(f)
        fw.writerow([header])
        for d in data:
            fw.writerow([d])
    f.close()
    
def save_features(data_path, index, features):
    fn = os.path.join(data_path, str(index) + ".npz")
    np.savez_compressed(fn, features=features)
    
def save_moltrees(data_path, index, moltrees):
    fn = os.path.join(data_path, str(index) + ".p")
    with open(fn, 'wb') as file: 
        pickle.dump(moltrees, file)
    file.close()

def save_cliques(data_path, index, cliques):
    clique_path = data_path+f'/clique{index}.txt'
    with open(clique_path, 'w') as file:
        for c in cliques:
            file.write(c)
            file.write('\n')
    file.close()
    
def load_checkpoint(process_path):
    with open(process_path, 'r') as file:
        line1 = file.readline()
        line2 = file.readline()
        temp_i = int(line1.split(' \n')[0])
        temp_num = int(line2)
    file.close()
    return temp_i, temp_num

def load_cliques(data_path, index):
    cliques = set()
    clique_path = args.output_path+f'/cliques/clique{index}.txt'
    with open(clique_path, 'r') as file:
        for line in file:
            cliques.add(line)
    file.close()
    return cliques

def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[]):
    all_attach_confs = []
    singletons = [nei_node.nid for nei_node in neighbors + prev_nodes if nei_node.mol.GetNumAtoms() == 1]
    e_atime = time.time()
    
    def search(cur_amap, depth):
        if len(all_attach_confs) > MAX_NCAND:
            return
        if depth == len(neighbors):
            all_attach_confs.append(cur_amap)
            return

        nei_node = neighbors[depth]
        cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons)
        cand_smiles = set()
        candidates = []
        
        for amap in cand_amap:
            cand_mol = local_attach(node.mol, neighbors[:depth + 1], prev_nodes, amap)
            cand_mol = sanitize(cand_mol)
            if cand_mol is None:
                continue
            smiles = get_smiles(cand_mol)
            if smiles in cand_smiles:
                continue
            cand_smiles.add(smiles)
            candidates.append(amap)
            cdtime = time.time() - e_atime
            if cdtime > 50:
                return [3]
            

        if len(candidates) == 0:
            return

        for new_amap in candidates:
            time1 = search(new_amap, depth + 1)
            if time1==[3]:
                return time1

    time2 = search(prev_amap, 0)
    if time2 == [3]:
        return [3]
    cand_smiles = set()
    candidates = []
    for amap in all_attach_confs:
        cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap)
        cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
        smiles = Chem.MolToSmiles(cand_mol)
        if smiles in cand_smiles:
            continue
        cand_smiles.add(smiles)
        Chem.Kekulize(cand_mol)
        candidates.append((smiles, cand_mol, amap))

    return candidates

class MolTreeNode(object):

    def __init__(self, smiles, clique=[]):
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        #self.mol = cmol

        self.clique = [x for x in clique] #copy
        self.neighbors = []
        
    def add_neighbor(self, nei_node):
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf: #Leaf node, no need to mark 
                continue
            for cidx in nei_node.clique:
                #allow singleton node override the atom mapping
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
        self.label_mol = get_mol(self.label)

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label
    
    def assemble(self):
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands = enum_assemble(self, neighbors)
        if cands == [3]:
            return [3]
        if len(cands) > 0:
            self.cands, self.cand_mols, _ = zip(*cands)
            self.cands = list(self.cands)
            self.cand_mols = list(self.cand_mols)
        else:
            self.cands = []
            self.cand_mols = []
            
class MolTree(object):

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)

        '''
        #Stereo Generation
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        self.smiles2D = Chem.MolToSmiles(mol)
        self.stereo_cands = decode_stereo(self.smiles2D)
        '''

        cliques, edges = brics_decomp(self.mol)
        if len(edges) <= 1:
            cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for i,c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            node = MolTreeNode(get_smiles(cmol), c)
            self.nodes.append(node)
            if min(c) == 0:
                root = i

        for x,y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])
        
        if root > 0:
            self.nodes[0],self.nodes[root] = self.nodes[root],self.nodes[0]

        for i,node in enumerate(self.nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1: #Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        return len(self.nodes)

    def recover(self):
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        for node in self.nodes:
            assem_time = node.assemble()
            if assem_time == [3] : 
                return assem_time

Molecule = Union[str, Chem.Mol]
FeaturesGenerator = Callable[[Molecule], np.ndarray]

# The functional group descriptors in RDkit.
RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN',
               'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2',
               'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0',
               'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2',
               'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide',
               'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl',
               'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine',
               'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester',
               'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone',
               'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone',
               'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine',
               'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho',
               'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol',
               'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine',
               'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN',
               'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole',
               'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea']

BOND_FEATURES = ['BondType', 'Stereo', 'BondDir']

@register_features_generator('allinone')
def make_fgfeatures_moltree_clique(mol):
    """
    Generates functional group label for a molecule using RDKit.

    :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
    :return: A 1D numpy array containing the RDKit 2D features.
    """
    try : 
        smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
        #test GetNumHeavyAtoms()
        mol = Chem.MolFromSmiles(smiles)
        mol.GetNumHeavyAtoms()

        #make fg features
        generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
        features = generator.process(smiles)[1:]
        features = np.array(features)
        features[features != 0] = 1

        #make cliuqe
        mol_tree = MolTree(smiles)
        cset=[]
        for node in mol_tree.nodes:
            cset.append(node.smiles)

        #make moltree and if smiles is need too much time, break
        mol_tree.recover()
        t=mol_tree.assemble()
        if t == [3]:
            info(f'smiles {smiles} is take too much time so break')
            return None, None, None, None
        
        return smiles, features, cset, mol_tree
    
    except : 
        return None, None, None, None
    
#execute

In [8]:
res = ['Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O', 'CCCOc1ccccc1-c1nnc(NN)s1.Cl', 'CC[N+](C)(CC)CCC(=O)Nc1ccc2c(c1)C(=O)c1cc(NC(=O)CC[N+](C)(CC)CC)ccc1-2.[I-].[I-]', 'Cl.O=C(COc1ccccc1)N1CCN(CC(O)COc2ccccc2)CC1']
data = pd.DataFrame(res, columns=['smiles'])
features_generator = get_features_generator('allinone')

In [9]:
mapping = map(features_generator, data.smiles)

In [10]:
smiles_list = []
smiles_full_list = []
features_list = []
moltree_list = []
cliques = []
for i, output in tqdm(enumerate(mapping), total=1):
    print(f'{output[0]}\n{output[2]}\n\n')
    smiles_list.append(output[0])
    smiles_full_list.append(output[0])
    features_list.append(output[1])
    cliques.append(output[2])
    moltree_list.append(output[3])

4it [00:00, 29.00it/s]                       

Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
['C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1', 'CO', 'CO', 'CO', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


CCCOc1ccccc1-c1nnc(NN)s1.Cl
['CCC', 'C1=CC=C(C2=NN=CS2)C=C1', 'NN', 'O']


CC[N+](C)(CC)CCC(=O)Nc1ccc2c(c1)C(=O)c1cc(NC(=O)CC[N+](C)(CC)CC)ccc1-2.[I-].[I-]
['CC', 'CC', 'CC', 'C1=CC=C2C(=C1)CC1=CC=CC=C12', 'CC', 'CC', 'CC', 'C', 'N', 'C', 'N', '[NH4+]', '[NH4+]', 'O', 'C', 'O', 'O', 'C']


Cl.O=C(COc1ccccc1)N1CCN(CC(O)COc2ccccc2)CC1
['C1=CC=CC=C1', 'C1CNCCN1', 'C1=CC=CC=C1', 'C', 'O', 'C', 'O', 'C', 'C', 'O', 'C', 'O']







In [51]:
for i in range(len(moltree_list)):
    print(f'molecular : {moltree_list[i].smiles}')
    for j in range(len(moltree_list[i].nodes)):
        print(f'molecular : {moltree_list[i].smiles}, {j}node : {moltree_list[i].nodes[j].smiles}')
        for k in range(len(moltree_list[i].nodes[j].neighbors)):
            print(f'{j}node, {k}neighbor : {moltree_list[i].nodes[j].neighbors[k].smiles}')
    print('')    

molecular : Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O
molecular : Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O, 0node : C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
0node, 0neighbor : CO
0node, 1neighbor : O
0node, 2neighbor : CO
0node, 3neighbor : O
0node, 4neighbor : O
0node, 5neighbor : O
0node, 6neighbor : O
0node, 7neighbor : O
0node, 8neighbor : O
0node, 9neighbor : O
0node, 10neighbor : O
0node, 11neighbor : CO
0node, 12neighbor : O
molecular : Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O, 1node : CO
1node, 0neighbor : C1=CC=C2CC3=C(C=CC(C4(C5CCCCO5)C5=CC=CC=C5CC5=CC=CC=C54)=C3)CC2=C1
molecular : Cl.O=C1c2cc(CO)cc(O)c2C(=O)c2c1ccc(C1(C3O[C@H](CO)[C@@H](O)[C@H](O)[C@H]3O)c3cccc(O)c3C(=O)c3c(O)cc(CO)cc31)c2O, 2node : CO
2node, 0neighbor : C1=CC=C2CC3=C(C=CC(C