In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np

In [None]:
from thesis.data_management import SequenceManager

# Measure how far the cameras spread around their common centre.
sm = SequenceManager(3)
w2c = sm.cameras[1]  # presumably world-to-camera matrices (going by the name) -- TODO confirm
c2w = torch.linalg.inv(w2c)

# First three entries of the last column of every inverted matrix.
cam_pos = c2w[:, :3, 3]
center = torch.mean(cam_pos, dim=0)

# Largest Euclidean distance between a camera position and the mean position.
dist_to_center = torch.linalg.norm(cam_pos - center, dim=1)
diagonal = dist_to_center.max()

# Cell output: that distance padded by 10%.
diagonal * 1.1

In [None]:
from thesis.flame import FlameHeadVanilla
from thesis.render_vertex_video import render_vertex_video

# Render the saved vertex predictions of one sequence as a video with the
# matching audio recording.
flame_head = FlameHeadVanilla()
faces = flame_head.faces
sequence = 86

# map_location='cpu' remaps tensor storages while the file is being
# deserialised, so a file written from a CUDA run also loads on a machine
# without a GPU. Without it the load itself fails there, before `.cpu()` is
# ever reached. `.cpu()` is kept so the resulting object is exactly what the
# previous version of this cell produced.
# NOTE(review): torch.load unpickles the file -- only load trusted files.
a = torch.load(
    f'saved_vertex_preds/sequence_{sequence}_flame.pt',
    map_location='cpu',
).cpu()

# Alternative audio source, kept for quick switching:
# audio_path = '/home/schlack/CodeTalker/demo/wav/man.wav'
# NOTE(review): absolute, machine-specific path -- consider a configurable data root.
audio_path = f'/home/schlack/new_master_thesis/data/nersemble/Paul-audio-856/856/sequences/sequence_{sequence:04d}/audio/audio_recording.ogg'
output_path = 'demo.mp4'

render_vertex_video(
    vertices=a,
    faces=faces,
    audio_path=audio_path,
    output_path=output_path,
)

In [None]:
# Display the shape of `a`. When the notebook is run top to bottom, `a` is the
# object loaded with torch.load in the preceding cell.
a.shape

In [None]:
from thesis.evaluation import evaluate

# Run the evaluation for sequences 80 and 81 on the GPU and keep the three
# returned values.
pred_dir = 'tmp/pred/ablations/7_markov_chain_monte_carlo/flame'
gt, pred, l = evaluate(pred_dir=pred_dir, sequences=[80, 81], device='cuda')

In [None]:
from matplotlib import pyplot as plt

# Move `pred` to host memory and draw it as an image.
pred_image = pred.cpu().numpy()
plt.imshow(pred_image)
plt.show()

In [None]:
from matplotlib import pyplot as plt

# Move `gt` to host memory and draw it as an image.
gt_image = gt.cpu().numpy()
plt.imshow(gt_image)
plt.show()

In [None]:
# Display `pred` as the cell output.
# NOTE(review): this prints the full repr; `pred.shape` or a slice may be
# easier to read if the object is large.
pred

In [None]:
from thesis.evaluation import evaluate

# Evaluate the same prediction directory again, this time passing only the
# directory so that `evaluate` falls back to its defaults for everything else.
# The return value is not stored; it is shown as the cell output.
path = 'tmp/pred/ablations/7_markov_chain_monte_carlo/flame'
evaluate(path)

In [None]:
# load config
from thesis.config import load_config

# Path is relative, so it is resolved against the kernel's current working
# directory.
config_path = "configs/single_frame.yml"
# `config` is read by the later cells that build the model and the dataset.
config = load_config(config_path)

In [None]:
# load single frame model
from thesis.gaussian_splatting.single_frame import GaussianSplattingSingleFrame

# Build the model from the loaded config and move it to the GPU in one step.
model = GaussianSplattingSingleFrame(
    gaussian_splatting_settings=config.gaussian_splatting_settings,
    learning_rates=config.learning_rates,
).cuda()

In [None]:
# Sanity check: hand the splat parameters and their optimizers to the
# densification strategy's own consistency check.
params = model.splats
optimizers, schedulers = model.configure_optimizers()

# Pair every splat optimizer key with the optimizer at the same index.
splat_optimizers = {
    key: optimizers[position]
    for position, key in enumerate(model.splat_optimizer_keys)
}
model.strategy.check_sanity(params, splat_optimizers)

In [None]:
from thesis.data_management import SingleSequenceDataset
from torch.utils.data import DataLoader

# Dataset limited to the index range [config.frame, config.frame + 1) of the
# configured sequence.
# NOTE(review): presumably `end_idx` is exclusive, i.e. exactly one frame -- confirm.
dataset_kwargs = dict(
    sequence=config.sequence,
    start_idx=config.frame,
    end_idx=config.frame + 1,
    n_cameras_per_frame=config.gaussian_splatting_settings.camera_batch_size,
)
train_set = SingleSequenceDataset(**dataset_kwargs)

# batch_size=None switches off automatic batching: the loader yields the
# dataset's items unchanged, in shuffled order.
train_loader = DataLoader(dataset=train_set, batch_size=None, shuffle=True)

In [None]:
# Pull one item from the loader, let the dataset prepare it, and render it.
a = next(iter(train_loader))
batch = train_set.prepare_data(a)

# Arguments for the forward pass, collected in one place.
forward_kwargs = dict(
    intrinsics=batch.intrinsics,
    world_2_cam=batch.world_2_cam,
    cam_2_world=None,
    image_height=int(batch.image.shape[1]),
    image_width=int(batch.image.shape[2]),
    # color_correction=batch.color_correction, # TODO: fix color correction
    cur_sh_degree=None,
    se3_transform=batch.se3_transform,
)
rendered_images, rendered_alphas, infos = model.forward(**forward_kwargs)

In [None]:
from matplotlib import pyplot as plt
from dreifus.render import project
from dreifus.matrix import Pose, Intrinsics

# Project the splat means into the first camera of the batch and paint a
# marker for each one onto the first rendered image, to check visually that
# the splats land inside the view.
intrinsics = Intrinsics(batch.intrinsics[0].detach().cpu().numpy())
pose = Pose(batch.world_2_cam[0].detach().cpu().numpy())
projected = project(model.splats['means'].detach().cpu().numpy(), pose, intrinsics)

# Copy before drawing so the markers can never be written into memory shared
# with the rendered tensor; re-running the cell then always starts clean.
# The indexing below treats `image` as (height, width, 3).
image = rendered_images[0].detach().cpu().numpy().copy()
image_height, image_width = image.shape[0], image.shape[1]

# Round to the nearest pixel and keep only points that fall inside the image.
projected_x = projected[:, 0].round().astype(int)
projected_y = projected[:, 1].round().astype(int)
valid_x = (0 <= projected_x) & (projected_x < image_width)
valid_y = (0 <= projected_y) & (projected_y < image_height)
valid_xy = valid_x & valid_y
print(f"{valid_xy.sum() / projected.shape[0] * 100:.1f}% of the vertices are visible"
      f" i.e. {valid_xy.sum()} out of {projected.shape[0]}\n")

# Paint a square marker centred on every visible point. The offsets run from
# -marker_radius to +marker_radius inclusive, so the marker is symmetric
# around the projected pixel (the earlier range(-3, 3) stopped at +2 and was
# shifted by one pixel). Each offset is applied to all points at once instead
# of looping over the points in Python.
marker_radius = 3
marker_offsets = range(-marker_radius, marker_radius + 1)
visible_y = projected_y[valid_xy]
visible_x = projected_x[valid_xy]
for i in marker_offsets:
    for j in marker_offsets:
        marker_y = visible_y + i
        marker_x = visible_x + j
        # Clip the part of the marker that would fall outside the image.
        inside = (
            (0 <= marker_y) & (marker_y < image_height)
            & (0 <= marker_x) & (marker_x < image_width)
        )
        # NOTE(review): 255 assumes an 8-bit colour scale; if the render is
        # float in [0, 1], imshow clips it to white -- confirm.
        image[marker_y[inside], marker_x[inside]] = [255, 255, 255]

plt.imshow(image)
plt.show()

In [None]:
# Look up one entry of the tracked FLAME translation array.
# NOTE(review): absolute, machine-specific path -- consider a configurable data root.
path = '/home/schlack/master-thesis/data/Paul-audio-85/085/sequences/SEN-05-glow_eyes_sweet_girl/annotations/tracking/FLAME2023_v2/tracked_flame_params.npz'
data = np.load(path)

# NOTE(review): 72 // 3 looks like frame 72 mapped onto a 3x subsampled
# tracking timeline -- confirm.
tracked_idx = 72 // 3
data['translation'][tracked_idx]