In [None]:
# Mount Google Drive so files under MyDrive are readable from this Colab runtime.
from google.colab import drive

drive.mount('/content/drive/')

Mounted at /content/drive/


# Data

In [None]:
! pip install rdkit

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rdkit
  Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)
[K     |████████████████████████████████| 36.8 MB 1.2 MB/s 
Installing collected packages: rdkit
Successfully installed rdkit-2022.3.5


In [None]:
import pickle
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import os
import sys
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tqdm import tqdm
from sklearn.model_selection import KFold, train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler
import random
import argparse
import pdb
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


###############################################
#                                             #
#              Dataset Base Class             #
#                                             #
###############################################


def onek_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set`` as a list of bools.

    A value that is not in ``allowable_set`` is encoded as the set's last
    element, so the last slot acts as the catch-all bucket.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]

def onek_encoding(x, allowable_set):
    """Strict one-hot encoding of ``x`` against ``allowable_set``.

    Returns a list of bools with True at the position(s) equal to ``x``.

    Raises:
        Exception: if ``x`` is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception('input {0} not in allowable set{1}:'.format(x, allowable_set))

def featurization(x):
    """Identity featurizer: returns its argument unchanged (placeholder hook)."""
    featurized = x
    return featurized

def check_exists(path):
    """Return True if ``path`` is an existing regular file with non-zero size.

    Directories, missing paths and empty files all return False.
    """
    # Both operands are already bools; the `True if ... else False` wrapper was redundant.
    return os.path.isfile(path) and os.path.getsize(path) > 0

def add_index(input_array, ebd_size):
    """Convert per-molecule neighbor indices into indices over a flattened batch.

    ``input_array`` is a padded (batch, n_rows, n_nbs) array whose entries are
    indices local to each molecule. Molecule ``i`` is shifted by
    ``i * ebd_size``, so the result addresses rows of a padded per-molecule
    feature tensor reshaped to (batch * ebd_size, features). Pass the padded
    length of the tensor that will be indexed as ``ebd_size``.

    The input array is not modified.

    Returns:
        1-D array of length batch * n_rows * n_nbs.
    """
    # The previous version ignored `ebd_size`, used only the previous molecule's
    # atom count as offset (not a running total), mutated `input_array` in place
    # and raised on molecules without any neighbors.
    batch_size = input_array.shape[0]
    offsets = np.arange(batch_size).reshape(-1, 1, 1) * ebd_size
    return (input_array + offsets).reshape(-1)


class DtiDatasetBase(Dataset):
    """Base dataset that loads and filters DTI metadata tables.

    For every '+'-separated subset in ``args.dataset_subsets`` it reads
    ``complex_metadata_<subset>.csv``, ``protein_metadata_<subset>.csv`` and
    ``ligand_metadata_<subset>.csv`` from ``<root_path>/dataset_<dataset_version>/``.
    It keeps complexes whose ``ba_measure`` equals ``args.ba_measure``. Unless
    ``args.inference_mode`` is set, it also drops complexes without a ``ba_value``.
    Subclasses fill ``data_instances`` / ``meta_instances`` (and ``indices``).
    """

    def __init__(self, args):
        self.args = args
        self.data_instances, self.meta_instances = [], []

        # When True, __getitem__ also returns the item's metadata tuple.
        self.analysis_mode = False

        # Gathering all metadata from the requested DTI dataset subsets.
        self.data_path = os.path.join(args.root_path, f'dataset_{args.dataset_version}/')
        complex_dataframe, protein_dataframe, ligand_dataframe = [], [], []

        for dataset in args.dataset_subsets.split('+'):
            complex_path = f'{self.data_path}complex_metadata_{dataset}.csv'
            protein_path = f'{self.data_path}protein_metadata_{dataset}.csv'
            ligand_path = f'{self.data_path}ligand_metadata_{dataset}.csv'
            complex_dataframe.append(pd.read_csv(complex_path, index_col='complex_id'))
            protein_dataframe.append(pd.read_csv(protein_path, index_col='protein_id'))
            ligand_dataframe.append(pd.read_csv(ligand_path, index_col='ligand_id'))

        self.complex_dataframe = pd.concat(complex_dataframe)
        self.protein_dataframe = pd.concat(protein_dataframe)
        self.ligand_dataframe = pd.concat(ligand_dataframe)

        # Keep one affinity measure. Build a new frame rather than calling
        # dropna(inplace=True) on a filtered slice (pandas chained-assignment trap).
        selected = self.complex_dataframe[self.complex_dataframe['ba_measure'] == args.ba_measure]
        if not args.inference_mode:
            selected = selected.dropna(subset=['ba_value'], axis=0)
        self.complex_dataframe = selected

        self.complex_indices = self.complex_dataframe.index
        if args.debug_mode or args.toy_test:
            self.complex_indices = self.complex_dataframe.index[:args.debug_index]

        self.kfold_splits = []

        # Which features to include (e.g. 'esm+blosum+onehot').
        self.protein_features = args.protein_features
        self.ligand_features = args.ligand_features

    def check_ligand(self, ligand_idx):
        """Ligand filter hook; accepts every ligand by default."""
        return

    def check_protein(self, protein_idx):
        """Raise FastaLengthException for proteins with 1000 or more residues."""
        # Previously read `self.pdf`, an attribute that is never set (AttributeError).
        fasta_length = self.protein_dataframe.loc[protein_idx, 'fasta_length']
        if fasta_length >= 1000:
            raise FastaLengthException(fasta_length)

    def check_complex(self, complex_idx):
        """Complex filter hook; accepts every complex by default."""
        return

    def __len__(self):
        return len(self.data_instances)

    def __getitem__(self, idx):
        """Return the model inputs, plus the metadata tuple when in analysis mode."""
        if self.analysis_mode:
            return self.data_instances[idx], self.meta_instances[idx]
        return self.data_instances[idx]

    def make_random_splits(self, n_splits=5, valid_size=0.05, seed=None):
        """Build shuffled K-fold (train, valid, test) index splits in ``self.kfold_splits``.

        Each fold's training part is further split so that ``valid_size`` of it
        becomes the validation set.

        Args:
            n_splits: number of folds.
            valid_size: fraction of each fold's training indices used for validation.
            seed: ``random_state`` for KFold and train_test_split. None keeps the
                previous behavior of drawing from NumPy's global RNG.
        """
        print("Making Random Splits")
        # Reset so that calling this twice does not append a second set of folds.
        self.kfold_splits = []
        indices = getattr(self, 'indices', None)
        if indices is None:
            indices = list(range(len(self)))
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
        for train_indices, test_indices in kf.split(indices):
            train_indices, valid_indices = train_test_split(
                train_indices, test_size=valid_size, random_state=seed)
            self.kfold_splits.append((train_indices, valid_indices, test_indices))

class FastaLengthException(Exception):
    """Signals that a protein sequence is too long to be used.

    Attributes:
        fasta_length: the offending sequence length.
        message: human-readable reason (also the exception's only arg).
    """

    def __init__(self, fasta_length, message="fasta length should not exceed 1000"):
        super().__init__(message)
        self.message = message
        self.fasta_length = fasta_length

    def __str__(self):
        return f'{self.fasta_length} -> {self.message}'

class NoProteinGraphException(Exception):
    """Signals that a protein's graph-structure file is missing.

    Attributes:
        protein_idx: identifier of the protein concerned.
        message: human-readable reason (also the exception's only arg).
    """

    def __init__(self, protein_idx, message="protein graph structure file not available"):
        super().__init__(message)
        self.message = message
        self.protein_idx = protein_idx

    def __str__(self):
        return f'{self.protein_idx} -> {self.message}'

class NoProteinFeaturesException(Exception):
    """Signals that a protein's advanced-features file is missing.

    Attributes:
        protein_idx: identifier of the protein concerned.
        message: human-readable reason (also the exception's only arg).
    """

    def __init__(self, protein_idx, message="protein advanced features file not available"):
        super().__init__(message)
        self.message = message
        self.protein_idx = protein_idx

    def __str__(self):
        return f'{self.protein_idx} -> {self.message}'

class NoComplexGraphException(Exception):
    """Signals that a complex's advanced-features file is missing.

    Attributes:
        complex_idx: identifier of the complex concerned.
        message: human-readable reason (also the exception's only arg).
    """

    def __init__(self, complex_idx, message="complex advanced features file not available"):
        super().__init__(message)
        self.message = message
        self.complex_idx = complex_idx

    def __str__(self):
        return f'{self.complex_idx} -> {self.message}'


###############################################
#                                             #
#              Collate Functions              #
#                                             #
###############################################

def stack_and_pad(arr_list, max_length=None):
    """Zero-pad 2-D arrays to a common shape and stack them into a batch.

    A 3-D input of shape (n, m, k) is first collapsed to a binary (n, m) map
    that is 1 wherever any of its k channels sums positive.

    Args:
        arr_list: non-empty list of arrays of shape (n_i, m_i) or (n_i, m_i, k_i).
        max_length: optional padded size of axis 0 (must be >= every n_i);
            defaults to the largest n_i.

    Returns:
        T: (batch, M, N) padded data.
        t: (batch, M) row mask. Sample i's rows are 1 only if the sample has
           any non-zero entry, so an all-zero sample gets an all-zero mask.
        s: (batch, M, N) element mask following the same per-sample rule.
    """
    M = max_length if max_length else max(x.shape[0] for x in arr_list)
    N = max(x.shape[1] for x in arr_list)
    batch = len(arr_list)
    T = np.zeros((batch, M, N))
    t = np.zeros((batch, M))
    s = np.zeros((batch, M, N))

    for i, arr in enumerate(arr_list):
        # Collapse interaction channels: one "any interaction" map is enough.
        if arr.ndim > 2:
            arr = (arr.sum(axis=2) > 0.0).astype(float)
        rows, cols = arr.shape[0], arr.shape[1]
        T[i, :rows, :cols] = arr
        # np.any instead of `sum() != 0`: mixed-sign data (e.g. BLOSUM scores)
        # could cancel to zero and wrongly mask out a real sample.
        flag = 1 if np.any(arr) else 0
        t[i, :rows] = flag
        s[i, :rows, :cols] = flag
    return T, t, s

def stack_and_pad_2d(arr_list, block='lower_left', max_length=None):
    """Zero-pad arrays into a (batch, max0, max1) array, each placed in a chosen corner.

    Args:
        arr_list: list of arrays of shape (n_i, m_i) or (n_i, m_i, k_i). A 3-D
            array is collapsed to a binary map (1 where any channel sums > 0).
            The caller's list is not modified.
        block: corner for each array: 'upper_left', 'upper_right',
            'lower_left' or 'lower_right'.
        max_length: optional padded size of axis 0 (defaults to the largest n_i).

    Returns:
        (padded, masks_2d, masks_3d) with shapes (batch, max0, max1),
        (batch, max0) and (batch, max0, max1). The masks are 1 over each
        array's region.

    Raises:
        ValueError: if ``block`` is not one of the four corners (previously a
            bare ``raise`` that produced an unrelated RuntimeError).
    """
    if block not in ('upper_left', 'upper_right', 'lower_left', 'lower_right'):
        raise ValueError(f'unknown block {block!r}')
    max0 = max_length if max_length else max(a.shape[0] for a in arr_list)
    max1 = max(a.shape[1] for a in arr_list)

    final_result = np.zeros((len(arr_list), max0, max1))
    final_masks_2d = np.zeros((len(arr_list), max0))
    final_masks_3d = np.zeros((len(arr_list), max0, max1))

    for i, arr in enumerate(arr_list):
        if arr.ndim > 2:
            # Any positive interaction channel counts; the old `== True` kept
            # only cells whose channel sum was exactly 1. Applied for every block.
            arr = (arr.sum(axis=2) > 0).astype(float)
        n0, n1 = arr.shape[0], arr.shape[1]
        rows = slice(0, n0) if block.startswith('upper') else slice(max0 - n0, max0)
        cols = slice(0, n1) if block.endswith('left') else slice(max1 - n1, max1)
        final_result[i, rows, cols] = arr
        final_masks_2d[i, rows] = 1
        final_masks_3d[i, rows, cols] = 1

    return final_result, final_masks_2d, final_masks_3d

def stack_and_pad_3d(arr_list, block='lower_left'):
    """Zero-pad 3-D arrays into a (batch, max0, max1, max2) array placed in a chosen corner.

    The corner (``block``) applies to the first two axes. The last axis is
    always left-aligned (``:n2``), in every branch. Previously only 'upper_left'
    sliced it, and the others broadcast or crashed when n2 < max2.

    Args:
        arr_list: list of arrays of shape (n_i, m_i, k_i).
        block: 'upper_left', 'upper_right', 'lower_left' or 'lower_right'.

    Returns:
        (padded, masks_2d, masks_3d, masks_4d). The 2-D/3-D masks are 1 over
        each array's (rows) / (rows, cols) region. The 4-D mask is 1 over the
        (rows, cols) region across the *whole* last axis, as before.

    Raises:
        ValueError: if ``block`` is not one of the four corners.
    """
    if block not in ('upper_left', 'upper_right', 'lower_left', 'lower_right'):
        raise ValueError(f'unknown block {block!r}')
    max0 = max(a.shape[0] for a in arr_list)
    max1 = max(a.shape[1] for a in arr_list)
    max2 = max(a.shape[2] for a in arr_list)

    final_result = np.zeros((len(arr_list), max0, max1, max2))
    final_masks_2d = np.zeros((len(arr_list), max0))
    final_masks_3d = np.zeros((len(arr_list), max0, max1))
    final_masks_4d = np.zeros((len(arr_list), max0, max1, max2))

    for i, arr in enumerate(arr_list):
        n0, n1, n2 = arr.shape[0], arr.shape[1], arr.shape[2]
        rows = slice(0, n0) if block.startswith('upper') else slice(max0 - n0, max0)
        cols = slice(0, n1) if block.endswith('left') else slice(max1 - n1, max1)
        final_result[i, rows, cols, :n2] = arr
        final_masks_2d[i, rows] = 1
        final_masks_3d[i, rows, cols] = 1
        final_masks_4d[i, rows, cols, :] = 1

    return final_result, final_masks_2d, final_masks_3d, final_masks_4d

def ds_normalize(input_array):
    """Doubly-stochastic normalization of a multi-channel edge tensor.

    Implements the edge normalization of Gong & Cheng, "Exploiting Edge
    Features for Graph Neural Networks" (CVPR 2019), per channel p:

        E~[i, j, p] = E[i, j, p] / sum_k E[i, k, p]
        E'[i, j, p] = sum_k E~[i, k, p] * E~[j, k, p] / sum_v E~[v, k, p]

    For non-negative input each channel of the result has rows and columns
    that sum to 1 (up to the 1e-8 stabilizers; all-zero rows stay zero). The
    previous version divided the product by its column sums instead, which
    normalized columns only.

    Args:
        input_array: array of shape (N, M, P), expected non-negative.

    Returns:
        Array of shape (N, N, P).
    """
    assert len(input_array.shape) == 3
    row_normalized = input_array / np.expand_dims(input_array.sum(1) + 1e-8, axis=1)
    # sum_v E~[v, k, p], shape (1, M, P); it sits inside the sum over k.
    col_sums = row_normalized.sum(0, keepdims=True) + 1e-8
    return np.einsum('ikp,jkp->ijp', row_normalized, row_normalized / col_sums)

In [None]:
def add_index(input_array, ebd_size):
    """Convert per-molecule neighbor indices into indices over a flattened batch.

    ``input_array`` is a padded (batch, n_rows, n_nbs) array whose entries are
    indices local to each molecule. Molecule ``i`` is shifted by
    ``i * ebd_size``, so the result addresses rows of a padded per-molecule
    feature tensor reshaped to (batch * ebd_size, features). Pass the padded
    length of the tensor that will be indexed as ``ebd_size``.

    The input array is not modified.

    Returns:
        1-D array of length batch * n_rows * n_nbs.
    """
    # NOTE(review): this cell redefines add_index and shadows any earlier
    # definition; keep both copies identical or delete one.
    # The previous version ignored `ebd_size`, used only the previous molecule's
    # atom count as offset (not a running total), mutated `input_array` in place
    # and raised on molecules without any neighbors.
    batch_size = input_array.shape[0]
    offsets = np.arange(batch_size).reshape(-1, 1, 1) * ebd_size
    return (input_array + offsets).reshape(-1)

In [None]:
# BLOSUM62 substitution scores used as 20-dim residue features (looked up by resi_features).
# Column order appears to follow the key order A C D E F G H I K L M N P Q R S T V W Y
# (each row's diagonal score sits at that residue's own position).
# 'X' and 'unk' both map to all-zero rows.
BLOSUM_DICT = {
	'A': [4,0,-2,-1,-2,0,-2,-1,-1,-1,-1,-2,-1,-1,-1,1,0,0,-3,-2],
	'C': [0,9,-3,-4,-2,-3,-3,-1,-3,-1,-1,-3,-3,-3,-3,-1,-1,-1,-2,-2],
	'D': [-2,-3,6,2,-3,-1,-1,-3,-1,-4,-3,1,-1,0,-2,0,-1,-3,-4,-3],
	'E': [-1,-4,2,5,-3,-2,0,-3,1,-3,-2,0,-1,2,0,0,-1,-2,-3,-2],
	'F': [-2,-2,-3,-3,6,-3,-1,0,-3,0,0,-3,-4,-3,-3,-2,-2,-1,1,3],
	'G': [0,-3,-1,-2,-3,6,-2,-4,-2,-4,-3,0,-2,-2,-2,0,-2,-3,-2,-3],
	'H': [-2,-3,-1,0,-1,-2,8,-3,-1,-3,-2,1,-2,0,0,-1,-2,-3,-2,2],
	'I': [-1,-1,-3,-3,0,-4,-3,4,-3,2,1,-3,-3,-3,-3,-2,-1,3,-3,-1],
	'K': [-1,-3,-1,1,-3,-2,-1,-3,5,-2,-1,0,-1,1,2,0,-1,-2,-3,-2],
	'L': [-1,-1,-4,-3,0,-4,-3,2,-2,4,2,-3,-3,-2,-2,-2,-1,1,-2,-1],
	'M': [-1,-1,-3,-2,0,-3,-2,1,-1,2,5,-2,-2,0,-1,-1,-1,1,-1,-1],
	'N': [-2,-3,1,0,-3,0,1,-3,0,-3,-2,6,-2,0,0,1,0,-3,-4,-2],
	'P': [-1,-3,-1,-1,-4,-2,-2,-3,-1,-3,-2,-2,7,-1,-2,-1,-1,-2,-4,-3],
	'Q': [-1,-3,0,2,-3,-2,0,-3,1,-2,0,0,-1,5,1,0,-1,-2,-2,-1],
	'R': [-1,-3,-2,0,-3,-2,0,-3,2,-2,-1,0,-2,1,5,-1,-1,-3,-3,-2],
	'S': [1,-1,0,0,-2,0,-1,-2,0,-2,-1,1,-1,0,-1,4,1,-2,-3,-2],
	'T': [0,-1,-1,-1,-2,-2,-2,-1,-1,-1,-1,0,-1,-1,-1,1,5,0,-2,-2],
	'V': [0,-1,-3,-2,-1,-3,-3,3,-2,1,1,-3,-2,-2,-3,-2,0,4,-3,-1],
	'W': [-3,-2,-4,-3,1,-2,-2,-3,-3,-2,-1,-4,-4,-2,-3,-3,-2,-3,11,2],
	'Y': [-2,-2,-3,-2,3,-3,2,-1,-2,-1,-1,-2,-3,-1,-2,-2,-2,-1,2,7],
    'X': [0 for _ in range(20)],
	'unk':[0 for _ in range(20)]}

# Atom symbols for the one-hot atom-type feature (63 entries). The final 'unk'
# is the fallback bucket that onek_encoding_unk uses for unlisted symbols.
ATOM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
             'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
             'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
             'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
             'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc',
             'Ba', 'Bi', 'Hf', 'Mo', 'U', 'Sm', 'Os', 'Ir', 'Ce','Gd',
             'Ga','Cs', 'unk']


class DtiDataset(DtiDatasetBase):
    """MONN-style DTI dataset built on DtiDatasetBase.

    Each complex is featurized into:
      - the ligand: RDKit atom and bond features plus fixed-width neighbor tables;
      - the protein: one BLOSUM62 row per residue;
      - the atom-residue interaction map: the PLIP ``.npy`` file when present,
        otherwise all zeros.

    Each entry of ``data_instances`` is the tuple that ``collate_fn`` expects:

        (atomwise_features, bondwise_features, atom_nb, bond_nb, num_nbs_mat,
         resiwise_features, atomresi_graph, ba_value)

    Complexes that fail any check or featurization step are skipped. The number
    skipped per exception type is printed after loading; previously they were
    dropped silently.
    """

    # Width of the per-atom neighbor tables (maximum neighbors per atom).
    MAX_NB = 10
    # Length of atom_features() output; also the size of the all-zero fallback row.
    ATOM_FEATURE_LEN = 82
    # Length of bond_features() output.
    BOND_FEATURE_LEN = 6

    def __init__(self, args):
        super().__init__(args)
        self.default_batch_size = 32
        skipped = {}
        for complex_idx in tqdm(self.complex_indices):
            try:
                pytrdata, metadata = self._build_instance(complex_idx)
            except Exception as e:
                # Best-effort loading: skip unusable complexes, but record why.
                reason = type(e).__name__
                skipped[reason] = skipped.get(reason, 0) + 1
                continue
            self.data_instances.append(pytrdata)
            self.meta_instances.append(metadata)

        print("Number of data samples for MONN: ", len(self.data_instances))
        if skipped:
            print("Skipped complexes by exception type:", skipped)
        self.indices = list(range(len(self.data_instances)))

    def _build_instance(self, complex_idx):
        """Featurize one complex; returns (pytrdata, metadata) or raises."""
        ligand_idx = self.complex_dataframe.loc[complex_idx, 'ligand_id']
        protein_idx = self.complex_dataframe.loc[complex_idx, 'protein_id']
        ba_value = self.complex_dataframe.loc[complex_idx, 'ba_value']
        self.check_protein(protein_idx)

        # Ligand / atom and bond features
        smiles = self.ligand_dataframe.loc[ligand_idx, 'smiles']
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'RDKit could not parse the SMILES of ligand {ligand_idx}')
        atomwise_features = self._atomwise_features(mol)
        bondwise_features, atom_nb, bond_nb, num_nbs_mat = self._bond_graph(mol)

        # Protein / residue features
        fasta = self.protein_dataframe.loc[protein_idx, 'fasta']
        resiwise_features = np.vstack(
            [np.array(self.resi_features(resi)).reshape(1, -1) for resi in fasta])

        # Complex / atom-residue interaction graph
        atomresi_graph = self._atomresi_graph(
            complex_idx, atomwise_features.shape[0], resiwise_features.shape[0])

        smiles = ''.join(filter(str.isalpha, smiles))
        metadata = (complex_idx, ligand_idx, protein_idx, smiles, fasta, ba_value)
        pytrdata = (atomwise_features, bondwise_features, atom_nb, bond_nb, num_nbs_mat,
                    resiwise_features, atomresi_graph, ba_value)
        return pytrdata, metadata

    def _atomwise_features(self, mol):
        """Stack atom_features() rows; an all-zero row replaces any atom that fails."""
        rows = []
        for atom in mol.GetAtoms():
            try:
                rows.append(self.atom_features(atom).reshape(1, -1))
            except Exception:
                rows.append(np.zeros(self.ATOM_FEATURE_LEN).reshape(1, -1))
        return np.vstack(rows)

    def _bond_graph(self, mol):
        """Build bond features and the (n_atoms, MAX_NB) neighbor tables of ``mol``.

        Returns (bondwise_features, atom_nb, bond_nb, num_nbs_mat). ``atom_nb`` and
        ``bond_nb`` hold neighbor atom and bond indices; ``num_nbs_mat`` is 1 in the
        slots that are used. Raises IndexError if an atom has more than MAX_NB neighbors.
        """
        n_atoms = mol.GetNumAtoms()
        # Keep at least one bond row; a bond-free molecule gets a single zero row
        # (previously the string 'null', which later crashed collate_fn).
        n_bonds = max(mol.GetNumBonds(), 1)
        bondwise_features = np.zeros((n_bonds, self.BOND_FEATURE_LEN), dtype=np.float32)
        atom_nb = np.zeros((n_atoms, self.MAX_NB), dtype=np.int32)
        bond_nb = np.zeros((n_atoms, self.MAX_NB), dtype=np.int32)
        num_nbs = np.zeros((n_atoms,), dtype=np.int32)
        num_nbs_mat = np.zeros((n_atoms, self.MAX_NB), dtype=np.int32)

        for bond in mol.GetBonds():
            a1, a2 = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bond_idx = bond.GetIdx()
            bondwise_features[bond_idx] = self.bond_features(bond)
            atom_nb[a1, num_nbs[a1]] = a2
            atom_nb[a2, num_nbs[a2]] = a1
            bond_nb[a1, num_nbs[a1]] = bond_idx
            bond_nb[a2, num_nbs[a2]] = bond_idx
            num_nbs[a1] += 1
            num_nbs[a2] += 1
        for atom_idx, count in enumerate(num_nbs):
            num_nbs_mat[atom_idx, :count] = 1
        return bondwise_features, atom_nb, bond_nb, num_nbs_mat

    def _atomresi_graph(self, complex_idx, n_atoms, n_residues):
        """Load the PLIP atom-residue interaction array, or zeros when the file is absent."""
        plip_path = f'{self.data_path}complexes/{complex_idx}/{complex_idx}.plip.npy'
        if check_exists(plip_path):
            # The last channel is excluded (kept from the original code; reason not documented here).
            return np.load(plip_path)[:, :, :-1]
        return np.zeros((n_atoms, n_residues, 1))

    def check_protein(self, protein_idx):
        """Raise FastaLengthException for proteins with 1000 or more residues."""
        if self.protein_dataframe.loc[protein_idx, 'fasta_length'] >= 1000:
            raise FastaLengthException(self.protein_dataframe.loc[protein_idx, 'fasta_length'])

    def check_complex(self, complex_idx):
        """Complex filter hook; accepts every complex."""
        return

    def atom_features(self, atom):
        """82-dim atom vector: symbol (63) + degree (6) + explicit valence (6) + implicit valence (6) + aromatic (1)."""
        return np.array(onek_encoding_unk(atom.GetSymbol(), ATOM_LIST)
                        + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
                        + onek_encoding_unk(atom.GetExplicitValence(), [1, 2, 3, 4, 5, 6])
                        + onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
                        + [atom.GetIsAromatic()], dtype=np.float32)

    def bond_features(self, bond):
        """6-dim bond vector: single/double/triple/aromatic type, conjugated, in ring."""
        bt = bond.GetBondType()
        return np.array([bt == Chem.rdchem.BondType.SINGLE,
                        bt == Chem.rdchem.BondType.DOUBLE,
                        bt == Chem.rdchem.BondType.TRIPLE,
                        bt == Chem.rdchem.BondType.AROMATIC,
                        bond.GetIsConjugated(),
                        bond.IsInRing()], dtype=np.float32)

    def resi_features(self, resi):
        """BLOSUM62 row for a residue letter.

        Letters missing from BLOSUM_DICT fall back to its all-zero 'unk' row
        instead of raising KeyError and dropping the whole complex.
        """
        return BLOSUM_DICT.get(resi, BLOSUM_DICT['unk'])


def collate_fn(batch, device=None):
    """Pad and stack a list of DtiDataset items into a list of 11 tensors.

    Order of the returned list (index: content):
        0 atom features, 1 atom mask, 2 bond features,
        3 flattened atom-neighbor indices, 4 flattened bond-neighbor indices,
        5 neighbor-slot mask, 6 residue features, 7 residue mask,
        8 binding-affinity targets of shape (batch, 1),
        9 atom-residue interaction graph, 10 atom-residue label mask.

    Args:
        batch: list of tuples produced by ``DtiDataset.__getitem__`` with
            ``analysis_mode`` off.
        device: target device. Defaults to CUDA when available, else CPU;
            the original always required CUDA.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def as_float(values):
        return torch.as_tensor(values, dtype=torch.float32, device=device)

    def as_long(values):
        return torch.as_tensor(values, dtype=torch.long, device=device)

    tensor_list = []
    atom_x, atom_mask, _ = stack_and_pad([x[0] for x in batch])
    tensor_list += [as_float(atom_x), as_float(atom_mask)]

    bond_x, _, _ = stack_and_pad([x[1] for x in batch])
    tensor_list.append(as_float(bond_x))

    atom_nb, _, _ = stack_and_pad([x[2] for x in batch])
    tensor_list.append(as_long(add_index(atom_nb, atom_nb.shape[1])))

    bond_nb, _, _ = stack_and_pad([x[3] for x in batch])
    # NOTE(review): the offset here is the padded *atom* count (bond_nb has one row
    # per atom), not the padded bond count of bond_x. This only matches a flattened
    # bond tensor when the two are equal; the exploration cells undo it with the
    # atom offset. Confirm against how the model indexes bond features.
    tensor_list.append(as_long(add_index(bond_nb, bond_nb.shape[1])))

    nb_mat, _, _ = stack_and_pad([x[4] for x in batch])
    tensor_list.append(as_float(nb_mat))

    resi_x, resi_mask, _ = stack_and_pad([x[5] for x in batch])
    tensor_list += [as_float(resi_x), as_float(resi_mask)]

    tensor_list.append(as_float([x[7] for x in batch]).view(-1, 1))

    # Binarize interactions; the 3rd mask is 0 for complexes whose graph is all zeros.
    graph_x, _, graph_mask = stack_and_pad([(x[6] > 0.).astype(np.int_) for x in batch])
    tensor_list += [as_float(graph_x), as_float(graph_mask)]

    return tensor_list

In [None]:
# Expose only the first GPU; this must run before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

SEED = 42

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', default='/content/drive/MyDrive/Colab Notebooks/monn', type=str)
parser.add_argument('--dataset_version', default=220722, type=int)
parser.add_argument('--dataset_subsets', default='pdb_2020_general', type=str)
# Binding-affinity measure to keep, e.g. 'KIKD' or 'IC50'.
parser.add_argument('--ba_measure', default='KIKD', type=str)
# Flags use store_true so that a value like 'False' is not parsed as a truthy string.
parser.add_argument('--inference_mode', action='store_true')
parser.add_argument('--debug_mode', action='store_true')
parser.add_argument('--toy_test', action='store_true')
# Complexes to keep in debug/toy mode; None keeps all (the old default False sliced to zero).
parser.add_argument('--debug_index', default=None, type=int)
parser.add_argument('--protein_features', type=str)
parser.add_argument('--ligand_features', type=str)

parser.add_argument('--fold', default=5, type=int)
parser.add_argument('--batch_size', default=8, type=int)
# parse_known_args ignores the extra arguments Jupyter/Colab pass to the kernel.
args, unknown = parser.parse_known_args()

# Seed every RNG in use. sklearn's KFold/train_test_split fall back to NumPy's
# global RNG when random_state is None, so seeding NumPy makes the splits reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = DtiDataset(args)

100%|██████████| 9136/9136 [00:30<00:00, 296.68it/s]

Number of data samples for MONN:  9085





- Protein(residue) => CNN
BLOSUM62 matrix => 각 residue의 특성을 encoding한 matrix, 20x20  
배치 내 가장 긴 protein의 residue 개수에 맞춰 zero padding  

- Compound(atom) => Graph convolution  
Node feature : 길이가 82인 feature vector (원자 종류 63 + Degree 6 + Explicit valence 6 + implicit valence 6 + aromatic 1 = 82)  
Edge feature(bond) : 길이가 6인 feature vector (결합 종류 4 + conjugate,ring 2 = 6)  

tensor list의 길이는 11 (아래 번호는 1부터 시작하며, 코드에서는 `data[번호-1]`로 접근)  
1,2 : atomwise feature와 atom mask (배치별)  
3 : bondwise feature  
4 : atom neighbors  
5 : bond neighbors  
6 : neighbor matrix  
7,8 : resiwise feature와 residue mask  
9 : ba_value  
10,11 : atomresi graph와 그 label mask

In [None]:
# Reuse the configured root instead of repeating the literal, so --root_path is honoured.
path = args.root_path

In [None]:
# Start from an empty split list so re-running this cell does not append another 5 folds.
dataset.kfold_splits = []
dataset.make_random_splits()
indices = dataset.kfold_splits

# Use the first fold for this exploration.
fold_indices = indices[0]
train_idx, valid_idx, test_idx = fold_indices
train = SubsetRandomSampler(train_idx)
valid = SubsetRandomSampler(valid_idx)
test = SubsetRandomSampler(test_idx)

train_loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn, sampler=train)
valid_loader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn, sampler=valid)
test_loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn, sampler=test)

# Cache one test batch for offline inspection.
# NOTE: the tensors are on whatever device collate_fn used (CUDA on a GPU runtime);
# unpickling them then also requires CUDA.
data = next(iter(test_loader))
sample_dir = os.path.join(path, 'data')
os.makedirs(sample_dir, exist_ok=True)
with open(os.path.join(sample_dir, 'data_sample.pkl'), 'wb') as writer:
    pickle.dump(data, writer)

Making Random Splits


In [None]:
# Smoke test: collate every training batch once so data problems surface early.
# Afterwards `a` holds the atom features of the last batch.
for i, train_data in enumerate(train_loader):
    a = train_data[0]

In [None]:
data[0].size()  # (batch, max atoms, 82)

torch.Size([8, 60, 82])

* 첫 번째 차원 (8): batch size  
* 두 번째 차원 (60): batch 내 compound의 최대 atom 개수  
* 세 번째 차원 (82): atom feature vector 길이

In [None]:
data[1].size()  # atom mask: (batch, max atoms)

torch.Size([8, 33])

In [None]:
data[1]  # atom mask from collate_fn: 1 on a ligand's atom rows, 0 on padding

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1

In [None]:
# Broadcastable (batch, length, 1) versions of the atom mask and residue mask.
v_mask = data[1].unsqueeze(-1)
p_mask = data[7].unsqueeze(-1)

In [None]:
# Masked mean over the sequence axis; padded atoms/residues contribute nothing.
a = (data[0] * v_mask).sum(dim=1) / v_mask.sum(dim=1)
b = (data[6] * p_mask).sum(dim=1) / p_mask.sum(dim=1)

In [None]:
a.size(), b.size()

(torch.Size([8, 82]), torch.Size([8, 20]))

In [None]:
def mask_softmax(a, mask, dim=-1):
    """Softmax of ``a`` along ``dim`` restricted to positions where ``mask`` is non-zero.

    Masked-out positions get weight 0. The stabilizing shift is the maximum
    over *unmasked* entries only. Previously large logits at padded positions
    could underflow every real entry, letting the 1e-6 term dominate. Rows
    whose mask is entirely zero return zeros. As before, the exponentials are
    multiplied by ``mask`` and 1e-6 is added to the denominator.

    Args:
        a: logits tensor.
        mask: tensor broadcastable to ``a``'s shape; 0 marks excluded positions.
        dim: dimension to normalize over.
    """
    valid = mask != 0
    a_max = torch.max(a.masked_fill(~valid, float('-inf')), dim, keepdim=True)[0]
    # Fully masked rows have max -inf; use 0 so the subtraction stays finite.
    a_max = torch.where(torch.isfinite(a_max), a_max, torch.zeros_like(a_max))
    # Set masked positions to -inf before exp so they cannot overflow (inf * 0 = nan).
    a_exp = torch.exp((a - a_max).masked_fill(~valid, float('-inf'))) * mask
    return a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)

