In [1]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

from modules import Generator, Gaussian_Predictor, Decoder_Fusion, Label_Encoder, RGB_Encoder

from dataloader import Dataset_Dance
from torchvision.utils import save_image
import random
import torch.optim as optim
from torch import stack

from tqdm import tqdm
import imageio

import matplotlib.pyplot as plt
from math import log10

In [5]:

# Image-quality metric: peak signal-to-noise ratio (dB) between two image tensors.
def Generate_PSNR(imgs1, imgs2, data_range=1.):
    """PSNR (dB) between two image tensors.

    Tensors with 4 or more dimensions are treated as batches along dim 0
    (e.g. ``(B, C, H, W)``): the PSNR is computed per sample and then averaged.
    Pooling the MSE over the whole batch first (as a single ``mse_loss`` call
    does) does not give the mean per-image PSNR when B > 1. Lower-rank tensors
    are scored as one image over all elements.

    Args:
        imgs1, imgs2: float tensors of the same (or broadcastable) shape.
        data_range: maximum possible pixel value (1.0 for ``ToTensor`` output).

    Returns:
        0-dim tensor holding the mean PSNR; ``inf`` if a sample has zero error.
    """
    sq_err = (imgs1 - imgs2) ** 2
    if sq_err.dim() >= 4:
        # One MSE per sample so every image contributes its own PSNR.
        mse = sq_err.flatten(start_dim=1).mean(dim=1)
    else:
        mse = sq_err.mean().reshape(1)
    psnr = 20 * log10(data_range) - 10 * torch.log10(mse)
    return psnr.mean()


