# Full pipeline for the VoxCeleb dataset: predict keypoints with an RNN, then reconstruct frames with the FOMM generator (reconstruction mode)

# Import functions

In [1]:
import os   
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from Training_Prediction.FOMM.Source_Model.logger import Logger, Visualizer
import numpy as np
import imageio
from Training_Prediction.FOMM.Source_Model.sync_batchnorm import DataParallelWithCallback
from Training_Prediction.PREDICTOR.RNN import GRUModel
from Training_Prediction.FOMM.Source_Model.augmentation import SelectRandomFrames, SelectFirstFrames_two, VideoToTensor
from tqdm import trange
from torch.utils.data import DataLoader, Dataset
from Training_Prediction.FOMM.Source_Model.frames_dataset import FramesDataset
import tensorflow.compat.v1 as tf
import pickle
from Training_Prediction.PREDICTOR.Source_Model.prediction_toplevel import KPDataset,get_data_from_dataloader_60
import gc
import pickle
import yaml
from Training_Prediction.FOMM.Source_Model.modules.generator import OcclusionAwareGenerator,calculate_frechet_distance,compute_fvd
from Training_Prediction.FOMM.Source_Model.modules.keypoint_detector import KPDetector
#from Training_Prediction.FOMM.Source_Model.logger import Logger, Visualizer, Visualizer_slow
from Training_Prediction.FOMM.Source_Model.logger import Logger, Visualizer
from torch import nn
import tensorflow.compat.v1 as tf
from torch.autograd import Variable
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

import os, sys
os.environ["CUDA_VISIBLE_DEVICES"]='0'

2024-07-18 01:38:01.324445: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-18 01:38:01.336312: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-18 01:38:01.339780: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-18 01:38:01.351297: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Import keypoints of 3883 training videos

In [None]:
# Load the pickled per-video keypoint records of the training split.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
train_kp_path = "kp_train_3883_vox.pkl"
with open(train_kp_path, "rb") as pkl_file:
    kp_time_series = pickle.load(pkl_file)
len(kp_time_series)  # shown as the cell output: number of videos

# Convert list of keypoints to dictionary

In [None]:
# Stack each training video's per-frame keypoint records into one dict of
# tensors: "value" -> (1, T, 10, 2) and "jacobian" -> (1, T, 10, 2, 2), where T
# is the number of frames of the video. Records are read via video['kp']
# instead of overwriting kp_time_series in place, so re-running this cell no
# longer fails on entries that were already unwrapped.
kp_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

kp_dict_init = []
for video_record in kp_time_series:
    frame_records = video_record['kp']  # per-frame dicts with 'value' and 'jacobian'
    init_mean = torch.cat(
        [torch.tensor(frame['value'].reshape(1, 10, 2)) for frame in frame_records]
    ).unsqueeze(0)
    init_jacobian = torch.cat(
        [torch.tensor(frame['jacobian'].reshape(1, 10, 2, 2)) for frame in frame_records]
    ).unsqueeze(0)
    # .to('cpu') returns a CPU tensor unchanged, so no branch is needed.
    kp_dict_init.append({"value": init_mean.to(kp_device),
                         "jacobian": init_jacobian.to(kp_device)})

# Apply per-video min-max normalization to the keypoints

In [None]:
# Flatten every video's keypoints into a (T, 60) numpy array: for each of the
# 10 keypoints the 6 features are [x, y, j00, j01, j10, j11] (value followed
# by the row-major 2x2 jacobian).
kp_list_train = []
for kp_dict in kp_dict_init:
    kp_one_video = torch.cat(
        (kp_dict['value'], kp_dict['jacobian'].reshape(1, -1, 10, 4)), dim=-1
    ).reshape(-1, 60)
    kp_list_train.append(np.array(kp_one_video.cpu()))

#####  per-video min-max normalization of the 60 features ######
# Each video is scaled with its own per-feature minimum and range over all of
# its frames; min_list / range_list keep these so x * range + min inverts it.
kp_list_train_std = []
min_list = []
range_list = []
for kp_video in kp_list_train:
    min_values = np.min(kp_video, axis=0)  # 60 per-feature minima of this video
    max_values = np.max(kp_video, axis=0)  # 60 per-feature maxima of this video
    range_values = max_values - min_values
    # A feature that is constant over the video has zero range; dividing by it
    # gives NaN/inf and makes the MSE training loss NaN. Use 1 instead: such a
    # feature maps to 0 and x * range + min still inverts exactly.
    range_values[range_values == 0] = 1
    kp_list_train_std.append((kp_video - min_values) / range_values)
    min_list.append(min_values)
    range_list.append(range_values)

