# LSTM baseline

from kuto

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [3]:
import os
import sys
import glob
import pickle
import random

In [4]:
import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path


In [5]:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [6]:
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger

In [7]:
sys.path.append('../../')
import src.utils as utils

In [8]:
import multiprocessing
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm

from indoor_location_competition_20.io_f import read_data_file
import indoor_location_competition_20.compute_f as compute_f

In [9]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids_original'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [10]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.01,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [11]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 500
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 50
IS_SAVE = True

utils.set_seed(SEED)

In [12]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [13]:
train_df = pd.read_csv(WIFI_DIR / 'train_10_th10000_base25_withdelta.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_10_th10000_base25_withdelta.csv')

In [14]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [13]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]
DELTA_FEATS  = ['delta_x', 'delta_y']

In [15]:
train_df.iloc[:, 100:110]

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9
0,e9b24f94c0007acb4b7169b945622efcd332cf6f,591ea59cf88e3397db5d60eb00a5147edd69399a,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,a77f8e93896f8fc8bc0d0700ca04b802ee79a07f,1b2fd184314ae440900fa9ce1addeb896b5604a9,2c09230bb32ee49f6a72928f6eeefb6885dc15ce,3799b46aa4cf6c3c45c0bc27d8f1efefea96914f,fc6956beb062b5158252c66953e92a0d25495cac,c71a2f5c4282d27f84b9b841db0e310ef0fcf6cd,4f8b7c168dc76c9d3b4ca7903042173e98fe2ddb
1,e9b24f94c0007acb4b7169b945622efcd332cf6f,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,591ea59cf88e3397db5d60eb00a5147edd69399a,d32dd11040b254cd889c9ead2d4a50f6e3900196,9c832009dfb1ee02053c9ce9b7770b6cd3191003,509d1f842b0773e85c6beec0bb530542efd35cb9,1b2fd184314ae440900fa9ce1addeb896b5604a9,f920a2e4cb52165850990d9d37d391b630f7de14,6a42281c99a4cff2ce9dba3fc91ad6a431af64d9,fc6956beb062b5158252c66953e92a0d25495cac
2,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,e9b24f94c0007acb4b7169b945622efcd332cf6f,591ea59cf88e3397db5d60eb00a5147edd69399a,509d1f842b0773e85c6beec0bb530542efd35cb9,b2337b25e7d1df04928bf6698a9c0b2764df7795,76f81d5047273fa64a434457531d400fc5d90fac,6e388d1db5ba8dd9de80522a4ddf50402cf443b3,8c6ab78f2797e076f9106af81090d0ab9904f5cd,ceccac4f0e50ec9e36e8d2800b8f2c7c3b4d903e,f920a2e4cb52165850990d9d37d391b630f7de14
3,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,591ea59cf88e3397db5d60eb00a5147edd69399a,e9b24f94c0007acb4b7169b945622efcd332cf6f,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,d32dd11040b254cd889c9ead2d4a50f6e3900196,f920a2e4cb52165850990d9d37d391b630f7de14,6a42281c99a4cff2ce9dba3fc91ad6a431af64d9,9c832009dfb1ee02053c9ce9b7770b6cd3191003,509d1f842b0773e85c6beec0bb530542efd35cb9,0452e85d0a41780463cfe079077ea5bd2f519c7a
4,5875360455060f20a3cba705f44a4e3987c9b9f3,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,0452e85d0a41780463cfe079077ea5bd2f519c7a,3c7e7fa0576bc8a2af71d5899581df36f4dab6c8,09e103887f42552d20328aa41891cf82dace79ab,54bba3a36204f8c71b93798c31f9e0b039914575,18067f8d8861af3bcae51ba04b6b11b9150b9ff2,591ea59cf88e3397db5d60eb00a5147edd69399a,f920a2e4cb52165850990d9d37d391b630f7de14,d32dd11040b254cd889c9ead2d4a50f6e3900196
...,...,...,...,...,...,...,...,...,...,...
251108,5964a27e0cb3344b0a18540e6b3120c433971c38,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,cce41299a022ada08aebf3d309acb07d5f00b014,346b34a42e801c64e043dbaacbe7fef9b8880774,4b5dbdb52b131410ea10b59ea451de62280b41d6,fa11fc4d4960379cb68cc6968ba6415168fef53c,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,4d2e5639041b40b0df2ee258aa504bd904133d80,dfc21edb1f7650d5645fd672bbe6a13fc6fd77f8,a94eb920c0a198fe8385f3de6a8e8e6d44b6f6c9
251109,5964a27e0cb3344b0a18540e6b3120c433971c38,346b34a42e801c64e043dbaacbe7fef9b8880774,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,cce41299a022ada08aebf3d309acb07d5f00b014,566e0c6e3bcf2b8b3d310d96f111043d17ace817,bd3fc24710537130e97dc2dab4a6bf70b3884a8b,4b5dbdb52b131410ea10b59ea451de62280b41d6,4d2e5639041b40b0df2ee258aa504bd904133d80,f4107af4418d57aacb3542343f7b47768debdc75,5f583dcccc43b5b7ac25d270e29c92d878fb2be0
251110,346b34a42e801c64e043dbaacbe7fef9b8880774,5964a27e0cb3344b0a18540e6b3120c433971c38,cce41299a022ada08aebf3d309acb07d5f00b014,566e0c6e3bcf2b8b3d310d96f111043d17ace817,bd3fc24710537130e97dc2dab4a6bf70b3884a8b,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,4b5dbdb52b131410ea10b59ea451de62280b41d6,a94eb920c0a198fe8385f3de6a8e8e6d44b6f6c9,ee5ca7a7deaacdcd5d99355ff5f156dc45b74efa,a7986c0cea5d2571ea42011ab4407039e977c0bd
251111,346b34a42e801c64e043dbaacbe7fef9b8880774,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,5964a27e0cb3344b0a18540e6b3120c433971c38,f4107af4418d57aacb3542343f7b47768debdc75,cce41299a022ada08aebf3d309acb07d5f00b014,4b5dbdb52b131410ea10b59ea451de62280b41d6,fa11fc4d4960379cb68cc6968ba6415168fef53c,a7986c0cea5d2571ea42011ab4407039e977c0bd,180a351ec58c07d60949862c534373c43f548a9a,4d2e5639041b40b0df2ee258aa504bd904133d80


bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [16]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in range(100, 200):
    wifi_bssids.extend(train_df.iloc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in range(100, 200):
    wifi_bssids_test.extend(test_df.iloc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 61592
BSSID TYPES(test): 27809
BSSID TYPES(all): 89401


In [17]:
# get numbers of bssids to embed them in a layer

# train
rssi_bssids = []
# bssidを列ごとにリストに入れていく
for i in RSSI_FEATS:
    rssi_bssids.extend(train_df.loc[:,i].values.tolist())
rssi_bssids = list(set(rssi_bssids))

train_rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(train): {train_rssi_bssids_size}')

# test
rssi_bssids_test = []
for i in RSSI_FEATS:
    rssi_bssids_test.extend(test_df.loc[:,i].values.tolist())
rssi_bssids_test = list(set(rssi_bssids_test))

test_rssi_bssids_size = len(rssi_bssids_test)
print(f'RSSI TYPES(test): {test_rssi_bssids_size}')


rssi_bssids.extend(rssi_bssids_test)
rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(all): {rssi_bssids_size}')

RSSI TYPES(train): 97
RSSI TYPES(test): 77
RSSI TYPES(all): 174


## PreProcess

In [18]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])
le_rssi = LabelEncoder()
le_rssi.fit(rssi_bssids)

