# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path
from glob import glob


sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'unified-ds-wifi-and-beacon'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 13
IS_SAVE = True

utils.set_seed(SEED)

## read data

In [5]:
wifi_path_list = sorted(glob(str(WIFI_DIR / '*train.csv')))
train_df = []
for path in wifi_path_list:
    train_df.append(pd.read_csv(path))
train_df = pd.concat(train_df).reset_index(drop=True)
train_df = train_df.drop('Unnamed: 0', axis=1)
train_df = train_df.rename(columns={'site':'site_id'})
train_df

Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,230.03738,153.496350,c08ad78a45798cfe176a42b35c7381ae602711c5,-46,434,30f85a5e14351468a6dd13718a9da3b0d7b73685,-46,...,1421,ff9a29cb2dcb4100880d02b305d5691e578e66cf,2.083036,2827,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,2.341936,398,aeb5121f95a613552e00b083ee11cd70be497ab5,3.696872,1088
1,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,231.40290,158.415150,56bc1d2557bf8225384acb992f4b1993a894db63,-47,1856,25d990fc62eb6d5a42994cd7da0f9b6a70fea57a,-47,...,2562,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.629927,1757,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2141,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,2.949746,2664
2,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,232.46200,164.416730,c0beb6ad539d9bd9333c62866c08c844ac69ab28,-48,2767,56bc1d2557bf8225384acb992f4b1993a894db63,-48,...,1351,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,1527,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.137370,2667,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,5.718217,2644
3,5a0546857ecc773753327266,-1,5e15730aa280850006f3d005,233.94418,171.414170,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-51,136,bfaebb72653fac35c19b00e7ce484dc2897f18bd,-52,...,100,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,148,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,220,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,6.356626,324
4,5a0546857ecc773753327266,-1,5e15730b1506f2000638fc29,198.36833,163.520630,24677fe9a6f29ace69792429fd85fa8f3efd0192,-49,1919,e43e8b111c0061d7cba667a23d6f1c0143bc73ac,-50,...,567,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,10.591243,2147,7063a8d2f2f49ab0d53883f4eda9923b97f97fab,11.687424,770,26145e606b575396f4ca1bc439d2a9b37fdc6fa0,12.882066,90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,5dc8cea7659e181adb076a3f,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-40,1848,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-43,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000
75274,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.79349,74.839640,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,1875,24489f6ba05ba6afd52ac54e453fd1a2f1128b1a,-79,...,2915,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,234,-,-99.000000,3000,-,-99.000000,3000
75275,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,249.43129,76.241234,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,404,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,...,1048,3873a11805bd7adb86762e806d0f20e56e709e76,1.850688,636,3873a11805bd7adb86762e806d0f20e56e709e76,2.341936,2045,-,-99.000000,3000
75276,5dc8cea7659e181adb076a3f,6,5dd0d97d878f3300066c750b,242.54440,72.935265,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-46,2666,5c93fe6e92a16cbcf2259aea405cafe7c073be6e,-50,...,3000,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000


In [6]:
# train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test.csv')
test_df = test_df.drop('Unnamed: 0', axis=1)
test_df = test_df.rename(columns={'site':'site_id'})
test_df['site_path_timestamp'] = test_df['site_id'] + '_' + test_df['path'] + '_' + test_df['timestamp'].map('{:013}'.format)
test_df

Unnamed: 0,timestamp,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,...,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9,site_path_timestamp
0,10,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,889bfa434d66eed8c386ccbc90f445932c43f8dd,-58,1170,29c7d9e757292e7b2b3d00dc4dae7514531b20b4,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
1,4048,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-56,1000,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
2,12526,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-62,1952,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
3,25542,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,98d67fadac518296992afddd24e97a2855af9472,-55,2019,889bfa434d66eed8c386ccbc90f445932c43f8dd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
4,37134,5da1389e4db8ce0c98bd0547,0,00ff0c9a71cc37a2ebdd0f05,0,0,11567178cc5ca582a37c4733207c77739e1bf5fd,-54,1977,11567178cc5ca582a37c4733207c77739e1bf5fd,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5da1389e4db8ce0c98bd0547_00ff0c9a71cc37a2ebdd0...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10128,35117,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,6c381e7e7d8e984394ce02a3da912acc1b7d294e,-60,2011,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10129,41230,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-59,244,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,5a95e3ee4af260d25d5c13fcf02760820cb6dbdc,67.382098,1334,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10130,51634,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,46e733fa58deea74d962874847a529fb4897e655,-49,1166,46e733fa58deea74d962874847a529fb4897e655,...,-,-99.000000,3000,-,-99.000000,3000,-,-99.000000,3000,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...
10131,60483,5a0546857ecc773753327266,0,ffcd9524c80c0fa5bb859eaf,0,0,e5c220800a2d5ec83e355c3b1a2d7a141947e95f,-53,1309,6c381e7e7d8e984394ce02a3da912acc1b7d294e,...,216d1cba23183d9f79bdd0ba7e45f1d55e611c6e,18.299263,118,9b91bb4419360103cea4e1f7d434878550956b32,18.299263,2225,9e9913d8d1bf56161d199b1655220e8be5711f8f,18.800756,2213,5a0546857ecc773753327266_ffcd9524c80c0fa5bb859...


