# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path
from glob import glob
from tqdm import tqdm

sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'unified-ds-wifi-and-beacon'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 12
IS_SAVE = True

utils.set_seed(SEED)

## read data

In [5]:
wifi_path_list = sorted(glob(str(WIFI_DIR / '*train.csv')))
train_df = []
for path in wifi_path_list:
    train_df.append(pd.read_csv(path))
train_df = pd.concat(train_df).reset_index(drop=True)
train_df = train_df.drop('Unnamed: 0', axis=1)
train_df = train_df.rename(columns={'site':'site_id'})
train_df

Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,230.03738,153.496350,c08ad78a45798cfe176a42b35c7381ae602711c5,-46,434,30f85a5e14351468a6dd13718a9da3b0d7b73685,-46,...,1421,ff9a29cb2dcb4100880d02b305d5691e578e66cf,2.083036,2827,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,2.341936,398,aeb5121f95a613552e00b083ee11cd70be497ab5,3.696872,1088
1,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,231.40290,158.415150,56bc1d2557bf8225384acb992f4b1993a894db63,-47,1856,25d990fc62eb6d5a42994cd7da0f9b6a70fea57a,-47,...,2562,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.629927,1757,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2141,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2664
2,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,232.46200,164.416730,c0beb6ad539d9bd9333c62866c08c844ac69ab28,-48,2767,56bc1d2557bf8225384acb992f4b1993a894db63,-48,...,1351,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,1527,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,2667,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.718217,2644
3,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,233.94418,171.414170,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-51,136,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-52,...,100,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,148,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,220,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,324
4,5a0546857ecc773753327266,-1,5e15730b1506f2000638fc29,198.36833,163.520630,24677fe9a6f29ace69792429fd85fa8f3efd0192,-49,1919,e43e8b111c0061d7cba667a23d6f1c0143bc73ac,-50,...,567,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,10.591243,2147,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,11.687424,770,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,12.882066,90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,5dc8cea7659e181adb076a3f,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-40,1848,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-43,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000
75274,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,1875,24489f6ba05ba6afd52ac54e453fd1a2f1128b1a,-79,...,2915,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,234,-,-99.000000,3000,-,-99.000000,3000
75275,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.43129,76.241234,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,404,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,...,1048,3873a11805bd7adb86762e806d0f20e56e709e76,1.850688,636,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,2045,-,-99.000000,3000
75276,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,242.54440,72.935265,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,2666,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-50,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000


In [6]:
# train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test.csv')
test_df = test_df.drop('Unnamed: 0', axis=1)
test_df = test_df.rename(columns={'site':'site_id'})
test_df['site_path_timestamp'] = test_df['site_id'] + '_' + test_df['path'] + '_' + test_df['timestamp'].map('{:013}'.format)
test_df

Unnamed: 0,timestamp,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,...,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9,site_path_timestamp
0,10,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,889bfa434d66eed8c386ccbc90f445932c43f8dd,-58,1170,29c7d9e757292e7b2b3d00dc4dae7514531b20b4,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
1,4048,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-56,1000,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
2,12526,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-62,1952,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
3,25542,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,98d67fadac518296992afddd24e97a2855af9472,-55,2019,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
4,37134,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-54,1977,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10128,35117,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,6c381e7e7d8e984394ce02a3da912acc1b7d294e,-60,2011,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10129,41230,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-59,244,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,5a95e3ee4af260d25d5c13fcf02760820cb6dbdc,67.382098,1334,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10130,51634,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-49,1166,46e733fa58deea74d962874847a529fb4897e655,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10131,60483,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,e5c220800a2d5ec83e355c3b1a2d7a141947e95f,-53,1309,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,216d1cba23183d9f79bdd0ba7e45f1d55e611c6e,18.299263,118,9b91bb4419360103cea4e1f7d434878550956b32,18.299263,2225,9e9913d8d1bf56161d199b1655220e8be5711f8f,18.800756,2213,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...


In [7]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,75.0,75.0
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,75.0,75.0


BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [8]:
# training target features
NUM_FEATS = 100
BSSID_FEATS = [f'wifi_bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'wifi_rssi_{i}' for i in range(NUM_FEATS)]
TIMEGAP_FEATS  = [f'wifi_timegap_{i}' for i in range(NUM_FEATS)]

In [9]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 54496
BSSID TYPES(test): 27843
BSSID TYPES(all): 82339


## preprocessing

In [10]:
timegaps = train_df[TIMEGAP_FEATS].values
bssids = train_df[BSSID_FEATS].values
rssis = train_df[RSSI_FEATS].values

In [11]:
ordered_timegaps = []
ordered_bssids = []
ordered_rssis = []

for i in tqdm(range(len(train_df))):
    order = np.argsort(timegaps[i])
    ordered_timegaps.append(timegaps[i][order])
    ordered_bssids.append(bssids[i][order])
    ordered_rssis.append(rssis[i][order])
    
ordered_timegaps = np.array(ordered_timegaps)
ordered_bssids = np.array(ordered_bssids)    
ordered_rssis = np.array(ordered_rssis)

100%|██████████| 75278/75278 [00:00<00:00, 137776.38it/s]


In [12]:
train_df[TIMEGAP_FEATS] = ordered_timegaps
train_df[BSSID_FEATS] = ordered_bssids
train_df[RSSI_FEATS] = ordered_rssis
train_df

Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,230.03738,153.496350,c08ad78a45798cfe176a42b35c7381ae602711c5,-46,434,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-68,...,1421,ff9a29cb2dcb4100880d02b305d5691e578e66cf,2.083036,2827,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,2.341936,398,aeb5121f95a613552e00b083ee11cd70be497ab5,3.696872,1088
1,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,231.40290,158.415150,b0cb9376f8cd87d3e49443e6c7464fe86ea900e3,-66,74,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-59,...,2562,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.629927,1757,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2141,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2664
2,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,232.46200,164.416730,3ce63428d3759cd84971e7ca5114981b2fd18449,-64,835,5c43215d2d7fa8309c357b0534b4e5fb064e7cfa,-56,...,1351,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,1527,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,2667,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.718217,2644
3,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,233.94418,171.414170,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-51,136,8e57a3171309c6ddd6d203133878fa10e48b5d97,-67,...,100,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,148,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,220,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,324
4,5a0546857ecc773753327266,-1,5e15730b1506f2000638fc29,198.36833,163.520630,24677fe9a6f29ace69792429fd85fa8f3efd0192,-49,1919,739ddddef59162677eb0a789f4e58475b07248fc,-76,...,567,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,10.591243,2147,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,11.687424,770,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,12.882066,90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,5dc8cea7659e181adb076a3f,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,7150550f0583994950cd821d38d6d9cf7a554b64,-87,56,05642eab9b5515eeaa54e50b7bfa93abf9594455,-87,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000
75274,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,1875,b6d081cf13776bfa566f563f7c25bc51e5f5e0e2,-88,...,2915,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,234,-,-99.000000,3000,-,-99.000000,3000
75275,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.43129,76.241234,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,404,d1f7a64fee4b811b4dd0226bef2be71f01214b00,-85,...,1048,3873a11805bd7adb86762e806d0f20e56e709e76,1.850688,636,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,2045,-,-99.000000,3000
75276,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,242.54440,72.935265,7150550f0583994950cd821d38d6d9cf7a554b64,-87,792,5964a27e0cb3344b0a18540e6b3120c433971c38,-84,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000


In [13]:
beacon_columns = [s for s in list(train_df.columns) if 'beacon_' in s]

In [14]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])

ss_gap = StandardScaler()
ss_gap.fit(train_df.loc[:,TIMEGAP_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])
    
    # gapの正規化
    output_df.loc[:,TIMEGAP_FEATS] = ss_gap.transform(input_df.loc[:,TIMEGAP_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])
    
    output_df = output_df.drop(beacon_columns, axis=1)

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)
train    

  return self.partial_fit(X, y)
  return self.partial_fit(X, y)


Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,wifi_timegap_96,wifi_bssid_97,wifi_rssi_97,wifi_timegap_97,wifi_bssid_98,wifi_rssi_98,wifi_timegap_98,wifi_bssid_99,wifi_rssi_99,wifi_timegap_99
0,0,-1,5e15730aa280850006f3d005,230.03738,153.496350,41663,0.222849,-0.515114,41490,0.132786,...,-0.067173,45919,0.327659,-0.069780,28038,0.364874,-0.072009,5971,0.318894,-0.074546
1,0,-1,5e15730aa280850006f3d005,231.40290,158.415150,38161,0.060209,-1.067310,41490,0.206979,...,-0.705287,46296,0.361087,-0.707686,41300,0.348248,-0.709593,13645,0.414933,-0.711877
2,0,-1,5e15730aa280850006f3d005,232.46200,164.416730,13281,0.076473,0.099971,19970,0.231710,...,0.802258,16806,0.401201,0.799367,10708,0.364874,0.796698,41711,0.428180,0.793818
3,0,-1,5e15730aa280850006f3d005,233.94418,171.414170,41490,0.182189,-0.972210,30748,0.141029,...,-0.547753,13281,0.384487,-0.550203,10708,0.381499,-0.552189,1797,0.335453,-0.554536
4,0,-1,5e15730b1506f2000638fc29,198.36833,163.520630,7849,0.198453,1.762696,25034,0.066836,...,-0.888745,28277,0.371116,-0.891084,27665,0.358224,-0.892898,20567,0.325518,-0.895110
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,23,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,24515,-0.110562,-1.094920,1146,-0.023845,...,1.266885,0,-2.757797,1.263842,0,-2.740793,1.260938,0,-2.721228,1.257875
75274,23,6,5dd0d97d878f3300066c750b,249.79349,74.839640,20039,0.222849,1.695205,39456,-0.032089,...,1.266885,0,-2.757797,1.263842,0,-2.740793,1.260938,0,-2.721228,1.257875
75275,23,6,5dd0d97d878f3300066c750b,249.43129,76.241234,20039,0.222849,-0.561130,45394,-0.007358,...,1.266885,0,-2.757797,1.263842,0,-2.740793,1.260938,0,-2.721228,1.257875
75276,23,6,5dd0d97d878f3300066c750b,242.54440,72.935265,24515,-0.110562,0.034015,19357,0.000886,...,1.238967,24138,0.307601,1.235934,16364,0.301696,1.233044,47275,0.322206,1.229992


