<a href="https://colab.research.google.com/github/skore11/Laban_Pose_Conditional_GAN/blob/main/LabanPoseTransformation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install aiohttp nest_asyncio tqdm c3d numpy scipy torch matplotlib

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
#from tqdm import tqdm
from tqdm.notebook import tqdm
from zipfile import ZipFile

# Global switch read by the rest of the notebook to decide whether models
# and tensors are moved to the GPU. Uncomment the override below to force
# CPU execution even when a GPU is available.
cuda = torch.cuda.is_available()
#cuda = False

In [None]:
def rough_unique(a, tol=2**13):
    """
    Return the indices of the rows of ``a`` that are not near-duplicates.

    The rows are ordered by their first column and every row is compared
    with its predecessor in that order. A row is kept when the mean
    absolute difference to its predecessor, divided once more by the
    number of columns, is greater than ``tol``. The first row in sorted
    order has no predecessor and is always kept.

    Parameters:
        a: 2D array, one sample per row.
        tol: threshold a row's scaled difference must exceed to be kept.

    Returns:
        1D integer index array into ``a`` (in first-column sorted order)
        selecting the rows to keep. Empty when ``a`` has no rows.
    """
    a = np.asarray(a)
    if a.shape[0] == 0:
        # Nothing to compare; the mask below would otherwise be one
        # element longer than the index array.
        return np.array([], dtype=np.intp)

    order = np.argsort(a, axis=0)[:, 0]
    step = np.mean(np.abs(np.diff(a[order], axis=0)), axis=1)
    scaled = step.astype(np.float64) / a.shape[1]
    # The first row is kept unconditionally. It must not go through the
    # column-count scaling, which would push any finite sentinel below
    # the tolerance for inputs with two or more columns.
    keep = np.concatenate(([True], scaled > tol))
    return order[keep]

