In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import pickle
import argparse
from collections import defaultdict
import pandas as pd
import numpy as np
import gzip
import shutil
import spacy 
import pandas as pd 
from tqdm import tqdm
from collections import Counter
import kora.install.rdkit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw, MolFromSmiles, MolToSmiles
from sklearn.manifold import TSNE
import copy
import torch.nn as nn
import torch.nn.functional as F
import argparse
import math
import os
from torch import optim
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

In [None]:
class SandwichTransformer(nn.Module):
    """Encoder-decoder transformer assembled from the sandwich-aware stacks in this file.

    The encoder is a ``TransformerEncoder`` and the decoder a
    ``TransformerDecoder``.  ``sandwich_k`` is forwarded to a stack only when
    the matching ``sandwich_encoder`` / ``sandwich_decoder`` flag is true;
    otherwise that stack is built with ``sandwich_k=0``.

    Args:
        d_model: feature size of every layer input/output.
        nhead: number of attention heads.
        num_encoder_layers: ``num_layers`` passed to the encoder stack.
        num_decoder_layers: ``num_layers`` passed to the decoder stack.
        dim_feedforward: hidden size of the position-wise feed-forward block.
        dropout: dropout probability used inside the layers.
        sandwich_k: sandwich coefficient handed to the enabled stacks.
        sandwich_encoder: apply ``sandwich_k`` to the encoder.
        sandwich_decoder: apply ``sandwich_k`` to the decoder.
        activation: callable used inside the feed-forward block.
        layer_norm_eps: epsilon of every ``nn.LayerNorm``.
    """

    def __init__(self, d_model = 512, nhead = 8, num_encoder_layers = 6,
                 num_decoder_layers = 6, dim_feedforward = 2048, dropout = 0.1,
                 sandwich_k = 2, sandwich_encoder = False, sandwich_decoder = False,
                 activation = F.relu, layer_norm_eps = 1e-5):
        super(SandwichTransformer, self).__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                activation, layer_norm_eps)
        encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm,
                                          sandwich_k if sandwich_encoder else 0)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                activation, layer_norm_eps)
        decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          sandwich_k if sandwich_decoder else 0)

        # Re-initialise every weight matrix once both stacks exist.
        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.sandwich_k = sandwich_k
        self.sandwich_encoder = sandwich_encoder
        self.sandwich_decoder = sandwich_decoder

    def forward(self, src, tgt, src_mask = None, tgt_mask = None, memory_mask = None,
                src_key_padding_mask = None, tgt_key_padding_mask = None,
                memory_key_padding_mask = None):
        """Encode ``src`` and decode ``tgt`` against the resulting memory.

        All mask arguments are optional and are passed through unchanged to
        the encoder (``src_*``) or decoder (``tgt_*`` / ``memory_*``).

        Returns:
            The decoder output.
        """
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    @staticmethod
    def generate_square_subsequent_mask(sz):
        """Return an ``(sz, sz)`` float mask: ``-inf`` above the diagonal, ``0`` elsewhere.

        Declared as a static method so it can be called on the class or on an
        instance; without the decorator an instance call would pass the
        instance itself as ``sz``.
        """
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

