In [None]:
# Improved with the help of ChatGPT & GitHub Copilot
# Run the shared setup notebook in this kernel; presumably it provides the imports
# used below (np, xr, pd, plt, pyproj/Proj, pyinterp/fill, griddata, fp) — TODO confirm.
# NOTE: the path spells the notebook "requierements.ipynb"; keep it as-is.
%run /bettik/PROJECTS/pr-data-ocean/riverama/Notebooks/requierements.ipynb

## Filter

In [None]:
def transform_to_regular_grid(ssh_hf_1h_masked, dx=None, dy=None):
    """Resample a lat/lon field onto a regular UTM (zone 58S, WGS84) grid in km.

    Parameters
    ----------
    ssh_hf_1h_masked : xarray.DataArray
        Field carrying ``lat`` and ``lon`` coordinates (degrees); masked points
        are NaN.
    dx, dy : float, optional
        Target grid spacing in km. When omitted, the notebook-level globals
        ``dx`` and ``dy`` are used, so existing one-argument calls keep working.

    Returns
    -------
    xarray.DataArray
        Field on the regular grid with dims ``('y', 'x')`` and 1-D ``y``/``x``
        coordinates in km covering the projected bounding box (upper bounds
        excluded, as with ``np.arange``). Masked (NaN) areas stay NaN.

    Raises
    ------
    NameError
        If ``dx``/``dy`` are neither passed nor defined as notebook globals.
    """
    # Fall back to the notebook-level spacing for backward compatibility.
    namespace = globals()
    if dx is None:
        if 'dx' not in namespace:
            raise NameError("grid spacing 'dx' is not defined: pass dx=... or set the global dx")
        dx = namespace['dx']
    if dy is None:
        if 'dy' not in namespace:
            raise NameError("grid spacing 'dy' is not defined: pass dy=... or set the global dy")
        dy = namespace['dy']

    # UTM Zone 58S projection
    proj = Proj(proj='utm', zone=58, south=True, ellps='WGS84')

    lat = ssh_hf_1h_masked.lat.values
    lon = ssh_hf_1h_masked.lon.values

    # pyproj.Proj accepts numpy arrays directly; np.vectorize called the
    # projection once per point in a Python loop.
    x, y = proj(lon, lat)

    # Metres -> kilometres
    x_km = np.asarray(x) / 1000
    y_km = np.asarray(y) / 1000

    x_min, x_max = x_km.min(), x_km.max()
    y_min, y_max = y_km.min(), y_km.max()

    # Regular target grid spanning the projected bounding box
    x_new = np.arange(x_min, x_max, dx)
    y_new = np.arange(y_min, y_max, dy)
    x_grid, y_grid = np.meshgrid(x_new, y_new)

    # Nearest-neighbour resampling copies NaN (masked) source values as NaN,
    # so masked regions remain NaN here and are filled later (fill_nan).
    data_interp = griddata(
        (x_km.ravel(), y_km.ravel()),
        ssh_hf_1h_masked.values.ravel(),
        (x_grid, y_grid),
        method='nearest',
    )

    return xr.DataArray(data_interp, coords=[('y', y_new), ('x', x_new)], dims=['y', 'x'])


In [None]:
def fill_nan(ssh_hf_1h_masked_regular, num_threads=16):
    """Fill NaN gaps of a regular-grid field with pyinterp's Gauss-Seidel solver.

    Parameters
    ----------
    ssh_hf_1h_masked_regular : xarray.DataArray
        Field with dims ``('y', 'x')`` and 1-D ``x``/``y`` coordinates, e.g. as
        returned by ``transform_to_regular_grid`` (projected km, not angles).
    num_threads : int, optional
        Threads passed to ``fill.gauss_seidel`` (default 16, as before).

    Returns
    -------
    tuple
        ``(has_converged, filled)`` exactly as returned by ``fill.gauss_seidel``.
    """
    # The coordinates are projected distances in km, not longitudes. They must
    # not be declared circular: pyinterp treats a circular axis as a
    # 360-degree periodic (longitude) axis, which is meaningless here.
    y_axis = pyinterp.core.Axis(ssh_hf_1h_masked_regular.y.values, is_circle=False)
    x_axis = pyinterp.core.Axis(ssh_hf_1h_masked_regular.x.values, is_circle=False)

    # Grid2D's first axis must match the array's first dimension; the values
    # are laid out (y, x), so the y axis goes first.
    grid = pyinterp.Grid2D(y_axis, x_axis, ssh_hf_1h_masked_regular.values)

    has_converged, ssh_hf_1h_masked_regular_filled = fill.gauss_seidel(grid, num_threads=num_threads)

    return has_converged, ssh_hf_1h_masked_regular_filled

