# Full pipeline for the VoxCeleb dataset: predict keypoints with a VRNN, then animate them in transfer mode

# Import functions

In [None]:
import os, sys  
from tqdm.auto import tqdm
import gc
import torch
from torch.utils.data import DataLoader
from Training_Prediction.FOMM.Source_Model.logger import Logger, Visualizer
import numpy as np
import imageio
from Training_Prediction.FOMM.Source_Model.sync_batchnorm import DataParallelWithCallback
from Training_Prediction.FOMM.Source_Model.modules.RNN_prediction_module import PredictionModule
from Training_Prediction.FOMM.Source_Model.augmentation import SelectRandomFrames, SelectFirstFrames_two, VideoToTensor
from tqdm import trange
from torch.utils.data import DataLoader, Dataset
from Training_Prediction.FOMM.Source_Model.frames_dataset import FramesDataset_transfer
import tensorflow.compat.v1 as tf
from Training_Prediction.PREDICTOR.Source_Model.VRNN import build_vrnn, get_config
from Training_Prediction.PREDICTOR.Source_Model.VRNN_pytorch import VRNN, get_config, init_weights
import pickle
from Training_Prediction.PREDICTOR.Source_Model.VRNN_prediction import VRNN_predict
from Training_Prediction.PREDICTOR.Source_Model.prediction_toplevel import KPDataset,get_data_from_dataloader_60
import yaml
from Training_Prediction.FOMM.Source_Model.modules.generator import OcclusionAwareGenerator
from Training_Prediction.FOMM.Source_Model.modules.keypoint_detector import KPDetector
from Training_Prediction.FOMM.Source_Model.logger import Logger, Visualizer, Visualizer_slow
from torch import nn

# Load the trained VRNN checkpoint

In [None]:
# Build the PyTorch VRNN and load its trained weights. (The earlier
# TensorFlow/Keras loading path that was commented out here has been removed;
# this PyTorch checkpoint supersedes it.)
# NOTE: `get_config` resolves to the VRNN_pytorch import, which shadows the
# TensorFlow `get_config` imported just before it.
cfg = get_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# "12-12" in the file name presumably means 12 observed + 12 predicted frames,
# matching step_interval = 12 below — confirm.
checkpoint_path = "Checkpoints_VRNN/VRNN_3883videos_vox_12-12_pytorch_T3.pth"
model = VRNN(cfg).to(device)
# map_location lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine (without it torch.load tries to restore onto CUDA).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Load keypoints of the 44 VoxCeleb driving test videos and pack them into per-video (frames, 60) arrays

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("kp_test_44_vox.pkl", "rb") as f:
    kp_time_series = pickle.load(f)
print(len(kp_time_series))

# Each pickled video entry keeps its per-frame keypoint dicts under 'kp'.
kp_time_series = [video['kp'] for video in kp_time_series]

# Per video, stack the frames into batched tensors
#   value    -> (1, num_frames, 10, 2)
#   jacobian -> (1, num_frames, 10, 2, 2)
# and put them on `device` (chosen when the VRNN was loaded) instead of a
# hard-coded 'cuda:0'.
kp_dict_init = []
for video_kp in kp_time_series:
    value = torch.cat(
        [torch.tensor(frame['value'].reshape(1, 10, 2)) for frame in video_kp])
    jacobian = torch.cat(
        [torch.tensor(frame['jacobian'].reshape(1, 10, 2, 2)) for frame in video_kp])
    kp_dict_init.append({"value": value.unsqueeze(0).to(device),
                         "jacobian": jacobian.unsqueeze(0).to(device)})

# Flatten every frame to 60 features: 10 keypoints x (2 value + 4 jacobian entries).
kp_list_test1 = [
    torch.cat((kp['value'], kp['jacobian'].reshape(1, -1, 10, 4)), dim=-1)
    .reshape(-1, 60).cpu().numpy()
    for kp in kp_dict_init
]

# Load keypoints of the 44 VoxCeleb source test videos and pack them into per-video (frames, 60) arrays

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("kp_test_44_vox_source.pkl", "rb") as f:
    kp_time_series = pickle.load(f)
print(len(kp_time_series))

# Each pickled video entry keeps its per-frame keypoint dicts under 'kp'.
kp_time_series = [video['kp'] for video in kp_time_series]

