# 時系列モデル

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [2]:
import os
import sys
import glob
import pickle
import random
import gc
from collections import defaultdict


In [3]:
import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path


In [4]:
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [5]:
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger

In [6]:
sys.path.append('../../')
import src.utils as utils

In [7]:
import multiprocessing
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm

from indoor_location_competition_20.io_f import read_data_file
import indoor_location_competition_20.compute_f as compute_f

In [8]:
DATA_DIR = Path("/home/knikaido/work/Indoor-Location-Navigation/data/")
WIFI_DIR = DATA_DIR / 'time-sequence-unified-wifi'
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'
OUTPUT_DIR = Path('./output/')
MLFLOW_DIR = DATA_DIR / 'mlflow/mlruns'

## config

In [9]:
configs = {
    'memory_length': 5,
    'loss':{
        'name': 'MSELoss',
        'params':{}
    },
    'optimizer':{
        'name': 'Adam',
        'params':{
            'lr': 0.01,
        }
    },

    'scheduler':{
        'name': 'ReduceLROnPlateau',
        'params':{
            'factor': 0.1,
            'patience': 3,
        }
    },

    'loader':{
        'train':{
            'batch_size': 512,
            'shuffle': True,
            'num_workers': 4,
        },
        'valid':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        },
        'test':{
            'batch_size': 512,
            'shuffle': False,
            'num_workers': 4,
        }
    }
}

In [10]:
# config
config = configs

# globals variable
SEED = 777
MAX_EPOCHS = 500
N_SPLITS = 5
DEBUG = False
MEMORY_LENGTH = config['memory_length']
# EXP_MESSAGE = config['globals']['exp_message']

EXP_NAME = 52
IS_SAVE = True

utils.set_seed(SEED)

In [11]:
!wandb login e8aaf98060af90035c3c28a83b34452780aeec20

/bin/sh: 1: wandb: not found


## read data

In [12]:
train_df = pd.read_csv(WIFI_DIR / 'train_all.csv', index_col=0)
test_df = pd.read_csv(WIFI_DIR / 'test_all.csv', index_col=0)

In [13]:
train_df.sort_values(['path_id', 't1_wifi'])

Unnamed: 0,building,path_id,t1_wifi,x,y,floor,magn,wb_0,wb_1,wb_2,...,wr_90,wr_91,wr_92,wr_93,wr_94,wr_95,wr_96,wr_97,wr_98,wr_99
15069,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,1951,197.323424,93.030002,2,32.320954,ed0c8923d8363c012ae93551a6d621ef1b47bf49,d6b5a7c24de18ba1deedbbbc8d918389c363a290,273fc15a48ff4dda348f2fea3f5b8412920dbcd0,...,-89.0,-89.0,-89.0,-90.0,-91.0,-92.0,-99.0,-99.0,-99.0,-99.0
15070,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,3694,195.481345,91.784404,2,33.694665,273fc15a48ff4dda348f2fea3f5b8412920dbcd0,79060c71437fada756dd90db204b6c0cb47fa2ec,d6b5a7c24de18ba1deedbbbc8d918389c363a290,...,-87.0,-87.0,-88.0,-88.0,-89.0,-89.0,-89.0,-91.0,-92.0,-92.0
15071,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,5952,192.408843,90.562736,2,43.278788,deb57e9620d57a75ce41d59aa415a979b08d6380,273fc15a48ff4dda348f2fea3f5b8412920dbcd0,79060c71437fada756dd90db204b6c0cb47fa2ec,...,-87.0,-87.0,-87.0,-87.0,-87.0,-87.0,-88.0,-88.0,-88.0,-88.0
15072,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,7616,189.980213,90.076632,2,33.129354,deb57e9620d57a75ce41d59aa415a979b08d6380,79060c71437fada756dd90db204b6c0cb47fa2ec,d6b5a7c24de18ba1deedbbbc8d918389c363a290,...,-85.0,-85.0,-86.0,-86.0,-86.0,-86.0,-86.0,-87.0,-87.0,-87.0
15073,5c3c44b80379370013e0fd2b,5d073b814a19c000086c558b,9854,186.640115,90.025574,2,41.116156,79060c71437fada756dd90db204b6c0cb47fa2ec,273fc15a48ff4dda348f2fea3f5b8412920dbcd0,d6b5a7c24de18ba1deedbbbc8d918389c363a290,...,-84.0,-85.0,-85.0,-85.0,-85.0,-85.0,-86.0,-86.0,-86.0,-86.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,5a0546857ecc773753327266,5e15bf941506f2000638fec5,3909,107.801695,149.287739,-1,32.315803,8c936564ea4b4300576f53136505527eb5972c07,c8f3983a40e3c462cdfd155f3d0c77868a717ffd,569878b7f574a738d49e365f45c49f44b81936f2,...,-83.0,-84.0,-84.0,-85.0,-86.0,-86.0,-86.0,-86.0,-87.0,-87.0
476,5a0546857ecc773753327266,5e15bf941506f2000638fec5,5664,107.219428,151.720520,-1,38.678058,f26678bbbbd078e242638a0d1fb5ba2e61262f4c,ce9e2ebf59e424773fcdbd986f7633306bf03124,1f37bbb3f42125f665b83584d0376b21ec3eb43c,...,-83.0,-83.0,-84.0,-85.0,-85.0,-85.0,-86.0,-86.0,-86.0,-86.0
477,5a0546857ecc773753327266,5e15bf941506f2000638fec5,10606,106.233656,155.024749,-1,39.725306,ce9e2ebf59e424773fcdbd986f7633306bf03124,1f37bbb3f42125f665b83584d0376b21ec3eb43c,0c5cda0386b9de9d4e8d5788f192af0dac40ce2c,...,-79.0,-80.0,-80.0,-81.0,-81.0,-81.0,-81.0,-82.0,-82.0,-82.0
478,5a0546857ecc773753327266,5e15bf941506f2000638fec5,11876,105.146487,155.943913,-1,41.274052,1f37bbb3f42125f665b83584d0376b21ec3eb43c,0c5cda0386b9de9d4e8d5788f192af0dac40ce2c,e57024faa1bfe6de5d71e1b779bff11e72e97e1d,...,-76.0,-76.0,-76.0,-76.0,-77.0,-77.0,-77.0,-77.0,-78.0,-78.0


In [14]:
train_df

Unnamed: 0,building,path_id,t1_wifi,x,y,floor,magn,wb_0,wb_1,wb_2,...,wr_90,wr_91,wr_92,wr_93,wr_94,wr_95,wr_96,wr_97,wr_98,wr_99
0,5a0546857ecc773753327266,5e1581c41506f2000638fc74,2564,145.299586,173.060338,-1,31.046997,68870898474881668126747d3f20fb6c773a61e7,a42659766c670273e04f3e57008d2dbd5dd8d7c9,cdd64541ef65a742a8792a678d7380c12b321ffc,...,-79.0,-79.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0
1,5a0546857ecc773753327266,5e1581c41506f2000638fc74,4179,143.192898,173.129835,-1,40.819800,a42659766c670273e04f3e57008d2dbd5dd8d7c9,bc6c9049fa1d4dc56d50c2c9b8a945c6b09bec2b,d3d056da278ff2975a96a486e4c1f70128570110,...,-80.0,-80.0,-81.0,-81.0,-81.0,-81.0,-81.0,-82.0,-82.0,-82.0
2,5a0546857ecc773753327266,5e1581c41506f2000638fc74,5753,141.427991,173.203771,-1,40.829225,d3d056da278ff2975a96a486e4c1f70128570110,f67bdd42a2c664b7283db103b49a2e5d95f53823,1b8286ce49240e893317288ec4c13ad463a63829,...,-81.0,-81.0,-82.0,-82.0,-82.0,-82.0,-82.0,-82.0,-82.0,-83.0
3,5a0546857ecc773753327266,5e1581c2f4c3420006d520fb,51345,140.315186,173.269150,-1,40.926437,d41a54c96fd9e42c17ae07a6a20dc76e27b1eeaa,ed3c9df2182af4c77e5de6beb4aeeb44b344d2a4,f67bdd42a2c664b7283db103b49a2e5d95f53823,...,-79.0,-79.0,-79.0,-79.0,-79.0,-79.0,-79.0,-79.0,-79.0,-79.0
4,5a0546857ecc773753327266,5e1581c41506f2000638fc74,9587,140.043815,172.636879,-1,40.227238,1ccfad270acf0c810365000a5b7b1202b946a80c,1b8286ce49240e893317288ec4c13ad463a63829,f67bdd42a2c664b7283db103b49a2e5d95f53823,...,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0,-80.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
220752,5dc8cea7659e181adb076a3f,5dcfb834878f3300066c70e8,15334,247.514738,75.595380,6,53.341577,1d6b29e1e0ccf63fde1bea4391dd08c9778dd199,b719d7de6ffe1b1b7409d2eb5ac3268e36cf2675,d5e6a4c6a501427a1f96ef34e221b502a36e9fe2,...,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0
220753,5dc8cea7659e181adb076a3f,5dcfb83394e49000061259a1,1696,249.736508,75.082064,6,44.700074,21f12bdf533e0437f00f79bd3455874e556acc5e,4b5dbdb52b131410ea10b59ea451de62280b41d6,66f3b15808291d9086fecd1bd5f1c0195ff5ccf5,...,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0
220754,5dc8cea7659e181adb076a3f,5dcfb83394e49000061259a1,4578,249.248034,74.793289,6,40.435783,21f12bdf533e0437f00f79bd3455874e556acc5e,66f3b15808291d9086fecd1bd5f1c0195ff5ccf5,d5e6a4c6a501427a1f96ef34e221b502a36e9fe2,...,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0
220755,5dc8cea7659e181adb076a3f,5dcfb83394e49000061259a1,5822,246.489740,74.066103,6,42.836657,21f12bdf533e0437f00f79bd3455874e556acc5e,24e18319c625893f325419b53c5cc22b75ed8e54,42a299c13d5fd662afdf3c77694372fa228a3da3,...,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0,-99.0


