In [None]:
import torch
import torch.utils.data as data
import torch.optim as optim
import torch.nn as nn
import torch.autograd as autograd
from torch.autograd import Variable

import os
from collections import OrderedDict
import pickle
from glob import glob

import numpy as np
import pandas as pd
import random
import math

import matplotlib.pyplot as plt
from datetime import datetime
from scipy import stats
import torch_two_sample.statistics_diff as diff
from scipy.stats import norm
import time

from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler, DataLoader

from model.model import Model,Generator

# Root directory of the skeleton data; the training cell reads files from "{pkl_dir}/3D/level*/...".
pkl_dir = "../level_data"
# Level sub-directory names (not referenced by the cells shown here).
levels = ['level0', 'level1', 'level3']

# First CUDA GPU when available, otherwise CPU. compute_gradient_penalty and train read this
# module-level `device` directly rather than taking it as a parameter.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [None]:
def angle(a, b, c):
    """Return the angle at vertex ``b`` swept from ray b->a to ray b->c, in degrees.

    Only the first two coordinates (x, y) of each point are used. The sweep follows the
    direction of increasing ``atan2`` angle (counter-clockwise in a y-up frame), and the
    result is normalised into the range [0, 360).
    """
    heading_to_c = math.atan2(c[1] - b[1], c[0] - b[0])
    heading_to_a = math.atan2(a[1] - b[1], a[0] - b[0])
    degrees = math.degrees(heading_to_c - heading_to_a)
    if degrees < 0:
        degrees += 360
    return degrees

