In [1]:
import os
import torch
import json
from io import BytesIO
import numpy as np
import torch.nn as nn
import imageio.v2 as imageio
from tqdm import tqdm, trange
from utils import SphericalDataset
from architectures import UNetSpherical
import matplotlib.pyplot as plt

In [2]:
# Target device for inference; later cells move their input tensors here with `.to(device)`.
device = torch.device('cpu')

In [3]:
# Experiment locations. Each can be overridden through an environment variable so the
# notebook runs on another machine; the defaults are the original author's local paths.
data_path = os.environ.get(
    "SURROGATE_DATA_PATH",
    '/Users/reza/Career/DMLab/SURROGATE/Data/psi_web_sample/train',
)
result_path = os.environ.get(
    "SURROGATE_RESULT_PATH",
    '/Users/reza/Career/DMLab/SURROGATE/results/laplace/deepsphere-weather-2/exp_11',
)
cfg_path = os.path.join(result_path, 'cfg.json')

# Config saved by the training run. Later cells read its 'train_files', 'train_min',
# 'train_max', 'n_neighbors' and 'seq_len' entries.
with open(cfg_path, 'r') as f:
    cfg = json.load(f)

In [4]:
# Number of training files recorded in the experiment config.
len(cfg['train_files'])

680

In [5]:
# Instrument sub-folders to look for inside every "cr*" directory of the data root.
instruments = [
    "kpo_mas_mas_std_0101",
    "mdi_mas_mas_std_0101",
    "hmi_mast_mas_std_0101",
    "hmi_mast_mas_std_0201",
    "hmi_masp_mas_std_0201",
    "mdi_mas_mas_std_0201",
]

# Sorting the listing makes the simulation order deterministic.
subdir_paths = sorted(os.listdir(data_path))
cr_paths = [
    os.path.join(data_path, entry)
    for entry in subdir_paths
    if entry.startswith("cr")
]

# Keep only the (cr folder, instrument) combinations that exist on disk.
# Order: outer loop over cr folders, inner loop over instruments.
sim_paths = [
    os.path.join(cr_path, inst)
    for cr_path in cr_paths
    for inst in instruments
    if os.path.exists(os.path.join(cr_path, inst))
]

In [6]:
# Build the evaluation dataset with the training run's normalisation bounds so that
# inputs are scaled exactly as the model saw them during training.
dataset = SphericalDataset(sim_paths, b_min=cfg["train_min"], b_max=cfg["train_max"])

In [7]:
# Normalisation bounds stored by the training run (passed to SphericalDataset as b_min / b_max).
cfg['train_min'], cfg['train_max']

(-2.6859591007232666, 2.788627862930298)

In [8]:
# Rebuild the network with the training run's neighbourhood size and place it on the
# same device the inputs are moved to in later cells. Without `.to(device)`, any
# non-CPU device would raise a device-mismatch error in the forward pass.
model = UNetSpherical(32, in_channels=2, out_channels=1, knn=cfg['n_neighbors']).to(device)
print(model.test_mode)  # sanity check of the model's test_mode flag (False in the recorded run)

# Restore weights from checkpoint '96.pth'; map_location lets a GPU-saved
# checkpoint load onto the chosen device.
state = torch.load(os.path.join(result_path, '96.pth'), map_location=device)
model.load_state_dict(state)

32
32
32
32
16
16
8
8
False


<All keys matched successfully>

In [9]:
# Fetch the first sample to inspect its layout.
cube = dataset[0]

In [10]:
# Recorded output: torch.Size([141, 12288, 1]). Presumably (time steps, grid points, channels) -- TODO confirm.
cube.shape

torch.Size([141, 12288, 1])

In [15]:
# Autoregressive rollout for sample I. At every step the model sees the initial state
# x0 concatenated (last axis) with its own previous prediction, which gives the
# 2 input channels the model was built with.
I = 10
cube = dataset[I].unsqueeze(0)  # leading batch dimension of size 1

model.eval()
x0 = cube[:, 0, :, :].to(device)
xi = x0  # the first step feeds the initial state twice
yhats = []
with torch.no_grad():
    for i in trange(cfg['seq_len'], leave=False):
        x = torch.cat([x0, xi], dim=-1)
        yhat = model(x)
        yhats.append(yhat)
        xi = yhat  # feed the prediction back in as the next "current" state
yhats = torch.stack(yhats, dim=1)  # stack predictions along a new step axis (dim 1)

                                             

In [16]:
# (batch, rollout step, ...) -- the recorded run gave torch.Size([1, 5, 12288, 1]).
yhats.shape

torch.Size([1, 5, 12288, 1])

In [17]:
# Render one comparison figure per rollout step (ground truth | prediction | abs error)
# and collect the PNG-decoded images in `frames` for the GIF written below.
GRID_SHAPE = (96, 128)  # 96 * 128 == 12288 values per map; presumably (lat, lon) -- TODO confirm

frames = []
for i in trange(cfg['seq_len']):
    # Prediction i was produced from step i, so it is compared with ground-truth step i + 1.
    # Convert to CPU numpy so np.abs / imshow also work when device is not the CPU.
    y = cube[:, i + 1, :, :].detach().cpu().numpy().reshape(GRID_SHAPE)
    yhat = yhats[:, i, :, :].detach().cpu().numpy().reshape(GRID_SHAPE)
    error = np.abs(y - yhat)

    # gt and prediction share one colour scale so they are directly comparable.
    vmin = min(y.min(), yhat.min())
    vmax = max(y.max(), yhat.max())
    cmap = "coolwarm"
    fig, axes = plt.subplots(1, 3, figsize=(12, 6))

    im_gt = axes[0].imshow(y, cmap=cmap, vmin=vmin, vmax=vmax)
    axes[0].set_title(f"gt: br002 at {i+1}")

    axes[1].imshow(yhat, cmap=cmap, vmin=vmin, vmax=vmax)
    axes[1].set_title(f"pred at {i+1} ")

    # This colorbar describes only the gt/pred panels, so attach it to those two axes
    # rather than to all three (the error panel uses its own grayscale).
    fig.colorbar(im_gt, ax=axes[:2], orientation="horizontal", fraction=0.1, pad=0.02)

    im_err = axes[2].imshow(error, cmap="gray")
    axes[2].set_title("|pred-gt|")
    fig.colorbar(im_err, ax=axes[2], orientation="vertical", fraction=0.05, pad=0.02)

    # Rasterise in memory; rewind the buffer before decoding it.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    frames.append(imageio.imread(buf))
    plt.close(fig)  # free the figure; otherwise one figure per step accumulates

100%|██████████| 5/5 [00:00<00:00,  6.16it/s]


In [18]:
# Stitch the rendered frames into a looping GIF named after the sample index I.
fps = 2  # playback speed; raise for a faster animation
output_filename = f"b{I}.gif"

# The writer is a context manager, so the file is finalised even if a frame fails.
with imageio.get_writer(output_filename, fps=fps, loop=0) as gif_writer:
    for png_frame in frames:
        gif_writer.append_data(png_frame)