# In [1]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# In [3]:
def calculate_OHT(T_ady_2d, T_diffy_2d):
    """Combine two transport fields into a single profile scaled by 1e15.

    Each input is processed the same way: entries whose magnitude equals
    1e+20 are replaced by NaN, the result is NaN-averaged over axis 0, and
    the averaged field is NaN-summed over its axis 1 (axis 2 of the input).
    The two resulting profiles are added and divided by 1e15.

    NOTE(review): axis 0 is presumably time and the summed axis longitude,
    with 1e+20 a fill value and 1e15 converting W to PW -- confirm against
    the caller's data.

    Parameters
    ----------
    T_ady_2d, T_diffy_2d : array-like with at least three dimensions

    Returns
    -------
    numpy.ndarray
        Sum of the two reduced profiles divided by 1e15.
    """
    def _reduce(field):
        # Blank out +/-1e20 entries so they do not enter the mean or the sum.
        cleaned = np.where(np.abs(field) == 1e+20, np.nan, field)
        return np.nansum(np.nanmean(cleaned, axis=0), axis=1)

    combined = _reduce(T_ady_2d) + _reduce(T_diffy_2d)
    return combined/1e15

def calculate_global_sfn(vmo, vhGM):
    """Build three 2-D streamfunction arrays from two 4-D transport fields.

    For each input, the records selected by ``[-31:-1]`` along axis 0 are
    NaN-averaged, the result is divided by 1025, NaN-summed over its last
    axis, cumulatively summed along its first axis, and divided by 1e6.

    NOTE(review): ``[-31:-1]`` selects up to 30 records and leaves out the
    final one; confirm that excluding the last record is intended.
    NOTE(review): 1025 is presumably a reference density (kg/m^3) and 1e6
    the m^3/s -> Sv conversion; axes are presumably (time, depth, lat, lon).

    Parameters
    ----------
    vmo, vhGM : 4-D array-like

    Returns
    -------
    tuple
        ``(from_vmo, from_vhGM, from_vmo + from_vhGM)``.
    """
    def _streamfunction(transport):
        window_mean = np.nanmean(transport[-31:-1, :, :, :], axis=0)
        summed = np.nansum(window_mean/1025, axis=2)
        return np.cumsum(summed, axis=0)/(1e6)

    mean_part = _streamfunction(vmo)
    res_part = _streamfunction(vhGM)
    return mean_part, res_part, mean_part + res_part

def calculate_basin_sfn(vmo, vhGM, x_west, x_east, y_south, y_north):
    """Build three 2-D streamfunction arrays restricted to an index box.

    For each input, the records selected by ``[-31:-1]`` along axis 0 are
    NaN-averaged; the averaged field is then cut to
    ``[:, y_south:y_north, x_west:x_east]``, divided by 1025, NaN-summed over
    its last axis, cumulatively summed along its first axis, and divided by
    1e6.

    NOTE(review): ``[-31:-1]`` leaves out the final record along axis 0;
    confirm that this is intended.

    Parameters
    ----------
    vmo, vhGM : 4-D array-like
    x_west, x_east : int
        Slice bounds applied to the last axis.
    y_south, y_north : int
        Slice bounds applied to the second-to-last axis.

    Returns
    -------
    tuple
        ``(from_vmo, from_vhGM, from_vmo + from_vhGM)``.
    """
    def _streamfunction(transport):
        window_mean = np.nanmean(transport[-31:-1, :, :, :], axis=0)
        boxed = window_mean[:, y_south:y_north, x_west:x_east]
        summed = np.nansum(boxed/1025, axis=2)
        return np.cumsum(summed, axis=0)/(1e6)

    mean_part = _streamfunction(vmo)
    res_part = _streamfunction(vhGM)
    return mean_part, res_part, mean_part + res_part

def variablename(*arg):
    """Look up the module-global names bound to each argument.

    For every positional argument, this module's ``globals()`` is scanned
    and every name whose value *is* that object (identity, not equality) is
    collected.

    Parameters
    ----------
    *arg : object
        Objects to look up.

    Returns
    -------
    list of list of str
        One inner list per argument, in argument order. An inner list is
        empty when no global refers to the object and has several entries
        when more than one global does.
    """
    module_globals = globals()
    return [
        [name for name, value in module_globals.items() if value is target]
        for target in arg
    ]