In [None]:
class TestAIMSCDataset(data.Dataset):
    """Skeleton-sequence dataset holding one fixed-length clip per input file.

    File names must look like ``"<ID>_<age>m..."`` (numeric ID, numeric age followed by
    ``m``), and every path must contain ``"3D/level<N>"``; level 3 is remapped to 2.

    ``__getitem__`` returns ``(clip, label, subject_name, head_angle)``:

    * ``clip`` -- float tensor of shape ``(1, max_frames, joint_num, dim + 2)``: per-frame
      centred, per-axis min-max normalised coordinates, followed by an age channel
      (``age / 14``) and a level channel (``level / 3``).
    * ``label`` -- ``label_list[index]`` as a float tensor.
    * ``subject_name`` -- the file name without its extension.
    * ``head_angle`` -- float tensor of shape ``(joint_num,)``: the sequence-mean head angle
      divided by 360, repeated for every joint.

    Args:
        file_list: skeleton files. With ``data_type == "3D"`` they are read with ``np.load``;
            otherwise they are unpickled (only load trusted files -- pickle can run code).
        label_list: one label per file, in the same order as ``file_list``.
        dim: number of coordinate channels per joint.
        max_frames: every clip is sub-sampled or padded to exactly this many frames.
        stride: accepted for API compatibility; not used.
        data_type: ``"3D"`` or anything else. For non-3D data with ``dim == 2`` every frame is
            re-ordered with the module-level ``re_order_indices`` global.
        joint_num: number of joints kept; ``5`` selects raw joints ``[0, 3, 6, 7, 10]``.
    """

    def __init__(self, file_list, label_list, dim=2, max_frames=48, stride=30, data_type="2D", joint_num=16):
        super(TestAIMSCDataset, self).__init__()
        self.max_frames = max_frames

        self.clips = []
        self.labels = label_list
        self.subject_names = []
        self.ages = []
        self.angles = []

        # keypoint
        # 0    3    6    7         10
        # RHip LHip nose LShoulder RShoulder
        if joint_num == 5:
            self.joint_index = [0, 3, 6, 7, 10]

        for pkl_filename in file_list:
            # read label from "<ID>_<age>m..." and the "3D/level<N>" path component
            name = os.path.splitext(os.path.basename(pkl_filename))[0]
            int(name.split('_')[0])  # validates the numeric subject-ID prefix (value unused)
            age = int(float(name.split('_')[1].split('m')[0]))
            self.ages.append(np.ones((1, 10)) * age / 14.)
            level = int(pkl_filename.split("3D/level")[1][:1])
            level = 2 if level == 3 else level

            skeleton_dict = self._load_skeletons(pkl_filename, data_type)
            num_frames = len(skeleton_dict)

            skeleton_np = np.zeros((num_frames, joint_num, dim))
            head_angle = 0
            for i in range(num_frames):
                frame = skeleton_dict[i]
                if dim == 2 and data_type != "3D":
                    # NOTE: reads the module-level `re_order_indices` global.
                    sk = frame[re_order_indices]
                elif joint_num == 5:
                    sk = frame[self.joint_index]
                else:
                    sk = frame

                if joint_num == 5:
                    # Angle at the mid-hip between the nose and the mid-shoulder point.
                    head_angle += angle(sk[2], (sk[0] + sk[1]) / 2, (sk[3] + sk[4]) / 2)
                    # NOTE(review): centres on the mean of RAW joints 0 and 1, whereas the
                    # mid-hip above averages selected joints 0 and 3 (RHip/LHip). Kept as-is
                    # because trained weights depend on it -- confirm intent.
                    G = (frame[0] + frame[1]) / 2
                else:
                    head_angle += angle(sk[2], sk[7], sk[10])
                    G = frame[7]

                skeleton_np[i] = sk - G
            self.angles.append(np.ones(joint_num) * (head_angle / 360) / num_frames)

            # Per-axis min-max normalisation over all frames and joints.
            max_coord = np.max(skeleton_np, axis=(0, 1))
            min_coord = np.min(skeleton_np, axis=(0, 1))
            coord_range = max_coord - min_coord
            # A constant axis would otherwise divide by zero and fill the clip with NaN.
            coord_range = np.where(coord_range == 0, 1.0, coord_range)
            skeleton_np = (skeleton_np - min_coord) / coord_range

            age_np = np.ones((num_frames, joint_num, 1)) * age / 14.
            level_np = np.ones((num_frames, joint_num, 1)) * level / 3.
            skeleton_np = np.concatenate((skeleton_np, age_np, level_np), axis=2)

            frame_data = self._fix_num_frames(skeleton_np, max_frames)
            # Leading axis of size 1; training code permutes it into the num_person position.
            # Stored as one ndarray so __getitem__ avoids the slow list-of-ndarrays path.
            self.clips.append(frame_data[np.newaxis])
            self.subject_names.append(name)

    @staticmethod
    def _load_skeletons(path, data_type):
        """Load the per-frame skeleton sequence stored at ``path``."""
        if data_type == "3D":
            return np.load(path)
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(path, "rb") as rfile:
            return pickle.load(rfile)

    @staticmethod
    def _fix_num_frames(skeleton_np, max_frames):
        """Resample a ``(frames, joints, channels)`` array to exactly ``max_frames`` frames.

        Longer sequences are sub-sampled with a constant ``step``; the gaps after the last
        ``compensation`` picks are ``step + 1`` so the picks spread over the whole sequence.
        Shorter (or equal) sequences are padded by repeating every frame ``step`` times, the
        first ``compensation`` frames once more.
        """
        num_frames = len(skeleton_np)
        frame_data = np.zeros((max_frames,) + skeleton_np.shape[1:])
        if max_frames < num_frames:  # drop frames
            step = num_frames // max_frames
            compensation = num_frames - (max_frames * step)
            j, k = 0, 0
            for i in range(num_frames):
                if i == j:
                    frame_data[k] = skeleton_np[i]
                    if k < (max_frames - compensation):
                        j = j + step
                    else:
                        j = j + (step + 1)
                    k += 1
        else:  # pad by repeating frames
            step = max_frames // num_frames
            compensation = max_frames - (num_frames * step)
            j, repeat_times = 0, 0
            for i in range(max_frames):
                frame_data[i] = skeleton_np[j]
                repeat_times += 1
                if j < compensation:
                    if (repeat_times % (step + 1)) == 0:
                        repeat_times = 0
                        j += 1
                else:
                    if (repeat_times % step) == 0:
                        repeat_times = 0
                        j += 1
        return frame_data

    def __getitem__(self, index):
        clips = torch.tensor(self.clips[index], dtype=torch.float)
        label = torch.tensor(self.labels[index], dtype=torch.float)
        subject_name = self.subject_names[index]
        head_angle = torch.tensor(self.angles[index], dtype=torch.float)

        return clips, label, subject_name, head_angle

    def __len__(self):
        return len(self.subject_names)

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN-GP.

    Picks one random point on the segment between each real/fake pair and penalises the
    critic's gradient norm there for deviating from 1:
    ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``.

    Args:
        D: the critic, called as ``D(x)`` on a batch shaped like ``real_samples``.
        real_samples: real batch; dim 0 is the batch size, any rank is accepted.
        fake_samples: generated batch with the same shape. It is detached, so the penalty
            never back-propagates into the generator.

    Returns:
        Scalar tensor. It stays differentiable w.r.t. the critic's parameters because the
        input gradient is taken with ``create_graph=True``.
    """
    batch_size = real_samples.size(0)
    # One interpolation weight per sample, broadcast over every remaining dimension.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples.detach()).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


In [None]:
def train(max_epoch, critic_step, train_loader, generator, discriminator, optimizer_G, optimizer_D, save_model,
          model_name, gp_weight=None):
    """Train a WGAN-GP generator/critic pair, then plot the recorded losses.

    The critic is updated on every batch. The generator is updated on batches whose index
    is a multiple of ``critic_step``, and the losses are logged on those batches.

    Args:
        max_epoch: number of passes over ``train_loader``.
        critic_step: generator update period, in batches.
        train_loader: yields ``(inputs, labels, names, angles)``. ``inputs`` is permuted with
            ``(0, 4, 2, 3, 1)`` before use.
        generator, discriminator: the two networks; the generator maps a real batch to a fake one.
        optimizer_G, optimizer_D: their optimizers.
        save_model: when True, checkpoint the "best" generator to
            ``./weights/wgan_G_best_{model_name}.pth``.
        model_name: suffix for the checkpoint file.
        gp_weight: gradient-penalty coefficient; ``None`` uses the module-level ``lambda_gp``.

    Note: reads the module-level ``device``.
    """
    if gp_weight is None:
        gp_weight = lambda_gp

    loss_G_plot = []
    loss_D_plot = []
    best_G_loss = float("inf")
    best_D_loss = float("inf")
    for epoch in range(max_epoch):
        for i, (inputs, _labels, _names, _angles) in enumerate(train_loader):
            # (batch, data_dim, frames, num_point, num_person)
            real_skeleton = inputs.to(device).permute(0, 4, 2, 3, 1)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            # The critic update must not back-propagate into the generator.
            with torch.no_grad():
                fake_skeleton = generator(real_skeleton)
            real_validity = discriminator(real_skeleton)
            fake_validity = discriminator(fake_skeleton)
            gradient_penalty = compute_gradient_penalty(discriminator, real_skeleton, fake_skeleton)
            wasserstein_d = -torch.mean(real_validity) + torch.mean(fake_validity)
            d_loss = wasserstein_d + gp_weight * gradient_penalty
            d_loss.backward()
            optimizer_D.step()

            if i % critic_step != 0:
                continue

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()
            fake_skeleton = generator(real_skeleton)
            g_loss = -torch.mean(discriminator(fake_skeleton))
            g_loss.backward()
            optimizer_G.step()

            d_loss_value = d_loss.item()
            g_loss_value = g_loss.item()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, max_epoch, i, len(train_loader), d_loss_value, g_loss_value)
            )
            loss_G_plot.append(g_loss_value)
            loss_D_plot.append(d_loss_value)

            # Heuristic checkpointing: keep the generator with the smallest |G loss| while the
            # critic loss is <= 100 and at most 10 above the smallest-magnitude critic loss seen.
            if d_loss_value <= 100 and abs(g_loss_value) <= abs(best_G_loss):
                best_G_loss = g_loss_value
                if save_model and d_loss_value <= best_D_loss + 10:
                    os.makedirs('./weights', exist_ok=True)
                    torch.save(generator.state_dict(), './weights/wgan_G_best_{}.pth'.format(model_name))
            if abs(d_loss_value) <= abs(best_D_loss):
                best_D_loss = d_loss_value

    # plot loss (one point per generator update)
    steps = list(range(len(loss_G_plot)))
    for title, values in (("Generator Loss", loss_G_plot), ("Discriminator Loss", loss_D_plot)):
        plt.title(title)
        plt.xlabel("step")
        plt.ylabel("Loss")
        plt.plot(steps, values)
        plt.show()


In [None]:
def std(x, mean):
    """Population standard deviation of ``x`` around the supplied ``mean``.

    Divides by ``len(x)`` (not ``len(x) - 1``). Raises ZeroDivisionError for an empty ``x``.
    """
    squared_sum = 0.0
    for value in x:
        squared_sum += (value - mean) ** 2
    return math.sqrt(squared_sum / len(x))

In [None]:
def mean(x):
    """Arithmetic mean of ``x``, accumulated left to right; raises ZeroDivisionError when empty."""
    total = 0.0
    for value in x:
        total += value
    return total / len(x)

In [None]:
def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _confusion_counts(predicted, labels):
    """Return ``(TP, TN, FP, FN)`` for 0/1 predictions against 0/1 labels (1 = pass/positive).

    Every sample that is not a TP, TN or FP is counted as an FN.
    """
    tp = ((predicted == 1) & (labels == 1)).sum().item()
    tn = ((predicted == 0) & (labels == 0)).sum().item()
    fp = ((predicted == 1) & (labels == 0)).sum().item()
    return tp, tn, fp, len(predicted) - tp - tn - fp


def _predict(outputs):
    """Threshold the first output column at 0.5 into 0/1 predictions (dtype of ``outputs``).

    NOTE(review): the loss applies a sigmoid to ``outputs`` but this thresholds the raw
    values; if the model emits logits the threshold should be 0 -- confirm.
    """
    return (outputs[:, 0] >= 0.5).to(outputs.dtype)


def _load_generator(path, joint_num, alpha, device):
    """Build an augmentation ``Generator`` and load its weights from ``path``."""
    generator = Generator(num_class=1, num_point=joint_num, num_person=1, in_channels=4,
                          out_channels=300, frames=96, alpha=alpha)
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    generator.load_state_dict(torch.load(path, map_location=device))
    return generator.to(device)


def _generate_samples(generator, real_samples, num_needed, device):
    """Return ``num_needed`` generated samples, or ``None`` if none are needed or possible.

    ``generator`` is run repeatedly over the stacked ``real_samples`` until enough outputs
    exist, then the result is truncated. No autograd: the outputs are used purely as data.
    """
    if num_needed <= 0 or len(real_samples) == 0:
        return None
    source = torch.stack(real_samples).to(device)
    chunks = []
    produced = 0
    with torch.no_grad():
        while produced < num_needed:
            chunk = generator(source)
            chunks.append(chunk)
            produced += len(chunk)
    return torch.cat(chunks, dim=0)[:num_needed]


def _augment_batch(inputs, labels, generator_pass, generator_fail, fake_real, device):
    """Rebalance a batch to ``int(len(inputs) * (1 + fake_real))`` samples, half pass / half fail.

    Real samples of each class are kept up to that class's half; any surplus is dropped. The
    shortfall is filled from that class's generator. A class with no real sample in the batch
    gets no generated samples. Returns shuffled ``(data, labels)`` with labels 1 (pass) /
    0 (fail) in ``labels``'s dtype.
    """
    pass_data = [inputs[j] for j in range(len(inputs)) if labels[j] == 1]
    fail_data = [inputs[j] for j in range(len(inputs)) if labels[j] != 1]

    total_data_size = int(len(inputs) * (1 + fake_real))
    pass_target = total_data_size // 2
    fail_target = total_data_size - pass_target

    kept_pass = pass_data[:pass_target]
    kept_fail = fail_data[:fail_target]
    new_pass = _generate_samples(generator_pass, pass_data, pass_target - len(pass_data), device)
    new_fail = _generate_samples(generator_fail, fail_data, fail_target - len(fail_data), device)

    data_parts, label_parts = [], []
    for samples, value in ((kept_pass, 1), (kept_fail, 0)):
        if samples:
            data_parts.append(torch.stack(samples))
            label_parts.append(torch.full((len(samples),), value, dtype=labels.dtype, device=labels.device))
    for samples, value in ((new_pass, 1), (new_fail, 0)):
        if samples is not None:
            data_parts.append(samples)
            label_parts.append(torch.full((len(samples),), value, dtype=labels.dtype, device=labels.device))

    print("pass: {}, fail: {}".format(
        len(kept_pass) + (0 if new_pass is None else len(new_pass)),
        len(kept_fail) + (0 if new_fail is None else len(new_fail))))

    if not data_parts:  # degenerate batch: nothing to keep or generate
        return inputs, labels
    concat_data = torch.cat(data_parts, dim=0)
    concat_labels = torch.cat(label_parts, dim=0)
    shuffle_index = torch.randperm(len(concat_labels))
    return concat_data[shuffle_index], concat_labels[shuffle_index]


def _train_one_epoch(model, train_loader, optimizer, criterion, device, data_aug,
                     generator_pass, generator_fail, fake_real, bootstrap):
    """Run one training epoch; return ``(loss_sum, correct, total, (TP, TN, FP, FN))``."""
    running_loss = 0.0
    correct = 0
    total = 0
    counts = [0, 0, 0, 0]
    for inputs, labels, _names, _angles in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # (batch, data_dim, frames, num_point, num_person)
        inputs = inputs.permute(0, 4, 2, 3, 1)

        optimizer.zero_grad()

        if data_aug:
            batch_data, batch_labels = _augment_batch(inputs, labels, generator_pass, generator_fail,
                                                      fake_real, device)
        elif bootstrap:  # resample the batch with replacement
            boot_index = np.random.choice(len(inputs), int(len(inputs) * (1 + fake_real)))
            boot_index = torch.as_tensor(boot_index, device=inputs.device)
            batch_data, batch_labels = inputs[boot_index], labels[boot_index]
        else:
            batch_data, batch_labels = inputs, labels

        outputs = model(batch_data)
        predicted = _predict(outputs)
        loss = criterion(torch.sigmoid(outputs)[:, 0], batch_labels)
        running_loss += loss.item()

        correct += (predicted == batch_labels).sum().item()
        for idx, value in enumerate(_confusion_counts(predicted, batch_labels)):
            counts[idx] += value
        total += len(predicted)

        loss.backward()
        optimizer.step()
    return running_loss, correct, total, tuple(counts)


def _evaluate(model, val_loader, criterion, device):
    """Evaluate ``model`` in eval mode without gradients.

    Returns ``(loss_sum, correct, total, (TP, TN, FP, FN), pass_size, fail_size, sec_per_sample)``.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    counts = [0, 0, 0, 0]
    pass_size = 0
    fail_size = 0
    start = time.time()
    with torch.no_grad():
        for inputs, labels, _names, _angles in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # (batch, data_dim, frames, num_point, num_person)
            inputs = inputs.permute(0, 4, 2, 3, 1)

            pass_size += int((labels == 1).sum().item())
            fail_size += int((labels == 0).sum().item())

            outputs = model(inputs)
            predicted = _predict(outputs)
            loss_sum += criterion(torch.sigmoid(outputs)[:, 0], labels).item()

            correct += (predicted == labels).sum().item()
            for idx, value in enumerate(_confusion_counts(predicted, labels)):
                counts[idx] += value
            total += len(predicted)
    inference_time = _ratio(time.time() - start, total)
    return loss_sum, correct, total, tuple(counts), pass_size, fail_size, inference_time


