<a href="https://colab.research.google.com/github/sanAkel/ufs_diurnal_diagnostics/blob/main/GFS/Nov2025/repair_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import xarray as xr
import numpy as np

import datetime

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
def fix_dataset(fName, date, var_names=('Temp', 'Salt', 'ave_ssh')):
    """
    Open a netCDF file, overwrite its 'Time' coordinate with a single datetime
    and set selected variables to NaN wherever 'Salt' is exactly 0.

    The mask is built once from the unmodified 'Salt' field and then applied to
    every variable named in ``var_names`` (including 'Salt' itself when it is
    listed), so the order of the names does not affect the result.

    Args:
        fName (str): The path to the netCDF file.
        date (datetime.datetime): The datetime to store as the only 'Time' value.
        var_names (str or iterable of str, optional): Names of the variables to
            mask. Defaults to ('Temp', 'Salt', 'ave_ssh'). A single name may be
            passed as a plain string.

    Returns:
        xr.Dataset: The dataset with the replaced 'Time' coordinate and with
                    the requested variables masked.

    Raises:
        KeyError: If 'Salt' or any requested variable is absent from the file.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(var_names, str):
        var_names = (var_names,)
    var_names = tuple(var_names)

    ds = xr.open_dataset(fName)

    # Fail before anything is modified, reporting every missing name at once.
    # dict.fromkeys de-duplicates while keeping the order of the names.
    required = dict.fromkeys(('Salt',) + var_names)
    missing = [name for name in required if name not in ds]
    if missing:
        ds.close()
        raise KeyError(f"{fName!r} is missing required variable(s): {missing}")

    # 'Time' becomes a one-element datetime64[ns] array on the 'Time' dimension.
    # NOTE(review): assumes the file's 'Time' dimension has length 1 - confirm.
    ds['Time'] = (('Time',), np.array([date], dtype='datetime64[ns]'))

    # True where Salt is exactly 0; computed once, before Salt itself is masked.
    salt_mask = ds['Salt'] == 0

    # Keep values where the mask is False; everything else becomes NaN.
    # NOTE(review): a variable whose dimensions differ from Salt's is presumably
    # broadcast against the mask by xarray - confirm that is intended.
    for var_name in var_names:
        ds[var_name] = ds[var_name].where(~salt_mask)

    return ds

In [None]:
# Directory that holds both input files; the trailing slash lets file names be appended directly.
# NOTE(review): the path sits under /content/drive, so this presumably needs
# Google Drive to be mounted in Colab before the cell runs - confirm.
data_path = "/content/drive/MyDrive/UFS-no-RTOFS/GFS/rt17_upd03_realtime/"

# File that is later handed to fix_dataset.
fName = f"{data_path}gfs.t00z.ocn.ana.nc"

# File that later supplies the xh / yh / geolon / geolat grid variables.
fModel = f"{data_path}gfs.ocean.t00z.6hr_avg.f006.nc"

In [None]:
# Only the grid description is taken from the model file; its time axis is left undecoded.
ds_model = xr.open_dataset(fModel, decode_times=False)[["xh", "yh", "geolon", "geolat"]]

# Load the other file with its 'Time' replaced and its Salt == 0 points masked.
file_date = datetime.datetime(2024, 10, 24, 0, 0, 0)
ds = fix_dataset(fName, file_date)

# 1-D axis values from the model grid, reused as the new index coordinates.
xh_values = ds_model['xh'].values
yh_values = ds_model['yh'].values

# Three steps, applied in order:
#   1. attach xh / yh as coordinates along the existing xaxis_1 / yaxis_1 dimensions,
#   2. make xh / yh the dimensions themselves, which renames them on every variable,
#   3. attach the 2-D geolon / geolat fields on the new (yh, xh) dimensions.
# NOTE(review): assumes xaxis_1 / yaxis_1 have the same lengths as the model's
# xh / yh - confirm the two files share a grid.
ds = (
    ds
    .assign_coords({
        'xh': ('xaxis_1', xh_values),
        'yh': ('yaxis_1', yh_values),
    })
    .swap_dims({'yaxis_1': 'yh', 'xaxis_1': 'xh'})
    .assign_coords({
        'geolon': (('yh', 'xh'), ds_model['geolon'].values),
        'geolat': (('yh', 'xh'), ds_model['geolat'].values),
    })
)

In [None]:
# Draw ave_ssh at the first zaxis_1 index against the 2-D geolon / geolat
# coordinates, with the colour range fixed to [-1.5, 1.5].
ds.isel(zaxis_1=0)["ave_ssh"].plot(
    x="geolon",
    y="geolat",
    vmin=-1.5,
    vmax=1.5,
    cmap='jet',
)

In [None]:
# Same field limited to xh in [-85, -60] and yh in [15, 30], drawn with a
# narrower colour range, [-0.6, 0.6], on the 'bwr' colormap.
ds.isel(zaxis_1=0).sel(
    xh=slice(-85, -60),
    yh=slice(15, 30),
)["ave_ssh"].plot(
    x="geolon",
    y="geolat",
    vmin=-0.6,
    vmax=0.6,
    cmap='bwr',
)