In [None]:
# Scratch: one 82 -> 82 linear layer applied to every atom feature vector.
mc0 = nn.Linear(in_features=82, out_features=82).cuda()
a = mc0(data[0])

In [None]:
a.size()

torch.Size([8, 50, 82])

In [None]:
# Scratch: reduce each atom vector to a single attention logit.
vs = nn.Linear(in_features=82, out_features=1).cuda()
a = vs(a)
a.size()

torch.Size([8, 50, 1])

In [None]:
# Attention weights over atoms, ignoring padding (data[1] is the atom mask).
att = mask_softmax(a.flatten(start_dim=1), data[1].flatten(start_dim=1))

In [None]:
# Scratch: a GRUCell sized to this batch's padded atom count. It was hard-coded to 50,
# which only matched one particular random batch and fails on any other.
n_atoms_padded = data[1].size(1)
gru = nn.GRUCell(n_atoms_padded, n_atoms_padded).cuda()
gru(data[1], data[1]).size()

torch.Size([8, 50])

In [None]:
data[2].size()  # (batch, max bonds, 6)

torch.Size([8, 41, 6])

* 첫 번째 차원 (8): batch size  
* 두 번째 차원 (41): batch 내 compound의 최대 edge(bond) 개수  
* 세 번째 차원 (6): edge feature vector 길이

