Create images & videos out of saved generation data

# Imports

In [1]:
import multiprocessing
import sys
import tempfile
from pathlib import Path

import imageio
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from IPython.display import HTML
from matplotlib.gridspec import GridSpec
from moviepy.editor import ImageSequenceClip
from torch import Tensor

sys.path.append(str(Path("..").resolve()))

from utils.misc import _normalize_elements_for_logging

In [2]:
# Disable gradient tracking globally: this notebook only loads and visualizes saved tensors.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f65c75c1f10>

# Experiment

In [3]:
# Root folder under which projects and runs are looked up.
# NOTE(review): hard-coded absolute path; adjust when running on another machine.
root_experiments_path = Path("/", "projects", "static2dynamic", "Thomas", "experiments")

In [6]:
# Project and run to visualize; together they select <root>/<project_name>/<run_name>.
project_name = "REMOVEME_tests"
run_name = "iter_inv_regen_test"

In [7]:
# Locate the run inside the experiments tree and stop early if a folder is missing.
project_path = root_experiments_path.joinpath(project_name)
assert project_path.exists(), f"Project path {project_path} does not exist."

run_path = project_path.joinpath(run_name)
assert run_path.exists(), f"Run path {run_path} does not exist."

# Files written by the first part of this notebook go under a folder relative to the working directory.
output_dir = Path(".", "artifacts_viz_outputs", project_name, run_name)
output_dir.mkdir(parents=True, exist_ok=True)
print("output dir:", output_dir)

output dir: artifacts_viz_outputs/REMOVEME_tests/iter_inv_regen_test


# Show video of trajectories

In [11]:
vids_path = run_path / "saved_artifacts" / "trajectories"
assert vids_path.exists(), f"Vids path {vids_path} does not exist."


def _traj_sort_key(path: Path) -> tuple[int, str]:
    """Sort key ordering trajectory files by the integer that follows `step` in their name.

    Assumes names such as `step_150_proc_0.pt` (as printed by this cell) — TODO confirm
    against the code that saves them. Files that do not match sort first (step = -1),
    with ties broken by file name.
    """
    parts = path.stem.split("_")
    try:
        step = int(parts[parts.index("step") + 1])
    except (ValueError, IndexError):
        step = -1
    return step, path.name


# `rglob` yields files in an arbitrary, filesystem-dependent order, so sort explicitly:
# otherwise `all_trajs[-1]` is not guaranteed to be the latest step. A numeric key is
# used because a plain string sort would place `step_150` before `step_50`.
all_trajs = sorted((x for x in vids_path.rglob("*.pt") if "proc_0" in x.name), key=_traj_sort_key)
assert all_trajs, f"No trajectory file containing 'proc_0' found under {vids_path}."

print(f"Found {len(all_trajs)} trajectories generated by process 0.")
last_traj_path = all_trajs[-1]
print(f"Last trajectory: {last_traj_path.name}")

Found 5 trajectories generated by process 0.
Last trajectory: step_150_proc_0.pt


In [13]:
# Load the selected trajectory onto the CPU, whatever device it was saved from.
# NOTE(review): torch.load unpickles the file — only load trajectories from trusted runs.
last_traj = torch.load(last_traj_path, map_location="cpu")
print(f"Loaded trajectory with shape {last_traj.shape}, dtype {last_traj.dtype} on cpu")

Loaded trajectory with shape torch.Size([16, 5, 3, 128, 128]), dtype torch.float32 on cpu


In [14]:
# Normalize the loaded trajectories with each of the listed methods. The result is iterated
# with `.items()` further down, i.e. it maps each method name to the normalized data.
# NOTE(review): `nomrd_vids` looks like a typo for `normd_vids`; a rename must also update its later use.
nomrd_vids = _normalize_elements_for_logging(
    last_traj,
    ["image min-max", "video min-max", "[-1;1] raw", "[-1;1] clipped"],
)

In [15]:
gen_traj_index = 1  # index along the first axis of the normalized data: which sample (= one video) to visualize

In [17]:
# Write one MP4 per normalization method for the selected sample and embed it in the notebook.
for norm_method, vid in nomrd_vids.items():
    print(
        f"Norm method: {norm_method} | saved data shape: {vid.shape} | selected video sample shape: {vid[gen_traj_index].shape}"
    )
    # Move axis 1 (presumably the channels — TODO confirm) last before handing the frames to imageio
    frames_last_axis = vid[gen_traj_index].transpose(0, 2, 3, 1)
    video_save_path = output_dir / f"{norm_method}.mp4"
    imageio.mimwrite(video_save_path, frames_last_axis, fps=1)

    # Show the saved file inline at a fixed width
    video_tag = f'<video width="300" controls><source src="{video_save_path}" type="video/mp4"></video>'
    display(HTML(video_tag))

