**Get the training dataset**

In [3]:
# Download the zipped pose dataset. The URL is quoted so the shell treats
# the `?` in the query string literally instead of as a glob character.
!wget -O RNN-HAR-2D-Pose-database.zip "https://drive.google.com/u/1/uc?id=1IuZlyNjg6DMQE3iaO1Px6h1yLKgatynt"

--2021-07-29 15:19:47--  https://drive.google.com/u/1/uc?id=1IuZlyNjg6DMQE3iaO1Px6h1yLKgatynt
Resolving drive.google.com (drive.google.com)... 172.217.163.238, 2404:6800:4005:810::200e
Connecting to drive.google.com (drive.google.com)|172.217.163.238|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://doc-10-00-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/b0k70af4mu2q65l0puoj4afa647fn07v/1627546725000/01198563878503023152/*/1IuZlyNjg6DMQE3iaO1Px6h1yLKgatynt [following]
--2021-07-29 15:19:51--  https://doc-10-00-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/b0k70af4mu2q65l0puoj4afa647fn07v/1627546725000/01198563878503023152/*/1IuZlyNjg6DMQE3iaO1Px6h1yLKgatynt
Resolving doc-10-00-docs.googleusercontent.com (doc-10-00-docs.googleusercontent.com)... 142.250.199.65, 2404:6800:4005:80b::2001
Connecting to doc-10-00-docs.googleusercontent.com (doc-10-00-docs.googleusercontent.com)|142.250.1

In [4]:
# Extract the archive. -o overwrites files left by a previous extraction so
# re-running this cell never stalls on unzip's interactive "replace?" prompt.
!unzip -o RNN-HAR-2D-Pose-database.zip

Archive:  RNN-HAR-2D-Pose-database.zip
   creating: RNN-HAR-2D-Pose-database/
  inflating: RNN-HAR-2D-Pose-database/README.md  
  inflating: RNN-HAR-2D-Pose-database/X_val2.txt  
  inflating: RNN-HAR-2D-Pose-database/X_val.txt  
  inflating: RNN-HAR-2D-Pose-database/Y_train.txt  
  inflating: RNN-HAR-2D-Pose-database/Y_test.txt  
  inflating: RNN-HAR-2D-Pose-database/X_train.txt  
  inflating: RNN-HAR-2D-Pose-database/X_test.txt  


**List the first two rows of the training set**

In [6]:
# Peek at the first two comma-separated rows of the training features.
!head -n 2 RNN-HAR-2D-Pose-database/X_train.txt

295.914,161.579,307.693,203.413,281.546,203.368,274.997,251.562,267.194,293.253,337.619,204.669,347.958,255.443,341.541,295.866,286.81,289.393,297.196,355.832,297.22,405.371,321.967,291.959,327.143,358.408,328.528,411.922,294.546,156.42,305.002,156.418,0,0,318.083,161.632
295.855,161.6,307.684,203.408,281.529,203.385,274.989,251.574,267.191,291.961,337.615,204.646,347.974,254.209,344.093,295.816,286.803,289.377,297.165,355.827,297.205,404.095,323.248,290.652,324.564,358.409,328.493,410.63,293.252,157.686,303.706,157.706,0,0,318.024,161.654


**Clone the codebase**

In [None]:
# Fetch the LearnOpenCV repository; the project used by this notebook lives
# in one of its sub-directories.
!git clone https://github.com/spmallick/learnopencv.git

In [None]:
# Switch the working directory to the project folder; relative paths used in
# later cells resolve against it.
# NOTE(review): the `from src.lstm import ...` cell presumably relies on this
# directory being the cwd -- confirm.
%cd learnopencv/Human-Action-Recognition-Using-Detectron2-And-Lstm

**Install dependencies**

In [None]:
# %pip (rather than !pip) installs into the environment of the running
# kernel, so the packages are importable from this notebook.
%pip install -r requirements.txt

In [7]:
# Directory holding the extracted pose dataset (the X_*.txt / Y_*.txt files);
# used as the default value of the --data_root option.
# NOTE(review): this path is relative to the current working directory. The
# archive is unzipped before the `%cd` into the project folder, so on a fresh
# top-to-bottom run the data may not be found here -- confirm and adjust.
DATASET_PATH = "RNN-HAR-2D-Pose-database/"

In [8]:
from argparse import ArgumentParser

def configuration_parser(parent_parser):
    """Build a parser that adds this notebook's training options.

    The returned parser inherits every argument already registered on
    ``parent_parser`` and does not add a ``-h/--help`` option of its own.

    Args:
        parent_parser: ``argparse.ArgumentParser`` whose arguments are inherited.

    Returns:
        A new ``ArgumentParser`` with ``--batch_size``, ``--epochs``,
        ``--data_root``, ``--learning_rate`` and ``--num_class`` added.
    """
    # (flag, value type, default) for each option, in registration order.
    option_specs = (
        ('--batch_size', int, 512),
        ('--epochs', int, 400),
        ('--data_root', str, DATASET_PATH),
        ('--learning_rate', float, 0.0001),
        ('--num_class', int, 6),
    )
    child_parser = ArgumentParser(parents=[parent_parser], add_help=False)
    for flag, value_type, default_value in option_specs:
        child_parser.add_argument(flag, type=value_type, default=default_value)
    return child_parser

