In [6]:
#!/usr/bin/env python
# coding: utf-8

"""
    Script to extract hydrographic sections from NEMO model output
    (adapted from the GEOMAR xorca_brokenline repository: https://git.geomar.de/python/xorca_brokenline/tree/master/)
    
    
    Input:
    -------
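    - model output on the T, U and V grids
    - mesh and mask information (here one combined mesh_mask file instead of separate mask.nc, mesh_hgr.nc, mesh_zgr.nc)
    - (lat, lon) station positions defining each section (a section may have more than two stations)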
    
    Output:
    -------
    - NetCDF file with properties along each hydrographic section
    - velocities normal to and along the section
    - volume, freshwater and heat transport across the section
    
    
"""

# In[1]:


import os
import sys
sys.path.append("/vortexfs1/home/sryan/Python/NEMO/NEMO_python_tools/xorca_lonlat2ij-master/")  # adds full path to working directory
sys.path.append("/vortexfs1/home/sryan/Python/NEMO/NEMO_python_tools/xorca_brokenline-master/")  # adds full path to working directory

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import xorca_brokenline as bl
from matplotlib import colors as c
from xorca_lonlat2ij import get_ij

In [7]:
## Load Data
# All model output lives in one directory. The file names share one experiment
# prefix and differ only in the glob part (`*`) and the grid letter (T, U, V).
datapath = "/vortexfs1/share/clidex/data/ORCA/NUSA/INDOARCHIPEL.L46.LIM2vp.JRA.XIOS2-KPW002.hindcast_5d/"
grid_file_prefix = datapath + "1_INDOARCHIPEL.L46.LIM2vp.JRA.XIOS2-KPW002.hindcast_5d_*_grid_"


def _open_grid(grid):
    """Open every file matching the glob for one grid ('T', 'U' or 'V') as one dataset."""
    return xr.open_mfdataset(grid_file_prefix + grid + "_subset.nc", combine='by_coords')


gridT = _open_grid("T")
gridU = _open_grid("U")
gridV = _open_grid("V")

# mesh and mask files: horizontal mesh, vertical mesh and land-sea mask are
# read from a single combined file.
meshpath = datapath
mesh_mask = xr.open_dataset(meshpath + '1_mesh_mask.nc')
# Later cells refer to the mesh/mask as mesh_hgr, mesh_zgr and mask (the names
# xorca_brokenline uses); these were never defined before, so point them all at
# the combined file.
mesh_hgr = mesh_zgr = mask = mesh_mask


In [None]:
## Define section

# The (lat, lon) stations below are converted to model (j, i) indices further
# down with [xorca_lonlat2ij](https://git.geomar.de/python/xorca_lonlat2ij).

# ------- Enter section names and stations here ----------
# Each entry maps a section name to its list of (lat, lon) stations. A section
# may consist of more than two stations (broken line).
sections = dict(
    ITF120E=[(-20.5, 120), (-8.5, 120)],
    ITF115E=[(-22.2, 115), (-8.3, 115)],
    LC24S=[(-24, 114), (-24, 100)],
    LC22S=[(-22, 114), (-22, 100)],
    LC27S=[(-27, 114), (-27, 100)],
    LC30S=[(-30, 117), (-30, 100)],
    LC32S=[(-32, 117), (-32, 100)],
    Makassar=[(-4.5, 120), (-3.3, 116)],
)
# ---------------------------------------------------------

# ---------------------------------------------------------------------------
# Process the selected sections.
#
# For each section:
#   1. convert the (lat, lon) stations to model (j, i) indices,
#   2. extract the section one time step at a time (memory),
#   3. optionally plot the time-mean section,
#   4. compute transports and write section + transports to NetCDF,
#   5. optionally plot the transport time series.
# ---------------------------------------------------------------------------

# Sections to process, e.g. 'LC22S','LC24S','LC30S','ITF120E','ITF115E','Makassar'
SECTIONS_TO_PROCESS = ['LC24S', 'Makassar']
saveplot = 0  # set to 1 to write figures
OUTDIR = '/vortexfs1/share/clidex/data/ORCA/NUSA/sections/'
# Run label used in the output file name. `run` was referenced but never
# defined in the original notebook (NameError).
# NOTE(review): 'KPW002' is taken from the experiment name in the input file
# names -- confirm this is the intended label.
run = 'KPW002'

plt.rcParams['figure.figsize'] = [10, 4]
plt.rcParams.update({'font.size': 12})


def _extract_section(ji, gridU_processed, gridV_processed, mesh):
    """Extract the section along the grid points ``ji`` for all time steps.

    For memory reasons the whole dataset cannot be processed at once, so
    ``bl.select_section`` is called once per ``time_counter`` step and the
    results are concatenated along ``time_counter``. The geometry/mask
    variables are then taken from a single time step.

    ``mesh`` is the combined mesh/mask dataset, passed as mesh_hgr, mesh_zgr
    and mask.
    """
    datasets = [
        bl.select_section(ji,
                          gridU_processed.isel(time_counter=tt),
                          gridV_processed.isel(time_counter=tt),
                          mesh, mesh, mesh)
        for tt in range(gridU_processed.sizes['time_counter'])
    ]
    section = xr.concat(datasets, data_vars='minimal', dim='time_counter')
    for var in ['ii', 'jj', 'dx', 'lat', 'lon', 'dz', 'mask']:
        section[var] = datasets[-1][var]
    return section


def _plot_mean_section(section, var, title, fname, **plot_kwargs):
    """Plot the time mean of ``section[var]`` in the upper 1000 m, with land
    (mask == 0) in gray, and save it to ``fname``.

    Extra keyword arguments (e.g. vmin/vmax) go to the xarray plot call.
    Each call gets its own figure, which is closed after saving so that plots
    do not draw on top of each other.
    """
    fig, ax = plt.subplots()
    section[var].mean('time_counter').plot(ax=ax, **plot_kwargs)
    section['mask'].where(section['mask'] == 0).plot(
        ax=ax, cmap=c.ListedColormap(['gray']), add_colorbar=0)
    # Labels are set after plotting because xarray's plot calls overwrite them.
    ax.set_ylim(1000, 0)
    ax.set_ylabel("Depth (m)")
    ax.set_xlabel("Distance (m)")
    ax.set_title(title)
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    plt.close(fig)


def _plot_transport(transport, var, ylabel, fname, scale=1.0):
    """Plot the time series ``transport[var] / scale`` with a zero line and
    save it to ``fname``, closing the figure afterwards."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(transport.time_counter.values, transport[var] / scale)
    ax.set_ylabel(ylabel)
    ax.axhline(0, color='k')
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    plt.close(fig)


# ## Shift grids
# bl.shift_grids() interpolates all variables onto the u- and v-points, drops
# unnecessary coordinate labels and variables and renames the depth dimension.
# Its inputs do not depend on the section, so it is done once for all sections
# (it takes some time). The result is also written to disk; it can be reloaded
# with xr.open_dataset(OUTDIR + 'gridU_processed_Nusa.nc') etc.
gridU_processed, gridV_processed = bl.shift_grids(
    gridU=gridU, gridV=gridV,
    mesh_hgr=mesh_mask, mesh_zgr=mesh_mask, mask=mesh_mask, gridT=gridT,
    vars_to_keep=('vozocrtx', 'vomecrty', 'votemper', 'vosaline', 'vosigma0'))
gridU_processed.to_netcdf(OUTDIR + 'gridU_processed_Nusa.nc')
gridV_processed.to_netcdf(OUTDIR + 'gridV_processed_Nusa.nc')
print('shift_grids done')

for sec_name in SECTIONS_TO_PROCESS:
    # os.makedirs does not expand '~' itself; without expanduser a literal
    # '~' directory would be created in the working directory.
    pathplot = os.path.expanduser('~/Python/NUSA/sections/' + sec_name + '/')
    os.makedirs(pathplot, exist_ok=True)

    # Convert the (lat, lon) stations to (j, i) indices and expand them into
    # the list of grid points along the (broken) section line.
    latlon_pairs = sections[sec_name]
    ji_pairs = get_ij(mesh_mask, latlon_pairs)
    ji = bl.section_indices(ji_pairs=ji_pairs)
    jj, ii = zip(*ji)  # index lists, e.g. for plotting the section on a map

    # Section includes the velocity normal to and along each segment, salinity,
    # temperature, segment length and depth, ji and lat/lon coordinates and a
    # land-sea mask.
    section = _extract_section(ji, gridU_processed, gridV_processed, mesh_mask)
    print('section done')

    if saveplot == 1:
        prefix = pathplot + 'section_' + sec_name + '_'
        _plot_mean_section(section, 'u_normal', "Across section velocity",
                           prefix + 'u_normal_mean.png')
        _plot_mean_section(section, 'votemper', "Section Temperature",
                           prefix + 'votemper_mean.png', vmin=10, vmax=29)
        _plot_mean_section(section, 'vosaline', "Section Salinity",
                           prefix + 'vosaline_mean.png', vmin=34.2, vmax=35.5)

    # ## Across section transport
    # Volume, freshwater and heat transport. Per the original notes, the
    # freshwater transport uses a reference salinity S_ref = 34.8 by default.
    transport = bl.calculate_transport(section)

    # ## Write section variables and transports into one NetCDF file
    data_sec = xr.merge([section, transport])
    data_sec.to_netcdf(OUTDIR + 'section_' + sec_name + '_' + run + '.nc',
                       format='NETCDF4_CLASSIC', unlimited_dims='time_counter')
    del data_sec

    # ## Transport time series plots
    if saveplot == 1:
        prefix = pathplot + 'section_' + sec_name + '_'
        _plot_transport(transport, 'trsp', "Volume transport (Sv)",
                        prefix + 'vol_transport_fullsec.png')
        _plot_transport(transport, 'fw_trsp', "Freshwater transport (Sv)",
                        prefix + 'fw_transport_fullsec.png')
        _plot_transport(transport, 'ht_trsp', "Heat transport (PW)",
                        prefix + 'ht_transport_fullsec.png', scale=1e15)