# LSTM baseline

from kuto

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [2]:
import os
import sys
import glob
import pickle
import random

In [3]:
import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path


In [4]:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [5]:
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger

In [6]:
sys.path.append('../../')
import src.utils as utils

In [7]:
import multiprocessing
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm

from indoor_location_competition_20.io_f import read_data_file
import indoor_location_competition_20.compute_f as compute_f

In [8]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids_original'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [9]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.01,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [10]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 500
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 43
IS_SAVE = True

utils.set_seed(SEED)

In [11]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [12]:
train_df = pd.read_csv(WIFI_DIR / 'train_10_th10000_200_withcount.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_10_th10000_200_withcount.csv')

In [13]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [14]:
# training target features
NUM_FEATS = 200
NUM_FEATS_EMB = 80
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]

bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [15]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
# wifi_bssids = list(set(wifi_bssids))
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 63084
BSSID TYPES(test): 33413
BSSID TYPES(all): 96497


In [16]:
# get numbers of bssids to embed them in a layer

# train
rssi_bssids = []
# bssidを列ごとにリストに入れていく
for i in RSSI_FEATS:
    rssi_bssids.extend(train_df.loc[:,i].values.tolist())
rssi_bssids = list(set(rssi_bssids))

train_rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(train): {train_rssi_bssids_size}')

# test
rssi_bssids_test = []
for i in RSSI_FEATS:
    rssi_bssids_test.extend(test_df.loc[:,i].values.tolist())
rssi_bssids_test = list(set(rssi_bssids_test))

test_rssi_bssids_size = len(rssi_bssids_test)
print(f'RSSI TYPES(test): {test_rssi_bssids_size}')


rssi_bssids.extend(rssi_bssids_test)
# rssi_bssids = list(set(rssi_bssids))
rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(all): {rssi_bssids_size}')

RSSI TYPES(train): 96
RSSI TYPES(test): 77
RSSI TYPES(all): 173


## PreProcess

In [17]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])
le_rssi = LabelEncoder()
le_rssi.fit(rssi_bssids)

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
#     output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？
    for i in RSSI_FEATS:
        output_df.loc[:,i] = le_rssi.transform(input_df.loc[:,i])

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

  return self.partial_fit(X, y)


Unnamed: 0,ssid_0,ssid_1,ssid_2,ssid_3,ssid_4,ssid_5,ssid_6,ssid_7,ssid_8,ssid_9,...,frequency_198,frequency_199,raw_count,wp_tmestamp,x,y,floor,floor_str,path_id,site_id
0,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,8c1562bec17e1425615f3402f72dded3caa42ce5,da39a3ee5e6b4b0d3255bfef95601890afd80709,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,...,0,0,200,1578469851129,157.99141,102.125390,-1.0,B1,5e158ef61506f2000638fd1f,0
1,b7e6027447eb1f81327d66cfd3adbe557aabf26c,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,7182afc4e5c212133d5d7d76eb3df6c24618302b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,da39a3ee5e6b4b0d3255bfef95601890afd80709,b7e6027447eb1f81327d66cfd3adbe557aabf26c,...,0,0,200,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0
2,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,b7e6027447eb1f81327d66cfd3adbe557aabf26c,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,7182afc4e5c212133d5d7d76eb3df6c24618302b,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,...,0,0,200,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0
3,cef6dc5e595dd99c3b2c605de65cfc1f147e892b,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b7e6027447eb1f81327d66cfd3adbe557aabf26c,da39a3ee5e6b4b0d3255bfef95601890afd80709,7182afc4e5c212133d5d7d76eb3df6c24618302b,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,b7e6027447eb1f81327d66cfd3adbe557aabf26c,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,...,2452,2412,200,1578469857653,162.93443,106.413020,-1.0,B1,5e158ef61506f2000638fd1f,0
4,da39a3ee5e6b4b0d3255bfef95601890afd80709,da39a3ee5e6b4b0d3255bfef95601890afd80709,d839a45ebe64ab48b60a407d837fb01d3c0dfef9,b7e6027447eb1f81327d66cfd3adbe557aabf26c,7182afc4e5c212133d5d7d76eb3df6c24618302b,5731b8e08abc69d4c4d685c58164059207c93310,b6ffe5619e02871fcd04f61c9bb4b5c53a3f46b7,ea4a14e0d5bcdd20703fbe3bbc90f70b171ff140,b9f0208be00bd8b337be7f12e02e3a3ce846e22b,7182afc4e5c212133d5d7d76eb3df6c24618302b,...,5745,5745,200,1578469862177,168.49713,109.861336,-1.0,B1,5e158ef61506f2000638fd1f,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
251108,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,da39a3ee5e6b4b0d3255bfef95601890afd80709,1f09251bbfadafb11c63c87963af25238d6bc886,1556355684145fce5e67ba749d943a180266ad90,...,0,0,200,1573733061352,203.53165,143.513960,6.0,F7,5dcd5c9323759900063d590a,23
251109,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,...,0,0,200,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23
251110,4abd3985ba804364272767c04cdc211615f77c56,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,1556355684145fce5e67ba749d943a180266ad90,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,5d998a8668536c4f51004c25f474117fe9555f78,...,0,0,200,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23
251111,4abd3985ba804364272767c04cdc211615f77c56,1556355684145fce5e67ba749d943a180266ad90,ea7731d04cf9ed352d4805b1ff904bebdf60eb49,4abd3985ba804364272767c04cdc211615f77c56,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,b5d43f6b4e1938ed497c7b589c6eae9ac0bee168,0a8a55cf161bc4980194ec9f7f7a448439be4b74,5d998a8668536c4f51004c25f474117fe9555f78,1556355684145fce5e67ba749d943a180266ad90,ad82e27aa3cd9f276fd3a5146fa8c7c5e5b5207d,...,0,0,200,1573733070079,192.57130,145.781450,6.0,F7,5dcd5c9323759900063d590a,23


