In [24]:
import gc
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import seaborn as sns
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, Dataset

from IPython.display import clear_output

sns.set(font_scale=1.8)
%matplotlib inline

import os

import catalyst
from catalyst import dl, utils
catalyst.__version__

'21.03.2'

In [22]:
X_tr = np.load('./train_wisdm_data.npy')
y_tr = np.load('./train_wisdm_label.npy')

X_val = np.load('./val_wisdm_data.npy')
y_val = np.load('./val_wisdm_label.npy')

X_test = np.load('./test_wisdm_data.npy')
y_test = np.load('./test_wisdm_label.npy')

In [25]:
train_loader = DataLoader(list(zip(X_tr, np.argmax(y_tr, axis=1))), 
                          batch_size=128, 
                          shuffle=True)
val_loader = DataLoader(list(zip(X_val,np.argmax(y_val, axis=1))), 
                        batch_size=128, 
                        shuffle=True)
loaders = {
    "train" : train_loader, 
    "valid": val_loader
}

In [26]:
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=256, p_dropout=1/2):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(input_size, 
                           hidden_size, 
                           batch_first=True)
        self.dropout = nn.Dropout(p_dropout)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(64, output_size)
        )
        
    def forward(self, input_seq):
        _, (hidden_state, _) = self.rnn(input_seq)
        hidden_state = self.dropout(hidden_state)
        predictions = self.classifier(hidden_state.squeeze(0))
        return predictions

In [27]:
class SimpleGRU(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=128, p_dropout=1/2):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, 
                          hidden_size, 
                          batch_first=True)
        self.dropout = nn.Dropout(p_dropout)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(p=p_dropout),
            nn.Linear(64, output_size)
        )
        
    def forward(self, input_seq):
        _, hidden_state = self.rnn(input_seq)
        predictions = self.classifier(hidden_state.squeeze(0))
        return predictions

In [28]:
device='cuda'
from catalyst.dl import (
    SupervisedRunner, AccuracyCallback,
    CriterionCallback, SchedulerCallback,
)

In [44]:
model = SimpleLSTM(3, 6, 256, 0.25).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=False)

# model training
runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir="./logdir",
    num_epochs=150,
    valid_loader="valid",
    valid_metric="accuracy",
    minimize_valid_metric=False,
    verbose=False,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=6, topk_args=[1]),
        dl.ConfusionMatrixCallback(
             input_key="logits", target_key="targets", num_classes=6
         )
    ]
)

train (1/150) accuracy: 0.5993098020553589 | accuracy/std: 0.11014770193543356 | accuracy01: 0.5993098020553589 | accuracy01/std: 0.11014770193543356 | loss: 1.1081650257110596 | loss/mean: 1.1081650257110596 | loss/std: 0.23516191903759615 | lr: 0.001 | momentum: 0.9
valid (1/150) accuracy: 0.7188548445701599 | accuracy/std: 0.04101523257602724 | accuracy01: 0.7188548445701599 | accuracy01/std: 0.04101523257602724 | loss: 0.7893596291542053 | loss/mean: 0.7893596291542053 | loss/std: 0.09035154052517444 | lr: 0.001 | momentum: 0.9
* Epoch (1/150) lr: 0.001 | momentum: 0.9
train (2/150) accuracy: 0.7213382124900818 | accuracy/std: 0.04189024452542519 | accuracy01: 0.7213382124900818 | accuracy01/std: 0.04189024452542519 | loss: 0.7948208451271057 | loss/mean: 0.7948208451271057 | loss/std: 0.09780174257946794 | lr: 0.001 | momentum: 0.9
valid (2/150) accuracy: 0.6873182654380798 | accuracy/std: 0.046713034966983785 | accuracy01: 0.6873182654380798 | accuracy01/std: 0.046713034966983785

