## Get env

In [3]:
# Show the GPU / driver state of the machine running this notebook.
!nvidia-smi

Wed Oct  6 22:09:13 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 471.11       Driver Version: 471.11       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:07:00.0  On |                  N/A |
|  0%   48C    P8    24W / 370W |   2925MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# Flags used to switch behaviour depending on where the notebook is running.
import sys
# Each flag is True when the corresponding module has already been loaded into this interpreter.
IN_COLAB = 'google.colab' in sys.modules
IN_KAGGLE = 'kaggle_web_client' in sys.modules
# Anything that is neither Kaggle nor Colab is treated as a local run.
LOCAL = not (IN_KAGGLE or IN_COLAB)
print(f'IN_COLAB:{IN_COLAB}, IN_KAGGLE:{IN_KAGGLE}, LOCAL:{LOCAL}')

IN_COLAB:False, IN_KAGGLE:False, LOCAL:True


In [5]:
# Colab only: mount Google Drive and install the Kaggle CLI credentials.
# ==================
if IN_COLAB:
    # mount googledrive
    from google.colab import drive
    drive.mount('/content/drive')
    # Reinstall the kaggle CLI, then copy kaggle.json from Google Drive into ~/.kaggle
    ! pip install --upgrade --force-reinstall --no-deps  kaggle > /dev/null
    ! mkdir ~/.kaggle
    ! cp "/content/drive/MyDrive/kaggle/kaggle.json" ~/.kaggle/
    # Restrict the credentials file to the current user.
    ! chmod 600 ~/.kaggle/kaggle.json
    
    # NOTE(review): the disabled download below targets the birdclef-2021 competition,
    # not the ventilator data used in the rest of this notebook, and it uses `os`
    # before the import cell has run — confirm before re-enabling.
    # if not os.path.exists("/content/input/train_short_audio"):
    #     !mkdir input
    #     !kaggle competitions download -c birdclef-2021
    #     !unzip /content/birdclef-2021.zip -d input

In [6]:
# Kaggle only: install the third-party packages imported later in this notebook.
# NOTE(review): no versions are pinned, so reruns may pick up different releases.
if IN_KAGGLE:
    !pip install --upgrade -q wandb
    !pip install -q pytorch-lightning
    !pip install torch_optimizer

## Import Libraries

In [7]:
# Hide Warning
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# Python Libraries
import os
import math
import random
import glob
import pickle
from collections import defaultdict
from pathlib import Path

# Third party
import numpy as np
import pandas as pd
from tqdm import tqdm

# Visualizations
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
%matplotlib inline
sns.set(style="whitegrid")

# Utilities and Metrics
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from sklearn.metrics import mean_absolute_error #[roc_auc_score, accuracy_score]

# Pytorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer, required
import torch_optimizer as optim

# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import Callback, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger, CSVLogger

# Weights and Biases Tool
import wandb
# Credentials must never be stored in the notebook. wandb.login() reads the
# WANDB_API_KEY environment variable (or a cached `wandb login` session) and
# otherwise prompts interactively, so export the key outside the notebook.
# The key that used to be hard-coded here has been exposed and should be revoked.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mteyosan1229[0m (use `wandb login --relogin` to force relogin)


True

## Config

In [8]:
class CFG:
    """Global experiment configuration, read as class attributes throughout the notebook."""
    debug = True
    competition='ventilator'
    exp_name = "test"
    seed = 29
    # model
    # NOTE(review): img_size is disabled here, but get_transforms still reads CFG.img_size.
    # img_size = 384
    
    # data
    target_col = 'pressure' # name of the column holding the target value
    target_size = 1
    
    # optimizer
    optimizer_name = 'RAdam'
    lr = 5e-3
    weight_decay = 1e-6
    amsgrad = False
    
    # scheduler
    # NOTE(review): get_scheduler also reads CFG.factor / CFG.patience / CFG.eps
    # (ReduceLROnPlateau) and CFG.T_0 (CosineAnnealingWarmRestarts); none of them
    # is defined in this class.
    epochs = 300
    scheduler = 'CosineAnnealingLR'
    T_max = 300  # overwritten inside train() before the scheduler is built
    min_lr = 1e-6
    # criterion
    # Loss that ignores the time steps where u_out == 1
    criterion_name = 'CustomLoss1'
    
    # training
    train = True
    inference = True
    n_fold = 5
    trn_fold = [0]  # folds that train() actually runs
    precision = 16 #[16, 32, 64]
    grad_acc = 1
    # DataLoader keyword arguments, unpacked into torch DataLoader by DataModule
    loader = {
        "train": {
            "batch_size": 256,
            "num_workers": 0,
            "shuffle": True,
            "pin_memory": True,
            "drop_last": True
        },
        "valid": {
            "batch_size": 256,
            "num_workers": 0,
            "shuffle": False,
            "pin_memory": True,
            "drop_last": False
        }
    }
    # Extra keyword arguments unpacked into pl.Trainer by train()
    trainer = {
        'gpus': 1,
        'progress_bar_refresh_rate': 1,
        'benchmark': False,
        'deterministic': True,
        }
    # LSTM
    num_layers = 4
    # Input feature columns; their count sets the input width of the model's first Linear layer.
    feature_cols = ['time_step', 'u_in', 'u_out'] + ['area'] + ['cross', 'cross2'] + ['u_in_cumsum', 'u_in_cummean'] + \
                   ['u_in_lag','u_in_lag2','u_in_lag3','u_in_lag_back','u_in_lag_back2','u_in_lag_back3'] + \
                   ['u_out_lag','u_out_lag2','u_out_lag3','u_out_lag_back','u_out_lag_back2','u_out_lag_back3'] + \
                   ['R_20', 'R_5', 'R_50', 'C_10', 'C_20', 'C_50', 'RC_2010', 'RC_2020', 'RC_2050', 'RC_5010', 'RC_5020', 'RC_5050', 'RC_510', 'RC_520', 'RC_550']
    
    # Layer widths used by CustomModel (MLP output, LSTM hidden state, head hidden layer)
    dense_dim = 512
    hidden_size = 512
    logit_dim = 512

In [9]:
# Number of input features fed to the model.
len(CFG.feature_cols)

35

## Directory & LoadData

In [10]:
# Choose input/output locations for the current environment.
# INPUT_DIR is a Path; OUTPUT_DIR is a plain string (it is concatenated with '+' in train()).
if IN_KAGGLE:
    INPUT_DIR = Path('../input/ventilator-pressure-prediction')
    OUTPUT_DIR = './'
elif IN_COLAB:
    INPUT_DIR = Path('/content/input/')
    OUTPUT_DIR = f'/content/drive/MyDrive/kaggle/ventilator-pressure-prediction/{CFG.exp_name}/'
# NOTE(review): absolute, machine-specific paths — adjust when running on another machine.
if LOCAL:
    INPUT_DIR = Path("F:/Kaggle/ventilator-pressure-prediction/data/input/")
    OUTPUT_DIR = f'F:/Kaggle/ventilator-pressure-prediction/data/output/{CFG.exp_name}/'
    
def load_datasets(feats):
    """Load the pre-computed feature files and join them column-wise.

    For every name in ``feats`` the files ``features/<name>_train.ftr`` and
    ``features/<name>_test.ftr`` under ``INPUT_DIR`` are read with
    ``pd.read_feather``; the frames of each split are concatenated along
    ``axis=1``.

    Returns:
        tuple: ``(X_train, X_test)`` DataFrames.
    """
    def _read_split(split):
        # Read one feather file per feature set, in the order given by `feats`.
        frames = []
        for name in feats:
            frames.append(pd.read_feather(INPUT_DIR / f'features/{name}_{split}.ftr'))
        return pd.concat(frames, axis=1)

    X_train = _read_split('train')
    X_test = _read_split('test')
    return X_train, X_test

# Feature sets to load; each name maps to a <name>_train.ftr / <name>_test.ftr pair.
feats = ['Base', 'Area', 'Cross', 'U_in_cumsum_mean', 'U_in_Lag', 'U_out_Lag', 'RC_OHE']
df_train, df_test = load_datasets(feats)

#df_train = pd.read_csv(INPUT_DIR / "train_v2.csv")
#df_test = pd.read_csv(INPUT_DIR / "test_v2.csv")
submission = pd.read_csv(INPUT_DIR / "sample_submission.csv")
display(df_train.head())
display(df_test.head())

# exist_ok=True replaces the check-then-create pattern: one call, no race, and
# rerunning the cell is harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

if CFG.debug:
    # Debug mode: short run on the first rows only.
    CFG.epochs = 5
    #CFG.inference = False
    #df_train = df_train.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)
    # head() keeps whole breaths together (the CV-split output shows 3000 groups of 80 rows).
    df_train = df_train.head(240000)

Unnamed: 0,id,breath_id,time_step,u_in,u_out,pressure,fold,area,cross,cross2,...,C_50,RC_2010,RC_2020,RC_2050,RC_5010,RC_5020,RC_5050,RC_510,RC_520,RC_550
0,1,1,0.0,0.080043,0,5.837492,4,0.0,0.0,0.0,...,1,0,0,1,0,0,0,0,0,0
1,2,1,0.033652,2.964399,0,5.907794,4,0.618632,0.0,0.0,...,1,0,0,1,0,0,0,0,0,0
2,3,1,0.067514,3.157395,0,7.876254,4,2.138333,0.0,0.0,...,1,0,0,1,0,0,0,0,0,0
3,4,1,0.101542,3.170056,0,11.742872,4,4.454391,0.0,0.0,...,1,0,0,1,0,0,0,0,0,0
4,5,1,0.135756,3.27169,0,12.234987,4,7.896588,0.0,0.0,...,1,0,0,1,0,0,0,0,0,0


Unnamed: 0,id,breath_id,time_step,u_in,u_out,area,cross,cross2,u_in_cumsum,u_in_cummean,...,C_50,RC_2010,RC_2020,RC_2050,RC_5010,RC_5020,RC_5050,RC_510,RC_520,RC_550
0,1,0,0.0,0.0,0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,1,0
1,2,0,0.031904,2.141835,0,0.239758,0.0,0.0,7.515046,3.757523,...,0,0,0,0,0,0,0,0,1,0
2,3,0,0.063827,2.750578,0,1.174935,0.0,0.0,22.166721,7.388907,...,0,0,0,0,0,0,0,0,1,0
3,4,0,0.095751,3.10147,0,3.207788,0.0,0.0,43.397331,10.849333,...,0,0,0,0,0,0,0,0,1,0
4,5,0,0.127644,3.307654,0,6.567489,0.0,0.0,69.718287,13.943657,...,0,0,0,0,0,0,0,0,1,0


## Utils

In [11]:
seed_everything(CFG.seed)
# Send a notification to LINE
import requests
def send_line_notification(message):
    """Post ``message`` to the LINE Notify endpoint as a best-effort notification.

    The access token is read from the ``LINE_NOTIFY_TOKEN`` environment
    variable so that no credential is stored in the notebook (the token that
    used to be hard-coded here has been exposed and should be revoked).

    Args:
        message: Anything that can be formatted into a string.

    Returns:
        None. If the token is missing, or the HTTP request fails, a short
        notice is printed and the function returns normally, so a
        notification problem cannot abort the cell that called it.
    """
    line_token = os.environ.get('LINE_NOTIFY_TOKEN')
    if not line_token:
        print('LINE_NOTIFY_TOKEN is not set; skipping LINE notification')
        return
    endpoint = 'https://notify-api.line.me/api/notify'
    payload = {'message': f"{message}"}
    headers = {'Authorization': 'Bearer {}'.format(line_token)}
    try:
        # A timeout keeps a stalled connection from hanging the notebook.
        requests.post(endpoint, data=payload, headers=headers, timeout=10)
    except requests.RequestException as err:
        print(f'LINE notification failed: {err}')

Global seed set to 29


## CV Split

In [12]:
# The 'fold' column is already present in the loaded features, so the GroupKFold
# assignment below (grouped by breath_id) is kept only for reference.
# df_train["fold"] = -1
# Fold = GroupKFold(n_splits=CFG.n_fold)
# for n, (train_index, val_index) in enumerate(Fold.split(df_train, df_train[CFG.target_col], groups=df_train.breath_id.values)):
#      df_train.loc[val_index, 'fold'] = int(n)
# df_train['fold'] = df_train['fold'].astype(int)
# Row count per (fold, breath_id) pair.
print(df_train.groupby(['fold', 'breath_id']).size())

fold  breath_id
0     4            80
      16           80
      18           80
      20           80
      23           80
                   ..
4     4858         80
      4867         80
      4874         80
      4888         80
      4907         80
Length: 3000, dtype: int64


## Transforms

In [13]:
def get_transforms(phase: str):
    """Build the image augmentation pipeline for ``phase``.

    ``'train'`` resizes, applies random transpose / flips / shift-scale-rotate,
    normalises and converts to a tensor; ``'valid'`` only resizes, normalises
    and converts. Any other phase yields ``None``.

    NOTE(review): this looks like a leftover from an image pipeline. The
    transform classes used here are not imported anywhere in the visible
    notebook, ``CFG.img_size`` is commented out, and no visible code calls
    this function — confirm before relying on it.
    """
    if phase not in ('train', 'valid'):
        return None

    steps = [Resize(CFG.img_size, CFG.img_size)]
    if phase == 'train':
        # Random geometric augmentations are applied during training only.
        steps.extend([
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
        ])
    steps.append(Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ))
    steps.append(ToTensorV2(p=1.0))
    return Compose(steps)


## Dataset

In [14]:
"""
breath_idで１まとまり？
"""
class TrainDataset(Dataset):
    """Dataset that yields one whole breath (all rows sharing a ``breath_id``) per item.

    Args:
        df: DataFrame containing ``breath_id``, ``u_out``, ``pressure`` and
            every column listed in ``CFG.feature_cols``.
        transform: Accepted for interface compatibility; it is not used.
    """
    def __init__(self, df, transform=None):
        self.df = df
        # GroupBy.groups maps each breath_id to the index *labels* of its rows.
        self.groups = df.groupby('breath_id').groups
        self.keys = list(self.groups.keys())

    def __len__(self):
        # One item per breath.
        return len(self.groups)

    def __getitem__(self, idx):
        """Return ``(x, u_out, label)`` tensors for the ``idx``-th breath."""
        row_labels = self.groups[self.keys[idx]]
        # The stored values are index labels, so select with .loc. The previous
        # positional .iloc only matched when the frame had a fresh 0..n-1 index
        # and would pick the wrong rows (or raise) otherwise.
        df = self.df.loc[row_labels]
        x = torch.FloatTensor(df[CFG.feature_cols].values)
        u_out = torch.LongTensor(df['u_out'].values)
        label = torch.FloatTensor(df['pressure'].values)
        return x, u_out, label
    
class TestDataset(Dataset):
    """Dataset that yields the feature tensor of one whole breath per item (no targets).

    Args:
        df: DataFrame containing ``breath_id`` and every column listed in
            ``CFG.feature_cols``.
    """
    def __init__(self, df):
        self.df = df
        # GroupBy.groups maps each breath_id to the index *labels* of its rows.
        self.groups = df.groupby('breath_id').groups
        self.keys = list(self.groups.keys())

    def __len__(self):
        # One item per breath.
        return len(self.groups)

    def __getitem__(self, idx):
        """Return the feature tensor for the ``idx``-th breath."""
        row_labels = self.groups[self.keys[idx]]
        # Label-based selection matches what GroupBy.groups stores; positional
        # .iloc was only correct for a fresh 0..n-1 index.
        df = self.df.loc[row_labels]
        x = torch.FloatTensor(df[CFG.feature_cols].values)
        return x

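Check what a single dataset item looks like.  
Each item holds the data for one `breath_id`, returned as several separate tensors (features, `u_out`, `pressure`).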

In [15]:
# Build the dataset on the full training frame and display its first item (one breath).
train_dataset = TrainDataset(df_train)
train_dataset[0]

(tensor([[0.0000, 0.0800, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0337, 2.9644, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0675, 3.1574, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [2.6218, 1.7887, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
         [2.6557, 1.7892, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
         [2.6898, 1.7896, 1.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([ 5.8375,  5.9078,  7.8763, 11.7429, 12.2350, 12.8677, 14.6956, 15.8907,
         15.5392, 15.7501, 17.2967, 17.2264, 16.1719, 17.3670, 18.0701, 17.1561,
         18.2810, 18.7731, 17.8592, 19.1246, 19.3355, 18.4919, 18.5622, 18.6325,
         18.8434, 19.0543, 19.2652, 19.3355, 19.3355, 19.4761, 19.5464, 1

In [17]:
# Raw rows of breath_id 1, for comparison with the tensors shown above.
df_train[df_train["breath_id"] ==1 ]

Unnamed: 0,id,breath_id,time_step,u_in,u_out,pressure,fold,area,cross,cross2,...,C_50,RC_2010,RC_2020,RC_2050,RC_5010,RC_5020,RC_5050,RC_510,RC_520,RC_550
0,1,1,0.000000,0.080043,0,5.837492,4,0.000000,0.000000,0.000000,...,1,0,0,1,0,0,0,0,0,0
1,2,1,0.033652,2.964399,0,5.907794,4,0.618632,0.000000,0.000000,...,1,0,0,1,0,0,0,0,0,0
2,3,1,0.067514,3.157395,0,7.876254,4,2.138333,0.000000,0.000000,...,1,0,0,1,0,0,0,0,0,0
3,4,1,0.101542,3.170056,0,11.742872,4,4.454391,0.000000,0.000000,...,1,0,0,1,0,0,0,0,0,0
4,5,1,0.135756,3.271690,0,12.234987,4,7.896588,0.000000,0.000000,...,1,0,0,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75,76,1,2.553593,1.787496,1,6.399909,4,585.651356,4.974474,2.553593,...,1,0,0,1,0,0,0,0,0,0
76,77,1,2.587754,1.788167,1,6.610815,4,598.534439,4.978481,2.587754,...,1,0,0,1,0,0,0,0,0,0
77,78,1,2.621773,1.788729,1,6.329607,4,611.595714,4.981847,2.621773,...,1,0,0,1,0,0,0,0,0,0
78,79,1,2.655746,1.789203,1,6.540513,4,624.833765,4.984683,2.655746,...,1,0,0,1,0,0,0,0,0,0


## DataModule

In [18]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule that wraps the train / valid / test DataFrames.

    Args:
        train_data, valid_data: Frames wrapped in ``TrainDataset`` by ``setup``.
        test_data: Frame wrapped in ``TestDataset`` by ``setup``.
        cfg: Configuration whose ``loader`` dict holds the DataLoader keyword
            arguments under the ``'train'`` and ``'valid'`` keys.
    """

    def __init__(self, train_data, valid_data, test_data, cfg):
        super().__init__()
        self.train_data, self.valid_data, self.test_data = train_data, valid_data, test_data
        self.cfg = cfg

    def setup(self, stage=None):
        """Create all three datasets; ``stage`` is ignored."""
        self.train_dataset = TrainDataset(self.train_data)
        self.valid_dataset = TrainDataset(self.valid_data)
        self.test_dataset = TestDataset(self.test_data)

    def _make_loader(self, dataset, split):
        # Shared constructor: `split` picks the keyword-argument set from cfg.loader.
        return DataLoader(dataset, **self.cfg.loader[split])

    def train_dataloader(self):
        """Loader over the training breaths (used by Trainer.fit)."""
        return self._make_loader(self.train_dataset, 'train')

    def val_dataloader(self):
        """Loader over the validation breaths (used by Trainer.fit)."""
        return self._make_loader(self.valid_dataset, 'valid')

    def test_dataloader(self):
        """Loader over the test breaths; reuses the 'valid' loader settings."""
        return self._make_loader(self.test_dataset, 'valid')

## Pytorch Lightning Module

In [19]:
# The loop `for n, m in model.named_modules()` was unclear, so inspect what it yields

# model = CustomModel(CFG)
# for n, m in model.named_modules():
#     print(f"n:{n}")
#     print(f"m:{m}")
#     if isinstance(m, nn.LSTM):
#         print(f'isinstance nn.LSTM')
#     print('='*50)

In [23]:
#model = CustomModel(CFG)
# Smoke test: build the loaders on the training frame and pull a single batch.
Data = DataModule(df_train,df_train,df_train,CFG)
Data.setup()
loader = Data.train_dataloader()
tmp = iter(loader)
# Use the built-in next(): the iterator's .next() alias is not part of the
# iterator protocol and is absent from newer PyTorch releases.
next(tmp)
# cate_seq_x, cont_seq_x, u_out, label = next(tmp)
# print(cate_seq_x.shape, cont_seq_x.shape, u_out.shape, label.shape)
# output = model(cate_seq_x,cont_seq_x)
# output.shape

[tensor([[[0.0000, 2.2962, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0320, 4.6151, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0640, 4.5983, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [2.4622, 1.7850, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
          [2.4942, 1.7860, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
          [2.5261, 1.7869, 1.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0318, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0647, 0.4595, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [2.4755, 1.7857, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
          [2.5085, 1.7867, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
          [2.5404, 1.7874, 1.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
         [[0.0000, 1.5478, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0320, 3.2211, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0639, 3.1414, 0.0000,  ...,

In [21]:
# ====================================================
# model
# ====================================================
class CustomModel(nn.Module):
    """Per-time-step regressor: MLP -> bidirectional LSTM -> linear head.

    Args:
        cfg: Configuration providing ``feature_cols``, ``dense_dim``,
            ``hidden_size``, ``num_layers``, ``logit_dim`` and ``target_size``.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.dense_dim = cfg.dense_dim
        self.hidden_size = cfg.hidden_size
        self.num_layers = cfg.num_layers
        self.logit_dim = cfg.logit_dim

        # Feature projection applied independently to every time step.
        half_dense = self.dense_dim // 2
        self.mlp = nn.Sequential(
            nn.Linear(len(cfg.feature_cols), half_dense),
            nn.ReLU(),
            nn.Linear(half_dense, self.dense_dim),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(
            self.dense_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            dropout=0.2,
            batch_first=True,
            bidirectional=True,
        )
        # The bidirectional LSTM emits 2 * hidden_size features per step.
        self.head = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.logit_dim),
            nn.ReLU(),
            nn.Linear(self.logit_dim, cfg.target_size),
        )
        self._init_recurrent_weights()

    def _init_recurrent_weights(self):
        # Recurrent layers (LSTM / GRU): orthogonal init for parameters with two
        # or more dimensions, standard-normal init for the 1-D ones.
        for _, module in self.named_modules():
            if not isinstance(module, (nn.LSTM, nn.GRU)):
                continue
            print(f'init {module}')
            for param in module.parameters():
                if param.dim() >= 2:
                    nn.init.orthogonal_(param.data)
                else:
                    nn.init.normal_(param.data)

    def forward(self, x):
        """Map ``x`` to one row of predictions per batch element.

        The head output is reshaped with ``view(batch_size, -1)``.
        """
        batch_size = x.size(0)
        hidden, _ = self.lstm(self.mlp(x))
        return self.head(hidden).view(batch_size, -1)
    
def get_model(cfg):
    """Instantiate the network for the given configuration."""
    return CustomModel(cfg)

# ====================================================
# criterion
# ====================================================
def compute_metric(df, preds):
    """Masked mean absolute error between ``df['pressure']`` and ``preds``.

    Rows are weighted by ``1 - u_out``, so only rows with ``u_out == 0``
    contribute to the sum and to the normaliser.

    Args:
        df: DataFrame with ``pressure`` and ``u_out`` columns.
        preds: NumPy array with the same shape as the ``pressure`` column.

    Returns:
        The weighted MAE as a NumPy scalar.
    """
    target = np.array(df['pressure'].values.tolist())
    mask = 1 - np.array(df['u_out'].values.tolist())

    shapes = (target.shape, preds.shape, mask.shape)
    assert target.shape == preds.shape and mask.shape == target.shape, shapes

    weighted_abs_err = mask * np.abs(target - preds)
    return weighted_abs_err.sum() / mask.sum()


class VentilatorLoss(nn.Module):
    """
    Directly optimizes the competition metric: per-sequence mean absolute
    error over the time steps where ``u_out == 0``.
    """
    def forward(self, preds, y, u_out):
        """Compute the masked MAE for every sequence in the batch.

        Implemented as ``forward`` (not ``__call__``) so the usual
        ``nn.Module`` call machinery and hooks stay intact; calling the module
        as ``loss(preds, y, u_out)`` works as before.

        Args:
            preds: Predictions, last dimension is time.
            y: Targets with the same shape as ``preds``.
            u_out: 0/1 indicator with the same shape; steps with 1 are ignored.

        Returns:
            Tensor with the last dimension reduced (one value per sequence).
        """
        w = 1 - u_out
        mae = w * (y - preds).abs()
        # A sequence without any scored step would give 0 / 0 = NaN and poison
        # the batch mean; clamping the count makes it contribute 0 instead.
        mae = mae.sum(-1) / w.sum(-1).clamp(min=1)

        return mae

def get_criterion():
    """Build the loss module selected by ``CFG.criterion_name``.

    Supported names: ``'BCEWithLogitsLoss'``, ``'CrossEntropyLoss'`` and
    ``'CustomLoss1'`` (the masked MAE ``VentilatorLoss``).

    Raises:
        NotImplementedError: For any other name.
    """
    # A single if/elif chain: with independent `if` statements the final `else`
    # belonged only to the last test, so the first two names always raised.
    if CFG.criterion_name == 'BCEWithLogitsLoss':
        # No .to(device) is needed when running under PyTorch Lightning
        criterion = nn.BCEWithLogitsLoss(reduction="mean")
    elif CFG.criterion_name == 'CrossEntropyLoss':
        criterion = nn.CrossEntropyLoss()
    elif CFG.criterion_name == 'CustomLoss1':
        # [reference]https://www.kaggle.com/theoviel/deep-learning-starter-simple-lstm
        criterion = VentilatorLoss()
    else:
        raise NotImplementedError(f"unknown criterion_name: {CFG.criterion_name!r}")
    return criterion
# ====================================================
# optimizer
# ====================================================
def get_optimizer(model: nn.Module, config: dict):
    """Create the optimizer named by ``config.optimizer_name``.

    Args:
        model: Module whose ``parameters()`` are optimised.
        config: Object exposing ``optimizer_name``, ``lr`` and
            ``weight_decay`` as attributes (plus ``amsgrad`` for Adam).

    Returns:
        An ``Adam``, ``RAdam`` (torch_optimizer) or Nesterov ``SGD`` instance.

    Raises:
        NotImplementedError: If the name is not 'Adam', 'RAdam' or 'sgd'.
    """
    name = config.optimizer_name

    if name == 'Adam':
        return Adam(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            amsgrad=config.amsgrad,
        )

    if name == 'RAdam':
        return optim.RAdam(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
        )

    if name == 'sgd':
        return SGD(
            model.parameters(),
            lr=config.lr,
            momentum=0.9,
            nesterov=True,
            weight_decay=config.weight_decay,
        )

    raise NotImplementedError

# ====================================================
# scheduler
# ====================================================
def get_scheduler(optimizer):
    """Create the learning-rate scheduler named by ``CFG.scheduler``.

    Supported names: 'ReduceLROnPlateau', 'CosineAnnealingLR' and
    'CosineAnnealingWarmRestarts'.

    Raises:
        NotImplementedError: For any other name.
    """
    name = CFG.scheduler

    if name == 'ReduceLROnPlateau':
        # factor   : multiplicative decay applied to the learning rate
        # patience : number of steps without improvement before decaying
        # eps      : small value used to avoid NaN / Inf
        return ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)

    if name == 'CosineAnnealingLR':
        # T_max   : number of steps in one half period
        # eta_min : minimum learning rate
        return CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)

    if name == 'CosineAnnealingWarmRestarts':
        # T_0    : number of iterations in the first cycle
        # T_mult : scale factor applied to the cycle length
        return CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)

    raise NotImplementedError

In [19]:
# # modelの動作確認
# model = get_model(CFG)
# Data = DataModule(df_train,df_train,df_train, CFG)
# Data.setup()
# loader = Data.train_dataloader()
# tmp = loader.__iter__()
# x, u_out, label = tmp.next()
# print(x.shape, u_out.shape, label.shape)
# output = model(x)
# print(output.shape)
# del model, Data, loader, tmp,x, u_out, label, output

In [20]:
# #schedulerの確認
# model = get_model(CFG)
# optimizer = get_optimizer(model, CFG)
# scheduler = get_scheduler(optimizer)
# from pylab import rcParams
# lrs = []
# for epoch in range(1, CFG.epochs+1):
#     scheduler.step(epoch-1)
#     lrs.append(optimizer.param_groups[0]["lr"])
# rcParams['figure.figsize'] = 20,3
# print(lrs)
# plt.plot(lrs)

In [21]:
class Trainer(pl.LightningModule):
    """LightningModule bundling the model, loss, optimizer and LR scheduler.

    Train / validation batches are ``(x, u_out, y)`` tuples as produced by
    ``TrainDataset``; predict / test batches are the feature tensor alone.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = get_model(cfg)
        self.criterion = get_criterion()

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        output = self.model(x)
        return output

    def training_step(self, batch, batch_idx):
        """Compute and log the batch-mean training loss."""
        x, u_out, y = batch
        # Batch-level augmentation such as mixup would be inserted here.
        output = self.forward(x)
        labels = y
        loss = self.criterion(output, labels, u_out).mean()
        self.log('train_loss', loss, on_step=True, prog_bar=True, logger=True)
        # Only "loss" needs its graph; detach the predictions so the collected
        # step outputs do not carry autograd state.
        return {"loss": loss, "predictions": output.detach(), "labels": labels}

    def training_epoch_end(self, outputs):
        """Log the learning rate of the first parameter group.

        ``outputs`` holds the collected ``training_step`` results; it is not used.
        """
        self.log("lr", self.optimizer.param_groups[0]['lr'], prog_bar=True, logger=True)

    def validation_step(self, batch, batch_idx):
        """Compute and log the batch-mean validation loss."""
        x, u_out, y = batch
        output = self.forward(x)
        labels = y
        loss = self.criterion(output, labels, u_out).mean()
        self.log('val_loss', loss, on_step= True, prog_bar=True, logger=True)
        return {"predictions": output,
                "labels": labels,
                "loss": loss.item()}

    def validation_epoch_end(self, outputs):
        """Log the unweighted mean of the per-batch validation losses."""
        if not outputs:
            # Nothing was validated; avoid dividing by zero.
            return
        # The stacked prediction / label tensors that used to be built here were
        # never read, so only the loss is aggregated.
        loss = sum(output['loss'] for output in outputs) / len(outputs)

        self.log("val_loss_epoch", loss, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_idx):
        """Return model outputs for a features-only batch."""
        x = batch
        output = self.forward(x)
        return output

    def test_step(self, batch, batch_idx):
        """Return model outputs for a features-only batch."""
        x = batch
        output = self.forward(x)
        return output

    def configure_optimizers(self):
        """Build the optimizer and a scheduler that is stepped every optimizer step."""
        self.optimizer = get_optimizer(self, self.cfg)
        self.scheduler = {'scheduler': get_scheduler(self.optimizer),
                          'interval': 'step', # or 'epoch'
                          'frequency': 1}
        return {'optimizer': self.optimizer, 'lr_scheduler': self.scheduler}

## Train

In [22]:
def train() -> None:
    """Train one model per fold listed in ``CFG.trn_fold`` and optionally predict the test set.

    For each selected fold this builds the checkpoint callback, W&B logger and
    DataModule, fits the model, saves the final weights, reloads the best
    checkpoint and, when ``CFG.inference`` is set, writes a per-fold submission
    CSV. It reads the module-level ``df_train``, ``df_test``, ``submission`` and
    ``OUTPUT_DIR`` and overwrites ``CFG.T_max``.
    """
    for fold in range(CFG.n_fold):
        # Skip folds that were not requested.
        if not fold in CFG.trn_fold:
            continue
        print(f"{'='*38} Fold: {fold} {'='*38}")
        # Logger
        #======================================================
        # NOTE(review): lr_monitor is created but not passed to the trainer's callbacks below.
        lr_monitor = LearningRateMonitor(logging_interval='step')
        # Required to save the trained weights
        loss_checkpoint = ModelCheckpoint(
            dirpath=OUTPUT_DIR,
            filename=f"best_loss_fold{fold}",
            monitor="val_loss",
            save_last=True,
            save_top_k=1,
            save_weights_only=True,
            mode="min",
        )
        
        wandb_logger = WandbLogger(
            project=f'{CFG.competition}',
            group= f'{CFG.exp_name}',
            name = f'Fold{fold}',
            save_dir=OUTPUT_DIR
        )
        
        # Rows of the current fold are the validation split, all others the training split;
        # the index is reset on both.
        data_module = DataModule(
            df_train[df_train['fold']!=fold].reset_index(drop=True),
            df_train[df_train['fold']==fold].reset_index(drop=True), 
            df_test, 
            CFG
        )
        data_module.setup()
        
        # The scheduler is stepped per optimizer step, so T_max is set to the
        # total number of optimizer steps over all epochs.
        CFG.T_max = int(math.ceil(len(data_module.train_dataloader())/CFG.grad_acc)*CFG.epochs)
        print(f"set schedular T_max {CFG.T_max}")
        #early_stopping_callback = EarlyStopping(monitor='val_loss_epoch',mode="min", patience=5)
        
        trainer = pl.Trainer(
            logger=wandb_logger,
            callbacks=[loss_checkpoint],#lr_monitor,early_stopping_callback
            default_root_dir=OUTPUT_DIR,
            accumulate_grad_batches=CFG.grad_acc,
            max_epochs=CFG.epochs,
            precision=CFG.precision,
            **CFG.trainer
        )
        # Training
        model = Trainer(CFG)
        trainer.fit(model, data_module)
        # Save the weights of the wrapped network as they are at the end of training.
        torch.save(model.model.state_dict(),OUTPUT_DIR + '/' + f'{CFG.exp_name}_fold{fold}.pth')
        # Load the model with the best monitored loss
        best_model = Trainer.load_from_checkpoint(cfg=CFG,checkpoint_path=loss_checkpoint.best_model_path)
        # Predict the test data and save the result
        if CFG.inference:
            predictions = trainer.predict(best_model, data_module.test_dataloader())
            preds = []
            # `+=` extends the list with the rows of each batch tensor (one row per breath).
            for p in predictions:
                preds += p
            # Flatten to one value per row of the submission frame.
            preds = torch.stack(preds).flatten()
            submission['pressure'] = preds.to('cpu').detach().numpy()
            submission.to_csv(OUTPUT_DIR + '/' + f'submission_fold{fold}.csv',index=False)
        
        # Close the W&B run of this fold.
        wandb.finish()

        
        

In [23]:
# Train the configured folds, then send a completion notice and close any open W&B run.
train()
send_line_notification("[locl]finished")
wandb.finish()



Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


set schedular T_max 70500
init LSTM(512, 512, num_layers=4, batch_first=True, dropout=0.2, bidirectional=True)


  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: wandb version 0.12.4 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name      | Type           | Params
---------------------------------------------
0 | model     | CustomModel    | 23.8 M
1 | criterion | VentilatorLoss | 0     
---------------------------------------------
23.8 M    Trainable params
0         Non-trainable params
23.8 M    Total params
95.070    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 29


Training: -1it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."


init LSTM(512, 512, num_layers=4, batch_first=True, dropout=0.2, bidirectional=True)


  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 235it [00:00, ?it/s]

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,0.64871
epoch,299.0
trainer/global_step,70499.0
_step,19709.0
_runtime,26357.0
_timestamp,1633498986.0
val_loss_step,0.61183
val_loss_epoch,0.62068
lr,0.0


0,1
train_loss,▅▄▂▂▂▄█▇▇▄▄▅▅▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▄▂▂▂▂▂▂▂▂▂▂▃▃▃▃█▃▃▃▃▃▃▃▃▃
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss_step,▃▃▂▂▃▆█▆▅▆▅▅▅▄▄▄▄▄▂▂▃▃▂▃▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁
val_loss_epoch,▃▃▂▂▂▆█▇▆▅▆▅▅▄▄▃▃▄▅▃▃▂▂▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
lr,███████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁


In [24]:
# print( "="*20,"last weight", "="*20)
#         print(model.state_dict()['model.seq_emb.0.weight'][0])
#         print( "="*20,"best weight", "="*20)
# model = Trainer(CFG)
# fold=0
# model = Trainer.load_from_checkpoint(checkpoint_path=OUTPUT_DIR + '/' + f'best_loss_fold{fold}.ckpt',cfg=CFG)
# model.state_dict()['model.seq_emb.0.weight'][0]

+ **TODO**: なぜ TestDataset では `u_out` を使っていないのか？
+ **TODO**: `hidden_size` がモデルのどの部分を指しているのか、まだよくわかっていない

In [25]:
# class Trainer(pl.LightningModule):
#     def training_step(self, batch, batch_idx):
#         self.log('train_loss', loss)
#         return {"loss": loss, "predictions": output, "labels": labels}
    
#     def training_epoch_end(self, outputs):
#         pass
    
#     def validation_step(self, batch, batch_idx):
#         self.log('val_loss', loss, on_step= True, prog_bar=True, logger=True)
#         return {"predictions": output,"labels": labels,"loss": loss.item()}

#     def validation_epoch_end(self, outputs):
#         self.log("val_loss_epoch", loss, prog_bar=True, logger=True)