# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path


sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 19
IS_SAVE = True

utils.set_seed(SEED)

In [5]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [6]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]

In [7]:
train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_all.csv')

In [8]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [9]:
# train_df[RSSI_FEATS] = train_df[RSSI_FEATS] * -1
# test_df[RSSI_FEATS] = test_df[RSSI_FEATS] * -1

In [10]:
train_df.iloc[:, 100:110]

Unnamed: 0,rssi_0,rssi_1,rssi_2,rssi_3,rssi_4,rssi_5,rssi_6,rssi_7,rssi_8,rssi_9
0,-32,-39,-47,-48,-48,-49,-51,-52,-54,-56
1,-29,-34,-47,-48,-48,-49,-52,-52,-52,-53
2,-33,-39,-48,-48,-49,-52,-54,-55,-55,-55
3,-46,-48,-49,-50,-51,-52,-54,-56,-57,-57
4,-42,-49,-51,-51,-52,-53,-54,-55,-55,-55
...,...,...,...,...,...,...,...,...,...,...
258120,-53,-63,-64,-66,-68,-68,-68,-68,-70,-71
258121,-58,-64,-66,-67,-68,-68,-69,-70,-71,-71
258122,-57,-58,-60,-64,-66,-67,-68,-69,-71,-73
258123,-58,-64,-66,-66,-68,-69,-69,-71,-71,-72


bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [20]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 59330
BSSID TYPES(test): 31266
BSSID TYPES(all): 90596


In [21]:
# get numbers of bssids to embed them in a layer

# train
rssi_bssids = []
# bssidを列ごとにリストに入れていく
for i in RSSI_FEATS:
    rssi_bssids.extend(train_df.loc[:,i].values.tolist())
rssi_bssids = list(set(rssi_bssids))

train_rssi_bssids_size = len(rssi_bssids)
print(f'BSSID TYPES(train): {train_rssi_bssids_size}')

# test
rssi_bssids_test = []
for i in RSSI_FEATS:
    rssi_bssids_test.extend(test_df.loc[:,i].values.tolist())
rssi_bssids_test = list(set(rssi_bssids_test))

test_rssi_bssids_size = len(rssi_bssids_test)
print(f'BSSID TYPES(test): {test_rssi_bssids_size}')


rssi_bssids.extend(rssi_bssids_test)
rssi_bssids_size = len(rssi_bssids)
print(f'BSSID TYPES(all): {rssi_bssids_size}')

BSSID TYPES(train): 96
BSSID TYPES(test): 77
BSSID TYPES(all): 173


## preprocessing

In [22]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])
le_rssi = LabelEncoder()
le_rssi.fit(rssi_bssids)


def preprocess(input_df):
    output_df = input_df.copy()
    # RSSIの正規化
#     output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
        
    for i in RSSI_FEATS:
        output_df.loc[:,i] = le_rssi.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,rssi_95,rssi_96,rssi_97,rssi_98,rssi_99,x,y,floor,path,site_id
0,50853,34759,2666,33807,51157,34159,41463,32473,22694,14753,...,-79,-79,-79,-79,-79,107.85044,161.892620,-1,5e1580adf4c3420006d520d4,0
1,34759,50853,7246,33807,51157,34159,21287,14753,16476,5179,...,-79,-79,-79,-80,-80,107.85044,161.892620,-1,5e1580adf4c3420006d520d4,0
2,34759,50853,51157,33807,34159,22694,47981,6460,7246,47101,...,-77,-78,-78,-78,-78,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,0
3,22694,33807,34159,50853,34759,3578,47981,15109,9832,4816,...,-75,-76,-76,-77,-77,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,0
4,34759,34159,22694,18860,50853,3578,47981,17722,20739,51241,...,-75,-76,-76,-77,-77,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,33969,14080,15962,20656,49000,22127,911,29499,42529,31676,...,-84,-85,-85,-85,-85,122.68994,124.028015,6,5dcd5c88a4dbe7000630b084,23
258121,33969,15962,20656,911,49000,22127,29136,32337,57808,14080,...,-85,-85,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,23
258122,911,33969,14080,15962,20656,22127,49000,29136,47077,57808,...,-84,-84,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,23
258123,14080,15962,33969,20656,49000,911,29136,47077,57049,22127,...,-85,-85,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,23


