In [1]:
from torch.nn.functional import softmax, relu, selu, leaky_relu, elu, max_pool1d, batch_norm
import torch.nn.init as init
import time
import torch
import torch.nn as nn
import torch.optim as optim
import inspect
import random
import math

In [2]:
class NN_Encoder(nn.Module):
    """Fully connected encoder built from Linear -> ReLU -> Dropout -> BatchNorm1d blocks.

    One block is created per consecutive pair in
    ``[input_dim, *hid_dims, out_dim]``, so the final block (the one that
    produces ``out_dim`` features) also ends with ReLU, dropout and batch norm.

    Args:
        input_dim: Number of input features.
        out_dim: Number of output features; also exposed as ``self.out_dim``.
        hid_dims: Iterable of hidden layer widths (list, tuple, ...). May be empty.
        dropout: Dropout probability used inside every block.
    """

    def __init__(self, input_dim, out_dim, hid_dims, dropout):
        super().__init__()

        self.out_dim = out_dim

        # Not used by forward(); kept as attributes so external code that
        # reads them keeps working.
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

        # Unpacking (instead of list concatenation) accepts tuples and other
        # iterables, not just lists.
        layer_dims = [input_dim, *hid_dims, out_dim]
        self.hid_dims = layer_dims

        self.fcs = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_in, d_out),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.BatchNorm1d(d_out),
            )
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        ])

    def forward(self, cell_src, batch_size=None):
        """Run ``cell_src`` through every block in order.

        Args:
            cell_src: Input of shape [batch_size, input_dim].
            batch_size: Ignored; accepted only for signature compatibility
                with existing callers.

        Returns:
            Tensor of shape [batch_size, out_dim].
        """
        x = cell_src

        for fc in self.fcs:
            x = fc(x)

        return x

In [3]:
class CNN_Encoder(nn.Module):
    """Token-sequence encoder: embedding -> parallel Conv1d filters -> max-over-time -> linear.

    Args:
        input_dim: Vocabulary size of the embedding table.
        emb_dim: Embedding width; also the number of Conv1d input channels.
        out_dim: Number of output features; also exposed as ``self.out_dim``.
        n_filters: Output channels of every convolution.
        filter_sizes: Sized iterable of kernel sizes, one convolution per entry.
        dropout: Dropout probability (applied to the embeddings, to the
            concatenated pooled features and to the final output).
    """

    def __init__(self, input_dim, emb_dim, out_dim, n_filters, filter_sizes, dropout):
        super().__init__()

        self.out_dim = out_dim

        # Index 0 is the padding token: its embedding is not updated by gradients.
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=0)

        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=emb_dim,
                      out_channels=n_filters,
                      kernel_size=fs,
                      padding_mode='zeros',
                      padding=0)
            for fs in filter_sizes
        ])

        self.dropout = nn.Dropout(dropout)

        # A single BatchNorm1d instance is shared by all convolution branches.
        self.batch_norm_cnn = nn.BatchNorm1d(n_filters)

        # NOTE: despite the attribute name, the activation in use is SELU.
        self.relu = selu

        self.maxpool = max_pool1d

        self.fc = nn.Linear(len(filter_sizes) * n_filters, out_dim)

        # Registered but not applied in forward(); kept so the module's
        # parameter set is unchanged.
        self.batch_norm_out = nn.BatchNorm1d(out_dim)

        # The convolutions use padding=0, so the sequence must be at least as
        # long as the largest kernel; forward() pads shorter inputs up to this.
        self.min_src_len = max(filter_sizes, default=0)

    def forward(self, src, batch_size=None):
        """Encode a batch of token-index sequences.

        Args:
            src: Integer tensor of shape [src len, batch size].
            batch_size: Ignored; accepted only for signature compatibility
                with existing callers.

        Returns:
            Tensor of shape [batch size, out_dim].
        """
        # Sequences shorter than the largest kernel would make Conv1d raise;
        # right-pad them with the embedding's padding index instead.
        missing = self.min_src_len - src.shape[0]
        if missing > 0:
            pad_rows = src.new_full((missing, src.shape[1]), self.embedding.padding_idx)
            src = torch.cat([src, pad_rows], dim=0)

        embedded = self.dropout(self.embedding(src))
        # embedded = [src len, batch size, emb dim]

        embedded = embedded.permute(1, 2, 0)
        # embedded = [batch size, emb dim, src len]

        conved = [self.batch_norm_cnn(self.relu(conv(embedded))) for conv in self.convs]
        # conved_n = [batch size, n_filters, src len - filter_sizes[n] + 1]

        # Pooling window equals the full remaining length -> max over time.
        pooled = [self.maxpool(conv, conv.shape[2]).squeeze(2) for conv in conved]
        # pooled_n = [batch size, n_filters]

        cat = self.dropout(torch.cat(pooled, dim=1))
        output = self.dropout(self.relu(self.fc(cat)))

        return output

