# Imports

In [1]:
# Allow importing items in src folder.
# The project root can be overridden with the BSLTRANSLATE_MODEL_DIR environment
# variable so the notebook is not tied to one machine's absolute path.
import os
import sys

SCRIPT_DIR = os.environ.get("BSLTRANSLATE_MODEL_DIR", "/home/shane/Projects/bsltranslate-model")
# Guard the append so re-running this cell does not add duplicate entries.
if SCRIPT_DIR not in sys.path:
    sys.path.append(SCRIPT_DIR)

In [2]:
from datetime import datetime
from os import listdir, makedirs, mkdir

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.mobile_optimizer import optimize_for_mobile

In [3]:
from src.files import read_vocab_file, load_pickle, save_pickle, save_csv
from src.data_proc import normalize_data, tokenize
from src.training import split_by_video, to_device, DeviceDataLoader, get_optimizer, set_learning_rates, \
                         train_model, valid_model, calc_class_weights
from src.datasets import *
from src.models import *
from src.utilities import lr_finder

# Variables

In [4]:
# Pipeline switches: which optional stages of this notebook run.
run_lr_finder = False         # plot loss vs. learning rate for each candidate weight decay
train = True                  # build a fresh model and run the training loop
print_per_sign_stats = True   # print per-sign validation accuracy after every epoch
save = True                   # export the desktop (state_dict) and mobile (TorchScript) models

In [5]:
# Dataset location, relative to the notebook's working directory.
dataset_folder = '../../bsltranslate-test-dataset/'
data_folder = f"{dataset_folder}video_data"   # one pickle per video
vocab_file = f"{dataset_folder}vocab.csv"

In [6]:
# Model identity and input geometry.
model_name = 'cnn_3d'       # embedded in every saved file name
fps = 7                     # frames per sign sample (passed as frames_per_sign; frame axis of the export input)
in_chan = 1                 # input channels given to the model constructor
no_of_coords = 126          # coordinate values per frame — presumably 2 hands x 21 points x 3 dims; confirm
dims = 3                    # coordinate dimensions; no_of_coords is split into dims x (no_of_coords / dims)

# Training setup.
validation_pct = 0.15       # fraction of the data held out for validation (split by video)
batch_size = 16_384
learning_rate = 1e-2        # passed to both the optimizer and the learning-rate schedule
weight_decay = 0
epochs = 25

In [7]:
# Constructors for the model, dataset and loss; swap these to try another setup.
model_fn, dataset_fn, loss_func = Cnn_3d, DatasetStream3d, nn.CrossEntropyLoss

In [8]:
# Parent folder for exports; each save creates a timestamped sub-folder inside it.
model_output_folder = "../models"

In [9]:
# Candidate weight decays swept by the LR finder (one loss plot each).
lr_finder_weight_decays = [
    0,
    1e-4,
    1e-3,
    1e-2,
    1e-1,
]

# Load MediaPipe Data - Individual Videos

In [10]:
# Load every per-video pickle. Element [0] of each loaded object becomes the frame
# data and element [1] the metadata.
# Files are sorted so the concatenation order (and anything downstream that depends
# on row order, such as a seeded split) is the same on every machine — os.listdir
# returns entries in arbitrary order.
# NOTE(review): load_pickle presumably unpickles, which can execute code; only point
# data_folder at trusted data.
file_list = sorted(listdir(data_folder))
data = []
metadata = []
for file in file_list:
    data_import = load_pickle(data_folder + "/" + file)
    data.append(pd.DataFrame(data_import[0]))
    metadata.append(pd.DataFrame(data_import[1]))
# Concatenate once (not inside the loop) to avoid quadratic copying.
# NOTE(review): pd.concat keeps each file's own index, so the combined index may
# contain duplicates; the column-wise concat before the split relies on data_df and
# metadata_df sharing an identical index.
data_df = pd.concat(data)
metadata_df = pd.concat(metadata)

# Normalize Data

In [11]:
# Normalise the frame data. The stats are kept because the datasets and both
# exported models reuse them.
# NOTE(review): the stats are computed on all videos, before the train/validation
# split, so validation data influences them.
(normalized_data_df,
 normalization_stats) = normalize_data(data_df)

# Tokenize Labels

In [12]:
# Vocabulary of sign labels; its length is the number of classes the model predicts.
vocab = read_vocab_file(vocab_file)
no_lbls = len(vocab)  # also sizes the class-weight vector

In [13]:
# Replace each text label with its tokenized form.
# NOTE(review): this overwrites metadata_df['label'] in place, so re-running the cell
# tokenizes already-tokenized values — rerun from the load cell instead.
metadata_df['label'] = list(map(lambda raw_label: tokenize(raw_label, vocab), metadata_df['label']))

# Create Training & Validation Set

In [14]:
# Join the normalised frame data with its metadata column-wise, then split.
# Judging by the name, split_by_video holds out whole videos for validation.
# NOTE(review): seed_no=0.05 is an unusual seed value — confirm it is intended.
train_df, valid_df = split_by_video(
    pd.concat([normalized_data_df, metadata_df], axis=1),
    validation_pct,
    seed_no=0.05,
)

