# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path
from glob import glob


sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'unified-ds-wifi-and-beacon'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 10
IS_SAVE = True

utils.set_seed(SEED)

## read data

In [5]:
wifi_path_list = sorted(glob(str(WIFI_DIR / '*train.csv')))
train_df = []
for path in wifi_path_list:
    train_df.append(pd.read_csv(path))
train_df = pd.concat(train_df).reset_index(drop=True)
train_df = train_df.drop('Unnamed: 0', axis=1)
train_df = train_df.rename(columns={'site':'site_id'})
train_df

Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,230.03738,153.496350,c08ad78a45798cfe176a42b35c7381ae602711c5,-46,434,30f85a5e14351468a6dd13718a9da3b0d7b73685,-46,...,1421,ff9a29cb2dcb4100880d02b305d5691e578e66cf,2.083036,2827,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,2.341936,398,aeb5121f95a613552e00b083ee11cd70be497ab5,3.696872,1088
1,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,231.40290,158.415150,56bc1d2557bf8225384acb992f4b1993a894db63,-47,1856,25d990fc62eb6d5a42994cd7da0f9b6a70fea57a,-47,...,2562,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.629927,1757,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2141,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2664
2,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,232.46200,164.416730,c0beb6ad539d9bd9333c62866c08c844ac69ab28,-48,2767,56bc1d2557bf8225384acb992f4b1993a894db63,-48,...,1351,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,1527,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,2667,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.718217,2644
3,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,233.94418,171.414170,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-51,136,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-52,...,100,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,148,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,220,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,324
4,5a0546857ecc773753327266,-1,5e15730b1506f2000638fc29,198.36833,163.520630,24677fe9a6f29ace69792429fd85fa8f3efd0192,-49,1919,e43e8b111c0061d7cba667a23d6f1c0143bc73ac,-50,...,567,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,10.591243,2147,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,11.687424,770,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,12.882066,90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,5dc8cea7659e181adb076a3f,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-40,1848,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-43,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000
75274,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,1875,24489f6ba05ba6afd52ac54e453fd1a2f1128b1a,-79,...,2915,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,234,-,-99.000000,3000,-,-99.000000,3000
75275,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.43129,76.241234,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,404,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,...,1048,3873a11805bd7adb86762e806d0f20e56e709e76,1.850688,636,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,2045,-,-99.000000,3000
75276,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,242.54440,72.935265,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,2666,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-50,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000


In [33]:
# train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test.csv')
test_df = test_df.drop('Unnamed: 0', axis=1)
test_df = test_df.rename(columns={'site':'site_id'})
test_df['site_path_timestamp'] = test_df['site_id'] + '_' + test_df['path'] + '_' + test_df['timestamp'].map('{:013}'.format)
test_df

Unnamed: 0,timestamp,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,...,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9,site_path_timestamp
0,10,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,889bfa434d66eed8c386ccbc90f445932c43f8dd,-58,1170,29c7d9e757292e7b2b3d00dc4dae7514531b20b4,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
1,4048,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-56,1000,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
2,12526,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-62,1952,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
3,25542,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,98d67fadac518296992afddd24e97a2855af9472,-55,2019,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
4,37134,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-54,1977,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10128,35117,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,6c381e7e7d8e984394ce02a3da912acc1b7d294e,-60,2011,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10129,41230,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-59,244,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,5a95e3ee4af260d25d5c13fcf02760820cb6dbdc,67.382098,1334,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10130,51634,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-49,1166,46e733fa58deea74d962874847a529fb4897e655,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10131,60483,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,e5c220800a2d5ec83e355c3b1a2d7a141947e95f,-53,1309,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,216d1cba23183d9f79bdd0ba7e45f1d55e611c6e,18.299263,118,9b91bb4419360103cea4e1f7d434878550956b32,18.299263,2225,9e9913d8d1bf56161d199b1655220e8be5711f8f,18.800756,2213,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...


In [36]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,75.0,75.0
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,75.0,75.0


BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [37]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'wifi_bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'wifi_rssi_{i}' for i in range(NUM_FEATS)]

In [38]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 52185
BSSID TYPES(test): 25967
BSSID TYPES(all): 78152


## preprocessing