In [14]:
sub = pd.read_csv(DATA_DIR/'indoor-location-navigation/sample_submission.csv', index_col=0)

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [15]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'wb_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'wr_{i}' for i in range(NUM_FEATS)]
# DELTA_FEATS  = ['delta_x', 'delta_y']

bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [16]:
def del_longdata(tgt_df):

    df_ = tgt_df.copy()
    del_path = []
    new_trains = []
    for i, (p, g) in enumerate(df_.groupby('path_id')):
        if len(g) > MEMORY_LENGTH:
            del_path.append(p)
            for j in range((len(g) // MEMORY_LENGTH) + 1):
                if j == (len(g) // MEMORY_LENGTH):
                    tmp = g.iloc[j*MEMORY_LENGTH:]
                else:
                    tmp = g.iloc[j*MEMORY_LENGTH:(j+1)*MEMORY_LENGTH]
                tmp.loc[:, 'path_id'] = p + '_' + str(j)
                new_trains.append(tmp)

    df_.drop(df_[df_['path_id'].isin(del_path)].index, inplace=True)
    df_ = pd.concat([df_, pd.concat(new_trains)]).reset_index(drop=True)
    print(f"previous path len:{i+1}, current path len:{len(df_.groupby('path_id'))}")
    return df_

train_data = del_longdata(train_df)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[item] = s


previous path len:10830, current path len:48502


In [17]:
BSSID_FEATS = [f'wb_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS = [f'wr_{i}' for i in range(NUM_FEATS)]
X_train = train_data.loc[:, ['t1_wifi', 'building',
                             'path_id'] + BSSID_FEATS + RSSI_FEATS]
y_train = train_data.loc[:, ['t1_wifi', 'path_id', 'x', 'y', 'building']]
X_test = test_df.loc[:, ['t1_wifi', 'building',
                           'path_id'] + BSSID_FEATS + RSSI_FEATS]

### building weight

In [18]:
test_building_weight = defaultdict(int)
train_building_weight = defaultdict(int)
building_weight = dict()

for building in [x.split('_')[0] for x in sub.reset_index()['site_path_timestamp'].values]:
    test_building_weight[building] += 1 * 24/len(sub)
    
for building in train_data['building'].values:
    train_building_weight[building] += 1
train_building_weight = dict((k, len(train_data)/24/v) for k, v in train_building_weight.items())

for building in list(train_building_weight.keys()):
    building_weight[building] = train_building_weight[building] * test_building_weight[building]
    
y_train['weight'] = [building_weight[x] for x in y_train['building'].values]

In [19]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in BSSID_FEATS:
    wifi_bssids.extend(train_df.loc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in BSSID_FEATS:
    wifi_bssids_test.extend(test_df.loc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 29458
BSSID TYPES(test): 25156
BSSID TYPES(all): 54614


In [20]:
# get numbers of bssids to embed them in a layer

# train
rssi_bssids = []
# bssidを列ごとにリストに入れていく
for i in RSSI_FEATS:
    rssi_bssids.extend(train_df.loc[:,i].values.tolist())
rssi_bssids = list(set(rssi_bssids))

train_rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(train): {train_rssi_bssids_size}')

# test
rssi_bssids_test = []
for i in RSSI_FEATS:
    rssi_bssids_test.extend(test_df.loc[:,i].values.tolist())
rssi_bssids_test = list(set(rssi_bssids_test))

test_rssi_bssids_size = len(rssi_bssids_test)
print(f'RSSI TYPES(test): {test_rssi_bssids_size}')


rssi_bssids.extend(rssi_bssids_test)
rssi_bssids_size = len(rssi_bssids)
print(f'RSSI TYPES(all): {rssi_bssids_size}')

RSSI TYPES(train): 94
RSSI TYPES(test): 234
RSSI TYPES(all): 328


## PreProcess

In [21]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['building'])
le_rssi = LabelEncoder()
le_rssi.fit(rssi_bssids)

# ss = StandardScaler()
# ss.fit(train_data.loc[:,DELTA_FEATS])


def preprocess(input_df, le=le, le_site=le_site):
    output_df = input_df.copy()
    # RSSIの正規化
#     output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
#         output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？
    for i in RSSI_FEATS:
        output_df.loc[:,i] = le_rssi.transform(input_df.loc[:,i])

    # site_idのLE
    output_df.loc[:, 'building'] = le_site.transform(input_df.loc[:, 'building'])

    # なぜ２重でやる？
#     output_df.loc[:,DELTA_FEATS] = ss.transform(output_df.loc[:,DELTA_FEATS])
    return output_df

train = preprocess(X_train)
test = preprocess(X_test)

train  

Unnamed: 0,t1_wifi,building,path_id,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,...,wr_70,wr_71,wr_72,wr_73,wr_74,wr_75,wr_76,wr_77,wr_78,wr_79
0,1945,0,5e158f1f1506f2000638fd3b,8776,24860,17948,24284,12635,365,27656,...,67,67,67,67,67,67,67,67,67,67
1,3941,0,5e158f1f1506f2000638fd3b,24860,24284,12635,8776,365,27656,17948,...,61,61,61,61,61,61,61,57,57,57
2,5696,0,5e158f1f1506f2000638fd3b,24284,12635,365,8776,24860,27656,10132,...,67,67,67,67,67,67,67,61,61,61
3,7814,0,5e158f1f1506f2000638fd3b,8776,27656,365,12635,24284,24860,17948,...,71,71,71,71,67,67,67,67,67,67
4,9992,0,5e158f1f1506f2000638fd3b,24860,24284,8776,12635,365,2700,27656,...,76,76,76,76,71,71,71,71,71,71
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
220752,3909,0,5e15bf941506f2000638fec5_0,16182,23203,9977,26485,27440,20075,6025,...,67,67,67,67,67,67,67,57,57,57
220753,5664,0,5e15bf941506f2000638fec5_0,27968,23877,3588,9977,26485,23203,20075,...,71,71,71,71,71,71,71,71,71,67
220754,10606,0,5e15bf941506f2000638fec5_0,23877,3588,1403,26485,23203,9977,6025,...,67,67,67,67,67,61,61,61,61,61
220755,11876,0,5e15bf941506f2000638fec5_0,3588,1403,26485,23877,20075,6025,25268,...,89,89,89,89,83,83,83,83,76,76


In [22]:
x_min = np.min(y_train.loc[:, 'x'].values)
y_min = np.min(y_train.loc[:, 'y'].values)
norm_x = np.max(y_train.loc[:, 'x'].values) - x_min
norm_y = np.max(y_train.loc[:, 'y'].values) - y_min

In [23]:
site_count = len(train['building'].unique())
site_count

24

In [24]:
def make_data(data, col):
    train_tmp = []
    path_ids = []
    for path_id_, group in tqdm(data.groupby('path_id')):
        group = group.sort_values('t1_wifi')
        train_tmp.append(group[col])
        path_ids.append(path_id_)
        
    return train_tmp, path_ids
train_cols = ['building'] + BSSID_FEATS + RSSI_FEATS 
y_train_np, _ = np.array(make_data(y_train, col=['x', 'y', 'weight']))
X_train_np, path_ids = np.array(make_data(train, col=train_cols))

100%|██████████| 48502/48502 [00:45<00:00, 1055.10it/s]
  # This is added back by InteractiveShellApp.init_path()
100%|██████████| 48502/48502 [00:46<00:00, 1043.50it/s]
  if sys.path[0] == '':


## PyTorch model
- embedding layerが重要  

In [32]:
class CustomDataset(Dataset):
    def __init__(self, x_train, y_train, transform, inverse_ratio=0.5, combine_ratio=0.2):
        self.transform = transform
        self.x_train = x_train
        self.y_train = y_train
        self.inverse_ratio = inverse_ratio
        self.combine_ratio = combine_ratio

    def __getitem__(self, index):
        x = self.x_train[index].values
        y = self.y_train[index].values
        mask = np.full(MEMORY_LENGTH, True, dtype=bool)
        if len(x) < MEMORY_LENGTH:
            x_out = np.pad(x, ([(0, MEMORY_LENGTH - len(x)), (0, 0)]), 'edge')
            y_out = np.pad(y, ([(0, MEMORY_LENGTH - len(x)), (0, 0)]), 'edge')
            mask[len(x):] = False
        else:
            x_out = x
            y_out = y

#         if self.transform:
#             # inverse trajectory
#             if np.random.rand() < self.inverse_ratio:
#                 tmp = np.arange(MEMORY_LENGTH-1, -1, -1)
#                 x_out = x_out[tmp, :]
#                 y_out = y_out[tmp, :]
#                 mask = mask[tmp]

#             # # combine trajectory
#             if np.random.rand() < self.combine_ratio:
#                 p = np.random.permutation(list(range(MEMORY_LENGTH)))
#                 x_out = x_out[p]
#                 y_out = y_out[p]
#                 mask = mask[p]

        return x_out, y_out, mask

    def __len__(self):
        return len(self.x_train)

In [26]:
class ManyToMany(nn.Module):
    def __init__(self, wifi_bssids_size, input_dim, hidden_dim):
        super(ManyToMany, self).__init__()
        self.emb_dim = 256
        self.entitity_dim = 128
        self.bssi = nn.Embedding(wifi_bssids_size, self.emb_dim)
        self.building = nn.Embedding(24, 2)

        self.entity1 = nn.Sequential(
            nn.Linear(self.emb_dim + 3, self.entitity_dim),
            nn.Tanh()
        )
        self.entity2 = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.entitity_dim*NUM_FEATS, input_dim),
        )
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers=1,batch_first=True)

        self.main = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(True),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        bssids = self.bssi(x[:, :, 1:int(NUM_FEATS+1)].long()).float()
        buildings = self.building(x[:, :, 0].long()).unsqueeze(2).expand(-1, -1, NUM_FEATS, -1).float()
        rssis = x[:, :, int(NUM_FEATS+1):int(NUM_FEATS*2+1)].unsqueeze(3).float()
        
        x = torch.cat((bssids,  rssis, buildings), axis=3)
        #(batch, memory_length, wifi_num, self.emb_dim+2+1)
        x = self.entity1(x)
        #(batch, memory_length, wifi_num, self.entity1_dim)
        x = x.flatten(start_dim=2)
        #(batch, memory_length, wifi_num * self.entity1_dim)
        x = self.entity2(x)
        #(batch, memory_length, input_dim)
        output, _ = self.gru(x)
        output = self.main(output)
        #(batch, memory_length, 2)
        return output

In [27]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [28]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)
    
def comp_metric_xy(output, y):
    return torch.sqrt((((output - y[:,:2]))**2).sum())

In [29]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = comp_metric_xy
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y, m = batch
        output = self.model(x)
        loss = self.xy_criterion(output[m], y[m])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, m = batch
        output = self.model(x)
        loss = self.xy_criterion(output[m], y[m])
        mpe = mean_position_error(
            to_np(output[m][:, 0]), to_np(output[m][:, 1]), 0, 
            to_np(y[m][:, 0]), to_np(y[m][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
#         self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
#         self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return mpe
    
    def validation_epoch_end(self, outputs):
        avg_loss = np.mean(outputs)
        print(f'epoch = {self.current_epoch}, mpe_loss = {avg_loss}')

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [30]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [33]:
oofs = np.zeros((len(train), 2), dtype = np.float32)  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = GroupKFold(n_splits=N_SPLITS)
# for fold, (trn_idx, val_idx) in enumerate(skf.split(train.loc[:, 'path'], train.loc[:, 'path'])):
# for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path_id'], groups=train.loc[:, 'path_id'])):
for fold, (trn_idx, val_idx) in enumerate(gkf.split(path_ids, groups=path_ids)):


    # 指定したfoldのみループを回す

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data

#     trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + DELTA_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)
#     val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + DELTA_FEATS + ['site_id', 'x','y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
#     loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
#     loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
#     loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)
  
    mtrain, mvalid = X_train_np[trn_idx], y_train_np[trn_idx]
    ltrain, lvalid = X_train_np[val_idx], y_train_np[val_idx]
    loaders["train"] = DataLoader(CustomDataset(
        mtrain, mvalid, transform=True), **loader_config["train"], worker_init_fn=worker_init_fn)
    loaders["valid"] = DataLoader(CustomDataset(
        ltrain, lvalid, transform=False), **loader_config["valid"], worker_init_fn=worker_init_fn)
    
    
    # model
#     model = LSTMModel(wifi_bssids_size+rssi_bssids_size, site_count)
    model = ManyToMany(wifi_bssids_size, 64, 128)
    model_name = model.__class__.__name__
    
    # loggers
    RUN_NAME = f'exp{str(EXP_NAME)}'
#     wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{fold}')
#     wandb.run.name = RUN_NAME + f'-fold-{fold}'
#     wandb_config = wandb.config
#     wandb_config.model_name = model_name
#     wandb.watch(model)
    
    
#     loggers = []
#     loggers.append(WandbLogger())

    learner = Learner(model, config)
    
    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        filename=f'{model_name}-{learner.current_epoch}-{fold}')
    callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=20,
        verbose=True,
        mode='min')
    callbacks.append(early_stop_callback)
    
    trainer = pl.Trainer(
#         logger=loggers,
        callbacks=callbacks,
        max_epochs=MAX_EPOCHS,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
        fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=True,
#         precision=16,
#         progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
        )


    trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])
    
    
    # prediction by sliding window averaging. see https://arxiv.org/pdf/1903.11703.pdf
    prediction = []
    with torch.no_grad():
        for p, x in X_test.groupby('path_id'):
            window_score = defaultdict(list)
            x = x.sort_values('t1_wifi')
            ts = x['t1_wifi'].to_numpy()
            x = x[train_cols].reset_index(drop=True)
            
            for window in range(len(ts)):
                if MEMORY_LENGTH + window > len(ts):
                    break
                
                x_out = torch.tensor(
                    x.iloc[window:(window+MEMORY_LENGTH)].drop('path_id', axis=1).to_numpy())
                x_out = x_out.unsqueeze(0)
                x_out = model(x_out.to(DEVICE))
                for i in range(MEMORY_LENGTH):
                    window_score[ts[window + i]].append(x_out.squeeze().cpu().detach().numpy()[i,:])
                
            #  sort by time and get average 
            window_score = sorted(window_score.items(), key=lambda x:x[0])
            prediction.extend(list(map(lambda x: np.mean(x[1],axis=0), window_score)))
    
    
    prediction = np.array(prediction)
    result['x'] += (prediction[:, 0] * norm_x + x_min) / FOLDS
    result['y'] += (prediction[:, 1] * norm_y + y_min) / FOLDS    
    
    
    
    

    #############
    # validation (to make oof)
    #############
    model.eval()
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    oofs[val_idx, 0] = oof_x
    oofs[val_idx, 1] = oof_y

    
    val_score = mean_position_error(
        oof_x, oof_y, 0,
        val_df['x'].values, val_df['y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
#     # inference
#     #############
#     preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
#     test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
#     test_preds.columns = sub.columns
#     test_preds["site_path_timestamp"] = test["site_path_timestamp"]
#     test_preds["floor"] = test_preds["floor"].astype(int)
#     predictions.append(test_preds)
    

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | model       | ManyToMany | 14.7 M
1 | f_criterion | MSELoss    | 0     
-------------------------------------------
14.7 M    Trainable params
0         Non-trainable params
14.7 M    Total params
58.995    Total estimated model params size (MB)


Fold 0


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 176.17425710673575


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, mpe_loss = 81.63021728111929


Validating: 0it [00:00, ?it/s]

epoch = 1, mpe_loss = 82.15960835511287


Validating: 0it [00:00, ?it/s]

epoch = 2, mpe_loss = 56.975124435453345


Validating: 0it [00:00, ?it/s]

epoch = 3, mpe_loss = 52.889505729328334


Validating: 0it [00:00, ?it/s]

epoch = 4, mpe_loss = 47.196360408429825


Validating: 0it [00:00, ?it/s]

epoch = 5, mpe_loss = 45.30448861288317


Validating: 0it [00:00, ?it/s]

epoch = 6, mpe_loss = 43.09282166459271


Validating: 0it [00:00, ?it/s]

epoch = 7, mpe_loss = 42.46735385472363


Validating: 0it [00:00, ?it/s]

epoch = 8, mpe_loss = 40.83732725943968


Validating: 0it [00:00, ?it/s]

epoch = 9, mpe_loss = 42.2777386940944


Validating: 0it [00:00, ?it/s]

epoch = 10, mpe_loss = 39.92289662489666


Validating: 0it [00:00, ?it/s]

epoch = 11, mpe_loss = 41.364841015243684


Validating: 0it [00:00, ?it/s]

epoch = 12, mpe_loss = 38.44619130874611


Validating: 0it [00:00, ?it/s]

epoch = 13, mpe_loss = 37.125384878135435


Validating: 0it [00:00, ?it/s]

epoch = 14, mpe_loss = 35.5327365259904


Validating: 0it [00:00, ?it/s]

epoch = 15, mpe_loss = 37.61473719028016


Validating: 0it [00:00, ?it/s]

epoch = 16, mpe_loss = 34.89544206996741


Validating: 0it [00:00, ?it/s]

epoch = 17, mpe_loss = 32.94291429351417


Validating: 0it [00:00, ?it/s]

epoch = 18, mpe_loss = 32.22592679494703


Validating: 0it [00:00, ?it/s]

epoch = 19, mpe_loss = 31.089956347317376


Validating: 0it [00:00, ?it/s]

epoch = 20, mpe_loss = 30.612800550026986


Validating: 0it [00:00, ?it/s]

epoch = 21, mpe_loss = 30.081347785604425


Validating: 0it [00:00, ?it/s]

epoch = 22, mpe_loss = 30.209058477752052


Validating: 0it [00:00, ?it/s]

epoch = 23, mpe_loss = 29.802099052594873


Validating: 0it [00:00, ?it/s]

epoch = 24, mpe_loss = 29.421681548772305


Validating: 0it [00:00, ?it/s]

epoch = 25, mpe_loss = 28.855522733430128


Validating: 0it [00:00, ?it/s]

epoch = 26, mpe_loss = 29.908074817347142


Validating: 0it [00:00, ?it/s]

epoch = 27, mpe_loss = 29.262887146579207


Validating: 0it [00:00, ?it/s]

epoch = 28, mpe_loss = 27.62871764387524


Validating: 0it [00:00, ?it/s]

epoch = 29, mpe_loss = 28.279177793275824


Validating: 0it [00:00, ?it/s]

epoch = 30, mpe_loss = 27.91447159785292


Validating: 0it [00:00, ?it/s]

epoch = 31, mpe_loss = 26.698456750828644


Validating: 0it [00:00, ?it/s]

epoch = 32, mpe_loss = 26.94815845919454


Validating: 0it [00:00, ?it/s]

epoch = 33, mpe_loss = 27.07938308619902


Validating: 0it [00:00, ?it/s]

epoch = 34, mpe_loss = 27.501029072966215


Validating: 0it [00:00, ?it/s]

epoch = 35, mpe_loss = 27.49855428651519


Validating: 0it [00:00, ?it/s]

epoch = 36, mpe_loss = 26.20595178625433


Validating: 0it [00:00, ?it/s]

epoch = 37, mpe_loss = 25.873798373723197


Validating: 0it [00:00, ?it/s]

epoch = 38, mpe_loss = 25.42527576020229


Validating: 0it [00:00, ?it/s]

epoch = 39, mpe_loss = 25.23995005380097


Validating: 0it [00:00, ?it/s]

epoch = 40, mpe_loss = 25.122640114610977


Validating: 0it [00:00, ?it/s]

epoch = 41, mpe_loss = 24.843999617026597


Validating: 0it [00:00, ?it/s]

epoch = 42, mpe_loss = 24.742999575214505


Validating: 0it [00:00, ?it/s]

epoch = 43, mpe_loss = 24.49149108272376


Validating: 0it [00:00, ?it/s]

epoch = 44, mpe_loss = 24.40304733646647


Validating: 0it [00:00, ?it/s]

epoch = 45, mpe_loss = 24.395806972582722


Validating: 0it [00:00, ?it/s]

epoch = 46, mpe_loss = 24.231427789538944


Validating: 0it [00:00, ?it/s]

epoch = 47, mpe_loss = 24.09917789022666


Validating: 0it [00:00, ?it/s]

epoch = 48, mpe_loss = 23.960385975048133


Validating: 0it [00:00, ?it/s]

epoch = 49, mpe_loss = 23.811656754166485


Validating: 0it [00:00, ?it/s]

epoch = 50, mpe_loss = 23.7203645800301


Validating: 0it [00:00, ?it/s]

epoch = 51, mpe_loss = 23.61137413769105


Validating: 0it [00:00, ?it/s]

epoch = 52, mpe_loss = 23.56802340589336


Validating: 0it [00:00, ?it/s]

epoch = 53, mpe_loss = 23.414018801102788


Validating: 0it [00:00, ?it/s]

epoch = 54, mpe_loss = 23.351377575381147


Validating: 0it [00:00, ?it/s]

epoch = 55, mpe_loss = 23.300858427004375


Validating: 0it [00:00, ?it/s]

epoch = 56, mpe_loss = 23.238078476971502


Validating: 0it [00:00, ?it/s]

epoch = 57, mpe_loss = 23.130990756265128


Validating: 0it [00:00, ?it/s]

epoch = 58, mpe_loss = 23.067019509838765


Validating: 0it [00:00, ?it/s]

epoch = 59, mpe_loss = 23.06670953652134


Validating: 0it [00:00, ?it/s]

epoch = 60, mpe_loss = 22.956040149051876


Validating: 0it [00:00, ?it/s]

epoch = 61, mpe_loss = 22.76137662691147


Validating: 0it [00:00, ?it/s]

epoch = 62, mpe_loss = 22.637513228382616


Validating: 0it [00:00, ?it/s]

epoch = 63, mpe_loss = 22.586685503230186


Validating: 0it [00:00, ?it/s]

epoch = 64, mpe_loss = 22.509632862935028


Validating: 0it [00:00, ?it/s]

epoch = 65, mpe_loss = 22.38668140327503


Validating: 0it [00:00, ?it/s]

epoch = 66, mpe_loss = 22.365822739595313


Validating: 0it [00:00, ?it/s]

epoch = 67, mpe_loss = 22.25312231604252


Validating: 0it [00:00, ?it/s]

epoch = 68, mpe_loss = 22.237905864423613


Validating: 0it [00:00, ?it/s]

epoch = 69, mpe_loss = 22.239746891932384


Validating: 0it [00:00, ?it/s]

epoch = 70, mpe_loss = 22.107722570230738


Validating: 0it [00:00, ?it/s]

epoch = 71, mpe_loss = 22.071090404869654


Validating: 0it [00:00, ?it/s]

epoch = 72, mpe_loss = 22.00288203489086


Validating: 0it [00:00, ?it/s]

epoch = 73, mpe_loss = 22.06721694172354


Validating: 0it [00:00, ?it/s]

epoch = 74, mpe_loss = 21.98428912859688


Validating: 0it [00:00, ?it/s]

epoch = 75, mpe_loss = 21.89063734285641


Validating: 0it [00:00, ?it/s]

epoch = 76, mpe_loss = 21.872251856450188


Validating: 0it [00:00, ?it/s]

epoch = 77, mpe_loss = 21.89413393602048


Validating: 0it [00:00, ?it/s]

epoch = 78, mpe_loss = 21.768182042870055


Validating: 0it [00:00, ?it/s]

epoch = 79, mpe_loss = 21.717964051891144


Validating: 0it [00:00, ?it/s]

epoch = 80, mpe_loss = 21.637485661421213


Validating: 0it [00:00, ?it/s]

epoch = 81, mpe_loss = 21.59916115634281


Validating: 0it [00:00, ?it/s]

epoch = 82, mpe_loss = 21.6083983476549


Validating: 0it [00:00, ?it/s]

epoch = 83, mpe_loss = 21.491308475980436


Validating: 0it [00:00, ?it/s]

epoch = 84, mpe_loss = 21.425652310613117


Validating: 0it [00:00, ?it/s]

epoch = 85, mpe_loss = 21.36610241958278


Validating: 0it [00:00, ?it/s]

epoch = 86, mpe_loss = 21.452877475903335


Validating: 0it [00:00, ?it/s]

epoch = 87, mpe_loss = 21.41652471576207


Validating: 0it [00:00, ?it/s]

epoch = 88, mpe_loss = 21.37781081308011


Validating: 0it [00:00, ?it/s]

epoch = 89, mpe_loss = 21.216736854537515


Validating: 0it [00:00, ?it/s]

epoch = 90, mpe_loss = 21.16699111274305


Validating: 0it [00:00, ?it/s]

epoch = 91, mpe_loss = 21.10860494247925


Validating: 0it [00:00, ?it/s]

epoch = 92, mpe_loss = 21.033120437812048


Validating: 0it [00:00, ?it/s]

epoch = 93, mpe_loss = 21.051296096160314


Validating: 0it [00:00, ?it/s]

epoch = 94, mpe_loss = 21.025861588838513


Validating: 0it [00:00, ?it/s]

epoch = 95, mpe_loss = 20.833417808687955


Validating: 0it [00:00, ?it/s]

epoch = 96, mpe_loss = 20.834586852686034


Validating: 0it [00:00, ?it/s]

epoch = 97, mpe_loss = 20.857859944657452


Validating: 0it [00:00, ?it/s]

epoch = 98, mpe_loss = 20.802476120248468


Validating: 0it [00:00, ?it/s]

epoch = 99, mpe_loss = 20.748676443704205


Validating: 0it [00:00, ?it/s]

epoch = 100, mpe_loss = 20.758943573699025


Validating: 0it [00:00, ?it/s]

epoch = 101, mpe_loss = 20.76110601808217


Validating: 0it [00:00, ?it/s]

epoch = 102, mpe_loss = 20.779376085513242


Validating: 0it [00:00, ?it/s]

epoch = 103, mpe_loss = 20.785537325001098


Validating: 0it [00:00, ?it/s]

epoch = 104, mpe_loss = 20.624109822649338


Validating: 0it [00:00, ?it/s]

epoch = 105, mpe_loss = 20.64667559357338


Validating: 0it [00:00, ?it/s]

epoch = 106, mpe_loss = 20.484916049887822


Validating: 0it [00:00, ?it/s]

epoch = 107, mpe_loss = 20.666244594324432


Validating: 0it [00:00, ?it/s]

epoch = 108, mpe_loss = 20.584826898438298


Validating: 0it [00:00, ?it/s]

epoch = 109, mpe_loss = 20.5988348268633


Validating: 0it [00:00, ?it/s]

epoch = 110, mpe_loss = 20.531247099273582


Validating: 0it [00:00, ?it/s]

epoch = 111, mpe_loss = 20.49663817875602


Validating: 0it [00:00, ?it/s]

epoch = 112, mpe_loss = 20.483480800462065


Validating: 0it [00:00, ?it/s]

epoch = 113, mpe_loss = 20.447078503976602


Validating: 0it [00:00, ?it/s]

epoch = 114, mpe_loss = 20.427897646190495


Validating: 0it [00:00, ?it/s]

epoch = 115, mpe_loss = 20.409337389199038


Validating: 0it [00:00, ?it/s]

epoch = 116, mpe_loss = 20.41802277851438


Validating: 0it [00:00, ?it/s]

epoch = 117, mpe_loss = 20.392519288942072


Validating: 0it [00:00, ?it/s]

epoch = 118, mpe_loss = 20.370715739016244


Validating: 0it [00:00, ?it/s]

epoch = 119, mpe_loss = 20.376411523862117


Validating: 0it [00:00, ?it/s]

epoch = 120, mpe_loss = 20.36569747638353


Validating: 0it [00:00, ?it/s]

epoch = 121, mpe_loss = 20.337285034836654


Validating: 0it [00:00, ?it/s]

epoch = 122, mpe_loss = 20.32520357996873


Validating: 0it [00:00, ?it/s]

epoch = 123, mpe_loss = 20.3253826028317


Validating: 0it [00:00, ?it/s]

epoch = 124, mpe_loss = 20.33021643711651


Validating: 0it [00:00, ?it/s]

epoch = 125, mpe_loss = 20.27619175719499


Validating: 0it [00:00, ?it/s]

epoch = 126, mpe_loss = 20.273498469070088


Validating: 0it [00:00, ?it/s]

epoch = 127, mpe_loss = 20.24122217820519


Validating: 0it [00:00, ?it/s]

epoch = 128, mpe_loss = 20.22389776311833


Validating: 0it [00:00, ?it/s]

epoch = 129, mpe_loss = 20.216581566540317


Validating: 0it [00:00, ?it/s]

epoch = 130, mpe_loss = 20.190879432990474


Validating: 0it [00:00, ?it/s]

epoch = 131, mpe_loss = 20.175073103730004


Validating: 0it [00:00, ?it/s]

epoch = 132, mpe_loss = 20.166553819734006


Validating: 0it [00:00, ?it/s]

epoch = 133, mpe_loss = 20.179129821054275


Validating: 0it [00:00, ?it/s]

epoch = 134, mpe_loss = 20.15944414638408


Validating: 0it [00:00, ?it/s]

epoch = 135, mpe_loss = 20.148916123690842


Validating: 0it [00:00, ?it/s]

epoch = 136, mpe_loss = 20.138741543568916


Validating: 0it [00:00, ?it/s]

epoch = 137, mpe_loss = 20.122802802060047


Validating: 0it [00:00, ?it/s]

epoch = 138, mpe_loss = 20.109911668853766


Validating: 0it [00:00, ?it/s]

epoch = 139, mpe_loss = 20.101206064181905


Validating: 0it [00:00, ?it/s]

epoch = 140, mpe_loss = 20.09528472594683


Validating: 0it [00:00, ?it/s]

epoch = 141, mpe_loss = 20.087674631130586


Validating: 0it [00:00, ?it/s]

epoch = 142, mpe_loss = 20.083960615913902


Validating: 0it [00:00, ?it/s]

epoch = 143, mpe_loss = 20.076322893986106


Validating: 0it [00:00, ?it/s]

epoch = 144, mpe_loss = 20.055976394845754


Validating: 0it [00:00, ?it/s]

epoch = 145, mpe_loss = 20.02507301071632


Validating: 0it [00:00, ?it/s]

epoch = 146, mpe_loss = 20.052088653620075


Validating: 0it [00:00, ?it/s]

epoch = 147, mpe_loss = 20.028057531034005


Validating: 0it [00:00, ?it/s]

epoch = 148, mpe_loss = 20.02863457250468


Validating: 0it [00:00, ?it/s]

epoch = 149, mpe_loss = 20.019185900808463


Validating: 0it [00:00, ?it/s]

epoch = 150, mpe_loss = 20.010094419964293


Validating: 0it [00:00, ?it/s]

epoch = 151, mpe_loss = 20.00555186308854


Validating: 0it [00:00, ?it/s]

epoch = 152, mpe_loss = 20.003794337774448


Validating: 0it [00:00, ?it/s]

epoch = 153, mpe_loss = 19.97959791899011


Validating: 0it [00:00, ?it/s]

epoch = 154, mpe_loss = 19.969160913509352


Validating: 0it [00:00, ?it/s]

epoch = 155, mpe_loss = 19.94764541079437


Validating: 0it [00:00, ?it/s]

epoch = 156, mpe_loss = 19.941334000430903


Validating: 0it [00:00, ?it/s]

epoch = 157, mpe_loss = 19.931996348593895


Validating: 0it [00:00, ?it/s]

epoch = 158, mpe_loss = 19.932495169138253


Validating: 0it [00:00, ?it/s]

epoch = 159, mpe_loss = 19.932989740595424


Validating: 0it [00:00, ?it/s]

epoch = 160, mpe_loss = 19.914176181705283


Validating: 0it [00:00, ?it/s]

epoch = 161, mpe_loss = 19.888336622312195


Validating: 0it [00:00, ?it/s]

epoch = 162, mpe_loss = 19.885794328031267


Validating: 0it [00:00, ?it/s]

epoch = 163, mpe_loss = 19.86872974459124


Validating: 0it [00:00, ?it/s]

epoch = 164, mpe_loss = 19.87428425055267


Validating: 0it [00:00, ?it/s]

epoch = 165, mpe_loss = 19.86903836146808


Validating: 0it [00:00, ?it/s]

epoch = 166, mpe_loss = 19.85884323021471


Validating: 0it [00:00, ?it/s]

epoch = 167, mpe_loss = 19.85620992176632


Validating: 0it [00:00, ?it/s]

epoch = 168, mpe_loss = 19.859726805452386


Validating: 0it [00:00, ?it/s]

epoch = 169, mpe_loss = 19.82976575419193


Validating: 0it [00:00, ?it/s]

epoch = 170, mpe_loss = 19.84045012012633


Validating: 0it [00:00, ?it/s]

epoch = 171, mpe_loss = 19.83888705673626


Validating: 0it [00:00, ?it/s]

epoch = 172, mpe_loss = 19.846948019721417


Validating: 0it [00:00, ?it/s]

epoch = 173, mpe_loss = 19.826212127474758


Validating: 0it [00:00, ?it/s]

epoch = 174, mpe_loss = 19.830315135661962


Validating: 0it [00:00, ?it/s]

epoch = 175, mpe_loss = 19.852180176573626


Validating: 0it [00:00, ?it/s]

epoch = 176, mpe_loss = 19.843050627203944


Validating: 0it [00:00, ?it/s]

epoch = 177, mpe_loss = 19.836847446258155


Validating: 0it [00:00, ?it/s]

epoch = 178, mpe_loss = 19.836061190933798


Validating: 0it [00:00, ?it/s]

epoch = 179, mpe_loss = 19.82951820647744


Validating: 0it [00:00, ?it/s]

epoch = 180, mpe_loss = 19.81638431759598


Validating: 0it [00:00, ?it/s]

epoch = 181, mpe_loss = 19.80380770568676


Validating: 0it [00:00, ?it/s]

epoch = 182, mpe_loss = 19.77459346795653


Validating: 0it [00:00, ?it/s]

epoch = 183, mpe_loss = 19.75733804265773


Validating: 0it [00:00, ?it/s]

epoch = 184, mpe_loss = 19.747235604998913


Validating: 0it [00:00, ?it/s]

epoch = 185, mpe_loss = 19.746523125890313


Validating: 0it [00:00, ?it/s]

epoch = 186, mpe_loss = 19.72470742706485


Validating: 0it [00:00, ?it/s]

epoch = 187, mpe_loss = 19.740393409368735


Validating: 0it [00:00, ?it/s]

epoch = 188, mpe_loss = 19.73871462232013


Validating: 0it [00:00, ?it/s]

epoch = 189, mpe_loss = 19.720463563109757


Validating: 0it [00:00, ?it/s]

epoch = 190, mpe_loss = 19.71292493238836


Validating: 0it [00:00, ?it/s]

epoch = 191, mpe_loss = 19.724244219894544


Validating: 0it [00:00, ?it/s]

epoch = 192, mpe_loss = 19.71200958533249


Validating: 0it [00:00, ?it/s]

epoch = 193, mpe_loss = 19.72273413396491


Validating: 0it [00:00, ?it/s]

epoch = 194, mpe_loss = 19.723066239904394


Validating: 0it [00:00, ?it/s]

epoch = 195, mpe_loss = 19.711686515775256


Validating: 0it [00:00, ?it/s]

epoch = 196, mpe_loss = 19.71235026880191


Validating: 0it [00:00, ?it/s]

epoch = 197, mpe_loss = 19.709288652854493


Validating: 0it [00:00, ?it/s]

epoch = 198, mpe_loss = 19.70801151012577


Validating: 0it [00:00, ?it/s]

epoch = 199, mpe_loss = 19.686329178751908


Validating: 0it [00:00, ?it/s]

epoch = 200, mpe_loss = 19.673272545428915


Validating: 0it [00:00, ?it/s]

epoch = 201, mpe_loss = 19.65781427674218


Validating: 0it [00:00, ?it/s]

epoch = 202, mpe_loss = 19.66467263082498


Validating: 0it [00:00, ?it/s]

epoch = 203, mpe_loss = 19.654074374489703


Validating: 0it [00:00, ?it/s]

epoch = 204, mpe_loss = 19.64096504184609


Validating: 0it [00:00, ?it/s]

epoch = 205, mpe_loss = 19.652350822263042


Validating: 0it [00:00, ?it/s]

epoch = 206, mpe_loss = 19.64169929245054


Validating: 0it [00:00, ?it/s]

epoch = 207, mpe_loss = 19.61330005554739


Validating: 0it [00:00, ?it/s]

epoch = 208, mpe_loss = 19.631652769364795


Validating: 0it [00:00, ?it/s]

epoch = 209, mpe_loss = 19.61195876389696


Validating: 0it [00:00, ?it/s]

epoch = 210, mpe_loss = 19.61469322905081


Validating: 0it [00:00, ?it/s]

epoch = 211, mpe_loss = 19.59646235466097


Validating: 0it [00:00, ?it/s]

epoch = 212, mpe_loss = 19.623226844062593


Validating: 0it [00:00, ?it/s]

epoch = 213, mpe_loss = 19.610855304661644


Validating: 0it [00:00, ?it/s]

epoch = 214, mpe_loss = 19.600967825099353


Validating: 0it [00:00, ?it/s]

epoch = 215, mpe_loss = 19.64391646101439


Validating: 0it [00:00, ?it/s]

epoch = 216, mpe_loss = 19.61851037571932


Validating: 0it [00:00, ?it/s]

epoch = 217, mpe_loss = 19.611938103701107


Validating: 0it [00:00, ?it/s]

epoch = 218, mpe_loss = 19.61633046073328


Validating: 0it [00:00, ?it/s]

epoch = 219, mpe_loss = 19.61782803392073


Validating: 0it [00:00, ?it/s]

epoch = 220, mpe_loss = 19.61417638167445


Validating: 0it [00:00, ?it/s]

epoch = 221, mpe_loss = 19.613968556021998


Validating: 0it [00:00, ?it/s]

epoch = 222, mpe_loss = 19.61085468997202


Validating: 0it [00:00, ?it/s]

epoch = 223, mpe_loss = 19.610699279862434


Validating: 0it [00:00, ?it/s]

epoch = 224, mpe_loss = 19.610296040737513


Validating: 0it [00:00, ?it/s]

epoch = 225, mpe_loss = 19.60994011258078


Validating: 0it [00:00, ?it/s]

epoch = 226, mpe_loss = 19.60954044191138


Validating: 0it [00:00, ?it/s]

epoch = 227, mpe_loss = 19.609461152542185


Validating: 0it [00:00, ?it/s]

epoch = 228, mpe_loss = 19.609380462651323


Validating: 0it [00:00, ?it/s]

epoch = 229, mpe_loss = 19.609355698813605


Validating: 0it [00:00, ?it/s]

epoch = 230, mpe_loss = 19.609257112237568


Validating: 0it [00:00, ?it/s]

epoch = 231, mpe_loss = 19.609256577671754


Validating: 0it [00:00, ?it/s]

epoch = 232, mpe_loss = 19.609256599088994


Validating: 0it [00:00, ?it/s]

epoch = 233, mpe_loss = 19.609256115348096


Validating: 0it [00:00, ?it/s]

epoch = 234, mpe_loss = 19.609256305742715


KeyError: "['path_id'] not found in axis"

In [27]:
oofs_df = pd.DataFrame(oofs, columns=['x', 'y'])
oofs_df['path'] = train_df['path_id']
oofs_df['timestamp'] = train_df['wp_tmestamp']
oofs_df['site'] = train_df['site_id']
oofs_df['site_path_timestamp'] = oofs_df['site'] + '_' + oofs_df['path'] + '_' + oofs_df['timestamp'].astype(str)
oofs_df['floor'] = train_df['floor']
oofs_df.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv", index=False)
oofs_df

Unnamed: 0,x,y,path,timestamp,site,site_path_timestamp,floor
0,161.605011,104.301254,5e158ef61506f2000638fd1f,1578469851129,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
1,163.834015,106.054291,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
2,162.539764,110.929245,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
3,162.025391,110.759506,5e158ef61506f2000638fd1f,1578469857653,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
4,161.221939,110.610168,5e158ef61506f2000638fd1f,1578469862177,5a0546857ecc773753327266,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0
...,...,...,...,...,...,...,...
251108,198.552979,141.156830,5dcd5c9323759900063d590a,1573733061352,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251109,193.014008,142.992310,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251110,190.655258,142.768112,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0
251111,190.270554,143.015411,5dcd5c9323759900063d590a,1573733070079,5dc8cea7659e181adb076a3f,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0


In [28]:
oofs_score = mean_position_error(
        oofs_df['x'], oofs_df['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:7.00547689143968


In [29]:
all_preds = pd.concat(predictions).groupby('site_path_timestamp').mean()
all_preds

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0,199.996414,112.778862


In [30]:
all_preds_37 = pd.read_csv('../37/output/sub37.csv', index_col=0)
all_preds_37.index = pd.read_csv(WIFI_DIR / 'test_7_th20000.csv')['site_path_timestamp']
all_preds_37

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0,88.266884,104.794300
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0,82.316630,104.338745
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0,84.221380,105.362060
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0,87.842510,109.344190
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0,87.390120,108.134900
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,5,214.121280,98.048190
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,5,211.845250,100.757540
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,5,208.958570,107.238950
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,5,202.751280,110.968445


In [31]:
all_preds_merge = pd.merge(all_preds_37, all_preds, how='left', on='site_path_timestamp')[['floor_y', 'x_y', 'y_y']]
all_preds_merge = all_preds_merge.rename(columns={'floor_y': 'floor', 'x_y': 'x', 'y_y': 'y'})
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0.0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0.0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0.0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0.0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0.0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0.0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0.0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0.0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0.0,199.996414,112.778862


In [32]:
all_preds_merge['floor'].fillna(all_preds_37['floor'], inplace=True)
all_preds_merge['x'].fillna(all_preds_37['x'], inplace=True)
all_preds_merge['y'].fillna(all_preds_37['y'], inplace=True)
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474563646,0.0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474572654,0.0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474578963,0.0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474582400,0.0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_1578474585965,0.0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731143256,0.0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731146426,0.0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731151563,0.0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_1573731157567,0.0,199.996414,112.778862


In [33]:
# foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
all_preds_merge.index = sub.index
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0.0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0.0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0.0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0.0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0.0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0.0,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0.0,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0.0,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0.0,199.996414,112.778862


In [34]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv('../01/submission.csv')
all_preds_merge['floor'] = simple_accurate_99['floor'].values
all_preds_merge.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")
all_preds_merge

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,85.713966,104.649239
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,80.899323,102.921295
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,85.225449,105.812569
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,88.351288,107.935364
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,87.318352,108.258492
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,210.101776,100.415657
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,208.612549,101.582489
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,205.288788,106.346169
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,199.996414,112.778862


# Post Proccess

In [14]:
oofs_df = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}.csv")
sub_df = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}.csv")

In [16]:
def compute_rel_positions(acce_datas, ahrs_datas):
    step_timestamps, step_indexs, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    headings = compute_f.compute_headings(ahrs_datas)
    stride_lengths = compute_f.compute_stride_length(step_acce_max_mins)
    step_headings = compute_f.compute_step_heading(step_timestamps, headings)
    rel_positions = compute_f.compute_rel_positions(stride_lengths, step_headings)
    return rel_positions

In [32]:
import math

order = 3
fs = 50.0  # sample rate, Hz
# fs = 100
# cutoff = 3.667  # desired cutoff frequency of the filter, Hz
cutoff = 3

step_distance = 0.8
w_height = 1.7
m_trans = -5

from scipy.signal import butter, lfilter

def butter_lowpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    return b, a

def butter_lowpass_filter(data, cutoff, fs, order=5):
    b, a = butter_lowpass(cutoff, fs, order=order)
    y = lfilter(b, a, data)
    return y

def peak_accel_threshold(data, timestamps, threshold):
    d_acc = []
    last_state = 'below'
    crest_troughs = 0
    crossings = []

    for i, datum in enumerate(data):
        
        current_state = last_state
        if datum < threshold:
            current_state = 'below'
        elif datum > threshold:
            current_state = 'above'

        if current_state is not last_state:
            if current_state is 'above':
                crossing = [timestamps[i], threshold]
                crossings.append(crossing)
            else:
                crossing = [timestamps[i], threshold]
                crossings.append(crossing)

            crest_troughs += 1
        last_state = current_state
    return np.array(crossings)

In [33]:
def steps_compute_rel_positions(sample_file):
    
    mix_acce = np.sqrt(sample_file.acce[:,1:2]**2 + sample_file.acce[:,2:3]**2 + sample_file.acce[:,3:4]**2)
    mix_acce = np.concatenate([sample_file.acce[:,0:1], mix_acce], 1)
    mix_df = pd.DataFrame(mix_acce)
    mix_df.columns = ["timestamp","acce"]
    
    filtered = butter_lowpass_filter(mix_df["acce"], cutoff, fs, order)

    threshold = filtered.mean() * 1.1
    crossings = peak_accel_threshold(filtered, mix_df["timestamp"], threshold)

    step_sum = len(crossings)/2
    distance = w_height * 0.4 * step_sum

    mag_df = pd.DataFrame(sample_file.magn)
    mag_df.columns = ["timestamp","x","y","z"]
    
    acce_df = pd.DataFrame(sample_file.acce)
    acce_df.columns = ["timestamp","ax","ay","az"]
    
    mag_df = pd.merge(mag_df,acce_df,on="timestamp")
    mag_df.dropna()
    
    time_di_list = []

    for i in mag_df.iterrows():

        gx,gy,gz = i[1][1],i[1][2],i[1][3]
        ax,ay,az = i[1][4],i[1][5],i[1][6]

        roll = math.atan2(ay,az)
        pitch = math.atan2(-1*ax , (ay * math.sin(roll) + az * math.cos(roll)))

        q = m_trans - math.degrees(math.atan2(
            (gz*math.sin(roll)-gy*math.cos(roll)),(gx*math.cos(pitch) + gy*math.sin(roll)*math.sin(pitch) + gz*math.sin(pitch)*math.cos(roll))
        )) -90
        if q <= 0:
            q += 360
        time_di_list.append((i[1][0],q))

    d_list = [x[1] for x in time_di_list]
    
    steps = []
    step_time = []
    di_dict = dict(time_di_list)

    for n,i in enumerate(crossings[:,:1]):
        if n % 2 == 1:
            continue
        direct_now = di_dict[i[0]]
        dx = math.sin(math.radians(direct_now))
        dy = math.cos(math.radians(direct_now))
#         print(int(n/2+1),"歩目/x:",dx,"/y:",dy,"/角度：",direct_now)
        steps.append((i[0],dx,dy))
        step_time.append(i[0])
    
        step_dtime = np.diff(step_time)/1000
        step_dtime = step_dtime.tolist()
        step_dtime.insert(0,5)
        
        rel_position = []

        wp_idx = 0
#         print("WP:",round(sample_file.waypoint[0,1],3),round(sample_file.waypoint[0,2],3),sample_file.waypoint[0,0])
#         print("------------------")
        for p,i in enumerate(steps):
            step_distance = 0
            if step_dtime[p] >= 1:
                step_distance = w_height*0.25
            elif step_dtime[p] >= 0.75:
                step_distance = w_height*0.3
            elif step_dtime[p] >= 0.5:
                step_distance = w_height*0.4
            elif step_dtime[p] >= 0.35:
                step_distance = w_height*0.45
            elif step_dtime[p] >= 0.2:
                step_distance = w_height*0.5
            else:
                step_distance = w_height*0.4

#             step_x += i[1]*step_distance
#             step_y += i[2]*step_distance
            
            rel_position.append([i[0], i[1]*step_distance, i[2]*step_distance])
#     print(rel_position)
    
    return np.array(rel_position)

In [34]:
def correct_path(args):

    path, path_df = args
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    txt_path = path_df['txt_path'].values[0]
    
    example = read_data_file(txt_path)
    rel_positions1 = compute_rel_positions(example.acce, example.ahrs)
    rel_positions2 = steps_compute_rel_positions(example)
    rel1 = rel_positions1.copy()
    rel2 = rel_positions2.copy()
    rel1[:,1:] = rel_positions1[:,1:] / 2
    rel2[:,1:] = rel_positions2[:,1:] / 2
    rel_positions = np.vstack([rel1,rel2])
    rel_positions = rel_positions[np.argsort(rel_positions[:, 0])]
    
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]
    delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)

    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

