Ispired by this notebook.

https://www.kaggle.com/mrkmakr/covid-ae-pretrain-gnn-attn-cnn

Pytorch Implementation.

Model: Convolution + Transfomer + GRU.

Training: Bert-Like Pretraining. Then Fine Tuning.

In [1]:
import numpy as np
import pandas as pd
import ast
from collections import OrderedDict
from fastprogress import progress_bar
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math

import functools
import itertools
import os
import random
import shutil
import torch
import torch.nn.functional as F

from sklearn.model_selection import train_test_split, ShuffleSplit, KFold, StratifiedKFold, GroupKFold
from torch import nn
from torch.utils.data import Dataset

import functools
from IPython.core.debugger import set_trace


target_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']
input_cols = ['sequence', 'structure', 'predicted_loop_type']
error_cols = ['reactivity_error', 'deg_error_Mg_pH10', 'deg_error_Mg_50C', 'deg_error_pH10', 'deg_error_50C']

token_dicts = {
    "sequence": {x: i for i, x in enumerate("ACGU")},
    "structure": {x: i for i, x in enumerate('().')},
    "predicted_loop_type": {x: i for i, x in enumerate("BEHIMSX")}
}


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### Loader


In [2]:
BASE_PATH = "../OpenVaccine"
MODEL_SAVE_PATH = "../OpenVaccine/pretrains"


def preprocess_inputs(df, cols):
    return np.concatenate([preprocess_feature_col(df, col) for col in cols], axis=2)


def preprocess_feature_col(df, col):
    dic = token_dicts[col]
    dic_len = len(dic)
    seq_length = len(df[col][0])
    ident = np.identity(dic_len)
    # convert to one hot
    arr = np.array(
        df[[col]].applymap(lambda seq: [ident[dic[x]] for x in seq]).values.tolist()
    ).squeeze(1)
    # shape: data_size x seq_length x dic_length
    assert arr.shape == (len(df), seq_length, dic_len)
    return arr


def preprocess(base_data, is_test=False):
    inputs = preprocess_inputs(base_data, input_cols)
    if is_test:
        labels = None
    else:
        labels = np.array(base_data[target_cols].values.tolist()).transpose((0, 2, 1))
        assert labels.shape[2] == len(target_cols)
    assert inputs.shape[2] == 14
    return inputs, labels


def get_bpp_feature(bpp):
    bpp_nb_mean = 0.077522  # mean of bpps_nb across all training data
    bpp_nb_std = 0.08914  # std of bpps_nb across all training data
    bpp_max = bpp.max(-1)[0]
    bpp_sum = bpp.sum(-1)
    bpp_nb = torch.true_divide((bpp > 0).sum(dim=1), bpp.shape[1])
    bpp_nb = torch.true_divide(bpp_nb - bpp_nb_mean, bpp_nb_std)
    return [bpp_max.unsqueeze(2), bpp_sum.unsqueeze(2), bpp_nb.unsqueeze(2)]


@functools.lru_cache(5000)
def load_from_id(id_):
    path = Path(BASE_PATH) / f"bpps/{id_}.npy"
    data = np.load(str(path))
    return data


def get_distance_matrix(leng):
    idx = np.arange(leng)
    Ds = []
    for i in range(len(idx)):
        d = np.abs(idx[i] - idx)
        Ds.append(d)

    Ds = np.array(Ds) + 1
    Ds = 1 / Ds
    Ds = Ds[None, :, :]
    Ds = np.repeat(Ds, 1, axis=0)

    Dss = []
    for i in [1, 2, 4]:
        Dss.append(Ds ** i)
    Ds = np.stack(Dss, axis=3)
    print(Ds.shape)
    return Ds


def get_structure_adj(df):
    Ss = []
    for i in range(len(df)):
        seq_length = df["seq_length"].iloc[i]
        structure = df["structure"].iloc[i]
        sequence = df["sequence"].iloc[i]

        cue = []
        a_structures = OrderedDict([
            (("A", "U"), np.zeros([seq_length, seq_length])),
            (("C", "G"), np.zeros([seq_length, seq_length])),
            (("U", "G"), np.zeros([seq_length, seq_length])),
            (("U", "A"), np.zeros([seq_length, seq_length])),
            (("G", "C"), np.zeros([seq_length, seq_length])),
            (("G", "U"), np.zeros([seq_length, seq_length])),
        ])
        for j in range(seq_length):
            if structure[j] == "(":
                cue.append(j)
            elif structure[j] == ")":
                start = cue.pop()
                a_structures[(sequence[start], sequence[j])][start, j] = 1
                a_structures[(sequence[j], sequence[start])][j, start] = 1

        a_strc = np.stack([a for a in a_structures.values()], axis=2)
        a_strc = np.sum(a_strc, axis=2, keepdims=True)
        Ss.append(a_strc)

    Ss = np.array(Ss)
    return Ss


def create_loader(df, batch_size=1, is_test=False):
    features, labels = preprocess(df, is_test)
    features_tensor = torch.from_numpy(features)
    if labels is not None:
        labels_tensor = torch.from_numpy(labels)
        dataset = VacDataset(features_tensor, df, labels_tensor)
        loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=False)
    else:
        dataset = VacDataset(features_tensor, df, None)
        loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=False, drop_last=False)
    return loader


class VacDataset(Dataset):
    def __init__(self, features, df, labels=None):
        self.features = features
        self.labels = labels
        self.test = labels is None
        self.ids = df["id"]
        self.score = None
        self.structure_adj = get_structure_adj(df)
        self.distance_matrix = get_distance_matrix(self.structure_adj.shape[1])
        if "score" in df.columns:
            self.score = df["score"]
        else:
            df["score"] = 1.0
            self.score = df["score"]
        self.signal_to_noise = None
        if not self.test:
            self.signal_to_noise = df["signal_to_noise"]
            assert self.features.shape[0] == self.labels.shape[0]
        else:
            assert self.ids is not None

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        bpp = torch.from_numpy(load_from_id(self.ids[index]).copy()).float()
        adj = self.structure_adj[index]
        distance = self.distance_matrix[0]
        bpp = np.concatenate([bpp[:, :, None], adj, distance], axis=2)
        if self.test:
            return dict(sequence=self.features[index].float(), bpp=bpp, ids=self.ids[index])
        else:
            return dict(sequence=self.features[index].float(), bpp=bpp,
                        label=self.labels[index], ids=self.ids[index],
                        signal_to_noise=self.signal_to_noise[index],
                        score=self.score[index])


### Model

