# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path
from glob import glob


sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'unified-ds-wifi-and-beacon'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 11
IS_SAVE = True

utils.set_seed(SEED)

## read data

In [5]:
wifi_path_list = sorted(glob(str(WIFI_DIR / '*train.csv')))
train_df = []
for path in wifi_path_list:
    train_df.append(pd.read_csv(path))
train_df = pd.concat(train_df).reset_index(drop=True)
train_df = train_df.drop('Unnamed: 0', axis=1)
train_df = train_df.rename(columns={'site':'site_id'})
train_df

Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,230.03738,153.496350,c08ad78a45798cfe176a42b35c7381ae602711c5,-46,434,30f85a5e14351468a6dd13718a9da3b0d7b73685,-46,...,1421,ff9a29cb2dcb4100880d02b305d5691e578e66cf,2.083036,2827,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,2.341936,398,aeb5121f95a613552e00b083ee11cd70be497ab5,3.696872,1088
1,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,231.40290,158.415150,56bc1d2557bf8225384acb992f4b1993a894db63,-47,1856,25d990fc62eb6d5a42994cd7da0f9b6a70fea57a,-47,...,2562,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.629927,1757,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2141,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2664
2,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,232.46200,164.416730,c0beb6ad539d9bd9333c62866c08c844ac69ab28,-48,2767,56bc1d2557bf8225384acb992f4b1993a894db63,-48,...,1351,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,1527,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,2667,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.718217,2644
3,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,233.94418,171.414170,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-51,136,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-52,...,100,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,148,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,220,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,324
4,5a0546857ecc773753327266,-1,5e15730b1506f2000638fc29,198.36833,163.520630,24677fe9a6f29ace69792429fd85fa8f3efd0192,-49,1919,e43e8b111c0061d7cba667a23d6f1c0143bc73ac,-50,...,567,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,10.591243,2147,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,11.687424,770,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,12.882066,90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,5dc8cea7659e181adb076a3f,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-40,1848,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-43,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000
75274,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,1875,24489f6ba05ba6afd52ac54e453fd1a2f1128b1a,-79,...,2915,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,234,-,-99.000000,3000,-,-99.000000,3000
75275,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.43129,76.241234,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,404,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,...,1048,3873a11805bd7adb86762e806d0f20e56e709e76,1.850688,636,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,2045,-,-99.000000,3000
75276,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,242.54440,72.935265,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,2666,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-50,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000


In [6]:
# train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test.csv')
test_df = test_df.drop('Unnamed: 0', axis=1)
test_df = test_df.rename(columns={'site':'site_id'})
test_df['site_path_timestamp'] = test_df['site_id'] + '_' + test_df['path'] + '_' + test_df['timestamp'].map('{:013}'.format)
test_df

Unnamed: 0,timestamp,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,...,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9,site_path_timestamp
0,10,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,889bfa434d66eed8c386ccbc90f445932c43f8dd,-58,1170,29c7d9e757292e7b2b3d00dc4dae7514531b20b4,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
1,4048,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-56,1000,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
2,12526,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-62,1952,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
3,25542,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,98d67fadac518296992afddd24e97a2855af9472,-55,2019,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
4,37134,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-54,1977,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10128,35117,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,6c381e7e7d8e984394ce02a3da912acc1b7d294e,-60,2011,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10129,41230,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-59,244,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,5a95e3ee4af260d25d5c13fcf02760820cb6dbdc,67.382098,1334,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10130,51634,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-49,1166,46e733fa58deea74d962874847a529fb4897e655,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10131,60483,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,e5c220800a2d5ec83e355c3b1a2d7a141947e95f,-53,1309,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,216d1cba23183d9f79bdd0ba7e45f1d55e611c6e,18.299263,118,9b91bb4419360103cea4e1f7d434878550956b32,18.299263,2225,9e9913d8d1bf56161d199b1655220e8be5711f8f,18.800756,2213,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...


In [7]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,75.0,75.0
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,75.0,75.0


BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [8]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'wifi_bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'wifi_rssi_{i}' for i in range(NUM_FEATS)]
TIMEGAP_FEATS  = [f'wifi_timegap_{i}' for i in range(NUM_FEATS)]

In [9]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 52185
BSSID TYPES(test): 25967
BSSID TYPES(all): 78152


## preprocessing

In [10]:
beacon_columns = [s for s in list(train_df.columns) if 'beacon_' in s]

In [11]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])

