In [10]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
import sys
sys.path.append('/app')
from dataloaders.dataloaders.simple_dataloader import SimpleDataset, collect_sim_paths, get_sims, min_max_normalize, compute_climatology, get_coords, get_cr_dirs
from model import make_deeponet
from utils.gif_generator import create_gif_from_array




In [None]:
class DeepONetDataset(SimpleDataset):
    def __init__(
        self,
        data_path,
        cr_list,
        v_min=None,
        v_max=None,
        instruments=None,
        scale_up=1,
        pos_embedding = None,
        trunk_sample_size=32768,
    ):
        super().__init__(
            data_path=data_path,
            cr_list=cr_list,
            v_min=v_min,
            v_max=v_max,
            instruments=instruments,
            scale_up=scale_up,
            pos_embedding=pos_embedding,
        )
        self.trunk_sample_size = trunk_sample_size
        # self.sim_paths = collect_sim_paths(data_path, cr_list, instruments)
        # sims, _ = get_sims(self.sim_paths, scale_up, pos_embedding)
        # sims, self.v_min, self.v_max = min_max_normalize(sims, v_min, v_max)
        # self.sims = sims
        # self.climatology = compute_climatology(sims[:, 0, 1:, :, :], scale_up)

    def __getitem__(self, index):
        cube = self.sims[index]

        u_surface = cube[:, 0, :, :]   # (C, H, W)
        y_target = cube[0, 1:, :, :] 

        # Flatten surface for branch input
        branch_input = torch.tensor(u_surface, dtype=torch.float32).reshape(-1)
        
        # Fast random sampling of trunk points
        nR, nH, nW = y_target.shape
        idx_r = np.random.randint(0, nR, size=self.trunk_sample_size)
        idx_h = np.random.randint(0, nH, size=self.trunk_sample_size)
        idx_w = np.random.randint(0, nW, size=self.trunk_sample_size)
        
        
        self.r = np.arange(1, nR + 1, dtype=np.float32)
        self.h = np.arange(nH, dtype=np.float32)
        self.w = np.arange(nW, dtype=np.float32)
        

        return {
            "branch": branch_input,   # (H * W * C,)
        }

    def __len__(self):
        return len(self.sims)

    def get_min_max(self):
        return {"v_min": float(self.v_min), "v_max": float(self.v_max)}

    def get_grid_points(self):
        return get_coords(self.sim_paths[0])

    def get_branch_input_dims(self):
        C, H, W = self.sims.shape[1], self.sims.shape[3], self.sims.shape[4]
        return C * H * W
        
    def get_trunk_input_dims(self):
        return 3  # r, theta, phi
    def get_grid_points_dim(self):
        return self.r, self.h, self.w

In [11]:
def make_gifs(preds, targets, shape, dataset, out_dir):
    C, H, W = shape

    # reshape flat to (frames, H, W)
    preds_grid = preds.reshape(-1, H, W)
    targets_grid = targets.reshape(-1, H, W)
    error_grid = np.abs(preds_grid - targets_grid)

    np.save(os.path.join(out_dir, "preds_grid.npy"), preds_grid)
    np.save(os.path.join(out_dir, "targets_grid.npy"), targets_grid)
    np.save(os.path.join(out_dir, "error_grid.npy"), error_grid)

    # make gifs
    create_gif_from_array(preds_grid,   os.path.join(out_dir, "preds.gif"))
    create_gif_from_array(targets_grid, os.path.join(out_dir, "targets.gif"))
    create_gif_from_array(error_grid,   os.path.join(out_dir, "error.gif"))


if __name__ == "__main__":
    # paths
    MODEL = "/app/output/DeepONetCNN/2025_11_20__074500/best_model.pt"
    DATA = "/app/data"

    out_dir = "./eval_outputs"

    cr_dirs = get_cr_dirs(DATA)
    cr_eval = cr_dirs[:32]   # choose some CRs

    preds, targets, shape, dataset = run_inference(
        model_path=MODEL,
        data_dir=DATA,
        cr_list=cr_eval,
        out_dir=out_dir,
        batch_size=8,
        branch_layers=[128,128,128,128],
        trunk_layers=[128,128,128,128],
        pos_embedding=dataset.pos_embedding if hasattr(dataset, "pos_embedding") else None,
        trunk_sample_size=dataset.trunk_sample_size if hasattr(dataset, "trunk_sample_size") else None,
    )

    make_gifs(preds, targets, shape, dataset, out_dir)
    print("GIFs created in", out_dir)

NameError: name 'dataset' is not defined

