In [1]:
import numpy as np
import math
import xarray as xr
import os
import tensorflow as tf
import xbatcher as xb
import xbatcher.loaders.keras
import copy

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from cnn import CNN, Scenario, ReplicationPadding2D, MaskedMSELoss
import preprocess_data

2026-02-03 10:41:42.957987: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-03 10:41:43.002313: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-03 10:42:26.698090: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.


Some functions

In [2]:
def _add_dimension_and_slice(ds, x_slice, y_slice, region):
    '''
        Adds a dimension to allow all regional datasets to 
            be combined.
        Slice the data based on input x and y.
    '''
    ds = ds.expand_dims({"r": 1})
    for coord in ds.coords:
        if coord != "t": #* omit t dimension
            ds[coord] = ds[coord].expand_dims({"r": 1})

    ds_tmp = ds.isel(x_c=x_slice, y_c=y_slice)

    return ds_tmp

In [3]:
def _slice_data(ds, x_slice, y_slice):
    '''
        Slice the data based on input x and y.
    '''

    ds_tmp = ds.isel(x_c=x_slice, y_c=y_slice)

    return ds_tmp

In [4]:
def open_and_process_data(scenario, directory, filenames, domain):
    '''
        Opens and preprocesses the data by feeding in directory and filenames
            into the preprocessing step.
            n.b. this only provides local normalization related to each sub domain.
                use open_and_combine_data then feed in datasets for global 
                normalization.
    '''
    ds = {}

    sc = copy.deepcopy(scenario) # make a copy
    print(sc)

    for region in domain:

        directory_region = directory.format(domain=region)
        fnames_region = [f.format(domain=region) for f in filenames]

        processor = preprocess_data.data_preparation(scenario, 
                                                     directory=directory_region, 
                                                     filenames=fnames_region,
                                                     parallel=False,
                                                     )
        ds[region] = processor()
        # print(ds[region])

        scenario = copy.deepcopy(sc)
        print(scenario)

    dom_slice_dict = {
        'dDP': (slice(2, 37), slice(30, 65)),
        'uDP': (slice(2, 37), slice(10, 45)),
        'SP': (slice(2, 37), slice(15, 50)),
        'IO': (slice(2, 37), slice(15, 50)),
        'SO_JET': (slice(3, 43), slice(7, 47)),
        }
    
    for region in domain:
        ds_region = _slice_data(ds[region], 
                                 dom_slice_dict[region][0],
                                 dom_slice_dict[region][1],
                                )

        ds[region] = ds_region

    dataset_list = [ds[region] for region in domain]

    ds_combined = xr.concat(dataset_list, dim='r')

    return ds_combined

Set up the scenario

In [5]:
sc = Scenario(['mke', 'vor', 'sa', 'eke_shift'], ['eke'],
              [256, 128, 64, 32, 32, 1], 
              [(5,5), (3,3), (3,3), (3,3), (3,3), (1,1)], 
              [(2,2), (1,1), (1,1), (1,1), (1,1), (0,0)],
              name = 'testing')

Open and process data using local normalization

In [6]:
directory = "/gws/nopw/j04/ai4pex/twilder/NEMO_data/DINO/EXP16/features/{domain}/coarsened_data/"

fns = ["MINT_1d_0061-0072_sa_c_{domain}.nc",
       "MINT_1d_0061-0072_vor_cg_{domain}_mod.nc",
       "MINT_1d_0061-0072_eke_c_{domain}.nc",
       "MINT_1d_0061-0072_eke_c_{domain}_shifted.nc",
       "MINT_1d_0061-0072_mke_c_{domain}.nc",
       "mesh_mask_exp4_{domain}_xnemo.nc"]

domain = ['SO_JET']

ds = open_and_process_data(sc, directory, fns, domain)