In [35]:
tmp = sub_df['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub_df['site'] = tmp[0]
sub_df['path'] = tmp[1]
sub_df['timestamp'] = tmp[2].astype(float)

In [36]:
used_buildings = sorted(sub_df['site'].value_counts().index.tolist())
test_txts = sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/test/*.txt'))
train_txts = [sorted(glob.glob(str(DATA_DIR/'indoor-location-navigation') + f'/train/{used_building}/*/*.txt')) for used_building in used_buildings]
train_txts = sum(train_txts, [])

In [37]:
txt_pathes = []
for path in tqdm(sub_df['path'].values):
    txt_pathes.append([test_txt for test_txt in test_txts if path in test_txt][0])

100%|██████████| 10133/10133 [00:00<00:00, 28922.91it/s]


In [38]:
sub_df['txt_path'] = txt_pathes

In [39]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, sub_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
sub_df_cm = pd.concat(dfs).sort_values('site_path_timestamp')

626it [02:21,  4.43it/s]


In [40]:
txt_pathes = []
for path in tqdm(oofs_df['path'].values):
    txt_pathes.append([train_txt for train_txt in train_txts if path in train_txt][0])

100%|██████████| 251113/251113 [02:26<00:00, 1711.14it/s]


In [41]:
oofs_df['txt_path'] = txt_pathes

In [42]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, oofs_df.groupby('path'))
    dfs = tqdm(dfs)
    dfs = list(dfs)