ss_gap = StandardScaler()
ss_gap.fit(train_df.loc[:,TIMEGAP_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])
    
    # gapの正規化
    output_df.loc[:,TIMEGAP_FEATS] = ss_gap.transform(input_df.loc[:,TIMEGAP_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])
    
    output_df = output_df.drop(beacon_columns, axis=1)

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)
train    

  return self.partial_fit(X, y)
  return self.partial_fit(X, y)


Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,wifi_timegap_96,wifi_bssid_97,wifi_rssi_97,wifi_timegap_97,wifi_bssid_98,wifi_rssi_98,wifi_timegap_98,wifi_bssid_99,wifi_rssi_99,wifi_timegap_99
0,0,-1,5e15730aa280850006f3d005,230.03738,153.496350,39873,0.145725,-1.134794,10121,0.161037,...,434,902588d66481d2a6efd19a12e75030f8f54747b1,-81,434,d3d056da278ff2975a96a486e4c1f70128570110,-81,2331,1b8286ce49240e893317288ec4c13ad463a63829,-81,2331
1,0,-1,5e15730aa280850006f3d005,231.40290,158.415150,17965,0.137646,0.537887,7792,0.152948,...,74,33e2552f3ae9950b5f73ef586b66e305927affb7,-74,74,932b23062713bdd2a66d05f2ef23953dbbe9400b,-74,74,a46b8cada5b70a31969bb3083a7b4a2cb8eb8bab,-74,74
2,0,-1,5e15730aa280850006f3d005,232.46200,164.416730,39921,0.129567,1.609485,17965,0.144860,...,835,5c1dffe86791122abdb56b0f235d9125ce6aa250,-74,835,9d56a98f7b999c582139778e4bbfdeb1f642f35a,-74,835,932b23062713bdd2a66d05f2ef23953dbbe9400b,-74,835
3,0,-1,5e15730aa280850006f3d005,233.94418,171.414170,39705,0.105331,-1.485328,39705,0.112506,...,1832,24677fe9a6f29ace69792429fd85fa8f3efd0192,-76,2090,bb92225a7f19430eac540651807e1bee4424a233,-76,2090,08723f18742bd39eb373008eef7f9db4cea3ae72,-76,2090
4,0,-1,5e15730b1506f2000638fc29,198.36833,163.520630,7522,0.121489,0.611993,47156,0.128683,...,1919,a07406ff086df3d2df63d04334812809daa53b6a,-78,1919,eb5965ab8eeae41a0053ec4ae1d8133e8478707c,-78,1919,5ec31ac8363b39bc7af3e6137f71ee08c42b832f,-79,1919
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,23,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,19175,0.194199,0.528477,19175,0.185302,...,3000,-,-999,3000,-,-999,3000,-,-999,3000
75274,23,6,5dd0d97d878f3300066c750b,249.79349,74.839640,19175,0.145725,0.560237,7503,-0.105883,...,3000,-,-999,3000,-,-999,3000,-,-999,3000
75275,23,6,5dd0d97d878f3300066c750b,249.43129,76.241234,19175,0.145725,-1.170083,19175,0.161037,...,3000,-,-999,3000,-,-999,3000,-,-999,3000
75276,23,6,5dd0d97d878f3300066c750b,242.54440,72.935265,19175,0.145725,1.490680,19175,0.128683,...,2986,9c68d8d0b2e65047f6eaa88592a5480e0ce4c77c,-86,2986,acda33ccb27755bdf6d572b7da8d2b9e64a7681c,-86,2986,7150550f0583994950cd821d38d6d9cf7a554b64,-87,792