valid (15/150) accuracy: 0.8045179843902588 | accuracy/std: 0.03388748651858141 | accuracy01: 0.8045179843902588 | accuracy01/std: 0.03388748651858141 | loss: 0.478073388338089 | loss/mean: 0.478073388338089 | loss/std: 0.06762970379270135 | lr: 0.0001 | momentum: 0.9
* Epoch (15/150) lr: 0.0001 | momentum: 0.9
train (16/150) accuracy: 0.8216065764427185 | accuracy/std: 0.037913071541958214 | accuracy01: 0.8216065764427185 | accuracy01/std: 0.037913071541958214 | loss: 0.4444732069969177 | loss/mean: 0.4444732069969177 | loss/std: 0.06470676177711028 | lr: 0.0001 | momentum: 0.9
valid (16/150) accuracy: 0.8110042810440063 | accuracy/std: 0.031019063371290043 | accuracy01: 0.8110042810440063 | accuracy01/std: 0.031019063371290043 | loss: 0.4637625813484192 | loss/mean: 0.4637625813484192 | loss/std: 0.07004782136157191 | lr: 0.0001 | momentum: 0.9
* Epoch (16/150) lr: 0.0001 | momentum: 0.9
train (17/150) accuracy: 0.8259681463241577 | accuracy/std: 0.030453171366352384 | accuracy01: 0.

valid (29/150) accuracy: 0.8295683264732361 | accuracy/std: 0.03156661443849629 | accuracy01: 0.8295683264732361 | accuracy01/std: 0.03156661443849629 | loss: 0.46406492590904236 | loss/mean: 0.46406492590904236 | loss/std: 0.07260988220709719 | lr: 1e-05 | momentum: 0.9
* Epoch (29/150) lr: 1e-05 | momentum: 0.9
train (30/150) accuracy: 0.8374233245849609 | accuracy/std: 0.0308352096922692 | accuracy01: 0.8374233245849609 | accuracy01/std: 0.0308352096922692 | loss: 0.4297949969768524 | loss/mean: 0.4297949969768524 | loss/std: 0.07388398392809525 | lr: 1e-05 | momentum: 0.9
valid (30/150) accuracy: 0.8302392959594727 | accuracy/std: 0.04674864168123528 | accuracy01: 0.8302392959594727 | accuracy01/std: 0.04674864168123528 | loss: 0.45863059163093567 | loss/mean: 0.45863059163093567 | loss/std: 0.08292819707625868 | lr: 1e-05 | momentum: 0.9
* Epoch (30/150) lr: 1e-05 | momentum: 0.9
train (31/150) accuracy: 0.8359854221343994 | accuracy/std: 0.028656645056233138 | accuracy01: 0.83598

train (43/150) accuracy: 0.842887282371521 | accuracy/std: 0.029525636494669574 | accuracy01: 0.842887282371521 | accuracy01/std: 0.029525636494669574 | loss: 0.4033026099205017 | loss/mean: 0.4033026099205017 | loss/std: 0.06160279391288781 | lr: 1.0000000000000002e-06 | momentum: 0.9
valid (43/150) accuracy: 0.8338179588317871 | accuracy/std: 0.034649425524779 | accuracy01: 0.8338179588317871 | accuracy01/std: 0.034649425524779 | loss: 0.4321706295013428 | loss/mean: 0.4321706295013428 | loss/std: 0.07804662969399302 | lr: 1.0000000000000002e-06 | momentum: 0.9
* Epoch (43/150) lr: 1.0000000000000002e-06 | momentum: 0.9
train (44/150) accuracy: 0.8440375924110413 | accuracy/std: 0.031020910160329807 | accuracy01: 0.8440375924110413 | accuracy01/std: 0.031020910160329807 | loss: 0.40021398663520813 | loss/mean: 0.40021398663520813 | loss/std: 0.061877668346883366 | lr: 1.0000000000000002e-06 | momentum: 0.9
valid (44/150) accuracy: 0.8342652916908264 | accuracy/std: 0.0345943852675869

