In [None]:
import math
import os
import sys

import numpy as np
import torch

# Root of the IUFNO channel-flow code and data; can be overridden via the IUFNO_PATH env var.
IUFNO_path = os.environ.get('IUFNO_PATH', '/home/cdagrad/dwyerdei/MOR_MoE/IUFNO-CHL')
# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if IUFNO_path not in sys.path:
    sys.path.append(IUFNO_path)

# vor=vorticity (legacy `vor_` prefix; per the layout note below the channels are u, v, w, p)
#-------------------------------------------------
# Dwyer: 21x400x32x33x16x4 (groups, time, x, y, z, channels)
# channel flow variables are: u,v,w and p. (these components make up the size_w dim)
# Memory-map the file: the full array has ~5.7e8 elements, and this cell only needs its
# shape and a small slice (the rollout section below memory-maps the same file as well).
vor_data = np.load(f'{IUFNO_path}/data_chl_re180/data_mave.npy', mmap_mode='r')

# Number of saved time steps covering one flow-through, rounded up.
# NOTE(review): 800 and 42.5 are presumably flow-through length and time per saved step -- TODO confirm.
n_steps_per_flow_through = math.ceil(800 / 42.5)

# Reduced dataset: first 2 groups, last flow-through worth of steps, copied into RAM.
vor_data_small = np.array(vor_data[:2, -n_steps_per_flow_through:])

SAVE_SMALL_DATASET = False  # set True to (re)write data_mave_small.npy
if SAVE_SMALL_DATASET:
    np.save(f'{IUFNO_path}/data_chl_re180/data_mave_small.npy', vor_data_small)

**Results Summary:**
* IUFNO model trained on the reduced dataset (`weights_IUFNO.pth`): **explodes** well before completing even a single flow-through time!
    * It does not visibly freeze; more precisely, we could not observe a potential freeze because it explodes first.
* There was a bug in the `data_mave_small.npy` preprocessing code that made it include several flow-through times, yet the model still exploded.
    * This is why we could not train it for the calculated required number of epochs.
* The paper's original model (`re180_IUFNO.pth`) is stable, but of course it requires a huge amount of training data.

In [None]:
# Fraction of the full dataset represented by one flow-through of steps, spread over
# (groups - 1) groups. NOTE(review): presumably one group is held out -- TODO confirm.
n_groups_total = vor_data.shape[0]
n_times_total = vor_data.shape[1]
data_set_size_mutiplier = (n_steps_per_flow_through / n_times_total) * (1 / (n_groups_total - 1))

print(f'{vor_data.shape=}')
print(f'{data_set_size_mutiplier=}')

In [None]:
from IUFNO import FNO4d

# Architecture hyper-parameters; these must match the checkpoint loaded in the next cell,
# since load_state_dict (strict by default) rejects mismatched parameter shapes.
modes = 8
width = 80
nlayer = 40

# Training hyper-parameters; not referenced by the cells visible in this notebook section.
epochs = 100
learning_rate = 0.001
weight_decay_value = 1e-11

# Same number of Fourier modes in all four dimensions.
model = FNO4d(*([modes] * 4), width, nlayer).cuda()

In [None]:
# Pick which checkpoint to evaluate: our reduced-data model or the paper's original one.
paper_model = 're180_IUFNO.pth'
our_model = 'weights_IUFNO.pth'
model_fn = our_model

# map_location='cpu' makes loading independent of the GPU the checkpoint was saved on;
# load_state_dict then copies the tensors into the model's (CUDA) parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
state_dict = torch.load(f'{IUFNO_path}/{model_fn}', map_location='cpu')
model.load_state_dict(state_dict)
model.eval();  # trailing ';' suppresses the very long 40-layer module repr

In [None]:
import torch
import numpy as np
from tqdm import tqdm

