In [None]:
def bandpass(wavenumbers,nx,ny,wavenum2D):
    _bandpass = np.zeros((3*ny,3*nx))
    for i in range(3*ny):
        for j in range(3*nx):
            if wavenum2D[i,j]>wavenumbers[0] and wavenum2D[i,j]<wavenumbers[1]:
                _bandpass[i,j] = 1
    return _bandpass

In [None]:
def extend(ssh,nx,ny):
    """
    Mirror-extend a 2-D field to three times its size in each direction.

    The result is a float64 (3*ny, 3*nx) array. The original field sits in
    the central block; the blocks above and below hold it flipped
    upside-down, and the left/right column bands are left-right mirrors of
    the whole middle column band (so corners are flipped on both axes).

    Args:
        ssh : 2-D array indexed as ssh[row, col], expected shape (ny, nx).
        nx, ny : number of columns and rows of ssh.

    Returns:
        The mirror-extended float64 array.
    """
    out = np.empty((3 * ny, 3 * nx))
    centre_cols = slice(nx, 2 * nx)
    upside_down = +ssh[::-1, :]

    # Middle column band: flipped copy, original, flipped copy (top to bottom).
    out[:ny, centre_cols] = upside_down
    out[ny:2 * ny, centre_cols] = +ssh
    out[2 * ny:, centre_cols] = upside_down

    # Side bands mirror the full-height middle band left-to-right.
    mirrored = out[:, centre_cols][:, ::-1]
    out[:, :nx] = mirrored
    out[:, 2 * nx:] = mirrored
    return out

In [None]:
def create_spatial_window(nx,ny):
    """
    Build the (3*ny, 3*nx) taper applied to the mirror-extended field.

    The central (ny, nx) block is 1. The edge bands follow a 1-D
    Gaspari-Cohn profile (centred on the inner block, tapering towards the
    outer boundary), and each corner block is the product of the
    corresponding x and y profiles.

    Args:
        nx, ny : size of the original (non-extended) field.

    Returns:
        float64 array of shape (3*ny, 3*nx).
    """
    # 1-D profiles over 2n samples, equal to 1 at index n.
    profile_x = gaspari_cohn(np.arange(2 * nx), nx, nx)
    profile_y = gaspari_cohn(np.arange(2 * ny), ny, ny)

    # Separable weights per axis: rising half, flat middle, falling half.
    weights_x = np.concatenate((profile_x[:nx], np.ones(nx), profile_x[nx:]))
    weights_y = np.concatenate((profile_y[:ny], np.ones(ny), profile_y[ny:]))

    # The outer product reproduces every tile: 1 in the centre, a single
    # profile along each edge band, and profile products in the corners.
    return np.outer(weights_y, weights_x)

In [None]:
def extract_it_mode(ssh0,window,dx,bandpass_mode_1,bandpass_mode_2,bandpass_mode_3):
    """
    Band-pass filter a 2-D field with three spectral masks.

    The field is mirror-extended with ``extend``, multiplied by ``window``,
    Fourier transformed with ``fp.fft2``, multiplied by each mask, and
    transformed back with ``fp.ifft2``. The real part of the central
    (ny, nx) block is kept for each mask.
    (``fp`` is presumably ``scipy.fftpack`` — confirm against the file's imports.)

    Args:
        ssh0 : 2-D field indexed as ssh0[row, col], shape (ny, nx).
        window : taper multiplied onto the extended field, presumably
            (3*ny, 3*nx) as built by create_spatial_window.
        dx : not used by this function; kept for interface compatibility.
        bandpass_mode_1, bandpass_mode_2, bandpass_mode_3 : spectral masks
            multiplied onto the FFT of the windowed, extended field.

    Returns:
        Tuple of three (ny, nx) real arrays, one per mask, in argument order.
    """
    ny, nx = ssh0.shape[0], ssh0.shape[1]
    inner = (slice(ny, 2 * ny), slice(nx, 2 * nx))

    spectrum = fp.fft2(extend(ssh0, nx, ny) * window)

    filtered = []
    for mask in (bandpass_mode_1, bandpass_mode_2, bandpass_mode_3):
        filtered.append(np.real(fp.ifft2(mask * spectrum))[inner])
    return tuple(filtered)


In [None]:
def gaspari_cohn(array,distance,center):
    """
    Evaluate the Gaspari-Cohn compactly supported taper. @vbellemin.

    With the normalized distance r = 2*|array - center| / distance, the
    function is the 5th-order piecewise rational polynomial of Gaspari &
    Cohn. It equals 1 at r = 0, is continuous at r = 1, vanishes at r = 2,
    and is exactly 0 for r > 2 (points farther than ``distance`` from ``center``).

    Args:
        array : positions at which to evaluate (array-like of any shape).
            A plain Python int/float is wrapped into a 1-element array.
        distance : support length; if ``distance <= 0`` zeros are returned.
        center : position at which the function equals 1.

    Returns:
        numpy array with the shape of ``array`` holding the taper values.
    """
    if type(array) is float or type(array) is int:
        # Plain Python scalars are wrapped so the result is always an array.
        array = np.array([array])
    if distance<=0:
        return np.zeros_like(array)

    # Normalized distance: 0 at ``center``, 2 at ``distance`` away.
    r = np.asarray(2*np.abs(np.asarray(array) - center)/distance)
    gp = np.zeros_like(r)

    # Boolean masks (not np.where(...)[0]) so element-wise selection is
    # correct for inputs of any dimensionality, not just 1-D.
    near = r <= 1.
    rn = r[near]
    gp[near] = -0.25*rn**5+0.5*rn**4+0.625*rn**3-5./3.*rn**2+1.

    far = (r > 1.) & (r <= 2.)
    rf = r[far]
    gp[far] = 1./12.*rf**5-0.5*rf**4+0.625*rf**3+5./3.*rf**2-5.*rf+4.-2./3./rf
    return gp

In [None]:
def interpolate_back_to_original_grid(cart_ssh_filtered, ENSLAT2D, ENSLON2D, ssh_hf_original):
    """
    Linearly interpolate a Cartesian-grid field back onto the model grid.

    Args:
        cart_ssh_filtered : values on the Cartesian grid; flattened together
            with ENSLAT2D/ENSLON2D, so it must have the same number of elements.
        ENSLAT2D, ENSLON2D : latitudes/longitudes of the Cartesian grid points
            (presumably from create_cartesian_grid — confirm with the caller).
        ssh_hf_original : dataset providing 2-D ``nav_lat``/``nav_lon`` over
            (y, x), ``y``/``x`` coordinates, and an ``ssh_it`` variable with a
            ``time_counter`` dimension (used only for its per-time dims).

    Returns:
        DataArray named 'ssh_bar' on the model grid, with ``nav_lat`` and
        ``nav_lon`` attached as 2-D coordinates. Points outside the convex hull
        of the Cartesian grid are NaN (griddata's default fill_value).
    """
    # Interpolate onto the full 2-D model coordinates, the same arrays that
    # are attached as coordinates below. A meshgrid of the first row/column
    # only matches them on rectilinear grids; this form is also correct on
    # curvilinear grids.
    target_lat = ssh_hf_original.nav_lat.values
    target_lon = ssh_hf_original.nav_lon.values

    # Points are ordered (lat, lon), matching the target tuple below.
    ssh_bar_array = griddata(
        np.array([ENSLAT2D.flatten(), ENSLON2D.flatten()]).T,
        cart_ssh_filtered.flatten(),
        (target_lat, target_lon),
        method='linear'
    )

    ssh_bar = xr.DataArray(
        data=ssh_bar_array,
        dims=ssh_hf_original.ssh_it.isel(time_counter=0).dims,
        coords={
            "y": ssh_hf_original.coords["y"],
            "x": ssh_hf_original.coords["x"],
            "nav_lat": (("y", "x"), target_lat),
            "nav_lon": (("y", "x"), target_lon),
        },
        name='ssh_bar'
    )

    return ssh_bar

In [None]:
def create_cartesian_grid(latitude, longitude, dx, extra_pixels=70):
    """
    Create a grid regular in distance, expressed in latitude/longitude.

    Kilometers are converted to degrees with 1 degree = 111 km, so one
    grid step is ``dx / 111`` degrees.

    Latitude rows start ``extra_pixels`` steps below ``latitude[0]`` and
    continue (via ``np.arange``) to about ``extra_pixels`` steps beyond
    ``latitude[-1]``.

    In each row, longitudes are centred on the midpoint of the input
    longitude range. The spacing is ``step * cos(latitude[0]) / cos(row latitude)``
    degrees.

    NOTE(review): that spacing corresponds to roughly ``dx * cos(latitude[0])`` km
    on every row, which equals ``dx`` only when ``latitude[0] == 0``. Confirm
    this is intended.

    Parameters
    ----------
    latitude : numpy ndarray
        1-D vector of latitudes of the geodesic input grid; only its first
        and last values are used.
    longitude : numpy ndarray
        1-D vector of longitudes of the geodesic input grid; its first and
        last values and its size are used.
    dx : float
        Grid spacing in kilometers.
    extra_pixels : int, optional
        Number of extra grid points added beyond the input extent on each
        side in latitude. In longitude, ``2 * (extra_pixels + 1)`` points are
        added in total. Default is 70.

    Returns
    -------
    ENSLAT2D : numpy ndarray
        2-D array (ny, nx) of the latitudes of the grid points (constant along each row).
    ENSLON2D : numpy ndarray
        2-D array (ny, nx) of the longitudes of the grid points.
    ny : int
        Number of rows (latitudes).
    nx : int
        Number of columns (longitudes).
    """
    import math

    km2deg = 1 / 111

    # Extend the latitude range by extra_pixels grid points on each side
    ENSLAT = np.arange(latitude[0] - extra_pixels * dx * km2deg, latitude[-1] + (extra_pixels + 1) * dx * km2deg, dx * km2deg)
    range_lon = longitude[-1] - longitude[0]

    # Input longitude range in grid steps (floor for an even-sized input,
    # ceil for an odd-sized one), plus extra_pixels + 1 points per side.
    if longitude.size % 2 == 0:
        nstep_lon = math.floor(range_lon / (dx * km2deg)) + 2 * (extra_pixels + 1)
    else:
        nstep_lon = math.ceil(range_lon / (dx * km2deg)) + 2 * (extra_pixels + 1)

    ENSLAT2D = np.repeat(np.expand_dims(ENSLAT, axis=1), axis=1, repeats=nstep_lon)

    # ENSEMBLE OF LONGITUDES
    mid_lon = (longitude[-1] + longitude[0]) / 2
    ENSLON2D = np.zeros_like(ENSLAT2D)

    cos_lat0 = np.cos(np.pi * latitude[0] / 180)  # loop-invariant
    offsets = np.arange(1, nstep_lon // 2 + 1)     # 1, 2, ..., nstep_lon // 2

    for i, lat_i in enumerate(ENSLAT):
        d_lon = dx * km2deg * (cos_lat0 / np.cos(np.pi * lat_i / 180))
        d_lon_range = offsets * d_lon
        lon_left = np.flip(mid_lon - d_lon_range)
        lon_right = mid_lon + d_lon_range
        # The symmetric row has 2 * (nstep_lon // 2) + 1 points. When
        # nstep_lon is even, the last eastern point is dropped, so the row is
        # not exactly centred on mid_lon.
        ENSLON2D[i, :] = np.concatenate((lon_left, [mid_lon], lon_right))[:nstep_lon]

    return ENSLAT2D, ENSLON2D, ENSLAT2D.shape[0], ENSLAT2D.shape[1]


In [None]:
def interpolate_back(filtered_data, lat2d, lon2d, ENSLAT2D, ENSLON2D):
    """
    Linearly interpolate Cartesian-grid values onto (lat2d, lon2d) targets.

    Args:
        filtered_data : values on the Cartesian grid; flattened alongside
            ENSLAT2D/ENSLON2D, so it must have the same number of elements.
        lat2d, lon2d : target latitudes/longitudes (same shape as each other).
        ENSLAT2D, ENSLON2D : latitudes/longitudes of the Cartesian grid points.

    Returns:
        Array of shape (1, *lat2d.shape): the interpolated field with a
        leading singleton axis. Targets outside the convex hull of the source
        points are NaN (griddata's default fill_value).
    """
    # Source points as (lat, lon) rows, matching the (lat2d, lon2d) target order.
    source_points = np.column_stack((ENSLAT2D.ravel(), ENSLON2D.ravel()))
    interpolated = griddata(
        source_points,
        filtered_data.ravel(),
        (lat2d, lon2d),
        method='linear',
    )
    return interpolated[np.newaxis, ...]