def plot_global_sfn(sfn_mean, sfn_res, sfn_resmean, sfn_bound, y_south, y_north):
    """Draw three stacked filled-contour panels and save them as one PNG.

    One panel is drawn per input array, in the order given. Panel titles and
    the output file name are derived from the module-global variable names
    bound to the arrays (via ``variablename``), so the arrays must be passed
    in as module-level variables for the labels to be meaningful.

    Relies on module-level ``lat``, ``z``, ``case`` and ``fig_path``.

    Parameters
    ----------
    sfn_mean, sfn_res, sfn_resmean : 2-D array-like
        Fields contoured against ``lat`` (x) and ``z`` (y).
    sfn_bound : float
        Contour levels and colorbar ticks span ``[-sfn_bound, sfn_bound]``.
    y_south, y_north : float
        The x-axis is limited to ``[-abs(y_south), y_north]``.

    Returns
    -------
    None
    """
    ticks = np.linspace(-sfn_bound, sfn_bound, 11, endpoint=True)
    levels = np.linspace(-sfn_bound, sfn_bound, 21, endpoint=True)
    fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(8, 12))
    fields = (sfn_mean, sfn_res, sfn_resmean)
    labels = [str(found).strip("['']") for found in variablename(*fields)]
    for ax, field, label in zip(axs, fields, labels):
        cf = ax.contourf(lat, z, field, levels, cmap='RdBu_r', extend='both')
        ax.invert_yaxis()
        ax.set_xlim(-np.abs(y_south), y_north)
        # Label through the Axes object: pyplot's ylabel()/xlabel() act on
        # the *current* axes, which stays the last subplot for every
        # iteration, so the upper panels were never labelled.
        ax.set_ylabel('Depth', fontsize=14)
        ax.set_title('{} {}'.format(label, case), fontsize=16)
        cbar = fig.colorbar(cf, ax=ax, ticks=ticks, format='%.0f')
        cbar.ax.set_ylabel('[Sv]', fontsize=12)
    # The panels share the same x quantity; label it once on the bottom panel
    # so it does not collide with the title of the panel below.
    axs[-1].set_xlabel('Latitude', fontsize=14)
    # File name uses the name found for the third array.
    fig.savefig(fig_path+'{}_global.png'.format(labels[2]))
    plt.show()
    return

def plot_basin_sfn(sfn_mean, sfn_res, sfn_resmean, sfn_bound, y_south, y_north):
    """Draw three stacked filled-contour panels for an index range of ``lat``.

    One panel is drawn per input array against ``lat[y_south:y_north]`` and
    ``z``. Panel titles and the output file name come from the module-global
    variable names bound to the arrays (via ``variablename``).

    Relies on module-level ``lat``, ``z``, ``case`` and ``fig_path``.

    Parameters
    ----------
    sfn_mean, sfn_res, sfn_resmean : 2-D array-like
        Fields to contour.
    sfn_bound : float
        Contour levels and colorbar ticks span ``[-sfn_bound, sfn_bound]``.
    y_south, y_north : int
        Index bounds into ``lat``; the x-axis runs from ``lat[y_south]`` to
        ``lat[y_north-1]``.

    Returns
    -------
    None
    """
    levels = np.linspace(-sfn_bound, sfn_bound, 21, endpoint=True)
    ticks = np.linspace(-sfn_bound, sfn_bound, 11, endpoint=True)
    fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(8,12))
    fields = (sfn_mean, sfn_res, sfn_resmean)
    found_names = variablename(sfn_mean, sfn_res, sfn_resmean)
    labels = [str(found).strip("['']") for found in found_names]
    basin_lat = lat[y_south:y_north]
    for ax, field, label in zip(axs, fields, labels):
        cf = ax.contourf(basin_lat, z, field[:,:], levels, cmap='RdBu_r', extend='both')
        ax.invert_yaxis()
        ax.set_xlim(lat[y_south], lat[y_north-1])
        ax.set_title('{}'.format(label)+' '+'{}'.format(case), fontsize=16)
        cbar = fig.colorbar(cf, ax=ax, ticks=ticks, format = '%.0f')
        cbar.ax.set_ylabel('[Sv]', fontsize=12)
    # Depth label on the middle panel, latitude label on the bottom panel.
    axs[1].set_ylabel('Depth', fontsize=14)
    axs[2].set_xlabel('Latitude', fontsize=14)
    # File name uses the name found for the third array.
    plt.savefig(fig_path+'{}.png'.format(labels[2]))
    plt.show()
    return