def trainClassifier(device, dataset, batch_size, max_epoch, optimizer, criterioin, scheduler, model, data_aug,
                    fake_real, pass_path, fail_path, alpha, bootstrap, joint_num):
    """5-fold cross-validate a binary pass/fail classifier, optionally with GAN augmentation.

    Every fold resets ``model``, ``optimizer`` and ``scheduler`` to their state at call time,
    trains for ``max_epoch`` epochs, and validates after each epoch. The weights with the best
    validation accuracy so far are saved to ``./weights/fold{fold}_best_gcn.pth``. Per-fold
    accuracy/loss curves are plotted.

    Args:
        device: torch device for batches, model inputs and generators.
        dataset: yields ``(clip, label, name, angle)``; labels are 1 (pass) / 0 (fail).
        batch_size, max_epoch: loader batch size and epochs per fold.
        optimizer, scheduler: optimizer for ``model`` and its LR scheduler (stepped per epoch).
        criterioin: loss applied to ``sigmoid(outputs)[:, 0]`` and the labels (the misspelt
            name is kept for keyword callers).
        model: the classifier.
        data_aug: rebalance each training batch with samples from the pass/fail generators.
        fake_real: augmented/bootstrapped batches contain ``int(n * (1 + fake_real))`` samples.
        pass_path, fail_path: generator checkpoints (used only when ``data_aug``).
        alpha: forwarded to the ``Generator`` constructor.
        bootstrap: when not ``data_aug``, resample each batch with replacement.
        joint_num: number of joints (generator ``num_point``).

    Returns:
        ``(mean_acc, mean_sen, mean_spe, mean_time)`` over folds, from the last epoch of each
        fold; accuracy, sensitivity and specificity are percentages (NaN when undefined).
    """
    import copy

    criterion = criterioin

    generator_pass = generator_fail = None
    if data_aug:
        generator_pass = _load_generator(pass_path, joint_num, alpha, device)
        generator_fail = _load_generator(fail_path, joint_num, alpha, device)

    # Snapshot the starting state so every fold trains from scratch; otherwise later folds
    # would validate on samples an earlier fold already trained on.
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_optimizer_state = copy.deepcopy(optimizer.state_dict())
    initial_scheduler_state = copy.deepcopy(scheduler.state_dict())

    kfold = KFold(n_splits=5, shuffle=True)

    print("dataset size: {}".format(len(dataset)))
    acc_list, sen_list, spe_list, time_list = [], [], [], []
    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print("======================fold {}======================".format(fold))
        model.load_state_dict(initial_model_state)
        optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))
        scheduler.load_state_dict(copy.deepcopy(initial_scheduler_state))

        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx))
        val_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_idx))

        train_acc_list, val_acc_list = [], []
        train_loss_list, val_loss_list = [], []
        best_acc = 0
        for epoch in range(max_epoch):
            # Re-enable training mode every epoch: validation leaves the model in eval mode.
            model.train()
            loss_sum, correct, total, (tp, tn, fp, fn) = _train_one_epoch(
                model, train_loader, optimizer, criterion, device, data_aug,
                generator_pass, generator_fail, fake_real, bootstrap)
            train_acc = _ratio(correct, total) * 100
            train_loss = _ratio(loss_sum, total)
            print("[train] epoch: {}, loss: {:.3f}, acc: {:.2f}, total: {}"
                  .format(epoch+1, train_loss, train_acc, total))
            print("sen: {:.2f}, spe: {:.2f}"
                  .format(_ratio(tp, tp + fn) * 100, _ratio(tn, tn + fp) * 100))
            train_acc_list.append(train_acc)
            train_loss_list.append(train_loss)

            scheduler.step()

            (loss_sum, correct, total, (tp, tn, fp, fn),
             pass_size, fail_size, inference_time) = _evaluate(model, val_loader, criterion, device)
            val_acc = _ratio(correct, total) * 100
            val_loss = _ratio(loss_sum, total)
            val_sen = _ratio(tp, tp + fn) * 100
            val_spe = _ratio(tn, tn + fp) * 100
            print("[validation] epoch: {}, loss: {:.3f}, acc: {:.2f}, total: {}, pass:fail={}:{}"
                  .format(epoch+1, val_loss, val_acc, total, pass_size, fail_size))
            print("sen: {:.2f}, spe: {:.2f}, inference time: {:.4f}"
                  .format(val_sen, val_spe, inference_time))
            val_acc_list.append(val_acc)
            val_loss_list.append(val_loss)

            if _ratio(correct, total) >= best_acc:
                best_acc = _ratio(correct, total)
                os.makedirs('./weights', exist_ok=True)
                torch.save(model.state_dict(), './weights/fold{}_best_gcn.pth'.format(fold))

            if epoch == max_epoch - 1:
                acc_list.append(val_acc)
                sen_list.append(val_sen)
                spe_list.append(val_spe)
                time_list.append(inference_time)

        epochs = list(range(max_epoch))
        plt.title("Fold {} Accuracy".format(fold))
        plt.plot(epochs, train_acc_list, label="train acc")
        plt.plot(epochs, val_acc_list, label="val acc")
        plt.legend()
        plt.show()

        plt.title("Fold {} Loss".format(fold))
        plt.plot(epochs, train_loss_list, label="train loss")
        plt.plot(epochs, val_loss_list, label="val loss")
        plt.legend()
        plt.show()

    M_acc = mean(acc_list)
    print("acc :{:.2f}+-{:.2f}".format(M_acc, std(acc_list, M_acc)))
    M_sen = mean(sen_list)
    print("sen :{:.2f}+-{:.2f}".format(M_sen, std(sen_list, M_sen)))
    M_spe = mean(spe_list)
    print("spe :{:.2f}+-{:.2f}".format(M_spe, std(spe_list, M_spe)))
    M_time = mean(time_list)
    print("time :{:.4f}+-{:.4f}".format(M_time, std(time_list, M_time)))

    return M_acc, M_sen, M_spe, M_time


