In [1]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
import sys
sys.path.append('/app')
from dataloaders.simple_dataloader import SimpleDataset, collect_sim_paths, get_sims, min_max_normalize, compute_climatology, get_coords, get_cr_dirs
from model import make_deeponet
from utils.gif_generator import create_gif_from_array
import torch 
import torch.nn as nn
import toml

# from utils.data_utils import read_hdf
# from dataloaders.cnn_deeponet_dataloader import DeepONetDataset, get_cr_dirs
# from utils.gif_generator import create_gif_from_array 
# from trainer import train, save_training_results_artifacts
from model import DeepONetCNN

  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch
Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.
paddle supports more examples now and is recommended.


In [6]:
class DeepONetDataset(SimpleDataset):
    def __init__(
        self,
        data_path,
        cr_list,
        v_min=None,
        v_max=None,
        instruments=None,
        scale_up=1,
        pos_embedding=None
    ):
        super().__init__(
            data_path=data_path,
            cr_list=cr_list,
            v_min=v_min,
            v_max=v_max,
            instruments=instruments,
            scale_up=scale_up,
            pos_embedding=pos_embedding,
        )
    def __getitem__(self, index):
        cube = self.sims[index]

        u_surface = cube[:, 0, :, :]              # (C, H, W)
        y_target = cube[0, 1:, :, :]              # (R, H, W)

        # Branch input (CNN)
        branch_input = torch.tensor(u_surface, dtype=torch.float32)

        # Grid
        nR, nH, nW = y_target.shape
        r = np.arange(1, nR + 1, dtype=np.float32)
        h = np.arange(nH, dtype=np.float32)
        w = np.arange(nW, dtype=np.float32)

        Rg, Hg, Wg = np.meshgrid(r, h, w, indexing="ij")

        coords = np.stack([Rg, Hg, Wg], axis=-1).reshape(-1, 3)      # (N,3)
        target = y_target.reshape(-1).astype(np.float32)             # (N,)

        trunk_input = torch.from_numpy(coords)    # (1, N, 3)
        target = torch.from_numpy(target)         # (1, N)

        return {
            "branch": branch_input,        # (C, H, W)
            "trunk": trunk_input,          # (1, N, 3)
            "target": target,              # (1, N)
        }

    def __len__(self):
        return len(self.sims)

    def get_branch_input_dims(self):
        C, H, W = self.sims.shape[1], self.sims.shape[3], self.sims.shape[4]
        return C * H * W

    def get_trunk_input_dims(self):
        return 3

In [7]:
def predict_full_grid_in_chunks(model, branch, coords, H, W, chunk_size=32768, accelerator=None):
    """
    model: DeepONet
    branch: (1, C, H, W)
    coords: (N, 3)
    """
    device = next(model.parameters()).device
    branch = branch.to(device)
    coords = coords.to(device)

    N = coords.shape[0]
    preds = torch.zeros(N, device=device)

    model.eval()
    with torch.no_grad():
        for start in range(0, N, chunk_size):
            end = min(start + chunk_size, N)
            coords_chunk = coords[start:end].unsqueeze(0)        # (1, n_chunk, 3)

            if accelerator:
                with accelerator.autocast():
                    y_chunk = model(branch, coords_chunk)         # (1, n_chunk)
            else:
                y_chunk = model(branch, coords_chunk)

            preds[start:end] = y_chunk[0]

    return preds.reshape(H, W)

In [8]:
with open('/app/src/DeepONetCNN/config.toml', 'r') as f:
    config = toml.load(f)

DATA_DIR = config['train_params']['data_dir']
BASE_DIR = config['train_params']['base_dir']
batch_size = config['train_params']['batch_size']


model_type = config['model_params']['model_type']
scale_up = config['model_params']['scale_up']
loss_fn_str = config['model_params']['loss_fn']
pos_embedding = config['model_params']['pos_embedding']
trunk_sample_size = config['model_params']['trunk_sample_size']
branch_layers = config['model_params'].get('branch_layers', [128,128,128,128])
trunk_layers = config['model_params'].get('trunk_layers', [128,128,128,128])
job_id = "2025_11_20__080328"

cr_dirs = get_cr_dirs(DATA_DIR)
split_ix = int(len(cr_dirs) * 0.8)
cr_train, cr_val = cr_dirs[:10], cr_dirs[split_ix:]
cr_val = cr_val[::len(cr_val)//10]
train_dataset = DeepONetDataset(DATA_DIR, cr_train, scale_up=scale_up, pos_embedding=pos_embedding)   
val_dataset = DeepONetDataset(
    DATA_DIR,
    cr_val,
    scale_up=scale_up,
    v_min=train_dataset.v_min,
    v_max=train_dataset.v_max,
    pos_embedding=pos_embedding,
)

Loading simulations: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.54it/s]
Loading simulations: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:09<00:00,  3.46it/s]


In [9]:
device = torch.device(f"cuda:1" if torch.cuda.is_available() else "cpu")
radii, thetas, phis = train_dataset.get_grid_points()

if loss_fn_str == "l2":
    loss_fn = LpLoss(d=2, p=2)
elif loss_fn_str == "h1":
    loss_fn = H1LossSpherical(r_grid=radii[1:], theta_grid=thetas, phi_grid=phis)