In [None]:
def extend(ssh, nx, ny):
    """Mirror-pad a 2-D field to three times its size along each axis.

    ``ssh`` (expected shape ``(ny, nx)``) is written into the centre of a new
    ``(3*ny, 3*nx)`` float array. The tiles directly above and below hold the
    field flipped top-to-bottom. The left and right thirds hold the whole
    central column band flipped left-to-right, so corners are flipped both ways.
    """
    extended = np.empty((3 * ny, 3 * nx))
    flipped_ud = ssh[::-1, :]
    middle = slice(nx, 2 * nx)

    # Central column band: flipped / original / flipped, stacked vertically.
    for rows, tile in (
        (slice(0, ny), flipped_ud),
        (slice(ny, 2 * ny), ssh),
        (slice(2 * ny, 3 * ny), flipped_ud),
    ):
        extended[rows, middle] = tile

    # Outer column bands: the central band mirrored left-to-right.
    mirrored_band = extended[:, middle][:, ::-1]
    extended[:, :nx] = mirrored_band
    extended[:, 2 * nx:] = mirrored_band
    return extended

In [None]:
def lowpass(_lambda, nx, ny, wavenum2D):
    """Build a sharp (ideal) low-pass mask in wavenumber space.

    Parameters
    ----------
    _lambda : float
        Cut-off wavelength, in the length unit of ``1 / wavenum2D``.
    nx, ny : int
        Size of the un-extended grid; the mask has shape ``(3*ny, 3*nx)``.
    wavenum2D : array_like
        Radial wavenumber magnitudes, at least ``(3*ny, 3*nx)``; only that
        leading window is read.

    Returns
    -------
    numpy.ndarray
        Float64 array equal to 1 where ``wavenum2D < 1/_lambda`` and 0
        elsewhere (NaN wavenumbers give 0).
    """
    cutoff = 1 / _lambda  # loop-invariant; raises ZeroDivisionError for 0 as before
    wavenum = np.asarray(wavenum2D)[:3 * ny, :3 * nx]
    _lowpass = np.zeros((3 * ny, 3 * nx))
    # Vectorised comparison replaces the former per-cell Python double loop,
    # which ran millions of iterations on the extended grid.
    _lowpass[wavenum < cutoff] = 1
    return _lowpass

In [None]:
def gaspari_cohn(array, distance, center):
    """Evaluate the Gaspari-Cohn 5th-order piecewise-rational taper.

    With ``r = 2*|array - center| / distance``, the weight is 1 at ``center``
    (r = 0), decreases to ~0 at ``|array - center| = distance`` (r = 2), and is
    exactly 0 beyond.

    Parameters
    ----------
    array : scalar or array_like
        Positions at which to evaluate the taper. Scalars (Python or NumPy) are
        promoted to a 1-element array; arrays of any shape are supported.
    distance : float
        Support half-width. A non-positive value returns all zeros.
    center : float
        Position of the taper maximum.

    Returns
    -------
    numpy.ndarray
        Weights with the (at least 1-D) shape of ``array``.
    """
    values = np.atleast_1d(np.asarray(array))
    if distance <= 0:
        return np.zeros_like(values)

    r = 2 * np.abs(values - center) / distance
    gp = np.zeros_like(r)

    # Boolean masks select the right elements for inputs of any dimension;
    # np.where(...)[0] only returned first-axis indices (wrong for N-D input).
    inner = r <= 1.
    ri = r[inner]
    gp[inner] = -0.25*ri**5 + 0.5*ri**4 + 0.625*ri**3 - 5./3.*ri**2 + 1.

    outer = (r > 1.) & (r <= 2.)
    ro = r[outer]
    gp[outer] = 1./12.*ro**5 - 0.5*ro**4 + 0.625*ro**3 + 5./3.*ro**2 - 5.*ro + 4. - 2./3./ro
    return gp