oofs_df_cm = pd.concat(dfs).sort_index()
oofs_df_cm

10789it [16:30, 10.89it/s]


Unnamed: 0,site_path_timestamp,floor,x,y
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,158.978636,104.825028
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.257042,108.604798
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.251276,108.609015
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.245114,108.610050
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,165.494042,112.005578
...,...,...,...,...
251108,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,203.264141,140.781741
251109,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,190.250754,143.186429
251110,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,190.243241,143.182151
251111,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,190.235164,143.178441


In [None]:
oofs_df_cm.to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv", index=False)

In [45]:
oofs_score = mean_position_error(
        oofs_df_cm['x'], oofs_df_cm['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.5303376207693065


In [None]:
sub_df_cm.to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv", index=False)

In [None]:
oofs_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_cm.csv")

In [None]:
sub_df_cm = pd.read_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_cm.csv")

In [46]:
def split_col(df):
    df = pd.concat([
        df['site_path_timestamp'].str.split('_', expand=True) \
        .rename(columns={0:'site',
                         1:'path',
                         2:'timestamp'}),
        df
    ], axis=1).copy()
    return df
def sub_process(sub, train_waypoints):
    train_waypoints['isTrainWaypoint'] = True
    sub = split_col(sub[['site_path_timestamp','floor','x','y']]).copy()
    sub = sub.merge(train_waypoints[['site','floorNo','floor']].drop_duplicates(), how='left')
    sub = sub.merge(
        train_waypoints[['x','y','site','floor','isTrainWaypoint']].drop_duplicates(),
        how='left',
        on=['site','x','y','floor']
             )
    sub['isTrainWaypoint'] = sub['isTrainWaypoint'].fillna(False)
    return sub.copy()

In [47]:
train_waypoints = pd.read_csv(str(DATA_DIR/'indoor-location-navigation') + '/train_waypoints.csv')


In [48]:
sub_df_cm = sub_process(sub_df_cm, train_waypoints)
oofs_df_cm = sub_process(oofs_df_cm, train_waypoints)

In [49]:
from scipy.spatial.distance import cdist

def add_xy(df):
    df['xy'] = [(x, y) for x,y in zip(df['x'], df['y'])]
    return df

def closest_point(point, points):
    """ Find closest point from a list of points. """
    return points[cdist([point], points).argmin()]

sub_df_cm = add_xy(sub_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(sub_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

sub_df_cm_ds = pd.concat(ds)


oofs_df_cm = add_xy(oofs_df_cm)
train_waypoints = add_xy(train_waypoints)

ds = []
for (site, myfloor), d in tqdm(oofs_df_cm.groupby(['site','floor'])):
    true_floor_locs = train_waypoints.loc[(train_waypoints['floor'] == myfloor) &
                                          (train_waypoints['site'] == site)] \
        .reset_index(drop=True)
    if len(true_floor_locs) == 0:
        print(f'Skipping {site} {myfloor}')
        continue
    d['matched_point'] = [closest_point(x, list(true_floor_locs['xy'])) for x in d['xy']]
    d['x_'] = d['matched_point'].apply(lambda x: x[0])
    d['y_'] = d['matched_point'].apply(lambda x: x[1])
    ds.append(d)

oofs_df_cm_ds = pd.concat(ds)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
100%|██████████| 118/118 [00:04<00:00, 27.25it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation:

In [50]:
def snap_to_grid(sub, threshold):
    """
    Snap to grid if within a threshold.
    
    x, y are the predicted points.
    x_, y_ are the closest grid points.
    _x_, _y_ are the new predictions after post processing.
    """
    sub['_x_'] = sub['x']
    sub['_y_'] = sub['y']
    sub.loc[sub['dist'] < threshold, '_x_'] = sub.loc[sub['dist'] < threshold]['x_']
    sub.loc[sub['dist'] < threshold, '_y_'] = sub.loc[sub['dist'] < threshold]['y_']
    return sub.copy()



In [51]:
# Calculate the distances
sub_df_cm_ds['dist'] = np.sqrt( (sub_df_cm_ds.x-sub_df_cm_ds.x_)**2 + (sub_df_cm_ds.y-sub_df_cm_ds.y_)**2 )
sub_pp = snap_to_grid(sub_df_cm_ds, threshold=5)
sub_pp = sub_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

# Calculate the distances
oofs_df_cm_ds['dist'] = np.sqrt( (oofs_df_cm_ds.x-oofs_df_cm_ds.x_)**2 + (oofs_df_cm_ds.y-oofs_df_cm_ds.y_)**2 )
oofs_pp = snap_to_grid(oofs_df_cm_ds, threshold=5)
oofs_pp = oofs_pp[['site_path_timestamp','floor','_x_','_y_','site','path','floorNo']] \
    .rename(columns={'_x_':'x', '_y_':'y'})

In [52]:
sub_pp = sub_pp.sort_index()
sub_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,93.728470,97.948860,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
1,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,79.662285,102.766754,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
2,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,87.162830,104.450550,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
3,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
4,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,81.657740,110.509090,5a0546857ecc773753327266,046cfa46be49fc10834815c6,F1
...,...,...,...,...,...,...,...
10128,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.596040,98.605774,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10129,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,206.627910,102.635210,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10130,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.511300,107.841324,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6
10131,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,192.553130,111.863014,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,F6


In [55]:
sub_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/sub{EXP_NAME}_pp.csv", index=False)

In [53]:
oofs_pp = oofs_pp.sort_index()
oofs_pp

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,floorNo
0,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,158.49695,107.122680,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
1,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
2,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
3,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
4,5a0546857ecc773753327266_5e158ef61506f2000638f...,-1.0,162.33182,110.822685,5a0546857ecc773753327266,5e158ef61506f2000638fd1f,B1
...,...,...,...,...,...,...,...
251108,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,202.50334,140.972670,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251109,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,191.95721,143.862530,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251110,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,191.95721,143.862530,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7
251111,5dc8cea7659e181adb076a3f_5dcd5c9323759900063d5...,6.0,191.95721,143.862530,5dc8cea7659e181adb076a3f,5dcd5c9323759900063d590a,F7


In [54]:
oofs_score = mean_position_error(
        oofs_pp['x'], oofs_pp['y'], 0,
        train_df['x'].values, train_df['y'].values, 0)
print(f"CV:{oofs_score}")

CV:5.05298385725833


In [56]:
oofs_pp[['site_path_timestamp','floor','x','y']] \
    .to_csv(str(OUTPUT_DIR) + f"/oof{EXP_NAME}_pp.csv", index=False)

In [59]:
wandb.init(project='Indoor_Location_Navigation', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
wandb.log({'CV_score': oofs_score})
wandb.save(utils.get_notebook_path())
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)




VBox(children=(Label(value=' 0.00MB of 0.42MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.00150438817…

0,1
CV_score,5.05298
_runtime,1.0
_timestamp,1619533532.0
_step,0.0


0,1
CV_score,▁
_runtime,▁
_timestamp,▁
_step,▁


In [None]:
import json
import matplotlib.pylab as plt

def plot_preds(
    site,
    floorNo,
    sub=None,
    true_locs=None,
    base=str(DATA_DIR/'indoor-location-navigation'),
    show_train=True,
    show_preds=True,
    fix_labels=True,
    map_floor=None
):
    """
    Plots predictions on floorplan map.
    
    map_floor : use a different floor's map
    """
    if map_floor is None:
        map_floor = floorNo
    # Prepare width_meter & height_meter (taken from the .json file)
    floor_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_image.png"
    json_plan_filename = f"{base}/metadata/{site}/{map_floor}/floor_info.json"
    with open(json_plan_filename) as json_file:
        json_data = json.load(json_file)

    width_meter = json_data["map_info"]["width"]
    height_meter = json_data["map_info"]["height"]

    floor_img = plt.imread(f"{base}/metadata/{site}/{map_floor}/floor_image.png")

    fig, ax = plt.subplots(figsize=(12, 12))
    plt.imshow(floor_img)

    if show_train:
        true_locs = true_locs.query('site == @site and floorNo == @map_floor').copy()
        true_locs["x_"] = true_locs["x"] * floor_img.shape[0] / height_meter
        true_locs["y_"] = (
            true_locs["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        true_locs.query("site == @site and floorNo == @map_floor").groupby("path").plot(
            x="x_",
            y="y_",
            style="+",
            ax=ax,
            label="train waypoint location",
            color="grey",
            alpha=0.5,
        )

    if show_preds:
        sub = sub.query('site == @site and floorNo == @floorNo').copy()
        sub["x_"] = sub["x"] * floor_img.shape[0] / height_meter
        sub["y_"] = (
            sub["y"] * -1 * floor_img.shape[1] / width_meter
        ) + floor_img.shape[0]
        for path, path_data in sub.query(
            "site == @site and floorNo == @floorNo"
        ).groupby("path"):
            path_data.plot(
                x="x_",
                y="y_",
                style=".-",
                ax=ax,
                title=f"{site} - floor - {floorNo}",
                alpha=1,
                label=path,
            )
    if fix_labels:
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(
            by_label.values(), by_label.keys(), loc="center left", bbox_to_anchor=(1, 0.5)
        )
    return fig, ax

In [None]:
example_site = '5a0546857ecc773753327266'
example_floorNo = 'F1'

sub_df = sub_process(sub_df, train_waypoints)
plot_preds(example_site, example_floorNo, sub_df,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_df_cm,
           train_waypoints, show_preds=True)

In [None]:
plot_preds(example_site, example_floorNo, sub_pp,
           train_waypoints, show_preds=True)