In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded
# before each cell runs and edits to project code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from pathlib import Path
import numpy as np
import os
from random import seed
import pickle
import yaml
import imageio
from copy import deepcopy

from src.modules.vqvae import VQVae
from src.modules.fsq_vqvae import FSQVAE
from train_tokenizer import VQVAEModule
from src.dataset import Dataset
from IPython.display import Image, display
import matplotlib.pyplot as plt
from utils.tokenizer_utils import load_fsq_vae, prepare_features, process_reconstruction, repackage_output

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory the reconstructed pickle is written into by the save cell further down.
pickle_dir = Path("dataset") / "pickles"

In [None]:
# "Rest" FSQ-VAE checkpoint (epoch 319). load_fsq_vae returns the callable model together
# with a per-feature settings mapping that later cells pass to prepare_features.
rest_ckpt = Path("outputs/modified_rest_fsq_D4/checkpoints/checkpoint_epoch=319.ckpt")
rest_fsq, rest_feats = load_fsq_vae(rest_ckpt)

In [None]:
# Rotation/scale FSQ-VAE checkpoint (epoch 339); its output supplies the 'R' and 'scale'
# channels when the full feature vector is assembled later in the notebook.
rot_scale_ckpt = Path("outputs/modified_rot_scale_D4/checkpoints/checkpoint_epoch=339.ckpt")
rot_scale_fsq, rot_scale_feats = load_fsq_vae(rot_scale_ckpt)

In [None]:
# Expression FSQ-VAE checkpoint (epoch 459); the reconstruction cell keeps the first 48
# channels of its output.
exp_ckpt = Path("outputs/modified_rest_vel_reg_fsq_D4/checkpoints/checkpoint_epoch=459.ckpt")
exp_fsq, exp_feats = load_fsq_vae(exp_ckpt)

In [None]:
# Lip FSQ-VAE checkpoint (epoch 229); the reconstruction cell keeps the first 15 channels
# of its output.
lip_ckpt = Path("outputs/modified_lips_vel_reg_fsq_D4/checkpoints/checkpoint_epoch=229.ckpt")
lip_fsq, lip_feats = load_fsq_vae(lip_ckpt)

In [None]:
# Evaluation split of the dataset. The sample indices noted in the "Prepare ds item" cell
# are tied to the seed used here (index 134 is recorded there for seed 2).
ds = Dataset(
    "dataset",
    split="eval",
    val_split=0.1,
    seed=2,
    compute_stats=False,
)

In [10]:
# for j in range(0, 20):
#     target_ids = ["dataset/pickles/JRG5gXNZbgQ_3.pkl"]#, "dataset/pickles/droRkoEh8iE_18.pkl", "dataset/pickles/WQvT1_tQDhg_22.pkl"]

#     counter = 0

#     ds = Dataset("dataset", split="eval", val_split=0.1, seed=j, compute_stats=False)
#     # find_id = f"dataset/pickles/{test_id}.pkl"

#     for i, item in enumerate(ds):
#         if item['metadata']['pickle_path'] in target_ids:
#             counter += 1
#             print(j, i)
        
#         if counter == 2:
#             exit()
#     counter = 0


### Prepare dataset item

In [None]:
# Sample indices of interest (indices into `ds`; they depend on the Dataset seed):
#   1746 - guy video
#   134  - girl video, seed 2
#   122  - seed 1
sample = ds[134]

pickle_path = sample['metadata']['pickle_path']
# Derive the video id with Path.stem, the same way the save cell computes `video_id`,
# so both cells agree on the id for any file name.
vid_id = Path(pickle_path).stem
vid_path = f"dataset/train/{vid_id}.mp4"

# Read the first frame from the video and display it.
# The context manager closes the reader once the frame has been fetched.
with imageio.get_reader(vid_path) as reader:
    frame = reader.get_data(0)
plt.imshow(frame)
plt.show()


In [None]:
# Build the expression model's input from the selected sample, using the feature settings
# loaded with exp_fsq. `exp_dims` is later passed to process_reconstruction.
exp_features, exp_dims = prepare_features(
    sample,
    exp_feats,
    only_lips=False,
)

In [None]:
# Build the rest model's input from the selected sample, using the feature settings
# loaded with rest_fsq. `rest_dims` is later passed to process_reconstruction.
rest_features, rest_dims = prepare_features(
    sample,
    rest_feats,
    only_lips=False,
)

In [None]:
# Build the rotation/scale model's input from the selected sample, using the feature
# settings loaded with rot_scale_fsq. `r_scale_dims` is later passed to process_reconstruction.
r_scale_features, r_scale_dims = prepare_features(
    sample,
    rot_scale_feats,
    only_lips=False,
)

In [None]:
# Build the lip model's input from the selected sample; this is the only call that
# requests the lips-only variant of the features.
feats_lip, lip_dims = prepare_features(
    sample,
    lip_feats,
    only_lips=True,
)

### Load data statistics

In [18]:
# Load stats: a pickled mapping whose 'mean' and 'std' entries hold per-feature statistics.
# The file is opened in a `with` block so the handle is closed right after loading.
# NOTE: pickle.load can execute arbitrary code - only load stats files you trust.
with open("dataset/stats_all.pkl", "rb") as stats_file:
    stats = pickle.load(stats_file)

In [19]:
# Send stats to GPU.
# The entries of `stats` are replaced in place, so `mean` / `std` are the very same
# mappings as stats['mean'] / stats['std'], holding the CUDA copies after the loop.
mean = stats['mean']
std = stats['std']
for feat_name in mean:
    mean[feat_name] = mean[feat_name].to("cuda")
    std[feat_name] = std[feat_name].to("cuda")

### Reconstruct features

In [30]:
# Run each FSQ-VAE on its prepared input; no gradients are needed for reconstruction.
with torch.no_grad():
    reconstr_rest = rest_fsq(rest_features)
    reconstr_r_scale = rot_scale_fsq(r_scale_features)
    reconstr_exp_rest = exp_fsq(exp_features)
    reconstr_lip = lip_fsq(feats_lip)

    # Remove velocity components: keep only the leading 48 channels of the expression
    # output and the leading 15 channels of the lip output, then join them.
    exp_part = reconstr_exp_rest[..., :48]
    lip_part = reconstr_lip[..., :15]
    reconstr_exp = torch.cat((exp_part, lip_part), dim=-1)

# The combined expression tensor has 48 + 15 = 63 channels; record that width so
# process_reconstruction sees the right size for 'exp'.
exp_dims['exp'] = 63


In [31]:
# with torch.no_grad():
#     rest_indices = rest_fsq.encode(rest_features)
#     r_scale_indices = rot_scale_fsq.encode(r_scale_features)
#     exp_indices = exp_fsq.encode(exp_features)
#     lip_indices = lip_fsq.encode(feats_lip)

# print(f"First 10 codes: {rest_indices[:, :10]}")

In [32]:
# with torch.no_grad():
#     rest_codes = rest_fsq.quantizer.indices_to_codes(rest_indices)
#     r_scale_codes = rot_scale_fsq.quantizer.indices_to_codes(r_scale_indices)
#     exp_codes = exp_fsq.quantizer.indices_to_codes(exp_indices)
#     lip_codes = lip_fsq.quantizer.indices_to_codes(lip_indices)

In [33]:
# with torch.no_grad():
#     rest_codes = rest_fsq.preprocess(rest_codes)
#     r_scale_codes = rot_scale_fsq.preprocess(r_scale_codes)
#     exp_codes = exp_fsq.preprocess(exp_codes)
#     lip_codes = lip_fsq.preprocess(lip_codes)

