In [1]:
import time
import argparse
import random
import copy
import os
import os.path as osp
import math
from glob import glob
import re

import numpy as np
import sympy as sym
from tqdm import tqdm
import pandas as pd

from typing import Optional, Callable, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from warmup_scheduler import GradualWarmupScheduler

from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.data import (InMemoryDataset, download_url, extract_zip, Data)
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool, global_add_pool, radius
from torch_geometric.utils import remove_self_loops, add_self_loops, sort_edge_index

from utils import BesselBasisLayer, SphericalBasisLayer, EMA, MLP
from layers import Global_MP, Local_MP


import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and puts cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; this is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False

seed_everything(1337)  # Fix the seed

In [3]:
# Test-set table. Samsung.process() reads this module-level frame and iterates
# over its first column to obtain the molecule identifiers used in file names.
test_df = pd.read_csv('data/test_set.csv')

In [4]:
class Samsung(InMemoryDataset):
    """In-memory graph dataset built from the test-set molecules.

    ``process`` iterates over the first column of the module-level ``test_df``
    frame, reads ``data/mol_files/test_set/<name>_g.mol`` for atoms and bonds,
    and reads the coordinates from ``data/test_g_file/<name>_g.csv`` and
    ``data/test_ex_file/<name>_ex.csv``. Each molecule becomes one ``Data``
    object with the fields ``x``, ``z``, ``pos_g``, ``pos_ex``, ``edge_index``,
    ``edge_attr``, ``name`` and ``idx``.

    Args:
        root: Dataset root directory passed to ``InMemoryDataset``.
        transform: Optional transform forwarded to ``InMemoryDataset``.
        pre_transform: Optional transform applied to each ``Data`` before saving.
        pre_filter: Optional predicate; ``Data`` objects it rejects are skipped.
    """

    if rdkit is not None:
        # Atom symbol -> row index into the models' embedding table (stored as ``x``).
        types = {'H':0, 'B':1, 'N':2, 'O':3, 'F':4, 'C':5, 'Si':6, 'P':7, 'S':8, 'Cl':9, 'Br':10, 'I':11}
        # NOTE(review): not referenced anywhere in this class; 'C' maps to 12
        # while the other entries look like atomic numbers -- confirm intent.
        symbols = {'H':1, 'B':5, 'N':7, 'O':8, 'F':9, 'C':12, 'Si':14, 'P':15, 'S':16, 'Cl':17, 'Br':35, 'I':53}
        # RDKit bond type -> class index used for the one-hot ``edge_attr``.
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # The base-class constructor has run by now; load the collated tensors
        # from the processed file.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the raw files the base class looks for under ``raw_dir``."""
        if rdkit is None:
            return ['samsung_test.pt']
        else:
            return ['samsung_test.sdf', 'samsung_test.sdf.csv', 'uncharacterized.txt']

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file holding the collated dataset."""
        return 'samsung_test.pt'

    def process(self):
        """Build the ``Data`` list and save the collated result to ``processed_paths[0]``.

        Raises:
            ValueError: If RDKit cannot parse a molecule's ``_g.mol`` file.
            KeyError: If an atom symbol is missing from ``types`` or a bond type
                is missing from ``bonds``.
        """
        if rdkit is None:
            # rdkit is imported unconditionally at the top of this notebook, so
            # this branch is only reachable if that import is made optional.
            print('Using a pre-processed version of the dataset. Please '
                  'install `rdkit` to alternatively process the raw data.')

            self.data, self.slices = torch.load(self.raw_paths[0])
            data_list = [self.get(i) for i in range(len(self))]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            torch.save(self.collate(data_list), self.processed_paths[0])
            return

        data_list = []
        # NOTE: relies on the module-level `test_df`; its first column holds the molecule names.
        for idx, name in enumerate(tqdm(test_df[test_df.columns[0]])):
            mol_path = f'data/mol_files/test_set/{name}_g.mol'
            g_mol = Chem.MolFromMolFile(mol_path, removeHs=False, sanitize=False)
            if g_mol is None:
                # Fail with a message naming the file instead of an opaque
                # AttributeError on None a few lines later.
                raise ValueError(f'RDKit could not parse {mol_path}')

            N = g_mol.GetNumAtoms()

            # Coordinates come from the CSV files (columns '0', '1', '2'), not from the mol file.
            tmp_g = pd.read_csv(f'data/test_g_file/{name}_g.csv')
            tmp_ex = pd.read_csv(f'data/test_ex_file/{name}_ex.csv')
            pos_g = torch.tensor(np.array(tmp_g[['0', '1', '2']]), dtype=torch.float)
            pos_ex = torch.tensor(np.array(tmp_ex[['0', '1', '2']]), dtype=torch.float)

            atoms = list(g_mol.GetAtoms())
            type_idx = [self.types[atom.GetSymbol()] for atom in atoms]
            atomic_number = [atom.GetAtomicNum() for atom in atoms]

            z = torch.tensor(atomic_number, dtype=torch.long)

            # Store every bond in both directions.
            row, col, edge_type = [], [], []
            for bond in g_mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                edge_type += 2 * [self.bonds[bond.GetBondType()]]

            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = F.one_hot(edge_type,
                                  num_classes=len(self.bonds)).to(torch.float)

            # Sort edges by (source, target) so the stored order is canonical.
            perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, perm]
            edge_attr = edge_attr[perm]

            x = torch.tensor(type_idx).to(torch.float)

            data = Data(x=x, z=z, pos_g=pos_g, pos_ex=pos_ex, edge_index=edge_index,
                        edge_attr=edge_attr, name=name, idx=idx)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)
        torch.save(self.collate(data_list), self.processed_paths[0])

In [5]:
path = 'data/samsung'
test_dataset = Samsung(path)
print('# of graphs:', len(test_dataset))
# shuffle=False keeps predictions in dataset order, which the submission cell relies on.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Prefer GPU 3 as before, but fall back to GPU 0 on machines exposing fewer
# than four devices instead of failing with an invalid device ordinal.
if torch.cuda.is_available():
    _gpu_index = 3 if torch.cuda.device_count() > 3 else 0
    device = torch.device(f'cuda:{_gpu_index}')
else:
    device = torch.device('cpu')

# Model hyper-parameters passed to Config in the cells that build the models.
n_layer = 6
dim = 128
cutoff = 5.0

# of graphs: 457




In [6]:
def test(loader, model):
    error = 0
    ema.assign(model)
    with torch.no_grad():
        pred = []
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred.append(output.cpu())
    ema.resume(model)
    return torch.cat(pred).numpy()

In [7]:
class Config:
    """Hyper-parameter bundle handed to the MXMNet models and their message-passing layers.

    Attributes are plain copies of the constructor arguments:
    ``dim``, ``n_layer`` and ``cutoff``.
    """

    def __init__(self, dim, n_layer, cutoff):
        self.dim, self.n_layer, self.cutoff = dim, n_layer, cutoff

class MXMNetG(nn.Module):
    """Message-passing network evaluated on the ground-state coordinates ``data.pos_g``.

    Each of the ``n_layer`` steps applies one ``Global_MP`` layer on a
    radius graph (all atom pairs within ``cutoff``) followed by one
    ``Local_MP`` layer on the bond graph ``data.edge_index``. The per-node
    outputs of the local layers are summed over layers and add-pooled per graph.

    Args:
        config: Object exposing ``dim``, ``n_layer`` and ``cutoff``.
        num_spherical: Forwarded to ``SphericalBasisLayer``.
        num_radial: Forwarded to ``SphericalBasisLayer``.
        envelope_exponent: Forwarded to the Bessel and spherical basis layers.
    """

    # The annotation is a string so the class body does not need ``Config`` at definition time.
    def __init__(self, config: "Config", num_spherical=7, num_radial=6, envelope_exponent=5):
        super(MXMNetG, self).__init__()

        self.dim = config.dim
        self.n_layer = config.n_layer
        self.cutoff = config.cutoff

        # 12 learnable rows, selected in forward() by the atom-type index stored in data.x.
        self.embeddings = nn.Parameter(torch.ones((12, self.dim)))

        # NOTE(review): the local RBF and the SBF use a literal 5 where the
        # global RBF uses self.cutoff -- confirm this is intentional.
        self.rbf_l = BesselBasisLayer(16, 5, envelope_exponent)
        self.rbf_g = BesselBasisLayer(16, self.cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, 5, envelope_exponent)

        self.rbf_g_mlp = MLP([16, self.dim])
        self.rbf_l_mlp = MLP([16, self.dim])

        self.sbf_1_mlp = MLP([num_spherical * num_radial, self.dim])
        self.sbf_2_mlp = MLP([num_spherical * num_radial, self.dim])

        # All global layers are built before any local layer, as in the original loops.
        self.global_layers = torch.nn.ModuleList(
            [Global_MP(config) for _ in range(config.n_layer)])
        self.local_layers = torch.nn.ModuleList(
            [Local_MP(config) for _ in range(config.n_layer)])

        self.init()

    def init(self):
        """Re-initialise the embedding table uniformly in ``[-sqrt(3), sqrt(3)]``."""
        stdv = math.sqrt(3)
        self.embeddings.data.uniform_(-stdv, stdv)

    @staticmethod
    def _angle_between(vec_a, vec_b):
        """Return the angle between corresponding rows of two ``(n, 3)`` tensors.

        The cross product is taken explicitly over the last dimension. Without
        ``dim`` ``torch.cross`` uses the first dimension of size 3, which is the
        wrong axis when ``n == 3``.
        """
        dot = (vec_a * vec_b).sum(dim=-1)
        cross_norm = torch.cross(vec_a, vec_b, dim=-1).norm(dim=-1)
        return torch.atan2(cross_norm, dot)

    def indices(self, edge_index, num_nodes):
        """Compute the node and edge indices used to define the two kinds of angles.

        Args:
            edge_index: ``(2, E)`` tensor of edges (the local/bond graph).
            num_nodes: Number of nodes; sets the sparse adjacency size.

        Returns:
            ``(idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2,
            idx_jj, idx_ji_2)`` -- node indices and edge ids for the two-hop
            angles followed by those for the one-hop angles.
        """
        row, col = edge_index

        # Sparse adjacency whose entry (col[e], row[e]) stores the edge id e.
        value = torch.arange(row.size(0), device=row.device)
        adj_t = SparseTensor(row=col, col=row, value=value,
                             sparse_sizes=(num_nodes, num_nodes))

        # Compute the node indices for two-hop angles
        adj_t_row = adj_t[row]
        num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

        idx_i = col.repeat_interleave(num_triplets)
        idx_j = row.repeat_interleave(num_triplets)
        idx_k = adj_t_row.storage.col()
        # Drop degenerate triplets whose two end nodes coincide.
        mask = idx_i != idx_k
        idx_i_1, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

        idx_kj = adj_t_row.storage.value()[mask]
        idx_ji_1 = adj_t_row.storage.row()[mask]

        # Compute the node indices for one-hop angles
        adj_t_col = adj_t[col]

        num_pairs = adj_t_col.set_value(None).sum(dim=1).to(torch.long)
        idx_i_2 = row.repeat_interleave(num_pairs)
        idx_j1 = col.repeat_interleave(num_pairs)
        idx_j2 = adj_t_col.storage.col()

        idx_ji_2 = adj_t_col.storage.row()
        idx_jj = adj_t_col.storage.value()

        return idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2


    def forward(self, data):
        """Predict one scalar per graph from ``data.x``, ``data.edge_index``, ``data.pos_g`` and ``data.batch``.

        Returns:
            1-D tensor obtained by flattening the add-pooled node outputs.
        """
        x = data.x
        edge_index = data.edge_index
        pos = data.pos_g
        batch = data.batch
        # Initialize node embeddings
        h = torch.index_select(self.embeddings, 0, x.long())

        # Get the edges and pairwise distances in the local layer
        edge_index_l, _ = remove_self_loops(edge_index)
        j_l, i_l = edge_index_l
        dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt()

        # Get the edges pairwise distances in the global layer
        row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500)
        edge_index_g = torch.stack([row, col], dim=0)
        edge_index_g, _ = remove_self_loops(edge_index_g)
        j_g, i_g = edge_index_g
        dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt()

        # Compute the node indices for defining the angles
        idx_i_1, idx_j, idx_k, idx_kj, idx_ji, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0))

        # Compute the two-hop angles
        pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j]
        angle_1 = self._angle_between(pos_ji_1, pos_kj)

        # Compute the one-hop angles
        pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1]
        angle_2 = self._angle_between(pos_ji_2, pos_jj)

        # Get the RBF and SBF embeddings
        rbf_g = self.rbf_g(dist_g)
        rbf_l = self.rbf_l(dist_l)
        sbf_1 = self.sbf(dist_l, angle_1, idx_kj)
        sbf_2 = self.sbf(dist_l, angle_2, idx_jj)

        rbf_g = self.rbf_g_mlp(rbf_g)
        rbf_l = self.rbf_l_mlp(rbf_l)
        sbf_1 = self.sbf_1_mlp(sbf_1)
        sbf_2 = self.sbf_2_mlp(sbf_2)

        # Perform the message passing schemes
        node_sum = 0

        for layer in range(self.n_layer):
            h = self.global_layers[layer](h, rbf_g, edge_index_g)
            h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l)
            node_sum += t

        # Readout
        output = global_add_pool(node_sum, batch)
        return output.view(-1)
    
class MXMNetEX(nn.Module):
    """Message-passing network evaluated on the excited-state coordinates ``data.pos_ex``.

    Each of the ``n_layer`` steps applies one ``Global_MP`` layer on a
    radius graph (all atom pairs within ``cutoff``) followed by one
    ``Local_MP`` layer on the bond graph ``data.edge_index``. The per-node
    outputs of the local layers are summed over layers and add-pooled per graph.

    Args:
        config: Object exposing ``dim``, ``n_layer`` and ``cutoff``.
        num_spherical: Forwarded to ``SphericalBasisLayer``.
        num_radial: Forwarded to ``SphericalBasisLayer``.
        envelope_exponent: Forwarded to the Bessel and spherical basis layers.
    """

    # The annotation is a string so the class body does not need ``Config`` at definition time.
    def __init__(self, config: "Config", num_spherical=7, num_radial=6, envelope_exponent=5):
        super(MXMNetEX, self).__init__()

        self.dim = config.dim
        self.n_layer = config.n_layer
        self.cutoff = config.cutoff

        # 12 learnable rows, selected in forward() by the atom-type index stored in data.x.
        self.embeddings = nn.Parameter(torch.ones((12, self.dim)))

        # NOTE(review): the local RBF and the SBF use a literal 5 where the
        # global RBF uses self.cutoff -- confirm this is intentional.
        self.rbf_l = BesselBasisLayer(16, 5, envelope_exponent)
        self.rbf_g = BesselBasisLayer(16, self.cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, 5, envelope_exponent)

        self.rbf_g_mlp = MLP([16, self.dim])
        self.rbf_l_mlp = MLP([16, self.dim])

        self.sbf_1_mlp = MLP([num_spherical * num_radial, self.dim])
        self.sbf_2_mlp = MLP([num_spherical * num_radial, self.dim])

        # All global layers are built before any local layer, as in the original loops.
        self.global_layers = torch.nn.ModuleList(
            [Global_MP(config) for _ in range(config.n_layer)])
        self.local_layers = torch.nn.ModuleList(
            [Local_MP(config) for _ in range(config.n_layer)])

        self.init()

    def init(self):
        """Re-initialise the embedding table uniformly in ``[-sqrt(3), sqrt(3)]``."""
        stdv = math.sqrt(3)
        self.embeddings.data.uniform_(-stdv, stdv)

    @staticmethod
    def _angle_between(vec_a, vec_b):
        """Return the angle between corresponding rows of two ``(n, 3)`` tensors.

        The cross product is taken explicitly over the last dimension. Without
        ``dim`` ``torch.cross`` uses the first dimension of size 3, which is the
        wrong axis when ``n == 3``.
        """
        dot = (vec_a * vec_b).sum(dim=-1)
        cross_norm = torch.cross(vec_a, vec_b, dim=-1).norm(dim=-1)
        return torch.atan2(cross_norm, dot)

    def indices(self, edge_index, num_nodes):
        """Compute the node and edge indices used to define the two kinds of angles.

        Args:
            edge_index: ``(2, E)`` tensor of edges (the local/bond graph).
            num_nodes: Number of nodes; sets the sparse adjacency size.

        Returns:
            ``(idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2,
            idx_jj, idx_ji_2)`` -- node indices and edge ids for the two-hop
            angles followed by those for the one-hop angles.
        """
        row, col = edge_index

        # Sparse adjacency whose entry (col[e], row[e]) stores the edge id e.
        value = torch.arange(row.size(0), device=row.device)
        adj_t = SparseTensor(row=col, col=row, value=value,
                             sparse_sizes=(num_nodes, num_nodes))

        # Compute the node indices for two-hop angles
        adj_t_row = adj_t[row]
        num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

        idx_i = col.repeat_interleave(num_triplets)
        idx_j = row.repeat_interleave(num_triplets)
        idx_k = adj_t_row.storage.col()
        # Drop degenerate triplets whose two end nodes coincide.
        mask = idx_i != idx_k
        idx_i_1, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

        idx_kj = adj_t_row.storage.value()[mask]
        idx_ji_1 = adj_t_row.storage.row()[mask]

        # Compute the node indices for one-hop angles
        adj_t_col = adj_t[col]

        num_pairs = adj_t_col.set_value(None).sum(dim=1).to(torch.long)
        idx_i_2 = row.repeat_interleave(num_pairs)
        idx_j1 = col.repeat_interleave(num_pairs)
        idx_j2 = adj_t_col.storage.col()

        idx_ji_2 = adj_t_col.storage.row()
        idx_jj = adj_t_col.storage.value()

        return idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2


    def forward(self, data):
        """Predict one scalar per graph from ``data.x``, ``data.edge_index``, ``data.pos_ex`` and ``data.batch``.

        Returns:
            1-D tensor obtained by flattening the add-pooled node outputs.
        """
        x = data.x
        edge_index = data.edge_index
        pos = data.pos_ex
        batch = data.batch
        # Initialize node embeddings
        h = torch.index_select(self.embeddings, 0, x.long())

        # Get the edges and pairwise distances in the local layer
        edge_index_l, _ = remove_self_loops(edge_index)
        j_l, i_l = edge_index_l
        dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt()

        # Get the edges pairwise distances in the global layer
        row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500)
        edge_index_g = torch.stack([row, col], dim=0)
        edge_index_g, _ = remove_self_loops(edge_index_g)
        j_g, i_g = edge_index_g
        dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt()

        # Compute the node indices for defining the angles
        idx_i_1, idx_j, idx_k, idx_kj, idx_ji, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0))

        # Compute the two-hop angles
        pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j]
        angle_1 = self._angle_between(pos_ji_1, pos_kj)

        # Compute the one-hop angles
        pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1]
        angle_2 = self._angle_between(pos_ji_2, pos_jj)

        # Get the RBF and SBF embeddings
        rbf_g = self.rbf_g(dist_g)
        rbf_l = self.rbf_l(dist_l)
        sbf_1 = self.sbf(dist_l, angle_1, idx_kj)
        sbf_2 = self.sbf(dist_l, angle_2, idx_jj)

        rbf_g = self.rbf_g_mlp(rbf_g)
        rbf_l = self.rbf_l_mlp(rbf_l)
        sbf_1 = self.sbf_1_mlp(sbf_1)
        sbf_2 = self.sbf_2_mlp(sbf_2)

        # Perform the message passing schemes
        node_sum = 0

        for layer in range(self.n_layer):
            h = self.global_layers[layer](h, rbf_g, edge_index_g)
            h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l)
            node_sum += t

        # Readout
        output = global_add_pool(node_sum, batch)
        return output.view(-1)

# Predict G

In [8]:
class Double_MXMNET(nn.Module):
    """Two-branch model whose output is ``excited(data) - ground(data)``.

    ``ground`` is an ``MXMNetG`` and ``excited`` an ``MXMNetEX``, both built
    from the same ``config``.
    """

    def __init__(self, config):
        super().__init__()
        # Build the ground branch first, then the excited branch.
        self.ground = MXMNetG(config)
        self.excited = MXMNetEX(config)

    def forward(self, data):
        """Return the excited-branch output minus the ground-branch output."""
        # The ground branch is evaluated before the excited branch.
        ground_out = self.ground(data)
        return self.excited(data) - ground_out

In [9]:
# Rebuild the architecture from the hyper-parameters defined earlier in the notebook.
config = Config(dim=dim, n_layer=n_layer, cutoff=cutoff)
model = Double_MXMNET(config)

# Restore the weights saved for the "g" target, then switch to eval mode on `device`.
best_checkpoint = torch.load('models/Double_MXMNET_for_g.pth', map_location=device)
model.load_state_dict(best_checkpoint)
model = model.eval().to(device)

print('Loaded the MXMNet.')

# test() reads the module-level `ema`, so it must exist before the call.
ema = EMA(model, decay=0.999)
g_pred = test(loader=test_loader, model=model)

Loaded the MXMNet.


# Predict EX

In [10]:
class Double_MXMNET(nn.Module):
    """Two-branch model whose output is ``ground(data) - excited(data)``.

    ``ground`` is an ``MXMNetG`` and ``excited`` an ``MXMNetEX``, both built
    from the same ``config``. Defining this class replaces any earlier
    ``Double_MXMNET`` already bound in the kernel.
    """

    def __init__(self, config):
        super().__init__()
        # Build the ground branch first, then the excited branch.
        self.ground = MXMNetG(config)
        self.excited = MXMNetEX(config)

    def forward(self, data):
        """Return the ground-branch output minus the excited-branch output."""
        # The ground branch is evaluated before the excited branch.
        ground_out = self.ground(data)
        return ground_out - self.excited(data)

In [11]:
# Rebuild the architecture from the hyper-parameters defined earlier in the notebook.
config = Config(dim=dim, n_layer=n_layer, cutoff=cutoff)
model = Double_MXMNET(config)

# Restore the weights saved for the "ex" target, then switch to eval mode on `device`.
best_checkpoint = torch.load('models/Double_MXMNET_for_ex.pth', map_location=device)
model.load_state_dict(best_checkpoint)
model = model.eval().to(device)

print('Loaded the MXMNet.')

# test() reads the module-level `ema`, so it must exist before the call.
ema = EMA(model, decay=0.999)
ex_pred = test(loader=test_loader, model=model)

Loaded the MXMNet.


# Make submission file

In [12]:
# Fill the sample submission with both prediction vectors; the predictions are
# in dataset order because the loader was built with shuffle=False.
submit = pd.read_csv('data/sample_submission.csv').assign(
    Reorg_ex=ex_pred,
    Reorg_g=g_pred,
)

submit.to_csv('submission.csv', index=False)
print('Done.')

submit

Done.


Unnamed: 0,index,Reorg_g,Reorg_ex
0,test_0,0.387975,0.309783
1,test_1,0.231006,0.210665
2,test_2,0.285344,0.268960
3,test_3,0.393129,0.264282
4,test_4,0.167339,0.208210
...,...,...,...
452,test_452,0.220338,0.224084
453,test_453,0.171201,0.147907
454,test_454,0.211920,0.168242
455,test_455,0.198944,0.164223
