In [1]:
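# Improved with ChatGPT & GitHub Copilot
# Run requierements.ipynb (file name spelled as on disk); it is expected to provide the imports and settings used below (xr, np, pd, ...)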
%run /bettik/PROJECTS/pr-data-ocean/riverama/Notebooks/requierements.ipynb

In [2]:
# Load ssh_hf_01.nc; the field is stored under xarray's default unnamed key
# ('__xarray_dataarray_variable__') and is renamed in the next cell.
ssh_hf = xr.open_mfdataset(
    "/bettik/PROJECTS/pr-data-ocean/riverama/Datos/Filtrage/ssh_hf/ssh_hf_01.nc"
)

In [3]:
# Give the field a proper variable name and keep the single snapshot at 2014-01-03 15:30.
ssh_hf_1h = (
    ssh_hf
    .rename({'__xarray_dataarray_variable__': 'ssh_hf'})
    .sel(time_counter=pd.to_datetime('2014-01-03T15:30:00'))
)

In [None]:
# Low pass filter
def lowpass(_lambda,nx,ny,wavenum2D):
    """
    Name: lowpass

    Description: Build a binary low-pass mask in the wavenumber domain of the 3x-extended grid.
    Entries whose wavenumber is strictly smaller than 1/_lambda are set to 1 (passed);
    all other entries, including NaN wavenumbers, are set to 0 (blocked).

    Args:
    _lambda (float): Cutoff wavelength, in the length unit whose inverse is used by wavenum2D.
        Only wavelengths longer than this value are kept.
    nx (int): Number of columns of the original (non-extended) field.
    ny (int): Number of rows of the original (non-extended) field.
    wavenum2D (2D array): Wavenumber magnitude of each Fourier coefficient, at least
        (3*ny, 3*nx); only the leading (3*ny, 3*nx) sub-array is used.

    Returns: float64 array of shape (3*ny, 3*nx) holding 1.0 (pass) or 0.0 (block).

    Raises: IndexError if wavenum2D is smaller than (3*ny, 3*nx).
    """
    shape = (3*ny, 3*nx)
    wavenum = np.asarray(wavenum2D)[:shape[0], :shape[1]]
    if wavenum.shape != shape:
        raise IndexError(f"wavenum2D has shape {np.shape(wavenum2D)}, expected at least {shape}")
    # Vectorized form of the former element-by-element loop (same strict '<' comparison).
    return (wavenum < 1/_lambda).astype(np.float64)

In [None]:
# extend
def extend(ssh,nx,ny):
    """
    Name: extend

    Description: Mirror-pad a 2D field to three times its size along each axis.
    The original field occupies the central tile; the tiles directly above and below hold
    its up-down mirror image, and the left/right column bands are left-right mirrors of the
    finished central column band (so the corner tiles are flipped along both axes).

    Parameters:
    ssh (2D array): Field to extend, expected shape (ny, nx).
    nx (int): Number of columns of ssh.
    ny (int): Number of rows of ssh.

    Returns:
    ssh_extended (2D float64 array): Mirror-padded array of shape (3*ny, 3*nx).
    """
    out = np.empty((3*ny, 3*nx))
    middle = slice(nx, 2*nx)
    # Unary plus produces a fresh copy, exactly as before.
    flipped_ud = +ssh[::-1, :]
    # Central column band, top to bottom: mirror / original / mirror.
    for rows, tile in ((slice(0, ny), flipped_ud),
                       (slice(ny, 2*ny), +ssh),
                       (slice(2*ny, 3*ny), flipped_ud)):
        out[rows, middle] = tile
    # Outer column bands are left-right mirrors of the completed central band.
    mirrored = out[:, middle][:, ::-1]
    out[:, :nx] = mirrored
    out[:, 2*nx:] = mirrored
    return out

In [None]:
# Gaspari Cohn
def gaspari_cohn(array,distance,center):
    """
    Name: gaspari_cohn

    Description: Gaspari-Cohn compactly supported, piecewise 5th-order weighting function. @vbellemin.
    The weight is 1 at `center`, decreases smoothly with |array - center| and is exactly 0
    once |array - center| exceeds `distance`.

    Args:
    array (scalar or array-like): Positions at which to evaluate the function. Python or
        NumPy scalars (and 0-d arrays) are promoted to a 1-element array; arrays of any
        dimensionality are evaluated element-wise.
    distance (float): Support half-width; the function is zero for |array - center| > distance.
        If distance <= 0 the result is all zeros.
    center (float): Position where the function equals 1.

    Returns: float array with the shape of `array` (shape (1,) for scalar input) holding the weights.
    """
    values = np.asarray(array, dtype=float)
    if values.ndim == 0:
        # Keep the historical behaviour of returning a 1-element array for scalars.
        values = values.reshape(1)
    if distance <= 0:
        return np.zeros_like(values)
    # Normalised distance: r = 0 at `center`, r = 2 at |array - center| == distance.
    r = 2*np.abs(values - center)/distance
    gp = np.zeros_like(r)
    # Boolean masks select individual elements for any input dimensionality
    # (np.where(...)[0] only returned row indices, corrupting 2-D inputs).
    inner = r <= 1.
    ri = r[inner]
    gp[inner] = -0.25*ri**5 + 0.5*ri**4 + 0.625*ri**3 - 5./3.*ri**2 + 1.
    outer = (r > 1.) & (r <= 2.)
    ro = r[outer]
    gp[outer] = 1./12.*ro**5 - 0.5*ro**4 + 0.625*ro**3 + 5./3.*ro**2 - 5.*ro + 4. - 2./3./ro
    return gp

In [None]:
# Create spatial window
def create_spatial_window(nx,ny):
    """
    Name: create_spatial_window

    Description: Build a (3*ny, 3*nx) taper that matches the layout produced by `extend`.
    The central (ny, nx) tile is 1; outside it the weight follows Gaspari-Cohn profiles
    (one along x, one along y). Every entry is the product of an x-profile value and a
    y-profile value, so the corner tiles receive the product of both tapers.

    Parameters:
    nx (int): Number of columns of the original field.
    ny (int): Number of rows of the original field.

    Returns:
    result (2D array): Weights of shape (3*ny, 3*nx).
    """
    # 1D tapers of length 2*n whose value is 1 at index n.
    taper_x = gaspari_cohn(np.arange(2*nx), nx, nx)
    taper_y = gaspari_cohn(np.arange(2*ny), ny, ny)
    # Split each taper at its peak and insert a flat run of ones for the central tile.
    profile_x = np.concatenate((taper_x[:nx], np.ones(nx), taper_x[nx:]))
    profile_y = np.concatenate((taper_y[:ny], np.ones(ny), taper_y[ny:]))
    return np.outer(profile_y, profile_x)

In [None]:
# Extract bar tide
def extract_bar_tide(ssh0,dx,lambda_bar=400,dy=None):
    """
    Name: extract_bar_tide

    Description: Isolate the large-scale (barotropic-tide) part of one 2D SSH snapshot with an
    isotropic spectral low-pass filter. To limit FFT edge effects the field is first
    mirror-extended to 3x its size (`extend`) and its padded borders are tapered with the
    Gaspari-Cohn window from `create_spatial_window`; the filtered field is then cropped back
    to the original domain.

    Parameters:
    ssh0 (2D array): SSH snapshot of shape (ny, nx). It must not contain NaNs (a NaN would
        spread over the whole field through the FFT), so fill gaps beforehand.
    dx (float): Grid spacing along x (km according to the original comments).
    lambda_bar (float, optional): Cutoff wavelength in the unit of dx. Only longer wavelengths
        are kept. Defaults to 400, the value previously hard-coded.
    dy (float, optional): Grid spacing along y. Defaults to dx (the previous behaviour).

    Returns:
    ssh_filtered (2D float array): Low-passed SSH of shape (ny, nx).
    """
    ny, nx = ssh0.shape[0], ssh0.shape[1]
    if dy is None:
        dy = dx

    # Wavenumber magnitude (cycles per unit length) of each coefficient of the 3x-extended grid.
    kx = np.fft.fftfreq(3*nx, dx)
    ky = np.fft.fftfreq(3*ny, dy)
    k, l = np.meshgrid(kx, ky)
    wavenum2D = np.sqrt(k**2 + l**2)

    lowpass_bar = lowpass(lambda_bar, nx, ny, wavenum2D)
    window = create_spatial_window(nx, ny)

    # Mirror-extend, taper the padding, filter in Fourier space, then keep the central tile.
    ssh = extend(ssh0, nx, ny) * window
    ssh_freq_filtered = lowpass_bar * fp.fft2(ssh)
    return np.real(fp.ifft2(ssh_freq_filtered))[ny:2*ny, nx:2*nx]

In [None]:
# Create bar tide

def create_bar_tide(date):
    """
    Estimate the barotropic-tide SSH for one input file and write it to NetCDF.

    Steps, as implemented below:
      1. open `path_to_input + <date with '-' removed> + ".nc"` and load `path_to_mask`;
      2. set `ssh_hf` to NaN where the mask is True, block-average it 4x4 in
         longitude/latitude and fill its NaNs with `fill.gauss_seidel`;
      3. build an (i_lat, i_lon) index grid from `array_cart_ssh` and fill its NaNs the same way;
      4. low-pass each time step with `extract_bar_tide` in 16 worker processes;
      5. interpolate the result onto the grid of the input file, re-apply the mask,
         rename it "ssh_bar" and write it to `path_to_save + <date with '-' removed> + ".nc"`.

    Args:
    date: presumably a numpy datetime64 at day precision (e.g. np.datetime64('2014-01-03')),
        so that `date.astype('str').replace('-','')` gives a file stem like "20140103" -- TODO confirm.

    Returns: None; the only output is the NetCDF file.

    NOTE(review): several names used here are not defined in this part of the notebook:
    path_to_input, path_to_mask, path_to_save, i_lat, i_lon, dx, array_cart_ssh, Axis,
    TemporalAxis, Grid3D, fill, Parallel, jb_delayed (presumably provided by the requirements
    notebook or an earlier session). `array_cart_ssh` in particular does not depend on `date`,
    so unless it is recomputed before every call, every date is filtered from the same array.
    Confirm its origin.
    NOTE(review): the number of time steps is hard-coded to 24 (reshape and range(24)), and
    `ssh_hf.copy(data=cart_ssh_filtered)` requires (24, i_lat, i_lon) to equal the shape of the
    coarsened `ssh_hf`.
    NOTE(review): `ds` is never closed.
    """

    ds = xr.open_dataset(path_to_input+date.astype('str').replace('-','')+".nc")
    mask = np.load(path_to_mask)

    # PROCESSING #

    # Points where `mask` is True are set to NaN.
    ssh_hf = ds.ssh_hf.where(mask==False,np.nan)

    # 4x4 block mean; trailing rows/columns that do not fill a whole block are dropped.
    ssh_hf = ssh_hf.coarsen(longitude=4,latitude=4,boundary='trim').mean()
    # Load into memory, then re-chunk with one time step per chunk.
    ssh_hf = ssh_hf.load().chunk({'time':1})

    # NOTE(review): the latitude axis is flagged is_circle=True although latitude is not
    # periodic -- confirm this is intended.
    x_axis = Axis(ssh_hf.longitude.values,is_circle=True)
    y_axis = Axis(ssh_hf.latitude.values,is_circle=True)
    t_axis = TemporalAxis(ssh_hf.time.values)

    # Assumes ssh_hf dims are (time, latitude, longitude); transpose(1,2,0) reorders them to
    # (latitude, longitude, time) to match the axis order passed to Grid3D.
    grid = Grid3D(y_axis, x_axis, t_axis, ssh_hf.values.transpose(1,2,0))
    # NOTE(review): `has_converged` is never checked.
    has_converged, filled = fill.gauss_seidel(grid,num_threads=16)

    # Back to (time, lat, lon). Only the time coordinate of this array is used later on.
    ssh_hf_filled = ssh_hf.copy(deep=True,data=filled.transpose(2,0,1)).chunk({'time':1})

    # INTERPOLATION OF NaNs #
    # Plain index axes (0..i_lon-1, 0..i_lat-1) for the second fill.
    x_axis = Axis(np.arange(i_lon))
    y_axis = Axis(np.arange(i_lat))
    t_axis = TemporalAxis(ssh_hf.time.values)

    # NOTE(review): `array_cart_ssh` is not defined in this function (see docstring).
    grid = Grid3D(y_axis, x_axis, t_axis, array_cart_ssh.reshape((24,i_lat,i_lon)).transpose(1,2,0))
    has_converged, filled = fill.gauss_seidel(grid,num_threads=16)

    # NOTE(review): `mask_cart` is computed but never used.
    mask_cart = np.isnan(array_cart_ssh[0].reshape((i_lat,i_lon)))

    # Filled field as a (time, y, x) DataArray, one chunk per time step.
    cart_ssh_hf = xr.DataArray(data=filled.transpose(2,0,1),
                            dims=["time","y","x"],
                            coords = dict(
                                time = ssh_hf_filled.time.values,
                                y=(["y"],np.arange(i_lat)),
                                x=(["x"],np.arange(i_lon))
                            )).chunk({'time':1})
    
    
    # EXTRACTING BAROTROPIC TIDE #
    # One extract_bar_tide call per time step, run in 16 processes; results stacked along axis 0.
    cart_ssh_filtered = np.array(Parallel(n_jobs=16,backend='multiprocessing')(jb_delayed(extract_bar_tide)(cart_ssh_hf[i].values,dx) for i in range(24)))  

    # FINAL FILE CREATION #

    # Reuse the coordinates/attributes of the coarsened ssh_hf for the filtered data.
    ssh_filtered = ssh_hf.copy(deep=True,data=cart_ssh_filtered).chunk({'time':1}) # originally geo_filtered, replaced by cart_ssh_filtered
    
    # Interpolate onto the grid of the input file, extrapolating outside the coarsened coordinate range.
    ssh_filtered = ssh_filtered.interp_like(ds,kwargs={"fill_value": "extrapolate"})

    # Re-apply the mask on the full-resolution grid.
    ssh_filtered = ssh_filtered.where(mask==False,np.nan)

    ssh_filtered = ssh_filtered.rename("ssh_bar")

    ssh_filtered.to_netcdf(path_to_save+date.astype('str').replace('-','')+".nc")


In [None]:
# # Define your range of x values
# x = np.arange(-3., 3., 0.1)

# # Set the parameters for your gaspari_cohn function
# center = 0  # Center of the function
# distance = 6  # Distance over which the function is non-zero

# # Plot the function
# plt.figure()
# plt.plot(x, gaspari_cohn(x, distance, center), 'k')
# plt.xlabel('x')
# plt.ylabel('Gaspari-Cohn Function Value')
# plt.title('Gaspari-Cohn Function')
# plt.grid(True)
# plt.show()