class TransformerEncoder(nn.Module):
    """Stack of cloned encoder layers with an optional sandwich schedule.

    The stack holds ``num_layers + sandwich_k`` clones of ``encoder_layer``.
    During ``forward`` the first ``sandwich_k`` clones are called with
    ``layer_type=1``, the last ``sandwich_k`` with ``layer_type=2`` and the
    remaining ones with ``layer_type=0``.
    """
    # NOTE(review): single-underscore spelling is kept as in the original;
    # presumably ``__constants__`` was intended - confirm before renaming.
    _constants_ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None, sandwich_k=0):
        super(TransformerEncoder, self).__init__()
        total_layers = num_layers + sandwich_k
        self.layers = _get_clones(encoder_layer, total_layers)
        # ``num_layers`` stores the *total* depth, including the sandwich extras.
        self.num_layers = total_layers
        self.norm = norm
        self.sandwich_k = sandwich_k

    def forward(self, src, mask = None, src_key_padding_mask = None):
        """Run ``src`` through every layer, then through ``norm`` when one is set."""
        tail_start = self.num_layers - self.sandwich_k
        hidden = src
        for idx, layer in enumerate(self.layers):
            # The head test takes precedence over the tail test, as before.
            if idx < self.sandwich_k:
                kind = 1
            elif idx >= tail_start:
                kind = 2
            else:
                kind = 0
            hidden = layer(hidden, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask,
                           layer_type=kind)

        if self.norm is None:
            return hidden
        return self.norm(hidden)

class TransformerDecoder(nn.Module):
    """Stack of cloned decoder layers with an optional sandwich schedule.

    The stack holds ``num_layers + sandwich_k`` clones of ``decoder_layer``.
    During ``forward`` the first ``sandwich_k`` clones are called with
    ``layer_type=1``, the last ``sandwich_k`` with ``layer_type=2`` and the
    remaining ones with ``layer_type=0``.
    """
    # NOTE(review): single-underscore spelling is kept as in the original;
    # presumably ``__constants__`` was intended - confirm before renaming.
    _constants_ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None, sandwich_k=0):
        super(TransformerDecoder, self).__init__()
        total_layers = num_layers + sandwich_k
        self.layers = _get_clones(decoder_layer, total_layers)
        # ``num_layers`` stores the *total* depth, including the sandwich extras.
        self.num_layers = total_layers
        self.norm = norm
        self.sandwich_k = sandwich_k

    def forward(self, tgt, memory, tgt_mask = None,
                memory_mask = None, tgt_key_padding_mask = None,
                memory_key_padding_mask = None):
        """Run ``tgt`` through every layer against ``memory``, then apply ``norm`` if set."""
        tail_start = self.num_layers - self.sandwich_k
        hidden = tgt

        for idx, layer in enumerate(self.layers):
            # The head test takes precedence over the tail test, as before.
            if idx < self.sandwich_k:
                kind = 1
            elif idx >= tail_start:
                kind = 2
            else:
                kind = 0
            hidden = layer(hidden, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           layer_type=kind)

        if self.norm is None:
            return hidden
        return self.norm(hidden)