In [None]:
data[3].size()

torch.Size([3040])

- atom neighbors: (batch × 최대 atom 수 × 10) 크기의 neighbor index를 1차원으로 펼친 것이며, `add_index`가 분자별 offset을 더함

In [None]:
data[3]  # flattened atom-neighbor indices, with add_index's per-molecule offsets

tensor([ 1,  0,  0,  ..., 14, 14, 14], device='cuda:0')

In [None]:
# For every real atom, gather its neighbor atoms' features concatenated with the
# features of the connecting bonds: neighbors[mol, atom, slot] = [atom feats | bond feats].
MAX_NB = 10  # width of the neighbor tables (max_nb in DtiDataset)

atom_features = data[0]   # (batch, max atoms, 82)
bond_features = data[2]   # (batch, max bonds, 6)
atom_neighbor = data[3]   # flattened, with per-molecule offsets from add_index
bond_neighbor = data[4]

batch_size = atom_features.size(0)
neighbors = torch.zeros(
    size=(batch_size, atom_features.size(1), MAX_NB, atom_features.size(2) + bond_features.size(2))
).cpu()

atom_neighbor = atom_neighbor.reshape(batch_size, -1, MAX_NB)
bond_neighbor = bond_neighbor.reshape(batch_size, -1, MAX_NB)

