# Import

In [37]:
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import json
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm

import lightgbm as lgb
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch.optim import lr_scheduler

# Settings

In [38]:
class config:
    data_dir = '../OpenVaccine/'
    train_file = '../OpenVaccine/train.json'
    test_file = '../OpenVaccine/test.json'
    pretrain_dir = '../OpenVaccine/pretrains/'
    sample_submission = '../OpenVaccine/sample_submission.csv'
    learning_rate = 0.01
    batch_size = 64
    n_epoch = 50
    n_split = 5
    seed = 1234

# Utils

In [39]:
class AverageMeter:
    """
    Computes and stores the average and current value
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [40]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(config.seed)

# Model

In [41]:
class Net(nn.Module):
    def __init__(self, num_embedding=14, seq_len=107, pred_len=68, dropout=0.5, embed_dim=100, hidden_dim=128):
        super(Net, self).__init__()
        
        self.pred_len = pred_len
        self.embedding_layer = nn.Embedding(num_embeddings=num_embedding, 
                                      embedding_dim=embed_dim)
        
        self.gru_layer = nn.GRU(input_size=3 * embed_dim, 
                          hidden_size=hidden_dim, 
                          num_layers=3, 
                          batch_first=True, 
                          dropout=dropout, 
                          bidirectional=True)
        
        self.linear_layer = nn.Linear(in_features=2 * hidden_dim, 
                                out_features=5)
        
    def forward(self, input_):
        embedding = self.embedding_layer(input_)
        embedding = torch.reshape(embedding, (-1, embedding.shape[1], embedding.shape[2] * embedding.shape[3]))
        
        gru_output, gru_hidden = self.gru_layer(embedding)
        truncated = gru_output[:, :self.pred_len, :]
        
        output = self.linear_layer(truncated)
        
        return output

# Load Data

In [42]:
pred_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']

In [43]:
token2int = {x:i for i, x in enumerate('().ACGUBEHIMSX')}

def preprocess_inputs(df, cols=['sequence', 'structure', 'predicted_loop_type']):
    return np.transpose(
        np.array(
            df[cols]
            .applymap(lambda seq: [token2int[x] for x in seq])
            .values
            .tolist()
        ),
        (0, 2, 1)
    )

In [44]:
train = pd.read_json(config.train_file, lines=True)
test = pd.read_json(config.test_file, lines=True)
sample_df = pd.read_csv(config.sample_submission)

In [45]:
train_inputs = preprocess_inputs(train)
train_labels = np.array(train[pred_cols].values.tolist()).transpose((0, 2, 1))

train_inputs = torch.tensor(train_inputs, dtype=torch.long)
train_labels = torch.tensor(train_labels, dtype=torch.float32)

# Train

In [46]:
def train_fn(epoch, model, train_loader, criterion, optimizer):
    model.train()
    model.zero_grad()
    train_loss = AverageMeter()
    
    for index, (input_, label) in tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Train Epoch {epoch}'):
        input_ = input_.cuda()
        label = label.cuda()
        preds = model(input_)
        
        loss = criterion(preds, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.update(loss.item())
    
    print(f"Train loss {train_loss.avg}")
    
def eval_fn(epoch, model, valid_loader, criterion):
    model.eval()
    eval_loss = AverageMeter()
    
    for index, (input_, label) in enumerate(valid_loader):
        input_ = input_.cuda()
        label = label.cuda()
        preds = model(input_)
        
        loss = criterion(preds, label)
        eval_loss.update(loss.item())
    
    print(f"Valid loss {eval_loss.avg}")

In [47]:
def run(fold, train_loader, valid_loader):
    model = Net()
    model.cuda()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)
    
    for epoch in range(config.n_epoch):
        print('#################')
        print('###Epoch:', epoch)
        print('#################')
        
        train_fn(epoch, model, train_loader, criterion, optimizer)
        eval_fn(epoch, model, valid_loader, criterion)
        
    torch.save(model.state_dict(), f'{config.pretrain_dir}/gru_{fold}.pt')

In [48]:
splits = KFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train_inputs)

for fold, (train_idx, val_idx) in enumerate(splits):
    train_dataset = TensorDataset(train_inputs[train_idx], train_labels[train_idx])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)
    
    valid_dataset = TensorDataset(train_inputs[val_idx], train_labels[val_idx])
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8)
    
    run(fold, train_loader, valid_loader)
    
    break

#################
###Epoch: 0
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 0', max=30.0, style=ProgressStyle(description…


Train loss 1.2931399504343668
Valid loss 0.8107497282326221
#################
###Epoch: 1
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 1', max=30.0, style=ProgressStyle(description…


Train loss 0.8899546657999357
Valid loss 0.7917404100298882
#################
###Epoch: 2
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 2', max=30.0, style=ProgressStyle(description…


Train loss 0.8773083945115407
Valid loss 0.7910861410200596
#################
###Epoch: 3
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 3', max=30.0, style=ProgressStyle(description…


Train loss 0.8727595816055934
Valid loss 0.7812612876296043
#################
###Epoch: 4
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 4', max=30.0, style=ProgressStyle(description…


Train loss 0.8714268957575162
Valid loss 0.7876414693892002
#################
###Epoch: 5
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 5', max=30.0, style=ProgressStyle(description…


Train loss 0.8690954372286797
Valid loss 0.7781409919261932
#################
###Epoch: 6
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 6', max=30.0, style=ProgressStyle(description…


Train loss 0.8655275076627731
Valid loss 0.7705586310476065
#################
###Epoch: 7
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 7', max=30.0, style=ProgressStyle(description…


Train loss 0.8594218363364537
Valid loss 0.7692657634615898
#################
###Epoch: 8
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 8', max=30.0, style=ProgressStyle(description…


Train loss 0.8578344558676084
Valid loss 0.7649569176137447
#################
###Epoch: 9
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 9', max=30.0, style=ProgressStyle(description…


Train loss 0.8537614191571872
Valid loss 0.7621177677065134
#################
###Epoch: 10
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 10', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8542804112037022
Valid loss 0.7690866179764271
#################
###Epoch: 11
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 11', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.849353342751662
Valid loss 0.7514516189694405
#################
###Epoch: 12
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 12', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8475689788659414
Valid loss 0.7565915808081627
#################
###Epoch: 13
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 13', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8435542896389961
Valid loss 0.7474155556410551
#################
###Epoch: 14
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 14', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8444405446449915
Valid loss 0.7685529496520758
#################
###Epoch: 15
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 15', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.847773402929306
Valid loss 0.7482206989079714
#################
###Epoch: 16
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 16', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8441636800765991
Valid loss 0.7837905529886484
#################
###Epoch: 17
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 17', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8470209310452144
Valid loss 0.757666839286685
#################
###Epoch: 18
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 18', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8406011362870535
Valid loss 0.7541365176439285
#################
###Epoch: 19
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 19', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8471624061465264
Valid loss 0.7594449445605278
#################
###Epoch: 20
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 20', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8512064099311829
Valid loss 0.7611109521239996
#################
###Epoch: 21
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 21', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8456360985835393
Valid loss 0.7508703600615263
#################
###Epoch: 22
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 22', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8425256575147311
Valid loss 0.748156001791358
#################
###Epoch: 23
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 23', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8444280525048574
Valid loss 0.7516131363809109
#################
###Epoch: 24
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 24', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8479284758369128
Valid loss 0.7565869353711605
#################
###Epoch: 25
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 25', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8387995660305023
Valid loss 0.7478378787636757
#################
###Epoch: 26
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 26', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8414486840367317
Valid loss 0.7489225156605244
#################
###Epoch: 27
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 27', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8383665651082992
Valid loss 0.7443334888666868
#################
###Epoch: 28
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 28', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8412712787588438
Valid loss 0.7512251734733582
#################
###Epoch: 29
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 29', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8363219807545345
Valid loss 0.7527503818273544
#################
###Epoch: 30
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 30', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8351916258533796
Valid loss 0.7448699437081814
#################
###Epoch: 31
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 31', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8356490820646286
Valid loss 0.7479509748518467
#################
###Epoch: 32
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 32', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8401934539278348
Valid loss 0.7400018591433764
#################
###Epoch: 33
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 33', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8356031343340874
Valid loss 0.7443856801837683
#################
###Epoch: 34
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 34', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8352412258585294
Valid loss 0.7387781832367182
#################
###Epoch: 35
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 35', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8471806744734446
Valid loss 0.7452088501304388
#################
###Epoch: 36
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 36', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8336769769589106
Valid loss 0.7447696123272181
#################
###Epoch: 37
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 37', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.837757155795892
Valid loss 0.7438877522945404
#################
###Epoch: 38
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 38', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.841324637333552
Valid loss 0.7579204775393009
#################
###Epoch: 39
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 39', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.845101556678613
Valid loss 0.764548022300005
#################
###Epoch: 40
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 40', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8425566623608272
Valid loss 0.7417305931448936
#################
###Epoch: 41
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 41', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8374810487031936
Valid loss 0.7504175342619419
#################
###Epoch: 42
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 42', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8361220056811969
Valid loss 0.7408953923732042
#################
###Epoch: 43
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 43', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8408884773651759
Valid loss 0.7701321672648191
#################
###Epoch: 44
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 44', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8348153640826543
Valid loss 0.7396283131092787
#################
###Epoch: 45
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 45', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8285442878802617
Valid loss 0.7485960405319929
#################
###Epoch: 46
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 46', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8345409577091535
Valid loss 0.736707141622901
#################
###Epoch: 47
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 47', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.830645955602328
Valid loss 0.7460522521287203
#################
###Epoch: 48
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 48', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8383462076385816
Valid loss 0.7358430400490761
#################
###Epoch: 49
#################


HBox(children=(FloatProgress(value=0.0, description='Train Epoch 49', max=30.0, style=ProgressStyle(descriptio…


Train loss 0.8331598450740179
Valid loss 0.7385227922350168


# Test

In [36]:
public_df = test.query("seq_length == 107").copy()
private_df = test.query("seq_length == 130").copy()

public_inputs = preprocess_inputs(public_df)
private_inputs = preprocess_inputs(private_df)

public_inputs = torch.tensor(public_inputs, dtype=torch.long)
private_inputs = torch.tensor(private_inputs, dtype=torch.long)

In [37]:
model_short = Net(seq_len=107, pred_len=107)
model_long = Net(seq_len=130, pred_len=130)

list_public_preds = []
list_private_preds = []
for fold in range(config.n_split):
    model_short.load_state_dict(torch.load(f'{config.pretrain_dir}/gru_{fold}.pt'))
    model_long.load_state_dict(torch.load(f'{config.pretrain_dir}/gru_{fold}.pt'))
    model_short.cuda()
    model_long.cuda()
    model_short.eval()
    model_long.eval()

    public_preds = model_short(public_inputs.cuda())
    private_preds = model_long(private_inputs.cuda())
    public_preds = public_preds.cpu().detach().numpy()
    private_preds = private_preds.cpu().detach().numpy()
    
    list_public_preds.append(public_preds)
    list_private_preds.append(private_preds)

In [38]:
public_preds = np.mean(list_public_preds, axis=0)
private_preds = np.mean(list_private_preds, axis=0)

In [39]:
preds_ls = []

for df, preds in [(public_df, public_preds), (private_df, private_preds)]:
    for i, uid in enumerate(df.id):
        single_pred = preds[i]

        single_df = pd.DataFrame(single_pred, columns=pred_cols)
        single_df['id_seqpos'] = [f'{uid}_{x}' for x in range(single_df.shape[0])]

        preds_ls.append(single_df)

preds_df = pd.concat(preds_ls)

In [40]:
submission = sample_df[['id_seqpos']].merge(preds_df, on=['id_seqpos'])
submission.to_csv(f'{config.data_dir}/submission.csv', index=False)

In [None]:
# 0.6 - 0.28785
# 0.4 - 0.28630