In [1]:
import os

import numpy as np
from scipy import stats
from scipy.special import softmax
import math
import random
import pickle

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data
from torch.nn.parameter import Parameter
from torch.nn import init

from tqdm import tqdm

In [13]:
def process(data):
    """Encode (sequence, value) pairs into padded arrays, labels and window bounds.

    Each sequence is centred inside a 40-symbol string by padding both sides
    with 'J' (the right side gets the extra symbol when the padding is odd),
    every symbol is looked up in the table unpickled from ./encoding.pkl, and
    the resulting array is transposed so the symbol axis comes last.

    Returns:
        (inputs, labels, ranges): three parallel lists. ``labels`` holds the
        values as floats. Each entry of ``ranges`` is the pair
        ``(40 - right_pad - 20, left_pad)``; the padding functions in this
        file use it as the inclusive range of start offsets for a 20-wide
        window.
    """
    # NOTE: pickle.load executes arbitrary code; only use a trusted encoding file.
    with open("./encoding.pkl", 'rb') as handle:
        encoding = pickle.load(handle)

    encoded, targets, windows = [], [], []
    for sequence, value in data:
        slack = 40 - len(sequence)
        left_pad = slack // 2            # floor(slack / 2)
        right_pad = slack - left_pad     # ceil(slack / 2)
        padded = 'J' * left_pad + sequence + 'J' * right_pad

        encoded.append(np.array([encoding[symbol] for symbol in padded]).T)
        targets.append(float(value))
        windows.append((40 - right_pad - 20, left_pad))
    return encoded, targets, windows
    
def toTensorDataset(datalist):
    """Merge several (inputs, labels, ranges) triples into one TensorDataset.

    For each of the three positions the per-split sequences are concatenated
    along the first axis; inputs and labels become float32 tensors and the
    ranges become an int64 tensor.
    """
    field_dtypes = (torch.float32, torch.float32, torch.int64)
    tensors = []
    for position, dtype in enumerate(field_dtypes):
        merged = np.concatenate([split[position] for split in datalist])
        tensors.append(torch.tensor(merged, dtype=dtype))
    return torch.utils.data.TensorDataset(*tensors)

def getData(fname):
    """Load a pickled (train, test) pair and prepare it for training.

    Returns:
        (train, test): ``train`` is a list with one ``process`` result per
        training split; ``test`` is an unshuffled DataLoader (batch size 1000)
        over the processed test data.
    """
    # NOTE: pickle.load executes arbitrary code; only open trusted data files.
    with open(fname, 'rb') as handle:
        train_raw, test_raw = pickle.load(handle)

    train_splits = list(map(process, train_raw))
    test_dataset = toTensorDataset([process(test_raw)])
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )
    return train_splits, test_loader

In [14]:
class ResBlock(nn.Module):
    """Residual block: conv -> batch norm -> ReLU -> conv -> batch norm, plus the input.

    Both convolutions have kernel size 3 and use ``padding == dilation``, so
    the sequence length and channel count are preserved and the skip
    connection can be added directly.
    """

    def __init__(self, channels, dilation):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        self.conv1 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.conv1_bn = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.conv2_bn = nn.BatchNorm1d(channels)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.conv1_bn(hidden)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.conv2_bn(hidden)
        # Skip connection: no activation is applied after the sum.
        return hidden + x

class EmbedderNet(nn.Module):
    """Convolutional embedder mapping a 40-channel sequence to ``outchannels`` values.

    Args:
        length: Sequence length of the inputs; used to size the first linear
            layer after the two ceil-mode poolings.
        channels: Number of feature channels used by the convolutions.
        outchannels: Size of the final output vector.
    """

    def __init__(self, length, channels, outchannels):
        super().__init__()
        self.length = length

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, ceil_mode=True)
        # Pointwise convolution lifting the 40 input channels to `channels`.
        self.conv = nn.Conv1d(40, channels, kernel_size=1, stride=1, padding=0)
        self.conv_bn = nn.BatchNorm1d(channels)

        self.block1 = ResBlock(channels, 1)
        self.block2 = ResBlock(channels, 1)
        self.block3 = ResBlock(channels, 1)
        self.block4 = ResBlock(channels, 1)
        self.block5 = ResBlock(channels, 1)

        # Each ceil-mode pooling halves the length, rounding up; it is applied twice.
        pooled_length = math.ceil(math.ceil(self.length / 2) / 2)
        self.embed_1 = nn.Linear(int(channels * pooled_length), 128)
        self.embed_2 = nn.Linear(128, outchannels)

    def forward(self, x):
        n_samples = x.shape[0]
        features = F.relu(self.conv_bn(self.conv(x)))

        for block in (self.block1, self.block2, self.block3):
            features = block(features)
        features = self.pool(self.block4(features))
        features = self.pool(self.block5(features))

        # Flatten channels and positions into one feature vector per sample.
        flat = features.view(n_samples, -1)
        return self.embed_2(F.relu(self.embed_1(flat)))
    
