# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path
from tqdm import tqdm

sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 15
IS_SAVE = True

utils.set_seed(SEED)

In [5]:
# training target features
NUM_FEATS = 100
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]
POSX_FEATS = [f'bssid_x_{i}' for i in range(NUM_FEATS)]
POSY_FEATS = [f'bssid_y_{i}' for i in range(NUM_FEATS)]

In [6]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [7]:
train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_all.csv')

In [8]:
train_df

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,rssi_95,rssi_96,rssi_97,rssi_98,rssi_99,x,y,floor,path,site_id
0,db01605eac3f33540038bd9722aba25774871d43,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,0b64e537cc3d1818ec46f94f8dc14043a98d0089,922e582c66016a2b9f64e38f89ebe82f66eefb24,dc4c46287575c45f3e32c022d868d047b485ed4c,93e20595eeef175d3aa3c3381f6a22ee792d48d9,b2b0ddbb5a2aadfc6ab2f388db584b6c280d3f82,8c936564ea4b4300576f53136505527eb5972c07,61c3aaf1a526f808c05952ea3f098e37354a674a,3f564032c7eebc173b38aee35225e323d4389faf,...,-79,-79,-79,-79,-79,107.85044,161.892620,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
1,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,db01605eac3f33540038bd9722aba25774871d43,1f37bbb3f42125f665b83584d0376b21ec3eb43c,922e582c66016a2b9f64e38f89ebe82f66eefb24,dc4c46287575c45f3e32c022d868d047b485ed4c,93e20595eeef175d3aa3c3381f6a22ee792d48d9,5c10b343d767a30515e6015de25751a2883328f8,3f564032c7eebc173b38aee35225e323d4389faf,46c934893439700099d03a6892ea934ecb2729d6,16374260af7d03b10f167358a4f6a70620e131f4,...,-79,-79,-79,-80,-80,107.85044,161.892620,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
2,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,db01605eac3f33540038bd9722aba25774871d43,dc4c46287575c45f3e32c022d868d047b485ed4c,922e582c66016a2b9f64e38f89ebe82f66eefb24,93e20595eeef175d3aa3c3381f6a22ee792d48d9,61c3aaf1a526f808c05952ea3f098e37354a674a,ce28608c3d091ac0d25d84459ebad253edf83e1f,1bb0e992cff45a54d29e97f47a7d1281435a5e3b,1f37bbb3f42125f665b83584d0376b21ec3eb43c,ca86c5b074c5768e481e069b751bf22c6d95bd48,...,-77,-78,-78,-78,-78,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
3,61c3aaf1a526f808c05952ea3f098e37354a674a,922e582c66016a2b9f64e38f89ebe82f66eefb24,93e20595eeef175d3aa3c3381f6a22ee792d48d9,db01605eac3f33540038bd9722aba25774871d43,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,0f5daed11a61e0d6941a1a42ff428ca216d61003,ce28608c3d091ac0d25d84459ebad253edf83e1f,40d99a3e5214aa704f637b7d72631e69550ee256,2aa08d092d0199c06d22684642ef1c79d9722adb,149c09a117b9851201c75f97b4a7cc94b75fdcb4,...,-75,-76,-76,-77,-77,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
4,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,93e20595eeef175d3aa3c3381f6a22ee792d48d9,61c3aaf1a526f808c05952ea3f098e37354a674a,51782c2fabefa97e99dca895fd36f1a47e214610,db01605eac3f33540038bd9722aba25774871d43,0f5daed11a61e0d6941a1a42ff428ca216d61003,ce28608c3d091ac0d25d84459ebad253edf83e1f,4c83a7a1e51bfa8a5fa20e854ab3feec057c52c9,599fa96d549ed870671d6bc1927aaa8bbaacca12,dc9fd0f591e9bfc22748106f31d72a23c1d294fd,...,-75,-76,-76,-77,-77,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,930dfcc059cd5c29f94b37435b31b14adf6141c1,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,5953d0b2247e16447d327eb2a8a9c1abe24ff425,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,040bfca8631c5506cba1c22cf7750636ba26ea54,7fcc50af42706af2d8ecfe55b7ccbb731aae6e27,b719d7de6ffe1b1b7409d2eb5ac3268e36cf2675,89395d0ee75307b3beb30aef2f19fc680095d514,...,-84,-85,-85,-85,-85,122.68994,124.028015,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f
258121,930dfcc059cd5c29f94b37435b31b14adf6141c1,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,5953d0b2247e16447d327eb2a8a9c1abe24ff425,040bfca8631c5506cba1c22cf7750636ba26ea54,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,7e5f2b71ff184401c1c31ffeab2861d86f1f9c25,8bed40c61bc42247c3b142f214f78dcc5850553f,f9100b425367a18de9f1cec34a34ea60cc7bf009,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,...,-85,-85,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f
258122,040bfca8631c5506cba1c22cf7750636ba26ea54,930dfcc059cd5c29f94b37435b31b14adf6141c1,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,5953d0b2247e16447d327eb2a8a9c1abe24ff425,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,7e5f2b71ff184401c1c31ffeab2861d86f1f9c25,ca69ae425b53d4c2fae3d97ec4ec61897a4a6b73,f9100b425367a18de9f1cec34a34ea60cc7bf009,...,-84,-84,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f
258123,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,930dfcc059cd5c29f94b37435b31b14adf6141c1,5953d0b2247e16447d327eb2a8a9c1abe24ff425,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,040bfca8631c5506cba1c22cf7750636ba26ea54,7e5f2b71ff184401c1c31ffeab2861d86f1f9c25,ca69ae425b53d4c2fae3d97ec4ec61897a4a6b73,f5c69326982eb7d74f899b933c4cb5a3d1592af2,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,...,-85,-85,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f


In [9]:
use_sites = train_df['site_id'].unique()
use_sites

bssid_uni = []

for site in use_sites:

    bssid_csv = pd.read_csv(str(DATA_DIR) + f'/bssid_position/nb013_bssid_position/nb013_bssid_position_{site}.csv')
    bssid_tgts = bssid_csv['bssid'].unique()

    for tgt in tqdm(bssid_tgts):
        bssid_rssi = bssid_csv[bssid_csv['bssid'] == tgt]['n_samples_rssi_over_m50']
        bssid_series = bssid_csv.iloc[bssid_rssi.index[bssid_rssi.values.argmax()]]
        bssid_uni.append(bssid_series)
bssid_uni_df = pd.DataFrame(bssid_uni)


100%|██████████| 2766/2766 [00:02<00:00, 1314.85it/s]
100%|██████████| 2568/2568 [00:01<00:00, 1292.47it/s]
100%|██████████| 5492/5492 [00:05<00:00, 1040.20it/s]
100%|██████████| 1139/1139 [00:00<00:00, 1573.19it/s]
100%|██████████| 1610/1610 [00:01<00:00, 1511.84it/s]
100%|██████████| 718/718 [00:00<00:00, 1631.44it/s]
100%|██████████| 907/907 [00:00<00:00, 1609.97it/s]
100%|██████████| 1409/1409 [00:00<00:00, 1491.18it/s]
100%|██████████| 1763/1763 [00:01<00:00, 1394.05it/s]
100%|██████████| 5131/5131 [00:04<00:00, 1160.20it/s]
100%|██████████| 1300/1300 [00:00<00:00, 1588.71it/s]
100%|██████████| 876/876 [00:00<00:00, 1582.73it/s]
100%|██████████| 300/300 [00:00<00:00, 1688.68it/s]
100%|██████████| 2200/2200 [00:01<00:00, 1302.43it/s]
100%|██████████| 974/974 [00:00<00:00, 1433.56it/s]
100%|██████████| 528/528 [00:00<00:00, 1680.45it/s]
100%|██████████| 1203/1203 [00:00<00:00, 1558.88it/s]
100%|██████████| 904/904 [00:00<00:00, 1609.80it/s]
100%|██████████| 1315/1315 [00:00<00:00, 1

In [10]:
bssid_uni_df.reset_index(drop=True, inplace=True)

In [11]:
for i, bssid in tqdm(enumerate(BSSID_FEATS)):
    train_df = train_df.merge(bssid_uni_df[['bssid', 'bssid_x', 'bssid_y', 'site']],
                               how='left', left_on=[bssid, 'site_id'], right_on=['bssid', 'site']
                              ).drop(['bssid', 'site'], axis=1
                                    ).rename({'bssid_x': f'bssid_x_{i}', 'bssid_y': f'bssid_y_{i}'}, axis=1)

100it [03:33,  2.13s/it]


In [12]:
for i, bssid in tqdm(enumerate(BSSID_FEATS)):
    test_df = test_df.merge(bssid_uni_df[['bssid', 'bssid_x', 'bssid_y', 'site']],
                               how='left', left_on=[bssid, 'site_id'], right_on=['bssid', 'site']
                              ).drop(['bssid', 'site'], axis=1
                                    ).rename({'bssid_x': f'bssid_x_{i}', 'bssid_y': f'bssid_y_{i}'}, axis=1)

100it [00:10,  9.99it/s]


In [13]:
for i, bssid in tqdm(enumerate(POSX_FEATS)):
    train_df[bssid] = train_df[bssid].fillna(np.mean(train_df[bssid]))
    test_df[bssid] = test_df[bssid].fillna(np.mean(test_df[bssid]))
for i, bssid in tqdm(enumerate(POSY_FEATS)):
    train_df[bssid] = train_df[bssid].fillna(np.mean(train_df[bssid]))
    test_df[bssid] = test_df[bssid].fillna(np.mean(test_df[bssid]))

#     break

100it [00:00, 613.84it/s]
100it [00:00, 621.67it/s]


In [14]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [15]:
train_df.iloc[:, 100:110]

Unnamed: 0,rssi_0,rssi_1,rssi_2,rssi_3,rssi_4,rssi_5,rssi_6,rssi_7,rssi_8,rssi_9
0,-32,-39,-47,-48,-48,-49,-51,-52,-54,-56
1,-29,-34,-47,-48,-48,-49,-52,-52,-52,-53
2,-33,-39,-48,-48,-49,-52,-54,-55,-55,-55
3,-46,-48,-49,-50,-51,-52,-54,-56,-57,-57
4,-42,-49,-51,-51,-52,-53,-54,-55,-55,-55
...,...,...,...,...,...,...,...,...,...,...
258120,-53,-63,-64,-66,-68,-68,-68,-68,-70,-71
258121,-58,-64,-66,-67,-68,-68,-69,-70,-71,-71
258122,-57,-58,-60,-64,-66,-67,-68,-69,-71,-73
258123,-58,-64,-66,-66,-68,-69,-69,-71,-71,-72


bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [16]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in range(100):
    wifi_bssids.extend(train_df.iloc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in range(100):
    wifi_bssids_test.extend(test_df.iloc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 61206
BSSID TYPES(test): 33042
BSSID TYPES(all): 94248


## preprocessing

In [17]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])

ss_posx = StandardScaler()
ss_posx.fit(train_df.loc[:,POSX_FEATS])

ss_posy = StandardScaler()
ss_posy.fit(train_df.loc[:,POSY_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])
    output_df.loc[:,POSX_FEATS] = ss_posx.transform(input_df.loc[:,POSX_FEATS])
    output_df.loc[:,POSY_FEATS] = ss_posy.transform(input_df.loc[:,POSY_FEATS])


    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

  return self.partial_fit(X, y)


Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,bssid_x_95,bssid_y_95,bssid_x_96,bssid_y_96,bssid_x_97,bssid_y_97,bssid_x_98,bssid_y_98,bssid_x_99,bssid_y_99
0,52392,35870,2764,34897,52709,35259,42719,33509,23416,15248,...,-1.025369e-01,7.024569e-01,-0.423699,1.162873,0.080504,1.122268,-9.953795e-02,0.707304,0.009011,0.360876
1,35870,52392,7486,34897,52709,35259,21970,15248,17024,5350,...,1.843030e-01,2.581074e-01,-0.780503,1.253044,-0.105976,0.715343,-1.999884e-01,1.191787,0.117968,1.121007
2,35870,52392,52709,34897,35259,23416,49407,6672,7486,48500,...,3.734856e-01,1.090571e+00,0.082364,1.120134,-0.113916,0.656415,-1.069940e-01,0.727827,-0.467115,1.204502
3,23416,34897,35259,52392,35870,3706,49407,15612,10166,4977,...,4.908429e-01,1.075836e+00,-0.108160,0.725587,-0.082866,0.543102,-1.036072e-01,0.718640,-0.461139,1.197266
4,35870,35259,23416,19472,52392,3706,49407,18305,21409,52794,...,4.029704e-01,1.087480e+00,0.082364,1.120134,0.111995,1.115840,8.358646e-02,1.126110,-0.101391,0.719398
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,35065,14545,16494,21326,50465,22830,943,30429,43802,32684,...,4.502291e-16,7.631404e-16,-0.146412,-0.175087,1.597496,0.289370,3.305615e-01,0.293692,0.118656,0.476024
258121,35065,16494,21326,943,50465,22830,30059,33363,59581,14545,...,1.097535e+00,-5.809640e-01,0.480270,0.217382,1.111275,-0.481343,3.305615e-01,0.293692,1.612922,0.292242
258122,943,35065,14545,16494,21326,22830,50465,30059,48476,59581,...,-2.554970e-01,-2.184554e-01,0.526810,0.240267,0.326511,0.290964,3.473219e-01,0.322715,-0.144139,-0.174254
258123,14545,16494,35065,21326,50465,943,30059,48476,58803,22830,...,4.773981e-01,2.174025e-01,0.068539,-0.017422,1.111275,-0.481343,2.926268e-01,0.226648,1.612922,0.292242


In [18]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [19]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.posx_feats = df[POSX_FEATS].values.astype(np.float32)
        self.posy_feats = df[POSY_FEATS].values.astype(np.float32)

        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'POSX_FEATS':self.posx_feats[idx],
            'POSY_FEATS':self.posy_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [20]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        self.posx = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        self.posy = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])
        
        x_posx = self.posx(x['POSX_FEATS'])
        x_posy = self.posx(x['POSY_FEATS'])

        x = torch.cat([x_bssid, x_site_id, x_rssi, x_posx, x_posy], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [21]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [22]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [23]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [24]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [26]:
!wandb login --relogin d6579ae9cb57a257a99fafb83b92f724e08968b2

/bin/sh: 1: wandb: not found


In [None]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + POSX_FEATS + POSY_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + POSX_FEATS + POSY_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 21.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
21.4 M    Trainable params
0         Non-trainable params
21.4 M    Total params
85.696    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 133.48785400390625


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 160.24827582633142


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 154.85700512672207


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 149.6632346760345


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 144.62097875009667


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 139.7233601163443


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 134.97506224030732


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 130.387511923184


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 125.96876751637001


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 121.72579204317158


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 117.66600148751421


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 113.79264809241319


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 110.12347934587375


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 105.40796394614752


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 101.10142618573677


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 97.35042585151864


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 93.84472286235474


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 90.53138465090414


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 87.40602415979082


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 84.42999034381704


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 81.20852773874113


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 77.647579299358


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 74.23379113587795


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 71.48434490477129


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 69.0999061479163


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 66.92406840427157


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 64.95147213286553


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 63.10758409843378


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 61.44638931136667


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 59.9042029283598


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 58.521678810052975


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 57.175572811297776


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 55.963680082664425


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 54.756518433114856


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 53.60108599712214


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 52.45383876100695


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 51.237998969556294


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 49.994980734156954


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 48.72710801717217


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 47.44296452823876


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 46.06891502884702


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 44.73189452787298


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 43.37429887368364


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 41.94142175853287


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 40.555240226683395


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 39.19067077969977


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 37.898937510522074


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 36.589560880222216


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 35.30737895683417


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 34.05764064682175


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 32.88665330134788


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 31.678407696805586


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 30.575493906684304


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 29.50356892088644


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 28.3824658556601


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 27.39979346147025


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 26.33854157804957


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 25.356579545398194


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 24.466570618584466


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 23.546876481944345


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 22.760549215240122


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 21.888126254639936


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 21.09178571413337


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 20.369860682012597


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 19.616453466520937


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 18.968037430980417


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 18.276691464405804


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 17.721304299805006


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 17.143194681222035


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 16.634482274919176


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 16.07168599628055


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 15.519026378493846


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 15.142294066489681


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 14.62161009482471


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 14.16624450969085


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 13.843047685339899


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 13.370749727592402


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 13.065145289188356


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 12.778093374415617


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 12.54125779272445


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 12.2219504208973


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 11.921819241734251


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 11.622168525702465


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 11.51952115357165


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 11.222996076318367


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 11.080408415186175


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 10.843930397333622


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 10.658028917618108


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 10.486835032014842


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 10.320581370494224


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 10.235396795511662


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 10.056385416020886


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 10.026886614486791


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 9.895429798286186


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 9.85690326472945


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 9.651712391919428


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 9.650517142778506


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 9.528083828863322


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 9.498175454759014


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 9.470447091950742


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 9.399715651554914


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 9.332672939211667


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 9.237281577299694


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 9.23198158769602


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 9.130025262471612


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 9.096855022994736


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 9.087322753268616


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 9.082844374261766


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 8.994877233877387


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 8.974480499774131


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 8.948566728979609


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 8.893445151745059


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 8.832228915714982


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 8.839963278817658


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 8.79794968494511


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 8.752352625936288


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 8.73950078558769


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 8.743541602544056


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 8.698499081033429


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 8.690679690694838


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 8.652287548089625


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 8.66921362760807


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 8.60830184908752


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 8.664731771983444


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 8.5362904544757


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 8.591526385236772


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 8.61417795409001


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 8.544388694763184


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 8.564566992957994


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.460630818102064


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.503454501319224


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.481819154564974


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.468043812911178


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.472681724857159


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.461293444380763


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.473783171247728


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.501195850994398


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.47394768213351


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.474414784204287


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.476171548962245


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.48020733772771


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.454792475836458


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.462630836773418


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.503881713842748


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.465060548399101


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.462002953633224


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.469794134974272


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.476121491236123


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.46591802315998


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.46990736403157


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.450773200311145


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.485254965078434


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.468566624369814


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.467669331634468


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.443978694044857


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.453343051306092


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.492127333326124


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.461069763485748


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.449591478671039


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.453115632341015


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.443732581741253


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.487265942417471


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.48355959800627


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.494327526603525


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.454270846389655


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.459096889951303


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.455337120327203


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.485909644200763


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.47135902238156


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.492694629620312


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.470659196117136


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.448026193197757


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.474550079518998


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.47949389433819


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.481092108858139


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.450908734305486


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.469807086849046


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.461349101713665


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.451003177345111


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.479964749631177


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.47072678395462


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 8.436925739428299


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 8.482063874081948


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 8.449339343478114


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 8.450128497755243


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 8.472729556328623


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.492772505443043


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 8.463049440045099


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 8.470466766835091


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 8.455725213930089


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 8.47619904125438


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 8.4933220678495


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 8.449658731206183


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 8.484685254805159


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 8.452105662568613


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 8.455255338405548


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 8.480182033276655


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 8.449173347881942


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 8.465969438458433


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 8.4747917878357


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 8.476215443175132
fold 0: mean position error 8.474736752096453
Fold 1


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,61.30841
Loss/xy,61.30841
Loss/floor,4.91672
MPE/val,8.47474
epoch,199.0
trainer/global_step,80799.0
_runtime,1228.0
_timestamp,1617728364.0
_step,199.0


0,1
Loss/val,█▇▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,███▇▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
MPE/val,█▇▆▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 21.4 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
21.4 M    Trainable params
0         Non-trainable params
21.4 M    Total params
85.696    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 156.56138610839844


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 160.81987835385874


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 157.41902730996205


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 154.09503828845195


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 150.82351606943334


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 147.5990671211825


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 144.4204911245214


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 141.29021629497905


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 138.20807123316396


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 135.173680366956


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 132.18965814722648


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 129.25870361350343


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 126.38471343576457


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 123.57004376135389


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 120.81909493594651


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 118.13722286879772


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 115.53106395154896


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 113.0021038052754


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 110.5517716547189


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 107.07593332804977


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 103.35735501897564


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 99.10915499332275


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 95.77392567282995


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 92.7269194516208


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 89.91175625353964


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 87.24171480850904


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 84.71698532115602


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 82.3307663224253


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 80.03357783455868


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 77.35666445606034


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 74.52944455297312


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 72.23485652785837


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 70.10450846406007


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 67.9387395599934


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 65.88439204740885


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 64.04004579116386


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 62.36843573772984


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 60.8299914381486


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 59.432159434692345


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 58.15025793103333


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 56.942802793406166


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 55.77621769686806


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 54.39341460372934


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 53.02683596967054


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 51.76470877294702


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 50.51189425162402


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 49.26367670535763


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 48.07192693880289


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 46.87614397900981


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 45.67552695103836


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 44.45980549702897


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 43.227366073610625


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 42.010126353815394


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 40.788757013199834


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 39.611873276321155


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 38.44279357022578


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 37.25393115925109


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 36.12817942514347


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 35.03399650293793


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 33.85783241851842


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 32.79440325522048


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 31.716508064936537


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 30.681565221498783


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 29.70084294241948


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 28.725107297969547


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 27.847287095591433


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 26.943589192325458


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 26.10635067704811


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 25.225291733705657


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 24.471021556321034


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 23.77048961522764


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 22.952037803403293


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 22.167631602656655


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 21.431397990280622


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 20.86135102378121


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 20.154903373696044


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 19.581349556844597


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 19.000597036180626


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 18.48088332543348


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 17.916643717226485


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 17.438514759728015


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 16.97006183565392


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 16.507940324617806


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 16.14296394392579


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 15.713180665197799


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 15.380083619763703


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 14.956077470396172


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 14.763279054449546


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 14.249483097345244


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 13.897615869729826


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 13.588112236893865


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 13.420925691867888


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 13.033445478604849


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 12.767459413054103


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 12.49353154923934


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 12.293029055531523


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 12.050535985816683


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 11.813619429711068


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 11.663051477575552


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 11.447153400127876


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 11.269496280434664


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 11.134154466807876


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 10.930507150771108


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 10.891901456858436


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 10.689573923298699


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 10.5655080240524


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 10.423514443410298


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 10.426556068869198


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 10.259397321100073


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 10.12380194687885


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 10.014052902259449


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.955124703408954


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.874631706431083


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 9.80174708562461


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 9.753315249094799


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 9.640284834246893


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 9.647912048513138


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 9.51634891508899


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 9.484181805212414


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 9.420467652491379


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 9.373196969851547


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 9.296093593501878


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 9.22142712258154


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 9.242479171552976


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 9.135273311283953


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 9.101338383158627


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 9.061563645353411


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 9.072274916509452


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 8.954353167345584


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.942130257879223


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.914447641700384


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.913765301768281


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.89374153160535


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.81751767220161


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.874417195972761


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.755157350263557


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.732157702393218


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.71362205370957


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.682468890532826


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.655838234245465


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.6402152504696


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.574642341738187


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.553237871870998


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.552516667416572


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.53765068856241


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.558211860778933


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.49965628734493


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.473828419655927


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.462325303294872


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.494336972755962


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.457585572568775


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.422219997621513


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.404303422715682


Validating: 0it [00:00, ?it/s]

In [29]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,bssid_y_97,bssid_y_98,bssid_y_99,site_id,x,y,floor,oof_x,oof_y,oof_floor
0,54638,6982,8546,26610,28338,37494,2205,25810,39211,20529,...,8.520601e-01,1.049220,1.050485,0,192.90768,159.26582,-1,196.314224,163.230988,0.036809
1,26610,6982,8546,2205,37494,51325,28338,39211,14273,25810,...,1.045541e+00,0.989748,1.050512,0,192.90768,159.26582,-1,198.174408,163.540314,0.030045
2,8546,6982,51325,26610,20054,8702,30893,15036,39211,37494,...,1.097333e+00,1.114323,1.050342,0,198.36833,163.52063,-1,196.313293,162.206604,0.034472
3,6982,51325,26610,8702,30893,20054,39211,11853,15036,54638,...,1.110497e+00,1.101270,1.102449,0,198.36833,163.52063,-1,196.413803,164.353668,0.030483
4,51325,6982,26610,8702,44925,39211,37494,27631,14273,45777,...,1.041316e+00,1.017684,1.032227,0,198.36833,163.52063,-1,196.816406,161.940338,0.035933
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51620,50465,53733,943,9222,20964,18804,962,35065,52136,14545,...,-5.102916e-16,0.000000,0.309021,23,132.28098,130.23691,6,121.991768,138.803467,0.164411
51621,9222,18804,20964,50465,53733,943,14545,35065,52136,962,...,-5.102916e-16,0.000000,-0.224974,23,122.73780,138.97691,6,123.435173,137.288727,0.159062
51622,53733,9222,20964,50465,18804,14545,13391,962,943,35065,...,-2.063195e-01,1.299566,0.687591,23,122.73780,138.97691,6,122.310532,140.577744,0.162617
51623,50465,9222,53733,20964,943,14545,18804,962,29009,35065,...,1.782349e-01,0.000000,-0.417125,23,122.73780,138.97691,6,121.966888,137.833160,0.161954


In [30]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,88.836189,104.709816
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.591652,102.256493
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,84.202522,104.898315
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.837555,107.035484
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.760399,107.360710
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,215.667282,92.139450
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,212.510925,99.947670
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,209.150009,106.003639
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,202.931107,114.472313


In [31]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,88.836189,104.709816
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.591652,102.256493
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,84.202522,104.898315
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.837555,107.035484
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,88.760399,107.360710
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,215.667282,92.139450
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,212.510925,99.947670
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,209.150009,106.003639
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,202.931107,114.472313


In [32]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [33]:
print(f"CV:{np.mean(val_scores)}")

CV:8.215602696222664


In [34]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,58.80938
Loss/xy,58.80938
Loss/floor,4.76888
MPE/val,8.42835
epoch,199.0
trainer/global_step,80799.0
_runtime,1260.0
_timestamp,1617733414.0
_step,199.0


0,1
Loss/val,█▇▆▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▆▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▇▇▇▇▁▂▃▃▃▃▄▅▅▅▆▆▆▆▇▇▇▇▇█████████████████
MPE/val,█▇▆▆▅▄▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███




VBox(children=(Label(value=' 0.00MB of 0.23MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00269669146…

0,1
CV_score,8.2156
_runtime,2.0
_timestamp,1617768720.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