In [None]:
def create_spatial_window(nx, ny):
    """Build a ``(3*ny, 3*nx)`` taper for a mirror-extended field.

    The central ``(ny, nx)`` tile is 1. Each edge strip holds one half of a
    Gaspari-Cohn taper (support ``nx`` along x, ``ny`` along y), constant along
    the strip. Each corner is the product of its neighbouring x and y halves.
    """
    # Full tapers over 2*n samples centred on n, as a row (x) and a column (y).
    taper_x = gaspari_cohn(np.arange(2 * nx), nx, nx)[np.newaxis, :]
    taper_y = gaspari_cohn(np.arange(2 * ny), ny, ny)[:, np.newaxis]

    # Rising halves (~0 -> ~1) feed the left/top borders, falling halves the right/bottom.
    left, right = taper_x[:, :nx], taper_x[:, nx:]
    top, bottom = taper_y[:ny, :], taper_y[ny:, :]

    window = np.ones((3 * ny, 3 * nx))
    centre_rows = slice(ny, 2 * ny)
    centre_cols = slice(nx, 2 * nx)

    # Edge strips: broadcasting stretches each half-taper across its strip.
    window[:ny, centre_cols] = top
    window[2 * ny:, centre_cols] = bottom
    window[centre_rows, :nx] = left
    window[centre_rows, 2 * nx:] = right

    # Corners: outer product of the neighbouring half-tapers.
    window[:ny, :nx] = top * left
    window[:ny, 2 * nx:] = top * right
    window[2 * ny:, :nx] = bottom * left
    window[2 * ny:, 2 * nx:] = bottom * right
    return window

In [None]:
def bar_igw_filter(date_str, lambda_bar=100):
    """Split one hourly high-frequency SSH snapshot into low-pass and residual parts.

    Pipeline:
    1. Load the snapshot for ``date_str`` from the monthly ``ssh_hf_MM.nc`` file.
    2. Apply the ``tmaskutil`` mask.
    3. Regrid onto the regular UTM km grid (``transform_to_regular_grid``; spacing
       from the globals ``dx``/``dy``).
    4. Fill NaNs (``fill_nan``).
    5. Mirror-extend (``extend``) and taper (``create_spatial_window``) the field.
    6. Keep wavenumbers below ``1 / lambda_bar`` (``lowpass``).

    Parameters
    ----------
    date_str : str or datetime-like
        Timestamp selected along ``time_counter``; its month picks the input file.
    lambda_bar : float, optional
        Cut-off wavelength in km (default 100).

    Returns
    -------
    ssh_igw : array
        Residual ``filled - ssh_bar`` on the regular grid.
    ssh_bar : numpy.ndarray
        Low-pass component on the same grid.

    Warns
    -----
    RuntimeWarning
        If the Gauss-Seidel NaN filling reports non-convergence.
    """
    import warnings

    # Parse the date string to get the month
    date = pd.to_datetime(date_str)
    month = date.month

    # Load the snapshot, closing the file handle once the slice is in memory
    file_path = f"/bettik/PROJECTS/pr-data-ocean/riverama/Datos/Filtrage/ssh_hf/ssh_hf_{month:02}.nc"
    with xr.open_mfdataset(file_path) as ssh_hf:
        ssh_hf_1h = ssh_hf.rename({'__xarray_dataarray_variable__': 'ssh_hf'}).sel(time_counter=date).load()

    # Keep values where tmaskutil == 1 (NaN elsewhere); .load() runs while the mask file is open
    with xr.open_dataset('/bettik/PROJECTS/pr-data-ocean/riverama/Datos/CALEDO60/1_mesh_mask_TROPICO12_L125_tr21.nc', drop_variables={"x", "y"}) as mask:
        ssh_hf_1h_masked = ssh_hf_1h['ssh_hf'].where(mask.tmaskutil[0,:,:] == 1, np.nan).load()

    # Regrid and fill missing data; surface a non-converged fill instead of discarding the flag
    ssh_hf_1h_masked_regular = transform_to_regular_grid(ssh_hf_1h_masked)
    has_converged, ssh_hf_1h_masked_regular_filled = fill_nan(ssh_hf_1h_masked_regular)
    if not has_converged:
        warnings.warn(
            f"Gauss-Seidel NaN filling did not converge for {date_str}; "
            "filtered fields may be inaccurate.",
            RuntimeWarning,
        )

    # Mirror-extend to (3*ny, 3*nx) so the field is continuous across the tile borders
    nx = int(ssh_hf_1h_masked_regular_filled.shape[1])
    ny = int(ssh_hf_1h_masked_regular_filled.shape[0])
    ssh_hf_1h_masked_regular_filled_extended = extend(ssh_hf_1h_masked_regular_filled, nx, ny)

    # Wavenumbers of the extended grid (cycles per km, since dx/dy are in km)
    kx = np.fft.fftfreq(3*nx, dx)
    ky = np.fft.fftfreq(3*ny, dy)
    k, l = np.meshgrid(kx, ky)
    wavenum2D = np.sqrt(k**2 + l**2)

    # Taper the mirrored borders; the central tile is multiplied by 1
    window = create_spatial_window(nx, ny)
    ssh_hf_1h_masked_regular_filled_extended_windowed = ssh_hf_1h_masked_regular_filled_extended * window

    # Sharp low-pass in the frequency domain, then crop back to the central tile
    lowpass_bar = lowpass(lambda_bar, nx, ny, wavenum2D)
    ssh_hf_freq = fp.fft2(ssh_hf_1h_masked_regular_filled_extended_windowed)
    ssh_freq_filtered = lowpass_bar * ssh_hf_freq
    ssh_bar = np.real(fp.ifft2(ssh_freq_filtered))[ny:2*ny, nx:2*nx]

    # Residual, labelled the internal gravity wave (IGW) part in this notebook
    ssh_igw = ssh_hf_1h_masked_regular_filled - ssh_bar

    return ssh_igw, ssh_bar