In [7]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,75.0,75.0
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,75.0,75.0
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,75.0,75.0
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,75.0,75.0


BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [8]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'wifi_bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'wifi_rssi_{i}' for i in range(NUM_FEATS)]
TIMEGAP_FEATS  = [f'wifi_timegap_{i}' for i in range(NUM_FEATS)]

B_NUM_FEATS = 10
MAC_FEATS = [f'beacon_macaddress_{i}' for i in range(B_NUM_FEATS)]
DIS_FEATS  = [f'beacon_distance_{i}' for i in range(B_NUM_FEATS)]
B_TIMEGAP_FEATS  = [f'beacon_timegap_{i}' for i in range(B_NUM_FEATS)]

train_df.loc[:,DIS_FEATS] = train_df.loc[:,DIS_FEATS].replace(np.inf,1e7)
test_df.loc[:,DIS_FEATS] = test_df.loc[:,DIS_FEATS].replace(np.inf,1e7)

In [9]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')

# train
beacon_macs = []
# bssidを列ごとにリストに入れていく
for i in MAC_FEATS:
    beacon_macs.extend(train_df.loc[:,i].values.tolist())
beacon_macs = list(set(beacon_macs))

train_beacon_macs_size = len(beacon_macs)
print(f'BSSID TYPES(train): {train_beacon_macs_size}')

# test
beacon_macs_test = []
for i in MAC_FEATS:
    beacon_macs_test.extend(test_df.loc[:,i].values.tolist())
beacon_macs_test = list(set(beacon_macs_test))

test_beacon_macs_size = len(beacon_macs_test)
print(f'BSSID TYPES(test): {test_beacon_macs_size}')


beacon_macs.extend(beacon_macs_test)
beacon_macs_size = len(beacon_macs)
print(f'BSSID TYPES(all): {beacon_macs_size}')

BSSID TYPES(train): 52185
BSSID TYPES(test): 25967
BSSID TYPES(all): 78152
BSSID TYPES(train): 16107
BSSID TYPES(test): 4413
BSSID TYPES(all): 20520


## preprocessing

In [14]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)

le_mac = LabelEncoder()
le_mac.fit(beacon_macs)

le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])

ss_gap = StandardScaler()
ss_gap.fit(train_df.loc[:,TIMEGAP_FEATS])

ss_dis = StandardScaler()
ss_dis.fit(train_df.loc[:,DIS_FEATS])

ss_beacon_gap = StandardScaler()
ss_beacon_gap.fit(train_df.loc[:,B_TIMEGAP_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])
    
    # gapの正規化
    output_df.loc[:,TIMEGAP_FEATS] = ss_gap.transform(input_df.loc[:,TIMEGAP_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
        
    # RSSIの正規化
    output_df.loc[:,DIS_FEATS] = ss_dis.transform(input_df.loc[:,DIS_FEATS])
    
    # gapの正規化
    output_df.loc[:,B_TIMEGAP_FEATS] = ss_beacon_gap.transform(input_df.loc[:,B_TIMEGAP_FEATS])

    # BSSIDのLE(1からふる)
    for i in MAC_FEATS:
        output_df.loc[:,i] = le_mac.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])
    

    return output_df

train = preprocess(train_df)
test = preprocess(test_df)
train    

  return self.partial_fit(X, y)
  return self.partial_fit(X, y)
  return self.partial_fit(X, y)