Scenario(input_var=['mke', 'vor', 'sa', 'eke_shift'], target=['eke'], filters=[256, 128, 64, 32, 32, 1], kernels=[(5, 5), (3, 3), (3, 3), (3, 3), (3, 3), (1, 1)], padding=[(2, 2), (1, 1), (1, 1), (1, 1), (1, 1), (0, 0)], name='testing')
Scenario(input_var=['mke', 'vor', 'sa', 'eke_shift'], target=['eke'], filters=[256, 128, 64, 32, 32, 1], kernels=[(5, 5), (3, 3), (3, 3), (3, 3), (3, 3), (1, 1)], padding=[(2, 2), (1, 1), (1, 1), (1, 1), (1, 1), (0, 0)], name='testing')


In [7]:
ds

In [10]:
directory = "/gws/nopw/j04/ai4pex/twilder/NEMO_data/DINO/EXP16/features/SO_JET/coarsened_data/"
ds.to_netcdf(directory + 'preprocessed_SO_JET_data.nc')

Subset the data for prediction

In [8]:
# ----------------------------
# 2. Define split sizes
# ----------------------------
n_test = 359     # last 30 days
n_val  = 360     # 60 days before test
train_stride = 5  # every 4th day

nt = ds.sizes["t"]

# ----------------------------
# 3. Create time indices
# ----------------------------
test_idx = np.arange(nt - n_test, nt)
val_idx  = np.arange(nt - n_test - n_val, nt - n_test)
train_idx_full = np.arange(0, nt - n_test - n_val)

# ----------------------------
# 4. Subsample training every 4th day
# ----------------------------
train_idx = train_idx_full[::train_stride]

# ----------------------------
# 5. Create split datasets
# ----------------------------
ds_train = ds.isel(t=train_idx)
ds_val   = ds.isel(t=val_idx)
ds_test  = ds.isel(t=test_idx)

print("Train:", ds_train.sizes["t"])
print("Val:  ", ds_val.sizes["t"])
print("Test: ", ds_test.sizes["t"])

Train: 721
Val:   360
Test:  359


In [21]:
ds_train

Loading the trained model

In [9]:
# load in keras model
filename = 'training/cnn_20260203-092450.keras'
model = keras.saving.load_model(
    filename,
    compile=True,
)
model.summary()

2026-02-03 10:46:21.091001: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [10]:
ds_prediction = ds_test.isel(r=0)

Now making a prediction

In [11]:
# retrieve input and taget variables
batch_input  = [ds_prediction[x] for x in sc.input_var]
batch_target  = [ds_prediction[x] for x in sc.target]
# adds an additional dimension for tf readability
batch_input  = xr.merge(batch_input).to_array('var').transpose(...,'var') # channels
batch_target  = xr.merge(batch_target).to_array('var').transpose(...,'var')

target = model.predict(batch_input.to_numpy())

[1m12/12[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 40ms/step


In [12]:
pred_ds = xr.Dataset(
        data_vars={
            "fine_ke_pred": (['time_counter', 'y', 'x', 'var'], np.exp(target)),  # Adjusted to match the model output
            "fine_ke_true": (['time_counter', 'y', 'x', 'var'], np.exp(batch_target.to_numpy())),
        },
        coords={
            "time_counter": (["time_counter"], ds_prediction.t.values, 
                             ds_prediction.t.attrs),
            "gphit": (["y", "x"], ds_prediction.gphit.values, 
                      {"standard_name": "Latitude", "units": "degrees_north"}),
            "glamt": (["y", "x"], ds_prediction.glamt.values, 
                      {"standard_name": "Longitude","units": "degrees_east"}),
            'var': sc.target,
        },
        attrs={
            'Title': 'Fine kinetic energy - predicted and truth (region IO)',
            'Description': 'Predicted fine kinetic energy from coarse-grained data using CNN',
            'Units': 'm^2/s^2',
            'Source': f'{filename}',
        }
    )

In [13]:

pred_ds.to_netcdf('predictions/eddy_energy_20260203-092450_SO_JET.nc')