@torch.no_grad()
def rollout_autoregressive(model: torch.nn.Module,
                           init_frames_np: np.ndarray,
                           num_steps: int,
                           device = torch.device('cuda')) -> np.ndarray:
    """
    Autoregressively predict `num_steps` future frames with a sliding window of 5 frames.

    At each step the current window is fed to `model`; its output is treated as an
    increment added to the most recent frame. The oldest frame is then dropped and the
    new frame appended, so the window always holds the 5 latest frames.

    Args:
        model: network mapping a window [1, X, Y, Z, 3, 5] to an increment
            [1, X, Y, Z, 3, 1]. (Shapes taken from the inline comments -- TODO confirm
            against the FNO4d definition.)
        init_frames_np: [5, X, Y, Z, 3] seed frames, oldest first.
        num_steps: number of frames to predict; must be >= 0 (0 yields an empty array).
        device: device the window is placed on; should match the model's device.

    Returns:
        float32 array of shape [num_steps, X, Y, Z, 3] with the predicted frames.

    Raises:
        ValueError: if init_frames_np is not [5, X, Y, Z, 3] or num_steps is negative.
    """
    if init_frames_np.ndim != 5 or init_frames_np.shape[0] != 5 or init_frames_np.shape[-1] != 3:
        raise ValueError(f"Expected init_frames_np as [5, X, Y, Z, 3], got {init_frames_np.shape}")
    if num_steps < 0:
        raise ValueError(f"num_steps must be non-negative, got {num_steps}")

    window = torch.from_numpy(init_frames_np.astype(np.float32))            # [5, X, Y, Z, 3]
    window = window.permute(1, 2, 3, 4, 0).unsqueeze(0).to(device)          # [1, X, Y, Z, 3, 5]

    # Preallocate the output so long rollouts don't hold both a list of frames and a stacked copy.
    preds = np.empty((num_steps, *init_frames_np.shape[1:]), dtype=np.float32)
    for step in tqdm(range(num_steps)):
        delta = model(window).squeeze(-1)                                   # [1, X, Y, Z, 3]
        next_frame = window[..., -1] + delta                                # [1, X, Y, Z, 3]
        preds[step] = next_frame.squeeze(0).cpu().numpy()
        # Slide the window: drop the oldest frame, append the new prediction.
        window = torch.cat([window[..., 1:], next_frame.unsqueeze(-1)], dim=-1)

    return preds

In [None]:
# Seed the rollout with 5 consecutive ground-truth frames (velocity channels only).
# Relies on `IUFNO_path` and `model` from the cells above.
group_index = 0
start_timestep = 0

# Memory-map the dataset so only the requested slice is actually read from disk.
data_path = f"{IUFNO_path}/data_chl_re180/data_mave.npy"
vor_data_mem = np.load(data_path, mmap_mode='r')  # [groups, time, X, Y, Z, 4]

n_init_frames = 5
time_window = slice(start_timestep, start_timestep + n_init_frames)
velocity_channels = slice(None, 3)  # u, v, w -- drops p
init_frames = vor_data_mem[group_index, time_window, ..., velocity_channels]  # [5, X, Y, Z, 3]
print("Init frames shape:", init_frames.shape)

In [None]:
# Smoke test: a single autoregressive step, to catch shape/device problems before the long rollout.
assert torch.cuda.is_available(), "CUDA is required to run the IUFNO model."
preds_1 = rollout_autoregressive(model, init_frames, num_steps=1)
print("1-step preds shape:", preds_1.shape)
# One predicted frame with the same spatial/channel layout as the seed frames: [1, X, Y, Z, 3]
assert preds_1.shape == (1, *init_frames.shape[1:]), f"unexpected 1-step shape {preds_1.shape}"

In [None]:
# Full rollout: 800 steps, then save.
# The plotting cells below load `rollout_800.npy` / use `preds_800`, so keep steps=800 or update them.
# NOTE(review): the file name does not include `model_fn`, so rollouts from different
# checkpoints overwrite each other.
steps = 800
preds_800 = rollout_autoregressive(model, init_frames, num_steps=steps)

save_path = f"{IUFNO_path}/rollout_{steps}.npy"
np.save(save_path, preds_800)
print(f"Saved predictions to: {save_path}")
print("Preds shape:", preds_800.shape)  # expected: [800, X, Y, Z, 3]

