In [1]:
# Imports

In [None]:
# All imports for the notebook (utils, dataloader, model, config, training).
# `from __future__` imports must be the first statements of the cell,
# otherwise Python raises a SyntaxError.
from __future__ import absolute_import
from __future__ import print_function

# Standard library
import configparser
import json
import logging
import math
import os
import random

# Third-party
import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.utils.multiclass import unique_labels
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset

# Project-local
import utils
#import utils
#from configs import Config
#from model import GCNMultiBlock
#from dataloader import Sign_Dataset

In [None]:
# Utils

In [None]:
#=====encoder and decoder functions=====

#to convert labels to categories (list of int)
def labels2cat(label_encoder, list):
    """Convert label names into their integer category ids.

    Args:
        label_encoder: fitted encoder exposing ``transform``.
        list: labels to convert (the name shadows the builtin but is kept
            so keyword callers keep working).

    Returns:
        The result of ``label_encoder.transform`` for the given labels.
    """
    categories = label_encoder.transform(list)
    return categories

#to encode labels to one-hot vectors
def labels2onehot(onehot_encoder, label_encoder, labels):
    """Encode label names as dense one-hot rows.

    The labels are first mapped to category ids by ``label_encoder``,
    reshaped into a single column, one-hot encoded by ``onehot_encoder`` and
    finally densified with ``toarray()``.
    """
    category_column = label_encoder.transform(labels).reshape(-1, 1)
    encoded = onehot_encoder.transform(category_column)
    return encoded.toarray()

#to decode one-hot to labels
def onehot2labels(label_encoder, y_onehot):
    """Decode one-hot rows back into a list of label names.

    The column index of every entry equal to 1 is taken as the category id
    and passed to ``label_encoder.inverse_transform``.
    """
    hot_positions = np.where(y_onehot == 1)
    category_ids = hot_positions[1]  # second axis = column (class) index
    decoded = label_encoder.inverse_transform(category_ids)
    return decoded.tolist()

#to convert categories to labels
def cat2labels(label_encoder, y_cat):
    """Convert integer category ids back into a list of label names."""
    decoded = label_encoder.inverse_transform(y_cat)
    return decoded.tolist()


# ------------------ RNN utils ------------------------##
def init_gru(gru):
    """Re-initialise the parameters of a GRU / GRUCell in place.

    Parameters with two or more dimensions receive an orthogonal
    initialisation, the remaining (1-D) ones are drawn from a standard
    normal. Any other module type is left untouched.
    """
    if not isinstance(gru, (nn.GRU, nn.GRUCell)):
        return

    for param in gru.parameters():
        initialiser = init.orthogonal_ if len(param.shape) >= 2 else init.normal_
        initialiser(param.data)

def pad_and_pack_sequence(input_sequence):
    """Pad a collection of variable-length sequences and pack them.

    Args:
        input_sequence: iterable of tensors whose first dimension is the
            sequence length (as required by ``pad_sequence``).

    Returns:
        A ``PackedSequence`` built with ``batch_first=True`` and
        ``enforce_sorted=False``.
    """
    sequences = list(input_sequence)

    # The true lengths must be measured BEFORE padding; after padding every
    # sequence has the same (maximum) length and packing would be a no-op.
    # Lengths are stored in a tensor, otherwise pytorch cannot trace correctly.
    seq_lengths = torch.LongTensor([len(sequence) for sequence in sequences])

    # pad sequences to have same length
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True)

    # create packed sequence
    packed_input_sequence = nn.utils.rnn.pack_padded_sequence(padded, seq_lengths, batch_first=True,
                                                              enforce_sorted=False)

    return packed_input_sequence


def batch_select_tail(batch, in_lengths):
    """Pick, for every batch element, the row at its last valid time step.

    For element ``i`` the row ``batch[i][in_lengths[i] - 1]`` is selected.

    E.g.
    batch = tensor([[[ 0,  1,  2,  3],
                     [ 4,  5,  6,  7],
                     [ 8,  9, 10, 11]],
                     [[12, 13, 14, 15],
                     [16, 17, 18, 19],
                     [20, 21, 22, 23]]])
    of size = (2, 3, 4)

    in_lengths = tensor([1, 2])

    returns tensor([[ 0,  1,  2,  3],
                    [16, 17, 18, 19]])
    """
    tails = []
    for i in range(batch.size(0)):
        last_step = in_lengths[i] - 1
        tails.append(torch.index_select(batch[i], 0, last_step).squeeze(0))

    return torch.stack(tails)


def batch_mean_pooling(batch, in_lengths):
    """Mean-pool each batch element over its first ``in_lengths[i]`` steps.

    E.g.
    batch = tensor([[[ 0,  1,  2,  3],
                     [ 4,  5,  6,  7],
                     [ 8,  9, 10, 11]],
                     [[12, 13, 14, 15],
                     [16, 17, 18, 19],
                     [20, 21, 22, 23]]])
    of size = (2, 3, 4)
    in_lengths = tensor([1, 2])
    returns tensor([[0, 1, 2, 3],
                    [14, 15, 16, 17]])
    """
    pooled = [
        torch.mean(instance[:int(in_lengths[idx])], dim=0)
        for idx, instance in enumerate(batch)
    ]
    return torch.stack(pooled, dim=0)


