In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules (e.g. under src/) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from pathlib import Path
import os
from random import seed
import pickle
import yaml

from src.modules.vqvae import VQVae
from src.dataset import Dataset

In [3]:
# Directory that receives the reconstructed motion pickles.
pickle_dir = Path("dataset") / "pickles"

In [4]:
# Load yaml config
# The context manager closes the file handle as soon as parsing finishes,
# instead of leaving it open until garbage collection.
with open("configs/default.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

In [5]:
# Build the model from the "vqvae" section of the YAML config.
vqvae = VQVae(**config["vqvae"])

# Deserialize the checkpoint onto the CPU first: without map_location the
# tensors are restored onto the device they were saved from, which fails if
# that device is unavailable and temporarily duplicates the weights on the GPU.
# The loaded object is passed straight to load_state_dict, so it is expected
# to be a plain state dict.
pretrained = torch.load(
    Path("models/vqvae-train-full-feats-lr1e-4-bs64-e300-t512-d512_final_20250509_210032.pth"),
    map_location="cpu",
)
vqvae.load_state_dict(pretrained)
vqvae.to("cuda")
# Inference mode; as the last expression this also displays the module summary.
vqvae.eval()

VQVae(
  (encoder): Encoder(
    (branch1): Sequential(
      (0): Conv1d(202, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): ReLU()
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
        (1): Resnet1D(
          (model): Sequential(
            (0): ResConv1DBlock(
              (norm1): Identity()
              (norm2): Identity()
              (activation1): ReLU()
              (activation2): ReLU()
              (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,))
              (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
            )
            (1): ResConv1DBlock(
              (norm1): Identity()
              (norm2): Identity()
              (activation1): ReLU()
              (activation2): ReLU()
              (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
              (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))

In [6]:
# Evaluation split of the dataset rooted at "dataset"; compute_stats is
# disabled (the recorded output shows precomputed statistics being loaded).
ds = Dataset(
    "dataset",
    split="eval",
    compute_stats=False,
)

Loading precomputed statistics from dataset/stats_all.pkl
Loaded feature-wise statistics successfully
Loaded 650 eval samples


### Prepare dataset item

In [7]:
# Pick the item at index 1 of the eval dataset as the clip to reconstruct.
sample = ds[1]

In [8]:
# Path of the source pickle, as recorded in the sample's metadata.
sample_metadata = sample['metadata']
pickle_path = sample_metadata['pickle_path']

In [13]:
# Number of frames available in the sample, capped at 300 for the model input.
frames = sample['kp'].shape[0]
seq_len = min(frames, 300)

# Clip every feature group to seq_len frames and flatten its per-frame values,
# giving tensors indexed as (1, seq_len, values_per_frame).
kps, exps, x_s, t, R, scale = (
    sample[key][:seq_len].reshape(1, seq_len, -1)
    for key in ('kp', 'exp', 'x_s', 't', 'R', 'scale')
)

# Concatenate along the last axis in the fixed order kp, exp, x_s, t, R, scale.
features = torch.concat((kps, exps, x_s, t, R, scale), dim=2).to("cuda")

### Reconstruct features

In [14]:
# Load stats
# NOTE: pickle executes arbitrary code on load; only use this on trusted,
# locally generated files.
# The context manager closes the file handle once the object is loaded.
with open("dataset/stats_all.pkl", "rb") as stats_file:
    stats = pickle.load(stats_file)

In [15]:
# Move the normalization statistics to the GPU, where the features live.
mean, std = (stats[name].to("cuda") for name in ('mean', 'std'))

In [16]:
# Normalize into a NEW name: rebinding `features` here would make the cell
# non-idempotent (every re-run would normalize already-normalized data).
features_norm = (features - mean) / std

# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    reconstr_norm, commit_loss, perplexity = vqvae(features_norm)

# Undo the normalization so the reconstruction is back in the raw feature scale.
reconstr = reconstr_norm * std + mean

### Generate reconstructed pickle

In [41]:
# Prepare output
def repackage_output(original, reconstr):
    """Split a reconstructed feature tensor into a per-frame motion dictionary.

    Args:
        original: Dataset sample. ``original['metadata']['output_fps']`` is
            copied through, and the tensors in ``original['c_eyes_lst']`` and
            ``original['c_lip_lst']`` are converted to NumPy arrays.
        reconstr: Tensor whose leading batch axis of size 1 is squeezed away,
            leaving one row per frame with at least 202 columns laid out as
            kp [0:63], exp [63:126], x_s [126:189], t [189:192], R [192:201]
            and scale [201].

    Returns:
        dict with keys ``n_frames``, ``output_fps``, ``motion`` (one dict of
        NumPy arrays per frame: kp/exp/x_s as (1, 21, 3), t as (1, 3), R as
        (1, 3, 3), scale as (1, 1)), ``c_eyes_lst`` and ``c_lip_lst``.
    """
    # One device-to-host copy for the whole clip instead of six per frame.
    # detach() keeps this working even if the tensor still tracks gradients.
    flat = reconstr.squeeze(0).detach().cpu().numpy()

    # The reconstruction may hold fewer frames than the source clip (e.g. when
    # the input sequence was truncated), so the frame count and the per-frame
    # side lists are derived from it; otherwise `n_frames` would point past
    # the end of `motion`.
    n_frames = len(flat)

    # (key, first column, one-past-last column, per-frame output shape)
    layout = (
        ("kp", 0, 63, (1, 21, 3)),
        ("exp", 63, 126, (1, 21, 3)),
        ("x_s", 126, 189, (1, 21, 3)),
        ("t", 189, 192, (1, 3)),
        ("R", 192, 201, (1, 3, 3)),
        ("scale", 201, 202, (1, 1)),
    )

    motion = [
        {key: row[start:stop].reshape(shape) for key, start, stop, shape in layout}
        for row in flat
    ]

    return {
        "n_frames": n_frames,
        "output_fps": original['metadata']['output_fps'],
        "motion": motion,
        "c_eyes_lst": [c.cpu().numpy() for c in original['c_eyes_lst'][:n_frames]],
        "c_lip_lst": [c.cpu().numpy() for c in original['c_lip_lst'][:n_frames]],
    }

In [42]:
# Build the pickle-ready dictionary from the dataset sample and the
# denormalized reconstruction.
output = repackage_output(sample, reconstr)

In [43]:
# Name the output after the source pickle: <stem>_reconstructed.pkl in pickle_dir.
video_id = Path(pickle_path).stem
new_path = pickle_dir / (video_id + "_reconstructed.pkl")

# Serialize the repackaged dictionary to disk.
with new_path.open("wb") as f:
    pickle.dump(output, f)

In [23]:
# Display the path the reconstructed pickle is written to.
new_path

PosixPath('dataset/pickles/F4ExnAr-QSE_0_reconstructed.pkl')

In [None]:
import imageio

# NOTE(review): the video is read from dataset/train although the sample comes
# from the "eval" split of the Dataset — confirm this is the intended folder.
# The context manager closes the reader (and its underlying file/decoder)
# as soon as the first frame has been read.
with imageio.get_reader(f"dataset/train/{video_id}.mp4") as vid:
    frame = vid.get_data(0)

# save frame
imageio.imwrite("assets/examples/source/reconstructed.png", frame)

In [20]:
# Load the original pickle of this clip so its layout can be inspected next to
# the reconstructed output.
# NOTE(review): this rebinds `stats`, which earlier held the normalization
# statistics from dataset/stats_all.pkl; re-running the mean/std cell after
# this one would fail. Consider a distinct name once the cells that read it
# are updated together.
with open("dataset/train/F4ExnAr-QSE_0.pkl", "rb") as original_file:
    stats = pickle.load(original_file)

In [40]:
# Shape of the rotation entry of frame 0 in the reconstructed output.
# NOTE(review): repackage_output reshapes R to (1, 3, 3), so a recorded (3, 3)
# most likely comes from a run against an earlier version of the function —
# re-run to confirm.
output['motion'][0]['R'].shape

(3, 3)

In [38]:
# Shape of the same rotation entry in the original pickle currently bound to
# `stats`, as the reference for the reconstructed output's layout.
stats['motion'][0]['R'].shape

(1, 3, 3)