In [1]:
# Mount Google Drive (Colab only) so the project folder under MyDrive becomes
# visible on the notebook's local filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import sys

# Project root on the mounted Drive, derived from the kernel's working directory
# (which was /content in the recorded Colab run). There is no universal rule for
# Colab / Jupyter / local runs, so check the printed path.
# The trailing "/" matters: later cells build paths as PROJECT_ROOT + "data/...".
PROJECT_ROOT = os.path.abspath(os.curdir) + "/drive/MyDrive/GCN/"
print('The Project is running from : ', PROJECT_ROOT)

# Guard so that re-running this cell does not append a duplicate sys.path entry.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

The Project is running from :  /content/drive/MyDrive/GCN/


Code adapted from https://github.com/garyzhao/SemGCN.

In [3]:
from __future__ import print_function, absolute_import, division

import os
import time
import datetime
import argparse
import numpy as np
import os.path as path
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import pickle

from progress.bar import Bar
from common.log import Logger, savefig
from common.utils import AverageMeter, lr_decay, save_ckpt
from common.graph_utils import adj_mx_from_skeleton
from common.data_utils import fetch, read_3d_data, create_2d_data
from common.generators import PoseGenerator,PoseGenerator_Multi
from common.loss import mpjpe, p_mpjpe
from models.sem_gcn import SemGCN
from models.MLP import mlp_model
from models.Sem_gcn_multiView import SemGCN_Concat_Non_Shared,SemGCN_Sum_Non_Shared,SemGCN_Concat_Shared,SemGCN_Sum_Shared