class Predictor_Dot(nn.Module):
    """Predict a score as the dot product of a prefix and a suffix embedding.

    The input is cut along its last axis at position 10; each part goes
    through its own EmbedderNet (sized for length 10, 16 outputs) and the two
    embeddings are multiplied element-wise and summed, giving one value per
    sample.
    """

    def __init__(self):
        super().__init__()
        self.embedPrefix = EmbedderNet(10, 64, 16)
        self.embedSuffix = EmbedderNet(10, 64, 16)

    def forward(self, x):
        head, tail = x[:, :, :10], x[:, :, 10:]
        head_embedding = self.embedPrefix(head)
        tail_embedding = self.embedSuffix(tail)
        return (head_embedding * tail_embedding).sum(dim=1)
    
class Predictor_Plus(nn.Module):
    """Predict a score as the sum of a prefix score and a suffix score.

    The input is cut along its last axis at position 10; each part goes
    through its own EmbedderNet (sized for length 10, a single output) and
    the two outputs are added, giving a (batch, 1) result.
    """

    def __init__(self):
        super().__init__()
        self.embedPrefix = EmbedderNet(10, 64, 1)
        self.embedSuffix = EmbedderNet(10, 64, 1)

    def forward(self, x):
        head, tail = x[:, :, :10], x[:, :, 10:]
        head_score = self.embedPrefix(head)
        tail_score = self.embedSuffix(tail)
        return head_score + tail_score
    
class Predictor_Joint(nn.Module):
    """Predict a score from the whole window with a single EmbedderNet.

    The wrapped embedder is sized for inputs of length 20 and produces one
    output per sample, i.e. a (batch, 1) result.
    """

    def __init__(self):
        super().__init__()
        self.predictor = EmbedderNet(20, 64, 1)

    def forward(self, x):
        score = self.predictor(x)
        return score

In [25]:
def RandomPadding(inputs, bounds, length=20):
    """Cut a randomly positioned window out of every padded sequence.

    Args:
        inputs: Tensor of shape (batch, channels, padded_length).
        bounds: Integer tensor of shape (batch, 2); row ``k`` holds the
            inclusive lowest and highest start offset allowed for sample ``k``.
        length: Width of the window to cut out (default 20).

    Returns:
        Tensor of shape (batch, channels, length) where sample ``k`` is
        ``inputs[k, :, s:s + length]`` for a start ``s`` drawn uniformly from
        its bounds with ``np.random.randint``.
    """
    bounds = torch.as_tensor(bounds)
    low = bounds[:, 0].cpu().numpy()
    high = bounds[:, 1].cpu().numpy()

    # randint's upper limit is exclusive, hence the +1 to make `high` reachable.
    starts = np.random.randint(low, high + 1).reshape(-1, 1)
    offsets = np.arange(length).reshape(1, -1)
    gatherindex = torch.tensor(starts + offsets, dtype=torch.int64, device=inputs.device)

    # Use the same positions for every channel. The channel count is taken
    # from `inputs` rather than being fixed at 40, so other encodings work too.
    gatherindex = gatherindex.unsqueeze(1).expand(-1, inputs.shape[1], -1)
    return torch.gather(inputs, 2, gatherindex)

def FixedPadding(inputs, bounds, start=10, length=20):
    """Cut the same fixed window out of every padded sequence.

    Args:
        inputs: Tensor of shape (batch, channels, padded_length).
        bounds: Per-sample bounds tensor; only its first dimension (the batch
            size) is used. It is accepted so the signature matches
            RandomPadding.
        start: First position of the window (default 10).
        length: Width of the window (default 20).

    Returns:
        Tensor of shape (batch, channels, length) equal to
        ``inputs[:batch, :, start:start + length]``.
    """
    batch = torch.as_tensor(bounds).shape[0]

    # Integer positions built directly on the input's device; the channel
    # count is taken from `inputs` rather than being fixed at 40.
    columns = torch.arange(start, start + length, dtype=torch.int64, device=inputs.device)
    gatherindex = columns.view(1, 1, -1).expand(batch, inputs.shape[1], -1)
    return torch.gather(inputs, 2, gatherindex)

def validate(net, loader, paddingFunction):
    """Run ``net`` over ``loader`` and collect predictions and targets.

    Args:
        net: Model to evaluate. It is called as-is; switching it to eval mode
            is left to the caller.
        loader: Iterable of (data, target, bounds) batches that supports len().
        paddingFunction: Callable ``(data, bounds) -> tensor`` applied to each
            batch before it is fed to the model.

    Returns:
        (predictions, targets): two flat numpy arrays; both are empty when
        the loader has no batches.
    """
    if len(loader) == 0:
        return np.array([]), np.array([])

    # Feed the data on whatever device the model lives on, so a CPU model
    # works as well as a GPU one. Models that expose no parameters keep the
    # previous behaviour of moving the data to CUDA.
    parameters = getattr(net, "parameters", None)
    first_parameter = next(iter(parameters()), None) if callable(parameters) else None
    device = first_parameter.device if first_parameter is not None else torch.device("cuda")

    predictions = []
    targets = []
    with torch.no_grad():
        pbar = tqdm(loader, position=0, leave=True)
        for (data, target, bounds) in pbar:
            data = paddingFunction(data, bounds).to(device)
            predictions.append(net(data).cpu().numpy().reshape(-1))
            targets.append(target.numpy().reshape(-1))
    return np.concatenate(predictions), np.concatenate(targets)