# Per video, stack the frames into batched tensors
#   value    -> (1, num_frames, 10, 2)
#   jacobian -> (1, num_frames, 10, 2, 2)
# and put them on `device` (chosen when the VRNN was loaded) instead of a
# hard-coded 'cuda:0'.
kp_dict_init = []
for video_kp in kp_time_series:
    value = torch.cat(
        [torch.tensor(frame['value'].reshape(1, 10, 2)) for frame in video_kp])
    jacobian = torch.cat(
        [torch.tensor(frame['jacobian'].reshape(1, 10, 2, 2)) for frame in video_kp])
    kp_dict_init.append({"value": value.unsqueeze(0).to(device),
                         "jacobian": jacobian.unsqueeze(0).to(device)})

# Flatten every frame to 60 features: 10 keypoints x (2 value + 4 jacobian entries).
kp_list_test2 = [
    torch.cat((kp['value'], kp['jacobian'].reshape(1, -1, 10, 4)), dim=-1)
    .reshape(-1, 60).cpu().numpy()
    for kp in kp_dict_init
]

# Apply per-video min-max normalization to the keypoints of the 44 driving test videos and split them into windows of 2 × step_interval frames (24 frames for step_interval = 12, 48 for step_interval = 24)

In [None]:
#####  min-max std to 60 dimensions of selected one video ######
# Per-video min-max normalization of the 60 keypoint features. The statistics
# are taken only from the first (observed) half of every 2*step_interval
# window, i.e. the frames the VRNN is conditioned on, never the frames it has
# to predict.
step_interval = 12  # choose between 12 frames or 24 frames
min_required_steps = 2 * step_interval  # window length: observed + predicted

kp_list_test_std = []
min_list = []
range_list = []
for data in kp_list_test1:
    num_windows = len(data) // min_required_steps
    if num_windows > 0:
        windows = data[:num_windows * min_required_steps].reshape(
            num_windows, min_required_steps, -1)
        stats_source = windows[:, :step_interval].reshape(-1, data.shape[-1])
    else:
        # Too short for a single window (the batching step drops it anyway);
        # fall back to the whole video so the per-video lists stay aligned.
        stats_source = data
    min_values = stats_source.min(axis=0)  # one min per feature
    max_values = stats_source.max(axis=0)  # one max per feature
    range_values = max_values - min_values
    # A constant feature would divide by zero (inf/nan). Leave it unscaled;
    # de-normalization multiplies by this same stored range, so the round
    # trip stays exact.
    range_values[range_values == 0] = 1
    kp_list_test_std.append((data - min_values) / range_values)
    min_list.append(min_values)
    range_list.append(range_values)

trajs = kp_list_test_std
print(len(trajs))
print(trajs[0].shape)

In [None]:
######### convert into batches:
# Cut every normalized video into non-overlapping windows of `frames` frames
# (observed half followed by the half to predict); a trailing partial window
# is discarded.
frames = min_required_steps
input_frames = int(frames / 2)  # length of the observed half
data_batch_test = []
for video_kp in kp_list_test_std:
    window_count = video_kp.shape[0] // frames
    if window_count < 1:
        continue  # shorter than one window
    usable = video_kp[:window_count * frames]
    data_batch_test.extend(np.array_split(usable, window_count))
print('test dataset batches:', len(data_batch_test))
print(data_batch_test[0].shape)

In [None]:
###### test dataset:
# Stack all windows into a single (num_windows, frames, 60) array.
test_data_reshape = np.asarray(data_batch_test).reshape((-1, frames, 60))
test_data_reshape.shape

# Predict keypoints with the trained VRNN