def plot_salt_section(title, lon_index, z):
    """Plot a section of ``salt`` at one longitude index and save it as PNG.

    Contours ``salt[:z, :, lon_index]`` against ``dy['yh']`` (x) and
    ``dy['zl'][0:z]`` (y) with 21 levels from 33 to 38.

    Relies on module-level ``dy``, ``salt`` and ``fig_path``.

    Parameters
    ----------
    title : str
        Prefix for the plot title and the output file name.
    lon_index : int
        Index into ``dy['xh']`` / the last axis of ``salt``.
    z : int
        Number of leading vertical levels to include.

    Returns
    -------
    None
    """
    ticks = np.linspace(33, 38, 11, endpoint=True)
    levels = np.linspace(33, 38, 21, endpoint=True)
    lon = np.asarray(dy['xh'][lon_index])
    fig, ax = plt.subplots(figsize=(9, 5))
    cf = ax.contourf(dy['yh'], dy['zl'][0:z], np.squeeze(salt[:z, :, lon_index]),
                     levels, extend='both')
    ax.invert_yaxis()
    ax.set_xlim(-70, 70)
    ax.set_ylabel('Depth', fontsize=14)
    ax.set_xlabel('Latitude', fontsize=14)
    ax.set_title(title+' '+'Salinity Section at'+' '+'{}'.format(lon)+'E', fontsize=16)
    # Ticks are 0.5 apart, so one decimal is needed; '%.0f' rounded them to
    # repeated integer labels (33, 34, 34, 34, 35, ...).
    cbar = fig.colorbar(cf, ax=ax, ticks=ticks, format='%.1f')
    cbar.ax.set_ylabel('[psu]', fontsize=12)
    fig.savefig(fig_path+'{}_salt_section_{}.png'.format(title, lon))
    plt.show()
    return

def plot_temp_section(title, lon_index, z):
    """Plot a section of ``temp`` at one longitude index and save it as PNG.

    Contours ``temp[:z, :, lon_index]`` against ``dy['yh']`` (x) and
    ``dy['zl'][0:z]`` (y) with 21 levels from -5 to 30.

    Relies on module-level ``dy``, ``temp`` and ``fig_path``.

    Parameters
    ----------
    title : str
        Prefix for the plot title and the output file name.
    lon_index : int
        Index into ``dy['xh']`` / the last axis of ``temp``.
    z : int
        Number of leading vertical levels to include.

    Returns
    -------
    None
    """
    ticks = np.linspace(-5, 30, 11, endpoint=True)
    levels = np.linspace(-5, 30, 21, endpoint=True)
    lon = np.asarray(dy['xh'][lon_index])
    fig, ax = plt.subplots(figsize=(9, 5))
    cf = ax.contourf(dy['yh'], dy['zl'][0:z], np.squeeze(temp[:z, :, lon_index]),
                     levels, cmap='coolwarm', extend='both')
    ax.invert_yaxis()
    ax.set_xlim(-70, 70)
    ax.set_ylabel('Depth', fontsize=14)
    ax.set_xlabel('Latitude', fontsize=14)
    ax.set_title(title+' '+'Temperature Section at'+' '+'{}'.format(lon)+'E', fontsize=16)
    # Ticks are 3.5 apart (-5, -1.5, 2, 5.5, ...); '%.0f' mislabelled the
    # half-degree ones (e.g. 5.5 shown as "6"), so print one decimal.
    cbar = fig.colorbar(cf, ax=ax, ticks=ticks, format='%.1f')
    cbar.ax.set_ylabel('[C]', fontsize=12)
    fig.savefig(fig_path+'{}_temp_section_{}.png'.format(title, lon))
    plt.show()
    return