class VAE_Model(nn.Module):
    """Conditional video VAE, used here only to restore a checkpoint and validate.

    Every submodule present in the checkpoint is constructed (so strict
    ``load_state_dict`` succeeds), but ``forward`` is a placeholder: this
    notebook never trains. Validation rolls the decoder out autoregressively
    (see ``val_one_step``).
    """

    # Checkpoint used when neither ``load_checkpoint(ckpt_path=...)`` nor
    # ``args.ckpt_path`` is given. Kept for backward compatibility.
    # NOTE(review): absolute, machine-specific path; prefer passing --ckpt_path.
    DEFAULT_CKPT_PATH = '/mnt/left/home/2023/angelina/DLCourse/Lab4/Result_Mono_0427/epoch=80.ckpt'

    def __init__(self, args):
        super(VAE_Model, self).__init__()
        self.args = args

        # Modules to transform image from RGB-domain to feature-domain
        self.frame_transformation = RGB_Encoder(3, args.F_dim)
        self.label_transformation = Label_Encoder(3, args.L_dim)

        # Posterior prediction (encoder side). Not used by validation below, but
        # its weights are in the checkpoint, so it must exist for strict loading.
        self.Gaussian_Predictor   = Gaussian_Predictor(args.F_dim + args.L_dim, args.N_dim)
        self.Decoder_Fusion       = Decoder_Fusion(args.F_dim + args.L_dim + args.N_dim, args.D_out_dim)

        # Generative model
        self.Generator            = Generator(input_nc=args.D_out_dim, output_nc=3)

        self.optim      = optim.Adam(self.parameters(), lr=self.args.lr)
        self.scheduler  = optim.lr_scheduler.MultiStepLR(self.optim, milestones=[2, 5], gamma=0.1)
        self.mse_criterion = nn.MSELoss()
        self.current_epoch = 0

        # Teacher forcing arguments (kept for parity with training; unused here)
        self.tfr = args.tfr
        self.tfr_d_step = args.tfr_d_step
        self.tfr_sde = args.tfr_sde

        self.train_vi_len = args.train_vi_len
        self.val_vi_len   = args.val_vi_len
        self.batch_size = args.batch_size

        # Per-video average PSNR, filled by eval().
        self.psnr = []

    def forward(self, img, label):
        """Placeholder; the notebook only evaluates via ``eval``/``val_one_step``."""
        pass

    @torch.no_grad()
    def eval(self):
        """Run the whole validation set and store per-video PSNR in ``self.psnr``.

        This shadows ``nn.Module.eval``. To keep the base-class contract it
        first switches every submodule to inference mode (``train(False)``),
        so normalization/dropout layers behave deterministically. Like the
        base method, it returns ``self``.
        """
        self.train(False)
        avg_psnr = []
        val_loader = self.val_dataloader()
        for (img, label) in (pbar := tqdm(val_loader, ncols=120)):
            img = img.to(self.args.device)
            label = label.to(self.args.device)
            loss, psnr = self.val_one_step(img, label)
            show_dic = {'loss': float(loss.detach().cpu()), 'psnr': psnr}
            self.tqdm_bar('val', pbar, show_dic, lr=self.scheduler.get_last_lr()[0])
            avg_psnr.append(psnr)
        self.psnr = avg_psnr
        return self

    def tqdm_bar(self, mode, pbar, show_dic, lr):
        """Update a tqdm bar with phase, epoch and learning rate, plus metrics as postfix.

        Args:
            mode: phase label shown in the description, e.g. ``'val'``.
            pbar: the ``tqdm`` instance being iterated.
            show_dic: mapping of metric name -> value shown as the postfix.
            lr: current learning rate.
        """
        pbar.set_description(f"({mode}) Epoch {self.current_epoch}, lr:{lr}", refresh=False)
        pbar.set_postfix(show_dic, refresh=False)
        pbar.refresh()

    def val_one_step(self, img, label):
        """Autoregressively reconstruct one batch of validation videos.

        Frame 0 of ``img`` seeds the rollout. Every later frame is generated from
        the previously *generated* frame plus the ground-truth label frame, with
        standard-normal noise in place of a posterior sample.

        Args:
            img: batch-first video tensor; assumed (B, T, C, H, W) — TODO confirm.
            label: label video tensor with the same layout as ``img``.

        Returns:
            ``(loss, avg_psnr)``: mean per-frame MSE (tensor) and mean per-frame
            PSNR (float), both over frames ``1 .. val_vi_len - 1``.

        Raises:
            ValueError: if ``val_vi_len`` < 2 (no frame would be predicted).
        """
        if self.val_vi_len < 2:
            raise ValueError(f"val_vi_len must be >= 2, got {self.val_vi_len}")

        # Time-major, so img[i] / label[i] select frame i of every video.
        img = img.permute(1, 0, 2, 3, 4)
        label = label.permute(1, 0, 2, 3, 4)

        loss, avg_psnr = 0.0, 0.0
        pre_out = img[0]
        for i in range(1, self.val_vi_len):
            # Noise follows the input's batch size, device and dtype instead of a
            # hard-coded single-sample CUDA tensor, so `--device cpu` also works.
            z = torch.randn(pre_out.shape[0], self.args.N_dim, self.args.frame_H, self.args.frame_W,
                            device=pre_out.device, dtype=pre_out.dtype)
            out_feat = self.frame_transformation(pre_out)
            label_feat = self.label_transformation(label[i])
            parm = self.Decoder_Fusion(out_feat, label_feat, z)
            out = self.Generator(parm)

            loss += self.mse_criterion(out, img[i])
            avg_psnr += Generate_PSNR(out, img[i]).item()

            # Feed the prediction back in: no teacher forcing during validation.
            pre_out = out

        n_pred = self.val_vi_len - 1
        return loss / n_pred, avg_psnr / n_pred

    def val_dataloader(self):
        """Build the (unshuffled, batch size 1) validation DataLoader."""
        transform = transforms.Compose([
            transforms.Resize((self.args.frame_H, self.args.frame_W)),
            transforms.ToTensor()
        ])
        dataset = Dataset_Dance(root=self.args.DR, transform=transform, mode='val', video_len=self.val_vi_len, partial=1.0)
        val_loader = DataLoader(dataset,
                                batch_size=1,
                                num_workers=self.args.num_workers,
                                drop_last=True,
                                shuffle=False)
        return val_loader

    def _resolve_ckpt_path(self, ckpt_path=None):
        """Choose the checkpoint path: the explicit argument, then ``args.ckpt_path``, then the default."""
        if ckpt_path:
            return ckpt_path
        from_args = getattr(self.args, 'ckpt_path', None)
        return from_args if from_args else self.DEFAULT_CKPT_PATH

    def load_checkpoint(self, ckpt_path=None):
        """Restore weights, lr, teacher-forcing ratio and epoch from a checkpoint.

        Args:
            ckpt_path: optional path; falls back to ``args.ckpt_path``, then
                ``DEFAULT_CKPT_PATH``.

        The optimizer and scheduler are rebuilt with the restored lr; their
        state is not read from the checkpoint.
        """
        path = self._resolve_ckpt_path(ckpt_path)
        # torch.load unpickles its input: only load trusted checkpoints.
        # map_location lets a GPU-saved checkpoint load with --device cpu.
        checkpoint = torch.load(path, map_location=self.args.device)
        self.load_state_dict(checkpoint['state_dict'], strict=True)
        self.args.lr = checkpoint['lr']
        self.tfr = checkpoint['tfr']

        self.optim      = optim.Adam(self.parameters(), lr=self.args.lr)
        # NOTE(review): milestones differ from __init__ ([2, 5]); confirm which is intended.
        self.scheduler  = optim.lr_scheduler.MultiStepLR(self.optim, milestones=[2, 4], gamma=0.1)
        self.current_epoch = checkpoint['last_epoch']

    def optimizer_step(self):
        """Clip the global gradient norm to 1 and take one optimizer step."""
        nn.utils.clip_grad_norm_(self.parameters(), 1.)
        self.optim.step()



def main(args):
    """Build the model, restore its checkpoint and run one validation pass.

    Ensures ``args.save_root`` exists first. Per-video PSNR values are left on
    the model's ``psnr`` attribute by ``VAE_Model.eval``; nothing is returned.
    """
    save_dir = args.save_root
    os.makedirs(save_dir, exist_ok=True)

    vae = VAE_Model(args)
    vae = vae.to(args.device)
    vae.load_checkpoint()
    vae.eval()