# Assumes each molecule's last neighbor slot is padding (0 before offsetting),
# so its value equals the offset add_index added for that molecule.
atom_id = torch.tensor([atom_neighbor[mol, -1, -1].item() for mol in range(batch_size)])
atom_id = atom_id.reshape(-1, 1).unsqueeze(2).expand(batch_size, -1, MAX_NB)
atom_id = atom_id.to(atom_neighbor.device)

# Subtract the offset to get back per-molecule indices.
atom_neighbor = (atom_neighbor - atom_id).type(torch.int64)
bond_neighbor = (bond_neighbor - atom_id).type(torch.int64)

atom_ids = data[1]         # atom mask
neighbor_matrix = data[5]  # 1 where a neighbor slot is used
# .tolist() yields plain ints, so indexing the CPU `neighbors` tensor does not mix devices.
for mol_idx, atom_idx in torch.nonzero(atom_ids).tolist():
    used_slots = torch.nonzero(neighbor_matrix[mol_idx][atom_idx])
    if used_slots.numel() == 0:
        continue  # isolated atom: nothing to gather (previously an IndexError)
    neighbor_idx = used_slots[-1].item()
    atom_indices = atom_neighbor[mol_idx, atom_idx, :neighbor_idx + 1]
    bond_indices = bond_neighbor[mol_idx, atom_idx, :neighbor_idx + 1]

    neighbor_atom_tensor = atom_features[mol_idx, atom_indices]
    neighbor_bond_tensor = bond_features[mol_idx, bond_indices]

    concat_atom_info = torch.cat((neighbor_atom_tensor, neighbor_bond_tensor), dim=1)
    neighbors[mol_idx, atom_idx, :atom_indices.size(0), :] = concat_atom_info


In [None]:
import torch.nn.functional as F


def initialize_(batch_size, input_dim, hidden_dim, device=None):
    """Kaiming-uniform (input_dim, hidden_dim) weight expanded to (batch_size, input_dim, hidden_dim).

    NOTE(review): only the expanded view is returned; the Parameter is not
    registered anywhere, so this is suitable for shape exploration, not training.
    """
    weight = torch.nn.Parameter(torch.empty(input_dim, hidden_dim, device=device))
    torch.nn.init.kaiming_uniform_(weight)
    return weight.reshape(1, input_dim, hidden_dim).expand(batch_size, input_dim, hidden_dim)


def embedding_(input, weight):
    """LeakyReLU(0.1) of a batched matmul: (B, N, D_in) x (B, D_in, D_out) -> (B, N, D_out)."""
    return F.leaky_relu(torch.bmm(input, weight), 0.1)


# One row per (molecule, atom): its neighbor slots x (atom + bond features).
# `n` was previously undefined in this notebook (it came from a deleted cell).
n = neighbors.view(-1, neighbors.size(2), neighbors.size(3))
weight1 = initialize_(n.size(0), n.size(2), 82)

In [None]:
n_mols = data[0].size(0)   # molecules in the sample batch (was hard-coded 8)
hidden = weight1.size(2)   # per-atom hidden size produced by weight1

# Embed every neighbor slot, sum over slots, and restore (batch, atoms, hidden).
a = embedding_(n, weight1).sum(dim=1).view(n_mols, -1, hidden)
# Concatenates `a` with itself; presumably a stand-in for [local, updated] features — TODO confirm.
updated_local_info = torch.cat((a, a), dim=2)
print(updated_local_info.size())

weight2 = initialize_(n_mols, 2 * hidden, hidden)
b = embedding_(updated_local_info, weight2)
print(b.size())

torch.Size([8, 38, 164])
torch.Size([8, 38, 82])


masking이 필요한 이유  
atom neighbor 값 0은 padding일 수도 있고, 실제로 atom 0과 연결된 경우일 수도 있음  
bond neighbor 값 0도 padding이거나 실제 bond(edge) 0일 수 있음 → neighbor matrix(`data[5]`)로 실제 neighbor slot만 사용

- bond neighbors (`data[4]`): atom neighbors와 같은 방식으로 1차원으로 펼친 bond index

In [None]:
data[5].size()  # neighbor-slot mask: (batch, max atoms, 10)

torch.Size([8, 27, 10])

- neighbor matrix (`data[5]`): 실제 neighbor가 있는 slot은 1, 나머지는 0  
- 10은 원자당 최대 neighbor 수 (`DtiDataset`의 `max_nb = 10`)

