# Import

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import json
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import plotly.express as px

import lightgbm as lgb
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch.optim import lr_scheduler

# Settings

In [2]:
class config:
    data_dir = '../OpenVaccine/'
    train_file = '../OpenVaccine/train.json'
    test_file = '../OpenVaccine/test.json'
    pretrain_dir = '../OpenVaccine/pretrains/'
    sample_submission = '../OpenVaccine/sample_submission.csv'
    learning_rate = 0.001
    batch_size = 64
    n_epoch = 50
    n_split = 5
    seed = 1234

# Utils

In [3]:
class AverageMeter:
    """
    Computes and stores the average and current value
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [4]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(config.seed)

# Model

In [5]:
class GCN(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCN, self).__init__()
        self.linear_gcn = nn.Linear(in_features=in_features, out_features=out_features)
        
    def forward(self, input_, adj_matrix):
        sum_adj = torch.sum(adj_matrix, axis=2)
        sum_adj[sum_adj==0] = 1
        
        adj_matrix = adj_matrix.type(torch.float32)
        feature_agg = torch.bmm(adj_matrix, input_)
        feature_agg = feature_agg / sum_adj.unsqueeze(dim=2)
        feature_cat = torch.cat((input_, feature_agg), axis=2)
        
        feature = torch.sigmoid(self.linear_gcn(feature_cat))
        feature = feature / torch.norm(feature, p=2, dim=2).unsqueeze(dim=2)
        
        return feature
        
    
class Net(nn.Module):
    def __init__(self, num_embedding=14, seq_len=107, pred_len=68, dropout=0.5, embed_dim=100, hidden_dim=128, K=1):
        super(Net, self).__init__()
        
        self.pred_len = pred_len
        self.embedding_layer = nn.Embedding(num_embeddings=num_embedding, 
                                      embedding_dim=embed_dim)
        
        self.gcn = nn.ModuleList([GCN(in_features=3 * embed_dim * 2, out_features=3 * embed_dim) for i in range(K)])
        
        self.gru_layer = nn.GRU(input_size=3 * embed_dim, 
                          hidden_size=hidden_dim, 
                          num_layers=3, 
                          batch_first=True, 
                          dropout=dropout, 
                          bidirectional=True)
        
        self.linear_layer = nn.Linear(in_features=2 * hidden_dim, 
                                out_features=5)
        
    def forward(self, input_, adj_matrix):
        #embedding
        embedding = self.embedding_layer(input_)
        embedding = torch.reshape(embedding, (-1, embedding.shape[1], embedding.shape[2] * embedding.shape[3]))
        
        #gcn
        gcn_feature = embedding
        for gcn_layer in self.gcn:
            gcn_feature = gcn_layer(gcn_feature, adj_matrix)
        
        #gru
        gru_output, gru_hidden = self.gru_layer(gcn_feature)
        truncated = gru_output[:, :self.pred_len]
        
        output = self.linear_layer(truncated)
        
        return output

# Load Data

In [6]:
pred_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']

In [7]:
token2int = {x:i for i, x in enumerate('().ACGUBEHIMSX')}

def get_couples(structure):
    """
    For each closing parenthesis, I find the matching opening one and store their index in the couples list.
    The assigned list is used to keep track of the assigned opening parenthesis
    """
    opened = [idx for idx, i in enumerate(structure) if i == '(']
    closed = [idx for idx, i in enumerate(structure) if i == ')']

    assert len(opened) == len(closed)
    assigned = []
    couples = []

    for close_idx in closed:
        for open_idx in opened:
            if open_idx < close_idx:
                if open_idx not in assigned:
                    candidate = open_idx
            else:
                break
        assigned.append(candidate)
        couples.append([candidate, close_idx])
        
    assert len(couples) == len(opened)
    
    return couples

def build_matrix(couples, size):
    mat = np.zeros((size, size))
    
    for i in range(size):  # neigbouring bases are linked as well
        if i < size - 1:
            mat[i, i + 1] = 1
        if i > 0:
            mat[i, i - 1] = 1
    
    for i, j in couples:
        mat[i, j] = 1
        mat[j, i] = 1
        
    return mat

def convert_to_adj(structure):
    couples = get_couples(structure)
    mat = build_matrix(couples, len(structure))
    return mat

def preprocess_inputs(df, cols=['sequence', 'structure', 'predicted_loop_type']):
    inputs = np.transpose(
        np.array(
            df[cols]
            .applymap(lambda seq: [token2int[x] for x in seq])
            .values
            .tolist()
        ),
        (0, 2, 1)
    )
    
    adj_matrix = np.array(df['structure'].apply(convert_to_adj).values.tolist())
    
    return inputs, adj_matrix

In [8]:
train = pd.read_json(config.train_file, lines=True)
test = pd.read_json(config.test_file, lines=True)
sample_df = pd.read_csv(config.sample_submission)

In [9]:
train_inputs, train_adj = preprocess_inputs(train)
train_labels = np.array(train[pred_cols].values.tolist()).transpose((0, 2, 1))

train_inputs = torch.tensor(train_inputs, dtype=torch.long)
train_adj = torch.tensor(train_adj, dtype=torch.long)
train_labels = torch.tensor(train_labels, dtype=torch.float32)

# Train

In [10]:
def train_fn(epoch, model, train_loader, criterion, optimizer):
    model.train()
    model.zero_grad()
    train_loss = AverageMeter()
    
    for index, (input_, adj, label) in tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Train Epoch {epoch}'):
        input_ = input_.cuda()
        adj = adj.cuda()
        label = label.cuda()
        preds = model(input_, adj)
        
        loss = criterion(preds, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.update(loss.item())
    
    print(f"Train loss {train_loss.avg}")
    return train_loss.avg
    
def eval_fn(epoch, model, valid_loader, criterion):
    model.eval()
    eval_loss = AverageMeter()
    
    for index, (input_, adj, label) in tqdm(enumerate(valid_loader), total=len(valid_loader), desc=f'Valid Epoch {epoch}'):
        input_ = input_.cuda()
        adj = adj.cuda()
        label = label.cuda()
        preds = model(input_, adj)
        
        loss = criterion(preds, label)
        eval_loss.update(loss.item())
    
    print(f"Valid loss {eval_loss.avg}")
    return eval_loss.avg

In [11]:
def run(fold, train_loader, valid_loader):
    model = Net()
    model.cuda()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)
    
    train_losses = []
    eval_losses = []
    for epoch in range(config.n_epoch):
        print('#################')
        print('###Epoch:', epoch)
        print('#################')
        
        train_loss = train_fn(epoch, model, train_loader, criterion, optimizer)
        eval_loss = eval_fn(epoch, model, valid_loader, criterion)
        train_losses.append(train_loss)
        eval_losses.append(eval_loss)
        
    torch.save(model.state_dict(), f'{config.pretrain_dir}/gru_{fold}.pt')
    return train_losses, eval_losses

In [12]:
splits = KFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train_inputs)

for fold, (train_idx, val_idx) in enumerate(splits):
    train_dataset = TensorDataset(train_inputs[train_idx], train_adj[train_idx], train_labels[train_idx])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)
    
    valid_dataset = TensorDataset(train_inputs[val_idx], train_adj[val_idx], train_labels[val_idx])
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8)
    
    train_losses, eval_losses = run(fold, train_loader, valid_loader)
    break

#################
###Epoch: 0
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 0', max=30.0, style=ProgressStyle(description…


Train loss 0.9379975060621898


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 0', max=8.0, style=ProgressStyle(description_…


Valid loss 0.8019687235355377
#################
###Epoch: 1
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 1', max=30.0, style=ProgressStyle(description…


Train loss 0.8756959021091462


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 1', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7724631242454052
#################
###Epoch: 2
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 2', max=30.0, style=ProgressStyle(description…


Train loss 0.8564922218521436


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 2', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7631312273442745
#################
###Epoch: 3
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 3', max=30.0, style=ProgressStyle(description…


Train loss 0.8408145402868589


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 3', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7410204913467169
#################
###Epoch: 4
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 4', max=30.0, style=ProgressStyle(description…


Train loss 0.8276837214827537


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 4', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7370243836194277
#################
###Epoch: 5
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 5', max=30.0, style=ProgressStyle(description…


Train loss 0.8251523693402608


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 5', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7331584952771664
#################
###Epoch: 6
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 6', max=30.0, style=ProgressStyle(description…


Train loss 0.8168214033047358


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 6', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7306485418230295
#################
###Epoch: 7
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 7', max=30.0, style=ProgressStyle(description…


Train loss 0.815260182817777


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 7', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7254208829253912
#################
###Epoch: 8
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 8', max=30.0, style=ProgressStyle(description…


Train loss 0.8114437376459439


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 8', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7260830253362656
#################
###Epoch: 9
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 9', max=30.0, style=ProgressStyle(description…


Train loss 0.8122871972620487


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 9', max=8.0, style=ProgressStyle(description_…


Valid loss 0.7220852542668581
#################
###Epoch: 10
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 10', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8072882657249768


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 10', max=8.0, style=ProgressStyle(description…


Valid loss 0.7198872454464436
#################
###Epoch: 11
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 11', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8030313144127528


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 11', max=8.0, style=ProgressStyle(description…


Valid loss 0.7184105515480042
#################
###Epoch: 12
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 12', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7980181828141213


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 12', max=8.0, style=ProgressStyle(description…


Valid loss 0.7148968391120434
#################
###Epoch: 13
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 13', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7947530840833982


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 13', max=8.0, style=ProgressStyle(description…


Valid loss 0.7119524721056223
#################
###Epoch: 14
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 14', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7902694846192996


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 14', max=8.0, style=ProgressStyle(description…


Valid loss 0.7117275018244982
#################
###Epoch: 15
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 15', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7869565506776174


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 15', max=8.0, style=ProgressStyle(description…


Valid loss 0.7129603307694197
#################
###Epoch: 16
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 16', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7873871177434921


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 16', max=8.0, style=ProgressStyle(description…


Valid loss 0.7156950924545527
#################
###Epoch: 17
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 17', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7837731381257375


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 17', max=8.0, style=ProgressStyle(description…


Valid loss 0.7104544378817081
#################
###Epoch: 18
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 18', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7790212641159694


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 18', max=8.0, style=ProgressStyle(description…


Valid loss 0.7036212515085936
#################
###Epoch: 19
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 19', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7767200658718745


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 19', max=8.0, style=ProgressStyle(description…


Valid loss 0.7036113571375608
#################
###Epoch: 20
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 20', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7755615388353666


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 20', max=8.0, style=ProgressStyle(description…


Valid loss 0.7049744818359613
#################
###Epoch: 21
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 21', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7725229198733966


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 21', max=8.0, style=ProgressStyle(description…


Valid loss 0.7031883075833321
#################
###Epoch: 22
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 22', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.769252173602581


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 22', max=8.0, style=ProgressStyle(description…


Valid loss 0.7001206800341606
#################
###Epoch: 23
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 23', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7661248818039894


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 23', max=8.0, style=ProgressStyle(description…


Valid loss 0.6994511317461729
#################
###Epoch: 24
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 24', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7689721216758092


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 24', max=8.0, style=ProgressStyle(description…


Valid loss 0.6980802100151777
#################
###Epoch: 25
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 25', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7618824611107509


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 25', max=8.0, style=ProgressStyle(description…


Valid loss 0.7018421739339828
#################
###Epoch: 26
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 26', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.763642476995786


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 26', max=8.0, style=ProgressStyle(description…


Valid loss 0.6990994978696108
#################
###Epoch: 27
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 27', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7580372214317321


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 27', max=8.0, style=ProgressStyle(description…


Valid loss 0.70032705552876
#################
###Epoch: 28
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 28', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7606729820370675


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 28', max=8.0, style=ProgressStyle(description…


Valid loss 0.6985184624791145
#################
###Epoch: 29
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 29', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7532033674418926


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 29', max=8.0, style=ProgressStyle(description…


Valid loss 0.6977801229804754
#################
###Epoch: 30
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 30', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7507945095499357


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 30', max=8.0, style=ProgressStyle(description…


Valid loss 0.6956973411142826
#################
###Epoch: 31
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 31', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7473209102948507


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 31', max=8.0, style=ProgressStyle(description…


Valid loss 0.7020544148981571
#################
###Epoch: 32
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 32', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7484260886907578


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 32', max=8.0, style=ProgressStyle(description…


Valid loss 0.7003972511738539
#################
###Epoch: 33
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 33', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7463013355930647


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 33', max=8.0, style=ProgressStyle(description…


Valid loss 0.6960629615932703
#################
###Epoch: 34
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 34', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.74413104703029


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 34', max=8.0, style=ProgressStyle(description…


Valid loss 0.6957117896527052
#################
###Epoch: 35
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 35', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7431227500240009


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 35', max=8.0, style=ProgressStyle(description…


Valid loss 0.70334143191576
#################
###Epoch: 36
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 36', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.741237630446752


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 36', max=8.0, style=ProgressStyle(description…


Valid loss 0.7015927508473396
#################
###Epoch: 37
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 37', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7372721912960212


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 37', max=8.0, style=ProgressStyle(description…


Valid loss 0.7002626452594995
#################
###Epoch: 38
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 38', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.728891118367513


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 38', max=8.0, style=ProgressStyle(description…


Valid loss 0.7031639143824577
#################
###Epoch: 39
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 39', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7257798391083876


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 39', max=8.0, style=ProgressStyle(description…


Valid loss 0.7006662357598543
#################
###Epoch: 40
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 40', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7287229532996814


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 40', max=8.0, style=ProgressStyle(description…


Valid loss 0.7018567025661469
#################
###Epoch: 41
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 41', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7379623517394066


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 41', max=8.0, style=ProgressStyle(description…


Valid loss 0.7127504721283913
#################
###Epoch: 42
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 42', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7239407161871593


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 42', max=8.0, style=ProgressStyle(description…


Valid loss 0.7124345880001783
#################
###Epoch: 43
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 43', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7217509138087431


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 43', max=8.0, style=ProgressStyle(description…


Valid loss 0.7112282048910856
#################
###Epoch: 44
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 44', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7178290575742722


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 44', max=8.0, style=ProgressStyle(description…


Valid loss 0.7094032075256109
#################
###Epoch: 45
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 45', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7102376858393351


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 45', max=8.0, style=ProgressStyle(description…


Valid loss 0.7184518799185753
#################
###Epoch: 46
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 46', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7063660750786463


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 46', max=8.0, style=ProgressStyle(description…


Valid loss 0.7112066633999348
#################
###Epoch: 47
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 47', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.7055866504708926


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 47', max=8.0, style=ProgressStyle(description…


Valid loss 0.7093294374644756
#################
###Epoch: 48
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 48', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.6950497336685657


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 48', max=8.0, style=ProgressStyle(description…


Valid loss 0.7172360941767693
#################
###Epoch: 49
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 49', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.694547921915849


HBox(children=(FloatProgress(value=0.0, description='Valid Epoch 49', max=8.0, style=ProgressStyle(description…


Valid loss 0.726085564121604


# Visualize

In [85]:
fig = px.line(
    pd.DataFrame([train_losses, eval_losses], index=['loss', 'val_loss']).T, 
    y=['loss', 'val_loss'], 
    labels={'index': 'epoch', 'value': 'Mean Squared Error'}, 
    title='Training History')
fig.show()

# Test

In [13]:
public_df = test.query("seq_length == 107").copy()
private_df = test.query("seq_length == 130").copy()

public_inputs, public_adj = preprocess_inputs(public_df)
private_inputs, private_adj = preprocess_inputs(private_df)

public_inputs = torch.tensor(public_inputs, dtype=torch.long)
private_inputs = torch.tensor(private_inputs, dtype=torch.long)
public_adj = torch.tensor(public_adj, dtype=torch.long)
private_adj = torch.tensor(private_adj, dtype=torch.long)

In [14]:
model_short = Net(seq_len=107, pred_len=107)
model_long = Net(seq_len=130, pred_len=130)

list_public_preds = []
list_private_preds = []
for fold in range(config.n_split):
    model_short.load_state_dict(torch.load(f'{config.pretrain_dir}/gru_{fold}.pt'))
    model_long.load_state_dict(torch.load(f'{config.pretrain_dir}/gru_{fold}.pt'))
    model_short.cuda()
    model_long.cuda()
    model_short.eval()
    model_long.eval()

    public_preds = model_short(public_inputs.cuda(), public_adj.cuda())
    private_preds = model_long(private_inputs.cuda(), private_adj.cuda())
    public_preds = public_preds.cpu().detach().numpy()
    private_preds = private_preds.cpu().detach().numpy()
    
    list_public_preds.append(public_preds)
    list_private_preds.append(private_preds)

In [15]:
public_preds = np.mean(list_public_preds, axis=0)
private_preds = np.mean(list_private_preds, axis=0)

In [16]:
preds_ls = []

for df, preds in [(public_df, public_preds), (private_df, private_preds)]:
    for i, uid in enumerate(df.id):
        single_pred = preds[i]

        single_df = pd.DataFrame(single_pred, columns=pred_cols)
        single_df['id_seqpos'] = [f'{uid}_{x}' for x in range(single_df.shape[0])]

        preds_ls.append(single_df)

preds_df = pd.concat(preds_ls)

In [17]:
submission = sample_df[['id_seqpos']].merge(preds_df, on=['id_seqpos'])
submission.to_csv(f'{config.data_dir}/submission.csv', index=False)

In [18]:
# 0.6 - 0.28785
# 0.4 - 0.28630