In [None]:
from torch import nn

import tqdm
import torch.nn.functional as F
from torch.autograd import Variable

# Experiment configuration shared by the model definitions and the training loop.
device = 'cuda'                 # torch device string
batch_size = 32                 # samples per mini-batch
input_size = 84                 # features per time step fed to the recurrent layers
output_size = 1                 # width of the final (sigmoid) output layer
hidden_size = 3 * input_size    # hidden width of the recurrent layers and the MLP
layers = 3                      # number of stacked recurrent layers

import pickle
import numpy as np
import torch
import torch.utils.data as thdata

class RNN(nn.Module):
    """Multi-layer ``nn.RNN`` classifier that scores the last time step.

    Input is time-major, ``(seq_len, batch, input_size)`` — batches are
    permuted with ``permute(1, 0, 2)`` before being passed in. Returns a
    ``(batch, output_size)`` tensor of sigmoid scores.
    """

    def __init__(self) -> None:
        super(RNN, self).__init__()
        self.layer1 = nn.RNN(input_size, hidden_size, layers).to(device=device)
        self.fc = nn.Linear(hidden_size, output_size).to(device=device)

    def forward(self, x):
        # Size the initial hidden state from the real batch (dim 1 of a
        # time-major input) and allocate it on x's device/dtype, instead of
        # assuming the global batch_size and a CUDA device.
        h0 = x.new_zeros(layers, x.size(1), hidden_size)
        output, _ = self.layer1(x, h0)
        # output[-1]: top-layer output at the final time step.
        y = self.fc(output[-1])
        return torch.sigmoid(y)
    
class GRU(nn.Module):
    """Multi-layer ``nn.GRU`` classifier that scores the last time step.

    Input is time-major, ``(seq_len, batch, input_size)`` — batches are
    permuted with ``permute(1, 0, 2)`` before being passed in. Returns a
    ``(batch, output_size)`` tensor of sigmoid scores.
    """

    def __init__(self) -> None:
        super(GRU, self).__init__()
        self.layer1 = nn.GRU(input_size, hidden_size, layers).to(device=device)
        self.fc = nn.Linear(hidden_size, output_size).to(device=device)

    def forward(self, x):
        # Initial hidden state sized from the actual batch (dim 1) and placed
        # on x's device/dtype rather than a hard-coded batch_size / 'cuda'.
        h0 = x.new_zeros(layers, x.size(1), hidden_size)
        output, _ = self.layer1(x, h0)
        # output[-1]: top-layer output at the final time step.
        y = self.fc(output[-1])
        return torch.sigmoid(y)


class LSTM(nn.Module):
    """Multi-layer ``nn.LSTM`` classifier that scores the last time step.

    Input is time-major, ``(seq_len, batch, input_size)`` — batches are
    permuted with ``permute(1, 0, 2)`` before being passed in. Returns a
    ``(batch, output_size)`` tensor of sigmoid scores.
    """

    def __init__(self) -> None:
        super(LSTM, self).__init__()
        self.layer1 = nn.LSTM(input_size, hidden_size, layers).to(device=device)
        self.fc = nn.Linear(hidden_size, output_size).to(device=device)

    def forward(self, x):
        # Hidden and cell states sized from the actual batch (dim 1) and placed
        # on x's device/dtype rather than a hard-coded batch_size / 'cuda'.
        state_shape = (layers, x.size(1), hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        output, _ = self.layer1(x, (h0, c0))
        # output[-1]: top-layer output at the final time step.
        y = self.fc(output[-1])
        return torch.sigmoid(y)


In [None]:
'''LEM'''

import math

class LEMCell(nn.Module):
    """One step of the LEM recurrence over two coupled hidden states ``y`` and ``z``.

    Args:
        ninp: number of input features per step.
        nhid: width of each hidden state.
        dt: scale applied to the sigmoid-gated step sizes.
    """

    def __init__(self, ninp, nhid, dt):
        super(LEMCell, self).__init__()
        self.ninp, self.nhid, self.dt = ninp, nhid, dt
        # Input projection yields four nhid-wide slices, hidden projection three.
        self.inp2hid = nn.Linear(ninp, 4 * nhid)
        self.hid2hid = nn.Linear(nhid, 3 * nhid)
        self.transform_z = nn.Linear(nhid, nhid)
        self.reset_parameters()

    def reset_parameters(self):
        """Overwrite every weight and bias with U(-1/sqrt(nhid), 1/sqrt(nhid))."""
        bound = 1.0 / math.sqrt(self.nhid)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)

    def forward(self, x, y, z):
        """Advance the state by one step and return the updated ``(y, z)``."""
        inp_dt_bar, inp_dt, inp_z, inp_y = torch.chunk(self.inp2hid(x), 4, dim=1)
        hid_dt_bar, hid_dt, hid_y = torch.chunk(self.hid2hid(y), 3, dim=1)

        # Input-dependent step sizes: a sigmoid gate scaled by dt.
        dt_bar = self.dt * torch.sigmoid(inp_dt_bar + hid_dt_bar)
        dt_z = self.dt * torch.sigmoid(inp_dt + hid_dt)

        # z is updated first; the new z then drives the y update.
        z_next = (1. - dt_z) * z + dt_z * torch.tanh(inp_y + hid_y)
        y_next = (1. - dt_bar) * y + dt_bar * torch.tanh(self.transform_z(z_next) + inp_z)
        return y_next, z_next