In [34]:
# with torch.no_grad():
#     reconstr_rest = rest_fsq.decoder(rest_codes)
#     reconstr_r_scale = rot_scale_fsq.decoder(r_scale_codes)
#     reconstr_exp_rest = exp_fsq.decoder(exp_codes)
#     reconstr_lip = lip_fsq.decoder(lip_codes)

#     reconstr_rest = rest_fsq.postprocess(reconstr_rest)
#     reconstr_r_scale = rot_scale_fsq.postprocess(reconstr_r_scale)
#     reconstr_exp_rest = exp_fsq.postprocess(reconstr_exp_rest)
#     reconstr_lip = lip_fsq.postprocess(reconstr_lip)

#     reconstr_exp = torch.cat([reconstr_exp_rest[..., :45], reconstr_lip[..., :18]], dim=-1)
#     exp_dims['exp'] = 63

### Test encoding features

In [35]:
# Round-trip check of the rest model's explicit API: encode the prepared features,
# then decode those encodings again. No gradients are tracked.
# `result` is not consumed by any later cell shown in this notebook.
with torch.no_grad():
    encodings = rest_fsq.encode(rest_features)
    result = rest_fsq.decode(encodings)

In [None]:
# Post-process the rest model's output using the dataset std/mean (moved to the GPU above).
# NOTE(review): the positional `False` is an opaque flag of process_reconstruction -
# confirm its meaning in utils.tokenizer_utils.
rest_new_reconstr = process_reconstruction(
    rest_dims,
    reconstr_rest,
    False,
    std,
    mean,
)

In [None]:
# Post-process the rotation/scale model's output using the dataset std/mean.
# NOTE(review): the positional `False` is an opaque flag of process_reconstruction -
# confirm its meaning in utils.tokenizer_utils.
r_scale_new_reconstr = process_reconstruction(
    r_scale_dims,
    reconstr_r_scale,
    False,
    std,
    mean,
)

In [None]:
# Post-process the combined expression + lip reconstruction (63 channels, as recorded in
# exp_dims['exp']) using the dataset std/mean.
# NOTE(review): the positional `False` is an opaque flag of process_reconstruction -
# confirm its meaning in utils.tokenizer_utils.
exp_new_reconstr = process_reconstruction(
    exp_dims,
    reconstr_exp,
    False,
    std,
    mean,
)

In [41]:
# Assemble the full 205-channel feature vector from the three processed reconstructions.
# The channel layout follows the order and sizes listed in `new_dims`:
#   R [0:9] | c_eyes_lst + c_lip_lst [9:12] | exp [12:75] | kp [75:138] |
#   scale [138:139] | t [139:142] | x_s [142:205]
new_reconstr = torch.zeros((*rest_new_reconstr.shape[:-1], 205))
new_reconstr[..., :9] = r_scale_new_reconstr[..., :9]           # R
new_reconstr[..., 9:12] = rest_new_reconstr[..., :3]            # c_eyes_lst (2) + c_lip_lst (1)
new_reconstr[..., 12:75] = exp_new_reconstr                     # exp (expression + lip)
new_reconstr[..., 75:138] = rest_new_reconstr[..., 3:66]        # kp
# scale comes from the rotation/scale model only. The earlier write of
# rest_new_reconstr[..., 75:76] into this slot was overwritten on the next line, and that
# channel lies inside the rest model's x_s range, so the dead store has been removed.
new_reconstr[..., 138:139] = r_scale_new_reconstr[..., 9:10]    # scale
new_reconstr[..., 139:142] = rest_new_reconstr[..., 66:69]      # t
new_reconstr[..., 142:205] = rest_new_reconstr[..., 69:132]     # x_s

In [42]:
# Width of each feature inside `new_reconstr`, listed in channel order.
# The trailing comments give the slice each feature occupies; the widths sum to 205.
new_dims = dict([
    ('R', 9),            # [0:9]
    ('c_eyes_lst', 2),   # [9:11]
    ('c_lip_lst', 1),    # [11:12]
    ('exp', 63),         # [12:75]
    ('kp', 63),          # [75:138]
    ('scale', 1),        # [138:139]
    ('t', 3),            # [139:142]
    ('x_s', 63),         # [142:205]
])