class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer whose sub-blocks can be run selectively.

    ``forward`` takes a ``layer_type`` switch:

    * ``0`` - self-attention followed by the feed-forward block;
    * ``1`` - self-attention block only;
    * anything else - feed-forward block only.

    Each executed sub-block is wrapped as ``norm(x + block(x))``.

    Args:
        d_model: feature size of the input/output.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability (attention, feed-forward and residual paths).
        activation: callable applied between the two linear layers.
        layer_norm_eps: epsilon of both ``nn.LayerNorm`` modules.
    """
    # NOTE(review): kept byte-identical. The single-underscore name is inert,
    # and neither listed attribute is set by this class - confirm intent
    # before renaming to ``__constants__``.
    _constants_ = ['batch_first', 'norm_first']

    def __init__(self, d_model, nhead, dim_feedforward = 2048, dropout = 0.1,
                 activation = F.relu, layer_norm_eps = 1e-5):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation

    def __setstate__(self, state):
        """Restore pickled state, defaulting a missing ``activation`` to ``F.relu``.

        The original spelling ``_setstate_`` was never invoked by pickle or
        ``copy`` and delegated to a non-existent ``super()._setstate_``.
        """
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask = None, src_key_padding_mask = None,
                layer_type=0):
        """Apply the sub-blocks selected by ``layer_type`` to ``src`` and return the result."""
        x = src

        if layer_type == 0:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        elif layer_type == 1:
            # Attention-only layer: the feed-forward weights are left unused.
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        else:
            # Feed-forward-only layer: the attention weights are left unused.
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention over ``x`` followed by ``dropout1``."""
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x):
        """``linear2(dropout(activation(linear1(x))))`` followed by ``dropout2``."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

class TransformerDecoderLayer(nn.Module):
    """Post-norm decoder layer whose sub-blocks can be run selectively.

    ``forward`` takes a ``layer_type`` switch:

    * ``0`` - self-attention, attention over ``memory``, then feed-forward;
    * ``1`` - the two attention blocks only;
    * anything else - feed-forward block only.

    Each executed sub-block is wrapped as ``norm(x + block(x))``.

    Args:
        d_model: feature size of the input/output.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability (attention, feed-forward and residual paths).
        activation: callable applied between the two linear layers.
        layer_norm_eps: epsilon of the three ``nn.LayerNorm`` modules.
    """
    # NOTE(review): kept byte-identical. The single-underscore name is inert,
    # and neither listed attribute is set by this class - confirm intent
    # before renaming to ``__constants__``.
    _constants_ = ['batch_first', 'norm_first']

    def __init__(self, d_model, nhead, dim_feedforward = 2048, dropout = 0.1,
                 activation = F.relu, layer_norm_eps = 1e-5):

        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation

    def __setstate__(self, state):
        """Restore pickled state, defaulting a missing ``activation`` to ``F.relu``.

        The original spelling ``_setstate_`` was never invoked by pickle or
        ``copy`` and delegated to a non-existent ``super()._setstate_``.
        """
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask = None, memory_mask= None,
                tgt_key_padding_mask = None, memory_key_padding_mask = None,
                layer_type = 0):
        """Apply the sub-blocks selected by ``layer_type`` to ``tgt`` and return the result."""
        x = tgt
        if layer_type == 0:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
            x = self.norm3(x + self._ff_block(x))
        elif layer_type == 1:
            # Attention-only layer: self-attention, then attention over memory.
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
        else:
            # Feed-forward-only layer: ``memory`` is not used on this path.
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x, attn_mask , key_padding_mask):
        """Self-attention over ``x`` followed by ``dropout1``."""
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        """Attention with ``x`` as query and ``mem`` as key/value, followed by ``dropout2``."""
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x):
        """``linear2(dropout(activation(linear1(x))))`` followed by ``dropout3``."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a tensor, then apply dropout.

    The table ``pe`` has shape ``(max_len, 1, d_model)`` and is registered as a
    buffer.  Even feature indices hold sines and odd indices hold cosines.

    Args:
        d_model: feature size; odd values are supported.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd index than even index, so trim
        # the frequency vector; for even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; positions are indexed along dim 0 of ``x``."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [None]:
# Integer ids of the special tokens. Within this file only PAD is read: it is
# passed as `ignore_index` to `F.nll_loss` in `evaluate`.
# NOTE(review): the evaluation cell further down re-declares the same five
# values as `pad_index` ... `mask_index`; presumably both sets must match the
# vocabulary used for training - confirm.
PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4