parser = argparse.ArgumentParser(add_help=True)
parser.add_argument('--batch_size',    type=int,    default=2)
parser.add_argument('--lr',            type=float,  default=0.001,     help="initial learning rate")
parser.add_argument('--device',        type=str, choices=["cuda", "cpu"], default="cuda")
parser.add_argument('--optim',         type=str, choices=["Adam", "AdamW"], default="Adam")
parser.add_argument('--gpu',           type=int, default=1)
parser.add_argument('--test',          action='store_true')
parser.add_argument('--store_visualization',      action='store_true', help="If you want to see the result while training")
parser.add_argument('--DR',            type=str, required=True,  help="Your Dataset Path")
parser.add_argument('--save_root',     type=str, required=True,  help="The path to save your data")
parser.add_argument('--num_workers',   type=int, default=4)
parser.add_argument('--num_epoch',     type=int, default=70,     help="number of total epoch")
parser.add_argument('--per_save',      type=int, default=3,      help="Save checkpoint every seted epoch")
parser.add_argument('--partial',       type=float, default=1.0,  help="Part of the training dataset to be trained")
parser.add_argument('--train_vi_len',  type=int, default=16,     help="Training video length")
parser.add_argument('--val_vi_len',    type=int, default=630,    help="valdation video length")
parser.add_argument('--frame_H',       type=int, default=32,     help="Height input image to be resize")
parser.add_argument('--frame_W',       type=int, default=64,     help="Width input image to be resize")


# Module parameters setting
parser.add_argument('--F_dim',         type=int, default=128,    help="Dimension of feature human frame")
parser.add_argument('--L_dim',         type=int, default=32,     help="Dimension of feature label frame")
parser.add_argument('--N_dim',         type=int, default=12,     help="Dimension of the Noise")
parser.add_argument('--D_out_dim',     type=int, default=192,    help="Dimension of the output in Decoder_Fusion")

# Teacher Forcing strategy
parser.add_argument('--tfr',           type=float, default=1.0,  help="The initial teacher forcing ratio")
parser.add_argument('--tfr_sde',       type=int,   default=10,   help="The epoch that teacher forcing ratio start to decay")
parser.add_argument('--tfr_d_step',    type=float, default=0.1,  help="Decay step that teacher forcing ratio adopted")
parser.add_argument('--ckpt_path',     type=str,    default=None,help="The path of your checkpoints")

# Training Strategy
parser.add_argument('--fast_train',         action='store_true')
parser.add_argument('--fast_partial',       type=float, default=0.4,    help="Use part of the training data to fasten the convergence")
parser.add_argument('--fast_train_epoch',   type=int, default=5,        help="Number of epoch to use fast train mode")

# KL annealing strategy arguments
parser.add_argument('--kl_anneal_type',     type=str, default='Cyclical',       help="")
parser.add_argument('--kl_anneal_cycle',    type=int, default=10,               help="")
parser.add_argument('--kl_anneal_ratio',    type=float, default=1,              help="")

# Inside a Jupyter kernel, sys.argv holds the kernel's own flags (`-f <connection>.json`),
# so a bare parse_args() aborts with "SystemExit: 2". In a notebook, parse NOTEBOOK_ARGV
# instead. It is built from environment variables so no machine-specific path is hard-coded:
#   VAE_DR (required), VAE_SAVE_ROOT (required), VAE_CKPT_PATH (optional).
import sys

NOTEBOOK_ARGV = []
for _flag, _env_var in (('--DR', 'VAE_DR'),
                        ('--save_root', 'VAE_SAVE_ROOT'),
                        ('--ckpt_path', 'VAE_CKPT_PATH')):
    if os.environ.get(_env_var):
        NOTEBOOK_ARGV += [_flag, os.environ[_env_var]]

args = parser.parse_args(NOTEBOOK_ARGV if 'ipykernel' in sys.modules else None)

main(args)


usage: ipykernel_launcher.py [-h] [--batch_size BATCH_SIZE] [--lr LR]
                             [--device {cuda,cpu}] [--optim {Adam,AdamW}]
                             [--gpu GPU] [--test] [--store_visualization] --DR
                             DR --save_root SAVE_ROOT
                             [--num_workers NUM_WORKERS]
                             [--num_epoch NUM_EPOCH] [--per_save PER_SAVE]
                             [--partial PARTIAL] [--train_vi_len TRAIN_VI_LEN]
                             [--val_vi_len VAL_VI_LEN] [--frame_H FRAME_H]
                             [--frame_W FRAME_W] [--F_dim F_DIM]
                             [--L_dim L_DIM] [--N_dim N_DIM]
                             [--D_out_dim D_OUT_DIM] [--tfr TFR]
                             [--tfr_sde TFR_SDE] [--tfr_d_step TFR_D_STEP]
                             [--ckpt_path CKPT_PATH] [--fast_train]
                             [--fast_partial FAST_PARTIAL]
                             [--fast_train_e

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:

# Re-run the full evaluation with the already-parsed arguments.
main(args=args)