DATA_ZIP = "./data.zip" # where the data is downloaded to
DATA_DIR = "./data" # where the data is unzipped to
def get_data_zip(
        url="https://cyprus-data.s3.us-east-2.amazonaws.com/data.zip",
        chunk_size=8192,
    ):
    """
    Ensure that dataset is downloaded and unzipped.

    Downloads ``url`` to ``DATA_ZIP`` unless that file already exists,
    then extracts the archive into the current working directory unless
    ``DATA_DIR`` already exists.

    Parameters:
        url: location of the zipped dataset.
        chunk_size: number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if not os.path.isfile(DATA_ZIP): # not already downloaded
        # Stream into a side file and rename on success, so an interrupted
        # download is never mistaken for a complete archive on the next run.
        partial = DATA_ZIP + ".part"
        try:
            with requests.get(url, stream=True) as r:
                # Fail loudly instead of saving an error page as the zip
                r.raise_for_status()
                size = r.headers.get("Content-Length")
                filesize = int(size) if size is not None else None
                with open(partial, 'wb') as fp, tqdm(
                        unit="B",
                        unit_scale=True,
                        unit_divisor=1024,
                        total=filesize,
                        desc=DATA_ZIP) as progress:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        progress.update(fp.write(chunk))
            os.replace(partial, DATA_ZIP)
        finally:
            # Only present here if the download did not complete
            if os.path.exists(partial):
                os.remove(partial)
    if not os.path.isdir(DATA_DIR): # not already unzipped
        with ZipFile(DATA_ZIP, 'r') as zp:
            files = zp.namelist()
            with tqdm(
                    total=len(files),
                    desc="Unzipping {}".format(DATA_DIR)) as progress:
                for name in files:
                    # Extracts relative to the current working directory;
                    # presumably the archive carries its own top-level
                    # folder matching DATA_DIR -- verify against the zip.
                    zp.extract(name)
                    progress.update()

def _augment_poses(data, copies=16):
    """
    Return ``data`` stacked with ``copies`` randomly scaled versions,
    ``copies`` randomly translated versions and ``copies`` versions that
    are scaled and then translated.
    """
    rows = data.shape[0]

    def rand_scale():
        # One scale factor per row, applied to every column
        return np.random.uniform(low=0.8, high=1.2, size=(rows, 1))

    def rand_trans():
        # Three offsets per row, each repeated over 38 consecutive columns;
        # the trailing 17 columns are left untranslated.
        # NOTE(review): np.repeat assumes the 114 coordinate columns are
        # laid out as three runs of 38 -- confirm against the data files.
        offsets = np.random.uniform(low=-100.0, high=100.0, size=(rows, 3))
        return np.hstack((np.repeat(offsets, 38, axis=1), np.zeros((rows, 17))))

    scale = np.vstack([data * rand_scale() for _ in range(copies)])
    trans = np.vstack([data + rand_trans() for _ in range(copies)])
    both = np.vstack([(data * rand_scale()) + rand_trans() for _ in range(copies)])
    return np.vstack([data, scale, trans, both])


def cyprus_dataset(toTensor):
    """
    Loads our augmented version of the Cyprus Dataset,
    returns two TensorDatasets, one for train and one for test

    Parameters:
        toTensor: callable that turns a float32 numpy array into a tensor.

    Returns:
        ``(train, test)`` TensorDatasets of ``(poses, labels)``; the first
        80% of the rows go to train and the rest to test.

    Raises:
        FileNotFoundError: if ``DATA_DIR`` is still missing after the
            download/unzip step.
        ValueError: if no usable pose survives filtering.
    """
    MAX_ZEROS = 3 # max number of 0's a single pose can have
    # Download/unzip first; checking for the directory before this call
    # would raise before the data ever had a chance to be fetched.
    get_data_zip()
    if not os.path.isdir(DATA_DIR):
        raise FileNotFoundError(
            "{} does not exist, make sure the data is downloaded "
            "and unzipped".format(DATA_DIR))

    # Get files from dataset: the .npy files directly inside DATA_DIR,
    # sorted so the load order does not depend on the filesystem
    npys = sorted(
        name for name in os.listdir(DATA_DIR)
        if name.endswith(".npy") and os.path.isfile(os.path.join(DATA_DIR, name)))

    # Load the data
    labels_map = {
        'afraid':  0,
        'bored':   1,
        'excited': 2,
        'neutral': 3,
        'relaxed': 4,
    }
    pose_chunks = []
    label_chunks = []
    for npy in tqdm(npys, desc="Loading data"):
        # The label is the file name prefix before the first underscore
        label = labels_map[npy.split("_")[0]]
        data = np.load(os.path.join(DATA_DIR, npy)).astype('float32')
        # Enforce MAX_ZEROS
        data = data[np.sum(data == 0., axis=-1) < MAX_ZEROS]
        if data.shape[0] == 0:
            continue
        # Remove similar poses
        data = data[rough_unique(data)]
        if data.shape[0] == 0:
            continue

        # Augment data with random scale and translations, both
        # together and separately
        data = _augment_poses(data)

        # Remove similar poses
        data = data[rough_unique(data)]
        if data.shape[0] == 0:
            continue

        pose_chunks.append(data)
        # repeat label for all samples in file
        label_chunks.append(np.repeat(label, data.shape[0]))

    if not pose_chunks:
        raise ValueError("no usable poses found in {}".format(DATA_DIR))
    # Stack once at the end instead of re-copying the growing array per file
    poses = np.vstack(pose_chunks)
    labels = np.concatenate(label_chunks)

    # Enforce MAX_ZEROS
    valid_rows = np.sum(poses == 0., axis=-1) < MAX_ZEROS
    poses = poses[valid_rows]
    labels = labels[valid_rows]
    if poses.shape[0] == 0:
        raise ValueError("no usable poses found in {}".format(DATA_DIR))

    # Remove similar poses
    uis = rough_unique(poses,2**12)
    poses = poses[uis]
    labels = labels[uis]

    # NOTE(review): rough_unique returns indices in first-column sorted
    # order, so this is a split by first-column value rather than a random
    # split -- confirm that is intended.
    cutoff = int(poses.shape[0] * 0.8)
    train_poses = toTensor(poses[:cutoff].astype('float32'))
    train_labels = toTensor(labels[:cutoff].astype('float32'))
    test_poses = toTensor(poses[cutoff:].astype('float32'))
    test_labels = toTensor(labels[cutoff:].astype('float32'))
    return TensorDataset(train_poses, train_labels), TensorDataset(test_poses, test_labels)
    
    

In [None]:
# Sizes shared by the networks and the training loop below.
# Number of pose coordinate columns (presumably 38 points x 3 coordinates
# -- confirm against the dataset); the discriminator consumes exactly these.
N_POINTS = 38*3
# Number of extra feature columns that follow the coordinates in a data row
N_FEATURES = 17
# Width of one full data row as fed to the generator
N_DATA = N_POINTS+N_FEATURES
# Number of labels; also the size of the generator's label embedding and
# of the discriminator's output
N_CLASSES = 5
# Width of the random noise vector concatenated to the generator input
N_NOISE = 19

def add_tensors(x,y):
    """
    Add x and y, trimming or zero-padding y along dimension 1 to match
    the size of x. The padded copy of y is created directly on x's device
    with x's dtype, so the result does not depend on any global GPU flag.

    Parameters:
        x: 3D tensor of size (b, u, ...); fixes the size of the result.
        y: 3D tensor of size (b, v, 1); only its first min(u, v) entries
           along dimension 1 are used.

    Returns:
        x + z, where z is y trimmed/zero-padded to size (b, u, 1).
    """
    b,u,_ = x.size()
    w = min(u, y.size(1))
    z = torch.zeros(b, u, 1, dtype=x.dtype, device=x.device)
    z[:,:w] = y[:,:w]
    return x+z

class ResConv(nn.Module):
    """
    Residual 1D convolution block (Conv1d -> BatchNorm1d -> PReLU) whose
    output is blended with its input, after Generative Tweening
    https://arxiv.org/pdf/2005.08891.pdf

    ``ratio`` sets the blend weights: the activated branch is weighted by
    sqrt(ratio) and the input by sqrt(1 - ratio), each clamped at zero.
    """
    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, ratio):
        super().__init__()
        # Blend weights, stored as 0-dim tensors
        branch_weight = float(max(0, ratio))
        skip_weight = float(max(0, 1.0 - ratio))
        self.root_ratio = torch.sqrt(torch.tensor(branch_weight))
        self.om_root_ratio = torch.sqrt(torch.tensor(skip_weight))

        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.affine = nn.BatchNorm1d(out_ch, affine=True)
        self.prelu = nn.PReLU(out_ch)

    def forward(self, x):
        activated = self.prelu(self.affine(self.conv(x)))
        # add_tensors trims/zero-pads the skip term to the branch's channels.
        # NOTE(review): the weighted input is added here and again below;
        # confirm against the paper that both additions are intended.
        blended = add_tensors(activated, self.om_root_ratio * x)
        return add_tensors(self.root_ratio * blended, self.om_root_ratio * x)

class ResConvT(ResConv):
    """
    ResConv variant whose convolution is a ConvTranspose1d; the
    normalisation, activation and residual blend come from ResConv.
    """
    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, ratio):
        super().__init__(in_ch, out_ch, kernel_size, stride, padding, ratio)
        # Replace the Conv1d built by the parent with its transposed
        # counterpart, using the same hyper-parameters
        transposed = nn.ConvTranspose1d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.conv = transposed

class Generator(nn.Module):
    """
    Conditional generator built from ResConv/ResConvT blocks, modelled on
    the networks from Generative Tweening.

    The input is a data row concatenated with the embedding of the target
    label and a noise vector; the output has N_POINTS channels.
    """
    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(N_CLASSES, N_CLASSES)

        # Arguments are (in_ch, out_ch, kernel_size, stride, padding, ratio)
        encoder = [
            ResConv(N_DATA + N_NOISE + N_CLASSES, 256, 3, 2, 1, 1/1),
            ResConv(256, 512, 3, 2, 1, 1/2),
            ResConv(512, 768, 3, 2, 1, 1/3),
            ResConv(768, 1024, 3, 2, 1, 1/4),
            ResConv(1024, 1536, 3, 2, 1, 1/5),
            ResConv(1536, 2048, 3, 2, 1, 1/6),
        ]
        decoder = [
            ResConv(2048, 1024, 1, 1, 0, 1/1),
            ResConv(1024, 1024, 3, 1, 1, 1/3),
            ResConvT(1024, 1024, 3, 2, 1, 1/3),
            ResConv(1024, 512, 3, 1, 1, 1/9),
            ResConvT(512, 512, 3, 2, 1, 1/9),
            ResConv(512, 256, 3, 1, 1, 1/18),
            ResConvT(256, 256, 3, 2, 1, 1/18),
            ResConv(256, N_POINTS, 3, 1, 1, 1/28),
        ]
        self.model = nn.Sequential(*encoder, *decoder)

    def forward(self, pose, label, noise):
        # Join pose, label embedding and noise on the last axis, then add
        # a trailing length-1 axis for the 1D convolutions
        conditioned = torch.cat((pose, self.label_embedding(label), noise), -1)
        return self.model(conditioned.unsqueeze(2))
        
class Discriminator(nn.Module):
    """
    Classifier over poses: a stack of ResConv/ResConvT blocks followed by
    a 1x1 convolution that produces one output channel per class.
    """
    def __init__(self):
        super().__init__()
        # Arguments are (in_ch, out_ch, kernel_size, stride, padding, ratio)
        features = [
            ResConv(N_POINTS, 512, 3, 2, 1, 1/1),
            ResConv(512, 512, 3, 2, 1, 1/2),
            ResConvT(512, 512, 3, 2, 1, 1/3),
            ResConv(512, 512, 3, 2, 1, 1/4),
            ResConvT(512, 1024, 3, 2, 1, 1/5),
            ResConv(1024, 1024, 3, 2, 1, 1/6),
        ]
        head = nn.Conv1d(1024, N_CLASSES, kernel_size=1, stride=1, padding=0)
        self.model = nn.Sequential(*features, head)

    def forward(self, x):
        return self.model(x)
    

In [None]:
# The general setup and training approach follow:
#   https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py

# Training hyper-parameters
batch_size = 2**7
epochs = 200
lr = 0.00001
# Adam betas; unused while RMSprop is the active optimizer below
b1 = 0.5
b2 = 0.999

# Networks and loss
generator = Generator()
discriminator = Discriminator()
adversarial_loss = nn.CrossEntropyLoss()

# If GPU available, move everything there
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Tensor constructors matching the device in use
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# Optimizers are created only after the models were moved to the GPU.
# Adam alternative: torch.optim.Adam(<model>.parameters(), lr=lr, betas=(b1, b2))
opt_G = torch.optim.RMSprop(generator.parameters(), lr=lr)
opt_D = torch.optim.RMSprop(discriminator.parameters(), lr=lr)


In [None]:
# Load data, display pie chart to show split
train, test = cyprus_dataset(FloatTensor)

# Count occurrences of each label in one pass over the label tensor
# (labels are the second tensor of the TensorDataset)
label_ids, label_counts = torch.unique(
    train.tensors[1].long(), return_counts=True)
count = dict(zip(label_ids.tolist(), label_counts.tolist()))

total = len(train)
print('Total Frames: {:,}'.format(total))

# Create pie chart to display the makeup of the dataset
pie_labels = ['afraid','bored','excited','neutral','relaxed']
present = sorted(count)
pcts = [count[k] / total for k in present]
if present:
    plt.pie(
        # Raw counts with normalize=True give the same wedges as the
        # fractions, without failing when the float fractions sum to
        # slightly more than 1
        [count[k] for k in present],
        # Pair every label id with its own percentage; pcts is ordered by
        # position in `present`, not by label id
        labels=[pie_labels[k] + " {:.2f}%".format(pct*100) for k, pct in zip(present, pcts)],
        normalize=True)
    plt.show()
print()

In [None]:
# Adversarial training: the generator is updated so that the discriminator
# assigns its output the requested label, then the discriminator is
# updated on real and on generated poses.
dl = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=0)
print("Training started: {:,} total poses, batch size {:,}".format(len(train), batch_size))
op = tqdm(range(epochs), position=1, desc="Epoch")
break_early = False

# One entry per training step, used for the loss plot
g_losses = np.array([])
d_losses = np.array([])

for epoch in op:
    ip = tqdm(
        enumerate(dl), 
        position=2, 
        # exact number of batches, also when len(train) is a multiple
        # of batch_size
        total=len(dl),
        desc="[G loss: ???] [D loss: ???]",
        leave=False)
    
    for i, (data, labels) in ip:
        # BatchNorm cannot compute training statistics from a single
        # sample, so skip a trailing batch of size one
        if data.shape[0] < 2:
            continue
        real_labels = labels.long()

        # Train generator
        opt_G.zero_grad()
        noise = FloatTensor(np.random.normal(0, 1, (data.shape[0], N_NOISE)))
        # Target labels: the batch's own labels in shuffled order
        gen_labels = labels[torch.randperm(data.shape[0])].long()
        gen_output = generator(data, gen_labels, noise)
        
        # How well are we fooling the discriminator 
        # (matching the desired label)?
        cls = discriminator(gen_output).squeeze(2)
        g_loss = adversarial_loss(cls, gen_labels)
        g_loss.backward()
        opt_G.step()
        
        # Train Discriminator
        opt_D.zero_grad()
        
        # How well can the discriminator label real poses?
        real_cls = discriminator(data[:,:N_POINTS].unsqueeze(2)).squeeze(2)
        d_real_loss = adversarial_loss(real_cls, real_labels)
        
        # Can the discriminator detect the generated poses? They are
        # scored against the source labels, not the requested gen_labels;
        # detach() keeps this update out of the generator.
        gen_cls = discriminator(gen_output.detach()).squeeze(2)
        d_gen_loss = adversarial_loss(gen_cls, real_labels)
        
        d_loss = (d_real_loss + d_gen_loss) / 2
        d_loss.backward()
        opt_D.step()     
        
        # Read each loss back from the device once per step
        g_loss_value = g_loss.item()
        d_loss_value = d_loss.item()

        # Update progress bar
        ip.set_description(
            desc="[G loss: %f] [D loss: %f]"
                % (g_loss_value, d_loss_value)
        )
        
        # Track losses
        g_losses = np.append(g_losses, g_loss_value)
        d_losses = np.append(d_losses, d_loss_value)
        
        # Break early on low generator loss
        if g_loss_value < 0.001:
            break_early = True
            break

    ip.close()
    
    if break_early:
        break
    
    # Checkpoint after every completed epoch
    torch.save(generator, "generator_checkpoint.torch")
    torch.save(discriminator, "discriminator_checkpoint.torch")

op.close()

torch.save(generator, "generator_final.torch")
torch.save(discriminator, "discriminator_final.torch")
print("Saved models!")

In [None]:
# TODO: display poses
# TODO: Play with losses + optimizers
# TODO: speed up training

In [None]:
# Plot the per-batch generator and discriminator losses on shared axes
xs = np.arange(g_losses.shape[0])
for series, legend_name in ((g_losses, "G Losses"), (d_losses, "D Losses")):
    plt.plot(xs, series, label=legend_name)
plt.xlabel("Training step (Batch)")
plt.ylabel("Loss")
plt.legend()
plt.show()