In [1]:
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats

import random
import math
import pickle
from collections import OrderedDict
import time
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data
from torch.nn.parameter import Parameter
from torch.nn import init

from tqdm import tqdm

In [4]:
class ResBlock(nn.Module):
    """Two-layer residual 1-D convolution block: out = BN(conv(ReLU(BN(conv(x))))) + x.

    Both convolutions use kernel_size=3, stride=1 and padding equal to the
    dilation, so the sequence length of a (N, channels, L) input is preserved
    and the identity shortcut can be added without any projection.

    Args:
        channels: number of input and output feature channels.
        dilation: dilation (and matching padding) of both convolutions.
    """
    def __init__(self, channels, dilation):
        super(ResBlock, self).__init__()
        # Shared conv hyper-parameters; padding == dilation keeps L unchanged
        # for a kernel of size 3.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        # Registration order (conv1, conv1_bn, conv2, conv2_bn) fixes both the
        # RNG draw order at init and the state_dict key order.
        self.conv1 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.conv1_bn = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.conv2_bn = nn.BatchNorm1d(channels)

    def forward(self, x):
        """Apply conv-BN-ReLU, then conv-BN, and add the unchanged input back."""
        shortcut = x
        hidden = F.relu(self.conv1_bn(self.conv1(x)))
        branch = self.conv2_bn(self.conv2(hidden))
        # No activation after the sum, matching a "pre-sum BN" residual form.
        return branch + shortcut

class EmbedderNet(nn.Module):
    """Embed a (batch, in_channels, length) sequence into a (batch, outchannels) vector.

    Pipeline:
      1. 1x1 conv + BatchNorm + ReLU lifts the input to `channels` feature maps.
      2. Five ResBlocks (dilation 1). The last two are each followed by a
         kernel-2 / stride-2 max-pool with ceil_mode=True, so the length goes
         L -> ceil(L/2) -> ceil(ceil(L/2)/2).
      3. The pooled feature maps are flattened and passed through a
         Linear(-> 128) + ReLU + Linear(-> outchannels) head.

    Args:
        length: expected input sequence length; it fixes the in_features of
            `embed_1`, so inputs whose pooled length differs will fail there.
        channels: width of the convolutional trunk.
        outchannels: size of the returned embedding.
        in_channels: number of input feature channels (default 40, the value
            that was previously hard-coded).
    """
    def __init__(self, length, channels, outchannels, in_channels=40):
        super(EmbedderNet, self).__init__()
        self.length = length

        # Submodule names and registration order are kept stable so existing
        # state_dicts / checkpoints still load.
        self.pool = nn.MaxPool1d(2, 2, ceil_mode = True)
        self.conv = nn.Conv1d(in_channels, channels, 1, 1, 0)
        self.conv_bn = nn.BatchNorm1d(channels)

        self.block1 = ResBlock(channels, 1)
        self.block2 = ResBlock(channels, 1)
        self.block3 = ResBlock(channels, 1)
        self.block4 = ResBlock(channels, 1)
        self.block5 = ResBlock(channels, 1)

        # Length after the two ceil-mode pools applied in forward().
        pooled_length = math.ceil(math.ceil(self.length / 2) / 2)
        self.embed_1 = nn.Linear(int(channels * pooled_length), 128)
        self.embed_2 = nn.Linear(128, outchannels)

    def forward(self, x):
        """Map x of shape (batch, in_channels, length) to (batch, outchannels)."""
        x = F.relu(self.conv_bn(self.conv(x)))

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.pool(self.block4(x))
        x = self.pool(self.block5(x))
        # flatten(1) keeps the batch axis explicit and, unlike
        # view(batchsize, -1), also works for an empty batch or a
        # non-contiguous tensor.
        x = torch.flatten(x, 1)
        return self.embed_2(F.relu(self.embed_1(x)))
    
class Predictor_Dot(nn.Module):
    """Score a sequence as the dot product of a prefix embedding and a suffix embedding.

    The input is split along its last (length) axis. The first
    `prefix_length` positions go to `embedPrefix` and the remainder to
    `embedSuffix`. Each EmbedderNet maps its slice to an `outchannels`-dim
    vector; the two vectors are multiplied elementwise and summed, giving one
    score per batch element.

    Args:
        prefix_length: positions routed to the prefix embedder (default 10).
        suffix_length: positions expected in the remainder (default 10). The
            input's last axis should be prefix_length + suffix_length long.
        channels: convolutional trunk width of both embedders (default 64).
        outchannels: embedding size shared by both embedders (default 16); a
            single value is used so the dot product is always well-defined.
    """
    def __init__(self, prefix_length=10, suffix_length=10, channels=64, outchannels=16):
        super(Predictor_Dot, self).__init__()
        self.embedPrefix = EmbedderNet(prefix_length, channels, outchannels)
        self.embedSuffix = EmbedderNet(suffix_length, channels, outchannels)

    def forward(self, x):
        """Return a (batch,) tensor of prefix-suffix scores.

        x is indexed as (batch, C, L); C must equal the embedders' input
        channel count.
        """
        # Take the split point from the prefix embedder itself so it cannot
        # drift from the length its Linear layer was sized for. Reading an
        # attribute EmbedderNet already stores also keeps previously pickled
        # models working.
        split = self.embedPrefix.length
        prefix = x[:, :, :split]
        suffix = x[:, :, split:]
        return torch.sum(self.embedPrefix(prefix) * self.embedSuffix(suffix), dim = 1)

In [5]:
class NetEnsemble(nn.Module):
    """Run several networks on the same input and stack their outputs along axis 1.

    If each of the K members returns a tensor of shape (N, *rest), the result
    has shape (N, K, *rest); slice [:, k] is member k's output.

    Args:
        lst: iterable of nn.Module members. They are wrapped in an
            nn.ModuleList so their parameters are registered (state_dict keys
            ``nets.<k>.*``). Must be non-empty, since torch.stack needs at
            least one tensor.
    """
    def __init__(self, lst):
        super(NetEnsemble, self).__init__()
        self.nets = nn.ModuleList(lst)

    def forward(self, x):
        """Return the members' outputs for x, stacked on a new dim 1."""
        # Stacking directly on dim=1 gives the same values as the former
        # stack(...).permute((1, 0, 2)) for 2-D member outputs. It also works
        # for 1-D (N,) or >2-D outputs, where the fixed 3-axis permute raised,
        # and it returns a contiguous tensor instead of a permuted view.
        return torch.stack([net(x) for net in self.nets], dim=1)

class Combine(nn.Module):
    """Parameter-free module that contracts two tensors into one score per sample.

    Every method multiplies its operands elementwise and sums over the last
    two axes. The methods differ only in which operand carries the leading
    batch axis `n`.
    """
    def __init__(self):
        super(Combine, self).__init__()

    def forward(self, x1, x2):
        """Both operands batched: (n, i, j) x (n, i, j) -> (n,)."""
        both_batched = "nij,nij->n"
        return torch.einsum(both_batched, x1, x2)

    def getPrefix(self, x1, x2):
        """Batched x1 of shape (n, i, j) against a single x2 of shape (i, j) -> (n,)."""
        batched_left = "nij,ij->n"
        return torch.einsum(batched_left, x1, x2)

    def getSuffix(self, x1, x2):
        """Single x1 of shape (i, j) against a batched x2 of shape (n, i, j) -> (n,)."""
        batched_right = "ij,nij->n"
        return torch.einsum(batched_right, x1, x2)

In [None]:
#TODO: add optimizer with documentation