In [1]:
import sys
from time import time
import numpy as np
import pandas as pd
from pathlib import Path
import lightgbm as lgb
import matplotlib.pyplot as plt 
import seaborn as sns
from tqdm import tqdm
import copy
import wandb
from collections import OrderedDict

from sklearn.metrics import mean_absolute_error
from sklearn import model_selection
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger


In [2]:
sys.path.append('../../src/')
import utils as utils
from utils import Timer

In [3]:
class CFG:
    seed = 42
    exp_num = 4
    local = True
    n_folds = 5
    folds = [0]
    debug = False
    bias = 1000
    epochs = 200

    
    ######################
    # Dataset #
    ######################
    transforms = {
        "train": [{"name": ""}],
        "valid": [{"name": ""}],
        "test": [{"name": ""}]
    }

    ######################
    # Loaders #
    ######################
    loader_params = {
        "train": {
            'batch_size': 128,
            'shuffle': True,
            'num_workers': 8,
            'pin_memory': True,
            'drop_last': True,
        },
        "valid": {
            'batch_size': 32,
            'shuffle': False,
            'num_workers': 8,
            'pin_memory': True,
            'drop_last': False,
        },
        "test": {
            'batch_size': 32,
            'shuffle': False,
            'num_workers': 8,
            'pin_memory': True,
            'drop_last': False,
        }
    }

    ######################
    # Split #
    ######################
    split = "GroupKFold"
    split_params = {
        "n_splits": 5,
    }

    ######################
    # Model #
    ######################
    input_dim = 5

    dense_dim = 512
    lstm_dim = 512
    logit_dim = 512
    num_classes = 1

    ######################
    # Criterion #
    ######################
#     loss_name = "rmspe_loss"
#     loss_params: dict = {}

    ######################
    # Optimizer #
    ######################
    optimizer_name = "Adam"
    optimizer_params = {
        "lr": 0.001
    }

    ######################
    # Scheduler #
    ######################
    scheduler_name = "ReduceLROnPlateau"
    scheduler_params = {
        'factor': 0.2, 
        'patience': 7
    }

In [4]:
utils.set_seed(CFG.seed)

In [5]:
if CFG.local:
    DATA_DIR = Path("/home/knikaido/work/Ventilator-Pressure-Prediction/data/ventilator-pressure-prediction")
    OUTPUT_DIR = Path('./output/')
else:
    DATA_DIR = Path("../input/ventilator-pressure-prediction")
    OUTPUT_DIR = Path('')   

In [6]:
def get_transforms(phase: str):
    transforms = CFG.transforms
    if transforms is None:
        return None
    else:
        if transforms[phase] is None:
            return None
        trns_list = []
        for trns_conf in transforms[phase]:
            trns_name = trns_conf["name"]
            trns_params = {} if trns_conf.get("params") is None else \
                trns_conf["params"]
            if globals().get(trns_name) is not None:
                trns_cls = globals()[trns_name]
                trns_list.append(trns_cls(**trns_params))

        if len(trns_list) > 0:
            return Compose(trns_list)
        else:
            return None
        
        
class Normalize:
    def __call__(self, y: np.ndarray):
        max_vol = np.abs(y).max()
        y_vol = y * 1 / max_vol
        return np.asfortranarray(y_vol)


class Compose:
    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        for trns in self.transforms:
            y = trns(y)
        return y

In [7]:
def compute_metric(preds, trues, u_outs):
    """
    Metric for the problem, as I understood it.
    """
    
    y = trues
    w = 1 - u_outs
    
    assert y.shape == preds.shape and w.shape == y.shape, (y.shape, preds.shape, w.shape)
    
    mae = w * np.abs(y - preds)
    mae = mae.sum() / w.sum()
    
    return mae


class VentilatorLoss(nn.Module):
    """
    Directly optimizes the competition metric
    """
    def __call__(self, preds, y, u_out):
        w = 1 - u_out
        mae = w * (y - preds).abs()
        mae = mae.sum(-1) / w.sum(-1)

        return mae

In [8]:
def get_criterion():
    return VentilatorLoss()

In [9]:
# Custom optimizer
__OPTIMIZERS__ = {}


def get_optimizer(model: nn.Module):
    optimizer_name = CFG.optimizer_name
    if optimizer_name == "SAM":
        base_optimizer_name = CFG.base_optimizer
        if __OPTIMIZERS__.get(base_optimizer_name) is not None:
            base_optimizer = __OPTIMIZERS__[base_optimizer_name]
        else:
            base_optimizer = optim.__getattribute__(base_optimizer_name)
        return SAM(model.parameters(), base_optimizer, **CFG.optimizer_params)

    if __OPTIMIZERS__.get(optimizer_name) is not None:
        return __OPTIMIZERS__[optimizer_name](model.parameters(),
                                              **CFG.optimizer_params)
    else:
        return optim.__getattribute__(optimizer_name)(model.parameters(),
                                                      **CFG.optimizer_params)


def get_scheduler(optimizer):
    scheduler_name = CFG.scheduler_name

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **CFG.scheduler_params)

In [10]:
# validation
splitter = getattr(model_selection, CFG.split)(**CFG.split_params)

In [11]:
class VentilatorDataset(torchdata.Dataset):
    def __init__(self, df, train_value_col, train_category_col):
        if "pressure" not in df.columns:
            df['pressure'] = 0
        self.df = df
        self.groups = df.groupby('breath_id').groups
        self.keys = list(self.groups.keys())
        self.train_value_col = train_value_col
        self.train_category_col = train_category_col

        
    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        indexes = self.groups[self.keys[idx]]
        df_ = self.df.iloc[indexes]
        
        input_value = df_[self.train_value_col].values
        input_category = df_[self.train_category_col].values

        u_out_ = df_['u_out'].values
        p_ = df_['pressure'].values

        data = {
            "input_value": input_value.astype(np.float32),
            "input_category": input_category.astype(int),
            "u_out": u_out_.astype(np.float32),
            "p": p_.astype(np.float32),
        }
        
        return data