In [4]:
from sklearn.metrics import mean_squared_error
def parse_args():
    """Build the training/evaluation configuration from command-line flags.

    Uses ``parse_known_args`` so that unrecognised arguments (such as any extra
    arguments the notebook kernel was launched with) are ignored rather than
    aborting the parse.

    Returns:
        argparse.Namespace holding the options declared below.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when ``--resume`` and
            ``--evaluate`` are both given.
    """
    parser = argparse.ArgumentParser(description='PyTorch training script')

    # General arguments
    parser.add_argument('-d', '--dataset', default='h36m', type=str, metavar='NAME', help='target dataset')
    parser.add_argument('-k', '--keypoints', default='gt', type=str, metavar='NAME', help='2D detections to use')
    parser.add_argument('-a', '--actions', default='*', type=str, metavar='LIST',
                        help='actions to train/test on, separated by comma, or * for all')
    parser.add_argument('--evaluate', default=False, type=str, metavar='FILENAME',
                        help='checkpoint to evaluate (file name)')
    parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME',
                        help='checkpoint to resume (file name)')
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                        help='checkpoint directory')
    # Help text previously advertised a default of 20 while the real default is 5.
    parser.add_argument('--snapshot', default=5, type=int, help='save models for every #snapshot epochs (default: 5)')

    # Model arguments
    parser.add_argument('-l', '--num_layers', default=4, type=int, metavar='N', help='num of residual layers')
    parser.add_argument('-z', '--hid_dim', default=128, type=int, metavar='N', help='num of hidden dimensions')
    parser.add_argument('-b', '--batch_size', default=64, type=int, metavar='N',
                        help='batch size in terms of predicted frames')
    parser.add_argument('-e', '--epochs', default=150, type=int, metavar='N', help='number of training epochs')
    parser.add_argument('--num_workers', default=2, type=int, metavar='N', help='num of workers for data loading')
    parser.add_argument('--lr', default=1.0e-4, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_decay', type=int, default=100000, help='num of steps of learning rate decay')
    parser.add_argument('--lr_gamma', type=float, default=0.96, help='gamma of learning rate decay')
    # Passing --no_max turns clipping OFF (store_false into max_norm, default True).
    parser.add_argument('--no_max', dest='max_norm', action='store_false',
                        help='disable max_norm clipping of the gradient')
    parser.set_defaults(max_norm=True)
    parser.add_argument('--non_local', dest='non_local', action='store_true', help='if use non-local layers')
    parser.set_defaults(non_local=False)
    parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate')

    # Experimental
    parser.add_argument('--downsample', default=1, type=int, metavar='FACTOR', help='downsample frame rate by factor')

    args, _unknown = parser.parse_known_args()

    # Check invalid configuration. parser.error raises SystemExit(2); the bare
    # site-module exit() used before is meant only for the interactive shell.
    if args.resume and args.evaluate:
        parser.error('Invalid flags: --resume and --evaluate cannot be set at the same time')

    return args

def pickle_read(filename):
    """Open ``filename`` in binary mode and return the object unpickled from it.

    Security note: unpickling can execute arbitrary code, so only pass files you
    trust (in this notebook, the project's own data pickles).
    """
    with open(filename, mode="rb") as handle:
        return pickle.load(handle)

def train_mono_or_mlp(data_loader, model_pos, criterion, optimizer, device, lr_init, lr_now, step, decay, gamma,
                      max_norm=True, error_threshold=10000):
    """Train a single-input model (GCN or MLP) for one epoch.

    Args:
        data_loader: iterable of ``(targets_3d, inputs_2d, _)`` batches.
        model_pos: model called as ``model_pos(inputs_2d)``.
        criterion: loss called as ``criterion(outputs_3d, targets_3d)``.
        optimizer: optimizer whose learning rate ``lr_decay`` adjusts.
        device: device the batch tensors are moved to.
        lr_init, decay, gamma: forwarded to ``lr_decay``.
        lr_now: current learning rate; replaced by ``lr_decay``'s result whenever
            ``step % decay == 0`` or on the very first step.
        step: global step counter, incremented once per batch (skipped ones too).
        max_norm: if true, clip the gradient norm to 1 before each update.
        error_threshold: a batch whose loss is non-finite or greater than this is
            skipped (no backward/update, excluded from the epoch average).

    Returns:
        ``(epoch_loss_3d_pos.avg, lr_now, step)``.
    """
    import math  # local: the notebook's import cell does not provide it

    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()

    bar = Bar('Train', max=len(data_loader))
    for i, (targets_3d, inputs_2d, _) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)

        # Uncomment for the MLP baseline (flatten joints into one feature vector):
        # inputs_2d = torch.flatten(inputs_2d, start_dim=1)
        # targets_3d = torch.flatten(targets_3d, start_dim=1)

        step += 1
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)
        targets_3d, inputs_2d = targets_3d.to(device), inputs_2d.to(device)

        outputs_3d = model_pos(inputs_2d)

        optimizer.zero_grad()
        loss_3d_pos = criterion(outputs_3d, targets_3d)
        loss_value = loss_3d_pos.item()

        # NaN compares False with everything, so a bare `loss > threshold` check
        # let NaN losses reach backward()/step() and poison the weights.
        if not math.isfinite(loss_value) or loss_value > error_threshold:
            print("Encountered bad Data in Training, skipped this batch")
        else:
            loss_3d_pos.backward()
            if max_norm:
                nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1)
            optimizer.step()
            epoch_loss_3d_pos.update(loss_value, num_poses)

        # Timing and the progress bar advance for skipped batches too; previously
        # `continue` skipped them, desyncing the bar and inflating data time.
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss_3d_pos.avg)
        bar.next()

    bar.finish()

    return epoch_loss_3d_pos.avg, lr_now, step


