Inspired by this notebook.

https://www.kaggle.com/mrkmakr/covid-ae-pretrain-gnn-attn-cnn

Pytorch Implementation.

Model: Convolution + Transformer + GRU.

Training: Bert-Like Pretraining. Then Fine Tuning.

In [1]:
import numpy as np
import pandas as pd
import ast
from collections import OrderedDict
from fastprogress import progress_bar
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math

import functools
import itertools
import os
import random
import shutil
import torch
import torch.nn.functional as F

from sklearn.model_selection import train_test_split, ShuffleSplit, KFold, StratifiedKFold, GroupKFold
from torch import nn
from torch.utils.data import Dataset

import functools
from IPython.core.debugger import set_trace


# Label columns: one value per sequence position and per column; the fine-tuned
# model emits len(target_cols) outputs per position.
target_cols = ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']
# String-valued input columns that get one-hot encoded (see token_dicts).
input_cols = ['sequence', 'structure', 'predicted_loop_type']
# Presumably the measurement-error columns of the training file; not referenced
# anywhere else in the visible code.
error_cols = ['reactivity_error', 'deg_error_Mg_pH10', 'deg_error_Mg_50C', 'deg_error_pH10', 'deg_error_50C']

# Per-column vocabulary mapping each symbol to its one-hot index.
# Widths: 4 (sequence) + 3 (structure) + 7 (loop type) = 14 input features.
token_dicts = {
    "sequence": {x: i for i, x in enumerate("ACGU")},
    "structure": {x: i for i, x in enumerate('().')},
    "predicted_loop_type": {x: i for i, x in enumerate("BEHIMSX")}
}


