Model: Convolution + Transfomer + GRU + Fine Tuning

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
from collections import OrderedDict
from fastprogress import progress_bar
from pathlib import Path
from sklearn.model_selection import train_test_split, ShuffleSplit
from torch import nn
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

import functools
import os
import pandas as pd
import random
import shutil
import torch
import torch.nn.functional as F


target_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']
input_cols = ['sequence', 'structure', 'predicted_loop_type']
error_cols = ['reactivity_error', 'deg_error_Mg_pH10', 'deg_error_Mg_50C', 'deg_error_pH10', 'deg_error_50C']

token_dicts = {
    "sequence": {x: i for i, x in enumerate("ACGU")},
    "structure": {x: i for i, x in enumerate('().')},
    "predicted_loop_type": {x: i for i, x in enumerate("BEHIMSX")}
}


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

### Loader


In [4]:
from sklearn.model_selection import train_test_split, ShuffleSplit
from torch import nn
from torch.utils.data import Dataset

import functools


BASE_PATH = "../../input/"
MODEL_SAVE_PATH = "./model"


def preprocess_inputs(df, cols):
    return np.concatenate([preprocess_feature_col(df, col) for col in cols], axis=2)


def preprocess_feature_col(df, col):
    dic = token_dicts[col]
    dic_len = len(dic)
    seq_length = len(df[col][0])
    ident = np.identity(dic_len)
    # convert to one hot
    arr = np.array(
        df[[col]].applymap(lambda seq: [ident[dic[x]] for x in seq]).values.tolist()
    ).squeeze(1)
    # shape: data_size x seq_length x dic_length
    assert arr.shape == (len(df), seq_length, dic_len)
    return arr


def preprocess(base_data, is_test=False):
    inputs = preprocess_inputs(base_data, input_cols)
    if is_test:
        labels = None
    else:
        labels = np.array(base_data[target_cols].values.tolist()).transpose((0, 2, 1))
        assert labels.shape[2] == len(target_cols)
    assert inputs.shape[2] == 14
    return inputs, labels


def get_bpp_feature(bpp):
    bpp_nb_mean = 0.077522  # mean of bpps_nb across all training data
    bpp_nb_std = 0.08914  # std of bpps_nb across all training data
    bpp_max = bpp.max(-1)[0]
    bpp_sum = bpp.sum(-1)
    bpp_nb = torch.true_divide((bpp > 0).sum(dim=1), bpp.shape[1])
    bpp_nb = torch.true_divide(bpp_nb - bpp_nb_mean, bpp_nb_std)
    return [bpp_max.unsqueeze(2), bpp_sum.unsqueeze(2), bpp_nb.unsqueeze(2)]


@functools.lru_cache(5000)
def load_from_id(id_):
    path = Path(BASE_PATH) / f"bpps/{id_}.npy"
    data = np.load(str(path))
    return data


def get_distance_matrix(leng):
    idx = np.arange(leng)
    Ds = []
    for i in range(len(idx)):
        d = np.abs(idx[i] - idx)
        Ds.append(d)

    Ds = np.array(Ds) + 1
    Ds = 1 / Ds
    Ds = Ds[None, :, :]
    Ds = np.repeat(Ds, 1, axis=0)

    Dss = []
    for i in [1, 2, 4]:
        Dss.append(Ds ** i)
    Ds = np.stack(Dss, axis=3)
    print(Ds.shape)
    return Ds


def get_structure_adj(df):
    Ss = []
    for i in range(len(df)):
        seq_length = df["seq_length"].iloc[i]
        structure = df["structure"].iloc[i]
        sequence = df["sequence"].iloc[i]

        cue = []
        a_structures = OrderedDict([
            (("A", "U"), np.zeros([seq_length, seq_length])),
            (("C", "G"), np.zeros([seq_length, seq_length])),
            (("U", "G"), np.zeros([seq_length, seq_length])),
            (("U", "A"), np.zeros([seq_length, seq_length])),
            (("G", "C"), np.zeros([seq_length, seq_length])),
            (("G", "U"), np.zeros([seq_length, seq_length])),
        ])
        for j in range(seq_length):
            if structure[j] == "(":
                cue.append(j)
            elif structure[j] == ")":
                start = cue.pop()
                a_structures[(sequence[start], sequence[j])][start, j] = 1
                a_structures[(sequence[j], sequence[start])][j, start] = 1

        a_strc = np.stack([a for a in a_structures.values()], axis=2)
        a_strc = np.sum(a_strc, axis=2, keepdims=True)
        Ss.append(a_strc)

    Ss = np.array(Ss)
    return Ss


def create_loader(df, batch_size=1, is_test=False):
    features, labels = preprocess(df, is_test)
    features_tensor = torch.from_numpy(features)
    if labels is not None:
        labels_tensor = torch.from_numpy(labels)
        dataset = VacDataset(features_tensor, df, labels_tensor)
        loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=False)
    else:
        dataset = VacDataset(features_tensor, df, None)
        loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=False, drop_last=False)
    return loader


class VacDataset(Dataset):
    def __init__(self, features, df, labels=None):
        self.features = features
        self.labels = labels
        self.test = labels is None
        self.ids = df["id"]
        self.score = None
        self.structure_adj = get_structure_adj(df)
        self.distance_matrix = get_distance_matrix(self.structure_adj.shape[1])
        if "score" in df.columns:
            self.score = df["score"]
        else:
            df["score"] = 1.0
            self.score = df["score"]
        self.signal_to_noise = None
        if not self.test:
            self.signal_to_noise = df["signal_to_noise"]
            assert self.features.shape[0] == self.labels.shape[0]
        else:
            assert self.ids is not None

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        bpp = torch.from_numpy(load_from_id(self.ids[index]).copy()).float()
        adj = self.structure_adj[index]
        distance = self.distance_matrix[0]
        bpp = np.concatenate([bpp[:, :, None], adj, distance], axis=2)
        if self.test:
            return dict(sequence=self.features[index].float(), bpp=bpp, ids=self.ids[index])
        else:
            return dict(sequence=self.features[index].float(), bpp=bpp,
                        label=self.labels[index], ids=self.ids[index],
                        signal_to_noise=self.signal_to_noise[index],
                        score=self.score[index])


### Model

In [5]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

import math


