# LSTM baseline

from kuto

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [2]:
import os
import sys
import glob
import pickle
import random

In [3]:
import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path


In [4]:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [5]:
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger

In [6]:
sys.path.append('../../')
import src.utils as utils

In [7]:
import multiprocessing
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm

from indoor_location_competition_20.io_f import read_data_file
import indoor_location_competition_20.compute_f as compute_f

In [8]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids_original'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [9]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.01,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [10]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 24
IS_SAVE = True

utils.set_seed(SEED)

In [11]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [12]:
train_df = pd.read_csv(WIFI_DIR / 'train_10.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_10.csv')

In [13]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [14]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]

In [15]:
train_df.iloc[:, 100:110]

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9
0,e9b24f94c0007acb4b7169b945622efcd332cf6f,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,591ea59cf88e3397db5d60eb00a5147edd69399a,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,a77f8e93896f8fc8bc0d0700ca04b802ee79a07f,1b2fd184314ae440900fa9ce1addeb896b5604a9,2c09230bb32ee49f6a72928f6eeefb6885dc15ce,3799b46aa4cf6c3c45c0bc27d8f1efefea96914f,fc6956beb062b5158252c66953e92a0d25495cac,c71a2f5c4282d27f84b9b841db0e310ef0fcf6cd
1,e9b24f94c0007acb4b7169b945622efcd332cf6f,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,591ea59cf88e3397db5d60eb00a5147edd69399a,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,d32dd11040b254cd889c9ead2d4a50f6e3900196,9c832009dfb1ee02053c9ce9b7770b6cd3191003,509d1f842b0773e85c6beec0bb530542efd35cb9,1b2fd184314ae440900fa9ce1addeb896b5604a9,f920a2e4cb52165850990d9d37d391b630f7de14,6a42281c99a4cff2ce9dba3fc91ad6a431af64d9
2,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,e9b24f94c0007acb4b7169b945622efcd332cf6f,591ea59cf88e3397db5d60eb00a5147edd69399a,509d1f842b0773e85c6beec0bb530542efd35cb9,b2337b25e7d1df04928bf6698a9c0b2764df7795,76f81d5047273fa64a434457531d400fc5d90fac,6e388d1db5ba8dd9de80522a4ddf50402cf443b3,8c6ab78f2797e076f9106af81090d0ab9904f5cd,ceccac4f0e50ec9e36e8d2800b8f2c7c3b4d903e,f920a2e4cb52165850990d9d37d391b630f7de14
3,590a4dd2afa1ad07090fd5f390f65a55e3dc5f56,591ea59cf88e3397db5d60eb00a5147edd69399a,e9b24f94c0007acb4b7169b945622efcd332cf6f,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,d32dd11040b254cd889c9ead2d4a50f6e3900196,f920a2e4cb52165850990d9d37d391b630f7de14,6a42281c99a4cff2ce9dba3fc91ad6a431af64d9,9c832009dfb1ee02053c9ce9b7770b6cd3191003,509d1f842b0773e85c6beec0bb530542efd35cb9,0452e85d0a41780463cfe079077ea5bd2f519c7a
4,5875360455060f20a3cba705f44a4e3987c9b9f3,6ec56c3efdeb067eb20bd2f4a6ccdae07d640cc1,0452e85d0a41780463cfe079077ea5bd2f519c7a,3c7e7fa0576bc8a2af71d5899581df36f4dab6c8,09e103887f42552d20328aa41891cf82dace79ab,54bba3a36204f8c71b93798c31f9e0b039914575,18067f8d8861af3bcae51ba04b6b11b9150b9ff2,591ea59cf88e3397db5d60eb00a5147edd69399a,f920a2e4cb52165850990d9d37d391b630f7de14,d32dd11040b254cd889c9ead2d4a50f6e3900196
...,...,...,...,...,...,...,...,...,...,...
258120,5964a27e0cb3344b0a18540e6b3120c433971c38,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,cce41299a022ada08aebf3d309acb07d5f00b014,346b34a42e801c64e043dbaacbe7fef9b8880774,4b5dbdb52b131410ea10b59ea451de62280b41d6,21310f6a93112e4cb928817e3af33ebb1bb62875,fa11fc4d4960379cb68cc6968ba6415168fef53c,89395d0ee75307b3beb30aef2f19fc680095d514,cc5250324fd7779782cf7066839a6be43bdbf72a,5f583dcccc43b5b7ac25d270e29c92d878fb2be0
258121,5964a27e0cb3344b0a18540e6b3120c433971c38,346b34a42e801c64e043dbaacbe7fef9b8880774,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,cce41299a022ada08aebf3d309acb07d5f00b014,566e0c6e3bcf2b8b3d310d96f111043d17ace817,bd3fc24710537130e97dc2dab4a6bf70b3884a8b,4b5dbdb52b131410ea10b59ea451de62280b41d6,4d2e5639041b40b0df2ee258aa504bd904133d80,f4107af4418d57aacb3542343f7b47768debdc75,5f583dcccc43b5b7ac25d270e29c92d878fb2be0
258122,346b34a42e801c64e043dbaacbe7fef9b8880774,5964a27e0cb3344b0a18540e6b3120c433971c38,cce41299a022ada08aebf3d309acb07d5f00b014,566e0c6e3bcf2b8b3d310d96f111043d17ace817,bd3fc24710537130e97dc2dab4a6bf70b3884a8b,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,4b5dbdb52b131410ea10b59ea451de62280b41d6,a94eb920c0a198fe8385f3de6a8e8e6d44b6f6c9,ee5ca7a7deaacdcd5d99355ff5f156dc45b74efa,a7986c0cea5d2571ea42011ab4407039e977c0bd
258123,346b34a42e801c64e043dbaacbe7fef9b8880774,d090a2f7f222fadeeb64e4fbdfe1ca8451116b04,5964a27e0cb3344b0a18540e6b3120c433971c38,f4107af4418d57aacb3542343f7b47768debdc75,cce41299a022ada08aebf3d309acb07d5f00b014,4b5dbdb52b131410ea10b59ea451de62280b41d6,fa11fc4d4960379cb68cc6968ba6415168fef53c,a7986c0cea5d2571ea42011ab4407039e977c0bd,180a351ec58c07d60949862c534373c43f548a9a,4d2e5639041b40b0df2ee258aa504bd904133d80


bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [16]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in range(100, 200):
    wifi_bssids.extend(train_df.iloc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in range(100, 200):
    wifi_bssids_test.extend(test_df.iloc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 60633
BSSID TYPES(test): 30362
BSSID TYPES(all): 90995


## preprocessing

In [17]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

  return self.partial_fit(X, y)
  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app


Unnamed: 0,ssid_0,ssid_1,ssid_2,ssid_3,ssid_4,ssid_5,ssid_6,ssid_7,ssid_8,ssid_9,...,frequency_97,frequency_98,frequency_99,wp_tmestamp,x,y,floor,floor_str,path_id,site_id
0,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,8c1562bec17e1425615f3402f72dded3caa42ce5,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,...,5745,5745,5180,1578469851129,157.99141,102.125390,-1.0,B1,5e158ef61506f2000638fd1f,0
1,b7e6027447eb1f81327d66cfd3adbe557aabf26c,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,da39a3ee5e6b4b0d3255bfef95601890afd80709,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,...,5180,2462,5765,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0
2,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,...,5180,2452,5320,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0
3,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,7182afc4e5c212133d5d7d76eb3df6c24618302b,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,...,2452,2452,5765,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0
4,da39a3ee5e6b4b0d3255bfef95601890afd80709,da39a3ee5e6b4b0d3255bfef95601890afd80709,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,5731b8e08abc69d4c4d685c58164059207c93310,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,7182afc4e5c212133d5d7d76eb3df6c24618302b,...,2452,2467,5765,1578469862177,168.49713,109.861336,-1.0,B1,5e158ef61506f2000638fd1f,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,70ba065d6d5456835fa594d193b2f41335da9dec,0a8a55cf161bc4980194ec9f7f7a448439be4b74,da39a3ee5e6b4b0d3255bfef95601890afd80709,2f797a25b58a1ed92176550ac6770f764703401d,1f09251bbfadafb11c63c87963af25238d6bc886,...,5745,5805,5765,1573733061352,203.53165,143.513960,6.0,F7,5dcd5c9323759900063d590a,23
258121,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,1f09251bbfadafb11c63c87963af25238d6bc886,...,5805,5765,5745,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23
258122,4abd3985ba804364272767c04cdc211615f77c56,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,5d998a8668536c4f51004c25f474117fe9555f78,...,5745,0,0,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23
258123,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,5d998a8668536c4f51004c25f474117fe9555f78,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,...,5745,5805,5765,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23


In [18]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [19]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [20]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])

        x = torch.cat([x_bssid, x_site_id, x_rssi], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [21]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [22]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [23]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [24]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [25]:
oofs = np.zeros((len(train), 2), dtype = np.float32)  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path_id'], groups=train.loc[:, 'path_id'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=10,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    oofs[val_idx, 0] = oof_x
    oofs[val_idx, 1] = oof_y

    
    val_score = mean_position_error(
        oof_x, oof_y, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.8 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.060    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 135.8000259399414


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 130.62685050577775


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 105.60582144673369


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 91.00020438365902


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 84.02335176373472


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 81.59445069381367


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 81.25338824192728


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 81.48093150348252


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 81.70892866006895


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 81.81809040076244


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 81.83781193541591


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 58.179682257179614


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 48.87949196786053


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 42.224844721048186


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 34.398960034076595


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 29.259575208798353


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 24.207503289782032


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 20.634366575517692


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 17.36483926223195


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 15.68844353785262


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 14.452224559950563


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 13.399619726870874


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 11.944204067925623


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 11.730755112569694


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 11.395116039125046


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 10.901851076660424


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 11.221438813006802


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 10.15766740447918


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.732048886965094


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.856073560220388


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.922029201294283


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.69457903117167


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.697015034281797


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.790456993378521


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 9.596496587496688


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.277262793970301


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.183351087597963


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 9.376507153930397


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 9.108586016051214


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 9.927116493609455


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 9.001709274365309


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 9.31218242858549


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 9.465435067245137


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.539821625977806


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.388722534635141


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.32652168391677


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.27843983658539


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.327626954980138


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.238431599827512


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.254618714687501


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.251307316859517


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.215697584679882


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 8.255848410822447


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 8.2186753063724


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 8.227459574595537


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 8.209156223501518


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 8.185951298217685


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 8.195995242385997


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 8.173303574865892


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 8.199877809414561


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 8.193950710840953
fold 0: mean position error 8.191549548991294
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,58.00534
Loss/xy,58.00534
Loss/floor,5.38096
MPE/val,8.19155
epoch,59.0
trainer/global_step,24239.0
_runtime,302.0
_timestamp,1618240828.0
_step,59.0


0,1
Loss/val,█▆▄▄▄▄▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▆▄▄▄▄▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███████▂▂▃▃▁▃▃▄▄▅▅▅▅▆▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
MPE/val,█▇▅▅▅▅▅▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.8 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.060    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 155.24102020263672


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 116.30949365890227


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 90.27973139981162


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 81.13762520307431


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 79.65129481480577


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 80.03455038000695


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 55.13054811403234


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 44.70443746767839


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 37.64919569523167


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 33.26257497506428


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 28.49725941077595


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 24.573624227564903


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 20.608383363880385


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 17.83902774721922


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 15.797276849885561


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 14.285050537395977


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 13.249610953622595


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 12.17951372542073


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 11.695887257015558


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 11.705100586346296


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 10.661985203638004


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 10.19470611704458


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 10.610176222017495


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 10.157618139263636


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 9.563963455162426


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.419030460664263


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.43572563504645


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.766936590109166


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.191282526270623


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.997410133608428


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.066134529624765


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.387885004786744


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.937625695541284


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.043690528203113


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 9.103533605943257


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.35182990556827


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.735590493363839


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 9.571344192651516


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.57802982692463


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.433670223071676


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.78119882592505


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.567899157746837


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.944955671429287


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.81606551744109


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.159348325323906


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.100004029360079


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.036016716748806


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.030627889283258


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.033688841982505


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.990240437147155


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.00027925532435


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.010347652146478


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.996411998756555


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 8.016323846984756


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.9851420117590965


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.974978922475683


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.968690961139036


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.972494768227126


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.969012157660136


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.971749977679467


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.973233625431982


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.9716364455792394


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.963661527439266


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.972430496026973


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.964780144105488


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.958785046725754


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.9681470426627214


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.965931763812839


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.974634366593668


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.971732635775762


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.960608968790284


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.971252920720628


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.960884648469138


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.957574933207028


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.9706774832692755


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.974940546510104
fold 1: mean position error 7.971440327853582
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,50.20943
Loss/xy,50.20943
Loss/floor,5.05675
MPE/val,7.97144
epoch,74.0
trainer/global_step,30299.0
_runtime,379.0
_timestamp,1618241214.0
_step,74.0


0,1
Loss/val,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███▇▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,█▆▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.8 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.060    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 149.55107498168945


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 122.78918478155387


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 96.80499970220033


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 84.77094107755062


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 80.83242211512375


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 80.20392052909838


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 58.377765339277616


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 54.60047019691889


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 51.630816645391754


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 49.72870551449793


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 41.53731334639623


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 33.77810675370534


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 27.2754194590458


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 23.121797385374055


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 19.185516019820056


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 16.904791465441573


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 14.833818002155374


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 13.532764840523171


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 12.414262516030616


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 11.702193687007952


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.427962409937598


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 10.874175596564886


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 10.53036969432042


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 11.934870749358598


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 9.841949367212152


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.809004170135072


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.963278136955976


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.620331886393704


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.245109704761362


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.306191190254154


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.093328488460445


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.013959277791205


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.965676342014941


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.047503341026061


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.79251337055835


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.06405160933924


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.34024570431762


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.648286849006915


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.938369697844072


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.856104930384463


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.809442174519365


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.930186866600291


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.753183115646461


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.878030444752289


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 9.194091487811205


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 8.567775827017925


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.388927860615485


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 10.589792694787196


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.540059529261736


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.687138045360408


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.478834814583763


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.846203791918835


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.771625767998316


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.742514916377215


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.72840215286406


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.722912463836359


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.717536167373614


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.712960043931049


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.705925362653072


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.714237260379692


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.708014277400982


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.70017517296464


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.682415004862417


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.672647615326097


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.6804973712159095


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.67909400375624


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.676002961726958


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.674132647861487


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.6754022557498


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.66509164539686


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.679686353619042


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.668059175265828


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.667518545166309


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.669675379771442


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.668629060177589


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.664149207757807


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.667642162786227


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.665802686771266


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.664331576603404


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.6737559025386215


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.6698692878318795


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.676475618675689


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.678208999889287


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.665742773445943


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.667706242874047


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.67103390662005


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.672002648001989


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.672121786996245


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.669851372584953


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.665960478385521


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.663680426742563


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.670345506773623


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.661681661611389


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.66276080709734


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.665617137202427


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.670549235707741


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.676018897485927


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.669626779184136


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.673072202029309


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.674023849296237


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.6709714565821425


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.671799680257719


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.669256988840875
fold 2: mean position error 7.666425908389862
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,46.81648
Loss/xy,46.81648
Loss/floor,4.18743
MPE/val,7.66643
epoch,101.0
trainer/global_step,41207.0
_runtime,510.0
_timestamp,1618241730.0
_step,101.0


0,1
Loss/val,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▅▆▄▅▆▇▇▇██████████████████████████████
MPE/val,█▆▄▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.8 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.060    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 142.0883445739746


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 138.58492027224952


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 114.62236557531163


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 98.62092948873584


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 89.33778999146433


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 84.86081967211075


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 83.16711397440959


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 82.70125823130354


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 82.6652247636514


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 82.70972065778409


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 59.94616637708145


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 53.78065199010513


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 42.51548448654824


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 34.7762285681554


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 28.903262669380556


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 24.282131439157887


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 21.00731675378514


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 18.89972865479591


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 17.03063836504848


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 15.75278163013447


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 14.63665759658147


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 13.854145462367168


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 13.607637602184564


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 12.730968995624735


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 11.938361518873192


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 11.826570032705174


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 11.528056491980665


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 11.01488734394711


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 10.706896116304481


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 10.468203351549027


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 10.603115508713273


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 10.17414776492688


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.873750596254734


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.847499624573357


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 10.36363722089702


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.603947608305676


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.653294611771438


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 9.431802869584578


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 9.37714541869033


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 9.516571683589289


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 9.2447760327583


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 9.56065907975443


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 9.145729018451867


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 9.208138545196835


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 9.27952917778846


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 9.313861116865866


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 9.743208974372806


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.83763467222157


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 9.098474951573563


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.959123539921844


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 9.192765164172574


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 9.520148436702401


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 8.472105205915803


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 8.373985546580844


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 8.294962720954286


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 8.294494154612964


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 8.272476036002905


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 8.26423967995752


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 8.235934596220003


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 8.228615802360538


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 8.233971502729537


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 8.212174873385377


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 8.215985140925447


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 8.20216519935088


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 8.206279804627425


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 8.20653741171142


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 8.198507601716315


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 8.18568124140149


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 8.188060243914276


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 8.175941766473395


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 8.181705966043


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 8.215244092802426


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 8.185439856067674


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 8.174998386796297


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 8.168210075658354


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 8.153377586458058


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 8.159937554449971


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 8.153893575085133


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 8.150285925912385


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 8.149074218430073


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 8.153031583762683


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 8.147441625234062


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 8.151864297408036


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 8.143357260964537


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 8.14822689244132


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 8.15631755089024


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 8.143462163195336


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 8.147555579128277


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 8.145706119726201


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 8.145536412942805


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 8.14301516440141


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 8.144976852508083


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 8.138347356881802


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 8.141853796415157


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 8.147755893025206


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 8.145076222883523


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 8.139712357629366


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 8.14025269087901


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 8.132712011664983


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 8.141746542035804


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 8.14044536279535


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 8.143100294325334


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 8.142396088046395


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 8.144951366864461


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 8.143465111032084


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 8.142544860195544


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 8.1468216338627


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 8.143871160095072


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 8.144417828141636
fold 3: mean position error 8.142003652268727
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,53.26152
Loss/xy,53.26152
Loss/floor,4.4991
MPE/val,8.142
epoch,107.0
trainer/global_step,43631.0
_runtime,555.0
_timestamp,1618242291.0
_step,107.0


0,1
Loss/val,█▅▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▂▄▅▆▆▇▇▇▇███████████████████████████
MPE/val,█▆▅▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.8 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.8 M    Trainable params
0         Non-trainable params
11.8 M    Total params
47.060    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 150.51847076416016


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 134.99227560675416


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 111.11755766098209


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 95.35843028073542


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 84.49261951390991


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 67.58954177700925


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 61.09037850861502


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 56.99449417411969


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 54.32584685826066


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 52.39227737257907


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 50.92205180693034


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 49.157635414598985


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 40.798884312469184


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 34.76594269790827


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 30.154037723108022


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 25.943620786927923


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 22.658160369197642


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 20.097121026822837


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 17.91201190616339


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 16.08627617860993


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 15.157724019118916


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 13.687336163795752


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 13.064724602274374


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 12.18463729558466


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 12.066587036790443


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 11.487587620441346


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 10.869584366926706


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 10.650795654114694


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 10.33669225588328


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 10.24715984194516


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 10.19643508605646


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 10.241053953220765


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.65905775376788


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.824141849835527


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 9.508735875855647


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 9.33853150794863


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.605646655416516


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 9.504166918328563


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 9.435665395644493


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 9.578645077318248


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 9.087342026665521


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 9.14023711783843


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 9.03947866880831


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 9.143906055343797


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 9.052459553044828


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 9.28992927467955


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 9.025334765290408


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.978605137226813


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 9.09168808866533


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 9.102852330938216


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 9.34660702935887


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.386158278014817


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 8.299851170552632


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 8.264143681567841


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 8.212396680813024


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 8.189160874621702


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 8.190530503372432


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 8.159343730790988


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 8.169717140286622


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 8.172991318555509


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 8.160128925836705


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 8.170827814812013


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 8.181416814611898


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 8.1346763540078


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 8.129903650367128


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 8.13373131381542


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 8.121440275155603


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 8.12646975144579


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 8.115262185587186


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 8.122976765518766


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 8.113991446664462


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 8.11629811830693


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 8.109364422624306


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 8.117686529676096


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 8.11449372672035


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 8.113523801104856


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 8.112145597383458


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 8.113379352282978


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 8.112310346071292


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 8.108544449414634


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 8.104635530726188


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 8.109142272651507


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 8.105169237577853


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 8.108526193577672


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 8.10539634399103


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 8.108520857978158


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 8.112366825708069


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 8.109469832352032


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 8.110073980843529


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 8.110748103313412


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 8.10368282287011
fold 4: mean position error 8.101528006346443


In [26]:
oofs_df = pd.DataFrame(oofs, columns=['x', 'y'])
oofs_df['path'] = train_df['path_id']
oofs_df['timestamp'] = train_df['wp_tmestamp']
oofs_df['site'] = train_df['site_id']
oofs_df['site_path_timestamp'] = oofs_df['site'] + '_' + oofs_df['path'] + '_' + oofs_df['timestamp'].astype(str)
oofs_df['floor'] = train_df['floor']
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,x,y,path,timestamp,site,site_path_timestamp,floor
0,161.518951,104.774261,5e158ef61506f2000638fd1f,1578469851129,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
1,163.856537,105.838348,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
2,162.445648,110.185104,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
3,163.074844,108.243256,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
4,160.549713,111.044586,5e158ef61506f2000638fd1f,1578469862177,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
...,...,...,...,...,...,...,...
258120,194.339676,143.278320,5dcd5c9323759900063d590a,1573733061352,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
258121,191.134796,143.961334,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
258122,190.360886,143.833862,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
258123,191.964066,143.614563,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0


In [28]:
oofs_score = mean_position_error(
        oofs_df['x'], oofs_df['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:8.014589488769984


In [29]:
# foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,86.680717,104.738564
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,81.394920,103.298012
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.116112,105.778549
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.789299,108.187607
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,86.664574,107.656998
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,215.937210,92.194702
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,210.181702,100.495781
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,208.077911,106.701790
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,201.723770,111.234108


In [30]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,86.680717,104.738564
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,81.394920,103.298012
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.116112,105.778549
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.789299,108.187607
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,86.664574,107.656998
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,215.937210,92.194702
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,210.181702,100.495781
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,208.077911,106.701790
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,201.723770,111.234108


# Post Proccess

In [31]:
oofs_df = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv")
sub_df = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [32]:
def compute_rel_positions(acce_datas, ahrs_datas):
    step_timestamps, step_indexs, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    headings = compute_f.compute_headings(ahrs_datas)
    stride_lengths = compute_f.compute_stride_length(step_acce_max_mins)
    step_headings = compute_f.compute_step_heading(step_timestamps, headings)
    rel_positions = compute_f.compute_rel_positions(stride_lengths, step_headings)
    return rel_positions

In [33]:
def correct_path(args):

    path, path_df = args
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    txt_path = path_df['txt_path'].values[0]
    
    example = read_data_file(txt_path)
    rel_positions = compute_rel_positions(example.acce, example.ahrs)
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]
    delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)

    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

In [34]:
tmp = sub_df['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub_df['site'] = tmp[0]
sub_df['path'] = tmp[1]
sub_df['timestamp'] = tmp[2].astype(float)

In [35]:
used_buildings = sorted(sub_df['site'].value_counts().index.tolist())
test_txts = sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/test/*.txt'))
train_txts = [sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/train/{used_building}/*/*.txt')) for used_building in used_buildings]
train_txts = sum(train_txts, [])

In [36]:
txt_pathes = []
for path in tqdm(sub_df['path'].values):
    txt_pathes.append([test_txt for test_txt in test_txts if path in test_txt][0])

100%|██████████| 10133/10133 [00:00<00:00, 27882.63it/s]


In [37]:
sub_df['txt_path'] = txt_pathes

In [38]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, sub_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
sub_df_cm = pd.concat(dfs).sort_values('site_path_timestamp')

626it [01:08,  9.11it/s]


In [39]:
txt_pathes = []
for path in tqdm(oofs_df['path'].values):
    txt_pathes.append([train_txt for train_txt in train_txts if path in train_txt][0])

100%|██████████| 258125/258125 [02:35<00:00, 1656.84it/s]


In [40]:
oofs_df['txt_path'] = txt_pathes

In [41]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, oofs_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
oofs_df_cm = pd.concat(dfs).sort_index()
oofs_df_cm

10852it [08:03, 22.45it/s]


Unnamed: 0,site_path_timestamp,floor,x,y
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,159.128950,104.586692
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.220472,108.161175
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.214949,108.164104
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.209110,108.164261
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,165.672209,111.009325
...,...,...,...,...
258120,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,200.704045,141.057770
258121,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.671693,144.794587
258122,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.689922,144.792721
258123,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.703115,144.792170


In [42]:
oofs_df_cm.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv", index=False)

In [43]:
oofs_score = mean_position_error(
        oofs_df_cm['x'], oofs_df_cm['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:6.377523991280171


In [44]:
sub_df_cm.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv", index=False)

In [45]:
oofs_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv")

In [46]:
sub_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv")

In [47]:
def split_col(df):
    df = pd.concat([
        df['site_path_timestamp'].str.split('_', expand=True) \
        .rename(columns={0:'site',
                         1:'path',
                         2:'timestamp'}),
        df
    ], axis=1).copy()
    return df
def sub_process(sub, train_waypoints):
    train_waypoints['isTrainWaypoint'] = True
    sub = split_col(sub[['site_path_timestamp','floor','x','y']]).copy()
    sub = sub.merge(train_waypoints[['site','floorNo','floor']].drop_duplicates(), how='left')
    sub = sub.merge(
        train_waypoints[['x','y','site','floor','isTrainWaypoint']].drop_duplicates(),
        how='left',
        on=['site','x','y','floor']
             )
    sub['isTrainWaypoint'] = sub['isTrainWaypoint'].fillna(False)
    return sub.copy()

In [48]:
train_waypoints = pd.read_csv(str(DATA_DIR/'indoor-location-navigation') + '/train_waypoints.csv')


In [49]:
sub_df_cm = sub_process(sub_df_cm, train_waypoints)
oofs_df_cm = sub_process(oofs_df_cm, train_waypoints)

In [50]:
from scipy.spatial.distance import cdist

def add_xy(df):
    df['xy'] = [(x, y) for x,y in zip(df['x'], df['y'])]
    return df

def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]

sub_df_cm = add_xy(sub_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(sub_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

sub_df_cm_ds = pd.concat(ds)


oofs_df_cm = add_xy(oofs_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(oofs_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

oofs_df_cm_ds = pd.concat(ds)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
100%|██████████| 118/118 [00:04<00:00, 26.91it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:

In [51]:
def snap_to_grid(sub, threshold):
    """
    Snap to grid if within a threshold.
    
    x, y are the predicted points.
    x_, y_ are the closest grid points.
    _x_, _y_ are the new predictions after post processing.
    """
    sub['_x_'] = sub['x']
    sub['_y_'] = sub['y']
    sub.loc[sub['dist'] < threshold, '_x_'] = sub.loc[sub['dist'] < threshold]['x_']
    sub.loc[sub['dist'] < threshold, '_y_'] = sub.loc[sub['dist'] < threshold]['y_']
    return sub.copy()



In [52]:
# Calculate the distances
sub_df_cm_ds['dist'] = np.sqrt( (sub_df_cm_ds.x-sub_df_cm_ds.x_)**2 + (sub_df_cm_ds.y-sub_df_cm_ds.y_)**2 )
sub_pp = snap_to_grid(sub_df_cm_ds, threshold=5)
sub_pp = sub_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

# Calculate the distances
oofs_df_cm_ds['dist'] = np.sqrt( (oofs_df_cm_ds.x-oofs_df_cm_ds.x_)**2 + (oofs_df_cm_ds.y-oofs_df_cm_ds.y_)**2 )
oofs_pp = snap_to_grid(oofs_df_cm_ds, threshold=5)
oofs_pp = oofs_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

In [53]:
sub_pp = sub_pp.sort_index()
sub_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,93.728470,97.948860,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
1,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,79.662285,102.766754,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
2,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,80.718400,107.197110,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
3,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
4,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,89.695300,111.917800,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
...,...,...,...,...,...,...,...
10128,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,215.262270,97.973610,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10129,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,210.662190,104.351776,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10130,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.511300,107.841324,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10131,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,197.610623,114.583396,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6


In [54]:
sub_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_pp.csv", index=False)

In [55]:
oofs_pp = oofs_pp.sort_index()
oofs_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,158.496950,107.122680,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.934430,106.413020,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.934430,106.413020,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.934430,106.413020,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,168.497130,109.861336,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
...,...,...,...,...,...,...,...
258120,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,202.503340,140.972670,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
258121,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.671693,144.794587,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
258122,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.689922,144.792721,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
258123,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.703115,144.792170,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7


In [56]:
oofs_score = mean_position_error(
        oofs_pp['x'], oofs_pp['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:6.022643066563715


In [57]:
oofs_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_pp.csv", index=False)

In [58]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': oofs_score})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,52.37339
Loss/xy,52.37339
Loss/floor,4.72416
MPE/val,8.10153
epoch,89.0
trainer/global_step,36359.0
_runtime,455.0
_timestamp,1618242752.0
_step,89.0


0,1
Loss/val,█▅▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,██▅▃▁▂▂▂▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
MPE/val,█▆▄▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.00MB of 0.33MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00192948912…

0,1
CV_score,6.02264
_runtime,2.0
_timestamp,1618243660.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁


In [None]:
import json
import matplotlib.pylab as plt

def plot_preds(
    site,
    floorNo,
    sub=None,
    true_locs=None,
    base=str(DATA_DIR/'indoor-location-navigation'),
    show_train=True,
    show_preds=True,
    fix_labels=True,
    map_floor=None
):
    """
    Plots predictions on floorplan map.
    
    map_floor : use a different floor's map
    """
    if map_floor is None:
        map_floor = floorNo
    # Prepare width_meter & height_meter (taken from the .json file)
    floor_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_image.png"
    json_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_info.json"
    with open(json_plan_filename) as json_file:
        json_data = json.load(json_file)

    width_meter = json_data["map_info"]["width"]
    height_meter = json_data["map_info"]["height"]

    floor_img = plt.imread(f"{base}/metadata/{site}/{map_floor}/floor_image.png")

    fig, ax = plt.subplots(figsize=(12, 12))
    plt.imshow(floor_img)

    if show_train:
        true_locs = true_locs.query('site == @site and floorNo == @map_floor').copy()
        true_locs["x_"] = true_locs["x"] * floor_img.shape[0] / height_meter
        true_locs["y_"] = (
            true_locs["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        true_locs.query("site == @site and floorNo == @map_floor").groupby("path").plot(
            x="x_",
            y="y_",
            style="+",
            ax=ax,
            label="train waypoint location",
            color="grey",
            alpha=0.5,
        )

    if show_preds:
        sub = sub.query('site == @site and floorNo == @floorNo').copy()
        sub["x_"] = sub["x"] * floor_img.shape[0] / height_meter
        sub["y_"] = (
            sub["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        for path, path_data in sub.query(
            "site == @site and floorNo == @floorNo"
        ).groupby("path"):
            path_data.plot(
                x="x_",
                y="y_",
                style=".-",
                ax=ax,
                title=f"{site} - floor - {floorNo}",
                alpha=1,
                label=path,
            )
    if fix_labels:
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(
            by_label.values(), by_label.keys(), loc="center left", bbox_to_anchor=(1, 0.5)
        )
    return fig, ax

In [None]:
example_site = '5dbc1d84c1eb61796cf7c010'
example_floorNo = 'F3'

sub_df = sub_process(sub_df, train_waypoints)
plot_preds(example_site, example_floorNo, sub_df,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_df_cm,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_pp,
           train_waypoints, show_preds=True)