# In [3]:
import torch
import pandas
import numpy
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import sys

# In [16]:
class logger():
    '''Simple wrapper around a TensorBoard ``SummaryWriter`` for scalar logging.

    Each scalar name has its own step counter, so logging one series does not
    advance the x-axis of another series.
    '''

    def __init__(self, write_path):
        '''Open a ``SummaryWriter`` that writes its event files to ``write_path``.'''
        self.writer = SummaryWriter(write_path)
        # Total number of values logged across all names; kept for callers
        # that read ``step`` directly.
        self.step = 0
        # Maps a scalar name to the step its next value will be written at.
        self.steps = {}

    def log(self, name, value):
        '''Write ``value`` under ``name`` at that name's next step.'''
        current = self.steps.get(name, 0)
        self.writer.add_scalar(name, value, current)
        self.steps[name] = current + 1
        self.step += 1

# In [11]:
class BN_block(nn.Module):
    '''Linear layer followed by 1-D batch normalisation and a sigmoid.

    The block keeps the feature width unchanged (``in_dim`` in, ``in_dim`` out).
    '''

    def __init__(self, in_dim):
        '''Create the layers for inputs with ``in_dim`` features.'''
        super(BN_block, self).__init__()
        self.bn = nn.BatchNorm1d(in_dim)
        self.weight = nn.Linear(in_dim, in_dim)
        # NOTE: despite its name this attribute holds a sigmoid, not a ReLU.
        # The name is kept so existing references to ``.relu`` keep working.
        self.relu = nn.Sigmoid()
        # Move the parameters to the GPU only when one is present, so the
        # block can also be built on CPU-only machines instead of raising.
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, inputs):
        '''Apply linear -> batch norm -> sigmoid to ``inputs``.'''
        x = self.weight(inputs)
        x = self.bn(x)
        x = self.relu(x)
        return x
        


# In [19]:
class Model(nn.Module):
    '''Feed-forward classifier: embedding, three ``BN_block`` layers, softmax head.

    ``num_hidden_dim`` and ``drop_out_percent`` are accepted by the
    constructor but are not used; the network always has three hidden blocks
    and applies no dropout.
    '''

    def __init__(self, input_dim, hidden_dim, num_hidden_dim, output_dim, drop_out_percent):
        super().__init__()
        # Project the raw features to the hidden width.
        self.init_embedding = nn.Linear(input_dim, hidden_dim)
        # Three hidden blocks of constant width.
        self.dense_1 = BN_block(hidden_dim)
        self.dense_2 = BN_block(hidden_dim)
        self.dense_3 = BN_block(hidden_dim)
        # Map the hidden representation to one score per output class.
        self.final = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        '''Return softmax scores over dimension 1 for ``inputs``.'''
        hidden = self.init_embedding(inputs)
        for block in (self.dense_1, self.dense_2, self.dense_3):
            hidden = block(hidden)
        logits = self.final(hidden)
        return torch.softmax(logits, dim=1)
        
        

# In [52]:
class TwinNetwork(nn.Module):
    '''Two-tower network that scores a (context, X) pair by a dot product.

    Each input goes through its own tower (linear embedding -> ``BN_block``
    -> linear projection). The score is the dot product of the two projected
    vectors.

    ``num_hidden_dim`` is accepted by the constructor but is not used; each
    tower always has exactly one hidden block.
    '''

    def __init__(self, input_dim, hidden_dim, num_hidden_dim, output_dim):
        super(TwinNetwork, self).__init__()
        self.init_embedding_X = nn.Linear(input_dim, hidden_dim)
        self.init_embedding_context = nn.Linear(input_dim, hidden_dim)

        self.dense_X = BN_block(hidden_dim)
        self.dense_context = BN_block(hidden_dim)

        self.final_X = nn.Linear(hidden_dim, output_dim)
        self.final_context = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        '''Score each (context, X) pair.

        ``inputs`` is a ``(context, X)`` pair of 2-D tensors with matching
        batch size. Returns a tensor of shape ``(batch_num, 1)`` holding one
        unnormalised score per pair.
        '''
        context, X = inputs

        # Context tower -> (batch_num, output_dim)
        context = self.init_embedding_context(context)
        context = self.dense_context(context)
        context = self.final_context(context)

        # X tower -> (batch_num, output_dim). This must use the ``*_X``
        # layers; routing X through the context tower would leave the X-tower
        # parameters unused and untrained.
        X = self.init_embedding_X(X)
        X = self.dense_X(X)
        X = self.final_X(X)

        # Batched dot product: (batch, 1, d) @ (batch, d, 1) -> (batch, 1, 1)
        x = X.unsqueeze(dim=1)
        context = context.unsqueeze(dim=2)
        out = torch.bmm(x, context)
        # Drop one singleton axis -> (batch, 1)
        return out.squeeze(dim=1)
        
        

# In [55]:
class ContrastiveTrainer():
    '''
    Handles training of a ``TwinNetwork``.

    The model is trained with a contrastive loss (binary cross-entropy on the
    pair scores) to avoid computing a full softmax.

    TODO: write inference
    '''
    def __init__(self, lr, input_dim, hidden_dim, num_hidden_dim, output_dim, loss, tensorBoardPath, generator):
        '''Build the model, optimizer and logger.

        ``loss`` must be ``"contrastive"``; any other value raises
        ``ValueError``. ``generator`` is an iterator yielding
        ``(context, X, labels)`` batches. ``tensorBoardPath`` is where the
        scalar logs are written.
        '''
        # Validate first so a bad configuration fails before any log
        # directory or model is created, instead of failing later with an
        # AttributeError on ``self.loss``.
        if loss != "contrastive":
            raise ValueError(
                "bad loss: {!r} (only 'contrastive' is supported)".format(loss)
            )
        self.loss = self.contrastive_loss

        self.logger = logger(tensorBoardPath)
        self.model = TwinNetwork(input_dim, hidden_dim, num_hidden_dim, output_dim)
        # Use the GPU when one exists; otherwise stay on the CPU.
        if torch.cuda.is_available():
            self.model.cuda()
        self.optim = torch.optim.Adam(self.model.parameters(), lr=lr)

        self.generator = generator
        print("yes")

    def on_epoch_tasks(self):
        '''Hook called at the end of every epoch; does nothing by default.'''
        pass

    def contrastive_loss(self, out, labels):
        '''Binary cross-entropy between the raw scores ``out`` and ``labels``.'''
        return torch.nn.BCEWithLogitsLoss()(out.float(), labels.float())

    def train(self, num_epochs, batches_per_epoch):
        '''Run ``num_epochs`` epochs of ``batches_per_epoch`` batches each.'''
        for epoch in range(num_epochs):
            for batch in range(batches_per_epoch):
                context, X, labels = next(self.generator)
                self.train_on(context, X, labels)

            self.on_epoch_tasks()

    def train_on(self, context, X, labels):
        '''Do one optimisation step on a single batch and return the loss.

        The returned loss is detached from the autograd graph.
        '''
        out = self.model((context, X))

        loss = self.loss(out, labels)
        self.optim.zero_grad()
        loss.backward()
        # Apply the gradients; without this step the weights never change.
        self.optim.step()

        # Log a plain float so the logger does not hold on to the graph.
        self.logger.log('loss', loss.item())
        return loss.detach()
        