def evaluate_mono_or_mlp(data_loader, model_pos, device, error_threshold=100):
    """Evaluate a single-input model (GCN or MLP) over ``data_loader``.

    Predictions are root-centred (joint 0 subtracted) before being compared to
    the targets with a mean-squared-error loss.

    Args:
        data_loader: iterable of ``(targets_3d, inputs_2d, _)`` batches; the
            targets are compared on CPU as yielded.
        model_pos: model called as ``model_pos(inputs_2d)``.
        device: device the inputs are moved to; predictions come back to CPU.
        error_threshold: batches whose MSE is not below this (NaN included) are
            reported and excluded from the average.

    Returns:
        ``(epoch_loss_3d_pos.avg, 0)``. The value is the MSE accumulated with
        ``AverageMeter.update(loss, batch_size)``. The second element is a
        placeholder for P-MPJPE, which is not computed. The progress bar labels
        the MSE as "MPJPE".
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()
    # Built once instead of per batch; it has no parameters, so no device move is needed.
    crit = nn.MSELoss(reduction='mean')

    # Switch to evaluate mode
    torch.set_grad_enabled(False)
    model_pos.eval()
    end = time.time()

    bar = Bar('Eval ', max=len(data_loader))
    for i, (targets_3d, inputs_2d, _) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)

        # Uncomment for the MLP baseline (flatten joints into one feature vector):
        # inputs_2d = torch.flatten(inputs_2d, start_dim=1)
        # targets_3d = torch.flatten(targets_3d, start_dim=1)

        inputs_2d = inputs_2d.to(device)
        outputs_3d = model_pos(inputs_2d).cpu()
        # Zero-centre the root (hip), out of place. The former in-place
        # `outputs_3d[:, :] -= outputs_3d[:, :1]` read from memory it was writing.
        # PyTorch can reject that (partial-overlap error for a batch of 1) or zero
        # the root before the other joints subtract it.
        outputs_3d = outputs_3d - outputs_3d[:, :1]

        ev_loss = crit(outputs_3d, targets_3d).item()
        # NaN fails `<`, so NaN batches land in the skip branch as well.
        if ev_loss < error_threshold:
            epoch_loss_3d_pos.update(ev_loss, num_poses)
        else:
            print("Encountered bad Data in Evaluation, skipped this batch")

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| MPJPE: {e1: .4f} | P-MPJPE: {e2: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, e1=epoch_loss_3d_pos.avg, e2=0)
        bar.next()

    bar.finish()
    return epoch_loss_3d_pos.avg, 0


def train_multi(data_loader, model_pos, criterion, optimizer, device, lr_init, lr_now, step, decay, gamma,
                max_norm=True, error_threshold=10000):
    """Train a four-input (multi-view) model for one epoch.

    Args:
        data_loader: iterable of ``(targets_3d, input_1, input_2, input_3, input_4, _)``
            batches.
        model_pos: model called as ``model_pos(input_1, input_2, input_3, input_4)``.
        criterion: loss called as ``criterion(outputs_3d, targets_3d)``.
        optimizer: optimizer whose learning rate ``lr_decay`` adjusts.
        device: device the batch tensors are moved to.
        lr_init, decay, gamma: forwarded to ``lr_decay``.
        lr_now: current learning rate; replaced by ``lr_decay``'s result whenever
            ``step % decay == 0`` or on the very first step.
        step: global step counter, incremented once per batch (skipped ones too).
        max_norm: if true, clip the gradient norm to 1 before each update.
        error_threshold: a batch whose loss is non-finite or greater than this is
            skipped (no backward/update, excluded from the epoch average).

    Returns:
        ``(epoch_loss_3d_pos.avg, lr_now, step)``.
    """
    import math  # local: the notebook's import cell does not provide it

    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()

    # Switch to train mode
    torch.set_grad_enabled(True)
    model_pos.train()
    end = time.time()

    bar = Bar('Train', max=len(data_loader))
    for i, (targets_3d, input_1, input_2, input_3, input_4, _) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)

        step += 1
        if step % decay == 0 or step == 1:
            lr_now = lr_decay(optimizer, step, lr_init, decay, gamma)
        views = [view.to(device) for view in (input_1, input_2, input_3, input_4)]
        targets_3d = targets_3d.to(device)
        outputs_3d = model_pos(*views)

        optimizer.zero_grad()
        loss_3d_pos = criterion(outputs_3d, targets_3d)
        loss_value = loss_3d_pos.item()

        # NaN compares False with everything, so a bare `loss > threshold` check
        # let NaN losses reach backward()/step() and poison the weights.
        if not math.isfinite(loss_value) or loss_value > error_threshold:
            print("Encountered bad Data in Training, skipped this batch")
        else:
            loss_3d_pos.backward()
            if max_norm:
                nn.utils.clip_grad_norm_(model_pos.parameters(), max_norm=1)
            optimizer.step()
            epoch_loss_3d_pos.update(loss_value, num_poses)

        # Timing and the progress bar advance for skipped batches too; previously
        # `continue` skipped them, desyncing the bar and inflating data time.
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| Loss: {loss: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, loss=epoch_loss_3d_pos.avg)
        bar.next()

    bar.finish()
    return epoch_loss_3d_pos.avg, lr_now, step


def evaluate_multi(data_loader, model_pos, device, error_threshold=10000):
    """Evaluate a four-input (multi-view) model over ``data_loader``.

    Predictions are root-centred (joint 0 subtracted) before being compared to
    the targets with a mean-squared-error loss.

    Args:
        data_loader: iterable of ``(targets_3d, input_1, input_2, input_3, input_4, _)``
            batches; the targets are compared on CPU as yielded.
        model_pos: model called as ``model_pos(input_1, input_2, input_3, input_4)``.
        device: device the inputs are moved to; predictions come back to CPU.
        error_threshold: batches whose MSE is not below this (NaN included) are
            reported and excluded from the average.

    Returns:
        ``(epoch_loss_3d_pos.avg, 0)``. The value is the MSE accumulated with
        ``AverageMeter.update(loss, batch_size)``. The second element is a
        placeholder for P-MPJPE, which is not computed. The progress bar labels
        the MSE as "MPJPE".
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    epoch_loss_3d_pos = AverageMeter()
    # Never updated: only feeds the P-MPJPE column of the progress bar.
    epoch_loss_3d_pos_procrustes = AverageMeter()
    # Built once instead of per batch; it has no parameters, so no device move is needed.
    crit = nn.MSELoss(reduction='mean')

    # Switch to evaluate mode
    torch.set_grad_enabled(False)
    model_pos.eval()
    end = time.time()

    bar = Bar('Eval ', max=len(data_loader))
    for i, (targets_3d, input_1, input_2, input_3, input_4, _) in enumerate(data_loader):
        # Measure data loading time
        data_time.update(time.time() - end)
        num_poses = targets_3d.size(0)
        views = [view.to(device) for view in (input_1, input_2, input_3, input_4)]

        outputs_3d = model_pos(*views).cpu()
        # Zero-centre the root (hip), out of place. The former in-place
        # `outputs_3d[:, :, :] -= outputs_3d[:, :1, :]` read from memory it was
        # writing. PyTorch can reject that (partial-overlap error for a batch of 1)
        # or zero the root before the other joints subtract it.
        outputs_3d = outputs_3d - outputs_3d[:, :1, :]

        ev_loss = crit(outputs_3d, targets_3d).item()
        # NaN fails `<`, so NaN batches land in the skip branch as well.
        if ev_loss < error_threshold:
            epoch_loss_3d_pos.update(ev_loss, num_poses)
        else:
            print("Encountered bad Data in Evaluation, skipped this batch")

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {ttl:} | ETA: {eta:} ' \
                     '| MPJPE: {e1: .4f} | P-MPJPE: {e2: .4f}' \
            .format(batch=i + 1, size=len(data_loader), data=data_time.avg, bt=batch_time.avg,
                    ttl=bar.elapsed_td, eta=bar.eta_td, e1=epoch_loss_3d_pos.avg, e2=epoch_loss_3d_pos_procrustes.avg)
        bar.next()

    bar.finish()
    return epoch_loss_3d_pos.avg, 0





