In [None]:
# Improved with the help of ChatGPT & GitHub Copilot
# Run requierements.ipynb (shared setup notebook; presumably provides the imports used below -- TODO confirm)
%run /bettik/PROJECTS/pr-data-ocean/riverama/Notebooks/requierements.ipynb

# Run functions_IT_modes_v1.ipynb (presumably defines the internal-tide mode helpers called below -- TODO confirm)
%run /bettik/PROJECTS/pr-data-ocean/riverama/Notebooks/OSSE_borrador/IT_modes/functions_IT_modes_v1.ipynb

In [None]:
# Helpers and entry point used to process a single time step


def _mode_to_dataarray(mode_field, lat_1d, lon_1d, time_value):
    """Wrap one back-interpolated mode as a (nav_lat, nav_lon, time_counter) DataArray.

    Parameters
    ----------
    mode_field : sequence
        Value returned by ``interpolate_back``; only its first element is used
        (presumably the 2-D field on the original grid -- TODO confirm).
    lat_1d, lon_1d : array-like
        1-D coordinate vectors used for the ``nav_lat`` / ``nav_lon`` coordinates.
    time_value : scalar
        Time stamp stored as the single ``time_counter`` coordinate value.
    """
    return xr.DataArray(
        data=mode_field[0][:, :, None],  # Add a length-1 trailing axis for time_counter
        dims=['nav_lat', 'nav_lon', 'time_counter'],
        coords={'time_counter': [time_value], 'nav_lat': lat_1d, 'nav_lon': lon_1d},
    )


def _append_to_monthly_file(it_modes, output_path):
    """Merge ``it_modes`` into the NetCDF file at ``output_path`` (created if missing).

    The read-merge-write cycle is protected by an exclusive advisory lock on
    ``output_path + ".lock"`` so that several worker processes on the same
    machine can append to the same monthly file without losing updates.
    The existing file is loaded into memory and closed before it is replaced,
    and the new content is first written to a temporary file that is then
    renamed over the target, so the target is never open while being rewritten.

    Records already present for the same ``time_counter`` value are replaced
    rather than duplicated, and the result is sorted along ``time_counter``.
    """
    try:
        import fcntl  # POSIX only
    except ImportError:  # no advisory locking available on this platform
        fcntl = None

    lock_path = output_path + ".lock"
    tmp_path = f"{output_path}.{os.getpid()}.tmp"

    with open(lock_path, "w") as lock_file:
        if fcntl is not None:
            # Blocks until no other process holds the lock; released when lock_file closes.
            fcntl.flock(lock_file, fcntl.LOCK_EX)

        combined_ds = it_modes
        if os.path.exists(output_path):
            # Load eagerly and close the handle: the file is about to be replaced.
            with xr.open_dataset(output_path) as existing_ds:
                existing_ds.load()

            # Drop stale records for the time stamps being written (idempotent re-runs).
            is_stale = np.isin(existing_ds['time_counter'].values, it_modes['time_counter'].values)
            existing_ds = existing_ds.isel(time_counter=np.flatnonzero(~is_stale))

            if existing_ds.sizes['time_counter'] > 0:
                # Workers finish in arbitrary order, so restore chronological order.
                combined_ds = xr.concat([existing_ds, it_modes], dim='time_counter').sortby('time_counter')

        try:
            combined_ds.to_netcdf(tmp_path)
            os.replace(tmp_path, output_path)
        finally:
            # Only still present if writing or renaming failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


