# LSTM baseline

from kuto

In [1]:
import os
import sys
import glob
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path
from tqdm import tqdm

sys.path.append('../../')
import src.utils as utils
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger


In [2]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'indoorunifiedwifids'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [3]:
configs = {
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.001,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [4]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 200
N_SPLITS = 5
DEBUG = False
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 14
IS_SAVE = True

utils.set_seed(SEED)

In [5]:
# training target features
NUM_FEATS = 100
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]
POSX_FEATS = [f'posx_{i}' for i in range(NUM_FEATS)]

In [6]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [7]:
train_df = pd.read_csv(WIFI_DIR / 'train_all.csv')
test_df = pd.read_csv(WIFI_DIR / 'test_all.csv')

In [8]:
train_df

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,rssi_95,rssi_96,rssi_97,rssi_98,rssi_99,x,y,floor,path,site_id
0,db01605eac3f33540038bd9722aba25774871d43,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,0b64e537cc3d1818ec46f94f8dc14043a98d0089,922e582c66016a2b9f64e38f89ebe82f66eefb24,dc4c46287575c45f3e32c022d868d047b485ed4c,93e20595eeef175d3aa3c3381f6a22ee792d48d9,b2b0ddbb5a2aadfc6ab2f388db584b6c280d3f82,8c936564ea4b4300576f53136505527eb5972c07,61c3aaf1a526f808c05952ea3f098e37354a674a,3f564032c7eebc173b38aee35225e323d4389faf,...,-79,-79,-79,-79,-79,107.85044,161.892620,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
1,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,db01605eac3f33540038bd9722aba25774871d43,1f37bbb3f42125f665b83584d0376b21ec3eb43c,922e582c66016a2b9f64e38f89ebe82f66eefb24,dc4c46287575c45f3e32c022d868d047b485ed4c,93e20595eeef175d3aa3c3381f6a22ee792d48d9,5c10b343d767a30515e6015de25751a2883328f8,3f564032c7eebc173b38aee35225e323d4389faf,46c934893439700099d03a6892ea934ecb2729d6,16374260af7d03b10f167358a4f6a70620e131f4,...,-79,-79,-79,-80,-80,107.85044,161.892620,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
2,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,db01605eac3f33540038bd9722aba25774871d43,dc4c46287575c45f3e32c022d868d047b485ed4c,922e582c66016a2b9f64e38f89ebe82f66eefb24,93e20595eeef175d3aa3c3381f6a22ee792d48d9,61c3aaf1a526f808c05952ea3f098e37354a674a,ce28608c3d091ac0d25d84459ebad253edf83e1f,1bb0e992cff45a54d29e97f47a7d1281435a5e3b,1f37bbb3f42125f665b83584d0376b21ec3eb43c,ca86c5b074c5768e481e069b751bf22c6d95bd48,...,-77,-78,-78,-78,-78,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
3,61c3aaf1a526f808c05952ea3f098e37354a674a,922e582c66016a2b9f64e38f89ebe82f66eefb24,93e20595eeef175d3aa3c3381f6a22ee792d48d9,db01605eac3f33540038bd9722aba25774871d43,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,0f5daed11a61e0d6941a1a42ff428ca216d61003,ce28608c3d091ac0d25d84459ebad253edf83e1f,40d99a3e5214aa704f637b7d72631e69550ee256,2aa08d092d0199c06d22684642ef1c79d9722adb,149c09a117b9851201c75f97b4a7cc94b75fdcb4,...,-75,-76,-76,-77,-77,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
4,965f254a2e8d05bbb40bd2413ff61de3ad6c4151,93e20595eeef175d3aa3c3381f6a22ee792d48d9,61c3aaf1a526f808c05952ea3f098e37354a674a,51782c2fabefa97e99dca895fd36f1a47e214610,db01605eac3f33540038bd9722aba25774871d43,0f5daed11a61e0d6941a1a42ff428ca216d61003,ce28608c3d091ac0d25d84459ebad253edf83e1f,4c83a7a1e51bfa8a5fa20e854ab3feec057c52c9,599fa96d549ed870671d6bc1927aaa8bbaacca12,dc9fd0f591e9bfc22748106f31d72a23c1d294fd,...,-75,-76,-76,-77,-77,98.33065,163.343340,-1,5e1580adf4c3420006d520d4,5a0546857ecc773753327266
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,930dfcc059cd5c29f94b37435b31b14adf6141c1,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,5953d0b2247e16447d327eb2a8a9c1abe24ff425,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,040bfca8631c5506cba1c22cf7750636ba26ea54,7fcc50af42706af2d8ecfe55b7ccbb731aae6e27,b719d7de6ffe1b1b7409d2eb5ac3268e36cf2675,89395d0ee75307b3beb30aef2f19fc680095d514,...,-84,-85,-85,-85,-85,122.68994,124.028015,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f
258121,930dfcc059cd5c29f94b37435b31b14adf6141c1,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,5953d0b2247e16447d327eb2a8a9c1abe24ff425,040bfca8631c5506cba1c22cf7750636ba26ea54,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,7e5f2b71ff184401c1c31ffeab2861d86f1f9c25,8bed40c61bc42247c3b142f214f78dcc5850553f,f9100b425367a18de9f1cec34a34ea60cc7bf009,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,...,-85,-85,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f
258122,040bfca8631c5506cba1c22cf7750636ba26ea54,930dfcc059cd5c29f94b37435b31b14adf6141c1,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,5953d0b2247e16447d327eb2a8a9c1abe24ff425,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,7e5f2b71ff184401c1c31ffeab2861d86f1f9c25,ca69ae425b53d4c2fae3d97ec4ec61897a4a6b73,f9100b425367a18de9f1cec34a34ea60cc7bf009,...,-84,-84,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f
258123,3c79d9efcdb4c5803ab1b97982c2f0b87519c477,4472ce0cc1af0641cdffd6cdc850e9dd2ec4ab91,930dfcc059cd5c29f94b37435b31b14adf6141c1,5953d0b2247e16447d327eb2a8a9c1abe24ff425,d2aa98dcefa4d46c4a946d26c5311bddddad9d6c,040bfca8631c5506cba1c22cf7750636ba26ea54,7e5f2b71ff184401c1c31ffeab2861d86f1f9c25,ca69ae425b53d4c2fae3d97ec4ec61897a4a6b73,f5c69326982eb7d74f899b933c4cb5a3d1592af2,5f583dcccc43b5b7ac25d270e29c92d878fb2be0,...,-85,-85,-85,-85,-85,127.17589,123.677780,6,5dcd5c88a4dbe7000630b084,5dc8cea7659e181adb076a3f


