In [1]:
# Goal: save the salinity-budget terms dSdt, ADV (adv_h, adv_v), SFX and the
# residual RES over the whole global domain.

# ------------------------
# User choice
# ------------------------
# Experiment name. It is inserted into the input directory
# (".../<prefix>.<nexpREF>-S/") and into every output file name.
# Other experiment names mentioned in the original note: "AI", "S".
nexpREF = "GAI"

# ------------------------
# Libraries
# ------------------------
import os
import sys

import numpy as np
import xarray as xr
from dask.distributed import Client

# Dask distributed client (created with default arguments).
c = Client()

# ------------------------
# Grid, masks and paths
# ------------------------
diro = f"/gpfsscratch/rech/cli/uor98hu/BILANS/{nexpREF}/"  # output directory
diri = "/gpfswork/rech/cli/rcli002/eORCA025.L75/eORCA025.L75-I/"  # mesh directory
mesh_hgr = xr.open_dataset(f"{diri}mesh_hgr.nc").squeeze()

# 3-D T-point mask; its vertical dimension is renamed so it lines up with
# the "deptht" dimension of the model output.
tmask = mesh_hgr["tmask"].rename({"nav_lev": "deptht"})
tmaskutil = mesh_hgr["tmaskutil"]

# Horizontal scale factors, with missing values replaced by 0.
e1t, e2t, e2u, e1v = (
    mesh_hgr[name].fillna(0) for name in ("e1t", "e2t", "e2u", "e1v")
)
nav_lon, nav_lat = mesh_hgr["nav_lon"], mesh_hgr["nav_lat"]

# Dask chunking used when opening the monthly model output.
chunk_size = {"x": 500, "y": 500}

prefix = "eORCA025.L75-IMHOTEP"
diridatref = f"/gpfsstore/rech/cli/rcli002/eORCA025.L75/{prefix}.{nexpREF}-S/"
fo = "1m"  # output frequency used
    
# Output files are written into diro; make sure it exists before to_netcdf().
os.makedirs(diro, exist_ok=True)

# Calendar months as they appear in the daily file names ("01" ... "12").
MONTHS = [f"{m:02d}" for m in range(1, 13)]

