In [1]:
import sys
from time import time
import numpy as np
import pandas as pd
from pathlib import Path
import lightgbm as lgb
import matplotlib.pyplot as plt 
import seaborn as sns
from tqdm import tqdm
import copy
import wandb
from collections import OrderedDict

from sklearn.metrics import mean_absolute_error
from sklearn import model_selection
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger


In [2]:
sys.path.append('../../src/')
import utils as utils
from utils import Timer

In [3]:
class CFG:
    """Experiment configuration.

    A plain namespace of class attributes; the rest of the notebook reads
    them directly (``CFG.seed``, ``CFG.loader_params["train"]``, ...).
    """

    # ---- Run control ----
    seed = 42
    exp_num = 14
    local = True
    n_folds = 5
    folds = [0]
    debug = False
    bias = 1000
    epochs = 200

    # ---- Dataset ----
    # One list of {"name": ..., "params": ...} entries per phase; an empty
    # name matches no transform class.
    transforms = dict(
        train=[dict(name="")],
        valid=[dict(name="")],
        test=[dict(name="")],
    )

    # ---- Loaders ----
    # Keyword arguments forwarded to torch DataLoader for each phase.
    loader_params = dict(
        train=dict(
            batch_size=128,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
        ),
        valid=dict(
            batch_size=32,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            drop_last=False,
        ),
        test=dict(
            batch_size=32,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            drop_last=False,
        ),
    )

    # ---- Split ----
    # Name of a class on sklearn.model_selection plus its constructor kwargs.
    split = "GroupKFold"
    split_params = dict(n_splits=5)

    # ---- Model ----
    input_dim = 5

    dense_dim = 512
    lstm_dim = 512
    logit_dim = 512
    num_classes = 1

    # ---- Criterion ----
    # (disabled) loss_name = "rmspe_loss"
    # (disabled) loss_params: dict = {}

    # ---- Optimizer ----
    optimizer_name = "AdamW"
    optimizer_params = dict(lr=0.001, weight_decay=1e-6)

    # ---- Scheduler ----
    scheduler_name = "CosineAnnealingLR"
    scheduler_params = dict(T_max=25, eta_min=1e-6)

In [4]:
# Seed the run from the configured value via the project helper.
# NOTE(review): which RNGs utils.set_seed covers is not visible here — confirm.
utils.set_seed(CFG.seed)

In [5]:
# Choose input/output locations: a machine-local checkout when CFG.local is
# set, otherwise the relative "../input" layout with outputs in the CWD.
DATA_DIR, OUTPUT_DIR = (
    (
        Path("/home/knikaido/work/Ventilator-Pressure-Prediction/data/ventilator-pressure-prediction"),
        Path('./output/'),
    )
    if CFG.local
    else (
        Path("../input/ventilator-pressure-prediction"),
        Path(''),
    )
)

In [6]:
def get_transforms(phase: str):
    """Build the transform pipeline configured for ``phase``.

    Reads ``CFG.transforms[phase]``, a list of ``{"name", "params"}`` entries.
    Each name is looked up in this module's globals; names that resolve to
    nothing are skipped silently.

    Returns:
        A ``Compose`` wrapping the instantiated transforms, or ``None`` when
        transforms are disabled or no entry resolved to a class.
    """
    config = CFG.transforms
    if config is None:
        return None

    phase_config = config[phase]
    if phase_config is None:
        return None

    pipeline = []
    for entry in phase_config:
        name = entry["name"]
        params = entry.get("params")
        if params is None:
            params = {}
        transform_cls = globals().get(name)
        if transform_cls is not None:
            pipeline.append(transform_cls(**params))

    return Compose(pipeline) if len(pipeline) > 0 else None
        
        
class Normalize:
    """Peak-normalise an array so that its largest absolute value becomes 1."""

    def __call__(self, y: np.ndarray):
        """Return ``y`` divided by ``max(|y|)`` as a Fortran-ordered array.

        Empty and all-zero inputs have no usable peak; they are returned
        unscaled (as floats for the all-zero case) instead of raising or
        producing NaNs from a division by zero.
        """
        y = np.asarray(y)
        if y.size == 0:
            return np.asfortranarray(y)

        max_vol = np.abs(y).max()
        if max_vol == 0:
            # Nothing to scale: every element is already zero.
            return np.asfortranarray(y * 1.0)

        y_vol = y * 1 / max_vol
        return np.asfortranarray(y_vol)


class Compose:
    """Chain several callables into one, applied in list order."""

    def __init__(self, transforms: list):
        # Kept as given; the list is not copied.
        self.transforms = transforms

    def __call__(self, y: np.ndarray):
        """Feed ``y`` through every transform, each receiving the previous result."""
        result = y
        for transform in self.transforms:
            result = transform(result)
        return result

In [7]:
def compute_metric(preds, trues, u_outs):
    """Masked mean absolute error.

    Each position is weighted by ``1 - u_outs``, so positions with
    ``u_outs == 1`` do not contribute. All three arrays must share a shape.

    Returns:
        Sum of weighted absolute errors divided by the sum of the weights.
    """
    weights = 1 - u_outs

    assert trues.shape == preds.shape and weights.shape == trues.shape, (trues.shape, preds.shape, weights.shape)

    weighted_abs_err = weights * np.abs(trues - preds)
    return weighted_abs_err.sum() / weights.sum()


class VentilatorLoss(nn.Module):
    """Masked MAE loss: absolute error weighted by ``1 - u_out``.

    Implemented as ``forward`` rather than by overriding ``__call__`` so that
    ``nn.Module.__call__`` keeps running its hook machinery; it is still
    invoked as ``criterion(preds, y, u_out)``.
    """

    def forward(self, preds, y, u_out):
        """Return the weighted mean absolute error along the last dimension.

        Args:
            preds: predicted values.
            y: target values, same shape as ``preds``.
            u_out: mask tensor, same shape; positions with value 1 get weight 0.

        Returns:
            ``sum(w * |y - preds|) / sum(w)`` over the last dimension, with
            ``w = 1 - u_out``. The result is NaN if every weight is zero.
        """
        w = 1 - u_out
        mae = w * (y - preds).abs()
        mae = mae.sum(-1) / w.sum(-1)

        return mae

In [8]:
def get_criterion():
    """Create the loss module used by the learner (a new ``VentilatorLoss``)."""
    criterion = VentilatorLoss()
    return criterion

In [9]:
# Custom optimizer
# Registry of optimizer classes keyed by name; consulted before torch.optim.
__OPTIMIZERS__ = {}


def get_optimizer(model: nn.Module):
    """Instantiate the optimizer named by ``CFG.optimizer_name`` for ``model``.

    The name is resolved against ``__OPTIMIZERS__`` first and ``torch.optim``
    second; ``CFG.optimizer_params`` are passed as keyword arguments.

    The special name ``"SAM"`` wraps a base optimizer (``CFG.base_optimizer``,
    resolved the same way) in a ``SAM`` class that must be defined in this
    module's globals.

    Raises:
        NameError: ``"SAM"`` is requested but no ``SAM`` class is defined.
        AttributeError: the requested name is neither registered nor found on
            ``torch.optim`` (or ``CFG.base_optimizer`` is missing for SAM).
    """
    optimizer_name = CFG.optimizer_name

    if optimizer_name == "SAM":
        base_optimizer_name = CFG.base_optimizer
        sam_cls = globals().get("SAM")
        if sam_cls is None:
            # Fail with an actionable message instead of a bare NameError.
            raise NameError(
                "CFG.optimizer_name is 'SAM' but no SAM class is defined; "
                "define or import SAM before calling get_optimizer()."
            )
        base_optimizer = __OPTIMIZERS__.get(base_optimizer_name)
        if base_optimizer is None:
            base_optimizer = getattr(optim, base_optimizer_name)
        return sam_cls(model.parameters(), base_optimizer, **CFG.optimizer_params)

    optimizer_cls = __OPTIMIZERS__.get(optimizer_name)
    if optimizer_cls is None:
        optimizer_cls = getattr(optim, optimizer_name)
    return optimizer_cls(model.parameters(), **CFG.optimizer_params)