In [None]:
# Run the VRNN forward pass over every test window.
# The model is fed windows shaped (num_windows, frames, 10, 6).
validation_data = test_data_reshape.reshape(-1, frames, 10, 6)
validation_data_tensor = torch.tensor(
    validation_data, dtype=torch.float32, device=device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    pred, kl_div = model(validation_data_tensor)

# NOTE(review): if `pred` really is [B, T, 10, 6], averaging over dim 1
# collapses the time axis and the later reshape to (-1, frames, 60) would mix
# windows together; dim 1 is presumably a sample axis — verify against
# VRNN.forward.
print(pred.shape)  # should be [B, T, 10, 6]
preds = pred.mean(dim=1)
pred_np = preds.detach().cpu().numpy()

# Undo the min-max normalization of the predicted keypoints

In [None]:
# save num_batches for each video:
# One entry per video in kp_list_test_std (0 for a video too short for a full
# window), so entry k lines up with min_list[k] / range_list[k] and with the
# windows stacked into test_data_reshape. This uses the same `>= frames` rule
# as the batching cell; the previous `> frames` skipped videos of exactly
# `frames` frames that were still batched, shifting every later video's
# slice and normalization statistics.
num_batch_video = []
num_full_batches_all = 0
for x in kp_list_test_std:
    num_full_batches = x.shape[0] // frames  # 0 when x.shape[0] < frames
    num_full_batches_all += num_full_batches
    num_batch_video.append(num_full_batches)
print('number of batches of each video:', len(num_batch_video))

In [None]:
# first half of frames: groundtruth; last half of frames: predicted
observed_part = test_data_reshape[:, :input_frames]
predicted_part = pred_np.reshape(-1, frames, 60)[:, input_frames:]
test_gt_pred = np.concatenate([observed_part, predicted_part], axis=1)
test_gt_pred.shape

In [None]:
# unstd for each video:
# Windows of all videos are stacked back-to-back in test_gt_pred; walk them
# with a running offset and invert each video's own min-max scaling.
test_video_unstd_list = []
window_offset = 0
for video_idx, window_count in enumerate(num_batch_video):
    test_video = test_gt_pred[window_offset:window_offset + window_count]
    window_offset += window_count
    test_video_unstd = test_video * range_list[video_idx] + min_list[video_idx]
    test_video_unstd_list.append(test_video_unstd)  # unstd video keypoints

# Optical flow and generator: render frames from the predicted keypoints

In [None]:
####### call the config functions and inference dataloader #########

config_path = "config/abs-vox.yml"

# Test dataset
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)
dataset_params = config['dataset_params']

# test: driving videos
dataset1 = FramesDataset_transfer(is_train=False, mode="RNN", **dataset_params)
print(len(dataset1))
dataloader1 = DataLoader(dataset1, batch_size=1, shuffle=False, num_workers=1)

# test_source: source videos
# NOTE(review): built with is_train=True although it serves as a test split —
# confirm this is intended for FramesDataset_transfer.
dataset2 = FramesDataset_transfer(is_train=True, mode="VRNN", **dataset_params)
print(len(dataset2))
dataloader2 = DataLoader(dataset2, batch_size=1, shuffle=False, num_workers=1)

# test_recon: reference videos (only FOMM, no keypoints prediction)
dataset3 = FramesDataset_transfer(is_train=False, mode="VRNN", **dataset_params)
print(len(dataset3))
dataloader3 = DataLoader(dataset3, batch_size=1, shuffle=False, num_workers=1)

### call the functions
model_params = config['model_params']
common_params = model_params['common_params']
generator = OcclusionAwareGenerator(**model_params['generator_params'],
                                    **common_params)
kp_detector = KPDetector(**model_params['kp_detector_params'],
                         **common_params)

log_dir = "./log/test-reconstruction-vox"
checkpoint = "./Training_Prediction/FOMM/Trained_Models/vox-cpk.pth.tar"

# Pretrained FOMM weights are mandatory for this pipeline.
if checkpoint is None:
    raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    