for year in range(2015, 2019):
    y1 = str(year)
    print(y1)

    # ------------------------
    # Monthly (frequency `fo`) input files for this year
    # ------------------------
    # Glob prefix, e.g. "<diridatref>/1m/2015/*1m_" + "gridT.nc"
    glob_prefix = diridatref + "/" + fo + "/" + y1 + "/*" + fo + "_"
    open_kw = dict(decode_coords=True, chunks=chunk_size, parallel=True)

    listTfiles = glob_prefix + "gridT.nc"
    listUfiles = glob_prefix + "gridU.nc"
    listVfiles = glob_prefix + "gridV.nc"

    T_ds = xr.open_mfdataset(listTfiles, **open_kw)
    U_ds = xr.open_mfdataset(listUfiles, **open_kw)
    V_ds = xr.open_mfdataset(listVfiles, **open_kw)

    e3t = T_ds.e3t.fillna(0)
    e3u = U_ds.e3u.fillna(0)
    e3v = V_ds.e3v.fillna(0)

    # Grid size taken from the data instead of hard-coded constants.
    # The arrays built below assume the (time_counter, deptht, y, x) ordering.
    nt, nz, ny, nx = e3t.shape

    # Velocity x salinity products (variables vous / vovs / vows).
    listProdUfiles = glob_prefix + "PRODU.nc"
    listProdVfiles = glob_prefix + "PRODV.nc"
    listProdWfiles = glob_prefix + "PRODW.nc"

    PRODU_ds = xr.open_mfdataset(listProdUfiles, **open_kw)
    PRODV_ds = xr.open_mfdataset(listProdVfiles, **open_kw)
    PRODW_ds = xr.open_mfdataset(listProdWfiles, **open_kw)

    US_ref = PRODU_ds.vous.fillna(0)
    VS_ref = PRODV_ds.vovs.fillna(0)
    WS_ref = PRODW_ds.vows.fillna(0)

    # ------------------------
    # Time tendency: dSdt, from the first and last daily record of each month
    # ------------------------
    dSdt1darr = np.zeros((len(MONTHS), nz, ny, nx))
    nbsec = np.zeros(len(MONTHS))

    for i, month in enumerate(MONTHS):
        fname = (diridatref + "/1d/" + y1 + "-concat/" + prefix + "." + nexpREF
                 + "_y" + y1 + "m" + month + "_1d_gridT.nc")
        # The context manager closes the daily file once the month is done.
        with xr.open_dataset(fname, chunks={"time_counter": 2, "deptht": 5}) as ds_T1d:
            S1d = ds_T1d.squeeze().vosaline
            tc = S1d.time_counter.values
            # Seconds between the first and last daily record, plus one day.
            deltat1d = float(np.timedelta64(tc[-1] - tc[0]) / np.timedelta64(1, "s")) + 86400
            # NOTE(review): S1d[-1] - S1d[0] spans (deltat1d - 86400) seconds but is
            # divided by deltat1d; confirm which time step the budget should use.
            dSdt1darr[i] = ((S1d[-1] - S1d[0]) / deltat1d).values
        nbsec[i] = deltat1d

    # Convert into a dataset
    dsdSdt1d = xr.Dataset(
        data_vars=dict(
            dSdt=(["time_counter", "deptht", "y", "x"], dSdt1darr),
            nbsec=(["time_counter"], nbsec)),
        coords=dict(
            time_counter=WS_ref.time_counter.values,
            deptht=e3t.deptht.values,
            nav_lat=(["y", "x"], nav_lat.values),
            nav_lon=(["y", "x"], nav_lon.values)),
        attrs=dict(
            description="dS/dt for each month using daily snapshots",
            units="10-3/s")
        )

    dsdSdt1d.to_netcdf(path=diro + nexpREF + "_dSdt_" + fo + y1 + ".nc", mode="w")

    # ------------------------
    # Horizontal advection (NEMO 4.0.1 reference manual, section 4.1)
    # ------------------------
    bt = e3t * e1t * e2t  # volume of each T cell

    prod1_U = e3u * US_ref * e2u  # transport through each U face
    prod1_V = e3v * VS_ref * e1v  # transport through each V face

    # roll() is periodic, so the first x column / y row pairs with the opposite
    # edge: values there are not meaningful.
    deltaU = prod1_U - prod1_U.roll(x=1)
    deltaV = prod1_V - prod1_V.roll(y=1)

    DIV = (deltaU.rename({"depthu": "deptht"}) + deltaV.rename({"depthv": "deptht"})).where(tmask)
    DIV = DIV * tmaskutil * (-1)  # tendency = minus the flux divergence
    adv_h = (DIV / bt).compute()
    dsadv_h = adv_h.to_dataset(name="adv_h")

    dsadv_h.to_netcdf(path=diro + nexpREF + "_adv_h_" + fo + y1 + ".nc", mode="w")

    # ------------------------
    # Vertical advection
    # ------------------------
    # diff() gives W[k+1] - W[k] labelled with depthw[k+1]; negated -> W[k] - W[k+1],
    # i.e. top minus bottom face of cell k. The top depthw label is lost.
    deltaW = -(WS_ref.diff(dim="depthw"))
    # Relabel e3t levels 0..nz-2 with those depthw values so that cell k's
    # thickness lines up with its top-minus-bottom difference.
    e3t4W = (e3t.isel(deptht=slice(0, nz - 1))
             .assign_coords({"deptht": deltaW.depthw.values})
             .rename({"deptht": "depthw"}))
    adv_v = deltaW / e3t4W
    adv_v = adv_v.rename({"depthw": "deptht"}).assign_coords({"deptht": e3t.deptht[:nz - 1].values})

    # Pad with a zero bottom level so adv_v has all nz levels.
    adv_varr = np.zeros((nt, nz, ny, nx))
    adv_varr[:, :nz - 1, :, :] = adv_v.values

    adv_vda = xr.DataArray(
        data=adv_varr,
        dims=["time_counter", "deptht", "y", "x"],
        coords=dict(
            time_counter=WS_ref.time_counter.values,
            time_centered=(["time_counter"], adv_h.time_centered.values),
            deptht=e3t.deptht.values,
            nav_lat=(["y", "x"], nav_lat.values),
            nav_lon=(["y", "x"], nav_lon.values)
        )
    )

    adv_v = adv_vda.where(tmask)
    adv_v = adv_v * tmaskutil * (-1)  # tendency = minus the flux divergence
    dsadv_v = adv_v.to_dataset(name="adv_v")

    dsadv_v.to_netcdf(path=diro + nexpREF + "_adv_v_" + fo + y1 + ".nc", mode="w")

    # ------------------------
    # Total advection
    # ------------------------
    ADV = adv_h + adv_v
    dsADV = ADV.to_dataset(name="adv")
    dsADV.to_netcdf(path=diro + nexpREF + "_adv_" + fo + y1 + ".nc", mode="w")

    # ------------------------
    # Surface salt flux, applied to the first level only
    # ------------------------
    listflxTfiles = glob_prefix + "flxT.nc"
    flxT_ds = xr.open_mfdataset(listflxTfiles, **open_kw)
    sfxvar = flxT_ds.sosfldow

    # NOTE(review): 1000 acts as the reference density and the minus sign sets the
    # flux sign convention; confirm both against the model configuration.
    SFX = sfxvar / (e3t[:, 0, :, :] * (-1000))
    SFXds = SFX.to_dataset(name="sfx")
    SFXds.to_netcdf(path=diro + nexpREF + "_sfx_" + fo + y1 + ".nc", mode="w")

    # ------------------------
    # Right-hand side: ADV + SFX
    # ------------------------
    SFXarr = np.zeros((nt, nz, ny, nx))
    SFXarr[:, 0, :, :] = SFX.values

    SFXda = xr.DataArray(
        data=SFXarr,
        dims=["time_counter", "deptht", "y", "x"],
        coords=dict(
            time_counter=e3t.time_counter.values,
            deptht=e3t.deptht.values,
            nav_lat=(["y", "x"], nav_lat.values),
            nav_lon=(["y", "x"], nav_lon.values))
        )

    rhs = ADV + SFXda

    rhsds = rhs.to_dataset(name="rhs")
    rhsds.to_netcdf(path=diro + nexpREF + "_rhs_" + fo + y1 + ".nc", mode="w")

    # ------------------------
    # Residual: dSdt - (ADV + SFX)
    # ------------------------
    res = dsdSdt1d.dSdt - rhs
    resds = res.to_dataset(name="res")
    resds.to_netcdf(path=diro + nexpREF + "_res_" + fo + y1 + ".nc", mode="w")


2015
2016
2017
2018
