In [34]:
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import json
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import plotly.express as px

import lightgbm as lgb
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
from torch.nn import functional as F
from torch.optim import lr_scheduler

In [35]:
class config:
    train_file = './stanford-covid-vaccine/train.json'
    test_file = './stanford-covid-vaccine/test.json'
    pretrain_dir = './baseline_model'
    sample_submission = './stanford-covid-vaccine/sample_submission.csv'
    learning_rate = 0.001
    batch_size = 64
    n_epoch = 200
    n_split = 5
    filter_noise = True
    patience= 10
    seed = 1234
    pooling_kernel = 3
    cnn_dropout_rate = 0.1

In [36]:
class AverageMeter:
    """
    Computes and stores the average and current value
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [37]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(config.seed)

## Load data

In [38]:
pred_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']
train = pd.read_json(config.train_file, lines=True)

if config.filter_noise:
    train = train[train.signal_to_noise > 1]
    
test = pd.read_json(config.test_file, lines=True)
sample_df = pd.read_csv(config.sample_submission)

In [52]:
def encode(data,dict_):
    one_hot_encoding = np.zeros([data.shape[0],len(data[0]),len(dict_)])
    for i in range(len(data)):
        d = data[i]
        for j in range(len(d)):
            idx = dict_[d[j]]
            one_hot_encoding[i,j,idx] = 1
    return one_hot_encoding


In [53]:
def one_hot_encoding(input_data):
    one_hot_encoding = []
    sequence_dict = {"A":0,"U":1,"C":2,"G":3}
    structure_dict = {"(":0,")":1,".":2}
    loop_dict = {"B":0,"E":1,"H":2,"I":3,"M":4,"S":5,"X":6}
    cols = ['sequence','structure', 'predicted_loop_type']
    for c in cols:
        data = input_data[c].values
        if c == "sequence":
            one_hot_encoding += [encode(data,sequence_dict)]
        elif c == "structure":
            one_hot_encoding += [encode(data,structure_dict)]
        else: 
            one_hot_encoding += [encode(data,loop_dict)]
    one_hot_encoding = np.concatenate(one_hot_encoding,axis=2)
    return one_hot_encoding

In [54]:
train_inputs = one_hot_encoding(train)
train_labels = np.array(train[pred_cols].values.tolist()).transpose((0, 2, 1))

train_inputs = torch.tensor(train_inputs, dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.float32)


## model 

In [56]:
class CNN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        
        self.kernel = kernel
        self.cnn = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, padding=kernel//2).cuda()
        self.norm = nn.LayerNorm(out_channels).cuda()
        self.pooling = nn.AvgPool1d(config.pooling_kernel,1,padding=config.pooling_kernel//2)
        self.dropout = nn.Dropout(p=config.cnn_dropout_rate).cuda()
        
    def forward(self, feature_embedding):
        f = feature_embedding.permute([0,2,1])
        f = self.cnn(f)
        f = F.leaky_relu(f)
        f = self.pooling(f)
        f = f.permute([0,2,1])
        f = self.norm(f)
        f = self.dropout(f)
        
        return f

In [57]:
class Net(nn.Module):
    def __init__(self, seq_len=107, pred_len=68):
        '''
        K: number of GCN layers
        aggregator: type of aggregator function
        '''
        super(Net, self).__init__()
        
        self.pred_len = pred_len
        self.CNN1 = CNN(in_channels=14,out_channels=128,kernel=5)
        self.CNN2 = CNN(in_channels=128,out_channels=64,kernel=9)
        self.CNN3 = CNN(in_channels=64,out_channels=32,kernel=17)
        self.linear_layer1 = nn.Linear(in_features=32, 
                                out_features=5)
        
    def forward(self, input_):
        cnn_output1 = self.CNN1(input_)
        cnn_output2 = self.CNN2(cnn_output1)
        cnn_output3 = self.CNN3(cnn_output2)
        truncated = cnn_output3[:, :self.pred_len,:]
        truncated = self.linear_layer1(truncated)
        
        return truncated


In [58]:
def train_fn(epoch, model, train_loader, criterion, optimizer):
    model.train()
    model.zero_grad()
    train_loss = AverageMeter()
    
    for index, (input_, label) in enumerate(train_loader):
        input_ = input_.cuda()
        label = label.cuda()
        preds = model(input_)
        
        loss = criterion(preds, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.update(loss.item())
    
    print(f"Train loss {train_loss.avg}")
    return train_loss.avg
    
def eval_fn(epoch, model, valid_loader, criterion):
    model.eval()
    eval_loss = AverageMeter()
    
    for index, (input_, label) in enumerate(valid_loader):
        input_ = input_.cuda()
        label = label.cuda()
        preds = model(input_)
        
        loss = criterion(preds, label)
        eval_loss.update(loss.item())
    
    print(f"Valid loss {eval_loss.avg}")
    return eval_loss.avg

In [59]:
def run(fold, train_loader, valid_loader):
    model = Net()
    model.cuda()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate, weight_decay=0.0)
    
    eval_loss_increase_step = 0
    train_losses = []
    eval_losses = []
    for epoch in range(config.n_epoch):
        print('#################')
        print('###Epoch:', epoch)
        
        train_loss = train_fn(epoch, model, train_loader, criterion, optimizer)
        eval_loss = eval_fn(epoch, model, valid_loader, criterion)
        train_losses.append(train_loss)
        eval_losses.append(eval_loss)
        
        # check if should early stop
        if len(eval_losses) == 2:
            previous_eval_loss = eval_losses[0]
        if len(eval_losses) >= 2:
            if eval_loss > previous_eval_loss:
                eval_loss_increase_step += 1
            else: 
                # save the model if it is better than previous step
                torch.save(model.state_dict(), f'{config.pretrain_dir}/baseline_{fold}.pt')
                eval_loss_increase_step = 0
                previous_eval_loss=eval_loss
                
            print("previous_eval_loss %s"%previous_eval_loss)
        if eval_loss_increase_step >= config.patience:
            print("early stop the model at Epoch: ", epoch)
            del model
            break
            
        
#     torch.save(model.state_dict(), f'{config.pretrain_dir}/gcn_gru_{fold}.pt')
    return train_losses, eval_losses

In [60]:
# train different aggregators
config.pretrain_dir = "baseline_model"
splits = KFold(n_splits=config.n_split, shuffle=True, random_state=config.seed).split(train_inputs)

for fold, (train_idx, val_idx) in enumerate(splits):
    train_dataset = TensorDataset(train_inputs[train_idx], train_labels[train_idx])
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8)

    valid_dataset = TensorDataset(train_inputs[val_idx], train_labels[val_idx])
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8)

    train_losses, eval_losses = run(fold, train_loader, valid_loader)


#################
###Epoch: 0
Train loss 0.29530018347280995
Valid loss 0.214090319616454
#################
###Epoch: 1
Train loss 0.2159131340406559
Valid loss 0.19492762642247335
previous_eval_loss 0.19492762642247335
#################
###Epoch: 2
Train loss 0.19862050811449686
Valid loss 0.18405300378799438
previous_eval_loss 0.18405300378799438
#################
###Epoch: 3
Train loss 0.18505494627687666
Valid loss 0.17527470205511367
previous_eval_loss 0.17527470205511367
#################
###Epoch: 4
Train loss 0.17449305565268905
Valid loss 0.16804164435182298
previous_eval_loss 0.16804164435182298
#################
###Epoch: 5
Train loss 0.1661762970465201
Valid loss 0.16858275021825517
previous_eval_loss 0.16804164435182298
#################
###Epoch: 6
Train loss 0.15856283516795547
Valid loss 0.15909463592938014
previous_eval_loss 0.15909463592938014
#################
###Epoch: 7
Train loss 0.1524307617434749
Valid loss 0.15309140299047744
previous_eval_loss 0.15309140299047

Train loss 0.07970534734151981
Valid loss 0.09438819863966533
previous_eval_loss 0.09378405553953988
#################
###Epoch: 64
Train loss 0.07923038138283624
Valid loss 0.0924643480351993
previous_eval_loss 0.0924643480351993
#################
###Epoch: 65
Train loss 0.07777676593374323
Valid loss 0.0933568509561675
previous_eval_loss 0.0924643480351993
#################
###Epoch: 66
Train loss 0.07898548300619479
Valid loss 0.09260428590433938
previous_eval_loss 0.0924643480351993
#################
###Epoch: 67
Train loss 0.07712562409815965
Valid loss 0.09248917869159154
previous_eval_loss 0.0924643480351993
#################
###Epoch: 68
Train loss 0.07683730125427246
Valid loss 0.09200175319399152
previous_eval_loss 0.09200175319399152
#################
###Epoch: 69
Train loss 0.07749684210176822
Valid loss 0.09230485877820424
previous_eval_loss 0.09200175319399152
#################
###Epoch: 70
Train loss 0.07741922195310946
Valid loss 0.09202673499073301
previous_eval_loss 0

Valid loss 0.08713833774839129
previous_eval_loss 0.08617001452616282
#################
###Epoch: 126
Train loss 0.06412568371053096
Valid loss 0.08728707156011037
previous_eval_loss 0.08617001452616282
#################
###Epoch: 127
Train loss 0.06444660546603026
Valid loss 0.08640284623418536
previous_eval_loss 0.08617001452616282
#################
###Epoch: 128
Train loss 0.06423243080024366
Valid loss 0.08686772946800504
previous_eval_loss 0.08617001452616282
#################
###Epoch: 129
Train loss 0.06490520898390699
Valid loss 0.086388641170093
previous_eval_loss 0.08617001452616282
early stop the model at Epoch:  129
#################
###Epoch: 0
Train loss 0.2941984396289896
Valid loss 0.19649753826005117
#################
###Epoch: 1
Train loss 0.22018871980684776
Valid loss 0.17895579125199998
previous_eval_loss 0.17895579125199998
#################
###Epoch: 2
Train loss 0.1994883120059967
Valid loss 0.1662836628300803
previous_eval_loss 0.1662836628300803
##############

Train loss 0.08320132808552848
Valid loss 0.0901557377406529
previous_eval_loss 0.0901557377406529
#################
###Epoch: 59
Train loss 0.08156494575518149
Valid loss 0.08857837638684682
previous_eval_loss 0.08857837638684682
#################
###Epoch: 60
Train loss 0.08127771179985117
Valid loss 0.08971371927431651
previous_eval_loss 0.08857837638684682
#################
###Epoch: 61
Train loss 0.0811072748016428
Valid loss 0.09026871940919332
previous_eval_loss 0.08857837638684682
#################
###Epoch: 62
Train loss 0.08063171693572292
Valid loss 0.08902498973267418
previous_eval_loss 0.08857837638684682
#################
###Epoch: 63
Train loss 0.0816981323339321
Valid loss 0.08969327062368393
previous_eval_loss 0.08857837638684682
#################
###Epoch: 64
Train loss 0.08045275509357452
Valid loss 0.08882168361118861
previous_eval_loss 0.08857837638684682
#################
###Epoch: 65
Train loss 0.07947513764655148
Valid loss 0.0882773186479296
previous_eval_loss 

Train loss 0.1235269973123515
Valid loss 0.11558129851307188
previous_eval_loss 0.11558129851307188
#################
###Epoch: 17
Train loss 0.12201060741036027
Valid loss 0.1130071569766317
previous_eval_loss 0.1130071569766317
#################
###Epoch: 18
Train loss 0.12127257193680163
Valid loss 0.11187517642974854
previous_eval_loss 0.11187517642974854
#################
###Epoch: 19
Train loss 0.11944886490150734
Valid loss 0.1123554642711367
previous_eval_loss 0.11187517642974854
#################
###Epoch: 20
Train loss 0.11817071338494618
Valid loss 0.1111358180642128
previous_eval_loss 0.1111358180642128
#################
###Epoch: 21
Train loss 0.11639671883097401
Valid loss 0.10690443643501826
previous_eval_loss 0.10690443643501826
#################
###Epoch: 22
Train loss 0.11499092589925837
Valid loss 0.1070186857666288
previous_eval_loss 0.10690443643501826
#################
###Epoch: 23
Train loss 0.11356512622700797
Valid loss 0.10570039706570762
previous_eval_loss 0.

Train loss 0.07608299398863758
Valid loss 0.08306723620210375
previous_eval_loss 0.08306723620210375
#################
###Epoch: 80
Train loss 0.07556196771286151
Valid loss 0.0831243848162038
previous_eval_loss 0.08306723620210375
#################
###Epoch: 81
Train loss 0.07640054628804878
Valid loss 0.08402522972651891
previous_eval_loss 0.08306723620210375
#################
###Epoch: 82
Train loss 0.07598800902013425
Valid loss 0.08268011254923684
previous_eval_loss 0.08268011254923684
#################
###Epoch: 83
Train loss 0.07640049799724861
Valid loss 0.08399064838886261
previous_eval_loss 0.08268011254923684
#################
###Epoch: 84
Train loss 0.07550817452095172
Valid loss 0.08432658974613462
previous_eval_loss 0.08268011254923684
#################
###Epoch: 85
Train loss 0.07633012947109011
Valid loss 0.08379442031894412
previous_eval_loss 0.08268011254923684
#################
###Epoch: 86
Train loss 0.07709282829805657
Valid loss 0.08329119426863534
previous_eval_l

Train loss 0.10281226480448688
Valid loss 0.10025336593389511
previous_eval_loss 0.10025336593389511
#################
###Epoch: 31
Train loss 0.10239476131068335
Valid loss 0.10125723055430821
previous_eval_loss 0.10025336593389511
#################
###Epoch: 32
Train loss 0.09990482804951845
Valid loss 0.09926779461758477
previous_eval_loss 0.09926779461758477
#################
###Epoch: 33
Train loss 0.09986246791150835
Valid loss 0.09986712038516998
previous_eval_loss 0.09926779461758477
#################
###Epoch: 34
Train loss 0.098957858979702
Valid loss 0.0990869871207646
previous_eval_loss 0.0990869871207646
#################
###Epoch: 35
Train loss 0.09795929408735699
Valid loss 0.0987225930605616
previous_eval_loss 0.0987225930605616
#################
###Epoch: 36
Train loss 0.09700653729615388
Valid loss 0.0966583000762122
previous_eval_loss 0.0966583000762122
#################
###Epoch: 37
Train loss 0.09734821678311736
Valid loss 0.09737756316150938
previous_eval_loss 0.0

Train loss 0.07004241341793979
Valid loss 0.08371191578251976
previous_eval_loss 0.08371191578251976
#################
###Epoch: 94
Train loss 0.0692575045481876
Valid loss 0.08401294159037727
previous_eval_loss 0.08371191578251976
#################
###Epoch: 95
Train loss 0.06899012349270008
Valid loss 0.0837104395031929
previous_eval_loss 0.0837104395031929
#################
###Epoch: 96
Train loss 0.06944032882650693
Valid loss 0.08683012106588908
previous_eval_loss 0.0837104395031929
#################
###Epoch: 97
Train loss 0.07021208452405753
Valid loss 0.08365948178938457
previous_eval_loss 0.08365948178938457
#################
###Epoch: 98
Train loss 0.06956089977864865
Valid loss 0.08444997348955699
previous_eval_loss 0.08365948178938457
#################
###Epoch: 99
Train loss 0.06932224278096799
Valid loss 0.0838201886841229
previous_eval_loss 0.08365948178938457
#################
###Epoch: 100
Train loss 0.06843194075756603
Valid loss 0.08433093449899129
previous_eval_loss

Valid loss 0.1124303298337119
previous_eval_loss 0.1114182812826974
#################
###Epoch: 24
Train loss 0.11426360077328152
Valid loss 0.10929876885243825
previous_eval_loss 0.10929876885243825
#################
###Epoch: 25
Train loss 0.11327121241225137
Valid loss 0.10864748912198204
previous_eval_loss 0.10864748912198204
#################
###Epoch: 26
Train loss 0.11152972143005442
Valid loss 0.10833208582230977
previous_eval_loss 0.10833208582230977
#################
###Epoch: 27
Train loss 0.11130127172779154
Valid loss 0.10836964100599289
previous_eval_loss 0.10833208582230977
#################
###Epoch: 28
Train loss 0.10973043298279797
Valid loss 0.10648371385676521
previous_eval_loss 0.10648371385676521
#################
###Epoch: 29
Train loss 0.10902470157102302
Valid loss 0.10619572337184634
previous_eval_loss 0.10619572337184634
#################
###Epoch: 30
Train loss 0.10862670700859141
Valid loss 0.10654668297086443
previous_eval_loss 0.10619572337184634
########

Train loss 0.07595304979218377
Valid loss 0.08650406450033188
previous_eval_loss 0.08555248592581068
#################
###Epoch: 87
Train loss 0.0751106865980007
Valid loss 0.08562626370361873
previous_eval_loss 0.08555248592581068
#################
###Epoch: 88
Train loss 0.07435175169397283
Valid loss 0.08586688339710236
previous_eval_loss 0.08555248592581068
#################
###Epoch: 89
Train loss 0.07532683428790835
Valid loss 0.0859811103769711
previous_eval_loss 0.08555248592581068
#################
###Epoch: 90
Train loss 0.0747007346815533
Valid loss 0.08639246331793922
previous_eval_loss 0.08555248592581068
#################
###Epoch: 91
Train loss 0.07399467847965381
Valid loss 0.08618186733552388
previous_eval_loss 0.08555248592581068
#################
###Epoch: 92
Train loss 0.07436263726817237
Valid loss 0.08725528525454658
previous_eval_loss 0.08555248592581068
#################
###Epoch: 93
Train loss 0.07444506124765785
Valid loss 0.08669852891138621
previous_eval_los

### prediction

In [66]:
public_df = test.query("seq_length == 107").copy()
private_df = test.query("seq_length == 130").copy()

public_inputs = one_hot_encoding(public_df)
private_inputs = one_hot_encoding(private_df)

public_inputs = torch.tensor(public_inputs, dtype=torch.float32)
private_inputs = torch.tensor(private_inputs, dtype=torch.float32)

In [67]:
model_short = Net(seq_len=107, pred_len=107)
model_long = Net(seq_len=130, pred_len=130)

list_public_preds = []
list_private_preds = []
for fold in range(5):
    config.pretrain_dir = "baseline_model"
    model_short.load_state_dict(torch.load(f'{config.pretrain_dir}/baseline_{fold}.pt'))
    model_long.load_state_dict(torch.load(f'{config.pretrain_dir}/baseline_{fold}.pt'))
    model_short.cuda()
    model_long.cuda()
    model_short.eval()
    model_long.eval()

    public_preds = model_short(public_inputs.cuda())
    private_preds = model_long(private_inputs.cuda())
    public_preds = public_preds.cpu().detach().numpy()
    private_preds = private_preds.cpu().detach().numpy()
    
    list_public_preds.append(public_preds)
    list_private_preds.append(private_preds)

In [68]:
public_preds = np.mean(list_public_preds, axis=0)
private_preds = np.mean(list_private_preds, axis=0)

In [69]:
preds_ls = []

for df, preds in [(public_df, public_preds), (private_df, private_preds)]:
    for i, uid in enumerate(df.id):
        single_pred = preds[i]

        single_df = pd.DataFrame(single_pred, columns=pred_cols)
        single_df['id_seqpos'] = [f'{uid}_{x}' for x in range(single_df.shape[0])]

        preds_ls.append(single_df)

preds_df = pd.concat(preds_ls)

In [70]:
submission = sample_df[['id_seqpos']].merge(preds_df, on=['id_seqpos'])
submission.to_csv('%s/submission.csv'%config.pretrain_dir, index=False)