In [4]:
class Seq2Func(nn.Module):
    """Joint model: concatenates a cell encoding and a SMILES encoding, then applies two linear layers.

    Args:
        cell_encoder: Module called as ``cell_encoder(cell_src, batch_size)``;
            must expose an ``out_dim`` attribute.
        smiles_encoder: Module called as ``smiles_encoder(smiles_src, batch_size)``;
            must expose an ``out_dim`` attribute.
        hid_dim: Width of the hidden layer after concatenation.
        out_dim: Width of the final output.
        dropout: Dropout probability applied after the hidden layer.
        device: Stored on ``self.device``; not otherwise used by this class.
    """

    def __init__(self, cell_encoder, smiles_encoder, hid_dim, out_dim, dropout, device):
        super().__init__()

        self.cell_encoder = cell_encoder

        self.smiles_encoder = smiles_encoder

        self.device = device

        self.fc1 = nn.Linear(cell_encoder.out_dim + smiles_encoder.out_dim, hid_dim)

        self.fc2 = nn.Linear(hid_dim, out_dim)

        self.dropout = nn.Dropout(dropout)

        # NOTE: despite the attribute name, the activation in use is leaky ReLU.
        self.relu = leaky_relu

    def forward(self, cell_src, smiles_src):
        """Predict from a batch of cell features and SMILES token sequences.

        Args:
            cell_src: Cell features, batch-first ([batch size, features]).
            smiles_src: Token indices, sequence-first ([src len, batch size]).

        Returns:
            Tensor of shape [batch size, out_dim].
        """
        # Get cell encoder output. cell_src is batch-first, so the batch size
        # is dimension 0 (dimension 1 is the feature dimension).
        cell_output = self.cell_encoder(cell_src, cell_src.shape[0])
        # cell_output = [batch size, cell out_dim]

        # Get smiles encoder output. smiles_src is sequence-first, so the
        # batch size is dimension 1.
        smiles_output = self.smiles_encoder(smiles_src, smiles_src.shape[1])
        # smiles_output = [batch size, smiles out_dim]

        ls_output = torch.cat((cell_output, smiles_output), 1)
        # ls_output = [batch size, cell out_dim + smiles out_dim]

        o1 = self.dropout(self.relu(self.fc1(ls_output)))
        # o1 = [batch size, hid_dim]

        # The activation is also applied to the final layer's output.
        final_output = self.relu(self.fc2(o1))
        # final_output = [batch size, out_dim]

        return final_output


In [5]:
def init_weights(m):
    """Overwrite every parameter of ``m`` in place with samples from U(-0.05, 0.05).

    All parameters reported by the module are re-drawn, including those of
    any nested submodules.
    """
    low, high = -0.05, 0.05
    for weight in m.parameters():
        nn.init.uniform_(weight.data, a=low, b=high)

In [6]:
def count_parameters(model):
    """Return the total element count over the model's parameters that require gradients."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total

def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and leftover whole seconds.

    Returns:
        Tuple ``(minutes, seconds)``, both truncated to ``int``.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [7]:
def evaluation(model, iterator, criterion, DEVICE):
    """Compute the mean per-batch loss of ``model`` over ``iterator`` without gradient tracking.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Module called as ``model(cell_src, smiles_src)``; its output is
            squeezed on dimension 1 before the loss is computed.
        iterator: Any iterable of batches (a ``__len__`` is not required).
            Each batch is indexed as ``batch[0]`` = SMILES tokens (transposed
            with ``permute(1, 0)`` before use), ``batch[1]`` = cell features,
            ``batch[2]`` = targets.
        criterion: Callable ``criterion(output, trg)`` returning a scalar tensor.
        DEVICE: Device every tensor of the batch is moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches processed.

    Raises:
        ZeroDivisionError: If ``iterator`` yields no batches.
    """
    model.eval()

    epoch_loss = 0
    # Count batches as they are consumed instead of calling len(iterator), so
    # generators and other unsized iterables are supported as well.
    n_batches = 0

    with torch.no_grad():

        for batch in iterator:

            cell_src = batch[1].to(DEVICE)
            smiles_src = batch[0].permute(1, 0).to(DEVICE)
            trg = batch[2].to(DEVICE)

            output = model(cell_src, smiles_src).squeeze(1)
            # output = [batch size]

            loss = criterion(output, trg)

            epoch_loss += loss.item()
            n_batches += 1

            # Drop the input references and release cached CUDA memory after
            # every batch.
            del cell_src
            del smiles_src
            torch.cuda.empty_cache()

    return epoch_loss / n_batches

In [8]:
def training(model, iterator, optimizer, criterion, clip, DEVICE):
    """Train ``model`` for one pass over ``iterator`` and return the mean per-batch loss.

    The model is switched to train mode and left in that mode on return.

    Args:
        model: Module called as ``model(cell_src, smiles_src)``; its output is
            squeezed on dimension 1 before the loss is computed.
        iterator: Any iterable of batches (a ``__len__`` is not required).
            Each batch is indexed as ``batch[0]`` = SMILES tokens (transposed
            with ``permute(1, 0)`` before use), ``batch[1]`` = cell features,
            ``batch[2]`` = targets.
        optimizer: Optimizer over the model's parameters; stepped once per batch.
        criterion: Callable ``criterion(output, trg)`` returning a scalar tensor.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        DEVICE: Device every tensor of the batch is moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches processed.

    Raises:
        ZeroDivisionError: If ``iterator`` yields no batches.
    """
    model.train()

    epoch_loss = 0
    # Count batches as they are consumed instead of calling len(iterator), so
    # generators and other unsized iterables are supported as well.
    n_batches = 0

    for batch in iterator:

        cell_src = batch[1].to(DEVICE)
        smiles_src = batch[0].permute(1, 0).to(DEVICE)
        trg = batch[2].to(DEVICE)

        optimizer.zero_grad()

        output = model(cell_src, smiles_src).squeeze(1)
        # output = [batch size]

        loss = criterion(output, trg)

        loss.backward()

        # Rescale gradients so their total norm does not exceed `clip`.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

        # Drop the input references and release cached CUDA memory after
        # every batch.
        del cell_src
        del smiles_src
        torch.cuda.empty_cache()

    return epoch_loss / n_batches