In [39]:
train_df

Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,230.03738,153.496350,c08ad78a45798cfe176a42b35c7381ae602711c5,-46,434,30f85a5e14351468a6dd13718a9da3b0d7b73685,-46,...,1421,ff9a29cb2dcb4100880d02b305d5691e578e66cf,2.083036,2827,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,2.341936,398,aeb5121f95a613552e00b083ee11cd70be497ab5,3.696872,1088
1,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,231.40290,158.415150,56bc1d2557bf8225384acb992f4b1993a894db63,-47,1856,25d990fc62eb6d5a42994cd7da0f9b6a70fea57a,-47,...,2562,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.629927,1757,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2141,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2664
2,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,232.46200,164.416730,c0beb6ad539d9bd9333c62866c08c844ac69ab28,-48,2767,56bc1d2557bf8225384acb992f4b1993a894db63,-48,...,1351,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,1527,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,2667,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.718217,2644
3,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,233.94418,171.414170,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-51,136,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-52,...,100,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,148,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,220,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,324
4,5a0546857ecc773753327266,-1,5e15730b1506f2000638fc29,198.36833,163.520630,24677fe9a6f29ace69792429fd85fa8f3efd0192,-49,1919,e43e8b111c0061d7cba667a23d6f1c0143bc73ac,-50,...,567,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,10.591243,2147,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,11.687424,770,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,12.882066,90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,5dc8cea7659e181adb076a3f,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-40,1848,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-43,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000
75274,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,1875,24489f6ba05ba6afd52ac54e453fd1a2f1128b1a,-79,...,2915,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,234,-,-99.000000,3000,-,-99.000000,3000
75275,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.43129,76.241234,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,404,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,...,1048,3873a11805bd7adb86762e806d0f20e56e709e76,1.850688,636,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,2045,-,-99.000000,3000
75276,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,242.54440,72.935265,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,2666,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-50,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000


In [40]:
beacon_columns = [s for s in list(train_df.columns) if 'beacon_' in s]

In [41]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])
    
    output_df = output_df.drop(beacon_columns, axis=1)

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

    

  return self.partial_fit(X, y)
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app