In [None]:
def test(model, test_dataset, batch_size, device):
    """Evaluate the five per-fold checkpoints on ``test_dataset`` and print metrics.

    For each fold, ``./weights/fold{fold}_best_gcn.pth`` is loaded into ``model``. The function
    prints accuracy, sensitivity and specificity (percent) and the inference seconds per sample.
    It then prints the mean +- population std of each metric over the folds. Undefined ratios,
    such as sensitivity with no positive samples, are reported as NaN instead of raising.
    """
    def pct(numerator, denominator):
        return (numerator / denominator) * 100 if denominator else float("nan")

    # The loader does not depend on the fold, so build it once.
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    acc_list, sen_list, spe_list, time_list = [], [], [], []
    for fold in range(5):
        PATH = './weights/fold{}_best_gcn.pth'.format(fold)
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        model.load_state_dict(torch.load(PATH, map_location=device))
        model.to(device)
        model.eval()

        test_acc = test_TP = test_TN = test_FP = total = 0
        start = time.time()
        with torch.no_grad():
            for inputs, labels, _names, _angles in test_loader:
                # (batch, data_dim, frames, num_point, num_person)
                inputs = inputs.to(device).permute(0, 4, 2, 3, 1)
                labels = labels.to(device)

                outputs = model(inputs)
                # NOTE(review): thresholds the raw outputs at 0.5; the training loss applies a
                # sigmoid to them, so if the model emits logits the threshold should be 0.
                predicted = (outputs[:, 0] >= 0.5).to(outputs.dtype)

                test_acc += (predicted == labels).sum().item()
                test_TP += ((predicted == 1) & (labels == 1)).sum().item()
                test_TN += ((predicted == 0) & (labels == 0)).sum().item()
                test_FP += ((predicted == 1) & (labels == 0)).sum().item()
                total += len(predicted)
        # Every remaining sample is counted as a false negative.
        test_FN = total - test_TP - test_TN - test_FP
        inference_time = (time.time() - start) / total if total else float("nan")

        acc = pct(test_acc, total)
        sen = pct(test_TP, test_TP + test_FN)
        spe = pct(test_TN, test_TN + test_FP)
        print("[test] fold: {}, acc: {:.2f}, total: {}".format(fold, acc, total))
        print("sen: {:.2f}, spe: {:.2f}, inference time: {:.4f}"
                  .format(sen, spe, inference_time))

        acc_list.append(acc)
        sen_list.append(sen)
        spe_list.append(spe)
        time_list.append(inference_time)

    M = mean(acc_list)
    print("acc :{:.2f}+-{:.2f}".format(M, std(acc_list, M)))
    M = mean(sen_list)
    print("sen :{:.2f}+-{:.2f}".format(M, std(sen_list, M)))
    M = mean(spe_list)
    print("spe :{:.2f}+-{:.2f}".format(M, std(spe_list, M)))
    M = mean(time_list)
    print("time :{:.4f}+-{:.4f}".format(M, std(time_list, M)))