trajs = kp_list_train_std
print(len(trajs))
print(trajs[0].shape)

# Split the normalized keypoints into fixed-length windows: 24 frames (12 input + 12 target); use 12-frame windows for the 6-6 model

In [None]:
######### convert data into batches #########
# Cut every normalized video into non-overlapping windows of `frames` frames;
# trailing frames that do not fill a whole window are dropped.
frames = 24  # window length: the first 12 frames are the input, the next 12 the target
input_frames = frames // 2
input_dim = 60

data_batch_train = []
for x in kp_list_train_std:
    num_full_batches = x.shape[0] // frames
    # A video needs at least `frames` frames, the same rule the test split uses
    # (`>=`). The former `> frames` silently dropped videos of exactly that length.
    if num_full_batches == 0:
        continue
    windows = x[:num_full_batches * frames].reshape(num_full_batches, frames, x.shape[1])
    data_batch_train.extend(windows)  # one (frames, 60) array per window
print('train dataset batches:', len(data_batch_train))
if data_batch_train:
    print(data_batch_train[0].shape)

In [None]:
##### train dataset:
# Stack the list of (frames, 60) windows into one (num_windows, frames, 60) array.
train_data_reshape = np.asarray(data_batch_train).reshape((-1, frames, 60))
train_data_reshape.shape

# Define RNN

In [4]:
# Hyperparameters of the GRU keypoint predictor.
input_dim, hidden_dim, num_layers = 60, 256, 3
output_dim = input_dim  # the model predicts the same 60 features it receives
learning_rate, num_epochs = 0.001, 100