Norm method: image min-max | saved data shape: (16, 5, 3, 128, 128) | selected video sample shape: (5, 3, 128, 128)


Norm method: video min-max | saved data shape: (16, 5, 3, 128, 128) | selected video sample shape: (5, 3, 128, 128)


Norm method: [-1;1] raw | saved data shape: (16, 5, 3, 128, 128) | selected video sample shape: (5, 3, 128, 128)


Norm method: [-1;1] clipped | saved data shape: (16, 5, 3, 128, 128) | selected video sample shape: (5, 3, 128, 128)


# Video grid

In [7]:
def tensor_transformation(t: Tensor, norm: str | None, img_size: int) -> Tensor:
    assert t.ndim == 4, f"Expected 4D tensor, got {t.ndim}D tensor"
    assert t.shape[2] == t.shape[3], f"Expected square image, got {t.shape}"
    match norm:
        case "clip[-1;1]":
            # clip to [-1:1] and move to [0;1]
            t = t.clip(-1, 1)
            t /= 2
            t += 0.5
        case "min-max":
            # min-max normalize to move to [0;1]
            t -= t.amin(dim=(1, 2, 3), keepdim=True)
            t /= t.amax(dim=(1, 2, 3), keepdim=True)
        case "min-max_across_times":
            # min-max normalize and move to [0;1]
            t -= t.min()
            t /= t.max()
        case "min-95perc_across_times":
            # min-max normalize to move to [0;1]
            t -= t.min()
            t = t.clip(0, t.to(torch.float32).quantile(0.95).item())
            t /= t.max()
        case None | "None":
            # do nothing and let Matplotlib *clip* to [0; 1]
            pass
        case "old log":
            t = (t.cpu().numpy() * 255).astype(np.uint8)
        case _:
            raise ValueError
    # resize to img_size
    if t.shape[2] != img_size:
        print(f"Resizing from {t.shape[2]} to {img_size}")
        t = F.interpolate(t, size=(img_size, img_size), mode="bilinear", align_corners=False)
    # order dims for matplotlib
    t = t.permute(0, 2, 3, 1)
    # convert to fp32 for matplotlib
    t = t.to(torch.float32)
    return t


def process_index(tensor: Tensor, idx: int, norm_method: str | None, img_size: int) -> Tensor:
    assert tensor.ndim == 5, f"Expected 5D tensor, got {tensor.ndim}D tensor"
    img = tensor_transformation(tensor[idx, ...], norm_method, img_size)
    return img

In [8]:
def save_frame(t: int, images: list[Tensor], temp_dir: str, nrows: int, ncols: int, tick_values: list[float]):
    """Render time step `t` of a grid of videos, with a progress bar, and save it as a PNG.

    Args:
        t: index of the frame to render.
        images: one entry per grid cell, in row-major order; `images[i][t]` is passed to
            `imshow` and `len(images[0])` is taken as the number of frames.
        temp_dir: directory in which `frame_{t}.png` is written.
        nrows: number of video rows in the grid.
        ncols: number of video columns in the grid.
        tick_values: x positions of the ticks drawn on the progress bar; they are
            labeled 1..len(tick_values).
    """
    n_frames = len(images[0])
    # figsize is (width, height): width follows the columns, height the rows plus the progress bar
    fig = plt.figure(layout="constrained", figsize=(ncols * 2, nrows * 2 + 0.2))
    fig.patch.set_facecolor("black")
    title = fig.suptitle(f"Time: {t}/{n_frames - 1}")
    title.set_color("white")
    height_ratios = [1] * nrows + [0.2]  # 0.2 for the progress bar row
    gs = GridSpec(nrows + 1, ncols, figure=fig, height_ratios=height_ratios)

    # Fraction of the video already played; a single-frame video is shown as complete
    # instead of dividing by zero.
    progress = t / (n_frames - 1) if n_frames > 1 else 1.0

    # Create a progress bar artist
    progress_bar = patches.Rectangle((0, 0.45), progress, 0.1, facecolor="white")
    progress_ax = fig.add_subplot(gs[nrows, :])  # Add progress bar to the bottom row
    progress_ax.add_patch(progress_bar)
    progress_ax.set_xticks(tick_values)
    progress_ax.get_yaxis().set_visible(False)
    progress_ax.xaxis.set_tick_params(width=2, color="white")
    progress_ax.set_facecolor("black")
    progress_ax.set_xticklabels([str(idx + 1) for idx in range(len(tick_values))], color="white")

    for row in range(nrows):
        for col in range(ncols):
            ax = fig.add_subplot(gs[row, col])
            # row-major flat index: the stride between rows is the number of columns
            ax.imshow(images[row * ncols + col][t])
            ax.axis("off")

    # Save the frame as an image, addressing the figure explicitly rather than pyplot's current figure
    fig.savefig(f"{temp_dir}/frame_{t}.png")
    plt.close(fig)