In [None]:
def find_files(filename, search_path):
    """Return the path of every file named ``filename`` anywhere under ``search_path``.

    The tree is walked top-down with ``os.walk`` and results are listed in walk order,
    including ``search_path`` itself when it contains the file directly.
    """
    return [
        os.path.join(directory, filename)
        for directory, _subdirs, names in os.walk(search_path)
        if filename in names
    ]


In [None]:
def getFileName(dir_name):
    """Return the part of ``dir_name`` after its last forward slash.

    Only '/' counts as a separator (backslashes are left untouched). A string
    without any '/' is returned unchanged, and a string ending in '/' gives ''.
    """
    return dir_name.rpartition("/")[2]

In [None]:
# Build a 13-joint generator with the classifier's settings and print its layers.
model = Generator(
    num_class=1,
    num_point=13,
    num_person=1,
    in_channels=4,
    out_channels=300,
    frames=96,
)
print(model)

In [None]:
# Inspect the EvoSkeleton Human3.6M training poses.
# allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code -- only use it on files you trust.
input_dir = "../../EvoSkeleton-master/data/human3.6M_origin/threeDPose_train.npy"
x = np.load(input_dir, allow_pickle=True)
print(x)

In [None]:
# Train a WGAN(-GP) skeleton generator on the level-3 recordings only.
# Every clip in this cell gets label 1, and the trained models are saved under
# `model_name` by `train` (when save_model is True).

# Loss weight for gradient penalty
# NOTE(review): lambda_gp is not passed to train(); presumably train reads it as a
# module-level global -- confirm before renaming or moving it.
lambda_gp = 10

# settings
data_dim = 3
in_channels = data_dim + 1 # coord. + age
out_channels = 300
frames = 96
stride = 60
# NOTE(review): re_order_indices is not referenced anywhere in this cell.
re_order_indices= [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16]
data_type = "3D"
joint_num = 5

max_epoch = 200 # 1000 is enough
batch_size = 128
# Adam lr 1e-4, betas (0, 0.9), 5 critic steps per generator step and lambda 10
# are the hyper-parameters used in the WGAN-GP paper (Gulrajani et al., 2017).
lr_G = 1e-4
lr_D = 1e-4
beta1 = 0
beta2 = 0.9
critic_step = 5
save_model = True

# Initialize generator and discriminator
generator = Generator(num_class=1, num_point=joint_num, num_person=1, in_channels=in_channels,
                 out_channels=out_channels, frames=frames)
discriminator = Model(num_class=1, num_point=joint_num, num_person=1, graph="graph.h36m.Graph",
                              in_channels=in_channels, out_channels=out_channels, frames=frames)
# replace batch normalization with layer normalization
# (the gradient penalty is per-sample, so the critic should not normalise across the batch)
# input shape = (batch_size, num_person * in_channels * num_point, frames)
# This swap must happen before .to(device) below so the new layer is moved too.
discriminator.data_bn = nn.LayerNorm([1 * in_channels * joint_num, frames])