Unnamed: 0,site_id,floor,path,x,y,wifi_bssid_0,wifi_rssi_0,wifi_timegap_0,wifi_bssid_1,wifi_rssi_1,...,beacon_timegap_6,beacon_macaddress_7,beacon_distance_7,beacon_timegap_7,beacon_macaddress_8,beacon_distance_8,beacon_timegap_8,beacon_macaddress_9,beacon_distance_9,beacon_timegap_9
0,0,-1,5e15730aa280850006f3d005,230.03738,153.496350,39873,0.145725,-1.134794,10121,0.161037,...,-0.816203,16382,-0.017373,0.580331,2356,-0.003689,-1.959768,11185,-0.005039,-1.306932
1,0,-1,5e15730aa280850006f3d005,231.40290,158.415150,17965,0.137646,0.537887,7792,0.152948,...,0.342621,7290,-0.017373,-0.518456,7290,-0.003689,-0.156913,7290,-0.005039,0.348762
2,0,-1,5e15730aa280850006f3d005,232.46200,164.416730,39921,0.129567,1.609485,17965,0.144860,...,-0.887296,7290,-0.017373,-0.754644,7290,-0.003689,0.387149,7290,-0.005039,0.327751
3,0,-1,5e15730aa280850006f3d005,233.94418,171.414170,39705,0.105331,-1.485328,39705,0.112506,...,-2.157838,7290,-0.017373,-2.170744,7290,-0.003689,-2.143880,7290,-0.005039,-2.109565
4,0,-1,5e15730b1506f2000638fc29,198.36833,163.520630,7522,0.121489,0.611993,47156,0.128683,...,-1.683543,7290,-0.017373,-0.117964,7290,-0.003689,-1.574993,2356,-0.005039,-2.355398
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75273,23,6,5dd0d97c94e4900006125dd9,249.79349,74.839640,19175,0.194199,0.528477,19175,0.185302,...,0.787463,0,-0.017373,0.757985,0,-0.003689,0.731585,0,-0.005039,0.701753
75274,23,6,5dd0d97d878f3300066c750b,249.79349,74.839640,19175,0.145725,0.560237,7503,-0.105883,...,0.701135,3504,-0.017373,-2.082430,0,-0.003689,0.731585,0,-0.005039,0.701753
75275,23,6,5dd0d97d878f3300066c750b,249.43129,76.241234,19175,0.145725,-1.170083,19175,0.161037,...,-1.195029,3504,-0.017373,-1.669615,3504,-0.003689,-0.256210,0,-0.005039,0.701753
75276,23,6,5dd0d97d878f3300066c750b,242.54440,72.935265,19175,0.145725,1.490680,19175,0.128683,...,0.787463,0,-0.017373,0.757985,0,-0.003689,0.731585,0,-0.005039,0.701753


