In [1]:
import numpy as np
import pandas as pd

%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import random


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device: " + DEVICE)
if torch.backends.cudnn.is_available():
    torch.backends.cudnn.enabled = True

SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

Using device: cuda


## Preparing data

In [2]:
def read_data(file_path_X, file_path_y, seq_length):
    X_data = np.array(pd.read_csv(file_path_X, header=None))
    y_data = np.array(pd.read_csv(file_path_y, header=None))
    
    # Shift labels in PyTorch classification models labels start from 0
    y_data = y_data - 1
    
    blocks = X_data.shape[0] / seq_length
    
    X_seq = np.array(np.split(X_data, blocks, axis=0))
    
    return X_seq, y_data

In [3]:
DATASETS_DIR = "./poses/"
SEQ_LENGTH = 32

X_train, y_train = read_data(DATASETS_DIR + "X_train.txt",
                             DATASETS_DIR + "Y_train.txt",
                             SEQ_LENGTH)

# In repo there aren't labels for validation set, so use test as validation for now
X_val, y_val = read_data(DATASETS_DIR + "X_test.txt",
                         DATASETS_DIR + "Y_test.txt",
                         SEQ_LENGTH)


In [4]:
train_dataset = torch.utils.data.TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                                               torch.tensor(y_train, dtype=torch.long).squeeze())

val_dataset = torch.utils.data.TensorDataset(torch.tensor(X_val, dtype=torch.float32),
                                             torch.tensor(y_val, dtype=torch.long).squeeze())

## NN model

In [6]:
class LstmClassifier(nn.Module):
    
    def __init__(self, input_dim, n_classes, lstm_hidden_dim=256, fc_hidden_dim=256, n_lstm_layers=2):
        super(LstmClassifier, self).__init__()
        
        self._lstm = nn.LSTM(input_size=input_dim,
                             hidden_size=lstm_hidden_dim,
                             num_layers=n_lstm_layers,
                             batch_first=True)
        
        self._fc = nn.Sequential(nn.Linear(lstm_hidden_dim, fc_hidden_dim),
                                 nn.ReLU(),
                                 nn.Linear(fc_hidden_dim, n_classes))
        
    def forward(self, x):
        lstm_output, _ = self._lstm.forward(x)
        lstm_output = lstm_output[:, -1, :]
        fc_output = self._fc.forward(lstm_output)
        return fc_output

## Training utils

In [12]:
def run_epoch(model, optimizer, criterion, batches, phase='train'):
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    epoch_loss = 0.0
    n_predictions = 0
    
    correct_predictions = 0

    for X_batch, y_batch in batches:
        X_batch = X_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)

        with torch.set_grad_enabled(is_train):
            y_pred = model.forward(X_batch)
            loss = criterion.forward(y_pred, y_batch)
    
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss += loss.item() * y_batch.shape[0]
        correct_predictions += (torch.argmax(y_pred, dim=1) == y_batch).sum().item()
        n_predictions += y_batch.shape[0]

    epoch_loss = epoch_loss / n_predictions
    epoch_accuracy = correct_predictions / n_predictions

    return epoch_loss, epoch_accuracy


def train_model(model, optimizer, criterion, n_epoch, batch_size, train_dataset, val_dataset, backup_name):
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    best_val_loss = np.inf
    
    train_batches = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
    val_batches = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=4, shuffle=False)

    for epoch in range(n_epoch):
        train_loss, train_accuracy = run_epoch(model, optimizer, criterion, train_batches, phase='train')
        val_loss, val_accuracy = run_epoch(model, optimizer, criterion, val_batches, phase='val')

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), backup_name)

        print("Epoch: " + str(epoch))
        print("Train loss: " + str(train_loss) + ", accuracy: " + str(train_accuracy))
        print("Val loss: " + str(val_loss) + ", accuracy: " + str(val_accuracy) + "\n\n")
        
    return train_losses, train_accuracies, val_losses, val_accuracies

## Training process

In [12]:
N_CLASSES = 6
INPUT_DIM = 36

In [13]:
model = LstmClassifier(INPUT_DIM, N_CLASSES)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

train_model(model,
            optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
            criterion=nn.CrossEntropyLoss(),
            n_epoch=100,
            batch_size=200,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            backup_name="lstm_action_classifier.pth.tar")

Epoch: 0
Train loss: 0.8735486226845841, accuracy: 0.6489281767955801
Val loss: 0.6211841620039762, accuracy: 0.7440445139975657