generator = generator.to(device)
discriminator = discriminator.to(device)

# data
input_dir = "{}/3D".format(pkl_dir)
subjects = glob(input_dir + "/level3/*/*") # {level}/{view}/{file}

# NOTE(review): `sample` is also read by the later skeleton-animation cell
# (np.load(sample[4])), so that cell only works after this one has run.
sample = subjects
label = [(1) for i in range(len(sample))]

dataset = TestAIMSCDataset(file_list=sample, label_list=label, dim=data_dim, max_frames=frames, stride=stride,
                           data_type=data_type, joint_num=joint_num)
# shuffle=False: batches are drawn in glob order every epoch.
train_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# optimizer
# wgan_gp
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr_G, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_D, betas=(beta1, beta2))
# wgan
# optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=lr_G)
# optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=lr_D)

model_name = "level3_5_joint"
train(max_epoch, critic_step, train_loader, generator, discriminator, optimizer_G, optimizer_D, save_model, model_name)

In [None]:
# Alpha sweep -- train: original + GAN-generated clips, test: original clips only.
# For each mixing weight `alpha`, a freshly initialised classifier is trained
# with `trainClassifier` on a pass/fail-balanced subset of the Pull-to-sit
# recordings. The per-alpha metrics are tabulated at the end.
fake_real = 1  # fake:real ratio handed to trainClassifier (fixed during this sweep)
alpha_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

# ---- Parse the label CSV once: it does not depend on alpha. ----
# Column 1 holds the video file name, column 4 the label ('1' = pass, '0' = fail).
file = pd.read_csv('Pull_to_sit_AIMS_0209.csv')
file_names = np.squeeze(file.values)[:, 1]
labels = np.squeeze(file.values)[:, 4]

all_files = []   # resolved .npy paths, index-aligned with all_labels
all_labels = []  # 1 = pass, 0 = fail
not_exist = 0
invalid = 0
for row in range(len(file_names)):
    temp = file_names[row]
    if not isinstance(temp, str):
        continue
    # NOTE(review): drops the last two characters whenever a tab appears
    # anywhere in the name -- assumes the tab is trailing.
    if temp.find("\t") != -1:
        temp = temp[0:-2]
    if temp.find(".mp4") != -1:
        temp = temp[0:-4]

    temp = temp + ".npy"
    file_dir = find_files(temp, "../level_data/3D")
    if len(file_dir) == 0:
        print("file {} does not exist!".format(temp))
        not_exist += 1
    elif not isinstance(labels[row], str) or (labels[row] != '1' and labels[row] != '0'):
        print("invalid label {}".format(labels[row]))
        invalid += 1
    else:
        all_files.append(file_dir[0])
        all_labels.append(1 if labels[row] == '1' else 0)

total_acc_list = []
total_sen_list = []
total_spe_list = []
total_time_list = []
for alpha in alpha_list:
    print("alpha: {}, fake_real: {}".format(alpha, fake_real))
    print("===============================")

    # Balance the classes: keep every pass clip and draw as many fail clips at
    # random. The draw is capped by the number of fail clips available, so
    # replace=False cannot fail.
    pass_index = [k for k, lab in enumerate(all_labels) if lab == 1]
    pass_size = len(pass_index)
    fail_index = [k for k, lab in enumerate(all_labels) if lab == 0]
    n_fail = min(pass_size, len(fail_index))
    final_fail = np.random.choice(np.array(fail_index), n_fail, replace=False).tolist()

    file_index = pass_index + final_fail

    # Re-index files AND labels with the same index list so every path keeps its own label.
    file_list = [all_files[k] for k in file_index]
    label_list = [all_labels[k] for k in file_index]

    print(file_list)
    print(label_list)
    print("origin data size: {}".format(len(labels)))
    print("used data size: {}".format(len(file_list)))
    print("number of not exist files: {}".format(not_exist))
    print("number of invalid label: {}".format(invalid))

    data_dim = 3
    frames = 96
    stride = 60
    data_type = "3D"
    joint_num = 13

    dataset = TestAIMSCDataset(file_list=file_list, label_list=label_list, dim=data_dim, max_frames=frames,
                               stride=stride, data_type=data_type, joint_num=joint_num)

    # model
    # NOTE(review): the other experiments use in_channels = data_dim + 1; this
    # sweep uses + 2 -- confirm it matches what trainClassifier feeds the model.
    in_channels = data_dim + 2
    out_channels = 300
    model = Model(num_class=1, num_point=joint_num, num_person=1, graph="graph.h36m.Graph",
                  in_channels=in_channels, out_channels=out_channels, frames=frames)
    model.to(device)

    # parameters
    batch_size = 64
    max_epoch = 100
    optimizer = optim.SGD(model.parameters(), lr=5e-4, momentum=0.9, weight_decay=0)
    criterion = nn.BCELoss()
    # LR is multiplied by 0.5 once the scheduler reaches milestone 60.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60], gamma=0.5)
    data_aug = False
    # NOTE(review): these generator checkpoints are named "5_joint" while the
    # classifier uses joint_num = 13 -- confirm trainClassifier reconciles the two.
    pass_path = './weights/wgan_G_best_level3_5_joint.pth'
    fail_path = './weights/wgan_G_best_all_5_joint.pth'
    bootstrap = False

    # train
    acc, sen, spe, inf_time = trainClassifier(device, dataset, batch_size, max_epoch, optimizer, criterion, scheduler,
                                              model, data_aug, fake_real, pass_path, fail_path, alpha, bootstrap,
                                              joint_num)

    total_acc_list.append(acc)
    total_sen_list.append(sen)
    total_spe_list.append(spe)
    total_time_list.append(inf_time)

# One column per alpha; built from the lists so changing alpha_list cannot break the table.
print("==============final result==============")
print("alpha:\t" + "\t".join(str(v) for v in alpha_list))
print("acc:\t" + "\t".join("{:.2f}".format(v) for v in total_acc_list))
print("sen:\t" + "\t".join("{:.2f}".format(v) for v in total_sen_list))
print("spe:\t" + "\t".join("{:.2f}".format(v) for v in total_spe_list))
print("time:\t" + "\t".join("{:.4f}".format(v) for v in total_time_list))

In [None]:
# GAN-augmented classifier (fake_real = 1.25, alpha = 0.5, bootstrap on):
# train on original + generated clips, test on original clips only.
# 5-fold cross validation repeated 30 times. Each run redraws the balanced
# pass/fail subset and builds a fresh model.
n_runs = 30