def process_time_step(time_step, ds, output_dir):
    """Split the internal-tide SSH of one time step into three wavenumber bands and save them.

    Steps performed:

    1. Select ``time_step`` from ``ds`` and interpolate ``ssh_it`` linearly onto
       the grid returned by ``create_cartesian_grid`` (spacing ``dx = 2``,
       in kilometres according to the original comment).
    2. Fill the NaNs left by the interpolation with ``fill.gauss_seidel``.
    3. Build three band-pass filters whose edges are the midpoints between the
       wavenumbers ``k1..k4`` (the first band starts at ``0.5 * k1``) and apply
       them through ``extract_it_mode``.
    4. Interpolate each filtered field back with ``interpolate_back`` and store
       the three results as ``ssh_it_mode_1..3`` in
       ``<output_dir>/it_modes_<YYYY-MM>.nc``, merging with any existing file.

    Parameters
    ----------
    time_step : datetime-like
        Time stamp to process; must match a ``time_counter`` label of ``ds``.
    ds : xarray.Dataset
        Dataset providing ``ssh_it``, ``nav_lat`` and ``nav_lon``
        (``nav_lat`` / ``nav_lon`` are indexed as 2-D arrays).
    output_dir : str
        Directory receiving the monthly NetCDF files.

    Returns
    -------
    None
        The result is written to disk and a progress message is printed.
    """
    import warnings

    ssh_it = ds.sel(time_counter=time_step)  # Select one time step

    # 1-D coordinate vectors, taken from the first column / first row of the 2-D fields
    lat_1d = ssh_it.nav_lat[:, 0].values
    lon_1d = ssh_it.nav_lon[0, :].values

    # To cartesian grid
    dx = 2  # in kilometers
    ENSLAT2D, ENSLON2D, i_lat, i_lon = create_cartesian_grid(lat_1d, lon_1d, dx)

    points = np.column_stack((ssh_it.nav_lat.data.flatten(), ssh_it.nav_lon.data.flatten()))
    values = ssh_it.ssh_it.data.flatten()
    target_grid = np.column_stack((ENSLAT2D.flatten(), ENSLON2D.flatten()))
    ssh_it_cart = griddata(points, values, target_grid, method='linear')
    ssh_it_cart = ssh_it_cart.reshape(ENSLAT2D.shape)

    # Interpolation of NaNs (index-based axes: only the grid topology matters here)
    x_axis = Axis(np.arange(i_lon))
    y_axis = Axis(np.arange(i_lat))
    grid = Grid2D(y_axis, x_axis, ssh_it_cart.reshape((i_lat, i_lon)))
    has_converged, filled = fill.gauss_seidel(grid)
    if not has_converged:
        warnings.warn(f"Gauss-Seidel NaN filling did not converge for time step: {time_step}")
    ssh_it_cart_filled = np.asarray(filled)

    # Define wavenumbers
    # NOTE(review): presumably the characteristic wavenumbers of modes 1-4 -- confirm source.
    k1 = 0.00252
    k2 = 0.00964
    k3 = 0.02100
    k4 = 0.03100
    nx = i_lon
    ny = i_lat

    # fftfreq with spacing dx (km) yields cycles per km; the 3x length presumably
    # matches padding applied inside bandpass / extract_it_mode -- TODO confirm.
    kx = np.fft.fftfreq(3*nx, dx)
    ky = np.fft.fftfreq(3*ny, dx)
    k, l = np.meshgrid(kx, ky)
    wavenum2D = np.sqrt(k**2 + l**2)

    # Band edges: midpoints between consecutive wavenumbers
    band_edges = [
        [0.5*k1, 0.5*(k1+k2)],
        [0.5*(k1+k2), 0.5*(k2+k3)],
        [0.5*(k2+k3), 0.5*(k3+k4)],
    ]

    # Build the bandpass filters
    bandpass_mode_1, bandpass_mode_2, bandpass_mode_3 = [
        bandpass(wavenumbers=edges, nx=nx, ny=ny, wavenum2D=wavenum2D) for edges in band_edges
    ]

    # Create the spatial window
    window = create_spatial_window(nx=nx, ny=ny)

    # Extract modes
    results = extract_it_mode(
        ssh_it_cart_filled, window, dx, bandpass_mode_1, bandpass_mode_2, bandpass_mode_3
    )
    ssh_it_filtered_1, ssh_it_filtered_2, ssh_it_filtered_3 = results

    lon2d, lat2d = np.meshgrid(lon_1d, lat_1d)
    time_counter = ssh_it.time_counter.values

    # Interpolate each mode back and gather them in a single xarray.Dataset
    it_modes = xr.Dataset(
        {
            f'ssh_it_mode_{mode}': _mode_to_dataarray(
                interpolate_back(filtered, lat2d, lon2d, ENSLAT2D, ENSLON2D),
                lat_1d, lon_1d, time_counter,
            )
            for mode, filtered in enumerate(
                (ssh_it_filtered_1, ssh_it_filtered_2, ssh_it_filtered_3), start=1
            )
        }
    )

    # Save datasets by month
    year_month = pd.to_datetime(time_step).strftime('%Y-%m')
    output_path = os.path.join(output_dir, f"it_modes_{year_month}.nc")
    _append_to_monthly_file(it_modes, output_path)

    print(f"Processed and saved data for time step: {time_step}")

# Open every ssh_it file lazily as one dataset concatenated along time_counter
ds = xr.open_mfdataset(
    "/bettik/PROJECTS/pr-data-ocean/riverama/Datos/Filtrage/ssh_it/ssh_it_*.nc",
    combine='nested',
    concat_dim='time_counter',
    parallel=True,
)

# Processing window (both ends included) and destination of the monthly files
start = '2014-05-01T00:30:00'
end = '2014-07-31T23:30:00'
output_dir = "/bettik/PROJECTS/pr-data-ocean/riverama/Datos/Filtrage/ssh_it_modes"

# Make sure the destination directory exists
os.makedirs(output_dir, exist_ok=True)

# Hourly time stamps covering the processing window
time_steps = pd.date_range(start=start, end=end, freq='1H')

# One task per time step, spread over all available cores
run_step = delayed(process_time_step)
Parallel(n_jobs=-1, backend='multiprocessing')(
    run_step(time_step, ds, output_dir) for time_step in time_steps
)

print("Processing complete.")