def get_scheduler(optimizer):
    """Create the LR scheduler named by ``CFG.scheduler_name``.

    The class is looked up on ``torch.optim.lr_scheduler`` and constructed
    with ``optimizer`` and ``CFG.scheduler_params``.

    Returns:
        The scheduler instance, or ``None`` when no scheduler is configured.
    """
    name = CFG.scheduler_name
    if name is None:
        return None

    scheduler_cls = getattr(optim.lr_scheduler, name)
    return scheduler_cls(optimizer, **CFG.scheduler_params)

In [10]:
# validation
# Look up the splitter class named by CFG.split on sklearn.model_selection
# and instantiate it with CFG.split_params.
splitter = getattr(model_selection, CFG.split)(**CFG.split_params)

In [11]:
class VentilatorDataset(torchdata.Dataset):
    """Dataset yielding one item per ``breath_id`` group of a DataFrame.

    Args:
        df: frame with at least ``breath_id``, ``u_out`` and the feature
            columns; ``pressure`` is optional and filled with 0 when absent.
        train_col: list of feature column names returned as ``"input"``.
    """

    def __init__(self, df, train_col):
        if "pressure" not in df.columns:
            # Work on a new frame so the caller's DataFrame is not modified.
            df = df.assign(pressure=0)
        self.df = df
        grouped = df.groupby('breath_id')
        self.groups = grouped.groups
        self.keys = list(self.groups.keys())
        # Positional row numbers per group. ``groups`` holds index *labels*,
        # which only coincide with positions for a default RangeIndex, so the
        # ``.iloc`` lookup below uses these positions instead.
        self._positions = grouped.indices
        self.train_col = train_col

    def __len__(self):
        """Number of distinct ``breath_id`` groups."""
        return len(self.groups)

    def __getitem__(self, idx):
        """Return the ``idx``-th group as a dict of float32 arrays.

        Keys: ``"input"`` (rows x len(train_col)), ``"u_out"`` and ``"p"``
        (one value per row of the group).
        """
        rows = self._positions[self.keys[idx]]
        df_ = self.df.iloc[rows]

        input_ = df_[self.train_col].values
        u_out_ = df_['u_out'].values
        p_ = df_['pressure'].values

        data = {
            "input": input_.astype(np.float32),
            "u_out": u_out_.astype(np.float32),
            "p": p_.astype(np.float32),
        }

        return data