train (56/150) accuracy: 0.8431748747825623 | accuracy/std: 0.034296410024518165 | accuracy01: 0.8431748747825623 | accuracy01/std: 0.034296410024518165 | loss: 0.40365204215049744 | loss/mean: 0.40365204215049744 | loss/std: 0.06434134870782904 | lr: 1.0000000000000002e-07 | momentum: 0.9
valid (56/150) accuracy: 0.834712564945221 | accuracy/std: 0.03164479791182792 | accuracy01: 0.834712564945221 | accuracy01/std: 0.03164479791182792 | loss: 0.4306536018848419 | loss/mean: 0.4306536018848419 | loss/std: 0.07593497771472899 | lr: 1.0000000000000002e-07 | momentum: 0.9
* Epoch (56/150) lr: 1.0000000000000002e-07 | momentum: 0.9
train (57/150) accuracy: 0.8423600196838379 | accuracy/std: 0.02983994341153828 | accuracy01: 0.8423600196838379 | accuracy01/std: 0.02983994341153828 | loss: 0.40108245611190796 | loss/mean: 0.40108245611190796 | loss/std: 0.05890528756369932 | lr: 1.0000000000000002e-07 | momentum: 0.9
valid (57/150) accuracy: 0.834712564945221 | accuracy/std: 0.04064590445830