In [23]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [24]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(int)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [31]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)
        
        self.rssi_embedding = nn.Embedding(173, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
#         self.rssi = nn.Sequential(
#             nn.BatchNorm1d(NUM_FEATS),
#             nn.Linear(NUM_FEATS, NUM_FEATS * 64)
#         )
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi_embedding(x['RSSI_FEATS'])
        x_rssi = self.flatten(x_rssi)

        x = torch.cat([x_bssid, x_site_id, x_rssi], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [32]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [33]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [34]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [35]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [36]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.343    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 133.49254608154297


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 161.4971921939661


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 157.30697641069415


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 153.23355096451542


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 149.24403000713298


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 145.33293121560064


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 141.5003811736154


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 137.74304005902246


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 134.07235908832672


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 130.49044457428388


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 127.00500226304638


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 123.62091896297075


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 120.34110881549923


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 117.17040817407376


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 114.11065803159026


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 111.17131194219661


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 108.36473338843086


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 105.68999290706255


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 103.15406453798592


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 100.76390684715454


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 98.5189807279539


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 93.23916959959195


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 88.25155560412931


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 84.3478468882628


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 81.03264224177401


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 78.13628761495903


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 75.5612976792946


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 73.24346340741539


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 71.06406325600769


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 69.02920645281125


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 66.79427711817853


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 64.95539317513735


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 63.30154851230272


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 61.76537254044672


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 60.33896912394253


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 59.040390236065285


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 57.802847399969885


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 56.65121558940889


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 55.58429323519671


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 54.57647657559927


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 53.53611399837754


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 52.46474892844554


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 51.37975659507613


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 50.24752014164599


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 49.13778724814822


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 48.00752351917496


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 46.76199064171446


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 45.53562609824179


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 44.33962942559982


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 42.95419970368534


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 41.38170096000267


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 39.9501524829809


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 38.6202972273585


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 37.333166562259784


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 36.10171868646985


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 34.88703799033901


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 33.66773687289354


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 32.48800691499083


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 31.403928486796822


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 30.282240942931132


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 29.24033082537687


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 28.179505168457123


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 27.209254703132927


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 26.237806232154682


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 25.311874276062884


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 24.4203117408708


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 23.543247872787504


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 22.71976616438975


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 21.952948141314643


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 21.152941993409005


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 20.42981351381412


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 19.79128632457158


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 19.074464343187067


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 18.434476090721702


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 17.858768584429697


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 17.271409065822006


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 16.761245621685102


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 16.208646919714827


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 15.71683145142046


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 15.261333580050415


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 14.852817708283714


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 14.375311348903436


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 13.984631117263644


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 13.55801707978157


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 13.25147883572409


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 12.963312437149698


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 12.639728289155956


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 12.365000162130189


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 12.08849546530251


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 11.77641783233847


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 11.583190649129238


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 11.339084913768113


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 11.111227907115245


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 10.88893701110111


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 10.761215528311533


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 10.641870315887237


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 10.406972663803856


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 10.271111698628305


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 10.107173029033682


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 10.042653908982274


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 9.923464606523375


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 9.812871130830546


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 9.727910519701558


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 9.620232165874821


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 9.527599021398403


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 9.49512932251412


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 9.364724238122974


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 9.331014905417458


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 9.296615391588517


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 9.176088030942317


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 9.109360396130251


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.051464507825582


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.031233056923229


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 8.977331592446507


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 8.91316533389172


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 8.907489333810958


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 8.849850903734886


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 8.81510174040886


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 8.828419389030444


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 8.803221223816736


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 8.722579557366613


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 8.699467593422447


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 8.6762519688348


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 8.649202501678023


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 8.578079264068158


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 8.556974606811133


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 8.537588849264866


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 8.515220131027053


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 8.554900429003032


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.528995106385485


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.460335122243428


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.478726777036174


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.443244963919762


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.431466318184926


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.479213399925966


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.421787745642954


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.373051895303787


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.390728831360628


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.397806242796733


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.37083391255116


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.390270717045423


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.343175251082618


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.331903327748606


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.324740523292217


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.33594384672757


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.30937990258583


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.307652478215301


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.272074726284695


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.258226731921882


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.259771630659865


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.252333149362874


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.245444396957105


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.240445873136082


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.246986447880278


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.236629328697173


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.225175256912419


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.22770489777669


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.223291690584249


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.20751694549845


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.228692965965738


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.218248387749416


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.207867431940558


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.212958867168037


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.21541098410818


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.21210068928758


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.210569622281962


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.219771035815924


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.210423393616097


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.209256078579013


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.212106012078864


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.220317500619327


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.220556486070885


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.221785973075106


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.216625430140999


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.207416377225863


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.225591837317012


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.219899436637421


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.225992095871726


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.223549451711364


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.215420225162317


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.210736535116206


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 8.21860611449582


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 8.213561361787203


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 8.217604087669626


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 8.210039609121457


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 8.21648904752648


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.213210860792815


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 8.209675430671101


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 8.2194569823099


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 8.213458728990238


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 8.208866863484014


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 8.220405889454454


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 8.225755043729627


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 8.208445424772517


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 8.215563591438876


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 8.223887524957496


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 8.214514996867994


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 8.22270765235692


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 8.215936925218212


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 8.209458064888548


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 8.209997205817567
fold 0: mean position error 8.208222630435468
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,56.61498
Loss/xy,56.61498
Loss/floor,5.38545
MPE/val,8.20822
epoch,199.0
trainer/global_step,80799.0
_runtime,913.0
_timestamp,1617988963.0
_step,199.0


0,1
Loss/val,█▇▆▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▆▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁████████████████████████████████████
MPE/val,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.343    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 156.68993377685547


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 161.91706987322107


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 159.322196561339


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 156.77563962542044


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 154.26074043091745


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 151.77212616606653


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 149.30611598179794


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 146.8620478103629


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 144.44126328154505


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 142.04554265061714


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 139.67357980913664


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 137.32689753155273


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 135.0059062489869


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 132.71216179369492


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 130.4468402232024


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 128.2118481404433


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 126.0094178931048


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 123.8423031514456


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 121.71072527548291


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 119.61773687443763


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 117.56659021034307


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 115.5599510520353


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 113.59892968255001


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 111.6816646736336


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 109.81181056996314


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 107.98858261757142


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 106.21381756173224


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 104.48878372681162


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 102.81483369386258


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 101.19296707979802


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 99.62419110955231


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 98.09854743744235


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 92.84245118413891


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 89.86737711724808


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 87.39602587180563


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 85.17913341226728


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 83.12057209716957


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 81.19902971251037


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 78.45660287025021


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 76.28316597957978


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 73.03348856906524


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 70.68540737909707


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 68.63406953468389


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 66.77679202386926


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 65.07946246307675


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 63.526677052159606


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 62.035643927496956


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 60.673896631808496


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 59.37716148363135


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 57.66960878970947


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 56.28274774670253


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 54.48145592337927


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 53.0171337618464


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 51.6118342849607


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 50.24377797874935


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 48.97434342783693


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 47.770147731761014


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 46.53096929112491


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 45.33638894903709


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 44.180299106880064


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 43.03181849527998


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 41.88854762085663


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 40.55002540793888


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 39.20030686435688


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 37.92123458926735


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 36.70992727765109


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 35.57699516881811


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 34.39924410386216


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 33.32439913825789


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 32.23777509436054


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 31.160843975809193


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 30.123433695833697


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 29.147553191498815


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 28.233149704997043


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 27.308668161724363


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 26.410896161223263


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 25.4775171190594


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 24.652865278369546


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 23.830768742514127


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 23.104731050179183


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 22.356800729554966


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 21.667210078397737


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 20.886333858443336


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 20.234727707164677


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 19.643893478140278


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 19.02333855571758


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 18.42990651082354


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 17.93134265445316


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 17.42815623128421


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 16.927722014847646


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 16.497097751148647


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 16.087923372776338


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 15.630692598468421


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 15.267574645816133


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 14.827350568032639


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 14.520234213980673


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 14.223915213383279


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 13.854460757893744


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 13.540267854767201


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 13.257622003627503


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 12.97069477312358


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 12.77606493443893


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 12.476841461468243


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 12.197316810583471


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 12.01265345803929


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 11.725944720008865


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 11.544036839240224


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 11.397615611676775


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 11.170345464494245


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 10.957122011023618


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 10.885378902537502


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 10.703993188172632


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 10.569396268281398


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 10.405023998033881


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 10.259840334124114


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 10.170258178033313


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 10.05965095388938


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 9.985139199458517


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 9.883838281137137


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 9.761621856500607


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 9.707671922640392


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 9.59676939963739


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 9.552068215716211


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 9.459602337583387


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 9.427317698106005


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 9.33118341564785


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 9.283807708882424


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 9.258809495236058


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 9.180125267470931


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 9.090357271671017


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 9.071329379639934


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 9.020239342748944


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 9.001721921197605


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.890262250739193


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.859640904650412


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.850863982585537


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.771097870861713


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.73916972919613


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.710908177164175


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.676718975604796


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.635648940546817


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.598972589729412


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.582396520615735


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.532391694901499


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.504752016217472


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.448807280445488


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.429780985688359


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.40925342692007


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.393290662982121


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.331769606035396


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.32833351870644


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.319237447370952


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.3443149737723


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.265037237766112


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.25923099183453


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.245625149399718


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.22795786301619


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.205940507445561


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.20690325936399


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.191401914797623


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.188727625689676


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.159331864338665


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.120298212525729


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.080928418387767


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.110570011338872


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.074579469861856


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.068901160511778


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.033213800375865


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.060847342734984


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.074586457878473


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.035297885288246


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.039560235739447


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 7.99481283805505


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 7.964959609115547


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 7.990493018050907


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 7.985171222464595


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 7.946895210580487


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.000645405897652


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 7.958851385466498


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 7.947750104216552


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 7.977445672223509


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 7.959883522051466


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 7.932054659555177


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 7.935244877370346


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 7.978081600136998


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 7.962628543122813


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 7.963784139113296


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 7.904063649313631


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 7.88042535932382


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 7.8923797934348325


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 7.883333043685539


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 7.885626632099451


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 7.880657554369856


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 7.882064550996313


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 7.883196350818605


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 7.879703203738951


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 7.8746677197171415


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 7.873549638524886


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 7.871439198675583


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 7.87769514110523


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 7.878632878140786
fold 1: mean position error 7.875189377620455
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,47.71848
Loss/xy,47.71848
Loss/floor,5.22984
MPE/val,7.87519
epoch,199.0
trainer/global_step,80799.0
_runtime,931.0
_timestamp,1617989901.0
_step,199.0


0,1
Loss/val,█▇▆▆▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▆▆▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,██████▂▁█▇▇▇██▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▂
MPE/val,██▇▆▆▆▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.343    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 149.7208023071289


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 158.65544253589806


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 153.66363175247182


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 148.83434222008646


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 144.1406692079201


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 139.58457044038232


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 135.15554165697404


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 130.85550594669758


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 126.68877606586862


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 122.65675319442794


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 118.77377043851253


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 115.05266898313509


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 111.50424738471288


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 108.13360991329667


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 104.95278307057704


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 101.97113019095652


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 99.19290927929175


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 96.62168905188194


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 94.2664662939126


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 92.00057109332877


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 85.08577593219648


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 81.62324197177077


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 78.71957372508774


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 76.13739988125961


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 73.74510853424694


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 71.55633689791502


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 69.55094395199278


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 67.63300134679919


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 65.8413920691462


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 64.21068788869476


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 62.74407694141741


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 60.87368587882698


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 59.37217045836763


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 57.92201871477593


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 56.68892612983824


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 55.53539836801341


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 54.43992066349481


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 53.444585104060856


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 52.45346192692627


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 51.01348003449659


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 49.7879108492052


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 48.57598152964908


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 47.43855327935017


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 46.24487783293149


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 45.072662706880564


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 43.89791132612012


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 42.77890587983884


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 41.49795761337236


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 40.231507348397756


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 39.009669612213564


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 37.87232701317682


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 36.70869420178769


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 35.6005041373269


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 34.465046200460655


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 33.375832993335756


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 32.285041762892426


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 31.2586997048273


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 30.258904015402774


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 29.28245583352615


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 28.366676236254293


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 27.470422609059284


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 26.58365599358436


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 25.736255249796596


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 24.92721374563929


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 24.118774812517294


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 23.331916962436416


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 22.600084794221676


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 21.890087523785315


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 21.181306723626324


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 20.556237517564046


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 19.893532740482428


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 19.30945113873579


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 18.72490239814887


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 18.15678290999762


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 17.7112710455093


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 17.147297533831765


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 16.669633359259205


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 16.1693336531473


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 15.729692373303529


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 15.318945243848779


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 14.880342512869461


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 64, in default_collate
    return default_collate([torch.as_tensor(b) for b in batch])
  File "/home/user/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 53, in default_collate
    storage = elem.storage()._new_shared(numel)
  File "/home/user/.local/lib/python3.6/site-packages/torch/storage.py", line 138, in _new_shared
    return cls._new_using_fd(size)
RuntimeError: unable to write to file </torch_4090_3910636463>


In [None]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

In [21]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,87.820877,103.378113
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.934212,102.221764
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,84.637535,106.694267
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.344582,108.472931
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.364525,107.649292
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,216.584732,91.119125
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,210.694809,98.997589
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,207.790909,106.390236
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,202.718658,113.294189


In [22]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,87.820877,103.378113
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.934212,102.221764
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,84.637535,106.694267
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.344582,108.472931
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.364525,107.649292
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,216.584732,91.119125
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,210.694809,98.997589
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,207.790909,106.390236
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,202.718658,113.294189


In [23]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [24]:
print(f"CV:{np.mean(val_scores)}")

CV:8.134229954895378


In [25]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,55.20833
Loss/xy,55.20833
Loss/floor,4.40498
MPE/val,8.26543
epoch,199.0
trainer/global_step,80799.0
_runtime,957.0
_timestamp,1616836981.0
_step,199.0


0,1
Loss/val,█▇▅▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▅▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▆▆▆▆▆▆▅█▇▆▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
MPE/val,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.00MB of 0.49MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00130087052…

0,1
CV_score,8.13423
_runtime,2.0
_timestamp,1616837013.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