ss = StandardScaler()
ss.fit(train_df.loc[:,DELTA_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
#     output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？
    for i in RSSI_FEATS:
        output_df.loc[:,i] = le_rssi.transform(input_df.loc[:,i])

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
    output_df.loc[:,DELTA_FEATS] = ss.transform(output_df.loc[:,DELTA_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

Unnamed: 0,ssid_0,ssid_1,ssid_2,ssid_3,ssid_4,ssid_5,ssid_6,ssid_7,ssid_8,ssid_9,...,frequency_99,wp_tmestamp,x,y,floor,floor_str,path_id,site_id,delta_x,delta_y
0,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,8c1562bec17e1425615f3402f72dded3caa42ce5,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,...,5745,1578469851129,157.99141,102.125390,-1.0,B1,5e158ef61506f2000638fd1f,0,-0.025217,-0.014531
1,b7e6027447eb1f81327d66cfd3adbe557aabf26c,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,b7e6027447eb1f81327d66cfd3adbe557aabf26c,...,5765,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0,0.446463,0.506216
2,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,...,5745,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0,0.446463,0.506216
3,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,7182afc4e5c212133d5d7d76eb3df6c24618302b,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,...,5825,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0,0.446463,0.506216
4,da39a3ee5e6b4b0d3255bfef95601890afd80709,da39a3ee5e6b4b0d3255bfef95601890afd80709,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,5731b8e08abc69d4c4d685c58164059207c93310,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,7182afc4e5c212133d5d7d76eb3df6c24618302b,...,5765,1578469862177,168.49713,109.861336,-1.0,B1,5e158ef61506f2000638fd1f,0,0.504093,0.398055
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
251108,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,1f09251bbfadafb11c63c87963af25238d6bc886,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,da39a3ee5e6b4b0d3255bfef95601890afd80709,1556355684145fce5e67ba749d943a180266ad90,...,0,1573733061352,203.53165,143.513960,6.0,F7,5dcd5c9323759900063d590a,23,0.121631,0.362279
251109,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,1f09251bbfadafb11c63c87963af25238d6bc886,...,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23,-2.376529,0.568969
251110,4abd3985ba804364272767c04cdc211615f77c56,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,5d998a8668536c4f51004c25f474117fe9555f78,...,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23,-2.376529,0.568969
251111,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,5d998a8668536c4f51004c25f474117fe9555f78,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,...,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23,-2.376529,0.568969


In [19]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [20]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(int)
        self.delta_feats = df[DELTA_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        concat_feat = np.empty(2 * NUM_FEATS).astype(int)
        concat_feat[0::2] = self.bssid_feats[idx]
        concat_feat[1::2] = self.rssi_feats[idx]
        
        feature = {
            'RSSI_BSSID_FEATS':concat_feat,
            'site_id':self.site_id[idx],
            'delta':self.delta_feats[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [21]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
#         self.rssi_embedding = nn.Embedding(rssi_size, 64, max_norm=True)
        self.delta = nn.Sequential(
            nn.BatchNorm1d(2),
            nn.Linear(2, 2 * 64)
        )
        
        concat_size = 64 + (2 * NUM_FEATS * 64) + 64 * 2
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['RSSI_BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)
        
        x_delta = self.delta(x['delta'])


        x = torch.cat([x_bssid, x_delta, x_site_id], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [44]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [23]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [24]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [25]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [26]:
oofs = np.zeros((len(train), 2), dtype = np.float32)  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path_id'], groups=train.loc[:, 'path_id'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + DELTA_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + DELTA_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size+rssi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=20,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    oofs[val_idx, 0] = oof_x
    oofs[val_idx, 1] = oof_y

    
    val_score = mean_position_error(
        oof_x, oof_y, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.303    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 165.8133773803711


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 134.42549086324175


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 109.51892014295856


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 94.20691978467708


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 86.41893258413499


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 83.19431330756477


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 82.25849671387401


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 82.14661011243928


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 82.2121002800665


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 82.26248636005926


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 59.830833512533104


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 54.58699796821337


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 52.2942063241217


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 50.61400614291039


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 49.33958125816426


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 48.514673887671336


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 47.6282562277995


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 46.6986020617195


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 45.98503320109016


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 36.987686179362136


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 26.97453365309531


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 21.87705498061435


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 17.397577769264927


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 15.033475350954443


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 13.731278600575923


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 12.341799618376472


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 11.924593973026566


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 11.178240357740654


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 10.704859136945839


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 10.387442056937498


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.958990463820063


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.843150153304592


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 10.389260898322


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.530816659552293


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 9.119738991378892


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.765507601215944


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.745517370292456


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.523431154981163


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.408667705301415


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 10.038669169602434


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.283357065415192


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.01432676507969


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.319176617935348


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.457639469271344


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.594558314305594


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.953743638055647


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.781810984928173


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.899452188361877


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.905198606654646


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.9595005693113885


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.982427995178793


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.21424355456682


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 6.8801261819154504


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 6.853213138127362


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 6.881287014358868


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 6.8447811889463965


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 6.852016282240447


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 6.845146605821418


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 6.832599122883452


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 6.863308169829981


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 6.840868991394339


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 6.868696468096252


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 6.841858876999184


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 6.81268510346051


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 6.795663779611463


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 6.799387531436286


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 6.796058202952999


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 6.791354520639091


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 6.797940272657738


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 6.794060186855055


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 6.790766756687783


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 6.785228468540175


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 6.785854021503211


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 6.786579021453448


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 6.798122051785435


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 6.785745189903681


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 6.782107008418139


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 6.788422363792828


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 6.788253555650381


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 6.795420157430302


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 6.78934634564488


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 6.785135042224165


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 6.7914037320682015


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 6.791757658786885


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 6.7881380828149895


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 6.7939084677857124


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 6.791336105276174


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 6.78657172562972


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 6.787036699708651


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 6.786285247901114


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 6.784616400958286


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 6.792288029754954


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 6.782612543682014


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 6.791857727875946


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 6.790647373439265


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 6.787539722379909


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 6.787848118528509
fold 0: mean position error 6.803763280843316
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,38.83481
Loss/xy,38.83481
Loss/floor,5.01106
MPE/val,6.80376
epoch,119.0
trainer/global_step,47159.0
_runtime,548.0
_timestamp,1619456049.0
_step,119.0


0,1
Loss/val,█▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▅▇█▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
MPE/val,█▅▅▄▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.303    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 141.21902465820312


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 132.73067718881344


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 107.50050584992721


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 92.00741922586369


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 83.98651607569997


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 68.53026251400411


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 61.773351567856814


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 58.01505621904561


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 54.76271903814615


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 54.70366244722945


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 42.72624749745857


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 35.44756037431099


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 30.355810466899992


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 25.59765117459981


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 21.236428130449287


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 18.40395911616611


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 16.388519354438415


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 14.924845428638962


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 13.57055828726576


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 12.419185871056222


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.66512379307094


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 12.0261295291139


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 10.379894612093217


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 10.311222645587211


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.067287862467403


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.619524090412533


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 10.491180562470873


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 10.239785396998963


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 8.964871122277016


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.503928559276844


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.712822312363855


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.377743824205167


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.236905464530633


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.309719655889351


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.377718870137139


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 10.593763018279544


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.42502063959869


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.836149315410806


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 9.108491087380004


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.593544823519267


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.071318795290042


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 7.946527940084825


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.223851739168936


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.236235328779255


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.14078733369199


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.806805242242901


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.060521611546331


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.003062460977443


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.990783901899163


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.815225823166907


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.533971576825946


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.379402480073163


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.312787220653811


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.265540611403183


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.228654493138838


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.1707432112744005


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.156635699568322


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.147275948457863


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.136728973355878


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.112317407359212


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.149662559323994


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.116043030049963


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.117109542607287


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.101925439996512


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.079251191192306


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.073591930390224


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.065858500505453


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.065377297822619


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.053172449796092


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.043059771583076


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.042898028275542


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.0463285865308185


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.049043090425972


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.051511415788596


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.046351320380271


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.04459186064513


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.044915592698622


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.045896866181678


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.050091987780389


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.048193439985382


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.043988518886045


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.048069753161098


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.0408513931825345


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.045900184775638


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.0386724538657575


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.041757060562747


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.039519758755331


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.049634747676328


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.053676567409671


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.043231028473405
fold 1: mean position error 7.021953872623633
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,39.93956
Loss/xy,39.93956
Loss/floor,5.25928
MPE/val,7.02195
epoch,88.0
trainer/global_step,34976.0
_runtime,416.0
_timestamp,1619456472.0
_step,88.0


0,1
Loss/val,█▅▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███████▁▂▃▂▄▆▇█████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
MPE/val,█▆▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.303    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 170.17762756347656


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 123.05954293899323


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 97.34375279728069


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 84.5129819620355


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 80.0154487745956


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 79.31007643994172


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 79.62823595038388


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 64.10747276263521


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 55.40338570631476


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 52.285608922899755


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 51.07120075918591


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 49.57712605310568


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 49.07196899928919


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 49.39058708053186


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 47.36776162423602


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 42.56572681413065


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 31.30981492903984


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 24.74251537478767


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 19.968316126351404


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 17.357294765052657


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 15.217772735002962


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 13.705630942208778


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 14.06033996333621


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 12.084255610182545


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 11.194586793094867


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 12.13381265328311


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 10.485624288042054


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 10.043799793543467


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.894010246427705


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.806450733645562


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.729024390161818


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 12.07893173940049


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.1997444503107


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.154929849426644


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.80245076756874


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.900860720236224


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.267959621473715


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 9.042190471106439


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.806391507457617


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 9.557837152798765


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.571602692020956


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.709300739699069


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.420356695559263


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.518450059712432


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.500971667873662


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.311710621140834


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 9.000369897637858


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.259057629628511


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.234552071372745


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.340531701506684


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 9.306617215400662


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.569564297510274


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 8.073786546050055


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 8.318197107817213


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 8.291452154200128


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 8.143243465837601


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 8.063872458574568


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 8.025959633663858


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 8.674809893356608


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 8.23882443710678


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 8.161697705484272


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.39355035343658


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.303694366973215


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.2454726766212865


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.238658382486277


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.221392703471377


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.239792361096518


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.239227485943681


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.223487465401607


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.201755827430907


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.196058721150886


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.2075831632370235


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.195640113567399


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.204902372845982


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.160336968106094


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.161868053688789


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.142690068799758


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.153036875697123


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.143965526214344


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.152132886521359


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.144154354651308


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.1435218422022935


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.140390988552209


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.146589493869378


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.136993943043685


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.136726053304732


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.139026647891994


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.136500210806091


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.139253332362287


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.140405897373746


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.1473786558819


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.145490379403182


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.1457722041008935


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.137000552049741


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.133677647289553


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.141264328976075


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.133971230178758


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.139089938417292


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.147033108605279


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.158167563995288


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.143294639595601


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.1426027445749085


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.138254701309811


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 7.136246141146363


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 7.140556593902542


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 7.133915610347234


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 7.142142869927616


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 7.136012259540111


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 7.135192128732182


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 7.147255840133448


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 7.1481489705803725


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 7.14061839663032


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 7.136855006448536


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 7.139542591026105


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 7.138859950520007


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 7.144752322282428


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 7.136281098546763
fold 2: mean position error 7.1155264035193255
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,42.58282
Loss/xy,42.58282
Loss/floor,4.77813
MPE/val,7.11553
epoch,115.0
trainer/global_step,45587.0
_runtime,545.0
_timestamp,1619457023.0
_step,115.0


0,1
Loss/val,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███▆▆▇▂▁▂▂▁▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
MPE/val,█▆▅▄▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.303    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 129.3344612121582


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 121.64419717667236


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 95.44989691253916


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 84.12296374240987


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 67.79214393521904


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 60.049089946922734


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 54.919692968453134


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 51.81644325859462


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 48.34806054299162


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 43.537413487792485


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 38.8466872234956


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 33.64242253410329


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 28.091788220939588


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 23.949171426686792


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 20.635056307402735


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 18.323911386673316


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 16.32736159020157


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 14.388653012394224


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 13.735713685663532


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 12.775564922488568


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.928126476917761


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 11.10208148646365


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 10.478329154656654


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 10.802899483373723


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 9.99681549863853


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.53873668347318


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 10.384298600174642


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.070817595152254


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 8.867153426351702


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.678755952854349


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 10.305245486382177


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.55065866536512


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.569749041613017


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.92765626940855


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.897645332125336


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.578731655184512


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.495739461457358


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.343346118717587


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.331446178345812


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.719015179677234


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.162360924505904


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.023363303342407


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.222295259350897


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 9.899158835044595


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.973511214836365


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.025145310396471


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.9121386346243225


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.109780129812556


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.104853725349583


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.245962349427181


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.19153527999898


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.362457907822868


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.23390795758381


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.142023545375349


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.147552454633885


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.106358295058702


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.062695981119094


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.066805688131783


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.037786029334858


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.082578428701502


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.04962575001103


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.054407271576598


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.053851520486396


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.009455126736004


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.006497851787662


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 6.9913395750559255


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 6.999632023696999


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 6.991895679638058


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 6.992233577505455


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 6.986817900214113


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 6.991315381381403


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 6.9894964890779505


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 6.98765744208033


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 6.985078861486194


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 6.9849680922352855


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 6.987429722914531


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 6.983472603583053


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 6.987377687771696


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 6.990088177397422


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 6.982228453765587


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 6.980298912834064


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 6.986855536633853


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 6.985096330228059


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 6.980907539448092


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 6.984278952598991


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 6.983324301782071


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 6.985693422701989


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 6.977894106610701


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 6.982315841661403


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 6.980857485577261


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 6.979983929536507


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 6.981076524924646


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 6.979173524158798


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 6.983389504918101


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 6.981038655426401


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 6.980169378145777


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 6.983087103043306


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 6.981588591281185


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 6.9822494605679655


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 6.98472554022982


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 6.98339332338483


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 6.98381282344009


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 6.9842001293435345


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 6.981282146865985


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 6.9790369920893935


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 6.985686471182681


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 6.9774676177603245


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 6.982423268445081


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 6.985348370811318


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 6.9823685668255075


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 6.982533051143574


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 6.988408352337127


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 6.9813323883723415


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 6.98249389458289


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 6.982364675063651


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 6.982150082217532


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 6.980010225105872


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 6.983094867929116


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 6.98329638827512


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 6.9830432939173415


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 6.98441775241964


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 6.980941445564668


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 6.983754062945768


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 6.98160410597495


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 6.983586868383302


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 6.982082966497502


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 6.978959409615739
fold 3: mean position error 6.979996726606102
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,38.5964
Loss/xy,38.5964
Loss/floor,4.79276
MPE/val,6.98
epoch,125.0
trainer/global_step,49517.0
_runtime,589.0
_timestamp,1619457618.0
_step,125.0


0,1
Loss/val,█▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,█▇▅▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
MPE/val,█▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.303    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 128.2920799255371


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 123.45542772768044


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 96.57584302279842


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 77.58849175645429


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 66.34119921061466


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 59.66947503984105


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 54.29634848884914


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 45.06744402942817


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 38.748544452166314


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 33.888369454277885


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 28.36899697367015


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 24.35749813418248


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 19.202848597791668


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 16.459624096338928


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 14.590740201344447


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 13.076195048615762


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 12.208840092703811


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 12.178574751593011


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 10.7566741686617


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 10.343317512885802


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.172978444114085


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 9.796757479220773


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 9.394364396364583


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 9.16145102183024


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 9.255327344097797


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 8.842376884055023


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 8.628304447581174


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 8.802271608621549


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.393755276160226


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.597233742750367


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.593780759033216


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.944289483048175


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.758958621144451


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.100074657286056


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.106774144388366


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.292682673672495


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.203795590105422


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.003012551829958


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 7.963544057972555


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 7.991255269534346


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.044796625144869


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.173373162458486


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 7.4161487942889535


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 7.305337417947655


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.286599675657249


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.251615826686747


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.257576725699685


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.229507042537617


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.21133770542706


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.217715608691457


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.202221216629826


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.1948315840726575


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.212259914354635


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.178383611093554


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.2189343083465


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.221404395398729


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.186484707307543


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.19312041110515


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.209124765324917


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.191697606717907


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.237502123632083


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.182074850988744


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.202303037446583


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.170987896087874


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.1686329770412796


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.158225009878006


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.143444059185151


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.150343394666897


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.171196482503001


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.159518328789404


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.169761603685864


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.136964794622998


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.176783670183332


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.145711313070934


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.136314006871595


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.14272467326741


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.113400824299286


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.110193457161589


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.096353821027787


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.1036149319400375


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.09452265346579


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.105981715164218


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.0876323507708525


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.086363762682553


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.083686491769817


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.087446788658843


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.091647714494245


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.083436212887301


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.091216638101838


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.085827130731072


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.091824290309499


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.087110378054291


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.0869216093851595


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.084130681288782


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.084525180480749


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.066758712028064


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.073210713525167


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.090588173009767


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.087516715875256


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.076941613518569


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.080697568927777


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.083809941496616


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.080437761155147


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 7.078376477467264


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 7.075767630592584


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 7.089553903475842


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 7.0887954275912755


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 7.0846839539566


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 7.072823267942152


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 7.07874792523692


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 7.075533292876768


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 7.082780413948029


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 7.086842846441122


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 7.083440916130927


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 7.090287600162524


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 7.0777064326775925
fold 4: mean position error 7.106145670701268


In [27]:
oofs_df = pd.DataFrame(oofs, columns=['x', 'y'])
oofs_df['path'] = train_df['path_id']
oofs_df['timestamp'] = train_df['wp_tmestamp']
oofs_df['site'] = train_df['site_id']
oofs_df['site_path_timestamp'] = oofs_df['site'] + '_' + oofs_df['path'] + '_' + oofs_df['timestamp'].astype(str)
oofs_df['floor'] = train_df['floor']
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,x,y,path,timestamp,site,site_path_timestamp,floor
0,161.605011,104.301254,5e158ef61506f2000638fd1f,1578469851129,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
1,163.834015,106.054291,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
2,162.539764,110.929245,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
3,162.025391,110.759506,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
4,161.221939,110.610168,5e158ef61506f2000638fd1f,1578469862177,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
...,...,...,...,...,...,...,...
251108,198.552979,141.156830,5dcd5c9323759900063d590a,1573733061352,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251109,193.014008,142.992310,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251110,190.655258,142.768112,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251111,190.270554,143.015411,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0


In [28]:
oofs_score = mean_position_error(
        oofs_df['x'], oofs_df['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:7.00547689143968


In [29]:
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean()
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0,199.996414,112.778862


In [30]:
all_preds_37 = pd.read_csv('../37/output/sub37.csv', index_col=0)
all_preds_37.index = pd.read_csv(WIFI_DIR / 'test_7_th20000.csv')['site_path_timestamp']
all_preds_37

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,88.266884,104.794300
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,82.316630,104.338745
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,84.221380,105.362060
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,87.842510,109.344190
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.390120,108.134900
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,5,214.121280,98.048190
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,5,211.845250,100.757540
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,5,208.958570,107.238950
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,5,202.751280,110.968445


In [31]:
all_preds_merge = pd.merge(all_preds_37, all_preds, how='left', on='site_path_timestamp')[['floor_y', 'x_y', 'y_y']]
all_preds_merge = all_preds_merge.rename(columns={'floor_y': 'floor', 'x_y': 'x', 'y_y': 'y'})
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0.0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0.0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0.0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0.0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0.0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0.0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0.0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0.0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0.0,199.996414,112.778862


In [32]:
all_preds_merge['floor'].fillna(all_preds_37['floor'], inplace=True)
all_preds_merge['x'].fillna(all_preds_37['x'], inplace=True)
all_preds_merge['y'].fillna(all_preds_37['y'], inplace=True)
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0.0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0.0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0.0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0.0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0.0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0.0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0.0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0.0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0.0,199.996414,112.778862


In [33]:
# foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds_merge.index = sub.index
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0.0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0.0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0.0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0.0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0.0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0.0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0.0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0.0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0.0,199.996414,112.778862


In [34]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds_merge['floor'] = simple_accurate_99['floor'].values
all_preds_merge.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,199.996414,112.778862


# Post Proccess

In [14]:
oofs_df = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv")
sub_df = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [16]:
def compute_rel_positions(acce_datas, ahrs_datas):
    step_timestamps, step_indexs, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    headings = compute_f.compute_headings(ahrs_datas)
    stride_lengths = compute_f.compute_stride_length(step_acce_max_mins)
    step_headings = compute_f.compute_step_heading(step_timestamps, headings)
    rel_positions = compute_f.compute_rel_positions(stride_lengths, step_headings)
    return rel_positions

In [32]:
import math

order = 3
fs = 50.0  # sample rate, Hz
# fs = 100
# cutoff = 3.667  # desired cutoff frequency of the filter, Hz
cutoff = 3

step_distance = 0.8
w_height = 1.7
m_trans = -5

from scipy.signal import butter, lfilter

def butter_lowpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a

def butter_lowpass_filter(data, cutoff, fs, order=5):
    b, a = butter_lowpass(cutoff, fs, order=order)
    y = lfilter(b, a, data)
    return y

def peak_accel_threshold(data, timestamps, threshold):
    d_acc = []
    last_state = 'below'
    crest_troughs = 0
    crossings = []

    for i, datum in enumerate(data):
        
        current_state = last_state
        if datum < threshold:
            current_state = 'below'
        elif datum > threshold:
            current_state = 'above'

        if current_state is not last_state:
            if current_state is 'above':
                crossing = [timestamps[i], threshold]
                crossings.append(crossing)
            else:
                crossing = [timestamps[i], threshold]
                crossings.append(crossing)

            crest_troughs += 1
        last_state = current_state
    return np.array(crossings)

In [33]:
def steps_compute_rel_positions(sample_file):
    
    mix_acce = np.sqrt(sample_file.acce[:,1:2]**2 + sample_file.acce[:,2:3]**2 + sample_file.acce[:,3:4]**2)
    mix_acce = np.concatenate([sample_file.acce[:,0:1], mix_acce], 1)
    mix_df = pd.DataFrame(mix_acce)
    mix_df.columns = ["timestamp","acce"]
    
    filtered = butter_lowpass_filter(mix_df["acce"], cutoff, fs, order)

    threshold = filtered.mean() * 1.1
    crossings = peak_accel_threshold(filtered, mix_df["timestamp"], threshold)

    step_sum = len(crossings)/2
    distance = w_height * 0.4 * step_sum

    mag_df = pd.DataFrame(sample_file.magn)
    mag_df.columns = ["timestamp","x","y","z"]
    
    acce_df = pd.DataFrame(sample_file.acce)
    acce_df.columns = ["timestamp","ax","ay","az"]
    
    mag_df = pd.merge(mag_df,acce_df,on="timestamp")
    mag_df.dropna()
    
    time_di_list = []

    for i in mag_df.iterrows():

        gx,gy,gz = i[1][1],i[1][2],i[1][3]
        ax,ay,az = i[1][4],i[1][5],i[1][6]

        roll = math.atan2(ay,az)
        pitch = math.atan2(-1*ax , (ay * math.sin(roll) + az * math.cos(roll)))

        q = m_trans - math.degrees(math.atan2(
            (gz*math.sin(roll)-gy*math.cos(roll)),(gx*math.cos(pitch) + gy*math.sin(roll)*math.sin(pitch) + gz*math.sin(pitch)*math.cos(roll))
        )) -90
        if q <= 0:
            q += 360
        time_di_list.append((i[1][0],q))

    d_list = [x[1] for x in time_di_list]
    
    steps = []
    step_time = []
    di_dict = dict(time_di_list)

    for n,i in enumerate(crossings[:,:1]):
        if n % 2 == 1:
            continue
        direct_now = di_dict[i[0]]
        dx = math.sin(math.radians(direct_now))
        dy = math.cos(math.radians(direct_now))
#         print(int(n/2+1),"歩目/x:",dx,"/y:",dy,"/角度：",direct_now)
        steps.append((i[0],dx,dy))
        step_time.append(i[0])
    
        step_dtime = np.diff(step_time)/1000
        step_dtime = step_dtime.tolist()
        step_dtime.insert(0,5)
        
        rel_position = []

        wp_idx = 0
#         print("WP:",round(sample_file.waypoint[0,1],3),round(sample_file.waypoint[0,2],3),sample_file.waypoint[0,0])
#         print("------------------")
        for p,i in enumerate(steps):
            step_distance = 0
            if step_dtime[p] >= 1:
                step_distance = w_height*0.25
            elif step_dtime[p] >= 0.75:
                step_distance = w_height*0.3
            elif step_dtime[p] >= 0.5:
                step_distance = w_height*0.4
            elif step_dtime[p] >= 0.35:
                step_distance = w_height*0.45
            elif step_dtime[p] >= 0.2:
                step_distance = w_height*0.5
            else:
                step_distance = w_height*0.4

#             step_x += i[1]*step_distance
#             step_y += i[2]*step_distance
            
            rel_position.append([i[0], i[1]*step_distance, i[2]*step_distance])
#     print(rel_position)
    
    return np.array(rel_position)

In [34]:
def correct_path(args):

    path, path_df = args
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    txt_path = path_df['txt_path'].values[0]
    
    example = read_data_file(txt_path)
    rel_positions1 = compute_rel_positions(example.acce, example.ahrs)
    rel_positions2 = steps_compute_rel_positions(example)
    rel1 = rel_positions1.copy()
    rel2 = rel_positions2.copy()
    rel1[:,1:] = rel_positions1[:,1:] / 2
    rel2[:,1:] = rel_positions2[:,1:] / 2
    rel_positions = np.vstack([rel1,rel2])
    rel_positions = rel_positions[np.argsort(rel_positions[:, 0])]
    
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]
    delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)

    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

In [35]:
tmp = sub_df['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub_df['site'] = tmp[0]
sub_df['path'] = tmp[1]
sub_df['timestamp'] = tmp[2].astype(float)

In [36]:
used_buildings = sorted(sub_df['site'].value_counts().index.tolist())
test_txts = sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/test/*.txt'))
train_txts = [sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/train/{used_building}/*/*.txt')) for used_building in used_buildings]
train_txts = sum(train_txts, [])

In [37]:
txt_pathes = []
for path in tqdm(sub_df['path'].values):
    txt_pathes.append([test_txt for test_txt in test_txts if path in test_txt][0])

100%|██████████| 10133/10133 [00:00<00:00, 28922.91it/s]


In [38]:
sub_df['txt_path'] = txt_pathes

In [39]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, sub_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
sub_df_cm = pd.concat(dfs).sort_values('site_path_timestamp')

626it [02:21,  4.43it/s]


In [40]:
txt_pathes = []
for path in tqdm(oofs_df['path'].values):
    txt_pathes.append([train_txt for train_txt in train_txts if path in train_txt][0])

100%|██████████| 251113/251113 [02:26<00:00, 1711.14it/s]


In [41]:
oofs_df['txt_path'] = txt_pathes

In [42]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, oofs_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
oofs_df_cm = pd.concat(dfs).sort_index()
oofs_df_cm

10789it [16:30, 10.89it/s]


Unnamed: 0,site_path_timestamp,floor,x,y
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,158.978636,104.825028
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.257042,108.604798
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.251276,108.609015
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.245114,108.610050
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,165.494042,112.005578
...,...,...,...,...
251108,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,203.264141,140.781741
251109,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,190.250754,143.186429
251110,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,190.243241,143.182151
251111,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,190.235164,143.178441


In [None]:
oofs_df_cm.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv", index=False)

In [45]:
oofs_score = mean_position_error(
        oofs_df_cm['x'], oofs_df_cm['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.5303376207693065


In [None]:
sub_df_cm.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv", index=False)

In [None]:
oofs_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv")

In [None]:
sub_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv")

In [46]:
def split_col(df):
    df = pd.concat([
        df['site_path_timestamp'].str.split('_', expand=True) \
        .rename(columns={0:'site',
                         1:'path',
                         2:'timestamp'}),
        df
    ], axis=1).copy()
    return df
def sub_process(sub, train_waypoints):
    train_waypoints['isTrainWaypoint'] = True
    sub = split_col(sub[['site_path_timestamp','floor','x','y']]).copy()
    sub = sub.merge(train_waypoints[['site','floorNo','floor']].drop_duplicates(), how='left')
    sub = sub.merge(
        train_waypoints[['x','y','site','floor','isTrainWaypoint']].drop_duplicates(),
        how='left',
        on=['site','x','y','floor']
             )
    sub['isTrainWaypoint'] = sub['isTrainWaypoint'].fillna(False)
    return sub.copy()

In [47]:
train_waypoints = pd.read_csv(str(DATA_DIR/'indoor-location-navigation') + '/train_waypoints.csv')


In [48]:
sub_df_cm = sub_process(sub_df_cm, train_waypoints)
oofs_df_cm = sub_process(oofs_df_cm, train_waypoints)

In [49]:
from scipy.spatial.distance import cdist

def add_xy(df):
    df['xy'] = [(x, y) for x,y in zip(df['x'], df['y'])]
    return df

def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]

sub_df_cm = add_xy(sub_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(sub_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

sub_df_cm_ds = pd.concat(ds)


oofs_df_cm = add_xy(oofs_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(oofs_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

oofs_df_cm_ds = pd.concat(ds)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
100%|██████████| 118/118 [00:04<00:00, 27.25it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:

In [50]:
def snap_to_grid(sub, threshold):
    """
    Snap to grid if within a threshold.
    
    x, y are the predicted points.
    x_, y_ are the closest grid points.
    _x_, _y_ are the new predictions after post processing.
    """
    sub['_x_'] = sub['x']
    sub['_y_'] = sub['y']
    sub.loc[sub['dist'] < threshold, '_x_'] = sub.loc[sub['dist'] < threshold]['x_']
    sub.loc[sub['dist'] < threshold, '_y_'] = sub.loc[sub['dist'] < threshold]['y_']
    return sub.copy()



In [51]:
# Calculate the distances
sub_df_cm_ds['dist'] = np.sqrt( (sub_df_cm_ds.x-sub_df_cm_ds.x_)**2 + (sub_df_cm_ds.y-sub_df_cm_ds.y_)**2 )
sub_pp = snap_to_grid(sub_df_cm_ds, threshold=5)
sub_pp = sub_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

# Calculate the distances
oofs_df_cm_ds['dist'] = np.sqrt( (oofs_df_cm_ds.x-oofs_df_cm_ds.x_)**2 + (oofs_df_cm_ds.y-oofs_df_cm_ds.y_)**2 )
oofs_pp = snap_to_grid(oofs_df_cm_ds, threshold=5)
oofs_pp = oofs_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

In [52]:
sub_pp = sub_pp.sort_index()
sub_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,93.728470,97.948860,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
1,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,79.662285,102.766754,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
2,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,87.162830,104.450550,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
3,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
4,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
...,...,...,...,...,...,...,...
10128,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.596040,98.605774,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10129,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,206.627910,102.635210,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10130,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.511300,107.841324,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10131,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,192.553130,111.863014,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6


In [55]:
sub_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_pp.csv", index=False)

In [53]:
oofs_pp = oofs_pp.sort_index()
oofs_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,158.49695,107.122680,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
...,...,...,...,...,...,...,...
251108,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,202.50334,140.972670,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251109,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,191.95721,143.862530,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251110,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,191.95721,143.862530,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251111,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,191.95721,143.862530,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7


In [54]:
oofs_score = mean_position_error(
        oofs_pp['x'], oofs_pp['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.05298385725833


In [56]:
oofs_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_pp.csv", index=False)

In [59]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': oofs_score})
wandb.save(utils.get_notebook_path())
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)




VBox(children=(Label(value=' 0.00MB of 0.42MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00150438817…

0,1
CV_score,5.05298
_runtime,1.0
_timestamp,1619533532.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁


In [None]:
import json
import matplotlib.pylab as plt

def plot_preds(
    site,
    floorNo,
    sub=None,
    true_locs=None,
    base=str(DATA_DIR/'indoor-location-navigation'),
    show_train=True,
    show_preds=True,
    fix_labels=True,
    map_floor=None
):
    """
    Plots predictions on floorplan map.
    
    map_floor : use a different floor's map
    """
    if map_floor is None:
        map_floor = floorNo
    # Prepare width_meter & height_meter (taken from the .json file)
    floor_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_image.png"
    json_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_info.json"
    with open(json_plan_filename) as json_file:
        json_data = json.load(json_file)

    width_meter = json_data["map_info"]["width"]
    height_meter = json_data["map_info"]["height"]

    floor_img = plt.imread(f"{base}/metadata/{site}/{map_floor}/floor_image.png")

    fig, ax = plt.subplots(figsize=(12, 12))
    plt.imshow(floor_img)

    if show_train:
        true_locs = true_locs.query('site == @site and floorNo == @map_floor').copy()
        true_locs["x_"] = true_locs["x"] * floor_img.shape[0] / height_meter
        true_locs["y_"] = (
            true_locs["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        true_locs.query("site == @site and floorNo == @map_floor").groupby("path").plot(
            x="x_",
            y="y_",
            style="+",
            ax=ax,
            label="train waypoint location",
            color="grey",
            alpha=0.5,
        )

    if show_preds:
        sub = sub.query('site == @site and floorNo == @floorNo').copy()
        sub["x_"] = sub["x"] * floor_img.shape[0] / height_meter
        sub["y_"] = (
            sub["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        for path, path_data in sub.query(
            "site == @site and floorNo == @floorNo"
        ).groupby("path"):
            path_data.plot(
                x="x_",
                y="y_",
                style=".-",
                ax=ax,
                title=f"{site} - floor - {floorNo}",
                alpha=1,
                label=path,
            )
    if fix_labels:
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(
            by_label.values(), by_label.keys(), loc="center left", bbox_to_anchor=(1, 0.5)
        )
    return fig, ax

In [None]:
example_site = '5a0546857ecc773753327266'
example_floorNo = 'F1'

sub_df = sub_process(sub_df, train_waypoints)
plot_preds(example_site, example_floorNo, sub_df,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_df_cm,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_pp,
           train_waypoints, show_preds=True)