def gather_last(batch_hidden_states, in_lengths):
    """Gather the hidden state at step ``in_lengths[i] - 1`` for every element.

    The length tensor is expanded to an index of one time step per batch
    element, repeated across the size of dimension 2, and used with
    ``torch.gather`` along dimension 1; that singleton dimension is then
    squeezed away.
    """
    hidden_width = int(batch_hidden_states.shape[2])

    # Two new axes after the batch axis, then shift lengths to 0-based steps.
    last_step = in_lengths[:, None, None] - 1
    gather_index = last_step.repeat(1, 1, hidden_width)

    selected = torch.gather(batch_hidden_states, 1, gather_index)
    return selected.squeeze(1)


# --------------- plotting utils -------------------- #

def plot_curves(A=None, B=None, C=None, D=None):
    """Plot training / test loss and accuracy curves and save them to disk.

    Args:
        A: per-epoch training losses, 2-D (epochs, steps); averaged over axis 1.
        B: per-epoch training scores, 2-D (epochs, steps); averaged over axis 1.
        C: test loss per epoch.
        D: test score per epoch.

    When ``A`` is ``None`` all four curves are loaded from the ``output/*.npy``
    files. The figure is written to ``output/curves.png``.
    """
    # `if not A` is ambiguous (ValueError) for NumPy arrays, so test for None.
    if A is None:
        A = np.load('output/epoch_training_losses.npy')
        B = np.load('output/epoch_training_scores.npy')
        C = np.load('output/epoch_test_loss.npy')
        D = np.load('output/epoch_test_score.npy')

    # Accept plain Python lists as well as arrays (`.shape` / axis means below).
    A, B, C, D = (np.asarray(curve) for curve in (A, B, C, D))

    epochs = A.shape[0]
    epoch_axis = np.arange(1, epochs + 1)

    # plot
    plt.figure(figsize=(10, 4))

    plt.subplot(121)
    plt.plot(epoch_axis, np.mean(A, axis=1))  # train loss (on epoch end)
    plt.plot(epoch_axis, C)  # test loss (on epoch end)
    plt.title("model loss")
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend(['train', 'test'], loc="upper left")

    # 2nd figure
    plt.subplot(122)
    plt.plot(epoch_axis, np.mean(B, axis=1))  # train accuracy (on epoch end)
    plt.plot(epoch_axis, D)  # test accuracy (on epoch end)
    plt.title("training scores")
    plt.xlabel('epochs')
    plt.ylabel('accuracy')
    plt.legend(['train', 'test'], loc="upper left")

    # Make sure the target directory exists when curves are passed in directly.
    os.makedirs('output', exist_ok=True)
    title = "output/curves.png"
    plt.savefig(title, dpi=600)


def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues,
                          save_to=None):
    """Draw the confusion matrix of ``y_pred`` against ``y_true``.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        classes: tick labels used on both axes.
        normalize: when true, each row is divided by its row sum.
        title: figure title; a default depending on ``normalize`` is used
            when it is empty.
        cmap: colour map handed to ``imshow``.
        save_to: optional path; when given the figure is saved there.

    Returns:
        The matplotlib axes holding the plot.
    """
    if not title:
        title = 'Normalized confusion matrix' if normalize else 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / row_totals

    fig, ax = plt.subplots()
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(image, ax=ax)

    # Show every tick and label it with the matching entry of `classes`.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    ax.set(xticks=np.arange(n_cols),
           yticks=np.arange(n_rows),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label',
           )

    ax.tick_params(axis='both', which='major', labelsize=5)

    # Rotate the x tick labels so they stay readable.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    fig.tight_layout()

    if save_to:
        plt.savefig(save_to, dpi=600)

    return ax


def encode_onehot(labels):
    """One-hot encode a sequence of labels.

    Classes are ordered by sorting the distinct labels, so the column
    assigned to each class is reproducible across runs (iterating a ``set``
    directly is not, e.g. for strings under hash randomisation).

    Args:
        labels: iterable of hashable, mutually sortable labels.

    Returns:
        ``np.int32`` array with one row per label and one column per class.
    """
    classes = sorted(set(labels))
    identity = np.identity(len(classes))  # built once, one row per class
    classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array([classes_dict[label] for label in labels],
                             dtype=np.int32)
    return labels_onehot


def load_data(path="../data/cora/", dataset="cora"):
    """Load a citation network dataset (cora only for now).

    Reads ``<path><dataset>.content`` (node id, features, label per row) and
    ``<path><dataset>.cites`` (edge list), builds a symmetric, row-normalised
    adjacency matrix with self loops and row-normalised features.

    Returns:
        ``(adj, features, labels, idx_train, idx_val, idx_test)`` where
        ``adj`` is a torch sparse tensor, ``features`` a float tensor,
        ``labels`` a long tensor of class indices and the three index
        tensors cover rows 0-139, 200-499 and 500-1499 respectively.
    """
    print('Loading {} dataset...'.format(dataset))

    raw_content = np.genfromtxt("{}{}.content".format(path, dataset),
                                dtype=np.dtype(str))
    node_ids = np.array(raw_content[:, 0], dtype=np.int32)
    features = sp.csr_matrix(raw_content[:, 1:-1], dtype=np.float32)
    labels = encode_onehot(raw_content[:, -1])

    # build graph: translate raw node ids into row positions
    row_of_node = {node_id: row for row, node_id in enumerate(node_ids)}
    raw_edges = np.genfromtxt("{}{}.cites".format(path, dataset),
                              dtype=np.int32)
    mapped = [row_of_node.get(node_id) for node_id in raw_edges.flatten()]
    edges = np.array(mapped, dtype=np.int32).reshape(raw_edges.shape)

    num_nodes = labels.shape[0]
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(num_nodes, num_nodes),
                        dtype=np.float32)

    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    features = normalize(features)
    adj = normalize(adj + sp.eye(adj.shape[0]))

    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(np.where(labels)[1])
    adj = sparse_mx_to_torch_sparse_tensor(adj)

    idx_train = torch.LongTensor(range(140))
    idx_val = torch.LongTensor(range(200, 500))
    idx_test = torch.LongTensor(range(500, 1500))

    return adj, features, labels, idx_train, idx_val, idx_test


