In [37]:
import os
import torch
import json
from io import BytesIO
import numpy as np
import torch.nn as nn
from tqdm import tqdm, trange
from model import EncoderDecoder
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from utils import Dataset
from torch.utils.data.dataloader import DataLoader

In [38]:
device = torch.device('cpu')

In [39]:
data_path = '/Users/reza/Career/DMLab/SURROGATE/Data/trash/test'
result_path = '/Users/reza/Career/DMLab/SURROGATE/results/laplace/3D2D/psi_web_v_first_100_kpo'
cfg_path = os.path.join(result_path, 'cfg.json')
with open(cfg_path, 'r') as f:
    cfg = json.load(f)
subdir_paths = sorted(os.listdir(data_path))
cr_paths = [os.path.join(data_path, p) for p in subdir_paths if p.startswith("cr")]

In [40]:
dataset = Dataset(
    cr_paths=cr_paths,
    instruments=["kpo_mas_mas_std_0101"],
    v_min=cfg["v_min"],
    v_max=cfg["v_max"],
    # rho_min=cfg["rho_min"],
    # rho_max=cfg["rho_max"],
)

In [41]:
model = EncoderDecoder(
    in_channels=1, base_channels=cfg["base_channels"], latent_dim=cfg["latent_dim"]
).to(device)

In [42]:
state_path = os.path.join(result_path, '5.pth')
model.load_state_dict(torch.load(state_path, map_location='cpu'))

  model.load_state_dict(torch.load(state_path, map_location='cpu'))


<All keys matched successfully>

# Validation partition

In [48]:
i = 0
cube = dataset[i]
print(cube.shape)
with torch.no_grad():
    # slicing
    # print(x.shape)
    x = torch.tensor(cube[:, 0, :, :], dtype=torch.float32)
    yhat = model.predict(x.unsqueeze(0), n_slices=cube.shape[1]-1)
    yhat = yhat.squeeze(0)
print(yhat.shape)

(1, 141, 128, 128)


100%|██████████| 140/140 [01:03<00:00,  2.19it/s]

torch.Size([1, 140, 128, 128])





In [49]:
y = cube[:, 1:, :, :]

In [50]:
yhat.shape, y.shape

(torch.Size([1, 140, 128, 128]), (1, 140, 128, 128))

In [51]:
frames = []
for step in trange(yhat.shape[1]):
    val = y[0, step, :, :]
    pred = yhat[0, step, :, :]
    # plt.figure(figsize=(12, 6))
    # plt.subplot(1, 2, 1)
    # plt.imshow(val, cmap="viridis")
    # # plt.contourf(
    # #     np.arange(val.shape[0]),
    # #     np.arange(val.shape[1]),
    # #     val.T,
    # #     cmap="viridis",
    # #     # norm=norm,
    # # )
    # plt.title(f"Actual Slice (k={step+1+starting_slice})")
    # plt.colorbar()
    # # plt.xlabel("I")
    # # plt.ylabel("J")

    # plt.subplot(1, 2, 2)
    # plt.imshow(pred, cmap="viridis")
    # # plt.contourf(
    # #     np.arange(pred.shape[0]),
    # #     np.arange(pred.shape[1]),
    # #     pred.T,
    # #     cmap="viridis",
    # #     # norm=norm,
    # # )
    # plt.title(f"Predicted Slice (k={step+1+starting_slice})")
    # plt.colorbar()
    # # plt.xlabel("I")
    # # plt.ylabel("J")
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    cmap = "viridis"

    # Plot first subplot
    im1 = axes[0].imshow(val, cmap=cmap)
    axes[0].set_title(f"Actual Slice (k={step+1})")

    # Plot second subplot
    im2 = axes[1].imshow(pred, cmap=cmap)
    axes[1].set_title(f"Predicted Slice (k={step+1})")

    # Create a single colorbar
    cbar = fig.colorbar(im1, ax=axes, orientation="vertical", fraction=0.05, pad=0.02)

    # plt.tight_layout()
    # plt.show()

    # plt.show()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    frames.append(imageio.imread(buf))
    plt.close()

100%|██████████| 140/140 [00:16<00:00,  8.62it/s]


In [52]:
output_filename = f"v_2d3d.mp4"
fps = 10  # Adjust frames per second as needed

# Create video directly from frames in memory
with imageio.get_writer(output_filename, fps=fps) as writer:
    for frame in frames:
        writer.append_data(frame)

