In [23]:
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import json
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import plotly.express as px

import lightgbm as lgb
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch.optim import lr_scheduler

In [54]:
class config:
    train_file = './stanford-covid-vaccine/train.json'
    test_file = './stanford-covid-vaccine/test.json'
    pretrain_dir = './GCN_lstm_patience10'
    sample_submission = './stanford-covid-vaccine/sample_submission.csv'
    learning_rate = 0.001
    batch_size = 64
    n_epoch = 200
    n_split = 5
    K = 1 # number of aggregation loop (also means number of GCN layers)
    gcn_agg = 'mean' # aggregator function: mean, conv, lstm, pooling
    filter_noise = True
    patience= 10
    seed = 1234
    loss_weights = torch.tensor([0.3,0.3,0.05,0.3,0.05]).cuda()
    pretrain_epoch = 1

In [25]:
class AverageMeter:
    """
    Computes and stores the average and current value
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [26]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(config.seed)

## model 

In [27]:
class GCN(nn.Module):
    '''
    Implementation of one layer of GraphSAGE
    '''
    def __init__(self, input_dim, output_dim, aggregator=config.gcn_agg):
        super(GCN, self).__init__()
        self.aggregator = aggregator
        
        if aggregator == 'mean':
            linear_input_dim = input_dim * 2
        elif aggregator == 'conv':
            linear_input_dim = input_dim
#         elif aggregator == 'pooling':
#             linear_input_dim = input_dim * 2
#             self.linear_pooling = nn.Linear(input_dim, input_dim)
        elif aggregator == 'lstm':
            self.lstm_hidden = 64
            linear_input_dim = input_dim + self.lstm_hidden
            self.lstm_agg = nn.LSTM(input_dim, self.lstm_hidden, num_layers=1, batch_first=True)
        
        self.linear_gcn = nn.Linear(in_features=linear_input_dim, out_features=output_dim)
        
    def forward(self, input_, adj_matrix):
        if self.aggregator == 'conv':
            # set elements in diagonal of adj matrix to 1 with conv aggregator
            idx = torch.arange(0, adj_matrix.shape[-1], out=torch.LongTensor())
            adj_matrix[:, idx, idx] = 1
            
        adj_matrix = adj_matrix.type(torch.float32)
        sum_adj = torch.sum(adj_matrix, axis=2)
        sum_adj[sum_adj==0] = 1
        
        if self.aggregator == 'mean' or self.aggregator == 'conv':
            feature_agg = torch.bmm(adj_matrix, input_)
            feature_agg = feature_agg / sum_adj.unsqueeze(dim=2)
            
#         elif self.aggregator == 'pooling':
#             feature_pooling = self.linear_pooling(input_)
#             feature_agg = torch.sigmoid(feature_pooling)
#             feature_agg = torch.bmm(adj_matrix, feature_agg)
#             feature_agg = feature_agg / sum_adj.unsqueeze(dim=2)

        elif self.aggregator == 'lstm':
            feature_agg = torch.zeros(input_.shape[0], input_.shape[1], self.lstm_hidden).cuda()
            for i in range(adj_matrix.shape[1]):
                neighbors = adj_matrix[:, i, :].unsqueeze(2) * input_
                _, hn = self.lstm_agg(neighbors)
                feature_agg[:, i, :] = torch.squeeze(hn[0], 0)
                
        if self.aggregator != 'conv':
            feature_cat = torch.cat((input_, feature_agg), axis=2)
        else:
            feature_cat = feature_agg
                
        feature = torch.sigmoid(self.linear_gcn(feature_cat))
        feature = feature / torch.norm(feature, p=2, dim=2).unsqueeze(dim=2)
        
        return feature
        

In [44]:
class Net(nn.Module):
    def __init__(self, num_embedding=14, seq_len=107, pred_len=68, dropout=0.5, 
                 embed_dim=100, hidden_dim=128, K=1, aggregator='mean'):
        '''
        K: number of GCN layers
        aggregator: type of aggregator function
        '''
        super(Net, self).__init__()
        
        self.pred_len = pred_len
        self.embedding_layer = nn.Embedding(num_embeddings=num_embedding, 
                                      embedding_dim=embed_dim)
        
        self.gcn = nn.ModuleList([GCN(3 * embed_dim, 3 * embed_dim, aggregator=aggregator) for i in range(K)])
        
        self.gru_layer = nn.GRU(input_size=3 * embed_dim, 
                          hidden_size=hidden_dim, 
                          num_layers=3, 
                          batch_first=True, 
                          dropout=dropout, 
                          bidirectional=True)
        
        self.linear_layer = nn.Linear(in_features=2 * hidden_dim, 
                                out_features=5)
        
        self.dropout_layer = nn.Dropout(p=0.5)
    def forward(self, input_, adj_matrix):
        #embedding
        embedding = self.embedding_layer(input_)
        embedding = torch.reshape(embedding, (-1, embedding.shape[1], embedding.shape[2] * embedding.shape[3]))
        
        if self.training == True:
            embedding = self.dropout_layer(embedding.float())
        
        #gcn
        gcn_feature = embedding
        for gcn_layer in self.gcn:
            gcn_feature = gcn_layer(gcn_feature, adj_matrix)
        
        #gru
        gru_output, gru_hidden = self.gru_layer(gcn_feature)
        truncated = gru_output[:, :self.pred_len]
        
        output = self.linear_layer(truncated)
        
        return output

In [45]:
class autoEncoder(nn.Module):
    def __init__(self, num_embedding=14, seq_len=107, pred_len=68, dropout=0.5, 
                 embed_dim=100, hidden_dim=128, K=1, aggregator='mean'):
        '''
        K: number of GCN layers
        aggregator: type of aggregator function
        '''
        super(autoEncoder, self).__init__()
        
        self.pred_len = pred_len
        self.embedding_layer = nn.Embedding(num_embeddings=num_embedding, 
                                      embedding_dim=embed_dim)
        
        self.gcn = nn.ModuleList([GCN(3 * embed_dim, 3 * embed_dim, aggregator=aggregator) for i in range(K)])
        
        
        self.linear_layer1 = nn.Linear(in_features=3 * embed_dim, 
                                out_features=128)
        
        self.linear_layer2 = nn.Linear(in_features=128,
                                out_features=3)
        
        self.dropout_layer = nn.Dropout(p=0.5)
        
    def forward(self, input_, adj_matrix):
        #embedding
        
        embedding = self.embedding_layer(input_)
        if self.training == True:
            embedding = self.dropout_layer(embedding.float())
        embedding = torch.reshape(embedding, (-1, embedding.shape[1], embedding.shape[2] * embedding.shape[3]))
        
        #gcn
        gcn_feature = embedding
        for gcn_layer in self.gcn:
            gcn_feature = gcn_layer(gcn_feature, adj_matrix)
            
        output = self.linear_layer1(gcn_feature)
        output = self.linear_layer2(output)
        return output

## load data

In [46]:
pred_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']

In [47]:
token2int = {x:i for i, x in enumerate('().ACGUBEHIMSX')}

def get_couples(structure):
    """
    For each closing parenthesis, I find the matching opening one and store their index in the couples list.
    The assigned list is used to keep track of the assigned opening parenthesis
    """
    opened = [idx for idx, i in enumerate(structure) if i == '(']
    closed = [idx for idx, i in enumerate(structure) if i == ')']

    assert len(opened) == len(closed)
    assigned = []
    couples = []

    for close_idx in closed:
        for open_idx in opened:
            if open_idx < close_idx:
                if open_idx not in assigned:
                    candidate = open_idx
            else:
                break
        assigned.append(candidate)
        couples.append([candidate, close_idx])
        
    assert len(couples) == len(opened)
    
    return couples

def build_matrix(couples, size):
    mat = np.zeros((size, size))
    
    for i in range(size):  # neigbouring bases are linked as well
        if i < size - 1:
            mat[i, i + 1] = 1
        if i > 0:
            mat[i, i - 1] = 1
    
    for i, j in couples:
        mat[i, j] = 1
        mat[j, i] = 1
        
    return mat

def convert_to_adj(structure):
    couples = get_couples(structure)
    mat = build_matrix(couples, len(structure))
    return mat

def preprocess_inputs(df, cols=['sequence', 'structure', 'predicted_loop_type']):
    inputs = np.transpose(
        np.array(
            df[cols]
            .applymap(lambda seq: [token2int[x] for x in seq])
            .values
            .tolist()
        ),
        (0, 2, 1)
    )
    
    adj_matrix = np.array(df['structure'].apply(convert_to_adj).values.tolist())
    
    return inputs, adj_matrix

In [48]:
train = pd.read_json(config.train_file, lines=True)

if config.filter_noise:
    train = train[train.signal_to_noise > 1]
    
test = pd.read_json(config.test_file, lines=True)
sample_df = pd.read_csv(config.sample_submission)

In [49]:
train_inputs, train_adj = preprocess_inputs(train)
train_labels = np.array(train[pred_cols].values.tolist()).transpose((0, 2, 1))

train_inputs = torch.tensor(train_inputs, dtype=torch.long)
train_adj = torch.tensor(train_adj, dtype=torch.long)
train_labels = torch.tensor(train_labels, dtype=torch.float32)

### Train KFold

In [50]:
def train_fn(epoch, model, train_loader, criterion, optimizer):
    model.train()
    model.zero_grad()
    train_loss = AverageMeter()
    
    for index, (input_, adj, label) in enumerate(train_loader):
        input_ = input_.cuda()
        adj = adj.cuda()
        label = label.cuda()
        preds = model(input_, adj)
        if criterion == weighted_mse_loss:
            loss = criterion(preds, label, config.loss_weights)
        else:
            loss = criterion(preds.float(), label.float())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.update(loss.item())
    
    print(f"Train loss {train_loss.avg}")
    return train_loss.avg
    
def eval_fn(epoch, model, valid_loader, criterion):
    model.eval()
    eval_loss = AverageMeter()
    
    for index, (input_, adj, label) in enumerate(valid_loader):
        input_ = input_.cuda()
        adj = adj.cuda()
        label = label.cuda()
        preds = model(input_, adj)
        
        loss = criterion(preds, label, config.loss_weights)
        eval_loss.update(loss.item())
    
    print(f"Valid loss {eval_loss.avg}")
    return eval_loss.avg

In [51]:
def weighted_mse_loss(pred_, label, weight):
    num = pred_.shape[0] * pred_.shape[1]
    return torch.sum(weight * (pred_ - label) ** 2)/num

### pretrain model

In [37]:
def pretrain(train_loader):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)
    
    train_losses = []
    for epoch in range(config.pretrain_epoch):
        train_loss = train_fn(epoch, model, train_loader, criterion, optimizer)
        train_losses.append(train_loss)
    return train_losses

In [38]:
public_df = test.query("seq_length == 107").copy()
private_df = test.query("seq_length == 130").copy()

public_inputs, public_adj = preprocess_inputs(public_df)
private_inputs, private_adj = preprocess_inputs(private_df)

public_inputs = torch.tensor(public_inputs, dtype=torch.long)
private_inputs = torch.tensor(private_inputs, dtype=torch.long)
public_adj = torch.tensor(public_adj, dtype=torch.long)
private_adj = torch.tensor(private_adj, dtype=torch.long)

In [39]:
# train different aggregators
aggregator = "mean"
config.gcn_agg = aggregator
config.pretrain_dir = "GCN_pretrain"

short_inputs = torch.cat((train_inputs, public_inputs), 0)
short_adj = torch.cat((train_adj, public_adj), 0)
short_dataset = TensorDataset(short_inputs, short_adj, short_inputs)
short_loader = DataLoader(short_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)

private_dataset = TensorDataset(private_inputs, private_adj, private_inputs)
private_loader = DataLoader(private_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)

In [40]:
model = autoEncoder(K=config.K, aggregator=config.gcn_agg)
model.cuda()

for train_round in range(50):
    print('#################')
    print('###Epoch:', train_round)
    train_losses = pretrain(short_loader)
    train_losses = pretrain(private_loader)

#################
###Epoch: 0
Train loss 23.498131995977356
Train loss 1.1987750149787741
#################
###Epoch: 1
Train loss 0.21746733628733214
Train loss 0.03037050631927683
#################
###Epoch: 2
Train loss 0.014607587720938894
Train loss 0.009905497741667515
#################
###Epoch: 3
Train loss 0.005833896039443654
Train loss 0.0045348779506426545
#################
###Epoch: 4
Train loss 0.0033644093173491054
Train loss 0.003533088411402671
#################
###Epoch: 5
Train loss 0.003968001344663546
Train loss 0.0034940153468112915
#################
###Epoch: 6
Train loss 0.0035899883822462153
Train loss 0.0031105340365568134
#################
###Epoch: 7
Train loss 0.0032509775022802833
Train loss 0.003210766126396769
#################
###Epoch: 8
Train loss 0.003256473138990682
Train loss 0.0028946933514894324
#################
###Epoch: 9
Train loss 0.00304492479386512
Train loss 0.0030721011396653533
#################
###Epoch: 10
Train loss 0.003150231977193

In [41]:
torch.save(model.state_dict(), "./pretrain_model/pretrain_model.bin")

## train model

In [52]:
def run(fold, train_loader, valid_loader,fine_tune):
    model = Net(K=config.K, aggregator=config.gcn_agg)
    model.cuda()
    # load pretrain model
    if fine_tune == False:
        print("load pretrain model")
        model.load_state_dict(torch.load(config.pretrain_dir), strict=False)

#         for para in model.embedding_layer.parameters():
#             para.requires_grad=False
#         for para in model.gcn.parameters():
#             para.requires_grad=False
    else:
        print("load pretrain model to fine tune")
        model.load_state_dict(torch.load(f'{config.pretrain_dir}/gcn_gru_{fold}.pt'), strict=False)
    
    criterion = weighted_mse_loss
#     if fine_tune == True:
#         optimizer = torch.optim.Adam(params=model.parameters(), lr=config.finetune_learning_rate, weight_decay=0.0)
#     else:
#         optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)
    
    eval_loss_increase_step = 0
    train_losses = []
    eval_losses = []
    for epoch in range(config.n_epoch):
        print('#################')
        print('###Epoch:', epoch)
        
        train_loss = train_fn(epoch, model, train_loader, criterion, optimizer)
        eval_loss = eval_fn(epoch, model, valid_loader, criterion)
        train_losses.append(train_loss)
        eval_losses.append(eval_loss)
        
        # check if should early stop
        if len(eval_losses) == 2:
            previous_eval_loss = eval_losses[0]
        if len(eval_losses) >= 2:
            if eval_loss > previous_eval_loss:
                eval_loss_increase_step += 1
            else: 
                # save the model if it is better than previous step
                torch.save(model.state_dict(), f'{config.save_dir}/gcn_gru_{fold}.pt')
                eval_loss_increase_step = 0
                previous_eval_loss=eval_loss
                
        if eval_loss_increase_step >= config.patience:
            print("early stop the model at Epoch: ", epoch)
            del model
            break
            
        
#     torch.save(model.state_dict(), f'{config.pretrain_dir}/gcn_gru_{fold}.pt')
    return train_losses, eval_losses

In [None]:
# train different aggregators
fine_tune = False
aggregator = "mean"
config.gcn_agg = aggregator
config.pretrain_dir =  "./pretrain_model/pretrain_model.bin"
config.save_dir = "./pretrain_model/GCN_mean"
splits = KFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train_inputs)

for fold, (train_idx, val_idx) in enumerate(splits):
    train_dataset = TensorDataset(train_inputs[train_idx], train_adj[train_idx], train_labels[train_idx])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)

    valid_dataset = TensorDataset(train_inputs[val_idx], train_adj[val_idx], train_labels[val_idx])
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8)

    train_losses, eval_losses = run(fold, train_loader, valid_loader,fine_tune)

    fig = px.line(
        pd.DataFrame([train_losses, eval_losses], index=['loss', 'val_loss']).T, 
        y=['loss', 'val_loss'], 
        labels={'index': 'epoch', 'value': 'Mean Squared Error'}, 
        title='Training History')
    fig.show()

load pretrain model
#################
###Epoch: 0
Train loss 0.23829771247175005
Valid loss 0.19860283604690007
#################
###Epoch: 1
Train loss 0.17473984316543298
Valid loss 0.1688109189271927
#################
###Epoch: 2
Train loss 0.15643366784961135
Valid loss 0.16001597259725844
#################
###Epoch: 3
Train loss 0.14604495189808034
Valid loss 0.14465837819235666
#################
###Epoch: 4
Train loss 0.13348221171785285
Valid loss 0.13236888072320394
#################
###Epoch: 5
Train loss 0.12245696810660539
Valid loss 0.12161290219851903
#################
###Epoch: 6
Train loss 0.11631865578669089
Valid loss 0.11812626038278852
#################
###Epoch: 7
Train loss 0.11213576931644369
Valid loss 0.1148280022399766
#################
###Epoch: 8
Train loss 0.10947544900355516
Valid loss 0.1142681815794536
#################
###Epoch: 9
Train loss 0.10577278639431353
Valid loss 0.11341891757079534
#################
###Epoch: 10
Train loss 0.10435921947161357
V

Train loss 0.04359269211137736
Valid loss 0.0654994079044887
#################
###Epoch: 89
Train loss 0.04430779911301754
Valid loss 0.06467472602214132
#################
###Epoch: 90
Train loss 0.043763927977394174
Valid loss 0.06495369066085134
early stop the model at Epoch:  90


load pretrain model
#################
###Epoch: 0
Train loss 0.23893364436096615
Valid loss 0.18567471844809397
#################
###Epoch: 1
Train loss 0.17575249407026503
Valid loss 0.1635939223425729
#################
###Epoch: 2
Train loss 0.15692456287366371
Valid loss 0.14821260741778783
#################
###Epoch: 3
Train loss 0.14542578077978557
Valid loss 0.13436825573444366
#################
###Epoch: 4
Train loss 0.13240437623527315
Valid loss 0.12144551639045988
#################
###Epoch: 5
Train loss 0.12306403699848387
Valid loss 0.11540638442550387
#################
###Epoch: 6
Train loss 0.11889156423233173
Valid loss 0.11672927758523396
#################
###Epoch: 7
Train loss 0.11452001691968353
Valid loss 0.11614341820989336
#################
###Epoch: 8
Train loss 0.1120598307914204
Valid loss 0.10523630465779986
#################
###Epoch: 9
Train loss 0.10755793474338672
Valid loss 0.1012720497591155
#################
###Epoch: 10
Train loss 0.10430825225732944
V

load pretrain model
#################
###Epoch: 0
Train loss 0.24166131681866115
Valid loss 0.1883363574743271
#################
###Epoch: 1
Train loss 0.17826386789480844
Valid loss 0.17029858699866704
#################
###Epoch: 2
Train loss 0.16259494479055758
Valid loss 0.15106571785041265
#################
###Epoch: 3
Train loss 0.15057945748170218
Valid loss 0.13729867232697351
#################
###Epoch: 4
Train loss 0.13723161457865327
Valid loss 0.12022798614842552
#################
###Epoch: 5
Train loss 0.12454297476344639
Valid loss 0.11336104997566768
#################
###Epoch: 6
Train loss 0.11986664793005695
Valid loss 0.11142305710486003
#################
###Epoch: 7
Train loss 0.1156894639134407
Valid loss 0.10803865747792381
#################
###Epoch: 8
Train loss 0.11206868373685414
Valid loss 0.10397611132689885
#################
###Epoch: 9
Train loss 0.10911130546419709
Valid loss 0.10145162897450584
#################
###Epoch: 10
Train loss 0.10471305582258436


Train loss 0.04534064785197929
Valid loss 0.05737101286649704
#################
###Epoch: 89
Train loss 0.045094313444914644
Valid loss 0.05867249784725053
#################
###Epoch: 90
Train loss 0.044958933222073096
Valid loss 0.05852825620344707
#################
###Epoch: 91
Train loss 0.0445477107056865
Valid loss 0.058368683393512456
#################
###Epoch: 92
Train loss 0.04422889746449612
Valid loss 0.05774646092738424
#################
###Epoch: 93
Train loss 0.04392208931622682
Valid loss 0.057656144457204
#################
###Epoch: 94
Train loss 0.04394182631814921
Valid loss 0.057656915592295785
#################
###Epoch: 95
Train loss 0.04280200848976771
Valid loss 0.05804593488574028
#################
###Epoch: 96
Train loss 0.042840438308539217
Valid loss 0.05737345878567014
#################
###Epoch: 97
Train loss 0.042426155121238145
Valid loss 0.05724158510565758
#################
###Epoch: 98
Train loss 0.04251519521629369
Valid loss 0.057248085737228394
####

load pretrain model
#################
###Epoch: 0
Train loss 0.2292865851411113
Valid loss 0.1801609503371375
#################
###Epoch: 1
Train loss 0.1735135813554128
Valid loss 0.1577057923589434
#################
###Epoch: 2
Train loss 0.1577072866536953
Valid loss 0.14689270087650844
#################
###Epoch: 3
Train loss 0.14628861789350156
Valid loss 0.1325191238096782
#################
###Epoch: 4
Train loss 0.13265270546630578
Valid loss 0.12112654426268168
#################
###Epoch: 5
Train loss 0.1226361788533352
Valid loss 0.116769939661026
#################
###Epoch: 6
Train loss 0.11737316212168446
Valid loss 0.11061481492859977
#################
###Epoch: 7
Train loss 0.11477045328528793
Valid loss 0.111296218420778
#################
###Epoch: 8
Train loss 0.11154786469759764
Valid loss 0.10471596036638532
#################
###Epoch: 9
Train loss 0.10763741974477414
Valid loss 0.10121438865150724
#################
###Epoch: 10
Train loss 0.1039285209995729
Valid loss

In [17]:
# aggregator = "mean"
# config.patience = 15
# config.gcn_agg = aggregator
# config.pretrain_dir =  "./pretrain_model/GCN_mean"
# config.save_dir = "./pretrain_model/GCN_mean_fine_tune"
# splits = KFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train_inputs)
# fine_tune = True

# for fold, (train_idx, val_idx) in enumerate(splits):
#     train_dataset = TensorDataset(train_inputs[train_idx], train_adj[train_idx], train_labels[train_idx])
#     train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)

#     valid_dataset = TensorDataset(train_inputs[val_idx], train_adj[val_idx], train_labels[val_idx])
#     valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8)

#     train_losses, eval_losses = run(fold, train_loader, valid_loader,fine_tune)

#     fig = px.line(
#         pd.DataFrame([train_losses, eval_losses], index=['loss', 'val_loss']).T, 
#         y=['loss', 'val_loss'], 
#         labels={'index': 'epoch', 'value': 'Mean Squared Error'}, 
#         title='Training History')
#     fig.show()

## prediction

In [18]:
public_df = test.query("seq_length == 107").copy()
private_df = test.query("seq_length == 130").copy()

public_inputs, public_adj = preprocess_inputs(public_df)
private_inputs, private_adj = preprocess_inputs(private_df)

public_inputs = torch.tensor(public_inputs, dtype=torch.long)
private_inputs = torch.tensor(private_inputs, dtype=torch.long)
public_adj = torch.tensor(public_adj, dtype=torch.long)
private_adj = torch.tensor(private_adj, dtype=torch.long)

In [19]:
model_short = Net(seq_len=107, pred_len=107, K=config.K, aggregator=config.gcn_agg)
model_long = Net(seq_len=130, pred_len=130, K=config.K, aggregator=config.gcn_agg)

list_public_preds = []
list_private_preds = []
for fold in range(5):
    config.pretrain_dir = "./pretrain_model/GCN_mean"
    model_short.load_state_dict(torch.load(f'{config.pretrain_dir}/gcn_gru_{fold}.pt'))
    model_long.load_state_dict(torch.load(f'{config.pretrain_dir}/gcn_gru_{fold}.pt'))
    model_short.cuda()
    model_long.cuda()
    model_short.eval()
    model_long.eval()

    public_preds = model_short(public_inputs.cuda(), public_adj.cuda())
    private_preds = model_long(private_inputs.cuda(), private_adj.cuda())
    public_preds = public_preds.cpu().detach().numpy()
    private_preds = private_preds.cpu().detach().numpy()
    
    list_public_preds.append(public_preds)
    list_private_preds.append(private_preds)

In [20]:
public_preds = np.mean(list_public_preds, axis=0)
private_preds = np.mean(list_private_preds, axis=0)

In [None]:
preds_ls = []

for df, preds in [(public_df, public_preds), (private_df, private_preds)]:
    for i, uid in enumerate(df.id):
        single_pred = preds[i]

        single_df = pd.DataFrame(single_pred, columns=pred_cols)
        single_df['id_seqpos'] = [f'{uid}_{x}' for x in range(single_df.shape[0])]

        preds_ls.append(single_df)

preds_df = pd.concat(preds_ls)

In [None]:
submission = sample_df[['id_seqpos']].merge(preds_df, on=['id_seqpos'])
submission.to_csv('%s/submission.csv'%config.pretrain_dir, index=False)