# Imports

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from PIL import Image
from torch import IntTensor

In [None]:
# Turn autograd off globally: nothing in this notebook calls backward, so no
# computation graphs need to be recorded for the tensors built below.
torch.set_grad_enabled(False)

# Training figure: forward noising & objective

In [None]:
# Dataset root (absolute, machine-specific path — edit when running elsewhere).
# The sampling cell reads PNG patches from its "1" and "5" sub-folders.
base_data_path = Path(
    "/projects/static2dynamic/datasets/biotine/3_channels_min_99_perc_normalized_rgb_stacks_png/patches_255"
)

In [None]:
# Pick 3 patches from time point "1" and 2 from time point "5" for the figure.
# The glob listing is sorted (its order is filesystem-dependent) and a
# dedicated seeded RNG is used, so re-running the notebook selects the same
# patches and regenerates the same figure.
rng = random.Random(0)
time_1 = rng.sample(sorted((base_data_path / "1").glob("*.png")), k=3)
time_5 = rng.sample(sorted((base_data_path / "5").glob("*.png")), k=2)
samples = time_1 + time_5
print([s.name for s in samples])

In [None]:
# Load the sampled PNG patches into one uint8 batch.
# The `with` block closes each file handle; np.array(img) reads the pixel
# data while the file is still open.
batch = []
for s in samples:
    with Image.open(s) as img:
        batch.append(torch.tensor(np.array(img)))

# Assumes every patch decodes to an (H, W, C) array of identical shape
# (TODO confirm for non-RGB PNGs) — torch.stack / permute fail otherwise.
clean = torch.stack(batch).permute(0, 3, 1, 2)  # (N, C, H, W)
print(clean.shape, clean.dtype, clean.min().item(), clean.max().item())

# Display the patches already in memory instead of re-reading them from disk.
# squeeze=False keeps `axes` 2-D, so this also works with a single sample.
fig, axes = plt.subplots(1, len(samples), squeeze=False)
for ax, pixels in zip(axes[0], batch):
    ax.imshow(pixels.numpy())
    ax.axis("off")
plt.show()

In [None]:
# Map uint8 pixels in [0, 255] to float32 values in [-1, 1].
scaled_clean = ((clean / 255 - 0.5) * 2).clamp(-1, 1).to(torch.float32)
# Standard-Gaussian noise with the same shape as the scaled batch.
noise = torch.randn_like(scaled_clean)

# Noise shown with a per-image min-max rescale to [0, 1].
# squeeze=False keeps `axes` 2-D even for a single image.
fig, axes = plt.subplots(1, len(noise), squeeze=False)
for ax, img in zip(axes[0], noise):
    img = (img - img.min()) / (img.max() - img.min())
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.axis("off")
plt.show()

# Same noise, clipped to [-1, 1] and mapped to [0, 1] (inverse of the scaling above).
fig, axes = plt.subplots(1, len(noise), squeeze=False)
for ax, img in zip(axes[0], noise):
    img = img.clamp(-1, 1) / 2 + 0.5
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.axis("off")
plt.show()


scheduler = DDIMScheduler()

# Forward-noise the scaled batch at diffusion timestep 300. A single timestep
# is passed; presumably the scheduler broadcasts it over the batch — verify
# against the diffusers version in use.
noised = scheduler.add_noise(scaled_clean, noise, IntTensor([300]))

# Noised images mapped back to [0, 255] uint8 for display.
fig, axes = plt.subplots(1, len(noised), squeeze=False)
for ax, img in zip(axes[0], noised):
    img = (img.clamp(-1, 1) / 2 + 0.5) * 255
    ax.imshow(img.permute(1, 2, 0).to(torch.uint8).numpy())
    ax.axis("off")
plt.show()

In [None]:
# Save each noise tensor as an 8-bit PNG (clipped to [-1, 1], mapped to [0, 255]).
# PIL's save does not create missing directories, so ensure the output folder exists.
out_dir = Path("misc_figures")
out_dir.mkdir(parents=True, exist_ok=True)
for i, img in enumerate(noise):
    img = img.clamp(-1, 1) / 2 + 0.5
    img = img.permute(1, 2, 0).numpy()
    Image.fromarray((img * 255).astype(np.uint8)).save(out_dir / f"noise_{i}.png")

# Video time embeddings visualization

In [None]:
import sys

import seaborn as sns
import umap

# Make the parent directory importable so the local `GaussianProxy` package
# resolves. ".." is relative to the kernel's current working directory
# (presumably this notebook's folder — verify). This must run before the
# import below.
# NOTE(review): re-running this cell appends ".." to sys.path again.
sys.path.append("..")
from GaussianProxy.utils.models import VideoTimeEncoding

# ggplot style applies to every figure created after this point.
plt.style.use("ggplot")

In [None]:
# Restore the trained video-time encoder from its saved checkpoint folder
# (machine-specific absolute path) and echo it as the cell output.
video_time_encoder_path = Path(
    "/projects/static2dynamic/Thomas/experiments/GaussianProxy/biotine_all_paired_new_jz_MANUAL_WEIGHTS_DOWNLOAD_FROM_JZ_11-02-2025_14h31/saved_model/video_time_encoder"
)
video_time_encoder = VideoTimeEncoding.from_pretrained(video_time_encoder_path)
video_time_encoder