In [9]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor

from src.lstm import ActionClassificationLSTM, PoseDataModule

In [10]:
def do_training_validation():
    """Train the LSTM action classifier on the pose dataset and return it.

    Seeds the random generators, builds the run configuration from the
    Lightning ``Trainer`` arguments plus the options added by
    ``configuration_parser``, creates the model and data module, and fits the
    model with early-stopping, checkpoint and learning-rate-monitor callbacks.

    Returns:
        The ``ActionClassificationLSTM`` instance that was passed to
        ``trainer.fit``.
    """
    # Local import keeps this block self-contained; torch is already a
    # dependency of pytorch_lightning.
    import torch

    pl.seed_everything(21)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = configuration_parser(parser)
    # parse_known_args (rather than parse_args) so that argv entries this
    # parser does not define -- e.g. ones supplied by the notebook kernel --
    # are ignored instead of aborting the run.
    args, _unknown = parser.parse_known_args()

    # NOTE(review): 34 and 50 are passed positionally; presumably the input
    # feature size and the hidden size -- confirm against src.lstm.
    model = ActionClassificationLSTM(34, 50, learning_rate=args.learning_rate)
    data_module = PoseDataModule(data_root=args.data_root,
                                 batch_size=args.batch_size)

    # Keep only the single best checkpoint, ranked by val_loss.
    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor='val_loss')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    # NOTE(review): early stopping watches train_loss while checkpointing
    # watches val_loss -- confirm this asymmetry is intended.
    early_stopping = EarlyStopping(monitor='train_loss', patience=15)

    # Use one GPU when CUDA is usable; otherwise fall back to the CPU instead
    # of failing on hosts without a GPU.
    gpus = 1 if torch.cuda.is_available() else 0

    # Keyword arguments override the corresponding values parsed into `args`.
    # Add fast_dev_run=True here for a quick smoke test of the pipeline.
    trainer = pl.Trainer.from_argparse_args(
        args,
        max_epochs=args.epochs,
        deterministic=True,
        gpus=gpus,
        progress_bar_refresh_rate=1,
        callbacks=[early_stopping, checkpoint_callback, lr_monitor])
    trainer.fit(model, data_module)
    return model

In [11]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard inline, reading event files from the lightning_logs folder
%tensorboard --logdir=lightning_logs

**Training**



In [None]:
# Run the full training loop. The returned model is not assigned to a
# variable; as the cell's last expression it is only echoed as cell output.
do_training_validation()

**Get the saved model**

In [13]:
import os
def get_latest_run_version_ckpt_epoch_no(lightning_logs_dir='lightning_logs', run_version=None):
    """Return the path of a checkpoint file saved by a Lightning run.

    Args:
        lightning_logs_dir: directory containing the ``version_<N>`` run folders.
        run_version: run number to inspect. When ``None``, the highest ``N``
            among the ``version_<N>`` folders is used (0 if there are none).

    Returns:
        Path of the most recently modified ``.ckpt`` file inside
        ``<lightning_logs_dir>/version_<run_version>/checkpoints``, or ``None``
        when that folder holds no ``.ckpt`` file.

    Raises:
        OSError: propagated from ``os.listdir`` when the logs directory or the
            checkpoints directory does not exist.
    """
    if run_version is None:
        run_version = 0
        for dir_name in os.listdir(lightning_logs_dir):
            # Only names of the exact form "version_<digits>" count; anything
            # else in the logs directory is ignored rather than crashing int().
            prefix, _, suffix = dir_name.partition('_')
            if prefix == 'version' and suffix.isdecimal():
                run_version = max(run_version, int(suffix))
    checkpoints_dir = os.path.join(lightning_logs_dir, 'version_{}'.format(run_version), 'checkpoints')
    ckpt_paths = []
    for file in os.listdir(checkpoints_dir):
        print(file)
        if file.endswith('.ckpt'):
            ckpt_paths.append(os.path.join(checkpoints_dir, file))
    if not ckpt_paths:
        # Report and return None instead of hitting an unbound local variable.
        print('CKPT file is not present')
        return None
    # With several checkpoints, pick the newest one so the result does not
    # depend on the arbitrary ordering of os.listdir.
    return max(ckpt_paths, key=os.path.getmtime)

In [None]:
# Look up a checkpoint from the highest-numbered run under lightning_logs
# and report its path.
ckpt_path = get_latest_run_version_ckpt_epoch_no()
print('The latest model path: {}'.format(ckpt_path))