In [12]:
class RNNModel(nn.Module):
    """Per-step MLP encoder, bidirectional LSTM, and per-step MLP head.

    Args:
        input_dim: number of input features per time step.
        lstm_dim: hidden size of each LSTM direction.
        dense_dim: output width of the encoder MLP (its hidden width is half).
        logit_dim: hidden width of the output head.
        num_classes: number of outputs per time step.
    """

    def __init__(
        self,
        input_dim=4,
        lstm_dim=256,
        dense_dim=256,
        logit_dim=256,
        num_classes=1,
    ):
        super().__init__()

        half_dense = dense_dim // 2
        encoder_layers = [
            nn.Linear(input_dim, half_dense),
            nn.ReLU(),
            nn.Linear(half_dense, dense_dim),
            nn.ReLU(),
        ]
        self.mlp = nn.Sequential(*encoder_layers)

        self.lstm = nn.LSTM(
            input_size=dense_dim,
            hidden_size=lstm_dim,
            batch_first=True,
            bidirectional=True,
        )

        # The head consumes both LSTM directions concatenated: 2 * lstm_dim.
        head_layers = [
            nn.Linear(2 * lstm_dim, logit_dim),
            nn.ReLU(),
            nn.Linear(logit_dim, num_classes),
        ]
        self.logits = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map ``x`` through encoder, LSTM and head; the LSTM state is discarded."""
        encoded = self.mlp(x)
        sequence_out, _state = self.lstm(encoded)
        return self.logits(sequence_out)

In [13]:
# Learner class (pytorch-lightning)
class Learner(pl.LightningModule):
    """LightningModule wrapping ``model`` with the notebook's loss and optimizer setup.

    Batches are dicts with ``"input"``, ``"p"`` (target) and ``"u_out"`` (mask).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = get_criterion()

    def _forward_and_loss(self, batch):
        """Run the model on ``batch`` and return ``(output, loss)``.

        Output, target and mask are flattened before being passed to the
        criterion, so the loss is computed over the whole batch at once.
        """
        output = self.model(batch['input'])
        loss = self.criterion(output.view(-1), batch['p'].view(-1), batch['u_out'].view(-1))
        return output, loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        _, loss = self._forward_and_loss(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the epoch-level validation loss and return detached tensors for epoch-end scoring."""
        preds, loss = self._forward_and_loss(batch)

        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)

        return OrderedDict({
            "targets": batch['p'].detach(),
            "preds": preds.detach(),
            "u_outs": batch['u_out'].detach(),
            "loss": loss.detach(),
        })

    def validation_epoch_end(self, outputs):
        """Compute the masked MAE over all validation batches, then log and print it."""
        targets = torch.cat([o["targets"].view(-1) for o in outputs]).cpu().numpy()
        preds = torch.cat([o["preds"].view(-1) for o in outputs]).cpu().numpy()
        u_outs = torch.cat([o["u_outs"].view(-1) for o in outputs]).cpu().numpy()

        score = get_score(preds, targets, u_outs)
        self.log('custom_mae/val', score, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        print(f'epoch = {self.current_epoch}, custom_mae = {score}')

    def configure_optimizers(self):
        """Return the optimizer, plus the scheduler when one is configured."""
        optimizer = get_optimizer(self.model)
        scheduler = get_scheduler(optimizer)
        if scheduler is None:
            # Do not hand Lightning an "lr_scheduler": None entry.
            return optimizer
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [14]:
# Use the CUDA device when torch reports one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
def get_score(y_pred, y_true, u_outs):
    """Score predictions against targets; delegates to ``compute_metric``."""
    score = compute_metric(y_pred, y_true, u_outs)
    return score


def to_np(input):
    """Detach a tensor from the graph, move it to the CPU and return it as a NumPy array."""
    cpu_tensor = input.detach().cpu()
    return cpu_tensor.numpy()

# oof
def evaluate(model, loaders, phase):
    """Run ``model`` over ``loaders[phase]`` without gradients.

    Args:
        model: module called on each batch's ``"input"`` tensor.
        loaders: dict of data loaders keyed by phase name.
        phase: key of the loader to iterate.

    Returns:
        ``(predictions, targets)`` as flat 1-D NumPy arrays, concatenated in
        loader order. Targets are taken from each batch's ``"p"`` entry.

    The model is switched to eval mode for the pass and put back into the
    mode it was in beforehand, even if the pass raises.
    """
    was_training = model.training
    model.eval()
    pred_list = []
    target_list = []
    try:
        with torch.no_grad():
            for batch in loaders[phase]:
                # Move a local reference only; the batch dict is left as-is.
                inputs = batch['input'].to(device)
                output = model(inputs)
                pred_list.append(to_np(output))
                target_list.append(to_np(batch['p']))
    finally:
        model.train(was_training)

    pred_list = np.concatenate(pred_list).reshape(-1)
    target_list = np.concatenate(target_list).reshape(-1)
    return pred_list, target_list

In [16]:
train = pd.read_csv(DATA_DIR / 'train.csv')
test = pd.read_csv(DATA_DIR / 'test.csv')
# A single display() call renders both frames. Writing
# `display(a), display(b)` would make the cell's value a tuple and echo a
# stray `(None, None)` underneath.
display(train, test)

Unnamed: 0,id,breath_id,R,C,time_step,u_in,u_out,pressure
0,1,1,20,50,0.000000,0.083334,0,5.837492
1,2,1,20,50,0.033652,18.383041,0,5.907794
2,3,1,20,50,0.067514,22.509278,0,7.876254
3,4,1,20,50,0.101542,22.808822,0,11.742872
4,5,1,20,50,0.135756,25.355850,0,12.234987
...,...,...,...,...,...,...,...,...
6035995,6035996,125749,50,10,2.504603,1.489714,1,3.869032
6035996,6035997,125749,50,10,2.537961,1.488497,1,3.869032
6035997,6035998,125749,50,10,2.571408,1.558978,1,3.798729
6035998,6035999,125749,50,10,2.604744,1.272663,1,4.079938


Unnamed: 0,id,breath_id,R,C,time_step,u_in,u_out
0,1,0,5,20,0.000000,0.000000,0
1,2,0,5,20,0.031904,7.515046,0
2,3,0,5,20,0.063827,14.651675,0
3,4,0,5,20,0.095751,21.230610,0
4,5,0,5,20,0.127644,26.320956,0
...,...,...,...,...,...,...,...
4023995,4023996,125748,20,10,2.530117,4.971245,1
4023996,4023997,125748,20,10,2.563853,4.975709,1
4023997,4023998,125748,20,10,2.597475,4.979468,1
4023998,4023999,125748,20,10,2.631134,4.982648,1


(None, None)

In [17]:
def get_raw_features(input_df, dataType = 'train'):
    """Return the columns that are used unmodified: ``time_step`` and ``u_in``.

    ``dataType`` is accepted so all feature functions share a signature; it is
    not used here.
    """
    raw_columns = ['time_step', 'u_in']
    return input_df.loc[:, raw_columns]

In [18]:
def get_category_features(input_df, dataType = 'train'):
    """Return the ``R`` and ``C`` columns cast to pandas ``category`` dtype.

    Column selection followed by ``astype`` already produces a new frame, so
    the input is neither copied as a whole nor modified. ``dataType`` is
    unused and kept for signature parity with the other feature functions.
    """
    colum = ['R', 'C']

    # A combined feature (R + C * 10) was tried here and is currently disabled.
    return input_df[colum].astype('category')

In [19]:
def get_diff_shift_features(input_df, dataType = 'train'):
    """Per-breath shifted values and differences of ``u_in`` and ``time_step``.

    For every offset ``i`` in ``[-2, -1, 1, 2, 3, 4]`` and each source column
    ``c``, two columns are produced:

    * ``{c}_shift_{i}``: the value ``i`` rows away within the same
      ``breath_id`` (NaN where the shift leaves the group);
    * ``{c}_diff_{i}``: the current value minus that shifted value.

    Both are derived from the same group-wise shift, so they stay consistent
    even when the rows of a breath are not stored contiguously. Subtracting
    the shifted series directly is also faster than a grouped ``diff``.

    Returns:
        A new DataFrame holding only the generated columns, indexed like
        ``input_df``. ``input_df`` is not modified. ``dataType`` is unused.
    """
    shift_idx = [-2, -1, 1, 2, 3, 4]
    source_cols = ['u_in', 'time_step']

    b_id_gby = input_df.groupby('breath_id')

    features = {}
    for i in shift_idx:
        for c_ in source_cols:
            shifted = b_id_gby[c_].shift(i)
            features[f'{c_}_shift_{i}'] = shifted
            features[f'{c_}_diff_{i}'] = input_df[c_] - shifted

    return pd.DataFrame(features, index=input_df.index)

In [20]:
def get_cum_features(input_df, dataType = 'train'):
    """Running sums of ``u_in`` and ``time_step`` within each ``breath_id``.

    Returns:
        A new DataFrame with columns ``u_in_cumsum`` and ``time_step_cumsum``,
        indexed like ``input_df``. The input frame is neither copied as a
        whole nor modified. ``dataType`` is unused.
    """
    b_id_gby = input_df.groupby('breath_id')

    return pd.DataFrame(
        {
            'u_in_cumsum': b_id_gby['u_in'].cumsum(),
            'time_step_cumsum': b_id_gby['time_step'].cumsum(),
        },
        index=input_df.index,
    )

In [21]:
def get_agg_features(input_df, dataType = 'train'):
    """Per-breath aggregate statistics of ``u_in``, broadcast back to every row.

    Aggregates are computed per ``breath_id`` over a time window and then
    left-merged onto the rows on ``breath_id``, so each row of a breath
    carries the same aggregate values.

    Returns only the columns appended after the original ones
    (``iloc[:, c_num:]``). ``dataType`` is unused.

    NOTE(review): ``pd.merge`` on a column does not carry over ``input_df``'s
    index; the later axis=1 concat in ``to_feature`` presumably relies on
    ``input_df`` having a default integer index — confirm.
    """
    
    # Work on a deep copy so input_df itself is left untouched.
    output_df = copy.deepcopy(input_df)
    # Number of original columns; everything to the right is a new feature.
    c_num = input_df.shape[1]
    
    # Dict for aggregations
    create_feature_dict = {
        'u_in': [np.max, np.std, np.mean, 'first', 'last'],
    }
    
    def get_agg_window(start_time=0, end_time=3.0, add_suffix = False):
        """Aggregate rows with ``start_time <= time_step <= end_time`` per breath.

        Column names are the source column and the aggregation name joined by
        ``_``; with ``add_suffix`` the window bounds are appended as
        ``_<start_time>_<end_time>``. The result is indexed by ``breath_id``.
        """
        
        df_tgt = output_df[(output_df['time_step'] >= start_time) & (output_df['time_step'] <= end_time)]
        df_feature = df_tgt.groupby(['breath_id']).agg(create_feature_dict)
        # agg() with a dict of lists yields two-level column names; flatten them.
        df_feature.columns = ['_'.join(col) for col in df_feature.columns]
        
        if add_suffix:
            df_feature = df_feature.add_suffix('_' + str(start_time) + '_' + str(end_time))
            
        return df_feature
    
    # Default window (0 to 3.0); breath_id becomes a regular column for the merge.
    df_agg_feature = get_agg_window().reset_index()
    
    # Additional sub-window aggregates, currently disabled:
#     df_tmp = get_agg_window(start_time = 2, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')
#     df_tmp = get_agg_window(start_time = 1, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')
#     df_tmp = get_agg_window(end_time = 1, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')
#     df_tmp = get_agg_window(end_time = 2, add_suffix = True).reset_index()
#     df_agg_feature = df_agg_feature.merge(df_tmp, how = 'left', on = 'breath_id')

    # Left merge keeps every input row; a breath with no rows in the window gets NaNs.
    output_df = pd.merge(output_df, df_agg_feature, how='left', on='breath_id')
    
    
    return output_df.iloc[:, c_num:]

In [22]:
def to_feature(input_df, dataType = 'train'):
    """Convert ``input_df`` into a feature matrix and return it as a new DataFrame.

    Every processor is called as ``func(input_df, dataType)`` inside a
    ``Timer`` block and must return one row per input row. The per-processor
    frames are concatenated column-wise in a single step and the result is
    passed through ``utils.reduce_mem_usage``.

    Raises:
        AssertionError: a processor returned a frame whose length differs
            from ``input_df`` (the message is the processor's name).
    """

    processors = [
        get_raw_features,
        get_category_features,
        get_diff_shift_features,
        get_cum_features,
        get_agg_features
    ]

    feature_frames = []

    for func in tqdm(processors, total=len(processors)):
        with Timer(prefix='' + func.__name__ + ' '):
            _df = func(input_df, dataType)

        # Lengths must match; a mismatch means the processor is implemented incorrectly.
        assert len(_df) == len(input_df), func.__name__
        feature_frames.append(_df)

    # Concatenate once instead of re-copying a growing frame on every iteration.
    out_df = pd.concat(feature_frames, axis=1)
    out_df = utils.reduce_mem_usage(out_df)

    return out_df

In [23]:
# Build the feature matrices for both datasets with the same processor pipeline.
train_df = to_feature(train, dataType = 'train')
test_df = to_feature(test, dataType = 'test')

 40%|████      | 2/5 [00:00<00:00, 10.14it/s]

get_raw_features  0.014[s]
get_category_features  0.151[s]
get_diff_shift_features  11.202[s]
get_cum_features  0.198[s]


 80%|████████  | 4/5 [00:12<00:03,  3.72s/it]

get_agg_features  1.160[s]


100%|██████████| 5/5 [00:14<00:00,  2.90s/it]
 40%|████      | 2/5 [00:00<00:00, 15.45it/s]

Mem. usage decreased from 1531.20 Mb to 391.43 Mb (74.4% reduction)
get_raw_features  0.009[s]
get_category_features  0.098[s]
get_diff_shift_features  7.021[s]
get_cum_features  0.109[s]


 80%|████████  | 4/5 [00:07<00:02,  2.33s/it]

get_agg_features  0.724[s]


100%|██████████| 5/5 [00:09<00:00,  1.82s/it]


Mem. usage decreased from 1020.80 Mb to 260.96 Mb (74.4% reduction)


In [24]:
# Fit the scaler on the training features only, then apply it to both sets.
ss = StandardScaler()
ss.fit(train_df)

train_df = pd.DataFrame(ss.transform(train_df), columns=list(train_df.columns))
# Column means of the scaled training data. Computed once and used to fill
# missing values in both train and test, so test is imputed from train only.
train_mean = train_df.mean()
train_df = train_df.fillna(train_mean)

test_df = pd.DataFrame(ss.transform(test_df), columns=list(test_df.columns))
test_df = test_df.fillna(train_mean)

In [25]:
# One display() call for both frames; a tuple of two display() calls would
# also echo `(None, None)` as the cell value.
display(train_df, test_df)

Unnamed: 0,time_step,u_in,R,C,u_in_shift_-2,u_in_diff_-2,time_step_shift_-2,time_step_diff_-2,u_in_shift_-1,u_in_diff_-1,...,u_in_diff_4,time_step_shift_4,time_step_diff_4,u_in_cumsum,time_step_cumsum,u_in_amax,u_in_std,u_in_mean,u_in_first,u_in_last
0,-1.706609,-0.538776,-0.359072,1.394522,1.240467e+00,-2.503374e+00,-1.703993e+00,-5.629722e-01,8.449263e-01,-2.511949e+00,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.980690,-1.116536,-0.245417,0.119488,0.513998,-0.550081,0.282547
1,-1.662664,0.823912,-0.359072,1.394522,1.264159e+00,-5.308007e-01,-1.658398e+00,-7.206978e-01,1.157443e+00,-5.786234e-01,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.936297,-1.115471,-0.245417,0.119488,0.513998,-0.550081,0.282547
2,-1.618480,1.130953,-0.359072,1.394522,1.467412e+00,-3.576065e-01,-1.612640e+00,-8.784233e-01,1.179935e+00,-5.675740e-02,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.881968,-1.113334,-0.245417,0.119488,0.513998,-0.550081,0.282547
3,-1.574017,1.153051,-0.359072,1.394522,1.619540e+00,-5.333698e-01,-1.567208e+00,-8.521357e-01,1.372891e+00,-3.633358e-01,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.826884,-1.110121,-0.245417,0.119488,0.513998,-0.550081,0.282547
4,-1.529395,1.342625,-0.359072,1.394522,1.608317e+00,-2.398601e-01,-1.521613e+00,-7.469854e-01,1.517311e+00,-2.756753e-01,...,3.504181e+00,-1.705506e+00,7.421831e-01,-0.765688,-1.105830,-0.245417,0.119488,0.513998,-0.550081,0.282547
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6035995,1.562294,-0.434126,1.171893,-0.937525,-4.320053e-01,-5.327896e-02,1.649282e+00,-2.738087e-01,-4.356242e-01,-1.572538e-02,...,1.286684e-01,1.553039e+00,2.278182e-01,-0.046433,1.892880,-0.334970,-0.430004,-0.438473,-0.313459,-3.250301
6035996,1.605641,-0.434199,1.171893,-0.937525,-4.548400e-01,-2.202934e-02,1.693733e+00,-2.475211e-01,-4.302972e-01,-2.550839e-02,...,1.291274e-01,1.598670e+00,2.548900e-01,-0.042811,1.973948,-0.334970,-0.430004,-0.438473,-0.313459,-3.250301
6035997,1.651539,-0.428965,1.171893,-0.937525,-4.380841e-01,-3.732968e-02,1.738184e+00,-1.686583e-01,-4.519751e-01,2.317612e-02,...,1.391917e-01,1.644300e+00,2.819618e-01,-0.038585,2.055017,-0.334970,-0.430004,-0.438473,-0.313459,-3.250301
6035998,1.694886,-0.450263,1.171893,-0.937525,-2.216228e-17,4.515522e-19,1.010101e-16,2.726380e-15,-4.360681e-01,-4.455089e-02,...,9.167304e-02,1.689931e+00,2.548900e-01,-0.035567,2.138062,-0.334970,-0.430004,-0.438473,-0.313459,-3.250301


Unnamed: 0,time_step,u_in,R,C,u_in_shift_-2,u_in_diff_-2,time_step_shift_-2,time_step_diff_-2,u_in_shift_-1,u_in_diff_-1,...,u_in_diff_4,time_step_shift_4,time_step_diff_4,u_in_cumsum,time_step_cumsum,u_in_amax,u_in_std,u_in_mean,u_in_first,u_in_last
0,-1.706609,-0.544978,-1.124554,-0.354513,6.126275e-01,-1.651318e+00,-1.708896e+00,1.014283e+00,2.101770e-02,-1.041172e+00,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.980892,-1.116536,0.048118,0.367372,0.364841,-0.553395,0.266802
1,-1.664975,0.014441,-1.124554,-0.354513,1.138217e+00,-1.549414e+00,-1.666161e+00,1.014283e+00,5.614117e-01,-9.894817e-01,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.962744,-1.115527,0.048118,0.367372,0.364841,-0.553395,0.266802
2,-1.623261,0.545366,-1.124554,-0.354513,1.544723e+00,-1.325054e+00,-1.623426e+00,1.014283e+00,1.060374e+00,-9.132784e-01,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.927355,-1.113508,0.048118,0.367372,0.364841,-0.553395,0.266802
3,-1.581587,1.035584,-1.124554,-0.354513,1.876412e+00,-1.060445e+00,-1.580773e+00,1.040571e+00,1.446285e+00,-7.102472e-01,...,2.022262e-18,1.113961e-16,-1.941055e-15,-0.876082,-1.110476,0.048118,0.367372,0.364841,-0.553395,0.266802
4,-1.539913,1.414733,-1.124554,-0.354513,2.119567e+00,-8.377977e-01,-1.537956e+00,1.014283e+00,1.761169e+00,-5.839523e-01,...,3.646851e+00,-1.705506e+00,-1.044558e+00,-0.812623,-1.106441,0.048118,0.367372,0.364841,-0.553395,0.266802
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4023995,1.595442,-0.174842,-0.359072,-0.937525,-1.589231e-01,-4.658715e-02,1.683274e+00,-5.103970e-01,-1.713460e-01,-1.650026e-02,...,1.152951e-01,1.585249e+00,6.880395e-01,-0.510644,1.922539,-0.027504,-0.551590,-0.843326,-0.156195,0.278611
4023996,1.641339,-0.174552,-0.359072,-0.937525,-1.586114e-01,-4.644624e-02,1.727725e+00,-4.578218e-01,-1.710501e-01,-1.640424e-02,...,1.147214e-01,1.630879e+00,7.151113e-01,-0.498571,2.003607,-0.027504,-0.551590,-0.843326,-0.156195,0.278611
4023997,1.684687,-0.174261,-0.359072,-0.937525,-1.586114e-01,-4.633293e-02,1.774791e+00,-6.944102e-01,-1.707541e-01,-1.632514e-02,...,1.142174e-01,1.679194e+00,6.068239e-01,-0.486800,2.086653,-0.027504,-0.551590,-0.843326,-0.156195,0.278611
4023998,1.728034,-0.173970,-0.359072,-0.937525,-2.216228e-17,4.515522e-19,1.010101e-16,2.726380e-15,-1.707541e-01,-1.626321e-02,...,1.137973e-01,1.724824e+00,5.256084e-01,-0.474726,2.169699,-0.027504,-0.551590,-0.843326,-0.156195,0.278611


(None, None)

In [26]:
# Feature column names, captured before the id/target columns are appended in the next cell.
train_col = train_df.columns.to_list()

In [27]:
# Re-attach the identifier, target and mask columns next to the scaled features.
# NOTE(review): the axis=1 concat aligns on the index, so this presumably relies on
# both frames sharing the same default integer index — confirm.
train_df = pd.concat([train_df, train[['id', 'breath_id', 'pressure', 'u_out']]], axis=1)
test_df = pd.concat([test_df, test[['id', 'breath_id', 'u_out']]], axis=1)

In [None]:
# Cross-validated training: for every fold listed in CFG.folds, train a model,
# score it out-of-fold, and predict the test set.
oof_total = np.zeros((len(train), CFG.num_classes))
# One column of test predictions per *selected* fold (see fold_pos below).
sub_preds = np.zeros((test.shape[0], len(CFG.folds)))
val_idxes = []
models = []
y = train['pressure']
groups = train['breath_id']
# Not used by the loop below, which iterates `splitter`; kept as-is.
gkfold = model_selection.GroupKFold(n_splits=CFG.n_folds)
scores = []
input_dim = len(train_col)

for i, (trn_idx, val_idx) in enumerate(splitter.split(train_df, y, groups)):
    if i not in CFG.folds:
        continue

    # Column of sub_preds for this fold. sub_preds has len(CFG.folds) columns,
    # so indexing it by the fold number i would overflow for e.g. folds=[3].
    fold_pos = CFG.folds.index(i)

    trn_df = train_df.loc[trn_idx, :].reset_index(drop=True)
    val_df = train_df.loc[val_idx, :].reset_index(drop=True)
    trn_y = y.values[trn_idx]
    val_y = y.values[val_idx]

    loaders = {
        phase: torchdata.DataLoader(
            VentilatorDataset(
                df_, train_col
            ),
            **CFG.loader_params[phase])  # type: ignore
        for phase, df_ in zip(["train", "valid", "test"], [trn_df, val_df, test_df])
    }

    model = RNNModel(
        input_dim=input_dim,
        lstm_dim=CFG.lstm_dim,
        dense_dim=CFG.dense_dim,
        logit_dim=CFG.logit_dim,
        num_classes=CFG.num_classes,
    )
    model_name = model.__class__.__name__

    learner = Learner(model)

    # loggers
    RUN_NAME = f'exp{str(CFG.exp_num)}'
    wandb.init(project='Ventilator-Pressure-Prediction', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{i}')
    wandb.run.name = RUN_NAME + f'-fold-{i}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)

    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor='Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        save_weights_only=True,
        # '{epoch}' is left as a template for the callback to fill in when it
        # saves; learner.current_epoch read here is just the value at
        # construction time.
        filename=f'{model_name}-{{epoch}}-{i}')
    callbacks.append(checkpoint_callback)

    # Optional early stopping (disabled):
#     early_stop_callback = EarlyStopping(
#         monitor='Loss/val',
#         min_delta=0.00,
#         patience=10,
#         verbose=True,
#         mode='min')
#     callbacks.append(early_stop_callback)

    loggers = []
    loggers.append(WandbLogger())

    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=CFG.epochs,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        deterministic=True,
        benchmark=False,
        )

    # Loaders are passed positionally so the call does not depend on the
    # version-specific name of the training-loader keyword.
    trainer.fit(learner, loaders['train'], loaders['valid'])
    print('train done.')

    #############
    # validation (to make oof)
    #############
    # Reload the best weights (lowest 'Loss/val') before scoring.
    checkpoint = torch.load(checkpoint_callback.best_model_path)
    learner.load_state_dict(checkpoint['state_dict'])

    model = model.to(device)
    oof_pred, oof_target = evaluate(model, loaders, phase="valid")
    models.append(model)

    oof_score = get_score(oof_pred, oof_target, val_df['u_out'].values)
    scores.append(oof_score)
    # NOTE(review): predictions are divided by CFG.bias only here; oof_score
    # above and sub_preds below use the undivided values — confirm this
    # scaling is intended.
    oof_total[val_idx] = oof_pred.reshape(1, -1).T / CFG.bias
    val_idxes.append(val_idx)

    print('validate done.')
    # The score is the masked MAE from get_score, so it is labelled as such.
    print(f'fold = {i}, custom_mae = {oof_score}')
    wandb.log({'CV_score': oof_score})

    #############
    # inference
    #############
    test_pred, _ = evaluate(model, loaders, phase="test")
    sub_preds[:, fold_pos] = test_pred

    print('inference done.')


[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params
---------------------------------------------
0 | model     | RNNModel       | 4.9 M 
1 | criterion | VentilatorLoss | 0     
---------------------------------------------
4.9 M     Trainable params
0         Non-trainable params
4.9 M     Total params
19.474    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, custom_mae = 17.430959701538086


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, custom_mae = 1.1416189670562744


Validating: 0it [00:00, ?it/s]

epoch = 1, custom_mae = 0.8791009187698364


Validating: 0it [00:00, ?it/s]

epoch = 2, custom_mae = 0.867949366569519


Validating: 0it [00:00, ?it/s]

epoch = 3, custom_mae = 0.7773261070251465


Validating: 0it [00:00, ?it/s]

epoch = 4, custom_mae = 0.8659289479255676


Validating: 0it [00:00, ?it/s]

epoch = 5, custom_mae = 0.7033107280731201


Validating: 0it [00:00, ?it/s]

epoch = 6, custom_mae = 0.7012588977813721


Validating: 0it [00:00, ?it/s]

epoch = 7, custom_mae = 0.664511501789093


Validating: 0it [00:00, ?it/s]

epoch = 8, custom_mae = 0.6467579007148743


Validating: 0it [00:00, ?it/s]

epoch = 9, custom_mae = 0.5969302654266357


Validating: 0it [00:00, ?it/s]

epoch = 10, custom_mae = 0.6229968667030334


Validating: 0it [00:00, ?it/s]

epoch = 11, custom_mae = 0.5772889256477356


Validating: 0it [00:00, ?it/s]

epoch = 12, custom_mae = 0.5613875389099121


Validating: 0it [00:00, ?it/s]

epoch = 13, custom_mae = 0.5272951126098633


Validating: 0it [00:00, ?it/s]

epoch = 14, custom_mae = 0.5185585618019104


Validating: 0it [00:00, ?it/s]

epoch = 15, custom_mae = 0.5046277046203613


Validating: 0it [00:00, ?it/s]

epoch = 16, custom_mae = 0.5345470905303955


Validating: 0it [00:00, ?it/s]

epoch = 17, custom_mae = 0.4687510132789612


Validating: 0it [00:00, ?it/s]

epoch = 18, custom_mae = 0.44642820954322815


Validating: 0it [00:00, ?it/s]

epoch = 19, custom_mae = 0.43084901571273804


Validating: 0it [00:00, ?it/s]

epoch = 20, custom_mae = 0.4207027852535248


Validating: 0it [00:00, ?it/s]

epoch = 21, custom_mae = 0.4154285490512848


Validating: 0it [00:00, ?it/s]

epoch = 22, custom_mae = 0.4103018045425415


Validating: 0it [00:00, ?it/s]

epoch = 23, custom_mae = 0.40671390295028687


Validating: 0it [00:00, ?it/s]

epoch = 24, custom_mae = 0.40412217378616333


Validating: 0it [00:00, ?it/s]

epoch = 25, custom_mae = 0.4034540355205536


Validating: 0it [00:00, ?it/s]

epoch = 26, custom_mae = 0.4030631482601166


Validating: 0it [00:00, ?it/s]

epoch = 27, custom_mae = 0.40481093525886536


Validating: 0it [00:00, ?it/s]

epoch = 28, custom_mae = 0.4089651107788086


Validating: 0it [00:00, ?it/s]

epoch = 29, custom_mae = 0.40759578347206116


Validating: 0it [00:00, ?it/s]

epoch = 30, custom_mae = 0.40767282247543335


Validating: 0it [00:00, ?it/s]

epoch = 31, custom_mae = 0.42291396856307983


Validating: 0it [00:00, ?it/s]

epoch = 32, custom_mae = 0.4283355474472046


Validating: 0it [00:00, ?it/s]

epoch = 33, custom_mae = 0.440805584192276


Validating: 0it [00:00, ?it/s]

epoch = 34, custom_mae = 0.42945414781570435


Validating: 0it [00:00, ?it/s]

epoch = 35, custom_mae = 0.4946896433830261


Validating: 0it [00:00, ?it/s]

epoch = 36, custom_mae = 0.43294447660446167


Validating: 0it [00:00, ?it/s]

epoch = 37, custom_mae = 0.4475654363632202


Validating: 0it [00:00, ?it/s]

epoch = 38, custom_mae = 0.45693346858024597


Validating: 0it [00:00, ?it/s]

epoch = 39, custom_mae = 0.4935589134693146


Validating: 0it [00:00, ?it/s]

epoch = 40, custom_mae = 0.5112201571464539


Validating: 0it [00:00, ?it/s]

epoch = 41, custom_mae = 0.4391661584377289


Validating: 0it [00:00, ?it/s]

epoch = 42, custom_mae = 0.470886766910553


Validating: 0it [00:00, ?it/s]

epoch = 43, custom_mae = 0.4862082302570343


Validating: 0it [00:00, ?it/s]

epoch = 44, custom_mae = 0.4700528383255005


Validating: 0it [00:00, ?it/s]

epoch = 45, custom_mae = 0.5821118354797363


Validating: 0it [00:00, ?it/s]

epoch = 46, custom_mae = 0.4836120903491974


Validating: 0it [00:00, ?it/s]

epoch = 47, custom_mae = 0.5053191781044006


Validating: 0it [00:00, ?it/s]

epoch = 48, custom_mae = 0.4368104338645935


Validating: 0it [00:00, ?it/s]

epoch = 49, custom_mae = 0.43317121267318726


Validating: 0it [00:00, ?it/s]

epoch = 50, custom_mae = 0.40634506940841675


Validating: 0it [00:00, ?it/s]

epoch = 51, custom_mae = 0.4351082444190979


Validating: 0it [00:00, ?it/s]

epoch = 52, custom_mae = 0.42427897453308105


Validating: 0it [00:00, ?it/s]

epoch = 53, custom_mae = 0.4735209047794342


Validating: 0it [00:00, ?it/s]

epoch = 54, custom_mae = 0.4475274085998535


Validating: 0it [00:00, ?it/s]

epoch = 55, custom_mae = 0.39773067831993103


Validating: 0it [00:00, ?it/s]

epoch = 56, custom_mae = 0.3831775188446045


Validating: 0it [00:00, ?it/s]

epoch = 57, custom_mae = 0.3959064483642578


Validating: 0it [00:00, ?it/s]

epoch = 58, custom_mae = 0.37376049160957336


Validating: 0it [00:00, ?it/s]

epoch = 59, custom_mae = 0.3688708543777466


Validating: 0it [00:00, ?it/s]

epoch = 60, custom_mae = 0.3685709834098816


Validating: 0it [00:00, ?it/s]

epoch = 61, custom_mae = 0.3561937212944031


Validating: 0it [00:00, ?it/s]

epoch = 62, custom_mae = 0.35721632838249207


Validating: 0it [00:00, ?it/s]

epoch = 63, custom_mae = 0.3465037941932678


Validating: 0it [00:00, ?it/s]

epoch = 64, custom_mae = 0.31294745206832886


Validating: 0it [00:00, ?it/s]

epoch = 65, custom_mae = 0.3116929531097412


Validating: 0it [00:00, ?it/s]

epoch = 66, custom_mae = 0.3048430383205414


Validating: 0it [00:00, ?it/s]

epoch = 67, custom_mae = 0.294059157371521


Validating: 0it [00:00, ?it/s]

epoch = 68, custom_mae = 0.288612425327301


Validating: 0it [00:00, ?it/s]

epoch = 69, custom_mae = 0.28646790981292725


Validating: 0it [00:00, ?it/s]

epoch = 70, custom_mae = 0.2788875997066498


Validating: 0it [00:00, ?it/s]

epoch = 71, custom_mae = 0.2758743464946747


Validating: 0it [00:00, ?it/s]

epoch = 72, custom_mae = 0.2783994972705841


Validating: 0it [00:00, ?it/s]

epoch = 73, custom_mae = 0.27356064319610596


Validating: 0it [00:00, ?it/s]

epoch = 74, custom_mae = 0.27316272258758545


Validating: 0it [00:00, ?it/s]

epoch = 75, custom_mae = 0.2730516195297241


Validating: 0it [00:00, ?it/s]

epoch = 76, custom_mae = 0.2730470895767212


Validating: 0it [00:00, ?it/s]

epoch = 77, custom_mae = 0.272980660200119


Validating: 0it [00:00, ?it/s]

epoch = 78, custom_mae = 0.2738143801689148


Validating: 0it [00:00, ?it/s]

epoch = 79, custom_mae = 0.27391088008880615


Validating: 0it [00:00, ?it/s]

epoch = 80, custom_mae = 0.27612388134002686


Validating: 0it [00:00, ?it/s]

epoch = 81, custom_mae = 0.28307008743286133


Validating: 0it [00:00, ?it/s]

epoch = 82, custom_mae = 0.2791364789009094


Validating: 0it [00:00, ?it/s]

epoch = 83, custom_mae = 0.27819740772247314


Validating: 0it [00:00, ?it/s]

epoch = 84, custom_mae = 0.3078193962574005


Validating: 0it [00:00, ?it/s]

epoch = 85, custom_mae = 0.29521068930625916


Validating: 0it [00:00, ?it/s]

epoch = 86, custom_mae = 0.34312039613723755


Validating: 0it [00:00, ?it/s]

epoch = 87, custom_mae = 0.32837918400764465


Validating: 0it [00:00, ?it/s]

epoch = 88, custom_mae = 0.32461676001548767


Validating: 0it [00:00, ?it/s]

epoch = 89, custom_mae = 0.31406334042549133


Validating: 0it [00:00, ?it/s]

epoch = 90, custom_mae = 0.3344520926475525


Validating: 0it [00:00, ?it/s]

epoch = 91, custom_mae = 0.36580705642700195


Validating: 0it [00:00, ?it/s]

epoch = 92, custom_mae = 0.32439985871315


Validating: 0it [00:00, ?it/s]

epoch = 93, custom_mae = 0.3677029013633728


Validating: 0it [00:00, ?it/s]

epoch = 94, custom_mae = 0.34141722321510315


Validating: 0it [00:00, ?it/s]

epoch = 95, custom_mae = 0.4168119728565216


Validating: 0it [00:00, ?it/s]

epoch = 96, custom_mae = 0.3662966787815094


Validating: 0it [00:00, ?it/s]

epoch = 97, custom_mae = 0.4433158040046692


Validating: 0it [00:00, ?it/s]

epoch = 98, custom_mae = 0.3840024173259735


Validating: 0it [00:00, ?it/s]

epoch = 99, custom_mae = 0.39122632145881653


Validating: 0it [00:00, ?it/s]

epoch = 100, custom_mae = 0.37326404452323914


Validating: 0it [00:00, ?it/s]

epoch = 101, custom_mae = 0.35947108268737793


Validating: 0it [00:00, ?it/s]

epoch = 102, custom_mae = 0.3697090446949005


Validating: 0it [00:00, ?it/s]

epoch = 103, custom_mae = 0.3765827417373657


Validating: 0it [00:00, ?it/s]

epoch = 104, custom_mae = 0.3739912509918213


Validating: 0it [00:00, ?it/s]

epoch = 105, custom_mae = 0.34031856060028076


Validating: 0it [00:00, ?it/s]

epoch = 106, custom_mae = 0.39740633964538574


Validating: 0it [00:00, ?it/s]

epoch = 107, custom_mae = 0.3254444897174835


Validating: 0it [00:00, ?it/s]

epoch = 108, custom_mae = 0.3195509612560272


Validating: 0it [00:00, ?it/s]

epoch = 109, custom_mae = 0.3015612065792084


Validating: 0it [00:00, ?it/s]

epoch = 110, custom_mae = 0.30830487608909607


Validating: 0it [00:00, ?it/s]

epoch = 111, custom_mae = 0.41761621832847595


Validating: 0it [00:00, ?it/s]

epoch = 112, custom_mae = 0.3760630488395691


Validating: 0it [00:00, ?it/s]

epoch = 113, custom_mae = 0.2944341003894806


Validating: 0it [00:00, ?it/s]

epoch = 114, custom_mae = 0.2822716534137726


Validating: 0it [00:00, ?it/s]

epoch = 115, custom_mae = 0.26914364099502563


Validating: 0it [00:00, ?it/s]

epoch = 116, custom_mae = 0.2666206657886505


Validating: 0it [00:00, ?it/s]

epoch = 117, custom_mae = 0.2640042006969452


Validating: 0it [00:00, ?it/s]

epoch = 118, custom_mae = 0.25784215331077576


Validating: 0it [00:00, ?it/s]

epoch = 119, custom_mae = 0.256095826625824


Validating: 0it [00:00, ?it/s]

epoch = 120, custom_mae = 0.25283724069595337


Validating: 0it [00:00, ?it/s]

epoch = 121, custom_mae = 0.25078821182250977


Validating: 0it [00:00, ?it/s]

epoch = 122, custom_mae = 0.24890026450157166


Validating: 0it [00:00, ?it/s]

epoch = 123, custom_mae = 0.24878597259521484


Validating: 0it [00:00, ?it/s]

epoch = 124, custom_mae = 0.2480844408273697


Validating: 0it [00:00, ?it/s]

epoch = 125, custom_mae = 0.24800936877727509


Validating: 0it [00:00, ?it/s]

epoch = 126, custom_mae = 0.24800705909729004


Validating: 0it [00:00, ?it/s]

epoch = 127, custom_mae = 0.24795059859752655


Validating: 0it [00:00, ?it/s]

epoch = 128, custom_mae = 0.24834589660167694


Validating: 0it [00:00, ?it/s]

epoch = 129, custom_mae = 0.24883057177066803


Validating: 0it [00:00, ?it/s]

epoch = 130, custom_mae = 0.2490956038236618


Validating: 0it [00:00, ?it/s]

epoch = 131, custom_mae = 0.2524987459182739


Validating: 0it [00:00, ?it/s]

epoch = 132, custom_mae = 0.25440487265586853


Validating: 0it [00:00, ?it/s]

epoch = 133, custom_mae = 0.25891605019569397


Validating: 0it [00:00, ?it/s]

epoch = 134, custom_mae = 0.2589152455329895


Validating: 0it [00:00, ?it/s]

epoch = 135, custom_mae = 0.2558973431587219


Validating: 0it [00:00, ?it/s]

epoch = 136, custom_mae = 0.3049897253513336


Validating: 0it [00:00, ?it/s]

epoch = 137, custom_mae = 0.27459368109703064


Validating: 0it [00:00, ?it/s]

epoch = 138, custom_mae = 0.2838951647281647


Validating: 0it [00:00, ?it/s]

epoch = 139, custom_mae = 0.2728775441646576


Validating: 0it [00:00, ?it/s]

epoch = 140, custom_mae = 0.3442526161670685


Validating: 0it [00:00, ?it/s]

epoch = 141, custom_mae = 0.29978615045547485


Validating: 0it [00:00, ?it/s]

epoch = 142, custom_mae = 0.29590585827827454


Validating: 0it [00:00, ?it/s]

epoch = 143, custom_mae = 0.3528928756713867


Validating: 0it [00:00, ?it/s]

epoch = 144, custom_mae = 0.3138945400714874


Validating: 0it [00:00, ?it/s]

epoch = 145, custom_mae = 0.29792678356170654


Validating: 0it [00:00, ?it/s]

epoch = 146, custom_mae = 0.33080577850341797


Validating: 0it [00:00, ?it/s]

epoch = 147, custom_mae = 0.31069982051849365


Validating: 0it [00:00, ?it/s]

epoch = 148, custom_mae = 0.32132795453071594


Validating: 0it [00:00, ?it/s]

epoch = 149, custom_mae = 0.34104910492897034


Validating: 0it [00:00, ?it/s]

epoch = 150, custom_mae = 0.32819950580596924


Validating: 0it [00:00, ?it/s]

epoch = 151, custom_mae = 0.3139752745628357


Validating: 0it [00:00, ?it/s]

epoch = 152, custom_mae = 0.3046962022781372


Validating: 0it [00:00, ?it/s]

epoch = 153, custom_mae = 0.3069967031478882


Validating: 0it [00:00, ?it/s]

epoch = 154, custom_mae = 0.3082621991634369


Validating: 0it [00:00, ?it/s]

epoch = 155, custom_mae = 0.2772802710533142


Validating: 0it [00:00, ?it/s]

epoch = 156, custom_mae = 0.3136462867259979


Validating: 0it [00:00, ?it/s]

epoch = 157, custom_mae = 0.30411458015441895


Validating: 0it [00:00, ?it/s]

epoch = 158, custom_mae = 0.27756446599960327


Validating: 0it [00:00, ?it/s]

epoch = 159, custom_mae = 0.2921149730682373


Validating: 0it [00:00, ?it/s]

epoch = 160, custom_mae = 0.2793535888195038


Validating: 0it [00:00, ?it/s]

epoch = 161, custom_mae = 0.262950599193573


Validating: 0it [00:00, ?it/s]

epoch = 162, custom_mae = 0.258510947227478


Validating: 0it [00:00, ?it/s]

epoch = 163, custom_mae = 0.25027307868003845


Validating: 0it [00:00, ?it/s]

epoch = 164, custom_mae = 0.25109267234802246


Validating: 0it [00:00, ?it/s]

epoch = 165, custom_mae = 0.27174481749534607


Validating: 0it [00:00, ?it/s]

epoch = 166, custom_mae = 0.2539628744125366


Validating: 0it [00:00, ?it/s]

epoch = 167, custom_mae = 0.2438044399023056


Validating: 0it [00:00, ?it/s]

epoch = 168, custom_mae = 0.2399175614118576


Validating: 0it [00:00, ?it/s]

epoch = 169, custom_mae = 0.24024507403373718


Validating: 0it [00:00, ?it/s]

epoch = 170, custom_mae = 0.2378750592470169


Validating: 0it [00:00, ?it/s]

epoch = 171, custom_mae = 0.23655220866203308


Validating: 0it [00:00, ?it/s]

epoch = 172, custom_mae = 0.23551471531391144


Validating: 0it [00:00, ?it/s]

epoch = 173, custom_mae = 0.23504312336444855


Validating: 0it [00:00, ?it/s]

epoch = 174, custom_mae = 0.2349730283021927


Validating: 0it [00:00, ?it/s]

epoch = 175, custom_mae = 0.23490466177463531


Validating: 0it [00:00, ?it/s]

epoch = 176, custom_mae = 0.23490560054779053


Validating: 0it [00:00, ?it/s]

epoch = 177, custom_mae = 0.23473221063613892


Validating: 0it [00:00, ?it/s]

epoch = 178, custom_mae = 0.23571573197841644


Validating: 0it [00:00, ?it/s]

epoch = 179, custom_mae = 0.23486699163913727


Validating: 0it [00:00, ?it/s]

epoch = 180, custom_mae = 0.237295001745224


Validating: 0it [00:00, ?it/s]

epoch = 181, custom_mae = 0.23669621348381042


Validating: 0it [00:00, ?it/s]

epoch = 182, custom_mae = 0.23865117132663727


Validating: 0it [00:00, ?it/s]

epoch = 183, custom_mae = 0.24380433559417725


Validating: 0it [00:00, ?it/s]

epoch = 184, custom_mae = 0.2695484459400177


Validating: 0it [00:00, ?it/s]

epoch = 185, custom_mae = 0.24815469980239868


Validating: 0it [00:00, ?it/s]

epoch = 186, custom_mae = 0.24550969898700714


Validating: 0it [00:00, ?it/s]

epoch = 187, custom_mae = 0.25361090898513794


Validating: 0it [00:00, ?it/s]

epoch = 188, custom_mae = 0.2514781951904297


Validating: 0it [00:00, ?it/s]

epoch = 189, custom_mae = 0.2593199908733368


Validating: 0it [00:00, ?it/s]

epoch = 190, custom_mae = 0.3061671257019043


Validating: 0it [00:00, ?it/s]

epoch = 191, custom_mae = 0.3363674283027649


Validating: 0it [00:00, ?it/s]

epoch = 192, custom_mae = 0.27447059750556946


Validating: 0it [00:00, ?it/s]

epoch = 193, custom_mae = 0.2673512101173401


Validating: 0it [00:00, ?it/s]

epoch = 194, custom_mae = 0.2643401622772217


Validating: 0it [00:00, ?it/s]

epoch = 195, custom_mae = 0.2778829038143158


Validating: 0it [00:00, ?it/s]

epoch = 196, custom_mae = 0.3217642903327942


Validating: 0it [00:00, ?it/s]

epoch = 197, custom_mae = 0.3081640899181366


In [None]:
# Score the out-of-fold (OOF) predictions and write them to
# OUTPUT_DIR / f'oof{CFG.exp_num}.csv'.
#
# `oof_score` is bound on BOTH paths: the W&B summary cell later in the
# notebook logs `oof_score`, and previously the all-folds branch only defined
# `score`, which made that cell fail with NameError on a full-CV run.
if len(CFG.folds) != CFG.n_folds:
    # Partial run: CFG.folds does not list every one of the CFG.n_folds folds.
    # NOTE(review): oof_pred / oof_target / val_df are presumably left over
    # from the (single) trained fold -- confirm against the training loop.
    oof_score = get_score(oof_pred, oof_target, val_df['u_out'].values)
    print(f'MAE {oof_score}')

    # Rows of the first recorded validation split, first column only
    # (presumably `id` -- TODO confirm). `.copy()` makes this an independent
    # frame so adding a column below does not trigger SettingWithCopyWarning.
    oof_df = train.iloc[val_idxes[0], :1].copy()
    oof_df['pressure'] = oof_pred
else:
    # Full run: every fold was trained, so score the complete OOF array.
    # NOTE(review): argument order here is (y, oof_total) while the branch
    # above passes (oof_pred, oof_target); harmless only if get_score is
    # symmetric in its first two arguments -- verify.
    oof_score = get_score(y, oof_total, train['u_out'].values)
    score = oof_score  # keep the original name bound for any later cell that reads it
    print(f'MAE {score}: folds: {scores}')

    oof_df = pd.DataFrame({'id': train['id'].values, 'pressure': oof_total.reshape(-1)})

# Same destination for both paths.
oof_df.to_csv(OUTPUT_DIR / f'oof{CFG.exp_num}.csv', index=False)
oof_df

In [None]:
# Build the submission file: start from the sample submission and replace its
# `pressure` column with the mean of `sub_preds` along axis 1
# (presumably one column per fold/model -- TODO confirm the layout of sub_preds).
sub = (
    pd.read_csv(DATA_DIR / 'sample_submission.csv')
    .assign(pressure=np.mean(sub_preds, axis=1))
)
sub.to_csv(OUTPUT_DIR / f'sub{CFG.exp_num}.csv', index=False)
sub

In [None]:
# Record the cross-validation score in a separate W&B run (job_type 'summary')
# that is grouped with the other runs of this experiment via RUN_NAME.
summary_run = wandb.init(
    project='Ventilator-Pressure-Prediction',
    entity='sqrt4kaido',
    group=RUN_NAME,
    job_type='summary',
)
summary_run.name = 'summary'
summary_run.log({'CV_score': oof_score})
# wandb.save(utils.get_notebook_path())
summary_run.finish()