class LEM(nn.Module):
    """Sequence classifier built on ``LEMCell`` (presumably Long Expressive Memory).

    Runs the cell over every time step of a time-major input
    ``(seq_len, batch, input_size)``, then maps the final ``y`` state through a
    linear head and a sigmoid. Returns ``(batch, output_size)`` scores.
    """

    def __init__(self, dt=1.):
        super(LEM, self).__init__()
        self.nhid = hidden_size
        self.cell = LEMCell(input_size, hidden_size, dt)
        self.classifier = nn.Linear(hidden_size, output_size)
        self.init_weights()

    def init_weights(self):
        """Kaiming-normal init for the classifier weight only; other parameters keep their init."""
        for name, param in self.named_parameters():
            if 'classifier' in name and 'weight' in name:
                nn.init.kaiming_normal_(param.data)

    def forward(self, input):
        # Zero initial states on input's device/dtype; dim 1 is the batch.
        y = input.new_zeros(input.size(1), self.nhid)
        z = input.new_zeros(input.size(1), self.nhid)
        # Iterating a tensor walks dim 0, i.e. one time step per iteration.
        for x in input:
            y, z = self.cell(x, y, z)
        out = self.classifier(y)
        # Apply the sigmoid function; nn.Sigmoid(out) would try to construct a
        # module from the tensor instead of returning scores.
        return torch.sigmoid(out)

In [None]:
'''MLP'''

class MLP(nn.Module):
    """Feed-forward baseline over a flattened window of ``5 * input_size`` features.

    Three ReLU hidden layers of width ``hidden_size`` followed by a linear layer
    to ``output_size`` and a sigmoid.
    """

    def __init__(self) -> None:
        super().__init__()
        widths = [input_size * 5, hidden_size, hidden_size, hidden_size]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack += [nn.Linear(hidden_size, output_size), nn.Sigmoid()]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Map a flattened ``(batch, 5 * input_size)`` tensor to sigmoid scores."""
        return self.layers(x)

In [None]:
def metrics(pred, label):
    """Confusion-matrix counts for one batch of sigmoid scores.

    A sample is predicted positive when its score is above 0.5. For
    multi-output predictions, the mean score is used. This is the same
    decision as "closer to 1 than to 0 in mean squared error", stated
    directly. A label equal to 1 is a positive; anything else is a negative.

    Args:
        pred: tensor whose first dimension is the batch (e.g. ``(batch, 1)``).
        label: tensor holding one target value per sample, on any device.

    Returns:
        ``(TP, FP, TN, FN)`` as Python ints. Every sample in the batch is
        counted, whatever the batch size.
    """
    scores = pred.detach().float()
    if scores.dim() > 1:
        scores = scores.flatten(1).mean(dim=1)
    predicted_pos = (scores > 0.5).cpu()
    actual_pos = (label.detach().reshape(-1) == 1).cpu()

    TP = int((predicted_pos & actual_pos).sum())
    FP = int((predicted_pos & ~actual_pos).sum())   # negative predicted positive
    TN = int((~predicted_pos & ~actual_pos).sum())
    FN = int((~predicted_pos & actual_pos).sum())   # positive predicted negative
    return TP, FP, TN, FN
    
    
def test(dataloader, model, model_name=None):
    """Evaluate ``model`` on every batch of ``dataloader`` and sum confusion counts.

    Switches the model to eval mode (it is left there; the caller switches
    back to train mode) and runs without autograd.

    Args:
        dataloader: yields ``(data, label)`` batches; ``data`` is assumed to be
            ``(batch, seq, features)`` — TODO confirm for new datasets.
        model: classifier returning per-sample scores.
        model_name: ``'mlp'`` flattens each sample; any other value feeds the
            batch time-major via ``permute(1, 0, 2)``.

    Returns:
        ``(TP, FP, TN, FN)`` summed over all batches, as produced by ``metrics``.
    """
    model.eval()
    TP = FP = TN = FN = 0
    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device=device)
            if model_name == 'mlp':
                # Flatten using the real batch size, not the global batch_size.
                inputs = data.reshape(data.size(0), -1)
            else:
                inputs = data.permute(1, 0, 2)
            tp, fp, tn, fn = metrics(model(inputs), label)
            TP += tp
            FP += fp
            TN += tn
            FN += fn
    return TP, FP, TN, FN

In [None]:
import os

# Experiment driver: for every pickled dataset in DATA_DIR, train each model in
# `models` and log per-epoch loss and confusion counts to RESULT_DIR.
# NOTE(review): "path" is a placeholder — point DATA_DIR at the real data.
DATA_DIR = "path"
RESULT_DIR = "result"
WINDOW = 5          # time steps per sample; each row holds WINDOW * input_size features
EPOCHS = 50
LEARNING_RATE = 0.001

MODEL_FACTORIES = {'rnn': RNN, 'gru': GRU, 'lstm': LSTM, 'lem': LEM, 'mlp': MLP}

files = os.listdir(DATA_DIR)
models = ['rnn', 'mlp', 'gru', 'lstm']


def _load_loaders(path):
    """Load one pickled dataset and return ``(train_loader, test_loader)``.

    Features are standardised per column, reshaped to
    ``(N, WINDOW, input_size)`` and split 80/20 with a fixed seed (42).
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as fh:
        records = pickle.load(fh)

    features = np.reshape([item[0] for item in records], (len(records), WINDOW * input_size))
    labels = torch.asarray([item[1] for item in records], dtype=torch.float)

    # BatchNorm1d in training mode standardises with whole-dataset statistics.
    # Its affine parameters start at weight=1, bias=0. Run it without autograd:
    # otherwise every sample shares one BatchNorm graph, and the second
    # loss.backward() fails because that graph was already freed.
    # NOTE(review): the statistics include the test split (mild leakage).
    norm = torch.nn.BatchNorm1d(WINDOW * input_size, affine=True)
    with torch.no_grad():
        samples = norm(torch.asarray(features, dtype=torch.float))
    samples = samples.reshape(len(records), WINDOW, input_size)

    full_dataset = thdata.TensorDataset(samples, labels)
    train_dataset, test_dataset = thdata.random_split(
        full_dataset, [0.8, 0.2], torch.Generator().manual_seed(42))
    train_loader = thdata.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = thdata.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    return train_loader, test_loader