In [None]:
class CNN(nn.Module):
    """Stack of 1-D convolutions over per-residue features, each followed by LeakyReLU(0.1).

    Input: (batch, n_residues, in_channels). Output: (batch, n_residues, hidden_size).
    The sequence length is preserved for odd ``kernel_size`` (padding = (k - 1) // 2).

    Args:
        num_layer: number of hidden_size -> hidden_size convolutions between the
            first and the last convolution.
        hidden_size: output channels of every convolution.
        kernel_size: convolution width (use an odd value to keep the length).
        in_channels: per-residue input feature size; default 20 is the BLOSUM row length.
    """

    def __init__(self, num_layer, hidden_size, kernel_size, in_channels=20):
        super(CNN, self).__init__()
        self.num_layer = num_layer
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        padding = (kernel_size - 1) // 2

        # Construction order (first, last, then the middle stack) is unchanged, so
        # parameter initialization consumes the RNG exactly as before.
        self.conv_first = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=padding)
        self.conv_last = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size, padding=padding)
        self.plain_CNN = nn.ModuleList([
            nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size, padding=padding)
            for _ in range(num_layer)
        ])

    def forward(self, x):
        h = x.transpose(1, 2)  # Conv1d expects (batch, channels, length)
        h = F.leaky_relu(self.conv_first(h), 0.1)
        for conv in self.plain_CNN:
            h = F.leaky_relu(conv(h), 0.1)
        h = F.leaky_relu(self.conv_last(h), 0.1)
        return h.transpose(1, 2)

In [None]:
# Build the protein CNN here; this cell previously relied on `cnn` from a later cell.
cnn = CNN(4, 128, 5)
cnn.cuda()

CNN(
  (conv_first): Conv1d(20, 128, kernel_size=(5,), stride=(1,), padding=(2,))
  (conv_last): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
  (plain_CNN): ModuleList(
    (0): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
    (2): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
    (3): Conv1d(128, 128, kernel_size=(5,), stride=(1,), padding=(2,))
  )
)

In [None]:
# Protein encoder: 4 intermediate conv layers, 128 channels, kernel width 5.
cnn = CNN(4, 128, 5).cuda()
output = cnn(data[6])

In [None]:
output.size()  # (batch, max residues, 128)

torch.Size([8, 366, 128])

In [None]:
# Scratch: a toy embedding with one row per padded residue position, and a stand-alone first conv.
embed = nn.Embedding(num_embeddings=data[6].size(1), embedding_dim=20, padding_idx=0)
conv_first = nn.Conv1d(in_channels=20, out_channels=128, kernel_size=5, padding=2)



In [None]:
torch.tensor([2,31])  # Scratch: result is neither assigned nor used.

tensor([ 2, 31])

In [None]:
# Scratch lookup. `embed` has one row per padded residue position of this batch, so
# wrap the indices into range; index 291 would fail for batches with <= 291 positions.
embed(torch.tensor([2, 291]) % embed.num_embeddings)

tensor([[-0.7318,  0.8266, -0.3092,  1.6290, -0.3277,  1.7404, -0.3083,  0.1312,
         -0.2667,  0.4664, -0.3234,  0.1287,  0.2820, -0.0963, -1.8444,  0.6796,
         -1.2653, -0.3600, -0.3921,  0.6482],
        [ 1.1570,  0.3718,  0.0839,  0.3193,  0.7952,  0.4620, -0.9109, -1.0247,
         -0.8341,  0.6086,  0.9505, -0.4411, -1.8334, -1.9562, -1.2634,  0.2680,
          0.1432, -1.3979, -0.2807,  1.4928]], grad_fn=<EmbeddingBackward0>)

In [None]:
torch.randn(1,2,3)  # Scratch: random sample, not seeded in this cell.

tensor([[[ 0.1544, -0.8639,  1.1462],
         [ 0.4451, -0.5022,  0.7280]]])

In [None]:
conv_first(data[6].transpose(1, 2).cpu()).size()  # (batch, 128, max residues)

torch.Size([8, 128, 366])

In [None]:
data[6].dtype  # residue (BLOSUM) features; collate_fn builds them as a float tensor

torch.float32

- `data[6]` 크기: (batch, batch 내 protein의 최대 residue 개수, 20)
- 마지막 차원 20은 BLOSUM vector 길이  

In [None]:
data[7].size()  # residue mask: (batch, max residues)

torch.Size([8, 846])

In [None]:
data[8].size()  # binding-affinity targets: (batch, 1)

torch.Size([8, 1])

- Target value

In [None]:
data[9].size()  # atom-residue interaction graph: (batch, max atoms, max residues)

torch.Size([8, 27, 846])

In [None]:
data[10].size()  # atom-residue label mask: (batch, max atoms, max residues)

torch.Size([8, 27, 846])

- `data[9]`: 배치별 compound의 atom과 protein의 residue 간 상호작용 여부 (0/1), `data[10]`: 해당 label이 존재하는지 나타내는 mask (plip 파일이 없는 complex는 전부 0)

# Model

In [None]:
'''base.py'''
@torch.no_grad()
def store_representations(self, input, output):
    """Append a detached CPU NumPy copy of ``output`` to ``self.representations``.

    The (module, input, output) signature matches a PyTorch forward hook;
    presumably registered with ``register_forward_hook`` — confirm at the call site.
    """
    as_numpy = output.detach().cpu().numpy()
    self.representations.append(as_numpy)

@torch.no_grad()
def store_interactions(self, input, output):
    """Append a detached CPU NumPy copy of ``output`` to ``self.interactions`` (forward-hook signature)."""
    as_numpy = output.detach().cpu().numpy()
    self.interactions.append(as_numpy)

@torch.no_grad()
def store_atomwise_compound_representations(self, input, output):
    """Forward-hook callback for outputs whose last two items are (atomwise, compound).

    Appends detached CPU NumPy copies to ``self.atomwise_representations`` and
    ``self.compound_representations`` respectively.
    """
    atomwise, compound = output[-2], output[-1]
    self.atomwise_representations.append(atomwise.detach().cpu().numpy())
    self.compound_representations.append(compound.detach().cpu().numpy())

@torch.no_grad()
def store_atomwise_resiwise_compound_representations(self, input, output):
    """Forward-hook callback for outputs ending in (atomwise, resiwise, compound).

    Appends detached CPU NumPy copies to the matching ``self.*_representations`` lists.
    """
    atomwise, resiwise, compound = output[-3], output[-2], output[-1]
    self.atomwise_representations.append(atomwise.detach().cpu().numpy())
    self.resiwise_representations.append(resiwise.detach().cpu().numpy())
    self.compound_representations.append(compound.detach().cpu().numpy())


def mask_softmax(a, mask, dim=-1):
    """Softmax of ``a`` along ``dim`` with masked-out entries forced to zero.

    The maximum along ``dim`` is subtracted before exponentiating, for numerical
    stability. Exponentials are multiplied by ``mask`` and renormalised. The
    ``1e-6`` added to the denominator keeps a fully masked slice at 0 instead of NaN.
    """
    shifted = a - a.amax(dim=dim, keepdim=True)
    weights = shifted.exp() * mask
    normaliser = weights.sum(dim=dim, keepdim=True) + 1e-6
    return weights / normaliser


class GraphDenseSequential(nn.Sequential):
    """``nn.Sequential`` whose children may optionally consume graph context.

    ``forward(X, adj, mask)`` feeds ``X`` through the children in order. A child
    whose ``forward`` can be bound to ``(X, adj, mask)`` is called with all three.
    Any other child (e.g. ``nn.ReLU``, ``nn.Linear``) is called with ``X`` alone.
    """

    def __init__(self, *args):
        super(GraphDenseSequential, self).__init__(*args)

    @staticmethod
    def _takes_graph_args(module, X, adj, mask):
        """Return True if ``module.forward`` accepts ``(X, adj, mask)`` positionally."""
        import inspect
        try:
            inspect.signature(module.forward).bind(X, adj, mask)
        except (TypeError, ValueError):
            return False
        return True

    def forward(self, X, adj, mask):
        # Choose the call form from the signature rather than wrapping the call in
        # try/except. The original caught BaseException, which swallowed
        # KeyboardInterrupt and any genuine error raised *inside* a graph layer
        # (shape mismatch, CUDA OOM). It then retried with one argument, turning
        # the real error into a misleading TypeError.
        for module in self._modules.values():
            if self._takes_graph_args(module, X, adj, mask):
                X = module(X, adj, mask)
            else:
                X = module(X)
        return X

In [None]:
import math
import numpy as np
from torch.nn.parameter import Parameter
import torch.nn.functional as F

# Per-atom input feature length:
# 63 (atom type) + 6 (degree) + 6 (explicit valence) + 6 (implicit valence) + 1 (aromatic) = 82.
atomf_len = 82
# Per-bond feature length: 4 (bond type) + 2 (conjugated, in ring) = 6.
bondf_len = 6

'''
- Protein(residue) => CNN
BLOSUM62 matrix => 각 residue의 특성을 encoding한 matrix, 20x20
각 residue의 max개수를 이용하여 padding

- Compound(atom) => Graph convolution
Node feature : 길이가 82인 feature vector (원자 종류 63 + Degree 6 + Explicit valence 6 + implicit valence 6 + aromatic 1 = 82)
Edge feature(bond) : 길이가 6인 feature vector (결합 종류 4 + conjugate,ring 2 = 6)

tensor list의 길이는 11
1,2 : atomwise feature 각 배치별,
3 : bondwise feature
4 : atom neighbors
5 : bond neighbors
6 : neighbor matrix
7,8 : resiwise feature
9 : ba_value
10,11 : atmoresi graph
'''
def initialize_(batch_size, input_dim, hidden_dim, device='cuda'):
    """Sample a Kaiming-uniform weight matrix and broadcast it over a batch.

    Args:
        batch_size: size of the leading (batch) dimension of the result.
        input_dim: number of rows of the underlying weight matrix.
        hidden_dim: number of columns of the underlying weight matrix.
        device: device the weight is created on. Defaults to ``'cuda'`` (the
            original hard-coded behaviour); pass ``'cpu'`` or ``x.device`` to run
            without a GPU.

    Returns:
        A ``batch_size x input_dim x hidden_dim`` tensor. It is an ``expand`` view
        of one freshly sampled matrix, so every batch entry holds the same values.

    Warning:
        Every call samples NEW random values, and the tensor is not registered
        with any ``nn.Module``, so ``Module.parameters()`` never sees it. Calling
        this inside ``forward`` therefore yields a different, untrained weight on
        every pass. Trainable weights belong in ``nn.Parameter`` attributes
        created once in ``__init__``.
    """
    # Create directly on the target device (the original sampled on CPU then copied).
    weight = torch.empty(input_dim, hidden_dim, device=device, requires_grad=True)
    torch.nn.init.kaiming_uniform_(weight)
    return weight.unsqueeze(0).expand(batch_size, input_dim, hidden_dim)

def embedding_(input, weight):
    """Batched linear projection followed by LeakyReLU with slope 0.1.

    ``input`` (B x N x D_in) is multiplied batch-wise with ``weight``
    (B x D_in x D_out) via ``torch.bmm``, giving a B x N x D_out representation.
    """
    projected = torch.bmm(input, weight)
    return F.leaky_relu(projected, negative_slope=0.1)

