# Filtering internal tide Mode 1 

This notebook filters the mode-1 internal tide signal **ssh_it** with a narrower frequency band than the one used for the initial filtering of the Internal Gravity Waves (IGW) signal **ssh_igw**. A band-pass filter centred on the semidiurnal tidal period (~12 h) is applied, keeping periods between 11 h and 13 h. The filtered field is stored as **ssh_it_12h**.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import scipy.fftpack as fp
from scipy.signal import find_peaks
#from scipy.interpolate import RegularGridInterpolator, griddata
from joblib import Parallel
from joblib import delayed as jb_delayed
from pyinterp import fill, Axis, TemporalAxis, Grid3D, Grid2D
from math import *
import glob
# import xrft

import sys
sys.path.append("/bettik/bellemva/src/")
from functions import open_mfdataset_w

In [None]:
# Number of parallel joblib workers used for the per-pixel filtering.
n_workers: int = 20

## 1. - Data import 

In [None]:
# Collect the evaluation files in lexicographic order
# (presumably chronological, since the file names are date-stamped).
list_files = sorted(
    glob.glob("/bettik/bellemva/ocean_data_challenge/2023e_SSHmapping_HF_Hawaii/dc_ref_eval_coarse/*.nc")
)
# Open them as one dataset, skipping the variables that are not needed here.
ds = open_mfdataset_w(list_files, drop_variables=["ssh", "ssh_bm"])

In [None]:
# Restrict the dataset to the study zone (longitudes in degrees east, 0-360).
lon_min, lon_max = 185, 205
lat_min, lat_max = 15, 35

ds = ds.sel(
    longitude=slice(lon_min, lon_max),
    latitude=slice(lat_min, lat_max),
    drop=True,
)

In [None]:
# Bring the internal-tide SSH into memory; the per-pixel loop below indexes it repeatedly.
ds_it = ds["ssh_it"].load()

## 2. - Internal tide extraction 

In [None]:
# Time stamps of the record and their count (samples assumed hourly, cf. dt = 3600 s).
array_time = ds["ssh_it"]["time"].values
nt = len(array_time)

In [None]:
# PARAMETERS #
# Taper applied to the mirror-extended series (length 3*nt):
#   - a Gaspari-Cohn ramp rising from ~0 toward 1 over the first nt samples,
#   - ones over the central nt samples (the original series),
#   - a ramp falling from 1 toward ~0 over the last nt samples.
gaspari = gaspari_cohn(np.arange(0, 2 * nt, 1), nt, nt)
wint = np.concatenate((gaspari[:nt], np.ones(nt), gaspari[nt:]))

dt = 3600  # sampling interval, in seconds

# FFT frequencies of the extended series, in s^-1.
w = fp.fftfreq(3 * nt, dt)
nw = len(w)

# Band-pass mask: keep |frequency| strictly between 1/13 h^-1 and 1/11 h^-1,
# i.e. periods between 11 h and 13 h (3600 converts hours to seconds).
w1 = 1 / 13 / 3600
w2 = 1 / 11 / 3600
H = np.logical_and(np.abs(w) > w1, np.abs(w) < w2)
w_filtered = H * w

In [None]:
# (latitude, longitude) indices of ocean pixels: those not NaN at the first time step.
idx_ocean = np.nonzero(~np.isnan(ds_it[0].values))

In [None]:
# PROCESSING # 
ssh_it_flat = np.array(Parallel(n_jobs=n_workers,backend='multiprocessing')(jb_delayed(extract_it)(ds_it[:,i,j],wint,H) for i,j in zip(idx_ocean[0],idx_ocean[1])))

In [None]:
# Dimension sizes of the cropped dataset, used to rebuild the full cube below.
# `Dataset.sizes` is used rather than `Dataset.dims[...]`: mapping-style access
# to `dims` is deprecated in recent xarray releases.
n_time = ds.sizes["time"]
n_latitude = ds.sizes["latitude"]
n_longitude = ds.sizes["longitude"]

In [None]:
# Release the cropped dataset to free memory; it is reopened further down
# to provide the structure for the output.
del ds

In [None]:
# Release the in-memory internal-tide array; the filtered series are already in ssh_it_flat.
del ds_it

In [None]:
# ARRAY TO STORE THE RESULTS #
# Start from an all-NaN (time, latitude, longitude) cube so that land pixels stay NaN,
# then scatter each filtered ocean series back to its pixel.
array_ssh_it = np.full((n_time, n_latitude, n_longitude), np.nan, dtype="float64")
array_ssh_it[:, idx_ocean[0], idx_ocean[1]] = ssh_it_flat.T

# Optional dump of the raw cube:
# np.save(file="/bettik/bellemva/MITgcm/MITgcm_it/hawaii_long/ssh_it.npy",arr=array_ssh_it)

In [None]:
# The filtered series now live in array_ssh_it; drop the flat copy to free memory.
del ssh_it_flat

Saving the filtered internal tide **ssh_it_12h** (added to the dataset alongside the original variables) into daily NetCDF files.

In [None]:
# RELOADING DS FOR THE STRUCTURE #
# Reopen the same files, this time keeping every variable, so that the filtered
# field can be attached to a dataset with the original coordinates and variables.
list_files = sorted(
    glob.glob("/bettik/bellemva/ocean_data_challenge/2023e_SSHmapping_HF_Hawaii/dc_ref_eval_coarse/*.nc")
)
ds = open_mfdataset_w(list_files)

# selecting the zone #
lon_min, lon_max = 185, 205
lat_min, lat_max = 15, 35

ds = ds.sel(
    longitude=slice(lon_min, lon_max),
    latitude=slice(lat_min, lat_max),
    drop=True,
)

In [None]:
# Wrap the filtered cube in a DataArray that reuses ssh_it's dims, coords and attrs,
# and store it in the dataset as the new variable "ssh_it_12h".
ds["ssh_it_12h"] = data_array_ssh_it = ds.ssh_it.copy(data=array_ssh_it)

In [None]:
# Daily dates of the record: 2012-05-01 through 2012-10-27
# (np.arange excludes the end date).
date_array = np.arange(np.datetime64("2012-05-01"), np.datetime64("2012-10-28"))

# Preview a single day (date_array[5]), sliced from 00:00 to 23:00.
ds.sel(
    time=slice(date_array[5], date_array[6] - np.timedelta64(1, "h")),
    drop=True,
)

In [None]:
date_array = np.arange(np.datetime64("2012-05-01"), np.datetime64("2012-10-28"))

# Export one NetCDF file per day, holding the records from day 00:00 to day 23:00.
# Each day's upper bound is day + 1 day - 1 hour, so the same loop handles every
# day, including the last one (2012-10-27). The previous index loop read
# date_array[i + 1] on its last iteration and raised IndexError before the final
# day could be written.
for day in date_array:
    day_end = day + np.timedelta64(1, "D") - np.timedelta64(1, "h")
    ds_day = ds.sel(time=slice(day, day_end), drop=True)
    ds_day.to_netcdf("/bettik/bellemva/ocean_data_challenge/2023e_SSHmapping_HF_Hawaii/dc_ref_eval_coarse/copy/2023e_SSHmapping_HF_Hawaii_eval_" + day.astype('str') + ".nc")
    print(day)

       

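## Functions 

These cells define helpers used in section 2 (`gaspari_cohn`, `extract_it`); run them before that section.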

In [None]:
def extract_it(array_ssh, wint, H):
    """
    Band-pass filter one SSH time series in the frequency domain.

    Steps:
      1. Extend the series by mirror reflection on both sides
         (reversed series, series, reversed series) to limit edge effects.
      2. Multiply by the taper ``wint``.
      3. Apply the FFT, multiply by the spectral mask ``H``, and apply the inverse FFT.
      4. Keep only the central third, which is the original time span.

    Args:
        array_ssh : 1-D series of length n, either a numpy array or an object
            exposing ``.values`` (e.g. an xarray DataArray).
        wint : taper of length 3*n applied to the extended series.
        H : spectral mask (boolean or weights) of length 3*n, ordered like
            the output of ``scipy.fftpack.fftfreq(3*n, dt)``.

    Returns:
        numpy ndarray of length n: the real part of the filtered series.
    """
    # Accept DataArrays (as passed by the processing cell) as well as plain arrays.
    ssh = np.asarray(getattr(array_ssh, "values", array_ssh))
    # Take the series length from the input instead of a notebook-level global,
    # so the function is self-contained in any worker process.
    n = ssh.size
    ssh_extended = np.concatenate((ssh[::-1], ssh, ssh[::-1]))
    ssh_f_filtered = H * fp.fft(wint * ssh_extended)
    return np.real(fp.ifft(ssh_f_filtered))[n:2 * n]

In [None]:
def gaspari_cohn(array, distance, center):
    """
    NAME
        bfn_gaspari_cohn

    DESCRIPTION
        Gaspari-Cohn function. @vbellemin.

        Compactly supported, piecewise 5th-order polynomial weight. It equals 1
        at ``center``, decays with |array - center| and is zero once
        |array - center| exceeds ``distance``.

        Args:
            array : scalar or array (any shape) of values on which the
                Gaspari-Cohn function is evaluated
            distance : distance from ``center`` above which the returned values are zero
            center : centred value of the function

        Returns: float ndarray of weights with the shape of ``array``
            (a scalar input gives a 1-element array). If ``distance <= 0``,
            all weights are zero.
    """
    # np.atleast_1d also covers numpy scalars (np.float64, np.int64), which a
    # `type(...) is float/int` test misses (0-d arrays then break the masking);
    # a float dtype keeps integer inputs from truncating the weights.
    r = np.atleast_1d(np.asarray(array, dtype=float))
    if distance <= 0:
        return np.zeros_like(r)

    # Normalised distance: 0 at the centre, 2 at `distance` from it.
    r = 2 * np.abs(r - center) / distance
    gp = np.zeros_like(r)

    # Boolean masks (rather than np.where(...)[0], which only yields first-axis
    # indices) keep the evaluation correct for arrays of any dimension.
    inner = r <= 1.0
    ri = r[inner]
    gp[inner] = -0.25 * ri**5 + 0.5 * ri**4 + 0.625 * ri**3 - 5.0 / 3.0 * ri**2 + 1.0

    outer = (r > 1.0) & (r <= 2.0)
    ro = r[outer]
    gp[outer] = (1.0 / 12.0 * ro**5 - 0.5 * ro**4 + 0.625 * ro**3 + 5.0 / 3.0 * ro**2
                 - 5.0 * ro + 4.0 - 2.0 / 3.0 / ro)
    return gp

In [None]:
def create_cartesian_grid(latitude, longitude, dx):
    """
    Creates a cartesian grid (regular in distance, kilometers) from a geodesic latitude, longitude grid.
    The new grid is expressed in latitude, longitude coordinates.

    Rows are spaced by ``dx`` km in latitude (1 degree taken as 111 km), starting at
    ``latitude[0]``. Each row holds ``nlon`` points placed symmetrically about the
    central longitude ``(longitude[0] + longitude[-1]) / 2``. The longitude step of
    row i is ``dx/111 * cos(lat_0) / cos(lat_i)`` degrees.

    NOTE(review): with this scaling, the zonal spacing is dx*cos(lat_0) km on every
    row rather than exactly dx km; confirm this is intended.

    Parameters
    ----------
    latitude : numpy ndarray
        Vector of latitude for geodesic input grid (only its first and last values are used).
    longitude : numpy ndarray
        Vector of longitude for geodesic input grid (its first/last values and its size are used).
    dx : float
        Grid spacing in kilometers.

    Returns
    -------
    ENSLAT2D :
        2-D numpy ndarray (nlat, nlon) of the latitudes of the points of the cartesian grid
    ENSLON2D :
        2-D numpy ndarray (nlat, nlon) of the longitudes of the points of the cartesian grid
    nlat : int
        Number of rows of the grid.
    nlon : int
        Number of columns of the grid (always even).
    """
    km2deg = 1 / 111
    step_deg = dx * km2deg

    # ENSEMBLE OF LATITUDES #
    ENSLAT = np.arange(latitude[0], latitude[-1] + step_deg, step_deg)

    # NUMBER OF LONGITUDE POINTS PER ROW #
    range_lon = longitude[-1] - longitude[0]
    if longitude.size % 2 == 0:
        nstep_lon = floor(range_lon / step_deg) + 2
    else:
        nstep_lon = ceil(range_lon / step_deg) + 2
    # Each row is built from two symmetric halves of nstep_lon // 2 points, so the
    # count must be even; an odd count made the row assignment fail with a
    # broadcast error.
    if nstep_lon % 2:
        nstep_lon += 1
    ENSLAT2D = np.repeat(ENSLAT[:, np.newaxis], nstep_lon, axis=1)

    # ENSEMBLE OF LONGITUDES #
    mid_lon = (longitude[-1] + longitude[0]) / 2
    ENSLON2D = np.zeros_like(ENSLAT2D)
    offsets = np.arange(1, nstep_lon // 2 + 1)

    for row, lat in enumerate(ENSLAT):
        # Longitude step widened by cos(lat_0)/cos(lat) to follow the meridian convergence.
        d_lon = step_deg * (np.cos(np.pi * ENSLAT[0] / 180) / np.cos(np.pi * lat / 180))
        d_lon_range = offsets * d_lon
        lon_left = np.flip(mid_lon - d_lon_range)
        lon_right = mid_lon + d_lon_range
        ENSLON2D[row, :] = np.concatenate((lon_left, lon_right))

    return ENSLAT2D, ENSLON2D, ENSLAT2D.shape[0], ENSLAT2D.shape[1]
    

In [None]:
def interpolate_ssh_it(ssh_it, dx=2, num_threads=4):
    """
    Fill gaps in an SSH field and resample it onto a regular-in-distance grid.

    Steps:
      1. Fill NaNs on the native (latitude, longitude) grid with pyinterp's
         Gauss-Seidel solver.
      2. Build a cartesian grid with ``create_cartesian_grid`` (spacing ``dx`` km)
         and interpolate the filled field onto it with xarray's vectorised ``interp``.
      3. Fill the NaNs left by step 2 (e.g. points outside the native grid) with
         Gauss-Seidel on the cartesian grid.

    Args:
        ssh_it : xarray DataArray with dims (time, latitude, longitude).
            NOTE(review): this dimension order is assumed from the transposes below.
        dx : spacing of the cartesian grid in km (default 2, the previously
            hard-coded value).
        num_threads : threads used by ``pyinterp.fill.gauss_seidel`` (default 4).

    Returns:
        DataArray (time, y, x) with y/x coordinates in km, chunked one time step per chunk.
    """
    n_time = ssh_it.sizes["time"]

    # FILLING NaNs ON THE NATIVE GRID #
    # Latitude is passed as pyinterp's first (x) axis to match the (lat, lon, time)
    # array layout. A latitude axis is never periodic, so it must not wrap.
    x_axis = Axis(ssh_it.longitude.values, is_circle=True)
    y_axis = Axis(ssh_it.latitude.values, is_circle=False)
    t_axis = TemporalAxis(ssh_it.time.values)

    grid = Grid3D(y_axis, x_axis, t_axis, ssh_it.values.transpose(1, 2, 0))
    _, filled = fill.gauss_seidel(grid, num_threads=num_threads)

    ssh_it_filled = ssh_it.copy(deep=True, data=filled.transpose(2, 0, 1)).chunk({'time': 1})

    # INTERPOLATION ON THE CARTESIAN GRID #
    ENSLAT2D, ENSLON2D, i_lat, i_lon = create_cartesian_grid(ssh_it_filled.latitude.values,
                                                             ssh_it_filled.longitude.values,
                                                             dx)

    array_cart_ssh = ssh_it_filled.interp(latitude=('z', ENSLAT2D.flatten()),
                                          longitude=('z', ENSLON2D.flatten()),
                                          ).values

    # INTERPOLATION OF NaNs #
    x_axis = Axis(np.arange(i_lon))
    y_axis = Axis(np.arange(i_lat))
    t_axis = TemporalAxis(ssh_it.time.values)

    # The number of time steps comes from the data (it was hard-coded to 24).
    grid = Grid3D(y_axis, x_axis, t_axis,
                  array_cart_ssh.reshape((n_time, i_lat, i_lon)).transpose(1, 2, 0))
    _, filled = fill.gauss_seidel(grid, num_threads=num_threads)

    # CREATION OF DataArray #
    cart_ssh_it = xr.DataArray(data=filled.transpose(2, 0, 1),
                               dims=["time", "y", "x"],
                               coords=dict(
                                   time=ssh_it_filled.time.values,
                                   y=np.arange(i_lat) * dx,
                                   x=np.arange(i_lon) * dx,
                               )).chunk({'time': 1})

    return cart_ssh_it


In [None]:
def plot_spectrum(res, k):
    """
    Plot an isotropic wavenumber spectrum of internal tides on two panels.

    Left panel:
      - spectrum against wavenumber, with the x axis reversed (0.03 -> 0 km^-1);
      - dotted red lines at the four wavenumbers k[0]..k[3];
      - solid red lines at k[0]/2 and at the midpoints between consecutive wavenumbers.

    Right panel: the same spectrum against wavelength (0-200 km).

    Args:
        res : spectrum exposing a ``freq_r`` coordinate (wavenumber, km^-1) and the
            power in ``.values`` (presumably an xarray DataArray).
        k : sequence of at least four wavenumbers (km^-1).
    """
    fig, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=200)

    k1, k2, k3, k4 = k[0], k[1], k[2], k[3]
    left, right = ax[0], ax[1]

    left.plot(res.freq_r.values, res.values)
    left.set_xlim(0.03, 0)
    left.set_xlabel("Wavenumber [km-1]")
    for mode_k in (k1, k2, k3, k4):
        left.axvline(mode_k, c='red', linestyle=':')
    for boundary in (k1 / 2, (k1 + k2) / 2, (k2 + k3) / 2, (k3 + k4) / 2):
        left.axvline(boundary, c='red', linestyle='-')

    right.plot(1 / res.freq_r.values, res.values)
    right.set_xlim(0, 200)
    right.set_xlabel("Wavelength [km]")

    fig.suptitle("Isotropic Power Spectrum of Internal Tides")