In [5]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import imageio
import numpy as np
import cv2

# Load one saved simulation file; weights_only=True restricts what torch.load will unpickle.
# NOTE(review): later cells read data["X"] and data["m"]; the file is assumed to provide both.
data = torch.load(
    r"./data/cloud/n_512_dt_0.1_F_512/000000.pt",
    weights_only=True,
)

In [7]:
# Histogram grid size in cells (points_to_histograms returns grids of shape (HEIGHT, WIDTH)).
WIDTH, HEIGHT = 512, 512

In [21]:
def points_to_histograms(X, weight, batch_size=4, width=None, height=None, progress=True):
    """
    Bin weighted points into one 2-D histogram per frame, processing frames in batches.

    Each point's coordinates are truncated to integers, and its weight is added to
    ``hist[f, int(X[f, j, 0]), int(X[f, j, 1])]``. That is, ``X[..., 0]`` selects the
    row (first grid axis, size ``height``) and ``X[..., 1]`` selects the column
    (second grid axis, size ``width``). Points outside the grid contribute nothing.

    arguments:
        X:          (F, n, 2) tensor of positions across F frames
        weight:     (F, n) tensor of weights
        batch_size: number of frames to process at once (bounds peak memory);
                    F does not need to be a multiple of it
        width:      number of grid columns; defaults to the global WIDTH
        height:     number of grid rows; defaults to the global HEIGHT
        progress:   show a tqdm progress bar over the batches

    returns:
        (F, height, width) tensor on X.device. Its dtype is weight.dtype promoted
        with the default float dtype (e.g. float32 weights -> float32).

    raises:
        AssertionError if weight is not of shape (F, n) or batch_size < 1.
    """
    if width is None:
        width = WIDTH
    if height is None:
        height = HEIGHT

    F, n, _ = X.shape
    assert tuple(weight.shape) == (F, n), "Weights must be of shape (F, n)"
    assert batch_size >= 1, "batch_size must be a positive integer"

    # scatter_add_ requires src to share the destination's dtype and device, so pick one
    # accumulation dtype up front (integer/half weights are promoted to a float type).
    acc_dtype = torch.promote_types(weight.dtype, torch.get_default_dtype())
    weight = weight.to(device=X.device, dtype=acc_dtype)

    result = torch.zeros(F, height * width, dtype=acc_dtype, device=X.device)

    starts = range(0, F, batch_size)
    if progress:
        starts = tqdm(starts)

    for start in starts:
        end = min(start + batch_size, F)  # the last batch may be shorter
        rows = X[start:end, :, 0]         # (f, n)
        cols = X[start:end, :, 1]         # (f, n)

        # Bounds must match how the flat index is built: rows span `height`, cols span `width`.
        inside = (0 <= rows) & (rows < height) & (0 <= cols) & (cols < width)

        flat = rows.long() * width + cols.long()
        # Out-of-grid points get a dummy in-range index and exactly zero weight.
        # masked_fill (rather than multiplying by the mask) keeps inf/NaN weights out.
        flat = flat.masked_fill(~inside, 0)
        src = weight[start:end].masked_fill(~inside, 0)

        # result[start:end] is a view of whole rows, so this accumulates into result.
        result[start:end].scatter_add_(1, flat, src)

    return result.reshape(F, height, width)

In [22]:
# Positions of every point in every frame; the 3-way unpack also insists X is 3-D.
X = data["X"]
F, n, _ = X.shape

# Per-point weights, repeated for each of the F frames as a broadcast view (no copy).
m = torch.broadcast_to(data["m"].unsqueeze(0), (F, n))
print(X.shape, m.shape)

torch.Size([512, 512, 2]) torch.Size([512, 512])


In [27]:
# One (HEIGHT, WIDTH) histogram per frame, built 64 frames at a time to bound memory.
hist = points_to_histograms(X, weight=m, batch_size=64)

100%|██████████| 8/8 [00:00<00:00, 119.51it/s]


In [28]:
# Axis extents of each rendered frame, in histogram-cell units
# (columns run along x, rows along y).
x_min, x_max = 0, WIDTH
y_min, y_max = 0, HEIGHT
DPI = 100  # figsize * DPI = output pixels, so one histogram cell maps to one pixel


def _figure_to_rgb(fig):
    """Draw `fig` and return its canvas as a new (height, width, 3) uint8 RGB array."""
    fig.canvas.draw()
    # np.asarray keeps the buffer's native (height, width, 4) layout. get_width_height()
    # returns (width, height) and would transpose non-square canvases.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # cvtColor allocates a new array, so the frame does not alias the canvas buffer.
    return cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)


frames = []
fig, ax = plt.subplots(figsize=(WIDTH / DPI, HEIGHT / DPI), dpi=DPI)
fig.patch.set_facecolor("black")
# Let the axes fill the whole canvas so each frame contains only the image.
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

n_frames = hist.shape[0]
for i in tqdm(range(n_frames), ncols=80):
    ax.clear()
    # Explicit limits with y increasing upward: histogram row 0 is drawn at the bottom.
    # NOTE(review): rows of hist come from X[..., 0], so X[..., 0] runs vertically in the
    # video; confirm this orientation is intended.
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    # imshow scales the colormap to each frame's own min/max; pass vmin/vmax for a fixed scale.
    ax.imshow(hist[i].cpu().numpy(), cmap="Greys_r")
    ax.axis("off")
    frames.append(_figure_to_rgb(fig))

plt.close(fig)

print("Finished rendering, saving to MP4...")

# Write the frames as an MP4 video at 30 fps.
imageio.mimsave("./test.mp4", frames, fps=30)

100%|█████████████████████████████████████████| 512/512 [00:14<00:00, 34.26it/s]


Finished rendering, saving to MP4...