In [15]:
# Videos used for training.
pd.unique(train_df['vid_fname'])

array(['../videos/Alphabet36.mp4', '../videos/Alphabet41.mp4',
       '../videos/Alphabet38.mp4', '../videos/Alphabet31.mp4',
       '../videos/Alphabet16.mp4', '../videos/Alphabet07.webm',
       '../videos/Alphabet35.mp4', '../videos/Alphabet08.mp4',
       '../videos/Alphabet23.mp4', '../videos/Alphabet15.mp4',
       '../videos/Alphabet28.mp4', '../videos/Alphabet32.mp4',
       '../videos/Alphabet34.mp4', '../videos/Alphabet49.mp4',
       '../videos/Alphabet29.mp4', '../videos/Alphabet14.mp4',
       '../videos/Alphabet50.mp4', '../videos/Alphabet43.mp4',
       '../videos/Alphabet44.mp4', '../videos/Alphabet06.webm',
       '../videos/Alphabet03.mp4', '../videos/Alphabet33.mp4',
       '../videos/Alphabet12.mp4', '../videos/Alphabet19.mp4',
       '../videos/Alphabet27.mp4', '../videos/Alphabet30.mp4',
       '../videos/Alphabet09.mp4', '../videos/Alphabet02.mp4',
       '../videos/Alphabet11.mp4', '../videos/Alphabet48.mp4',
       '../videos/Alphabet05.mp4', '../videos/Alphabe

In [16]:
# Videos held out for validation. Fail loudly if any video appears in both splits:
# frames from a shared video would make validation accuracy look optimistic.
assert set(train_df['vid_fname']).isdisjoint(valid_df['vid_fname']), "train/validation splits share videos"
valid_df['vid_fname'].unique()

array(['../videos/Alphabet24.mp4', '../videos/Alphabet39.mp4',
       '../videos/Alphabet46.mp4', '../videos/Alphabet37.mp4',
       '../videos/Alphabet20.mp4', '../videos/Alphabet01.mp4',
       '../videos/Alphabet26.mp4'], dtype=object)

# Create Dataset

In [17]:
# Build the train/validation datasets with the same normalisation stats;
# samples presumably span `fps` frames (frames_per_sign).
train_ds, valid_ds = (
    dataset_fn(split_df, normalization_stats, frames_per_sign=fps)
    for split_df in (train_df, valid_df)
)

# Create Dataloaders

In [18]:
# Wrap each dataset in a batching DataLoader, then in the project's DeviceDataLoader.
# NOTE(review): the training loader is not shuffled — confirm this is intended. If
# dataset_fn yields an IterableDataset, shuffle=True would raise, so check first.
train_dl, valid_dl = (
    DeviceDataLoader(DataLoader(split_ds, batch_size=batch_size, shuffle=False))
    for split_ds in (train_ds, valid_ds)
)

# Calculate Class Weights

Weighting the loss per class is intended to offset the class imbalance: NaS has many times more examples in the dataset than any other sign.

In [19]:
# Per-class loss weights computed from the training split (one entry per vocab label).
class_weights = calc_class_weights(train_df, no_lbls)
class_weights.shape

torch.Size([27])

# Run Learning Rate Finder (Optional)

In [20]:
if run_lr_finder:
    for wd in lr_finder_weight_decays:
        # Build a fresh model per weight decay so every sweep starts from an untrained
        # state, in case lr_finder leaves the weights trained by the previous sweep.
        model = model_fn(no_lbls, in_channels=in_chan)
        to_device(model)
        # Put the class weights on the model's device rather than hard-coding .cuda(),
        # which fails on machines without CUDA.
        model_device = next(model.parameters()).device
        loss_fn = loss_func(weight=class_weights.to(model_device))
        lrs, losses = lr_finder(train_dl, model, loss_fn, weight_decay=wd)
        fig, ax = plt.subplots()
        ax.plot(lrs, losses)
        ax.set(xlabel='Learning Rates (10^)', ylabel='Loss', title="Weight Decay: {}".format(wd))
        plt.show()

# Create Model

In [21]:
if train:
    # Seed weight initialisation so repeated runs start from the same weights.
    torch.manual_seed(0)
    model = model_fn(no_lbls, in_channels=in_chan)
    to_device(model)
    optim = get_optimizer(model, lr=learning_rate, wd=weight_decay)
    # Match the model's device instead of hard-coding .cuda(). This way the cell does
    # not fail on a machine without CUDA (provided to_device falls back to the CPU).
    loss_fn = loss_func(weight=class_weights.to(next(model.parameters()).device))

# Train

In [22]:
if train:
    # Learning-rate schedule built up front for every optimisation step of every epoch
    # (presumably one rate per batch — confirm against set_learning_rates).
    lrs = set_learning_rates(epochs * len(train_dl), learning_rate)
    for epoch in range(epochs):
        print("Epoch: ", epoch)
        loss = train_model(model, optim, loss_fn, train_dl, lrs, epoch)
        print("Training Loss: ", loss)

        valid_loss, acc, sign_correct = valid_model(model, valid_dl, loss_fn, no_lbls)
        print("Validation Loss: ", valid_loss)
        print("Accuracy: ", acc)

        if not print_per_sign_stats:
            continue
        # Pair each vocab key with its accuracy and print the label stored under that key.
        print("Per sign accuracy:")
        for vocab_key, sign_acc in zip(vocab, sign_correct):
            print("{}: {}".format(vocab[vocab_key], sign_acc))

Epoch:  0
Training Loss:  2.6450214439150286
Validation Loss:  3.0408025283949622
Accuracy:  0.08923990819495073
Per sign accuracy:
NaS: 0.9377063777687085
a: 0.1794871794871795
b: 0.05739795918367347
c: 0.004250797024442083
d: 0.0
e: 0.0
f: 0.0
g: 0.2984771573604061
h: 0.0029895366218236174
i: 0.0
j: 0.03
k: 0.0
l: 0.0
m: 0.44216867469879517
n: 0.0
o: 0.0
p: 0.0
q: 0.4342273307790549
r: 0.0
s: 0.20745131244707873
t: 0.18148487626031165
y: 0.14832535885167464
u: 0.0
v: 0.0
w: 0.0
x: 0.0
z: 0.21114369501466276
Epoch:  1
Training Loss:  1.8204613629210133
Validation Loss:  2.692819075025965
Accuracy:  0.24787363304981774
Per sign accuracy:
NaS: 0.6630059733610358
a: 0.014652014652014652
b: 0.6122448979591837
c: 0.9298618490967057
d: 0.6590909090909091
e: 0.0
f: 0.3088942307692308
g: 0.0
h: 0.08221225710014948
i: 0.031496062992125984
j: 0.37
k: 0.0
l: 0.28614008941877794
m: 0.3819277108433735
n: 0.3333333333333333
o: 0.00663716814159292
p: 0.5493562231759657
q: 0.35887611749680715
r: 0.10

# Save the Model (for desktop + mobile)

In [23]:
def save_models(model_output_folder, model_name, desktop_model, mobile_model, normalization_stats, vocab):
    """Save desktop and mobile variants of a trained model into a new timestamped folder.

    Layout created under ``<model_output_folder>/<YYYYmmddHHMMSS>/``:

    * ``desktop/``: normalisation stats (pickle), the model ``state_dict`` and the vocab (CSV).
    * ``mobile/``: the TorchScript module, the vocab (CSV) and the normalisation stats (CSV).

    Args:
        model_output_folder: Parent folder for saved models; created if it does not exist.
        model_name: Embedded in every file name as ``stream_<model_name>...``.
        desktop_model: Module whose ``state_dict()`` is written with ``torch.save``.
        mobile_model: TorchScript module written with ``torch.jit.save``.
        normalization_stats: Stats written via the project's save_pickle / save_csv helpers.
        vocab: Vocabulary written via the project's save_csv helper.

    Returns:
        The path of the timestamped folder that was created.

    Raises:
        FileExistsError: If the timestamped folder already exists (two saves within
            the same second). An earlier save is never overwritten.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    model_save_dir = model_output_folder + "/" + timestamp
    # makedirs also creates model_output_folder on a fresh checkout, but still raises
    # if the timestamped folder itself already exists.
    makedirs(model_save_dir)

    # Save desktop model
    mkdir(model_save_dir + "/desktop")
    save_pickle(normalization_stats, "{}/desktop/stream_{}_norm_stats.pkl".format(model_save_dir, model_name))
    torch.save(desktop_model.state_dict(), "{}/desktop/stream_{}.pt".format(model_save_dir, model_name))
    save_csv(vocab, "{}/desktop/stream_{}_vocab.csv".format(model_save_dir, model_name))

    # Save mobile model
    mkdir(model_save_dir + "/mobile")
    torch.jit.save(mobile_model, "{}/mobile/stream_{}.pt".format(model_save_dir, model_name))
    save_csv(vocab, "{}/mobile/stream_{}_vocab.csv".format(model_save_dir, model_name))
    save_csv(normalization_stats, "{}/mobile/stream_{}_norm_stats.csv".format(model_save_dir, model_name))

    return model_save_dir

In [24]:
if save:
    # Tracing records the ops executed on the example input, and optimize_for_mobile
    # folds conv/batch-norm. Switch to eval mode first so training-only behaviour
    # (e.g. dropout, batch-norm statistic updates) is not baked into the export.
    # Note: this cell leaves `model` on the CPU in eval mode.
    model.eval()
    assert no_of_coords % dims == 0, "no_of_coords must split evenly into dims"
    # Shape: (batch=1, channels, frames, dims, coords per dim).
    dummy_input = torch.rand(1, in_chan, fps, dims, no_of_coords // dims)
    android_model = torch.jit.trace(model.to("cpu"), dummy_input)
    torchscript_optim_android_model = optimize_for_mobile(android_model)
    # Need to move mobile model to CPU before saving or the app crashes
    save_models(model_output_folder, model_name, model, torchscript_optim_android_model.cpu(), normalization_stats, vocab)