def create_image_sequence_video(
    tensor_path: Path,
    norm_method: str | None,
    tick_values: list[float],
    sample_idx: list[int] | None = None,
    img_size: int = 128,
    nrows: int = 4,
    ncols: int = 4,
    save_dir: Path | None = None,
):
    """Build an MP4 showing a `nrows` x `ncols` grid of videos taken from a saved tensor.

    Args:
        tensor_path: path of the saved tensor; it must be 5D (checked by `process_index`),
            with the samples along dim 0.
        norm_method: normalization name forwarded to `tensor_transformation`.
        tick_values: tick positions of the progress bar, forwarded to `save_frame`.
        sample_idx: indices (along dim 0) of the samples to show, in row-major grid order;
            defaults to the first `nrows * ncols` samples. Entries beyond the grid size are ignored.
        img_size: side length the frames are resized to.
        nrows: number of grid rows.
        ncols: number of grid columns.
        save_dir: output folder. Defaults to
            `make_traj_vids_saved_plots/<project_name>/<run_name>`, read from the notebook-level
            `project_name` and `run_name` variables.

    Raises:
        ValueError: if there are fewer selected samples than grid cells, or if a selected
            index is out of range for the loaded tensor.
    """
    # Load tensor straight onto the CPU so tensors saved from a GPU also load on CPU-only machines.
    # NOTE(review): torch.load unpickles the file — only load tensors from trusted runs.
    tensor = torch.load(tensor_path, map_location="cpu")
    print(f"Loaded tensor of shape {tensor.shape}")

    print(f"Using norm method: {norm_method}, image_size: {img_size}")

    sel_samples = sample_idx if sample_idx is not None else list(range(nrows * ncols))

    # Fail early with a clear message rather than with an IndexError raised inside a worker.
    n_cells = nrows * ncols
    if len(sel_samples) < n_cells:
        raise ValueError(f"Need at least {n_cells} samples to fill a {nrows}x{ncols} grid, got {len(sel_samples)}")
    n_samples = tensor.shape[0]
    out_of_range = [idx for idx in sel_samples if not -n_samples <= idx < n_samples]
    if out_of_range:
        raise ValueError(f"Sample indices {out_of_range} are out of range for a tensor with {n_samples} samples")

    with multiprocessing.Pool() as pool:
        images = pool.starmap(process_index, [(tensor, idx, norm_method, img_size) for idx in sel_samples])

    print(f"Processed images; got: len(images): {len(images)}; images[0].shape: {images[0].shape}")
    n_frames = len(images[0])

    # Create a temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        with multiprocessing.Pool() as pool:
            pool.starmap(save_frame, [(t, images, temp_dir, nrows, ncols, tick_values) for t in range(n_frames)])

        # Compile the frames into a video
        frames = [f"{temp_dir}/frame_{i}.png" for i in range(n_frames)]
        print(f"Saved {len(frames)} frames: {frames}")
        # Check the files themselves: the list above is built from names only.
        missing = [frame for frame in frames if not Path(frame).exists()]
        assert not missing, f"Expected {n_frames} frames, missing: {missing}"

        if save_dir is None:
            save_dir = Path("make_traj_vids_saved_plots", project_name, run_name)
        save_path = Path(save_dir, f"{tensor_path.stem}_{norm_method}_norm.mp4")
        save_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"Saving videos to {save_path}...")
        # one tenth of the frame count, at least 1 fps
        fps = max(n_frames // 10, 1)
        clip = ImageSequenceClip(frames, fps=fps)
        clip.write_videofile(save_path.as_posix(), threads=4)

    print(f"Saved videos to {save_path}")

In [9]:
# x positions of the ticks drawn on the progress bar under the video grid;
# `save_frame` labels them 1..len(ticks_pos).
ticks_pos = [0.0, 0.2531645596027374, 0.5063291192054749, 0.7468354105949402, 1.0]

In [10]:
# Each frame is min-max normalized on its own.
create_image_sequence_video(last_traj_path, "min-max", ticks_pos)

Loaded tensor of shape torch.Size([16, 10, 3, 128, 128])
Using norm method: min-max, image_size: 128
Processed images; got: len(images): 16; images[0].shape: torch.Size([10, 128, 128, 3])
Saved 10 frames: ['/tmp/tmp4leanf4y/frame_0.png', '/tmp/tmp4leanf4y/frame_1.png', '/tmp/tmp4leanf4y/frame_2.png', '/tmp/tmp4leanf4y/frame_3.png', '/tmp/tmp4leanf4y/frame_4.png', '/tmp/tmp4leanf4y/frame_5.png', '/tmp/tmp4leanf4y/frame_6.png', '/tmp/tmp4leanf4y/frame_7.png', '/tmp/tmp4leanf4y/frame_8.png', '/tmp/tmp4leanf4y/frame_9.png']
Saving videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_norm.mp4...
Moviepy - Building video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_norm.mp4.
Moviepy - Writing video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_norm.mp4



                                                           

Moviepy - Done !
Moviepy - video ready make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_norm.mp4
Saved videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_norm.mp4


In [11]:
# No normalization: raw values are passed to Matplotlib, which clips them to [0, 1].
create_image_sequence_video(last_traj_path, "None", ticks_pos)

Loaded tensor of shape torch.Size([16, 10, 3, 128, 128])
Using norm method: None, image_size: 128
Processed images; got: len(images): 16; images[0].shape: torch.Size([10, 128, 128, 3])


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

Saved 10 frames: ['/tmp/tmphc16kare/frame_0.png', '/tmp/tmphc16kare/frame_1.png', '/tmp/tmphc16kare/frame_2.png', '/tmp/tmphc16kare/frame_3.png', '/tmp/tmphc16kare/frame_4.png', '/tmp/tmphc16kare/frame_5.png', '/tmp/tmphc16kare/frame_6.png', '/tmp/tmphc16kare/frame_7.png', '/tmp/tmphc16kare/frame_8.png', '/tmp/tmphc16kare/frame_9.png']
Saving videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_None_norm.mp4...
Moviepy - Building video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_None_norm.mp4.
Moviepy - Writing video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_None_norm.mp4



                                                            

Moviepy - Done !
Moviepy - video ready make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_None_norm.mp4
Saved videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_None_norm.mp4


In [12]:
# One min/max pair per sample, shared by all of its frames.
create_image_sequence_video(last_traj_path, "min-max_across_times", ticks_pos)

Loaded tensor of shape torch.Size([16, 10, 3, 128, 128])
Using norm method: min-max_across_times, image_size: 128
Processed images; got: len(images): 16; images[0].shape: torch.Size([10, 128, 128, 3])
Saved 10 frames: ['/tmp/tmp9uh3wyt0/frame_0.png', '/tmp/tmp9uh3wyt0/frame_1.png', '/tmp/tmp9uh3wyt0/frame_2.png', '/tmp/tmp9uh3wyt0/frame_3.png', '/tmp/tmp9uh3wyt0/frame_4.png', '/tmp/tmp9uh3wyt0/frame_5.png', '/tmp/tmp9uh3wyt0/frame_6.png', '/tmp/tmp9uh3wyt0/frame_7.png', '/tmp/tmp9uh3wyt0/frame_8.png', '/tmp/tmp9uh3wyt0/frame_9.png']
Saving videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_across_times_norm.mp4...
Moviepy - Building video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_across_times_norm.mp4.
Moviepy - Writing video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_across_times_norm.mp4



                                                            

Moviepy - Done !
Moviepy - video ready make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_across_times_norm.mp4
Saved videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_min-max_across_times_norm.mp4


In [13]:
# Values are clipped to [-1, 1], then mapped linearly to [0, 1].
create_image_sequence_video(last_traj_path, "clip[-1;1]", ticks_pos)

Loaded tensor of shape torch.Size([16, 10, 3, 128, 128])
Using norm method: clip[-1;1], image_size: 128
Processed images; got: len(images): 16; images[0].shape: torch.Size([10, 128, 128, 3])
Saved 10 frames: ['/tmp/tmpbg8fv47o/frame_0.png', '/tmp/tmpbg8fv47o/frame_1.png', '/tmp/tmpbg8fv47o/frame_2.png', '/tmp/tmpbg8fv47o/frame_3.png', '/tmp/tmpbg8fv47o/frame_4.png', '/tmp/tmpbg8fv47o/frame_5.png', '/tmp/tmpbg8fv47o/frame_6.png', '/tmp/tmpbg8fv47o/frame_7.png', '/tmp/tmpbg8fv47o/frame_8.png', '/tmp/tmpbg8fv47o/frame_9.png']
Saving videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_clip[-1;1]_norm.mp4...
Moviepy - Building video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_clip[-1;1]_norm.mp4.
Moviepy - Writing video make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_clip[-1;1]_norm.mp4



                                                            

Moviepy - Done !
Moviepy - video ready make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_clip[-1;1]_norm.mp4
Saved videos to make_traj_vids_saved_plots/GaussianProxy/xattn_IBENS_test/step_56400_clip[-1;1]_norm.mp4