# Model, loss function and optimizer.
model = GRUModel(input_dim, hidden_dim, output_dim, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training

In [3]:
# Train the model
train_data = train_data_reshape

# The DataLoader shuffles whole windows and groups them into mini-batches of 250.
train_loader = DataLoader(train_data, batch_size=250, shuffle=True)

# Put the model in training mode explicitly: a later cell calls model.eval(),
# and notebook cells are often re-run out of order.
model.train()
for epoch in range(num_epochs):
    for step, batch in enumerate(train_loader):
        # Default collation keeps the numpy dtype; cast to float32 to match the
        # model weights (the evaluation cell uses dtype=torch.float32 as well).
        batch = batch.float()
        batch_input = batch[:, :input_frames]   # observed frames
        batch_output = batch[:, input_frames:]  # frames the model must predict
        outputs = model(batch_input)

        loss = criterion(outputs, batch_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch + 1}/{num_epochs}, Step {step + 1}/{len(train_loader)}, Loss: {loss.item():.4f}')

# Save the final-epoch weights (no validation-based model selection is done).
os.makedirs('Checkpoints', exist_ok=True)
torch.save(model.state_dict(), 'Checkpoints/RNN_3883videos_vox_12-12.pth') # 12-12 frames
# torch.save(model.state_dict(), 'Checkpoints/RNN_3883videos_vox_6-6.pth') # 6-6 frames

NameError: name 'train_data_reshape' is not defined

# After RNN is trained, load saved model

In [5]:
# Load the saved parameters
# Load the saved parameters. map_location='cpu' lets a checkpoint that was
# written from GPU tensors load on any machine; nothing in this notebook moves
# the RNN to a GPU, so its weights live on the CPU anyway.
model.load_state_dict(torch.load('Checkpoints/RNN_3883videos_vox_12-12.pth', map_location='cpu')) # 12-12 frames
# model.load_state_dict(torch.load('Checkpoints/RNN_3883videos_vox_6-6.pth', map_location='cpu')) # 6-6 frames

# Set the model to evaluation mode (important if using dropout or batch normalization)
model.eval()

GRUModel(
  (gru): GRU(60, 256, num_layers=3, batch_first=True)
  (fc): Linear(in_features=256, out_features=60, bias=True)
)

# Import keypoints of 44 VoxCeleb test videos

In [6]:
# Load the pickled per-video keypoint records of the test split.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
test_kp_path = "kp_test_44_vox.pkl"
with open(test_kp_path, "rb") as pkl_file:
    kp_time_series = pickle.load(pkl_file)
len(kp_time_series)  # shown as the cell output: number of videos

44

# Convert list of keypoints to dictionary

In [7]:
# Stack each test video's per-frame keypoint records into one dict of tensors:
# "value" -> (1, T, 10, 2) and "jacobian" -> (1, T, 10, 2, 2), where T is the
# number of frames of the video. Records are read via video['kp'] instead of
# overwriting kp_time_series in place, so re-running this cell no longer fails
# on entries that were already unwrapped.
kp_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

kp_dict_init = []
for video_record in kp_time_series:
    frame_records = video_record['kp']  # per-frame dicts with 'value' and 'jacobian'
    init_mean = torch.cat(
        [torch.tensor(frame['value'].reshape(1, 10, 2)) for frame in frame_records]
    ).unsqueeze(0)
    init_jacobian = torch.cat(
        [torch.tensor(frame['jacobian'].reshape(1, 10, 2, 2)) for frame in frame_records]
    ).unsqueeze(0)
    # .to('cpu') returns a CPU tensor unchanged, so no branch is needed.
    kp_dict_init.append({"value": init_mean.to(kp_device),
                         "jacobian": init_jacobian.to(kp_device)})

# Apply per-video min-max normalization to the test keypoints (batching is done in the next cell)

In [8]:
# Flatten every video's keypoints into a (T, 60) numpy array: for each of the
# 10 keypoints the 6 features are [x, y, j00, j01, j10, j11].
kp_list_test = []
for kp_dict in kp_dict_init:
    kp_one_video = torch.cat(
        (kp_dict['value'], kp_dict['jacobian'].reshape(1, -1, 10, 4)), dim=-1
    ).reshape(-1, 60)
    kp_list_test.append(np.array(kp_one_video.cpu()))

#####  per-video min-max normalization of the 60 features ######
# Unlike the training split, the statistics come only from the input half of
# every complete window (the first `step_interval` frames of each
# `min_required_steps`-frame block). The target frames the RNN has to predict
# therefore do not influence the normalization.
step_interval = 12  # input frames per window: 12 for the 12-12 model, 6 for the 6-6 model
min_required_steps = 2 * step_interval  # window length (input + target frames)

kp_list_test_std = []
min_list = []
range_list = []
for data in kp_list_test:
    data_length = len(data)
    selected_data = [
        data[i:i + step_interval]
        for i in range(0, data_length - min_required_steps + 1, min_required_steps)
    ]
    if selected_data:
        stats_frames = np.concatenate(selected_data, axis=0)
    else:
        # Shorter than one window: the video yields no test batches, but it
        # still needs statistics so all per-video lists stay index-aligned
        # (np.min on the empty selection used to raise ValueError here).
        stats_frames = data
    min_values = np.min(stats_frames, axis=0)  # 60 per-feature minima
    max_values = np.max(stats_frames, axis=0)  # 60 per-feature maxima
    range_values = max_values - min_values
    # Zero range (a feature constant over the observed frames) would divide by 0.
    range_values[range_values == 0] = 1
    kp_list_test_std.append((data - min_values) / range_values)
    min_list.append(min_values)
    range_list.append(range_values)

trajs = kp_list_test_std
print(len(trajs))
print(trajs[0].shape)

44
(118, 60)


In [9]:
######### convert into batches:
# Windows have the same length the normalization cell used.
frames = min_required_steps
input_frames = int(frames / 2)
data_batch_test = []
for video_kp in kp_list_test_std:
    n_windows = video_kp.shape[0] // frames
    if n_windows < 1:
        continue  # shorter than one window
    # Drop the trailing partial window and cut the rest into n_windows pieces
    # of shape (frames, 60).
    data_batch_test.extend(
        video_kp[:n_windows * frames].reshape(n_windows, frames, video_kp.shape[1])
    )
print('test dataset batches:', len(data_batch_test))
print(data_batch_test[0].shape)

test dataset batches: 529
(24, 60)


In [10]:
###### test dataset:
# Stack the list of (frames, 60) windows into one (num_windows, frames, 60) array.
test_data_reshape = np.asarray(data_batch_test).reshape((-1, frames, 60))
test_data_reshape.shape

(529, 24, 60)

# Predict keypoints using trained model:

In [11]:
# test dataset
validation_data = test_data_reshape

# Split every window into its observed input half and its ground-truth target half.
validation_input = torch.tensor(validation_data[:, :input_frames], dtype=torch.float32)  # (num_windows, input_frames, 60)
kp_gt = torch.tensor(validation_data[:, input_frames:], dtype=torch.float32)  # (num_windows, frames - input_frames, 60)

# Inference only: no_grad avoids building an autograd graph over all windows.
with torch.no_grad():
    pred = model(validation_input)  # expected to match kp_gt.shape (printed below)

print(kp_gt.shape)
print(pred.shape)

torch.Size([529, 12, 60])
torch.Size([529, 12, 60])


# Undo the min-max normalization so the keypoints are back on their original scale

In [12]:
# Record how many `frames`-long windows each test video contributed, with one
# entry per video (0 for videos shorter than one window). The rule must match
# the batching cell, which keeps a video when it has >= `frames` frames. With
# the former `> frames` and no entry for skipped videos, a video of exactly
# `frames` frames added a window to data_batch_test but no count here. That
# shifted every later video's slice and its min/range statistics.
num_batch_video = [x.shape[0] // frames for x in kp_list_test_std]
num_full_batches_all = sum(num_batch_video)
print('number of test videos:', len(num_batch_video), '| total batches:', num_full_batches_all)

number of batches of each video: 44


In [13]:
# Rebuild full windows: ground-truth input frames followed by the RNN's predicted frames.
observed_half = test_data_reshape[:, :input_frames]
predicted_half = pred.detach().numpy()
test_gt_pred = np.concatenate([observed_half, predicted_half], axis=1)
test_gt_pred.shape

(529, 24, 60)

In [14]:
# Undo the per-video min-max normalization. test_gt_pred holds the windows of
# all videos back to back, so walk through it with a running offset.
test_video_unstd_list = []
window_start = 0
for video_idx, n_windows in enumerate(num_batch_video):
    window_stop = window_start + n_windows
    video_windows = test_gt_pred[window_start:window_stop]
    # x * range + min inverts (x - min) / range
    test_video_unstd_list.append(video_windows * range_list[video_idx] + min_list[video_idx])
    window_start = window_stop

# Optical flow and generator

In [15]:
####### load the config, the test dataloader and the pretrained FOMM networks #########
config_path = "./config/abs-vox.yml"
with open(config_path) as f:
    config = yaml.safe_load(f)

# Test split (is_train=False) in the dataset's "RNN" mode; one video per batch, not shuffled.
dataset = FramesDataset(is_train=False, **config['dataset_params'], mode="RNN")
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

### build the generator and keypoint detector; both also receive the shared common_params
model_params = config['model_params']
generator = OcclusionAwareGenerator(**model_params['generator_params'],
                                    **model_params['common_params'])
kp_detector = KPDetector(**model_params['kp_detector_params'],
                         **model_params['common_params'])

log_dir = "./log/test-reconstruction-vox"
checkpoint = "./Training_Prediction/FOMM/Trained_Models/vox-cpk.pth.tar"

# Load the pretrained FOMM weights into both networks.
if checkpoint is None:
    raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    
def save_obj(obj, name):
    """Pickle ``obj`` to ``./<name>.pkl`` with the highest pickle protocol."""
    target_path = './' + name + '.pkl'
    with open(target_path, 'wb') as out_file:
        pickle.dump(obj, out_file, protocol=pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Return the object unpickled from ``./<name>.pkl``.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    source_path = './' + name + '.pkl'
    with open(source_path, 'rb') as in_file:
        loaded = pickle.load(in_file)
    return loaded
        
# Output locations: <log_dir>/prediction and <log_dir>/prediction/png.
png_dir = os.path.join(log_dir, 'prediction/png')
log_dir = os.path.join(log_dir, 'prediction')
for out_dir in (log_dir, png_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

# Wrap both networks with the callback-aware DataParallel when a GPU is present.
if torch.cuda.is_available():
    generator, kp_detector = (DataParallelWithCallback(generator),
                              DataParallelWithCallback(kp_detector))

generator.eval()
kp_detector.eval()

# NOTE: this rebinds num_epochs, which the RNN training cell also uses.
prediction_params = config['prediction_params']
num_epochs, lr, bs, num_frames = (
    prediction_params[key] for key in ('num_epochs', 'lr', 'batch_size', 'num_frames')
)
loss_list_total, fvd_list_total = [], []

Use predefined train-test split.
using videos from test directory
['id10280#NXjT3732Ekg#001093#001192.mp4', 'id10281#NHARUN9OhSo#000605#000886.mp4', 'id10281#NHARUN9OhSo#001059#001210.mp4', 'id10281#NHARUN9OhSo#002098#002175.mp4', 'id10281#NHARUN9OhSo#002209#002570.mp4', 'id10281#NHARUN9OhSo#006609#006906.mp4', 'id10281#NHARUN9OhSo#006912#007284.mp4', 'id10281#NHARUN9OhSo#007425#007663.mp4', 'id10282#IDA_ElNHLn4#000674#000852.mp4', 'id10282#IDA_ElNHLn4#001226#001390.mp4', 'id10283#N69Hp2DGMLk#000519#000619.mp4', 'id10283#N69Hp2DGMLk#000721#000842.mp4', 'id10283#N69Hp2DGMLk#000893#001589.mp4', 'id10283#N69Hp2DGMLk#004133#005157.mp4', 'id10283#N69Hp2DGMLk#005157#005316.mp4', 'id10283#N69Hp2DGMLk#005931#006184.mp4', 'id10283#N69Hp2DGMLk#006184#006353.mp4', 'id10283#N69Hp2DGMLk#006405#006583.mp4', 'id10283#N69Hp2DGMLk#006600#007118.mp4', 'id10283#N69Hp2DGMLk#007129#007281.mp4', 'id10283#r9-0pljhZqs#002414#002769.mp4', 'id10283#r9-0pljhZqs#003725#003847.mp4', 'id10283#r9-0pljhZqs#004062#004

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [16]:
#########  FOMM+RNN ########
# For each test video, animate its first frame with the FOMM generator. The
# driving keypoints come in windows: ground truth in the first half of each
# window, RNN predictions in the second half. Report the mean absolute (L1)
# error between the generated and the real frames.
num_videos = config['reconstruction_params']['num_videos']
# The visualizer's settings do not change between frames, so build it once.
visualizer = Visualizer(**config['visualizer_params'])

for it, x in tqdm(enumerate(dataloader), total=len(dataloader)):
    # Process at most `num_videos` videos (the former `it > num_videos` check
    # let one extra video through).
    if num_videos is not None and it >= num_videos:
        break

    ######## keypoints ########
    # (frames_kept, 10, 6) per-frame keypoints: [..., :2] value, [..., 2:] jacobian.
    # NOTE(review): assumes the dataloader (shuffle=False) yields videos in the
    # same order as the test keypoint pickle - confirm.
    kp_driving_video = torch.tensor(test_video_unstd_list[it].reshape(-1, 10, 6),
                                    dtype=torch.float32)
    if kp_driving_video.shape[0] == 0:
        # Shorter than one window: there are no keypoints to drive the generator.
        print("Skipping video %d: shorter than %d frames" % (it, frames))
        continue
    # Source keypoints: frame 0, which is a ground-truth (input) frame.
    kp_source = {"value": kp_driving_video[0, :, :2].reshape(1, 10, 2),
                 "jacobian": kp_driving_video[0, :, 2:].reshape(1, 10, 2, 2)}

    predictions = []
    visualizations = []
    loss_list = []
    fvd_list = []
    source = x['video'][:, :, 0]
    num_eval_frames = (x['video'].shape[2] // frames) * frames  # cut the last partial window

    ##### Start generator
    # Pure inference: the whole generator loop runs under no_grad (previously
    # only the keypoint setup did, so every generator call built a graph).
    with torch.no_grad():
        for i in range(num_eval_frames):
            driving = x['video'][:, :, i]
            kp_driving = {"value": kp_driving_video[i, :, :2].reshape(1, 10, 2),
                          "jacobian": kp_driving_video[i, :, 2:].reshape(1, 10, 2, 2)}
            out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
            out['kp_source'] = kp_source
            out['kp_driving'] = kp_driving
            del out['sparse_deformed']
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
            visualizations.append(visualizer.visualize(source=source, driving=driving, out=out))

            # Mean absolute (L1) error of this frame; frames whose error is
            # exactly zero are excluded, as before.
            frame_error = np.abs(out['prediction'].cpu().numpy() - driving.cpu().numpy()).mean()
            if frame_error != 0:
                loss_list.append(frame_error)
                # # Calculate FVD for each frame using ground truth and predicted videos
                # ground_truth_features = driving.detach().cpu().permute(0,2,3,1).reshape(256,256,3)
                # predicted_features = out['prediction'].detach().cpu().permute(0,2,3,1).reshape(256,256,3)
                # fvd_list.append(compute_fvd(ground_truth_features, predicted_features))

    if loss_list:
        video_loss = np.mean(loss_list)
        print("Reconstruction loss: %s" % video_loss)
        loss_list_total.append(video_loss)
    else:
        print("Reconstruction loss: n/a (no frames evaluated)")

    # print("FVD Score: %s" % np.mean(fvd_list))
    # fvd_list_total.append(np.mean(fvd_list))

    # Uncomment to save the generated frames and the visualizations:
    # predictions = np.concatenate(predictions, axis=1)
    # imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
    # image_name = x['name'][0] + config['reconstruction_params']['format']
    # imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

print("mean Reconstruction loss: %s" % np.mean(loss_list_total))
# print("mean FVD score: %s" % np.mean(fvd_list_total))
# print("mean FVD score: %s" % np.mean(fvd_list_total)) 

1it [00:03,  3.42s/it]

Reconstruction loss: 0.069955476


2it [00:13,  7.45s/it]

Reconstruction loss: 0.05067924


3it [00:18,  6.46s/it]

Reconstruction loss: 0.061348695


4it [00:21,  4.82s/it]

Reconstruction loss: 0.06758981


5it [00:34,  7.76s/it]

Reconstruction loss: 0.07394007


6it [00:44,  8.46s/it]

Reconstruction loss: 0.08074589


7it [00:57,  9.93s/it]

Reconstruction loss: 0.05999605


8it [01:04,  9.19s/it]

Reconstruction loss: 0.054235034


9it [01:09,  7.96s/it]

Reconstruction loss: 0.10292894


10it [01:14,  6.86s/it]

Reconstruction loss: 0.07873031


11it [01:18,  5.96s/it]

Reconstruction loss: 0.042145252


12it [01:22,  5.56s/it]

Reconstruction loss: 0.044186223


13it [01:47, 11.39s/it]

Reconstruction loss: 0.052939765


14it [02:24, 18.94s/it]

Reconstruction loss: 0.051838864


15it [02:29, 14.81s/it]

Reconstruction loss: 0.055624884


16it [02:37, 12.93s/it]

Reconstruction loss: 0.0676938


17it [02:43, 10.77s/it]

Reconstruction loss: 0.060428556


18it [02:49,  9.36s/it]

Reconstruction loss: 0.052590877


19it [03:07, 11.96s/it]

Reconstruction loss: 0.0648884


20it [03:12,  9.83s/it]

Reconstruction loss: 0.06157716


21it [03:25, 10.70s/it]

Reconstruction loss: 0.071538426


22it [03:29,  8.82s/it]

Reconstruction loss: 0.058248665


23it [03:41,  9.77s/it]

Reconstruction loss: 0.07494083


24it [03:49,  9.25s/it]

Reconstruction loss: 0.06661647


25it [03:58,  9.06s/it]

Reconstruction loss: 0.07490342


26it [04:02,  7.49s/it]

Reconstruction loss: 0.06279347


27it [04:07,  6.98s/it]

Reconstruction loss: 0.062716626


28it [04:24,  9.95s/it]

Reconstruction loss: 0.06446099


29it [04:31,  8.91s/it]

Reconstruction loss: 0.06397948


30it [04:36,  7.74s/it]

Reconstruction loss: 0.07525618


31it [04:44,  7.99s/it]

Reconstruction loss: 0.06610256


32it [04:56,  9.14s/it]

Reconstruction loss: 0.066081755


33it [05:03,  8.57s/it]

Reconstruction loss: 0.067728676


34it [05:12,  8.58s/it]

Reconstruction loss: 0.08904547


35it [05:16,  7.05s/it]

Reconstruction loss: 0.08821155


36it [05:23,  7.08s/it]

Reconstruction loss: 0.09193583


37it [05:29,  6.75s/it]

Reconstruction loss: 0.059987795


38it [05:44,  9.38s/it]

Reconstruction loss: 0.043185733


39it [05:50,  8.26s/it]

Reconstruction loss: 0.060258795


40it [05:55,  7.26s/it]

Reconstruction loss: 0.04956418


41it [05:58,  6.16s/it]

Reconstruction loss: 0.0762435


42it [06:03,  5.64s/it]

Reconstruction loss: 0.059528988


43it [06:07,  5.24s/it]

Reconstruction loss: 0.050037395


44it [06:23,  8.71s/it]

Reconstruction loss: 0.043088377
mean Reconstruction loss: 0.06455724