class SandwichSmTrfm(nn.Module):
    """Token-level sequence model built around ``SandwichTransformer``.

    ``forward`` feeds the same embedded sequence to the transformer as both
    source and target and returns per-token log-probabilities.  ``encode``
    turns token ids into fixed-size NumPy feature vectors using the encoder
    stack only.

    Args:
        in_size: number of rows of the embedding table.
        hidden_size: embedding size, transformer ``d_model`` and feed-forward width.
        out_size: size of the output projection.
        n_layers: ``num_encoder_layers`` and ``num_decoder_layers`` of the transformer.
        sandwich_k: sandwich coefficient passed to the transformer.
        sandwich_encoder: enable sandwiching in the encoder.
        sandwich_decoder: enable sandwiching in the decoder.
    """

    def __init__(self, in_size, hidden_size, out_size, n_layers, sandwich_k, sandwich_encoder, sandwich_decoder):
        super(SandwichSmTrfm, self).__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(in_size, hidden_size)
        self.pe = PositionalEncoding(hidden_size, 0.1)
        self.sandwich_k = sandwich_k
        self.sandwich_encoder = sandwich_encoder
        self.sandwich_decoder = sandwich_decoder
        self.trfm = SandwichTransformer(d_model=hidden_size, nhead=4,
                                        num_encoder_layers=n_layers, num_decoder_layers=n_layers,
                                        dim_feedforward=hidden_size,
                                        sandwich_k=sandwich_k, sandwich_encoder=sandwich_encoder,
                                        sandwich_decoder=sandwich_decoder,
                                        dropout=0.4)
        self.out = nn.Linear(hidden_size, out_size)

    def forward(self, src):
        """Return log-softmax scores over dim 2 for the token ids in ``src``."""
        embedded = self.embed(src)
        # NOTE(review): ``self.pe`` already returns dropout(x + pe), so this
        # in-place add yields embedded + dropout(embedded + pe), whereas
        # ``_encode`` uses ``self.pe(embedded)`` alone.  Left unchanged because
        # the checkpoint loaded later in this notebook may have been trained
        # with exactly this forward pass - confirm before aligning the two.
        embedded += self.pe(embedded)
        hidden = self.trfm(embedded, embedded)
        out = self.out(hidden)
        out = F.log_softmax(out, dim=2)
        return out

    def _encode(self, src):
        """Encode one batch of token ids into a 2-D NumPy feature matrix.

        The encoder layers are applied with the same ``layer_type`` schedule as
        ``TransformerEncoder.forward`` (read from the encoder itself, so the
        two cannot drift apart).  The activation entering the last layer is
        kept as the "penultimate" representation.

        Returns:
            ``np.hstack`` of four blocks computed over dim 0 of the encoder
            output: mean, max, the slice at index 0 of the final output, and
            the slice at index 0 of the penultimate activation.
        """
        encoder = self.trfm.encoder
        total_layers = encoder.num_layers
        k = encoder.sandwich_k

        output = self.pe(self.embed(src))
        for i, layer in enumerate(encoder.layers):
            if i == total_layers - 1:
                # Snapshot the input of the final layer before running it.
                penul = output.cpu().detach().numpy()
            if i < k:
                layer_type = 1
            elif i >= total_layers - k:
                layer_type = 2
            else:
                layer_type = 0
            output = layer(output, None, layer_type=layer_type)

        if encoder.norm is not None:
            output = encoder.norm(output)
        output = output.cpu().detach().numpy()

        return np.hstack([np.mean(output, axis=0), np.max(output, axis=0), output[0,:,:], penul[0,:,:]])

    def encode(self, src, chunk_size=100):
        """Encode ``src`` in chunks of at most ``chunk_size`` columns along dim 1.

        Args:
            src: token-id tensor; dim 1 is treated as the batch dimension.
            chunk_size: maximum number of sequences passed to ``_encode`` at once.

        Returns:
            The ``_encode`` outputs of all chunks concatenated along axis 0.
        """
        batch_size = src.shape[1]
        if batch_size <= chunk_size:
            return self._encode(src)

        # Batch is too large to encode at once: process slices and join once.
        # (The original loop tested an undefined name ``o`` and raised NameError.)
        chunks = [self._encode(src[:, start:start + chunk_size])
                  for start in range(0, batch_size, chunk_size)]
        return np.concatenate(chunks, axis=0)