class Conv1dStack(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, dilation=1):
        super(Conv1dStack, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        h = self.res(x)
        return x + h


class Conv2dStack(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, dilation=1):
        super(Conv2dStack, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        h = self.res(x)
        return x + h


class SeqEncoder(nn.Module):
    def __init__(self, in_dim: int):
        super(SeqEncoder, self).__init__()
        self.conv0 = Conv1dStack(in_dim, 128, 3, padding=1)
        self.conv1 = Conv1dStack(128, 64, 6, padding=5, dilation=2)
        self.conv2 = Conv1dStack(64, 32, 15, padding=7, dilation=1)
        self.conv3 = Conv1dStack(32, 32, 30, padding=29, dilation=2)

    def forward(self, x):
        x1 = self.conv0(x)
        x2 = self.conv1(x1)
        x3 = self.conv2(x2)
        x4 = self.conv3(x3)
        x = torch.cat([x1, x2, x3, x4], dim=1)
        # x = x.permute(0, 2, 1).contiguous()
        # BATCH x 256 x seq_length
        return x


class BppAttn(nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super(BppAttn, self).__init__()
        self.conv0 = Conv1dStack(in_channel, out_channel, 3, padding=1)
        self.bpp_conv = Conv2dStack(5, out_channel)

    def forward(self, x, bpp):
        x = self.conv0(x)
        bpp = self.bpp_conv(bpp)
        # BATCH x C x SEQ x SEQ
        # BATCH x C x SEQ
        x = torch.matmul(bpp, x.unsqueeze(-1))
        return x.squeeze(-1)


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransformerWrapper(nn.Module):
    def __init__(self, dmodel=256, nhead=8, num_layers=2):
        super(TransformerWrapper, self).__init__()
        self.pos_encoder = PositionalEncoding(256)
        encoder_layer = TransformerEncoderLayer(d_model=dmodel, nhead=nhead)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers)
        self.pos_emb = PositionalEncoding(dmodel)

    def flatten_parameters(self):
        pass

    def forward(self, x):
        x = x.permute((1, 0, 2)).contiguous()
        x = self.pos_emb(x)
        x = self.transformer_encoder(x)
        x = x.permute((1, 0, 2)).contiguous()
        return x, None


class RnnLayers(nn.Module):
    def __init__(self, dmodel, dropout=0.3, transformer_layers: int = 2):
        super(RnnLayers, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.rnn0 = TransformerWrapper(dmodel, nhead=8, num_layers=transformer_layers)
        self.rnn1 = nn.LSTM(dmodel, dmodel // 2, batch_first=True, num_layers=1, bidirectional=True)
        self.rnn2 = nn.GRU(dmodel, dmodel // 2, batch_first=True, num_layers=1, bidirectional=True)

    def forward(self, x):
        self.rnn0.flatten_parameters()
        x, _ = self.rnn0(x)
        if self.rnn1 is not None:
            self.rnn1.flatten_parameters()
            x = self.dropout(x)
            x, _ = self.rnn1(x)
        if self.rnn2 is not None:
            self.rnn2.flatten_parameters()
            x = self.dropout(x)
            x, _ = self.rnn2(x)
        return x

    
class BaseAttnModel(nn.Module):
    def __init__(self, transformer_layers: int = 2):
        super(BaseAttnModel, self).__init__()
        self.linear0 = nn.Linear(14 + 3, 1)
        self.seq_encoder_x = SeqEncoder(18)
        self.attn = BppAttn(256, 128)
        self.seq_encoder_bpp = SeqEncoder(128)
        self.seq = RnnLayers(256 * 2, dropout=0.3,
                             transformer_layers=transformer_layers)

    def forward(self, x, bpp):
        bpp_features = get_bpp_feature(bpp[:, :, :, 0].float())
        x = torch.cat([x] + bpp_features, dim=-1)
        learned = self.linear0(x)
        x = torch.cat([x, learned], dim=-1)
        x = x.permute(0, 2, 1).contiguous().float()
        # BATCH x 18 x seq_len
        bpp = bpp.permute([0, 3, 1, 2]).contiguous().float()
        # BATCH x 5 x seq_len x seq_len
        x = self.seq_encoder_x(x)
        # BATCH x 256 x seq_len
        bpp = self.attn(x, bpp)
        bpp = self.seq_encoder_bpp(bpp)
        # BATCH x 256 x seq_len
        x = x.permute(0, 2, 1).contiguous()
        # BATCH x seq_len x 256
        bpp = bpp.permute(0, 2, 1).contiguous()
        # BATCH x seq_len x 256
        x = torch.cat([x, bpp], dim=2)
        # BATCH x seq_len x 512
        x = self.seq(x)
        return x


class AEModel(nn.Module):
    def __init__(self, transformer_layers: int = 2):
        super(AEModel, self).__init__()
        self.seq = BaseAttnModel(transformer_layers=transformer_layers)
        self.linear = nn.Sequential(
            nn.Linear(256 * 2, 14),
            nn.Sigmoid(),
        )

    def forward(self, x, bpp):
        x = self.seq(x, bpp)
        x = F.dropout(x, p=0.3)
        x = self.linear(x)
        return x


class FromAeModel(nn.Module):
    def __init__(self, seq, pred_len=68, dmodel: int = 256):
        super(FromAeModel, self).__init__()
        self.seq = seq
        self.pred_len = pred_len
        self.linear = nn.Sequential(
            nn.Linear(dmodel * 2, len(target_cols)),
        )

    def forward(self, x, bpp):
        x = self.seq(x, bpp)
        x = self.linear(x)
        x = x[:, :self.pred_len]
        return x


In [6]:
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
base_train_data.head()

device = torch.device('cuda')
BATCH_SIZE = 64
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
base_test_data = pd.read_json(str(Path(BASE_PATH) / 'test.json'), lines=True)
public_df = base_test_data.query("seq_length == 107").copy()
private_df = base_test_data.query("seq_length == 130").copy()
print(f"public_df: {public_df.shape}")
print(f"private_df: {private_df.shape}")
public_df = public_df.reset_index()
private_df = private_df.reset_index()

features, _ = preprocess(base_train_data, True)
features_tensor = torch.from_numpy(features)
dataset0 = VacDataset(features_tensor, base_train_data, None)
features, _ = preprocess(public_df, True)
features_tensor = torch.from_numpy(features)
dataset1 = VacDataset(features_tensor, public_df, None)
features, _ = preprocess(private_df, True)
features_tensor = torch.from_numpy(features)
dataset2 = VacDataset(features_tensor, private_df, None)

loader0 = torch.utils.data.DataLoader(dataset0, BATCH_SIZE, shuffle=False, drop_last=False)
loader1 = torch.utils.data.DataLoader(dataset1, BATCH_SIZE, shuffle=False, drop_last=False)
loader2 = torch.utils.data.DataLoader(dataset2, BATCH_SIZE, shuffle=False, drop_last=False)

public_df: (629, 7)
private_df: (3005, 7)
(1, 107, 107, 3)
(1, 107, 107, 3)
(1, 130, 130, 3)


### Pretrain

In [7]:
def learn_from_batch_ae(model, data, device):
    seq = data["sequence"].clone()
    seq[:, :, :14] = F.dropout2d(seq[:, :, :14], p=0.3)
    target = data["sequence"][:, :, :14]
    out = model(seq.to(device), data["bpp"].to(device))
    loss = F.binary_cross_entropy(out, target.to(device))
    return loss


def train_ae(model, train_data, optimizer, lr_scheduler, epochs=10, device="cpu",
             start_epoch: int = 0, start_it: int = 0, log_path: str = "./logs"):
    print(f"device: {device}")
    losses = []
    it = start_it
    model_save_path = Path(MODEL_SAVE_PATH)
    start_epoch = start_epoch
    end_epoch = start_epoch + epochs
    min_loss = 10.0
    min_loss_epoch = 0
    if not model_save_path.exists():
        model_save_path.mkdir(parents=True)
    for epoch in progress_bar(range(start_epoch, end_epoch)):
        print(f"epoch: {epoch}")
        model.train()
        for i, data in enumerate(train_data):
            optimizer.zero_grad()
            loss = learn_from_batch_ae(model, data, device)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()
            loss_v = loss.item()
            losses.append(loss_v)
            it += 1
        loss_m = np.mean(losses)
        if loss_m < min_loss:
            min_loss_epoch = epoch
            min_loss = loss_m
        print(f'epoch: {epoch} loss: {loss_m}')
        losses = []
        torch.save(optimizer.state_dict(), str(model_save_path / "optimizer.pt"))
        torch.save(model.state_dict(), str(model_save_path / f"model-{epoch}.pt"))
    return dict(end_epoch=end_epoch, it=it, min_loss_epoch=min_loss_epoch)


In [8]:
import shutil


set_seed(123)
shutil.rmtree("./model", True)
shutil.rmtree("./logs", True)
save_path = Path("./model_prediction")
if not save_path.exists():
    save_path.mkdir(parents=True)

lr_scheduler = None
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AEModel()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
res = dict(end_epoch=0, it=0, min_loss_epoch=0)
epochs = [5, 5, 5, 5]
for e in epochs:
    res = train_ae(model, loader0, optimizer, lr_scheduler, e, device=device,
                   start_epoch=res["end_epoch"], start_it=res["it"])
    res = train_ae(model, loader1, optimizer, lr_scheduler, e, device=device,
                   start_epoch=res["end_epoch"], start_it=res["it"])
    res = train_ae(model, loader2, optimizer, lr_scheduler, e, device=device,
                   start_epoch=res["end_epoch"], start_it=res["it"])

epoch = res["min_loss_epoch"]
shutil.copyfile(str(Path(MODEL_SAVE_PATH) / f"model-{epoch}.pt"), "ae-model.pt")

device: cuda


epoch: 0
epoch: 0 loss: 0.31882223722181824
epoch: 1
epoch: 1 loss: 0.17517408101182236
epoch: 2
epoch: 2 loss: 0.1470378245178022
epoch: 3
epoch: 3 loss: 0.12168800909268229
epoch: 4
epoch: 4 loss: 0.09409924380873379
device: cuda


epoch: 5
epoch: 5 loss: 0.07560041099786759
epoch: 6
epoch: 6 loss: 0.06614229753613472
epoch: 7
epoch: 7 loss: 0.05657824091613293
epoch: 8
epoch: 8 loss: 0.04485596902668476
epoch: 9
epoch: 9 loss: 0.038395919278264044
device: cuda


epoch: 10
epoch: 10 loss: 0.04017001378567929
epoch: 11
epoch: 11 loss: 0.025475389621359236
epoch: 12
epoch: 12 loss: 0.021186316306603715
epoch: 13
epoch: 13 loss: 0.018733715202580107
epoch: 14
epoch: 14 loss: 0.017134580543225118
device: cuda


epoch: 15
epoch: 15 loss: 0.012958621165077937
epoch: 16
epoch: 16 loss: 0.010914026048818701
epoch: 17
epoch: 17 loss: 0.010145704825653842
epoch: 18
epoch: 18 loss: 0.009687816797706642
epoch: 19
epoch: 19 loss: 0.008968610916972944
device: cuda


epoch: 20
epoch: 20 loss: 0.008991648443043232
epoch: 21
epoch: 21 loss: 0.008263317309319973
epoch: 22
epoch: 22 loss: 0.007726962631568313
epoch: 23
epoch: 23 loss: 0.007300319010391831
epoch: 24
epoch: 24 loss: 0.007323366636410355
device: cuda


epoch: 25
epoch: 25 loss: 0.014351425811331323
epoch: 26
epoch: 26 loss: 0.012842220769442142
epoch: 27
epoch: 27 loss: 0.012268878638427308
epoch: 28
epoch: 28 loss: 0.011755118007831116
epoch: 29
epoch: 29 loss: 0.011265068057369678
device: cuda


epoch: 30
epoch: 30 loss: 0.007474323563081653
epoch: 31
epoch: 31 loss: 0.007744161289577421
epoch: 32
epoch: 32 loss: 0.00765139037302058
epoch: 33
epoch: 33 loss: 0.006985742488483849
epoch: 34
epoch: 34 loss: 0.006537947276803224
device: cuda


epoch: 35
epoch: 35 loss: 0.005245428392663598
epoch: 36
epoch: 36 loss: 0.005438948096707463
epoch: 37
epoch: 37 loss: 0.004695903789252043
epoch: 38
epoch: 38 loss: 0.004718615929596126
epoch: 39
epoch: 39 loss: 0.0048602860886603596
device: cuda


epoch: 40
epoch: 40 loss: 0.011894844076100817
epoch: 41
epoch: 41 loss: 0.01043161383255365
epoch: 42
epoch: 42 loss: 0.010052733361086946
epoch: 43
epoch: 43 loss: 0.009701506531936056
epoch: 44
epoch: 44 loss: 0.009377883667958544
device: cuda


epoch: 45
epoch: 45 loss: 0.006192560638546159
epoch: 46
epoch: 46 loss: 0.005543965356130349
epoch: 47
epoch: 47 loss: 0.005970509686066132
epoch: 48
epoch: 48 loss: 0.0059537784760131645
epoch: 49
epoch: 49 loss: 0.005761238262302389
device: cuda


epoch: 50
epoch: 50 loss: 0.004520050436258316
epoch: 51
epoch: 51 loss: 0.004751157085411251
epoch: 52
epoch: 52 loss: 0.004272466991096735
epoch: 53
epoch: 53 loss: 0.004051221394911408
epoch: 54
epoch: 54 loss: 0.004011103441007436
device: cuda


epoch: 55
epoch: 55 loss: 0.01005062203299492
epoch: 56
epoch: 56 loss: 0.008941334375041597
epoch: 57
epoch: 57 loss: 0.008624718960453855
epoch: 58
epoch: 58 loss: 0.008248806940986122
epoch: 59
epoch: 59 loss: 0.00812909893810432


'ae-model.pt'

### Training

In [9]:
def MCRMSE(y_true, y_pred):
    colwise_mse = torch.mean(torch.square(y_true - y_pred), dim=1)
    return torch.mean(torch.sqrt(colwise_mse), dim=1)


def sn_mcrmse_loss(predict, target, signal_to_noise):
    loss = MCRMSE(target, predict)
    weight = 0.5 * torch.log(signal_to_noise + 1.01)
    loss = (loss * weight).mean()
    return loss


def learn_from_batch(model, data, optimizer, lr_scheduler, device):
    optimizer.zero_grad()
    out = model(data["sequence"].to(device), data["bpp"].to(device))
    signal_to_noise = data["signal_to_noise"] * data["score"]
    loss = sn_mcrmse_loss(out, data["label"].to(device), signal_to_noise.to(device))
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if lr_scheduler:
        lr_scheduler.step()
    return out, loss


def evaluate(model, valid_data, device):
    model.eval()
    loss_list = []
    mcrmse = []
    for i, data in enumerate(valid_data):
        with torch.no_grad():
            y = model(data["sequence"].to(device), data["bpp"].to(device))
            mcrmse_ = MCRMSE(data["label"].to(device), y)[data["signal_to_noise"] > 1]
            mcrmse.append(mcrmse_.mean().item())
            loss = sn_mcrmse_loss(y, data["label"].to(device), data["signal_to_noise"].to(device))
            loss_list.append(loss.item())
    model.train()
    return dict(loss=np.mean(loss_list), mcmse=np.mean(mcrmse))


def train(model, train_data, valid_data, optimizer, lr_scheduler, epochs=10, device="cpu",
          start_epoch: int = 0, log_path: str = "./logs"):
    print(f"device: {device}")
    losses = []
    writer = SummaryWriter(log_path)
    it = 0
    model_save_path = Path(MODEL_SAVE_PATH)
    start_epoch = start_epoch
    end_epoch = start_epoch + epochs
    if not model_save_path.exists():
        model_save_path.mkdir(parents=True)
    min_eval_loss = 10.0
    min_eval_epoch = None
    for epoch in progress_bar(range(start_epoch, end_epoch)):
        print(f"epoch: {epoch}")
        model.train()
        for i, data in enumerate(train_data):
            _, loss = learn_from_batch(model, data, optimizer, lr_scheduler, device)
            loss_v = loss.item()
            writer.add_scalar('loss', loss_v, it)
            losses.append(loss_v)
            it += 1
        print(f'epoch: {epoch} loss: {np.mean(losses)}')
        losses = []

        eval_result = evaluate(model, valid_data, device)
        eval_loss = eval_result["loss"]
        if eval_loss <= min_eval_loss:
            min_eval_epoch = epoch
            min_eval_loss = eval_loss

        print(f"eval loss: {eval_loss} {eval_result['mcmse']}")
        writer.add_scalar(f"evaluate/loss", eval_loss, epoch)
        writer.add_scalar(f"evaluate/mcmse", eval_result["mcmse"], epoch)
        model.train()
        torch.save(optimizer.state_dict(), str(model_save_path / "optimizer.pt"))
        torch.save(model.state_dict(), str(model_save_path / f"model-{epoch}.pt"))
    print(f'min eval loss: {min_eval_loss} epoch {min_eval_epoch}')
    return min_eval_epoch


In [10]:
device = torch.device('cuda')
BATCH_SIZE = 64
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
samples = base_train_data
save_path = Path("./model_prediction")
if not save_path.exists():
    save_path.mkdir(parents=True)
shutil.rmtree("./model", True)
shutil.rmtree("./logs", True)
split = ShuffleSplit(n_splits=5, test_size=.1)
ids = samples.reset_index()["id"]
set_seed(124)
for fold, (train_index, test_index) in enumerate(split.split(samples)):
    print(f"fold: {fold}")
    train_df = samples.loc[train_index].reset_index()
    val_df = samples.loc[test_index].reset_index()
    train_loader = create_loader(train_df, BATCH_SIZE)
    valid_loader = create_loader(val_df, BATCH_SIZE)
    print(train_df.shape, val_df.shape)
    ae_model = AEModel()
    state_dict = torch.load("./ae-model.pt")
    ae_model.load_state_dict(state_dict)
    del state_dict
    model = FromAeModel(ae_model.seq)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    lr_scheduler = None
    epoch = train(model, train_loader, valid_loader, optimizer, lr_scheduler, 200, device=device,
                  log_path=f"logs/{fold}")
    shutil.copyfile(str(Path(MODEL_SAVE_PATH) / f"./model-{epoch}.pt"), f"model_prediction/model-{fold}.pt")
    del model

fold: 0
(1, 107, 107, 3)
(1, 107, 107, 3)
(2160, 21) (240, 21)
device: cuda


epoch: 0
epoch: 0 loss: 0.3177025982349815
eval loss: 0.2510703484812935 0.3287137100886572
epoch: 1
epoch: 1 loss: 0.24563885382637493
eval loss: 0.2287785113395545 0.3009757565183342
epoch: 2
epoch: 2 loss: 0.23057666266640534
eval loss: 0.22546030916546672 0.29468193855184716
epoch: 3
epoch: 3 loss: 0.21659198751390918
eval loss: 0.20454477929404374 0.2690668367534538
epoch: 4
epoch: 4 loss: 0.20606597615219677
eval loss: 0.20178344382576294 0.2650432708153864
epoch: 5
epoch: 5 loss: 0.20023093245721416
eval loss: 0.19806548305731486 0.26189495976311183
epoch: 6
epoch: 6 loss: 0.19658712018955077
eval loss: 0.1870464674612002 0.24565475595073166
epoch: 7
epoch: 7 loss: 0.1881081432914251
eval loss: 0.1816334157305413 0.23914802028395746
epoch: 8
epoch: 8 loss: 0.18550608696093643
eval loss: 0.18204575756903485 0.23930442010862174
epoch: 9
epoch: 9 loss: 0.18330566619801714
eval loss: 0.18400677076015065 0.2419307163798818
epoch: 10
epoch: 10 loss: 0.18029594829978352
eval loss: 0.17

eval loss: 0.15492578005292415 0.20340259341490702
epoch: 86
epoch: 86 loss: 0.12847896421982147
eval loss: 0.16091540570876797 0.21009309214050764
epoch: 87
epoch: 87 loss: 0.12828724874426767
eval loss: 0.15767866654147428 0.20663813011211785
epoch: 88
epoch: 88 loss: 0.1275664209269087
eval loss: 0.15730446158435152 0.2068249756484499
epoch: 89
epoch: 89 loss: 0.12826604463493982
eval loss: 0.15989844692239435 0.20833178796074558
epoch: 90
epoch: 90 loss: 0.12796674088198867
eval loss: 0.15515885377612626 0.20328277038573395
epoch: 91
epoch: 91 loss: 0.12468228537422413
eval loss: 0.15676650936823205 0.20714445434818046
epoch: 92
epoch: 92 loss: 0.1262776665222465
eval loss: 0.15641553289764865 0.2059494350461708
epoch: 93
epoch: 93 loss: 0.13784769376510075
eval loss: 0.16626289980927905 0.21857444837165074
epoch: 94
epoch: 94 loss: 0.15416742917610127
eval loss: 0.1705702040580433 0.22456175212470886
epoch: 95
epoch: 95 loss: 0.14806580291329915
eval loss: 0.16694587839969424 0.21

epoch: 169 loss: 0.1141077903444973
eval loss: 0.15743398645947293 0.2056995662840674
epoch: 170
epoch: 170 loss: 0.11399681107167733
eval loss: 0.1542707148992438 0.20341223273627076
epoch: 171
epoch: 171 loss: 0.11463539888555346
eval loss: 0.1563439257904387 0.20516845529578576
epoch: 172
epoch: 172 loss: 0.11386108671624154
eval loss: 0.15645827237259477 0.20479719027427531
epoch: 173
epoch: 173 loss: 0.11424778806341529
eval loss: 0.15400441509969842 0.20195190549372632
epoch: 174
epoch: 174 loss: 0.11323846967814503
eval loss: 0.15374255771931883 0.20240041365017292
epoch: 175
epoch: 175 loss: 0.11547990733971217
eval loss: 0.15607957838527986 0.20571733376800114
epoch: 176
epoch: 176 loss: 0.11378783384117254
eval loss: 0.15304332984100102 0.2015166864785688
epoch: 177
epoch: 177 loss: 0.11209179857350769
eval loss: 0.15489982898876375 0.20388073188506836
epoch: 178
epoch: 178 loss: 0.11358491009621463
eval loss: 0.15461480759075796 0.20343887442335706
epoch: 179
epoch: 179 loss

epoch: 0
epoch: 0 loss: 0.3054855206325494
eval loss: 0.25769360123239127 0.3300628060130335
epoch: 1
epoch: 1 loss: 0.2413718787154986
eval loss: 0.2419858265456819 0.31117819999826524
epoch: 2
epoch: 2 loss: 0.22522085229402997
eval loss: 0.22681292829322694 0.29118138422563516
epoch: 3
epoch: 3 loss: 0.21387706057561193
eval loss: 0.22201741891766108 0.2852308380445783
epoch: 4
epoch: 4 loss: 0.2024705836797799
eval loss: 0.21063681220576547 0.2705374026266576
epoch: 5
epoch: 5 loss: 0.19693335740011414
eval loss: 0.2033648627407909 0.2618481143777419
epoch: 6
epoch: 6 loss: 0.188311746939162
eval loss: 0.19332927554389776 0.24971250405003265
epoch: 7
epoch: 7 loss: 0.1865912144170365
eval loss: 0.2017151924825308 0.25956201485404373
epoch: 8
epoch: 8 loss: 0.18045201919428494
eval loss: 0.19186812509289386 0.24718836090655613
epoch: 9
epoch: 9 loss: 0.17649111584602378
eval loss: 0.19669294307613813 0.2528875036970701
epoch: 10
epoch: 10 loss: 0.17839324019454028
eval loss: 0.19178

epoch: 86
epoch: 86 loss: 0.12239971653233427
eval loss: 0.16707852282840857 0.21616548334950403
epoch: 87
epoch: 87 loss: 0.12611623982008002
eval loss: 0.1673228185756296 0.21672411578774606
epoch: 88
epoch: 88 loss: 0.1273441484412895
eval loss: 0.16779337371546285 0.21606270805094097
epoch: 89
epoch: 89 loss: 0.13376555165779033
eval loss: 0.17127811716845925 0.22047118185536457
epoch: 90
epoch: 90 loss: 0.13318865664626908
eval loss: 0.17230273672797528 0.22171626536177128
epoch: 91
epoch: 91 loss: 0.13417260630395905
eval loss: 0.16938732252304814 0.21990947132696875
epoch: 92
epoch: 92 loss: 0.1313510650292233
eval loss: 0.1675652440783666 0.21589845270594105
epoch: 93
epoch: 93 loss: 0.12801555083436122
eval loss: 0.16724243055301097 0.21508578374028509
epoch: 94
epoch: 94 loss: 0.1254633551740774
eval loss: 0.16738400604780196 0.2162654675812112
epoch: 95
epoch: 95 loss: 0.12499191985240334
eval loss: 0.16593737249126458 0.21442890357197633
epoch: 96
epoch: 96 loss: 0.12413592

epoch: 170
epoch: 170 loss: 0.11278259521690952
eval loss: 0.16181182065574964 0.20838081022610577
epoch: 171
epoch: 171 loss: 0.11165223869918012
eval loss: 0.16138338174892056 0.20837550905443974
epoch: 172
epoch: 172 loss: 0.11229869883077094
eval loss: 0.16313655710484892 0.2095956839828252
epoch: 173
epoch: 173 loss: 0.11118144483553037
eval loss: 0.1612194015926287 0.20755257906730876
epoch: 174
epoch: 174 loss: 0.1107848191188282
eval loss: 0.16335961623817888 0.209883174551388
epoch: 175
epoch: 175 loss: 0.11033138134181603
eval loss: 0.16352262888352678 0.21119475642261631
epoch: 176
epoch: 176 loss: 0.11003748035100071
eval loss: 0.16191249462721333 0.20915016889647742
epoch: 177
epoch: 177 loss: 0.1090644715289564
eval loss: 0.162725744753771 0.2096144726965747
epoch: 178
epoch: 178 loss: 0.10952701537444008
eval loss: 0.1607369249466112 0.2068994294275585
epoch: 179
epoch: 179 loss: 0.10895657226555862
eval loss: 0.16338960320696302 0.210375966459231
epoch: 180
epoch: 180 l

epoch: 0
epoch: 0 loss: 0.30028902517858846
eval loss: 0.2525042113902186 0.3310614119526397
epoch: 1
epoch: 1 loss: 0.23687105811737186
eval loss: 0.22008011670819866 0.2898621765616925
epoch: 2
epoch: 2 loss: 0.21883635615547234
eval loss: 0.20660597977042627 0.2709269339257027
epoch: 3
epoch: 3 loss: 0.20407692191086738
eval loss: 0.19498832477751263 0.2554000927881096
epoch: 4
epoch: 4 loss: 0.1965839542587805
eval loss: 0.19018200327435564 0.2492272403870344
epoch: 5
epoch: 5 loss: 0.18764947873962634
eval loss: 0.18669771343164376 0.24418062508978425
epoch: 6
epoch: 6 loss: 0.18334978629241416
eval loss: 0.1837542075466725 0.2400235332688355
epoch: 7
epoch: 7 loss: 0.17782650777138048
eval loss: 0.18341805810562223 0.23917824902421622
epoch: 8
epoch: 8 loss: 0.17322014558755533
eval loss: 0.17458898788448315 0.22931343760435935
epoch: 9
epoch: 9 loss: 0.17076027158878962
eval loss: 0.17459188534355952 0.2285700423650372
epoch: 10
epoch: 10 loss: 0.1664903189394379
eval loss: 0.17

eval loss: 0.15650687414244924 0.20435867977888783
epoch: 86
epoch: 86 loss: 0.11490919325945878
eval loss: 0.15971468602130234 0.20645453873337444
epoch: 87
epoch: 87 loss: 0.11375979147955535
eval loss: 0.15660803789775582 0.20465233950893325
epoch: 88
epoch: 88 loss: 0.11358244317593517
eval loss: 0.15745178109012387 0.20516554602049358
epoch: 89
epoch: 89 loss: 0.11288989943425005
eval loss: 0.15735063453611675 0.2040476247855708
epoch: 90
epoch: 90 loss: 0.11300301732275775
eval loss: 0.15685588915416712 0.20568329127462284
epoch: 91
epoch: 91 loss: 0.11300895320877764
eval loss: 0.15649517240972688 0.2034450840464221
epoch: 92
epoch: 92 loss: 0.11206709421988638
eval loss: 0.15534159492205152 0.2018638533251303
epoch: 93
epoch: 93 loss: 0.11122857606775596
eval loss: 0.15564652617436509 0.20256313507560683
epoch: 94
epoch: 94 loss: 0.11137903726095082
eval loss: 0.15652086157475278 0.20407254969409383
epoch: 95
epoch: 95 loss: 0.11051816400476337
eval loss: 0.15627301690659925 0.

eval loss: 0.15545171346774567 0.20324814765803306
epoch: 170
epoch: 170 loss: 0.09593199376534525
eval loss: 0.1561106720411623 0.20302005965128506
epoch: 171
epoch: 171 loss: 0.09517651864896427
eval loss: 0.1550259984441012 0.2025128243962826
epoch: 172
epoch: 172 loss: 0.09466885866602652
eval loss: 0.155594916235477 0.20311654883053873
epoch: 173
epoch: 173 loss: 0.09447987550292053
eval loss: 0.15639214223789025 0.2025919737938059
epoch: 174
epoch: 174 loss: 0.09512176274185259
eval loss: 0.15686361383856035 0.2042297523994517
epoch: 175
epoch: 175 loss: 0.09486763831595253
eval loss: 0.15541224429264594 0.2027233830862779
epoch: 176
epoch: 176 loss: 0.0948453606916006
eval loss: 0.15656778173408292 0.20299023387790482
epoch: 177
epoch: 177 loss: 0.09495006897576305
eval loss: 0.1571044663698444 0.20503973845575243
epoch: 178
epoch: 178 loss: 0.09464857372390202
eval loss: 0.154981467975505 0.2017085950730032
epoch: 179
epoch: 179 loss: 0.0953792946807221
eval loss: 0.15712225842

epoch: 0
epoch: 0 loss: 0.29977360992230323
eval loss: 0.25212849738421356 0.3354790844471585
epoch: 1
epoch: 1 loss: 0.23853620675615356
eval loss: 0.22299837915850113 0.2967533751137255
epoch: 2
epoch: 2 loss: 0.2196382435665687
eval loss: 0.20750444293667392 0.279282824661696
epoch: 3
epoch: 3 loss: 0.20770664801008748
eval loss: 0.20174694085265674 0.2688169017101762
epoch: 4
epoch: 4 loss: 0.20149138981231798
eval loss: 0.1905146383546944 0.25476979985841663
epoch: 5
epoch: 5 loss: 0.19249608018243916
eval loss: 0.18634227969327202 0.24865106590886404
epoch: 6
epoch: 6 loss: 0.18624235470343442
eval loss: 0.17864474909457978 0.2362854254962605
epoch: 7
epoch: 7 loss: 0.18197209982388773
eval loss: 0.17770131549681425 0.23607282745148786
epoch: 8
epoch: 8 loss: 0.17850847313286
eval loss: 0.17441276970011516 0.23170260891167335
epoch: 9
epoch: 9 loss: 0.17633426254338475
eval loss: 0.17166729401229922 0.2287093898820861
epoch: 10
epoch: 10 loss: 0.17154680726959015
eval loss: 0.178

eval loss: 0.14955133271985277 0.19823430092364358
epoch: 86
epoch: 86 loss: 0.11699785391059853
eval loss: 0.15132520905394542 0.20068897724303755
epoch: 87
epoch: 87 loss: 0.11660226066416712
eval loss: 0.15075155925931516 0.20010237964262875
epoch: 88
epoch: 88 loss: 0.1168705101845373
eval loss: 0.15379786686390834 0.2040539227443519
epoch: 89
epoch: 89 loss: 0.11598181716986665
eval loss: 0.15047164949358044 0.20032294555950966
epoch: 90
epoch: 90 loss: 0.11481088420219593
eval loss: 0.1499914549374555 0.19895897475205646
epoch: 91
epoch: 91 loss: 0.11526993841438302
eval loss: 0.15281356377302058 0.20279113759373496
epoch: 92
epoch: 92 loss: 0.11458798357291819
eval loss: 0.151720512298045 0.20176637015314025
epoch: 93
epoch: 93 loss: 0.113959899214674
eval loss: 0.15141309661870034 0.20028507247938795
epoch: 94
epoch: 94 loss: 0.11407840311329438
eval loss: 0.1507100504136632 0.20175896987874078
epoch: 95
epoch: 95 loss: 0.11298692041119537
eval loss: 0.15239793443570934 0.20268

eval loss: 0.14882843840902119 0.19822818191034736
epoch: 170
epoch: 170 loss: 0.09734573747334399
eval loss: 0.14965985695874504 0.19978250766236327
epoch: 171
epoch: 171 loss: 0.09707226511488418
eval loss: 0.14971351680815465 0.1993344225799143
epoch: 172
epoch: 172 loss: 0.09742792173535238
eval loss: 0.14890432725476715 0.19850493034743186
epoch: 173
epoch: 173 loss: 0.09758654029116513
eval loss: 0.14996425532753732 0.1992141406023239
epoch: 174
epoch: 174 loss: 0.09721183405231547
eval loss: 0.14879124303432403 0.1982729123116503
epoch: 175
epoch: 175 loss: 0.09562203654736406
eval loss: 0.1492875549470165 0.19810663552988783
epoch: 176
epoch: 176 loss: 0.09725792177930108
eval loss: 0.15170968091003673 0.20078174374382327
epoch: 177
epoch: 177 loss: 0.09614412940298032
eval loss: 0.14904121633897557 0.19854212368668966
epoch: 178
epoch: 178 loss: 0.09606760640209736
eval loss: 0.1486976424742871 0.1980167484671374
epoch: 179
epoch: 179 loss: 0.0964671595867759
eval loss: 0.1491

epoch: 0
epoch: 0 loss: 0.30498708430691424
eval loss: 0.24435602404157095 0.3202493192262386
epoch: 1
epoch: 1 loss: 0.2359797548781307
eval loss: 0.22342879976882501 0.2938054696036824
epoch: 2
epoch: 2 loss: 0.21648283570134708
eval loss: 0.20569559941042584 0.2704521567946928
epoch: 3
epoch: 3 loss: 0.20453296743033864
eval loss: 0.19566636050720504 0.25661105591711586
epoch: 4
epoch: 4 loss: 0.1950344603960127
eval loss: 0.19118970561826246 0.2515768067726739
epoch: 5
epoch: 5 loss: 0.18707916566517893
eval loss: 0.18593407565085715 0.24437206596550531
epoch: 6
epoch: 6 loss: 0.18354989066816937
eval loss: 0.18004211714552137 0.23684437976352218
epoch: 7
epoch: 7 loss: 0.1780018507176173
eval loss: 0.1823810159198591 0.24018152092496647
epoch: 8
epoch: 8 loss: 0.17419118072329834
eval loss: 0.17316995931595086 0.22988947259834852
epoch: 9
epoch: 9 loss: 0.16979566598081008
eval loss: 0.17677842641472233 0.23343681610397365
epoch: 10
epoch: 10 loss: 0.1695930351150727
eval loss: 0.

epoch: 86
epoch: 86 loss: 0.11459515347203253
eval loss: 0.15405082366416192 0.2032435281408927
epoch: 87
epoch: 87 loss: 0.11367547144106951
eval loss: 0.1549252555832599 0.205469167105631
epoch: 88
epoch: 88 loss: 0.11273616175343444
eval loss: 0.15410704129013308 0.2046252485402001
epoch: 89
epoch: 89 loss: 0.11321188206318744
eval loss: 0.15501141039447253 0.20522331667028515
epoch: 90
epoch: 90 loss: 0.11333222026590156
eval loss: 0.15478063625300512 0.20476952544845534
epoch: 91
epoch: 91 loss: 0.11167208723977864
eval loss: 0.15455208700627515 0.20563651770074604
epoch: 92
epoch: 92 loss: 0.1118773642939714
eval loss: 0.15641036466742622 0.20828014470707407
epoch: 93
epoch: 93 loss: 0.11237322553342126
eval loss: 0.15339594912608628 0.20208147873189036
epoch: 94
epoch: 94 loss: 0.11196514501652359
eval loss: 0.15508335422314573 0.20545415638463285
epoch: 95
epoch: 95 loss: 0.11076371176146882
eval loss: 0.1551114702156364 0.20500951645517054
epoch: 96
epoch: 96 loss: 0.110185393

epoch: 170
epoch: 170 loss: 0.09553196684788103
eval loss: 0.15396891866205767 0.20367897186176653
epoch: 171
epoch: 171 loss: 0.0951193428356385
eval loss: 0.1536083912161528 0.20330683551212933
epoch: 172
epoch: 172 loss: 0.09517937051196557
eval loss: 0.15394148824189788 0.20409374930482804
epoch: 173
epoch: 173 loss: 0.09503245806653587
eval loss: 0.1523932294769246 0.20152016448077203
epoch: 174
epoch: 174 loss: 0.09528366329820682
eval loss: 0.15381814303839392 0.2042157114943693
epoch: 175
epoch: 175 loss: 0.09466898930392999
eval loss: 0.15512516813103053 0.20475203609927456
epoch: 176
epoch: 176 loss: 0.09460511116426817
eval loss: 0.15518156144446132 0.2046965894487009
epoch: 177
epoch: 177 loss: 0.09439174871893796
eval loss: 0.15388048004377639 0.203920661237184
epoch: 178
epoch: 178 loss: 0.09410383879529412
eval loss: 0.1548939029424169 0.20634717143871711
epoch: 179
epoch: 179 loss: 0.09409260263155063
eval loss: 0.15295199130842255 0.20240466647716168
epoch: 180
epoch: 

In [11]:
def predict_batch(model, data, device):
    # batch x seq_len x target_size
    with torch.no_grad():
        pred = model(data["sequence"].to(device), data["bpp"].to(device))
        pred = pred.detach().cpu().numpy()
    return_values = []
    ids = data["ids"]
    for idx, p in enumerate(pred):
        id_ = ids[idx]
        assert p.shape == (model.pred_len, len(target_cols))
        for seqpos, val in enumerate(p):
            assert len(val) == len(target_cols)
            dic = {key: val for key, val in zip(target_cols, val)}
            dic["id_seqpos"] = f"{id_}_{seqpos}"
            return_values.append(dic)
    return return_values


def predict_data(model, loader, device, batch_size):
    data_list = []
    for i, data in enumerate(progress_bar(loader)):
        data_list += predict_batch(model, data, device)
    expected_length = model.pred_len * len(loader) * batch_size
    assert len(data_list) == expected_length, f"len = {len(data_list)} expected = {expected_length}"
    return data_list


### Prediction

In [12]:
device = torch.device('cuda') if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1
base_test_data = pd.read_json(str(Path(BASE_PATH) / 'test.json'), lines=True)
public_df = base_test_data.query("seq_length == 107").copy()
private_df = base_test_data.query("seq_length == 130").copy()
print(f"public_df: {public_df.shape}")
print(f"private_df: {private_df.shape}")
public_df = public_df.reset_index()
private_df = private_df.reset_index()
pub_loader = create_loader(public_df, BATCH_SIZE, is_test=True)
pri_loader = create_loader(private_df, BATCH_SIZE, is_test=True)
pred_df_list = []
c = 0
for fold in range(5):
    model_load_path = f"./model_prediction/model-{fold}.pt"
    ae_model0 = AEModel()
    ae_model1 = AEModel()
    model_pub = FromAeModel(pred_len=107, seq=ae_model0.seq)
    model_pub = model_pub.to(device)
    model_pri = FromAeModel(pred_len=130, seq=ae_model1.seq)
    model_pri = model_pri.to(device)
    state_dict = torch.load(model_load_path, map_location=device)
    model_pub.load_state_dict(state_dict)
    model_pri.load_state_dict(state_dict)
    del state_dict

    data_list = []
    data_list += predict_data(model_pub, pub_loader, device, BATCH_SIZE)
    data_list += predict_data(model_pri, pri_loader, device, BATCH_SIZE)
    pred_df = pd.DataFrame(data_list, columns=["id_seqpos"] + target_cols)
    print(pred_df.head())
    print(pred_df.tail())
    pred_df_list.append(pred_df)
    c += 1
data_dic = dict(id_seqpos=pred_df_list[0]["id_seqpos"])
for col in target_cols:
    vals = np.zeros(pred_df_list[0][col].shape[0])
    for df in pred_df_list:
        vals += df[col].values
    data_dic[col] = vals / float(c)
pred_df_avg = pd.DataFrame(data_dic, columns=["id_seqpos"] + target_cols)
print(pred_df_avg.head())
pred_df_avg.to_csv("./submission.csv", index=False)


public_df: (629, 7)
private_df: (3005, 7)
(1, 107, 107, 3)
(1, 130, 130, 3)


        id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C   deg_50C
0  id_00073f8be_0    0.889971     0.578829  1.819234    0.525138  0.761425
1  id_00073f8be_1    2.399060     3.076897  4.175925    3.054830  2.517211
2  id_00073f8be_2    1.625049     0.711416  0.825830    0.866535  0.754455
3  id_00073f8be_3    1.357023     1.372728  1.276058    1.813206  1.830867
4  id_00073f8be_4    0.674013     0.607116  0.482549    0.856179  0.958451
               id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C  \
457948  id_ffda94f24_125    0.122056     0.378401  0.499851    0.590173   
457949  id_ffda94f24_126    0.290203     0.455823  0.768566    0.630433   
457950  id_ffda94f24_127    0.647525     0.113197  0.247836    0.306979   
457951  id_ffda94f24_128    0.179872     0.325130  0.155531    0.333050   
457952  id_ffda94f24_129    0.222451     0.631444  0.362951    0.654375   

         deg_50C  
457948  0.261905  
457949  0.643318  
457950  0.410135  
457951  0.100388  
4579

        id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C   deg_50C
0  id_00073f8be_0    0.840175     0.674623  1.874771    0.576894  0.763813
1  id_00073f8be_1    2.598827     3.051501  3.963598    3.037505  2.675584
2  id_00073f8be_2    1.837594     0.756073  0.999536    0.950256  0.827487
3  id_00073f8be_3    1.373335     1.148979  1.352951    1.640345  1.714775
4  id_00073f8be_4    0.791760     0.635696  0.485334    0.857873  0.739246
               id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C  \
457948  id_ffda94f24_125    0.061936     0.380067  0.417753    0.526409   
457949  id_ffda94f24_126    0.175213     0.436718  0.669643    0.622146   
457950  id_ffda94f24_127    0.515669     0.249173  0.363749    0.307995   
457951  id_ffda94f24_128    0.254049     0.341768  0.382574    0.350461   
457952  id_ffda94f24_129    0.078056     0.196059 -0.084321    0.261044   

         deg_50C  
457948  0.259373  
457949  0.530936  
457950  0.371209  
457951  0.188759  
4579

        id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C   deg_50C
0  id_00073f8be_0    0.855185     0.662608  1.900996    0.466713  0.729960
1  id_00073f8be_1    2.780891     3.752434  4.286752    3.277346  2.745527
2  id_00073f8be_2    1.866325     0.458401  0.728388    0.528044  0.771705
3  id_00073f8be_3    1.455908     1.373978  1.179715    1.703558  1.886637
4  id_00073f8be_4    0.832506     0.632477  0.425005    0.768186  0.937939
               id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C  \
457948  id_ffda94f24_125    0.142682     0.446002  0.377910    0.508230   
457949  id_ffda94f24_126    0.360360     0.377472  0.727666    0.589547   
457950  id_ffda94f24_127    0.972605     0.208708  0.529396    0.400295   
457951  id_ffda94f24_128    0.116183     0.425355  0.224215    0.397149   
457952  id_ffda94f24_129    0.210886     0.303565  0.533920    0.392660   

         deg_50C  
457948  0.293122  
457949  0.678551  
457950  0.543406  
457951  0.003551  
4579

        id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C   deg_50C
0  id_00073f8be_0    0.746331     0.503821  1.961306    0.510185  0.651083
1  id_00073f8be_1    2.494136     3.099043  4.126005    3.013498  2.694892
2  id_00073f8be_2    1.483856     0.585563  0.654471    0.584309  0.726083
3  id_00073f8be_3    1.279821     1.247037  1.308155    1.597072  2.021647
4  id_00073f8be_4    0.814999     0.487120  0.446322    0.847108  0.979084
               id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C  \
457948  id_ffda94f24_125    0.112796     0.368742  0.331351    0.516847   
457949  id_ffda94f24_126    0.199566     0.458324  0.913162    0.561407   
457950  id_ffda94f24_127    0.496672     0.240959  0.211376    0.422419   
457951  id_ffda94f24_128    0.127723     0.156009  0.018023    0.244848   
457952  id_ffda94f24_129    0.207528     0.198639  0.318818    0.458210   

         deg_50C  
457948  0.235081  
457949  0.605117  
457950  0.335691  
457951  0.070953  
4579

        id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C   deg_50C
0  id_00073f8be_0    0.822579     0.445991  1.683710    0.580682  0.593131
1  id_00073f8be_1    2.361188     2.836679  4.318111    3.132144  2.762321
2  id_00073f8be_2    1.547433     0.637337  0.970721    0.756127  0.768788
3  id_00073f8be_3    1.236453     1.188284  1.486593    1.798093  1.770873
4  id_00073f8be_4    0.736645     0.673909  0.575954    0.884074  0.803650
               id_seqpos  reactivity  deg_Mg_pH10  deg_pH10  deg_Mg_50C  \
457948  id_ffda94f24_125    0.206949     0.455570  0.542259    0.483514   
457949  id_ffda94f24_126    0.257412     0.152836  0.924369    0.573841   
457950  id_ffda94f24_127    0.689666     0.164514  0.334700    0.247386   
457951  id_ffda94f24_128    0.313861     0.144368  0.117415    0.250088   
457952  id_ffda94f24_129    0.260051     0.265586  0.925387    0.715781   

         deg_50C  
457948  0.199750  
457949  0.645355  
457950  0.466647  
457951 -0.053965  
4579