## Regridding

In [None]:
# Load the month-01 high-frequency SSH file and select one hourly snapshot for testing
ssh_hf = xr.open_mfdataset("/bettik/PROJECTS/pr-data-ocean/riverama/Datos/Filtrage/ssh_hf/ssh_hf_01.nc")
ssh_hf_1h = ssh_hf.rename({'__xarray_dataarray_variable__': 'ssh_hf'}).sel(time_counter=pd.to_datetime('2014-01-10T00:30:00')).load()

# Apply the mask to ssh_hf_1h: keep values where tmaskutil == 1, set the rest to NaN
mask = xr.open_dataset('/bettik/PROJECTS/pr-data-ocean/riverama/Datos/CALEDO60/1_mesh_mask_TROPICO12_L125_tr21.nc',drop_variables={"x","y"}) 
ssh_hf_1h_masked = ssh_hf_1h['ssh_hf'].where(mask.tmaskutil[0,:,:] == 1, np.nan).load()

In [None]:
# UTM zone 58 south on the WGS84 ellipsoid (same projection as transform_to_regular_grid)
utm_proj = pyproj.Proj(proj='utm', zone=58, south=True, ellps='WGS84')

# Project the 2-D lon/lat arrays in a single call; results are in metres
utm_x, utm_y = utm_proj(
    ssh_hf_1h_masked['lon'].values,
    ssh_hf_1h_masked['lat'].values,
)

# Attach the projected positions as 2-D (y, x) coordinates
ssh_hf_1h_masked = ssh_hf_1h_masked.assign_coords(
    utm_x=(("y", "x"), utm_x),
    utm_y=(("y", "x"), utm_y),
)

In [None]:
# Quick-look of the masked snapshot in projected UTM coordinates (metres)
fig, ax = plt.subplots()
mesh = ax.pcolormesh(ssh_hf_1h_masked['utm_x'], ssh_hf_1h_masked['utm_y'], ssh_hf_1h_masked)
fig.colorbar(mesh, ax=ax, label='ssh_hf')
ax.set(title='Masked high-frequency SSH (UTM zone 58S)', xlabel='UTM x [m]', ylabel='UTM y [m]')
plt.show()