In [10]:
# use_sites = train_df['site_id'].unique()
# use_sites

# train_pos_x = []
# test_pos_x = []

# for site in use_sites:

#     bssid_csv = pd.read_csv(str(DATA_DIR) + f'/bssid_position/nb013_bssid_position/nb013_bssid_position_{site}.csv')
#     bssid_tgts = bssid_csv['bssid'].unique()

#     bssid_uni = []
#     for tgt in tqdm(bssid_tgts):
#         bssid_rssi = bssid_csv[bssid_csv['bssid'] == tgt]['n_samples_rssi_over_m50']
#         bssid_series = bssid_csv.iloc[bssid_rssi.index[bssid_rssi.values.argmax()]]
#         bssid_uni.append(bssid_series)
#     bssid_uni_df = pd.DataFrame(bssid_uni)
    
#     bssid_x_values = bssid_uni_df['bssid_x'].values
#     bssid_x_values = np.append(bssid_x_values, -100)
    
#     train_df_onesite = train_df[train_df['site_id'] == site]
#     for i in tqdm(range(len(train_df_onesite))):
#         order = []
#         for ordered_bssid in train_df_onesite[BSSID_FEATS].iloc[i].values:
#             itti_ind = np.where(bssid_uni_df['bssid'] == ordered_bssid)[0]
#             if len(itti_ind) > 0:
#                 order.append(itti_ind[0])
#             else:
#                 order.append(-1)
#         order = np.array(order).reshape(-1)
#         train_pos_x.append(bssid_x_values[order])
        
#     test_df_onesite = test_df[test_df['site_id'] == site]
#     for i in tqdm(range(len(test_df_onesite))):
#         order = []
#         for ordered_bssid in test_df_onesite[BSSID_FEATS].iloc[i].values:
#             itti_ind = np.where(bssid_uni_df['bssid'] == ordered_bssid)[0]
#             if len(itti_ind) > 0:
#                 order.append(itti_ind[0])
#             else:
#                 order.append(-1)
#         order = np.array(order).reshape(-1)
#         test_pos_x.append(bssid_x_values[order])

# train_pos_x = np.array(train_pos_x)
# test_pos_x = np.array(test_pos_x)

# np.save(DATA_DIR / 'train_pos_x.npy', train_pos_x)
# np.save(DATA_DIR / 'test_pos_x.npy', test_pos_x)

In [11]:
train_pos_x = np.load(DATA_DIR / 'train_pos_x.npy')
test_pos_x = np.load(DATA_DIR / 'test_pos_x.npy')

In [12]:
for i, retu in enumerate(POSX_FEATS):
    train_df[retu] = train_pos_x[:, i]
for i, retu in enumerate(POSX_FEATS):
    test_df[retu] = test_pos_x[:, i]

In [13]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [14]:
train_df.iloc[:, 100:110]

Unnamed: 0,rssi_0,rssi_1,rssi_2,rssi_3,rssi_4,rssi_5,rssi_6,rssi_7,rssi_8,rssi_9
0,-32,-39,-47,-48,-48,-49,-51,-52,-54,-56
1,-29,-34,-47,-48,-48,-49,-52,-52,-52,-53
2,-33,-39,-48,-48,-49,-52,-54,-55,-55,-55
3,-46,-48,-49,-50,-51,-52,-54,-56,-57,-57
4,-42,-49,-51,-51,-52,-53,-54,-55,-55,-55
...,...,...,...,...,...,...,...,...,...,...
258120,-53,-63,-64,-66,-68,-68,-68,-68,-70,-71
258121,-58,-64,-66,-67,-68,-68,-69,-70,-71,-71
258122,-57,-58,-60,-64,-66,-67,-68,-69,-71,-73
258123,-58,-64,-66,-66,-68,-69,-69,-71,-71,-72


bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [15]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in range(100):
    wifi_bssids.extend(train_df.iloc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in range(100):
    wifi_bssids_test.extend(test_df.iloc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 61206
BSSID TYPES(test): 33042
BSSID TYPES(all): 94248


## preprocessing

In [16]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])