In [43]:
# zeros = torch.zeros(1, new_reconstr.shape[1], new_reconstr.shape[-1] + 18)
# # full_exp = sample['exp'][:95].reshape(1, 95, -1).to('cuda')

# cur_ind = 0
# reconstr_ind = 0

# for feat, indices in dims.items():
#     print(f"{feat}: {cur_ind}: {cur_ind + indices}")
#     if feat == 'exp':
#         zeros[..., cur_ind: cur_ind + indices] = new_reconstr[..., reconstr_ind: reconstr_ind + indices] 
#         # zeros[..., cur_ind: cur_ind + indices] = full_exp[..., reconstr_ind: reconstr_ind + indices] 
#         zeros[..., cur_ind + indices: cur_ind + indices + 18] = reconst_lip * std[feat][:, 45:] + mean[feat][:, 45:]
#         cur_ind += indices + 18
        
#         reconstr_ind += indices
#     else:
#         zeros[..., cur_ind: cur_ind + indices] = new_reconstr[..., reconstr_ind: reconstr_ind + indices]
#         cur_ind += indices
#         reconstr_ind += indices

In [44]:
# full_feats = sample['exp'][:300].reshape(1, 300, -1).to('cuda')
# zeros = torch.zeros_like(full_feats, device='cuda')
# zeros[..., :45] = new_reconstr
# zeros[..., 45:] = full_feats[..., 45:] * std['exp'][:, 45:] + mean['exp'][:, 45:]

# new_reconstr = zeros

### Generate reconstructed pickle

In [45]:
# Build a copy of the rest model's feature settings with every feature switched on.
# Each entry is deep-copied first: assigning the original value and then setting
# 'enabled' on it would also flip the flag inside `rest_feats`, silently changing what
# prepare_features(sample, rest_feats, ...) produces if that cell is re-run.
new_rest_feats = {}

for key, value in rest_feats.items():
    feat_settings = deepcopy(value)
    feat_settings['enabled'] = True
    new_rest_feats[key] = feat_settings

In [None]:
# Repackage the assembled 205-channel reconstruction into the structure that is pickled
# in the save cell, using the layout in `new_dims` and the all-enabled feature settings.
output = repackage_output(
    sample,
    new_reconstr,
    ds,
    dims=new_dims,
    feats_data=new_rest_feats,
)

In [48]:
# # Replace reconstructed features with original ones for a specific range
# feat = 'exp'

# feat_range = (0, 16)

# denormalized_sample = ds.denormalize_features(sample[feat], feat)

# for i in range(output['n_frames']):
#     output['motion'][i][feat][:, feat_range[0]:feat_range[1]] = denormalized_sample[i][:, feat_range[0]:feat_range[1]]

In [None]:
# Write the repackaged output next to the other pickles as "<video id>_reconstructed.pkl".
# `video_id` and `new_path` are reused by the following cells.
video_id = Path(pickle_path).stem
reconstructed_name = f"{video_id}_reconstructed.pkl"
new_path = pickle_dir / reconstructed_name
print(new_path)
with new_path.open("wb") as out_file:
    pickle.dump(output, out_file)

In [50]:
# Export the first frame of the source video as the still image passed to inference.py
# via -s. imageio is already imported in the notebook's import cell, so it is not
# re-imported here. The context manager closes the reader once the frame has been read.
with imageio.get_reader(f"dataset/train/{video_id}.mp4") as vid:
    frame = vid.get_data(0)

# save frame
imageio.imwrite("assets/examples/source/reconstructed.png", frame)

In [None]:
# Run the project's inference script in a shell. `{new_path}` is interpolated by IPython
# from the Python variable set in the save cell (the reconstructed pickle), and -s points
# at the source frame written by the previous cell.
!python inference.py -d {new_path} -s assets/examples/source/reconstructed.png