def set_seed(seed: int = 42):
    """Seed every random number generator used in this file.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU rather than only the current one, so
    # multi-GPU runs are reproducible too; it is safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### Loader


In [2]:
# Root data directory: train.json, test.json and the ``bpps/<id>.npy`` files are read from here.
BASE_PATH = "../OpenVaccine"
# Directory that receives the per-epoch model checkpoints and the optimizer state.
MODEL_SAVE_PATH = "../OpenVaccine/pretrains"


def preprocess_inputs(df, cols):
    """One-hot encode every column in ``cols`` and join the results on the feature axis.

    Returns an array of shape (len(df), seq_length, sum of vocabulary sizes).
    """
    encoded = []
    for name in cols:
        encoded.append(preprocess_feature_col(df, name))
    return np.concatenate(encoded, axis=2)


def preprocess_feature_col(df, col):
    """One-hot encode a single string-valued column of ``df``.

    Args:
        df: DataFrame whose ``col`` column holds strings of equal length over
            the alphabet ``token_dicts[col]``.
        col: Column name; must be a key of ``token_dicts``.

    Returns:
        float64 array of shape (len(df), seq_length, len(token_dicts[col])).

    Raises:
        KeyError: If a string contains a symbol missing from the vocabulary.
    """
    dic = token_dicts[col]
    dic_len = len(dic)
    # Positional access: the frame's index is not guaranteed to contain label 0.
    seq_length = len(df[col].iloc[0])
    ident = np.identity(dic_len)
    # Row i of the identity matrix is the one-hot vector of token i, so indexing
    # with the list of token ids yields one (seq_length, dic_len) block per row.
    # This also avoids the deprecated DataFrame.applymap.
    arr = np.array([ident[[dic[x] for x in seq]] for seq in df[col]])
    # shape: data_size x seq_length x dic_length
    assert arr.shape == (len(df), seq_length, dic_len)
    return arr


def preprocess(base_data, is_test=False):
    """Turn a raw DataFrame into model inputs and, unless ``is_test``, labels.

    Returns:
        ``(inputs, labels)``. ``inputs`` is the one-hot encoding of
        ``input_cols`` with 14 features per position. ``labels`` is ``None``
        for test data, otherwise an array laid out as
        (sample, position, target column).
    """
    inputs = preprocess_inputs(base_data, input_cols)
    labels = None
    if not is_test:
        per_target = np.array(base_data[target_cols].values.tolist())
        # (sample, target, position) -> (sample, position, target)
        labels = per_target.transpose((0, 2, 1))
        assert labels.shape[2] == len(target_cols)
    # 4 bases + 3 structure symbols + 7 loop types
    assert inputs.shape[2] == 14
    return inputs, labels


def get_bpp_feature(bpp):
    """Summarise a batch of square pairing matrices into three per-position features.

    Args:
        bpp: Tensor indexed as (batch, position, position).

    Returns:
        List of three (batch, position, 1) tensors: the maximum over the last
        axis, the sum over the last axis, and the standardised fraction of
        strictly positive entries counted along axis 1.
    """
    # Normalisation constants: mean / std of bpps_nb across all training data.
    nb_mean = 0.077522
    nb_std = 0.08914
    row_max, _ = bpp.max(-1)
    row_sum = bpp.sum(-1)
    positive_frac = torch.true_divide((bpp > 0).sum(dim=1), bpp.shape[1])
    positive_norm = torch.true_divide(positive_frac - nb_mean, nb_std)
    return [feat.unsqueeze(2) for feat in (row_max, row_sum, positive_norm)]


@functools.lru_cache(5000)
def load_from_id(id_):
    """Load ``<BASE_PATH>/bpps/<id_>.npy``; results are memoised for up to 5000 ids.

    The cached array object is shared between calls, so callers must not
    modify it in place.
    """
    npy_file = Path(BASE_PATH, "bpps", f"{id_}.npy")
    return np.load(str(npy_file))


def get_distance_matrix(leng):
    """Build inverse-distance features between all pairs of sequence positions.

    For positions i and j the base value is ``1 / (|i - j| + 1)``; the last
    axis holds that value raised to the powers 1, 2 and 4.

    Args:
        leng: Sequence length.

    Returns:
        float64 array of shape (1, leng, leng, 3). An empty array is returned
        for ``leng == 0``.
    """
    idx = np.arange(leng)
    # Broadcasting builds the full |i - j| table in one step (no Python loop).
    gaps = np.abs(idx[:, None] - idx[None, :]) + 1
    inv = (1 / gaps)[None, :, :]
    Ds = np.stack([inv ** power for power in (1, 2, 4)], axis=3)
    # Shape trace kept from the original; it shows up in the notebook output.
    print(Ds.shape)
    return Ds


def get_structure_adj(df):
    """Build a base-pairing adjacency matrix for every row from its dot-bracket structure.

    Matching "(" / ")" positions are paired with a stack. Each pair is
    recorded symmetrically in a matrix chosen by the two bases involved, and
    the per-base-pair matrices are then summed into a single channel.

    Args:
        df: DataFrame with ``seq_length``, ``structure`` and ``sequence`` columns.

    Returns:
        Array of shape (len(df), seq_length, seq_length, 1) when all rows share
        the same length.

    Raises:
        KeyError: If a bracket pair joins two bases that are not one of the
            six listed combinations.
        IndexError: If a ")" has no matching "(".
    """
    pair_types = [("A", "U"), ("C", "G"), ("U", "G"), ("U", "A"), ("G", "C"), ("G", "U")]
    adjacency = []
    for n, dot_bracket, bases in zip(df["seq_length"], df["structure"], df["sequence"]):
        per_type = OrderedDict((pair, np.zeros([n, n])) for pair in pair_types)
        open_positions = []  # stack of not-yet-closed "(" indices
        for pos in range(n):
            symbol = dot_bracket[pos]
            if symbol == "(":
                open_positions.append(pos)
            elif symbol == ")":
                partner = open_positions.pop()
                per_type[(bases[partner], bases[pos])][partner, pos] = 1
                per_type[(bases[pos], bases[partner])][pos, partner] = 1

        stacked = np.stack(list(per_type.values()), axis=2)
        adjacency.append(stacked.sum(axis=2, keepdims=True))

    return np.array(adjacency)


def create_loader(df, batch_size=1, is_test=False):
    """Preprocess ``df`` and wrap it in a ``DataLoader`` over a ``VacDataset``.

    Loaders built with labels are shuffled, loaders without labels keep the
    row order. The last, possibly smaller, batch is always kept.
    """
    features, labels = preprocess(df, is_test)
    features_tensor = torch.from_numpy(features)
    has_labels = labels is not None
    labels_tensor = torch.from_numpy(labels) if has_labels else None
    dataset = VacDataset(features_tensor, df, labels_tensor)
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=has_labels, drop_last=False)


class VacDataset(Dataset):
    """Dataset yielding one sample dict (inputs, pairwise features, ids and labels).

    Args:
        features: Tensor of encoded inputs, one row per sample.
        df: DataFrame aligned row-by-row with ``features``. Needs the ``id``,
            ``seq_length``, ``structure`` and ``sequence`` columns, plus
            ``signal_to_noise`` when labels are given. An optional ``score``
            column supplies per-sample weights (1.0 when absent).
        labels: Tensor of targets, or ``None`` for an unlabelled dataset.
    """

    def __init__(self, features, df, labels=None):
        self.features = features
        self.labels = labels
        self.test = labels is None
        self.ids = df["id"]
        self.structure_adj = get_structure_adj(df)
        # One distance matrix is shared by every sample of the dataset.
        self.distance_matrix = get_distance_matrix(self.structure_adj.shape[1])
        if "score" in df.columns:
            self.score = df["score"]
        else:
            # Default weight of 1.0 per row, built without adding a column to
            # the caller's DataFrame.
            self.score = pd.Series(1.0, index=df.index)
        self.signal_to_noise = None
        if not self.test:
            self.signal_to_noise = df["signal_to_noise"]
            assert self.features.shape[0] == self.labels.shape[0]
        else:
            assert self.ids is not None

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        # .iloc keeps the lookup positional, so it also works when the frame's
        # index is not 0..n-1.
        id_ = self.ids.iloc[index]
        # astype returns a fresh float32 copy, leaving the cached array intact.
        pairing = load_from_id(id_).astype(np.float32)
        adj = self.structure_adj[index]
        distance = self.distance_matrix[0]
        # Channels: 1 pairing matrix + 1 structure adjacency + 3 distance powers.
        bpp = np.concatenate([pairing[:, :, None], adj, distance], axis=2)
        sample = dict(sequence=self.features[index].float(), bpp=bpp)
        if self.test:
            sample["ids"] = id_
            return sample
        sample.update(
            label=self.labels[index],
            ids=id_,
            signal_to_noise=self.signal_to_noise.iloc[index],
            score=self.score.iloc[index],
        )
        return sample


### Model

In [3]:
class Conv1dStack(nn.Module):
    """A 1-D convolution unit followed by a residual unit of the same width.

    Each unit is Conv1d (no bias) -> BatchNorm1d -> Dropout(0.1) -> LeakyReLU.
    ``forward`` returns ``conv(x) + res(conv(x))``.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, dilation=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False)
        # Attribute names and layer order determine the state-dict keys, so
        # they are kept stable for saved checkpoints.
        self.conv = nn.Sequential(
            nn.Conv1d(in_dim, out_dim, **conv_kwargs),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv1d(out_dim, out_dim, **conv_kwargs),
            nn.BatchNorm1d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        projected = self.conv(x)
        return projected + self.res(projected)


class Conv2dStack(nn.Module):
    """A 2-D convolution unit followed by a residual unit of the same width.

    Each unit is Conv2d (no bias) -> BatchNorm2d -> Dropout(0.1) -> LeakyReLU.
    ``forward`` returns ``conv(x) + res(conv(x))``.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, dilation=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding, dilation=dilation, bias=False)
        # Attribute names and layer order determine the state-dict keys, so
        # they are kept stable for saved checkpoints.
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, **conv_kwargs),
            nn.BatchNorm2d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, **conv_kwargs),
            nn.BatchNorm2d(out_dim),
            nn.Dropout(0.1),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        projected = self.conv(x)
        return projected + self.res(projected)


class SeqEncoder(nn.Module):
    """Four chained ``Conv1dStack`` stages whose outputs are concatenated channel-wise.

    The stage widths are 128, 64, 32 and 32, so the output has 256 channels.
    Kernel size, padding and dilation of each stage are chosen so that the
    sequence length is preserved.
    """

    def __init__(self, in_dim: int):
        super().__init__()
        self.conv0 = Conv1dStack(in_dim, 128, 3, padding=1)
        self.conv1 = Conv1dStack(128, 64, 6, padding=5, dilation=2)
        self.conv2 = Conv1dStack(64, 32, 15, padding=7, dilation=1)
        self.conv3 = Conv1dStack(32, 32, 30, padding=29, dilation=2)

    def forward(self, x):
        stage_outputs = []
        hidden = x
        for stage in (self.conv0, self.conv1, self.conv2, self.conv3):
            hidden = stage(hidden)
            stage_outputs.append(hidden)
        # BATCH x 256 x seq_length
        return torch.cat(stage_outputs, dim=1)


class BppAttn(nn.Module):
    """Mix sequence features across positions using learned pairwise weight maps.

    The 5-channel pairwise input is convolved to ``out_channel`` maps. Each map
    is then matrix-multiplied with the matching channel of the projected
    sequence features.
    """

    def __init__(self, in_channel: int, out_channel: int):
        super().__init__()
        self.conv0 = Conv1dStack(in_channel, out_channel, 3, padding=1)
        self.bpp_conv = Conv2dStack(5, out_channel)

    def forward(self, x, bpp):
        seq_feats = self.conv0(x)            # BATCH x C x SEQ
        pair_weights = self.bpp_conv(bpp)    # BATCH x C x SEQ x SEQ
        # Per channel: (SEQ x SEQ) @ (SEQ x 1) -> SEQ x 1
        mixed = torch.matmul(pair_weights, seq_feats.unsqueeze(-1))
        return mixed.squeeze(-1)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    Even feature indices carry sines and odd indices cosines. The table is
    registered as the buffer ``pe`` with shape (max_len, 1, d_model), so it
    broadcasts over the batch axis of a (seq_len, batch, d_model) input.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # (max_len, d_model) -> (max_len, 1, d_model)
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len, :])


class TransformerWrapper(nn.Module):
    """Transformer encoder with an RNN-like calling convention.

    Takes and returns batch-first tensors, and returns ``(output, None)`` so
    it can be used where an ``nn.LSTM`` / ``nn.GRU`` call is expected.
    """

    def __init__(self, dmodel=256, nhead=8, num_layers=2):
        super().__init__()
        # NOTE(review): pos_encoder is never used in forward and is fixed at
        # width 256. It is kept because its buffer is part of the state dict
        # of previously saved checkpoints.
        self.pos_encoder = PositionalEncoding(256)
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model=dmodel, nhead=nhead), num_layers
        )
        self.pos_emb = PositionalEncoding(dmodel)

    def flatten_parameters(self):
        # No-op, present only to mirror the recurrent-layer interface.
        pass

    def forward(self, x):
        # batch-first -> sequence-first, as the encoder expects
        seq_first = x.transpose(0, 1).contiguous()
        encoded = self.transformer_encoder(self.pos_emb(seq_first))
        # back to batch-first
        return encoded.transpose(0, 1).contiguous(), None


class RnnLayers(nn.Module):
    """Sequence head: transformer encoder, then a bidirectional LSTM, then a bidirectional GRU.

    Dropout is applied before each recurrent layer. Both recurrent layers use
    ``dmodel // 2`` hidden units per direction, so the feature width stays at
    ``dmodel`` throughout.
    """

    def __init__(self, dmodel, dropout=0.3, transformer_layers: int = 2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.rnn0 = TransformerWrapper(dmodel, nhead=8, num_layers=transformer_layers)
        self.rnn1 = nn.LSTM(dmodel, dmodel // 2, batch_first=True, num_layers=1, bidirectional=True)
        self.rnn2 = nn.GRU(dmodel, dmodel // 2, batch_first=True, num_layers=1, bidirectional=True)

    def forward(self, x):
        self.rnn0.flatten_parameters()
        x, _ = self.rnn0(x)
        for recurrent in (self.rnn1, self.rnn2):
            # Either recurrent layer may be disabled by setting it to None.
            if recurrent is None:
                continue
            recurrent.flatten_parameters()
            x = self.dropout(x)
            x, _ = recurrent(x)
        return x

    
class BaseAttnModel(nn.Module):
    """Backbone shared by the pretraining and fine-tuning models.

    Inputs are the 14 one-hot features per position and the 5-channel pairwise
    tensor. The output is a BATCH x seq_len x 512 feature tensor.
    """

    def __init__(self, transformer_layers: int = 2):
        super().__init__()
        # 14 one-hot features + 3 pairing statistics -> 1 learned feature
        self.linear0 = nn.Linear(14 + 3, 1)
        self.seq_encoder_x = SeqEncoder(18)
        self.attn = BppAttn(256, 128)
        self.seq_encoder_bpp = SeqEncoder(128)
        self.seq = RnnLayers(256 * 2, dropout=0.3,
                             transformer_layers=transformer_layers)

    def forward(self, x, bpp):
        # Summary statistics are taken from channel 0 of the pairwise tensor.
        pair_stats = get_bpp_feature(bpp[:, :, :, 0].float())
        tokens = torch.cat([x, *pair_stats], dim=-1)
        tokens = torch.cat([tokens, self.linear0(tokens)], dim=-1)
        # BATCH x 18 x seq_len
        tokens = tokens.permute(0, 2, 1).contiguous().float()
        # BATCH x 5 x seq_len x seq_len
        pair_maps = bpp.permute([0, 3, 1, 2]).contiguous().float()
        # BATCH x 256 x seq_len
        seq_feats = self.seq_encoder_x(tokens)
        # BATCH x 256 x seq_len
        pair_feats = self.seq_encoder_bpp(self.attn(seq_feats, pair_maps))
        # Both branches become BATCH x seq_len x 256 and are joined to 512 features.
        merged = torch.cat(
            [seq_feats.permute(0, 2, 1).contiguous(), pair_feats.permute(0, 2, 1).contiguous()],
            dim=2,
        )
        return self.seq(merged)


class AEModel(nn.Module):
    """Pretraining model: the backbone plus a sigmoid head that reconstructs the 14 one-hot inputs.

    Args:
        transformer_layers: Number of transformer encoder layers in the backbone.
    """

    def __init__(self, transformer_layers: int = 2):
        super(AEModel, self).__init__()
        self.seq = BaseAttnModel(transformer_layers=transformer_layers)
        self.linear = nn.Sequential(
            nn.Linear(256 * 2, 14),
            nn.Sigmoid(),
        )

    def forward(self, x, bpp):
        x = self.seq(x, bpp)
        # F.dropout defaults to training=True; passing the module's own flag
        # switches dropout off in eval() mode.
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.linear(x)
        return x


class FromAeModel(nn.Module):
    """Fine-tuning model: a given backbone plus a linear head with one output per target column.

    Args:
        seq: Backbone module called as ``seq(x, bpp)``; must return
            ``dmodel * 2`` features per position.
        pred_len: Number of leading positions kept in the output.
        dmodel: Half of the backbone's feature width.
    """

    def __init__(self, seq, pred_len=68, dmodel: int = 256):
        super().__init__()
        self.seq = seq
        self.pred_len = pred_len
        # Wrapped in Sequential so the state-dict keys stay ``linear.0.*``.
        self.linear = nn.Sequential(
            nn.Linear(dmodel * 2, len(target_cols)),
        )

    def forward(self, x, bpp):
        per_position = self.linear(self.seq(x, bpp))
        return per_position[:, :self.pred_len]

In [4]:
# Augmentation table consumed by aug_data(): rows are matched to the originals
# on ``id`` / ``sequence`` and must provide a ``cnt`` column. Judging from
# aug_data, it also supplies the ``structure`` / ``predicted_loop_type`` values.
aug_df = pd.read_csv('../OpenVaccine/aug_data1.csv')

def aug_data(df):
    """Return ``df`` extended with the matching rows of the global ``aug_df`` table.

    Rows of ``aug_df`` whose ``id`` occurs in ``df`` are selected and inherit
    every other column of the original row through a left merge on
    ``id`` / ``sequence``. Their ``structure`` and ``predicted_loop_type``
    come from ``aug_df``. The original rows get ``cnt`` (looked up by id),
    ``log_gamma = 100`` and ``score = 1.0``.

    Args:
        df: DataFrame with at least ``id``, ``sequence``, ``structure`` and
            ``predicted_loop_type`` columns. It is not modified.

    Returns:
        New DataFrame with the original rows followed by the augmented rows.
        The index is not reset.
    """
    originals = df.copy()  # work on a copy so the caller's frame is untouched
    # Drop the structure columns so the merge takes them from aug_df instead.
    without_structure = df.drop(columns=['structure', 'predicted_loop_type'])
    new_df = aug_df[aug_df['id'].isin(df['id'])]
    new_df = new_df.merge(without_structure, on=['id', 'sequence'], how='left')

    originals['cnt'] = originals['id'].map(new_df.set_index('id')['cnt'].to_dict())
    originals['log_gamma'] = 100
    originals['score'] = 1.0
    # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
    return pd.concat([originals, new_df[originals.columns]])

In [5]:
pseudo_df = pd.read_csv('../OpenVaccine/pseudo_test.csv')

# Constant metadata so the pseudo-labelled rows carry the same columns that
# the training pipeline reads from real training rows.
pseudo_df['SN_filter'] = 1
pseudo_df['signal_to_noise'] = 1.0
pseudo_df['score'] = 1.0
# Each target column is stored as the text of a Python literal; parse it back.
for _col in ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']:
    pseudo_df[_col] = pseudo_df[_col].apply(ast.literal_eval)

In [6]:
# Split the pseudo-labelled rows by sequence length. .copy() gives each subset
# its own data, so the column assignments below write to a real frame and no
# longer raise SettingWithCopyWarning.
pseudo_st = pseudo_df[pseudo_df['seq_length'] == 107].copy()
pseudo_lg = pseudo_df[pseudo_df['seq_length'] == 130].copy()

# Keep only the leading positions of every target: 68 for the 107-long
# sequences and 91 for the 130-long ones.
for _col in ['reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']:
    pseudo_st[_col] = pseudo_st[_col].apply(lambda x: x[:68])
    pseudo_lg[_col] = pseudo_lg[_col].apply(lambda x: x[:91])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  after removing the cwd from sys.path.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: 

In [7]:
# Build the three unlabelled datasets / loaders used for pretraining:
# training sequences, and the test sequences split by length (107 and 130).
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
#base_train_data.head()

# add aug data

base_train_data = aug_data(base_train_data)
# base_train_data = base_train_data.append(pseudo_st)
# Renumber rows 0..n-1: VacDataset looks samples up by integer position.
base_train_data = base_train_data.reset_index(drop = True)
# base_train_data_lg = pseudo_lg.reset_index(drop = True)


# NOTE(review): requires a CUDA device; the pretraining cell later rebinds
# ``device`` with a CPU fallback.
device = torch.device('cuda')
BATCH_SIZE = 64

#base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)
base_test_data = pd.read_json(str(Path(BASE_PATH) / 'test.json'), lines=True)
base_test_data = aug_data(base_test_data)

# Test sequences come in two lengths; each gets its own dataset because a
# dataset shares one distance matrix sized from its sequences.
public_df = base_test_data.query("seq_length == 107").copy()
private_df = base_test_data.query("seq_length == 130").copy()
print(f"public_df: {public_df.shape}")
print(f"private_df: {private_df.shape}")
# reset_index() without drop keeps the old index as an extra ``index`` column.
public_df = public_df.reset_index()
private_df = private_df.reset_index()

# is_test=True / labels=None: only the inputs are needed for pretraining.
features, _ = preprocess(base_train_data, True)
features_tensor = torch.from_numpy(features)
dataset0 = VacDataset(features_tensor, base_train_data, None)

# features, _ = preprocess(base_train_data_lg, True)
# features_tensor = torch.from_numpy(features)
# dataset0_lg = VacDataset(features_tensor, base_train_data_lg, None)

features, _ = preprocess(public_df, True)
features_tensor = torch.from_numpy(features)
dataset1 = VacDataset(features_tensor, public_df, None)

features, _ = preprocess(private_df, True)
features_tensor = torch.from_numpy(features)
dataset2 = VacDataset(features_tensor, private_df, None)

# The loaders keep the row order (shuffle=False) and the final partial batch.
loader0 = torch.utils.data.DataLoader(dataset0, BATCH_SIZE, shuffle=False, drop_last=False)
# loader0_lg = torch.utils.data.DataLoader(dataset0_lg, BATCH_SIZE, shuffle=False, drop_last=False)
loader1 = torch.utils.data.DataLoader(dataset1, BATCH_SIZE, shuffle=False, drop_last=False)
loader2 = torch.utils.data.DataLoader(dataset2, BATCH_SIZE, shuffle=False, drop_last=False)

public_df: (1258, 10)
private_df: (6010, 10)
(1, 107, 107, 3)
(1, 107, 107, 3)
(1, 130, 130, 3)


### Pretrain

In [8]:
def learn_from_batch_ae(model, data, device):
    """Compute the reconstruction loss for one pretraining batch.

    The first 14 input features are corrupted with ``F.dropout2d`` (p=0.3) on a
    copy of the batch. The model must reconstruct the uncorrupted features,
    and the result is scored with binary cross-entropy.

    Args:
        model: Callable ``model(sequence, bpp)`` returning values in [0, 1]
            with the same shape as the 14 target features.
        data: Batch dict with ``"sequence"`` and ``"bpp"`` tensors.
        device: Device the tensors are moved to.

    Returns:
        Scalar loss tensor. No backward pass is run here.
    """
    clean = data["sequence"]
    corrupted = clean.clone()  # leave the batch itself untouched
    corrupted[:, :, :14] = F.dropout2d(corrupted[:, :, :14], p=0.3)
    reconstruction = model(corrupted.to(device), data["bpp"].to(device))
    return F.binary_cross_entropy(reconstruction, clean[:, :, :14].to(device))


def train_ae(model, train_data, optimizer, lr_scheduler, epochs=10, device="cpu", 
             start_epoch: int = 0, start_it: int = 0, log_path: str = "./logs"): #10
    """Run reconstruction pretraining for ``epochs`` epochs on one loader.

    After every epoch the optimizer state is written to ``optimizer.pt`` and
    the model weights to ``model-<epoch>.pt`` inside ``MODEL_SAVE_PATH``.

    Args:
        model: Model accepted by ``learn_from_batch_ae``.
        train_data: Iterable of batch dicts.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Optional scheduler, stepped after every batch.
        epochs: Number of epochs to run in this call.
        device: Device the batches are moved to.
        start_epoch: Global number of the first epoch; used in checkpoint names.
        start_it: Global iteration counter to continue from.
        log_path: Unused; kept for interface compatibility.

    Returns:
        dict with ``end_epoch`` (start_epoch + epochs), ``it`` (updated
        iteration counter) and ``min_loss_epoch``, the epoch of THIS call with
        the lowest mean loss. ``min_loss_epoch`` is ``start_epoch`` if no
        epoch produced a comparable loss.
    """
    print(f"device: {device}")
    it = start_it
    model_save_path = Path(MODEL_SAVE_PATH)
    # exist_ok avoids the racy exists()/mkdir() pair.
    model_save_path.mkdir(parents=True, exist_ok=True)
    end_epoch = start_epoch + epochs
    # An infinite sentinel means the first finite epoch loss always wins. The
    # old 10.0 / epoch-0 defaults could return an epoch from an earlier call.
    min_loss = float("inf")
    min_loss_epoch = start_epoch
    for epoch in progress_bar(range(start_epoch, end_epoch)):
        print(f"epoch: {epoch}")
        model.train()
        losses = []
        for data in train_data:
            optimizer.zero_grad()
            loss = learn_from_batch_ae(model, data, device)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()
            losses.append(loss.item())
            it += 1
        loss_m = np.mean(losses)
        if loss_m < min_loss:
            min_loss_epoch = epoch
            min_loss = loss_m
        print(f'epoch: {epoch} loss: {loss_m}')
        torch.save(optimizer.state_dict(), str(model_save_path / "optimizer.pt"))
        torch.save(model.state_dict(), str(model_save_path / f"model-{epoch}.pt"))
    return dict(end_epoch=end_epoch, it=it, min_loss_epoch=min_loss_epoch)

In [9]:
import shutil

# Pretraining driver: cycles through the three unlabelled loaders, continuing
# the global epoch / iteration counters from one train_ae call to the next.
set_seed(123)
for _stale_dir in ("./model", "./logs"):
    shutil.rmtree(_stale_dir, True)  # True = ignore_errors
save_path = Path("./model_prediction")
if not save_path.exists():
    save_path.mkdir(parents=True)

lr_scheduler = None
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AEModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
res = dict(end_epoch=0, it=0, min_loss_epoch=0)
epochs = [5, 5, 5, 5] # [5, 5, 5, 5]
_pretrain_loaders = (('--loader0', loader0), ('--loader1', loader1), ('--loader2', loader2))
for e in epochs:
    for _banner, _loader in _pretrain_loaders:
        print(_banner)
        res = train_ae(model, _loader, optimizer, lr_scheduler, e, device=device,
                       start_epoch=res["end_epoch"], start_it=res["it"])

# NOTE(review): ``res`` comes from the final train_ae call only, so the chosen
# checkpoint is the best epoch of that last phase, not of the whole run.
epoch = res["min_loss_epoch"]
shutil.copyfile(str(Path(MODEL_SAVE_PATH) / f"model-{epoch}.pt"), "ae-model.pt")

--loader0
device: cuda


epoch: 0
epoch: 0 loss: 0.25064601679642995
epoch: 1
epoch: 1 loss: 0.1384060760339101
epoch: 2
epoch: 2 loss: 0.08038593033949534
epoch: 3
epoch: 3 loss: 0.03185466726620992
epoch: 4
epoch: 4 loss: 0.019618625019987426
--loader1
device: cuda


epoch: 5
epoch: 5 loss: 0.017772705294191837
epoch: 6
epoch: 6 loss: 0.01611646986566484
epoch: 7
epoch: 7 loss: 0.013704161671921612
epoch: 8
epoch: 8 loss: 0.011801821365952491
epoch: 9
epoch: 9 loss: 0.010440635774284602
--loader2
device: cuda


epoch: 10
epoch: 10 loss: 0.01808331135977456
epoch: 11
epoch: 11 loss: 0.014588844268879991
epoch: 12
epoch: 12 loss: 0.012971125731363576
epoch: 13
epoch: 13 loss: 0.012219857741543588
epoch: 14
epoch: 14 loss: 0.011251784831364738
--loader0
device: cuda


epoch: 15
epoch: 15 loss: 0.11141376870994767
epoch: 16
epoch: 16 loss: 0.14910804033279418
epoch: 17
epoch: 17 loss: 0.1133262633283933
epoch: 18
epoch: 18 loss: 0.08852919230858484
epoch: 19
epoch: 19 loss: 0.0644798181951046
--loader1
device: cuda


epoch: 20
epoch: 20 loss: 0.04376504682004452
epoch: 21
epoch: 21 loss: 0.03720299247652292
epoch: 22
epoch: 22 loss: 0.031649123318493365
epoch: 23
epoch: 23 loss: 0.028315145894885062
epoch: 24
epoch: 24 loss: 0.024756696540862322
--loader2
device: cuda


epoch: 25
epoch: 25 loss: 0.02692662131913165
epoch: 26
epoch: 26 loss: 0.02191674142600374
epoch: 27
epoch: 27 loss: 0.019343210384249687
epoch: 28
epoch: 28 loss: 0.018591710326677943
epoch: 29
epoch: 29 loss: 0.018955215732467934
--loader0
device: cuda


epoch: 30
epoch: 30 loss: 0.01184980947524309
epoch: 31
epoch: 31 loss: 0.011005594283342362
epoch: 32
epoch: 32 loss: 0.010138162014385065
epoch: 33
epoch: 33 loss: 0.00937236485381921
epoch: 34
epoch: 34 loss: 0.009434934935222069
--loader1
device: cuda


epoch: 35
epoch: 35 loss: 0.007858033105731011
epoch: 36
epoch: 36 loss: 0.008021229738369584
epoch: 37
epoch: 37 loss: 0.0076625833986327056
epoch: 38
epoch: 38 loss: 0.007388785784132779
epoch: 39
epoch: 39 loss: 0.006525391549803317
--loader2
device: cuda


epoch: 40
epoch: 40 loss: 0.014415230730825917
epoch: 41
epoch: 41 loss: 0.013332009156967731
epoch: 42
epoch: 42 loss: 0.01705033711573862
epoch: 43
epoch: 43 loss: 0.030063067166570652
epoch: 44
epoch: 44 loss: 0.01620889058772554
--loader0
device: cuda


epoch: 45
epoch: 45 loss: 0.0104247345837454
epoch: 46
epoch: 46 loss: 0.010372398793697358
epoch: 47
epoch: 47 loss: 0.011601323901365201
epoch: 48
epoch: 48 loss: 0.030044919215142726
epoch: 49
epoch: 49 loss: 0.0157378122707208
--loader1
device: cuda


epoch: 50
epoch: 50 loss: 0.009795540105551481
epoch: 51
epoch: 51 loss: 0.009084979956969618
epoch: 52
epoch: 52 loss: 0.009306920599192381
epoch: 53
epoch: 53 loss: 0.011493802862241864
epoch: 54
epoch: 54 loss: 0.01057223421521485
--loader2
device: cuda


epoch: 55
epoch: 55 loss: 0.019030503562076927
epoch: 56
epoch: 56 loss: 0.016735949563456975
epoch: 57
epoch: 57 loss: 0.014369799279944693
epoch: 58
epoch: 58 loss: 0.013059445131728624
epoch: 59
epoch: 59 loss: 0.012738071332507312


'ae-model.pt'

### Training

In [10]:
def MCRMSE(y_true, y_pred):
    """Mean column-wise RMSE per sample.

    The squared error is averaged over axis 1 and square-rooted, giving one
    RMSE per column of the last axis. Those RMSEs are then averaged, leaving
    one value per sample.
    """
    squared_error = torch.square(y_true - y_pred)
    rmse_per_column = squared_error.mean(dim=1).sqrt()
    return rmse_per_column.mean(dim=1)


def sn_mcrmse_loss(predict, target, signal_to_noise):
    """Batch mean of the per-sample MCRMSE, weighted by ``0.5 * log(signal_to_noise + 1.01)``."""
    per_sample = MCRMSE(target, predict)
    sample_weight = torch.log(signal_to_noise + 1.01) * 0.5
    return (per_sample * sample_weight).mean()


def learn_from_batch(model, data, optimizer, lr_scheduler, device):
    """Run one supervised optimisation step on a batch.

    The prediction is cut to the label length. The loss weight of each sample
    is derived from ``signal_to_noise * score``. Gradients are clipped to a
    total norm of 0.5 before the optimizer step.

    Returns:
        ``(out, loss)``: the truncated prediction and the scalar loss tensor.
    """
    optimizer.zero_grad()

    labels = data["label"]
    out = model(data["sequence"].to(device), data["bpp"].to(device))[:, :labels.shape[1]]
    weight_source = data["signal_to_noise"] * data["score"]
    loss = sn_mcrmse_loss(out, labels.to(device), weight_source.to(device))
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    if lr_scheduler:
        lr_scheduler.step()
    return out, loss

def predict_batch(model, data, device):
    """Predict one batch and flatten it into one record per (sample, position).

    Returns:
        List of dicts holding one value per target column plus an
        ``"id_seqpos"`` key of the form ``"<id>_<position>"``.
    """
    # batch x seq_len x target_size
    with torch.no_grad():
        pred = model(data["sequence"].to(device), data["bpp"].to(device))
        pred = pred.detach().cpu().numpy()
    expected_shape = (model.pred_len, len(target_cols))
    records = []
    for id_, sample_pred in zip(data["ids"], pred):
        assert sample_pred.shape == expected_shape
        for seqpos, values in enumerate(sample_pred):
            assert len(values) == len(target_cols)
            record = dict(zip(target_cols, values))
            record["id_seqpos"] = f"{id_}_{seqpos}"
            records.append(record)
    return records


def predict_data(model, loader, device, batch_size):
    """Predict every batch of ``loader`` and return the concatenated records.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: Model accepted by ``predict_batch``.
        loader: Iterable of batch dicts.
        device: Device the batches are moved to.
        batch_size: Unused; kept for interface compatibility.

    Returns:
        List of per-position record dicts, in loader order.
    """
    model.eval()
    data_list = []
    for data in progress_bar(loader):
        data_list.extend(predict_batch(model, data, device))
    return data_list

def evaluate(model, valid_data, device):
    """Evaluate ``model`` on a labelled loader.

    Args:
        model: Model called as ``model(sequence, bpp)``.
        valid_data: Iterable of batch dicts with ``sequence``, ``bpp``,
            ``label`` and ``signal_to_noise`` entries.
        device: Device the batches are moved to.

    Returns:
        dict with ``loss`` (mean over batches of the signal-to-noise weighted
        loss) and ``mcmse`` (mean over batches of the unweighted MCRMSE,
        restricted to samples with ``signal_to_noise > 1``). ``mcmse`` is NaN
        when no such sample exists. The model is left in train mode.
    """
    model.eval()
    loss_list = []
    mcrmse = []
    with torch.no_grad():
        for data in valid_data:
            label = data["label"].to(device)
            signal_to_noise = data["signal_to_noise"]
            y = model(data["sequence"].to(device), data["bpp"].to(device))
            # Cut to the label length, as learn_from_batch does, instead of a
            # hard-coded 68.
            y = y[:, :label.shape[1]]
            clean = signal_to_noise > 1
            # A batch without any clean sample would add a NaN to the average.
            if clean.any():
                mcrmse.append(MCRMSE(label, y)[clean.to(device)].mean().item())
            loss = sn_mcrmse_loss(y, label, signal_to_noise.to(device))
            loss_list.append(loss.item())
    model.train()
    return dict(loss=np.mean(loss_list),
                mcmse=np.mean(mcrmse) if mcrmse else float("nan"))

def train(model, train_data, valid_data, train_data_lg, valid_data_lg, optimizer, lr_scheduler, epochs=10, device="cpu",
          start_epoch: int = 0, log_path: str = "./logs"):
    """Fine-tune ``model`` on two interleaved loaders and track the best validation epoch.

    Every iteration takes one optimisation step on a batch from
    ``train_data_lg`` (while it lasts) and one on a batch from ``train_data``
    (while it lasts). After each epoch the model is evaluated on
    ``valid_data``. The optimizer state is written to ``optimizer.pt`` and the
    weights to ``model-<epoch>.pt`` inside ``MODEL_SAVE_PATH``.

    Args:
        model: Model to train; must expose ``pred_len`` for prediction.
        train_data: Main training loader.
        valid_data: Validation loader used for model selection.
        train_data_lg: Second training loader, interleaved with ``train_data``.
        valid_data_lg: Unused; kept for interface compatibility.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Optional scheduler, stepped after every batch.
        epochs: Number of epochs to run.
        device: Device the batches are moved to.
        start_epoch: Number of the first epoch; used in checkpoint names.
        log_path: TensorBoard log directory.

    Returns:
        ``(min_eval_epoch, valid_pred)``: the epoch with the lowest validation
        loss and the validation predictions made at that epoch. These are
        ``None`` and ``[]`` if no epoch produced a comparable loss.
    """
    print(f"device: {device}")
    losses = []
    writer = SummaryWriter(log_path)
    it = 0
    model_save_path = Path(MODEL_SAVE_PATH)
    model_save_path.mkdir(parents=True, exist_ok=True)
    end_epoch = start_epoch + epochs
    min_eval_loss = float("inf")
    min_eval_epoch = None
    # Initialised up front so the summary print and the return cannot hit
    # unbound names when no epoch improves.
    min_mcrmse = None
    valid_pred = []
    try:
        for epoch in progress_bar(range(start_epoch, end_epoch)):
            print(f"epoch: {epoch}")
            model.train()

            for data, data_lg in itertools.zip_longest(train_data, train_data_lg):
                if data_lg is not None:
                    _, loss_lg = learn_from_batch(model, data_lg, optimizer, lr_scheduler, device)
                    losses.append(loss_lg.item())

                if data is None:
                    # train_data ran out first; only the _lg loader has batches left.
                    continue
                _, loss = learn_from_batch(model, data, optimizer, lr_scheduler, device)
                loss_v = loss.item()
                writer.add_scalar('loss', loss_v, it)
                losses.append(loss_v)
                it += 1
            print(f'epoch: {epoch} loss: {np.mean(losses)}')
            losses = []

            eval_result = evaluate(model, valid_data, device)
            eval_loss = eval_result["loss"]
            if eval_loss <= min_eval_loss:
                min_eval_epoch = epoch
                min_eval_loss = eval_loss
                min_mcrmse = eval_result["mcmse"]
                valid_pred = predict_data(model, valid_data, device, BATCH_SIZE)

            print(f"eval loss: {eval_loss} {eval_result['mcmse']}")
            writer.add_scalar(f"evaluate/loss", eval_loss, epoch)
            writer.add_scalar(f"evaluate/mcmse", eval_result["mcmse"], epoch)
            # predict_data leaves the model in eval mode.
            model.train()
            torch.save(optimizer.state_dict(), str(model_save_path / "optimizer.pt"))
            torch.save(model.state_dict(), str(model_save_path / f"model-{epoch}.pt"))
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    print(f'min eval loss: {min_eval_loss} min mcrmse: {min_mcrmse} epoch {min_eval_epoch}')
    return min_eval_epoch, valid_pred

In [11]:
# Fine-tuning driver: 5-fold GroupKFold (grouped by sequence id) over the
# augmented training rows plus the short pseudo-labelled rows, with the long
# pseudo-labelled rows split the same way and interleaved during training.
# Fall back to CPU when CUDA is unavailable, as the pretraining cell does.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 64

# categorical value for target (used for stratified kfold)
#def add_y_cat(df):
#    target_mean = df['reactivity'].apply(np.mean) + \
#                  df['deg_Mg_pH10'].apply(np.mean) + \
#                  df['deg_Mg_50C'].apply(np.mean) 
                  #df['deg_pH10'].apply(np.mean) + \
                  #df['deg_50C'].apply(np.mean)
#    df['y_cat'] = pd.qcut(np.array(target_mean), q=20).codes
    
base_train_data = pd.read_json(str(Path(BASE_PATH) / 'train.json'), lines=True)

base_train_data = aug_data(base_train_data)
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
base_train_data = pd.concat([base_train_data, pseudo_st])
base_train_data = base_train_data.reset_index(drop = True)
base_train_data_lg = pseudo_lg.reset_index(drop = True)
#add_y_cat(base_train_data)

samples = base_train_data
save_path = Path("./model_prediction")
if not save_path.exists():
    save_path.mkdir(parents=True)
shutil.rmtree("./model", True)
shutil.rmtree("./logs", True)
split = ShuffleSplit(n_splits=5, test_size=.1)
# kf = StratifiedKFold(n_splits = 5, shuffle = True, random_state = 123)
kf = GroupKFold(n_splits=5)
    
ids = samples.reset_index()["id"]

set_seed(123)

oof_list = []
for fold, ((train_index, test_index), (train_index_lg, test_index_lg)) in enumerate(zip(kf.split(samples, groups=ids), 
                                                     kf.split(base_train_data_lg, groups=base_train_data_lg['id']))): #split.split(samples) / kf
    print(f"fold: {fold}")
    # The splitter returns positional indices, so select rows with .iloc.
    train_df = samples.iloc[train_index].reset_index()
    val_df = samples.iloc[test_index].reset_index()
    train_loader = create_loader(train_df, BATCH_SIZE)
    valid_loader = create_loader(val_df, BATCH_SIZE)
    print(train_df.shape, val_df.shape)
    
    train_df_lg = base_train_data_lg.iloc[train_index_lg].reset_index()
    val_df_lg = base_train_data_lg.iloc[test_index_lg].reset_index()
    train_loader_lg = create_loader(train_df_lg, BATCH_SIZE)
    valid_loader_lg = create_loader(val_df_lg, BATCH_SIZE)
    print(train_df_lg.shape, val_df_lg.shape)
    
    # Every fold starts from the same pretrained backbone.
    ae_model = AEModel()
    # map_location lets a checkpoint saved on GPU load on any machine; the
    # model is moved to ``device`` below.
    state_dict = torch.load("./ae-model.pt", map_location="cpu")
    ae_model.load_state_dict(state_dict)
    del state_dict
    
    model = FromAeModel(ae_model.seq, pred_len=91)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    lr_scheduler = None
    
    epoch, valid_pred = train(model, train_loader, valid_loader, train_loader_lg, valid_loader_lg, optimizer, lr_scheduler, 120, device=device,
                  log_path=f"logs/{fold}") #150 #200 epochs
    oof_list += valid_pred
    # Keep the best checkpoint of this fold under a per-fold name.
    shutil.copyfile(str(Path(MODEL_SAVE_PATH) / f"./model-{epoch}.pt"), f"model_prediction/model-{fold}.pt")
    del model

fold: 0
(1, 107, 107, 3)
(1, 107, 107, 3)
(4343, 23) (1086, 23)
(1, 130, 130, 3)
(1, 130, 130, 3)
(2404, 16) (601, 16)
device: cuda


epoch: 0
epoch: 0 loss: 0.18869469190160354


eval loss: 0.21509586582567805 0.30248562270230833
epoch: 1
epoch: 1 loss: 0.15431082517049252


eval loss: 0.19450895108046576 0.2751677992319841
epoch: 2
epoch: 2 loss: 0.1409647957733456


eval loss: 0.18319475955616493 0.25941055190707996
epoch: 3
epoch: 3 loss: 0.13313073580271603


eval loss: 0.1751153792362447 0.24878046937433704
epoch: 4
epoch: 4 loss: 0.1260147069673287


eval loss: 0.17387720131095485 0.24637305981125834
epoch: 5
epoch: 5 loss: 0.1251439502125519


eval loss: 0.1673050698606395 0.23773711710470308
epoch: 6
epoch: 6 loss: 0.11798529815028175


eval loss: 0.16092019830261317 0.22820773168995986
epoch: 7
epoch: 7 loss: 0.11572650257334663


eval loss: 0.1588031307547628 0.22562494214136528
epoch: 8
epoch: 8 loss: 0.12525338029193928
eval loss: 0.1757836279326918 0.24876036953446062
epoch: 9
epoch: 9 loss: 0.11915921948519449


eval loss: 0.15715470709875767 0.22341047527250044
epoch: 10
epoch: 10 loss: 0.10923985603093833


eval loss: 0.15555753364278008 0.22079903092879008
epoch: 11
epoch: 11 loss: 0.10734444366286816


eval loss: 0.15293985531077664 0.21760346058056707
epoch: 12
epoch: 12 loss: 0.10502835396815706
eval loss: 0.15474201691317083 0.21940833837155474
epoch: 13
epoch: 13 loss: 0.10354271482633025


eval loss: 0.15259977259905685 0.2163742196166505
epoch: 14
epoch: 14 loss: 0.10607421644744496
eval loss: 0.1567805315656572 0.22190535566665176
epoch: 15
epoch: 15 loss: 0.12386550660789576
eval loss: 0.22185320611084 0.3121371216417926
epoch: 16
epoch: 16 loss: 0.12293150342125381
eval loss: 0.1554273144667279 0.22103406711712886
epoch: 17
epoch: 17 loss: 0.10413770863564441


eval loss: 0.14908484358761212 0.21175911591271573
epoch: 18
epoch: 18 loss: 0.09990744230075865


eval loss: 0.14859221026640015 0.21067117098473023
epoch: 19
epoch: 19 loss: 0.09790598302829052


eval loss: 0.14704820099090526 0.20879715472657928
epoch: 20
epoch: 20 loss: 0.09651227709511713
eval loss: 0.14820451838263557 0.21041560433488443
epoch: 21
epoch: 21 loss: 0.10078369775647837
eval loss: 0.14976783183635087 0.21251299375565322
epoch: 22
epoch: 22 loss: 0.0970091755189624


eval loss: 0.14611168419077364 0.20693586608448947
epoch: 23
epoch: 23 loss: 0.09341484359190579
eval loss: 0.14628201972270272 0.2076701873152163
epoch: 24
epoch: 24 loss: 0.09468588480376847
eval loss: 0.14649692129650738 0.20791293997436094
epoch: 25
epoch: 25 loss: 0.0924241936974732


eval loss: 0.14443135377595948 0.20517442146313258
epoch: 26
epoch: 26 loss: 0.09051501597197727


eval loss: 0.14429533194912475 0.20467855112872088
epoch: 27
epoch: 27 loss: 0.08933266638460298


eval loss: 0.14403480019925374 0.20422429382854665
epoch: 28
epoch: 28 loss: 0.0885558130814473


eval loss: 0.14304431794884204 0.20292351336027106
epoch: 29
epoch: 29 loss: 0.0877060894974546
eval loss: 0.14346866565860206 0.20344658873679125
epoch: 30
epoch: 30 loss: 0.0868115956010161
eval loss: 0.14344341988342973 0.20346013773636057
epoch: 31
epoch: 31 loss: 0.08669366939460854
eval loss: 0.14356581488087336 0.20353281118050925
epoch: 32
epoch: 32 loss: 0.08522749959554682


eval loss: 0.14293531277769259 0.20307684690755107
epoch: 33
epoch: 33 loss: 0.08476193354312885


eval loss: 0.14226254317149534 0.2021286966720494
epoch: 34
epoch: 34 loss: 0.08386400286253114
eval loss: 0.14230773889584059 0.20162747931398478
epoch: 35
epoch: 35 loss: 0.08537384817615597
eval loss: 0.1456272528432045 0.20634943722768417
epoch: 36
epoch: 36 loss: 0.08595606949187404
eval loss: 0.14454779963789363 0.204390543909979
epoch: 37
epoch: 37 loss: 0.11835942973405877
eval loss: 0.15920709250167128 0.22390182405837197
epoch: 38
epoch: 38 loss: 0.19188321247579004
eval loss: 0.2869085094367883 0.4000022557100484
epoch: 39
epoch: 39 loss: 0.18533785016249688
eval loss: 0.21418841035489594 0.30122886140382676
epoch: 40
epoch: 40 loss: 0.13353221305368934
eval loss: 0.15677584298331543 0.22165533659978412
epoch: 41
epoch: 41 loss: 0.1149969141042389
eval loss: 0.15452779358309346 0.21916269268016256
epoch: 42
epoch: 42 loss: 0.10107235680125717
eval loss: 0.14720788612483793 0.20816284246932232
epoch: 43
epoch: 43 loss: 0.09041530163376112
eval loss: 0.14562411651371648 0.2058

eval loss: 0.142155870868348 0.20136865646865404
epoch: 52
epoch: 52 loss: 0.07854060466614908
eval loss: 0.1425838904906089 0.2020993762304355
epoch: 53
epoch: 53 loss: 0.07800511759789998
eval loss: 0.14239899436921089 0.20184685135325353
epoch: 54
epoch: 54 loss: 0.07760558136283219


eval loss: 0.1421557652140282 0.20131153152424752
epoch: 55
epoch: 55 loss: 0.07758087601633956
eval loss: 0.1428274658654587 0.20203749025208054
epoch: 56
epoch: 56 loss: 0.07671235861514518


eval loss: 0.14214810796512234 0.2012805080693335
epoch: 57
epoch: 57 loss: 0.07758128118180456
eval loss: 0.1538635044104888 0.2170393570229879
epoch: 58
epoch: 58 loss: 0.12041203110771122
eval loss: 0.294782348236443 0.41035590891179086
epoch: 59
epoch: 59 loss: 0.22292333824267474
eval loss: 0.2927106332316579 0.40898228023504357
epoch: 60
epoch: 60 loss: 0.21509345787175546
eval loss: 0.2633926990430057 0.36872002751121846
epoch: 61
epoch: 61 loss: 0.20651166714009647
eval loss: 0.2522488525871006 0.3535871648300144
epoch: 62
epoch: 62 loss: 0.19014336252440855
eval loss: 0.278877479232329 0.389595004738562
epoch: 63
epoch: 63 loss: 0.1833166545997828
eval loss: 0.217402462812959 0.3070021921262876
epoch: 64
epoch: 64 loss: 0.1802360518539033
eval loss: 0.29456635546126825 0.4112549846468226
epoch: 65
epoch: 65 loss: 0.22219364870990285
eval loss: 0.2890757874114241 0.40292004071521414
epoch: 66
epoch: 66 loss: 0.216075743495483
eval loss: 0.25920930796384256 0.3637703317579688
ep

epoch: 0
epoch: 0 loss: 0.1832286934888531


eval loss: 0.21038217970474338 0.30068665307801584
epoch: 1
epoch: 1 loss: 0.14912984156161788


eval loss: 0.19386492164931618 0.27750831348618665
epoch: 2
epoch: 2 loss: 0.13651658355524932


eval loss: 0.1845491769986911 0.2639032875874997
epoch: 3
epoch: 3 loss: 0.12914610431105483


eval loss: 0.17771856598441776 0.25480550599042817
epoch: 4
epoch: 4 loss: 0.12393856190925143


eval loss: 0.17443900564929282 0.24949470390053172
epoch: 5
epoch: 5 loss: 0.11811722891611529


eval loss: 0.1679366497425598 0.2403755247960581
epoch: 6
epoch: 6 loss: 0.11532141694079624


eval loss: 0.16611673249786985 0.237908110631734
epoch: 7
epoch: 7 loss: 0.1157567493780093
eval loss: 0.21125046228096095 0.2987312732863776
epoch: 8
epoch: 8 loss: 0.13836243568846457
eval loss: 0.20397646100124658 0.29260075783710265
epoch: 9
epoch: 9 loss: 0.13801112106816285
eval loss: 0.18071308658507892 0.25897234695994487
epoch: 10
epoch: 10 loss: 0.12177578668090332
eval loss: 0.1668787208200755 0.23854222839386907
epoch: 11
epoch: 11 loss: 0.1152473090498907


eval loss: 0.1654172368798285 0.23620444066538832
epoch: 12
epoch: 12 loss: 0.11199007016842573


eval loss: 0.16282117626540984 0.2322687244610462
epoch: 13
epoch: 13 loss: 0.11211027210863633


eval loss: 0.15978857046400896 0.22801078602171237
epoch: 14
epoch: 14 loss: 0.10757429316258023
eval loss: 0.1615180784028432 0.23032060946034827
epoch: 15
epoch: 15 loss: 0.10584828971766373
eval loss: 0.1606445113411099 0.22872858322579584
epoch: 16
epoch: 16 loss: 0.10668201116006595


eval loss: 0.15734968766661905 0.22412733114205635
epoch: 17
epoch: 17 loss: 0.10354947301974884


eval loss: 0.1544462951675551 0.22049108639066414
epoch: 18
epoch: 18 loss: 0.1004584953130145
eval loss: 0.1545907271046382 0.22089413742669703
epoch: 19
epoch: 19 loss: 0.09887429249071567


eval loss: 0.15341171273806897 0.21878148309801249
epoch: 20
epoch: 20 loss: 0.09824207162164983


eval loss: 0.15321378500240154 0.2181085176715476
epoch: 21
epoch: 21 loss: 0.09752397725111925


eval loss: 0.15136709033941512 0.21597226534058256
epoch: 22
epoch: 22 loss: 0.09721412855948583


eval loss: 0.15133605276523243 0.21531598747422812
epoch: 23
epoch: 23 loss: 0.09460453858333248


eval loss: 0.14975974049896146 0.21341286679001334
epoch: 24
epoch: 24 loss: 0.09321093209152327
eval loss: 0.14999677833956307 0.2139836308237914
epoch: 25
epoch: 25 loss: 0.09305403321050745


eval loss: 0.14941174499747079 0.21287141298530401
epoch: 26
epoch: 26 loss: 0.09389665184417208
eval loss: 0.1496977681558611 0.21322287012846478
epoch: 27
epoch: 27 loss: 0.12018148342413852
eval loss: 0.1668848197789503 0.23777012144633627
epoch: 28
epoch: 28 loss: 0.12467698723091092
eval loss: 0.19589392654260068 0.2740959610499951
epoch: 29
epoch: 29 loss: 0.11724711301381519
eval loss: 0.16371802365719404 0.23176751782007224
epoch: 30
epoch: 30 loss: 0.11263618748569551
eval loss: 0.15358245542124177 0.21792496299243233
epoch: 31
epoch: 31 loss: 0.09862341125052584
eval loss: 0.15251684168551896 0.21643041710533284
epoch: 32
epoch: 32 loss: 0.09340114830959054
eval loss: 0.14960590341733235 0.21289125092464734
epoch: 33
epoch: 33 loss: 0.09468501541458367
eval loss: 0.15326417101948064 0.21763965407917746
epoch: 34
epoch: 34 loss: 0.09151495835435276


eval loss: 0.14903358025184438 0.2117650910437114
epoch: 35
epoch: 35 loss: 0.0917561589855559
eval loss: 0.154362648048234 0.2189423504608124
epoch: 36
epoch: 36 loss: 0.09289205375811938
eval loss: 0.15014833781405454 0.21319991782318837
epoch: 37
epoch: 37 loss: 0.088929663736904


eval loss: 0.14830003478543818 0.21070607380599804
epoch: 38
epoch: 38 loss: 0.08757611039587314


eval loss: 0.14779360104963987 0.21057484648719757
epoch: 39
epoch: 39 loss: 0.08622343134482015


eval loss: 0.14726861820758758 0.20932356962708815
epoch: 40
epoch: 40 loss: 0.08586604306549371
eval loss: 0.14771685646489371 0.20962918779514733
epoch: 41
epoch: 41 loss: 0.0875067165533
eval loss: 0.14927971608520943 0.21180681381184907
epoch: 42
epoch: 42 loss: 0.09456302307248439
eval loss: 0.15080723227812193 0.21396071063846928
epoch: 43
epoch: 43 loss: 0.09451773495748024
eval loss: 0.14849017641398485 0.21101920026832763
epoch: 44
epoch: 44 loss: 0.08857591975152145
eval loss: 0.1489919258385354 0.21149936853170392
epoch: 45
epoch: 45 loss: 0.08637018584982199
eval loss: 0.14754811775012544 0.20918798831816568
epoch: 46
epoch: 46 loss: 0.08407233078301478


eval loss: 0.14626938092878336 0.2080678048686785
epoch: 47
epoch: 47 loss: 0.08335056413822858
eval loss: 0.16936648394820839 0.24019859431851398
epoch: 48
epoch: 48 loss: 0.08525172194383165
eval loss: 0.14904995969912468 0.21114030999015176
epoch: 49
epoch: 49 loss: 0.09516347522641104
eval loss: 0.15132289547615715 0.21380002108415105
epoch: 50
epoch: 50 loss: 0.10802108988029596
eval loss: 0.35196827770244726 0.4949795999605602
epoch: 51
epoch: 51 loss: 0.16942831133496042
eval loss: 0.2411107094684406 0.3402562388660316
epoch: 52
epoch: 52 loss: 0.18293413382306184
eval loss: 0.2715237483109847 0.3834438383240407
epoch: 53
epoch: 53 loss: 0.15507844194350817
eval loss: 0.2057161522645555 0.2924617203535471
epoch: 54
epoch: 54 loss: 0.1613123797707975
eval loss: 0.19094027150693313 0.27145091098060936
epoch: 55
epoch: 55 loss: 0.1487135490947118
eval loss: 0.1739728410387044 0.2471791209815737
epoch: 56
epoch: 56 loss: 0.13956945781698935
eval loss: 0.17201342359629754 0.242688795

eval loss: 0.14609222483964282 0.20673961477568686
epoch: 104
epoch: 104 loss: 0.07628775694646084
eval loss: 0.15350980247338453 0.21749236414518563
epoch: 105
epoch: 105 loss: 0.08448937360527922
eval loss: 0.15172049389965964 0.2147925680663555
epoch: 106
epoch: 106 loss: 0.08513353577674629
eval loss: 0.14934353053092045 0.21080586485033972
epoch: 107
epoch: 107 loss: 0.07586879085377954
eval loss: 0.1467867445186285 0.20791687867981692
epoch: 108
epoch: 108 loss: 0.07349166215442716


eval loss: 0.14549492653755902 0.20611269358306253
epoch: 109
epoch: 109 loss: 0.07136571280809806


eval loss: 0.14507171754570933 0.2056472649100093
epoch: 110
epoch: 110 loss: 0.07189347995343821
eval loss: 0.14568729406145262 0.2061895010430972
epoch: 111
epoch: 111 loss: 0.07170257448182911
eval loss: 0.14547295183089715 0.20578128129742382
epoch: 112
epoch: 112 loss: 0.07241577789997256
eval loss: 0.14571886608888765 0.2062818834103401
epoch: 113
epoch: 113 loss: 0.07373850763748936
eval loss: 0.1469411068344604 0.20785460595644406
epoch: 114
epoch: 114 loss: 0.07254101410445649
eval loss: 0.1463349370210146 0.20723244154195097
epoch: 115
epoch: 115 loss: 0.08371731610219998
eval loss: 0.1491470278947283 0.21079043768189387
epoch: 116
epoch: 116 loss: 0.10113203008788176
eval loss: 0.18593537890678438 0.261839799005363
epoch: 117
epoch: 117 loss: 0.10633574376126173
eval loss: 0.14913430400206606 0.2105669915376956
epoch: 118
epoch: 118 loss: 0.08782371873441498
eval loss: 0.1502217163854782 0.21254845863904603
epoch: 119
epoch: 119 loss: 0.08547186041953461
eval loss: 0.1992831

epoch: 0
epoch: 0 loss: 0.18531746649434072


eval loss: 0.21332057522823747 0.3046974466306092
epoch: 1
epoch: 1 loss: 0.14764689188132557


eval loss: 0.19518473079850898 0.2793498002712094
epoch: 2
epoch: 2 loss: 0.13781922017952278


eval loss: 0.18987851939605213 0.27190676816928816
epoch: 3
epoch: 3 loss: 0.12859661645504497


eval loss: 0.18000530406679704 0.25821519946245086
epoch: 4
epoch: 4 loss: 0.12168485680638562


eval loss: 0.17607336714889513 0.2524257227324666
epoch: 5
epoch: 5 loss: 0.11698641053310091


eval loss: 0.17054452289160163 0.245148346450804
epoch: 6
epoch: 6 loss: 0.11291703685612722


eval loss: 0.1663650819310744 0.2393434128948634
epoch: 7
epoch: 7 loss: 0.11029112114669373


eval loss: 0.1660249590334316 0.23866174197721157
epoch: 8
epoch: 8 loss: 0.11038293590581105
eval loss: 0.16677039934872873 0.23943031469155157
epoch: 9
epoch: 9 loss: 0.10636431992360205


eval loss: 0.16201087379083765 0.23247533450985491
epoch: 10
epoch: 10 loss: 0.10381943756735731


eval loss: 0.16035222569034113 0.230725930101216
epoch: 11
epoch: 11 loss: 0.10237020156373647


eval loss: 0.159793379048548 0.22965117663843235
epoch: 12
epoch: 12 loss: 0.10035928898982334


eval loss: 0.15779644060693418 0.22711054188683313
epoch: 13
epoch: 13 loss: 0.09850383780698418


eval loss: 0.156152568187931 0.22490951654755229
epoch: 14
epoch: 14 loss: 0.09648923634521768


eval loss: 0.1551984075825674 0.22356919990562446
epoch: 15
epoch: 15 loss: 0.09539727395966263
eval loss: 0.15628187173938907 0.2244702977999045
epoch: 16
epoch: 16 loss: 0.09444787969750358
eval loss: 0.15641149079031985 0.22430369200359057
epoch: 17
epoch: 17 loss: 0.09319350106385517
eval loss: 0.1556916100230088 0.22363121079082485
epoch: 18
epoch: 18 loss: 0.09183257881246822


eval loss: 0.15272744366414348 0.21973678505446748
epoch: 19
epoch: 19 loss: 0.0904826660031846


eval loss: 0.1520977756953178 0.21946004791462817
epoch: 20
epoch: 20 loss: 0.08932063184443143


eval loss: 0.15173076711523084 0.21858674852269835
epoch: 21
epoch: 21 loss: 0.08854714909105943


eval loss: 0.15040424802936775 0.21656224065081778
epoch: 22
epoch: 22 loss: 0.08762863926461245
eval loss: 0.1510437881700893 0.21762046141008398
epoch: 23
epoch: 23 loss: 0.08676173143180316


eval loss: 0.15015414929353332 0.21646247129176963
epoch: 24
epoch: 24 loss: 0.0859209586261886


eval loss: 0.14997961136361138 0.21615591697730982
epoch: 25
epoch: 25 loss: 0.08581118815547475


eval loss: 0.14980180286453473 0.21529819601054237
epoch: 26
epoch: 26 loss: 0.0837618288240845


eval loss: 0.14928186829742723 0.21442538931781854
epoch: 27
epoch: 27 loss: 0.08366248837270814


eval loss: 0.1482850312283317 0.2134564245202819
epoch: 28
epoch: 28 loss: 0.08244057280008611
eval loss: 0.14893619090683016 0.21427485096922927
epoch: 29
epoch: 29 loss: 0.08159473020994444
eval loss: 0.14902233495898412 0.21418101860225547
epoch: 30
epoch: 30 loss: 0.08088132318005184


eval loss: 0.14708170003923707 0.2119119977909746
epoch: 31
epoch: 31 loss: 0.08008558206647927
eval loss: 0.1486222110066081 0.21364142287578564
epoch: 32
epoch: 32 loss: 0.07949208903298434
eval loss: 0.14768442912216223 0.21232290322396394
epoch: 33
epoch: 33 loss: 0.07886192377311241
eval loss: 0.1475135890028587 0.21218925623176146
epoch: 34
epoch: 34 loss: 0.07817733769276927


eval loss: 0.14621379012224206 0.2105910002185405
epoch: 35
epoch: 35 loss: 0.07765380616542442
eval loss: 0.14710123093290117 0.21170929541355904
epoch: 36
epoch: 36 loss: 0.07736498651143556


eval loss: 0.14617452165589714 0.21025841490986652
epoch: 37
epoch: 37 loss: 0.07653870077042509


eval loss: 0.14576648132695463 0.20966346104869701
epoch: 38
epoch: 38 loss: 0.07608818271172453
eval loss: 0.14596599120326137 0.2094538632545681
epoch: 39
epoch: 39 loss: 0.07554892936166578
eval loss: 0.1464172440634691 0.21036850495419362
epoch: 40
epoch: 40 loss: 0.07501620443888823
eval loss: 0.1467126344417764 0.2108214552108735
epoch: 41
epoch: 41 loss: 0.0743736849451858
eval loss: 0.14665137087543975 0.2105739608119368
epoch: 42
epoch: 42 loss: 0.07403739398248818
eval loss: 0.14624387631783697 0.21017756339772603
epoch: 43
epoch: 43 loss: 0.0734152928953578
eval loss: 0.1466145860028855 0.2104557461200624
epoch: 44
epoch: 44 loss: 0.07308666676874416
eval loss: 0.1464534655455723 0.21038678553322265
epoch: 45
epoch: 45 loss: 0.0725377553681599


eval loss: 0.14517860495694634 0.20836465097799534
epoch: 46
epoch: 46 loss: 0.07237516604369343
eval loss: 0.14742642226571467 0.2113471557295405
epoch: 47
epoch: 47 loss: 0.0717536232154588
eval loss: 0.1455317495657306 0.2087034259553394
epoch: 48
epoch: 48 loss: 0.07168438189446544
eval loss: 0.14653775024283758 0.20999695412510175
epoch: 49
epoch: 49 loss: 0.11554991622186912
eval loss: 0.29976874848255336 0.4235337728838602
epoch: 50
epoch: 50 loss: 0.2255078995879097
eval loss: 0.29600729327978614 0.418012439453577
epoch: 51
epoch: 51 loss: 0.22258111887541307
eval loss: 0.29510022113793033 0.41604302420579714
epoch: 52
epoch: 52 loss: 0.22215770016769743
eval loss: 0.2946012521401246 0.4154334699554514
epoch: 53
epoch: 53 loss: 0.22200054825487278
eval loss: 0.29701107021514533 0.41923316596921084
epoch: 54
epoch: 54 loss: 0.22188104536265466
eval loss: 0.2936972613635407 0.41454488487881197
epoch: 55
epoch: 55 loss: 0.22193636992525334
eval loss: 0.293795668475409 0.4146343982

epoch: 0
epoch: 0 loss: 0.18691852125308717


eval loss: 0.22426756288149874 0.31399002532891285
epoch: 1
epoch: 1 loss: 0.15104417106654236


eval loss: 0.20315634522344614 0.28691051515362553
epoch: 2
epoch: 2 loss: 0.13845591832004794


eval loss: 0.18799383465066422 0.26616914784685347
epoch: 3
epoch: 3 loss: 0.13713045501113688
eval loss: 0.19903947824599383 0.28069053931974974
epoch: 4
epoch: 4 loss: 0.13348535755807506


eval loss: 0.1826408019575425 0.2584563116107897
epoch: 5
epoch: 5 loss: 0.1255450032693182


eval loss: 0.1796403463750513 0.25456835826131724
epoch: 6
epoch: 6 loss: 0.1209047656197222


eval loss: 0.1708505193126682 0.2429247761046927
epoch: 7
epoch: 7 loss: 0.11708989495850593


eval loss: 0.1694397709271088 0.24007868216540662
epoch: 8
epoch: 8 loss: 0.11583377689800216


eval loss: 0.16793526481754764 0.23815303105628333
epoch: 9
epoch: 9 loss: 0.11325436024921977


eval loss: 0.167478221155667 0.23728790973034844
epoch: 10
epoch: 10 loss: 0.10971915544784068


eval loss: 0.16134649155200678 0.22982833611754971
epoch: 11
epoch: 11 loss: 0.10750899725736024
eval loss: 0.1640502291511768 0.2329995097239545
epoch: 12
epoch: 12 loss: 0.10620429569801691


eval loss: 0.15752375112873768 0.2243614255187069
epoch: 13
epoch: 13 loss: 0.10326570334199167


eval loss: 0.1562635186010555 0.22303765548873863
epoch: 14
epoch: 14 loss: 0.10134403329161189
eval loss: 0.15670940301907765 0.2229901363310685
epoch: 15
epoch: 15 loss: 0.10043781516309543


eval loss: 0.15482742796809731 0.22058113513662495
epoch: 16
epoch: 16 loss: 0.09904566528982216


eval loss: 0.1539928760643002 0.21953924960515986
epoch: 17
epoch: 17 loss: 0.09774515241963698
eval loss: 0.1552127339079183 0.22096335798389524
epoch: 18
epoch: 18 loss: 0.09908349465040776
eval loss: 0.1912004477770912 0.2661339413645356
epoch: 19
epoch: 19 loss: 0.1326379587926569
eval loss: 0.18175564092843807 0.25538380410478045
epoch: 20
epoch: 20 loss: 0.11355023784867342
eval loss: 0.16098281353339647 0.2281537319619311
epoch: 21
epoch: 21 loss: 0.10952933290624825
eval loss: 0.15931816208383395 0.22599055500835616
epoch: 22
epoch: 22 loss: 0.09901273717815065


eval loss: 0.15351923986047705 0.21857723751712943
epoch: 23
epoch: 23 loss: 0.09662982859655103
eval loss: 0.15359635694989032 0.21884793509931677
epoch: 24
epoch: 24 loss: 0.09465076290832446


eval loss: 0.1522594845529625 0.21651779625085182
epoch: 25
epoch: 25 loss: 0.09321380595438644


eval loss: 0.15190553785862898 0.21627464301145652
epoch: 26
epoch: 26 loss: 0.09261110342315362
eval loss: 0.15271977245030938 0.21699893290726222
epoch: 27
epoch: 27 loss: 0.09154197069147645
eval loss: 0.15258088926476723 0.21667465247398215
epoch: 28
epoch: 28 loss: 0.09060768225751031


eval loss: 0.15050098673421727 0.2141726740297598
epoch: 29
epoch: 29 loss: 0.0899079455929122
eval loss: 0.1525933099437251 0.21691592273060586
epoch: 30
epoch: 30 loss: 0.09952334004070829
eval loss: 0.15253803621954132 0.21692977010444503
epoch: 31
epoch: 31 loss: 0.0891048406502607


eval loss: 0.15047070465206452 0.21434208861810317
epoch: 32
epoch: 32 loss: 0.08750422771410164


eval loss: 0.15014847072065038 0.21392520698806483
epoch: 33
epoch: 33 loss: 0.08669301821777099
eval loss: 0.1537629500385292 0.21827046844326176
epoch: 34
epoch: 34 loss: 0.08874690260543472


eval loss: 0.14990456295339333 0.21302661671511247
epoch: 35
epoch: 35 loss: 0.08538140032247536


eval loss: 0.14895476874874813 0.211931756427891
epoch: 36
epoch: 36 loss: 0.08513434720050178
eval loss: 0.1492016234921498 0.21253000084820575
epoch: 37
epoch: 37 loss: 0.09186444191721287
eval loss: 0.17230380976332862 0.2438397982151086
epoch: 38
epoch: 38 loss: 0.10479703606160636
eval loss: 0.15055459290950673 0.21407773713376782
epoch: 39
epoch: 39 loss: 0.08738939087639445
eval loss: 0.15002031562663962 0.21370981380535597
epoch: 40
epoch: 40 loss: 0.08666080214821556
eval loss: 0.14965451310288738 0.2129237048732554
epoch: 41
epoch: 41 loss: 0.08354562510583127


eval loss: 0.14869463060172453 0.2110472476582641
epoch: 42
epoch: 42 loss: 0.08242162606534544


eval loss: 0.14741724778292073 0.21016090956427294
epoch: 43
epoch: 43 loss: 0.08150809769478892
eval loss: 0.14756039944692956 0.2101560627350413
epoch: 44
epoch: 44 loss: 0.08052281135122473
eval loss: 0.14774573511091701 0.21029492570335323
epoch: 45
epoch: 45 loss: 0.08240074215268957
eval loss: 0.1545130139527248 0.21818566887788637
epoch: 46
epoch: 46 loss: 0.08539918286582265
eval loss: 0.14795523604885483 0.21059036122001337
epoch: 47
epoch: 47 loss: 0.07985768120941737


eval loss: 0.14655194189520718 0.20837830998935702
epoch: 48
epoch: 48 loss: 0.07866774999801766
eval loss: 0.14663750445347204 0.2088500879645359
epoch: 49
epoch: 49 loss: 0.10080274399573962
eval loss: 0.16367352536304308 0.2311112091570251
epoch: 50
epoch: 50 loss: 0.12074246578385898
eval loss: 0.3974068068364575 0.5469630091175932
epoch: 51
epoch: 51 loss: 0.19952805204676438
eval loss: 0.2695055684137919 0.375880238185865
epoch: 52
epoch: 52 loss: 0.15267749435657782
eval loss: 0.1925198447796636 0.26942576222087855
epoch: 53
epoch: 53 loss: 0.10784399004676543
eval loss: 0.15399264756814926 0.21877410466019845
epoch: 54
epoch: 54 loss: 0.08846659295669417
eval loss: 0.1500483803877169 0.21346023670351913
epoch: 55
epoch: 55 loss: 0.08354578969407167
eval loss: 0.14881261134267004 0.21138042833331736
epoch: 56
epoch: 56 loss: 0.07999258516274364
eval loss: 0.14872771563907064 0.21130711562315307
epoch: 57
epoch: 57 loss: 0.07936170123852379
eval loss: 0.14798355447137096 0.210290

eval loss: 0.14651152149381225 0.20858794302305714
epoch: 65
epoch: 65 loss: 0.07406692707864065


eval loss: 0.14605124560117275 0.20811529748816995
epoch: 66
epoch: 66 loss: 0.0740246449126862
eval loss: 0.1468445701476806 0.20856445687730643
epoch: 67
epoch: 67 loss: 0.07431710417132323
eval loss: 0.14784209587432637 0.20989970462119834
epoch: 68
epoch: 68 loss: 0.0777736546417284
eval loss: 0.1716274185870072 0.24192973691788938
epoch: 69
epoch: 69 loss: 0.15615135514837558
eval loss: 0.33103289911672773 0.45711026133464777
epoch: 70
epoch: 70 loss: 0.22277123495543721
eval loss: 0.3041969559367203 0.4202574173187732
epoch: 71
epoch: 71 loss: 0.21999072003540304
eval loss: 0.301913840832907 0.41743499270597956
epoch: 72
epoch: 72 loss: 0.2193663025120777
eval loss: 0.30802842313943973 0.42583103907262837
epoch: 73
epoch: 73 loss: 0.21956554624077593
eval loss: 0.3018169116455983 0.4170716259842517
epoch: 74
epoch: 74 loss: 0.21923615437622848
eval loss: 0.3015089183793339 0.4164820904728956
epoch: 75
epoch: 75 loss: 0.2188255936744632
eval loss: 0.3053062667538618 0.421992998543

epoch: 0
epoch: 0 loss: 0.18164363834911773


eval loss: 0.2296332019581523 0.3231019029876028
epoch: 1
epoch: 1 loss: 0.1515320348012016


eval loss: 0.20307952733124074 0.28695060730878
epoch: 2
epoch: 2 loss: 0.13822229583458953


eval loss: 0.18906808014715487 0.2674766511446341
epoch: 3
epoch: 3 loss: 0.12854440571802023


eval loss: 0.1789639286574892 0.25319820390174885
epoch: 4
epoch: 4 loss: 0.12220646287658292


eval loss: 0.17882584491103695 0.25258758492300615
epoch: 5
epoch: 5 loss: 0.11804124204120992


eval loss: 0.17170344927275968 0.24363980324788972
epoch: 6
epoch: 6 loss: 0.11541881032155428


eval loss: 0.1695468605130066 0.24040066370542476
epoch: 7
epoch: 7 loss: 0.112069783883026


eval loss: 0.1651371480349907 0.23489446084547502
epoch: 8
epoch: 8 loss: 0.10820088034457243


eval loss: 0.16324682058856946 0.23212492815385927
epoch: 9
epoch: 9 loss: 0.10580747069022071
eval loss: 0.163364165230014 0.2326371317746523
epoch: 10
epoch: 10 loss: 0.11397423102126528
eval loss: 0.2730206048825907 0.37919457737918166
epoch: 11
epoch: 11 loss: 0.12413141560746035
eval loss: 0.17149867999642382 0.24388303908739914
epoch: 12
epoch: 12 loss: 0.13358961811197326
eval loss: 0.1834978210956987 0.25896014214168334
epoch: 13
epoch: 13 loss: 0.12481006169551394
eval loss: 0.1706398614272968 0.2420561126902718
epoch: 14
epoch: 14 loss: 0.11247162721515047


eval loss: 0.16257751945686733 0.23104289055126137
epoch: 15
epoch: 15 loss: 0.10405455884405264


eval loss: 0.15991194233900116 0.22735849453996687
epoch: 16
epoch: 16 loss: 0.10137389464879304


eval loss: 0.15865622246832736 0.22563775186488466
epoch: 17
epoch: 17 loss: 0.10216809420326534
eval loss: 0.1602854243442162 0.2269728588726134
epoch: 18
epoch: 18 loss: 0.09950936580824316


eval loss: 0.15797951207235728 0.22440517000252647
epoch: 19
epoch: 19 loss: 0.09756186521908579


eval loss: 0.1572311370742962 0.22380806554118232
epoch: 20
epoch: 20 loss: 0.09988150146427523


eval loss: 0.1568533048022231 0.22264407296759015
epoch: 21
epoch: 21 loss: 0.09894686571349585
eval loss: 0.16051410587014495 0.22765745143960242
epoch: 22
epoch: 22 loss: 0.10237990178957314
eval loss: 0.15818852208667464 0.22482123725366998
epoch: 23
epoch: 23 loss: 0.0949001381209889


eval loss: 0.15494915668101955 0.2202580211476583
epoch: 24
epoch: 24 loss: 0.09246699415940372


eval loss: 0.15373540900300212 0.21830846161399242
epoch: 25
epoch: 25 loss: 0.09163225795725195


eval loss: 0.15351554922030347 0.2180543495536564
epoch: 26
epoch: 26 loss: 0.0899286904363221


eval loss: 0.1525326993534609 0.21682270051906094
epoch: 27
epoch: 27 loss: 0.08933683019007511
eval loss: 0.15259257276351534 0.21693479051810693
epoch: 28
epoch: 28 loss: 0.08821828088291353


eval loss: 0.15188624177967164 0.21547394779329387
epoch: 29
epoch: 29 loss: 0.08741298157966748


eval loss: 0.15174258538718532 0.21600332232221356
epoch: 30
epoch: 30 loss: 0.08921844954947984
eval loss: 0.15489585490159136 0.21997611491300859
epoch: 31
epoch: 31 loss: 0.09361306514520985
eval loss: 0.15871755201411566 0.22469935536855967
epoch: 32
epoch: 32 loss: 0.10820948684488106
eval loss: 0.23226745931091913 0.3264969201207637
epoch: 33
epoch: 33 loss: 0.12845488751955564
eval loss: 0.18166220080153425 0.2555124198954886
epoch: 34
epoch: 34 loss: 0.12240139587725163
eval loss: 0.22022796346927193 0.3072146496122934
epoch: 35
epoch: 35 loss: 0.14546333087610613
eval loss: 0.183564339020873 0.25933968324094225
epoch: 36
epoch: 36 loss: 0.11907244858185337
eval loss: 0.1611495418378268 0.2280247596978115
epoch: 37
epoch: 37 loss: 0.09725705685590798
eval loss: 0.15476387238109005 0.21953420998409423
epoch: 38
epoch: 38 loss: 0.0899617632976229
eval loss: 0.15223141925781286 0.21619092074046656
epoch: 39
epoch: 39 loss: 0.0879564682063333
eval loss: 0.15174619835897804 0.214836

eval loss: 0.15170938507827444 0.2153080834019533
epoch: 43
epoch: 43 loss: 0.08722946755121375


eval loss: 0.1513173512493379 0.2144563306910257
epoch: 44
epoch: 44 loss: 0.08316457605766085


eval loss: 0.1502836576580044 0.21343520745071554
epoch: 45
epoch: 45 loss: 0.08250454278926725
eval loss: 0.15168532852296476 0.21461554174541164
epoch: 46
epoch: 46 loss: 0.08330500412101716


eval loss: 0.14998232652076848 0.21231635149881345
epoch: 47
epoch: 47 loss: 0.08148964118340139
eval loss: 0.15085586492954295 0.2139350229036206
epoch: 48
epoch: 48 loss: 0.08550705491517908
eval loss: 0.15149598845236836 0.21486822717081336
epoch: 49
epoch: 49 loss: 0.08024550149466195


eval loss: 0.14994062148684467 0.2125610639671116
epoch: 50
epoch: 50 loss: 0.08232475354676938


eval loss: 0.14898649745066386 0.21158178663776644
epoch: 51
epoch: 51 loss: 0.08449402850395704
eval loss: 0.15415241489753018 0.21834847910724364
epoch: 52
epoch: 52 loss: 0.09454252345014136
eval loss: 0.15384900703917445 0.2176564978719071
epoch: 53
epoch: 53 loss: 0.09798771750010828
eval loss: 0.17182953791857777 0.24088143876552007
epoch: 54
epoch: 54 loss: 0.1798150153972822
eval loss: 0.30218118847670394 0.41950147150068423
epoch: 55
epoch: 55 loss: 0.19540211495562806
eval loss: 0.2549729346478726 0.35476433209609287
epoch: 56
epoch: 56 loss: 0.14304057193366676
eval loss: 0.19814231572960628 0.27862812504768675
epoch: 57
epoch: 57 loss: 0.11070869790839813
eval loss: 0.1550825360079942 0.2200011036097438
epoch: 58
epoch: 58 loss: 0.09012958632048676
eval loss: 0.152252462898259 0.21588831349996976
epoch: 59
epoch: 59 loss: 0.09807063841169283
eval loss: 0.1575332492313526 0.22288476838261628
epoch: 60
epoch: 60 loss: 0.09680505777637356
eval loss: 0.1546795034212294 0.219317

In [12]:
# Gather the validation predictions collected across all folds into a single
# frame (one id column followed by the target columns) and write it to CSV.
oof_columns = ["id_seqpos", *target_cols]
oof_df = pd.DataFrame(oof_list, columns=oof_columns)
oof_df.to_csv('validation_aepytorch.csv', index=False)

### Prediction

In [13]:
# Test-time prediction: run every fold checkpoint on the public (length 107)
# and private (length 130) test sequences, average the fold predictions and
# write the submission file.

# Always build a torch.device (the old fallback was a bare "cpu" string).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 1
base_test_data = pd.read_json(str(Path(BASE_PATH) / 'test.json'), lines=True)
base_test_data = aug_data(base_test_data)

# The two sequence lengths are handled separately because the model is built
# with a fixed pred_len.
public_df = base_test_data.query("seq_length == 107").copy()
private_df = base_test_data.query("seq_length == 130").copy()
print(f"public_df: {public_df.shape}")
print(f"private_df: {private_df.shape}")
public_df = public_df.reset_index()
private_df = private_df.reset_index()
pub_loader = create_loader(public_df, BATCH_SIZE, is_test=True)
pri_loader = create_loader(private_df, BATCH_SIZE, is_test=True)

pred_df_list = []
for fold in range(5):
    model_load_path = f"./model_prediction/model-{fold}.pt"
    # Same weights, two wrappers: one per prediction length.
    ae_model0 = AEModel()
    ae_model1 = AEModel()
    model_pub = FromAeModel(pred_len=107, seq=ae_model0.seq)
    model_pub = model_pub.to(device)
    model_pri = FromAeModel(pred_len=130, seq=ae_model1.seq)
    model_pri = model_pri.to(device)
    state_dict = torch.load(model_load_path, map_location=device)
    model_pub.load_state_dict(state_dict)
    model_pri.load_state_dict(state_dict)
    del state_dict

    data_list = []
    data_list += predict_data(model_pub, pub_loader, device, BATCH_SIZE)
    data_list += predict_data(model_pri, pri_loader, device, BATCH_SIZE)
    pred_df = pd.DataFrame(data_list, columns=["id_seqpos"] + target_cols)
    #print(pred_df.head())
    #print(pred_df.tail())
    pred_df_list.append(pred_df)

# Average the folds row by row.
# NOTE(review): this relies on every fold emitting its rows in the same
# order as fold 0 -- confirm the test loaders do not shuffle.
n_folds = len(pred_df_list)
data_dic = dict(id_seqpos=pred_df_list[0]["id_seqpos"])
for col in target_cols:
    vals = np.zeros(pred_df_list[0][col].shape[0])
    for df in pred_df_list:
        vals += df[col].values
    data_dic[col] = vals / float(n_folds)
pred_df_avg = pd.DataFrame(data_dic, columns=["id_seqpos"] + target_cols)
#print(pred_df_avg.head())
#pred_df_avg.to_csv("./submission_all.csv", index=False)

##### submission final
sample_df = pd.read_csv(str(Path(BASE_PATH) / 'sample_submission.csv'))
target_cols = [name for name in sample_df.columns if name != 'id_seqpos']
output_df = pd.DataFrame({'id_seqpos': sample_df.id_seqpos.values})

# An id_seqpos may occur several times in pred_df_avg, so collapse the
# duplicates with a mean before aligning to the sample-submission row order.
for col in target_cols:
    x = pred_df_avg.groupby('id_seqpos')[col].mean().reset_index()
    print(x.shape)
    output_df = pd.merge(output_df, x, on='id_seqpos')

# The merges above are inner joins and would silently drop ids that have no
# prediction; surface that instead of writing a short submission unnoticed.
if len(output_df) != len(sample_df):
    print(f"WARNING: submission has {len(output_df)} rows, expected {len(sample_df)}")

output_df.to_csv('submission_aepytorch.csv', index=False)
print(output_df.shape)

public_df: (1258, 10)
private_df: (6010, 10)
(1, 107, 107, 3)
(1, 130, 130, 3)


(457953, 2)
(457953, 2)
(457953, 2)
(457953, 2)
(457953, 2)
(457953, 6)


In [14]:
# Show the first 10 rows of output_df for a quick visual check.
output_df.head(10)

Unnamed: 0,id_seqpos,reactivity,deg_Mg_pH10,deg_pH10,deg_Mg_50C,deg_50C
0,id_00073f8be_0,0.647781,0.634589,1.88013,0.497652,0.708136
1,id_00073f8be_1,1.988751,2.927485,3.775051,2.9225,2.57154
2,id_00073f8be_2,1.441579,0.613249,0.685714,0.71315,0.67665
3,id_00073f8be_3,1.19209,1.029909,1.09555,1.474018,1.638364
4,id_00073f8be_4,0.809244,0.525973,0.460842,0.805732,0.854745
5,id_00073f8be_5,0.705192,0.513918,0.53422,0.657263,0.705889
6,id_00073f8be_6,0.739668,0.949133,0.950112,1.003862,0.979089
7,id_00073f8be_7,0.921899,0.937763,1.101304,0.940724,1.34659
8,id_00073f8be_8,0.196084,0.90614,0.817736,0.971482,0.70886
9,id_00073f8be_9,0.052724,0.208538,0.265326,0.280487,0.396252
