In [139]:
import numpy as np
import json, codecs
import os
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

# We preprocessed the file names, so we have an index listing every dance as
# [category, dance_id, number_of_frames].
FRAME_LIST_INDEX = './dance-frame-list.json'

# np.random.seed(0)
NUM_BODY_PARTS = 13
TOTAL_FRAMES = 250

# Keypoints dropped from every pose, turning the 17 detected body parts into
# NUM_BODY_PARTS (13).
_EXCLUDED_BODY_PARTS = frozenset(['left_eye', 'left_ear', 'right_eye', 'right_ear'])


def from_motion_to_numpy_vector(motion):
    '''Convert one dance's raw pose frames into a fixed-size numpy array.

    motion: list of frames as stored in `<category>/<dance_id>.json`. A non-empty
        frame is a list of people; only person 0 is used. That person's entry is
        `[<first item, skipped>, [part_name, value], ...]`, and each kept `value`
        fills one 2-wide row (presumably an (x, y) coordinate pair).

    Returns a float array of shape (TOTAL_FRAMES, NUM_BODY_PARTS, 2). Frames past
    TOTAL_FRAMES are ignored; missing or empty frames stay all-zero. Eye and ear
    keypoints are dropped; raises IndexError if more than NUM_BODY_PARTS parts
    remain in a frame.
    '''
    motion_vector = np.zeros((TOTAL_FRAMES, NUM_BODY_PARTS, 2))
    for i, frame in enumerate(motion):
        if i >= TOTAL_FRAMES:
            break  # everything from here on would be truncated anyway
        if len(frame) == 0:
            continue  # nobody detected in this frame: leave it zero
        # TODO extend this past just 1 person
        person0 = frame[0][1:]
        kept_parts = [part[1] for part in person0 if part[0] not in _EXCLUDED_BODY_PARTS]
        for part_idx, value in enumerate(kept_parts):
            motion_vector[i, part_idx] = value
    return motion_vector


def from_numpy_vector_to_motion_coordinates(motion_vector):
    '''Reshape a flat motion array into per-frame (NUM_BODY_PARTS, 2) rows.

    Accepts any number of frames (e.g. generated sequences of MAX_SEQ_LEN) as long
    as the total size is a multiple of NUM_BODY_PARTS * 2; returns an array of
    shape (num_frames, NUM_BODY_PARTS, 2).
    '''
    return motion_vector.reshape(-1, NUM_BODY_PARTS, 2)