In [5]:
def _build_adj_matrix():
    """Return the fixed 16x16 row-normalised skeleton adjacency used by the GCN models.

    Row ``i`` is non-zero only at joint ``i`` itself and at the joints it is
    linked to. Every non-zero entry of a row equals the same rounded reciprocal
    of that row's non-zero count (0.5, 0.3333, 0.25 or 0.2), so each row sums
    to approximately 1.
    """
    # Non-zero column indices of each row (self-loop included), rows 0..15 in order.
    row_support = [
        (0, 1, 4, 7),
        (0, 1, 2),
        (1, 2, 3),
        (2, 3),
        (0, 4, 5),
        (4, 5, 6),
        (5, 6),
        (0, 7, 8),
        (7, 8, 9, 10, 13),
        (8, 9),
        (8, 10, 11),
        (10, 11, 12),
        (11, 12),
        (8, 13, 14),
        (13, 14, 15),
        (14, 15),
    ]
    # Literal weights as used in the hand-written matrix (0.3333, not 1/3).
    weight_for_count = {2: 0.5, 3: 0.3333, 4: 0.25, 5: 0.2}

    num_joints = len(row_support)
    adj = torch.zeros(num_joints, num_joints)
    for row, cols in enumerate(row_support):
        adj[row, list(cols)] = weight_for_count[len(cols)]
    return adj