In [None]:
# Evenly spaced video times over [0, 1) with step 1e-3, encoded in a single
# forward pass; both shapes are printed for a quick sanity check.
vid_times = torch.arange(start=0, end=1, step=1e-3)
print(vid_times.shape)
vid_times_encs = video_time_encoder(vid_times)
print(vid_times_encs.shape)

In [None]:
# Echo the raw encodings (one row per entry of `vid_times`) as the cell output.
vid_times_encs

In [None]:
import pandas as pd

In [None]:
# Pick up to 10 distinct embedding dimensions for the pairplot. Sampling is
# WITHOUT replacement: the default replace=True could pick the same
# dimension twice, producing duplicate columns and redundant panels.
# Columns are labelled with the chosen dimension index (sorted for readability).
n_dims = vid_times_encs.shape[1]
chosen_dims = np.sort(np.random.choice(n_dims, size=min(10, n_dims), replace=False))
df = pd.DataFrame(vid_times_encs.numpy()[:, chosen_dims], columns=chosen_dims)
df

In [None]:
%matplotlib inline

In [None]:
# Pairwise scatter plots of the sampled dimensions; corner=True skips the
# redundant upper triangle of the grid.
sns.pairplot(data=df, corner=True)
plt.show()

## UMAP

In [None]:
# 2-D UMAP of the encodings (n_components=2 is UMAP's default, spelled out here).
# The encodings are fed as-is, without feature scaling.
reducer = umap.UMAP(n_components=2)
embedding = reducer.fit_transform(vid_times_encs)
print(embedding.shape)

In [None]:
# Scatter the 2-D UMAP embedding, each point coloured by its video time.
# `cmap` is defined here and reused by the later scatter cells.
cmap = sns.color_palette("magma", as_cmap=True)
fig, ax = plt.subplots(figsize=(8, 8))
sc = ax.scatter(embedding[:, 0], embedding[:, 1], c=vid_times.numpy(), cmap=cmap)
ax.set_aspect("equal", "datalim")
ax.set_title("UMAP projection of learned video time embeddings")
cbar = fig.colorbar(sc, ax=ax, label="video time", fraction=0.046, pad=0.04)
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1])
plt.show()

In [None]:
# 3-D UMAP of the same (unscaled) encodings; overwrites the 2-D `embedding`.
reducer = umap.UMAP(n_components=3)
embedding = reducer.fit_transform(vid_times_encs)
print(embedding.shape)

In [None]:
%matplotlib widget

In [None]:
# Interactive 3-D scatter of the UMAP embedding, coloured by video time.
ax = plt.figure(figsize=(10, 10)).add_subplot(111, projection="3d")
sc = ax.scatter(
    embedding[:, 0], embedding[:, 1], embedding[:, 2], c=vid_times.numpy(), cmap=cmap
)
ax.set_aspect("equal", "datalim")
ax.set_title("UMAP projection of learned video time embeddings")
cbar = plt.colorbar(sc, ax=ax, label="video time", fraction=0.046, pad=0.04)
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1])
plt.show()

In [None]:
%matplotlib inline

## PCA

In [None]:
from sklearn.decomposition import PCA

In [None]:
# PCA keeping every component (n_components=None is the default, spelled out);
# `reduced` holds the projections onto all principal axes.
pca = PCA(n_components=None)
reduced = pca.fit_transform(vid_times_encs)
reduced.shape

In [None]:
# Scatter on the first two principal components, coloured by video time;
# each axis label reports that component's share of explained variance.
fig, ax = plt.subplots(figsize=(8, 8))
sc = ax.scatter(reduced[:, 0], reduced[:, 1], c=vid_times.numpy(), cmap=cmap)
ax.set_aspect("equal", "datalim")
ax.set_title("PCA projection of learned video time embeddings")
cbar = fig.colorbar(sc, ax=ax, label="video time", fraction=0.046, pad=0.04)
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1])
variance_ratios = pca.explained_variance_ratio_
ax.set_xlabel(f"PCA 1 ({variance_ratios[0] * 100:.0f}% of variance)")
ax.set_ylabel(f"PCA 2 ({variance_ratios[1] * 100:.0f}% of variance)")
plt.show()

In [None]:
%matplotlib widget

In [None]:
# Interactive 3-D scatter on the first three principal components, coloured by
# video time; every axis label reports its explained-variance share.
ax = plt.figure(figsize=(10, 10)).add_subplot(111, projection="3d")
sc = ax.scatter(
    reduced[:, 0], reduced[:, 1], reduced[:, 2], c=vid_times.numpy(), cmap=cmap
)
ax.set_aspect("equal", "datalim")
ax.set_title("PCA projection of learned video time embeddings")
cbar = plt.colorbar(sc, ax=ax, label="video time", fraction=0.046, pad=0.04)
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1])
variance_ratios = pca.explained_variance_ratio_
ax.set_xlabel(f"PCA 1 ({variance_ratios[0] * 100:.0f}% of variance)")
ax.set_ylabel(f"PCA 2 ({variance_ratios[1] * 100:.0f}% of variance)")
ax.set_zlabel(f"PCA 3 ({variance_ratios[2] * 100:.0f}% of variance)")
plt.show()

In [None]:
%matplotlib inline