class Network(nn.Module):
    """Compound–protein model: atom graph encoder, residue CNN, pairwise and affinity heads.

    Args:
        GNN_depth: number of message-passing (MPU + Warp_GRU) rounds over the atom graph.
        k_head: number of attention heads inside Warp_GRU.
        hidden_dim1: size of atom, super-node and residue features.
        hidden_dim2: internal size of the affinity head.
        DAN_depth: number of attention rounds in the affinity head.

    ``forward`` returns the ``(affinity_pred, pairwise_pred)`` tuple produced by
    the affinity head.
    """

    def __init__(self, GNN_depth=2, k_head=3, hidden_dim1=10, hidden_dim2=10, DAN_depth=2):
        super(Network, self).__init__()

        self.GNN_depth = GNN_depth
        self.k_head = k_head
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.DAN_depth = DAN_depth

        # Atom-feature and super-node projections. They are registered as parameters
        # so they are created once, trained by the optimizer, and moved by
        # .cuda()/.to(). The original resampled them with initialize_ on every
        # forward call, so they were random and never learned.
        self.atom_embed_weight = nn.Parameter(torch.empty(atomf_len, hidden_dim1))
        self.super_embed_weight = nn.Parameter(torch.empty(hidden_dim1, hidden_dim1))
        nn.init.kaiming_uniform_(self.atom_embed_weight)
        nn.init.kaiming_uniform_(self.super_embed_weight)

        # MPU and CNN must produce hidden_dim1 features, because Warp_GRU, the
        # pairwise head and the affinity head all consume hidden_dim1. The original
        # hard-coded 128, which crashed for any other hidden_dim1 (including the
        # default of 10).
        self.mpu = MPU(self.hidden_dim1)
        self.warp_gru = Warp_GRU(self.k_head, self.hidden_dim1)

        self.cnn = CNN(4, self.hidden_dim1, 5)
        self.pairwise = Pairwise_pred_module(self.hidden_dim1)
        self.affinity = Affinity_pred_module(self.hidden_dim1, self.hidden_dim2, self.DAN_depth)

    @staticmethod
    def _batched(weight, batch_size):
        """Broadcast a 2-D weight to ``batch_size x in x out`` as ``embedding_`` expects."""
        return weight.unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, atom_feature, masking, bond_feature, atom_neighbor, bond_neighbor, neighbor_matrix, residue_feature, residue_mask):
        batch_size = atom_feature.size(0)

        # atom_feature: batch x atoms x atomf_len -> batch x atoms x h1
        vertex_feature = embedding_(atom_feature, self._batched(self.atom_embed_weight, batch_size))
        # Super node initialised from the sum of all atom features: batch x 1 x h1
        super_node_init = torch.sum(vertex_feature, dim=1, keepdim=True)
        super_node_feature = torch.tanh(embedding_(super_node_init, self._batched(self.super_embed_weight, batch_size)))

        for _ in range(self.GNN_depth):
            u_i = self.mpu(vertex_feature, masking, bond_feature, atom_neighbor, bond_neighbor, neighbor_matrix)
            vertex_feature, super_node_feature = self.warp_gru(vertex_feature, super_node_feature, u_i)

        protein_feature = self.cnn(residue_feature)  # batch x residues x h1

        pairwise_pred = self.pairwise(vertex_feature, protein_feature, masking, residue_mask)
        # The affinity head returns (affinity_pred, pairwise_pred); pass the tuple through.
        return self.affinity(vertex_feature, super_node_feature, protein_feature, masking, residue_mask, pairwise_pred)


class MPU(nn.Module):
    """Message-passing unit that produces one updated feature vector per atom.

    For each unmasked atom:
      1. Gather up to 10 neighbours. Each neighbour atom's feature is concatenated
         with the feature of the connecting bond (``hidden_dim + 6`` values).
      2. Project each neighbour to ``hidden_dim`` with LeakyReLU(0.1) and sum over
         neighbours.
      3. Concatenate the result with the atom's own feature and project back to
         ``hidden_dim``.

    Args:
        hidden_dim: size of the incoming atom features and of the output.
    """

    def __init__(self, hidden_dim):
        super(MPU, self).__init__()
        self.hidden_dim = hidden_dim
        # Created once and registered so the optimizer trains them and
        # .cuda()/.to() moves them. The original resampled both matrices with
        # initialize_ on every forward call, so they were random and never learned.
        self.local_weight = nn.Parameter(torch.empty(hidden_dim + 6, hidden_dim))  # 6 = bond feature length
        self.update_local_weight = nn.Parameter(torch.empty(hidden_dim * 2, hidden_dim))
        nn.init.kaiming_uniform_(self.local_weight)
        nn.init.kaiming_uniform_(self.update_local_weight)

    def gather_neighbor_info(self, atom_features, masking, bond_features, atom_neighbor, bond_neighbor, neighbor_matrix):
        """Return a ``batch x atoms x 10 x (hidden + 6)`` tensor of [neighbour atom | bond] features.

        Unused neighbour slots, masked atoms, and atoms without any neighbour stay zero.
        """
        batch_size = atom_features.size(0)
        neighbors = torch.zeros(batch_size, atom_features.size(1), 10, atom_features.size(2) + 6,
                                device=atom_features.device)

        atom_neighbor = atom_neighbor.reshape(batch_size, -1, 10)
        bond_neighbor = bond_neighbor.reshape(batch_size, -1, 10)

        # NOTE(review): assumes the last slot of the last atom of each molecule holds
        # that molecule's batching offset (cf. add_index), so subtracting it recovers
        # per-molecule indices. TODO confirm against the data pipeline.
        atom_id = torch.tensor([atom_neighbor[mol, -1, -1].item() for mol in range(batch_size)])
        atom_id = atom_id.reshape(-1, 1, 1).to(atom_neighbor.device)  # broadcasts over atoms and slots

        atom_neighbor = (atom_neighbor - atom_id).type(torch.int64)  # atom indices before add_index
        bond_neighbor = (bond_neighbor - atom_id).type(torch.int64)  # bond indices before add_index

        # Iterate over (molecule, atom) pairs that are unmasked.
        for mol_idx, atom_idx in torch.nonzero(masking):
            valid_slots = torch.nonzero(neighbor_matrix[mol_idx][atom_idx])
            if valid_slots.numel() == 0:
                continue  # isolated atom: no messages (the original raised IndexError here)
            n_used = valid_slots[-1].item() + 1  # slots 0 .. last non-zero entry
            atom_indices = atom_neighbor[mol_idx, atom_idx, :n_used]
            bond_indices = bond_neighbor[mol_idx, atom_idx, :n_used]

            neighbor_atom_tensor = atom_features[mol_idx, atom_indices]
            neighbor_bond_tensor = bond_features[mol_idx, bond_indices]
            neighbors[mol_idx, atom_idx, :n_used, :] = torch.cat((neighbor_atom_tensor, neighbor_bond_tensor), dim=1)
        return neighbors

    def forward(self, atom_features, masking, bond_features, atom_neighbor, bond_neighbor, neighbor_matrix):
        neighbors = self.gather_neighbor_info(atom_features, masking, bond_features, atom_neighbor, bond_neighbor, neighbor_matrix)
        # Same LeakyReLU(0.1) projection as embedding_, done with a broadcast matmul.
        # Zero (unused) neighbour slots map to LeakyReLU(0) = 0 and add nothing to the sum.
        local = F.leaky_relu(torch.matmul(neighbors, self.local_weight), 0.1).sum(dim=2)  # batch x atoms x hidden
        update_local = torch.cat((local, atom_features), dim=2)  # batch x atoms x 2*hidden
        return F.leaky_relu(torch.matmul(update_local, self.update_local_weight), 0.1)  # batch x atoms x hidden

class Attention(nn.Module):
    """Multi-head, super-node-guided attention over atom features.

    Args:
        batch_size: kept for backward compatibility only. The weights are shared
            across the batch and broadcast to whatever batch size ``forward`` receives.
        k_head: number of attention heads, each with its own scoring vector.
        hidden_dim1: atom / super-node feature size.

    ``forward(vertex, super_node)`` takes vertex (batch x atoms x h1) and
    super_node (batch x 1 x h1), and returns batch x atoms x h1.
    """

    def __init__(self, batch_size, k_head, hidden_dim1):
        super(Attention, self).__init__()
        self.k_head = k_head
        self.hidden_dim1 = hidden_dim1
        self.batch_size = batch_size

        # Registered parameters, created once; the original sampled unregistered
        # tensors via initialize_, so they were never trained. The scoring vector
        # is now one per head; previously every head reused a single vector, so
        # all heads were identical.
        self.att_weight = nn.Parameter(torch.empty(k_head, hidden_dim1, 1))
        self.super_att_weight = nn.Parameter(torch.empty(hidden_dim1, hidden_dim1))
        self.vertex_att_weight = nn.Parameter(torch.empty(hidden_dim1, hidden_dim1))
        self.v_s_weight = nn.Parameter(torch.empty(k_head * hidden_dim1, hidden_dim1))
        for head_weight in self.att_weight:
            nn.init.kaiming_uniform_(head_weight)
        for weight in (self.super_att_weight, self.vertex_att_weight, self.v_s_weight):
            nn.init.kaiming_uniform_(weight)

    def forward(self, vertex, super_node):
        v_att = torch.tanh(torch.matmul(vertex, self.vertex_att_weight))  # batch x atoms x h1
        s_att = torch.tanh(torch.matmul(super_node, self.super_att_weight))  # batch x 1 x h1
        b_ = v_att * s_att  # batch x atoms x h1 (super node broadcast over atoms)

        heads = []
        for head in range(self.k_head):
            # Normalise over the atom axis (dim=1). The original omitted `dim`; for
            # 3-D input F.softmax then defaults to dim=0, which normalised across
            # the molecules of the batch instead of across atoms.
            # NOTE(review): padded atoms are not masked out here; no mask is passed in.
            alpha = F.softmax(torch.matmul(b_, self.att_weight[head]), dim=1)  # batch x atoms x 1
            heads.append(alpha * vertex)  # batch x atoms x h1
        concat = torch.cat(heads, dim=-1)  # batch x atoms x k_head*h1 (same layout as stack+reshape)
        return torch.tanh(torch.matmul(concat, self.v_s_weight))  # batch x atoms x h1