In [15]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [16]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.timegap_feats = df[TIMEGAP_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'TIMEGAP_FEATS':self.timegap_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [17]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        self.timegap = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])
        
        x_timegap = self.timegap(x['TIMEGAP_FEATS'])

        x = torch.cat([x_bssid, x_site_id, x_rssi, x_timegap], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [18]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [19]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [20]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [21]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [22]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + TIMEGAP_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + TIMEGAP_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params
66.851    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 185.5595245361328


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.49281600560897


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 163.1592658996582


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 161.88499822372043


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 160.63703464605868


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 159.40410357744264


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 158.18332325862005


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 156.97271748078174


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 155.77053212874975


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 154.5778028439253


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 153.39177748851287


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 152.21167782514524


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 151.03959638155422


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 149.8734328441131


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 148.7141954275278


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 147.56199849446614


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 146.41626508663862


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 145.2783009455754


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 144.1468610127767


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 143.023731642503


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 141.9055333259778


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 140.7952689537635


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 139.6915391384027


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 138.59554754403922


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 137.50636676886143


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 136.42446470994216


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 135.3503657023112


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 134.28357704358223


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 133.22506742232886


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 132.1735729315342


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 131.1302962474334


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 130.09518945156


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 129.06779008523011


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 128.0494493337778


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 127.03918213477502


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 126.03754454392653


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 125.045999223758


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 124.06079664963941


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 123.0855005508814


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 122.11877938295022


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 121.16257249881059


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 120.21591268686147


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 119.2787370730669


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 118.34958156683506


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 117.43351023747371


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 116.52448973044372


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 115.62861360403208


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 114.74128827804175


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 113.86376717396271


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 112.99679641723633


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 112.14091798831255


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 111.29384565108862


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 110.45921993744679


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 109.63447441687951


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 108.82009496444311


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 108.01541457543006


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 107.22201512654622


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 106.43766715220916


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 105.66443963662171


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 104.90189098455967


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 104.15009600321451


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 103.41025764269708


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 102.68110907138923


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 101.9625509408804


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 101.25589127173791


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 100.5599531418238


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 99.87753344804813


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 99.20540806696964


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 98.54445948478502


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 97.89559264549843


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 97.25899879259941


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 96.63470685909957


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 96.02259799272586


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 95.42294062100925


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 94.83550033569335


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 94.26118574876051


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 93.69676388471554


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 93.14559798607459


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 92.60497321104391


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 92.07669581877879


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 91.56034382551144


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 91.05475535270494


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 90.56067769955366


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 90.07866561107147


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 89.60775338197367


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 89.14845260228866


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 84.37401091746796


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 83.07056876940605


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 81.8412696838379


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 80.70980229500013


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 79.69902072808682


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 78.7450178879958


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 77.90850724440355


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 77.12694571079352


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 76.3374066181672


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 75.60681993533403


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 74.88859379108135


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 74.22427925696739


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 73.13610981672238


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 72.37961263412085


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 71.59956107506385


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 70.73331382458028


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 69.8246881436079


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 68.93263161488068


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 68.14494942396115


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 67.48663334968762


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 66.76192816220797


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 66.10014690496982


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 65.47248851091435


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 64.86462938846687


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 64.25762158418313


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 63.718710341820355


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 63.21515098963028


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 62.67867131844545


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 62.19690207456931


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 61.66712273328732


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 61.2163592069577


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 60.74081598917643


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 60.296783344562236


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 59.82356040172088


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 59.422240472451236


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 58.99092879662147


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 58.54978498801207


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 58.17351997571114


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 57.71753078851944


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 57.334421240977754


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 56.96591331775372


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 56.59437874525021


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 56.243258099678236


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 55.94543704497508


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 55.59648592044145


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 55.329027268825435


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 54.90747779944004


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 54.59842000129895


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 54.238987467839166


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 53.92442366771209


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 53.62448058006091


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 53.346665142744015


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 53.079868820386054


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 52.67180675115341


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 52.31562014359694


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 52.022954681592104


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 51.732450817792845


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 51.39632501357641


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 51.125788864722615


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 50.74556478842711


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 50.427762134258565


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 50.13606571784386


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 49.81954657725799


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 49.47403568854699


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 49.16253327100705


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 48.823203140650044


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 48.476468311212


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 48.14154616624881


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 47.81264220017653


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 47.44823428422977


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 47.17746446560591


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 46.8264436232738


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 46.49703626877223


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 46.084849113073105


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 45.76540428552872


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 45.4150907320854


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 45.03681356968024


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 44.692476751865485


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 44.31415840051113


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 44.01554822188157


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 43.683074750655734


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 43.307819439814644


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 42.97896277109782


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 42.63902771289532


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 42.27151792966402


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 41.93612194550344


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 41.6005220804459


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 41.23411494524051


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 40.91912106244992


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 40.55545764825283


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 40.20214519989796


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 39.88492977924836


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 39.54463704916147


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 39.203482539837175


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 38.87465975834773


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 38.511341921488444


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 38.18541737580911


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 37.83966203836294


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 37.55388233967316


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 37.14115958091541


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 36.854203082353635


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 36.52819104805971


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 36.217924264761116


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 35.897335255451694


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 35.56076274040417


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 35.28094641367594


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 34.917859886854124


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 34.6259804334396


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 34.294181121923984


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 34.042096162453674


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 33.709826826437926


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 33.41528062820434


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 33.12467867533366


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 32.825547384604434


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 32.54885458823962
fold 0: mean position error 32.55860588925218
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1115.9071
Loss/xy,1115.9071
Loss/floor,5.50095
MPE/val,32.55861
epoch,199.0
trainer/global_step,23599.0
_runtime,396.0
_timestamp,1617444856.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▄▅▅▅▅▅▅▆▆▆▆▇▇▇██████
MPE/val,██▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params
66.851    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 195.19263458251953


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 166.19086111997947


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 165.2265829233023


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 164.31361410678963


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 163.42023311517178


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 162.53902415740185


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 161.66586931668795


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 160.79959548558946


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 159.93873279278097


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 159.08202657455053


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 158.229982425005


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 157.38134947556716


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 156.5369358551808


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 155.69534910153118


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 154.85691459851387


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 154.02244207920174


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 153.18983586629233


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 152.36098333505484


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 151.5349292363876


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 150.7119874611879


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 149.89121660085823


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 149.07508487212354


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 148.2615723830003


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 147.45168779813326


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 146.64640068641074


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 145.8438414940467


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 145.04429720365084


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 144.24821180685973


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 143.45489977323092


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 142.66592678167882


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 141.8786236983079


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 141.09559699816583


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 140.31518114530124


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 139.53891913585173


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 138.76473416059446


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 137.9945147196452


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 137.22750044602614


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 136.4644401941544


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 135.7042106237167


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 134.948194924379


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 134.19509056286932


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 133.445570901724


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 132.69971084594727


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 131.95742133702987


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 131.21847252478966


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 130.4833818582388


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 129.75154690864758


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 129.02384357941457


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 128.2993466499524


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 127.57844720498109


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 126.86230906951121


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 126.14911100925544


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 125.44103751549353


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 124.7368649996244


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 124.03614157652243


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 123.34028795682467


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 122.64901030124763


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 121.96131396171374


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 121.27845376821665


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 120.60057482597155


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 119.92703774281037


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 119.25849173130133


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 118.59482409159342


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 117.93558400472006


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 117.28161525237255


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 116.6323509803185


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 115.98848323332957


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 115.34937254098745


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 114.71532909686748


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 114.08565899775577


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 113.4626636113876


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 112.84382497347318


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 112.23021593338403


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 111.62289047241211


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 111.02113206325433


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 110.42416296738844


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 109.83263546870305


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 109.24498260693672


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 108.66320790021847


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 108.0870896266057


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 107.51544151306152


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 106.94990628560384


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 106.3902188227727


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 105.83527931800255


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 105.28723598382412


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 104.74406447777382


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 104.20700618059207


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 103.67459020370093


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 103.14981818565956


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 102.62970759073893


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 102.11470742836975


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 101.60621535472382


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 101.10323439378004


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 100.60600314996182


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 100.1145483946189


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 99.6292105845916


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 99.14902282127967


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 98.67634936601688


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 98.20978846427721


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 97.749773534139


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 97.29634160750952


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 96.84910579583584


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 96.40821931301019


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 95.9731169480544


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 95.5453670501709


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 95.1229498056265


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 94.70615629538513


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 94.29624473376153


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 93.8929048000238


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 93.49470022152632


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 93.10216224866035


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 92.71646153376652


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 92.33599458352113


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 91.96232235248272


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 91.59476042527419


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 91.23339356642504


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 90.87806076636681


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 90.52887418698042


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 90.18515127133101


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 89.84770309252617


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 89.51697201851087


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 89.19184874510152


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 88.87354719699958


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 88.56121599246293


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 88.25501803862743


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 83.59011721488757


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 81.65792674529247


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 79.92322491376828


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 78.5023435152494


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 77.29340946490947


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 76.35325103172889


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 75.49535585061098


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 74.67535581344214


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 73.93272788218964


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 73.20831639216496


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 72.60495992807242


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 71.86926660782251


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 71.10961628449269


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 70.42363207890438


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 69.82882071274977


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 69.21506697825897


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 68.65819897773939


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 68.11629238617726


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 67.58139214148888


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 67.0540532430013


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 66.57890443068284


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 66.07593475733047


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 65.62203906132625


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 65.12999704801119


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 64.73435927170974


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 63.633010795788884


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 62.60894780281262


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 61.78500460111177


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 61.07500567313953


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 60.52436211414826


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 59.971153342418184


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 59.39466082744109


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 58.80680585029798


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 58.35721853207319


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 57.84593733274019


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 57.34411091437707


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 56.93336808131291


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 56.49130479372465


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 56.0949508373554


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 55.59253341968243


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 55.38487830039783


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 54.62552398779453


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 54.19210489224165


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 53.77212247114915


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 53.38702213091728


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 53.00848419971955


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 52.657219123840335


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 52.22328147399119


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 51.880706767546826


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 51.49815005277976


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 51.1347687794612


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 50.72633592165434


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 50.393463917267624


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 49.97332457517966


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 49.68598348177397


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 49.32784061920949


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 48.93415546906301


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 48.58269143715883


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 48.28024277320275


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 47.882111573830635


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 47.56424324818146


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 47.19476703252548


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 46.811170128064276


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 46.53195200210963


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 46.094580508501096


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 45.704077627719975


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 45.43412366035657


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 45.127398451780664


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 44.66794847830748


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 44.28384866959009


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 43.94549875992995


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 43.54736380454822


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 43.2178205587925


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 42.87847757094946


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 42.48652781462058


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 42.1977208113059
fold 1: mean position error 42.31937834532694
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1667.77417
Loss/xy,1667.77417
Loss/floor,6.01595
MPE/val,42.31938
epoch,199.0
trainer/global_step,23599.0
_runtime,404.0
_timestamp,1617445267.0
_step,199.0


0,1
Loss/val,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
Loss/floor,▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▁▂▂▂▂▆▆▆▆▇▇▇███
MPE/val,███▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params
66.851    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 202.7873992919922


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.21426900227866


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 163.24711068960335


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 162.33288850050707


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 161.4391528007312


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 160.5567389072516


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 159.68321251502402


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 158.81673118395685


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 157.9550444187262


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 157.0986355121319


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 156.246667773907


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 155.39814807207156


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 154.5533763200809


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 153.7115255893805


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 152.87380396525066


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 152.0384076240735


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 151.20648330297226


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 150.37744690332656


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 149.55123220590445


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 148.72853325085762


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 147.90874835283327


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 147.09072787945087


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 146.2778102776943


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 145.46691378079927


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 144.66012048965845


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 143.85551667824768


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 143.05421346028646


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 142.25589998685396


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 141.46042130299105


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 140.66829296014248


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 139.87928228133762


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 139.09354523878832


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 138.31114239814954


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 137.53191960652669


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 136.75644318018203


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 135.98474114246858


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 135.2164058196239


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 134.45064995594512


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 133.68919194539387


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 132.9324977483505


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 132.17723134358724


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 131.4269759642772


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 130.6790305504432


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 129.93687866406563


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 129.19785396869366


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 128.46320161085862


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 127.73259546328813


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 127.006071579762


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 126.28391866439429


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 125.56569931812776


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 124.85147000826322


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 124.14240685487405


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 123.4361743829189


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 122.7348885560647


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 122.03800164247171


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 121.34600888765775


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 120.65818017812876


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 119.97517104515663


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 119.29645845462115


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 118.6234616010617


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 117.95438194274902


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 117.29013635684282


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 116.63053532135793


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 115.97757704319098


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 115.32755506466597


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 114.68392592210036


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 114.04507520626753


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 113.41135497459999


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 112.78270252912472


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 112.15968262110002


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 111.54179783356496


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 110.92780845837714


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 110.3208492963742


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 109.71861852010092


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 109.1216651134002


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 108.5306891808143


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 107.94394761109965


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 107.3631224999061


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 106.78729679401104


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 106.21709310091458


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 105.6522385328244


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 105.0933512956668


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 104.53958848806528


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 103.99261577312764


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 103.45028778467423


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 102.91487152882111


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 102.38560877090845


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 101.86132341042543


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 101.34314621167304


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 100.8315976754213


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 100.32570238357934


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 99.82615200923038


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 99.33363918402256


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 98.84659847357334


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 98.36599300091083


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 97.89167161599183


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 97.4226200788449


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 96.96098880278758


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 96.50552505101912


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 96.05704161815154


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 95.61429551931528


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 95.17818351158729


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 93.05290657434709


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 91.17212473551432


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 89.64803925538675


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 88.27983639056865


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 87.16740546593299


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 86.18573976174379


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 85.28919404836802


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 84.42295037294046


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 83.53811508569963


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 82.7926750769982


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 82.05824104700334


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 81.3403042133038


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 80.67564451755622


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 79.86784210205079


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 78.9878869472406


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 78.14160307859763


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 77.40277399894518


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 76.72232347146058


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 75.98468372638408


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 75.33805535145295


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 74.75108043964092


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 74.12004393553121


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 73.54809777186466


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 72.9744350091005


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 72.41855079944317


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 71.91657361739722


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 71.32866699023126


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 70.82382855048546


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 70.3039040981195


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 69.79741049546462


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 69.2974968347794


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 68.72404361627041


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 66.84335788824619


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 66.09858334370148


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 65.44467841906425


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 65.01330276391445


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 64.29910759559044


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 63.55373148795886


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 62.993853681515425


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 62.43905252309946


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 61.91642558269012


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 61.45691933265099


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 60.927109004289676


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 60.455544887444916


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 59.79877605927297


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 59.28406149057241


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 58.836335920676206


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 58.381406485728725


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 57.915101525722406


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 57.49726363451053


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 57.05693028278839


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 56.74373225187644


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 56.18772815313095


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 55.822800817245096


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 55.22448208148663


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 54.839766903412645


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 54.45373647151849


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 54.00973679469182


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 53.66471533653064


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 53.30189695602808


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 52.95690650939942


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 52.61690600468562


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 52.26650388668745


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 51.849764046302205


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 51.06516510401016


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 50.68286017393454


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 50.30803575760279


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 49.92285817953257


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 49.60352731362367


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 49.28734121567164


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 48.99005475166516


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 48.59141403589493


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 48.33826514513065


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 48.032342617328354


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 47.750149863805525


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 47.44674189640925


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 47.09577240577111


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 46.839841681260324


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 46.52274045699682


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 46.21478223556127


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 45.922936806311974


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 45.656287677471454


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 45.33815621596116


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 45.06191225785475


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 44.86367265261137


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 44.54815581395076


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 44.23786595173371


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 44.02968482971191


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 43.701996191954


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 43.427185386266466


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 43.0828848423102


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 42.80061970246144


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 42.49534581013214


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 42.17945649562738


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 41.89260968428392


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 41.62549772507105


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 41.36013916700314


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 41.04127777295235


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 40.73160143632155
fold 2: mean position error 40.82932465039846
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1579.53662
Loss/xy,1579.53662
Loss/floor,5.84241
MPE/val,40.82932
epoch,199.0
trainer/global_step,23599.0
_runtime,409.0
_timestamp,1617445683.0
_step,199.0


0,1
Loss/val,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Loss/floor,█████████████████████▆▆███▄▄▅▄▄▂▂▂▁▁▁▁▁▁
MPE/val,███▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params
66.851    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 192.10541534423828


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.2591101604959


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 162.9053963886943


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 161.6257924950641


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 160.37541360778317


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 159.1437288981513


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.9259979174333


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 156.71801114289656


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 155.520618500149


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 154.3310998102701


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 153.1491952340192


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 151.97499764975333


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 150.80810608794724


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 149.64785778518655


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 148.49438053935813


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 147.3485033720204


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 146.20916222712077


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 145.0773694270283


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 143.9510769031666


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 142.83247668009832


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 141.72008714783402


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 140.61703345748728


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 139.51842634904403


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 138.42724842556817


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 137.34362411007595


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 136.26702493123963


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 135.1967990881387


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 134.1369044551143


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 133.0826322276043


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 132.03544437512875


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 130.99761704461778


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 129.96791414392933


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 128.94549145291586


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 127.93131406172846


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 126.92597303528716


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 125.92741007167551


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 124.93887411845479


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 123.95804871453178


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 122.98596275071591


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 122.02342197891211


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 121.06906350607265


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 120.12520017700687


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 119.1907538458537


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 118.26526176772065


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 117.35018225345827


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 116.44425233604443


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 115.54827045741672


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 114.66244084577822


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 113.7868457161478


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 112.92212320963542


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 112.06752960770218


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 111.22468268107292


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 110.3922255769444


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 109.57063240014413


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 108.76011909878004


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 107.95986645463584


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 107.17008524724251


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 106.39175039343597


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 105.62528541552656


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 104.8681135819538


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 104.12432376413146


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 103.3914390809678


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 102.67034264593693


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 101.96096747393769


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 101.26240934596161


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 100.57535092573427


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 99.90137544592031


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 99.2383909744533


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 98.58821708814342


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 97.94846959444251


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 97.32019120061264


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 96.70544338134177


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 96.10081742266718


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 95.50744805758318


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 94.92548296155945


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 94.3552670655427


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 93.79773853443288


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 93.25147242215904


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 92.71571767257224


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 92.19286185739121


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 91.6791878864577


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 91.17760327534208


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 90.68849012955374


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 90.20878849735966


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 89.7418634356316


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 89.28624357509152


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 88.84230361791047


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 88.40952435553362


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 87.989299948012


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 87.5802469453182


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 87.1821071440471


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 86.79607152892771


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 86.42085937647428


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 86.05968106342014


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 85.73208785064747


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 80.35062530382434


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 78.73748932130479


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 77.70640188023664


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 76.43431806395402


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 75.4763528323212


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 74.55738916719594


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 73.78116295610265


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 73.02517803210567


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 72.18347080402712


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 71.09788144490953


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 69.87777910217187


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 69.00062134814915


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 68.02845225556845


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 67.23356115829542


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 66.47841386902543


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 65.83598377946494


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 65.1480177813298


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 64.58761812784438


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 63.937808736779644


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 63.38364069780482


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 62.91094448194028


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 62.34320214005868


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 61.86288745959984


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 61.351852260657175


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 60.905825190766805


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 60.433552547890976


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 60.011341713553655


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 59.58133258143869


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 59.20874471280502


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 58.80796877911701


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 58.441814212215505


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 57.9119249743156


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 57.33841202070939


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 56.84530389020984


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 56.42714022453664


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 56.02486685232264


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 55.60812064399658


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 55.15906490749783


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 54.83232912871381


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 54.39942753648988


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 54.01938472661803


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 53.685085626500815


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 53.35102008322011


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 52.929268053457164


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 52.59393687870192


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 52.20277779467056


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 51.81282283703102


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 51.45704677323788


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 51.08443452248443


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 50.69793915372348


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 50.39824644847385


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 49.93580254970925


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 49.552324507893


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 49.182941920822564


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 48.80377607422367


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 48.42550334408279


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 48.043028180034845


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 47.667955162674915


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 47.25207488724959


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 46.89754225443719


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 46.45544690026178


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 46.05540924594406


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 45.6324598002165


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 45.23040419745944


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 44.81808803929991


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 44.470853200473265


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 44.046136999207036


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 43.664203517870824


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 43.22513635423448


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 42.83253591018407


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 42.45436378408361


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 42.07436207619266


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 41.73560938520324


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 41.37640466060423


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 40.94416764684729


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 40.58310902299128


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 40.170812486104914


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 39.79950752872584


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 39.443705336559994


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 39.116035296951516


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 38.76235760122106


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 38.41864414291874


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 38.05251252962195


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 37.698807592591606


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 37.318296427427285


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 36.974101913071294


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 36.66327495498165


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 36.32852369974969


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 35.92114825932108


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 35.657644474871105


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 35.31384846630496


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 34.906372407400276


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 34.585596378068416


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 34.24794060911342


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 33.955869260502325


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 33.61757106458507


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 33.28872579706656


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 32.97700249139046


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 32.68749367488179


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 32.36083585689993


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 32.058981627619396


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 31.740942801446348


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 31.463012212486085


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 31.165563969297303


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 30.850564475681473


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 30.54070802832955
fold 3: mean position error 30.45215899482389
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1060.95715
Loss/xy,1060.95715
Loss/floor,5.24985
MPE/val,30.45216
epoch,199.0
trainer/global_step,23599.0
_runtime,410.0
_timestamp,1617446099.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▇▇▇▇███▇▇▇▇▇▇▇▇▇▇▇▇
MPE/val,██▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params
66.851    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 190.52713775634766


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 161.48732491338117


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 160.35227447116625


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 159.30070568705142


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 158.279810418023


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 157.27571634486102


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 156.2835981200091


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 155.29985001052634


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 154.3240021009952


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.3542618117279


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 152.39056374370188


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 151.4320735261828


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 150.4793109199659


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 149.53049063935947


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 148.58648850937013


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 147.64659446433737


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 146.7123147285695


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 145.78116542306307


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 144.8551712748701


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 143.93397444204433


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 143.0179409801096


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 142.10659854415917


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 141.19993640574089


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 140.29783459661854


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 139.3994244598536


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 138.5063631410952


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 137.61715356332286


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 136.7317563565265


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 135.85230262375492


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 134.97649974730857


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 134.10444439874178


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 133.2376664817429


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 132.37471594818166


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 131.51784236373533


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 130.66407376963542


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 129.81682846204478


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 128.97313438329527


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 128.1344085994358


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 127.30071221319373


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 126.47146474374473


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 125.64694087386323


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 124.82750314046028


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 124.01272892360718


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 123.20234574027683


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 122.39747137416772


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 121.59697912894968


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 120.8023726526282


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 120.01265436562555


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 119.22966365353497


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 118.45021392650266


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 117.67782553397898


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 116.91000829810297


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 116.14848437224802


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 115.39236712739856


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 114.6415953318278


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 113.89800635259509


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 113.15901566841752


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 112.42673034852254


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 111.6997629414434


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 110.98001588892053


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 110.2656956732561


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 109.55879430294804


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 108.85747115339443


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 108.1637381469186


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 107.47582360603958


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 106.79533316140781


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 106.12081890198344


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 105.45376169017355


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 104.79354795373004


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 104.14130993411545


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 103.49621823075888


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 102.85938294145029


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 102.22895549215173


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 101.60821469993407


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 100.99423127473841


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 100.38849447192008


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 99.7911466171584


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 99.20099906614246


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 98.61932316839983


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 98.04576808112446


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 97.48174320964229


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 96.92470281972594


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 96.37654180664948


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 95.8357379446475


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 95.30363431737042


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 94.78073977041936


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 94.26524027825938


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 93.75911741210642


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 93.26177880721775


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 92.77181662468902


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 92.29162181049537


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 91.81930911790538


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 91.35488536423146


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 90.89887434057952


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 90.45177335201637


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 90.0119338835687


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 89.58118231461627


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 89.15864961473646


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 88.74478344756048


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 88.33901323757694


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 87.89607531168227


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 87.26777082864022


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 86.77328003040257


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 86.32146914569651


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 85.89744116618822


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 85.49271081872224


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 85.10749360605138


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 84.73787522062588


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 84.3818554016703


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 84.04109342209576


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 83.71234762772269


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 83.39703651035082


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 83.09201178312686


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 82.80018776664795


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 82.521160187775


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 82.22853806659987


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 75.87306946587064


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 74.17217119312133


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 72.99467171286615


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 71.96820973474622


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 71.10738441786712


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 70.39692986898376


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 69.64886535976244


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 68.98400621183828


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 68.40996000286846


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 67.84430263107717


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 67.20049423180916


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 66.65279841246428


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 65.7813392823445


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 65.18782028898525


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 64.55190402843334


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 63.99428504876276


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 63.239084597601405


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 62.557437150374696


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 61.822090153840044


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 61.02315235045797


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 60.321345555340805


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 59.591727366731554


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 58.98115469399666


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 58.30900497067954


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 57.97069346478596


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 57.38694213861045


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 56.674688597846526


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 56.07727222841911


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 55.551641310201944


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 54.99261061673003


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 54.27172971439823


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 53.70971820772942


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 53.170727656697686


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 52.65560181828129


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 52.09599447265723


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 51.8733200189956


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 51.234368907541466


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 50.7115649174185


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 50.13992108177639


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 49.70837828226136


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 49.25067711045393


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 48.90499313557205


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 48.42233504174028


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 48.07507585104729


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 47.47837380165063


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 47.07182843950059


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 46.67099516741128


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 46.21873103355248


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 45.83623447633212


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 45.59003902318589


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 45.04945288441607


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 44.66860778596666


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 44.26607453535144


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 43.92192772482905


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 43.52336351176583


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 43.174225428023774


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 42.806672122666605


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 42.40017219795313


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 42.10437973655173


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 41.74414546094271


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 41.5333326868005


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 41.054153918144976


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 40.694321082294856


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 40.37793222356725


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 39.99536452577502


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 39.62126825888568


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 39.32350489329217


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 38.99005393290865


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 38.65887378072201


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 38.27818973820759


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 37.945354926528566


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 37.726969747343695


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 37.24978433108368


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 36.94572004327451


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 36.598408816056555


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 36.30268292726526


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 36.001399492027296


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 35.66794805327092


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 35.31090112093374


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 34.98676767195673


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 34.65738484632949


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 34.30902214948682


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 33.959087824552725


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 33.67124149426938


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 33.3629050232554
fold 4: mean position error 33.39957943653497


In [23]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,wifi_bssid_0,wifi_bssid_1,wifi_bssid_2,wifi_bssid_3,wifi_bssid_4,wifi_bssid_5,wifi_bssid_6,wifi_bssid_7,wifi_bssid_8,wifi_bssid_9,...,wifi_timegap_97,wifi_timegap_98,wifi_timegap_99,site_id,x,y,floor,oof_x,oof_y,oof_floor
0,41663,41490,6097,13281,7849,17205,40545,38814,28277,40389,...,-0.069780,-0.072009,-0.074546,0,230.03738,153.496350,-1,177.019577,164.333542,0.220378
1,38161,41490,7941,13281,31760,41300,17205,10708,6289,40545,...,-0.707686,-0.709593,-0.711877,0,231.40290,158.415150,-1,177.019119,164.348785,0.220364
2,13281,19970,41490,30748,33953,29390,10708,53418,38161,46296,...,0.799367,0.796698,0.793818,0,232.46200,164.416730,-1,177.035110,163.737869,0.220855
3,41490,30748,24148,38161,29390,44510,6289,45919,16936,8134,...,-0.550203,-0.552189,-0.554536,0,233.94418,171.414170,-1,177.019058,164.354492,0.220360
4,27602,53675,21370,44510,47170,8942,41125,17256,4140,20358,...,-0.805365,-0.807223,-0.809469,0,210.86192,165.376080,-1,177.018921,164.355209,0.220359
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15050,20039,46329,54181,39456,54484,16364,44446,15563,39565,33741,...,1.263842,1.260938,1.257875,23,249.43129,76.241234,6,163.109192,118.054108,0.383541
15051,21863,15563,39628,39565,53086,1942,19357,39456,18998,47308,...,1.263842,1.260938,1.257875,23,237.22395,73.177680,6,162.498444,117.116028,0.372976
15052,20039,54181,486,15563,39628,53754,21863,39565,53086,1942,...,1.263842,1.260938,1.257875,23,242.54440,72.935265,6,163.732056,119.006485,0.394242
15053,20039,53086,2282,39628,16364,39565,1942,19357,24138,18998,...,1.263842,1.260938,1.257875,23,249.43129,76.241234,6,163.867401,119.213043,0.396559


In [24]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,96.238533,100.854156
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,86.830490,98.938667
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,87.715874,102.928299
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,86.239601,104.513565
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,85.633614,106.807449
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,162.991730,121.536186
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,163.648239,122.691208
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,166.087997,126.914223
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,168.254501,130.636414


In [25]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,96.238533,100.854156
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,86.830490,98.938667
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,87.715874,102.928299
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,86.239601,104.513565
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,85.633614,106.807449
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,162.991730,121.536186
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,163.648239,122.691208
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,166.087997,126.914223
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,168.254501,130.636414


In [26]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [27]:
print(f"CV:{np.mean(val_scores)}")

CV:35.91180946326729


In [27]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,878.90088
Loss/xy,878.90088
Loss/floor,5.63142
MPE/val,26.88436
epoch,199.0
trainer/global_step,23599.0
_runtime,384.0
_timestamp,1617432959.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███▇▇▇▇▇▇▇▆▆▆▄▃▃▃▃▃▃▃▂▂▂▂
MPE/val,██▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.57MB of 0.57MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
CV_score,34.50177
_runtime,2.0
_timestamp,1617433227.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