In [12]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [13]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.timegap_feats = df[TIMEGAP_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'TIMEGAP_FEATS':self.timegap_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [14]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        self.timegap = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])
        
        x_timegap = self.timegap(x['TIMEGAP_FEATS'])

        x = torch.cat([x_bssid, x_site_id, x_rssi, x_timegap], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [15]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [16]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [17]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [18]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [19]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + TIMEGAP_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + TIMEGAP_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 14.0 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
55.999    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 185.69168090820312


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.74712669176932


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 163.40157447228066


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 162.12075713720077


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 160.86778157552084


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 159.3467003455529


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.87401840992464


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 156.47655727679913


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 155.11467157999675


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.77547411796374


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 152.45269849728314


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 151.14563064575196


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 149.84935283171825


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 148.5647932394957


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 147.29179532955854


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 146.0295168363131


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 144.7771547170786


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 143.53591073843148


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 142.3037919068948


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 141.07972838572965


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 139.86660520113432


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 138.66282383845405


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 137.46809125068862


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 136.2831863109882


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 135.1075746291723


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 133.94101274930514


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 132.78594231238733


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 131.6407891982641


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 130.503441795936


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 129.37760046934469


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 128.2618070749136


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 127.15733685615736


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 126.06434099246295


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 124.97996050516764


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 123.9062993367513


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 122.84457076635115


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 121.79301972022424


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 120.75430250901442


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 119.72674132127028


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 118.71155472779886


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 117.70793135227301


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 116.71672366215633


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 115.73699111938477


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 114.77076073670999


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 113.81567405309433


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 112.87301241556803


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 111.94407682174293


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 111.02673699794671


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 110.12244638296275


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 109.2304015428592


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 108.35084797785832


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 107.48371498890411


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 106.62778008289825


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 105.78583266429412


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 104.95673982669146


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 104.13824692750589


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 103.33576140770545


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 102.54495012332231


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 101.76723318833571


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 101.00447363242125


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 100.25592793195675


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 99.5190069540953


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 98.7970196063702


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 98.08868720225799


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 97.3944442162147


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 96.71448543255147


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 96.05047383430677


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 95.40104180360453


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 94.76436430124136


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 94.1441050407214


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 93.53603128286508


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 92.94220583010943


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 92.36403873150167


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 91.79631685110239


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 91.2454282613901


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 90.70556203402005


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 90.18230503766965


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 89.66992184565616


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 89.17169041755872


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 88.68731226798815


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 88.21602102426382


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 87.7598990611541


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 87.31593457735501


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 86.8877513494247


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 84.10622631953312


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 81.04301093663925


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 79.83246840452536


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 79.00790069775704


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 78.03353357559595


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 77.28189942775629


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 76.5362176552797


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 75.7695183191544


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 75.016009418781


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 74.38662033081054


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 73.4624563461695


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 72.61206684601613


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 71.85086133908003


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 71.12650990608411


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 70.61918220030955


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 69.77512424175556


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 69.19720206627478


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 68.54951872703356


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 67.98402621929462


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 67.42454447624011


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 66.9702719370524


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 66.29620314378005


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 65.74667093814948


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 65.25913805839343


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 64.58013643118052


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 63.99025679368239


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 63.417406341357115


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 62.8541746873122


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 62.30857049990923


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 61.898785796532266


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 61.361383213141025


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 60.944270579020184


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 60.494705552321214


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 60.036930612417365


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 59.61774669549404


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 59.225923949021556


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 58.85077640826886


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 58.47121599148481


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 58.10131363990979


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 57.72575118236053


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 57.413335707248784


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 57.04532087766207


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 56.71539188287197


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 56.3603515625


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 56.089644950475446


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 55.76952822758601


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 55.45279029454941


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 55.136882102183804


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 54.85184909624931


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 54.5595655294565


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 54.27037916428004


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 54.05041020711263


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 53.73307814475818


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 53.447958965790576


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 53.169211426759375


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 52.90027900597988


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 52.620915456918574


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 52.32288125845103


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 52.07761727846586


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 51.80487481630766


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 51.47882303091196


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 51.200050813723834


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 50.921023275913335


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 50.62443709251208


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 50.319546802227315


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 50.0199703901242


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 49.713380065331094


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 49.44704491541936


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 49.13279171968118


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 48.81975126511011


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 48.4715350273328


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 48.17039218804776


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 47.85967876727764


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 47.511691024975896


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 47.22049215267866


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 46.85578761956631


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 46.5170957222963


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 46.181171588408645


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 45.87879311977289


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 45.50032372107873


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 45.19180674919715


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 44.863452089749856


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 44.50960692870311


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 44.16505702092097


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 43.85015865717178


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 43.53954820877467


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 43.16672294812325


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 42.80952029595008


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 42.45552509014423


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 42.143520159599106


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 41.79004967029278


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 41.51075338705992


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 41.157718795385115


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 40.80783239511343


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 40.47703799712352


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 40.18316489977714


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 39.85899452551817


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 39.49124938769218


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 39.16307134872828


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 38.90723740504338


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 38.50974032573211


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 38.18014497023363


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 37.83822765839406


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 37.51498315517719


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 37.16691078283848


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 36.91892699217185


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 36.55519945438092


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 36.22636210123698


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 35.942811080736995


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 35.64678740134606


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 35.31715745192308


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 35.05660471304869


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 34.74760788648556


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 34.434747334015675


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 34.17519970673781


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 33.87818120320638


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 33.58841432913756
fold 0: mean position error 33.60109476557105
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1186.86316
Loss/xy,1186.86316
Loss/floor,5.55224
MPE/val,33.6011
epoch,199.0
trainer/global_step,23599.0
_runtime,378.0
_timestamp,1617431388.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅█████▇▇▇▇▇▇▇▇▇▇▇▇▇▇
MPE/val,██▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 14.0 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
55.999    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 195.3228988647461


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 165.40287755330402


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 163.76973916078225


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 162.23580044477416


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 160.74256826547477


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 159.27480699588094


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.8253200433193


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 156.38988081125115


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 154.96851617862018


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.55846780630256


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 152.16061970637395


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 150.77372295673075


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 149.39831345386995


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 148.03524979811448


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 146.68521469311835


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 145.34680747007712


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 144.01865727351262


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 142.7026578463041


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 141.39744617755596


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 140.1037074064597


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 138.81973660786946


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 137.54767816983738


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 136.28735310481144


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 135.03760709518042


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 133.80119243524013


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 132.5730298653627


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 131.3575881860195


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 130.15257453918457


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 128.96021745143793


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 127.77811612838354


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 126.60818988115359


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 125.45095927898701


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 124.30566890423115


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 123.171766212659


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 122.05153703934107


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 120.94379430917593


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 119.8514389820588


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 118.77126970535669


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 117.70454190572103


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 116.6525447454208


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 115.6138263898018


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 114.59021180959849


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 113.58067186795749


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 112.58551238622421


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 111.60487871414576


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 110.64027589651255


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 109.68755223200871


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 108.75064449799366


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 107.82704029572317


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 106.91818771362304


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 106.02468845660869


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 105.14689217591898


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 104.28490924346141


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 103.43736737569174


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 102.6053079654009


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 101.78913637797038


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 100.98759312751966


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 100.200084568904


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 99.43125355060283


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 98.67717005411784


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 97.9402865287585


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 97.22023425958096


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 96.51664495223608


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 95.8290575956687


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 95.15831133524577


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 94.50526043818546


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 93.86629982972757


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 93.2430962684827


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 92.63536149049418


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 92.0434611882919


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 91.46807099366801


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 90.90773911109338


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 90.36269027514336


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 89.83382792350572


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 89.32070664625901


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 88.82284405048077


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 88.34169045472757


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 87.87579334943722


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 87.42821433238494


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 86.99570020039876


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 86.57875270354442


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 86.17875583355244


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 85.7949059999906


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 85.42631281339206


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 85.07369470840845


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 84.73620085104919


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 84.41473062955417


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 84.10709864298502


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 83.81462241930839


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 83.53650643764398


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 83.27347099108574


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 83.02393603202624


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 82.78746167696438


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 82.56542178422977


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 82.35640232379619


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 82.16065834241036


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 81.97749457726111


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 81.80818820855556


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 81.65039514394907


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 81.50515981820914


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 81.37209811088366


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 81.24965685330905


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 81.13911848801833


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 81.03898221529447


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 80.94944506914187


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 80.87106425945576


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 80.80176457136105


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 80.74134019704965


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 80.68951533390926


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 80.64578916109525


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 80.60958388890975


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 80.58013126666728


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 80.55642865498861


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 80.53862899389021


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 80.52587226476425


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 80.51759309035081


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 80.51314571087178


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 80.51219379718486


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 80.51416225922414


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 80.51876370356632


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 80.52524078564765


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 80.53332659403483


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 80.5424638503637


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 80.5529460711357


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 80.56283608461038


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 80.57435383918958


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 80.58434364123222


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 80.59464313800518


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 80.60427384987855


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 80.61333727714343


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 80.62190726842636


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 80.62902955275315


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 80.63551795176971


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 80.64251731481308


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 80.64702187562601


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 83.33854552048903


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 65.45522829691569


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 63.182225603935045


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 61.78521229181534


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 61.150337380629324


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 60.61976060133714


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 60.144906489054364


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 59.489444395212026


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 59.20017511416704


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 58.77694022349822


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 58.40323028564453


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 58.212361555833084


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 57.77543893960806


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 57.51611796648074


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 57.260622782584946


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 57.12978078402006


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 56.689428882109816


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 56.480382963327266


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 56.27261038560133


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 56.06663148831099


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 55.874071913499094


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 55.66121895129864


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 55.44480710151868


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 55.29489560738588


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 55.12595597780668


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 54.93957762106871


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 54.786648955711954


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 54.57110086587759


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 54.550080646612706


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 54.29132284995838


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 54.21806037609394


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 53.99776960519644


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 53.836660722585826


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 53.76841280032427


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 53.57467116331443


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 53.47661580794897


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 53.412293566190286


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 53.22228664006943


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 53.111073381472856


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 52.98140435830141


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 52.92321916726919


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 52.78003657414363


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 52.668781544612


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 52.534708448556756


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 52.32537491138165


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 52.053925744081155


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 51.63622098091321


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 51.203322493724336


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 50.7865879596808


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 50.46464720505934


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 49.817193926297705


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 49.400931915870075


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 48.9294717104007


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 48.56117462256016


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 48.16687004871857


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 47.784365409459824


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 47.40924986325777


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 47.02550900777181


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 46.69220091501872


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 46.339528176723384


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 45.986708743755635


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 45.64133911621877


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 45.30534353989821


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 44.9296258975298


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 44.597375175280455


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 44.26461399763058
fold 1: mean position error 44.48774507089203
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1771.61707
Loss/xy,1771.61707
Loss/floor,6.12709
MPE/val,44.48775
epoch,199.0
trainer/global_step,23599.0
_runtime,384.0
_timestamp,1617431778.0
_step,199.0


0,1
Loss/val,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███████████████████████████▁▅▆▇▆▆▆▆▅▆▇▇▇
MPE/val,██▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 14.0 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
55.999    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 202.85586547851562


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 163.72767404409555


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 162.36590985029173


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 161.081375787197


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 159.82695380968926


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 158.5923824994992


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.37102762858072


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 156.16140295175407


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 154.96069052280524


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.76986541748047


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 152.58638559977214


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 151.41074705857497


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 150.24087305313503


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 149.07998438126003


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 147.92528796073717


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 146.77868114373624


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 145.63882049169297


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 144.50641939212116


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 143.38054743057643


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 142.26322025396885


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 141.1516322600536


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 140.04714103111854


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 138.9511973063151


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 137.86108772082207


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 136.77873963576096


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 135.70452059232272


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 134.63883488972982


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 133.57987690460988


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 132.52869215745193


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 131.4861431219639


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 130.4511510604467


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 129.42358676225712


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 128.40489164499135


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 127.39543946095002


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 126.39324364295372


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 125.40111902677096


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 124.41715028224847


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 123.4422595684345


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 122.47553361745982


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 121.5178237915039


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 120.5687652881329


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 119.62976159315843


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 118.70027857071314


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 117.77945436330943


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 116.86923792912411


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 115.96787333366198


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 115.07805432050655


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 114.19644843859551


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 113.32582213572967


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 112.46499979068071


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 111.61498547089406


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 110.77536619137496


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 109.94526396531325


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 109.12653491680439


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 108.31662798172388


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 107.51790999388082


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 106.73029380211463


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 105.95169691428161


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 105.18508516947428


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 104.42913208007812


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 103.6842816964174


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 102.94975896003919


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 102.22788553482447


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 101.5158914712759


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 100.81715398935171


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 100.12985719534068


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 99.45370587079952


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 98.78986352773813


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 98.13760649852264


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 97.49694963112854


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 96.86858351291754


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 96.25230700174967


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 95.64960010235127


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 95.05659438891288


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 94.47816022237141


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 93.91065641549918


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 93.35570745223609


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 92.81109155508189


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 92.27897453308105


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 91.75885060628255


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 91.25016696636494


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 90.75284897241838


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 90.26785195179474


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 89.79336021618965


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 89.33066363212389


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 88.87962928185097


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 88.43991870391064


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 88.01200175163073


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 87.59446337773251


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 87.18997430067796


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 86.79640179169483


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 86.41443994962252


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 86.04396731058756


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 85.68547152983837


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 85.3397198408078


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 85.00455545278696


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 84.68208001943735


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 84.3711900759966


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 84.07130353878706


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 83.78344622880985


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 83.5071450551351


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 83.24132315806854


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 82.98706305088142


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 82.74397068512745


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 82.51203431349535


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 82.29088969108385


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 82.08011628175394


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 81.8799371279203


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 81.6909715016683


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 81.51173212100298


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 81.34271926879883


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 81.18379003084623


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 81.03391386178824


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 80.8933334252773


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 80.76208575322079


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 80.63980682568672


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 80.5267583993765


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 80.42207315885103


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 80.32570563096266


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 80.23745125012519


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 80.15697786380083


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 80.08442759391589


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 80.01945805671888


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 79.9613718962058


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 79.90983188335713


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 79.86481030782063


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 79.82631437839606


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 79.7936773055639


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 79.76619815337352


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 79.74394834469527


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 79.72681315495419


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 79.71388787489671


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 79.70507582640036


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 79.69971906221829


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 79.6975726592235


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 79.69835290175219


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 79.70172825348683


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 79.70701707693247


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 79.71457953819862


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 79.72336421868741


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 79.73394598349547


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 79.74452727880232


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 79.75610446440868


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 79.76759885152181


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 63.42725277925149


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 62.057089957212796


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 61.26584771963267


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 60.68854784843249


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 60.00635154430683


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 59.558885857997794


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 59.087255086654274


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 58.645027835552504


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 58.198606002025116


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 57.81255725958408


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 57.49627008193578


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 57.17956888247759


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 56.7833727469811


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 56.487488952049844


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 56.229259231763


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 55.98023144648626


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 55.735815811157224


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 55.3357294131548


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 54.33416767609425


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 53.512483821771085


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 52.88883005777995


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 52.41458870325333


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 51.94192920587002


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 51.51841729971079


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 51.132081882770244


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 50.830497932434085


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 50.387699484213805


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 50.03152878101056


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 49.6001646286402


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 49.244136692927434


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 48.830968651404746


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 48.495101762429265


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 48.11522805140569


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 47.75029009794578


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 47.483178212092476


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 46.973607097527925


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 46.59391208550869


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 46.24907104296562


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 45.93512651492388


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 45.54346124697954


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 45.182783738160744


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 44.83522391686073


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 44.52695178007468


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 44.142542408674196


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 43.80365232809996


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 43.48907731374105


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 42.90472334348238


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 42.45203296954815


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 42.098913354140066


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 41.66708217523037


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 41.26220199389335


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 40.922861182384004


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 40.50976519951453


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 40.120634582715155


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 39.78511242499719


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 39.39240077092097


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 39.09867065136249
fold 2: mean position error 39.20695964261819
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1508.0675
Loss/xy,1508.0675
Loss/floor,5.99754
MPE/val,39.20696
epoch,199.0
trainer/global_step,23599.0
_runtime,389.0
_timestamp,1617432173.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 14.0 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
55.999    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 192.14067840576172


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.1278339944984


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 162.66424954302263


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 161.27146057276335


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 159.90863134534655


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 158.5646078321669


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.23547648799976


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 155.91828671515276


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 154.61106273632694


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.31458356307516


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 152.02618923709397


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 150.74694032162284


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 149.47610939013595


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 148.2133531641077


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 146.9597901901761


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 145.71500078652792


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 144.4791869846903


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 143.24998129876914


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 142.03110376870958


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 140.81929483459768


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 139.61712860746277


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 138.42415670741966


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 137.23947975293834


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 136.0628745529003


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 134.89631958990665


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 133.73876477731406


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 132.59094904778277


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 131.4526075747086


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 130.32320123571128


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 129.20500543996715


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 128.09558649109184


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 126.9972188719229


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 125.90898963326225


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 124.8308946164142


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 123.76142868619418


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 122.70386361753306


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 121.65805488193286


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 120.6233135647244


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 119.60100342755157


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 118.58893792464154


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 117.58767932455704


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 116.59923180412747


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 115.62275231587138


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 114.6575901504492


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 113.70669124007416


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 112.76600322385532


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 111.83812740528641


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 110.92557423110937


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 110.02475139292349


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 109.13710765992194


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 108.26246532126902


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 107.40052221301289


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 106.55182939305206


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 105.71607834943443


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 104.89357871640708


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 104.08495204513967


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 103.28975311377582


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 102.50995469085643


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 101.74396173596958


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 100.99057117959728


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 100.2510093609108


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 99.52555217527922


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 98.81494351754058


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 98.11829743991919


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 97.43523962601371


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 96.76528599872681


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 96.1108413045149


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 95.46805283243913


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 94.8400578233163


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 94.22523232182058


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 93.62368795238255


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 93.03699985510293


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 92.46340437995063


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 91.90208386714524


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 91.35535708748392


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 90.82088937252615


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 90.30059291268317


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 89.79353452058997


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 89.30018084675025


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 88.82053197532076


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 88.35483118201608


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 87.90286383636524


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 87.46377771841347


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 83.25378962469178


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 81.37310607529301


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 80.05585640555611


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 78.90126469807157


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 78.01294156257272


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 77.19986640389607


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 76.39392588211527


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 75.67377563267706


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 74.94807567903577


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 74.2482352319739


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 73.58589954468363


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 72.97869259439612


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 72.3333125607403


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 71.79659038580557


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 71.18034510681595


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 70.06915991102824


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 68.85612247301184


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 67.8851866963213


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 67.06300111853558


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 66.2970649746881


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 65.5087392186581


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 64.78910675048829


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 64.10044616508792


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 63.52662214749101


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 62.89234844705333


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 62.357431898869564


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 61.76293071034258


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 61.24311396188782


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 60.73799690547581


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 60.26285058603579


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 59.89408934404309


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 59.41329832706666


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 58.98789749636934


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 58.561905164918265


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 58.057087851876034


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 57.653801007662416


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 57.250464771412034


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 56.92595819851051


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 56.496801505956476


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 56.07839126740485


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 55.75658764462924


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 55.30709806617332


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 55.005545763270675


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 54.57159309479349


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 54.19948550299554


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 53.83578032095843


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 53.45547035419998


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 53.05873293362186


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 52.71876915334312


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 52.350185381080024


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 52.00595216920026


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 51.67776037299115


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 51.32299675288792


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 50.948494281553806


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 50.590213498286


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 50.273865831838904


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 49.95363184488146


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 49.575704554006485


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 49.28914834727412


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 48.936361710307295


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 48.597266504346074


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 48.22555613740439


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 47.878031158140125


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 47.600780841194684


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 47.213924441744545


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 46.89546452981454


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 46.50755676017675


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 46.13586005304555


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 45.776243270652884


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 45.44182737998532


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 45.1064253412391


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 44.74636445851718


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 44.41789168006172


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 44.069812630148135


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 43.72171366640911


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 43.358524906232155


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 42.77802740164618


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 42.31332563808767


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 41.87217413768676


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 41.46670330496034


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 41.06727666440217


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 40.67708024103285


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 40.270967885722285


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 39.857401097031996


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 39.51955327680529


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 39.10920325600198


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 38.745147269244356


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 38.37498597699471


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 38.02198133837198


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 37.65046299183426


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 37.31018543826977


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 36.91564701528749


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 36.55967398426959


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 36.2447509728768


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 35.841664500090616


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 35.48410603896431


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 35.111099454323835


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 34.791189958200746


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 34.423464331419574


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 34.073376420615375


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 33.73642979389995


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 33.38945910166619


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 33.070332793759455


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 32.7091606427314


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 32.417155479271436


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 32.08149429696002


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 31.69702014708097


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 31.416280752295652


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 31.07454467503153


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 30.771468059797794


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 30.481191013323897


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 30.197202895344166


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 29.871434963152602


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 29.600264089618136


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 29.277080452653326


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 29.03772515958826


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 28.734810964765565


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 28.43224730645209
fold 3: mean position error 28.328698485192106
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,972.74329
Loss/xy,972.74329
Loss/floor,5.52553
MPE/val,28.3287
epoch,199.0
trainer/global_step,23599.0
_runtime,390.0
_timestamp,1617432569.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,█████████████████▂▂▃▁▂▂▂▃▃▄▄▅▅▅▅▇▇▇▇▇▇▇▇
MPE/val,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 14.0 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
55.999    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 190.34444427490234


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 160.36526847876212


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 158.7243364576748


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 157.19283152901227


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 155.7013877340369


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 154.23587580633242


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 152.7894578365504


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 151.3575724148712


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 149.9406757913734


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 148.53587106898212


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 147.14281872858365


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 145.76258984257058


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 144.39423454518095


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 143.0394193996362


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 141.69548923351147


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 140.36321371388703


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 139.04390803105204


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 137.73468907367013


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 136.4365377576647


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 135.1496975743636


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 133.87435290755857


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 132.61028688464572


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 131.35779203209134


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 130.11590842156403


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 128.887995612103


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 127.66969187455476


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 126.46282925321667


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 125.26729941744352


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 124.0831626486663


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 122.91068601992203


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 121.74996512101276


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 120.60138144869352


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 119.46449976013479


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 118.33983454496963


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 117.22876090819133


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 116.12909210561362


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 115.04254585794396


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 113.96968093540357


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 112.90783331190714


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 111.86255658897609


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 110.82952711931557


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 109.80995854241068


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 108.80600313195859


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 107.81400909485257


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 106.83865039106728


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 105.87543379865022


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 104.92920211877991


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 103.99754063518728


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 103.08184319463905


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 102.18364118861692


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 101.30256842522613


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 100.43722300322159


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 99.58752010425316


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 98.75541324093338


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 97.94271175734663


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 97.14574949537882


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 96.36484021234436


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 95.60387961100456


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 94.85740315534068


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 94.1303522463198


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 93.42080335509566


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 92.72798526843773


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 92.0523508253113


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 91.39486994290313


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 90.75354047151771


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 90.13102955841212


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 89.5239781451878


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 88.93418573327303


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 88.36190554976655


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 87.80662448225775


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 87.26952420478857


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 86.74834843941166


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 86.24698621286095


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 85.7619426616724


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 85.11661233579478


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 80.49424033356942


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 79.09448727839619


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 77.68697389611876


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 76.80241067735852


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 75.84130025946575


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 75.40315870693533


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 74.32546555400855


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 73.57651529542491


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 72.89491965129564


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 72.22567874859304


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 71.5499486213721


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 70.893533077025


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 70.30196113862853


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 69.78616194517716


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 69.12889548837659


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 68.69477007247016


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 68.09874999680572


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 67.41302434273196


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 66.8966138486509


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 66.25882832378197


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 65.67593586441015


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 64.65871982973746


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 63.853302904949096


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 63.2037410484228


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 62.530102723961676


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 61.95287582071891


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 61.396183860397954


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 60.856098756774806


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 60.29772399988344


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 59.7825828662817


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 59.43005514405968


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 59.08588903675909


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 58.30683508518237


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 57.92220107399515


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 57.4392854411053


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 56.9714895091771


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 56.73088360047763


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 56.112689384055024


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 55.79665147041161


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 55.48428593234739


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 54.87730885135571


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 54.55171080572402


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 54.169435520755684


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 53.81661567257991


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 53.41047456544763


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 53.08015239918289


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 52.72075146354147


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 52.343439213665214


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 51.99626824967143


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 51.70227312895795


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 51.34345237145293


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 51.04517542091161


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 50.735103218183035


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 50.42700242243719


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 50.10973668689697


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 49.797109942044614


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 49.53114875818029


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 49.16713532756492


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 48.88373791453535


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 48.59357085604214


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 48.273922608171304


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 47.98062145076513


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 47.67961434934067


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 47.405088202619325


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 47.1205119146817


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 46.84061171121643


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 46.53315295405242


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 46.25927249367878


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 45.75235218855878


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 45.31174402252295


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 44.891297978940216


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 44.4702419537469


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 44.15420271693797


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 43.735416147597554


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 43.38546528593545


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 43.013206879988964


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 42.672058957786376


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 42.249660432665046


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 41.91326330740863


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 41.54420427991956


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 41.14841013031498


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 40.78861321819385


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 40.41028190526793


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 40.0453275517757


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 39.70835642714815


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 39.37950122759538


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 38.997294207126046


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 38.62268371720245


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 38.32552345463235


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 37.939025012460114


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 37.631118096245665


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 37.27502658585995


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 36.93408257519757


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 36.56816587709191


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 36.23662093795248


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 35.9047894662897


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 35.543076374603736


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 35.214561373408095


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 34.87103285651276


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 34.539809096992116


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 34.21134327813239


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 33.88709841235248


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 33.598586770762566


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 33.224563530600975


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 32.921558225174074


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 32.612207601227816


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 32.32323240980434


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 31.988754425494182


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 31.705779369864104


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 31.342963469700344


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 31.036414288245922


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 30.753468856796168


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 30.49736475829341


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 30.158717401784017


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 29.92402000580817


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 29.60681903074329


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 29.28583551528181


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 29.01919897927178


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 28.709025613658863


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 28.488795733567024


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 28.17949042926856


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 27.929870144986875


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 27.617982084086936


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 27.389129218963035


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 27.14860602039454


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 26.87187683662931
fold 4: mean position error 26.88435848072176


In [22]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,wifi_bssid_0,wifi_bssid_1,wifi_bssid_2,wifi_bssid_3,wifi_bssid_4,wifi_bssid_5,wifi_bssid_6,wifi_bssid_7,wifi_bssid_8,wifi_bssid_9,...,wifi_timegap_77,wifi_timegap_78,wifi_timegap_79,site_id,x,y,floor,oof_x,oof_y,oof_floor
0,39873,10121,39095,16067,52302,13060,39873,16067,10121,52302,...,-1.330807,-1.330601,-1.331937,0,230.03738,153.496350,-1,176.430313,155.572342,0.157174
1,17965,7792,39921,19108,29424,39921,16191,7792,19108,29424,...,0.432958,0.430131,-1.733404,0,231.40290,158.415150,-1,176.407486,155.841202,0.158001
2,39921,17965,16191,7792,19108,52934,52302,16067,10121,39095,...,-0.580340,-0.581424,1.269790,0,232.46200,164.416730,-1,176.420746,155.689865,0.157538
3,39705,39705,39873,39873,39705,16067,52302,39095,13060,10121,...,0.232759,0.230276,0.227092,0,233.94418,171.414170,-1,176.414627,155.753967,0.157730
4,26417,12652,41723,21199,16863,39050,23964,19076,17470,10752,...,0.378155,0.375422,0.372066,0,210.86192,165.376080,-1,176.171432,157.637131,0.138766
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15050,19175,19175,19175,7503,18513,13910,7503,18513,13910,7503,...,-1.715547,0.164402,0.161296,23,249.43129,76.241234,6,176.280716,118.478104,0.013724
15051,19175,19175,19175,13910,32715,13910,32715,13910,7503,34182,...,0.582827,0.579743,0.576145,23,237.22395,73.177680,6,168.918488,119.839279,0.107121
15052,19175,19175,19175,7503,32715,7503,7503,7038,7038,32715,...,-0.258232,-0.259870,-0.262474,23,242.54440,72.935265,6,170.729889,116.128365,0.027674
15053,19175,19175,19175,7503,7038,7503,7503,7038,32715,13910,...,-0.471852,-0.473123,-0.475475,23,249.43129,76.241234,6,172.952377,113.620102,0.000000


In [23]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,91.197029,101.293404
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.478302,101.666359
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.910324,106.208855
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.830978,106.066658
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.957901,105.958168
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,167.698700,125.386986
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,167.104019,124.574036
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,170.583908,129.099243
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,171.617325,130.661057


In [24]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,91.197029,101.293404
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.478302,101.666359
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.910324,106.208855
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.830978,106.066658
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.957901,105.958168
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,167.698700,125.386986
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,167.104019,124.574036
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,170.583908,129.099243
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,171.617325,130.661057


In [25]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [26]:
print(f"CV:{np.mean(val_scores)}")

CV:34.501771288999024


In [27]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,878.90088
Loss/xy,878.90088
Loss/floor,5.63142
MPE/val,26.88436
epoch,199.0
trainer/global_step,23599.0
_runtime,384.0
_timestamp,1617432959.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███▇▇▇▇▇▇▇▆▆▆▄▃▃▃▃▃▃▃▂▂▂▂
MPE/val,██▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.57MB of 0.57MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
CV_score,34.50177
_runtime,2.0
_timestamp,1617433227.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