In [42]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [44]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [45]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])

        x = torch.cat([x_bssid, x_site_id, x_rssi], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [46]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [47]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [48]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [49]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [50]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1147.9928
Loss/xy,1147.9928
Loss/floor,5.14924
MPE/val,32.71607
epoch,199.0
trainer/global_step,23599.0
_runtime,320.0
_timestamp,1617424782.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███████████████▅▅▂▂▂▄▄▂▂▂▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄
MPE/val,██▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 10.9 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.772    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 185.55563354492188


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 165.08265134371243


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 163.99766857440653


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 162.96365724221252


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 161.95147122114133


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 160.95127633901743


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 159.96050663483447


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 158.97668533325196


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 157.99958421756062


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 157.0279556274414


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 156.06080928704677


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 155.09834968371268


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 154.1408560924041


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 153.18757337912535


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 152.23770108345227


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 151.2916632823455


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 150.34940711779473


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 149.4110370733799


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 148.47709880731045


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 147.5469678829878


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 146.62081740941758


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 145.69902603931916


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 144.78237377068936


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 143.86927275046324


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 142.96015122242463


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 142.05554072062174


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 141.1535084846692


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 140.2569143833258


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 139.36427657298552


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 138.47634825095153


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 137.59228337605794


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 136.71267044849887


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 135.83769537118764


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 134.96744169577576


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 134.10168157724235


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 133.24129051795373


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 132.38469468141213


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 131.5338867578751


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 130.6873838767027


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 129.84589440761468


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 129.00929897993038


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 128.1795268327762


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 127.35403846349472


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 126.5350461324056


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 125.72074436285557


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 124.91062425466684


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 124.10646284054488


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 123.30797312809871


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 122.51538535876152


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 121.72977376106458


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 120.94970241448819


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 120.17583452860514


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 119.40758578960713


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 118.64671542827901


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 117.89225407135793


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 117.14377193939991


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 116.40137348664113


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 115.66674585586938


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 114.93827285766602


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 114.2171318152012


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 113.50186994503706


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 112.79382397578313


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 112.09237958467924


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 111.39785799857897


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 110.71029013609274


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 110.03003401144957


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 109.35630293626052


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 108.68978643172827


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 108.02889209649501


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 107.37528106493828


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 106.72792594127166


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 106.08851503225473


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 105.4551578326103


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 104.82936958899865


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 104.21163538419283


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 103.60027648730156


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 102.99617591271034


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 102.3994485512758


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 101.810720316569


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 101.22895313165127


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 100.65511931394919


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 100.08982784564678


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 99.53190895471818


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 113.40392107841295


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 96.72514054714105


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 96.08202679951985


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 95.35767208001553


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 94.68623137840858


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 94.03696115933933


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 93.37027569306203


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 92.76945821321927


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 92.08342184409118


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 91.48121317349948


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 90.88312151982235


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 90.25505971664037


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 89.66198801872058


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 89.08015135740622


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 88.49142970549754


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 87.91475366445688


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 87.34701820764786


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 86.78290058038174


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 86.23075132125464


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 85.66561078780737


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 85.13937029716296


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 84.56165390992777


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 83.6174806643755


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 81.7025454105475


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 80.31683425414256


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 79.22292238871256


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 78.32030084072015


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 77.39511423355493


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 76.5425770294972


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 75.77316551208496


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 75.00098760066888


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 74.01267330463115


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 73.12399865663969


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 72.32498809618828


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 71.59208131447816


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 70.85675709064189


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 70.17123735868014


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 69.49128477145464


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 68.86172923552684


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 68.22116555434008


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 67.64157981872559


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 67.05407149975116


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 66.2978514059996


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 65.6325100287413


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 64.9331576860868


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 64.27188131381304


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 63.70644249549279


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 63.13404991932404


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 62.608371167305194


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 62.01926893087534


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 61.46139854039902


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 60.93909009786753


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 60.425414677155324


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 59.92329850319104


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 59.42733535766602


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 58.98855166312976


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 58.4984167881501


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 58.01827338781113


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 57.56880900554168


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 57.17992010361109


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 56.70852606357673


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 56.31533402173947


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 55.90178279876709


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 55.49229694268642


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 55.11059871575772


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 54.666567518772226


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 54.2938318056938


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 53.9024576480572


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 53.5443021774292


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 53.14656020433475


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 52.798758384508965


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 52.48083849686842


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 52.031177501189404


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 51.46229454431778


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 51.02438600980318


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 50.60363162603134


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 50.179155247028056


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 49.80542789361416


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 49.389967429332245


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 49.010772621937285


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 48.654226185725285


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 48.24239167433519


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 47.829493503081494


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 47.46445681254069


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 47.11694739904159


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 46.73671517005334


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 46.38787120916904


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 45.99463590964293


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 45.68372296553392


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 45.3279092104007


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 44.93301300635705


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 44.61465568053416


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 44.204643107683225


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 43.88057589408679


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 43.545641718155295


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 43.19409695649758


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 42.85444120749449


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 42.51076471377642


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 42.16611704704089


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 41.78919135362674


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 41.34022867496197


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 40.89164995535826


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 40.4572368083856


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 40.111253868005214


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 39.6755220682193


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 39.27938207724155


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 38.93953311382196


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 38.564104588826496


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 38.1898861127022


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 37.83341815655048


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 37.490699007572275


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 37.105152973761925


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 36.774672522911665


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 36.45083776620718


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 36.05231264554537


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 35.75157618889442


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 35.37415713041256


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 35.02909615100959
fold 0: mean position error 35.0405163145144
Fold 1


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1284.76221
Loss/xy,1284.76221
Loss/floor,5.79971
MPE/val,35.04052
epoch,199.0
trainer/global_step,23599.0
_runtime,336.0
_timestamp,1617425768.0
_step,199.0


0,1
Loss/val,██▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Loss/floor,███████████████████████▁▂▂▅▇████████████
MPE/val,██▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 10.9 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.772    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 195.0871353149414


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 167.28099216558994


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 166.95257603571966


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 166.64768518301156


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 166.34992192586262


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 165.3201764815893


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 164.73555968847032


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 164.25396693303034


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 163.8046880868765


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 163.37172569861778


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 162.94915663890347


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 162.53398844401042


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 162.12408279027696


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 161.7181166820037


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 161.31587964571438


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 160.91609661395734


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 160.51829830071864


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 160.1227549626277


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 159.72884476490512


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 159.33625284830728


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 158.9449794084598


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 158.55506462684042


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 158.16627347897258


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 157.77792104085287


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 157.39032140878533


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 157.00390049861028


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 156.6182489444048


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 156.23314141493577


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 155.8484516045986


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 155.46432615426872


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 155.08131030156062


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 154.69845631917318


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 154.3160395304362


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 153.93433211889024


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 153.55269395877156


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 153.17176926441684


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 152.7912760221041


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 152.41117398188666


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 152.03177766066332


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 151.65258685380982


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 151.2736895634578


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 150.89591077168782


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 150.51785383958082


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 150.14049600454476


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 149.76394786345654


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 149.38780230986768


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 149.0121707035945


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 148.63720908531778


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 148.26286484155898


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 147.88881698999649


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 147.51558495545999


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 147.1432090074588


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 146.77086189465643


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 146.39959186651768


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 146.0283194126227


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 145.65794582856006


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 145.28816864796175


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 144.9188431568635


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 144.5501435402112


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 144.18175578973234


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 143.81407709366238


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 143.44676428574783


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 143.0806005428999


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 142.7143701308813


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 142.34844320248334


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 141.98366435124322


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 141.6191313327887


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 141.25583129295939


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 140.89235794849887


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 140.52984677828275


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 140.16769350492038


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 139.80625603504672


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 139.44542323381472


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 139.0852664751884


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 138.72522399119842


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 138.36629102658


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 138.00779377863958


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 137.6500705816807


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 137.29315048609024


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 136.93647443331204


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 136.58117047823393


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 136.2259169554099


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 135.87151675102038


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 135.51803978161934


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 135.16484820048015


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 134.81270896715995


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 134.46085102374735


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 134.10997454325357


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 133.75948674128605


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 133.4100367717254


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 133.060882910704


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 132.71301396687826


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 132.36530405680338


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 132.01855071630231


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 131.67210469368177


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 131.3265394161909


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 130.98194290552385


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 130.6380812718318


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 130.29463874621268


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 129.95249009743713


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 129.61091903295272


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 129.26985706427158


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 128.92947421929776


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 128.59000276418834


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 128.2513291774652


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 127.91356948461288


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 127.57617619832357


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 127.23956791804386


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 126.90397090422802


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 126.56902190966484


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 126.23528876671425


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 125.90201446337578


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 125.56989027170034


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 125.2383294129983


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 124.90798929654635


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 124.57821676792243


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 124.24908229143192


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 123.92133996425531


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 123.59417980878781


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 123.26805031605255


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 122.94247411092122


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 122.61853816692647


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 122.28670839162973


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 121.52429409516164


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 120.90127721933217


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 119.73020643576598


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 117.65394706726075


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 116.06541243333082


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 114.8053640805758


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 112.59939811902169


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 110.71003761291504


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 109.12797730274688


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 107.72040481567383


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 105.96049076960637


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 104.11953620910644


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 102.58699555030236


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 101.21057468316494


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 99.89858063917895


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 98.65433867038824


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 97.47134589659862


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 96.29273739350148


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 94.82960555736835


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 93.5591300377479


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 92.40617018479567


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 91.30081259898651


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 90.24283722608517


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 89.2043275979849


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 88.22199132381341


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 87.25038433564015


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 86.29209207388071


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 85.36231409708658


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 84.47127314836551


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 83.57975801321177


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 82.71490808144593


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 81.88767130436041


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 81.06237163543702


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 80.25430944393842


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 79.46964480571258


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 78.68240023881961


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 77.92868949694511


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 77.19095266293257


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 76.44424211550981


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 75.73329502986027


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 75.02123782818134


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 74.32987736922044


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 73.65687914628249


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 73.00547661903578


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 72.35577047299117


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 71.72347255608975


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 71.1114286373823


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 70.52267854152582


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 69.92888366992656


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 69.35824584960938


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 68.79538837335049


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 68.23575396415515


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 67.7093596580701


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 67.18494450984856


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 66.6598807995136


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 66.17544885293032


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 65.66371491750081


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 65.18475749676044


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 64.71976056710268


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 64.2564844033657


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 63.787284063681575


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 63.365943400065106


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 62.90331533872164


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 62.48880611321865


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 61.95016854604085


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 61.514873597560786


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 61.09544647901486


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 60.82343312287942


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 60.297360650087015


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 59.89975992349478


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 59.53724866035657


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 59.164349291874814


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 58.80603010715583


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 58.45333086649577


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 58.290386640108544


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 57.93632116562281


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 57.46550805018499


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 57.12002457838792
fold 1: mean position error 57.28776218397872
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,2680.04028
Loss/xy,2680.04028
Loss/floor,6.12897
MPE/val,57.28776
epoch,199.0
trainer/global_step,23599.0
_runtime,340.0
_timestamp,1617426113.0
_step,199.0


0,1
Loss/val,███▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁
Loss/xy,███▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁
Loss/floor,████████████████████████████▇▇▇▇▅▅▅▄▂▂▁▁
MPE/val,████▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 10.9 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.772    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 203.06328582763672


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 163.47612844613883


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 161.29857193384413


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 159.52769112220176


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 157.85077802217924


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 156.22230682373046


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 154.6274379632412


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 153.05821089133238


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 151.50907125228488


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 149.97853428767277


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 148.46504855033677


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 146.9666440327962


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 145.48552408462916


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 144.0183137942583


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 142.5664807050656


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 141.129283005152


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 139.70555537297176


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 138.29671654334433


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 136.90192953256462


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 135.5208619924692


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 134.15567682706393


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 132.80380554199218


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 131.46647916940543


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 130.1425845415164


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 128.8351511148306


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 127.54237684592223


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 126.26558284270457


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 125.00332076243866


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 123.75713295569787


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 122.52420246417705


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 121.30834254729442


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 120.10790130419609


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 118.92315591665415


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 117.75590713696602


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 116.60475448217147


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 115.47061919187887


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 114.35278036655524


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 113.25236950409719


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 112.16781204419257


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 111.10223755469688


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 110.05291042816945


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 109.02085349254119


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 108.00490575937124


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 107.00623363592686


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 106.02433031522311


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 105.06130575522398


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 104.1158995897342


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 103.19001061855218


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 102.27947907081017


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 101.3891936668983


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 100.51745905753893


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 99.66488883189666


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 98.8305158174955


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 98.01591646243364


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 97.22024014790853


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 96.44404157002766


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 95.6867286682129


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 94.95114042820074


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 94.23272839081594


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 93.53448568490835


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 92.85542577107748


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 92.19427117567797


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 91.5512464474409


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 90.92709658696101


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 90.32142265515449


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 89.73392128577599


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 89.16366962530674


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 88.61202295743503


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 88.0783984942314


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 87.56274183224409


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 87.06494817489234


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 86.58567154713165


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 86.1251205835587


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 85.68170405656863


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 85.25780359903972


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 84.8516569871169


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 84.46295044727815


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 84.0928243783804


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 83.74137476407564


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 83.40586373744867


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 83.08886916331757


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 82.78831332280086


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 82.50395972423065


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 82.23673517276079


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 81.98515882247534


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 81.74943410433255


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 81.52997873746432


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 81.32419024736454


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 81.13503276140263


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 80.95867603986692


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 80.79626746544471


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 70.43809086726262


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 69.37493504255245


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 68.63749618530274


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 67.96292997506949


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 67.33150817675468


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 66.73168152051095


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 66.1485856667543


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 65.5599021129119


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 65.09801752628424


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 64.51806161342523


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 63.805661010742185


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 63.137103041624414


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 62.50905747535901


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 61.90392438937456


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 61.37471145238632


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 60.83469268114139


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 60.32611099634415


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 59.85093907087277


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 59.35299544212146


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 58.87473908937895


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 58.18174256055783


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 57.588368508754634


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 57.076066012260235


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 56.56715793120555


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 56.08917757670085


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 55.62131020961663


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 55.20681301508194


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 54.75875543936705


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 54.34842361548008


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 53.91407498090695


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 53.522970483241934


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 53.11853505159036


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 52.7433882199801


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 52.37671535198505


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 52.00427898504795


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 51.65994151677841


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 51.29738656068459


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 50.961664463923526


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 50.62986108828814


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 50.29037085802127


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 49.965559905614604


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 49.55733233231764


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 49.01616281851744


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 48.617552581200236


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 48.16145558968569


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 47.81273422730275


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 47.44583468803992


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 47.05421241369003


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 46.71786669706687


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 46.31593988858736


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 46.006909096546664


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 45.661629138848724


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 45.28971232389792


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 44.94151136936286


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 44.64887805840908


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 44.252578138693785


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 43.92587932684482


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 43.536114409031015


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 43.252489080184546


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 42.93826012244592


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 42.530946697332915


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 42.17532229301257


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 41.837895745497484


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 41.50875940567408


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 41.12839156419803


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 40.760211714720114


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 40.42374958136143


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 40.120738410949706


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 39.76507068047157


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 39.435191599528


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 39.06519172619551


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 38.696977361043295


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 38.4108754549271


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 38.005256070846166


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 37.68698384700677


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 37.3508359175462


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 37.00483544178498


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 36.6665858293191


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 36.371369555057626


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 36.02487848722017


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 35.68968900044759


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 35.37059492942615


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 35.06211852538279


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 34.6980281047332


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 34.3959084608616


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 34.07183306278326


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 33.73915297923944


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 33.41766159595587


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 33.13983269715921


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 32.84512906685854


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 32.51152880741999


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 32.18759526717357


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 31.904190410711823


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 31.68093226017096


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 31.327433811089932


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 31.010141494946602


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 30.691223315703564


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 30.391226130265455


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 30.101357447795376


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 29.835581075228177


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 29.51512501789973


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 29.237271812634592


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 28.913024383936172


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 28.684699982863208


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 28.44449104406895


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 28.129654339032296


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 27.876698970794678


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 27.626035660963794


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 27.327440274067417


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 27.02056229909261
fold 2: mean position error 27.08224816121162
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,883.90381
Loss/xy,883.90381
Loss/floor,5.9133
MPE/val,27.08225
epoch,199.0
trainer/global_step,23599.0
_runtime,343.0
_timestamp,1617426461.0
_step,199.0


0,1
Loss/val,█▇▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▄▂▂▅▅▆▆▇▇▇███████████
MPE/val,██▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 10.9 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.772    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 192.03323364257812


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.41231999788883


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 162.837009986472


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 161.40078485706962


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 160.0121157917999


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 158.65009662179747


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.3070078065046


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 155.97943763794336


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 154.6644280253977


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.3606655114707


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 152.06695299624627


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 150.7838354838643


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 149.50840027174897


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 148.24339236690994


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 146.9864149317841


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 145.73970509773292


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 144.50193159438177


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 143.27169707047958


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 142.05024235551866


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 140.83770776524443


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 139.63369191364774


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 138.43866600245477


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 137.25337068370382


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 136.07671547980317


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 134.90835112911108


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 133.75105881744727


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 132.60287027650793


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 131.4643043419781


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 130.33502378479102


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 129.21590410806897


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 128.10679047687427


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 127.00738221566266


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 125.9182810637494


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 124.83917627687808


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 123.7714390501308


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 122.71389456977782


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 121.6673095629411


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 120.63247535631852


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 119.60852459279428


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 118.59688321787762


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 117.59577046761383


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 116.60696323413204


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 115.62957463533213


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 114.66536427496327


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 113.71299822034851


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 112.77374771228735


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 111.84676233067414


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 110.9320910166619


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 110.03154483924165


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 109.14351050396857


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 108.26907661241418


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 107.40598706047315


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 106.55759158909993


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 105.72242814154633


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 104.89970444266154


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 104.08996750590498


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 103.29455184445098


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 102.51423561661332


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 101.74754347808887


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 100.99464546731897


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 100.2554787652696


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 99.5303805352793


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 98.81856989530358


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 98.12128287789902


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 97.43773362840048


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 96.76808560259292


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 96.11111701989903


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 95.4685529447792


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 94.84002948459988


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 94.22547986618754


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 93.62370203174831


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 93.03734937474347


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 92.46219600548491


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 91.90113734738263


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 91.35399545125915


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 90.81990666535357


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 90.29895016582692


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 89.7917309280371


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 89.29814341141214


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 88.81920231989616


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 88.3527015907177


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 87.89931868278269


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 87.46033482513182


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 87.0360244087551


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 86.62442915543265


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 86.2272863206848


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 85.84493265259476


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 85.47549495420594


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 85.12137084813509


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 84.78050900703467


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 79.81668502267048


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 76.52229184444016


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 74.9484529234169


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 73.84355589180177


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 72.9166295102253


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 71.80559809250148


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 70.78896808132842


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 69.86758376319628


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 69.06824727849299


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 68.32826541587352


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 67.61581956891047


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 66.9458192585747


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 66.32470326200968


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 65.7051237348965


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 65.20410197621958


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 64.53697204098417


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 64.03498150952964


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 63.504367192586265


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 62.90789729685024


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 62.46988506378567


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 62.006423131663254


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 61.45168595797774


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 60.963493266082615


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 60.63711211316635


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 60.09530301024949


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 59.69616779136965


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 59.293817896697064


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 58.88999858155919


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 58.55996422944246


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 58.290634322972686


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 57.80441342352285


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 57.414856240829984


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 57.08815513463412


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 56.740679836426764


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 56.38981047306276


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 56.12803336495171


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 55.83946196146057


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 55.47487654447939


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 55.14694473739599


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 54.85129660147208


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 54.53325826058257


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 54.28155637424733


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 53.909456128966596


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 53.676747870022936


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 53.37096915774875


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 53.03676104921841


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 52.722587075287215


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 52.41939827800756


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 52.12020265944722


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 51.770144479478226


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 51.51422189414597


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 51.136124298537986


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 50.809242813982635


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 50.554747260826225


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 50.18421578246038


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 49.853106675938896


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 49.45532950955696


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 49.105952900197


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 48.81780596961914


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 48.40613793819999


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 48.02684280776363


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 47.68001057414425


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 47.28431726753616


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 46.91402895523538


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 46.538340877526814


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 46.140903108937735


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 45.73922008262548


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 45.380802310568896


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 45.04173100236533


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 44.56817857693167


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 44.21861141330762


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 43.80740978836821


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 43.405284785917225


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 43.13226120667757


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 42.75672225598935


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 42.35944353370851


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 41.9193719037681


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 41.533205739388336


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 41.10411354439654


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 40.74417899987164


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 40.330720242056486


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 39.95947038186729


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 39.61799744238984


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 39.2266275714561


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 38.859954894184106


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 38.53048122055864


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 38.127675835982615


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 37.80577810874116


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 37.44307235250918


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 37.11155871914973


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 36.74316028634899


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 36.381364168637035


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 36.072233270408645


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 35.718427373975175


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 35.38408366585699


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 35.12894434675502


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 34.72290166203719


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 34.39288648491705


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 34.05951037061387


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 33.771682015171756


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 33.44950793799187


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 33.13306173717726


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 32.837187792367985


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 32.53541803221772


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 32.220698506820604


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 31.901512489840986


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 31.582800164276467


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 31.317261989949788


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 31.052806205565226


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 30.680391558894407


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 30.457959680971893
fold 3: mean position error 30.32329018891897
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1065.85254
Loss/xy,1065.85254
Loss/floor,5.71823
MPE/val,30.32329
epoch,199.0
trainer/global_step,23599.0
_runtime,343.0
_timestamp,1617426809.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,██▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 10.9 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.772    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 190.50552368164062


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 160.97718698744228


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 159.48391632227506


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 158.0753774314303


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 156.7043754639065


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 155.35580949153686


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 154.02284631775197


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 152.70264668119125


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 151.3953109274356


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 150.0974690190068


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 148.80982669947036


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 147.5311934392809


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 146.26223662371797


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 145.0021212419642


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 143.7534789596779


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 142.5146382557597


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 141.28443379793765


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 140.06481146267262


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 138.85357664172776


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 137.65268397308202


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 136.45780496950502


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 135.27461296179828


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 134.09915978129166


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 132.93410229398816


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 131.7778609695066


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 130.63124763931054


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 129.49488039554222


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 128.36607259728868


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 127.24884196909537


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 126.13890658883848


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 125.03841550891526


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 123.9483641748843


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 122.8679716868869


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 121.79538486130572


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 120.73402960465533


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 119.68344347058479


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 118.6413383262745


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 117.61151464649636


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 116.59024635720367


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 115.58125041771243


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 114.5820891191418


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 113.59514772903516


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 112.61703538449298


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 111.65347680538748


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 110.69962284192563


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 109.75704227269368


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 108.8265284349377


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 107.90932296286074


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 107.00260880573168


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 106.10949836583528


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 105.2283777049582


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 104.36015593364428


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 103.50548759067308


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 102.66529317232337


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 101.83830409272666


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 101.02615650011145


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 100.22681059353593


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 99.44318347445625


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 98.67332290243988


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 97.91992892266856


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 97.18007254485346


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 96.45468555272298


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 95.74345631499605


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 95.04793756273058


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 93.96017190438731


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 90.79554427572302


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 89.33986278939362


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 88.00458302643756


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 86.87195286313117


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 85.8342088204843


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 84.69019898861502


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 83.60102670979768


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 82.27949855055019


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 81.29605605913245


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 80.22623615787033


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 79.30569760726462


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 78.46782040004761


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 77.52208940226483


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 76.66801027645043


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 75.84952894201601


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 75.10999367325394


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 74.30580390709035


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 73.59523705253663


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 72.83838506228682


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 72.12453064849412


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 71.43047168504405


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 70.75726725598273


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 70.1083239019397


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 69.45989786108143


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 68.81618234877041


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 68.22949686157914


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 67.60800273790836


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 67.08978800750585


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 66.46992758230311


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 65.87976396533026


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 65.36541122031096


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 64.83201017640832


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 64.32690652242222


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 63.8278260892908


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 63.33263468105052


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 62.65140011552452


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 62.04742575198556


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 61.49633957836747


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 60.95575381630669


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 60.40275705279168


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 59.778745413210466


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 59.295805008745425


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 58.832825680669764


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 58.61292352415321


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 57.90765779068313


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 57.49894586461754


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 57.19888919578466


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 56.69903263823038


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 56.3145498684255


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 55.99063622847847


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 55.5194816030358


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 55.1130922584718


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 54.76470777447097


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 54.38900620864401


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 54.08800986469656


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 53.69049982264422


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 53.32612679246543


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 52.9813242416259


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 52.62499843480698


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 52.280085198161295


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 51.973097778172885


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 51.56945190890399


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 51.32329907102477


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 50.97879805081132


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 50.71727175904549


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 50.33334709787906


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 50.00617637634277


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 49.717629853769196


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 49.43793047522578


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 49.12551431870883


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 48.77019126572663


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 48.44247316768972


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 48.1385302449962


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 47.865105808645055


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 47.55416827885233


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 47.20876182381082


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 46.85053953149277


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 46.53634361322375


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 46.14263419398555


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 45.8864540014098


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 45.482040504173


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 45.15601149542128


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 44.81841328539518


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 44.449512847494965


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 44.09950936942477


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 43.74496616510953


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 43.412904180996655


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 43.06063973661782


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 42.67366092731027


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 42.3492248462978


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 41.988770648937866


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 41.58439785378375


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 41.230988201350215


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 40.90010487100353


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 40.53145128640192


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 40.15692318626073


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 39.789514378533845


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 39.459960120503645


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 39.06670487230335


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 38.73502982564978


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 38.37331848282745


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 38.035834837036624


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 37.68427396273651


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 37.3288373624645


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 36.970608626164484


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 36.59548751812626


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 36.23322030864476


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 35.8986256049643


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 35.63384890963298


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 35.243257839706594


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 34.89657281546969


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 34.514656805876946


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 34.20745104783591


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 33.87260300072687


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 33.50886678987463


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 33.189099647534256


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 32.865474734482945


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 32.57748923954372


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 32.263887185405416


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 31.950129864718797


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 31.588406100941164


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 31.33645083900427


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 31.009906859405568


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 30.743504475549035


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 30.39314497396374


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 30.150352234157005


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 29.799268026858712


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 29.59658338881537


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 29.255374592014746


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 28.982765777944174


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 28.716850135023083


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 28.42556491781164


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 28.15143518977695


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 27.85781464999041


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 27.57602515351177


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 27.319380289498543
fold 4: mean position error 27.335164778088384


In [56]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,wifi_bssid_0,wifi_bssid_1,wifi_bssid_2,wifi_bssid_3,wifi_bssid_4,wifi_bssid_5,wifi_bssid_6,wifi_bssid_7,wifi_bssid_8,wifi_bssid_9,...,wifi_rssi_77,wifi_rssi_78,wifi_rssi_79,site_id,x,y,floor,oof_x,oof_y,oof_floor
0,39873,10121,39095,16067,52302,13060,39873,16067,10121,52302,...,0.279928,0.283362,0.286632,0,230.03738,153.496350,-1,172.351013,141.136810,0.000000
1,17965,7792,39921,19108,29424,39921,16191,7792,19108,29424,...,0.306497,0.309732,0.309077,0,231.40290,158.415150,-1,172.350861,141.137222,0.000000
2,39921,17965,16191,7792,19108,52934,52302,16067,10121,39095,...,0.314088,0.317267,0.320299,0,232.46200,164.416730,-1,172.350220,141.135941,0.000000
3,39705,39705,39873,39873,39705,16067,52302,39095,13060,10121,...,0.302701,0.305965,0.309077,0,233.94418,171.414170,-1,172.350784,141.137772,0.000000
4,26417,12652,41723,21199,16863,39050,23964,19076,17470,10752,...,0.283724,0.287129,0.290373,0,210.86192,165.376080,-1,172.349136,141.141098,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15050,19175,19175,19175,7503,18513,13910,7503,18513,13910,7503,...,0.238176,0.241923,0.245484,23,249.43129,76.241234,6,176.420517,120.608803,0.005381
15051,19175,19175,19175,13910,32715,13910,32715,13910,7503,34182,...,0.249563,0.253224,0.256706,23,237.22395,73.177680,6,175.924881,119.061211,0.000000
15052,19175,19175,19175,7503,32715,7503,7503,7038,7038,32715,...,0.249563,0.253224,0.256706,23,242.54440,72.935265,6,178.619904,120.864449,0.000000
15053,19175,19175,19175,7503,7038,7503,7503,7038,32715,13910,...,0.249563,0.253224,0.256706,23,249.43129,76.241234,6,175.069153,116.344002,0.000000


In [51]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,90.989883,101.665627
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.448128,101.667984
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,88.571632,105.758469
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,90.495361,105.674240
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,90.122734,106.457397
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,163.001083,119.541588
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,162.870422,119.380234
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,165.387131,123.414742
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,167.301056,126.359894


In [52]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,90.989883,101.665627
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.448128,101.667984
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,88.571632,105.758469
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,90.495361,105.674240
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,90.122734,106.457397
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,163.001083,119.541588
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,162.870422,119.380234
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,165.387131,123.414742
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,167.301056,126.359894


In [54]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [53]:
print(f"CV:{np.mean(val_scores)}")

CV:35.413796325342425


In [55]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,899.63599
Loss/xy,899.63599
Loss/floor,6.1533
MPE/val,27.33517
epoch,199.0
trainer/global_step,23599.0
_runtime,343.0
_timestamp,1617427157.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,█████████████▆▆▄▄▄▅▄▄▄▅▄▄▄▃▃▃▂▂▂▁▁▁▁▁▁▁▁
MPE/val,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.00MB of 0.58MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00109057002…

0,1
CV_score,35.4138
_runtime,2.0
_timestamp,1617429740.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