def normalize(mx):
    """Row-normalize a sparse matrix.

    Every row is scaled by the inverse of its sum; rows whose inverse sum is
    infinite (zero sum) are scaled by 0 instead.
    """
    row_totals = np.array(mx.sum(1))
    inverse = np.power(row_totals, -1).flatten()
    inverse[np.isinf(inverse)] = 0.
    scaling = sp.diags(inverse)
    return scaling.dot(mx)


def accuracy(output, labels):
    """Fraction of rows of ``output`` whose arg-max along dim 1 equals the label.

    Returns a double-precision scalar tensor.
    """
    _, predicted = output.max(1)
    predicted = predicted.type_as(labels)
    num_correct = predicted.eq(labels).double().sum()
    return num_correct / len(labels)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse (COO) float32 tensor.

    Args:
        sparse_mx: any scipy sparse matrix; it is converted to COO / float32.

    Returns:
        A sparse ``torch.float32`` tensor with the same shape and entries.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # `torch.sparse.FloatTensor(...)` is a deprecated legacy constructor;
    # `torch.sparse_coo_tensor` is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, shape)


In [2]:
# Dataloader

In [None]:
class Sign_Dataset(Dataset):
    """Dataset of pose sequences for sign (gloss) classification.

    The index file is a JSON list of entries with a ``gloss`` and a list of
    ``instances`` (each with ``split``, ``frame_start``, ``frame_end`` and
    ``video_id``). Every item is ``(x, y, video_id)`` where ``x`` is the pose
    tensor built by ``_load_poses`` and ``y`` the integer class id of the gloss.
    """

    def __init__(self, file_name_index, split, pose_directory, sampling_method='random', num_samples=25, num_copies=4,
                 img_transforms=None, video_transforms=None, test_index_file=None):
        """
        Args:
            file_name_index: path of the JSON index used to fit the label encoder.
            split: split name or list of split names to keep.
            pose_directory: directory holding one sub-directory of pose files per video.
            sampling_method: 'random', 'seq' or 'k_copies'.
            num_samples: number of frames per sampled clip.
            num_copies: number of clips for the 'k_copies' strategy.
            img_transforms: optional callable applied to every frame's pose.
            video_transforms: optional callable applied to the whole pose tensor.
            test_index_file: optional second index whose instances are used
                instead of those of ``file_name_index``.
        """
        assert os.path.exists(file_name_index), "File path does not exist...check...: {}.".format(file_name_index)
        assert os.path.exists(pose_directory), "File to pose does not exist...check... {}.".format(pose_directory)

        self.data = []
        self.label_encoder, self.onehot_encoder = LabelEncoder(), OneHotEncoder(categories='auto')

        # `type(split) == 'str'` is never true; without wrapping, a plain
        # string would be matched by substring (`'val' in 'train_val'`).
        if isinstance(split, str):
            split = [split]

        self.test_index_file = test_index_file
        self.init_dataset(file_name_index, split)

        self.file_name_index = file_name_index
        self.pose_directory = pose_directory
        self.framename = 'image_{}_keypoints.json'
        self.sampling_method = sampling_method
        self.num_samples = num_samples

        self.img_transforms = img_transforms
        self.video_transforms = video_transforms

        self.num_copies = num_copies

    def __len__(self):
        """Number of video instances kept for the requested split(s)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(poses, label_number, video_id)`` for one instance."""
        video_id, label_number, frame_start, frame_end = self.data[index]
        # pose tensor with the sampled frames concatenated along dim 1
        x = self._load_poses(video_id, frame_start, frame_end, self.sampling_method, self.num_samples)

        if self.video_transforms:
            x = self.video_transforms(x)

        y = label_number

        return x, y, video_id

    def init_dataset(self, file_name_index, split):
        """Fit the encoders on the index file and collect matching instances.

        Appends ``(video_id, label_number, frame_start, frame_end)`` tuples to
        ``self.data`` for every instance whose split is in ``split``.
        """
        if isinstance(split, str):
            split = [split]

        with open(file_name_index, 'r') as f:
            content = json.load(f)

        # create label encoder
        labels = sorted([label_entry['gloss'] for label_entry in content])

        self.label_encoder.fit(labels)
        self.onehot_encoder.fit(self.label_encoder.transform(self.label_encoder.classes_).reshape(-1, 1))

        if self.test_index_file is not None:
            print('Trained on {}, tested on {}'.format(file_name_index, self.test_index_file))
            with open(self.test_index_file, 'r') as f:
                content = json.load(f)

        # make dataset
        for label_entry in content:
            gloss, instances = label_entry['gloss'], label_entry['instances']
            # Call the encoder directly so the class does not depend on an
            # external `utils` module for a one-line transform.
            label_number = self.label_encoder.transform([gloss])[0]

            for instance in instances:
                if instance['split'] not in split:
                    continue

                frame_end = instance['frame_end']
                frame_start = instance['frame_start']
                video_id = instance['video_id']

                instance_entry = video_id, label_number, frame_start, frame_end
                self.data.append(instance_entry)

    def _load_poses(self, video_id, frame_start, frame_end, sampling_method, num_samples):
        """Load the poses of the sampled frames of one video.

        Start and end indices are provided just to avoid listing and sorting
        the directory unnecessarily. A frame without a detected pose reuses
        the previous frame's pose; if fewer than ``num_samples`` poses were
        collected, the last pose is repeated to pad the result.

        Raises:
            NotImplementedError: for an unknown ``sampling_method``.
            IndexError: when no pose at all could be loaded.
        """
        poses = []

        if sampling_method == 'random':
            frames_to_sample = rand_start_sampling(frame_start, frame_end, num_samples)
        elif sampling_method == 'seq':
            frames_to_sample = sequential_sampling(frame_start, frame_end, num_samples)
        elif sampling_method == 'k_copies':
            frames_to_sample = k_copies_fixed_length_sequential_sampling(frame_start, frame_end, num_samples,
                                                                         self.num_copies)

        else:
            raise NotImplementedError('Unimplemented sample strategy found: {}.'.format(sampling_method))

        for i in frames_to_sample:
            pose_path = os.path.join(self.pose_directory, video_id, self.framename.format(str(i).zfill(5)))
            pose = read_pose_file(pose_path)

            if pose is not None:
                if self.img_transforms:
                    pose = self.img_transforms(pose)

                poses.append(pose)
            elif poses:
                # no pose in this frame: repeat the previous one
                poses.append(poses[-1])
            else:
                # nothing to repeat yet; report the offending file
                print(pose_path)

        if not poses:
            raise IndexError('No pose could be loaded for video {} (frames {}-{}).'.format(
                video_id, frame_start, frame_end))

        poses_across_time = torch.cat(poses, dim=1)

        if len(poses) < num_samples:
            # Pad by the number of poses actually missing; counting the
            # requested frames instead under-pads when leading frames had no pose.
            num_padding = num_samples - len(poses)
            pad = poses[-1].repeat(1, num_padding)
            poses_across_time = torch.cat([poses_across_time, pad], dim=1)

        return poses_across_time