In [12]:
class CustomModel(nn.Module):
    def __init__(self, input_dim=4, hidden_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.r_emb = nn.Embedding(3, 2, padding_idx=0)
        self.c_emb = nn.Embedding(3, 2, padding_idx=0)
        self.seq_emb = nn.Sequential(
            nn.Linear(4 + input_dim, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, 
                            dropout=0.2, batch_first=True, bidirectional=True)
        self.head = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
            nn.LayerNorm(self.hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(0.),
            nn.Linear(self.hidden_size * 2, 1),
        )
        for n, m in self.named_modules():
            if isinstance(m, nn.LSTM):
                print(f'init {m}')
                for param in m.parameters():
                    if len(param.shape) >= 2:
                        nn.init.orthogonal_(param.data)
                    else:
                        nn.init.normal_(param.data)
            elif isinstance(m, nn.GRU):
                print(f"init {m}")
                for param in m.parameters():
                    if len(param.shape) >= 2:
                        init.orthogonal_(param.data)
                    else:
                        init.normal_(param.data)

    def forward(self, cont_seq_x, cate_seq_x):
        bs = cont_seq_x.size(0)
        r_emb = self.r_emb(cate_seq_x[:,:,0]).view(bs, 80, -1)
        c_emb = self.c_emb(cate_seq_x[:,:,1]).view(bs, 80, -1)
        seq_x = torch.cat((r_emb, c_emb, cont_seq_x), 2)
        seq_emb = self.seq_emb(seq_x)
        seq_emb, _ = self.lstm(seq_emb)
        output = self.head(seq_emb).view(bs, -1)
        return output

In [30]:
class RNNModel(nn.Module):
    def __init__(
        self,
        input_dim=4,
        lstm_dim=256,
        dense_dim=256,
        logit_dim=256,
        num_classes=1,
    ):
        super().__init__()
        
        self.r_emb = nn.Embedding(3, 2, padding_idx=0)
        self.c_emb = nn.Embedding(3, 2, padding_idx=0)

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, dense_dim // 2),
            nn.ReLU(),
            nn.Linear(dense_dim // 2, dense_dim),
            nn.ReLU(),
        )

        self.lstm = nn.LSTM(dense_dim + 4, lstm_dim, batch_first=True, bidirectional=True)

        self.logits = nn.Sequential(
            nn.Linear(lstm_dim * 2, logit_dim),
            nn.ReLU(),
            nn.Linear(logit_dim, num_classes),
        )

    def forward(self, cont_seq_x, cate_seq_x):
        bs = cont_seq_x.size(0)

        features = self.mlp(cont_seq_x)
        r_emb = self.r_emb(cate_seq_x[:,:,0]).view(bs, 80, -1)
        c_emb = self.c_emb(cate_seq_x[:,:,1]).view(bs, 80, -1)
        features = torch.cat((r_emb, c_emb, features), 2)
        
        features, _ = self.lstm(features)
        pred = self.logits(features)
        return pred

In [14]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = get_criterion()
    
    def training_step(self, batch, batch_idx):
        d_ = batch
        output = self.model(d_['input_value'], d_['input_category'])
        loss = self.criterion(output.view(-1), d_['p'].view(-1), d_['u_out'].view(-1))
        return loss
    
    def validation_step(self, batch, batch_idx):
        d_ = batch
        output = self.model(d_['input_value'], d_['input_category'])
        loss = self.criterion(output.view(-1), d_['p'].view(-1), d_['u_out'].view(-1))
        
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        
        output = OrderedDict({
            "targets": d_['p'].detach(), "preds": output.detach(), "u_outs": d_['u_out'].detach(), "loss": loss.detach()
        })
        return output
    
    def validation_epoch_end(self, outputs):

        targets = torch.cat([o["targets"].view(-1) for o in outputs]).cpu().numpy()
        preds = torch.cat([o["preds"].view(-1) for o in outputs]).cpu().numpy()
        u_outs = torch.cat([o["u_outs"].view(-1) for o in outputs]).cpu().numpy()

        score = get_score(preds, targets, u_outs)
        self.log(f'custom_mae/val', score, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        print(f'epoch = {self.current_epoch}, custom_mae = {score}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model)
        scheduler = get_scheduler(optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
def get_score(y_pred, y_true, u_outs):
    return compute_metric(y_pred, y_true, u_outs)


def to_np(input):
    return input.detach().cpu().numpy()

# oof
def evaluate(model, loaders, phase):
    model.eval()
    pred_list = []
    target_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            d_ = batch
            d_['input_value'] = d_['input_value'].to(device)
            d_['input_category'] = d_['input_category'].to(device)
            output = model(d_['input_category'])
#             output = nn.Softmax(dim=1)(output)
            pred_list.append(to_np(output))
            target_list.append(to_np(d_['p']))

    pred_list = np.concatenate(pred_list).reshape(-1)
    target_list = np.concatenate(target_list).reshape(-1)
    model.train()
    return pred_list, target_list

In [17]:
train = pd.read_csv(DATA_DIR / 'train.csv')
test = pd.read_csv(DATA_DIR / 'test.csv')
display(train), display(test)

Unnamed: 0,id,breath_id,R,C,time_step,u_in,u_out,pressure
0,1,1,20,50,0.000000,0.083334,0,5.837492
1,2,1,20,50,0.033652,18.383041,0,5.907794
2,3,1,20,50,0.067514,22.509278,0,7.876254
3,4,1,20,50,0.101542,22.808822,0,11.742872
4,5,1,20,50,0.135756,25.355850,0,12.234987
...,...,...,...,...,...,...,...,...
6035995,6035996,125749,50,10,2.504603,1.489714,1,3.869032
6035996,6035997,125749,50,10,2.537961,1.488497,1,3.869032
6035997,6035998,125749,50,10,2.571408,1.558978,1,3.798729
6035998,6035999,125749,50,10,2.604744,1.272663,1,4.079938


Unnamed: 0,id,breath_id,R,C,time_step,u_in,u_out
0,1,0,5,20,0.000000,0.000000,0
1,2,0,5,20,0.031904,7.515046,0
2,3,0,5,20,0.063827,14.651675,0
3,4,0,5,20,0.095751,21.230610,0
4,5,0,5,20,0.127644,26.320956,0
...,...,...,...,...,...,...,...
4023995,4023996,125748,20,10,2.530117,4.971245,1
4023996,4023997,125748,20,10,2.563853,4.975709,1
4023997,4023998,125748,20,10,2.597475,4.979468,1
4023998,4023999,125748,20,10,2.631134,4.982648,1


(None, None)

In [18]:
def get_raw_features(input_df, dataType = 'train'):
    colum = ['time_step', 'u_in']

    return input_df[colum]

In [19]:
def get_category_features(input_df, dataType = 'train'):
    output_df = copy.deepcopy(input_df)
    colum = ['R', 'C']
    r_map = {5: 0, 20: 1, 50: 2}
    c_map = {10: 0, 20: 1, 50: 2}
    output_df['R'] = output_df['R'].map(r_map)
    output_df['C'] = output_df['C'].map(c_map)
    
#     output_df['R_C'] = output_df['R'] + output_df['C'] * 10

    return output_df[colum]

In [20]:
def get_diff_shift_features(input_df, dataType = 'train'):
    
    
    output_df = copy.deepcopy(input_df)
    c_num = input_df.shape[1]
    
    b_id_gby = input_df.groupby(['breath_id'])
    shift_idx = [-2, -1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    
    for i in shift_idx:
        output_df[f'u_in_diff_{i}'] = b_id_gby['u_in'].diff(i)
        output_df[f'u_in_shift_{i}'] = b_id_gby['u_in'].shift(i)

        output_df[f'time_step_diff_{i}'] = b_id_gby['time_step'].diff(i)
        output_df[f'time_step_shift_{i}'] = b_id_gby['time_step'].shift(i)
    
    return output_df.iloc[:, c_num:]

In [21]:
def get_cum_features(input_df, dataType = 'train'):
    
    output_df = copy.deepcopy(input_df)
    c_num = input_df.shape[1]
    
    b_id_gby = input_df.groupby(['breath_id'])
    
    output_df['u_in_cumsum'] = b_id_gby['u_in'].cumsum()
    output_df['time_step_cumsum'] = b_id_gby['time_step'].cumsum()
    
    return output_df.iloc[:, c_num:]

In [22]:
def get_agg_features(input_df, dataType = 'train'):
    
    output_df = copy.deepcopy(input_df)
    c_num = input_df.shape[1]
    
    # Dict for aggregations
    create_feature_dict = {
        'u_in': [np.max, np.std, np.mean, 'first', 'last'],
    }
    
    def get_agg_window(start_time=0, end_time=3.0, add_suffix = False):
        
        df_tgt = output_df[(output_df['time_step'] >= start_time) & (output_df['time_step'] <= end_time)]
        df_feature = df_tgt.groupby(['breath_id']).agg(create_feature_dict)
        df_feature.columns = ['_'.join(col) for col in df_feature.columns]
        
        if add_suffix:
            df_feature = df_feature.add_suffix('_' + str(start_time) + '_' + str(end_time))
            
        return df_feature
    
    df_agg_feature = get_agg_window().reset_index()
    
#     df_tmp = get_agg_window(start_time = 2, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')
#     df_tmp = get_agg_window(start_time = 1, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')
#     df_tmp = get_agg_window(end_time = 1, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')
#     df_tmp = get_agg_window(end_time = 2, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')

    output_df = pd.merge(output_df, df_agg_feature, how='left', on='breath_id')
    
    
    return output_df.iloc[:, c_num:]

In [23]:
def to_feature(input_df, dataType = 'train'):
    """input_df を特徴量行列に変換した新しいデータフレームを返す.
    """

    processors = [
        get_raw_features,
        get_category_features,
        get_diff_shift_features,
        get_cum_features,
        get_agg_features
    ]

    out_df = pd.DataFrame()

    for func in tqdm(processors, total=len(processors)):
        with Timer(prefix='' + func.__name__ + ' '):
            _df = func(input_df, dataType)

        # 長さが等しいことをチェック (ずれている場合, func の実装がおかしい)
        assert len(_df) == len(input_df), func.__name__
        out_df = pd.concat([out_df, _df], axis=1)
    out_df = utils.reduce_mem_usage(out_df)
    
    return out_df

In [24]:
train_df = to_feature(train, dataType = 'train')
test_df = to_feature(test, dataType = 'test')

  0%|          | 0/5 [00:00<?, ?it/s]

get_raw_features  0.014[s]
get_category_features  0.165[s]


 40%|████      | 2/5 [00:00<00:00,  9.00it/s]

get_diff_shift_features  271.737[s]


 60%|██████    | 3/5 [04:32<03:47, 113.72s/it]

get_cum_features  0.198[s]


 80%|████████  | 4/5 [04:34<01:11, 71.78s/it] 

get_agg_features  1.147[s]


100%|██████████| 5/5 [04:36<00:00, 55.32s/it]
 40%|████      | 2/5 [00:00<00:00, 15.09it/s]

Mem. usage decreased from 2717.01 Mb to 667.74 Mb (75.4% reduction)
get_raw_features  0.009[s]
get_category_features  0.094[s]


 40%|████      | 2/5 [00:16<00:00, 15.09it/s]

get_diff_shift_features  180.849[s]


 60%|██████    | 3/5 [03:01<02:31, 75.69s/it]

get_cum_features  0.121[s]


 80%|████████  | 4/5 [03:02<00:47, 47.77s/it]

get_agg_features  0.712[s]


100%|██████████| 5/5 [03:04<00:00, 36.81s/it]


Mem. usage decreased from 1811.34 Mb to 445.16 Mb (75.4% reduction)


In [25]:
train_value_col = [i for i in train_df.columns.to_list() if i not in ['R', 'C']]
train_category_col = ['R', 'C']

In [26]:
ss = StandardScaler()

train_category = train_df[train_category_col]
train_df = pd.DataFrame(ss.fit_transform(train_df[train_value_col]), columns=train_value_col)
train_mean = train_df.mean()
train_df = train_df.fillna(train_df.mean())

test_category = test_df[train_category_col]
test_df = pd.DataFrame(ss.transform(test_df[train_value_col]), columns=train_value_col)
test_df = test_df.fillna(train_mean)

  sqr = np.multiply(arr, arr, out=arr)


In [27]:
train_df = pd.concat([train_df, train_category, train[['id', 'breath_id', 'pressure', 'u_out']]], axis=1)
test_df = pd.concat([test_df, test_category, test[['id', 'breath_id', 'u_out']]], axis=1)

In [28]:
display(train_df), display(test_df)

Unnamed: 0,time_step,u_in,u_in_diff_-2,u_in_shift_-2,time_step_diff_-2,time_step_shift_-2,u_in_diff_-1,u_in_shift_-1,time_step_diff_-1,time_step_shift_-1,...,u_in_std,u_in_mean,u_in_first,u_in_last,R,C,id,breath_id,pressure,u_out
0,-1.707031,-0.538574,-2.503906,1.241211,-0.562988,-1.703125,-2.511719,0.845215,-0.456543,-1.705078,...,0.119507,0.513672,-0.550293,0.282471,1,2,1,1,5.837492,0
1,-1.663086,0.823730,-0.530762,1.264648,-0.720703,-1.657227,-0.578613,1.158203,-0.629883,-1.660156,...,0.119507,0.513672,-0.550293,0.282471,1,2,2,1,5.907794,0
2,-1.618164,1.130859,-0.357422,1.467773,-0.877930,-1.612305,-0.056763,1.180664,-0.753906,-1.616211,...,0.119507,0.513672,-0.550293,0.282471,1,2,3,1,7.876254,0
3,-1.574219,1.153320,-0.533691,1.620117,-0.851562,-1.567383,-0.363525,1.373047,-0.901855,-1.571289,...,0.119507,0.513672,-0.550293,0.282471,1,2,4,1,11.742872,0
4,-1.530273,1.341797,-0.239746,1.608398,-0.747070,-1.521484,-0.275879,1.517578,-0.679199,-1.526367,...,0.119507,0.513672,-0.550293,0.282471,1,2,5,1,12.234987,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6035995,1.561523,-0.434082,-0.053284,-0.432129,-0.273682,1.649414,-0.015732,-0.435791,-0.208618,1.603516,...,-0.430176,-0.438721,-0.313477,-3.250000,2,0,6035996,125749,3.869032,1
6035996,1.605469,-0.434082,-0.022018,-0.454834,-0.247437,1.694336,-0.025513,-0.430420,-0.282959,1.650391,...,-0.430176,-0.438721,-0.313477,-3.250000,2,0,6035997,125749,3.869032,1
6035997,1.651367,-0.428711,-0.037323,-0.437988,-0.168579,1.738281,0.023178,-0.451904,-0.183838,1.694336,...,-0.430176,-0.438721,-0.313477,-3.250000,2,0,6035998,125749,3.798729,1
6035998,1.694336,-0.450439,-0.000000,-0.000000,0.000000,0.000000,-0.044556,-0.436035,-0.134277,1.737305,...,-0.430176,-0.438721,-0.313477,-3.250000,2,0,6035999,125749,4.079938,1


Unnamed: 0,time_step,u_in,u_in_diff_-2,u_in_shift_-2,time_step_diff_-2,time_step_shift_-2,u_in_diff_-1,u_in_shift_-1,time_step_diff_-1,time_step_shift_-1,...,u_in_amax,u_in_std,u_in_mean,u_in_first,u_in_last,R,C,id,breath_id,u_out
0,-1.707031,-0.544922,-1.651367,0.612793,1.014648,-1.708984,-1.041016,0.021011,0.980957,-1.708008,...,0.048096,0.367432,0.364746,-0.553711,0.266602,0,1,1,0,0
1,-1.665039,0.014435,-1.548828,1.138672,1.014648,-1.665039,-0.989746,0.561523,0.956055,-1.665039,...,0.048096,0.367432,0.364746,-0.553711,0.266602,0,1,2,0,0
2,-1.623047,0.545410,-1.325195,1.544922,1.014648,-1.624023,-0.913574,1.060547,0.956055,-1.623047,...,0.048096,0.367432,0.364746,-0.553711,0.266602,0,1,3,0,0
3,-1.582031,1.035156,-1.060547,1.876953,1.041016,-1.580078,-0.710449,1.446289,0.980957,-1.581055,...,0.048096,0.367432,0.364746,-0.553711,0.266602,0,1,4,0,0
4,-1.540039,1.414062,-0.837891,2.119141,1.014648,-1.537109,-0.583984,1.761719,0.956055,-1.539062,...,0.048096,0.367432,0.364746,-0.553711,0.266602,0,1,5,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4023995,1.594727,-0.174927,-0.046600,-0.158936,-0.510254,1.683594,-0.016510,-0.171387,-0.505859,1.639648,...,-0.027512,-0.551270,-0.843262,-0.156250,0.278564,1,0,4023996,125748,1
4023996,1.640625,-0.174561,-0.046448,-0.158569,-0.457764,1.728516,-0.016403,-0.171021,-0.431641,1.683594,...,-0.027512,-0.551270,-0.843262,-0.156250,0.278564,1,0,4023997,125748,1
4023997,1.684570,-0.174316,-0.046326,-0.158569,-0.694336,1.775391,-0.016327,-0.170776,-0.456543,1.727539,...,-0.027512,-0.551270,-0.843262,-0.156250,0.278564,1,0,4023998,125748,1
4023998,1.727539,-0.173950,-0.000000,-0.000000,0.000000,0.000000,-0.016266,-0.170776,-0.877441,1.773438,...,-0.027512,-0.551270,-0.843262,-0.156250,0.278564,1,0,4023999,125748,1


(None, None)

In [31]:
oof_total = np.zeros((len(train), CFG.num_classes))
sub_preds = np.zeros((test.shape[0], len(CFG.folds)))
val_idxes = []
models = []
y = train['pressure']
groups = train['breath_id']
gkfold = model_selection.GroupKFold(n_splits=CFG.n_folds)
scores = []
input_dim = len(train_value_col)

for i, (trn_idx, val_idx) in enumerate(splitter.split(train_df, y, groups)):
    if i not in CFG.folds:
        continue

    trn_df = train_df.loc[trn_idx, :].reset_index(drop=True)
    val_df = train_df.loc[val_idx, :].reset_index(drop=True)
    trn_y = y.values[trn_idx]
    val_y = y.values[val_idx]
    
    
    loaders = {
        phase: torchdata.DataLoader(
            VentilatorDataset(
                df_, train_value_col, train_category_col
            ),
            **CFG.loader_params[phase])  # type: ignore
        for phase, df_ in zip(["train", "valid", "test"], [trn_df, val_df, test_df])
    }
    
    
    model = RNNModel(
        input_dim=input_dim,
        lstm_dim=CFG.lstm_dim,
        dense_dim=CFG.dense_dim,
        logit_dim=CFG.logit_dim,
        num_classes=CFG.num_classes,
    )
    model_name = model.__class__.__name__
#     break
    
    learner = Learner(model)
    
    # loggers
    RUN_NAME = f'exp{str(CFG.exp_num)}'
    wandb.init(project='Ventilator-Pressure-Prediction', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{i}')
    wandb.run.name = RUN_NAME + f'-fold-{i}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        save_weights_only=True,
        filename=f'{model_name}-{learner.current_epoch}-{i}')
    callbacks.append(checkpoint_callback)

#     early_stop_callback = EarlyStopping(
#         monitor='Loss/val',
#         min_delta=0.00,
#         patience=10,
#         verbose=True,
#         mode='min')
#     callbacks.append(early_stop_callback)
    
    loggers = []
    loggers.append(WandbLogger())
    
    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=CFG.epochs,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
#         fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=False,
        )
    
    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])
#     trainer.save_checkpoint(OUTPUT_DIR / "last.ckpt")
    print('train done.')
    
    #############
    # validation (to make oof)
    #############
    checkpoint = torch.load(checkpoint_callback.best_model_path)
    learner.load_state_dict(checkpoint['state_dict'])
    
    model = model.to(device)
    oof_pred, oof_target = evaluate(model, loaders, phase="valid")
    models.append(model)
    
    oof_score = get_score(oof_pred, oof_target, val_df['u_out'].values)
    scores.append(oof_score)
    oof_total[val_idx] = oof_pred.reshape(1, -1).T / CFG.bias
    val_idxes.append(val_idx)
    
    print('validate done.')
    print(f'fold = {i}, auc = {oof_score}')
    wandb.log({'CV_score': oof_score})
    
    #############
    # inference
    #############
    test_pred, _ = evaluate(model, loaders, phase="test")
    sub_preds[:, i] = test_pred
    
    print('inference done.')

# test_preds_total = np.array(test_preds_total)


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params
---------------------------------------------
0 | model     | RNNModel       | 4.9 M 
1 | criterion | VentilatorLoss | 0     
---------------------------------------------
4.9 M     Trainable params
0         Non-trainable params
4.9 M     Total params
19.563    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, custom_mae = 17.441207885742188


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, custom_mae = 1.1678025722503662


Validating: 0it [00:00, ?it/s]

epoch = 1, custom_mae = 0.9120286107063293


Validating: 0it [00:00, ?it/s]

epoch = 2, custom_mae = 0.8011571168899536


Validating: 0it [00:00, ?it/s]

epoch = 3, custom_mae = 0.8308899402618408


Validating: 0it [00:00, ?it/s]

epoch = 4, custom_mae = 0.7497460246086121


Validating: 0it [00:00, ?it/s]

epoch = 5, custom_mae = 0.7276131510734558


Validating: 0it [00:00, ?it/s]

epoch = 6, custom_mae = 0.680391788482666


Validating: 0it [00:00, ?it/s]

epoch = 7, custom_mae = 0.6263871192932129


Validating: 0it [00:00, ?it/s]

epoch = 8, custom_mae = 0.6427077054977417


Validating: 0it [00:00, ?it/s]

epoch = 9, custom_mae = 0.6225729584693909


Validating: 0it [00:00, ?it/s]

epoch = 10, custom_mae = 0.7020693421363831


Validating: 0it [00:00, ?it/s]

epoch = 11, custom_mae = 0.5831179022789001


Validating: 0it [00:00, ?it/s]

epoch = 12, custom_mae = 0.5741868615150452


Validating: 0it [00:00, ?it/s]

epoch = 13, custom_mae = 0.5713881850242615


Validating: 0it [00:00, ?it/s]

epoch = 14, custom_mae = 0.5558568239212036


Validating: 0it [00:00, ?it/s]

epoch = 15, custom_mae = 0.545965850353241


Validating: 0it [00:00, ?it/s]

epoch = 16, custom_mae = 0.5538600087165833


Validating: 0it [00:00, ?it/s]

epoch = 17, custom_mae = 0.5059000849723816


Validating: 0it [00:00, ?it/s]

epoch = 18, custom_mae = 0.4942365884780884


Validating: 0it [00:00, ?it/s]

epoch = 19, custom_mae = 0.5145253539085388


Validating: 0it [00:00, ?it/s]

epoch = 20, custom_mae = 0.5289043188095093


Validating: 0it [00:00, ?it/s]

epoch = 21, custom_mae = 0.4792182445526123


Validating: 0it [00:00, ?it/s]

epoch = 22, custom_mae = 0.47797590494155884


Validating: 0it [00:00, ?it/s]

epoch = 23, custom_mae = 0.4319208860397339


Validating: 0it [00:00, ?it/s]

epoch = 24, custom_mae = 0.5166347026824951


Validating: 0it [00:00, ?it/s]

epoch = 25, custom_mae = 0.4493643641471863


Validating: 0it [00:00, ?it/s]

epoch = 26, custom_mae = 0.4276195466518402


Validating: 0it [00:00, ?it/s]

epoch = 27, custom_mae = 0.4555519223213196


Validating: 0it [00:00, ?it/s]

epoch = 28, custom_mae = 0.5984547138214111


Validating: 0it [00:00, ?it/s]

epoch = 29, custom_mae = 0.4860117435455322


Validating: 0it [00:00, ?it/s]

epoch = 30, custom_mae = 0.45248591899871826


Validating: 0it [00:00, ?it/s]

epoch = 31, custom_mae = 0.4377414584159851


Validating: 0it [00:00, ?it/s]

epoch = 32, custom_mae = 0.42213645577430725


Validating: 0it [00:00, ?it/s]

epoch = 33, custom_mae = 0.4603876769542694


Validating: 0it [00:00, ?it/s]

epoch = 34, custom_mae = 0.4048911929130554


Validating: 0it [00:00, ?it/s]

epoch = 35, custom_mae = 0.3823695480823517


Validating: 0it [00:00, ?it/s]

epoch = 36, custom_mae = 0.39339479804039


Validating: 0it [00:00, ?it/s]

epoch = 37, custom_mae = 0.4077553153038025


Validating: 0it [00:00, ?it/s]

epoch = 38, custom_mae = 0.4005122780799866


Validating: 0it [00:00, ?it/s]

epoch = 39, custom_mae = 0.4052567780017853


Validating: 0it [00:00, ?it/s]

epoch = 40, custom_mae = 0.37734633684158325


Validating: 0it [00:00, ?it/s]

epoch = 41, custom_mae = 0.3716513514518738


Validating: 0it [00:00, ?it/s]

epoch = 42, custom_mae = 0.40463587641716003


Validating: 0it [00:00, ?it/s]

epoch = 43, custom_mae = 0.3841870129108429


Validating: 0it [00:00, ?it/s]

epoch = 44, custom_mae = 0.40211987495422363


Validating: 0it [00:00, ?it/s]

epoch = 45, custom_mae = 0.38021814823150635


Validating: 0it [00:00, ?it/s]

epoch = 46, custom_mae = 0.3835064172744751


Validating: 0it [00:00, ?it/s]

epoch = 47, custom_mae = 0.42257943749427795


Validating: 0it [00:00, ?it/s]

epoch = 48, custom_mae = 0.3633236885070801


Validating: 0it [00:00, ?it/s]

epoch = 49, custom_mae = 0.34992215037345886


Validating: 0it [00:00, ?it/s]

epoch = 50, custom_mae = 0.39550521969795227


Validating: 0it [00:00, ?it/s]

epoch = 51, custom_mae = 0.3701828420162201


Validating: 0it [00:00, ?it/s]

epoch = 52, custom_mae = 0.34873560070991516


Validating: 0it [00:00, ?it/s]

epoch = 53, custom_mae = 0.3533232510089874


Validating: 0it [00:00, ?it/s]

epoch = 54, custom_mae = 0.3915421664714813


Validating: 0it [00:00, ?it/s]

epoch = 55, custom_mae = 0.39964258670806885


Validating: 0it [00:00, ?it/s]

epoch = 56, custom_mae = 0.38213565945625305


Validating: 0it [00:00, ?it/s]

epoch = 57, custom_mae = 0.394172340631485


Validating: 0it [00:00, ?it/s]

epoch = 58, custom_mae = 0.34809640049934387


Validating: 0it [00:00, ?it/s]

epoch = 59, custom_mae = 0.3498207926750183


Validating: 0it [00:00, ?it/s]

epoch = 60, custom_mae = 0.39286574721336365


Validating: 0it [00:00, ?it/s]

epoch = 61, custom_mae = 0.36118558049201965


Validating: 0it [00:00, ?it/s]

epoch = 62, custom_mae = 0.33819833397865295


Validating: 0it [00:00, ?it/s]

epoch = 63, custom_mae = 0.32863596081733704


Validating: 0it [00:00, ?it/s]

epoch = 64, custom_mae = 0.32191842794418335


Validating: 0it [00:00, ?it/s]

epoch = 65, custom_mae = 0.34160882234573364


Validating: 0it [00:00, ?it/s]

epoch = 66, custom_mae = 0.3854815661907196


Validating: 0it [00:00, ?it/s]

epoch = 67, custom_mae = 0.3451760709285736


Validating: 0it [00:00, ?it/s]

epoch = 68, custom_mae = 0.32037198543548584


Validating: 0it [00:00, ?it/s]

epoch = 69, custom_mae = 0.3280428647994995


Validating: 0it [00:00, ?it/s]

epoch = 70, custom_mae = 0.3328806161880493


Validating: 0it [00:00, ?it/s]

epoch = 71, custom_mae = 0.32637834548950195


Validating: 0it [00:00, ?it/s]

epoch = 72, custom_mae = 0.32981202006340027


Validating: 0it [00:00, ?it/s]

epoch = 73, custom_mae = 0.30470746755599976


Validating: 0it [00:00, ?it/s]

epoch = 74, custom_mae = 0.38072335720062256


Validating: 0it [00:00, ?it/s]

epoch = 75, custom_mae = 0.3200618028640747


Validating: 0it [00:00, ?it/s]

epoch = 76, custom_mae = 0.32027962803840637


Validating: 0it [00:00, ?it/s]

epoch = 77, custom_mae = 0.312086284160614


Validating: 0it [00:00, ?it/s]

epoch = 78, custom_mae = 0.40549221634864807


Validating: 0it [00:00, ?it/s]

epoch = 79, custom_mae = 0.3582950234413147


Validating: 0it [00:00, ?it/s]

epoch = 80, custom_mae = 0.31366944313049316


Validating: 0it [00:00, ?it/s]

epoch = 81, custom_mae = 0.30942657589912415


Validating: 0it [00:00, ?it/s]

epoch = 82, custom_mae = 0.2743402123451233


Validating: 0it [00:00, ?it/s]

epoch = 83, custom_mae = 0.2781956195831299


Validating: 0it [00:00, ?it/s]

epoch = 84, custom_mae = 0.2716595232486725


Validating: 0it [00:00, ?it/s]

epoch = 85, custom_mae = 0.26907309889793396


Validating: 0it [00:00, ?it/s]

epoch = 86, custom_mae = 0.2676846981048584


Validating: 0it [00:00, ?it/s]

epoch = 87, custom_mae = 0.2664794623851776


Validating: 0it [00:00, ?it/s]

epoch = 88, custom_mae = 0.2650305926799774


Validating: 0it [00:00, ?it/s]

epoch = 89, custom_mae = 0.2650876045227051


Validating: 0it [00:00, ?it/s]

epoch = 90, custom_mae = 0.26226145029067993


Validating: 0it [00:00, ?it/s]

epoch = 91, custom_mae = 0.26373055577278137


Validating: 0it [00:00, ?it/s]

epoch = 92, custom_mae = 0.26107311248779297


Validating: 0it [00:00, ?it/s]

epoch = 93, custom_mae = 0.2611369788646698


Validating: 0it [00:00, ?it/s]

epoch = 94, custom_mae = 0.2613253891468048


Validating: 0it [00:00, ?it/s]

epoch = 95, custom_mae = 0.261644572019577


Validating: 0it [00:00, ?it/s]

epoch = 96, custom_mae = 0.2630506753921509


Validating: 0it [00:00, ?it/s]

epoch = 97, custom_mae = 0.26064616441726685


Validating: 0it [00:00, ?it/s]

epoch = 98, custom_mae = 0.2582496106624603


Validating: 0it [00:00, ?it/s]

epoch = 99, custom_mae = 0.2587800621986389


Validating: 0it [00:00, ?it/s]

epoch = 100, custom_mae = 0.2580699622631073


Validating: 0it [00:00, ?it/s]

epoch = 101, custom_mae = 0.2579943835735321


Validating: 0it [00:00, ?it/s]

epoch = 102, custom_mae = 0.2559516727924347


Validating: 0it [00:00, ?it/s]

epoch = 103, custom_mae = 0.25727248191833496


Validating: 0it [00:00, ?it/s]

epoch = 104, custom_mae = 0.25594741106033325


Validating: 0it [00:00, ?it/s]

epoch = 105, custom_mae = 0.25622785091400146


Validating: 0it [00:00, ?it/s]

epoch = 106, custom_mae = 0.2551858425140381


Validating: 0it [00:00, ?it/s]

epoch = 107, custom_mae = 0.2560684084892273


Validating: 0it [00:00, ?it/s]

epoch = 108, custom_mae = 0.2549721598625183


Validating: 0it [00:00, ?it/s]

epoch = 109, custom_mae = 0.2534162998199463


Validating: 0it [00:00, ?it/s]

epoch = 110, custom_mae = 0.2533490061759949


Validating: 0it [00:00, ?it/s]

epoch = 111, custom_mae = 0.2540771961212158


Validating: 0it [00:00, ?it/s]

epoch = 112, custom_mae = 0.2560593783855438


Validating: 0it [00:00, ?it/s]

epoch = 113, custom_mae = 0.25530949234962463


Validating: 0it [00:00, ?it/s]

epoch = 114, custom_mae = 0.25236260890960693


Validating: 0it [00:00, ?it/s]

epoch = 115, custom_mae = 0.2645738124847412


Validating: 0it [00:00, ?it/s]

epoch = 116, custom_mae = 0.2527962923049927


Validating: 0it [00:00, ?it/s]

epoch = 117, custom_mae = 0.25316619873046875


Validating: 0it [00:00, ?it/s]

epoch = 118, custom_mae = 0.2509187161922455


Validating: 0it [00:00, ?it/s]

epoch = 119, custom_mae = 0.251234233379364


Validating: 0it [00:00, ?it/s]

epoch = 120, custom_mae = 0.2513061761856079


Validating: 0it [00:00, ?it/s]

epoch = 121, custom_mae = 0.25019970536231995


Validating: 0it [00:00, ?it/s]

epoch = 122, custom_mae = 0.251199871301651


Validating: 0it [00:00, ?it/s]

epoch = 123, custom_mae = 0.2514662742614746


Validating: 0it [00:00, ?it/s]

epoch = 124, custom_mae = 0.2519843876361847


Validating: 0it [00:00, ?it/s]

epoch = 125, custom_mae = 0.25346457958221436


Validating: 0it [00:00, ?it/s]

epoch = 126, custom_mae = 0.25444746017456055


Validating: 0it [00:00, ?it/s]

epoch = 127, custom_mae = 0.25151777267456055


Validating: 0it [00:00, ?it/s]

epoch = 128, custom_mae = 0.24989952147006989


Validating: 0it [00:00, ?it/s]

epoch = 129, custom_mae = 0.2511695325374603


Validating: 0it [00:00, ?it/s]

epoch = 130, custom_mae = 0.25318214297294617


Validating: 0it [00:00, ?it/s]

epoch = 131, custom_mae = 0.24983589351177216


Validating: 0it [00:00, ?it/s]

epoch = 132, custom_mae = 0.24985581636428833


Validating: 0it [00:00, ?it/s]

epoch = 133, custom_mae = 0.25013095140457153


Validating: 0it [00:00, ?it/s]

epoch = 134, custom_mae = 0.2514011859893799


Validating: 0it [00:00, ?it/s]

epoch = 135, custom_mae = 0.2511430084705353


Validating: 0it [00:00, ?it/s]

epoch = 136, custom_mae = 0.24931631982326508


Validating: 0it [00:00, ?it/s]

epoch = 137, custom_mae = 0.24861082434654236


Validating: 0it [00:00, ?it/s]

epoch = 138, custom_mae = 0.2503777742385864


Validating: 0it [00:00, ?it/s]

epoch = 139, custom_mae = 0.25077274441719055


Validating: 0it [00:00, ?it/s]

epoch = 140, custom_mae = 0.24923905730247498


Validating: 0it [00:00, ?it/s]

epoch = 141, custom_mae = 0.25197187066078186


Validating: 0it [00:00, ?it/s]

epoch = 142, custom_mae = 0.24835903942584991


Validating: 0it [00:00, ?it/s]

epoch = 143, custom_mae = 0.25094711780548096


Validating: 0it [00:00, ?it/s]

epoch = 144, custom_mae = 0.24877914786338806


Validating: 0it [00:00, ?it/s]

epoch = 145, custom_mae = 0.2492087334394455


Validating: 0it [00:00, ?it/s]

epoch = 146, custom_mae = 0.24938040971755981


Validating: 0it [00:00, ?it/s]

epoch = 147, custom_mae = 0.24888357520103455


Validating: 0it [00:00, ?it/s]

epoch = 148, custom_mae = 0.2502705156803131


Validating: 0it [00:00, ?it/s]

epoch = 149, custom_mae = 0.25157472491264343


Validating: 0it [00:00, ?it/s]

epoch = 150, custom_mae = 0.24846462905406952


Validating: 0it [00:00, ?it/s]

epoch = 151, custom_mae = 0.24476513266563416


Validating: 0it [00:00, ?it/s]

epoch = 152, custom_mae = 0.24454215168952942


Validating: 0it [00:00, ?it/s]

epoch = 153, custom_mae = 0.24479930102825165


Validating: 0it [00:00, ?it/s]

epoch = 154, custom_mae = 0.2445240467786789


Validating: 0it [00:00, ?it/s]

epoch = 155, custom_mae = 0.244753897190094


Validating: 0it [00:00, ?it/s]

epoch = 156, custom_mae = 0.2447330206632614


Validating: 0it [00:00, ?it/s]

epoch = 157, custom_mae = 0.2447165548801422


Validating: 0it [00:00, ?it/s]

epoch = 158, custom_mae = 0.244577556848526


Validating: 0it [00:00, ?it/s]

epoch = 159, custom_mae = 0.24506595730781555


Validating: 0it [00:00, ?it/s]

epoch = 160, custom_mae = 0.24488940834999084


Validating: 0it [00:00, ?it/s]

epoch = 161, custom_mae = 0.2443196326494217


Validating: 0it [00:00, ?it/s]

epoch = 162, custom_mae = 0.24442368745803833


Validating: 0it [00:00, ?it/s]

epoch = 163, custom_mae = 0.24453948438167572


Validating: 0it [00:00, ?it/s]

epoch = 164, custom_mae = 0.24447864294052124


Validating: 0it [00:00, ?it/s]

epoch = 165, custom_mae = 0.2445286065340042


Validating: 0it [00:00, ?it/s]

epoch = 166, custom_mae = 0.24448451399803162


Validating: 0it [00:00, ?it/s]

epoch = 167, custom_mae = 0.24444209039211273


Validating: 0it [00:00, ?it/s]

epoch = 168, custom_mae = 0.24439340829849243


Validating: 0it [00:00, ?it/s]

epoch = 169, custom_mae = 0.24435631930828094


Validating: 0it [00:00, ?it/s]

epoch = 170, custom_mae = 0.24440215528011322


Validating: 0it [00:00, ?it/s]

epoch = 171, custom_mae = 0.24445213377475739


Validating: 0it [00:00, ?it/s]

epoch = 172, custom_mae = 0.24440833926200867
train done.


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


TypeError: forward() missing 1 required positional argument: 'cate_seq_x'

In [None]:
if len(CFG.folds) != CFG.n_folds:

    oof_score = get_score(oof_pred, oof_target, val_df['u_out'].values)
    print(f'MAE {oof_score}')

    oof_df = train.iloc[val_idx, :1]
    oof_df['target'] = oof_pred
    oof_df.to_csv(OUTPUT_DIR / f'oof{CFG.exp_num}.csv',index = False)    
else:
    score = mean_absolute_error(y, oof_total)
    print(f'MAE {score}: folds: {scores}')

    oof_df = pd.DataFrame({'id': train['id'].values, 'pressure':oof_total.reshape(-1)})
    oof_df.to_csv(OUTPUT_DIR / f'oof{CFG.exp_num}.csv',index = False)
oof_df

In [None]:
sub = pd.read_csv(DATA_DIR / 'sample_submission.csv')
sub['pressure'] = np.mean(sub_preds, axis=1)
sub.to_csv(OUTPUT_DIR / f'sub{CFG.exp_num}.csv',index = False)
sub

In [None]:
wandb.init(project='Ventilator-Pressure-Prediction', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': oof_score})
# wandb.save(utils.get_notebook_path())
wandb.finish()