def plot_surface_diffs(case, case2, name, name2):
    """Plot time-mean surface differences (``case`` minus ``case2``).

    Five figures are shown in sequence: an SST difference map, SST
    difference profiles, an SSS difference map, SSS difference profiles, and
    an SSH difference map. Maps use a Robinson projection and overlay a gray
    contour marking where the first case's time-mean ``tos`` is NaN. Profiles
    show the NaN-mean over axis 1 for the whole field, for columns
    ``x_west:x_east`` ("Small Basin") and for columns ``:x_west``
    ("Large Basin").

    Relies on module-level ``dy_straight`` (to locate the columns where
    ``xh`` equals 211 and 351) and ``dy_both`` (coordinates for the mask
    contour).
    NOTE(review): both globals are presumably on the same grid as ``case``;
    confirm, or switch them to ``case['xh']`` / ``case['yh']``.

    Parameters
    ----------
    case, case2 : mapping of str to xarray.DataArray (e.g. xarray.Dataset)
        Must provide ``tos``, ``sos`` and ``zos`` with a ``time`` dimension;
        ``case`` must also provide ``xh`` and ``yh``.
    name, name2 : str
        Labels used in the titles as ``"<name> - <name2> <field>"``.

    Returns
    -------
    None
    """
    sst = case['tos'].mean(dim='time') - case2['tos'].mean(dim='time')
    sss = case['sos'].mean(dim='time') - case2['sos'].mean(dim='time')
    ssh = case['zos'].mean(dim='time') - case2['zos'].mean(dim='time')

    # The NaN mask comes from the first case's own SST. It lives in its own
    # variable: reusing the name ``sst`` here used to overwrite the difference
    # field, so the "SST difference" figures showed absolute SST instead.
    sst_case = case['tos'].mean(dim='time')
    mask = np.where(np.isnan(sst_case), 1, 0)

    x_west = np.where(dy_straight['xh']==211)[0][0]
    x_east = np.where(dy_straight['xh']==351)[0][0]

    heading = '{} - {}'.format(name, name2)

    def _draw_map(field, bound, cmap, units, label):
        # Filled-contour map of one difference field with the mask outline.
        levels = np.linspace(-bound, bound, 21, endpoint=True)
        ticks = np.linspace(-bound, bound, 11, endpoint=True)
        fig = plt.figure(figsize=(14, 6))
        fig.add_subplot(1, 1, 1, projection=ccrs.Robinson(central_longitude=180.0))
        plt.contourf(case['xh'], case['yh'], field, levels,
                     cmap=cmap, extend='both', transform=ccrs.PlateCarree())
        cbar = plt.colorbar(ticks=ticks, boundaries=ticks, spacing='uniform', extend='both')
        cbar.ax.set_ylabel(units, fontsize=14)
        plt.ylabel('Latitude', fontsize=14)
        plt.xlabel('Longitude', fontsize=14)
        plt.title('{} {}'.format(heading, label), fontsize=16)
        plt.contour(dy_both['xh'], dy_both['yh'], mask, [0.01], colors='gray',
                    transform=ccrs.PlateCarree())
        plt.show()

    def _draw_profiles(field, units, label):
        # Axis-1 NaN-means: full width plus the two column ranges.
        plt.subplots(figsize=(8, 5))
        plt.plot(case['yh'], np.nanmean(field, axis=1), linewidth=2, label='Global Mean')
        # The next two lines are for configurations with more than one basin.
        plt.plot(case['yh'], np.nanmean(field[:, x_west:x_east], axis=1), '--',
                 linewidth=2, label='Small Basin')
        plt.plot(case['yh'], np.nanmean(field[:, :x_west], axis=1), '--',
                 linewidth=2, label='Large Basin')
        plt.ylabel(units, fontsize=14)
        plt.xlabel('Latitude', fontsize=14)
        plt.title('{} {}'.format(heading, label), fontsize=16)
        plt.grid()
        plt.legend()
        plt.show()

    # Raw string: '\c' is not a valid escape sequence in a normal literal.
    _draw_map(sst, 2, 'RdBu_r', r'[$^\circ$C]', 'SST')
    _draw_profiles(sst, '[C]', 'SST')
    _draw_map(sss, 1, 'BrBG_r', '[psu]', 'SSS')
    _draw_profiles(sss, '[psu]', 'SSS')
    _draw_map(ssh, 0.4, 'PRGn', '[m]', 'SSH')
    return

