# Import

In [153]:
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import json
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import plotly.express as px

import lightgbm as lgb
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch.optim import lr_scheduler

# Settings

In [154]:
class config:
    """Notebook-wide settings: data locations, optimisation and CV hyper-parameters."""

    # --- paths (all derived from the data directory) ---
    data_dir = '../OpenVaccine/'
    train_file = data_dir + 'train.json'
    test_file = data_dir + 'test.json'
    pretrain_dir = data_dir + 'pretrains/'
    sample_submission = data_dir + 'sample_submission.csv'

    # --- optimisation ---
    learning_rate = 0.001
    batch_size = 64
    n_epoch = 100

    # --- cross-validation / model ---
    n_split = 5
    K = 1             # number of stacked GCN layers passed to Net
    gcn_agg = 'mean'  # GCN aggregator name passed to Net
    filter_noise = True  # keep only rows with SN_filter == 1 when loading train
    pred_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']
    seed = 1234

# Utils

In [155]:
class AverageMeter:
    """Running (optionally weighted) mean of a stream of values.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count``; 0 before the first update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything seen so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* with weight *n* and refresh the running mean."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

In [156]:
def seed_everything(seed=1234):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Note: ``PYTHONHASHSEED`` is read by the interpreter at start-up, so
    setting it here only affects child processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    
# Seed every RNG once, before any model is built or the CV splits are drawn.
seed_everything(config.seed)

# Model

In [157]:
class GCN(nn.Module):
    """One graph-convolution layer over a dense, batched adjacency matrix.

    Each node's features are combined with an aggregate of its neighbours'
    features, projected by a linear layer, passed through a sigmoid and
    L2-normalised along the feature dimension.

    Args:
        input_dim: size of the last axis of the node-feature tensor.
        output_dim: size of the produced node features.
        aggregator: ``'mean'``, ``'conv'``, ``'pooling'`` or ``'lstm'``.

    Raises:
        ValueError: if ``aggregator`` is not one of the supported names.
    """

    def __init__(self, input_dim, output_dim, aggregator='mean'):
        super(GCN, self).__init__()
        self.aggregator = aggregator

        if aggregator == 'mean':
            # node features concatenated with the mean of neighbour features
            linear_input_dim = input_dim * 2
        elif aggregator == 'conv':
            # self-loops are added, so the aggregate already includes the node itself
            linear_input_dim = input_dim
        elif aggregator == 'pooling':
            linear_input_dim = input_dim * 2
            self.linear_pooling = nn.Linear(input_dim, input_dim)
        elif aggregator == 'lstm':
            self.lstm_hidden = 128
            linear_input_dim = input_dim + self.lstm_hidden
            self.lstm_agg = nn.LSTM(input_dim, self.lstm_hidden, num_layers=1, batch_first=True)
        else:
            raise ValueError(
                f"unknown aggregator {aggregator!r}; expected 'mean', 'conv', 'pooling' or 'lstm'"
            )

        self.linear_gcn = nn.Linear(in_features=linear_input_dim, out_features=output_dim)

    def forward(self, input_, adj_matrix):
        """Apply the layer.

        Args:
            input_: node features; batched matrix multiplication with
                ``adj_matrix`` implies shape ``(batch, nodes, input_dim)``.
            adj_matrix: ``(batch, nodes, nodes)`` adjacency (any numeric
                dtype). It is never modified.

        Returns:
            ``(batch, nodes, output_dim)`` features, L2-normalised over the
            last axis.
        """
        # .to() returns the same tensor when it is already float32, so any
        # in-place edit below must work on a clone to keep the caller's tensor intact.
        adj = adj_matrix.to(torch.float32)
        if self.aggregator == 'conv':
            adj = adj.clone()
            idx = torch.arange(adj.shape[-1], device=adj.device)
            adj[:, idx, idx] = 1  # self-loops

        degree = adj.sum(dim=2)
        degree[degree == 0] = 1  # isolated nodes: avoid division by zero

        if self.aggregator in ('mean', 'conv'):
            feature_agg = torch.bmm(adj, input_) / degree.unsqueeze(dim=2)
        elif self.aggregator == 'pooling':
            pooled = torch.sigmoid(self.linear_pooling(input_))
            feature_agg = torch.bmm(adj, pooled) / degree.unsqueeze(dim=2)
        else:  # 'lstm'
            per_node = []
            for i in range(adj.shape[1]):
                # neighbour features of node i (non-neighbours zeroed out)
                neighbors = adj[:, i, :].unsqueeze(2) * input_
                _, (h_n, _c_n) = self.lstm_agg(neighbors)
                per_node.append(h_n.squeeze(0))  # final hidden state, (batch, lstm_hidden)
            # built on input_'s device (the original hard-coded .cuda())
            feature_agg = torch.stack(per_node, dim=1)

        if self.aggregator == 'conv':
            feature_cat = feature_agg
        else:
            feature_cat = torch.cat((input_, feature_agg), dim=2)

        feature = torch.sigmoid(self.linear_gcn(feature_cat))
        # sigmoid outputs are positive, so the norm is non-zero in practice
        return feature / torch.norm(feature, p=2, dim=2).unsqueeze(dim=2)
        
    
class Net(nn.Module):
    """Token embedding -> K stacked GCN layers -> 3-layer bi-GRU -> linear head.

    The forward pass returns one prediction per target column
    (``len(config.pred_cols)``) for each of the first ``pred_len`` positions.
    ``seq_len`` is accepted for call-site compatibility but not used.
    """

    def __init__(self, num_embedding=14, seq_len=107, pred_len=68, dropout=0.5,
                 embed_dim=100, hidden_dim=128, K=1, aggregator='mean'):
        super(Net, self).__init__()
        self.pred_len = pred_len
        # three token channels are embedded and concatenated per position
        node_dim = 3 * embed_dim

        self.embedding_layer = nn.Embedding(num_embeddings=num_embedding, embedding_dim=embed_dim)
        self.gcn = nn.ModuleList(
            GCN(node_dim, node_dim, aggregator=aggregator) for _ in range(K)
        )
        self.gru_layer = nn.GRU(
            input_size=node_dim,
            hidden_size=hidden_dim,
            num_layers=3,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        # bidirectional GRU -> 2 * hidden_dim features per position
        self.linear_layer = nn.Linear(in_features=2 * hidden_dim, out_features=len(config.pred_cols))

    def forward(self, input_, adj_matrix):
        """Predict targets for the first ``pred_len`` positions.

        ``input_`` holds integer token ids; the reshape below requires the
        embedding to be 4-D, i.e. ``input_`` is ``(batch, seq, channels)``.
        """
        embedded = self.embedding_layer(input_)
        dims = embedded.shape
        # merge the per-channel embeddings: (batch, seq, channels * embed_dim)
        hidden = torch.reshape(embedded, (-1, dims[1], dims[2] * dims[3]))

        for layer in self.gcn:
            hidden = layer(hidden, adj_matrix)

        sequence_out, _ = self.gru_layer(hidden)
        return self.linear_layer(sequence_out[:, :self.pred_len])

# Load Data

In [158]:
# Vocabulary shared by all three input channels (sequence, structure,
# predicted loop type). Its size (14) matches Net's default num_embedding.
_VOCAB = '().ACGUBEHIMSX'
token2int = dict(zip(_VOCAB, range(len(_VOCAB))))

def get_couples(structure):
    """Pair each closing parenthesis of a dot-bracket string with its opener.

    Uses a stack, so each ')' is matched with the most recent unmatched '('.
    Characters other than '(' and ')' are ignored.

    Args:
        structure: dot-bracket structure string.

    Returns:
        List of ``[open_idx, close_idx]`` pairs, ordered by ``close_idx``.

    Raises:
        ValueError: if the parentheses are unbalanced (a ')' with no pending
            '(' or a '(' that is never closed).
    """
    pending = []  # indices of '(' still waiting for their ')'
    couples = []

    for idx, char in enumerate(structure):
        if char == '(':
            pending.append(idx)
        elif char == ')':
            if not pending:
                raise ValueError(f"unmatched ')' at position {idx} in structure")
            couples.append([pending.pop(), idx])

    if pending:
        raise ValueError(f"unmatched '(' at position {pending[-1]} in structure")

    return couples

def build_matrix(couples, size):
    """Return a ``size x size`` float adjacency matrix for one RNA molecule.

    Neighbouring bases ``(i, i + 1)`` are linked in both directions, and every
    ``[i, j]`` pair in *couples* is linked symmetrically. Edges are 1.0 and
    everything else (including the diagonal, unless a couple says otherwise)
    is 0.0.
    """
    # backbone links: ones on the first super- and sub-diagonals
    mat = np.eye(size, k=1) + np.eye(size, k=-1)
    for i, j in couples:
        mat[i, j] = mat[j, i] = 1
    return mat

def convert_to_adj(structure):
    """Return the backbone + base-pair adjacency matrix for a dot-bracket string."""
    return build_matrix(get_couples(structure), len(structure))

def bpp_to_adj(id):
    """Load the BPP matrix of molecule *id* and use it as a weighted adjacency.

    Reads ``{config.data_dir}/bpps/{id}.npy``. A binarised variant,
    ``np.where(bpp > 0, 1, 0)``, is an alternative that was also tried.
    """
    return np.load(f'{config.data_dir}/bpps/{id}.npy')

def preprocess_inputs(df, cols=('sequence', 'structure', 'predicted_loop_type')):
    """Encode RNA string columns as token ids and build adjacency matrices.

    Args:
        df: one row per molecule. Each column in *cols* holds a string whose
            characters are keys of ``token2int``; a ``'structure'`` column
            (dot-bracket) is also required. All rows must share one sequence
            length for the returned arrays to be rectangular.
        cols: string columns to encode, in channel order. A tuple default
            is used instead of a list to avoid a shared mutable default.

    Returns:
        ``(inputs, adj_matrix)`` where ``inputs`` is an int array of shape
        ``(n_rows, seq_len, len(cols))`` and ``adj_matrix`` has shape
        ``(n_rows, seq_len, seq_len)``.
    """
    cols = list(cols)  # DataFrame column selection needs a list, not a tuple
    # Plain comprehension instead of DataFrame.applymap (deprecated in pandas 2.1).
    encoded = [
        [[token2int[char] for char in seq] for seq in row]
        for row in df[cols].itertuples(index=False, name=None)
    ]
    # (rows, channels, seq_len) -> (rows, seq_len, channels)
    inputs = np.transpose(np.array(encoded), (0, 2, 1))

    adj_matrix = np.array(df['structure'].apply(convert_to_adj).values.tolist())
    # Alternative: base-pair probabilities as weighted edges.
    # adj_matrix = np.array(df['id'].apply(bpp_to_adj).values.tolist())

    return inputs, adj_matrix

In [159]:
train = pd.read_json(config.train_file, lines=True)
if config.filter_noise:
    # keep only molecules flagged as passing the signal-to-noise filter
    train = train.loc[train['SN_filter'] == 1]

test = pd.read_json(config.test_file, lines=True)
sample_df = pd.read_csv(config.sample_submission)

In [160]:
train_inputs, train_adj = preprocess_inputs(train)
# per molecule: one list per target column -> (n, targets, seq) -> (n, seq, targets)
train_labels = np.array(train[config.pred_cols].values.tolist()).transpose((0, 2, 1))

train_inputs, train_adj, train_labels = (
    torch.tensor(train_inputs, dtype=torch.long),
    torch.tensor(train_adj, dtype=torch.float32),
    torch.tensor(train_labels, dtype=torch.float32),
)

# Train

In [161]:
def train_fn(epoch, model, train_loader, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        epoch: epoch index (unused; kept for call-site symmetry with eval_fn).
        model: module called as ``model(input_, adj)``.
        train_loader: yields ``(input_, adj, label)`` batches.
        criterion: ``criterion(preds, label)`` -> scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        Unweighted mean of the per-batch losses (a smaller final batch counts
        as much as a full one); 0 if the loader is empty.
    """
    model.train()
    # Move batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    train_loss = AverageMeter()

    for input_, adj, label in train_loader:
        input_, adj, label = input_.to(device), adj.to(device), label.to(device)

        optimizer.zero_grad()
        loss = criterion(model(input_, adj), label)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item())

    print(f"Train loss {train_loss.avg}")
    return train_loss.avg
    