def rand_start_sampling(frame_start, frame_end, num_samples):
    """Randomly select a starting point and return ``num_samples`` consecutive frames.

    When the clip has no more than ``num_samples`` frames, all of its frames
    (``frame_start`` .. ``frame_end`` inclusive) are returned instead.
    """
    num_frames = frame_end - frame_start + 1

    if num_frames > num_samples:
        # Last start for which the window still ends on `frame_end`; it must be
        # included in the candidates, otherwise the final frame is never sampled.
        last_valid_start = frame_end - num_samples + 1
        select_from = range(frame_start, last_valid_start + 1)
        sample_start = random.choice(select_from)
        frames_to_sample = list(range(sample_start, sample_start + num_samples))
    else:
        frames_to_sample = list(range(frame_start, frame_end + 1))

    return frames_to_sample


def sequential_sampling(frame_start, frame_end, num_samples):
    """Keep ``num_samples`` frames spread uniformly over the whole clip.

    When the clip has more than ``num_samples`` frames, exactly
    ``num_samples`` strictly increasing frame indices inside
    ``[frame_start, frame_end]`` are returned, starting at ``frame_start``.
    Otherwise every frame of the clip is returned.

    The previous skip-based selection could drop one frame too many (and
    skipped by absolute frame index rather than position in the clip), so it
    did not reliably return ``num_samples`` evenly spaced frames.
    """
    num_frames = frame_end - frame_start + 1

    if num_frames > num_samples:
        # Integer arithmetic keeps the offsets strictly increasing (because
        # num_frames > num_samples) and below num_frames.
        frames_to_sample = [frame_start + (i * num_frames) // num_samples
                            for i in range(num_samples)]
    else:
        frames_to_sample = list(range(frame_start, frame_end + 1))

    return frames_to_sample

def k_copies_fixed_length_sequential_sampling(frame_start, frame_end, num_samples, num_copies):
    """Build ``num_copies`` clips of ``num_samples`` frame indices each.

    * Short clip (no more than ``num_samples`` frames): all frames, padded
      with ``frame_end`` up to ``num_samples``, repeated ``num_copies`` times.
    * Long clip (more than ``num_samples * num_copies`` frames): back-to-back
      windows positioned around the middle of the clip.
    * Otherwise: windows starting every ``stride`` frames from
      ``frame_start``, where the stride is the floor of the spare frames
      divided by ``num_copies - 1``.

    Returns the concatenated list of frame indices.
    """
    num_frames = frame_end - frame_start + 1

    if num_frames <= num_samples:
        single_copy = list(range(frame_start, frame_end + 1))
        single_copy.extend([frame_end] * (num_samples - num_frames))
        return single_copy * num_copies

    if num_samples * num_copies < num_frames:
        centre = (frame_start + frame_end) // 2
        first_frame = centre - num_samples * num_copies // 2
        step = num_samples
    else:
        first_frame = frame_start
        step = math.floor((num_frames - num_samples) / (num_copies - 1))

    frames_to_sample = []
    for copy_idx in range(num_copies):
        window_start = first_frame + copy_idx * step
        frames_to_sample.extend(list(range(window_start, window_start + num_samples)))

    return frames_to_sample