In [3]:
class Conv1dStack(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, dilation=1):
        super(Conv1dStack, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        h = self.res(x)
        return x + h


class Conv2dStack(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, dilation=1):
        super(Conv2dStack, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        h = self.res(x)
        return x + h


class SeqEncoder(nn.Module):
    def __init__(self, in_dim: int):
        super(SeqEncoder, self).__init__()
        self.conv0 = Conv1dStack(in_dim, 128, 3, padding=1)
        self.conv1 = Conv1dStack(128, 64, 6, padding=5, dilation=2)
        self.conv2 = Conv1dStack(64, 32, 15, padding=7, dilation=1)
        self.conv3 = Conv1dStack(32, 32, 30, padding=29, dilation=2)

    def forward(self, x):
        x1 = self.conv0(x)
        x2 = self.conv1(x1)
        x3 = self.conv2(x2)
        x4 = self.conv3(x3)
        x = torch.cat([x1, x2, x3, x4], dim=1)
        # x = x.permute(0, 2, 1).contiguous()
        # BATCH x 256 x seq_length
        return x


class BppAttn(nn.Module):
    def __init__(self, in_channel: int, out_channel: int):
        super(BppAttn, self).__init__()
        self.conv0 = Conv1dStack(in_channel, out_channel, 3, padding=1)
        self.bpp_conv = Conv2dStack(5, out_channel)

    def forward(self, x, bpp):
        x = self.conv0(x)
        bpp = self.bpp_conv(bpp)
        # BATCH x C x SEQ x SEQ
        # BATCH x C x SEQ
        x = torch.matmul(bpp, x.unsqueeze(-1))
        return x.squeeze(-1)


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransformerWrapper(nn.Module):
    def __init__(self, dmodel=256, nhead=8, num_layers=2):
        super(TransformerWrapper, self).__init__()
        self.pos_encoder = PositionalEncoding(256)
        encoder_layer = TransformerEncoderLayer(d_model=dmodel, nhead=nhead)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers)
        self.pos_emb = PositionalEncoding(dmodel)

    def flatten_parameters(self):
        pass

    def forward(self, x):
        x = x.permute((1, 0, 2)).contiguous()
        x = self.pos_emb(x)
        x = self.transformer_encoder(x)
        x = x.permute((1, 0, 2)).contiguous()
        return x, None


class RnnLayers(nn.Module):
    def __init__(self, dmodel, dropout=0.3, transformer_layers: int = 2):
        super(RnnLayers, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.rnn0 = TransformerWrapper(dmodel, nhead=8, num_layers=transformer_layers)
        self.rnn1 = nn.LSTM(dmodel, dmodel // 2, batch_first=True, num_layers=1, bidirectional=True)
        self.rnn2 = nn.GRU(dmodel, dmodel // 2, batch_first=True, num_layers=1, bidirectional=True)

    def forward(self, x):
        self.rnn0.flatten_parameters()
        x, _ = self.rnn0(x)
        if self.rnn1 is not None:
            self.rnn1.flatten_parameters()
            x = self.dropout(x)
            x, _ = self.rnn1(x)
        if self.rnn2 is not None:
            self.rnn2.flatten_parameters()
            x = self.dropout(x)
            x, _ = self.rnn2(x)
        return x

    
class BaseAttnModel(nn.Module):
    def __init__(self, transformer_layers: int = 2):
        super(BaseAttnModel, self).__init__()
        self.linear0 = nn.Linear(14 + 3, 1)
        self.seq_encoder_x = SeqEncoder(18)
        self.attn = BppAttn(256, 128)
        self.seq_encoder_bpp = SeqEncoder(128)
        self.seq = RnnLayers(256 * 2, dropout=0.3,
                             transformer_layers=transformer_layers)

    def forward(self, x, bpp):
        bpp_features = get_bpp_feature(bpp[:, :, :, 0].float())
        x = torch.cat([x] + bpp_features, dim=-1)
        learned = self.linear0(x)
        x = torch.cat([x, learned], dim=-1)
        x = x.permute(0, 2, 1).contiguous().float()
        # BATCH x 18 x seq_len
        bpp = bpp.permute([0, 3, 1, 2]).contiguous().float()
        # BATCH x 5 x seq_len x seq_len
        x = self.seq_encoder_x(x)
        # BATCH x 256 x seq_len
        bpp = self.attn(x, bpp)
        bpp = self.seq_encoder_bpp(bpp)
        # BATCH x 256 x seq_len
        x = x.permute(0, 2, 1).contiguous()
        # BATCH x seq_len x 256
        bpp = bpp.permute(0, 2, 1).contiguous()
        # BATCH x seq_len x 256
        x = torch.cat([x, bpp], dim=2)
        # BATCH x seq_len x 512
        x = self.seq(x)
        return x


class AEModel(nn.Module):
    def __init__(self, transformer_layers: int = 2):
        super(AEModel, self).__init__()
        self.seq = BaseAttnModel(transformer_layers=transformer_layers)
        self.linear = nn.Sequential(
            nn.Linear(256 * 2, 14),
            nn.Sigmoid(),
        )

    def forward(self, x, bpp):
        x = self.seq(x, bpp)
        x = F.dropout(x, p=0.3)
        x = self.linear(x)
        return x


class FromAeModel(nn.Module):
    def __init__(self, seq, pred_len=68, dmodel: int = 256):
        super(FromAeModel, self).__init__()
        self.seq = seq
        self.pred_len = pred_len
        self.linear = nn.Sequential(
            nn.Linear(dmodel * 2, len(target_cols)),
        )

    def forward(self, x, bpp):
        x = self.seq(x, bpp)
        x = self.linear(x)
        x = x[:, :self.pred_len]
        return x

In [4]:
aug_df = pd.read_csv('../OpenVaccine/aug_data1.csv')

def aug_data(df):
    target_df = df.copy()
    new_df = aug_df[aug_df['id'].isin(target_df['id'])]
                         
    del target_df['structure']
    del target_df['predicted_loop_type']
    new_df = new_df.merge(target_df, on=['id','sequence'], how='left')

    df['cnt'] = df['id'].map(new_df[['id','cnt']].set_index('id').to_dict()['cnt'])
    df['log_gamma'] = 100
    df['score'] = 1.0
    df = df.append(new_df[df.columns])
    return df

In [5]:
pseudo_df = pd.read_csv('../OpenVaccine/pseudo_test.csv')

pseudo_df['SN_filter'] = 1
pseudo_df['signal_to_noise'] = 1.0
pseudo_df['score'] = 1.0
pseudo_df['reactivity'] = pseudo_df['reactivity'].apply(lambda x: ast.literal_eval(x))
pseudo_df['deg_Mg_pH10'] = pseudo_df['deg_Mg_pH10'].apply(lambda x: ast.literal_eval(x))
pseudo_df['deg_pH10'] = pseudo_df['deg_pH10'].apply(lambda x: ast.literal_eval(x))
pseudo_df['deg_Mg_50C'] = pseudo_df['deg_Mg_50C'].apply(lambda x: ast.literal_eval(x))
pseudo_df['deg_50C'] = pseudo_df['deg_50C'].apply(lambda x: ast.literal_eval(x))

In [6]:
pseudo_st = pseudo_df[pseudo_df['seq_length'] == 107]
pseudo_lg = pseudo_df[pseudo_df['seq_length'] == 130]

pseudo_st['reactivity'] = pseudo_st['reactivity'].apply(lambda x: x[:68])
pseudo_st['deg_Mg_pH10'] = pseudo_st['deg_Mg_pH10'].apply(lambda x: x[:68])
pseudo_st['deg_pH10'] = pseudo_st['deg_pH10'].apply(lambda x: x[:68])
pseudo_st['deg_Mg_50C'] = pseudo_st['deg_Mg_50C'].apply(lambda x: x[:68])
pseudo_st['deg_50C'] = pseudo_st['deg_50C'].apply(lambda x: x[:68])

pseudo_lg['reactivity'] = pseudo_lg['reactivity'].apply(lambda x: x[:91])
pseudo_lg['deg_Mg_pH10'] = pseudo_lg['deg_Mg_pH10'].apply(lambda x: x[:91])
pseudo_lg['deg_pH10'] = pseudo_lg['deg_pH10'].apply(lambda x: x[:91])
pseudo_lg['deg_Mg_50C'] = pseudo_lg['deg_Mg_50C'].apply(lambda x: x[:91])
pseudo_lg['deg_50C'] = pseudo_lg['deg_50C'].apply(lambda x: x[:91])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  after removing the cwd from sys.path.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: 

In [7]:
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
#base_train_data.head()

# add aug data

base_train_data = aug_data(base_train_data)
# base_train_data = base_train_data.append(pseudo_st)
# base_train_data = base_train_data[base_train_data['SN_filter'] == 1]
base_train_data = base_train_data.reset_index(drop = True)

# base_train_data_lg = pseudo_lg.reset_index(drop = True)


device = torch.device('cuda:1')
BATCH_SIZE = 64

#base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
base_test_data = pd.read_json(str(Path(BASE_PATH) / 'test.json'), lines=True)
base_test_data = aug_data(base_test_data)

public_df = base_test_data.query("seq_length == 107").copy()
private_df = base_test_data.query("seq_length == 130").copy()
print(f"public_df: {public_df.shape}")
print(f"private_df: {private_df.shape}")
public_df = public_df.reset_index()
private_df = private_df.reset_index()

features, _ = preprocess(base_train_data, True)
features_tensor = torch.from_numpy(features)
dataset0 = VacDataset(features_tensor, base_train_data, None)

# features, _ = preprocess(base_train_data_lg, True)
# features_tensor = torch.from_numpy(features)
# dataset0_lg = VacDataset(features_tensor, base_train_data_lg, None)

features, _ = preprocess(public_df, True)
features_tensor = torch.from_numpy(features)
dataset1 = VacDataset(features_tensor, public_df, None)

features, _ = preprocess(private_df, True)
features_tensor = torch.from_numpy(features)
dataset2 = VacDataset(features_tensor, private_df, None)

loader0 = torch.utils.data.DataLoader(dataset0, BATCH_SIZE, shuffle=False, drop_last=False)
# loader0_lg = torch.utils.data.DataLoader(dataset0_lg, BATCH_SIZE, shuffle=False, drop_last=False)
loader1 = torch.utils.data.DataLoader(dataset1, BATCH_SIZE, shuffle=False, drop_last=False)
loader2 = torch.utils.data.DataLoader(dataset2, BATCH_SIZE, shuffle=False, drop_last=False)

public_df: (1258, 10)
private_df: (6010, 10)
(1, 107, 107, 3)
(1, 107, 107, 3)
(1, 130, 130, 3)


### Pretrain

In [8]:
def learn_from_batch_ae(model, data, device):
    seq = data["sequence"].clone()
    seq[:, :, :14] = F.dropout2d(seq[:, :, :14], p=0.3)
    target = data["sequence"][:, :, :14]
    out = model(seq.to(device), data["bpp"].to(device))
    loss = F.binary_cross_entropy(out, target.to(device))
    return loss


def train_ae(model, train_data, optimizer, lr_scheduler, epochs=10, device="cpu", 
             start_epoch: int = 0, start_it: int = 0, log_path: str = "./logs"): #10
    print(f"device: {device}")
    losses = []
    it = start_it
    model_save_path = Path(MODEL_SAVE_PATH)
    start_epoch = start_epoch
    end_epoch = start_epoch + epochs
    min_loss = 10.0
    min_loss_epoch = 0
    if not model_save_path.exists():
        model_save_path.mkdir(parents=True)
    for epoch in progress_bar(range(start_epoch, end_epoch)):
        print(f"epoch: {epoch}")
        model.train()
        for i, data in enumerate(train_data):
            optimizer.zero_grad()
            loss = learn_from_batch_ae(model, data, device)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()
            loss_v = loss.item()
            losses.append(loss_v)
            it += 1
        loss_m = np.mean(losses)
        if loss_m < min_loss:
            min_loss_epoch = epoch
            min_loss = loss_m
        print(f'epoch: {epoch} loss: {loss_m}')
        losses = []
        torch.save(optimizer.state_dict(), str(model_save_path / "optimizer.pt"))
        torch.save(model.state_dict(), str(model_save_path / f"model-{epoch}.pt"))
    return dict(end_epoch=end_epoch, it=it, min_loss_epoch=min_loss_epoch)

In [9]:
import shutil

set_seed(123)
shutil.rmtree("./model", True)
shutil.rmtree("./logs", True)
save_path = Path("./model_prediction")
if not save_path.exists():
    save_path.mkdir(parents=True)

lr_scheduler = None
device = "cuda:1" if torch.cuda.is_available() else "cpu"
model = AEModel()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
res = dict(end_epoch=0, it=0, min_loss_epoch=0)
epochs = [5, 5, 5, 5] # [5, 5, 5, 5]
for e in epochs:
    print('--loader0')
    res = train_ae(model, loader0, optimizer, lr_scheduler, e, device=device,
                   start_epoch=res["end_epoch"], start_it=res["it"])
#     print('--loader0_lg')
#     res = train_ae(model, loader0_lg, optimizer, lr_scheduler, e, device=device,
#                    start_epoch=res["end_epoch"], start_it=res["it"])
    print('--loader1')
    res = train_ae(model, loader1, optimizer, lr_scheduler, e, device=device,
                   start_epoch=res["end_epoch"], start_it=res["it"])
    print('--loader2')
    res = train_ae(model, loader2, optimizer, lr_scheduler, e, device=device,
                   start_epoch=res["end_epoch"], start_it=res["it"])

epoch = res["min_loss_epoch"]
shutil.copyfile(str(Path(MODEL_SAVE_PATH) / f"model-{epoch}.pt"), "ae-model.pt")

--loader0
device: cuda:1


epoch: 0
epoch: 0 loss: 0.25064601679642995
epoch: 1
epoch: 1 loss: 0.1384060760339101
epoch: 2
epoch: 2 loss: 0.08038593033949534
epoch: 3
epoch: 3 loss: 0.03185466726620992
epoch: 4
epoch: 4 loss: 0.019618625019987426
--loader1
device: cuda:1


epoch: 5
epoch: 5 loss: 0.017772705294191837
epoch: 6
epoch: 6 loss: 0.01611646986566484
epoch: 7
epoch: 7 loss: 0.013704161671921612
epoch: 8
epoch: 8 loss: 0.011801821365952491
epoch: 9
epoch: 9 loss: 0.010440635774284602
--loader2
device: cuda:1


epoch: 10
epoch: 10 loss: 0.01808331135977456
epoch: 11
epoch: 11 loss: 0.014588844268879991
epoch: 12
epoch: 12 loss: 0.012971125731363576
epoch: 13
epoch: 13 loss: 0.012219857741543588
epoch: 14
epoch: 14 loss: 0.011251784831364738
--loader0
device: cuda:1


epoch: 15
epoch: 15 loss: 0.11141376870994767
epoch: 16
epoch: 16 loss: 0.14910804033279418
epoch: 17
epoch: 17 loss: 0.1133262633283933
epoch: 18
epoch: 18 loss: 0.08852919230858484
epoch: 19
epoch: 19 loss: 0.0644798181951046
--loader1
device: cuda:1


epoch: 20
epoch: 20 loss: 0.04376504682004452
epoch: 21
epoch: 21 loss: 0.03720299247652292
epoch: 22
epoch: 22 loss: 0.031649123318493365
epoch: 23
epoch: 23 loss: 0.028315145894885062
epoch: 24
epoch: 24 loss: 0.024756696540862322
--loader2
device: cuda:1


epoch: 25
epoch: 25 loss: 0.02692662131913165
epoch: 26
epoch: 26 loss: 0.02191674142600374
epoch: 27
epoch: 27 loss: 0.019343210384249687
epoch: 28
epoch: 28 loss: 0.018591710326677943
epoch: 29
epoch: 29 loss: 0.018955215732467934
--loader0
device: cuda:1


epoch: 30
epoch: 30 loss: 0.01184980947524309
epoch: 31
epoch: 31 loss: 0.011005594283342362
epoch: 32
epoch: 32 loss: 0.010138162014385065
epoch: 33
epoch: 33 loss: 0.00937236485381921
epoch: 34
epoch: 34 loss: 0.009434934935222069
--loader1
device: cuda:1


epoch: 35
epoch: 35 loss: 0.007858033105731011
epoch: 36
epoch: 36 loss: 0.008021229738369584
epoch: 37
epoch: 37 loss: 0.0076625833986327056
epoch: 38
epoch: 38 loss: 0.007388785784132779
epoch: 39
epoch: 39 loss: 0.006525391549803317
--loader2
device: cuda:1


epoch: 40
epoch: 40 loss: 0.014415230730825917
epoch: 41
epoch: 41 loss: 0.013332009156967731
epoch: 42
epoch: 42 loss: 0.01705033711573862
epoch: 43
epoch: 43 loss: 0.030063067166570652
epoch: 44
epoch: 44 loss: 0.01620889058772554
--loader0
device: cuda:1


epoch: 45
epoch: 45 loss: 0.0104247345837454
epoch: 46
epoch: 46 loss: 0.010372398793697358
epoch: 47
epoch: 47 loss: 0.011601323901365201
epoch: 48
epoch: 48 loss: 0.030044919215142726
epoch: 49
epoch: 49 loss: 0.0157378122707208
--loader1
device: cuda:1


epoch: 50
epoch: 50 loss: 0.009795540105551481
epoch: 51
epoch: 51 loss: 0.009084979956969618
epoch: 52
epoch: 52 loss: 0.009306920599192381
epoch: 53
epoch: 53 loss: 0.011493802862241864
epoch: 54
epoch: 54 loss: 0.01057223421521485
--loader2
device: cuda:1


epoch: 55
epoch: 55 loss: 0.019030503562076927
epoch: 56
epoch: 56 loss: 0.016735949563456975
epoch: 57
epoch: 57 loss: 0.014369799279944693
epoch: 58
epoch: 58 loss: 0.013059445131728624
epoch: 59
epoch: 59 loss: 0.012738071332507312


'ae-model.pt'

### Training

In [9]:
def MCRMSE(y_true, y_pred):
    colwise_mse = torch.mean(torch.square(y_true - y_pred), dim=1)
    return torch.mean(torch.sqrt(colwise_mse), dim=1)


def sn_mcrmse_loss(predict, target, signal_to_noise):
    loss = MCRMSE(target, predict)
#     weight = 0.5 * torch.log(signal_to_noise + 1.01)
#     loss = (loss * weight).mean()
    return loss.mean()


def learn_from_batch(model, data, optimizer, lr_scheduler, device):
    optimizer.zero_grad()

    out = model(data["sequence"].to(device), data["bpp"].to(device))
    out = out[:, :data["label"].shape[1]]
    signal_to_noise = data["signal_to_noise"] * data["score"]
    loss = sn_mcrmse_loss(out, data["label"].to(device), signal_to_noise.to(device))
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if lr_scheduler:
        lr_scheduler.step()
    return out, loss

def predict_batch(model, data, device):
    # batch x seq_len x target_size
    with torch.no_grad():
        pred = model(data["sequence"].to(device), data["bpp"].to(device))
        pred = pred.detach().cpu().numpy()
    return_values = []
    ids = data["ids"]
    for idx, p in enumerate(pred):
        id_ = ids[idx]
        assert p.shape == (model.pred_len, len(target_cols))
        for seqpos, val in enumerate(p):
            assert len(val) == len(target_cols)
            dic = {key: val for key, val in zip(target_cols, val)}
            dic["id_seqpos"] = f"{id_}_{seqpos}"
            return_values.append(dic)
    return return_values


def predict_data(model, loader, device, batch_size):
    model.eval()
    data_list = []
    for i, data in enumerate(progress_bar(loader)):
        data_list += predict_batch(model, data, device)
    expected_length = model.pred_len * len(loader) * batch_size
#     assert len(data_list) == expected_length, f"len = {len(data_list)} expected = {expected_length}"
    return data_list

def evaluate(model, valid_data, device):
    model.eval()
    loss_list = []
    mcrmse = []
    for i, data in enumerate(valid_data):
        with torch.no_grad():
            y = model(data["sequence"].to(device), data["bpp"].to(device))
            y = y[:, :68]
            mcrmse_ = MCRMSE(data["label"].to(device), y) #[data["signal_to_noise"] > 1]
            mcrmse.append(mcrmse_.mean().item())
            loss = sn_mcrmse_loss(y, data["label"].to(device), data["signal_to_noise"].to(device))
            loss_list.append(loss.item())
    model.train()
    return dict(loss=np.mean(loss_list), mcmse=np.mean(mcrmse))

def train(model, train_data, valid_data, train_data_lg, valid_data_lg, optimizer, lr_scheduler, epochs=10, device="cpu",
          start_epoch: int = 0, log_path: str = "./logs"):
    print(f"device: {device}")
    losses = []
    writer = SummaryWriter(log_path)
    it = 0
    model_save_path = Path(MODEL_SAVE_PATH)
    start_epoch = start_epoch
    end_epoch = start_epoch + epochs
    if not model_save_path.exists():
        model_save_path.mkdir(parents=True)
    min_eval_loss = 10.0
    min_eval_epoch = None
    for epoch in progress_bar(range(start_epoch, end_epoch)):
        print(f"epoch: {epoch}")
        model.train()
        
        for i, (data, data_lg) in enumerate(itertools.zip_longest(train_data, train_data_lg)):
            if data_lg is not None:
                _, loss_lg = learn_from_batch(model, data_lg, optimizer, lr_scheduler, device)
                losses.append(loss_lg.item())
                
            _, loss = learn_from_batch(model, data, optimizer, lr_scheduler, device)
            loss_v = loss.item()
            writer.add_scalar('loss', loss_v, it)
            losses.append(loss_v)
            it += 1
        print(f'epoch: {epoch} loss: {np.mean(losses)}')
        losses = []

        eval_result = evaluate(model, valid_data, device)
        eval_loss = eval_result["loss"]
        if eval_loss <= min_eval_loss:
            min_eval_epoch = epoch
            min_eval_loss = eval_loss
            min_mcrmse = eval_result["mcmse"]
            valid_pred = predict_data(model, valid_data, device, BATCH_SIZE)

        print(f"eval loss: {eval_loss} {eval_result['mcmse']}")
        writer.add_scalar(f"evaluate/loss", eval_loss, epoch)
        writer.add_scalar(f"evaluate/mcmse", eval_result["mcmse"], epoch)
        model.train()
        torch.save(optimizer.state_dict(), str(model_save_path / "optimizer.pt"))
        torch.save(model.state_dict(), str(model_save_path / f"model-{epoch}.pt"))
        
    print(f'min eval loss: {min_eval_loss} min mcrmse: {min_mcrmse} epoch {min_eval_epoch}')
    return min_eval_epoch, valid_pred

In [10]:
device = torch.device('cuda:1')
BATCH_SIZE = 64

# categorical value for target (used for stratified kfold)
#def add_y_cat(df):
#    target_mean = df['reactivity'].apply(np.mean) + \
#                  df['deg_Mg_pH10'].apply(np.mean) + \
#                  df['deg_Mg_50C'].apply(np.mean) 
                  #df['deg_pH10'].apply(np.mean) + \
                  #df['deg_50C'].apply(np.mean)
#    df['y_cat'] = pd.qcut(np.array(target_mean), q=20).codes
    
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)

base_train_data = aug_data(base_train_data)
base_train_data = base_train_data.append(pseudo_st)
base_train_data = base_train_data[base_train_data['SN_filter'] == 1]
base_train_data = base_train_data.reset_index(drop = True)

base_train_data_lg = pseudo_lg.reset_index(drop = True)
#add_y_cat(base_train_data)

samples = base_train_data
save_path = Path("./model_prediction")
if not save_path.exists():
    save_path.mkdir(parents=True)
shutil.rmtree("./model", True)
shutil.rmtree("./logs", True)
split = ShuffleSplit(n_splits=5, test_size=.1)
# kf = StratifiedKFold(n_splits = 5, shuffle = True, random_state = 123)
kf = GroupKFold(n_splits=5)
    
ids = samples.reset_index()["id"]

set_seed(123)

oof_list = []
for fold, ((train_index, test_index), (train_index_lg, test_index_lg)) in enumerate(zip(kf.split(samples, groups=ids), 
                                                     kf.split(base_train_data_lg, groups=base_train_data_lg['id']))): #split.split(samples) / kf
    print(f"fold: {fold}")
    train_df = samples.loc[train_index].reset_index()
    val_df = samples.loc[test_index].reset_index()
    train_loader = create_loader(train_df, BATCH_SIZE)
    valid_loader = create_loader(val_df, BATCH_SIZE)
    print(train_df.shape, val_df.shape)
    
    train_df_lg = base_train_data_lg.loc[train_index_lg].reset_index()
    val_df_lg = base_train_data_lg.loc[test_index_lg].reset_index()
    train_loader_lg = create_loader(train_df_lg, BATCH_SIZE)
    valid_loader_lg = create_loader(val_df_lg, BATCH_SIZE)
    print(train_df_lg.shape, val_df_lg.shape)
    
    ae_model = AEModel()
    state_dict = torch.load("./ae-model.pt")
    ae_model.load_state_dict(state_dict)
    del state_dict
    
    model = FromAeModel(ae_model.seq, pred_len=91)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    lr_scheduler = None
    
    epoch, valid_pred = train(model, train_loader, valid_loader, train_loader_lg, valid_loader_lg, optimizer, lr_scheduler, 150, device=device,
                  log_path=f"logs/{fold}") #150 #200 epochs
    oof_list += valid_pred
    shutil.copyfile(str(Path(MODEL_SAVE_PATH) / f"./model-{epoch}.pt"), f"model_prediction/model-{fold}.pt")
    del model

fold: 0
(1, 107, 107, 3)
(1, 107, 107, 3)
(3045, 23) (762, 23)
(1, 130, 130, 3)
(1, 130, 130, 3)
(2404, 16) (601, 16)
device: cuda:1


epoch: 0
epoch: 0 loss: 0.28286328863691057


eval loss: 0.2623660741429435 0.2623660741429435
epoch: 1
epoch: 1 loss: 0.2139285774706132


eval loss: 0.24422846096852358 0.24422846096852358
epoch: 2
epoch: 2 loss: 0.1892291842582362
eval loss: 0.24579620056284857 0.24579620056284857
epoch: 3
epoch: 3 loss: 0.17810301003949489


eval loss: 0.21774162759340068 0.21774162759340068
epoch: 4
epoch: 4 loss: 0.16308569032544168


eval loss: 0.20429748354437952 0.20429748354437952
epoch: 5
epoch: 5 loss: 0.15465881631977005


eval loss: 0.19918136001394593 0.19918136001394593
epoch: 6
epoch: 6 loss: 0.14918654299496084


eval loss: 0.19781545938710277 0.19781545938710277
epoch: 7
epoch: 7 loss: 0.14367057427190685


eval loss: 0.19362520458740132 0.19362520458740132
epoch: 8
epoch: 8 loss: 0.13974011909911674


eval loss: 0.19022096662134502 0.19022096662134502
epoch: 9
epoch: 9 loss: 0.13870093609826495


eval loss: 0.1901595437182774 0.1901595437182774
epoch: 10
epoch: 10 loss: 0.1346225675163326


eval loss: 0.18701443872626677 0.18701443872626677
epoch: 11
epoch: 11 loss: 0.13283342411371657


eval loss: 0.18415654422567693 0.18415654422567693
epoch: 12
epoch: 12 loss: 0.1289976335076433
eval loss: 0.18557248621905376 0.18557248621905376
epoch: 13
epoch: 13 loss: 0.126967857671882


eval loss: 0.18385273420643067 0.18385273420643067
epoch: 14
epoch: 14 loss: 0.12519794685592836


eval loss: 0.17882717240100307 0.17882717240100307
epoch: 15
epoch: 15 loss: 0.12324988661960448
eval loss: 0.18054060150171838 0.18054060150171838
epoch: 16
epoch: 16 loss: 0.12148582390867561
eval loss: 0.17968540764504684 0.17968540764504684
epoch: 17
epoch: 17 loss: 0.12055273748221518
eval loss: 0.17962127565493405 0.17962127565493405
epoch: 18
epoch: 18 loss: 0.12057955518651009
eval loss: 0.17997039802536396 0.17997039802536396
epoch: 19
epoch: 19 loss: 0.11767072735910138


eval loss: 0.17812070684513703 0.17812070684513703
epoch: 20
epoch: 20 loss: 0.11710595512382746


eval loss: 0.17730624580564655 0.17730624580564655
epoch: 21
epoch: 21 loss: 0.11641794964671319
eval loss: 0.17893271034092162 0.17893271034092162
epoch: 22
epoch: 22 loss: 0.11492762559325106
eval loss: 0.1778544670890849 0.1778544670890849
epoch: 23
epoch: 23 loss: 0.11504120863149328
eval loss: 0.17758369937243831 0.17758369937243831
epoch: 24
epoch: 24 loss: 0.1123554507113918


eval loss: 0.1747537654484621 0.1747537654484621
epoch: 25
epoch: 25 loss: 0.11122482351940169


eval loss: 0.17198085101915275 0.17198085101915275
epoch: 26
epoch: 26 loss: 0.11314377702199266
eval loss: 0.17514093091674987 0.17514093091674987
epoch: 27
epoch: 27 loss: 0.10991075354321599
eval loss: 0.17297489120112472 0.17297489120112472
epoch: 28
epoch: 28 loss: 0.10767803209739744


eval loss: 0.17129225086065658 0.17129225086065658
epoch: 29
epoch: 29 loss: 0.10811798589190928
eval loss: 0.17361782797672795 0.17361782797672795
epoch: 30
epoch: 30 loss: 0.10767711296882126
eval loss: 0.17929367265817572 0.17929367265817572
epoch: 31
epoch: 31 loss: 0.10705641761415782


eval loss: 0.17051068915631254 0.17051068915631254
epoch: 32
epoch: 32 loss: 0.10457343681158118
eval loss: 0.1724574871850534 0.1724574871850534
epoch: 33
epoch: 33 loss: 0.10443890443413749
eval loss: 0.1736875959843928 0.1736875959843928
epoch: 34
epoch: 34 loss: 0.10671895793379767
eval loss: 0.17069732748217345 0.17069732748217345
epoch: 35
epoch: 35 loss: 0.1035267514336702
eval loss: 0.17273678349524035 0.17273678349524035
epoch: 36
epoch: 36 loss: 0.10320223711145207
eval loss: 0.1738403327327461 0.1738403327327461
epoch: 37
epoch: 37 loss: 0.10128518275010025
eval loss: 0.17258510059190113 0.17258510059190113
epoch: 38
epoch: 38 loss: 0.10027850482532445
eval loss: 0.17419975520068298 0.17419975520068298
epoch: 39
epoch: 39 loss: 0.10079523170552798


eval loss: 0.17034436262501937 0.17034436262501937
epoch: 40
epoch: 40 loss: 0.09981005970193114
eval loss: 0.17053463096805235 0.17053463096805235
epoch: 41
epoch: 41 loss: 0.09763822136265521


eval loss: 0.16887796560699378 0.16887796560699378
epoch: 42
epoch: 42 loss: 0.09846647157930738
eval loss: 0.17345452504017764 0.17345452504017764
epoch: 43
epoch: 43 loss: 0.09952069220043167
eval loss: 0.17088227921840815 0.17088227921840815
epoch: 44
epoch: 44 loss: 0.0973182633228332
eval loss: 0.17009235619843807 0.17009235619843807
epoch: 45
epoch: 45 loss: 0.09695753457838564


eval loss: 0.16870622434976346 0.16870622434976346
epoch: 46
epoch: 46 loss: 0.09716807359923033


eval loss: 0.16841802528294195 0.16841802528294195
epoch: 47
epoch: 47 loss: 0.09632971150767028
eval loss: 0.1713222037884632 0.1713222037884632
epoch: 48
epoch: 48 loss: 0.095681635430036
eval loss: 0.16938366790752088 0.16938366790752088
epoch: 49
epoch: 49 loss: 0.09566728369351747


eval loss: 0.16829380776420141 0.16829380776420141
epoch: 50
epoch: 50 loss: 0.09603162673639563


eval loss: 0.16791686600398936 0.16791686600398936
epoch: 51
epoch: 51 loss: 0.09385074497904115
eval loss: 0.1718490103708462 0.1718490103708462
epoch: 52
epoch: 52 loss: 0.09366078652859196
eval loss: 0.16898601799104443 0.16898601799104443
epoch: 53
epoch: 53 loss: 0.0923652412250632
eval loss: 0.16802728290723123 0.16802728290723123
epoch: 54
epoch: 54 loss: 0.0924217274463274
eval loss: 0.17546550020909804 0.17546550020909804
epoch: 55
epoch: 55 loss: 0.09477460978579223
eval loss: 0.16930888638849226 0.16930888638849226
epoch: 56
epoch: 56 loss: 0.09225244007407639
eval loss: 0.17255288837826066 0.17255288837826066
epoch: 57
epoch: 57 loss: 0.09114660506687591
eval loss: 0.16832474282701337 0.16832474282701337
epoch: 58
epoch: 58 loss: 0.09316098038682213
eval loss: 0.16989234820743726 0.16989234820743726
epoch: 59
epoch: 59 loss: 0.09043658060423689
eval loss: 0.16840744932071414 0.16840744932071414
epoch: 60
epoch: 60 loss: 0.09014280842816666
eval loss: 0.17029268568315634 0.1

eval loss: 0.1656160009840286 0.1656160009840286
epoch: 63
epoch: 63 loss: 0.0896536538343156
eval loss: 0.16707343299582111 0.16707343299582111
epoch: 64
epoch: 64 loss: 0.08790000722529602
eval loss: 0.16800157185779308 0.16800157185779308
epoch: 65
epoch: 65 loss: 0.08792722808874648
eval loss: 0.171340417184633 0.171340417184633
epoch: 66
epoch: 66 loss: 0.08998366469537315
eval loss: 0.17274689527665746 0.17274689527665746
epoch: 67
epoch: 67 loss: 0.08776596186593044
eval loss: 0.16883324938744862 0.16883324938744862
epoch: 68
epoch: 68 loss: 0.08709971670534727
eval loss: 0.16967312484320315 0.16967312484320315
epoch: 69
epoch: 69 loss: 0.08593584134396881
eval loss: 0.16819682004587375 0.16819682004587375
epoch: 70
epoch: 70 loss: 0.08621803051015975
eval loss: 0.16871751203609534 0.16871751203609534
epoch: 71
epoch: 71 loss: 0.08605734954100347
eval loss: 0.1692669774818865 0.1692669774818865
epoch: 72
epoch: 72 loss: 0.08572722278093005
eval loss: 0.1693584808438741 0.1693584

epoch: 147 loss: 0.33141629734732947
eval loss: 0.35816155029260294 0.35816155029260294
epoch: 148
epoch: 148 loss: 0.3325244014787904
eval loss: 0.35971676050798324 0.35971676050798324
epoch: 149
epoch: 149 loss: 0.328998286438479
eval loss: 0.3580578257784088 0.3580578257784088
min eval loss: 0.1656160009840286 min mcrmse: 0.1656160009840286 epoch 62
fold: 1
(1, 107, 107, 3)
(1, 107, 107, 3)
(3045, 23) (762, 23)
(1, 130, 130, 3)
(1, 130, 130, 3)
(2404, 16) (601, 16)
device: cuda:1


epoch: 0
epoch: 0 loss: 0.31152110851244835


eval loss: 0.27341997837128024 0.27341997837128024
epoch: 1
epoch: 1 loss: 0.2203649458604234


eval loss: 0.23708212888168154 0.23708212888168154
epoch: 2
epoch: 2 loss: 0.1961635056132221


eval loss: 0.2248811587756673 0.2248811587756673
epoch: 3
epoch: 3 loss: 0.18105219300260536


eval loss: 0.21560929602815912 0.21560929602815912
epoch: 4
epoch: 4 loss: 0.1683889460278933


eval loss: 0.19978784021369736 0.19978784021369736
epoch: 5
epoch: 5 loss: 0.15999269189387863


eval loss: 0.19823078767903365 0.19823078767903365
epoch: 6
epoch: 6 loss: 0.1552315689855958


eval loss: 0.19245113955329304 0.19245113955329304
epoch: 7
epoch: 7 loss: 0.14871518830160402


eval loss: 0.18775028642800465 0.18775028642800465
epoch: 8
epoch: 8 loss: 0.14507442816895305


eval loss: 0.18681151947482713 0.18681151947482713
epoch: 9
epoch: 9 loss: 0.14084173516774906
eval loss: 0.1898186107891837 0.1898186107891837
epoch: 10
epoch: 10 loss: 0.13714610056773577


eval loss: 0.183655171147438 0.183655171147438
epoch: 11
epoch: 11 loss: 0.13510125419934518


eval loss: 0.17944905462395022 0.17944905462395022
epoch: 12
epoch: 12 loss: 0.134051643970446
eval loss: 0.1801801153354686 0.1801801153354686
epoch: 13
epoch: 13 loss: 0.1307773412946359


eval loss: 0.17857141110986344 0.17857141110986344
epoch: 14
epoch: 14 loss: 0.12917578456758957


eval loss: 0.17525763670454195 0.17525763670454195
epoch: 15
epoch: 15 loss: 0.1266910897623717
eval loss: 0.17926948835848336 0.17926948835848336
epoch: 16
epoch: 16 loss: 0.12555351134509876
eval loss: 0.18688621135096253 0.18688621135096253
epoch: 17
epoch: 17 loss: 0.12533738338699277
eval loss: 0.17809955444261516 0.17809955444261516
epoch: 18
epoch: 18 loss: 0.12385339796585787
eval loss: 0.17671511093962175 0.17671511093962175
epoch: 19
epoch: 19 loss: 0.12410520496325615


eval loss: 0.1738489354064162 0.1738489354064162
epoch: 20
epoch: 20 loss: 0.11964078477573847


eval loss: 0.17074414301988824 0.17074414301988824
epoch: 21
epoch: 21 loss: 0.12038625694584591
eval loss: 0.17122158871785623 0.17122158871785623
epoch: 22
epoch: 22 loss: 0.11823025249264316


eval loss: 0.17003468025919277 0.17003468025919277
epoch: 23
epoch: 23 loss: 0.11985372726092297


eval loss: 0.16758186199035385 0.16758186199035385
epoch: 24
epoch: 24 loss: 0.11774121837424632
eval loss: 0.17218138180191503 0.17218138180191503
epoch: 25
epoch: 25 loss: 0.11662823862276754
eval loss: 0.1693995608637153 0.1693995608637153
epoch: 26
epoch: 26 loss: 0.11907408674965356
eval loss: 0.1694250514510628 0.1694250514510628
epoch: 27
epoch: 27 loss: 0.11501798100072791


eval loss: 0.16698477977957468 0.16698477977957468
epoch: 28
epoch: 28 loss: 0.113080476439299
eval loss: 0.1678188372730133 0.1678188372730133
epoch: 29
epoch: 29 loss: 0.11205734719306253
eval loss: 0.16938092111121772 0.16938092111121772
epoch: 30
epoch: 30 loss: 0.1142016233523466
eval loss: 0.17096066006537924 0.17096066006537924
epoch: 31
epoch: 31 loss: 0.1105357108567972


eval loss: 0.1637861132449098 0.1637861132449098
epoch: 32
epoch: 32 loss: 0.10866938875744037
eval loss: 0.16395597279072363 0.16395597279072363
epoch: 33
epoch: 33 loss: 0.10925340675059712
eval loss: 0.16484845810432233 0.16484845810432233
epoch: 34
epoch: 34 loss: 0.10803490166057221


eval loss: 0.16296662839451706 0.16296662839451706
epoch: 35
epoch: 35 loss: 0.10783175908260371
eval loss: 0.16574773501630627 0.16574773501630627
epoch: 36
epoch: 36 loss: 0.10764795161618908
eval loss: 0.16461066984739833 0.16461066984739833
epoch: 37
epoch: 37 loss: 0.10735663138599941
eval loss: 0.17685800678785302 0.17685800678785302
epoch: 38
epoch: 38 loss: 0.10793110389516593
eval loss: 0.16304461709387508 0.16304461709387508
epoch: 39
epoch: 39 loss: 0.10526354424950106
eval loss: 0.16341189839695094 0.16341189839695094
epoch: 40
epoch: 40 loss: 0.10586704974010347
eval loss: 0.16405100842892376 0.16405100842892376
epoch: 41
epoch: 41 loss: 0.10565450226864097
eval loss: 0.16540880996003937 0.16540880996003937
epoch: 42
epoch: 42 loss: 0.10290710493143441
eval loss: 0.16502805463515072 0.16502805463515072
epoch: 43
epoch: 43 loss: 0.10403327574653659
eval loss: 0.16547374610218732 0.16547374610218732
epoch: 44
epoch: 44 loss: 0.10206898594537196
eval loss: 0.16350047824430683

eval loss: 0.16096700224332747 0.16096700224332747
epoch: 49
epoch: 49 loss: 0.09954354004467036
eval loss: 0.16217947851384992 0.16217947851384992
epoch: 50
epoch: 50 loss: 0.10108063098059539
eval loss: 0.1618693255296531 0.1618693255296531
epoch: 51
epoch: 51 loss: 0.09961702815101628
eval loss: 0.16268185750349612 0.16268185750349612
epoch: 52
epoch: 52 loss: 0.09815237270075768
eval loss: 0.1626993439437442 0.1626993439437442
epoch: 53
epoch: 53 loss: 0.10023855561425674
eval loss: 0.16182428432677443 0.16182428432677443
epoch: 54
epoch: 54 loss: 0.10057889729860725
eval loss: 0.16393698959431624 0.16393698959431624
epoch: 55
epoch: 55 loss: 0.09741924054945374
eval loss: 0.1614297697639939 0.1614297697639939
epoch: 56
epoch: 56 loss: 0.09639829738872527
eval loss: 0.16429386837625218 0.16429386837625218
epoch: 57
epoch: 57 loss: 0.10053358392692875
eval loss: 0.1616901479039654 0.1616901479039654
epoch: 58
epoch: 58 loss: 0.09735975250480582
eval loss: 0.1626698123455763 0.162669

epoch: 134
epoch: 134 loss: 0.3143890821187298
eval loss: 0.34705295060877533 0.34705295060877533
epoch: 135
epoch: 135 loss: 0.3175117650903228
eval loss: 0.34871375090335704 0.34871375090335704
epoch: 136
epoch: 136 loss: 0.31677917481353607
eval loss: 0.34251427341605845 0.34251427341605845
epoch: 137
epoch: 137 loss: 0.31531893294336194
eval loss: 0.34305112217201067 0.34305112217201067
epoch: 138
epoch: 138 loss: 0.3118210083904876
eval loss: 0.33928026682204004 0.33928026682204004
epoch: 139
epoch: 139 loss: 0.3105826380003245
eval loss: 0.33847506210635103 0.33847506210635103
epoch: 140
epoch: 140 loss: 0.31015493452621395
eval loss: 0.33821519518603643 0.33821519518603643
epoch: 141
epoch: 141 loss: 0.3120737897270862
eval loss: 0.3438273436508903 0.3438273436508903
epoch: 142
epoch: 142 loss: 0.3143289459351606
eval loss: 0.34367458445486965 0.34367458445486965
epoch: 143
epoch: 143 loss: 0.3118775613143845
eval loss: 0.3420490120000039 0.3420490120000039
epoch: 144
epoch: 144

epoch: 0
epoch: 0 loss: 0.27533954453492987


eval loss: 0.2578477311060042 0.2578477311060042
epoch: 1
epoch: 1 loss: 0.20989604904883288


eval loss: 0.23864290173041547 0.23864290173041547
epoch: 2
epoch: 2 loss: 0.1890919394098314


eval loss: 0.22352910962319059 0.22352910962319059
epoch: 3
epoch: 3 loss: 0.1756779174206408


eval loss: 0.21201088915133534 0.21201088915133534
epoch: 4
epoch: 4 loss: 0.16544854584116728


eval loss: 0.2039941265083497 0.2039941265083497
epoch: 5
epoch: 5 loss: 0.1577717189251038


eval loss: 0.19838611281855512 0.19838611281855512
epoch: 6
epoch: 6 loss: 0.15247410934515845


eval loss: 0.19089944370440293 0.19089944370440293
epoch: 7
epoch: 7 loss: 0.14713257672077076


eval loss: 0.18761336358713834 0.18761336358713834
epoch: 8
epoch: 8 loss: 0.1432078200681094


eval loss: 0.18548232081509144 0.18548232081509144
epoch: 9
epoch: 9 loss: 0.14196513119887844


eval loss: 0.18494741346271845 0.18494741346271845
epoch: 10
epoch: 10 loss: 0.1357093359438675


eval loss: 0.1797273153602541 0.1797273153602541
epoch: 11
epoch: 11 loss: 0.13266978106667113


eval loss: 0.17762606861843575 0.17762606861843575
epoch: 12
epoch: 12 loss: 0.13040766281256752


eval loss: 0.17612939553342136 0.17612939553342136
epoch: 13
epoch: 13 loss: 0.1287065649940328


eval loss: 0.17566716902929422 0.17566716902929422
epoch: 14
epoch: 14 loss: 0.12777521215417897


eval loss: 0.17465612830178445 0.17465612830178445
epoch: 15
epoch: 15 loss: 0.1243917473376048


eval loss: 0.17247952792990162 0.17247952792990162
epoch: 16
epoch: 16 loss: 0.1246118023980773


eval loss: 0.16984807887771006 0.16984807887771006
epoch: 17
epoch: 17 loss: 0.12190296499034618
eval loss: 0.1790741209554214 0.1790741209554214
epoch: 18
epoch: 18 loss: 0.12048856110867653
eval loss: 0.17102494730373152 0.17102494730373152
epoch: 19
epoch: 19 loss: 0.11911388857007768


eval loss: 0.16978584368969685 0.16978584368969685
epoch: 20
epoch: 20 loss: 0.11793942508273364


eval loss: 0.1676946383923672 0.1676946383923672
epoch: 21
epoch: 21 loss: 0.1157514637942288
eval loss: 0.16930545195375069 0.16930545195375069
epoch: 22
epoch: 22 loss: 0.11585549592158459
eval loss: 0.16902098951018893 0.16902098951018893
epoch: 23
epoch: 23 loss: 0.1150568726668746
eval loss: 0.17109597990045247 0.17109597990045247
epoch: 24
epoch: 24 loss: 0.11453081048601717


eval loss: 0.16752054052471 0.16752054052471
epoch: 25
epoch: 25 loss: 0.11157515468096808


eval loss: 0.1645797941377221 0.1645797941377221
epoch: 26
epoch: 26 loss: 0.11183371889256127
eval loss: 0.16827509734425286 0.16827509734425286
epoch: 27
epoch: 27 loss: 0.10972942395393182
eval loss: 0.16556594277423956 0.16556594277423956
epoch: 28
epoch: 28 loss: 0.10925870543081258
eval loss: 0.1667768198418755 0.1667768198418755
epoch: 29
epoch: 29 loss: 0.1096217566253164
eval loss: 0.1658571654353168 0.1658571654353168
epoch: 30
epoch: 30 loss: 0.10795790166607767


eval loss: 0.16379283078691356 0.16379283078691356
epoch: 31
epoch: 31 loss: 0.10677164084862201
eval loss: 0.16539584812774819 0.16539584812774819
epoch: 32
epoch: 32 loss: 0.10583802225755189
eval loss: 0.17275620750122836 0.17275620750122836
epoch: 33
epoch: 33 loss: 0.10511448851921522


eval loss: 0.1635572402304887 0.1635572402304887
epoch: 34
epoch: 34 loss: 0.10370700795777518


eval loss: 0.162365516417803 0.162365516417803
epoch: 35
epoch: 35 loss: 0.10699549221493335
eval loss: 0.16292449169345483 0.16292449169345483
epoch: 36
epoch: 36 loss: 0.10237728316959763


eval loss: 0.16198968548957388 0.16198968548957388
epoch: 37
epoch: 37 loss: 0.10281191525015299
eval loss: 0.167214242284644 0.167214242284644
epoch: 38
epoch: 38 loss: 0.10210434347113247


eval loss: 0.16192155471585395 0.16192155471585395
epoch: 39
epoch: 39 loss: 0.10173396516467313


eval loss: 0.1609064180155388 0.1609064180155388
epoch: 40
epoch: 40 loss: 0.1001425178323239
eval loss: 0.1622327986978592 0.1622327986978592
epoch: 41
epoch: 41 loss: 0.0990991837543414


eval loss: 0.16084132995072317 0.16084132995072317
epoch: 42
epoch: 42 loss: 0.0996992259671634


eval loss: 0.16074955140595995 0.16074955140595995
epoch: 43
epoch: 43 loss: 0.1004611079704874


eval loss: 0.16054398976982856 0.16054398976982856
epoch: 44
epoch: 44 loss: 0.09822678446012947
eval loss: 0.16181036930558898 0.16181036930558898
epoch: 45
epoch: 45 loss: 0.09860517550624065
eval loss: 0.16421372914128837 0.16421372914128837
epoch: 46
epoch: 46 loss: 0.09871038922772316
eval loss: 0.1681330712107514 0.1681330712107514
epoch: 47
epoch: 47 loss: 0.09834392484300326
eval loss: 0.16252882611468575 0.16252882611468575
epoch: 48
epoch: 48 loss: 0.096456497387369


eval loss: 0.15978805647107117 0.15978805647107117
epoch: 49
epoch: 49 loss: 0.09490483832890034
eval loss: 0.1625639708476818 0.1625639708476818
epoch: 50
epoch: 50 loss: 0.09453797026827034
eval loss: 0.16026967548752472 0.16026967548752472
epoch: 51
epoch: 51 loss: 0.09459072845237909


eval loss: 0.1593678356148955 0.1593678356148955
epoch: 52
epoch: 52 loss: 0.09616625074723524
eval loss: 0.16150294280146274 0.16150294280146274
epoch: 53
epoch: 53 loss: 0.0950744322234205
eval loss: 0.1625733796240326 0.1625733796240326
epoch: 54
epoch: 54 loss: 0.094876578376603
eval loss: 0.16134819599566333 0.16134819599566333
epoch: 55
epoch: 55 loss: 0.09640002998019521
eval loss: 0.16028848159031497 0.16028848159031497
epoch: 56
epoch: 56 loss: 0.09338349235214846
eval loss: 0.16078282258683316 0.16078282258683316
epoch: 57
epoch: 57 loss: 0.09386671225929001
eval loss: 0.16423113190307184 0.16423113190307184
epoch: 58
epoch: 58 loss: 0.09187604111877833
eval loss: 0.16132980549259723 0.16132980549259723
epoch: 59
epoch: 59 loss: 0.091959205971643


eval loss: 0.15872754950860954 0.15872754950860954
epoch: 60
epoch: 60 loss: 0.09152956299514986
eval loss: 0.16136799823935175 0.16136799823935175
epoch: 61
epoch: 61 loss: 0.0904718880064009
eval loss: 0.15986845324699614 0.15986845324699614
epoch: 62
epoch: 62 loss: 0.09124137146589843
eval loss: 0.1610863256967753 0.1610863256967753
epoch: 63
epoch: 63 loss: 0.09053118421647997
eval loss: 0.15951182043065656 0.15951182043065656
epoch: 64
epoch: 64 loss: 0.09017393581033628
eval loss: 0.16035491358256093 0.16035491358256093
epoch: 65
epoch: 65 loss: 0.09053814157657539
eval loss: 0.1623507059427918 0.1623507059427918
epoch: 66
epoch: 66 loss: 0.08864759986412284
eval loss: 0.1590894585883633 0.1590894585883633
epoch: 67
epoch: 67 loss: 0.09071645546007863
eval loss: 0.16214548390442943 0.16214548390442943
epoch: 68
epoch: 68 loss: 0.09006457174538685
eval loss: 0.159197079898641 0.159197079898641
epoch: 69
epoch: 69 loss: 0.08735823922999307
eval loss: 0.15933222984202108 0.15933222

epoch: 145
epoch: 145 loss: 0.3282341074292684
eval loss: 0.36399836879180936 0.36399836879180936
epoch: 146
epoch: 146 loss: 0.32818271299890706
eval loss: 0.3604382163740009 0.3604382163740009
epoch: 147
epoch: 147 loss: 0.3264930645145361
eval loss: 0.3601011846880735 0.3601011846880735
epoch: 148
epoch: 148 loss: 0.3254999799932768
eval loss: 0.3571696501476924 0.3571696501476924
epoch: 149
epoch: 149 loss: 0.32506137586217687
eval loss: 0.362560332891431 0.362560332891431
min eval loss: 0.15872754950860954 min mcrmse: 0.15872754950860954 epoch 59
fold: 3
(1, 107, 107, 3)
(1, 107, 107, 3)
(3046, 23) (761, 23)
(1, 130, 130, 3)
(1, 130, 130, 3)
(2404, 16) (601, 16)
device: cuda:1


epoch: 0
epoch: 0 loss: 0.2934841860369835


eval loss: 0.2679976945937938 0.2679976945937938
epoch: 1
epoch: 1 loss: 0.21724215749711834


eval loss: 0.2485765628407004 0.2485765628407004
epoch: 2
epoch: 2 loss: 0.1923432694162525


eval loss: 0.22542516926764308 0.22542516926764308
epoch: 3
epoch: 3 loss: 0.17370227397514648


eval loss: 0.21312797336384895 0.21312797336384895
epoch: 4
epoch: 4 loss: 0.16386388006367095


eval loss: 0.20425596390196746 0.20425596390196746
epoch: 5
epoch: 5 loss: 0.15429004301976185


eval loss: 0.20100494337051492 0.20100494337051492
epoch: 6
epoch: 6 loss: 0.14877361162498312
eval loss: 0.20319055098083114 0.20319055098083114
epoch: 7
epoch: 7 loss: 0.1442794753551702
eval loss: 0.20586895697058005 0.20586895697058005
epoch: 8
epoch: 8 loss: 0.1403785639881727


eval loss: 0.191542868812925 0.191542868812925
epoch: 9
epoch: 9 loss: 0.135383567819488


eval loss: 0.1899796081583225 0.1899796081583225
epoch: 10
epoch: 10 loss: 0.13149222238333244


eval loss: 0.18584534329754362 0.18584534329754362
epoch: 11
epoch: 11 loss: 0.12882805714130965
eval loss: 0.18738145948515786 0.18738145948515786
epoch: 12
epoch: 12 loss: 0.12675540839221575
eval loss: 0.18597104342553125 0.18597104342553125
epoch: 13
epoch: 13 loss: 0.12569452446625368


eval loss: 0.18093367803167057 0.18093367803167057
epoch: 14
epoch: 14 loss: 0.12266299020717329


eval loss: 0.1798740376026017 0.1798740376026017
epoch: 15
epoch: 15 loss: 0.1222302694092468


eval loss: 0.17908077303576805 0.17908077303576805
epoch: 16
epoch: 16 loss: 0.11913387545037611


eval loss: 0.1781599715705338 0.1781599715705338
epoch: 17
epoch: 17 loss: 0.11813534397828533
eval loss: 0.178695277000764 0.178695277000764
epoch: 18
epoch: 18 loss: 0.11708432815596347
eval loss: 0.1831002726788029 0.1831002726788029
epoch: 19
epoch: 19 loss: 0.11759032984322844
eval loss: 0.18130748704642807 0.18130748704642807
epoch: 20
epoch: 20 loss: 0.11563558033379771


eval loss: 0.1754833246895204 0.1754833246895204
epoch: 21
epoch: 21 loss: 0.1129364879723655


eval loss: 0.174707550683067 0.174707550683067
epoch: 22
epoch: 22 loss: 0.11226242390359376
eval loss: 0.17618780874913584 0.17618780874913584
epoch: 23
epoch: 23 loss: 0.11055970230145142
eval loss: 0.17645639874265776 0.17645639874265776
epoch: 24
epoch: 24 loss: 0.11008776145012315


eval loss: 0.1737559072420356 0.1737559072420356
epoch: 25
epoch: 25 loss: 0.1089280081223054
eval loss: 0.17693402594702298 0.17693402594702298
epoch: 26
epoch: 26 loss: 0.10941675996057285
eval loss: 0.17772387553385546 0.17772387553385546
epoch: 27
epoch: 27 loss: 0.10794523975225881
eval loss: 0.17476435507893526 0.17476435507893526
epoch: 28
epoch: 28 loss: 0.10615883659307948
eval loss: 0.1761021833247156 0.1761021833247156
epoch: 29
epoch: 29 loss: 0.10707101165868212


eval loss: 0.17152393870530347 0.17152393870530347
epoch: 30
epoch: 30 loss: 0.10435483221640124
eval loss: 0.17370674109743386 0.17370674109743386
epoch: 31
epoch: 31 loss: 0.10295937861369467


eval loss: 0.17023480753725642 0.17023480753725642
epoch: 32
epoch: 32 loss: 0.10322636845937172
eval loss: 0.17546844263630867 0.17546844263630867
epoch: 33
epoch: 33 loss: 0.10225380955721172


eval loss: 0.16971413408323135 0.16971413408323135
epoch: 34
epoch: 34 loss: 0.10091511424274258
eval loss: 0.17052755792771115 0.17052755792771115
epoch: 35
epoch: 35 loss: 0.10083998144543847
eval loss: 0.17195656479876095 0.17195656479876095
epoch: 36
epoch: 36 loss: 0.09947084930752587


eval loss: 0.16912314907606305 0.16912314907606305
epoch: 37
epoch: 37 loss: 0.09961113018179792
eval loss: 0.1708776319274029 0.1708776319274029
epoch: 38
epoch: 38 loss: 0.09881032283505764
eval loss: 0.1700948957383046 0.1700948957383046
epoch: 39
epoch: 39 loss: 0.09769001971489964


eval loss: 0.16909186128357 0.16909186128357
epoch: 40
epoch: 40 loss: 0.09934352701196834
eval loss: 0.16996266391997847 0.16996266391997847
epoch: 41
epoch: 41 loss: 0.0977964882321232
eval loss: 0.17082077884554417 0.17082077884554417
epoch: 42
epoch: 42 loss: 0.09558555877728846
eval loss: 0.17382652042060476 0.17382652042060476
epoch: 43
epoch: 43 loss: 0.09695991904248447
eval loss: 0.1695977938917407 0.1695977938917407
epoch: 44
epoch: 44 loss: 0.0948977567657664
eval loss: 0.17056169990335357 0.17056169990335357
epoch: 45
epoch: 45 loss: 0.09438416659642763


eval loss: 0.16780385913307205 0.16780385913307205
epoch: 46
epoch: 46 loss: 0.0940434976483105
eval loss: 0.16941299433991683 0.16941299433991683
epoch: 47
epoch: 47 loss: 0.09582586019629637
eval loss: 0.17384103668098505 0.17384103668098505
epoch: 48
epoch: 48 loss: 0.09345339510426132


eval loss: 0.16705090969380323 0.16705090969380323
epoch: 49
epoch: 49 loss: 0.09333363395975436
eval loss: 0.16816359748958865 0.16816359748958865
epoch: 50
epoch: 50 loss: 0.0915125047722032
eval loss: 0.1715363197066052 0.1715363197066052
epoch: 51
epoch: 51 loss: 0.09196094355084018
eval loss: 0.16879078667562086 0.16879078667562086
epoch: 52
epoch: 52 loss: 0.09121006503056074
eval loss: 0.16809554588281103 0.16809554588281103
epoch: 53
epoch: 53 loss: 0.09065609576240807
eval loss: 0.16824491133019517 0.16824491133019517
epoch: 54
epoch: 54 loss: 0.08960289301198161
eval loss: 0.1681247949767848 0.1681247949767848
epoch: 55
epoch: 55 loss: 0.08940061822038098
eval loss: 0.16784916200122077 0.16784916200122077
epoch: 56
epoch: 56 loss: 0.08873138108023143
eval loss: 0.16762196216191683 0.16762196216191683
epoch: 57
epoch: 57 loss: 0.08857182110286876
eval loss: 0.1677130316541505 0.1677130316541505
epoch: 58
epoch: 58 loss: 0.08841058427864892
eval loss: 0.1698116839478554 0.16981

eval loss: 0.1663051792894458 0.1663051792894458
epoch: 61
epoch: 61 loss: 0.08944719019771417
eval loss: 0.1696566667993963 0.1696566667993963
epoch: 62
epoch: 62 loss: 0.08642617804232997
eval loss: 0.16908808020120278 0.16908808020120278
epoch: 63
epoch: 63 loss: 0.08625787578011368
eval loss: 0.16692832233173438 0.16692832233173438
epoch: 64
epoch: 64 loss: 0.08733205951446914
eval loss: 0.1677540562652127 0.1677540562652127
epoch: 65
epoch: 65 loss: 0.0874776227393943
eval loss: 0.1674159046409008 0.1674159046409008
epoch: 66
epoch: 66 loss: 0.08585916510593564


eval loss: 0.16613025308070564 0.16613025308070564
epoch: 67
epoch: 67 loss: 0.08514389452086363
eval loss: 0.1716490202961091 0.1716490202961091
epoch: 68
epoch: 68 loss: 0.0849435920346823
eval loss: 0.16713317315194617 0.16713317315194617
epoch: 69
epoch: 69 loss: 0.08508734313579426
eval loss: 0.1670510479027715 0.1670510479027715
epoch: 70
epoch: 70 loss: 0.08340131416735769
eval loss: 0.16635935420834494 0.16635935420834494
epoch: 71
epoch: 71 loss: 0.08329726175323596
eval loss: 0.16955118626350998 0.16955118626350998
epoch: 72
epoch: 72 loss: 0.0827129521357869
eval loss: 0.16990481638168767 0.16990481638168767
epoch: 73
epoch: 73 loss: 0.08275736479318449
eval loss: 0.1699732954309414 0.1699732954309414
epoch: 74
epoch: 74 loss: 0.08243687765382914
eval loss: 0.166874593886675 0.166874593886675
epoch: 75
epoch: 75 loss: 0.08248753999659969


eval loss: 0.16565628525040713 0.16565628525040713
epoch: 76
epoch: 76 loss: 0.08161745057667087
eval loss: 0.1675997476552431 0.1675997476552431
epoch: 77
epoch: 77 loss: 0.08136534271801228
eval loss: 0.170714211912019 0.170714211912019
epoch: 78
epoch: 78 loss: 0.08174062068913056
eval loss: 0.16839165353243724 0.16839165353243724
epoch: 79
epoch: 79 loss: 0.0809452966954849
eval loss: 0.16715949514621423 0.16715949514621423
epoch: 80
epoch: 80 loss: 0.08170804079121097
eval loss: 0.167671106140475 0.167671106140475
epoch: 81
epoch: 81 loss: 0.08060521520726623
eval loss: 0.16803889227491184 0.16803889227491184
epoch: 82
epoch: 82 loss: 0.0814364122799837
eval loss: 0.1697330195936664 0.1697330195936664
epoch: 83
epoch: 83 loss: 0.08088895336144203
eval loss: 0.1722072690799968 0.1722072690799968
epoch: 84
epoch: 84 loss: 0.07999444333326144
eval loss: 0.1713323796909432 0.1713323796909432
epoch: 85
epoch: 85 loss: 0.08006504453325859
eval loss: 0.16736781137281542 0.167367811372815

epoch: 0
epoch: 0 loss: 0.2915149955695964


eval loss: 0.2651013919832455 0.2651013919832455
epoch: 1
epoch: 1 loss: 0.21261025195870084


eval loss: 0.24125434040981023 0.24125434040981023
epoch: 2
epoch: 2 loss: 0.19100319551979233


eval loss: 0.22575255483439371 0.22575255483439371
epoch: 3
epoch: 3 loss: 0.17725571448570318


eval loss: 0.22448871734472675 0.22448871734472675
epoch: 4
epoch: 4 loss: 0.16682449782594705


eval loss: 0.2041298729948632 0.2041298729948632
epoch: 5
epoch: 5 loss: 0.1585320417382682


eval loss: 0.20189635204109002 0.20189635204109002
epoch: 6
epoch: 6 loss: 0.15053350560698947


eval loss: 0.19530392429902635 0.19530392429902635
epoch: 7
epoch: 7 loss: 0.14522926632725647


eval loss: 0.1914787932150289 0.1914787932150289
epoch: 8
epoch: 8 loss: 0.14101900632877068


eval loss: 0.1879505855493732 0.1879505855493732
epoch: 9
epoch: 9 loss: 0.13852495531608136


eval loss: 0.18481601886861668 0.18481601886861668
epoch: 10
epoch: 10 loss: 0.1346759094795216


eval loss: 0.18232446709778993 0.18232446709778993
epoch: 11
epoch: 11 loss: 0.13106428985533466
eval loss: 0.18516814541832094 0.18516814541832094
epoch: 12
epoch: 12 loss: 0.12824936565289968


eval loss: 0.17937822849646023 0.17937822849646023
epoch: 13
epoch: 13 loss: 0.12598104354931444
eval loss: 0.18323478229830228 0.18323478229830228
epoch: 14
epoch: 14 loss: 0.1251505846385419


eval loss: 0.17774776680035173 0.17774776680035173
epoch: 15
epoch: 15 loss: 0.12263800215779216


eval loss: 0.1757689311655691 0.1757689311655691
epoch: 16
epoch: 16 loss: 0.1214548597227272


eval loss: 0.17559124628677683 0.17559124628677683
epoch: 17
epoch: 17 loss: 0.1192278091877343


eval loss: 0.17318201951316167 0.17318201951316167
epoch: 18
epoch: 18 loss: 0.11769021194373576
eval loss: 0.17485812835790382 0.17485812835790382
epoch: 19
epoch: 19 loss: 0.11789992496007136
eval loss: 0.1755170271431945 0.1755170271431945
epoch: 20
epoch: 20 loss: 0.11543535104995431
eval loss: 0.17382008310277922 0.17382008310277922
epoch: 21
epoch: 21 loss: 0.11444372521404632


eval loss: 0.1708763960256684 0.1708763960256684
epoch: 22
epoch: 22 loss: 0.1125238330972912
eval loss: 0.17134562984068738 0.17134562984068738
epoch: 23
epoch: 23 loss: 0.11253885532049845
eval loss: 0.1732820446559649 0.1732820446559649
epoch: 24
epoch: 24 loss: 0.1113166045019456


eval loss: 0.16885493979416524 0.16885493979416524
epoch: 25
epoch: 25 loss: 0.11013268075879465
eval loss: 0.1704434618473759 0.1704434618473759
epoch: 26
epoch: 26 loss: 0.10939009471332811
eval loss: 0.1697487332008527 0.1697487332008527
epoch: 27
epoch: 27 loss: 0.10872900971042071
eval loss: 0.17000105688894465 0.17000105688894465
epoch: 28
epoch: 28 loss: 0.10726177264222857


eval loss: 0.16833815427114096 0.16833815427114096
epoch: 29
epoch: 29 loss: 0.10590655637215295


eval loss: 0.1672125336826309 0.1672125336826309
epoch: 30
epoch: 30 loss: 0.10456544918811998


eval loss: 0.16680973611929603 0.16680973611929603
epoch: 31
epoch: 31 loss: 0.10502823883875692
eval loss: 0.169708927564853 0.169708927564853
epoch: 32
epoch: 32 loss: 0.10466228478295346
eval loss: 0.16805249134529424 0.16805249134529424
epoch: 33
epoch: 33 loss: 0.10379878127015409
eval loss: 0.16793926848785845 0.16793926848785845
epoch: 34
epoch: 34 loss: 0.10275607880275506
eval loss: 0.16899202676975125 0.16899202676975125
epoch: 35
epoch: 35 loss: 0.10106303513694924


eval loss: 0.16588203558397674 0.16588203558397674
epoch: 36
epoch: 36 loss: 0.10129905696754643
eval loss: 0.16675824855526108 0.16675824855526108
epoch: 37
epoch: 37 loss: 0.09918566497826647
eval loss: 0.1666178494367155 0.1666178494367155
epoch: 38
epoch: 38 loss: 0.09961219761414729
eval loss: 0.16608005186340588 0.16608005186340588
epoch: 39
epoch: 39 loss: 0.09965742619651703


eval loss: 0.16488907907414788 0.16488907907414788
epoch: 40
epoch: 40 loss: 0.09783773119102845
eval loss: 0.16529280266480914 0.16529280266480914
epoch: 41
epoch: 41 loss: 0.09668438742818887
eval loss: 0.16502624625602702 0.16502624625602702
epoch: 42
epoch: 42 loss: 0.09632655568322761
eval loss: 0.16537683642899134 0.16537683642899134
epoch: 43
epoch: 43 loss: 0.0972176709638243
eval loss: 0.16504900778302084 0.16504900778302084
epoch: 44
epoch: 44 loss: 0.09708896461444047
eval loss: 0.1667988983254716 0.1667988983254716
epoch: 45
epoch: 45 loss: 0.09477534670163137
eval loss: 0.1678816745972432 0.1678816745972432
epoch: 46
epoch: 46 loss: 0.09657393378458276


eval loss: 0.16487641162894745 0.16487641162894745
epoch: 47
epoch: 47 loss: 0.09650614777024207


eval loss: 0.16338028267641594 0.16338028267641594
epoch: 48
epoch: 48 loss: 0.09333501403979504


eval loss: 0.16275318489416915 0.16275318489416915
epoch: 49
epoch: 49 loss: 0.09423255203367925
eval loss: 0.16716502998050417 0.16716502998050417
epoch: 50
epoch: 50 loss: 0.09216191677816922
eval loss: 0.163313597871277 0.163313597871277
epoch: 51
epoch: 51 loss: 0.092063543445322
eval loss: 0.16302260851240663 0.16302260851240663
epoch: 52
epoch: 52 loss: 0.0930640347787958
eval loss: 0.16675889981725345 0.16675889981725345
epoch: 53
epoch: 53 loss: 0.0919933049783522
eval loss: 0.16667839298198925 0.16667839298198925
epoch: 54
epoch: 54 loss: 0.09099090864699665
eval loss: 0.167342112060737 0.167342112060737
epoch: 55
epoch: 55 loss: 0.0908735365762644


eval loss: 0.16247049371780578 0.16247049371780578
epoch: 56
epoch: 56 loss: 0.09016928304918656


eval loss: 0.16228038513450624 0.16228038513450624
epoch: 57
epoch: 57 loss: 0.0887125525562936
eval loss: 0.16509418817028165 0.16509418817028165
epoch: 58
epoch: 58 loss: 0.08967807136862234
eval loss: 0.16326157067929833 0.16326157067929833
epoch: 59
epoch: 59 loss: 0.09340627013772389
eval loss: 0.16459933326527187 0.16459933326527187
epoch: 60
epoch: 60 loss: 0.09062865157671643
eval loss: 0.1640413692934521 0.1640413692934521
epoch: 61
epoch: 61 loss: 0.08908713700380533
eval loss: 0.1629358352540209 0.1629358352540209
epoch: 62
epoch: 62 loss: 0.08858021406721504
eval loss: 0.163507350273373 0.163507350273373
epoch: 63
epoch: 63 loss: 0.08719879276535082


eval loss: 0.1616144785568587 0.1616144785568587
epoch: 64
epoch: 64 loss: 0.08667113755912595
eval loss: 0.16514213374970058 0.16514213374970058
epoch: 65
epoch: 65 loss: 0.08614668949504073
eval loss: 0.16265384236147787 0.16265384236147787
epoch: 66
epoch: 66 loss: 0.08562281452058883
eval loss: 0.16362618353684052 0.16362618353684052
epoch: 67
epoch: 67 loss: 0.08460358598412955
eval loss: 0.1644209336514047 0.1644209336514047
epoch: 68
epoch: 68 loss: 0.0848064645627343


eval loss: 0.16142933778782279 0.16142933778782279
epoch: 69
epoch: 69 loss: 0.08371413586108961
eval loss: 0.16310666112236716 0.16310666112236716
epoch: 70
epoch: 70 loss: 0.08522664940118867
eval loss: 0.16179916626330162 0.16179916626330162
epoch: 71
epoch: 71 loss: 0.08422335897124342
eval loss: 0.16264296837514386 0.16264296837514386
epoch: 72
epoch: 72 loss: 0.08281282986183057
eval loss: 0.1635028932634329 0.1635028932634329
epoch: 73
epoch: 73 loss: 0.08329071394404083
eval loss: 0.16428739695683034 0.16428739695683034
epoch: 74
epoch: 74 loss: 0.08300791362809655
eval loss: 0.1648388226673769 0.1648388226673769
epoch: 75
epoch: 75 loss: 0.08213866807615874
eval loss: 0.1650974273057287 0.1650974273057287
epoch: 76
epoch: 76 loss: 0.08636800151943835
eval loss: 0.16300509337945998 0.16300509337945998
epoch: 77
epoch: 77 loss: 0.08490529265653811
eval loss: 0.1659781371055881 0.1659781371055881
epoch: 78
epoch: 78 loss: 0.08273538265432578
eval loss: 0.1637187647627174 0.163718

In [11]:
oof_df = pd.DataFrame(oof_list, columns=["id_seqpos"] + target_cols)
oof_df.to_csv('validation_aepytorch.csv', index=False)

### Prediction

In [12]:
device = torch.device('cuda:1') if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1
base_test_data = pd.read_json(str(Path(BASE_PATH) / 'test.json'), lines=True)
base_test_data = aug_data(base_test_data)

public_df = base_test_data.query("seq_length == 107").copy()
private_df = base_test_data.query("seq_length == 130").copy()
print(f"public_df: {public_df.shape}")
print(f"private_df: {private_df.shape}")
public_df = public_df.reset_index()
private_df = private_df.reset_index()
pub_loader = create_loader(public_df, BATCH_SIZE, is_test=True)
pri_loader = create_loader(private_df, BATCH_SIZE, is_test=True)
pred_df_list = []
c = 0
for fold in range(5):
    model_load_path = f"./model_prediction/model-{fold}.pt"
    ae_model0 = AEModel()
    ae_model1 = AEModel()
    model_pub = FromAeModel(pred_len=107, seq=ae_model0.seq)
    model_pub = model_pub.to(device)
    model_pri = FromAeModel(pred_len=130, seq=ae_model1.seq)
    model_pri = model_pri.to(device)
    state_dict = torch.load(model_load_path, map_location=device)
    model_pub.load_state_dict(state_dict)
    model_pri.load_state_dict(state_dict)
    del state_dict

    data_list = []
    data_list += predict_data(model_pub, pub_loader, device, BATCH_SIZE)
    data_list += predict_data(model_pri, pri_loader, device, BATCH_SIZE)
    pred_df = pd.DataFrame(data_list, columns=["id_seqpos"] + target_cols)
    #print(pred_df.head())
    #print(pred_df.tail())
    pred_df_list.append(pred_df)
    c += 1
data_dic = dict(id_seqpos=pred_df_list[0]["id_seqpos"])
for col in target_cols:
    vals = np.zeros(pred_df_list[0][col].shape[0])
    for df in pred_df_list:
        vals += df[col].values
    data_dic[col] = vals / float(c)
pred_df_avg = pd.DataFrame(data_dic, columns=["id_seqpos"] + target_cols)
#print(pred_df_avg.head())
#pred_df_avg.to_csv("./submission_all.csv", index=False)

##### submission final
sample_df = pd.read_csv('../OpenVaccine/sample_submission.csv')
target_cols = [c for c in sample_df.columns if c != 'id_seqpos']
list_id = list(sample_df.id_seqpos.values)
output = {}
output_df = pd.DataFrame({'id_seqpos': sample_df.id_seqpos.values})

for c in target_cols:
    output_values = []
    x = pred_df_avg.groupby('id_seqpos')[c].mean().reset_index()
    print(x.shape)
    output_df = pd.merge(output_df, x, on='id_seqpos')
    
output_df.to_csv('submission_aepytorch.csv', index=False)
print(output_df.shape)

public_df: (1258, 10)
private_df: (6010, 10)
(1, 107, 107, 3)
(1, 130, 130, 3)


(457953, 2)
(457953, 2)
(457953, 2)
(457953, 2)
(457953, 2)
(457953, 6)


In [13]:
output_df.head(10)

Unnamed: 0,id_seqpos,reactivity,deg_Mg_pH10,deg_pH10,deg_Mg_50C,deg_50C
0,id_00073f8be_0,0.699546,0.64135,1.813082,0.526265,0.696698
1,id_00073f8be_1,2.198713,3.088485,4.026186,3.064846,2.537296
2,id_00073f8be_2,1.573336,0.5917,0.687707,0.687055,0.638836
3,id_00073f8be_3,1.304416,1.074144,1.07941,1.533832,1.484126
4,id_00073f8be_4,0.838901,0.604326,0.508964,0.847205,0.773494
5,id_00073f8be_5,0.674927,0.603603,0.591096,0.712282,0.681853
6,id_00073f8be_6,0.74029,0.900757,0.84498,0.946156,0.849289
7,id_00073f8be_7,0.853472,1.024172,1.078439,0.984978,1.205088
8,id_00073f8be_8,0.203472,0.74566,0.736198,0.831221,0.644785
9,id_00073f8be_9,0.067765,0.221344,0.236622,0.263358,0.347607