class LetsDanceDataset(torch.utils.data.Dataset):
    '''Fixed-length pose sequences loaded from per-dance keypoint JSON files.

    Each item is a float32 tensor of shape (TOTAL_FRAMES, NUM_BODY_PARTS, 2),
    standardized per body part with the precomputed MEAN / STD below.
    '''
    # Dance category -> integer label (not used elsewhere in this class).
    categories_hash = {'tango': 0, 'break': 1, 'swing': 2,'quickstep': 3,
                  'foxtrot': 4,'pasodoble': 5,'tap': 6,'samba': 7,'flamenco': 8,
                  'ballet': 9,'rumba': 10,'waltz': 11,'cha': 12,'latin': 13,
                  'square': 14,'jive': 15}
    
    # Per-body-part (NUM_BODY_PARTS x 2) statistics, precomputed offline.
    MEAN=torch.Tensor([[722.8463, 230.9753], [725.5026, 284.8430], [718.0136, 283.9306], [729.7226, 332.3776], 
          [717.4737, 331.9450], [731.9489, 333.1949], [719.6969, 335.0007], [724.7956, 446.3675],
          [719.7034, 446.8887], [727.5336, 563.0570], [720.0659, 563.5020], [729.5637, 658.3285],
          [716.7125, 658.4642]])
    
    STD = torch.Tensor([[248.5471,  54.5708], [253.7432,  50.6963], [256.4125,  50.9480], [259.8698,  64.7350],
           [262.0792,  64.8512], [262.9285,  85.9804], [260.5722,  86.1529], [254.4909,  51.1837],
           [256.7613,  51.5563], [253.6787,  62.7815], [256.8294,  62.8685], [260.5059,  70.0873],
           [262.8192,  68.4242]])
    
    def __init__(self, root_dir, dances):
        '''Load every kept dance into memory.

        root_dir: directory holding one sub-directory per category, each with
            `<dance_id>.json` motion files.
        dances: iterable of [category, dance_id, num_frames] entries (as in
            FRAME_LIST_INDEX). Dances with fewer than TOTAL_FRAMES frames are skipped.
        '''
        super().__init__()
        self.root_dir = root_dir

        # Filter *before* allocating so that `data`, `metadata` and __len__ all
        # describe exactly the kept dances (no trailing all-zero samples and no
        # metadata/data index mismatch).
        dances = [dance for dance in dances if dance[2] >= TOTAL_FRAMES]
        self.metadata = dances
        self.data = np.zeros((len(dances), TOTAL_FRAMES, NUM_BODY_PARTS, 2))

        for i, (category, dance_id, _) in enumerate(tqdm(dances)):
            motion_path = os.path.join(root_dir, str(category), '{}.json'.format(dance_id))
            with open(motion_path) as motion_file:
                motion = json.load(motion_file)
            self.data[i] = from_motion_to_numpy_vector(motion)

        # NOTE(review): not used by __getitem__, which calls self.normalize()
        # directly; kept so existing references to `.transform` keep working.
        self.transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize(mean=torch.Tensor(self.MEAN),
                               std=torch.Tensor(self.STD)),
        ])
        
    def __len__(self):
        '''Number of dances that passed the length filter.'''
        return len(self.data)
    
    def getitem_metadata(self, index):
        '''Return the [category, dance_id, num_frames] entry of item `index`.'''
        return self.metadata[index]
    
    def __getitem__(self, index):
        '''
        Returns the normalized motion of dance `index` as a float tensor of shape
        `(TOTAL_FRAMES, NUM_BODY_PARTS, 2)`. Use `getitem_metadata(index)` for the
        matching [category, dance_id, num_frames] entry.
        '''
        data = torch.Tensor(self.data[index])
        return self.normalize(data)

    def get_num_body_parts(self):
        '''Number of body parts per frame (eyes and ears removed).'''
        return NUM_BODY_PARTS

    def save_data(self, filename, np_array, directory='data'):
        '''Write `np_array` as compact JSON to `<directory>/<filename>.json`.

        The directory is created if needed; returns the written file path.
        '''
        os.makedirs(directory, exist_ok=True)
        file_path = os.path.join(directory, filename + ".json")
        # `with` guarantees the handle is flushed and closed.
        with open(file_path, 'w', encoding='utf-8') as out_file:
            json.dump(np_array.tolist(), out_file, separators=(',', ':'), sort_keys=True)
        return file_path

    def normalize(self, motion):
        '''Standardize `motion` (..., NUM_BODY_PARTS, 2) with MEAN / STD.'''
        return (motion - self.MEAN) / (self.STD)
    
    def denormalize(self, motion):
        '''Inverse of `normalize`: map standardized values back to raw coordinates.'''
        return (motion * self.STD) + self.MEAN

# Load the dance index and shuffle it in place so the slices below pick a random
# subset. No category filter is applied here. The shuffle is unseeded
# (np.random.seed(0) is commented out at the top), so the subset differs per run.
with open(FRAME_LIST_INDEX) as f:
    frames_index = json.load(f)
np.random.shuffle(frames_index)

# Full train / validation split, disabled for now:
# train_dances = frames_index[:1000]
# valid_dances = frames_index[1000:]
# train_dataset = LetsDanceDataset('../densepose/full/', train_dances)
# valid_dataset = LetsDanceDataset('../densepose/full/', valid_dances)