def _model_inputs(data, model_name):
    """Move a ``(batch, WINDOW, input_size)`` batch to ``device`` in the layout the model expects."""
    data = data.to(device=device)
    if model_name == 'mlp':
        return data.reshape(data.size(0), -1)   # MLP consumes the flattened window
    return data.permute(1, 0, 2)                # recurrent models are time-major


def _rates(TP, FP, TN, FN):
    """Return ``(TPR, FPR, ACC, F1)``; a ratio with an empty denominator is 0."""
    positives, negatives = TP + FN, FP + TN
    total = positives + negatives
    TPR = TP / positives if positives else 0
    FPR = FP / negatives if negatives else 0
    ACC = (TP + TN) / total if total else 0
    P = TP / (TP + FP) if (TP + FP) else 0
    F1 = 2 * P * TPR / (P + TPR) if (P + TPR) else 0
    return TPR, FPR, ACC, F1


os.makedirs(RESULT_DIR, exist_ok=True)
for file in files:
    stem = os.path.splitext(file)[0]
    train_dataset_loader, test_dataset_loader = _load_loaders(os.path.join(DATA_DIR, file))

    for model_name in models:
        # KeyError on an unknown name instead of silently reusing the previous model.
        model = MODEL_FACTORIES[model_name]().to(device=device)
        opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        result_path = os.path.join(RESULT_DIR, '{}_{}_{}_new'.format(stem, model_name, batch_size))

        with open(result_path, 'w+') as result_file:
            tqdm_epoch = tqdm.trange(EPOCHS, desc='{}-{}'.format(stem, model_name))
            for e in tqdm_epoch:
                model.train()
                loss_sum = 0
                for data, label in train_dataset_loader:
                    pred = model(_model_inputs(data, model_name))
                    # pred is (batch, output_size) while label is (batch,). Reshape the
                    # target so mse_loss is element-wise rather than broadcasting to a
                    # (batch, batch) matrix of every prediction against every label.
                    target = label.to(device=device).view_as(pred)
                    loss = F.mse_loss(pred, target)

                    opt.zero_grad()
                    loss.backward()
                    opt.step()
                    loss_sum += loss.item()

                TP, FP, TN, FN = test(test_dataset_loader, model, model_name)
                TPR, FPR, ACC, F1 = _rates(TP, FP, TN, FN)
                mean_loss = loss_sum / max(len(train_dataset_loader), 1)
                tqdm_epoch.set_postfix(loss='{}'.format(mean_loss), TPR=TPR, FPR=FPR,
                                       ACC=ACC, F1=F1, refresh=False)

                result_file.write('{}, {}, {}, {}, {}, {}, {}\n'.format(e, loss_sum, TP, FP, TN, FN, ACC))
                result_file.flush()
            