elif loss_fn_str == "h1mae":
    loss_fn = H1LossSphericalMAE(r_grid=radii[1:], theta_grid=thetas, phi_grid=phis)
elif loss_fn_str == "mse":
    loss_fn = nn.MSELoss()
else:
    raise ValueError("unsupported loss function")

out_path = os.path.join(BASE_DIR, model_type, job_id)

os.makedirs(os.path.join(out_path, 'result_gifs'), exist_ok=True)

if pos_embedding == 'pt':
    in_channels = 4
elif pos_embedding == 'ptr':
    raise ValueError('radii embedding is the same in full channel and is not supported here')
elif pos_embedding is None or pos_embedding == False:
    in_channels = 1
else:
    raise ValueError('wrong pos embedding')

model = DeepONetCNN(
    in_channels=in_channels,
    trunk_in_dim=3,
    latent_dim=128,
    trunk_hidden=trunk_layers,
)

model.load_state_dict(torch.load('./best_model.pt', map_location='cpu', weights_only=True))
model = model.to(device)
print(model)
batch_size = 6


gen_cpu = torch.Generator(device="cuda")
gen_cpu.manual_seed(42)  # optional, for reproducibility    # Make DataLoaders use CPU RNG to avoid device mismatch

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=False,
    generator=gen_cpu,
)

model.eval()
step = 1
for batch in tqdm(train_loader):
    u = batch["branch"].to(device)     # (B, C*H*W)   or (B, D_branch)
    coords = batch["trunk"].to(device) # (B, N, 3)    or sometimes (N, 3) broadcasted
    y_true = batch["target"].to(device) # (B, N)

    B = y_true.shape[0]
    R,H,W = val_dataset.sims.shape[2:]
    pred = model(u, coords)     # (B, N)

    # ---- denormalize for metrics (matches your code path) ----
    real_y   = y_true * (train_dataset.v_max - train_dataset.v_min) + train_dataset.v_min
    real_pred= pred    * (train_dataset.v_max - train_dataset.v_min) + train_dataset.v_min
    real_y   = real_y * 481.3711
    real_pred= real_pred * 481.3711
    real_y = real_y.view(B, R, H, W)
    real_pred = real_pred.view(B, R, H, W)
    for i in range(len(B)):
        input_file_name = f'input_step_{step}.gif'
        output_file_name = f'output_step_{step}.gif'
        create_gif_from_array(real_y.transpose(1,2,0), os.path.join(out_path, 'result_gifs'), file_name=input_file_name)
        create_gif_from_array(real_pred.transpose(1,2,0), os.path.join(out_path, 'result_gifs'), file_name=input_file_name)
        step += 1


DeepONetCNN(
  (branch): CNNBranch(
    (cnn): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
      (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (10): ReLU()
      (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (mix): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
    (gap): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=128, out_features=128, bias=True)
  )
  (trunk): TrunkMLP(
    (ml

RuntimeError: CUDA error: CUDA driver version is insufficient for CUDA runtime version
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.

In [21]:
# !pip install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

Looking in indexes: https://download.pytorch.org/whl/cu124, https://pypi.ngc.nvidia.com
Collecting torch==2.4.0
  Downloading https://download.pytorch.org/whl/cu124/torch-2.4.0%2Bcu124-cp310-cp310-linux_x86_64.whl (797.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m797.3/797.3 MB[0m [31m77.7 MB/s[0m  [33m0:00:10[0m:00:01[0m00:01[0m
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu124/torchaudio-2.6.0%2Bcu124-cp310-cp310-linux_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.99 (from torch==2.4.0)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cuda_nvrtc_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (24.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.7/24.7 MB[0m [31m71.8 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.4.99 (from torch==2.4.0)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cuda_runtime_cu12-12.4.99-py

Installing collected packages: triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cusolver-cu12, torch, torchvision, torchaudio
[2K  Attempting uninstall: triton
[2K    Found existing installation: triton 3.2.0
[2K    Uninstalling triton-3.2.0:
[2K      Successfully uninstalled triton-3.2.0━━━━━[0m [32m 0/15[0m [triton]
[2K  Attempting uninstall: nvidia-nvtx-cu12━━━━━━━━[0m [32m 0/15[0m [triton]
[2K    Found existing installation: nvidia-nvtx-cu12 12.4.1275[0m [triton]
[2K    Uninstalling nvidia-nvtx-cu12-12.4.127:━[0m [32m 0/15[0m [triton]
[2K      Successfully uninstalled nvidia-nvtx-cu12-12.4.127━━━━━━━━━━[0m [32m 1/15[0m [nvidia-nvtx-cu12]
[2K  Attempting uninstall: nvidia-nvjitlink-cu12━━━━━━━━━━━━━━━━━[0m [32m 1/15[0m [nvidia-nvtx-cu12]
[2K    Found existing installation: nvidia-

In [10]:
print("Torch CUDA version:", torch.version.cuda)
print("Compiled With:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("Device Count:", torch.cuda.device_count())

x = torch.randn(1).cuda()
print("Tensor is ok:", x)

Torch CUDA version: 12.4
Compiled With: 12.4
CUDA available: False
Device Count: 0


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx