### Jane Street 1st place solution -> Crypto

In this notebook, I adapt Yirun Zhang's 1st place solution from the Jane Street competition to crypto. There is already an excellent kernel by Yam Peleg written in TensorFlow, but I decided to make another one using PyTorch. In contrast to Yam Peleg's kernel, I train a single model for all assets.

Yirun Zhang's original solution: https://www.kaggle.com/gogo827jz/jane-street-supervised-autoencoder-mlp 

Yam Peleg's adaptation: https://www.kaggle.com/yamqwe/1st-place-of-jane-street-adapted-to-crypto

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found below the Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for file_name in files:
        file_path = os.path.join(root, file_name)
        print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SequentialSampler
import time
from torch.optim import AdamW
import warnings
import random
import os

#!mkdir ../checkpoints

## PurgedGroupTimeSeriesSplit

See https://www.kaggle.com/yamqwe/purgedgrouptimeseries-cv-with-extra-data-tabnet

In [None]:
from sklearn.model_selection._split import _BaseKFold, indexable, _num_samples
from sklearn.utils.validation import _deprecate_positional_args

# https://github.com/getgaurav2/scikit-learn/blob/d4a3af5cc9da3a76f0266932644b884c99724c57/sklearn/model_selection/_split.py#L2243
class GroupTimeSeriesSplit(_BaseKFold):
    """Time Series cross-validator variant with non-overlapping groups.
    Provides train/test indices to split time series data samples
    that are observed at fixed time intervals according to a
    third-party provided group.
    In each split, test indices must be higher than before, and thus shuffling
    in cross validator is inappropriate.
    This cross-validation object is a variation of :class:`KFold`.
    In the kth split, it returns first k folds as train set and the
    (k+1)th fold as test set.
    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).
    Note that unlike standard cross-validation methods, successive
    training sets are supersets of those that come before them.
    Read more in the :ref:`User Guide <cross_validation>`.
    Parameters
    ----------
    n_splits : int, default=5
        Number of splits. Must be at least 2.
    max_train_size : int, default=None
        Maximum size (in samples) for a single training set. When set, only
        the last ``max_train_size`` training indices are kept.
    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupTimeSeriesSplit
    >>> groups = np.array(['a', 'a', 'a', 'a', 'a', 'a',\
                           'b', 'b', 'b', 'b', 'b',\
                           'c', 'c', 'c', 'c',\
                           'd', 'd', 'd'])
    >>> gtss = GroupTimeSeriesSplit(n_splits=3)
    >>> for train_idx, test_idx in gtss.split(groups, groups=groups):
    ...     print("TRAIN:", train_idx, "TEST:", test_idx)
    ...     print("TRAIN GROUP:", groups[train_idx],\
                  "TEST GROUP:", groups[test_idx])
    TRAIN: [0, 1, 2, 3, 4, 5] TEST: [6, 7, 8, 9, 10]
    TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a']\
    TEST GROUP: ['b' 'b' 'b' 'b' 'b']
    TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] TEST: [11, 12, 13, 14]
    TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b']\
    TEST GROUP: ['c' 'c' 'c' 'c']
    TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]\
    TEST: [15, 16, 17]
    TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b' 'c' 'c' 'c' 'c']\
    TEST GROUP: ['d' 'd' 'd']
    """
    @_deprecate_positional_args
    def __init__(self,
                 n_splits=5,
                 *,
                 max_train_size=None
                 ):
        super().__init__(n_splits, shuffle=False, random_state=None)
        self.max_train_size = max_train_size

    @staticmethod
    def _gather_indices(group_dict, group_labels):
        """Return the sorted, de-duplicated sample indices of ``group_labels``."""
        parts = [group_dict[label] for label in group_labels]
        if not parts:
            return np.array([], dtype=int)
        # np.unique sorts its result, so one call per fold replaces the
        # original concatenate + unique + sort that ran once per group.
        return np.unique(np.concatenate(parts))

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.
        Yields
        ------
        train : list of int
            The training set indices for that split.
        test : list of int
            The testing set indices for that split.
        Raises
        ------
        ValueError
            If ``groups`` is None or there are fewer distinct groups than
            ``n_splits + 1``.
        """
        if groups is None:
            raise ValueError(
                "The 'groups' parameter should not be None")
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        n_folds = n_splits + 1

        # Distinct groups in order of first appearance (np.unique alone
        # would return them sorted by value).
        u, ind = np.unique(groups, return_index=True)
        unique_groups = u[np.argsort(ind)]
        n_groups = _num_samples(unique_groups)

        # Map every group label to the positions of its samples.
        group_dict = {}
        for idx in range(n_samples):
            group_dict.setdefault(groups[idx], []).append(idx)

        if n_folds > n_groups:
            raise ValueError(
                ("Cannot have number of folds={0} greater than"
                 " the number of groups={1}").format(n_folds,
                                                     n_groups))

        group_test_size = n_groups // n_folds
        group_test_starts = range(n_groups - n_splits * group_test_size,
                                  n_groups, group_test_size)
        for group_test_start in group_test_starts:
            train_array = self._gather_indices(
                group_dict, unique_groups[:group_test_start])
            train_end = train_array.size
            if self.max_train_size and self.max_train_size < train_end:
                # Keep only the most recent max_train_size samples.
                train_array = train_array[train_end -
                                          self.max_train_size:train_end]

            test_array = self._gather_indices(
                group_dict,
                unique_groups[group_test_start:
                              group_test_start + group_test_size])

            yield [int(i) for i in train_array], [int(i) for i in test_array]
import numpy as np
from sklearn.model_selection import KFold
from sklearn.model_selection._split import _BaseKFold, indexable, _num_samples
from sklearn.utils.validation import _deprecate_positional_args

# modified code for group gaps; source
# https://github.com/getgaurav2/scikit-learn/blob/d4a3af5cc9da3a76f0266932644b884c99724c57/sklearn/model_selection/_split.py#L2243
class PurgedGroupTimeSeriesSplit(_BaseKFold):
    """Time Series cross-validator variant with non-overlapping groups.
    Allows for a gap in groups to avoid potentially leaking info from
    train into test if the model has windowed or lag features.
    Provides train/test indices to split time series data samples
    that are observed at fixed time intervals according to a
    third-party provided group.
    In each split, test indices must be higher than before, and thus shuffling
    in cross validator is inappropriate.
    This cross-validation object is a variation of :class:`KFold`.
    In the kth split, it returns first k folds as train set and the
    (k+1)th fold as test set.
    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).
    Read more in the :ref:`User Guide <cross_validation>`.
    Parameters
    ----------
    n_splits : int, default=5
        Number of splits. Must be at least 2.
    max_train_group_size : int, default=Inf
        Maximum number of groups in a single training set.
    max_test_group_size : int, default=Inf
        Maximum number of groups in a single test set.
    group_gap : int, default=None
        Number of groups discarded from the end of each train split, i.e.
        the gap between train and test. ``None`` is treated as no gap.
    verbose : bool, default=False
        Stored on the instance; it currently has no effect on ``split``.
    """

    @_deprecate_positional_args
    def __init__(self,
                 n_splits=5,
                 *,
                 max_train_group_size=np.inf,
                 max_test_group_size=np.inf,
                 group_gap=None,
                 verbose=False
                 ):
        super().__init__(n_splits, shuffle=False, random_state=None)
        self.max_train_group_size = max_train_group_size
        self.group_gap = group_gap
        self.max_test_group_size = max_test_group_size
        self.verbose = verbose

    @staticmethod
    def _gather_indices(group_dict, group_labels):
        """Return the sorted, de-duplicated sample indices of ``group_labels``."""
        parts = [group_dict[label] for label in group_labels]
        if not parts:
            return np.array([], dtype=int)
        # np.unique sorts its result, so one call per fold replaces the
        # original concatenate + unique + sort that ran once per group.
        return np.unique(np.concatenate(parts))

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.
        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.
        Yields
        ------
        train : list of int
            The training set indices for that split (empty when the gap
            swallows every group that precedes the test window).
        test : list of int
            The testing set indices for that split.
        Raises
        ------
        ValueError
            If ``groups`` is None or there are fewer distinct groups than
            ``n_splits + 1``.
        """
        if groups is None:
            raise ValueError(
                "The 'groups' parameter should not be None")
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        # The constructor default is None; arithmetic below needs a number.
        group_gap = 0 if self.group_gap is None else self.group_gap
        max_test_group_size = self.max_test_group_size
        max_train_group_size = self.max_train_group_size
        n_folds = n_splits + 1

        # Distinct groups in order of first appearance (np.unique alone
        # would return them sorted by value).
        u, ind = np.unique(groups, return_index=True)
        unique_groups = u[np.argsort(ind)]
        n_groups = _num_samples(unique_groups)

        # Map every group label to the positions of its samples.
        group_dict = {}
        for idx in range(n_samples):
            group_dict.setdefault(groups[idx], []).append(idx)

        if n_folds > n_groups:
            raise ValueError(
                ("Cannot have number of folds={0} greater than"
                 " the number of groups={1}").format(n_folds,
                                                     n_groups))

        group_test_size = min(n_groups // n_folds, max_test_group_size)
        group_test_starts = range(n_groups - n_splits * group_test_size,
                                  n_groups, group_test_size)
        for group_test_start in group_test_starts:
            # Clamp at 0: a negative slice end would count from the end of
            # unique_groups and silently pull later (test) groups into train.
            train_stop = max(0, group_test_start - group_gap)
            train_start = max(0, train_stop - max_train_group_size)
            train_array = self._gather_indices(
                group_dict, unique_groups[train_start:train_stop])

            test_array = self._gather_indices(
                group_dict,
                unique_groups[group_test_start:
                              group_test_start + group_test_size])

            # NOTE(review): this drops the first `group_gap` *samples* (not
            # groups) of the test window, on top of the gap already removed
            # from the train side. Kept as-is to preserve existing splits;
            # confirm whether it is intended.
            test_array = test_array[group_gap:]

            yield [int(i) for i in train_array], [int(i) for i in test_array]

## Seed

In [None]:
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also exports ``PYTHONHASHSEED`` and asks cuDNN for deterministic kernels.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True

## Config

Kaggle CPU is a bottleneck, so we train our network only on the first fold. If you want to experiment with all folds, then set train_only_first_fold = False.

In [None]:
class GlobalConfig:
    """Notebook-wide settings, read as class attributes (used as a namespace)."""

    # Debug mode: create_df keeps fewer rows and a smaller group gap is used.
    debug = False

    # Use the GPU when one is visible, otherwise fall back to the CPU so that
    # `.to(GlobalConfig.device)` does not fail in a CPU-only session.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # PurgedGroupTimeSeriesSplit
    num_folds = 5
    group_gap = 31 if not debug else 1
    # When True, train() stops after the first fold.
    train_only_first_fold = True

    # dataloader
    train_batch_size = 4096
    val_batch_size = 4096
    num_workers = 2

    # optimizer
    lr = 1e-2

## DataFrame

In [None]:
def create_df():
    """Load the competition CSVs and return the fully prepared DataFrame.

    The price rows are left-joined with the asset details on ``Asset_ID``,
    indexed by ``timestamp``, run through the preprocessing / feature /
    date / group steps in that order, and finally stripped of rows that
    still contain NaN.
    """
    prices = pd.read_csv("../input/g-research-crypto-forecasting/train.csv")
    asset_details = pd.read_csv("../input/g-research-crypto-forecasting/asset_details.csv")
    frame = prices.merge(asset_details, on="Asset_ID", how="left")

    # Debug runs only keep the first rows to stay fast.
    if GlobalConfig.debug:
        frame = frame.iloc[:200000]

    frame = frame.set_index("timestamp")

    pipeline = (preprocess_df, add_row_features, add_date, add_group)
    for step in pipeline:
        frame = step(frame)

    return frame.dropna()




def add_row_features(df):
    """Add per-row candle ratio features to ``df`` in place and return it.

    Reads the ``Open``, ``High``, ``Low``, ``Close``, ``Volume`` and ``Count``
    columns and writes nine new ratio columns.
    """
    high = df['High']
    low = df['Low']
    candle_body = df[['Close', 'Open']]

    df['upper_shadow'] = high / candle_body.max(axis=1)
    df['lower_shadow'] = candle_body.min(axis=1) / low
    df['open2close'] = df['Close'] / df['Open']
    df['high2low'] = high / low

    # Row-wise centre of the four prices, used as the ratio denominator.
    ohlc = df[['Open', 'High', 'Low', 'Close']]
    row_mean = ohlc.mean(axis=1)
    row_median = ohlc.median(axis=1)

    df['high2mean'] = high / row_mean
    df['low2mean'] = low / row_mean
    df['high2median'] = high / row_median
    df['low2median'] = low / row_median

    # +1 keeps the denominator away from zero when Count is 0.
    df['volume2count'] = df['Volume'] / (df['Count'] + 1)

    return df



def preprocess_df(df):
    """Fill time gaps and bad values separately for every asset.

    ``df`` is expected to be indexed by ``timestamp`` (integer seconds) and to
    have an ``Asset_ID`` column. For each asset the rows are put on a regular
    60-second grid between its first and last timestamp (missing minutes copy
    the previous row), infinities are turned into NaN and NaN values are
    forward-filled. The per-asset frames are concatenated and sorted by
    timestamp again.

    Returns a new DataFrame; the argument is not modified.
    """
    df = df.sort_values(by="timestamp")

    dfs = []
    for asset_id in df.Asset_ID.unique():
        asset_df = df[df.Asset_ID == asset_id]

        # reindex returns a new frame; the original discarded it, so the
        # gap filling never took effect. The index name is carried over
        # explicitly so the final sort by "timestamp" keeps working.
        minute_grid = pd.Index(
            range(asset_df.index[0], asset_df.index[-1] + 60, 60),
            name=asset_df.index.name,
        )
        asset_df = asset_df.reindex(minute_grid, method='pad')

        asset_df = asset_df.replace([np.inf, -np.inf], np.nan)
        # .ffill() is the supported spelling of fillna(method="ffill").
        asset_df = asset_df.ffill()

        dfs.append(asset_df)

    # pd.concat raises on an empty list; an empty input has nothing to fill.
    if not dfs:
        return df

    df = pd.concat(dfs)
    df = df.sort_values(by="timestamp")

    return df



def add_date(df):
    """Add a ``date`` column built from the index, read as seconds since the epoch.

    The frame is modified in place and also returned.
    """
    index_as_datetime = pd.to_datetime(df.index, unit='s')
    df['date'] = index_as_datetime
    return df



def add_group(df):
    """Add an integer ``group`` column that identifies the calendar day of ``date``.

    Codes are assigned in order of first appearance. The frame is modified
    in place and also returned.
    """
    date_parts = df['date'].dt
    day_text = date_parts.day.astype(str)
    month_text = date_parts.month.astype(str)
    year_text = date_parts.year.astype(str)

    # One "day_month_year" label per row; equal labels share a group code.
    day_label = day_text + '_' + month_text + '_' + year_text
    codes, _uniques = pd.factorize(day_label)

    df["group"] = codes
    return df

## CryptoDataset

In [None]:
class CryptoDataset:
    """Map-style dataset that serves one DataFrame row per index.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns listed in ``self.features`` plus ``Weight``,
        and ``Target`` when ``is_train`` is True.
    is_train : bool, default=True
        When True every item also carries the ``target`` value.
    """

    def __init__(self, df, is_train=True):
        self.df = df
        self.is_train = is_train
        self.features = ['Asset_ID', 'upper_shadow', 'lower_shadow', 'open2close', 'high2low', 'high2mean', 'low2mean',\
                                         'high2median', 'low2median', 'volume2count']

        # Convert the needed columns to float32 arrays once. Indexing the
        # DataFrame with .iloc for every sample builds a fresh Series per
        # item, which makes data loading far slower than the model itself.
        # The arrays are a snapshot of df taken at construction time.
        self._feature_matrix = df[self.features].to_numpy(dtype=np.float32)
        self._weights = df["Weight"].to_numpy(dtype=np.float32)
        self._targets = df["Target"].to_numpy(dtype=np.float32) if is_train else None

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return a dict with ``features``, ``weight`` and, in train mode, ``target``.

        ``features`` is a float32 vector ordered like ``self.features``;
        ``weight`` and ``target`` are float32 scalars.
        """
        features = self._feature_matrix[index]
        weight = self._weights[index]

        if self.is_train:
            target = self._targets[index]
            return {"features" : features, "target" : target, "weight" : weight}

        return {"features" : features, "weight" : weight}

## DataLoader

In [None]:
def get_dataloader(df, batch_size, is_train=True):
    """Wrap ``df`` in a CryptoDataset and return a DataLoader that reads it in order.

    Batches follow the row order of ``df`` (sequential sampler); the number
    of worker processes comes from ``GlobalConfig.num_workers``.
    """
    dataset = CryptoDataset(df, is_train=is_train)
    loader_options = dict(
        sampler=SequentialSampler(dataset),
        batch_size=batch_size,
        pin_memory=True,
        num_workers=GlobalConfig.num_workers,
    )
    return DataLoader(dataset, **loader_options)

## Architectures

In [None]:
def make_model():
    """Build the full network (preprocessing, encoder, decoder, head, FFN).

    All layer sizes are derived here from the number of input features and
    the embedding size, then the sub-modules are created and wired into a
    MainNetwork.
    """
    # Preprocessing layer: column 0 is embedded, the other columns pass through.
    num_features = 10
    num_embeddings = 14   # presumably the number of distinct Asset_ID values - TODO confirm
    embedding_dim = 10

    # Width after preprocessing: the embedding replaces one raw column.
    preprocessed_width = num_features + embedding_dim - 1
    bottleneck_width = int(preprocessed_width * 0.8)

    # Decoder: reconstructs every input column except the embedded one.
    decoder_hidden = 80
    reconstruction_width = num_features - 1

    # Head on top of the reconstruction.
    head_hidden = int(reconstruction_width * 1.5)

    # FFN sees the preprocessed input concatenated with the encoder output.
    ffn_width = preprocessed_width + bottleneck_width
    num_blocks = 2

    # Arguments are evaluated left to right, so the sub-modules are created
    # (and randomly initialised) in this order.
    return MainNetwork(
        PreprocessingLayer(num_features, num_embeddings, embedding_dim),
        Encoder(preprocessed_width, bottleneck_width),
        Decoder(bottleneck_width, decoder_hidden, reconstruction_width),
        Head(reconstruction_width, head_hidden),
        FFNetwork(ffn_width, num_blocks),
    )

#### Main Network

In [None]:
class MainNetwork(nn.Module):
    """Combine the sub-modules into a supervised auto-encoder with two prediction heads."""

    def __init__(self, preprocessing_layer, encoder, decoder, head_decoder, ffn):
        super().__init__()

        self.preprocessing_layer = preprocessing_layer
        self.encoder = encoder
        self.decoder = decoder
        self.head_decoder = head_decoder
        self.ffn = ffn

    def forward(self, x):
        """Return ``(ffn prediction, head prediction, feature reconstruction)``.

        The FFN receives the preprocessed input concatenated with the encoder
        output; the head predicts from the decoder's reconstruction.
        """
        inputs = self.preprocessing_layer(x)
        encoded = self.encoder(inputs)

        ffn_input = torch.cat([inputs, encoded], dim=1)
        target_predictions = self.ffn(ffn_input)

        reconstruction = self.decoder(encoded)
        head_predictions = self.head_decoder(reconstruction)

        return target_predictions, head_predictions, reconstruction
    

#### Encoder

In [None]:
class Encoder(nn.Module):
    """Single encoding stage: linear projection, batch norm, swish."""

    def __init__(self, num_input_features, num_output_features):
        super().__init__()

        self.fcl = nn.Linear(num_input_features, num_output_features)
        self.bn = nn.BatchNorm1d(num_output_features)

    def forward(self, x):
        """Project ``x`` and return the swish-activated, batch-normalised result."""
        # Input noise is currently disabled: x = add_gaussian_noise(x)
        projected = self.fcl(x)
        normalised = self.bn(projected)
        return swish(normalised)

#### Decoder

In [None]:
class Decoder(nn.Module):
    """Reconstruct features from the encoding: dropout, hidden layer with swish, linear output."""

    def __init__(self, num_input_features, num_hidden_units, num_output_features):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)

        # A single-layer variant was tried before:
        # self.fcl = nn.Linear(num_input_features, num_output_features)
        self.fcl = nn.Linear(num_input_features, num_hidden_units)
        self.fcl_output = nn.Linear(num_hidden_units, num_output_features)

    def forward(self, x):
        """Return the reconstruction computed from the encoded batch ``x``."""
        dropped = self.dropout(x)
        hidden = swish(self.fcl(dropped))
        return self.fcl_output(hidden)

#### Head (decoder)

In [None]:
class Head(nn.Module):
    """Prediction head: linear, batch norm, swish, dropout, then a one-unit output."""

    def __init__(self, num_input_features, num_hidden_units):
        super().__init__()
        self.fcl = nn.Linear(num_input_features, num_hidden_units)
        self.bn = nn.BatchNorm1d(num_hidden_units)
        self.dropout = nn.Dropout(p=0.5)
        self.fcl_output = nn.Linear(num_hidden_units, 1)

    def forward(self, x):
        """Return one prediction per row of ``x``."""
        hidden = swish(self.bn(self.fcl(x)))
        regularised = self.dropout(hidden)
        return self.fcl_output(regularised)

#### FFNetwork

In [None]:
class FFNetwork(nn.Module):
    """Feed-forward predictor: batch norm and dropout, a stack of blocks, one-unit output."""

    def __init__(self, num_hidden_units, num_blocks):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_hidden_units)
        self.dropout = nn.Dropout(p=0.5)

        # Every block keeps the width unchanged, so they can simply be chained.
        blocks = [BlockFFNetwork(num_hidden_units) for _ in range(num_blocks)]
        self.backbone = nn.Sequential(*blocks)

        self.fcl_output = nn.Linear(num_hidden_units, 1)

    def forward(self, x):
        """Return one prediction per row of ``x``."""
        regularised = self.dropout(self.bn(x))
        hidden = self.backbone(regularised)
        return self.fcl_output(hidden)

    
    
        
class BlockFFNetwork(nn.Module):
    """Width-preserving block: linear, batch norm, swish, dropout."""

    def __init__(self, num_hidden_units):
        super().__init__()

        self.fcl = nn.Linear(num_hidden_units, num_hidden_units)
        self.bn = nn.BatchNorm1d(num_hidden_units)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same width."""
        activated = swish(self.bn(self.fcl(x)))
        return self.dropout(activated)

#### Preprocessing Layer

In [None]:
class PreprocessingLayer(nn.Module):
    """Embed the id stored in column 0 and batch-normalise it together with the other columns."""

    def __init__(self, num_features, num_embeddings, embedding_dim=10):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # The embedding vector replaces the single id column.
        self.bn = nn.BatchNorm1d(num_features + embedding_dim - 1)

    def forward(self, x):
        """Return the normalised ``[embedding, remaining columns]`` matrix for the 2-D batch ``x``."""
        id_column = x[:, 0].long()
        other_columns = x[:, 1:]

        combined = torch.cat([self.embedding(id_column), other_columns], dim=1)
        return self.bn(combined)

#### Helper functions

In [None]:
def add_gaussian_noise(x):
    """Return ``x`` plus standard-normal noise of the same shape.

    The noise is sampled directly on ``x``'s device, so the function works
    for tensors on any device and does not depend on global configuration.
    Naive implementation: every feature receives unit-variance noise even
    though different features have different scales.
    """
    noise = torch.randn(x.shape, device=x.device)
    return x + noise
    
    
def swish(x):
    """Swish activation: ``x`` scaled element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate

## Metrics

I took the weighted correlation function from https://www.kaggle.com/code1110/gresearch-simple-lgb-starter

AverageMeter and CorrelationMeter are inspired by the work of Alex Shonenkov, see https://www.kaggle.com/shonenkov/training-cv-melanoma-starter.

In [None]:
def corr(x, y, w):
    """Weighted correlation of ``x`` and ``y`` under the weights ``w``."""
    covariance_xy = cov(x, y, w)
    variance_product = cov(x, x, w) * cov(y, y, w)
    return covariance_xy / np.sqrt(variance_product)


def cov(x, y, w):
    """Weighted covariance of ``x`` and ``y`` under the weights ``w``."""
    x_centered = x - m(x, w)
    y_centered = y - m(y, w)
    return np.sum(w * x_centered * y_centered) / np.sum(w)


def m(x, w):
    """Weighted mean of ``x`` under the weights ``w``."""
    weighted_total = np.sum(x * w)
    return weighted_total / np.sum(w)




def corr_torch(x, y, w):
    """Weighted correlation of the tensors ``x`` and ``y`` under the weights ``w``."""
    covariance_xy = cov_torch(x, y, w)
    variance_product = cov_torch(x, x, w) * cov_torch(y, y, w)
    return covariance_xy / torch.sqrt(variance_product)


def cov_torch(x, y, w):
    """Weighted covariance of the tensors ``x`` and ``y`` under the weights ``w``."""
    x_centered = x - m_torch(x, w)
    y_centered = y - m_torch(y, w)
    return torch.sum(w * x_centered * y_centered) / torch.sum(w)


def m_torch(x, w):
    """Weighted mean of the tensor ``x`` under the weights ``w``."""
    weighted_total = torch.sum(x * w)
    return weighted_total / torch.sum(w)




class AverageMeter:
    """Running weighted average of a scalar (for example the per-batch loss)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
        
        
    
class CorrelationMeter:
    """Collect predictions, targets and weights and track their weighted correlation."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all collected values and set ``avg`` back to 0."""
        self.predictions = np.array([])
        self.targets = np.array([])
        self.weights = np.array([])
        self.avg = 0

    def update(self, y_true, y_pred, weights):
        """Append one batch of tensors and recompute ``avg`` over everything seen so far."""
        batch_predictions = y_pred.detach().cpu().squeeze().numpy()
        self.predictions = np.hstack([self.predictions, batch_predictions])

        batch_targets = y_true.cpu().numpy()
        self.targets = np.hstack([self.targets, batch_targets])

        batch_weights = weights.cpu().numpy()
        self.weights = np.hstack([self.weights, batch_weights])

        # The correlation is recomputed over the full history, not just this batch.
        self.avg = corr(self.predictions, self.targets, self.weights)

## Loss 

In [None]:
def criterion(target_predictions, target_predictions_head_decoder, targets,\
                                                            features_reconstruction, features, weights):
    """Loss: minus the sum of the weighted correlations of both prediction heads with ``targets``.

    ``features_reconstruction`` and ``features`` are accepted but not used
    by the current objective.
    """
    ffn_correlation = corr_torch(target_predictions.squeeze(), targets, weights)
    head_correlation = corr_torch(target_predictions_head_decoder.squeeze(), targets, weights)

    return -(ffn_correlation + head_correlation)

    # Alternative objective kept for reference (MSE on both heads plus reconstruction):
    #return F.mse_loss(target_predictions.squeeze(), targets) + F.mse_loss(target_predictions_head_decoder.squeeze(), targets)+\
        #F.mse_loss( features_reconstruction, features[:,1:] )

## Train function

Training function inspired by https://www.kaggle.com/shonenkov/training-cv-melanoma-starter

In [None]:
#return target_predictions, target_predictions_head_decoder, features_reconstruction
def train(df, num_epochs=10):
    """Train one model per cross-validation fold and checkpoint it after every epoch.

    Folds come from PurgedGroupTimeSeriesSplit over ``df["group"]``. For each
    fold a fresh model and optimizer are created; after every epoch the
    state is saved as ``last_checkpoint`` and, when the validation score
    improved, also as ``score_<value>``.
    """
    groups = df["group"].values
    splitter = PurgedGroupTimeSeriesSplit(n_splits=GlobalConfig.num_folds,
                                          group_gap=GlobalConfig.group_gap)

    for fold, (train_idx, val_idx) in enumerate(splitter.split(df, y=None, groups=groups)):

        # Optionally stop after the first fold to save time.
        if GlobalConfig.train_only_first_fold and fold > 0:
            break

        checkpoint_dir = f"../working/checkpoints/fold_{fold}"
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        print("% of data used for training: ", len(train_idx)/df.shape[0] * 100, "% of data used for validation: ",  len(val_idx)/df.shape[0] * 100)

        model = make_model().to(GlobalConfig.device)
        optimizer = AdamW(model.parameters(), lr=GlobalConfig.lr)

        # The validation loader also needs targets, hence is_train=True.
        train_loader = get_dataloader(df.iloc[train_idx], batch_size=GlobalConfig.train_batch_size, is_train=True)
        val_loader = get_dataloader(df.iloc[val_idx], batch_size=GlobalConfig.val_batch_size, is_train=True)

        best_final_scores = -10

        for epoch in range(num_epochs):
            t = time.time()
            summary_loss_train, final_scores_train = train_one_epoch(model, optimizer, train_loader)
            print(f'[RESULT]: Train. Epoch: {epoch}, summary_loss: {summary_loss_train.avg:.5f}, final_score: {final_scores_train.avg:.5f}, time: {(time.time() - t):.5f}')

            summary_loss_val, final_scores_val = validation_one_epoch(model, val_loader)
            print(f'[RESULT]: Val. Epoch: {epoch}, summary_loss: {summary_loss_val.avg:.5f}, final_score: {final_scores_val.avg:.5f}')

            model.eval()

            save_checkpoint(model, optimizer, fold, name="last_checkpoint")

            improved = final_scores_val.avg > best_final_scores
            if improved:
                best_final_scores = final_scores_val.avg
                save_checkpoint(model, optimizer, fold, name=f"score_{best_final_scores}")
                

        
        
        

        
    
def train_one_epoch(model, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Returns the loss meter and the correlation meter accumulated over the
    epoch; the correlation meter tracks the FFN predictions.
    """
    model.train()
    summary_loss = AverageMeter()
    final_scores = CorrelationMeter()
    t = time.time()

    for step, batch in enumerate(train_loader):
        # Progress line, refreshed in place every second step.
        if step % 2 == 0:
            print(
                f'Train Step {step}/{len(train_loader)}, ' + \
                f'summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                f'time: {(time.time() - t):.5f}', end='\r'
                )

        weights, features, targets = (
            batch[key].to(GlobalConfig.device).float()
            for key in ("weight", "features", "target")
        )
        batch_size = weights.shape[0]

        optimizer.zero_grad()

        target_predictions, target_predictions_head_decoder, features_reconstruction = model(features)
        loss = criterion(target_predictions, target_predictions_head_decoder, targets,
                         features_reconstruction, features, weights)
        loss.backward()

        summary_loss.update(loss.detach().cpu().item(), batch_size)
        final_scores.update(targets, target_predictions, weights)

        optimizer.step()

    return summary_loss, final_scores




def validation_one_epoch(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Returns the loss meter and the correlation meter accumulated over the
    loader; the correlation meter tracks the FFN predictions.
    """
    model.eval()

    summary_loss = AverageMeter()
    final_scores = CorrelationMeter()

    for step, batch in enumerate(val_loader):
        # Progress line, refreshed in place every tenth step.
        if step % 10 == 0:
            print(
                f'Val Step {step}/{len(val_loader)}, ' + \
                f'summary_loss: {summary_loss.avg:.5f}, final_score: {final_scores.avg:.5f} ', end='\r')

        weights, features, targets = (
            batch[key].to(GlobalConfig.device).float()
            for key in ("weight", "features", "target")
        )
        batch_size = weights.shape[0]

        # Only the forward pass is wrapped: no graph is needed for evaluation.
        with torch.no_grad():
            target_predictions, target_predictions_head_decoder, features_reconstruction = model(features)

        loss = criterion(target_predictions, target_predictions_head_decoder, targets,
                         features_reconstruction, features, weights)

        summary_loss.update(loss.cpu().item(), batch_size)
        final_scores.update(targets, target_predictions, weights)

    return summary_loss, final_scores
        
        
        
def save_checkpoint(model, optimizer, fold, name):
    """Write the model and optimizer state dicts to ``<fold dir>/<name>.pt``.

    The fold directory must already exist.
    """
    state = {}
    state["model_state_dict"] = model.state_dict()
    state["optimizer_state_dict"] = optimizer.state_dict()

    torch.save(state, f"../working/checkpoints/fold_{fold}/{name}.pt")
    
    
    
    
def load_best_checkpoint(model, fold):
    """Load the weights of the best-scoring checkpoint of ``fold`` into ``model``.

    Checkpoints written for improved scores are named ``score_<value>.pt``;
    the file whose ``<value>`` is numerically largest is chosen. Other files
    in the directory, such as ``last_checkpoint.pt``, are ignored.

    Raises
    ------
    FileNotFoundError
        If the fold directory holds no ``score_<value>.pt`` file.
    """
    checkpoint_dir = f"../working/checkpoints/fold_{fold}"
    prefix, suffix = "score_", ".pt"

    best_name = None
    best_score = None
    for file_name in os.listdir(checkpoint_dir):
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            continue
        try:
            # Compare scores as numbers: as text, "0.05" would sort after "0.2".
            score = float(file_name[len(prefix):-len(suffix)])
        except ValueError:
            continue
        if score != score:
            # NaN never compares as larger, so it can never be the best score.
            continue
        if best_score is None or score > best_score:
            best_name, best_score = file_name, score

    if best_name is None:
        raise FileNotFoundError(
            f"No score checkpoint found in {checkpoint_dir}")

    # Listed names already end in ".pt", so nothing is appended here.
    checkpoint = torch.load(os.path.join(checkpoint_dir, best_name))

    model.load_state_dict(checkpoint["model_state_dict"])
    
    
    
def load_last_checkpoint(model, optimizer, fold):
    """Restore ``model`` and ``optimizer`` from the fold's ``last_checkpoint.pt`` file."""
    state = torch.load( f"../working/checkpoints/fold_{fold}/last_checkpoint.pt" )

    restore_plan = ((model, "model_state_dict"), (optimizer, "optimizer_state_dict"))
    for receiver, key in restore_plan:
        receiver.load_state_dict(state[key])

## Training

In [None]:
# Seed every random generator first so the run is repeatable.
set_seed()
# Load the CSVs and build the feature / group columns.
df = create_df()
# Show the first rows as a quick sanity check of the prepared frame.
df.head()

In [None]:
# Start cross-validated training; checkpoints are written per fold by train().
train(df, num_epochs=10)