# LSTM baseline

from kuto

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [2]:
import os
import sys
import glob
import pickle
import random

In [3]:
import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path


In [4]:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [5]:
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger

In [6]:
sys.path.append('../../')
import src.utils as utils

In [7]:
import multiprocessing
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm

from indoor_location_competition_20.io_f import read_data_file
import indoor_location_competition_20.compute_f as compute_f

In [8]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids_original'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [9]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.01,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [10]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 500
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 53
IS_SAVE = True

utils.set_seed(SEED)

In [11]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [12]:
train_df = pd.read_csv(WIFI_DIR / 'train_10_th10000_base25_withmagndelta.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_10_th10000_base25_allwifibase_withmagndelta.csv')

In [13]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [14]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]
DELTA_FEATS  = ['delta_x', 'delta_y']

In [15]:
train_df.iloc[:, 100:110]

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9
0,e9b24f94c0007acb4b7169b945622efcd332cf6f,591ea59cf88e3397db5d60eb00a5147edd69399a,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,a77f8e93896f8fc8bc0d0700ca04b802ee79a07f,1b2fd184314ae440900fa9ce1addeb896b5604a9,2c09230bb32ee49f6a72928f6eeefb6885dc15ce,3799b46aa4cf6c3c45c0bc27d8f1efefea96914f,fc6956beb062b5158252c66953e92a0d25495cac,c71a2f5c4282d27f84b9b841db0e310ef0fcf6cd,4f8b7c168dc76c9d3b4ca7903042173e98fe2ddb
1,e9b24f94c0007acb4b7169b945622efcd332cf6f,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,591ea59cf88e3397db5d60eb00a5147edd69399a,d32dd11040b254cd889c9ead2d4a50f6e3900196,9c832009dfb1ee02053c9ce9b7770b6cd3191003,509d1f842b0773e85c6beec0bb530542efd35cb9,1b2fd184314ae440900fa9ce1addeb896b5604a9,f920a2e4cb52165850990d9d37d391b630f7de14,6a42281c99a4cff2ce9dba3fc91ad6a431af64d9,fc6956beb062b5158252c66953e92a0d25495cac
2,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,e9b24f94c0007acb4b7169b945622efcd332cf6f,591ea59cf88e3397db5d60eb00a5147edd69399a,509d1f842b0773e85c6beec0bb530542efd35cb9,b2337b25e7d1df04928bf6698a9c0b2764df7795,76f81d5047273fa64a434457531d400fc5d90fac,6e388d1db5ba8dd9de80522a4ddf50402cf443b3,8c6ab78f2797e076f9106af81090d0ab9904f5cd,ceccac4f0e50ec9e36e8d2800b8f2c7c3b4d903e,f920a2e4cb52165850990d9d37d391b630f7de14
3,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,591ea59cf88e3397db5d60eb00a5147edd69399a,e9b24f94c0007acb4b7169b945622efcd332cf6f,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,d32dd11040b254cd889c9ead2d4a50f6e3900196,f920a2e4cb52165850990d9d37d391b630f7de14,6a42281c99a4cff2ce9dba3fc91ad6a431af64d9,9c832009dfb1ee02053c9ce9b7770b6cd3191003,509d1f842b0773e85c6beec0bb530542efd35cb9,0452e85d0a41780463cfe079077ea5bd2f519c7a
4,5875360455060f20a3cba705f44a4e3987c9b9f3,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,0452e85d0a41780463cfe079077ea5bd2f519c7a,3c7e7fa0576bc8a2af71d5899581df36f4dab6c8,09e103887f42552d20328aa41891cf82dace79ab,54bba3a36204f8c71b93798c31f9e0b039914575,18067f8d8861af3bcae51ba04b6b11b9150b9ff2,591ea59cf88e3397db5d60eb00a5147edd69399a,f920a2e4cb52165850990d9d37d391b630f7de14,d32dd11040b254cd889c9ead2d4a50f6e3900196
...,...,...,...,...,...,...,...,...,...,...
251108,5964a27e0cb3344b0a18540e6b3120c433971c38,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,cce41299a022ada08aebf3d309acb07d5f00b014,346b34a42e801c64e043dbaacbe7fef9b8880774,4b5dbdb52b131410ea10b59ea451de62280b41d6,fa11fc4d4960379cb68cc6968ba6415168fef53c,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,4d2e5639041b40b0df2ee258aa504bd904133d80,dfc21edb1f7650d5645fd672bbe6a13fc6fd77f8,a94eb920c0a198fe8385f3de6a8e8e6d44b6f6c9
251109,5964a27e0cb3344b0a18540e6b3120c433971c38,346b34a42e801c64e043dbaacbe7fef9b8880774,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,cce41299a022ada08aebf3d309acb07d5f00b014,566e0c6e3bcf2b8b3d310d96f111043d17ace817,bd3fc24710537130e97dc2dab4a6bf70b3884a8b,4b5dbdb52b131410ea10b59ea451de62280b41d6,4d2e5639041b40b0df2ee258aa504bd904133d80,f4107af4418d57aacb3542343f7b47768debdc75,5f583dcccc43b5b7ac25d270e29c92d878fb2be0
251110,346b34a42e801c64e043dbaacbe7fef9b8880774,5964a27e0cb3344b0a18540e6b3120c433971c38,cce41299a022ada08aebf3d309acb07d5f00b014,566e0c6e3bcf2b8b3d310d96f111043d17ace817,bd3fc24710537130e97dc2dab4a6bf70b3884a8b,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,4b5dbdb52b131410ea10b59ea451de62280b41d6,a94eb920c0a198fe8385f3de6a8e8e6d44b6f6c9,ee5ca7a7deaacdcd5d99355ff5f156dc45b74efa,a7986c0cea5d2571ea42011ab4407039e977c0bd
251111,346b34a42e801c64e043dbaacbe7fef9b8880774,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,5964a27e0cb3344b0a18540e6b3120c433971c38,f4107af4418d57aacb3542343f7b47768debdc75,cce41299a022ada08aebf3d309acb07d5f00b014,4b5dbdb52b131410ea10b59ea451de62280b41d6,fa11fc4d4960379cb68cc6968ba6415168fef53c,a7986c0cea5d2571ea42011ab4407039e977c0bd,180a351ec58c07d60949862c534373c43f548a9a,4d2e5639041b40b0df2ee258aa504bd904133d80


bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [16]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in range(100, 200):
    wifi_bssids.extend(train_df.iloc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in range(100, 200):
    wifi_bssids_test.extend(test_df.iloc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 61592
BSSID TYPES(test): 28283
BSSID TYPES(all): 89875


In [17]:
# get numbers of bssids to embed them in a layer

# train
rssi_bssids = []
# bssidを列ごとにリストに入れていく
for i in RSSI_FEATS:
    rssi_bssids.extend(train_df.loc[:,i].values.tolist())
rssi_bssids = list(set(rssi_bssids))

train_rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(train): {train_rssi_bssids_size}')

# test
rssi_bssids_test = []
for i in RSSI_FEATS:
    rssi_bssids_test.extend(test_df.loc[:,i].values.tolist())
rssi_bssids_test = list(set(rssi_bssids_test))

test_rssi_bssids_size = len(rssi_bssids_test)
print(f'RSSI TYPES(test): {test_rssi_bssids_size}')


rssi_bssids.extend(rssi_bssids_test)
rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(all): {rssi_bssids_size}')

RSSI TYPES(train): 97
RSSI TYPES(test): 82
RSSI TYPES(all): 179


## PreProcess

In [18]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])
le_rssi = LabelEncoder()
le_rssi.fit(rssi_bssids)