def evaluate(model, test_loader, vocab):
    """Return the mean per-batch NLL of ``model`` predicting its own input tokens.

    Each batch from ``test_loader`` is moved to the device the model's
    parameters live on (previously hard-coded to CUDA), transposed with
    ``torch.t`` and scored against itself with ``F.nll_loss``; positions equal
    to ``PAD`` are ignored.

    Args:
        model: module whose output can be viewed as ``(-1, len(vocab))``.
        test_loader: iterable of 2-D token-id tensors that supports ``len``.
        vocab: sized object; only ``len(vocab)`` is used.

    Returns:
        Sum of the batch losses divided by ``len(test_loader)``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    # Fall back to the original CUDA behaviour for a parameter-less model.
    device = first_param.device if first_param is not None else torch.device('cuda')

    total_loss = 0
    with torch.no_grad():
        for sm in test_loader:
            sm = torch.t(sm.to(device))
            output = model(sm)
            loss = F.nll_loss(output.view(-1, len(vocab)),
                              sm.contiguous().view(-1),
                              ignore_index=PAD)
            total_loss += loss.item()
    return total_loss / len(test_loader)


In [None]:
# Same values as before, spelled out by parameter name.
# NOTE(review): 45 is presumably the vocabulary size - confirm against the
# vocabulary used to produce the checkpoint loaded in the next cell.
model = SandwichSmTrfm(
    in_size=45,
    hidden_size=256,
    out_size=45,
    n_layers=4,
    sandwich_k=1,
    sandwich_encoder=False,
    sandwich_decoder=True,
)
# The cells below feed the model CUDA tensors (`.cuda()` on every input), so
# the parameters have to live on the GPU as well when one is available.
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
# Load onto the CPU first: `load_state_dict` copies the tensors into the
# model's own parameters, so this works whether or not the checkpoint was
# saved from a GPU and whether or not a GPU is present now.
checkpoint = torch.load('/content/drive/MyDrive/project_best.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
def eval_mlp(X, y, rate, n_repeats, random_state=None):
    """Score features with repeated stratified splits and an ``MLPClassifier``.

    For each repeat the data are split with ``test_size = 1 - rate``
    (stratified on ``y``), a fresh ``MLPClassifier(max_iter=1000)`` is fitted
    on the training part, and ROC-AUC is computed from column 1 of
    ``predict_proba`` on the test part.

    Args:
        X: feature matrix.
        y: labels; column 1 of ``predict_proba`` is scored against them.
        rate: fraction of the samples used for training.
        n_repeats: number of independent split/fit/score rounds.
        random_state: optional integer seed.  When given, repeat ``i`` uses
            ``random_state + i`` for both the split and the classifier, making
            the result reproducible.  ``None`` keeps the unseeded behaviour.

    Returns:
        Dict with keys ``'auc mean'`` and ``'auc std'`` over the repeats.
    """
    auc = np.empty(n_repeats)
    for i in range(n_repeats):
        seed = None if random_state is None else random_state + i
        clf = MLPClassifier(max_iter=1000, random_state=seed)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=1-rate, stratify=y, random_state=seed)
        clf.fit(X_train, y_train)
        y_score = clf.predict_proba(X_test)
        auc[i] = roc_auc_score(y_test, y_score[:,1])

    res = {}
    res['auc mean'] = np.mean(auc)
    # `auc` is 1-D, so its std is already a scalar; no outer mean is needed.
    res['auc std'] = np.std(auc)
    return res

def get_inputs(sm, seq_len=220):
    """Convert a whitespace-tokenised string into padded id and segment lists.

    Reads the module-level names ``vocab``, ``unk_index``, ``sos_index``,
    ``eos_index`` and ``pad_index``, which must be defined before the call.

    Args:
        sm: string whose tokens are separated by whitespace.
        seq_len: length of both returned lists, including the start and end
            markers.  The default reproduces the previous hard-coded 220
            (218 tokens at most, 109 kept from each end when truncating).

    Returns:
        ``(ids, seg)``: ``ids`` is ``sos_index``, the token ids and
        ``eos_index``, padded with ``pad_index``; ``seg`` is ``1`` for every
        non-padding position, padded with ``pad_index`` as well.
    """
    tokens = sm.split()
    max_tokens = seq_len - 2  # room left after the start and end markers
    if len(tokens) > max_tokens:
        # Keep the head and the tail of an over-long sequence.
        half = max_tokens // 2
        tokens = tokens[:half] + tokens[len(tokens) - half:]
    ids = [vocab.word_dict.get(token, unk_index) for token in tokens]
    ids = [sos_index] + ids + [eos_index]
    seg = [1] * len(ids)
    padding = [pad_index] * (seq_len - len(ids))
    ids.extend(padding)
    seg.extend(padding)
    return ids, seg

def get_array(smiles):
    """Run ``get_inputs`` on every item of ``smiles`` and stack the results.

    Returns:
        ``(ids, segments)`` as two tensors built from the per-item lists.
    """
    encoded = [get_inputs(sm) for sm in smiles]
    id_rows = [pair[0] for pair in encoded]
    seg_rows = [pair[1] for pair in encoded]
    return torch.tensor(id_rows), torch.tensor(seg_rows)

In [None]:
# One entry per benchmark: (csv path, label column, SMILES column).
fs = [('/content/drive/MyDrive/BBBP.csv', 'p_np', 'smiles'), ('/content/drive/MyDrive/HIV.csv', 'HIV_active', 'smiles'), ('/content/drive/MyDrive/bace.csv', 'Class', 'mol')]

# Special-token ids read as globals by `get_inputs`.
pad_index = 0
unk_index = 1
eos_index = 2
sos_index = 3
mask_index = 4

# NOTE(review): `split` (and `vocab`, used by `get_inputs`) are not defined in
# the visible part of this notebook - they must exist in the kernel already.
for f in fs:
    csv_path, label_col, smiles_col = f
    model.eval()
    df = pd.read_csv(csv_path)
    # Training fractions 1/80, 2/80, ..., 64/80.
    rates = 2**np.arange(7)/80
    x_split = [split(sm) for sm in df[smiles_col].values]
    xid, _ = get_array(x_split)
    X = model.encode(torch.t(xid).cuda())
    mean_score = np.zeros(len(rates))
    print(X.shape)
    for i, rate in enumerate(rates):
        score_dic = eval_mlp(X, df[label_col].values, rate, 20)
        mean_score[i] = score_dic['auc mean']
        print(rate, score_dic)

    print(np.mean(mean_score))

In [None]:
# Re-encode the ids currently held in `xid`. After the evaluation loop above,
# `xid` belongs to the last file listed in `fs` (bace.csv).
model.eval()
X = model.encode(torch.t(xid).cuda())

In [None]:
# Project the encoded features to two dimensions with t-SNE; the seed is fixed
# through `random_state=0`.
X_reduced = TSNE(n_components=2, random_state=0).fit_transform(X)

In [None]:
fig = plt.figure(figsize=(8,7))
# The earlier `font.size = 14` was immediately overwritten; only 12 took effect.
plt.rcParams['font.size'] = 12

# Label column to colour by: 'p_np' for BBBP, 'Class' for bace, 'HIV_active' for HIV.
# NOTE(review): after the evaluation loop `df` holds the last file in `fs`
# (bace.csv), whose label column is listed there as 'Class'. The original
# comment said to "use df1 not df", but no `df1` is defined in the visible
# notebook - confirm which frame and which `X_reduced` belong together.
plot_label_col = 'HIV_active'
labels = df[plot_label_col].to_numpy()
negative = X_reduced[labels == 0]
positive = X_reduced[labels == 1]
plt.scatter(negative[:,0], negative[:,1], label='negative', marker='o', alpha=0.5)
plt.scatter(positive[:,0], positive[:,1], label='positive', marker='o', alpha=0.8)

plt.axis('off')
plt.legend(loc='upper left')

out_path = '/content/11785FinalProject/LatentSpaceImages/HIV/HIV_baseline.png'
# `savefig` does not create missing directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path)