def eval_fn(epoch, model, valid_loader, criterion):
    """Evaluate *model* on *valid_loader* without tracking gradients.

    Args:
        epoch: epoch index (unused; kept for call-site symmetry with train_fn).
        model: module called as ``model(input_, adj)``.
        valid_loader: yields ``(input_, adj, label)`` batches.
        criterion: ``criterion(preds, label)`` -> scalar loss tensor.

    Returns:
        ``(mean_batch_loss, oof_label, oof_pred)``. The label/prediction
        tensors are concatenated along dim 0 in loader order and live on the
        CPU; both are ``None`` if the loader is empty.
    """
    model.eval()
    device = next(model.parameters()).device
    eval_loss = AverageMeter()
    label_batches = []
    pred_batches = []

    # Without no_grad every validation batch would record an autograd graph.
    with torch.no_grad():
        for input_, adj, label in valid_loader:
            input_, adj, label = input_.to(device), adj.to(device), label.to(device)
            preds = model(input_, adj)

            eval_loss.update(criterion(preds, label).item())
            label_batches.append(label.cpu())
            pred_batches.append(preds.cpu())

    # One concatenation at the end instead of re-copying the accumulator per batch.
    oof_label = torch.cat(label_batches, 0) if label_batches else None
    oof_pred = torch.cat(pred_batches, 0) if pred_batches else None

    print(f"Valid loss {eval_loss.avg}")
    return eval_loss.avg, oof_label, oof_pred