ss = StandardScaler()
ss.fit(train_df.loc[:,DELTA_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
#     output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？
    for i in RSSI_FEATS:
        output_df.loc[:,i] = le_rssi.transform(input_df.loc[:,i])

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
    output_df.loc[:,DELTA_FEATS] = ss.transform(output_df.loc[:,DELTA_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

Unnamed: 0,ssid_0,ssid_1,ssid_2,ssid_3,ssid_4,ssid_5,ssid_6,ssid_7,ssid_8,ssid_9,...,frequency_99,wp_tmestamp,x,y,floor,floor_str,path_id,site_id,delta_x,delta_y
0,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,8c1562bec17e1425615f3402f72dded3caa42ce5,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,...,5745,1578469851129,157.99141,102.125390,-1.0,B1,5e158ef61506f2000638fd1f,0,-0.024981,-0.022974
1,b7e6027447eb1f81327d66cfd3adbe557aabf26c,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,b7e6027447eb1f81327d66cfd3adbe557aabf26c,...,5765,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0,0.481487,0.530641
2,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,...,5745,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0,0.481487,0.530641
3,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,7182afc4e5c212133d5d7d76eb3df6c24618302b,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,...,5825,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0,0.481487,0.530641
4,da39a3ee5e6b4b0d3255bfef95601890afd80709,da39a3ee5e6b4b0d3255bfef95601890afd80709,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,5731b8e08abc69d4c4d685c58164059207c93310,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,7182afc4e5c212133d5d7d76eb3df6c24618302b,...,5765,1578469862177,168.49713,109.861336,-1.0,B1,5e158ef61506f2000638fd1f,0,0.473545,0.488420
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
251108,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,1f09251bbfadafb11c63c87963af25238d6bc886,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,da39a3ee5e6b4b0d3255bfef95601890afd80709,1556355684145fce5e67ba749d943a180266ad90,...,0,1573733061352,203.53165,143.513960,6.0,F7,5dcd5c9323759900063d590a,23,0.141618,0.376964
251109,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,1f09251bbfadafb11c63c87963af25238d6bc886,...,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23,-1.866413,0.396725
251110,4abd3985ba804364272767c04cdc211615f77c56,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,5d998a8668536c4f51004c25f474117fe9555f78,...,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23,-1.866413,0.396725
251111,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,5d998a8668536c4f51004c25f474117fe9555f78,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,...,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23,-1.866413,0.396725


In [19]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [20]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(int)
        self.delta_feats = df[DELTA_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        concat_feat = np.empty(2 * NUM_FEATS).astype(int)
        concat_feat[0::2] = self.bssid_feats[idx]
        concat_feat[1::2] = self.rssi_feats[idx]
        
        feature = {
            'RSSI_BSSID_FEATS':concat_feat,
            'site_id':self.site_id[idx],
            'delta':self.delta_feats[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [21]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
#         self.rssi_embedding = nn.Embedding(rssi_size, 64, max_norm=True)
        self.delta = nn.Sequential(
            nn.BatchNorm1d(2),
            nn.Linear(2, 2 * 64)
        )
        
        concat_size = 64 + (2 * NUM_FEATS * 64) + 64 * 2
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['RSSI_BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)
        
        x_delta = self.delta(x['delta'])


        x = torch.cat([x_bssid, x_delta, x_site_id], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [22]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [23]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [24]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [25]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [26]:
oofs = np.zeros((len(train), 2), dtype = np.float32)  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path_id'], groups=train.loc[:, 'path_id'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + DELTA_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + DELTA_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size+rssi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=20,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    oofs[val_idx, 0] = oof_x
    oofs[val_idx, 1] = oof_y

    
    val_score = mean_position_error(
        oof_x, oof_y, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.4 M    Trainable params
0         Non-trainable params
11.4 M    Total params
45.425    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 165.57601165771484


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 131.1819489653485


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 105.40312624090389


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 91.03064052976127


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 84.66111311306882


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 82.5437791087718


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 82.14621494899072


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 63.71621826790031


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 49.01202746501359


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 41.63124019049738


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 35.97619981825185


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 30.57386945166589


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 26.193983804275216


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 22.95501486808798


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 20.17406083721257


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 18.006471192634365


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 15.93888814777757


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 14.410175361198787


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 13.007101774574377


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 12.467662061026807


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.800758896964405


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 10.765838065250904


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 11.489973234971465


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 9.976972466408759


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.244028803253338


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.68374015540429


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.026247747938992


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.014650271473304


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 8.668250780374088


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.792566672053974


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.328581448654807


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.657096392719458


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.279107368148683


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.190927212299288


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.110639092232011


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.205454829326886


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.836218003224435


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.3253867789645


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.381400361871709


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 7.305377592303639


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 7.219846612633927


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 7.163734070080828


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 7.105069461341266


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 7.095899741162338


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.088023633770397


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.072087713807674


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.07046742524328


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.025815639909866


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.043349844702797


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.053453044124811


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.005302008315873


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.028350886083177


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 6.9906417131680145


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 6.998340845159272


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 6.9915404191663795


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.000161390470582


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 6.9839684493357765


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 6.948928136864503


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 6.945955389165478


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 6.942281883812811


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 6.943623540531096


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 6.924864453857442


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 6.92991790063351


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 6.9270573647694


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 6.911535820396336


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 6.927188601105438


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 6.925114828925888


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 6.9240577254734985


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 6.92057777065988


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 6.919578015202224


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 6.914253353739512


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 6.912366344061253


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 6.914008613096984


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 6.92195012209212


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 6.919082514578211


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 6.9153600608100945


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 6.911273398502858


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 6.909352056702926


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 6.909507913279989


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 6.914504722807962


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 6.9099803867069545


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 6.914316340822982


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 6.908956889718829


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 6.9135299013835905


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 6.915865652232741
fold 0: mean position error 6.924551830963387
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,40.15426
Loss/xy,40.15426
Loss/floor,5.61235
MPE/val,6.92455
epoch,83.0
trainer/global_step,33011.0
_runtime,419.0
_timestamp,1619799753.0
_step,83.0


0,1
Loss/val,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,████▅▁█▇▇▇█▇████████████████████████████
MPE/val,█▆▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.4 M    Trainable params
0         Non-trainable params
11.4 M    Total params
45.425    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 140.84415435791016


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 139.28599515054498


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 116.96181805928548


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 100.78123492157488


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 90.24263196839021


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 68.29979547459209


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 59.41741150085987


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 55.30074791136594


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 52.91810177742167


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 51.710336013325005


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 50.055448537228465


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 48.62484365193121


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 39.317905329284734


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 33.42658632349358


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 27.21437952765951


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 23.167311304696216


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 20.74209037873403


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 17.75647932305122


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 15.900096664063884


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 14.366304029886528


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 13.639318135075229


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 12.046549016676641


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 11.815898642139283


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 11.41406548399836


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.245296302621606


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.724464639601798


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.721120174792462


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.493715920808006


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.316493712982105


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.021289000889318


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.316555999208209


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.220081394215233


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.624943156255693


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.599783127983745


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.674153941896636


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.17996136716379


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.359014982514143


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.303848384760437


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.537663296323014


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.804252259069788


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.07681624145885


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.061410504772974


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.027452372541331


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.16155297664476


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.390378930813108


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.99233993417136


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.114661275764037


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.217507250970495


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 9.093732993115463


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.168862415185007


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.382950028936606


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.236265251771266


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.212544652740853


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.199595616142032


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.158828808419453


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.133511142579043


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.1486443935249175


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.128500010975966


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.125248502152929


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.1331732069679985


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.119202570276775


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.106190262931203


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.122740489960536


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.131555031633572


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.103215618490174


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.0770027328505964


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.076733734898226


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.078219825212555


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.092244856182846


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.0900812491329415


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.086105112514008


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.04800441048362


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.038008106157084


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.033678374033244


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.022706614907727


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.0292896673086505


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.036441489596175


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.038014887841111


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.032716199143994


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.04124456074016


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.035673432521299


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.033331739346545


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.036761046429488


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.030481361968785


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.03223447002392


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.032114627103562


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.027279829635636


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.034052154414761


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.039184102596018


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.0424786757685


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.034389868260147


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.043542026423035


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.0372913591379564


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.040113621467215


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.046969832181162
fold 1: mean position error 7.030594708196308
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,39.89498
Loss/xy,39.89498
Loss/floor,4.87751
MPE/val,7.03059
epoch,93.0
trainer/global_step,36941.0
_runtime,468.0
_timestamp,1619800230.0
_step,93.0


0,1
Loss/val,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,██▄▂▁▃▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
MPE/val,█▆▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.4 M    Trainable params
0         Non-trainable params
11.4 M    Total params
45.425    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 170.29209518432617


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 126.00652363365398


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 100.81269450446956


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 86.87084550568552


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 80.92489241897842


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 79.37064605801433


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 79.44673591333385


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 57.93751116370785


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 54.204115861261016


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 52.47185245401598


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 50.69700403006493


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 50.152411165120604


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 48.93243129499851


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 48.54442771880468


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 47.93315463063027


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 39.889290685631146


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 33.48740085181231


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 26.89907274844696


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 22.08877301948854


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 18.88290141696139


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 16.658198216743887


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 14.584288662898778


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 12.942429199645066


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 12.147043760837905


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 11.790818276519087


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 11.175990244559825


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 10.599842558823271


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 10.1445253194083


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 10.186926268772902


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 12.098860194650795


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.549751887494661


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.340102522341942


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.121420192226603


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.285037297011703


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 9.120981257934865


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.455614447004432


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 11.943805857522499


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.61836806343828


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.450984005873725


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.38026082600262


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.268471901887672


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.86096489170713


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.583456606345358


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.592020043090878


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.310638877414464


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.695814014429075


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.521220027772735


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.745582813194893


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.55435851802371


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.6212920989883655


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.5454025888555725


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.505278933107174


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.442270791635651


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.431514347698672


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.4219473991090705


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.3693610999293675


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.386104048899674


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.390755078983696


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.351903372366352


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.360300564084441


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.359809239397965


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.364521262211721


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.3727924091544885


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.314216806328529


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.307442904184793


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.299062732468619


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.300543505838552


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.286400070657736


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.295426786932258


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.288908653148599


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.282825734265359


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.284290080887257


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.286492463929049


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.282726595690244


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.284286713718011


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.278890266742293


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.279760193471215


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.281306771188814


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.281238936414622


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.284469244891588


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.279927611786551


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.281764256013943


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.2785439520991035


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.2883630063286295


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.285187287149239


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.277476100032109


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.281529701014833


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.281002902943627


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.2836966257363835


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.280049019569022


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.27525718007681


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.27779369919935


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.2824236323391265


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.281740481382592


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.286004078559867


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.284612296141529


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.282783118981124


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.2770645028669145


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.281978119682707


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.278977972805795


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.278916016256273


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.273775670391602


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.285967562241169


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 7.284662416717608


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 7.277169345758972


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 7.275627307678688


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 7.282068976273774


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 7.273242854211194


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 7.279917380539751


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 7.283017148362829


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 7.2807241955701185


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 7.281104620927999


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 7.278999240480903


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 7.280233404699407


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 7.279421818223072


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 7.277867335850978


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 7.282127114846684


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 7.278559534825895


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 7.278632058140132


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 7.279768858216231


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 7.275879878575587


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 7.279909747968355
fold 2: mean position error 7.2555088289257315
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,44.42846
Loss/xy,44.42846
Loss/floor,4.81937
MPE/val,7.25551
epoch,120.0
trainer/global_step,47552.0
_runtime,600.0
_timestamp,1619800837.0
_step,120.0


0,1
Loss/val,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,█████▁▃▅▅▆▇▇▇▇██████████████████████████
MPE/val,█▅▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.4 M    Trainable params
0         Non-trainable params
11.4 M    Total params
45.425    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 129.3900375366211


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 138.08806637247562


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 115.97076887449884


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 99.98026401367322


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 89.76278081650162


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 84.11041611833534


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 81.5901583008563


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 80.77890441854323


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 80.74179282207263


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 80.92402030733734


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 81.08608189056356


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 81.1651633241274


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 81.19660653008356


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 81.20761587371625


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 81.21267048999516


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 81.20372353052433


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 63.741549352412164


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 53.41492166925429


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 47.962256622146924


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 42.332984044356046


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 39.00419131778655


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 34.467675141613896


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 29.16718153645713


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 25.06349639406736


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 22.83954034084074


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 20.2715236461105


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 17.267517442079193


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 15.388373935751398


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 13.127183329708282


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 12.143507680402914


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 12.344456190805346


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 11.317783686588037


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 11.375948354974879


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 10.824828254480241


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 10.370457702422279


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 10.194099676938867


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.941298796830075


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 10.090777845303112


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 9.464498814334458


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 9.25207466486519


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 9.698381201133008


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 9.583050188601515


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 10.842626028643085


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.555064854938104


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 9.437750797707311


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.751532560908904


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.563518944883953


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.752857405101725


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.678388755573847


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 9.182732805118887


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.62411212586079


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.282889398190344


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 9.098441617343317


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 8.036362289492374


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 8.134252039340174


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 8.108505826798456


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 8.125577980874136


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 8.69381733384633


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.452291364045744


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.33452628859837


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.265523280184364


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.250127643255167


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.236425404387396


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.229858829622265


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.228363214902413


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.2171230569553835


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.207364779058967


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.237466087349686


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.206363835877162


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.192150967865594


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.203316595976323


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.174117167057779


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.160958212826092


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.158550222662389


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.16441245470646


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.154120262223061


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.145236138874468


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.148530959035935


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.148936703271075


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.136007434818166


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.138913157533857


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.146289009350562


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.131456901799915


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.137084120975339


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.133458374150761


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.1524654400154954


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.143328846574196


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.139947546329423


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.136713165748942


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.139544615161278


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.138365334463476


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.142605206083718


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.1321250342150035


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.136469673262283


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.134630888451805


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.137292545303525


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.1335204593120665


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.141080502621898


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.13403318469651


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.13711490681363


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.1369006835843685


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.145509473995833


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.130837212856998


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 7.139548926380679


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 7.137505914958674
fold 3: mean position error 7.105433723995068
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,40.34948
Loss/xy,40.34948
Loss/floor,4.56044
MPE/val,7.10543
epoch,103.0
trainer/global_step,40871.0
_runtime,524.0
_timestamp,1619801369.0
_step,103.0


0,1
Loss/val,█▅▄▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▄▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▂▆██▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
MPE/val,█▆▅▅▅▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.4 M    Trainable params
0         Non-trainable params
11.4 M    Total params
45.425    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 128.08655548095703


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 129.79250320615503


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 103.64793801255324


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 89.02056524069832


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 82.36777914832128


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 80.32913093751179


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 80.0384195300116


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 80.1552117054118


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 57.58804045783149


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 52.57783828705196


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 46.7470516653819


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 41.503300285255584


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 34.27410670503692


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 28.781472688397663


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 24.68509184921526


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 21.55002405625663


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 18.92586262476355


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 16.40905387159707


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 14.554967702874814


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 13.782841216067656


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 12.387395453337886


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 11.505270402579125


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 12.576348771743204


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 10.821102744819932


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.637480608664115


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.806563098649471


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.372558821593138


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.472252263800847


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.300253228969515


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.845205805623118


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.851529517611443


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.82307979670019


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.441021161169456


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.46305536752633


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.277944570054702


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.320647644064747


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.180452873479602


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.527396316008973


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.363815419444611


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.42209569016307


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.06783304206101


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.297780168857498


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 10.844660945770036


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.094180236116124


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.95511436085456


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.85675905396868


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.891887091208195


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.83354682403015


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.030682584304374


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.727932396613562


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.278797143794593


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.500514159874797


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.667383665990347


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.494449175079507


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.396215843222463


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.357505312497856


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.313801802717703


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.298766559187027


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.267666515782574


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.252095707372766


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.226092914661715


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.239284193091713


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.224996112133256


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.216334266834318


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.198810848375135


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.180621507558374


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.1827928445084765


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.189427723002947


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.177463343230371


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.175852324389867


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.151655190442287


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.135767916812151


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.128595675509071


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.134252152842914


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.117850058948884


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.125500982562649


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.113605314354443


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.110326559437856


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.1140403894148845


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.111833517102228


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.114716702194449


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.116647877397902


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.1040093166868985


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.113075343673565


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.115380954365486


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.115380330938072


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.109097737306872


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.104626822900919


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.098890584238696


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.106496058486249


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.105658326172022


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.105585870483124


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.112641568240441


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.101239228698837


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.102144921511791


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.122942167360844


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.118868975666568


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.109670395698262


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.120022501922461


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.104594031184401


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.110935401120727


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.107508921549934


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.100924111271198


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 7.105583962296833


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 7.108488319524296


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 7.1057537470044


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 7.115691912294638


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 7.11206860881549


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 7.109244234162985


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 7.107098210418124


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 7.116730897008822


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 7.101950850673135


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 7.107273248920432


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 7.1143469674366315


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 7.121024763055366


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 7.119015444251702


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 7.106285200759827


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 7.110634346829206


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 7.112074787490035


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 7.106325330784513


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 7.108322697456437


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 7.113737992775456


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 7.103658858338207
fold 4: mean position error 7.124749507434723


In [202]:
oofs_df = pd.DataFrame(oofs, columns=['x', 'y'])
oofs_df['path'] = train_df['path_id']
oofs_df['timestamp'] = train_df['wp_tmestamp']
oofs_df['site'] = train_df['site_id']
oofs_df['site_path_timestamp'] = oofs_df['site'] + '_' + oofs_df['path'] + '_' + oofs_df['timestamp'].astype(str)
oofs_df['floor'] = train_df['floor']
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,x,y,path,timestamp,site,site_path_timestamp,floor
0,160.423431,105.089066,5e158ef61506f2000638fd1f,1578469851129,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
1,161.996552,107.799286,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
2,162.238556,112.793472,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
3,160.717529,112.267014,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
4,162.516251,111.863922,5e158ef61506f2000638fd1f,1578469862177,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
...,...,...,...,...,...,...,...
251108,198.058701,142.002686,5dcd5c9323759900063d590a,1573733061352,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251109,188.711761,142.446594,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251110,189.257156,140.555359,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251111,185.828659,142.298462,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0


In [203]:
oofs_df_dupli = oofs_df[~oofs_df.duplicated(subset='timestamp')].sort_values('timestamp').reset_index(drop=True)

In [204]:
oofs_df_dupli[['path', 'timestamp', 'site']]

Unnamed: 0,path,timestamp,site
0,5d08a2553f461f0008dac591,1560495549579,5c3c44b80379370013e0fd2b
1,5d08a2553f461f0008dac591,1560495574531,5c3c44b80379370013e0fd2b
2,5d073b814a19c000086c558b,1560500995805,5c3c44b80379370013e0fd2b
3,5d073b814a19c000086c558b,1560501011427,5c3c44b80379370013e0fd2b
4,5d073b821a69370008bc5cf8,1560501353313,5c3c44b80379370013e0fd2b
...,...,...,...
70727,5e15bf91f4c3420006d52341,1578483551384,5a0546857ecc773753327266
70728,5e15bf91f4c3420006d52341,1578483556553,5a0546857ecc773753327266
70729,5e15bf91f4c3420006d52341,1578483567957,5a0546857ecc773753327266
70730,5e15bf91f4c3420006d52341,1578483574917,5a0546857ecc773753327266


In [205]:
oofs_df_gby = oofs_df.groupby('timestamp').mean()[['x', 'y']].sort_index().reset_index(drop=True)
oofs_df_gby[['path', 'timestamp', 'site', 'site_path_timestamp', 'floor']] = oofs_df_dupli[['path', 'timestamp', 'site', 'site_path_timestamp', 'floor']]
oofs_df_gby.to_csv(str(OUTPUT_DIR) + f"/oof_gby{EXP_NAME}.csv", index=False)
oofs_df_gby

Unnamed: 0,x,y,path,timestamp,site,site_path_timestamp,floor
0,166.068970,38.230618,5d08a2553f461f0008dac591,1560495549579,5c3c44b80379370013e0fd2b,5c3c44b80379370013e0fd2b_5d08a2553f461f0008dac...,0.0
1,187.829163,31.312820,5d08a2553f461f0008dac591,1560495574531,5c3c44b80379370013e0fd2b,5c3c44b80379370013e0fd2b_5d08a2553f461f0008dac...,0.0
2,193.163330,91.676598,5d073b814a19c000086c558b,1560500995805,5c3c44b80379370013e0fd2b,5c3c44b80379370013e0fd2b_5d073b814a19c000086c5...,2.0
3,186.780960,86.928207,5d073b814a19c000086c558b,1560501011427,5c3c44b80379370013e0fd2b,5c3c44b80379370013e0fd2b_5d073b814a19c000086c5...,2.0
4,184.996216,86.089577,5d073b821a69370008bc5cf8,1560501353313,5c3c44b80379370013e0fd2b,5c3c44b80379370013e0fd2b_5d073b821a69370008bc5...,2.0
...,...,...,...,...,...,...,...
70727,65.555809,180.498825,5e15bf91f4c3420006d52341,1578483551384,5a0546857ecc773753327266,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0
70728,62.187752,178.135742,5e15bf91f4c3420006d52341,1578483556553,5a0546857ecc773753327266,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0
70729,52.109596,181.569809,5e15bf91f4c3420006d52341,1578483567957,5a0546857ecc773753327266,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0
70730,44.894745,183.290253,5e15bf91f4c3420006d52341,1578483574917,5a0546857ecc773753327266,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0


In [206]:
train_df_gby = train_df.groupby('wp_tmestamp').mean()[['x', 'y']]

In [207]:
oofs_score = mean_position_error(
        oofs_df['x'], oofs_df['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
oofs_score_gby = mean_position_error(
        oofs_df_gby['x'], oofs_df_gby['y'], 0,
        train_df_gby['x'].values, train_df_gby['y'].values, 0)
print(f"CV:{oofs_score}, CV_gby:{oofs_score_gby}")

CV:7.088167505466546, CV_gby:6.565643716260137


In [208]:
fold_mean_xy = 0
for i in range(len(predictions)):
    fold_mean_xy += predictions[i][['x', 'y']]
fold_mean_xy = fold_mean_xy / 5
fold_mean_xy

Unnamed: 0,x,y
0,71.194847,85.575447
1,75.961678,86.774094
2,71.241074,85.795677
3,70.683273,82.870407
4,69.701759,81.978279
...,...,...
37395,131.009430,183.347931
37396,132.237106,181.345490
37397,132.319672,177.304153
37398,131.089142,174.237656


In [137]:
fold_mean_xy['site_path_timestamp'] = predictions[0]['site_path_timestamp']

In [138]:
all_preds = fold_mean_xy.groupby('site_path_timestamp').mean()

In [139]:
# all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean()
# all_preds

In [140]:
all_preds_50 = pd.read_csv('../50/output/sub50.csv', index_col=0)
all_preds_50.index = pd.read_csv(WIFI_DIR / 'test_7_th20000.csv')['site_path_timestamp']
all_preds_50

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,85.71397,104.649240
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,80.89932,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,85.22545,105.812570
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,88.35129,107.935360
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.31835,108.258490
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,5,210.10178,100.415660
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,5,208.61255,101.582490
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,5,205.28879,106.346170
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,5,199.99641,112.778860


In [141]:
all_preds_merge = pd.merge(all_preds_50, all_preds, how='left', on='site_path_timestamp')[['floor', 'x_y', 'y_y']]
all_preds_merge = all_preds_merge.rename(columns={'x_y': 'x', 'y_y': 'y'})
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,,
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,,
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,,
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,,
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,,
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,5,208.540924,98.267525
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,5,209.484756,101.554031
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,5,203.762955,108.776146
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,5,200.898163,111.846176


In [142]:
all_preds_merge['floor'].fillna(all_preds_50['floor'], inplace=True)
all_preds_merge['x'].fillna(all_preds_50['x'], inplace=True)
all_preds_merge['y'].fillna(all_preds_50['y'], inplace=True)
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,5,208.540924,98.267525
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,5,209.484756,101.554031
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,5,203.762955,108.776146
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,5,200.898163,111.846176


In [143]:
# foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds_merge.index = sub.index
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,208.540924,98.267525
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,209.484756,101.554031
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,203.762955,108.776146
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,200.898163,111.846176


In [144]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds_merge['floor'] = simple_accurate_99['floor'].values
all_preds_merge.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,208.540924,98.267525
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,209.484756,101.554031
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,203.762955,108.776146
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,200.898163,111.846176


# Post Proccess

In [209]:
oofs_df = pd.read_csv(str(OUTPUT_DIR) + f"/oof_gby{EXP_NAME}.csv")
sub_df = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [210]:
def compute_rel_positions(acce_datas, ahrs_datas):
    step_timestamps, step_indexs, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    headings = compute_f.compute_headings(ahrs_datas)
    stride_lengths = compute_f.compute_stride_length(step_acce_max_mins)
    step_headings = compute_f.compute_step_heading(step_timestamps, headings)
    rel_positions = compute_f.compute_rel_positions(stride_lengths, step_headings)
    return rel_positions

In [211]:
import math

order = 3
fs = 50.0  # sample rate, Hz
# fs = 100
# cutoff = 3.667  # desired cutoff frequency of the filter, Hz
cutoff = 3

step_distance = 0.8
w_height = 1.7
m_trans = -5

from scipy.signal import butter, lfilter

def butter_lowpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a

def butter_lowpass_filter(data, cutoff, fs, order=5):
    b, a = butter_lowpass(cutoff, fs, order=order)
    y = lfilter(b, a, data)
    return y

def peak_accel_threshold(data, timestamps, threshold):
    d_acc = []
    last_state = 'below'
    crest_troughs = 0
    crossings = []

    for i, datum in enumerate(data):
        
        current_state = last_state
        if datum < threshold:
            current_state = 'below'
        elif datum > threshold:
            current_state = 'above'

        if current_state is not last_state:
            if current_state is 'above':
                crossing = [timestamps[i], threshold]
                crossings.append(crossing)
            else:
                crossing = [timestamps[i], threshold]
                crossings.append(crossing)

            crest_troughs += 1
        last_state = current_state
    return np.array(crossings)

In [212]:
def steps_compute_rel_positions(sample_file):
    
    mix_acce = np.sqrt(sample_file.acce[:,1:2]**2 + sample_file.acce[:,2:3]**2 + sample_file.acce[:,3:4]**2)
    mix_acce = np.concatenate([sample_file.acce[:,0:1], mix_acce], 1)
    mix_df = pd.DataFrame(mix_acce)
    mix_df.columns = ["timestamp","acce"]
    
    filtered = butter_lowpass_filter(mix_df["acce"], cutoff, fs, order)

    threshold = filtered.mean() * 1.1
    crossings = peak_accel_threshold(filtered, mix_df["timestamp"], threshold)

    step_sum = len(crossings)/2
    distance = w_height * 0.4 * step_sum

    mag_df = pd.DataFrame(sample_file.magn)
    mag_df.columns = ["timestamp","x","y","z"]
    
    acce_df = pd.DataFrame(sample_file.acce)
    acce_df.columns = ["timestamp","ax","ay","az"]
    
    mag_df = pd.merge(mag_df,acce_df,on="timestamp")
    mag_df.dropna()
    
    time_di_list = []

    for i in mag_df.iterrows():

        gx,gy,gz = i[1][1],i[1][2],i[1][3]
        ax,ay,az = i[1][4],i[1][5],i[1][6]

        roll = math.atan2(ay,az)
        pitch = math.atan2(-1*ax , (ay * math.sin(roll) + az * math.cos(roll)))

        q = m_trans - math.degrees(math.atan2(
            (gz*math.sin(roll)-gy*math.cos(roll)),(gx*math.cos(pitch) + gy*math.sin(roll)*math.sin(pitch) + gz*math.sin(pitch)*math.cos(roll))
        )) -90
        if q <= 0:
            q += 360
        time_di_list.append((i[1][0],q))

    d_list = [x[1] for x in time_di_list]
    
    steps = []
    step_time = []
    di_dict = dict(time_di_list)

    for n,i in enumerate(crossings[:,:1]):
        if n % 2 == 1:
            continue
        direct_now = di_dict[i[0]]
        dx = math.sin(math.radians(direct_now))
        dy = math.cos(math.radians(direct_now))
#         print(int(n/2+1),"歩目/x:",dx,"/y:",dy,"/角度：",direct_now)
        steps.append((i[0],dx,dy))
        step_time.append(i[0])
    
        step_dtime = np.diff(step_time)/1000
        step_dtime = step_dtime.tolist()
        step_dtime.insert(0,5)
        
        rel_position = []

        wp_idx = 0
#         print("WP:",round(sample_file.waypoint[0,1],3),round(sample_file.waypoint[0,2],3),sample_file.waypoint[0,0])
#         print("------------------")
        for p,i in enumerate(steps):
            step_distance = 0
            if step_dtime[p] >= 1:
                step_distance = w_height*0.25
            elif step_dtime[p] >= 0.75:
                step_distance = w_height*0.3
            elif step_dtime[p] >= 0.5:
                step_distance = w_height*0.4
            elif step_dtime[p] >= 0.35:
                step_distance = w_height*0.45
            elif step_dtime[p] >= 0.2:
                step_distance = w_height*0.5
            else:
                step_distance = w_height*0.4

#             step_x += i[1]*step_distance
#             step_y += i[2]*step_distance
            
            rel_position.append([i[0], i[1]*step_distance, i[2]*step_distance])
#     print(rel_position)
    
    return np.array(rel_position)

In [213]:
def correct_path(args):

    path, path_df = args
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    txt_path = path_df['txt_path'].values[0]
    
    example = read_data_file(txt_path)
    rel_positions1 = compute_rel_positions(example.acce, example.ahrs)
    rel_positions2 = steps_compute_rel_positions(example)
    rel1 = rel_positions1.copy()
    rel2 = rel_positions2.copy()
    rel1[:,1:] = rel_positions1[:,1:] / 2
    rel2[:,1:] = rel_positions2[:,1:] / 2
    rel_positions = np.vstack([rel1,rel2])
    rel_positions = rel_positions[np.argsort(rel_positions[:, 0])]
    
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]
    delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)

    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

In [214]:
tmp = sub_df['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub_df['site'] = tmp[0]
sub_df['path'] = tmp[1]
sub_df['timestamp'] = tmp[2].astype(float)

In [215]:
used_buildings = sorted(sub_df['site'].value_counts().index.tolist())
test_txts = sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/test/*.txt'))
train_txts = [sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/train/{used_building}/*/*.txt')) for used_building in used_buildings]
train_txts = sum(train_txts, [])

In [216]:
txt_pathes = []
for path in tqdm(sub_df['path'].values):
    txt_pathes.append([test_txt for test_txt in test_txts if path in test_txt][0])

100%|██████████| 10133/10133 [00:00<00:00, 28574.17it/s]


In [217]:
sub_df['txt_path'] = txt_pathes

In [218]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, sub_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
sub_df_cm = pd.concat(dfs).sort_values('site_path_timestamp')

626it [02:25,  4.29it/s]


In [219]:
txt_pathes = []
for path in tqdm(oofs_df['path'].values):
    txt_pathes.append([train_txt for train_txt in train_txts if path in train_txt][0])

100%|██████████| 70732/70732 [00:41<00:00, 1707.00it/s]


In [220]:
oofs_df['txt_path'] = txt_pathes

In [221]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, oofs_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
oofs_df_cm = pd.concat(dfs).sort_index()
oofs_df_cm

10789it [16:55, 10.63it/s]


Unnamed: 0,site_path_timestamp,floor,x,y
0,5c3c44b80379370013e0fd2b_5d08a2553f461f0008dac...,0.0,165.537923,38.994414
1,5c3c44b80379370013e0fd2b_5d08a2553f461f0008dac...,0.0,188.360207,30.549024
2,5c3c44b80379370013e0fd2b_5d073b814a19c000086c5...,2.0,196.443684,91.277600
3,5c3c44b80379370013e0fd2b_5d073b814a19c000086c5...,2.0,183.500606,87.327210
4,5c3c44b80379370013e0fd2b_5d073b821a69370008bc5...,2.0,181.887303,85.937432
...,...,...,...,...
70727,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,65.682126,178.702787
70728,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,60.491982,180.332963
70729,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,52.212050,183.005921
70730,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,46.905675,185.432584


In [224]:
oofs_df_cm.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm_gby.csv", index=False)

In [223]:
oofs_score = mean_position_error(
        oofs_df_cm['x'], oofs_df_cm['y'], 0,
        train_df_gby['x'].values, train_df_gby['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.45021334803446


In [225]:
sub_df_cm.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv", index=False)

In [226]:
oofs_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm_gby.csv")

In [227]:
sub_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv")

In [228]:
def split_col(df):
    df = pd.concat([
        df['site_path_timestamp'].str.split('_', expand=True) \
        .rename(columns={0:'site',
                         1:'path',
                         2:'timestamp'}),
        df
    ], axis=1).copy()
    return df
def sub_process(sub, train_waypoints):
    train_waypoints['isTrainWaypoint'] = True
    sub = split_col(sub[['site_path_timestamp','floor','x','y']]).copy()
    sub = sub.merge(train_waypoints[['site','floorNo','floor']].drop_duplicates(), how='left')
    sub = sub.merge(
        train_waypoints[['x','y','site','floor','isTrainWaypoint']].drop_duplicates(),
        how='left',
        on=['site','x','y','floor']
             )
    sub['isTrainWaypoint'] = sub['isTrainWaypoint'].fillna(False)
    return sub.copy()

In [229]:
train_waypoints = pd.read_csv(str(DATA_DIR/'indoor-location-navigation') + '/train_waypoints.csv')


In [230]:
sub_df_cm = sub_process(sub_df_cm, train_waypoints)
oofs_df_cm = sub_process(oofs_df_cm, train_waypoints)

In [231]:
from scipy.spatial.distance import cdist

def add_xy(df):
    df['xy'] = [(x, y) for x,y in zip(df['x'], df['y'])]
    return df

def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]

sub_df_cm = add_xy(sub_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(sub_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

sub_df_cm_ds = pd.concat(ds)


oofs_df_cm = add_xy(oofs_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(oofs_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

oofs_df_cm_ds = pd.concat(ds)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
100%|██████████| 118/118 [00:04<00:00, 27.14it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:

In [232]:
def snap_to_grid(sub, threshold):
    """
    Snap to grid if within a threshold.
    
    x, y are the predicted points.
    x_, y_ are the closest grid points.
    _x_, _y_ are the new predictions after post processing.
    """
    sub['_x_'] = sub['x']
    sub['_y_'] = sub['y']
    sub.loc[sub['dist'] < threshold, '_x_'] = sub.loc[sub['dist'] < threshold]['x_']
    sub.loc[sub['dist'] < threshold, '_y_'] = sub.loc[sub['dist'] < threshold]['y_']
    return sub.copy()



In [233]:
# Calculate the distances
sub_df_cm_ds['dist'] = np.sqrt( (sub_df_cm_ds.x-sub_df_cm_ds.x_)**2 + (sub_df_cm_ds.y-sub_df_cm_ds.y_)**2 )
sub_pp = snap_to_grid(sub_df_cm_ds, threshold=5)
sub_pp = sub_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

# Calculate the distances
oofs_df_cm_ds['dist'] = np.sqrt( (oofs_df_cm_ds.x-oofs_df_cm_ds.x_)**2 + (oofs_df_cm_ds.y-oofs_df_cm_ds.y_)**2 )
oofs_pp = snap_to_grid(oofs_df_cm_ds, threshold=5)
oofs_pp = oofs_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

In [234]:
sub_pp = sub_pp.sort_index()
sub_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,93.728470,97.948860,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
1,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,79.662285,102.766754,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
2,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,87.162830,104.450550,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
3,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
4,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
...,...,...,...,...,...,...,...
10128,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.596040,98.605774,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10129,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,206.627910,102.635210,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10130,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.511300,107.841324,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10131,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,192.553130,111.863014,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6


In [235]:
sub_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_pp.csv", index=False)

In [236]:
oofs_pp = oofs_pp.sort_index()
oofs_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5c3c44b80379370013e0fd2b_5d08a2553f461f0008dac...,0.0,163.710890,36.197370,5c3c44b80379370013e0fd2b,5d08a2553f461f0008dac591,F1
1,5c3c44b80379370013e0fd2b_5d08a2553f461f0008dac...,0.0,188.360207,30.549024,5c3c44b80379370013e0fd2b,5d08a2553f461f0008dac591,F1
2,5c3c44b80379370013e0fd2b_5d073b814a19c000086c5...,2.0,195.909590,90.654030,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,F3
3,5c3c44b80379370013e0fd2b_5d073b814a19c000086c5...,2.0,185.099150,88.624350,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,F3
4,5c3c44b80379370013e0fd2b_5d073b821a69370008bc5...,2.0,181.347610,84.303024,5c3c44b80379370013e0fd2b,5d073b821a69370008bc5cf8,F3
...,...,...,...,...,...,...,...
70727,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,68.407646,177.942440,5a0546857ecc773753327266,5e15bf91f4c3420006d52341,B1
70728,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,61.413883,176.395230,5a0546857ecc773753327266,5e15bf91f4c3420006d52341,B1
70729,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,55.143543,185.641190,5a0546857ecc773753327266,5e15bf91f4c3420006d52341,B1
70730,5a0546857ecc773753327266_5e15bf91f4c3420006d52...,-1.0,49.308346,187.129460,5a0546857ecc773753327266,5e15bf91f4c3420006d52341,B1


In [238]:
oofs_score = mean_position_error(
        oofs_pp['x'], oofs_pp['y'], 0,
        train_df_gby['x'].values, train_df_gby['y'].values, 0)
print(f"CV:{oofs_score}")

CV:4.911771025398721


In [239]:
oofs_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_pp_gby.csv", index=False)

In [64]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': oofs_score})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,42.81095
Loss/xy,42.81095
Loss/floor,5.20935
MPE/val,7.12475
epoch,121.0
trainer/global_step,47945.0
_runtime,610.0
_timestamp,1619801987.0
_step,121.0


0,1
Loss/val,█▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,████▇▅▇▇▄▃▃▃▂▁▂▂▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
MPE/val,█▅▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




VBox(children=(Label(value=' 0.00MB of 0.40MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00159089871…

0,1
CV_score,5.1373
_runtime,2.0
_timestamp,1619803482.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁


In [None]:
import json
import matplotlib.pylab as plt

def plot_preds(
    site,
    floorNo,
    sub=None,
    true_locs=None,
    base=str(DATA_DIR/'indoor-location-navigation'),
    show_train=True,
    show_preds=True,
    fix_labels=True,
    map_floor=None
):
    """
    Plots predictions on floorplan map.
    
    map_floor : use a different floor's map
    """
    if map_floor is None:
        map_floor = floorNo
    # Prepare width_meter & height_meter (taken from the .json file)
    floor_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_image.png"
    json_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_info.json"
    with open(json_plan_filename) as json_file:
        json_data = json.load(json_file)

    width_meter = json_data["map_info"]["width"]
    height_meter = json_data["map_info"]["height"]

    floor_img = plt.imread(f"{base}/metadata/{site}/{map_floor}/floor_image.png")

    fig, ax = plt.subplots(figsize=(12, 12))
    plt.imshow(floor_img)

    if show_train:
        true_locs = true_locs.query('site == @site and floorNo == @map_floor').copy()
        true_locs["x_"] = true_locs["x"] * floor_img.shape[0] / height_meter
        true_locs["y_"] = (
            true_locs["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        true_locs.query("site == @site and floorNo == @map_floor").groupby("path").plot(
            x="x_",
            y="y_",
            style="+",
            ax=ax,
            label="train waypoint location",
            color="grey",
            alpha=0.5,
        )

    if show_preds:
        sub = sub.query('site == @site and floorNo == @floorNo').copy()
        sub["x_"] = sub["x"] * floor_img.shape[0] / height_meter
        sub["y_"] = (
            sub["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        for path, path_data in sub.query(
            "site == @site and floorNo == @floorNo"
        ).groupby("path"):
            path_data.plot(
                x="x_",
                y="y_",
                style=".-",
                ax=ax,
                title=f"{site} - floor - {floorNo}",
                alpha=1,
                label=path,
            )
    if fix_labels:
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(
            by_label.values(), by_label.keys(), loc="center left", bbox_to_anchor=(1, 0.5)
        )
    return fig, ax

In [None]:
example_site = '5a0546857ecc773753327266'
example_floorNo = 'F1'

sub_df = sub_process(sub_df, train_waypoints)
plot_preds(example_site, example_floorNo, sub_df,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_df_cm,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_pp,
           train_waypoints, show_preds=True)