In [18]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [19]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(int)
        self.site_id = df['site_id'].values.astype(int)
        self.raw_count = df['raw_count'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        bssid = self.bssid_feats[idx]
        rssi = self.rssi_feats[idx]
        
        if self.phase != 'test':
            if np.random.rand() < 0.3:
                tyohuku = np.where(rssi == pd.Series(rssi[:min(self.raw_count[idx], NUM_FEATS)]).value_counts().index[0])[0]
                indexes = np.arange(len(rssi))
                indexes[tyohuku[0]:tyohuku[-1]+1] = np.random.permutation(tyohuku)
                bssid = bssid[indexes]
                rssi = rssi[indexes]
        
        concat_feat = np.empty(2 * NUM_FEATS_EMB).astype(int)
        concat_feat[0::2] = bssid[:NUM_FEATS_EMB]
        concat_feat[1::2] = rssi[:NUM_FEATS_EMB]
        
        feature = {
            'RSSI_BSSID_FEATS':concat_feat,
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [20]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
#         self.rssi_embedding = nn.Embedding(rssi_size, 64, max_norm=True)
#         self.rssi = nn.Sequential(
#             nn.BatchNorm1d(NUM_FEATS),
#             nn.Linear(NUM_FEATS, NUM_FEATS * 64)
#         )
        
        concat_size = 64 + (2 * NUM_FEATS_EMB * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['RSSI_BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

#         x_rssi = self.rssi_embedding(x['RSSI_FEATS'])
#         x_rssi = self.flatten(x_rssi)

        x = torch.cat([x_bssid, x_site_id], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [21]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [22]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [23]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [24]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [25]:
oofs = np.zeros((len(train), 2), dtype = np.float32)  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path_id'], groups=train.loc[:, 'path_id'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + ['raw_count', 'site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + ['raw_count', 'site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size+rssi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=20,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    oofs[val_idx, 0] = oof_x
    oofs[val_idx, 1] = oof_y

    
    val_score = mean_position_error(
        oof_x, oof_y, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.853    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 165.76107025146484


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 128.15609688659444


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 101.8420923139981


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 88.60386714242541


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 83.54742989129157


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 82.26042180293969


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 82.14991872981366


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 61.05881417563672


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 55.44368725081247


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 52.95973851638227


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 50.81757233236621


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 49.663070475391585


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 38.89993947740527


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 30.5987829648626


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 25.07132397983091


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 21.057183110839496


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 18.588058731465242


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 16.183237605317288


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 14.199433055252971


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 13.464134222559844


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 12.697578943700602


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 11.864304386260253


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 11.306422810643662


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 11.999736666192605


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.360296932927978


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.99236094677497


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.671869519190867


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.391724760292064


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.50841221890346


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.279137630552828


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.827249985470866


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.838616850144014


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.527716733295893


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 12.727327382464628


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.396872012353368


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.156014158883044


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.022965764784182


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.316738171191107


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.08713803710462


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.612432272234912


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.433000098805824


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.198811089380207


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.138828652150497


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 7.997845270032655


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.198026965977734


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.949030318466177


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.991575101716735


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.508665139767885


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.335968278695916


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.940641403095762


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.348090277982736


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.287501660311527


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.2332409107118885


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.200745656365814


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.186207815367459


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.183946047806878


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.154546785938748


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.162778063535741


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.166738241020644


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.131001164522218


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.170235988941598


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.155275939127309


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.132156270308159


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.111174437401757


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.105548206258019


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.106398198877692


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.10943408338275


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.101488962958455


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.124444872339438


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.10664435240522


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.075672070595261


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.056560013285509


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.046312092351267


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.054810419406477


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.056851470288527


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.044776471589115


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.0512539032272645


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.0392803357233165


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.046124821244171


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.050636173396271


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.043907994894015


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.0529337499107365


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.050322962247914


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.0451707040113245


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.045564059567611


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.049176002097033


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.047470690638897


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.048671252748517


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.045448736096105


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.048034614474446


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.046639291184091


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.054345888653905


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.041269995847005


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.0523371226091225


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.052889547526337


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.0470105953019795


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.049643890385779


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.047360315134314
fold 0: mean position error 7.058994149399365
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,41.82158
Loss/xy,41.82158
Loss/floor,4.8581
MPE/val,7.05899
epoch,96.0
trainer/global_step,38120.0
_runtime,1475.0
_timestamp,1619153939.0
_step,96.0


0,1
Loss/val,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███▅▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,█▆▅▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.853    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 141.1328468322754


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 139.65845244534933


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 117.28800261264048


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 101.01453177400643


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 80.56472068506237


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 64.67406606766426


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 56.54222046265724


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 49.14760028032392


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 44.25856235773877


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 39.65974244081459


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 35.705290155720256


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 31.48872875321841


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 27.8245288353696


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 24.538305292225846


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 21.610528401551697


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 19.128991205206027


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 17.754910601458704


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 14.706079826953461


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 13.308940563513442


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 13.787017434278201


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.702704100769719


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 11.564382828463028


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 12.083750082655873


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 10.007510073026637


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.173038309578125


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.528949151161074


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 10.51959224129502


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 8.998331019321412


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 9.178463472644667


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.69955566460769


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.91242224770978


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.650530040686858


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.460257601025743


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.634771109600875


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.477457343494407


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.527865687468065


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 9.158885273194533


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.944440294229263


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.408862407248993


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.345497911934082


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.288389333929524


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.319521734900508


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.305535243501874


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 11.875412927554189


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.96893596218756


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.900703968500645


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.814570064368413


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.9808136201945015


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 8.42691802445622


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 8.213643969809038


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.471724518391232


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.4302777414959325


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.342169974712616


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.2975232776918135


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.30830157651149


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.268365234757864


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.275087226110967


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.259440271515706


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.252744756461265


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.223677779893938


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.228151337910334


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.243001609316903


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.227855447182573


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.240735146247673


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.205430876483668


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.212316470225089


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.197476812313891


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.203085805276837


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.195742767087213


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.205808652028867


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.1964576819776696


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.196332330373751


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.198031805572575


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.19378573300632


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.193654413200721


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.193119384077782


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.19419134372848


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.195121369668291


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.192315216195627


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.1907348980218035
fold 1: mean position error 7.182029051910519
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,40.65705
Loss/xy,40.65705
Loss/floor,5.25952
MPE/val,7.18203
epoch,78.0
trainer/global_step,31046.0
_runtime,1204.0
_timestamp,1619155151.0
_step,78.0


0,1
Loss/val,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,█▆▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.853    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 170.21638870239258


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 136.0650980900417


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 114.24931284865332


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 98.5636090521092


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 86.07186693584639


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 68.44458391126857


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 59.48639373750603


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 54.53459206791609


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 50.22583616335828


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 43.49666847620066


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 38.608388188114326


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 34.76197419534466


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 31.114212361710216


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 27.607628953090515


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 25.913152820137068


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 22.304327864967675


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 20.184367040805913


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 18.05104009448752


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 16.411934019749122


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 15.519179313945791


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 14.034007658734005


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 13.122778097686014


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 12.78681787459691


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 12.618283235921723


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.914962053657629


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 10.777384592491199


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.99112818765302


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.708623172338614


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 10.172802155796186


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 9.118804123692172


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 9.222008797678157


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 9.019549289775515


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 9.027173386013638


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.331937342901371


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.713147298869025


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.62725273809099


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.325194147512935


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.536920712477363


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.371260313841187


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.519994431353629


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.60436930001589


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.596157214474017


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 8.411048558151341


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 8.183836227587518


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 8.092002657544555


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 9.377500765046431


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 8.55908877870587


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 8.19383822956983


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.9290218403486445


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.970960138793559


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 8.195256751659171


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 8.231315210868331


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 8.63534641491324


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.640485828624293


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.508844866179763


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.451542057018497


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.401309574646974


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.407048731241692


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.395404484152614


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.387416569668581


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.377965280170879


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.351538125965996


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.3389264936219645


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.33203410456674


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.338821937467574


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.336394778226438


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.353948624908296


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.342243811199288


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.309265513330948


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.29592333889356


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.294052083339072


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.290389516085623


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.282982380447043


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.2854674353846525


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.287392329329346


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.280581150366469


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.28305957631452


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.287052978271315


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.280267704488794


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.283998327359779


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.28173771699782


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.279621380671312


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.287126428000317


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.277851829676379


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.2783382290951435


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.279680867669609


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.274801762673308


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.286642956318662


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.281539075635414


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.276458503506707


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.274496909190115


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.281988338631966


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.279107917131415


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.272323275954908


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.2715223853674695


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.2772863516365875


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.279194784195173


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.272473300013162


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.275597262736061


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 7.276708989150546


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 7.280768330420521


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 7.276866531925459


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 7.278692569831046


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 7.2867758276640355


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 7.278547778479033


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 7.282587084082565


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 7.275385572145299


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 7.277576407181072


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 7.2747645484691486


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 7.277940312432085


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 7.286266995788056


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 7.283217373546057


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 7.277390431466216


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 7.281071535315842


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 7.274931263980009
fold 2: mean position error 7.248996071681983
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,43.6517
Loss/xy,43.6517
Loss/floor,4.78346
MPE/val,7.249
epoch,113.0
trainer/global_step,44801.0
_runtime,1745.0
_timestamp,1619156904.0
_step,113.0


0,1
Loss/val,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███▅▃▃▁▁▁▂▂▃▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
MPE/val,█▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.853    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 129.30083465576172


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 134.4047525219086


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 110.74695031036659


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 95.06788544367949


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 80.22321747791155


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 68.14116473024123


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 55.711721979844036


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 44.661983123095766


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 37.50457131742227


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 32.25412362491975


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 27.379043785165578


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 23.03416986951296


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 20.1041429554835


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 17.553826639722406


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 15.521352126507264


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 14.997154732780809


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 13.820959426459378


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 12.28059214125823


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 11.314013987085932


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 10.704214453958974


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 10.828071273415038


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 9.944741587289132


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 10.266930555218817


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 9.270966805226456


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 10.671377712454145


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.034268085193258


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 8.896621827446792


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 9.850486644800158


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 8.86182356864986


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.620155662415998


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.525109860893922


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.530576345530005


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.611299071020865


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 9.21299971181361


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.274199925549574


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.034088265440367


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 8.02699535836449


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.175522231092357


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.555499173383415


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 13.311519851902363


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.351829915388038


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 7.606622970884287


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 7.560605239491219


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 7.5043514875345805


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.481935292312227


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.466698429591832


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.441546337930163


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.450828102004596


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.45350267994964


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.403951570345846


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.437056939749536


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.3995106309246585


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.383398189636821


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.380705200165157


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.37885966968997


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.372113611282463


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.372050333923554


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.372823917504513


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.366476509620289


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.368005498543506


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.3599483182569525


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.353635054174934


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.359996202592008


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.358990311779681


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.354316941781897


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.347572312210545


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.353550853988922


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.348121537940251


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.355943281921735


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.347596398036523


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.343027585422883


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.354809081915653


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.342254358475073


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.352189301083262


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.3483528196680785


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.349574773836198


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.345424666130391


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.352031002245717


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 7.346599530482638


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 7.355405320606056


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 7.349028433212126


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 7.345215480789365


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 7.349753900636012


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 7.3498281806406345


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 7.345305200726723


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 7.352183414228035


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 7.353268234469045


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 7.350514339049972


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 7.344298240852607


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 7.35265535485289


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 7.350644990337172


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 7.35045543803842


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 7.352473380013835


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 7.34803722739272


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 7.3431528074258665


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 7.3481493536505615


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 7.353245626482253


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 7.350976009352143


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 7.353224179699697
fold 3: mean position error 7.35376866193105
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,45.23141
Loss/xy,45.23141
Loss/floor,4.68014
MPE/val,7.35377
epoch,97.0
trainer/global_step,38513.0
_runtime,1490.0
_timestamp,1619158403.0
_step,97.0


0,1
Loss/val,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,██▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,█▆▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 11.7 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.853    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 128.2037353515625


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 132.97237778966448


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 107.76354549242103


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 92.11378971630928


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 84.09072963511888


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 63.63630756768522


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 56.58085540181773


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 47.689476058835986


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 40.52468007313457


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 34.964058931113655


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 29.322387703270255


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 25.271932468950407


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 21.95266860506228


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 19.461046449647014


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 17.373194187945003


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 16.355721170042603


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 14.243789981109927


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 13.17537991645622


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 12.661616619815417


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 11.57888360059099


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 11.447407033498944


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 10.157826114967826


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 9.75526177678969


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 9.603937444950752


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 9.572061492415651


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 9.245711644699975


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 9.869139431964738


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 10.32968621679079


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 8.698666816329453


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 8.556978815838493


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 8.857925277453637


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 8.624632609010947


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 8.941400135092218


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 8.407096995771273


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 8.618670259801558


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 8.435482375497195


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 11.397017688985764


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 8.186459770839258


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 8.281280995043527


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 8.282953016057473


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 8.500364378978142


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 8.269145874480799


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 7.638418081406705


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 7.543741772140281


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 7.48093143303142


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 7.423008620660453


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 7.41018030169139


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 7.407477198853323


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 7.402129112025023


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 7.3901301456429636


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 7.364016569336099


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 7.338585789286363


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 7.377372902320255


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 7.354107634875666


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 7.3085843976464355


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 7.314482110605252


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 7.336440785600263


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 7.311366519123826


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 7.33953237900045


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 7.306654056855655


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 7.342541539046029


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 7.334924319044874


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 7.300726912238381


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 7.296903377360402


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 7.288653292465377


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 7.277245927391002


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 7.2761788803807566


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 7.278255481074974


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 7.285533482116412


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 7.271386515429316


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 7.2794548434859365


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 7.273602449428004


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 7.272742778102086


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 7.284513035834124


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 7.287177557268842


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 7.27450609845868


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 7.2863017843646025


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 7.283469903976916
fold 4: mean position error 7.3014489812039685


In [26]:
oofs_df = pd.DataFrame(oofs, columns=['x', 'y'])
oofs_df['path'] = train_df['path_id']
oofs_df['timestamp'] = train_df['wp_tmestamp']
oofs_df['site'] = train_df['site_id']
oofs_df['site_path_timestamp'] = oofs_df['site'] + '_' + oofs_df['path'] + '_' + oofs_df['timestamp'].astype(str)
oofs_df['floor'] = train_df['floor']
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,x,y,path,timestamp,site,site_path_timestamp,floor
0,159.672821,104.919128,5e158ef61506f2000638fd1f,1578469851129,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
1,163.329376,104.348907,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
2,162.308899,109.900879,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
3,157.548828,111.771835,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
4,161.536011,112.530319,5e158ef61506f2000638fd1f,1578469862177,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
...,...,...,...,...,...,...,...
251108,194.566864,142.143356,5dcd5c9323759900063d590a,1573733061352,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251109,192.168854,142.372360,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251110,188.707001,142.054688,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251111,189.105972,143.832748,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0


In [27]:
oofs_score = mean_position_error(
        oofs_df['x'], oofs_df['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:7.229046598228676


In [28]:
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean()
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,85.759979,104.662834
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,81.538498,102.262177
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,83.906349,104.877182
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,88.609116,107.953110
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,88.815140,108.171616
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0,210.329102,98.822342
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0,208.603775,101.567291
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0,205.314331,105.555504
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0,200.148834,110.720680


In [29]:
all_preds_37 = pd.read_csv('../37/output/sub37.csv', index_col=0)
all_preds_37.index = pd.read_csv(WIFI_DIR / 'test_7_th20000.csv')['site_path_timestamp']
all_preds_37

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,88.266884,104.794300
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,82.316630,104.338745
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,84.221380,105.362060
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,87.842510,109.344190
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.390120,108.134900
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,5,214.121280,98.048190
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,5,211.845250,100.757540
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,5,208.958570,107.238950
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,5,202.751280,110.968445


In [30]:
all_preds_merge = pd.merge(all_preds_37, all_preds, how='left', on='site_path_timestamp')[['floor_y', 'x_y', 'y_y']]
all_preds_merge = all_preds_merge.rename(columns={'floor_y': 'floor', 'x_y': 'x', 'y_y': 'y'})
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0.0,85.759979,104.662834
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0.0,81.538498,102.262177
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0.0,83.906349,104.877182
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0.0,88.609116,107.953110
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0.0,88.815140,108.171616
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0.0,210.329102,98.822342
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0.0,208.603775,101.567291
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0.0,205.314331,105.555504
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0.0,200.148834,110.720680


In [31]:
all_preds_merge['floor'].fillna(all_preds_37['floor'], inplace=True)
all_preds_merge['x'].fillna(all_preds_37['x'], inplace=True)
all_preds_merge['y'].fillna(all_preds_37['y'], inplace=True)
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0.0,85.759979,104.662834
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0.0,81.538498,102.262177
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0.0,83.906349,104.877182
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0.0,88.609116,107.953110
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0.0,88.815140,108.171616
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0.0,210.329102,98.822342
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0.0,208.603775,101.567291
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0.0,205.314331,105.555504
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0.0,200.148834,110.720680


In [32]:
# foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds_merge.index = sub.index
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0.0,85.759979,104.662834
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0.0,81.538498,102.262177
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0.0,83.906349,104.877182
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0.0,88.609116,107.953110
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0.0,88.815140,108.171616
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0.0,210.329102,98.822342
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0.0,208.603775,101.567291
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0.0,205.314331,105.555504
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0.0,200.148834,110.720680


In [33]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds_merge['floor'] = simple_accurate_99['floor'].values
all_preds_merge.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,85.759979,104.662834
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,81.538498,102.262177
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,83.906349,104.877182
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.609116,107.953110
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.815140,108.171616
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,210.329102,98.822342
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,208.603775,101.567291
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,205.314331,105.555504
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,200.148834,110.720680


# Post Proccess

In [34]:
oofs_df = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv")
sub_df = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [35]:
def compute_rel_positions(acce_datas, ahrs_datas):
    step_timestamps, step_indexs, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    headings = compute_f.compute_headings(ahrs_datas)
    stride_lengths = compute_f.compute_stride_length(step_acce_max_mins)
    step_headings = compute_f.compute_step_heading(step_timestamps, headings)
    rel_positions = compute_f.compute_rel_positions(stride_lengths, step_headings)
    return rel_positions

In [36]:
def correct_path(args):

    path, path_df = args
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    txt_path = path_df['txt_path'].values[0]
    
    example = read_data_file(txt_path)
    rel_positions = compute_rel_positions(example.acce, example.ahrs)
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]
    delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)

    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

In [37]:
tmp = sub_df['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub_df['site'] = tmp[0]
sub_df['path'] = tmp[1]
sub_df['timestamp'] = tmp[2].astype(float)

In [38]:
used_buildings = sorted(sub_df['site'].value_counts().index.tolist())
test_txts = sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/test/*.txt'))
train_txts = [sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/train/{used_building}/*/*.txt')) for used_building in used_buildings]
train_txts = sum(train_txts, [])

In [39]:
txt_pathes = []
for path in tqdm(sub_df['path'].values):
    txt_pathes.append([test_txt for test_txt in test_txts if path in test_txt][0])

100%|██████████| 10133/10133 [00:00<00:00, 27012.08it/s]


In [40]:
sub_df['txt_path'] = txt_pathes

In [41]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, sub_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
sub_df_cm = pd.concat(dfs).sort_values('site_path_timestamp')

626it [01:08,  9.10it/s]


In [42]:
txt_pathes = []
for path in tqdm(oofs_df['path'].values):
    txt_pathes.append([train_txt for train_txt in train_txts if path in train_txt][0])

100%|██████████| 251113/251113 [02:32<00:00, 1648.39it/s]


In [43]:
oofs_df['txt_path'] = txt_pathes

In [44]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, oofs_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
oofs_df_cm = pd.concat(dfs).sort_index()
oofs_df_cm

10789it [08:07, 22.12it/s]


Unnamed: 0,site_path_timestamp,floor,x,y
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,157.676557,105.222664
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,160.798654,108.835284
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,160.792444,108.841855
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,160.784154,108.846972
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,164.344130,111.724300
...,...,...,...,...
251108,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,202.175651,141.740328
251109,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.988907,144.806225
251110,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.993952,144.799463
251111,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.996647,144.796466


In [45]:
oofs_df_cm.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv", index=False)

In [46]:
oofs_score = mean_position_error(
        oofs_df_cm['x'], oofs_df_cm['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.605500300723079


In [47]:
sub_df_cm.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv", index=False)

In [48]:
oofs_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv")

In [49]:
sub_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv")

In [50]:
def split_col(df):
    df = pd.concat([
        df['site_path_timestamp'].str.split('_', expand=True) \
        .rename(columns={0:'site',
                         1:'path',
                         2:'timestamp'}),
        df
    ], axis=1).copy()
    return df
def sub_process(sub, train_waypoints):
    train_waypoints['isTrainWaypoint'] = True
    sub = split_col(sub[['site_path_timestamp','floor','x','y']]).copy()
    sub = sub.merge(train_waypoints[['site','floorNo','floor']].drop_duplicates(), how='left')
    sub = sub.merge(
        train_waypoints[['x','y','site','floor','isTrainWaypoint']].drop_duplicates(),
        how='left',
        on=['site','x','y','floor']
             )
    sub['isTrainWaypoint'] = sub['isTrainWaypoint'].fillna(False)
    return sub.copy()

In [51]:
train_waypoints = pd.read_csv(str(DATA_DIR/'indoor-location-navigation') + '/train_waypoints.csv')


In [52]:
sub_df_cm = sub_process(sub_df_cm, train_waypoints)
oofs_df_cm = sub_process(oofs_df_cm, train_waypoints)

In [53]:
from scipy.spatial.distance import cdist

def add_xy(df):
    df['xy'] = [(x, y) for x,y in zip(df['x'], df['y'])]
    return df

def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]

sub_df_cm = add_xy(sub_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(sub_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

sub_df_cm_ds = pd.concat(ds)


oofs_df_cm = add_xy(oofs_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(oofs_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

oofs_df_cm_ds = pd.concat(ds)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
100%|██████████| 118/118 [00:04<00:00, 26.74it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:

In [54]:
def snap_to_grid(sub, threshold):
    """
    Snap to grid if within a threshold.
    
    x, y are the predicted points.
    x_, y_ are the closest grid points.
    _x_, _y_ are the new predictions after post processing.
    """
    sub['_x_'] = sub['x']
    sub['_y_'] = sub['y']
    sub.loc[sub['dist'] < threshold, '_x_'] = sub.loc[sub['dist'] < threshold]['x_']
    sub.loc[sub['dist'] < threshold, '_y_'] = sub.loc[sub['dist'] < threshold]['y_']
    return sub.copy()



In [55]:
# Calculate the distances
sub_df_cm_ds['dist'] = np.sqrt( (sub_df_cm_ds.x-sub_df_cm_ds.x_)**2 + (sub_df_cm_ds.y-sub_df_cm_ds.y_)**2 )
sub_pp = snap_to_grid(sub_df_cm_ds, threshold=5)
sub_pp = sub_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

# Calculate the distances
oofs_df_cm_ds['dist'] = np.sqrt( (oofs_df_cm_ds.x-oofs_df_cm_ds.x_)**2 + (oofs_df_cm_ds.y-oofs_df_cm_ds.y_)**2 )
oofs_pp = snap_to_grid(oofs_df_cm_ds, threshold=5)
oofs_pp = oofs_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

In [56]:
sub_pp = sub_pp.sort_index()
sub_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,93.728470,97.948860,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
1,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,79.662285,102.766754,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
2,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,80.718400,107.197110,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
3,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
4,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
...,...,...,...,...,...,...,...
10128,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.596040,98.605774,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10129,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,206.627910,102.635210,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10130,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.511300,107.841324,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10131,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,190.580600,115.806564,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6


In [57]:
sub_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_pp.csv", index=False)

In [58]:
oofs_pp = oofs_pp.sort_index()
oofs_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,158.496950,107.122680,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.331820,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.331820,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.331820,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.331820,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
...,...,...,...,...,...,...,...
251108,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,202.503340,140.972670,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251109,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.988907,144.806225,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251110,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.993952,144.799463,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251111,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,186.996647,144.796466,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7


In [59]:
oofs_score = mean_position_error(
        oofs_pp['x'], oofs_pp['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.157887849633147


In [60]:
oofs_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_pp.csv", index=False)

In [61]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': oofs_score})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,46.10779
Loss/xy,46.10779
Loss/floor,4.94275
MPE/val,7.30145
epoch,76.0
trainer/global_step,30260.0
_runtime,1178.0
_timestamp,1619159590.0
_step,76.0


0,1
Loss/val,█▆▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▆▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███▂▁▁▁▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
MPE/val,█▇▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




VBox(children=(Label(value=' 0.00MB of 0.35MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00180777507…

0,1
CV_score,5.15789
_runtime,2.0
_timestamp,1619163036.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁


In [None]:
import json
import matplotlib.pylab as plt

def plot_preds(
    site,
    floorNo,
    sub=None,
    true_locs=None,
    base=str(DATA_DIR/'indoor-location-navigation'),
    show_train=True,
    show_preds=True,
    fix_labels=True,
    map_floor=None
):
    """
    Plots predictions on floorplan map.
    
    map_floor : use a different floor's map
    """
    if map_floor is None:
        map_floor = floorNo
    # Prepare width_meter & height_meter (taken from the .json file)
    floor_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_image.png"
    json_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_info.json"
    with open(json_plan_filename) as json_file:
        json_data = json.load(json_file)

    width_meter = json_data["map_info"]["width"]
    height_meter = json_data["map_info"]["height"]

    floor_img = plt.imread(f"{base}/metadata/{site}/{map_floor}/floor_image.png")

    fig, ax = plt.subplots(figsize=(12, 12))
    plt.imshow(floor_img)

    if show_train:
        true_locs = true_locs.query('site == @site and floorNo == @map_floor').copy()
        true_locs["x_"] = true_locs["x"] * floor_img.shape[0] / height_meter
        true_locs["y_"] = (
            true_locs["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        true_locs.query("site == @site and floorNo == @map_floor").groupby("path").plot(
            x="x_",
            y="y_",
            style="+",
            ax=ax,
            label="train waypoint location",
            color="grey",
            alpha=0.5,
        )

    if show_preds:
        sub = sub.query('site == @site and floorNo == @floorNo').copy()
        sub["x_"] = sub["x"] * floor_img.shape[0] / height_meter
        sub["y_"] = (
            sub["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        for path, path_data in sub.query(
            "site == @site and floorNo == @floorNo"
        ).groupby("path"):
            path_data.plot(
                x="x_",
                y="y_",
                style=".-",
                ax=ax,
                title=f"{site} - floor - {floorNo}",
                alpha=1,
                label=path,
            )
    if fix_labels:
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(
            by_label.values(), by_label.keys(), loc="center left", bbox_to_anchor=(1, 0.5)
        )
    return fig, ax

In [None]:
example_site = '5a0546857ecc773753327266'
example_floorNo = 'F1'

sub_df = sub_process(sub_df, train_waypoints)
plot_preds(example_site, example_floorNo, sub_df,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_df_cm,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_pp,
           train_waypoints, show_preds=True)