# Report the first predicted step containing NaN/inf, to quantify when the rollout blows up.
# Large-but-finite divergence is not detected by this check.
finite_per_step = np.isfinite(preds_800).reshape(len(preds_800), -1).all(axis=1)
if finite_per_step.all():
    print("All predicted steps are finite.")
else:
    print("First non-finite predicted step:", int(np.argmin(finite_per_step)))

## Visualize Predictions

In [None]:
from grid_figures import GridFigure
import numpy as np

# Fall back to the saved rollout when this section is run on a fresh kernel.
if 'preds_800' not in globals():
    preds_800 = np.load(f"{IUFNO_path}/rollout_800.npy")

# preds_800 layout: [T, X, Y, Z, 3]; show the middle z-plane for at most 800 steps.
z_idx = preds_800.shape[3] // 2
max_time = min(800, preds_800.shape[0])

# One [H, W, T] stack per velocity component (u, v, w).
u_xyt, v_xyt, w_xyt = (
    np.transpose(preds_800[:max_time, :, :, z_idx, comp], (1, 2, 0))
    for comp in range(3)
)

fig = GridFigure(title=f"Predicted u/v/w at z={z_idx} over {max_time} steps", cmap='bwr')
for comp_xyt, comp_name in zip((u_xyt, v_xyt, w_xyt), ('u', 'v', 'w')):
    fig.add_3d_row(comp_xyt, y_title=comp_name)

fig_path = f"{IUFNO_path}/grid_rollout_midZ.png"
fig.show(fig_path=fig_path)
print("Saved figure:", fig_path)

## Ground Truth vs Prediction Comparison

In [None]:
# Compare predicted vs ground-truth (u component) at mid z across selected timesteps
import os
import numpy as np
from grid_figures import GridFigure

assert 'preds_800' in globals() or os.path.exists(f"{IUFNO_path}/rollout_800.npy"), "Run the rollout cell first."
if 'preds_800' not in globals():
    preds_800 = np.load(f"{IUFNO_path}/rollout_800.npy")

# Ground truth aligned with the rollout: prediction k corresponds to time start_timestep + 5 + k
# (the first 5 frames seeded the rollout).
T_pred = preds_800.shape[0]
assert 'vor_data_mem' in globals(), "Reload the dataset cell so vor_data_mem is available."

gt_seq = vor_data_mem[group_index, start_timestep+5:start_timestep+5+T_pred, ..., :3]
T_gt = gt_seq.shape[0]
T = min(T_pred, T_gt)
assert T > 0, "No overlapping timesteps between the rollout and the ground truth."

# Truncate into a new name rather than overwriting `preds_800`, so re-running this cell
# (or the plotting cell above) still sees the full rollout.
preds_cmp = preds_800[:T]

z_idx = preds_cmp.shape[3] // 2

u_pred_xyt = np.transpose(preds_cmp[:, :, :, z_idx, 0], (1, 2, 0))  # [H,W,T]
u_gt_xyt   = np.transpose(gt_seq[:T, :, :, z_idx, 0], (1, 2, 0))    # [H,W,T]

# Sample 10 timesteps across the rollout
ncols = 10
sample_ts = list(np.linspace(0, T-1, num=ncols, dtype=int))
imgs_gt = [u_gt_xyt[:, :, t] for t in sample_ts]
imgs_pr = [u_pred_xyt[:, :, t] for t in sample_ts]
titles  = [f"t={start_timestep+5+t}" for t in sample_ts]

# TODO: intended to set a symmetric color range from the 99th percentile (to avoid outliers),
# but no percentile is computed yet and nothing is passed to GridFigure -- this stack is unused.
stack_for_range = np.stack([u_gt_xyt[:, :, sample_ts], u_pred_xyt[:, :, sample_ts]], axis=-1)

fig = GridFigure(title=f"u at z={z_idx}: GT vs Pred", cmap='bwr')
fig.add_img_seq_row(imgs_gt, x_titles=titles, y_title='GT u')
fig.add_img_seq_row(imgs_pr, x_titles=titles, y_title='Pred u')
fig_path = f"{IUFNO_path}/grid_rollout_u_midZ_gt_vs_pred.png"
fig.show(fig_path=fig_path)
print("Saved comparison:", fig_path)