def read_pose_file(filepath,
                   feature_dir='/home/jovyan/Documents/DL/DL_Project/WLASL/code/Pose-GCN/posegcn/features'):
    """Return the (x, y) keypoint coordinates stored for one frame.

    The keypoint JSON at ``filepath`` is read first; when its ``people`` list
    is empty, ``None`` is returned. Otherwise the frame's feature tensor is
    loaded from ``<feature_dir>/<video dir name>/<frame id>_ft.pt`` if that
    cache file exists, or computed from the JSON and written there.

    Args:
        filepath: path of the keypoint JSON; the first 11 characters of the
            file name are used as frame id and the parent directory name as
            video id.
        feature_dir: root directory of the feature cache (defaults to the
            location that was previously hard-coded).

    Returns:
        The first two columns of the feature tensor, or ``None`` when no
        person is present in the frame.
    """
    body_pose_exclude = {9, 10, 11, 22, 23, 24, 12, 13, 14, 19, 20, 21}

    try:
        # `with` closes the file; the bare `open()` leaked one handle per frame.
        with open(filepath) as pose_file:
            content = json.load(pose_file)["people"][0]
    except IndexError:
        return None

    directory, filename = os.path.split(filepath)
    frame_id = filename[:11]
    vid = os.path.split(directory)[-1]

    save_to = os.path.join(feature_dir, vid)
    cache_path = os.path.join(save_to, frame_id + '_ft.pt')

    try:
        ft = torch.load(cache_path)
    except FileNotFoundError:
        # Flat keypoint list: body, then left hand, then right hand; entries
        # are read in groups of three, of which the first two are used as x, y.
        keypoints = (list(content["pose_keypoints_2d"])
                     + list(content["hand_left_keypoints_2d"])
                     + list(content["hand_right_keypoints_2d"]))

        x = [v for i, v in enumerate(keypoints) if i % 3 == 0 and i // 3 not in body_pose_exclude]
        y = [v for i, v in enumerate(keypoints) if i % 3 == 1 and i // 3 not in body_pose_exclude]

        # rescale coordinates (divide by 256, centre on 0, double)
        x = 2 * ((torch.FloatTensor(x) / 256.0) - 0.5)
        y = 2 * ((torch.FloatTensor(y) / 256.0) - 0.5)

        x_diff = torch.FloatTensor(get_difference(x)) / 2
        y_diff = torch.FloatTensor(get_difference(y)) / 2

        orient = y_diff / x_diff
        # Zero out the entries that divided by zero. A boolean mask touches
        # exactly those entries; indexing with the (N, 2) result of
        # `nonzero()` would overwrite whole rows instead.
        orient[x_diff == 0] = 0

        xy = torch.stack([x, y]).transpose_(0, 1)

        ft = torch.cat([xy, x_diff, y_diff, orient], dim=1)

        # makedirs also creates missing parents and tolerates an existing directory
        os.makedirs(save_to, exist_ok=True)
        torch.save(ft, cache_path)

    return ft[:, :2]
    

def get_difference(x):
    """Pairwise differences between every element of ``x`` and all the others.

    Row ``i`` holds ``x[i] - x[k]`` for every ``k != i`` (in order of ``k``),
    so the result has ``len(x)`` rows of ``len(x) - 1`` values.

    The previous implementation reused ``j`` as the inner loop index, which
    shadowed the outer element and produced ``index - value`` instead.
    """
    return [
        [value - other for k, other in enumerate(x) if k != i]
        for i, value in enumerate(x)
    ]

In [None]:
# TGCN Model

In [None]:
# Basic GCN Layer 
# Reference: https://towardsdatascience.com/program-a-simple-graph-net-in-pytorch-e00b500a642d
class GCNLayer(nn.Module):
    """Linear projection of the last dimension, optionally followed by ReLU.

    Args:
        in_features: size of the input feature dimension.
        out_features: size of the output feature dimension.
        acti: when true, an in-place ReLU is applied after the projection.
    """

    def __init__(self, in_features, out_features, acti=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.acti = nn.ReLU(inplace=True) if acti else None

    def forward(self, F):
        """Project ``F`` and apply the activation when one is configured."""
        projected = self.linear(F)
        if self.acti:
            return self.acti(projected)
        return projected

# GC block module with stacked GCN and BatchNorm layers.
# BatchNorm standardises the input to a layer for each mini-batch, which helps to stabilise the learning process.
class GCBlock(nn.Module):
    """Two GCNLayer -> BatchNorm -> Tanh -> Dropout stages with an optional skip.

    The batch-norm layers are sized ``55 * in_features`` and are applied to
    the input flattened to ``(batch, nodes * features)``, so the block expects
    55 nodes per sample.

    Args:
        in_features: feature size per node (input and output).
        p_dropout: dropout probability.
        bias: accepted for interface compatibility; not used by the block.
        is_resi: when true, the block input is added to its output.
    """

    def __init__(self, in_features, p_dropout, bias=True, is_resi=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = in_features
        # residual connection around the two stages
        self.is_resi = is_resi

        flat_width = 55 * in_features

        self.gc1 = GCNLayer(in_features, in_features)
        self.bn1 = nn.BatchNorm1d(flat_width)

        self.gc2 = GCNLayer(in_features, in_features)
        self.bn2 = nn.BatchNorm1d(flat_width)

        self.do = nn.Dropout(p_dropout)
        self.act_f = nn.Tanh()

    def _stage(self, features, gc, bn):
        """Run one projection / batch-norm / tanh / dropout stage."""
        hidden = gc(features)
        batch, nodes, width = hidden.shape
        # batch norm works on the flattened (nodes * width) vector per sample
        hidden = bn(hidden.view(batch, -1)).view(batch, nodes, width)
        return self.do(self.act_f(hidden))

    def forward(self, x):
        """Apply both stages and, if enabled, add the residual input."""
        y = self._stage(x, self.gc1, self.bn1)
        y = self._stage(y, self.gc2, self.bn2)
        if self.is_resi:
            return y + x
        return y

# Multi GCBlock model with FC final layer for classification
class GCNMultiBlock(nn.Module):
    """Input stage, ``num_stage`` GCBlocks, node-mean pooling and a linear classifier.

    The input batch-norm is sized ``55 * hidden_feature`` and applied to the
    flattened ``(batch, nodes * features)`` tensor, so 55 nodes are expected.

    Args:
        input_feature: feature size per node of the input.
        hidden_feature: feature size per node inside the network.
        num_class: number of output classes.
        p_dropout: dropout probability.
        num_stage: number of stacked GCBlocks.
        is_resi: whether the GCBlocks use residual connections.
    """

    def __init__(self, input_feature, hidden_feature, num_class, p_dropout, num_stage=1, is_resi=True):
        super().__init__()
        self.num_stage = num_stage

        self.gc1 = GCNLayer(input_feature, hidden_feature)
        self.bn1 = nn.BatchNorm1d(55 * hidden_feature)

        blocks = [GCBlock(hidden_feature, p_dropout=p_dropout, is_resi=is_resi)
                  for _ in range(num_stage)]
        self.gcbs = nn.ModuleList(blocks)

        self.do = nn.Dropout(p_dropout)
        self.act_f = nn.Tanh()

        self.fc_out = nn.Linear(hidden_feature, num_class)

    def forward(self, x):
        """Return class scores of shape (batch, num_class)."""
        hidden = self.gc1(x)
        batch, nodes, width = hidden.shape
        hidden = self.bn1(hidden.view(batch, -1)).view(batch, nodes, width)
        hidden = self.do(self.act_f(hidden))

        for stage in range(self.num_stage):
            hidden = self.gcbs[stage](hidden)

        # average over the node dimension before classifying
        pooled = torch.mean(hidden, dim=1)
        return self.fc_out(pooled)

In [None]:
# Config

In [None]:
class Config:
    """Typed view of the training ``.ini`` file.

    Reads the ``TRAIN``, ``OPTIMIZER`` and ``GCN`` sections and exposes their
    values as ``int`` / ``float`` attributes.
    """

    def __init__(self, config_path):
        parser = configparser.ConfigParser()
        parser.read(config_path)

        # training
        train_section = parser['TRAIN']
        self.batch_size = int(train_section['BATCH_SIZE'])
        self.max_epochs = int(train_section['MAX_EPOCHS'])
        self.log_interval = int(train_section['LOG_INTERVAL'])
        self.num_samples = int(train_section['NUM_SAMPLES'])
        self.drop_p = float(train_section['DROP_P'])

        # optimizer
        optimizer_section = parser['OPTIMIZER']
        self.init_lr = float(optimizer_section['INIT_LR'])
        self.adam_eps = float(optimizer_section['ADAM_EPS'])
        self.adam_weight_decay = float(optimizer_section['ADAM_WEIGHT_DECAY'])

        # GCN
        gcn_section = parser['GCN']
        self.hidden_size = int(gcn_section['HIDDEN_SIZE'])
        self.num_stages = int(gcn_section['NUM_STAGES'])

    def __str__(self):
        """Compact tag made of the main hyper-parameters."""
        fields = (self.batch_size, self.num_samples, self.drop_p,
                  self.init_lr, self.adam_eps, self.adam_weight_decay)
        return 'bs={}_ns={}_drop={}_lr={}_eps={}_wd={}'.format(*fields)

In [None]:
# Parameters

In [3]:
[TRAIN]
BATCH_SIZE = 64
MAX_EPOCHS = 200
LOG_INTERVAL = 1
; during training, only take NUM_SAMPLES frames per video
NUM_SAMPLES = 50
DROP_P = 0.3

[OPTIMIZER]
INIT_LR = 0.001
ADAM_EPS = 1e-3
ADAM_WEIGHT_DECAY = 0

[GCN]
HIDDEN_SIZE = 64
NUM_STAGES = 20

NameError: name 'TRAIN' is not defined

In [None]:
# Train Function

In [None]:
# Expose only GPU index 0 to CUDA for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
def run(file_split, pose_directory, configs, save_model=None):
    """Train the pose GCN on one split file and evaluate it after every epoch.

    Args:
        file_split: path of the JSON split/index file. Its base name without
            extension (e.g. ``asl100``) names the checkpoint sub-directory.
        pose_directory: directory containing the per-video pose files.
        configs: object exposing ``max_epochs``, ``log_interval``,
            ``num_samples``, ``drop_p``, ``num_stages``, ``batch_size``,
            ``init_lr``, ``adam_eps`` and ``adam_weight_decay``.
        save_model: optional directory; when given, ``validation`` stores a
            checkpoint there after every epoch.

    Side effects:
        Writes the best checkpoints to ``checkpoints/<subset>/``, the
        per-epoch curves to ``output/*.npy`` and the plots to ``output/``.
    """
    epochs = int(configs.max_epochs)
    log_interval = configs.log_interval
    num_samples = configs.num_samples
    drop_p = configs.drop_p
    num_stages = configs.num_stages

    # Derive the subset name from the split file instead of relying on a
    # module-level `subset` variable that only exists when the script cell ran.
    subset_name = os.path.splitext(os.path.basename(file_split))[0]
    checkpoint_dir = os.path.join('checkpoints', subset_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs('output', exist_ok=True)
    if save_model:
        os.makedirs(save_model, exist_ok=True)

    # setup dataset
    train_dataset = Sign_Dataset(file_name_index=file_split, split=['train', 'val'], pose_directory=pose_directory,
                                 img_transforms=None, video_transforms=None, num_samples=num_samples)

    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=configs.batch_size,
                                                    shuffle=True)

    val_dataset = Sign_Dataset(file_name_index=file_split, split=['test'], pose_directory=pose_directory,
                               img_transforms=None, video_transforms=None,
                               num_samples=num_samples, sampling_method='k_copies')

    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=configs.batch_size,
                                                  shuffle=True)

    # setup the model; the hidden width is tied to the input width
    # (2 coordinates per sampled frame), configs.hidden_size is not consulted.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCNMultiBlock(input_feature=num_samples * 2, hidden_feature=num_samples * 2,
                          num_class=len(train_dataset.label_encoder.classes_), p_dropout=drop_p,
                          num_stage=num_stages).to(device)

    # setup training parameters, learning rate, optimizer
    lr = configs.init_lr
    optimizer = optim.Adam(model.parameters(), lr=lr, eps=configs.adam_eps, weight_decay=configs.adam_weight_decay)

    epoch_train_losses = []
    epoch_train_scores = []
    epoch_val_losses = []
    epoch_val_scores = []

    best_test_acc = 0
    # start training
    for epoch in range(epochs):
        print('start training.')
        train_losses, train_scores, train_gts, train_preds = train(log_interval, model,
                                                                   train_data_loader, optimizer, epoch)
        print('start testing.')
        val_loss, val_score, val_gts, val_preds, incorrect_samples = validation(model,
                                                                                val_data_loader, epoch,
                                                                                save_to=save_model)

        epoch_train_losses.append(train_losses)
        epoch_train_scores.append(train_scores)
        epoch_val_losses.append(val_loss)
        epoch_val_scores.append(val_score[0])

        # Persist the curves every epoch; plot_curves() reads exactly these files.
        np.save('output/epoch_training_losses.npy', np.array(epoch_train_losses))
        np.save('output/epoch_training_scores.npy', np.array(epoch_train_scores))
        np.save('output/epoch_test_loss.npy', np.array(epoch_val_losses))
        np.save('output/epoch_test_score.npy', np.array(epoch_val_scores))

        if val_score[0] > best_test_acc:
            best_test_acc = val_score[0]
            best_epoch_num = epoch

            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'gcn_epoch={}_val_acc={}.pth'.format(
                best_epoch_num, best_test_acc)))

    if not epoch_val_losses:
        # zero epochs requested: nothing to plot
        return

    # The plotting helpers are defined in this notebook, so they are called
    # directly rather than through an external `utils` module.
    plot_curves()

    class_names = train_dataset.label_encoder.classes_
    plot_confusion_matrix(train_gts, train_preds, classes=class_names, normalize=False,
                          save_to='output/train-conf-mat')
    plot_confusion_matrix(val_gts, val_preds, classes=class_names, normalize=False, save_to='output/val-conf-mat')