adj_matrix = _build_adj_matrix()

In [8]:
# Fallback checkpoint used when --resume/--evaluate does not name an existing file.
_DEFAULT_CKPT_PATH = "/content/drive/MyDrive/GCN/Final_MLP_WORLD/2023-01-19T13:45:49.676759/ckpt_best.pth.tar"


def _load_pose_data():
    """Read the train/test pickles stored under ``PROJECT_ROOT + "data/"``.

    Returns:
        ``(inputs_2d_train, targets_3d_train, inputs_2d_test, targets_3d_test)``,
        each exactly as unpickled. Per the file names, these are 2D camera-space
        inputs and 3D world-space targets.
    """
    data_dir = PROJECT_ROOT + "data/"
    inputs_2d_train = pickle_read(data_dir + "2D_CAMERA_TRAINING_FOR_3D_CAMERA.pickle")
    targets_3d_train = pickle_read(data_dir + "3D_WORLD_TRAINING.pickle")
    inputs_2d_test = pickle_read(data_dir + "2D_CAMERA_TEST_FOR_3D_CAMERA.pickle")
    targets_3d_test = pickle_read(data_dir + "3D_WORLD_TEST.pickle")
    return inputs_2d_train, targets_3d_train, inputs_2d_test, targets_3d_test


def _make_loader(generator_cls, poses_3d, poses_2d, actions, args, shuffle):
    """Wrap ``generator_cls(poses_3d, poses_2d, actions)`` in a DataLoader configured from ``args``."""
    return DataLoader(generator_cls(poses_3d, poses_2d, actions), batch_size=args.batch_size,
                      shuffle=shuffle, num_workers=args.num_workers, pin_memory=True)


def _plot_losses(epochs, train_losses, eval_losses, out_dir):
    """Plot per-epoch train/eval losses on twin y-axes, save to ``out_dir/logAdapted.png``, then show."""
    fig, ax = plt.subplots()
    ax.plot(epochs, train_losses, color="red")
    ax.set_xlabel("epoch", fontsize=14)
    ax.set_ylabel("Train loss", color="red", fontsize=14)
    # Second y-axis, since the two losses may be on different scales.
    ax2 = ax.twinx()
    ax2.plot(epochs, eval_losses, color="blue")
    ax2.set_ylabel("Evaluation loss", color="blue", fontsize=14)
    # Save BEFORE show(): inline backends may close the figure once it is shown.
    # The format must match the .png extension (JPEG bytes were written before).
    fig.savefig(path.join(out_dir, 'logAdapted.png'), format='png', dpi=100, bbox_inches='tight')
    plt.show()