train (69/150) accuracy: 0.84375 | accuracy/std: 0.03202014970387288 | accuracy01: 0.84375 | accuracy01/std: 0.03202014970387288 | loss: 0.4017910957336426 | loss/mean: 0.4017910957336426 | loss/std: 0.06310302164420548 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (69/150) accuracy: 0.834712564945221 | accuracy/std: 0.027761277652403726 | accuracy01: 0.834712564945221 | accuracy01/std: 0.027761277652403726 | loss: 0.43057680130004883 | loss/mean: 0.43057680130004883 | loss/std: 0.07148700080535293 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (69/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (70/150) accuracy: 0.8429352045059204 | accuracy/std: 0.03286976729971871 | accuracy01: 0.8429352045059204 | accuracy01/std: 0.03286976729971871 | loss: 0.4021720588207245 | loss/mean: 0.4021720588207245 | loss/std: 0.06904200211738032 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (70/150) accuracy: 0.834712564945221 | accuracy/std: 0.023074543561569846 | accuracy01: 0.834

train (82/150) accuracy: 0.842743456363678 | accuracy/std: 0.030927307929413606 | accuracy01: 0.842743456363678 | accuracy01/std: 0.030927307929413606 | loss: 0.40254271030426025 | loss/mean: 0.40254271030426025 | loss/std: 0.06617691719253466 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (82/150) accuracy: 0.834712564945221 | accuracy/std: 0.023997479274182717 | accuracy01: 0.834712564945221 | accuracy01/std: 0.023997479274182717 | loss: 0.4305189251899719 | loss/mean: 0.4305189251899719 | loss/std: 0.06517157343138899 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (82/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (83/150) accuracy: 0.8439896702766418 | accuracy/std: 0.031837463529103695 | accuracy01: 0.8439896702766418 | accuracy01/std: 0.031837463529103695 | loss: 0.4021463692188263 | loss/mean: 0.4021463692188263 | loss/std: 0.06504789408756358 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (83/150) accuracy: 0.834712564945221 | accuracy/std: 0.03442334957930

train (95/150) accuracy: 0.8437020778656006 | accuracy/std: 0.0345025728110666 | accuracy01: 0.8437020778656006 | accuracy01/std: 0.0345025728110666 | loss: 0.40028417110443115 | loss/mean: 0.40028417110443115 | loss/std: 0.06703421275301791 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (95/150) accuracy: 0.834936261177063 | accuracy/std: 0.03823467878627138 | accuracy01: 0.834936261177063 | accuracy01/std: 0.03823467878627138 | loss: 0.4304687976837158 | loss/mean: 0.4304687976837158 | loss/std: 0.08964281914705075 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (95/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (96/150) accuracy: 0.8431748747825623 | accuracy/std: 0.03197987444069538 | accuracy01: 0.8431748747825623 | accuracy01/std: 0.03197987444069538 | loss: 0.4014713764190674 | loss/mean: 0.4014713764190674 | loss/std: 0.06917688093811863 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (96/150) accuracy: 0.834936261177063 | accuracy/std: 0.025084570115406612 |

train (108/150) accuracy: 0.844756543636322 | accuracy/std: 0.034563842880930395 | accuracy01: 0.844756543636322 | accuracy01/std: 0.034563842880930395 | loss: 0.4005781412124634 | loss/mean: 0.4005781412124634 | loss/std: 0.0686213820168871 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (108/150) accuracy: 0.834936261177063 | accuracy/std: 0.03356813121643025 | accuracy01: 0.834936261177063 | accuracy01/std: 0.03356813121643025 | loss: 0.43042314052581787 | loss/mean: 0.43042314052581787 | loss/std: 0.08883567966493892 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (108/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (109/150) accuracy: 0.8437020778656006 | accuracy/std: 0.031256702925327395 | accuracy01: 0.8437020778656006 | accuracy01/std: 0.031256702925327395 | loss: 0.4002814292907715 | loss/mean: 0.4002814292907715 | loss/std: 0.06699475480019727 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (109/150) accuracy: 0.834936261177063 | accuracy/std: 0.030400115151

train (121/150) accuracy: 0.842887282371521 | accuracy/std: 0.0353125884532519 | accuracy01: 0.842887282371521 | accuracy01/std: 0.0353125884532519 | loss: 0.40036430954933167 | loss/mean: 0.40036430954933167 | loss/std: 0.06695299271654583 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (121/150) accuracy: 0.834936261177063 | accuracy/std: 0.024404967173927785 | accuracy01: 0.834936261177063 | accuracy01/std: 0.024404967173927785 | loss: 0.43037137389183044 | loss/mean: 0.43037137389183044 | loss/std: 0.06944729776873684 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (121/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (122/150) accuracy: 0.8439896702766418 | accuracy/std: 0.029083723554336378 | accuracy01: 0.8439896702766418 | accuracy01/std: 0.029083723554336378 | loss: 0.40223702788352966 | loss/mean: 0.40223702788352966 | loss/std: 0.057583846952316414 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (122/150) accuracy: 0.834936261177063 | accuracy/std: 0.03036143

train (134/150) accuracy: 0.842887282371521 | accuracy/std: 0.03436673906011205 | accuracy01: 0.842887282371521 | accuracy01/std: 0.03436673906011205 | loss: 0.39943039417266846 | loss/mean: 0.39943039417266846 | loss/std: 0.06744782500766042 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (134/150) accuracy: 0.834936261177063 | accuracy/std: 0.03133156494951861 | accuracy01: 0.834936261177063 | accuracy01/std: 0.03133156494951861 | loss: 0.430320143699646 | loss/mean: 0.430320143699646 | loss/std: 0.06670858096324342 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (134/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (135/150) accuracy: 0.8432707190513611 | accuracy/std: 0.03109088995821821 | accuracy01: 0.8432707190513611 | accuracy01/std: 0.03109088995821821 | loss: 0.4030103385448456 | loss/mean: 0.4030103385448456 | loss/std: 0.06596742865455026 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (135/150) accuracy: 0.834936261177063 | accuracy/std: 0.02794212565972055

train (147/150) accuracy: 0.8442292809486389 | accuracy/std: 0.03064201198033942 | accuracy01: 0.8442292809486389 | accuracy01/std: 0.03064201198033942 | loss: 0.4019884765148163 | loss/mean: 0.4019884765148163 | loss/std: 0.07110130202919344 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (147/150) accuracy: 0.834936261177063 | accuracy/std: 0.028886482315009492 | accuracy01: 0.834936261177063 | accuracy01/std: 0.028886482315009492 | loss: 0.4302501380443573 | loss/mean: 0.4302501380443573 | loss/std: 0.06428140552376965 | lr: 1.0000000000000004e-08 | momentum: 0.9
* Epoch (147/150) lr: 1.0000000000000004e-08 | momentum: 0.9
train (148/150) accuracy: 0.8436062335968018 | accuracy/std: 0.03244388053872605 | accuracy01: 0.8436062335968018 | accuracy01/std: 0.03244388053872605 | loss: 0.40159931778907776 | loss/mean: 0.40159931778907776 | loss/std: 0.07327961113483113 | lr: 1.0000000000000004e-08 | momentum: 0.9
valid (148/150) accuracy: 0.834936261177063 | accuracy/std: 0.03595509354

In [43]:
!rm -r logdir

In [46]:
%reload_ext tensorboard
%tensorboard --logdir logdir/tensorboard/ --port 6000

Reusing TensorBoard on port 6000 (pid 769), started 0:05:42 ago. (Use '!kill 769' to kill it.)

In [None]:
!ls