def train(frequency, model, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch.

    Args:
        frequency: print a progress line every ``frequency`` batches.
        model: network to optimise; batches are moved to the device of its
            parameters.
        train_loader: iterable yielding ``(X, y, video_ids)`` batches.
        optimizer: optimiser stepping the model parameters.
        epoch: zero-based epoch index (used for logging only).

    Returns:
        ``(losses, scores, train_labels, train_preds)``: per-batch losses,
        per-batch accuracies, and the flat lists of labels and predictions.
    """
    # Switch to training mode explicitly: validation() leaves the model in
    # eval mode, which would otherwise disable dropout and freeze the
    # batch-norm statistics from the second epoch onwards.
    model.train()
    device = next(model.parameters()).device

    losses = []
    scores = []
    train_labels = []
    train_preds = []

    train_count = 0  # counting total trained sample in one epoch
    for batch_idx, data in enumerate(train_loader):
        X, y, video_ids = data
        # distribute data to the model's device
        X, y = X.to(device), y.to(device).view(-1, )

        train_count += X.size(0)

        optimizer.zero_grad()
        out = model(X)  # output has dim = (batch, number of classes)

        loss = F.cross_entropy(out, y)

        losses.append(loss.item())

        # to compute accuracy
        y_pred = torch.max(out, 1)[1]  # y_pred != output

        # Both tensors are already 1-D; squeezing them would turn a batch of
        # one sample into a 0-d value and break the calls below.
        y_true_cpu = y.detach().cpu()
        y_pred_cpu = y_pred.detach().cpu()

        step_score = accuracy_score(y_true_cpu.numpy(), y_pred_cpu.numpy())

        # collect prediction labels
        train_labels.extend(y_true_cpu.tolist())
        train_preds.extend(y_pred_cpu.tolist())

        scores.append(step_score)  # computed on CPU

        loss.backward()
        optimizer.step()

        # show information
        if (batch_idx + 1) % frequency == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accu: {:.6f}%'.format(
                epoch + 1, train_count, len(train_loader.dataset), 100. * (batch_idx + 1) / len(train_loader), loss.item(),
                100 * step_score))

    return losses, scores, train_labels, train_preds


def validation(model, test_loader, epoch, save_to, num_copies=4):
    """Evaluate ``model`` on ``test_loader``.

    Every input is split along dimension 2 into ``num_copies`` equal slices;
    the model outputs of the slices are averaged before computing the loss
    and the predictions.

    Args:
        model: network to evaluate; batches are moved to the device of its
            parameters.
        test_loader: iterable yielding ``(X, y, video_ids)`` batches.
        epoch: zero-based epoch index (used for the checkpoint name / logging).
        save_to: optional directory; when truthy the model state is saved there.
        num_copies: number of slices per input (default 4, as before).

    Returns:
        ``(val_loss, [top1, top3, top5, top10, top30], labels, predictions,
        incorrect_video_ids)`` where ``val_loss`` is the mean of the batch
        losses and ``incorrect_video_ids`` is a list of
        ``(video_id, predicted_label)`` for misclassified samples.
    """
    model.eval()
    device = next(model.parameters()).device

    val_loss = []
    label_batches = []
    pred_batches = []
    output_batches = []
    all_video_ids = []

    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            # distribute data to the model's device
            X, y, video_ids = data
            X, y = X.to(device), y.to(device).view(-1, )

            stride = X.size()[2] // num_copies

            all_output = [model(X[:, :, i * stride: (i + 1) * stride]) for i in range(num_copies)]
            output = torch.mean(torch.stack(all_output, dim=1), dim=1)

            loss = F.cross_entropy(output, y)

            val_loss.append(loss.item())  # sum up batch loss
            y_pred = output.max(1)[1]  # index of the highest averaged score

            # collect all y and y_pred in all batches
            label_batches.append(y)
            pred_batches.append(y_pred)
            output_batches.append(output)
            all_video_ids.extend(video_ids)

    # this computes the average loss on the BATCH
    val_loss = sum(val_loss) / len(val_loss)

    # Concatenating 1-D batches keeps the results 1-D even for a single
    # sample (stack + squeeze collapsed that case to a 0-d tensor).
    labels = torch.cat(label_batches, dim=0).cpu().numpy()
    labels_pred = torch.cat(pred_batches, dim=0).cpu().numpy()
    all_pool_out = torch.cat(output_batches, dim=0).cpu().numpy()

    # log down incorrectly labelled instances
    incorrect_video_ids = [(vid, int(pred)) for vid, truth, pred in zip(all_video_ids, labels, labels_pred)
                           if truth != pred]

    # top-k accuracy
    top1acc = accuracy_score(labels, labels_pred)
    top3acc = compute_top_n_accuracy(labels, all_pool_out, 3)
    top5acc = compute_top_n_accuracy(labels, all_pool_out, 5)
    top10acc = compute_top_n_accuracy(labels, all_pool_out, 10)
    top30acc = compute_top_n_accuracy(labels, all_pool_out, 30)

    # show information
    print('\nVal. set ({:d} samples): Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(len(labels), val_loss,
                                                                                        100 * top1acc))

    if save_to:
        # save the model state of this epoch
        torch.save(model.state_dict(),
                   os.path.join(save_to, 'gcn_epoch{}.pth'.format(epoch + 1)))
        print("Epoch {} model saved!".format(epoch + 1))

    return val_loss, [top1acc, top3acc, top5acc, top10acc, top30acc], labels.tolist(), labels_pred.tolist(), incorrect_video_ids



def compute_top_n_accuracy(truths, preds, n):
    """Fraction of samples whose true label is among the ``n`` highest scores.

    Args:
        truths: array of true class indices, one per sample.
        preds: 2-D array of scores, one row per sample.
        n: number of top-scoring classes considered a hit.
    """
    top_n = np.argsort(preds, axis=1)[:, -n:]
    total = truths.shape[0]
    hits = sum(1 for row in range(total) if truths[row] in top_n[row, :])
    return float(hits) / total


# Script entry point: build the paths for one WLASL subset and launch training.
if __name__ == "__main__":
    # Root of the local WLASL checkout (absolute, machine-specific path).
    direc = '/home/jovyan/Documents/DL/DL_Project/WLASL'
    # Subset name; selects the split file and the config file below.
    subset = 'asl100'
    split_file = os.path.join(direc, 'data/splits/{}.json'.format(subset))
    pose_data = os.path.join(direc, 'data/pose_per_individual_videos')
    config_file = os.path.join(direc, 'code/TGCN/configs/{}.ini'.format(subset))
    configs = Config(config_file)
    run(file_split=split_file, configs=configs, pose_directory=pose_data)