# Quick-experiment subset: the first 10 shuffled dances.
first_ten_dances = frames_index[:10]
mini_dataset = LetsDanceDataset('../densepose/full/', first_ten_dances)





  0%|          | 0/10 [00:00<?, ?it/s][A[A[A


 30%|███       | 3/10 [00:00<00:00, 13.79it/s][A[A[A


 60%|██████    | 6/10 [00:00<00:00, 13.09it/s][A[A[A


 80%|████████  | 8/10 [00:00<00:00, 13.80it/s][A[A[A


100%|██████████| 10/10 [00:00<00:00, 14.31it/s][A[A[A


[A[A[A

In [141]:
mini_dataset[-1].reshape(TOTAL_FRAMES, NUM_BODY_PARTS * 2)  # last sample, one row of 26 values per frame

tensor([[ 0.0835, -0.7903,  0.0752,  ...,  5.9665, -0.0803,  6.1095],
        [ 0.0732, -0.7499,  0.0573,  ...,  5.9665,  1.0394,  6.1095],
        [ 0.0872, -0.5644,  0.0632,  ...,  5.9665,  1.0301,  6.1095],
        ...,
        [ 1.4977, -2.1652,  2.3629,  ...,  4.3849,  0.7016,  4.4895],
        [ 1.7235, -1.0195,  2.5406,  ...,  4.9546,  0.7213,  4.1823],
        [ 1.8532, -1.6082,  1.9623,  ...,  3.8469,  0.8097,  3.9531]])

In [89]:
# Slightly larger subset: index entries of the first 30 shuffled dances (the dataset drops those shorter than TOTAL_FRAMES).
small_dataset = frames_index[:30]
test_dataset = LetsDanceDataset(root_dir='../densepose/full/', dances=small_dataset)




  0%|          | 0/26 [00:00<?, ?it/s][A[A[A


  8%|▊         | 2/26 [00:00<00:02,  9.45it/s][A[A[A


 12%|█▏        | 3/26 [00:00<00:02,  7.77it/s][A[A[A


 15%|█▌        | 4/26 [00:00<00:03,  6.51it/s][A[A[A


 27%|██▋       | 7/26 [00:00<00:02,  9.39it/s][A[A[A


 38%|███▊      | 10/26 [00:00<00:01, 11.14it/s][A[A[A


 46%|████▌     | 12/26 [00:01<00:01, 10.29it/s][A[A[A


 54%|█████▍    | 14/26 [00:01<00:01, 10.45it/s][A[A[A


 62%|██████▏   | 16/26 [00:01<00:00, 10.10it/s][A[A[A


 69%|██████▉   | 18/26 [00:01<00:00, 10.50it/s][A[A[A


 77%|███████▋  | 20/26 [00:01<00:00, 10.73it/s][A[A[A


 85%|████████▍ | 22/26 [00:02<00:00, 10.91it/s][A[A[A


 92%|█████████▏| 24/26 [00:02<00:00, 11.16it/s][A[A[A


100%|██████████| 26/26 [00:02<00:00, 10.24it/s][A[A[A


[A[A[A

tensor([[722.8463, 230.9753],
        [725.5026, 284.8430],
        [718.0136, 283.9306],
        [729.7226, 332.3776],
        [717.4737, 331.9450],
        [731.9489, 333.1949],
        [719.6969, 335.0007],
        [724.7956, 446.3675],
        [719.7034, 446.8887],
        [727.5336, 563.0570],
        [720.0659, 563.5020],
        [729.5637, 658.3285],
        [716.7125, 658.4642]])

In [92]:
# Toy tensor (10 x 4 x 2, float64, values 0..79) for checking normalize/denormalize.
v = torch.arange(80, dtype=torch.float64).view(10, 4, 2)
v

tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.],
         [ 6.,  7.]],

        [[ 8.,  9.],
         [10., 11.],
         [12., 13.],
         [14., 15.]],

        [[16., 17.],
         [18., 19.],
         [20., 21.],
         [22., 23.]],

        [[24., 25.],
         [26., 27.],
         [28., 29.],
         [30., 31.]],

        [[32., 33.],
         [34., 35.],
         [36., 37.],
         [38., 39.]],

        [[40., 41.],
         [42., 43.],
         [44., 45.],
         [46., 47.]],

        [[48., 49.],
         [50., 51.],
         [52., 53.],
         [54., 55.]],

        [[56., 57.],
         [58., 59.],
         [60., 61.],
         [62., 63.]],

        [[64., 65.],
         [66., 67.],
         [68., 69.],
         [70., 71.]],

        [[72., 73.],
         [74., 75.],
         [76., 77.],
         [78., 79.]]], dtype=torch.float64)

In [95]:
# Example of norm / denorm: standardize `v` per (row, column) position across
# its first axis, then invert the transform.
# The statistics are used as-is; dividing them by len(v) would make normalize()
# output values that are neither zero-mean nor unit-std.
mean = v.mean(0)
std = v.std(0)

def normalize(v, mean, std):
    '''Standardize `v` with the given (broadcastable) mean and std.'''
    return (v - mean)/std
def denormalize(v, mean, std):
    '''Inverse of `normalize`.'''
    return (v * std) + mean

denormalize(normalize(v, mean, std), mean, std)

tensor([[[ 0.0000,  1.0000],
         [ 2.0000,  3.0000],
         [ 4.0000,  5.0000],
         [ 6.0000,  7.0000]],

        [[ 8.0000,  9.0000],
         [10.0000, 11.0000],
         [12.0000, 13.0000],
         [14.0000, 15.0000]],

        [[16.0000, 17.0000],
         [18.0000, 19.0000],
         [20.0000, 21.0000],
         [22.0000, 23.0000]],

        [[24.0000, 25.0000],
         [26.0000, 27.0000],
         [28.0000, 29.0000],
         [30.0000, 31.0000]],

        [[32.0000, 33.0000],
         [34.0000, 35.0000],
         [36.0000, 37.0000],
         [38.0000, 39.0000]],

        [[40.0000, 41.0000],
         [42.0000, 43.0000],
         [44.0000, 45.0000],
         [46.0000, 47.0000]],

        [[48.0000, 49.0000],
         [50.0000, 51.0000],
         [52.0000, 53.0000],
         [54.0000, 55.0000]],

        [[56.0000, 57.0000],
         [58.0000, 59.0000],
         [60.0000, 61.0000],
         [62.0000, 63.0000]],

        [[64.0000, 65.0000],
         [66.0000, 67.0000]

In [None]:
import os
from argparse import ArgumentParser

import torch
import torch.nn as nn
from torch import optim

from c_rnn_gan import Generator, Discriminator

# Where generated samples and model checkpoints live.
DATA_DIR = 'data'
CKPT_DIR = 'models'

# Checkpoint file names inside CKPT_DIR.
G_FN, D_FN = 'c_rnn_gan_g.pth', 'c_rnn_gan_d.pth'

# Optimisation defaults.
G_LRN_RATE = 1e-3
D_LRN_RATE = 1e-3
MAX_GRAD_NORM = 5.0  # gradient-norm clipping threshold for both networks

# Defaults that the driver cell overwrites at runtime.
MAX_SEQ_LEN = 200
BATCH_SIZE = 10

# Lower clamp bound applied to probabilities before log(), to avoid log(0).
# NOTE(review): 1e-40 is below float32's smallest normal value (~1.18e-38), so it
# is stored as a subnormal; if subnormals are flushed to zero the bound becomes 0.
EPSILON = 1e-40

class GLoss(nn.Module):
    ''' C-RNN-GAN generator loss: mean of -log(D(G(z))).

    The input is the discriminator's output for generated sequences, treated as a
    probability and clamped to [EPSILON, 1] before the log.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, logits_gen):
        probs = logits_gen.clamp(min=EPSILON, max=1.0)
        return (-probs.log()).mean()


class DLoss(nn.Module):
    ''' C-RNN-GAN discriminator loss.

    Binary cross-entropy on the discriminator's output probabilities: real
    sequences have target 1 (0.9 with one-sided label smoothing) and generated
    sequences have target 0. Probabilities are clamped to [EPSILON, 1] before
    every log.
    '''
    def __init__(self, label_smoothing=False):
        super().__init__()
        self.label_smoothing = label_smoothing

    def forward(self, logits_real, logits_gen):
        ''' Discriminator loss

        logits_real: D's output when its input is real data
        logits_gen: D's output when its input comes from the generator

        loss = -(ylog(p) + (1-y)log(1-p)), averaged over all elements
        '''
        p_real = torch.clamp(logits_real, EPSILON, 1.0)
        real_term = -torch.log(p_real)
        if self.label_smoothing:
            # target 0.9 instead of 1: blend in the "fake" term for real samples
            real_as_fake = -torch.log(torch.clamp((1 - p_real), EPSILON, 1.0))
            real_term = 0.9*real_term + 0.1*real_as_fake

        gen_term = -torch.log(torch.clamp((1 - logits_gen), EPSILON, 1.0))
        return torch.mean(real_term + gen_term)


def _flatten_frames(dance):
    '''Collapse all per-frame axes into one feature axis: (B, T, ...) -> (B, T, F).

    LetsDanceDataset yields (batch, frames, NUM_BODY_PARTS, 2) tensors, but the
    generator is fed noise shaped (batch, seq_len, num_feats) and the
    discriminator scores real and generated sequences alike, so real data is
    brought into the same 3-D layout (assumes G emits (batch, seq_len,
    num_feats) -- TODO confirm against c_rnn_gan).
    '''
    return dance.reshape(dance.shape[0], dance.shape[1], -1)


def _num_frame_feats(dataset):
    '''Features per flattened frame: 2 values per body part (see LetsDanceDataset).'''
    return dataset.get_num_body_parts() * 2


def run_training(model, optimizer, criterion, dataloader, freeze_g=False, freeze_d=False):
    ''' Run single training epoch

    model, optimizer, criterion: dicts with a 'g' (generator) and a 'd'
        (discriminator) entry.
    dataloader: DataLoader over a LetsDanceDataset; its `.dataset` supplies the
        number of body parts.
    freeze_g / freeze_d: when True, that network's backward pass and optimizer
        step are skipped (its loss is still computed and reported).

    Returns (model, g_loss_avg, d_loss_avg, d_acc): losses are averaged over
    batches; d_acc is the percentage of real + generated sequences the
    discriminator classified correctly.
    '''
    num_feats = _num_frame_feats(dataloader.dataset)
    # Batches must live on the models' device (main() moves them to GPU when available).
    device = next(model['g'].parameters()).device

    model['g'].train()
    model['d'].train()

    loss = {}
    g_loss_total = 0.0
    d_loss_total = 0.0
    num_corrects = 0
    num_sample = 0
    num_batches = 0

    for dance in dataloader:
        dance = _flatten_frames(dance.float()).to(device)
        real_batch_sz, seq_len = dance.shape[0], dance.shape[1]

        # get initial states
        # each batch is independent i.e. not a continuation of previous batch
        # so we reset states for each batch
        # POSSIBLE IMPROVEMENT: next batch is continuation of previous batch
        g_states = model['g'].init_hidden(real_batch_sz)
        d_state = model['d'].init_hidden(real_batch_sz)

        #### GENERATOR ####
        if not freeze_g:
            optimizer['g'].zero_grad()
        # random noise, as long as the real sequences
        z = torch.empty([real_batch_sz, seq_len, num_feats], device=device).uniform_()

        # feed inputs to generator
        g_feats, _ = model['g'](z, g_states)

        # calculate loss, backprop, and update weights of G
        if isinstance(criterion['g'], GLoss):
            d_logits_gen, _, _ = model['d'](g_feats, d_state)
            loss['g'] = criterion['g'](d_logits_gen)
        else: # feature matching
            # feed real and generated input to discriminator
            _, d_feats_real, _ = model['d'](dance, d_state)
            _, d_feats_gen, _ = model['d'](g_feats, d_state)
            loss['g'] = criterion['g'](d_feats_real, d_feats_gen)

        if not freeze_g:
            loss['g'].backward()
            nn.utils.clip_grad_norm_(model['g'].parameters(), max_norm=MAX_GRAD_NORM)
            optimizer['g'].step()

        #### DISCRIMINATOR ####
        if not freeze_d:
            optimizer['d'].zero_grad()
        # feed real and generated input to discriminator
        d_logits_real, _, _ = model['d'](dance, d_state)
        # need to detach from operation history to prevent backpropagating to generator
        d_logits_gen, _, _ = model['d'](g_feats.detach(), d_state)
        # calculate loss, backprop, and update weights of D
        loss['d'] = criterion['d'](d_logits_real, d_logits_gen)
        if not freeze_d:
            loss['d'].backward()
            nn.utils.clip_grad_norm_(model['d'].parameters(), max_norm=MAX_GRAD_NORM)
            optimizer['d'].step()

        g_loss_total += loss['g'].item()
        d_loss_total += loss['d'].item()
        # NOTE(review): assumes D returns one probability per sequence; if it
        # returns per-timestep values these counts exceed 2 * num_sample.
        num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item()
        num_sample += real_batch_sz
        num_batches += 1

    # every loss item is already a per-batch value, so average over batches
    g_loss_avg = g_loss_total / num_batches if num_batches else 0.0
    d_loss_avg = d_loss_total / num_batches if num_batches else 0.0
    # 2 * num_sample because each sample is judged twice (real + generated)
    d_acc = 100 * num_corrects / (2 * num_sample) if num_sample else 0.0

    return model, g_loss_avg, d_loss_avg, d_acc


# def run_validation(model, criterion, dataloader):
#     ''' Run single validation epoch
#     '''
#     num_feats = dataloader.get_num_body_parts()
#     dataloader.rewind(part='validation')
#     batch_meta, batch_song = dataloader.get_batch(BATCH_SIZE, MAX_SEQ_LEN, part='validation')

#     model['g'].eval()
#     model['d'].eval()

#     g_loss_total = 0.0
#     d_loss_total = 0.0
#     num_corrects = 0
#     num_sample = 0

#     while batch_meta is not None and batch_song is not None:

#         real_batch_sz = batch_song.shape[0]

#         # initial states
#         g_states = model['g'].init_hidden(real_batch_sz)
#         d_state = model['d'].init_hidden(real_batch_sz)

#         #### GENERATOR ####
#         # prepare inputs
#         z = torch.empty([real_batch_sz, MAX_SEQ_LEN, num_feats]).uniform_() # random vector
#         batch_song = torch.Tensor(batch_song)

#         # feed inputs to generator
#         g_feats, _ = model['g'](z, g_states)
#         # feed real and generated input to discriminator
#         d_logits_real, d_feats_real, _ = model['d'](batch_song, d_state)
#         d_logits_gen, d_feats_gen, _ = model['d'](g_feats, d_state)
#         # calculate loss
#         if isinstance(criterion['g'], GLoss):
#             g_loss = criterion['g'](d_logits_gen)
#         else: # feature matching
#             g_loss = criterion['g'](d_feats_real, d_feats_gen)

#         d_loss = criterion['d'](d_logits_real, d_logits_gen)

#         g_loss_total += g_loss.item()
#         d_loss_total += d_loss.item()
#         num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item()
#         num_sample += real_batch_sz

#         # fetch next batch
#         batch_meta, batch_song = dataloader.get_batch(BATCH_SIZE, MAX_SEQ_LEN, part='validation')

#     g_loss_avg, d_loss_avg = 0.0, 0.0
#     d_acc = 0.0
#     if num_sample > 0:
#         g_loss_avg = g_loss_total / num_sample
#         d_loss_avg = d_loss_total / num_sample
#         d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated)

#     return g_loss_avg, d_loss_avg, d_acc


def run_epoch(model, optimizer, criterion, dataloader, ep, num_ep,
              freeze_g=False, freeze_d=False, pretraining=False):
    ''' Run a single epoch: train, log metrics, and save one generated sample.

    ep / num_ep: 0-based epoch index and total number of epochs (used for
        logging and for the sample file name).
    Returns (model, training discriminator accuracy in percent).
    '''
    model, trn_g_loss, trn_d_loss, trn_acc = \
        run_training(model, optimizer, criterion, dataloader, freeze_g=freeze_g, freeze_d=freeze_d)

    # val_g_loss, val_d_loss, val_acc = run_validation(model, criterion, dataloader)

    if pretraining:
        print("Pretraining Epoch %d/%d " % (ep+1, num_ep), "[Freeze G: ", freeze_g, ", Freeze D: ", freeze_d, "]")
    else:
        print("Epoch %d/%d " % (ep+1, num_ep), "[Freeze G: ", freeze_g, ", Freeze D: ", freeze_d, "]")
        print("\t[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f" % (trn_g_loss, trn_d_loss, trn_acc))
        # TODO: also print validation metrics once run_validation is re-enabled.

    # -- DEBUG --
    # Sample one sequence from the generator to monitor training progress and
    # save it as JSON through the dataset's save_data().
    dataset = dataloader.dataset
    g_states = model['g'].init_hidden(1)
    z = torch.empty([1, MAX_SEQ_LEN, _num_frame_feats(dataset)]).uniform_() # random vector
    if torch.cuda.is_available():
        z = z.cuda()
        model['g'].cuda()

    model['g'].eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        g_feats, _ = model['g'](z, g_states)
    dance_data = g_feats.squeeze(0).cpu().numpy()

    # The name depends only on num_ep, so each epoch overwrites the same file;
    # the last epoch gets a '_final' suffix.
    suffix = '_final' if (ep + 1) == num_ep else ''
    dataset.save_data('sample{}{}.dance'.format(num_ep, suffix), dance_data)
    # -- DEBUG --

    return model, trn_acc


def main(args, dataset=None):
    ''' Training sequence: optional D then G pretraining, adversarial training,
    then checkpointing.

    args: object exposing the option attributes read below (the ARGS defaults in
        the driver cell, or the commented-out ArgumentParser).
    dataset: dataset to train on; defaults to the global `train_dataset`.
    '''
    if dataset is None:
        dataset = train_dataset
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True)
    # G's noise and D's input share this per-frame feature size.
    num_feats = _num_frame_feats(dataset)

    # First checking if GPU is available
    train_on_gpu = torch.cuda.is_available()
    if train_on_gpu:
        print('Training on GPU.')
    else:
        print('No GPU available, training on CPU.')

    model = {
        'g': Generator(num_feats, use_cuda=train_on_gpu),
        'd': Discriminator(num_feats, use_cuda=train_on_gpu)
    }

    if args.use_sgd:
        optimizer = {
            'g': optim.SGD(model['g'].parameters(), lr=args.g_lrn_rate, momentum=0.9),
            'd': optim.SGD(model['d'].parameters(), lr=args.d_lrn_rate, momentum=0.9)
        }
    else:
        optimizer = {
            'g': optim.Adam(model['g'].parameters(), args.g_lrn_rate),
            'd': optim.Adam(model['d'].parameters(), args.d_lrn_rate)
        }

    criterion = {
        'g': nn.MSELoss(reduction='sum') if args.feature_matching else GLoss(),
        'd': DLoss(args.label_smoothing)
    }

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # the models are moved to the GPU below when one is available.
    if args.load_g:
        ckpt = torch.load(os.path.join(CKPT_DIR, G_FN), map_location='cpu')
        model['g'].load_state_dict(ckpt)
        print("Continue training of %s" % os.path.join(CKPT_DIR, G_FN))

    if args.load_d:
        ckpt = torch.load(os.path.join(CKPT_DIR, D_FN), map_location='cpu')
        model['d'].load_state_dict(ckpt)
        print("Continue training of %s" % os.path.join(CKPT_DIR, D_FN))

    if train_on_gpu:
        model['g'].cuda()
        model['d'].cuda()

    # run_epoch() saves a generated sample every epoch via save_data(), which
    # writes under 'data' (DATA_DIR); make sure it exists.
    os.makedirs(DATA_DIR, exist_ok=True)

    if not args.no_pretraining:
        for ep in range(args.d_pretraining_epochs):
            model, _ = run_epoch(model, optimizer, criterion, dataloader,
                              ep, args.d_pretraining_epochs, freeze_g=True, pretraining=True)

        for ep in range(args.g_pretraining_epochs):
            model, _ = run_epoch(model, optimizer, criterion, dataloader,
                              ep, args.g_pretraining_epochs, freeze_d=True, pretraining=True)

    freeze_d = False
    for ep in range(args.num_epochs):
        # if ep % args.freeze_d_every == 0:
        #     freeze_d = not freeze_d

        model, trn_acc = run_epoch(model, optimizer, criterion, dataloader, ep, args.num_epochs, freeze_d=freeze_d)
        if args.conditional_freezing:
            # conditional freezing: pause D while it is already near-perfect
            freeze_d = trn_acc >= 95.0

    if not (args.no_save_g and args.no_save_d):
        os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save does not create directories

    if not args.no_save_g:
        torch.save(model['g'].state_dict(), os.path.join(CKPT_DIR, G_FN))
        print("Saved generator: %s" % os.path.join(CKPT_DIR, G_FN))

    if not args.no_save_d:
        torch.save(model['d'].state_dict(), os.path.join(CKPT_DIR, D_FN))
        print("Saved discriminator: %s" % os.path.join(CKPT_DIR, D_FN))


# if __name__ == "__main__":

# ARG_PARSER = ArgumentParser()
# ARG_PARSER.add_argument('--load_g', action='store_true')
# ARG_PARSER.add_argument('--load_d', action='store_true')
# ARG_PARSER.add_argument('--no_save_g', action='store_true')
# ARG_PARSER.add_argument('--no_save_d', action='store_true')

# ARG_PARSER.add_argument('--num_epochs', default=300, type=int)
# ARG_PARSER.add_argument('--batch_size', default=16, type=int)
# ARG_PARSER.add_argument('--g_lrn_rate', default=0.001, type=float)
# ARG_PARSER.add_argument('--d_lrn_rate', default=0.001, type=float)

# ARG_PARSER.add_argument('--no_pretraining', action='store_true')
# ARG_PARSER.add_argument('--g_pretraining_epochs', default=5, type=int)
# ARG_PARSER.add_argument('--d_pretraining_epochs', default=5, type=int)
# ARG_PARSER.add_argument('--use_sgd', action='store_true')
# ARG_PARSER.add_argument('--conditional_freezing', action='store_true')
# ARG_PARSER.add_argument('--label_smoothing', action='store_true')
# ARG_PARSER.add_argument('--feature_matching', action='store_true')

class TrainingArgs():
    '''Notebook stand-in for the command-line options of the original script
    (see the commented-out ArgumentParser above); main() reads these attributes.'''
    def __init__(self):
        self.load_g =  False
        self.load_d =  False
        self.no_save_g =  False
        self.no_save_d =  False

        self.num_epochs =  500
        self.batch_size =  8
        # NOTE(review): 1e-7 is far below the script defaults G_LRN_RATE /
        # D_LRN_RATE (1e-3); confirm this is intended.
        self.g_lrn_rate =  0.0000001
        self.d_lrn_rate =  0.0000001

        self.no_pretraining =  False
        self.g_pretraining_epochs =  5
        self.d_pretraining_epochs =  5
        self.use_sgd =  False
        self.conditional_freezing =  False
        self.label_smoothing =  False
        self.feature_matching =  False

# Instance has its own name so it no longer shadows its class.
ARGS = TrainingArgs()
# ARGS = ARG_PARSER.parse_args()
MAX_SEQ_LEN = TOTAL_FRAMES # todo
BATCH_SIZE = ARGS.batch_size

# main() reads the global `train_dataset`, whose full construction is commented
# out in the first cell; until that is enabled, train on the small subset.
train_dataset = mini_dataset

main(ARGS)