# ---- Parse the label CSV once: it is identical for every run. ----
# Column 1 holds the video file name, column 4 the label ('1' = pass, '0' = fail).
file = pd.read_csv('Pull_to_sit_AIMS_0209.csv')
file_names = np.squeeze(file.values)[:, 1]
labels = np.squeeze(file.values)[:, 4]

all_files = []   # resolved .npy paths, index-aligned with all_labels
all_labels = []  # 1 = pass, 0 = fail
not_exist = 0
invalid = 0
for row in range(len(file_names)):
    temp = file_names[row]
    if not isinstance(temp, str):
        continue
    # NOTE(review): drops the last two characters whenever a tab appears
    # anywhere in the name -- assumes the tab is trailing.
    if temp.find("\t") != -1:
        temp = temp[0:-2]
    if temp.find(".mp4") != -1:
        temp = temp[0:-4]

    temp = temp + ".npy"
    file_dir = find_files(temp, "../level_data/3D")
    if len(file_dir) == 0:
        print("file {} does not exist!".format(temp))
        not_exist += 1
    elif not isinstance(labels[row], str) or (labels[row] != '1' and labels[row] != '0'):
        print("invalid label {}".format(labels[row]))
        invalid += 1
    else:
        all_files.append(file_dir[0])
        all_labels.append(1 if labels[row] == '1' else 0)

total_acc_list = []
total_sen_list = []
total_spe_list = []
total_time_list = []
for run in range(n_runs):
    print("{} times".format(run + 1))

    # Balance the classes: keep every pass clip and draw as many fail clips at
    # random. The draw is capped by the number of fail clips available, so
    # replace=False cannot fail.
    pass_index = [k for k, lab in enumerate(all_labels) if lab == 1]
    pass_size = len(pass_index)
    fail_index = [k for k, lab in enumerate(all_labels) if lab == 0]
    n_fail = min(pass_size, len(fail_index))
    final_fail = np.random.choice(np.array(fail_index), n_fail, replace=False).tolist()

    file_index = pass_index + final_fail

    # Re-index files AND labels with the same index list so every path keeps its own label.
    file_list = [all_files[k] for k in file_index]
    label_list = [all_labels[k] for k in file_index]

    print(file_list)
    print(label_list)
    print("origin data size: {}".format(len(labels)))
    print("used data size: {}".format(len(file_list)))
    print("number of not exist files: {}".format(not_exist))
    print("number of invalid label: {}".format(invalid))

    data_dim = 3
    frames = 96
    stride = 60
    data_type = "3D"
    joint_num = 13

    dataset = TestAIMSCDataset(file_list=file_list, label_list=label_list, dim=data_dim, max_frames=frames,
                               stride=stride, data_type=data_type, joint_num=joint_num)

    # model
    in_channels = data_dim + 1
    out_channels = 300
    model = Model(num_class=1, num_point=joint_num, num_person=1, graph="graph.h36m.Graph",
                  in_channels=in_channels, out_channels=out_channels, frames=frames)
    model.to(device)

    # parameters
    batch_size = 64
    max_epoch = 100
    optimizer = optim.SGD(model.parameters(), lr=5e-4, momentum=0.9, weight_decay=0)
    criterion = nn.BCELoss()
    # LR is multiplied by 0.5 once the scheduler reaches milestone 60.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60], gamma=0.5)
    data_aug = False
    fake_real = 1.25
    pass_path = './weights/wgan_G_best_level3_default_para.pth'
    fail_path = './weights/wgan_G_best_all_default_para.pth'
    alpha = 0.5
    bootstrap = True

    # train
    acc, sen, spe, inf_time = trainClassifier(device, dataset, batch_size, max_epoch, optimizer, criterion, scheduler,
                                              model, data_aug, fake_real, pass_path, fail_path, alpha, bootstrap)

    total_acc_list.append(acc)
    total_sen_list.append(sen)
    total_spe_list.append(spe)
    total_time_list.append(inf_time)

print("==============final result==============")
M = mean(total_acc_list)
print("acc :{:.2f}+-{:.2f}".format(M, std(total_acc_list, M)))
M = mean(total_sen_list)
print("sen :{:.2f}+-{:.2f}".format(M, std(total_sen_list, M)))
M = mean(total_spe_list)
print("spe :{:.2f}+-{:.2f}".format(M, std(total_spe_list, M)))
M = mean(total_time_list)
print("time :{:.4f}+-{:.4f}".format(M, std(total_time_list, M)))


In [None]:
# Scratch check: element-wise cube roots of a small integer tensor.
a = torch.tensor([4, 8, 16, 32])
cube_roots = a.pow(1 / 3)
print(cube_roots)

In [None]:
# Baseline: balanced pass/fail, real clips only (fake_real = 0, no GAN data).
# 5-fold cross validation repeated 30 times. Each run redraws the balanced
# subset and builds a fresh model.
n_runs = 30

# ---- Parse the label CSV once: it is identical for every run. ----
# Column 1 holds the video file name, column 4 the label ('1' = pass, '0' = fail).
file = pd.read_csv('Pull_to_sit_AIMS_0209.csv')
file_names = np.squeeze(file.values)[:, 1]
labels = np.squeeze(file.values)[:, 4]

all_files = []   # resolved .npy paths, index-aligned with all_labels
all_labels = []  # 1 = pass, 0 = fail
not_exist = 0
invalid = 0
for row in range(len(file_names)):
    temp = file_names[row]
    if not isinstance(temp, str):
        continue
    # NOTE(review): drops the last two characters whenever a tab appears
    # anywhere in the name -- assumes the tab is trailing.
    if temp.find("\t") != -1:
        temp = temp[0:-2]
    if temp.find(".mp4") != -1:
        temp = temp[0:-4]

    temp = temp + ".npy"
    file_dir = find_files(temp, "../level_data/3D")
    if len(file_dir) == 0:
        print("file {} does not exist!".format(temp))
        not_exist += 1
    elif not isinstance(labels[row], str) or (labels[row] != '1' and labels[row] != '0'):
        print("invalid label {}".format(labels[row]))
        invalid += 1
    else:
        all_files.append(file_dir[0])
        all_labels.append(1 if labels[row] == '1' else 0)

total_acc_list = []
total_sen_list = []
total_spe_list = []
total_time_list = []
for run in range(n_runs):
    print("{} times".format(run + 1))

    # Balance the classes: keep every pass clip and draw as many fail clips at
    # random. The draw is capped by the number of fail clips available, so
    # replace=False cannot fail.
    pass_index = [k for k, lab in enumerate(all_labels) if lab == 1]
    pass_size = len(pass_index)
    fail_index = [k for k, lab in enumerate(all_labels) if lab == 0]
    n_fail = min(pass_size, len(fail_index))
    final_fail = np.random.choice(np.array(fail_index), n_fail, replace=False).tolist()

    file_index = pass_index + final_fail

    # Re-index files AND labels with the same index list so every path keeps its own label.
    file_list = [all_files[k] for k in file_index]
    label_list = [all_labels[k] for k in file_index]

    print(file_list)
    print(label_list)
    print("origin data size: {}".format(len(labels)))
    print("used data size: {}".format(len(file_list)))
    print("number of not exist files: {}".format(not_exist))
    print("number of invalid label: {}".format(invalid))

    data_dim = 3
    frames = 96
    stride = 60
    data_type = "3D"
    joint_num = 13

    dataset = TestAIMSCDataset(file_list=file_list, label_list=label_list, dim=data_dim, max_frames=frames,
                               stride=stride, data_type=data_type, joint_num=joint_num)

    # model
    in_channels = data_dim + 1
    out_channels = 300
    model = Model(num_class=1, num_point=joint_num, num_person=1, graph="graph.h36m.Graph",
                  in_channels=in_channels, out_channels=out_channels, frames=frames)
    model.to(device)

    # parameters
    batch_size = 128
    max_epoch = 100
    optimizer = optim.SGD(model.parameters(), lr=5e-4, momentum=0.9, weight_decay=0)
    criterion = nn.BCELoss()
    # LR is multiplied by 0.5 once the scheduler reaches milestone 60.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60], gamma=0.5)
    data_aug = False
    fake_real = 0  # no generated clips

    # train
    acc, sen, spe, inf_time = trainClassifier(device, dataset, batch_size, max_epoch, optimizer, criterion, scheduler,
                                              model, data_aug, fake_real)

    total_acc_list.append(acc)
    total_sen_list.append(sen)
    total_spe_list.append(spe)
    total_time_list.append(inf_time)

M = mean(total_acc_list)
print("acc :{:.2f}+-{:.2f}".format(M, std(total_acc_list, M)))
M = mean(total_sen_list)
print("sen :{:.2f}+-{:.2f}".format(M, std(total_sen_list, M)))
M = mean(total_spe_list)
print("spe :{:.2f}+-{:.2f}".format(M, std(total_spe_list, M)))
M = mean(total_time_list)
print("time :{:.4f}+-{:.4f}".format(M, std(total_time_list, M)))


In [None]:
# Load the best WGAN generator checkpoint and sample skeletons from it.
# NOTE(review): despite the original "generate and save" heading, nothing is
# written to disk yet. `Generator()` is built with its default arguments, which
# must match the configuration the checkpoint was trained with -- TODO confirm.
best_generator = Generator()
PATH = './weights/wgan_G_best_epoch192.pth'
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
best_generator.load_state_dict(torch.load(PATH, map_location=device))
best_generator.to(device)
best_generator.eval()  # inference behaviour for any dropout/normalisation layers

# Sampling needs no autograd graph; skipping it saves memory.
with torch.no_grad():
    new_data = best_generator()

In [None]:
# Animate one recorded 3D skeleton clip (13 joints, last 96 frames).
# `plt` and `np` come from the notebook's import cell.
import mpl_toolkits.mplot3d.axes3d as p3  # registers the "3d" projection on older matplotlib
from matplotlib import animation

n_frames = 96   # frames shown in the animation
n_joints = 13   # joints per skeleton in the recorded clips

fig = plt.figure()
# `p3.Axes3D(fig)` is deprecated since matplotlib 3.4 and no longer attaches
# itself to the figure in newer releases (blank plot);
# add_subplot(projection="3d") is the supported constructor.
ax = fig.add_subplot(projection="3d")

# NOTE(review): `sample` is defined by the WGAN training cell; run that first.
q = np.load(sample[4])
q = q[-n_frames:]
print(q.shape)

# NOTE(review): np.reshape reinterprets memory, it does not transpose. This is
# only correct if the stored array is already laid out as (3, frames, joints).
# If it is frame-major (e.g. (T, 13, 3)), a transpose is needed instead -- confirm.
q = np.reshape(q, (3, n_frames, n_joints))

# First-frame coordinates; q[axis][frame] is a length-n_joints vector.
x = np.array(q[0][0])
y = np.array(q[1][0])
z = np.array(q[2][0])

# Bones, as (joint, joint) index pairs.
pairs = [[0, 1], [0, 3], [0, 10], [1, 2], [3, 4], [3, 5], [6, 7], [6, 10],
         [7, 8], [8, 9], [10, 11], [11, 12]]
lines = []
for pair in pairs:
    line, = ax.plot(np.take(x, pair), np.take(y, pair), np.take(z, pair))
    lines.append(line)

points, = ax.plot(x, y, z, 'bo')
txt = fig.suptitle('')


def update_points(num, q, points, lines):
    """Move the joint markers and bones to frame ``num`` of ``q``.

    ``q`` is indexed as ``q[axis][frame]`` with axes 0/1/2 = x/y/z. Returns all
    modified artists (markers, title and bones), as FuncAnimation expects when
    blitting.
    """
    txt.set_text('num={:d}'.format(num))  # for debug purposes

    new_x = q[0][num]
    new_y = q[1][num]
    new_z = q[2][num]

    points.set_data(new_x, new_y)
    points.set_3d_properties(new_z, 'z')

    for bone, pair in zip(lines, pairs):
        bone.set_data(np.take(new_x, pair), np.take(new_y, pair))
        bone.set_3d_properties(np.take(new_z, pair), 'z')

    return [points, txt, *lines]


ani = animation.FuncAnimation(fig, update_points, frames=n_frames, fargs=(q, points, lines))


In [None]:
# Embed the skeleton animation as an HTML5 <video> element.
# to_html5_video() encodes with matplotlib's configured movie writer (ffmpeg by default).
from IPython.display import HTML

video_html = ani.to_html5_video()
HTML(video_html)