class Warp_GRU(nn.Module):
    """Gated exchange between atom nodes and the super node, followed by GRU updates.

    ``forward(vertex, super_node, u_i)`` takes vertex and u_i (batch x atoms x h1)
    and super_node (batch x 1 x h1). It returns the updated
    (batch x atoms x h1, batch x 1 x h1) pair.
    """

    def __init__(self, k_head, hidden_dim1):
        super(Warp_GRU, self).__init__()
        self.hidden_dim1 = hidden_dim1
        self.k_head = k_head

        self.v_gru = nn.GRUCell(self.hidden_dim1, self.hidden_dim1)
        self.s_gru = nn.GRUCell(self.hidden_dim1, self.hidden_dim1)

        # Built once so its weights are registered and trained. The original
        # constructed a fresh Attention (with fresh random weights) on every forward
        # call. Attention no longer depends on batch size, so 1 is just a placeholder.
        self.k_head_attn = Attention(1, self.k_head, self.hidden_dim1)

        def square_weight():
            # h1 x h1 Kaiming-uniform matrix, registered when assigned to self.
            weight = nn.Parameter(torch.empty(hidden_dim1, hidden_dim1))
            nn.init.kaiming_uniform_(weight)
            return weight

        self.s_weight = square_weight()
        self.s_v_weight = square_weight()
        self.weight_11 = square_weight()
        self.weight_12 = square_weight()
        self.weight_21 = square_weight()
        self.weight_22 = square_weight()

    def forward(self, vertex, super_node, u_i):
        batch_size = vertex.size(0)

        u_s = torch.tanh(torch.matmul(super_node, self.s_weight))  # batch x 1 x h1
        u_s_v = torch.tanh(torch.matmul(super_node, self.s_v_weight))  # batch x 1 x h1, info from super node

        u_v_s = self.k_head_attn(vertex, super_node)  # batch x atoms x h1
        g_v_s = torch.sigmoid(torch.matmul(u_v_s, self.weight_11) + torch.matmul(u_s, self.weight_12))  # batch x atoms x h1
        t_v_s = (1 - g_v_s) * u_v_s + g_v_s * u_s  # batch x atoms x h1

        g_s_i = torch.sigmoid(torch.matmul(u_i, self.weight_21) + torch.matmul(u_s_v, self.weight_22))  # batch x atoms x h1
        t_s_i = (1 - g_s_i) * u_i + g_s_i * u_s_v  # batch x atoms x h1

        # NOTE(review): GRUCell(input, hx). Here the current state is passed as
        # `input` and the gated message as `hx`; the usual recurrent convention is
        # the reverse. Kept as in the original; confirm intent.
        atom_node_updated = self.v_gru(vertex.view(-1, self.hidden_dim1), t_s_i.view(-1, self.hidden_dim1))
        atom_node_updated = atom_node_updated.view(batch_size, -1, self.hidden_dim1)  # batch x atoms x h1
        super_node_updated = self.s_gru(super_node.view(-1, self.hidden_dim1), torch.sum(t_v_s, dim=1).view(-1, self.hidden_dim1))
        super_node_updated = super_node_updated.view(batch_size, -1, self.hidden_dim1)  # batch x 1 x h1

        return atom_node_updated, super_node_updated



class CNN(nn.Module):
    """Stack of 1-D convolutions over a residue sequence.

    Layers: ``conv_first`` -> ``num_layer`` plain convolutions -> ``conv_last``,
    each followed by LeakyReLU(0.1). With an odd ``kernel_size`` the padding of
    ``(kernel_size - 1) // 2`` keeps the sequence length unchanged.

    Args:
        num_layer: number of intermediate hidden_size -> hidden_size convolutions.
        hidden_size: number of output channels of every convolution.
        kernel_size: convolution width.
        in_channels: per-residue input feature size. Defaults to 20, the BLOSUM
            vector length used in this notebook.

    ``forward(x)`` maps x (batch x residues x in_channels) to
    batch x residues x hidden_size.
    """

    def __init__(self, num_layer, hidden_size, kernel_size, in_channels=20):
        super(CNN, self).__init__()
        self.num_layer = num_layer
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.in_channels = in_channels

        padding = (kernel_size - 1) // 2  # equals int((k - 1) / 2) for k >= 1
        # Creation order (first, last, plain stack) is kept from the original.
        self.conv_first = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=padding)
        self.conv_last = nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size, padding=padding)
        self.plain_CNN = nn.ModuleList(
            [nn.Conv1d(hidden_size, hidden_size, kernel_size=kernel_size, padding=padding) for _ in range(num_layer)]
        )

    def forward(self, x):
        sequence = x.transpose(1, 2)  # batch x in_channels x residues (Conv1d is channels-first)
        hidden = F.leaky_relu(self.conv_first(sequence), 0.1)  # batch x hidden x residues
        for conv in self.plain_CNN:
            hidden = F.leaky_relu(conv(hidden), 0.1)
        hidden = F.leaky_relu(self.conv_last(hidden), 0.1)
        return hidden.transpose(1, 2)  # batch x residues x hidden


class Pairwise_pred_module(nn.Module):
    """Predicts an atom x residue interaction map from compound and protein features.

    ``forward(comp_feature, prot_feature, vertex_mask, seq_mask)`` takes
    comp_feature (batch x atoms x hidden), prot_feature (batch x residues x hidden),
    vertex_mask (batch x atoms) and seq_mask (batch x residues). It returns sigmoid
    scores of shape batch x atoms x residues, zeroed wherever either mask is zero.
    """

    def __init__(self, hidden_size):
        super(Pairwise_pred_module, self).__init__()
        self.hidden_dim = hidden_size
        # Registered once so they are trained and follow .cuda()/.to(). The
        # original resampled them with initialize_ on every forward call.
        self.weight_atom = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.weight_residue = nn.Parameter(torch.empty(hidden_size, hidden_size))
        nn.init.kaiming_uniform_(self.weight_atom)
        nn.init.kaiming_uniform_(self.weight_residue)

    def forward(self, comp_feature, prot_feature, vertex_mask, seq_mask):
        batch_size = comp_feature.size(0)

        # Same LeakyReLU(0.1) projection as embedding_, with a broadcast matmul.
        pairwise_c_feature = F.leaky_relu(torch.matmul(comp_feature, self.weight_atom), 0.1)  # batch x atom x hidden
        pairwise_p_feature = F.leaky_relu(torch.matmul(prot_feature, self.weight_residue), 0.1)  # batch x residue x hidden
        pairwise_pred = torch.sigmoid(torch.matmul(pairwise_c_feature, pairwise_p_feature.transpose(1, 2)))  # batch x atom x residue
        pairwise_mask = torch.matmul(vertex_mask.view(batch_size, -1, 1), seq_mask.view(batch_size, 1, -1))  # batch x atom x residue
        return pairwise_pred * pairwise_mask


class Affinity_pred_module(nn.Module):
    """Affinity read-out head.

    The head:
      1. Projects compound, super-node and protein features from hidden_size to
         hidden_size2.
      2. Runs ``DAN_depth`` rounds of pairwise-guided attention with a GRU memory.
      3. Regresses one value per complex from the outer product of the pooled
         compound/super-node and protein vectors.

    Args:
        hidden_size: size of the incoming compound / super-node / protein features.
        hidden_size2: internal size used by the attention rounds and the read-out.
        DAN_depth: number of attention rounds. With 0, the masked mean-pooled
            features are used directly.

    ``forward`` returns ``(affinity_pred, pairwise_pred)``. affinity_pred is
    batch x 1; ``pairwise_pred`` is passed through unchanged.
    """

    def __init__(self, hidden_size, hidden_size2, DAN_depth):
        super(Affinity_pred_module, self).__init__()
        self.hidden_dim1 = hidden_size
        self.hidden_dim2 = hidden_size2
        self.DAN_depth = DAN_depth

        def per_round(out_features):
            # One Linear(hidden_dim2 -> out_features) per attention round.
            return nn.ModuleList([nn.Linear(hidden_size2, out_features) for _ in range(DAN_depth)])

        # No .cuda() here: submodules follow the parent's .cuda()/.to() (the
        # notebook always builds Network(...).cuda()). Hard-coded .cuda() made this
        # module unusable without a GPU.
        self.mc = per_round(hidden_size2)
        self.mp = per_round(hidden_size2)

        self.vc = per_round(hidden_size2)
        self.rp = per_round(hidden_size2)
        self.vs = per_round(1)
        self.rs = per_round(1)

        self.v_to_r_transform = per_round(hidden_size2)
        self.r_to_v_transform = per_round(hidden_size2)

        self.GRU_dan = nn.GRUCell(hidden_size2, hidden_size2)
        self.W_out = nn.Linear(hidden_size2 * hidden_size2 * 2, 1)

        # Input projections (hidden_size -> hidden_size2), registered once so they
        # are trained. The original resampled them with initialize_ on every call.
        self.weight_atom = nn.Parameter(torch.empty(hidden_size, hidden_size2))
        self.weight_super = nn.Parameter(torch.empty(hidden_size, hidden_size2))
        self.weight_protein = nn.Parameter(torch.empty(hidden_size, hidden_size2))
        for weight in (self.weight_atom, self.weight_super, self.weight_protein):
            nn.init.kaiming_uniform_(weight)

    def mask_softmax(self, a, mask, dim=-1):
        """Softmax of ``a`` along ``dim`` with masked entries set to zero (1e-6 guards all-masked rows)."""
        a_max = torch.max(a, dim, keepdim=True)[0]
        a_exp = torch.exp(a - a_max)
        a_exp = a_exp * mask
        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
        return a_softmax

    def dan_gru(self, batch_size, comp_feature, prot_feature, vertex_mask, seq_mask, pairwise_pred):
        """Run DAN_depth attention rounds.

        Returns (compound_fixed, protein_fixed, m), each of shape batch x hidden2.
        """
        vertex_mask = vertex_mask.view(batch_size, -1, 1)
        seq_mask = seq_mask.view(batch_size, -1, 1)

        # Masked mean pooling as the initial summaries.
        c0 = torch.sum(comp_feature * vertex_mask, dim=1) / torch.sum(vertex_mask, dim=1)  # batch x hidden2
        p0 = torch.sum(prot_feature * seq_mask, dim=1) / torch.sum(seq_mask, dim=1)  # batch x hidden2

        m = c0 * p0  # batch x hidden2
        # Fallback for DAN_depth == 0. The original left these unassigned and
        # raised UnboundLocalError at the return.
        compound_fixed, protein_fixed = c0, p0
        for DAN_iter in range(self.DAN_depth):
            r_to_v = torch.matmul(pairwise_pred, torch.tanh(self.r_to_v_transform[DAN_iter](prot_feature)))  # batch x atoms x hidden2, eq. 23
            v_to_r = torch.matmul(pairwise_pred.transpose(1, 2), torch.tanh(self.v_to_r_transform[DAN_iter](comp_feature)))  # batch x residues x hidden2, eq. 24

            v_tmp = torch.tanh(self.vc[DAN_iter](comp_feature)) * torch.tanh(self.mc[DAN_iter](m)).view(batch_size, 1, -1) * r_to_v  # batch x atoms x hidden2, eq. 25
            r_tmp = torch.tanh(self.rp[DAN_iter](prot_feature)) * torch.tanh(self.mp[DAN_iter](m)).view(batch_size, 1, -1) * v_to_r  # batch x residues x hidden2, eq. 26

            v_att = self.mask_softmax(self.vs[DAN_iter](v_tmp).view(batch_size, -1), vertex_mask.view(batch_size, -1))  # batch x atoms, eq. 27
            r_att = self.mask_softmax(self.rs[DAN_iter](r_tmp).view(batch_size, -1), seq_mask.view(batch_size, -1))  # batch x residues, eq. 28

            compound_fixed = torch.sum(comp_feature * v_att.view(batch_size, -1, 1), dim=1)  # batch x hidden2, eq. 29
            protein_fixed = torch.sum(prot_feature * r_att.view(batch_size, -1, 1), dim=1)  # batch x hidden2, eq. 30

            m = self.GRU_dan(m, compound_fixed * protein_fixed)  # batch x hidden2, eq. 31

        return compound_fixed, protein_fixed, m

    def forward(self, comp_feature, super_feature, prot_feature, vertex_mask, seq_mask, pairwise_pred):
        batch_size = comp_feature.size(0)

        # Same LeakyReLU(0.1) projection as embedding_, with a broadcast matmul.
        comp_feature = F.leaky_relu(torch.matmul(comp_feature, self.weight_atom), 0.1)  # batch x atoms x hidden2
        super_feature = F.leaky_relu(torch.matmul(super_feature, self.weight_super), 0.1)  # batch x 1 x hidden2
        prot_feature = F.leaky_relu(torch.matmul(prot_feature, self.weight_protein), 0.1)  # batch x residues x hidden2

        # The final GRU memory is not used by the read-out below.
        cf, pf, _ = self.dan_gru(batch_size, comp_feature, prot_feature, vertex_mask, seq_mask, pairwise_pred)
        combined_representation = torch.cat([cf.view(batch_size, -1), super_feature.view(batch_size, -1)], dim=1)  # batch x 2*hidden2
        # Outer product with the protein vector, flattened: batch x 2*hidden2^2
        flatten = F.leaky_relu(torch.matmul(combined_representation.view(batch_size, -1, 1), pf.view(batch_size, 1, -1)).view(batch_size, -1), 0.1)
        affinity_pred = self.W_out(flatten)  # batch x 1

        return affinity_pred, pairwise_pred