In [15]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [16]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.timegap_feats = df[TIMEGAP_FEATS].values.astype(np.float32)
        
        self.mac_feats = df[MAC_FEATS].values.astype(int)
        self.dis_feats = df[DIS_FEATS].values.astype(np.float32)
        self.b_timegap_feats = df[B_TIMEGAP_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'TIMEGAP_FEATS':self.timegap_feats[idx],
            'MAC_FEATS':self.mac_feats[idx],
            'DIS_FEATS':self.dis_feats[idx],
            'B_TIMEGAP_FEATS':self.b_timegap_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [17]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, macs_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        self.mac_embedding = nn.Embedding(macs_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        self.timegap = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        self.dis = nn.Sequential(
            nn.BatchNorm1d(B_NUM_FEATS),
            nn.Linear(B_NUM_FEATS, B_NUM_FEATS * 64)
        )
        
        self.b_timegap = nn.Sequential(
            nn.BatchNorm1d(B_NUM_FEATS),
            nn.Linear(B_NUM_FEATS, B_NUM_FEATS * 64)
        )
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (B_NUM_FEATS * 64) + (B_NUM_FEATS * 64) + (B_NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_mac = self.mac_embedding(x['MAC_FEATS'])
        x_mac = self.flatten(x_mac)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])
        x_dis = self.dis(x['DIS_FEATS'])
        
        x_timegap = self.timegap(x['TIMEGAP_FEATS'])
        x_b_timegap = self.b_timegap(x['B_TIMEGAP_FEATS'])


        x = torch.cat([x_bssid, x_site_id, x_rssi, x_timegap, x_mac, x_dis, x_b_timegap], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [18]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [19]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [20]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [21]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [None]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + TIMEGAP_FEATS + 
                       MAC_FEATS + DIS_FEATS + B_TIMEGAP_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + TIMEGAP_FEATS + 
                       MAC_FEATS + DIS_FEATS + B_TIMEGAP_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, beacon_macs_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.3 M    Trainable params
0         Non-trainable params
16.3 M    Total params
65.272    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 185.50580596923828


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 164.21453346839317


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 162.68755182119514


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 161.26596838144158


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 159.88225518251076


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 158.52115003145659


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 157.17906498053134


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 155.8494790297288


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 154.5323122660319


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 153.22463774069763


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 151.92757818760018


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 150.63906539525743


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 149.35911830021783


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 148.08891380505685


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 146.82688483213767


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 145.5763808225974


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 144.3334410251715


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 143.0993180103791


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 141.87454984616008


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 140.65893247555462


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 139.45034025143352


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 138.2530257396209


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 137.06298274504834


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 135.88268590095717


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 134.71271895384177


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 133.55134776188777


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 132.3998813726963


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 131.2589983622233


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 130.12617835998535


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 129.00449844751603


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 127.89291063944499


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 126.7937385950333


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 125.70113182067871


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 124.62274239368928


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 123.55364780915089


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 122.49472999572754


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 121.44845171219262


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 120.413648439065


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 119.39071903717824


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 118.37874089754544


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 117.37894954192332


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 116.39163589477539


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 115.41663920084635


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 114.4549687507825


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 113.50366731301331


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 112.56558943528395


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 111.63933180784566


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 110.7275762900328


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 109.82693482423441


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 108.93899110647348


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 108.06310786711863


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 107.19987117082646


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 106.34938068878957


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 105.51050260494917


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 104.68576745253343


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 103.87299942603478


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 103.07379691295135


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 102.28752660506811


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 101.51498597952036


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 100.75537997514773


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 100.01126636602939


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 99.2796947870499


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 98.56223665873209


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 97.85816122201773


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 97.1690634213961


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 96.49504535381611


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 95.83533095335348


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 95.18920243092072


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 94.55866262973883


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 93.94233352465508


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 93.33963837256799


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 92.75046227283967


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 92.17539239541078


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 91.61489377144056


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 91.06564830877842


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 90.53166983188727


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 90.01078847249349


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 89.50435911325307


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 89.00975176493327


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 88.53022428659293


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 88.06446646665914


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 87.61194918705867


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 87.17344636183519


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 86.74851776514298


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 86.33865483601888


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 85.9418761131091


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 85.55916349948981


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 85.1901193569868


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 84.83358509357159


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 84.49177040686975


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 84.16308504740397


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 83.84786721254007


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 83.54432592147437


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 83.25328679207044


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 82.9763583354461


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 82.71191479609563


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 82.45964242005961


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 82.21874809265137


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 81.99105425125512


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 81.77460690033742


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 81.56864146697215


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 80.77766870351938


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 72.22503167176859


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 70.71740951538087


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 69.94296202048278


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 69.1327172303811


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 68.47070334507869


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 67.83581787011562


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 67.26919415302766


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 66.69492616897975


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 66.22493870563996


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 65.67819861387595


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 65.11937413826966


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 64.52781217526167


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 64.19149635510567


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 63.590401942913346


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 63.14219340299949


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 62.70362789936554


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 62.25926517584385


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 61.842777369572566


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 61.42409010178004


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 60.6428036909837


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 60.035848192068244


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 59.62021657014505


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 59.00569032522348


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 58.58489869435628


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 58.14309962346003


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 57.69986280783629


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 57.26826893244034


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 56.87449635236691


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 56.549303250435074


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 56.15566631708389


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 55.72510838630872


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 55.340171099931766


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 55.02423300131773


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 54.649075747758914


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 54.29378399971203


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 53.94693208841177


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 53.37293145106389


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 52.89320846948868


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 52.553182866023135


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 52.15190177819668


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 51.823186708107976


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 51.407547041086055


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 51.031285373981184


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 50.6753368964562


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 50.393764471396416


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 50.054389899816265


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 49.83603002597124


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 49.42348681718875


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 49.10769348144531


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 48.7339834359976


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 48.421540710253595


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 48.07668234507243


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 47.58828861530011


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 47.13448937244904


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 46.749201422471266


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 46.40439704503769


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 45.98604679107666


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 45.6199368256789


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 45.275859607794345


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 44.87975508861053


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 44.51674419794327


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 44.12896214020558


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 43.842440605163574


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 43.42462201729799


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 43.06160600124261


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 42.73014196738219


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 42.38409704061655


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 42.0692001293867


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 41.7712527348445


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 41.356556002298994


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 41.01743808648526


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 40.67887920477451


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 40.35922321417392


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 39.94199211903107


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 39.661359518002236


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 39.31086054092798


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 38.968641508542575


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 38.640452903356305


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 38.28806829207983


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 37.97169240804819


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 37.68799068255302


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 37.297861216618465


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 36.97957339164539


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 36.62945007911095


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 36.316778170756805


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 35.967099336477425


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 35.664007869133584


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 35.3259547037956


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 34.98782885869344


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 34.70558058420817


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 34.38608503586207


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 34.05516784863594


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 33.774956600482646


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 33.414751390310435


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 33.15274800520677


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 32.83171578431741


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 32.545239089085506


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 32.18667281957773


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 31.91097989449134
fold 0: mean position error 31.94406264090428
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.01MB of 0.01MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,1096.37659
Loss/xy,1096.37659
Loss/floor,5.79977
MPE/val,31.94406
epoch,199.0
trainer/global_step,23599.0
_runtime,433.0
_timestamp,1617473871.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Loss/floor,████████████████████████▆▁▂▆▇███████████
MPE/val,██▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 16.3 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
16.3 M    Trainable params
0         Non-trainable params
16.3 M    Total params
65.272    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 195.15760040283203


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 166.4236842228816


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 165.4966950441018


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 164.48180678930038


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 163.55049958840394


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 162.64638266930214


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 161.75907829480292


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 160.88214518229168


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 160.00978147066556


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 159.14620781922952


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 158.29060232700445


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 157.4371874295748


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 156.58663682204028


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 155.7438499255058


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 154.9018365028577


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 154.0638660919972


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 153.23172361911872


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 152.39834868602262


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 151.57236369206353


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 150.7495622683794


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 149.9263015942696


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 149.11083366198417


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 148.2968764476287


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 147.48434954912233


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 146.67980876824794


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 145.88002057197767


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 145.07935767540567


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 144.27603481977414


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 143.4886687058669


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 142.69610847081896


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 141.9083386739095


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 141.1292752779447


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 140.3430654183412


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 139.57265599568686


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 138.7972546112843


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 138.02328713245882


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 137.2574082692464


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 136.49257479936648


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 135.72845678084937


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 134.98022578312802


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 134.22021361130933


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 133.47855008443196


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 132.73094830635267


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 131.99365811470227


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 131.2501277923584


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 130.51354059072642


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 129.78560735262357


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 129.0633153475248


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 128.33474782307943


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 127.60594454056177


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 126.89609596056816


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 126.18557205200196


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 125.47806069789789


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 124.76059867663261


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 124.07007190997783


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 123.36617308396559


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 122.68443011748485


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 121.98504483149601


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 121.31294939090044


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 120.63535957336425


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 119.96099633436937


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 119.29366391499838


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 118.62451406625601


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 117.95797052627954


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 117.32038850050706


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 116.66714360163762


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 116.0248241228935


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 115.36940999153333


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 114.73884486663036


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 114.12226693569086


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 113.48241683764336


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 112.8741899050199


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 112.26654522235576


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 111.66213800479204


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 111.06213627350637


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 110.46350048749875


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 109.86399635901817


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 109.28369600344926


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 108.70418181786171


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 108.11475161038912


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 107.54558154375125


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 106.9798934349647


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 106.43100767869215


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 105.85663738739797


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 105.31807188376403


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 104.77071986076159


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 104.24593244699331


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 103.71531606820913


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 103.19154531038724


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 102.66942403744429


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 102.15568418258276


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 101.28237077761919


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 100.59045184208797


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 99.9639492817414


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 99.35383246984237


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 98.78032695085575


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 98.21485036214193


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 97.46098935053898


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 95.03888706305088


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 92.99421203808906


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 91.37983596019257


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 90.13650390429375


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 88.97699463673126


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 87.90511629153521


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 86.94432279635699


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 86.02823723524044


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 85.1589059438461


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 84.34100123185378


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 83.4966822795379


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 82.74675395183074


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 80.85901052523882


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 79.76894234877366


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 78.7841617388603


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 77.90170207879483


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 76.92305920918783


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 76.11021967178735


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 75.3052658765744


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 74.50887653644267


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 73.74628942440718


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 72.82942042717566


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 71.97362509996464


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 71.12157701834654


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 70.42464751708202


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 69.67093112163055


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 69.00812319731101


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 68.3468441596398


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 67.71456192212227


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 67.06913792047744


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 66.51284366509853


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 65.8456379768176


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 65.23443524776361


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 64.6651625657693


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 64.11134749192458


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 63.539692472800226


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 62.951444650307685


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 62.48221326974722


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 61.9159180078751


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 61.38612256172375


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 60.930719673939244


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 60.378814614124785


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 59.88543075659336


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 59.44726197658441


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 58.91243160932492


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 58.484623913887226


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 58.02016063592373


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 57.58203585698054


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 57.13406004294371


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 56.73791082333296


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 56.261215068132444


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 55.890159059182196


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 55.492509910387874


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 55.08657012352577


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 54.69676717122396


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 54.32864536872277


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 53.9775392190004


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 53.55571793776292


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 53.201374024611255


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 52.881135128705935


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 52.48099399957901


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 52.11852062910031


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 51.77794640369904


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 51.41094168638572


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 51.04613425425995


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 50.676095077319026


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 50.34397906278953


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 49.98350299933018


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 49.595476067371855


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 49.28078318864871


In [22]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,wifi_bssid_0,wifi_bssid_1,wifi_bssid_2,wifi_bssid_3,wifi_bssid_4,wifi_bssid_5,wifi_bssid_6,wifi_bssid_7,wifi_bssid_8,wifi_bssid_9,...,wifi_timegap_77,wifi_timegap_78,wifi_timegap_79,site_id,x,y,floor,oof_x,oof_y,oof_floor
0,39873,10121,39095,16067,52302,13060,39873,16067,10121,52302,...,-1.330807,-1.330601,-1.331937,0,230.03738,153.496350,-1,176.430313,155.572342,0.157174
1,17965,7792,39921,19108,29424,39921,16191,7792,19108,29424,...,0.432958,0.430131,-1.733404,0,231.40290,158.415150,-1,176.407486,155.841202,0.158001
2,39921,17965,16191,7792,19108,52934,52302,16067,10121,39095,...,-0.580340,-0.581424,1.269790,0,232.46200,164.416730,-1,176.420746,155.689865,0.157538
3,39705,39705,39873,39873,39705,16067,52302,39095,13060,10121,...,0.232759,0.230276,0.227092,0,233.94418,171.414170,-1,176.414627,155.753967,0.157730
4,26417,12652,41723,21199,16863,39050,23964,19076,17470,10752,...,0.378155,0.375422,0.372066,0,210.86192,165.376080,-1,176.171432,157.637131,0.138766
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15050,19175,19175,19175,7503,18513,13910,7503,18513,13910,7503,...,-1.715547,0.164402,0.161296,23,249.43129,76.241234,6,176.280716,118.478104,0.013724
15051,19175,19175,19175,13910,32715,13910,32715,13910,7503,34182,...,0.582827,0.579743,0.576145,23,237.22395,73.177680,6,168.918488,119.839279,0.107121
15052,19175,19175,19175,7503,32715,7503,7503,7038,7038,32715,...,-0.258232,-0.259870,-0.262474,23,242.54440,72.935265,6,170.729889,116.128365,0.027674
15053,19175,19175,19175,7503,7038,7503,7503,7038,32715,13910,...,-0.471852,-0.473123,-0.475475,23,249.43129,76.241234,6,172.952377,113.620102,0.000000


In [23]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,91.197029,101.293404
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.478302,101.666359
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.910324,106.208855
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.830978,106.066658
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.957901,105.958168
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,167.698700,125.386986
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,167.104019,124.574036
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,170.583908,129.099243
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,171.617325,130.661057


In [24]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,91.197029,101.293404
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.478302,101.666359
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.910324,106.208855
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.830978,106.066658
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.957901,105.958168
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,167.698700,125.386986
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,167.104019,124.574036
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,170.583908,129.099243
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,171.617325,130.661057


In [25]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [26]:
print(f"CV:{np.mean(val_scores)}")

CV:34.501771288999024


In [27]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,878.90088
Loss/xy,878.90088
Loss/floor,5.63142
MPE/val,26.88436
epoch,199.0
trainer/global_step,23599.0
_runtime,384.0
_timestamp,1617432959.0
_step,199.0


0,1
Loss/val,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███▇▇▇▇▇▇▇▆▆▆▄▃▃▃▃▃▃▃▂▂▂▂
MPE/val,██▇▇▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.57MB of 0.57MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
CV_score,34.50177
_runtime,2.0
_timestamp,1617433227.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