In [162]:
def MCRMSE(y_true, y_pred):
    """Mean column-wise RMSE, averaged per sample.

    For ``(batch, seq, targets)`` tensors (the layout of ``Net``'s output and
    of the training labels), the squared error is averaged over the sequence
    axis (dim 1). The square root of every ``(sample, target)`` entry is then
    taken and all of them are averaged. This mean of per-sample RMSEs is
    generally smaller than an RMSE pooled over all samples. The two arguments
    are interchangeable.
    """
    squared_error = torch.square(y_pred - y_true)
    per_column_rmse = squared_error.mean(dim=1).sqrt()
    return per_column_rmse.mean()

class RMSELoss(nn.Module):
    """Square root of ``nn.MSELoss`` with *eps* added inside the root.

    The epsilon keeps the gradient of ``sqrt`` finite when the MSE is 0.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.mse = torch.nn.MSELoss()

    def forward(self, yhat, y):
        return (self.mse(yhat, y) + self.eps).sqrt()
    
class MCRMSELoss(nn.Module):
    """Equal-weight average of per-column ``RMSELoss`` values.

    Inputs are indexed as ``[:, :, col]``, i.e. target columns on the last
    axis.

    Args:
        num_scored: when *scored_cols* is not given, score the first
            ``num_scored`` columns (the original behaviour).
        scored_cols: optional explicit column indices to score. With the
            notebook's ``config.pred_cols`` order, the first three columns are
            reactivity, deg_Mg_pH10 and deg_pH10.
            NOTE(review): confirm that is the intended subset, or pass explicit
            indices.

    Raises:
        ValueError: if no columns would be scored.
    """

    def __init__(self, num_scored=3, scored_cols=None):
        super().__init__()
        self.rmse = RMSELoss()
        if scored_cols is None:
            self.scored_cols = list(range(num_scored))
        else:
            self.scored_cols = list(scored_cols)
        if not self.scored_cols:
            raise ValueError("MCRMSELoss needs at least one scored column")
        self.num_scored = len(self.scored_cols)

    def forward(self, yhat, y):
        score = 0
        for col in self.scored_cols:
            score += self.rmse(yhat[:, :, col], y[:, :, col]) / self.num_scored
        return score
    
def run(fold, train_loader, valid_loader):
    """Train a fresh ``Net`` on one CV fold and save its final weights.

    The weights after the LAST epoch (not the best epoch) are written to
    ``{config.pretrain_dir}/gru_{fold}.pt``; the directory is created if
    missing.

    Returns:
        ``(train_losses, eval_losses, oof_label, oof_pred)``: per-epoch loss
        histories plus the validation labels/predictions from the last epoch
        (``None`` if ``config.n_epoch`` is 0).
    """
    model = Net(K=config.K, aggregator=config.gcn_agg)
    model.cuda()
    criterion = MCRMSE  # also tried: MCRMSELoss(num_scored=5), torch.nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)

    train_losses = []
    eval_losses = []
    oof_label = oof_pred = None  # defined even when n_epoch == 0
    for epoch in range(config.n_epoch):
        print('#################')
        print('###Epoch:', epoch)

        train_loss = train_fn(epoch, model, train_loader, criterion, optimizer)
        eval_loss, oof_label, oof_pred = eval_fn(epoch, model, valid_loader, criterion)
        train_losses.append(train_loss)
        eval_losses.append(eval_loss)

    # torch.save does not create missing directories.
    os.makedirs(config.pretrain_dir, exist_ok=True)
    torch.save(model.state_dict(), f'{config.pretrain_dir}/gru_{fold}.pt')
    return train_losses, eval_losses, oof_label, oof_pred

In [163]:
splits = KFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train_inputs)
# splits = StratifiedKFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train['id'], train['signal_to_noise'].apply(lambda x: int(x)))


def _tail_mean(values, k):
    """Mean of the last *k* entries of *values* (fewer if the list is shorter)."""
    tail = values[-k:]
    return sum(tail) / len(tail)


overall_5_train_loss, overall_5_valid_loss = AverageMeter(), AverageMeter()
overall_10_train_loss, overall_10_valid_loss = AverageMeter(), AverageMeter()
oof_label_parts = []
oof_pred_parts = []

for fold, (train_idx, val_idx) in enumerate(splits):
    train_loader = DataLoader(
        TensorDataset(train_inputs[train_idx], train_adj[train_idx], train_labels[train_idx]),
        batch_size=config.batch_size, shuffle=True, num_workers=8,
    )
    valid_loader = DataLoader(
        TensorDataset(train_inputs[val_idx], train_adj[val_idx], train_labels[val_idx]),
        batch_size=config.batch_size, shuffle=False, num_workers=8,
    )

    train_losses, eval_losses, oof_label, oof_pred = run(fold, train_loader, valid_loader)
    oof_label_parts.append(oof_label)
    oof_pred_parts.append(oof_pred)

    # smoothed end-of-training losses over the last 5 and last 10 epochs
    for k, train_meter, valid_meter in (
        (5, overall_5_train_loss, overall_5_valid_loss),
        (10, overall_10_train_loss, overall_10_valid_loss),
    ):
        tail_train = _tail_mean(train_losses, k)
        tail_valid = _tail_mean(eval_losses, k)
        train_meter.update(tail_train)
        valid_meter.update(tail_valid)
        print(f'Last {k} epoch train: {tail_train} - valid: {tail_valid}')

# out-of-fold labels/predictions, stacked in fold order
oof_label_all = torch.cat(oof_label_parts, 0)
oof_pred_all = torch.cat(oof_pred_parts, 0)

print(f'Overall 5 last epoch train: {overall_5_train_loss.avg} - valid: {overall_5_valid_loss.avg}')
print(f'Overall 10 last epoch train: {overall_10_train_loss.avg} - valid: {overall_10_valid_loss.avg}')
print(f'OOF CV: {MCRMSE(oof_label_all, oof_pred_all)}')

#################
###Epoch: 0
Train loss 0.4736518368124962
Valid loss 0.41888638734817507
#################
###Epoch: 1
Train loss 0.40256469547748563
Valid loss 0.37292989492416384
#################
###Epoch: 2
Train loss 0.3741307511925697
Valid loss 0.35383965969085696
#################
###Epoch: 3
Train loss 0.3538067176938057
Valid loss 0.33535735607147216
#################
###Epoch: 4
Train loss 0.3359106555581093
Valid loss 0.33099058270454407
#################
###Epoch: 5
Train loss 0.32673446387052535
Valid loss 0.31456226110458374
#################
###Epoch: 6
Train loss 0.31805810928344724
Valid loss 0.307984322309494
#################
###Epoch: 7
Train loss 0.31302081793546677
Valid loss 0.3049154162406921
#################
###Epoch: 8
Train loss 0.3086333736777306
Valid loss 0.3012210190296173
#################
###Epoch: 9
Train loss 0.304643914103508
Valid loss 0.2953181266784668
#################
###Epoch: 10
Train loss 0.30137674808502196
Valid loss 0.29262409210205076

Train loss 0.19137678891420365
Valid loss 0.21653702855110168
#################
###Epoch: 90
Train loss 0.19023684561252593
Valid loss 0.2167837530374527
#################
###Epoch: 91
Train loss 0.1896523840725422
Valid loss 0.21967997252941132
#################
###Epoch: 92
Train loss 0.18934932202100754
Valid loss 0.21901734471321105
#################
###Epoch: 93
Train loss 0.18840970620512962
Valid loss 0.21635543406009675
#################
###Epoch: 94
Train loss 0.18871036544442177
Valid loss 0.2182433396577835
#################
###Epoch: 95
Train loss 0.1875872418284416
Valid loss 0.21841321885585785
#################
###Epoch: 96
Train loss 0.1866515427827835
Valid loss 0.215856671333313
#################
###Epoch: 97
Train loss 0.18727300241589545
Valid loss 0.21898153722286223
#################
###Epoch: 98
Train loss 0.18649930730462075
Valid loss 0.21650962829589843
#################
###Epoch: 99
Train loss 0.18656896129250528
Valid loss 0.21737924516201018
Last 5 epoch tr

Train loss 0.19731983095407485
Valid loss 0.22326373755931855
#################
###Epoch: 78
Train loss 0.19700249582529067
Valid loss 0.22562285363674164
#################
###Epoch: 79
Train loss 0.1967194303870201
Valid loss 0.22315486371517182
#################
###Epoch: 80
Train loss 0.19572677239775657
Valid loss 0.22158203423023223
#################
###Epoch: 81
Train loss 0.1951096959412098
Valid loss 0.22384800612926484
#################
###Epoch: 82
Train loss 0.194117022305727
Valid loss 0.22162909507751466
#################
###Epoch: 83
Train loss 0.1937007375061512
Valid loss 0.22186661660671234
#################
###Epoch: 84
Train loss 0.19414836615324021
Valid loss 0.22328489124774933
#################
###Epoch: 85
Train loss 0.19338174313306808
Valid loss 0.22353942096233367
#################
###Epoch: 86
Train loss 0.19236019775271415
Valid loss 0.2221648007631302
#################
###Epoch: 87
Train loss 0.1918291375041008
Valid loss 0.22426466047763824
###############

Train loss 0.2053619660437107
Valid loss 0.21968188285827636
#################
###Epoch: 66
Train loss 0.2042568065226078
Valid loss 0.22228715419769288
#################
###Epoch: 67
Train loss 0.2048453189432621
Valid loss 0.2203601211309433
#################
###Epoch: 68
Train loss 0.20389125049114226
Valid loss 0.2202998638153076
#################
###Epoch: 69
Train loss 0.2027270309627056
Valid loss 0.21995292901992797
#################
###Epoch: 70
Train loss 0.2023948110640049
Valid loss 0.21831955909729003
#################
###Epoch: 71
Train loss 0.20151490420103074
Valid loss 0.2194102495908737
#################
###Epoch: 72
Train loss 0.2011263146996498
Valid loss 0.21895726323127745
#################
###Epoch: 73
Train loss 0.20000938847661018
Valid loss 0.22084712088108063
#################
###Epoch: 74
Train loss 0.19889691472053528
Valid loss 0.21895231902599335
#################
###Epoch: 75
Train loss 0.1991618812084198
Valid loss 0.21850109100341797
#################


Train loss 0.21353540495038031
Valid loss 0.2347470462322235
#################
###Epoch: 54
Train loss 0.21419822573661804
Valid loss 0.23242753148078918
#################
###Epoch: 55
Train loss 0.21261354833841323
Valid loss 0.2311989575624466
#################
###Epoch: 56
Train loss 0.21099662482738496
Valid loss 0.23275499641895295
#################
###Epoch: 57
Train loss 0.2107037477195263
Valid loss 0.23236227929592132
#################
###Epoch: 58
Train loss 0.2102730877697468
Valid loss 0.23277421295642853
#################
###Epoch: 59
Train loss 0.20978924185037612
Valid loss 0.2306637465953827
#################
###Epoch: 60
Train loss 0.2085300363600254
Valid loss 0.23120792508125304
#################
###Epoch: 61
Train loss 0.20801911279559135
Valid loss 0.23092640936374664
#################
###Epoch: 62
Train loss 0.20676039382815362
Valid loss 0.23051493763923644
#################
###Epoch: 63
Train loss 0.20555393770337105
Valid loss 0.23307294249534607
##############

Train loss 0.23018686175346376
Valid loss 0.2268580377101898
#################
###Epoch: 42
Train loss 0.2292154885828495
Valid loss 0.22333547472953796
#################
###Epoch: 43
Train loss 0.22767195329070092
Valid loss 0.2219638019800186
#################
###Epoch: 44
Train loss 0.22575237303972245
Valid loss 0.22072623670101166
#################
###Epoch: 45
Train loss 0.2243036575615406
Valid loss 0.22195352911949157
#################
###Epoch: 46
Train loss 0.22720347940921784
Valid loss 0.22165523171424867
#################
###Epoch: 47
Train loss 0.22238143905997276
Valid loss 0.22189440429210663
#################
###Epoch: 48
Train loss 0.22111476138234137
Valid loss 0.2182176321744919
#################
###Epoch: 49
Train loss 0.22171741500496864
Valid loss 0.2184634953737259
#################
###Epoch: 50
Train loss 0.22038271352648736
Valid loss 0.22101665437221527
#################
###Epoch: 51
Train loss 0.21946285888552666
Valid loss 0.21797596514225007
##############

In [None]:
Overall 5 epoch train: 0.0451215458744102 - valid: 0.06322710109608513
Overall 10 epoch train: 0.04563371909713303 - valid: 0.06327088166560445
        
Use BPP as weighted adjacency matrix
GRU
Overall 5 epoch train: 0.0430864584280385 - valid: 0.06221413508057594
Overall 10 epoch train: 0.04368967310146049 - valid: 0.06239632720393794
LSTM
Overall 5 last epoch train: 0.041662138899167374 - valid: 0.06205343489136015
Overall 10 last epoch train: 0.042378014334374006 - valid: 0.06232069560459682
GRU -> GCN
Overall 5 last epoch train: 0.04252960059377882 - valid: 0.06481357118913104
Overall 10 last epoch train: 0.04310291423841759 - valid: 0.0650169599056244
SN_filter
Last 5 epoch train: 0.03829054716974497 - valid: 0.054680862873792646
Last 10 epoch train: 0.038869887720793486 - valid: 0.05474848434329033
Overall 5 last epoch train: 0.03866128552705049 - valid: 0.055820192784070966
Overall 10 last epoch train: 0.03917157341167331 - valid: 0.055871706023812286
3 labels
Overall 5 last epoch train: 0.033922446276992556 - valid: 0.05792523550987243
Overall 10 last epoch train: 0.03448757581226528 - valid: 0.05792512205243111
5 labels MCRMSE loss
Overall 5 last epoch train: 0.18149467387795448 - valid: 0.21504289948940278
Overall 10 last epoch train: 0.18257398997247218 - valid: 0.21543894702196117
Overall 5 last epoch train: 0.18749359598755838 - valid: 0.21788148260116577
Overall 10 last epoch train: 0.18864331699907777 - valid: 0.21844578289985658
MCRMSE loss class
Overall 5 last epoch train: 0.1900108553469181 - valid: 0.23267986249923706
Overall 10 last epoch train: 0.1914477916508913 - valid: 0.23293828153610235
Overall 5 last epoch train: 0.19669136756658553 - valid: 0.23536729907989504
Overall 10 last epoch train: 0.19817317792773242 - valid: 0.23596047580242155

GPU + adj + bpp
Overall 5 last epoch train: 0.044193561888403364 - valid: 0.06220439549003328
Overall 10 last epoch train: 0.0448748238619279 - valid: 0.06226083738463265
        
K = 2
Use BPP as weighted adjacency matrix
Overall 5 last epoch train: 0.05076819153295623 - valid: 0.06468318649700709
Overall 10 last epoch train: 0.051428089461944715 - valid: 0.06473811021872929
        
Convert BPP to a binary (0/1) adjacency matrix
Overall 5 last epoch train: 0.04308395972406422 - valid: 0.06330352274434907
Overall 10 last epoch train: 0.0436869018221343 - valid: 0.06328786111303739
        
Conv aggregator
Overall 5 last epoch train: 0.046746739074587824 - valid: 0.06578907604728426
Overall 10 last epoch train: 0.047424182744213825 - valid: 0.06575814057673728
        
Pooling aggregator
Overall 5 last epoch train: 0.04330202187928889 - valid: 0.06245726023401532
Overall 10 last epoch train: 0.04384024161155577 - valid: 0.06253742285072802
        
learning_rate = 0.002
Use BPP as weighted adjacency matrix
Mean aggregator
Overall 5 last epoch train: 0.035587881477894615 - valid: 0.06147534329976354
Overall 10 last epoch train: 0.03606973063890581 - valid: 0.06147345149091312

learning_rate = 0.003
Overall 5 last epoch train: 0.03372417225605912 - valid: 0.06138800899897303
Overall 10 last epoch train: 0.034225094819234476 - valid: 0.0614313035777637
        
learning_rate = 0.005
Overall 5 last epoch train: 0.03593841718026885 - valid: 0.06251461637871605
Overall 10 last epoch train: 0.036377053264942436 - valid: 0.06247474372386932
        
learning_rate = 0.003
No GCN
Overall 5 last epoch train: 0.042106328460353386 - valid: 0.06521862102406364
Overall 10 last epoch train: 0.042440810569182585 - valid: 0.06519380110715116
        
Dropout 0.2, with GCN
Overall 5 last epoch train: 0.024202211432986794 - valid: 0.06166929091726031
Overall 10 last epoch train: 0.024575392231345174 - valid: 0.06173114275293692

Epoch 0
Last 5 epoch train: 0.0329512511552484 - valid: 0.06581846997141838
Last 10 epoch train: 0.03340453362023389 - valid: 0.06565298672233309

# Visualize

In [125]:
# Loss curves from the most recent `run` call, i.e. the last CV fold only.
history = pd.DataFrame({'loss': train_losses, 'val_loss': eval_losses})
fig = px.line(
    history,
    y=['loss', 'val_loss'],
    # run() trains with the MCRMSE criterion, not plain MSE.
    labels={'index': 'epoch', 'value': 'MCRMSE loss'},
    title='Training History')
fig.show()

# Test

In [126]:
# Public (length 107) and private (length 130) molecules are processed
# separately because each array needs a single sequence length.
public_df = test.query("seq_length == 107").copy()
private_df = test.query("seq_length == 130").copy()

public_inputs, public_adj = preprocess_inputs(public_df)
private_inputs, private_adj = preprocess_inputs(private_df)

public_inputs = torch.tensor(public_inputs, dtype=torch.long)
private_inputs = torch.tensor(private_inputs, dtype=torch.long)
# float32 as in training: a long dtype would truncate fractional edge
# weights (e.g. the BPP matrices from bpp_to_adj) to 0.
public_adj = torch.tensor(public_adj, dtype=torch.float32)
private_adj = torch.tensor(private_adj, dtype=torch.float32)

In [127]:
# Same architecture for both; only the number of returned positions differs,
# so every fold checkpoint fits both models.
model_short = Net(seq_len=107, pred_len=107, K=config.K, aggregator=config.gcn_agg)
model_long = Net(seq_len=130, pred_len=130, K=config.K, aggregator=config.gcn_agg)
model_short.cuda()
model_long.cuda()
model_short.eval()
model_long.eval()

# Move the test tensors to the GPU once instead of once per fold.
public_inputs_gpu, public_adj_gpu = public_inputs.cuda(), public_adj.cuda()
private_inputs_gpu, private_adj_gpu = private_inputs.cuda(), private_adj.cuda()

list_public_preds = []
list_private_preds = []
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for fold in range(config.n_split):
        state_dict = torch.load(f'{config.pretrain_dir}/gru_{fold}.pt')
        model_short.load_state_dict(state_dict)
        model_long.load_state_dict(state_dict)

        list_public_preds.append(model_short(public_inputs_gpu, public_adj_gpu).cpu().numpy())
        list_private_preds.append(model_long(private_inputs_gpu, private_adj_gpu).cpu().numpy())

In [128]:
# Fold ensemble: element-wise mean over the per-fold prediction arrays.
public_preds = np.stack(list_public_preds).mean(axis=0)
private_preds = np.stack(list_private_preds).mean(axis=0)

In [129]:
preds_ls = []

for frame, averaged in ((public_df, public_preds), (private_df, private_preds)):
    for row_no, uid in enumerate(frame['id']):
        # one row per predicted position, keyed as "<id>_<position>"
        molecule_df = pd.DataFrame(averaged[row_no], columns=config.pred_cols)
        molecule_df['id_seqpos'] = [f'{uid}_{pos}' for pos in range(len(molecule_df))]
        preds_ls.append(molecule_df)

preds_df = pd.concat(preds_ls)

In [130]:
# Keep the sample submission's row order; inner join on the position key.
id_frame = sample_df[['id_seqpos']]
submission = pd.merge(id_frame, preds_df, on=['id_seqpos'])
submission.to_csv(f'{config.data_dir}/submission.csv', index=False)

In [None]:
# 0.6 - 0.28785
# 0.4 - 0.28630