In [None]:
net = Network(GNN_depth=4, k_head=3, hidden_dim1=128, hidden_dim2=128, DAN_depth = 2).cuda()
net.parameters()

<generator object Module.parameters at 0x7fe181fbd650>

In [None]:
class Masked_BCELoss(nn.Module):
    """Binary cross-entropy averaged over valid atom x residue pairs.

    ``forward(pred, label, pairwise_mask, vertex_mask, seq_mask)``:
        pred, label: predicted probabilities and targets. In train() these are
            the batch x atoms x residues pairwise map and its label.
        pairwise_mask: mask/weights broadcastable to ``pred``.
        vertex_mask: batch x atoms; seq_mask: batch x residues.

    Returns the element-wise BCE summed over positions kept by all three masks,
    divided by ``sum(pairwise_mask)``. The divisor is clamped at 1e-10, so an
    all-zero mask gives 0 rather than NaN.
    """

    def __init__(self):
        super(Masked_BCELoss, self).__init__()
        # reduction='none' is the supported spelling of the deprecated reduce=False.
        self.criterion = nn.BCELoss(reduction='none')

    def forward(self, pred, label, pairwise_mask, vertex_mask, seq_mask):
        batch_size = pred.size(0)
        loss_all = self.criterion(pred, label)
        loss_mask = torch.matmul(vertex_mask.view(batch_size, -1, 1), seq_mask.view(batch_size, 1, -1)) * pairwise_mask
        loss = torch.sum(loss_all * loss_mask) / torch.sum(pairwise_mask).clamp(min=1e-10)
        return loss

In [None]:
data[1].view(8,-1,1).shape, data[7].view(8,1,-1).shape

(torch.Size([8, 37, 1]), torch.Size([8, 1, 844]))

In [None]:
data[8].sum()

tensor(51.3000, device='cuda:0')

In [None]:
loss_all.view(-1,1,1).shape

torch.Size([249824, 1, 1])

In [None]:
data[10].shape

torch.Size([8, 37, 844])

In [None]:
data[8].shape

torch.Size([8, 1])

In [None]:
bce = nn.BCELoss(reduce=False)
loss_all = bce(pairwise,data[9])



In [None]:
loss_all.shape

torch.Size([8, 37, 844])

In [None]:
criterion2 = Masked_BCELoss()
loss_pairwise = criterion2(pairwise, data[9], data[10], data[1], data[7])



In [None]:
loss_pairwise

tensor(0., device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
import torch.optim as optim
from sklearn.metrics import roc_auc_score, mean_squared_error
def train(train_loader, valid_loader, test_loader, epochs, lambd):
    """Train a fresh Network and report test metrics from the best-validation epoch.

    Args:
        train_loader, valid_loader, test_loader: iterables of the 11-tensor batches
            used in this notebook. Index 8 is the affinity target, 9 and 10 are the
            pairwise label and mask, 1 is the atom mask and 7 is the residue mask.
        epochs: number of training epochs.
        lambd: weight of the pairwise (atom-residue) BCE loss in the total loss.

    Returns:
        (test_rmse, test_auc, test_label, test_output) evaluated after the epoch
        with the lowest validation RMSE. All four are None if no epoch produced a
        finite validation RMSE.
    """
    net = Network(GNN_depth=4, k_head=3, hidden_dim1=128, hidden_dim2=128, DAN_depth=2).cuda()

    criterion1 = nn.MSELoss()
    criterion2 = Masked_BCELoss()

    optimizer = optim.Adam(net.parameters(), lr=0.0005, weight_decay=0, amsgrad=True)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    min_rmse = float('inf')
    # Pre-bind the results so the return cannot hit an unbound name when no epoch
    # improves (e.g. every validation RMSE is NaN).
    test_rmse = test_auc = test_label = test_output = None
    loss_name = ['total loss', 'affinity loss', 'pairwise loss']
    for epoch in range(epochs):
        print('Epoch {} started!'.format(epoch))
        net.train()

        loss_sums = [0.0, 0.0, 0.0]
        n_samples = 0
        for train_data in train_loader:
            batch_size = train_data[0].size(0)
            n_samples += batch_size

            optimizer.zero_grad()
            affinity_pred, pairwise_pred = net(*train_data[:8])

            loss_aff = criterion1(affinity_pred, train_data[8])
            loss_pairwise = criterion2(pairwise_pred, train_data[9], train_data[10], train_data[1], train_data[7])
            loss = loss_aff + lambd * loss_pairwise

            # Batch-size-weighted sums, so the epoch figure below is a per-sample mean.
            for idx, value in enumerate((loss, loss_aff, loss_pairwise)):
                loss_sums[idx] += value.item() * batch_size

            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), 5)
            optimizer.step()
        scheduler.step()

        # Divide by the number of samples seen this epoch. The original divided by
        # len(train_data[0]), the size of the LAST batch only. Counting samples also
        # stays correct when a sampler uses only a subset of the dataset.
        denom = max(n_samples, 1)
        print_loss = [name + ' ' + str(round(total / denom, 6)) for name, total in zip(loss_name, loss_sums)]
        print('epoch:', epoch, ' '.join(print_loss))

        valid_rmse, valid_auc, valid_label, valid_output = test(net, valid_loader)
        print('Valid rmse : {}\n'.format(valid_rmse) + 'Valid average AUC : {}'.format(valid_auc))

        if valid_rmse < min_rmse:
            min_rmse = valid_rmse
            test_rmse, test_auc, test_label, test_output = test(net, test_loader)
            print('Test rmse : {}\n'.format(test_rmse) + 'Test average AUC : {}'.format(test_auc))

    print('Train finished!')
    return test_rmse, test_auc, test_label, test_output


def test(net, test_loader):
  output_list = []
  label_list = []
  pairwise_auc_list = []

  for i, test_data in enumerate(test_loader):
    batch_size = test_data[0].size(0)

    affinity_pred, pairwise_pred = net(test_data[0], test_data[1], test_data[2], test_data[3], test_data[4], test_data[5], test_data[6], test_data[7])

    # for j in range(len(test_data[10])):
    #   num_vertex = int(torch.sum(test_data[1][j,:]))
    #   num_residue = int(torch.sum(test_data[7][j,:]))
    #   pairwise_pred_i = pairwise_pred[j, :num_vertex, :num_residue].cpu().detach().numpy().reshape(-1)
    #   pairwise_label_i = test_data[10][j,:num_vertex,:num_residue].cpu().detach().numpy().reshape(-1)
    #   pairwise_auc_list.append(roc_auc_score(pairwise_label_i, pairwise_pred_i))
    output_list += affinity_pred.cpu().detach().numpy().reshape(-1).tolist()
    label_list += test_data[8].reshape(-1).tolist()
  output_list = np.array(output_list)
  label_list = np.array(label_list)
  rmse_value = np.sqrt(mean_squared_error(label_list,output_list))
  #average_pairwise_auc = np.mean(pairwise_auc_list)

  return rmse_value, 1, label_list, output_list



In [None]:
test_rmse, test_auc, test_label, test_output = train(train_loader, valid_loader, test_loader, 3,0.1)



Epoch 0 started!




epoch: 0 total loss 3442.2249 affinity loss 3442.2249 pairwise loss 0.0
Valid rmse : 2.5780447802050883
Valid average AUC : 1
Valid rmse : 2.5780447802050883





ValueError: ignored

In [None]:
data[7].shape

torch.Size([8, 844])

In [None]:
data[9].shape, data[10].shape

(torch.Size([8, 67, 346]), torch.Size([8, 67, 346]))

In [None]:
net = Network(GNN_depth=4, k_head=3, hidden_dim1=128, hidden_dim2=128, DAN_depth = 2).cuda()
affinity, pairwise = net(data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7])



In [None]:
affinity, pairwise

(tensor([[ 0.0694],
         [-0.1045],
         [-0.0691],
         [-0.0164],
         [ 0.0290],
         [ 0.0617],
         [-0.0709],
         [-0.0851]], device='cuda:0', grad_fn=<AddmmBackward0>),
 tensor([[[0.5361, 0.5373, 0.5494,  ..., 0.0000, 0.0000, 0.0000],
          [0.5584, 0.5549, 0.5721,  ..., 0.0000, 0.0000, 0.0000],
          [0.5487, 0.5471, 0.5608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.5591, 0.5638, 0.5643,  ..., 0.0000, 0.0000, 0.0000],
          [0.5770, 0.5822, 0.5860,  ..., 0.0000, 0.0000, 0.0000],
          [0.5554, 0.5603, 0.5650,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.5755, 0.5807, 0.5818,  ..., 0.0000, 0.0000, 0.0000],
          [0.5735, 0.5787, 0.5794,  ..., 0.0000, 0.0000, 0.0000],
          [0.5777, 0.5830, 0.5849,  

In [None]:
a,b,c = stack_and_pad(pairwise.detach().cpu().numpy())

In [None]:
 gru = nn.GRU(82,82).cuda()
output , h1 = gru(data[0])

In [None]:
data[6].shape

torch.Size([8, 844, 20])