def getTrainingValidationSplit(data, i):
    """Build training and validation loaders with split ``i`` held out.

    Args:
        data: List of processed splits, each an (inputs, labels, ranges) triple.
        i: Index of the split used for validation; negative indices count
            from the end, as with list indexing.

    Returns:
        (trainloader, valloader): a shuffled loader (batch size 100,
        incomplete last batch dropped) over all other splits, and an
        unshuffled loader (batch size 1000) over split ``i``.

    Raises:
        IndexError: if ``i`` does not address an existing split.
    """
    n_splits = len(data)
    if not -n_splits <= i < n_splits:
        raise IndexError("split index {} out of range for {} splits".format(i, n_splits))
    # Normalise negative indices first: with the raw value, data[i+1:] for
    # i == -1 is data[0:], which would put the validation split (and every
    # other split a second time) into the training set.
    i %= n_splits

    t = toTensorDataset(data[:i] + data[i+1:])
    v = toTensorDataset([data[i]])
    trainloader = torch.utils.data.DataLoader(t, batch_size=100, shuffle=True, num_workers=2, drop_last=True)
    valloader = torch.utils.data.DataLoader(v, batch_size=1000, shuffle=False, num_workers=2, drop_last=False)
    return trainloader, valloader

def trainNet(epochs, train_loader, validation_loader, modelname, architecture, paddingFunction):
    """Train a fresh model on the GPU and return the best-validating weights.

    Args:
        epochs: Number of training epochs.
        train_loader: Dataloader for training data; yields (data, target, bounds).
        validation_loader: Dataloader for validation data.
        modelname: Descriptive name for model. Can be anything; it is only
            used in the checkpoint file name ./weights/<modelname>_best.pt.
        architecture: Predictor_Dot, Predictor_Joint, or Predictor_Plus
            (called without arguments to build the model).
        paddingFunction: RandomPadding or FixedPadding.

    Returns:
        (net, loss_trace, validation_trace): a CPU copy of the model in eval
        mode holding the weights of the epoch with the highest validation
        Pearson r, the per-epoch lists of batch losses, and the per-epoch
        validation correlations.
    """
    checkpoint_path = "./weights/{}_best.pt".format(modelname)
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    net = architecture().cuda()
    optimizer = torch.optim.Adam(net.parameters())
    lossfunc = nn.MSELoss()

    loss_trace = []
    validation_trace = []
    bestr = float('-inf')
    checkpoint_saved = False
    for i in range(epochs):
        losses = []
        net = net.train()
        pbar = tqdm(train_loader, position=0, leave=True)
        for (data, target, bounds) in pbar:
            optimizer.zero_grad()
            data = paddingFunction(data, bounds).cuda()
            output = net(data)

            # Flatten so (batch,) and (batch, 1) model outputs both match the targets.
            loss = lossfunc(output.reshape(-1), target.cuda())

            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            losses.append(current_loss)

            # `i` is the epoch index, so label it as such.
            pbar.set_description("Epoch {}: loss_avg = {:.5f}, loss_now = {:.5f}".format(i, np.mean(losses), current_loss))
        net = net.eval()
        valx, valy = validate(net, validation_loader, paddingFunction)
        r = stats.pearsonr(valx, valy)[0]
        validation_trace.append(r)
        print ("Epoch {}: r={}".format(i,r))

        # A NaN correlation (e.g. constant predictions) never compares greater
        # than anything; rank it lowest, and always checkpoint the first epoch
        # so the file loaded below is guaranteed to come from this run.
        score = float('-inf') if np.isnan(r) else r
        if not checkpoint_saved or score > bestr:
            torch.save(net.state_dict(), checkpoint_path)
            bestr = score
            checkpoint_saved = True
        loss_trace.append(losses)

    if not checkpoint_saved:
        # No epoch ran: store the initial weights instead of loading a stale
        # or missing checkpoint.
        torch.save(net.state_dict(), checkpoint_path)

    net = architecture()
    net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    net.eval()
    return net, loss_trace, validation_trace



In [None]:
# Load the data: a list of processed training splits and a DataLoader over the test set
DATA_PATH = "./data/ranibizumab_log_enrichment.pkl"
training_splits, test_loader = getData(DATA_PATH)


In [None]:
# Train the model: hold out split 0 for validation and train on the remaining splits
train_loader, validation_loader = getTrainingValidationSplit(training_splits, 0)
model, losses, performance_trace = trainNet(
    epochs=2,
    train_loader=train_loader,
    validation_loader=validation_loader,
    modelname="testrun",
    architecture=Predictor_Dot,
    paddingFunction=RandomPadding,
)


In [None]:
# Validate the model on the test set.
# `model` is already the network itself (the first element of trainNet's
# result was unpacked into it), so it is moved to the GPU directly; indexing
# an nn.Module with [0] raises a TypeError.
predictions, values = validate(model.cuda(), test_loader, FixedPadding)
stats.pearsonr(predictions, values)