Epoch: 1
Train loss: 0.48130274497343034, accuracy: 0.8121988950276243
Val loss: 0.45876633668316125, accuracy: 0.8125543383759346


Epoch: 2
Train loss: 0.3702782436299719, accuracy: 0.8585193370165746
Val loss: 0.5034662859063835, accuracy: 0.8101199791340636


Epoch: 3
Train loss: 0.3491351202047991, accuracy: 0.8681546961325967
Val loss: 0.27011767895693073, accuracy: 0.9012345679012346


Epoch: 4
Train loss: 0.2769810041340675, accuracy: 0.891889502762431
Val loss: 0.313655665158899, accuracy: 0.869414014953921


Epoch: 5
Train loss: 0.20154464583205914, accuracy: 0.9223425414364641
Val loss: 0.20805899514211246, accuracy: 0.9231438010780734


Epoch: 6
Train loss: 0.2445332597879415, accuracy: 0.906121546961326
Val loss: 0.3991670820106114, accuracy: 0.8497652582159625


Epoch: 7
Train loss: 0.17224384997431086, accuracy: 0.9344530386740332
Val loss: 0

([0.8735486226845841,
  0.48130274497343034,
  0.3702782436299719,
  0.3491351202047991,
  0.2769810041340675,
  0.20154464583205914,
  0.2445332597879415,
  0.17224384997431086,
  0.1885981674023096,
  0.19384880530241444,
  0.19337554822639866,
  0.16536105377239416,
  0.1571138994552154,
  0.15273290199140158,
  0.10815980746436514,
  0.12133629483412643,
  0.13833611768088946,
  0.1293170387481046,
  0.10733797376491747,
  0.10590762594786797,
  0.0918239034835805,
  0.1047938061950121,
  0.07058675783115197,
  0.08766820735364987,
  0.10277271190071632,
  0.12322432884653176,
  0.08019538524956828,
  0.08205282129964776,
  0.07426594978620334,
  0.0755969595853548,
  0.09760807308867492,
  0.07207354542732404,
  0.0975812500682325,
  0.09537964315045605,
  0.09366875433450493,
  0.04696851966758542,
  0.0512240380254569,
  0.05459230239451631,
  0.08089525441829507,
  0.06131623744944018,
  0.06199158364649636,
  0.051481724839184166,
  0.06538778840672245,
  0.0629294374865063,
 

In [14]:
train_losses, train_accuracies, val_losses, val_accuracies = ([0.8735486226845841,
  0.48130274497343034,
  0.3702782436299719,
  0.3491351202047991,
  0.2769810041340675,
  0.20154464583205914,
  0.2445332597879415,
  0.17224384997431086,
  0.1885981674023096,
  0.19384880530241444,
  0.19337554822639866,
  0.16536105377239416,
  0.1571138994552154,
  0.15273290199140158,
  0.10815980746436514,
  0.12133629483412643,
  0.13833611768088946,
  0.1293170387481046,
  0.10733797376491747,
  0.10590762594786797,
  0.0918239034835805,
  0.1047938061950121,
  0.07058675783115197,
  0.08766820735364987,
  0.10277271190071632,
  0.12322432884653176,
  0.08019538524956828,
  0.08205282129964776,
  0.07426594978620334,
  0.0755969595853548,
  0.09760807308867492,
  0.07207354542732404,
  0.0975812500682325,
  0.09537964315045605,
  0.09366875433450493,
  0.04696851966758542,
  0.0512240380254569,
  0.05459230239451631,
  0.08089525441829507,
  0.06131623744944018,
  0.06199158364649636,
  0.051481724839184166,
  0.06538778840672245,
  0.0629294374865063,
  0.04425359480803349,
  0.05855927577252546,
  0.08699687327160525,
  0.07487187519443134,
  0.05951084620550851,
  0.05211047360593793,
  0.0654746382687632,
  0.07681369263658207,
  0.044812741151172154,
  0.04169330514596971,
  0.056397179180745924,
  0.06540772032921424,
  0.07633163423087057,
  0.055633779746886776,
  0.06838037114775641,
  0.053590912112544256,
  0.0628714390222539,
  0.045393846311621903,
  0.05751636609517408,
  0.08074402875827821,
  0.04321444615804029,
  0.05079126861872594,
  0.050103869856186006,
  0.03507644276016325,
  0.03688339269977289,
  0.041359042481246576,
  0.03544109247172077,
  0.03485896200435596,
  0.04210835051470699,
  0.07559601816384749,
  0.04357416177816813,
  0.030860471100808837,
  0.036541297235294604,
  0.03869030817890365,
  0.036192047033067894,
  0.05795359878571838,
  0.040659130261747875,
  0.04481880535572273,
  0.06331975865306445,
  0.04256203787514518,
  0.03762714751888046,
  0.0230426168100116,
  0.03029545205979716,
  0.04663788862835142,
  0.04618885129772497,
  0.06090077068450866,
  0.05378697729046914,
  0.05114451131438682,
  0.0467184322129791,
  0.05389563671920387,
  0.048165335642040105,
  0.04513653215453111,
  0.037053695992046956,
  0.07355283905870348,
  0.05021775581395428,
  0.05387003534121017],
 [0.6489281767955801,
  0.8121988950276243,
  0.8585193370165746,
  0.8681546961325967,
  0.891889502762431,
  0.9223425414364641,
  0.906121546961326,
  0.9344530386740332,
  0.9319779005524862,
  0.927646408839779,
  0.925878453038674,
  0.9390055248618785,
  0.9413480662983426,
  0.9398453038674033,
  0.9611491712707182,
  0.9558895027624309,
  0.9517348066298342,
  0.9534585635359116,
  0.9620773480662983,
  0.9622099447513812,
  0.9653038674033149,
  0.9622541436464088,
  0.9747624309392265,
  0.9688397790055249,
  0.963403314917127,
  0.9549613259668508,
  0.9732154696132597,
  0.9703867403314917,
  0.974232044198895,
  0.9735690607734807,
  0.9656574585635359,
  0.9750276243093923,
  0.964817679558011,
  0.9656132596685083,
  0.9657016574585635,
  0.9833370165745856,
  0.9824530386740331,
  0.9800220994475138,
  0.9707845303867403,
  0.9776795580110498,
  0.9775027624309393,
  0.9813480662983426,
  0.9775911602209945,
  0.9772375690607735,
  0.9844861878453038,
  0.9794475138121547,
  0.9695469613259668,
  0.9748066298342541,
  0.9797569060773481,
  0.981524861878453,
  0.9777237569060774,
  0.9734806629834254,
  0.9838232044198895,
  0.9848397790055249,
  0.980243093922652,
  0.9768839779005525,
  0.9737016574585635,
  0.9805966850828729,
  0.9762651933701657,
  0.9809060773480663,
  0.978342541436464,
  0.9840441988950276,
  0.97953591160221,
  0.9716685082872928,
  0.9851491712707182,
  0.9826298342541436,
  0.9824972375690608,
  0.9877127071823204,
  0.9868729281767956,
  0.9853259668508287,
  0.9874475138121547,
  0.9878453038674033,
  0.9854143646408839,
  0.9754696132596685,
  0.9849723756906077,
  0.9887734806629834,
  0.9881546961325967,
  0.9861215469613259,
  0.9877569060773481,
  0.9792707182320441,
  0.9848397790055249,
  0.984707182320442,
  0.9776795580110498,
  0.984707182320442,
  0.9879779005524861,
  0.9927513812154696,
  0.9897458563535911,
  0.9848839779005525,
  0.9849281767955801,
  0.9796685082872928,
  0.9809502762430939,
  0.9824530386740331,
  0.9832044198895028,
  0.9803314917127072,
  0.9834254143646409,
  0.9847513812154696,
  0.9872265193370166,
  0.9749834254143647,
  0.9823204419889503,
  0.982718232044199],
 [0.6211841620039762,
  0.45876633668316125,
  0.5034662859063835,
  0.27011767895693073,
  0.313655665158899,
  0.20805899514211246,
  0.3991670820106114,
  0.32158903264995403,
  0.24071570668090958,
  0.2567666179594932,
  0.3852716677423312,
  0.4597050237921575,
  0.17627169369967327,
  0.23679502981565434,
  0.17220754785561754,
  0.18002146034804078,
  0.22931813937735357,
  0.16827500484832936,
  0.15206738518692042,
  0.14282209915193997,
  0.18685480398813387,
  0.11685570925779426,
  0.20067864179665781,
  0.1411727972821085,
  0.1530777178196244,
  0.13876111770846125,
  0.14196548254774952,
  0.2708777552003365,
  0.11546770097299551,
  0.3046354775860649,
  0.10854558102802168,
  0.07933161259128273,
  0.12404691129011139,
  0.38784327351626957,
  0.07989218154476951,
  0.17622217972792195,
  0.09162371938435548,
  0.13405118411540942,
  0.10392435214123062,
  0.10115272597895769,
  0.09087322915753134,
  0.1590506696334804,
  0.22317246896386603,
  0.08073770713501655,
  0.058100116569523644,
  0.14888879645115743,
  0.13044291378562928,
  0.07736094586014841,
  0.11220023536174023,
  0.13659915878864518,
  0.07129768802977254,
  0.20124599370539134,
  0.07354985906765003,
  0.08835384215030306,
  0.21074753585173564,
  0.09616019399096185,
  0.11539241931760864,
  0.120142063001071,
  0.0775641988276421,
  0.09695740938079764,
  0.09058009887954965,
  0.06536958760441548,
  0.14220434166765117,
  0.07930320302022702,
  0.08943844097148232,
  0.06493218757685769,
  0.08265407143692353,
  0.0510503365727251,
  0.06409464690922057,
  0.0902351198168243,
  0.06698930089201968,
  0.07941066838737426,
  0.27309605946031246,
  0.10141229745728485,
  0.08438794725363463,
  0.06564738088682306,
  0.07934415308124455,
  0.11631594676814944,
  0.11903673065906457,
  0.09058221255546416,
  0.12191057274405673,
  0.1233115003656721,
  0.10492566676817347,
  0.11040246008239304,
  0.0784957536592911,
  0.05796452499581156,
  0.06605394684588234,
  0.08701971364890956,
  0.23499713032428876,
  0.14679802382028134,
  0.0931847726414798,
  0.1281534519503502,
  0.09721326736259316,
  0.0960857172550666,
  0.1704642465234064,
  0.11685481813445657,
  0.14966437045596195,
  0.22490323309782892,
  0.09092609374476246,
  0.08504296933522287],
 [0.7440445139975657,
  0.8125543383759346,
  0.8101199791340636,
  0.9012345679012346,
  0.869414014953921,
  0.9231438010780734,
  0.8497652582159625,
  0.892018779342723,
  0.9128847157016171,
  0.9021039819161885,
  0.8753260302556077,
  0.8575899843505478,
  0.9285341679707877,
  0.9189706138062945,
  0.9384454877412624,
  0.9398365501651886,
  0.9163623717614328,
  0.94679186228482,
  0.9419231438010781,
  0.942792557816032,
  0.9297513475917232,
  0.952529994783516,
  0.9400104329681794,
  0.9497478699356634,
  0.9521822291775344,
  0.9539210572074421,
  0.9410537297861241,
  0.9031472787341331,
  0.9601808381151105,
  0.8941053729786124,
  0.9612241349330551,
  0.9688749782646496,
  0.9575725960702487,
  0.8626325856372805,
  0.9680055642496957,
  0.9514866979655712,
  0.9640062597809077,
  0.952529994783516,
  0.9575725960702487,
  0.9641801425838985,
  0.9650495565988524,
  0.9384454877412624,
  0.9297513475917232,
  0.9683533298556772,
  0.9746131107633456,
  0.9436619718309859,
  0.9544427056164145,
  0.9688749782646496,
  0.9613980177360459,
  0.9547904712223961,
  0.9721787515214745,
  0.9337506520605112,
  0.9754825247782994,
  0.968527212658668,
  0.9198400278212485,
  0.9629629629629629,
  0.9573987132672579,
  0.9657450878108155,
  0.97339593114241,
  0.96452790818988,
  0.9657450878108155,
  0.9735698139454008,
  0.9497478699356634,
  0.9697443922796035,
  0.9680055642496957,
  0.9721787515214745,
  0.970961571900539,
  0.9780907668231612,
  0.974960876369327,
  0.9699182750825943,
  0.9723526343244653,
  0.9711354547035298,
  0.9167101373674144,
  0.9673100330377326,
  0.9711354547035298,
  0.9798295948530691,
  0.9725265171274561,
  0.9643540253868892,
  0.9551382368283776,
  0.9697443922796035,
  0.9587897756911842,
  0.9591375412971657,
  0.970961571900539,
  0.9648756737958616,
  0.9720048687184838,
  0.9810467744740046,
  0.975830290384281,
  0.9728742827334377,
  0.9300991131977048,
  0.9553121196313684,
  0.9706138062945575,
  0.9601808381151105,
  0.9674839158407234,
  0.9673100330377326,
  0.9394887845592071,
  0.9636584941749261,
  0.9558337680403408,
  0.9257520431229351,
  0.9659189706138063,
  0.970961571900539])