In [None]:
# Regular-grid spacing in km. transform_to_regular_grid and bar_igw_filter read
# these as globals, so this cell must run before either function is called.
dx = dy = 1.7

In [None]:
# Regrid the test snapshot onto the regular UTM grid (spacing dx, dy in km)
ssh_hf_1h_masked_regular = transform_to_regular_grid(ssh_hf_1h_masked)

In [None]:
# Run the full filter for the same timestamp as the snapshot loaded above
date_to_process = '2014-01-10T00:30:00'
lambda_value = 400  # km
ssh_igw, ssh_bar = bar_igw_filter(date_to_process, lambda_bar=lambda_value)

In [None]:
# Size of the filtered field on the regular grid (rows follow y, columns follow x)
rows, columns = ssh_igw.shape
print("Number of rows:", rows)
print("Number of columns:", columns)

In [None]:
# Label the filtered field with the regular-grid coordinates (UTM km).
# They come from the grid built by transform_to_regular_grid for the same
# snapshot rather than hard-coded linspace endpoints and sizes, so they stay
# correct if dx/dy or the domain change.
y_coords = ssh_hf_1h_masked_regular['y'].values
x_coords = ssh_hf_1h_masked_regular['x'].values
assert ssh_igw.shape == (y_coords.size, x_coords.size), "ssh_igw does not match the regular grid"

# Create the DataArray
ssh_igw_da = xr.DataArray(data=ssh_igw, coords=[("y", y_coords), ("x", x_coords)])

In [None]:
# High-pass (IGW) component on the regular UTM grid (km)
fig, ax = plt.subplots()
mesh = ax.pcolormesh(ssh_igw_da['x'], ssh_igw_da['y'], ssh_igw_da)
fig.colorbar(mesh, ax=ax, label='ssh_igw')
ax.set(title='SSH IGW component (regular UTM grid)', xlabel='x [km]', ylabel='y [km]')
plt.show()

In [None]:
# Inspect the labelled IGW field (dims, coordinates)
ssh_igw_da

In [None]:
# Projections for converting UTM zone 58S coordinates back to geographic lon/lat
proj_utm = Proj(proj='utm', zone=58, ellps='WGS84', south=True)
proj_geo = Proj(proj='latlong', datum='WGS84')

# Inverse UTM projection (metres -> degrees). This replaces pyproj.transform,
# which is deprecated and, in this notebook, only imported in a later cell.
lon, lat = proj_utm(utm_x, utm_y, inverse=True)

In [None]:
# Interpolate the regular-grid IGW field back onto the model's curvilinear grid.
# ssh_igw_da's x/y axes are UTM km, so the targets must be in the same units:
# the model points' UTM coordinates (metres, attached as utm_x/utm_y) / 1000.
# lon/lat in degrees lie outside the km axes and would yield only NaN.
interpolated_da = ssh_igw_da.interp(
    x=ssh_hf_1h_masked['utm_x'] / 1000,
    y=ssh_hf_1h_masked['utm_y'] / 1000,
    method='linear',
)


In [None]:
# Inspect the field interpolated back onto the model grid
interpolated_da

In [None]:
# IGW component on the model grid, plotted in geographic coordinates
fig, ax = plt.subplots()
mesh = ax.pcolormesh(interpolated_da['lon'], interpolated_da['lat'], interpolated_da)
fig.colorbar(mesh, ax=ax, label='ssh_igw')
ax.set(title='SSH IGW component on the model grid', xlabel='Longitude [deg]', ylabel='Latitude [deg]')
plt.show()

In [None]:
# NOTE(review): this cell repeats the earlier UTM -> lon/lat conversion cell.
# `transform` stays imported only in case other cells still reference it.
from pyproj import Proj, transform

# Define the projection for UTM (zone 58, southern hemisphere)
proj_utm = Proj(proj='utm', zone=58, ellps='WGS84', south=True)

# Define the projection for geographic coordinates
proj_geo = Proj(proj='latlong', datum='WGS84')

# Inverse UTM projection (metres -> degrees) instead of the deprecated pyproj.transform
lon, lat = proj_utm(utm_x, utm_y, inverse=True)