In [None]:
def predict_full_grid_in_chunks(model, branch, coords, H, W, chunk_size=32768, accelerator=None):
    """
    model: DeepONet
    branch: (1, C, H, W)
    coords: (N, 3)
    """
    device = next(model.parameters()).device
    branch = branch.to(device)
    coords = coords.to(device)

    N = coords.shape[0]
    preds = torch.zeros(N, device=device)

    model.eval()
    with torch.no_grad():
        for start in range(0, N, chunk_size):
            end = min(start + chunk_size, N)
            coords_chunk = coords[start:end].unsqueeze(0)        # (1, n_chunk, 3)

            if accelerator:
                with accelerator.autocast():
                    y_chunk = model(branch, coords_chunk)         # (1, n_chunk)
            else:
                y_chunk = model(branch, coords_chunk)

            preds[start:end] = y_chunk[0]

    return preds.reshape(H, W)

In [None]:
def main():

    parser = argparse.ArgumentParser(description='Document helper.....')
    parser.add_argument('--ngpu', type=int, default=0, help='set the gpu on which the model will run')
    
    args = parser.parse_args()
    ngpu      = args.ngpu
    
    with open('/app/src/DeepONet/config.toml', 'r') as f:
        config = toml.load(f)
    
    DATA_DIR = config['train_params']['data_dir']
    BASE_DIR = config['train_params']['base_dir']
    batch_size = config['train_params']['batch_size']


    model_type = config['model_params']['model_type']
    scale_up = config['model_params']['scale_up']
    loss_fn_str = config['model_params']['loss_fn']
    pos_embedding = config['model_params']['pos_embedding']
    trunk_sample_size = config['model_params']['trunk_sample_size']
    branch_layers = config['model_params'].get('branch_layers', [128,128,128,128])
    trunk_layers = config['model_params'].get('trunk_layers', [128,128,128,128])

    cr_dirs = get_cr_dirs(DATA_DIR)
    split_ix = int(len(cr_dirs) * 0.8)
    cr_train, cr_val = cr_dirs[:split_ix], cr_dirs[split_ix:]
    cr_val = cr_val[::len(cr_val)//10]
    train_dataset = DeepONetDataset(DATA_DIR, cr_train, scale_up=scale_up, pos_embedding=pos_embedding, trunk_sample_size=trunk_sample_size)   
    val_dataset = DeepONetDataset(
        DATA_DIR,
        cr_val,
        scale_up=scale_up,
        v_min=train_dataset.v_min,
        v_max=train_dataset.v_max,
        pos_embedding=pos_embedding,
        trunk_sample_size=trunk_sample_size
    )
    device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
    radii, thetas, phis = train_dataset.get_grid_points()

    if loss_fn_str == "l2":
        loss_fn = LpLoss(d=2, p=2)
    elif loss_fn_str == "h1":
        loss_fn = H1LossSpherical(r_grid=radii[1:], theta_grid=thetas, phi_grid=phis)
    elif loss_fn_str == "h1mae":
        loss_fn = H1LossSphericalMAE(r_grid=radii[1:], theta_grid=thetas, phi_grid=phis)
    elif loss_fn_str == "mse":
        loss_fn = nn.MSELoss()
    else:
        raise ValueError("unsupported loss function")

    out_path = os.path.join(BASE_DIR, model_type, job_id)

        
    if pos_embedding == 'pt':
        in_channels = 4
    elif pos_embedding == 'ptr':
        raise ValueError('radii embedding is the same in full channel and is not supported here')
    elif pos_embedding is None:
        in_channels = 1
    else:
        raise ValueError('wrong pos embedding')
    
    model = DeepONet(
        in_channels=in_channels,
        trunk_in_dim=3,
        latent_dim=128,
        trunk_hidden=trunk_layers,
    ).to(device)
    
    print(model)
    batch_size = 6
    

    gen_cpu = torch.Generator(device="cuda")
    gen_cpu.manual_seed(42)  # optional, for reproducibility    # Make DataLoaders use CPU RNG to avoid device mismatch

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=False,
        generator=gen_cpu,
    )

    model.load_state_dict(torch.load(os.path.join(BASE_DIR), map_location='cpu', weights_only=True))
    model = model.to(device)

    torch.save(best_state_dict, os.path.join(out_path, "best_model.pt"))
    if run is not None:
        artifact = wandb.Artifact(
            name='best_model',
            type='model',
            description='best model after training'
        )
        artifact.add_file(os.path.join(out_path, f"best_model.pt"))
        run.log_artifact(artifact)

    filename = f"best_epoch-{best_epoch}.txt"
    with open(
        os.path.join(out_path, filename), "w", encoding="utf-8"
    ) as f:
        f.write(f"best_epoch: {best_epoch}")
    if run is not None:
        artifact = wandb.Artifact(
            name='best_epoch',
            type='evaluation',
            description='epoch with lowest validation loss'
        )
        artifact.add_file(os.path.join(out_path, filename))
        run.log_artifact(artifact)

    save_training_results_artifacts(run, out_path, training_results)

    print("Training completed.")
    wandb.finish()