ss_posx = StandardScaler()
ss_posx.fit(train_df.loc[:,POSX_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])
    output_df.loc[:,POSX_FEATS] = ss_posx.transform(input_df.loc[:,POSX_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？

    # site_idのLE
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
#     output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

train  

  return self.partial_fit(X, y)


Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,posx_90,posx_91,posx_92,posx_93,posx_94,posx_95,posx_96,posx_97,posx_98,posx_99
0,52392,35870,2764,34897,52709,35259,42719,33509,23416,15248,...,0.162567,0.370890,-0.294730,0.171690,0.178635,0.182635,-0.035823,0.312678,0.195191,0.273622
1,35870,52392,7486,34897,52709,35259,21970,15248,17024,5350,...,0.497911,0.298596,-0.048805,-0.075937,0.178635,0.379002,-0.279804,0.185843,0.127319,0.346746
2,35870,52392,52709,34897,35259,23416,49407,6672,7486,48500,...,0.316117,0.172696,-0.074634,-0.291525,0.104827,0.508513,0.310220,0.180442,0.190153,-0.045920
3,23416,34897,35259,52392,35870,3706,49407,15612,10166,4977,...,0.162356,0.166995,0.168673,0.171118,0.172957,0.588854,0.179941,0.201561,0.192441,-0.041908
4,35870,35259,23416,19472,52392,3706,49407,18305,21409,52794,...,0.530552,0.166996,0.583522,0.504496,0.178635,0.528698,0.310220,0.334097,0.318924,0.199529
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,35065,14545,16494,21326,50465,22830,943,30429,43802,32684,...,1.291378,0.792870,0.040202,0.841729,0.073349,-2.101580,0.153784,1.344472,0.485800,0.347208
258121,35065,16494,21326,943,50465,22830,30059,33363,59581,14545,...,-2.143823,0.038261,0.068889,0.071024,0.275655,1.004186,0.582306,1.013765,0.485800,1.350053
258122,943,35065,14545,16494,21326,22830,50465,30059,48476,59581,...,0.061612,-2.132787,0.893998,0.042857,0.702699,0.077921,0.614130,0.480002,0.497125,0.170839
258123,14545,16494,35065,21326,50465,943,30059,48476,58803,22830,...,0.695602,1.292888,0.065861,0.895252,0.326178,0.579650,0.300767,1.013765,0.460168,1.350053


In [17]:
site_count = len(train['site_id'].unique())
site_count

24

## PyTorch model
- embedding layerが重要  

In [18]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.posx_feats = df[POSX_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            self.xy = df[['x', 'y']].values.astype(np.float32)
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'POSX_FEATS':self.posx_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [19]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, bssid_size=94248, site_size=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.bssid_embedding = nn.Embedding(bssid_size, 64, max_norm=True)
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す       
        self.site_embedding = nn.Embedding(site_size, 64, max_norm=True)

        # rssi
        # 次元を64倍に線形変換
        self.rssi = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        self.posx = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * 64)
        )
        
        
        concat_size = 64 + (NUM_FEATS * 64) + (NUM_FEATS * 64) + (NUM_FEATS * 64)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(concat_size),
            nn.Dropout(0.3),
            nn.Linear(concat_size, 256),
            nn.ReLU()
        )
        self.bn1 = nn.BatchNorm1d(concat_size)

        self.flatten = nn.Flatten()

        self.dropout1 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(in_features=concat_size, out_features=256)#, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.bssid_embedding(x['BSSID_FEATS'])
        x_bssid = self.flatten(x_bssid)
        
        x_site_id = self.site_embedding(x['site_id'])
        x_site_id = self.flatten(x_site_id)

        x_rssi = self.rssi(x['RSSI_FEATS'])
        
        x_posx = self.posx(x['POSX_FEATS'])

        x = torch.cat([x_bssid, x_site_id, x_rssi, x_posx], dim=1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [20]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [21]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [22]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [23]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [24]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):

    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + POSX_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + POSX_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
    
    # model
    model = LSTMModel(wifi_bssids_size, site_count)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
    wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=3,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    val_df["oof_x"] = oof_x
    val_df["oof_y"] = oof_y
    val_df["oof_floor"] = oof_f
    oofs.append(val_df)
    
    val_score = mean_position_error(
        val_df["oof_x"].values, val_df["oof_y"].values, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############
    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    predictions.append(test_preds)
    

Fold 0


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 17.5 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
69.900    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 133.70531845092773


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 161.64858097411897


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 157.4395309837488


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 153.3587029188597


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 149.36238688349516


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 145.44788429451322


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 141.61147283484092


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 137.852638168646


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 134.17923164532084


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 130.59401952544687


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 127.10826777760902


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 123.72000330521189


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 120.43618042569835


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 117.25999566679718


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 114.19700986507819


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 111.25324139505608


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 108.44190690579912


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 105.76203318725302


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 103.22213678878757


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 100.82609994429521


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 98.57570354659403


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 96.46933822440768


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 92.36312887699435


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 87.89820222142554


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 84.93925472048458


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 82.32092543746957


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 79.88458273703787


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 77.63519924183814


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 75.5474996674526


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 73.61302682895472


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 71.8459815076096


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 70.17304944260007


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 68.68218520935983


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 66.70232466940416


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 65.0153904079155


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 63.17842857538578


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 61.642853773158166


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 60.22075571032409


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 58.90550496607377


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 57.64633506863771


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 56.46373940780541


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 55.34129156327067


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 54.25252571374785


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 53.22541957090952


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 52.23884637009494


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 51.235419894036816


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 50.25521319689832


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 49.25639727604132


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 48.250990415049905


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 47.23365260657976


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 46.18149161234941


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 45.19378610563194


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 44.12802833246088


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 43.138775813758684


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 42.09780241709083


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 41.154692272592854


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 40.10517951423232


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 39.039994136663665


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 38.01750970664938


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 37.034311826169386


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 36.05163732837506


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 35.06913284150681


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 34.16459583561297


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 33.219221713388805


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 32.340651290818016


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 31.404104097585176


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 30.512839926518048


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 29.7477197969911


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 28.869572819091584


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 28.032385288131884


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 27.07435033588265


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 26.091099101906956


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 25.257988495968355


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 24.423973450735687


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 23.64442165526944


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 22.88276088757923


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 22.152280423759237


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 21.480388946533203


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 20.819307015918337


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 20.201577169977096


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 19.559413172922373


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 18.912452800019786


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 18.290059120698107


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 17.696895481725036


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 17.173656982372997


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 16.651152893448543


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 16.13959817531988


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 15.684286418563763


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 15.247402196227636


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 14.781241371041478


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 14.528888775853828


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 14.020515947658515


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 13.672846418301866


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 13.368876113436148


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 13.026512824521205


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 12.743896346317173


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 12.457891901257293


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 12.22010269342777


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 12.007497709137656


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 11.818208102780021


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 11.525159199936278


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 11.365400755953914


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 11.220509149810223


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 10.981847273016227


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 10.804244608348654


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 10.68308564792071


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 10.492373877543658


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 10.347553606338813


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 10.261101211732255


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 10.144938115823665


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 10.035425147788082


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.988028732926331


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.85741204318989


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 9.84252390915388


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 9.710717370517044


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 9.624332004196088


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 9.57162354963077


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 9.502474065857289


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 9.444211943642511


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 9.352620772160135


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 9.363152016976855


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 9.295240069129958


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 9.236228957134552


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 9.157836180327587


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 9.115032027238458


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 9.098672980051369


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 9.024526733960533


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 9.044305368289047


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 9.016397683205268


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.937840900976573


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.884798633728739


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.895123364532417


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.87771017923283


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.801576354471694


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.771252315667875


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.742046344092785


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.732559144870445


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.68560520382072


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.652371079195497


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.67984455705175


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.652180093544198


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.60455905743234


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.591434721782864


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.577984172263392


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.573019699936491


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.551698537581037


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.564932166939915


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.545904411033758


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.552633272747558


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.55286195250674


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.558158279467824


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.540475052348667


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.54131769058103


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.538009289489906


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.528530038479254


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.540680479817288


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.539580896874119


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.548532730218069


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.524998108703311


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.532005009948339


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.540071612676355


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.539280555138562


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.539645545239635


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.542788064209056


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.548947155740834


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.539258636636795


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.547626971420883


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.542655314288309


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.519781595917726


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.54019182511798


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.533100714697419


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.543089487056921


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.539401289796023


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.526380814711715


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.546152520721293


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.524785733642313


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.521591673358348


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.533053305953064


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.544338457999565


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.545237423680728


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.553002446673954


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 8.53838636392761


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 8.539868743932033


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 8.54017448929624


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 8.53878921757622


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 8.543702041457125


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.535992810466501


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 8.528139480153142


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 8.529211360307093


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 8.53174865997039


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 8.542716481337106


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 8.543594426536671


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 8.545967011915506


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 8.53949824360407


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 8.537921181202767


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 8.533444221954257


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 8.542866504982452


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 8.532250671931042


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 8.534623385347734


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 8.537934221024422


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 8.529231771936788
fold 0: mean position error 8.528395007406811
Fold 1


  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,64.97753
Loss/xy,64.97753
Loss/floor,5.32458
MPE/val,8.5284
epoch,199.0
trainer/global_step,80799.0
_runtime,1075.0
_timestamp,1617694406.0
_step,199.0


0,1
Loss/val,█▇▆▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▆▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,████▁▆▅▄▃▃▂▂▂▂▂▂▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
MPE/val,█▇▆▆▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 17.5 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
69.900    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 156.66897583007812


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 159.8349915842935


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 155.61534765884497


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 151.51888358958738


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 147.5096078367683


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 143.58186920014938


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 139.73679683989678


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 135.97296009505843


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 132.2925219736338


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 128.69624699512053


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 125.19416478329781


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 121.78857294437006


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 118.4863903062095


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 115.30176637195194


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 112.2345743738529


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 109.28873191375821


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 106.46576201846034


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 103.76934369962022


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 101.20293888511947


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 98.77196324934428


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 96.48069424490909


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 94.33268396047639


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 92.332227109177


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 90.48578567238276


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 88.79143490581924


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 87.25454056674822


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 85.87638103047429


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 84.65444519927215


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 83.5908473540602


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 74.96012362254103


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 70.92012619871697


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 68.02382655804145


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 65.82154206028774


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 63.321410688534684


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 61.03940312913774


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 59.133015407680006


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 57.440391286184564


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 55.90403906969393


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 54.4872265306405


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 53.14252895399659


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 51.87179872962716


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 50.66887719578151


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 49.44791872258928


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 48.261263862925354


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 47.091057749033396


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 45.85458127462246


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 44.64668493358632


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 43.43163493556899


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 42.129655751187784


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 40.85946308500348


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 39.610818778056355


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 38.35750758931466


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 37.102698306504145


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 35.8978597156989


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 34.647485236921625


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 33.47425817508509


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 32.265961919872304


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 31.11881007037332


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 30.006054443354934


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 28.948898918305154


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 27.912745087312


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 26.927930769278827


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 26.03471327554506


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 25.182149892272683


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 24.215124454531615


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 23.356364968577026


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 22.509372784965038


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 21.831894651711785


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 21.0698901967608


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 20.33664438700912


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 19.619070899543473


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 19.064880617793797


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 18.4136098951563


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 17.81385289922591


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 17.324663796610935


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 16.798085667105283


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 16.28753821516842


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 15.871767163662696


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 15.451141858586894


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 14.933770378015037


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 14.62462976443469


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 14.184786480418735


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 13.785869442101614


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 13.494317190062535


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 13.2318709797712


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 12.86695617716883


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 12.567979304325323


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 12.29070166197362


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 12.040166967699676


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 11.760716951689057


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 11.559675761030105


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 11.339426406633736


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 11.150847066713755


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 10.999842832339803


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 10.808278452951544


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 10.698521053723006


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 10.58092778341952


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 10.493900830863726


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 10.305096108631542


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 10.31937676862429


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 10.109019208873644


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 10.012351237637537


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 9.958440016500465


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 9.804570767519845


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 9.734655302755348


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 9.67952760224295


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 9.569372422457713


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 9.48870086020068


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 9.401267185822155


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 9.357695573530158


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 9.293339775454255


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.23539274041293


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.18098488884883


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 9.097093976353516


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 9.058691979521017


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 9.00989160855426


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 8.966604034177216


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 8.912661828189377


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 8.866170138923943


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 8.857128806505775


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 8.769826995269739


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 8.804972861217216


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 8.69763512560845


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 8.65272877442677


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 8.61959237885739


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 8.61328857101947


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 8.581566691723548


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 8.516599996507898


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 8.521900195686639


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.52179648747608


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.472145055127685


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.460103622529836


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.456692423934358


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.460962375303167


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.4199895121375


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.358053828568812


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.365593157675493


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.343737222719275


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.383275002794482


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.306858163277077


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.307913831900773


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.323299480076091


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.332462256081103


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.260839089219765


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.235246607647708


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.305162129443818


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.267471889113715


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.189216627077649


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.20487151437536


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.209430386604927


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.206386095234732


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.206409891547334


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.178280705300333


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.193542882289453


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.202724852353665


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.131103923844263


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.154009345848337


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.140105589609474


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.151614346079308


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.151015924965844


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.096339073392215


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.086724098507721


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.075589312274024


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.07735929479138


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.079328433577238


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.071628679532122


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.077150932549737


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.066252271733314


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.071738764752325


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.063841485849423


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.060470754537876


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.06214805646351


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.062922069913924


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.062562970343064


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.06603431606126


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.068593250512382


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.070555381430536


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.063068781626663


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.065756510577371


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.075302077611656


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.066776422246223


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 8.067314304625079


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 8.055837150122722


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 8.059480894218716


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 8.065050229043134


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 8.062132165572223


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.069238657243181


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 8.072655525174193


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 8.060831559547243


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 8.055873495278556


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 8.06411669609501


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 8.065272889501076


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 8.068350815881319


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 8.058821863553241


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 8.065363543604304


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 8.072083962184438


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 8.055468388036996


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 8.06996242862006


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 8.06427877212862


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 8.059157269210681


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 8.065410970166873
fold 1: mean position error 8.059990456492113
Fold 2


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,51.79793
Loss/xy,51.79793
Loss/floor,5.25406
MPE/val,8.05999
epoch,199.0
trainer/global_step,80799.0
_runtime,1090.0
_timestamp,1617695505.0
_step,199.0


0,1
Loss/val,█▇▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MPE/val,█▇▆▆▅▅▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 17.5 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
69.900    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 149.43074417114258


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 159.80377971375898


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 155.588648941671


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 151.49655542825778


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 147.49613834938756


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 143.58644465646984


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 139.76450237720735


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 136.028035439173


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 132.37576838160643


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 128.81134201707434


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 125.33686308197039


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 121.95371582395612


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 118.67467253265647


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 115.50886953959302


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 112.46228341917417


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 109.54141496932708


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 106.75107779909999


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 104.09720561770136


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 101.5860821872238


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 99.22287285730877


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 97.00801297382762


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 94.94718234520425


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 93.0377407506766


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 91.27095336985157


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 89.64404911261211


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 88.15395105800728


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 86.8025785185559


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 85.59102671188327


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 84.52223092254124


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 83.59134594787882


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 82.51035457684401


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 71.58133340256812


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 68.95656982875107


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 66.56358129560218


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 64.55146860781421


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 62.73194653381632


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 61.09254339559174


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 59.5050862940731


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 58.05928860499405


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 56.674360923011704


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 55.423362237156034


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 54.2588313384992


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 53.064184687108444


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 51.937758230854755


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 51.04835665914439


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 49.89405589944065


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 48.909128907147576


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 47.974394657910395


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 46.97167737995252


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 45.97548123441329


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 44.97967429341586


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 43.99494882824675


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 42.98134483437379


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 40.12941906259722


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 38.93206446804831


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 37.83465942385034


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 36.848525922916345


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 35.70759937064205


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 34.6697623666109


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 33.7098863221659


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 32.7461497669631


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 31.751281853955227


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 30.78253445672517


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 29.859716974312857


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 28.989829060249573


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 28.167572633057883


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 27.335320772727563


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 26.573014771623114


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 25.73832476930835


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 24.95466746358588


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 24.249225252414764


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 23.510329590493605


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 22.809888137479682


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 22.144235008819614


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 21.490399050193258


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 20.88023152934026


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 20.265443277986645


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 19.72957675190672


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 19.16180900523742


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 18.680754282156535


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 18.124684749030067


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 17.638639695629102


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 17.26039792096955


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 16.699738419809936


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 16.120763396728016


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 15.658557180561294


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 15.224794548970014


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 14.779548336022387


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 14.426748247141052


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 14.054351649731483


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 13.718724070667317


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 13.370190601748787


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 13.107185307037852


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 12.777767783738739


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 12.583278290765815


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 12.227877573980864


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 12.046801469144116


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 11.815451071622


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 11.624789320246453


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 11.37367620578115


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 11.283623376000948


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 11.041066984022226


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 10.852512308708373


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 10.725390933874948


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 10.583382167282949


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 10.426162653208195


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 10.282560637679014


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 10.210769667217187


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 10.137203691212605


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 9.990844355801475


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 9.917872085926765


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.841519065783617


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.73778486028548


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 9.686739576536345


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 9.622631084717632


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 9.494827364986556


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 9.476780215925995


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 9.469744103342277


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 9.355867551460332


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 9.328300640381695


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 9.181877085497439


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 9.139310481560807


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 9.058390428290926


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 9.085751374055288


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 9.008381779292467


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 8.964593272189727


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 8.905798347986362


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 8.876738984536209


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 8.835046445805455


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.794464523235447


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.765065119548392


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.73999211902874


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.6955102123045


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.689803678760295


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.6350798790609


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.617740389357907


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.581770514564315


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.552992878371224


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.53630253571303


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.541298813694914


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.444115268782815


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.437399416975177


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.410963028833903


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.407284258840802


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.355378011049194


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.363808256814698


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.382039780380703


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.348767814214103


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.307898698416851


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.276759946475421


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.2873748142218


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.269644517004316


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.268692164812661


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.24181151736109


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.256080868153914


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.20013570673881


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.231168122044398


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.21871599863349


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.171518488369088


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.172759298958887


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.193014043409184


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.161853316428706


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.166109841268424


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.143959067205252


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.096924364815358


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.117499497601926


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.119735920238995


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.113314098487015


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.106194346460688


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.102784109604379


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.072119839312242


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.076559157352081


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.056219582332728


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.039784211395922


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.01795378085688


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.019010341990876


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.034711818761721


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 7.980755535979984


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.015903530965105


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.070492248129693


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.000433825713364


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 7.977646172106371


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 7.980490575078344


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 7.9803719080667825


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 7.982546418733769


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 7.996617128131136


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.01023755785054


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 7.954464622013224


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 7.9475625920837265


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 7.970287974429255


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 7.9433887037789335


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 7.910532712542046


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 7.94749267909717


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 7.927639436843997


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 7.938416831094856


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 7.915596737236775


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 7.899971465942258


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 7.903140149658061


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 7.919827391625007


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 7.899620616337415


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 7.862000670652439
fold 2: mean position error 7.859496951847164
Fold 3


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,48.34843
Loss/xy,48.34843
Loss/floor,4.65816
MPE/val,7.8595
epoch,199.0
trainer/global_step,80799.0
_runtime,1098.0
_timestamp,1617696612.0
_step,199.0


0,1
Loss/val,█▇▅▅▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▅▅▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▁▁▁▁▁▁▂▄▄▄▅▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇▇█████████████
MPE/val,█▇▆▆▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 17.5 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
69.900    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 141.7290267944336


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 165.83690360929637


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 161.2312629245587


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 156.77162758126812


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 152.41373581990158


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 148.1555254470212


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 144.00163698688525


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 139.95543750227506


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 136.02117514350905


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 132.19520784879052


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 128.48094417680886


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 124.88804149812088


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 121.42847525794907


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 118.1056814658732


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 114.9161535660972


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 111.86349652942975


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 108.9481111242331


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 106.1793106869146


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 103.56475715468355


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 101.10756214175171


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 98.81699413115157


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 96.68566418082892


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 92.2179966092429


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 86.815661397233


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 83.2967522025622


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 80.41069171285658


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 77.39590464281494


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 74.70203374481646


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 72.10941949647739


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 69.60415370134677


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 67.37654925693519


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 65.32015435949202


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 63.45126015669256


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 61.64492552146845


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 59.996601173476414


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 58.47634368596552


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 57.04081569844925


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 55.67566162264895


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 54.38593744825609


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 53.031576417779675


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 51.731291615281194


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 50.482496427314956


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 49.20971051162249


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 47.868697040693384


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 46.60161398859307


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 45.378168585815615


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 44.11014158034505


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 42.967989293422264


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 41.66865401830657


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 40.38718022161538


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 39.17042363285116


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 37.98156931130082


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 36.74789431941322


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 35.53608648946692


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 34.41391165760553


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 33.27222139227439


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 32.24629713884952


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 31.18890170997704


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 30.18267512103188


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 29.190695983118005


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 28.220368338647667


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 27.267207671372642


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 26.365590702494565


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 25.474433478853626


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 24.66793462264517


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 24.086935959428843


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 23.115103297602943


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 22.364752151275972


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 21.67324892929424


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 21.009799431060454


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 20.400839992483203


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 19.82666443606484


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 19.226413949867563


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 18.741689288650896


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 18.217237422578755


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 17.697837433190145


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 17.2333673363865


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 16.78217245666246


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 16.449508602828846


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 15.858297722649283


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 15.410311812143654


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 14.995661580947228


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 14.64491103050663


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 14.294565174800404


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 13.957824899499057


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 13.64160375336262


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 13.360615106303815


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 13.080046716122414


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 12.862798536252892


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 12.577947763821797


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 12.390345028070216


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 12.159935976327265


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 11.996258793021608


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 11.838681057978317


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 11.614338341236948


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 11.4801626457832


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 11.349208490263951


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 11.200040814727979


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 11.085634010672639


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 10.923569484415204


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 10.847208233235113


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 10.708114434832671


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 10.623659147308118


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 10.50909087660428


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 10.448851670613452


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 10.344478060189694


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 10.235272077739273


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 10.178736741982995


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 10.09345831768278


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 9.985757954852415


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 9.962405433010487


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.881070094200227


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.80456653143961


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 9.720090982529335


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 9.646474628237423


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 9.634801568779476


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 9.581628490363572


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 9.470463604679896


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 9.422163056655199


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 9.395836822496435


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 9.358017920581283


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 9.282858020113732


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 9.248702476281515


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 9.220950328058189


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 9.208851856393876


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 9.143019687006067


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 9.14160966127781


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 9.120634368358827


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 9.063460245460151


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 9.00988593573062


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 9.0081931498888


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.984324136755115


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.950194171156797


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.952518754646954


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.925543029831811


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.900966544942795


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.90042658362058


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.844080781109055


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.858351077400258


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.833480144338546


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.825178223289441


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.832415193203374


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.75087867386895


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.75476050988421


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.719129854223377


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.755136717605813


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.72688476650258


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.704159113495727


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.685695450049051


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.657528072805965


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.659042516396221


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.652277750427388


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.58079999131662


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.637369565458304


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.608744104359825


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.55763096888814


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.566410503098627


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.560390998649263


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.542931806567664


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.533746252021057


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.512599471852608


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.525461547109506


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.510453160907476


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.48810670421647


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.515293282068393


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.476174810161824


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.519607755064479


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.489769744073227


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.461201696645817


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.441923743610682


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.43921223284411


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.4362430933984


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.428632480119228


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.427811940448118


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.430247807433318


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.435621420338741


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.424233034697119


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.423002863080264


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.421066659028682


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.417228499597496


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.419934100504872


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 8.415072453231678


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 8.414121622942046


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 8.40976543487611


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 8.414997037188842


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 8.42185493000455


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.399263030603727


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 8.411536768167325


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 8.411637366177665


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 8.404406503367438


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 8.40588233978858


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 8.404071684744585


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 8.403831070315094


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 8.399054053022756


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 8.40379979617386


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 8.392957768537446


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 8.395517989699064


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 8.398557550966844


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 8.398405165919469


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 8.402724176094708


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 8.39697988437926
fold 3: mean position error 8.394969491496205
Fold 4


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,62.20581
Loss/xy,62.20581
Loss/floor,4.91806
MPE/val,8.39497
epoch,199.0
trainer/global_step,80799.0
_runtime,1097.0
_timestamp,1617697718.0
_step,199.0


0,1
Loss/val,█▇▅▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▅▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,▄▄▄▄▄▃▁▁▂▃▃▂▂▃▃▅▅▆▆▆▆▇▇▇▇▇██████████████
MPE/val,█▇▆▆▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | model        | LSTMModel | 17.5 M
1 | xy_criterion | MSELoss   | 0     
2 | f_criterion  | MSELoss   | 0     
-------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
69.900    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 151.0456428527832


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 162.31380747471292


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 158.11779787706786


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 154.04032396222104


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 150.04876763514326


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 146.14195613027937


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 142.31766736794296


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 138.57624049538293


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 134.9222717910349


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 131.35732702345229


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 127.88442897923268


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 124.5036979894799


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 121.22547376731556


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 118.05697311934581


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 115.00532010024553


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 112.06901547724658


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 109.25332188396864


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 106.55632130084652


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 103.98296241573628


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 101.54344163120383


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 99.2367699402602


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 97.07013731519359


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 95.04776478935692


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 87.22410605011254


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 82.40886319146576


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 78.72687359551601


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 75.63007741458206


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 72.87443744359491


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 70.3783361355621


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 68.13334124725644


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 66.04164431034859


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 64.05142880448089


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 62.32083567339942


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 60.658659483432494


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 59.076386221528544


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 57.68399039726724


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 56.369186320763106


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 55.0954907970173


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 53.91616907303044


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 52.73423478928012


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 51.51988033230248


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 50.2480168445345


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 48.95694667256288


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 47.60402902986396


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 46.28484968584224


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 44.96578032620785


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 43.67630559273633


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 42.34977715336172


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 40.99581274133681


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 39.69071741971709


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 38.43266069128962


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 37.116576570345345


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 35.86701361791158


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 34.67233440392505


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 33.4811868309572


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 32.31504047148299


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 31.175800613898765


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 30.137766332487484


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 29.072021359803028


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 28.082695067344048


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 27.097627461643917


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 26.138654901630044


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 25.296325618752785


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 24.460393059796626


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 23.602841915015084


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 22.788875400663187


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 22.0133196797868


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 21.27390066144629


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 20.552281425955968


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 19.862433889196858


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 19.224552993841066


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 18.586991235314237


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 18.012563744424455


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 17.45934775044214


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 16.89591256790961


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 16.436518905286395


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 15.950155440382161


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 15.53416651673003


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 15.066825086452548


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 14.675740660965687


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 14.241021566985301


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 13.859777561747618


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 13.519871260432497


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 13.206491229457779


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 12.875550923255828


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 12.628895302552015


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 12.347788312019965


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 12.138538647386733


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 11.908585400056477


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 11.664383004187982


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 11.485939018658309


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 11.291693980595229


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 11.15714755098835


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 11.02438838327771


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 10.819265246038597


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 10.724268268314155


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 10.596373351890817


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 10.500292185915022


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 10.312759837337754


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 10.22659679716107


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 10.13116279640931


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 10.064951870773307


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 9.956417092005042


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 9.828280976817021


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 9.77074025028031


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 9.713684997636353


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 9.667333651783165


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 9.616247887253136


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 9.53481402943144


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 9.468005859694928


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 9.431441251058251


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 9.359067134790527


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 9.348750426115238


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 9.26984394724451


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 9.247257963934816


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 9.165600256222795


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 9.109902337773567


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 9.092837586433442


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 9.053233551331989


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 9.01212671417654


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 8.972025189452486


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 8.915599312801728


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 8.879557061870221


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 8.818044986969243


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 8.850165332800856


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 8.761579172937552


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 8.782986143053307


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 8.75922534942627


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 8.762596677659069


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 8.7261849426154


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 8.700792447713898


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 8.700064255431704


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 8.644967059578114


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 8.637668053844187


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 8.621028277830947


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 8.645187453412705


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 8.632192999762578


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 8.573353096346002


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 8.579332316770204


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 8.548764125166345


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 8.565351631040135


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 8.512412982078645


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 8.536949562468221


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 8.495729999009024


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 8.471289025963776


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 8.466306171494995


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 8.515968806811662


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 8.489660541143401


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 8.472215598689031


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 8.446324259691899


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 8.405930835188443


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 8.399808132791492


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 8.3993391826032


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 8.39473988560236


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 8.386135605016142


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 8.381778676550125


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 8.384873956225956


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 8.379807276906284


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 8.37239979264621


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 8.368625195702634


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 8.37595070477343


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 8.377045181332178


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 8.372910448539791


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 8.37001141725339


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 8.370727286130519


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 8.36918993395015


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 8.367962375767508


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 8.376508887562562


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 8.370332744000745


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 8.379120609136814


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 8.378429270573807


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 8.37527531534416


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 8.370230574943884


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 8.369256751583146


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 8.37186341238494


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 8.379850434118325


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 8.367490051338406


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 8.365938634022628


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 8.369689803059599


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 8.378760485851842


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 8.36753116983137


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 8.370765812340071


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 8.371129707998499


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 8.36008649511676


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 8.369664055630434


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 8.374741634497756


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 8.363782680924714


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 8.370954462261345


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 8.372105916689215


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 8.374787914562726


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 8.37126133372774


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 8.365395818209329


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 8.37472704964595


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 8.372808420362345


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 8.37102428327415


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 8.376327565392575


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 8.369209948269239


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 8.3731480985807


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 8.375674145997925


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 8.375019073130852


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 8.368026091142942
fold 4: mean position error 8.362432512686086


In [25]:
if len(oofs) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,posx_97,posx_98,posx_99,site_id,x,y,floor,oof_x,oof_y,oof_floor
0,54638,6982,8546,26610,28338,37494,2205,25810,39211,20529,...,1.272323,0.834709,0.838465,0,192.90768,159.26582,-1,194.452713,162.153702,0.0
1,26610,6982,8546,2205,37494,51325,28338,39211,14273,25810,...,0.828953,1.090478,0.838144,0,192.90768,159.26582,-1,197.646942,161.264389,0.0
2,8546,6982,51325,26610,20054,8702,30893,15036,39211,37494,...,0.484567,0.474116,0.838968,0,198.36833,163.52063,-1,198.924484,160.096298,0.0
3,6982,51325,26610,8702,30893,20054,39211,11853,15036,54638,...,0.468287,0.489564,0.496164,0,198.36833,163.52063,-1,197.777466,163.455688,0.0
4,51325,6982,26610,8702,44925,39211,37494,27631,14273,45777,...,1.048296,1.080331,1.025532,0,198.36833,163.52063,-1,197.561127,161.933960,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51620,50465,53733,943,9222,20964,18804,962,35065,52136,14545,...,-2.083513,-2.070738,0.716463,23,132.28098,130.23691,6,121.046852,142.936386,0.0
51621,9222,18804,20964,50465,53733,943,14545,35065,52136,962,...,-2.083513,-2.070738,0.015276,23,122.73780,138.97691,6,119.349892,138.471207,0.0
51622,53733,9222,20964,50465,18804,14545,13391,962,943,35065,...,1.192445,1.123033,0.855399,23,122.73780,138.97691,6,120.011894,142.899796,0.0
51623,50465,9222,53733,20964,943,14545,18804,962,29009,35065,...,0.996133,-2.070738,-0.376593,23,122.73780,138.97691,6,122.750595,140.380402,0.0


In [26]:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)

all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,88.722160,105.782982
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,83.814423,101.425003
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,83.859917,105.824104
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.263161,106.920578
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,87.503738,108.331017
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,215.957413,91.799461
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,211.413361,98.566650
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,209.161575,106.988518
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,203.829727,112.459106


In [27]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds['floor'] = simple_accurate_99['floor'].values
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,88.722160,105.782982
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,83.814423,101.425003
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,83.859917,105.824104
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.263161,106.920578
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,87.503738,108.331017
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,215.957413,91.799461
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,211.413361,98.566650
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,209.161575,106.988518
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,203.829727,112.459106


In [28]:
all_preds.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [29]:
print(f"CV:{np.mean(val_scores)}")

CV:8.241056883985674


In [30]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': np.mean(val_scores)})
wandb.save(utils.get_notebook_path())
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Loss/val,58.07487
Loss/xy,58.07487
Loss/floor,4.9106
MPE/val,8.36243
epoch,199.0
trainer/global_step,80799.0
_runtime,1102.0
_timestamp,1617698829.0
_step,199.0


0,1
Loss/val,█▇▆▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/xy,█▇▆▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/floor,██████▇▆▅▄▃▃▂▁▁▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
MPE/val,█▇▆▆▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




VBox(children=(Label(value=' 0.00MB of 0.53MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00120007976…

0,1
CV_score,8.24106
_runtime,2.0
_timestamp,1617698891.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁
