In [None]:
def create_cartesian_grid(latitude, longitude, dx):
    """
    Creates a cartesian grid (regular in distance, kilometers) from a geodesic latitude, longitude grid.
    The new grid is expressed in latitude, longitude coordinates.

    Rows are spaced by ``dx`` kilometers in latitude (using 1 degree = 111 km).
    Within each row the columns are evenly spaced in longitude and centered on
    the middle longitude of the input grid; the longitude step of a row is
    ``dx/111`` degrees scaled by ``cos(latitude[0]) / cos(row latitude)``.

    Parameters
    ----------
    latitude : numpy ndarray
        Vector of latitude for geodesic input grid.
    longitude : numpy ndarray
        Vector of longitude for geodesic input grid.
    dx : float
        Grid spacing in kilometers.

    Returns
    -------
    ENSLAT2D :
        2-D numpy ndarray of the latitudes of the points of the cartesian grid
    ENSLON2D :
        2-D numpy ndarray of the longitudes of the points of the cartesian grid
    n_lat : int
        Number of rows of the cartesian grid (``ENSLAT2D.shape[0]``).
    n_lon : int
        Number of columns of the cartesian grid (``ENSLAT2D.shape[1]``).
    """
    latitude = np.asarray(latitude)
    longitude = np.asarray(longitude)

    km2deg = 1/111
    step_deg = dx*km2deg

    # ENSEMBLE OF LATITUDES #
    ENSLAT = np.arange(latitude[0], latitude[-1]+step_deg, step_deg)

    # Number of columns: same parity-based rule as before, forced to a plain
    # int so it is a valid repeat count whatever floor/ceil implementation
    # is in scope.
    range_lon = longitude[-1]-longitude[0]
    if longitude.size % 2 == 0:
        nstep_lon = int(floor(range_lon/step_deg))+1
    else:
        nstep_lon = int(ceil(range_lon/step_deg))+1
    ENSLAT2D = np.repeat(ENSLAT[:, np.newaxis], nstep_lon, axis=1)

    # ENSEMBLE OF LONGITUDES #
    mid_lon = (longitude[-1]+longitude[0])/2

    # The step of each row must follow the latitude of that row of the NEW
    # grid (ENSLAT), not the i-th latitude of the input grid.
    d_lon = step_deg*(np.cos(np.pi*latitude[0]/180)/np.cos(np.pi*ENSLAT/180))

    # Symmetric, evenly spaced column offsets (in units of d_lon). This works
    # for odd and even column counts and keeps a constant step across the
    # center of the row.
    offsets = np.arange(nstep_lon)-(nstep_lon-1)/2
    ENSLON2D = mid_lon+d_lon[:, np.newaxis]*offsets[np.newaxis, :]

    return ENSLAT2D, ENSLON2D, ENSLAT2D.shape[0], ENSLAT2D.shape[1]
    

In [None]:
def gaspari_cohn(array,distance,center):
    """
    Gaspari-Cohn function. @vbellemin.

    The input is converted to the normalized distance
    ``r = 2 * |array - center| / distance`` and the piecewise fifth-order
    polynomial is evaluated: 1 at ``r = 0``, decreasing to 0 at ``r = 2``
    (i.e. ``|array - center| = distance``) and exactly 0 beyond.

    Parameters
    ----------
    array : scalar or array-like
        Values at which the function is evaluated. Arrays of any number of
        dimensions are supported; a scalar is treated as a 1-element vector.
    distance : float
        Distance from ``center`` above which the returned values are zero.
    center : float
        Value at which the function is centered (where it equals 1).

    Returns
    -------
    numpy ndarray
        Gaspari-Cohn weights, same shape as ``array`` (shape ``(1,)`` for a
        scalar input). All zeros when ``distance <= 0``.
    """
    values = np.atleast_1d(np.asarray(array))
    if distance<=0:
        return np.zeros_like(values)

    r = 2*np.abs(values-center)/distance
    gp = np.zeros_like(r)

    # Boolean masks (rather than np.where(...)[0]) select the right elements
    # whatever the number of dimensions of the input.
    inner = r<=1.
    ri = r[inner]
    gp[inner] = -0.25*ri**5+0.5*ri**4+0.625*ri**3-5./3.*ri**2+1.

    outer = (r>1.)&(r<=2.)
    ro = r[outer]
    gp[outer] = 1./12.*ro**5-0.5*ro**4+0.625*ro**3+5./3.*ro**2-5.*ro+4.-2./3./ro

    return gp

In [None]:
def apply_low_pass_time(array_ssh,wint,H):
    """
    Apply a spectral filter to a time series using a mirrored extension.

    The series is extended to three times its length by placing a flipped
    copy before and after it, multiplied by the window ``wint``, transformed
    with the FFT module available as ``fp`` at file level, multiplied by the
    transfer function ``H`` and transformed back. Only the central part,
    corresponding to the original samples, is returned.

    Parameters
    ----------
    array_ssh : object exposing ``.values``
        Time series to filter; ``array_ssh.values`` must be a numpy array.
        NOTE(review): presumably a 1-D xarray DataArray along time — confirm
        against the caller.
    wint : numpy ndarray
        Window multiplied with the extended series (must broadcast against
        three times the series length).
    H : numpy ndarray
        Transfer function multiplied with the spectrum of the windowed,
        extended series.

    Returns
    -------
    numpy ndarray
        Real part of the filtered series, same length as the input.
    """
    series = array_ssh.values
    # Use the length of the series itself instead of a module-level ``nt``.
    n_samples = series.shape[0]

    mirrored = np.flip(series)
    ssh_extended = np.concatenate((mirrored, series, mirrored))

    ssh_f_t = fp.fft(wint * ssh_extended)
    ssh_filtered = np.real(fp.ifft(H * ssh_f_t))

    # Keep only the central copy, i.e. the original samples.
    return ssh_filtered[n_samples:2*n_samples]

In [None]:
def extend(ssh,nx,ny):
    """
    Extend a 2-D field to three times its size by mirroring it.

    The result is a 3 x 3 tiling: the original field sits in the central
    tile, the tiles above and below hold the field with its rows reversed,
    and the left and right thirds are the central column of tiles with its
    columns reversed.

    Parameters
    ----------
    ssh : numpy ndarray
        Field of shape ``(ny, nx)``.
    nx : int
        Number of columns of ``ssh``.
    ny : int
        Number of rows of ``ssh``.

    Returns
    -------
    numpy ndarray
        Float array of shape ``(3*ny, 3*nx)``.
    """
    # Central tile, stored as floats like the rest of the extended field.
    centre = np.empty((ny,nx))
    centre[:,:] = ssh

    # Middle third: row-reversed copy above and below the field.
    upside_down = centre[::-1,:]
    middle_band = np.vstack((upside_down, centre, upside_down))

    # Left and right thirds: column-reversed copy of the middle third.
    side_band = middle_band[:,::-1]
    return np.hstack((side_band, middle_band, side_band))

In [None]:
def create_spatial_window(nx,ny):
    """
    Build the tapering window applied to a 3x-extended field.

    The window has shape ``(3*ny, 3*nx)``. It equals 1 on the central
    ``(ny, nx)`` tile, follows a Gaspari-Cohn profile on the four edge tiles
    and the product of the two profiles on the four corner tiles.

    Parameters
    ----------
    nx : int
        Number of columns of the original (non-extended) field.
    ny : int
        Number of rows of the original (non-extended) field.

    Returns
    -------
    numpy ndarray
        Window of shape ``(3*ny, 3*nx)``.
    """
    # Gaspari-Cohn profiles of length 2*n centered on n: the first half is
    # used on the leading tile, the second half on the trailing tile.
    taper_x = gaspari_cohn(np.arange(2*nx),nx,nx)
    taper_y = gaspari_cohn(np.arange(2*ny),ny,ny)

    # 1-D profiles along each axis: taper, plateau of ones, taper.
    profile_x = np.concatenate((taper_x[:nx], np.ones(nx), taper_x[nx:]))
    profile_y = np.concatenate((taper_y[:ny], np.ones(ny), taper_y[ny:]))

    # The outer product yields ones in the center, a single profile on the
    # edges and the product of both profiles in the corners.
    return profile_y[:,np.newaxis]*profile_x[np.newaxis,:]

In [None]:
def lowpass(_lambda,nx,ny,wavenum2D) :
    """
    Build a sharp low-pass mask in wavenumber space.

    Parameters
    ----------
    _lambda : float
        Cut-off wavelength; wavenumbers strictly below ``1/_lambda`` are kept.
    nx : int
        Number of columns of the original (non-extended) field.
    ny : int
        Number of rows of the original (non-extended) field.
    wavenum2D : numpy ndarray
        2-D array of wavenumber magnitudes, of shape ``(3*ny, 3*nx)``.

    Returns
    -------
    numpy ndarray
        Float array of shape ``(3*ny, 3*nx)`` holding 1 where
        ``wavenum2D < 1/_lambda`` and 0 elsewhere.
    """
    cutoff = 1/_lambda
    _lowpass = np.zeros((3*ny,3*nx))
    # One vectorized comparison instead of a Python loop over every point.
    _lowpass[...] = np.asarray(wavenum2D)[:3*ny,:3*nx] < cutoff
    return _lowpass

In [None]:
def extract_low_pass_space(ssh0,dx):
    """
    Low-pass filter a 2-D field in wavenumber space.

    The field is mirrored to three times its size with ``extend``, tapered
    with ``create_spatial_window``, transformed with the 2-D FFT of the
    file-level module ``fp``, multiplied by the mask returned by ``lowpass``
    for the file-level cut-off wavelength ``lambda_cut``, and transformed
    back. Only the central tile is returned.

    Parameters
    ----------
    ssh0 : numpy ndarray
        2-D field of shape ``(ny, nx)`` on a grid of regular spacing ``dx``.
    dx : float
        Grid spacing, passed as the sample spacing to ``np.fft.fftfreq``.

    Returns
    -------
    numpy ndarray
        Real part of the filtered field, of shape ``(ny, nx)``.
    """
    ny, nx = ssh0.shape[0], ssh0.shape[1]

    # Wavenumber magnitude (cycles per unit of dx) on the extended domain.
    freq_x = np.fft.fftfreq(3*nx,dx)
    freq_y = np.fft.fftfreq(3*ny,dx)
    wavenum2D = np.sqrt(freq_x[np.newaxis,:]**2 + freq_y[:,np.newaxis]**2)

    spectral_mask = lowpass(lambda_cut,nx,ny,wavenum2D)
    taper = create_spatial_window(nx,ny)

    # Mirror and taper, then filter in spectral space.
    tapered = extend(ssh0,nx,ny) * taper
    spectrum = fp.fft2(tapered)
    smoothed = np.real(fp.ifft2(spectral_mask * spectrum))

    # Central tile = original domain.
    return smoothed[ny:2*ny,nx:2*nx]


In [None]:
def apply_low_pass_space(ds_ssh_filtered_time):
    """
    Spatially low-pass filter a field and return it on its original grid.

    Steps:

    1. fill the input field with ``fill.gauss_seidel`` on its
       latitude/longitude/time grid;
    2. interpolate it onto the grid returned by ``create_cartesian_grid``
       with a hard-coded spacing of 14 km;
    3. fill the NaNs present after that interpolation with
       ``fill.gauss_seidel``;
    4. apply ``extract_low_pass_space`` to every time step in parallel;
    5. interpolate every filtered time step back onto the input
       latitude/longitude grid with cubic ``griddata``.

    Progress messages are printed after each step. The number of
    threads/processes is read from the file-level ``n_workers``.

    Parameters
    ----------
    ds_ssh_filtered_time : xarray DataArray
        Field with ``time``, ``latitude`` and ``longitude`` coordinates whose
        values are ordered ``(time, latitude, longitude)`` (they are
        transposed to ``(latitude, longitude, time)`` before filling).

    Returns
    -------
    numpy ndarray
        Filtered field, one 2-D array per time step evaluated on the
        ``(latitude, longitude)`` mesh of the input.
    """
    import warnings

    n_time = ds_ssh_filtered_time.sizes["time"]
    time_values = np.ascontiguousarray(ds_ssh_filtered_time.time.values)

    # FILLING GAPS ON THE GEODESIC GRID #

    # NOTE(review): both axes are declared with is_circle=True, including
    # latitude, and latitude is passed as the first grid axis. Confirm this
    # is intended for the domain being processed; left unchanged here.
    x_axis = Axis(ds_ssh_filtered_time.longitude.values,is_circle=True)
    y_axis = Axis(ds_ssh_filtered_time.latitude.values,is_circle=True)
    t_axis = TemporalAxis(time_values)

    grid = Grid3D(y_axis, x_axis, t_axis, ds_ssh_filtered_time.values.transpose(1,2,0))
    has_converged, filled = fill.gauss_seidel(grid,num_threads=n_workers)
    if not has_converged:
        # Do not fail: the filled field is still used, but make it visible.
        warnings.warn("gauss_seidel filling of the input grid did not converge",
                      RuntimeWarning, stacklevel=2)

    ssh_filtered_time_filled = ds_ssh_filtered_time.copy(deep=False,data=filled.transpose(2,0,1)).chunk({'time':1})

    print("ssh filled!")

    # TO CARTESIAN GRID #

    dx = 14 # in kilometers, spacing of the target grid

    ENSLAT2D, ENSLON2D, i_lat, i_lon = create_cartesian_grid(ssh_filtered_time_filled.latitude.values,
                                                            ssh_filtered_time_filled.longitude.values,
                                                            dx)

    # Point-wise interpolation: every (lat, lon) pair of the cartesian grid
    # is one element of the shared dimension 'z'.
    array_cart_ssh = ssh_filtered_time_filled.interp(latitude=('z',ENSLAT2D.flatten()),
                                                     longitude=('z',ENSLON2D.flatten()),
                                                    ).values

    print("interpolated to cartesian grid!")

    # INTERPOLATION OF NaNs #
    x_axis = Axis(np.arange(i_lon))
    y_axis = Axis(np.arange(i_lat))
    t_axis = TemporalAxis(time_values)

    grid = Grid3D(y_axis, x_axis, t_axis, array_cart_ssh.reshape((n_time,i_lat,i_lon)).transpose(1,2,0))
    has_converged, filled = fill.gauss_seidel(grid,num_threads=n_workers)
    if not has_converged:
        warnings.warn("gauss_seidel filling of the cartesian grid did not converge",
                      RuntimeWarning, stacklevel=2)

    cart_ssh_hf = xr.DataArray(data=filled.transpose(2,0,1),
                            dims=["time","y","x"],
                            coords = dict(
                                time = ssh_filtered_time_filled.time.values,
                                y=(["y"],np.arange(i_lat)),
                                x=(["x"],np.arange(i_lon))
                            )).chunk({'time':1})

    print("nan interpolated on cartesian grid!")

    # SPATIAL LOW-PASS FILTER, ONE TIME STEP PER JOB #
    cart_ssh_filtered = np.array(Parallel(n_jobs=n_workers,backend='multiprocessing')(jb_delayed(extract_low_pass_space)(cart_ssh_hf[i].values,dx) for i in range(n_time)))

    print("low pass filter applied!")

    # BACK TO THE GEODESIC GRID #
    lon2d, lat2d = np.meshgrid(ds_ssh_filtered_time.longitude.values, ds_ssh_filtered_time.latitude.values)

    # Scattered source points as (lat, lon) pairs, shared by all time steps.
    cart_points = np.array([ENSLAT2D.flatten(),ENSLON2D.flatten()]).T

    geo_filtered = np.array(Parallel(n_jobs=n_workers,backend='multiprocessing')(jb_delayed(griddata)(cart_points,
                                        cart_ssh_filtered[i].flatten(),
                                        (lat2d,lon2d),'cubic') for i in range(n_time)))

    print("interpolated to geo grid back!")

    return geo_filtered