def main(args, multiview_mode=True):
    """Train, resume, or evaluate a 2D->3D pose model on the pickled dataset.

    Args:
        args: namespace returned by ``parse_args()``.
        multiview_mode: True (the previous hard-coded setting) uses
            ``PoseGenerator_Multi`` with ``train_multi``/``evaluate_multi``. False
            uses ``PoseGenerator`` with ``train_mono_or_mlp``/``evaluate_mono_or_mlp``.
            The model constructed below must match this choice.

    Side effects:
        Reads the data pickles. When training, it writes checkpoints, ``log.txt``
        and ``logAdapted.png`` into a new timestamped directory under
        ``PROJECT_ROOT + "Final_MLP_WORLD"``, or next to the checkpoint when
        resuming.
    """
    print('==> Using settings {}'.format(args))

    print('==> Loading dataset...')
    inputs_2d_train, targets_3d_train, inputs_2d_test, targets_3d_test = _load_pose_data()

    # The generators take one action label per sample; a single placeholder is used.
    actions_train = ['SomeAction1'] * len(targets_3d_train)
    actions_test = ['SomeAction1'] * len(targets_3d_test)

    # Select generator + train/eval pair; only the loaders actually used are built.
    if multiview_mode:
        generator_cls, train, evaluate = PoseGenerator_Multi, train_multi, evaluate_multi
    else:
        generator_cls, train, evaluate = PoseGenerator, train_mono_or_mlp, evaluate_mono_or_mlp
    train_loader = _make_loader(generator_cls, targets_3d_train, inputs_2d_train, actions_train, args, shuffle=True)
    valid_loader = _make_loader(generator_cls, targets_3d_test, inputs_2d_test, actions_test, args, shuffle=False)

    cudnn.benchmark = True
    # Fall back to CPU instead of failing when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model
    print("==> Creating model...")
    p_dropout = None if args.dropout == 0.0 else args.dropout
    adj = adj_matrix

    # Alternatives (mono/MLP ones also need multiview_mode=False):
    #   SemGCN(adj, args.hid_dim, num_layers=args.num_layers, p_dropout=p_dropout, nodes_group=None)
    #   mlp_model(nb_hidden_layers=1)
    #   SemGCN_Concat_Non_Shared(adj, args.hid_dim, num_layers=args.num_layers, p_dropout=p_dropout, nodes_group=None)
    # Current configuration: multi-view aggregation by sum.
    model_pos = SemGCN_Sum_Non_Shared(adj, args.hid_dim, num_layers=args.num_layers, p_dropout=p_dropout,
                                      nodes_group=None).to(device)

    print("==> Total parameters: {:.2f}M".format(sum(p.numel() for p in model_pos.parameters()) / 1000000.0))

    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Optionally resume from / evaluate a checkpoint
    if args.resume or args.evaluate:
        # Honour the path given on the command line when it exists; otherwise keep
        # the previous behaviour of loading the hard-coded checkpoint.
        requested_path = args.resume or args.evaluate
        ckpt_path = requested_path if path.isfile(requested_path) else _DEFAULT_CKPT_PATH
        if not path.isfile(ckpt_path):
            raise RuntimeError("==> No checkpoint found at '{}'".format(ckpt_path))

        print("==> Loading checkpoint '{}'".format(ckpt_path))
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        ckpt = torch.load(ckpt_path, map_location=device)
        start_epoch = ckpt['epoch']
        error_best = ckpt['error']
        glob_step = ckpt['step']
        lr_now = ckpt['lr']
        model_pos.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print("==> Loaded checkpoint (Epoch: {} | Error: {})".format(start_epoch, error_best))

        if args.resume:
            ckpt_dir_path = path.dirname(ckpt_path)
            logger = Logger(path.join(ckpt_dir_path, 'log.txt'), resume=True)
    else:
        start_epoch = 0
        error_best = None
        glob_step = 0
        lr_now = args.lr
        ckpt_dir_path = path.join(PROJECT_ROOT + "Final_MLP_WORLD", datetime.datetime.now().isoformat())
        if not path.exists(ckpt_dir_path):
            os.makedirs(ckpt_dir_path)
            print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))
        logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'))
        logger.set_names(['epoch', 'lr', 'loss_train', 'error_eval_p1', 'error_eval_p2'])

    if args.evaluate:
        print("The evaluation loss is", evaluate(valid_loader, model_pos, device)[0])
        return

    epoch_numbers = []
    ep_loss = []
    eval_loss = []
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(train_loader, model_pos, criterion, optimizer, device, args.lr, lr_now,
                                              glob_step, args.lr_decay, args.lr_gamma, max_norm=args.max_norm)

        # Evaluate
        error_eval_p1, error_eval_p2 = evaluate(valid_loader, model_pos, device)

        epoch_numbers.append(epoch + 1)
        ep_loss.append(epoch_loss)
        eval_loss.append(error_eval_p1)

        # Update log file
        logger.append([epoch + 1, lr_now, epoch_loss, error_eval_p1, error_eval_p2])

        def checkpoint_state():
            # A fresh dict per save, as before.
            return {'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(),
                    'optimizer': optimizer.state_dict(), 'error': error_eval_p1}

        # Save checkpoint
        if error_best is None or error_best > error_eval_p1:
            error_best = error_eval_p1
            save_ckpt(checkpoint_state(), ckpt_dir_path, suffix='best')

        if (epoch + 1) % args.snapshot == 0:
            save_ckpt(checkpoint_state(), ckpt_dir_path)

    logger.close()

    _plot_losses(epoch_numbers, ep_loss, eval_loss, ckpt_dir_path)
    return



In [None]:
# Parse the CLI flags (parse_args ignores arguments it does not recognise), then run.
cli_args = parse_args()
main(cli_args)