def save_obj(obj, name):
    """Pickle *obj* to ``./<name>.pkl`` with the highest pickle protocol.

    The path is relative to the current working directory.
    """
    target = "".join(("./", name, ".pkl"))
    with open(target, "wb") as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Unpickle and return the object stored in ``./<name>.pkl``.

    The path is relative to the current working directory. Only use on
    trusted files: unpickling can execute arbitrary code.
    """
    source_path = "".join(("./", name, ".pkl"))
    with open(source_path, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded

png_dir = os.path.join(log_dir, 'prediction/png')
log_dir = os.path.join(log_dir, 'prediction')

# exist_ok replaces the separate exists() check, which could race with another
# process creating the directory in between.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(png_dir, exist_ok=True)

if torch.cuda.is_available():
    # torch.nn.DataParallel requires the wrapped module's parameters to be on
    # device_ids[0] before forward. Previously only the generator was moved
    # (later, inside the loop) and kp_detector stayed on the CPU.
    # Assumes DataParallelWithCallback subclasses DataParallel — TODO confirm.
    generator = DataParallelWithCallback(generator.to(device))
    kp_detector = DataParallelWithCallback(kp_detector.to(device))

generator.eval()
kp_detector.eval()

prediction_params = config['prediction_params']

num_epochs = prediction_params['num_epochs']
lr = prediction_params['lr']
bs = prediction_params['batch_size']
num_frames = prediction_params['num_frames']
loss_list_total = []  # per-video mean reconstruction losses

In [None]:
# Animate one source frame per driving video with the (observed + VRNN
# predicted) keypoints, save the frames, and score them against the FOMM-only
# reference videos (dataset3) with a mean absolute (L1) error.
#
# NOTE(review): videos are matched across dataloader1, test_video_unstd_list,
# kp_list_test2, dataset2 and dataset3 purely by position `it`; this assumes
# they all list the same videos in the same order — confirm.
num_videos = config['reconstruction_params']['num_videos']
visualizer = Visualizer(**config['visualizer_params'])
generator = generator.to(device)

for it, x in tqdm(enumerate(dataloader1), total=len(dataloader1)):
    # `>=` so that exactly num_videos videos are evaluated (the previous `>`
    # evaluated num_videos + 1).
    if num_videos is not None and it >= num_videos:
        break

    predictions = []
    visualizations = []
    loss_list = []

    # The whole per-video pass runs without autograd. Previously only the
    # keypoint setup was inside no_grad, so every generator call built a graph.
    with torch.no_grad():
        ######## keypoints ########
        # Each row is one frame: 10 keypoints x (2 value + 4 jacobian entries).
        kp_driving_video = torch.tensor(
            test_video_unstd_list[it].reshape(-1, 10, 6)).to(device)
        kp_source_frame = torch.tensor(kp_list_test2[it][0]).reshape(10, 6).to(device)
        # Batch dimension added to 'value' so kp_source has the same layout as
        # kp_driving below.
        kp_source = {
            "value": kp_source_frame[:, :2].reshape(1, 10, 2),
            "jacobian": kp_source_frame[:, 2:].reshape(1, 10, 2, 2),
        }

        # Fetch the source and reference videos once per driving video rather
        # than once per frame (each dataset access loads the whole video).
        # NOTE(review): dataset2 uses is_train=True; if its __getitem__ is
        # randomized, the source is now fixed per video — confirm intended.
        source = torch.tensor(dataset2[it]['video'][:, 0]).reshape(1, 3, 256, 256).to(device)
        reference_video = dataset3[it]['video']

        ##### Start generator
        # Only whole `frames`-long windows have predicted keypoints; the tail
        # of the video is skipped.
        num_eval_frames = (x['video'].shape[2] // frames) * frames
        for i in range(num_eval_frames):
            driving = x['video'][:, :, i].to(device)
            driving_reference = torch.tensor(reference_video[:, i]).reshape(1, 3, 256, 256)
            kp_driving = {
                "value": kp_driving_video[i, :, :2].reshape(1, 10, 2),
                "jacobian": kp_driving_video[i, :, 2:].reshape(1, 10, 2, 2),
            }
            out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
            out['kp_source'] = kp_source
            out['kp_driving'] = kp_driving
            del out['sparse_deformed']

            prediction_np = out['prediction'].cpu().numpy()
            predictions.append(np.transpose(prediction_np, [0, 2, 3, 1])[0])
            # visualizations[0].shape: (256, 1280, 3)
            visualizations.append(
                visualizer.visualize(source=source, driving=driving, out=out))

            # Mean absolute (L1) error vs. the FOMM-only reference frame;
            # frames with exactly zero error are skipped, as before.
            frame_loss = np.abs(prediction_np - driving_reference.numpy()).mean()
            if frame_loss != 0:
                loss_list.append(frame_loss)

    if not predictions:
        # Shorter than one window: nothing to save or score (np.concatenate
        # would fail on an empty list).
        print("Skipping %s: fewer than %d frames" % (x['name'][0], frames))
        continue

    video_loss = np.mean(loss_list)
    print("Reconstruction loss: %s" % video_loss)
    loss_list_total.append(video_loss)

    predictions = np.concatenate(predictions, axis=1)
    imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
    image_name = x['name'][0] + config['reconstruction_params']['format']
    imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
    del predictions, visualizations
    gc.collect()
    torch.cuda.empty_cache()

print("mean Reconstruction loss: %s" % np.mean(loss_list_total))