def plot_surface(case, name, sst_range=(-5, 30), sss_range=(33, 38), ssh_range=(-0.4, 0.4)):
    """Plot time-mean surface fields of a single case.

    Five figures are shown in sequence: an SST map, SST profiles, an SSS map,
    SSS profiles, and an SSH map. Maps use a Robinson projection. Profiles
    show the NaN-mean over axis 1 for the whole field, for columns
    ``x_west:x_east`` ("Small Basin") and for columns ``:x_west``
    ("Large Basin").

    Relies on module-level ``x_west`` and ``x_east``; they are not defined in
    this function.

    The default SST and SSS ranges match the temperature and salinity section
    plots in this file (-5..30 and 33..38). The previous hard-coded ranges
    (-2..2 and -1..1) were those of the difference plots and saturate the
    colour scale for absolute fields; pass them explicitly to restore the old
    look.

    Parameters
    ----------
    case : mapping of str to xarray.DataArray (e.g. xarray.Dataset)
        Must provide ``tos``, ``sos``, ``zos`` (each with a ``time``
        dimension) and ``xh``, ``yh``.
    name : str
        Label used in the titles as ``"<name> <field>"``.
    sst_range, sss_range, ssh_range : tuple of (float, float), optional
        ``(low, high)`` limits for the contour levels and colorbar ticks of
        the respective map.

    Returns
    -------
    None
    """
    sst = case['tos'].mean(dim='time')
    sss = case['sos'].mean(dim='time')
    ssh = case['zos'].mean(dim='time')

    def _draw_map(field, limits, cmap, units, label):
        # Filled-contour map of one field on a Robinson projection.
        low, high = limits
        levels = np.linspace(low, high, 21, endpoint=True)
        ticks = np.linspace(low, high, 11, endpoint=True)
        fig = plt.figure(figsize=(14, 6))
        fig.add_subplot(1, 1, 1, projection=ccrs.Robinson(central_longitude=180.0))
        plt.contourf(case['xh'], case['yh'], field, levels,
                     cmap=cmap, extend='both', transform=ccrs.PlateCarree())
        cbar = plt.colorbar(ticks=ticks, boundaries=ticks, spacing='uniform', extend='both')
        cbar.ax.set_ylabel(units, fontsize=14)
        plt.ylabel('Latitude', fontsize=14)
        plt.xlabel('Longitude', fontsize=14)
        plt.title('{} {}'.format(name, label), fontsize=16)
        plt.show()

    def _draw_profiles(field, units, label):
        # Axis-1 NaN-means: full width plus the two column ranges.
        plt.subplots(figsize=(8, 5))
        plt.plot(case['yh'], np.nanmean(field, axis=1), linewidth=2, label='Global Mean')
        # The next two lines are for configurations with more than one basin.
        plt.plot(case['yh'], np.nanmean(field[:, x_west:x_east], axis=1), '--',
                 linewidth=2, label='Small Basin')
        plt.plot(case['yh'], np.nanmean(field[:, :x_west], axis=1), '--',
                 linewidth=2, label='Large Basin')
        plt.ylabel(units, fontsize=14)
        plt.xlabel('Latitude', fontsize=14)
        plt.title('{} {}'.format(name, label), fontsize=16)
        plt.grid()
        plt.legend()
        plt.show()

    # Raw string: '\c' is not a valid escape sequence in a normal literal.
    _draw_map(sst, sst_range, 'RdBu_r', r'[$^\circ$C]', 'SST')
    _draw_profiles(sst, '[C]', 'SST')
    _draw_map(sss, sss_range, 'BrBG_r', '[psu]', 'SSS')
    _draw_profiles(sss, '[psu]', 'SSS')
    _draw_map(ssh, ssh_range, 'PRGn', '[m]', 'SSH')
    return