# import funcs and load data

In [None]:
import numpy as np
import pandas as pd
from scipy import interpolate

import xarray as xr
import dask.array as da
from dask.distributed import Client

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches # for creating legend
import matplotlib.dates as mdates # converts datetime64 to datetime

import cartopy
import cartopy.crs as ccrs # for plotting
import cartopy.feature as cfeature # for map features
from cartopy.util import add_cyclic_point # for wrapping map fully - avoiding white line on 0 deg
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

import matplotlib.dates as mdates # converts datetime64 to datetime
import matplotlib.gridspec as gridspec # to create grid-shaped combos of axes
from mpl_toolkits import mplot3d # 3d plotting tool
import cmocean # for nice oceanography colour pallettes

#import argopy
#from argopy import DataFetcher as ArgoDataFetcher # to load Argo ds directly

import os # for finding files

import gsw # for conversion functions

from tqdm.notebook import tqdm_notebook as tqdm
import glob # for downloading data
import sys # for path to functions

import seaborn as sns

# Seaborn base theme: white axes/figure background, black axes/text, no grid,
# outward ticks drawn on the bottom and left only. (A 'Times New Roman' font
# option was tried and left disabled.)
sns.set(
    rc={
        'axes.axisbelow': False,
        'axes.edgecolor': 'Black',
        'axes.facecolor': 'w',  # grey alternative that was tried: '#aeaeae'
        'axes.grid': False,
        'axes.labelcolor': 'k',
        'axes.spines.right': True,
        'axes.spines.top': True,
        'figure.facecolor': 'white',
        'lines.solid_capstyle': 'round',
        'patch.edgecolor': 'k',
        'patch.force_edgecolor': True,
        'text.color': 'k',
        'xtick.bottom': True,
        'xtick.color': 'k',
        'xtick.direction': 'out',
        'xtick.top': False,
        'ytick.color': 'k',
        'ytick.direction': 'out',
        'ytick.left': True,
        'ytick.right': False,
    },
    font_scale=1,
)

# Base font sizes applied on top of the seaborn theme.
mpl.rcParams.update({
    "figure.titlesize": 20,
    "axes.titlesize": 20,
    "axes.labelsize": 12,
    "font.size": 10,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "ytick.labelright": False,
})

# (A Times font family for plt was also tried and left disabled.)

# Silence every warning for the rest of the session.
from warnings import filterwarnings as fw
fw('ignore')

In [None]:
# Larger fonts for the publication figures; these values replace the base sizes.
mpl.rcParams.update(
    {
        "figure.titlesize": 25,
        "axes.titlesize": 25,
        "axes.labelsize": 20,
        "font.size": 20,
        "xtick.labelsize": 15,
        "ytick.labelsize": 15,
        "ytick.labelright": False,
    }
)

In [None]:
import importlib
import sys

# Project root; the helper modules live in <root>/functions.
path = '/home/theospira/notebooks/projects/WW_climatology'
sys.path.append(f'{path}/functions')
# To pick up edits to load_data without a restart:
# importlib.reload(sys.modules['load_data'])

from plot_formatting import *
from inspection_funcs import boxplot
from smoothing_and_interp import *

from load_data import load_data
# Unpack everything load_data returns for the seasonal climatology file.
ds, ds_s, ds_ww, bth, ssh, si, szn = load_data(
    f'{path}/data/hydrographic_profiles/submission2/SO_1yr_clim_seasonal.nc'
)

In [None]:
def seasonal_grouping(ds):
    """Group a month-indexed object into austral seasons.

    Calendar months are relabelled so that July -> 1, August -> 2, ..., June -> 12.
    They are then binned in threes with ``groupby_bins``. Bin labels 0, 1, 2, 3
    therefore hold Jul-Sep, Oct-Dec, Jan-Mar and Apr-Jun
    (winter, spring, summer, autumn).

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Object with a ``month`` coordinate of calendar months 1-12, for example
        the result of ``.groupby('time.month').sum()``. The months may be in any
        order or be a subset.

    Returns
    -------
    The ``groupby_bins`` object over ``month_bins`` with labels 0-3.

    Notes
    -----
    The relabelling is derived from the existing month values rather than
    assumed from their position. The caller's object is not modified:
    ``assign_coords`` returns a new object.
    """
    # Shift by six months (Jul -> 1, ..., Jun -> 12) so that austral winter comes first.
    austral_month = (np.asarray(ds['month']) + 5) % 12 + 1
    shifted = ds.assign_coords(month=austral_month)
    # Bins (0,3], (3,6], (6,9], (9,12] -> season labels 0..3.
    return shifted.groupby_bins(group='month', bins=range(0, 15, 3), labels=range(0, 4))

# some stats

In [None]:
# WW-only hydrographic dataset; its `ww_type` variable is counted per season in the next cells.
tmp = xr.open_dataset('/home/theospira/notebooks/projects/WW_climatology/data/hydrographic_profiles/submission2/ww_only_hydrographic_data.nc')

In [None]:
# Group ww_type by calendar month of `time`; tmpgr[m] is indexed with months 1-12 below.
tmpgr = tmp.ww_type.groupby('time.month')#[1]

In [None]:
# For each calendar-month triplet (JFM, AMJ, JAS, OND), print the number of WW
# type 1 and type 2 values, their share of all non-null ww_type values, and
# that total. Triplet j is titled with szn[(j + 2) % 4], i.e. JFM -> szn[2],
# AMJ -> szn[3], JAS -> szn[0], OND -> szn[1].
# (The loop variable is `ww_kind`, not `type`, so the builtin is not shadowed.)
for j, i in enumerate([1, 4, 7, 10]):
    print(szn[int((j + 2) % 4)])
    months = (i, i + 1, i + 2)
    ttl = sum(tmpgr[m].notnull().sum() for m in months).data
    for ww_kind in (1, 2):
        sznl = sum((tmpgr[m] == ww_kind).sum() for m in months).data
        print('type ', ww_kind, ' ', sznl, np.round(sznl / ttl * 100, decimals=2), '%')
    print('total ', ' ', '', ttl, '\n')

# Figure 2: circumpolar data distribution and comparison

## funcs

In [None]:
from scipy.stats import mode
def ww_profile(ax,ds_s,ds,leg_coords,lon=0,lat=-63,return_sec=False,add_leg=True,
               c1_f='#377eb8', c2_f='#ff7f00', c1='k', c2='#e41a1c'):
    """Plot one seasonal temperature/density profile per axis with the WW layer marked.

    For each axis ``ax[i]`` (season ``i``) this draws:

    * Conservative Temperature against depth.
    * Potential density ``sig`` on a twin x-axis (27-28).
    * Dashed lines at the WW upper bound, lower bound and core.
    * A shaded band between the bounds, coloured by the modal ``ww_type`` in the
      lon band at ``lat`` (type 1 -> ``c1_f``, type 2 -> ``c2_f``).

    Temperature tick labels appear only on ``ax[3]``; density tick labels and the
    density label appear only on ``ax[0]``. ``ax[2]`` carries the
    upper-bound/lower-bound/core annotations.

    Parameters
    ----------
    ax : sequence of matplotlib Axes
        One axis per season, indexed positionally.
    ds_s : xarray.Dataset
        Contains ``ctemp``, ``sig``, ``up_bd``, ``lw_bd`` and ``ww_cd``, selectable
        by ``lon``/``lat``/``pres`` and indexable by ``season``.
    ds : xarray.Dataset
        Contains ``ww_type``, used to choose the shading colour.
    leg_coords : tuple
        ``bbox_to_anchor`` for the figure legend (used only when ``add_leg``).
    lon, lat : float
        Data are averaged over ``lon-3..lon+3`` and the profile is taken at ``lat``.
    return_sec : bool
        If truthy, return ``(tmp, tmp2)``: the lon-averaged section and the
        single-latitude profiles.
    add_leg : bool
        Add a legend to the axes' figure for whichever WW types were shaded.
    c1_f, c2_f : colour
        Fill colours for WW type 1 (ML) and type 2 (SS).
    c1, c2 : colour
        Line/axis colours for temperature and density.
    """
    # Lon-average over lon-3..lon+3 and keep 10-400 dbar.
    tmp  = ds_s[['ctemp','sig','up_bd','lw_bd','ww_cd',]].sel(
        lon=slice(lon-3,lon+3),pres=slice(10,400)).mean('lon',skipna=True)
    tmp2 = tmp.sel(lat=lat,pres=slice(10,400))
    # Depth (positive downwards) from pressure at the profile latitude.
    z    = gsw.z_from_p(tmp2.pres,tmp2.lat)*-1

    # Legend handles for the WW shading; they stay None if that type is never shaded.
    f1 = f2 = None
    shade_x = np.arange(-2,2.5,0.001)
    for i,a in enumerate(ax):
        # --- temperature profile ---
        a.plot(tmp2.ctemp.isel(season=i),z,zorder=10,c=c1)
        a.set_xlim(-2,2)
        # Pin the tick positions: set_*ticklabels alone relabels whatever ticks
        # the auto-locator chose, which can put labels on the wrong ticks.
        a.set_xticks([-2, 0, 2])
        a.set_xticklabels([])
        a.set_ylim(300,0)
        a.set_ylabel('')
        a.yaxis.tick_right()
        a.set_xlabel('')
        if i == 3:
            a.set_xticklabels(np.arange(-2,2.5,2))
            a.set_xlabel('CT 'r'(°C)',)
        a.spines['bottom'].set_color(c1)
        a.xaxis.label.set_color(c1)
        for t in a.xaxis.get_ticklabels():
            t.set_color(c1)

        # --- WW upper bound, lower bound and core depth ---
        y1 = tmp2.up_bd.isel(season=i).data
        y2 = tmp2.lw_bd.isel(season=i).data
        cd = tmp2.ww_cd.isel(season=i).data
        a.axhline(y1,ls='--',c='#626262')
        a.axhline(y2,ls='--',c='#626262')
        a.axhline(cd,ls='--',c='k')

        # Shade the WW layer by the modal ww_type in the lon band at this latitude.
        cdn = mode(ds.sel(lon=slice(lon-3,lon+3),pres=slice(10,300),lat=lat,season=i).ww_type,
                   nan_policy='omit')[0]
        if cdn == 1:
            f1 = a.fill_between(x=shade_x,y1=y1,y2=y2,color=c1_f,alpha=0.75,)
        elif cdn == 2:
            f2 = a.fill_between(x=shade_x,y1=y1,y2=y2,color=c2_f,alpha=0.75,)

        # --- density on a twin x-axis ---
        a1 = a.twiny()
        a1.plot(tmp2.sig.isel(season=i),z,zorder=5,c=c2,)
        a1.set_xlim(27,28)
        a1.set_xticks([27.0,27.5,28,])
        a1.set_xticklabels([])
        a1.spines['top'].set_color(c2)
        a1.xaxis.label.set_color(c2)
        a1.tick_params(axis='x', colors=c2)
        a1.set_title('')

        a.set_yticks(np.arange(0,350,50))
        a.set_yticklabels(np.arange(0,350,50),)
        a.set_title('')
        if i == 0:
            a1.set_xticklabels([27.0,27.5,28],c=c2)
            a1.set_xlabel(r'$\sigma_0$ (kg m$^{-3}$)',)
        else:
            a1.set_xlabel('')

        if i==2:
            a.annotate('upper\nbound',(0.275,y1),xytext=(0.7,y1-10),fontsize=12,
                            arrowprops={'arrowstyle':'->','color':'#626262'})
            a.annotate('lower\nbound',(-1.,y2),xytext=(-0.75,200),fontsize=12,
                            arrowprops={'arrowstyle':'->','color':'#626262'},zorder=15)
            a.annotate('core',(-1.65,cd),xytext=(-1.7,cd+90),fontsize=12,
                            arrowprops={'arrowstyle':'->','color':'k'})

    if add_leg:
        # Only include WW types that were actually shaded (avoids unbound handles).
        pairs = [(h, lbl) for h, lbl in ((f1,'WW$_{ML}$'), (f2,'WW$_{SS}$')) if h is not None]
        if pairs:
            handles, labels = zip(*pairs)
            ax[0].figure.legend(list(handles), list(labels),
                                bbox_to_anchor=leg_coords, edgecolor=None)

    if return_sec:
        return tmp, tmp2

def update_projection(ax, axi, projection, fig=None):
    """
    Replace a subplot with a new axis of a different projection, in place.

    The new axis is created on ``axi``'s own SubplotSpec. It therefore keeps the
    same grid cell(s), including spans and the parent GridSpec's width/height
    ratios. Check available projections with
    matplotlib.projections.get_projection_names(), or pass a Cartopy CRS.

    params:
    -------
    ax:  numpy array of axes for the whole figure; the entry holding ``axi`` is
         overwritten with the new axis
    axi: subplot axis to re-project
    projection: desired new projection
    fig: figure to add the new axis to; defaults to ``axi``'s figure

    returns:
    --------
    The newly created axis (also stored in ``ax``).
    """
    if fig is None:
        fig = axi.figure
    spec = axi.get_subplotspec()
    # Find axi in the array by identity. Fall back to its grid index, which equals
    # the flat position when ax is the full row-major subplot grid.
    pos = next((k for k, a in enumerate(ax.flat) if a is axi),
               spec.get_geometry()[2])
    axi.remove()
    new_ax = fig.add_subplot(spec, projection=projection)
    ax.flat[pos] = new_ax
    return new_ax

def seasonal_grouping(ds):
    """Group a month-indexed object into austral seasons.

    Calendar months are relabelled so that July -> 1, August -> 2, ..., June -> 12.
    They are then binned in threes with ``groupby_bins``. Bin labels 0, 1, 2, 3
    therefore hold Jul-Sep, Oct-Dec, Jan-Mar and Apr-Jun
    (winter, spring, summer, autumn).

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Object with a ``month`` coordinate of calendar months 1-12, for example
        the result of ``.groupby('time.month').sum()``. The months may be in any
        order or be a subset.

    Returns
    -------
    The ``groupby_bins`` object over ``month_bins`` with labels 0-3.

    Notes
    -----
    The relabelling is derived from the existing month values rather than
    assumed from their position. The caller's object is not modified:
    ``assign_coords`` returns a new object.
    """
    # Shift by six months (Jul -> 1, ..., Jun -> 12) so that austral winter comes first.
    austral_month = (np.asarray(ds['month']) + 5) % 12 + 1
    shifted = ds.assign_coords(month=austral_month)
    # Bins (0,3], (3,6], (6,9], (9,12] -> season labels 0..3.
    return shifted.groupby_bins(group='month', bins=range(0, 15, 3), labels=range(0, 4))

import matplotlib.cm as mplcm

def plot_cross_section(fig,ax,ds_s,ds_ww,l=0,xlim=(-70,-43),lvls=9,**cbar_kwargs):
    """Plot seasonal meridional temperature sections with WW bounds and front lines.

    For each axis ``ax[i]`` (``ctemp[i]``, i.e. the i-th entry along the first
    dimension) this draws:

    * Filled contours of smoothed Conservative Temperature (``ctemp``, -2..2).
    * A dashed white 2 °C isotherm.
    * The WW upper/lower bounds (grey, dashed) and core (black, dashed), shown
      only where the section is colder than 2 °C at that depth.
    * Solid vertical lines at the latitudes where the season-mean ``adt`` is
      closest to -0.58 and to -0.1.

    A shared colourbar is then added to ``fig``.

    Parameters
    ----------
    fig : matplotlib Figure
    ax : sequence of Axes, one per season.
    ds_s : xarray.Dataset
        Contains ``ctemp`` and ``adt``, selectable by ``lon`` and ``lat``
        (``ctemp`` also by ``pres``), with a ``season`` dimension.
    ds_ww : xarray.Dataset
        Contains ``up_bd``, ``lw_bd`` and ``ww_cd``.
    l : float
        Section longitude; data are averaged over ``l-3..l+3``.
    xlim : (float, float)
        Latitude limits of every section. The default is (-70, -43).
    lvls : int
        Number of filled-contour levels.
    **cbar_kwargs
        Passed through to ``fig.colorbar``.
    """
    band = slice(l-3, l+3)
    # Lon-averaged section, smoothed by a rolling mean in pressure and latitude.
    tmp1  = ds_s.sel(lon=band).mean('lon').rolling({'pres':12,'lat':8},min_periods=1).mean()
    # True where the section is colder than 2 °C.
    condn = ((tmp1.ctemp-2)<0)
    # WW bounds/core, lon-averaged and smoothed along latitude.
    tmp2  = ds_ww.sel(lon=band).mean('lon').rolling({'lat':8},min_periods=1).mean()[['up_bd','lw_bd','ww_cd']]

    # Front latitudes do not depend on the season panel, so compute them once.
    adt_mean = tmp1.adt.mean('season')
    front_lats = [abs(adt_mean - lvl).idxmin() for lvl in (-0.58, -0.1)]
    bound_styles = (('up_bd', '#626262'), ('lw_bd', '#626262'), ('ww_cd', 'k'))

    for i,a in enumerate(ax):
        cs = tmp1.ctemp[i].plot.contourf(ax=a,x='lat',cmap='cmo.thermal',vmin=-2,vmax=2,levels=lvls,add_colorbar=False)
        tmp1.ctemp[i].plot.contour(ax=a,x='lat',colors=['w'],levels=[2],linestyles=['--'])
        for var, colour in bound_styles:
            bd = tmp2[var][i]
            # Keep the line only where the nearest-pressure grid point is < 2 °C.
            bd.where(condn[i].sel(pres=bd,method='nearest')>0
                     ).plot(x='lat',c=colour,zorder=10,ax=a,ls='--')
        for front_lat in front_lats:
            a.axvline(front_lat,ls='-',c='k',lw=2.5)
        a.set_xlabel('')
        a.set_ylabel('')
        a.set_ylim(300,0)
        a.set_xlim(*xlim)
        if i != len(ax)-1:
            a.set_xticklabels([])

    ax[-1].set_xlabel('Latitude (°N)')

    cb = fig.colorbar(cs,ax=ax,**cbar_kwargs)
    cb.set_ticks(np.arange(-2,2.1,1))
    cb.minorticks_off()

## plots

## add cross section

In [None]:
# Figure 2 layout: 4 rows (one per season) x 3 columns of axes, flattened row-major:
#   ax[0::3] - circumpolar maps of ww_type (SouthPolarStereo),
#   ax[1::3] - meridional temperature sections along lon = ln (re-projected to rectilinear),
#   ax[2::3] - example profiles at (ln, lt) (re-projected to rectilinear).
ln  = 0  # chosen lon for cross section and profiles
lt  = -59  # chosen lat for the example profiles (gold star on the maps)
col = 'k'
sec_kwargs=dict(c=col,ls='--',lw=2,alpha=1,)  # line style for the section markers

alphbt = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
          'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

si_x = 15  # NOTE(review): not referenced in this cell

extent = [180,-180,-90,-45]  # NOTE(review): not referenced in this cell
crs = ccrs.PlateCarree()  # transform for data given in lon/lat

fig,ax = plt.subplots(4,3,figsize=[13,17.5],constrained_layout=True,width_ratios=[4,5,1.5],
                      subplot_kw={'projection': ccrs.SouthPolarStereo(),}, dpi=600)
ax=ax.flatten()

# Shared map formatting for the first column (helper from the star-imported project modules).
circular_plot_fomatting(fig,ax[::3],ds=ds,si=si,bathym=False,annotation=False,)

# plot mode WW type (column 0): section line at lon=ln and a star at (ln, lt)
cm = mpl.colors.LinearSegmentedColormap.from_list("", ['#377eb8','#ff7f00'])
for i,a in enumerate(ax[::3]):
    hmp = ds['ww_type'].isel(season=i).plot(x='lon',transform=crs,ax=a,cmap=cm,levels=3,vmin=0,vmax=4,add_colorbar=False,)
    a.plot([ln,ln],[-80,-40],transform=crs,zorder=7,**sec_kwargs)
    a.scatter(ln, lt, s=320, marker=(5,1), color='gold',edgecolors='k',zorder=8,transform=crs)
    a.text(x=0.06,y=0.89,s="(" + alphbt[i] + ")",transform=a.transAxes,fontsize=17.5)
# Ticks at 1 and 3 label the two colour bins (ML, SS).
cb = fig.colorbar(hmp,ax=ax[::3],location='top',aspect=12*2*0.4/0.8,shrink=0.4,pad=-0.02)
cb.set_ticks([1,3])
cb.set_ticklabels([r'WW$_{\text{ML}}$',r'WW$_{\text{SS}}$'],fontsize=20)

# plot cross section of WW (column 1): swap the polar axes for plain ones first
for i,a in enumerate(ax[1:12:3]):
    update_projection(ax,a,'rectilinear')
cbar_kwargs = {'aspect':12*2,'shrink':0.8,'location':'top','pad':-0.02,
               'label':'Conservative Temperature 'r'(°C)',}
plot_cross_section(fig,ax[1:12:3],ds_s,ds_ww,**cbar_kwargs,l=ln,lvls=17)
ax[-2].set_xlabel('Latitude (°N)')

# plot example WW profiles (column 2); ax[2:24:3] is the same as ax[2::3] for 12 axes
for i,a in enumerate(ax[2:24:3]):
    update_projection(ax,a,'rectilinear')
#ax_share = np.append(ax[1:12:3],ax[2:24:3])
#set_share_axes(ax_share,sharey=True)

ww_profile(ax[2::3],ds_s,ds,lon=ln,lat=lt,leg_coords=(0.655,0.97),c2='#e41a1c',add_leg=False)
ax[-1].set_xlabel  # NOTE(review): no-op, the method is referenced but never called

# annotate middle column (sections): mark the profile latitude
for i,a in enumerate(ax[1:12:3]):
    a.axvline(x=lt,**sec_kwargs,zorder=2)
    a.text(x=-0.125,y=0.975,s="(" + alphbt[i+4] + ")",transform=a.transAxes,fontsize=17.5)
    a.yaxis.tick_left()
#ax[-2].scatter(lt, 275, s=320, marker=(5,1), color='gold',edgecolors='k',zorder=3,)

# annotate final column (profiles): depth labels are shared with the sections
for i,a in enumerate(ax[2::3]):
    a.text(x=0.0175,y=.025,s="(" + alphbt[i+8] + ")",transform=a.transAxes,fontsize=17.5)
    a.set_yticklabels([])
    a.yaxis.tick_left()

# One shared vertical axis label for columns 1 and 2.
fig.text(x=0.33,y=0.46,s="Depth (m)",rotation='vertical',ha='center')
    
# set formatting of circumpolar plots
for i,a in enumerate(ax[::3]):
    a.set_yticklabels([])
    a.set_yticks([])
    a.set_ylabel('')

# remove titles of all panels
for a in ax:
    a.set_title('')

# Season name above each map row, positioned in that map's axes coordinates.
#x,y = 0.97,0.99
x,y = -0.05,0.99
for a,s in zip(ax[::3],szn):
    fig.text(x=x,y=y,s=s,transform=a.transAxes,fontsize=22.5,ha='left')

#fig.tight_layout()

#fig.tight_layout()

# supplementary: d/dt of WW type

In [None]:
# Supplementary figure: season-to-season change of WW type.
# Keep only classified cells (type 1 = WW_ML, type 2 = WW_SS).
ww_type = ds['ww_type'].where(np.logical_or(ds.ww_type==1,ds.ww_type==2))

extent = [180,-180,-90,-45]
crs = ccrs.PlateCarree()

fig,ax = plt.subplots(1,4,figsize=[20,8],constrained_layout=True,dpi=600,
                      subplot_kw={'projection':ccrs.SouthPolarStereo(),'facecolor':'darkgrey'})
ax=ax.flatten()
asp = 12  # aspect for colorbars

circular_plot_fomatting(fig,ax,ds=ds,fronts=False,sea_ice=False,bathym=False,annotation=True,draw_labels=False)

# Diverging colormap for the change: -1 (SS -> ML), 0 (unchanged), +1 (ML -> SS).
cm   = mpl.colors.LinearSegmentedColormap.from_list("", ['#377eb8','w','#ff7f00'])
cm2  = mpl.colors.LinearSegmentedColormap.from_list("", ['#f781bf','#f781bf'])

# Each season minus the previous one (season 0 is compared with season 3).
tmp = xr.concat([ww_type.isel(season=i) - ww_type.isel(season=(i-1)%4) for i in range(4)],
                dim='season')

# Cells whose type never changes: the 3rd-order difference over the 4 seasons of
# the "unchanged" (== 0) mask is non-NaN only where all four seasons are 0.
# The mask is the same for every panel, so compute it once.
ww_type_stable = tmp.where(tmp==0).diff('season',3)[0].notnull()

for i,a in enumerate(ax):
    # add annotation for each subfig
    a.text(x=.06,y=.89,s="(" + alphbt[i] + ")",transform=a.transAxes)
    hmp = tmp[i].plot(x='lon',transform=crs,cmap=cm,ax=a,add_colorbar=False,)
    # Outline the region whose WW type is the same in every season.
    ww_type_stable.plot.contour(x='lon',transform=crs,ax=a,colors='k',levels=[0],
                                linewidths=[1.],linestyles=['-'],)
    a.set_title(szn[i])
cb = fig.colorbar(hmp,ax=ax,shrink=0.45,aspect=asp,pad=0.025)
cb.set_ticks([-1,0,1])
cb.set_ticklabels(['SS -> ML','remains same','ML -> SS'])

# PF & SAF (season-mean ADT = -0.58 and -0.1) on the first panel. Pass explicit
# lon/lat: a bare array would be contoured against array indices, not geography.
# add_cyclic_point closes the gap at the wrap-around longitude.
adt_mean = ds.adt.mean('season').transpose('lat','lon')
adt_cyc, lon_cyc = cartopy.util.add_cyclic_point(adt_mean.values, coord=adt_mean['lon'].values)
ax[0].contour(lon_cyc, adt_mean['lat'].values, adt_cyc, levels=[-0.58,-0.1], colors='k',
              linewidths=1.5, linestyles=['-'], transform=crs, zorder=10)

#fig.savefig(path+'/figs/A1-d_dt(ww_type)',format='png')

In [None]:
# Per season: how many cells are WW_ML (type 1) / WW_SS (type 2), and what share
# of that season's classified cells they make up.
for j, l in enumerate(['ML', 'SS']):
    print('\n', l)
    for i in range(4):
        season_types = ww_type[i]
        n_match = season_types.where(season_types == j + 1).notnull().sum()
        share = np.round((n_match / season_types.notnull().sum()).data * 100, decimals=2)
        print(szn[i], n_match.data, ' ', share, '%')

In [None]:
# Per season: count of cells in each transition class (-1, 0, +1 in `tmp`) and
# that count as a share of the season's classified ww_type cells.
for j, l in enumerate(['SS -> ML', 'remains same', 'ML -> SS']):
    print('\n', l)
    change = j - 1
    for i in range(4):
        n_cells = tmp[i].where(tmp[i] == change).notnull().sum()
        share = np.round((n_cells / ww_type[i].notnull().sum()).data * 100, decimals=2)
        print(szn[i], n_cells.data, ' ', share, '%')

# Figure 1: funcs for timeseries + plots

## main

In [None]:
# Inspect the data-source field of the seasonal climatology.
ds['dsource']

In [None]:
# Import necessary libraries
from datetime import datetime as dt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Define a function to create a profile timeseries plot
def profile_timeseries_sp(ax,freq='1M'):
    """
    Plot a stacked bar chart of profile counts per time bin, split by data source.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axis where the profile timeseries plot will be created.
    freq : str, optional
        pandas frequency string for the time bins (default '1M', monthly).
        The bins run from 2003-12-31 to 2022-01-01 and are right-closed.

    Description:
    ------------
    Loads the hydrographic profile file and bins the profiles by time. It counts
    the profiles of each data source ('Float', 'MEOP', 'CTD', 'Gliders') in every
    bin and draws them as stacked bars, one per bin, at integer x positions.
    Ticks are placed every 12 bins. The labelled ticks (every 24 bins, offset by
    12) show the calendar year in which that bin starts. The x-limits cover all bins.
    """
    nc_path = '/home/theospira/notebooks/projects/WW_climatology/data/hydrographic_profiles/ww_gauss_smoothed_ds.nc'

    t_bins = pd.date_range('2003-12-31', '2022-01-01', freq=freq)
    n_bins = t_bins.size - 1

    # Data sources and their bar colours (same order).
    dsrce = ['Float',    'MEOP',    'CTD',     'Gliders',]
    c     = ['#9e0142', '#86cfa5', '#f98e52', '#ffffbe', ]

    # NaN = source absent from the bin (no bar segment drawn).
    arr = np.full([len(dsrce), n_bins], np.nan)
    with xr.open_dataset(nc_path) as tmp:
        grp = tmp[['n_prof', 'dsource']].groupby_bins(tmp.time, bins=t_bins, labels=np.arange(n_bins))
        # Group keys are the bin labels: index the columns by label so that
        # empty bins (which have no group) do not shift later columns.
        for t in grp.groups:
            dsr = grp[t].dsource
            for i, d in enumerate(dsrce):
                if d in dsr:
                    arr[i, int(t)] = (dsr == d).sum()

    # Stack the sources bin by bin; absent sources add nothing to the stack base.
    bottom = np.zeros(n_bins)
    x_ax = np.arange(n_bins)
    for i, d in enumerate(dsrce):
        ax.bar(x_ax, arr[i], label=d, bottom=bottom, color=c[i], width=1, edgecolor=None, linewidth=0, align='edge')
        bottom += np.nan_to_num(arr[i])

    # Ticks every 12 bins. Labels give the year of each labelled bin's start date
    # (bin k starts the day after t_bins[k]).
    xtix = np.arange(0, n_bins + 1, 12)
    ax.set_xticks(xtix[1::2],)
    ax.xaxis.set_minor_locator(plt.FixedLocator(xtix[::2]))
    ax.set_xticklabels((t_bins[xtix[1::2]] + pd.Timedelta(days=1)).strftime('%Y'))

    # Limits in bar-index units, matching the x positions of the bars.
    ax.set_xlim(0, n_bins)

def total_dsr_pcnt(ax):
    """Draw one stacked bar (at x = 1) of each data source's share of all profiles, in %.

    Sources are stacked bottom-up as 'Float', 'MEOP', 'CTD', 'Gliders'. A source
    that does not occur in the file contributes a zero-height segment; it does
    not break the stack.
    """
    dsrce = ['Float',    'MEOP',    'CTD',     'Gliders', ]
    c     = ['#9e0142', '#86cfa5', '#f98e52', '#ffffbe', ]

    with xr.open_dataset('/home/theospira/notebooks/projects/WW_climatology/data/hydrographic_profiles/ww_gauss_smoothed_ds.nc') as tmp:
        # total number of profiles for each data source
        names, counts = np.unique(tmp.dsource, return_counts=True)
    # as percentage of all profiles
    pct = counts / counts.sum() * 100

    bottom = 0.
    for d, colour in zip(dsrce, c):
        share = pct[names == d].sum()  # 0.0 when the source is absent
        ax.bar(1, share, label=d, bottom=bottom, color=colour,
               edgecolor=None, linewidth=0)
        bottom += share

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

alphbt = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 
          'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

# Seasonal totals of the monthly profile counts, with zero counts masked to NaN.
# NOTE(review): n_prof is not referenced below in this cell (the maps use ds.n_prof).
n_prof = xr.open_dataset('/home/theospira/notebooks/2m_interp/data/n_profs_monthly_time_series.nc')
n_prof = seasonal_grouping((n_prof.n_prof.groupby('time.month')).sum()).sum().rename({'month_bins':'season'})
n_prof = n_prof.where(n_prof>0)#.rename({'time':'season'})


# Colourbar layout settings: aspect, padding, shrink.
asp = 15
pad = 0.07
shr = 0.65

crs = ccrs.PlateCarree()

# Create the gridspec layout: rows 0-3 hold the maps (one per season), row 4 the timeseries.
fig = plt.figure(figsize=(10, 14), constrained_layout=True, )#dpi=600)
gs = gridspec.GridSpec(5, 3, figure=fig, height_ratios=[1, 1, 1, 1, 0.8],)

# Axes list: filled column by column below.
ax = []

# Map axes: ax[0:4] = grid column 0 (profile counts), ax[4:8] = grid column 1 (modal data source).
for i in range(2):
    for j in range(4):
        ax.append(fig.add_subplot(gs[j, i],projection=ccrs.SouthPolarStereo()))
        ax[-1].set_facecolor('darkgrey')

circular_plot_fomatting(fig,ax[:8],ds=ds,si=si,bathym=False,annotation=True)
    
# plot data distribution
for i,a in enumerate(ax[:4]):
    hmp1 = ds.n_prof.isel(season=i).plot(vmin=0,vmax=30,x='lon',cmap='cmo.amp',ax=a,
                                     transform=crs,add_colorbar=False,levels=11)
    a.set_title('')
    #a.text(x=-0.005,y=1.05,s=szn[i],transform=a.transAxes,fontsize=20)

# plot modal data source
for i,a in enumerate(ax[4:8]):
    hmp = ds.dsource.isel(season=i).plot(x='lon',ax=a,transform=crs,cmap='Spectral',
                                         vmin=1,vmax=6,levels=6,add_colorbar=False)
    a.set_title('')

# Create the subplot for the final row, spanning the first two columns
ax.append(fig.add_subplot(gs[4, :-1]))

# plot time series barchart (whichever profile_timeseries_sp is currently defined in the session)
profile_timeseries_sp(ax[-1],freq='1M')
ax[-1].set_ylabel('Number of Profiles',fontsize=15)
ax[-1].set_xlabel('Year',fontsize=15)
ax[-1].text(x=0,y=1.05,s='(i)',transform=ax[-1].transAxes)
ax[-1].tick_params(axis='both', which='major', labelsize=12)

# data source as a percentage
# add a new narrow axis to the right of the timeseries
ax.append(make_axes_locatable(ax[-1]).append_axes("right", size="2%", pad=0.1)) 
total_dsr_pcnt(ax[-1])
ax[-1].yaxis.tick_right()
ax[-1].yaxis.set_label_position("right")
ax[-1].set_ylim(0,100)
ax[-1].set_xticks([])

ax[-1].text(x=-0.4,y=1.05,s='(j)',transform=ax[-1].transAxes)
ax[-1].text(x=3,y=1.05,s='%', transform=ax[-1].transAxes, fontsize=15)
ax[-1].tick_params(axis='both', which='major', labelsize=12)

# add cbar for data sources
# NOTE(review): 5 tick positions (1.5..5.5) but only 4 labels; recent matplotlib
# raises ValueError on this mismatch - confirm against the installed version.
cb = fig.colorbar(hmp,ax=ax[5:7],shrink=shr,aspect=asp*2,pad=pad,ticks=np.arange(1.5,6.5,1))
cb.set_ticklabels(['Argo', 'CTD', 'Gliders', 'MEOP'])
cb.ax.tick_params(labelsize=12)
cb.ax.minorticks_off()

# add cbar from profile distribution
cb = fig.colorbar(hmp1,ax=ax[1:3],shrink=shr,aspect=asp*2,pad=-0.24,location='left')
cb.set_label(label='Number of Profiles',size=15,)
cb.ax.tick_params(labelsize=12)

# Season name above each row, positioned in the first-column map's axes coordinates.
x,y = 0.97,0.99
for a,s in zip(ax[:4],szn):
    fig.text(x=x,y=y,s=s,transform=a.transAxes,fontsize=20,ha='center')

In [None]:
# Display the most recently assigned mappable `hmp` (for inspection).
hmp

In [None]:
# Write the current figure to <project>/figs as a PNG.
fig.savefig(f'{path}/figs/02-hydrographic_data.png', format='png')

## swap axes

### func

In [None]:
# Import necessary libraries
from datetime import datetime as dt
from mpl_toolkits.axes_grid1 import make_axes_locatable

def load_dsource_timeseries():
    """Load the per-profile data sources and encode them as integer codes.

    Codes follow the alphabetically sorted unique source names: the first
    name -> 1, the second -> 2, and so on. The exception is 'SOCCOM', which is
    merged into code 1 so it is counted as float data. The plotting functions
    that consume this map 1 = Argo/Float, 2 = CTD, 3 = Gliders, 4 = MEOP, which
    assumes the file's names sort that way - TODO confirm against the file.

    Returns
    -------
    xarray.Dataset
        Only the ``dsource`` variable (with its coordinates), holding integer
        codes. It is fully loaded and the source file is closed.
    """
    nc_path = '/home/theospira/notebooks/projects/WW_climatology/data/hydrographic_profiles/superseded/ww_gauss_smoothed_ds-preDec23.nc'
    with xr.open_dataset(nc_path) as src:
        ds = src[['dsource']].load()

    names = ds['dsource'].values
    # One vectorised pass: position in the sorted unique names, 1-based.
    _, inverse = np.unique(names, return_inverse=True)
    codes = inverse.reshape(names.shape) + 1
    # convert SOCCOM into "float" data
    codes[names == 'SOCCOM'] = 1
    # Store as a real integer array; writing ints into a string array would coerce them to str.
    ds['dsource'] = ds['dsource'].copy(data=codes)

    return ds
    
# Define a function to create a profile timeseries plot
def profile_timeseries_sp(ax,freq='3M'):
    """
    Plot a horizontal stacked bar chart of profile counts per time bin, split by data source.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        The axis where the profile timeseries plot will be created.
    freq : {'3M', '1M'}, optional
        Bin width: quarterly (default) or monthly. The bins run from 2003-12-31
        to 2022-01-01 and are right-closed.

    Description:
    ------------
    Uses ``load_dsource_timeseries()``, whose ``dsource`` holds integer codes. It
    counts the profiles of each source in every time bin and draws them as
    horizontal stacked bars. Each bar is placed at a fixed offset into its bin
    (+1 month 14 days for '3M', +14 days for '1M'). Time increases down the y-axis.

    Raises:
    -------
    ValueError
        If ``freq`` is not '3M' or '1M'.
    """
    # Bar positions and thickness (in days) depend on the bin width.
    t_bins = pd.date_range('2003-12-31', '2022-01-01', freq=freq)
    if freq=='3M':
        x_ax = t_bins[:-1] + pd.DateOffset(months=1) + pd.DateOffset(days=14)
        w    = 75
    elif freq=='1M':
        x_ax = t_bins[:-1] + pd.DateOffset(days=14)
        w    = 25
    else:
        raise ValueError(f"freq must be '3M' or '1M', got {freq!r}")
    n_bins = t_bins.size - 1

    tmp = load_dsource_timeseries()
    grp = tmp[['n_prof', 'dsource']].groupby_bins(tmp.time, bins=t_bins, labels=np.arange(n_bins))

    # Sources, their integer codes in `dsource`, and their colours (same order).
    dsrce = ['Argo',    'MEOP',    'CTD',     'Gliders',]
    src_n = [1, 4, 2, 3]
    c     = ['#9e0142', '#86cfa5', '#f98e52', '#ffffbe', ]

    # NaN = source absent from the bin. Columns are indexed by bin label so that
    # empty bins (no group) do not shift the later columns.
    arr = np.full([len(dsrce), n_bins], np.nan)
    for t in grp.groups:
        dsr = grp[t].dsource.values
        for i, code in enumerate(src_n):
            n = (dsr == code).sum()
            if n:
                arr[i, int(t)] = n

    bottom = np.zeros(n_bins)
    for i, d in enumerate(dsrce):
        ax.barh(y = x_ax, width=arr[i], label=d, left=bottom, color=c[i], height=w,
               edgecolor=None, linewidth=0)
        bottom += np.nan_to_num(arr[i])

    # Set the y-axis limits (%m = month; %M would parse the field as minutes).
    ax.set_ylim(t_bins[0], dt.strptime('2022-01-01', '%Y-%m-%d'))
    ax.invert_yaxis()

def total_dsr_pcnt(ax):
    """Draw one horizontal stacked bar (at y = 1) of each data source's share of all profiles, in %.

    Uses the integer codes from ``load_dsource_timeseries()``. Sources are
    stacked left to right as Argo (1), MEOP (4), CTD (2), Gliders (3). A code
    that does not occur contributes a zero-width segment; it does not break the stack.
    """
    dsrce = ['Argo',    'MEOP',    'CTD',     'Gliders',]
    src_n = [1,4,2,3,]
    col   = ['#9e0142', '#86cfa5', '#f98e52', '#ffffbe',]

    # total number of profiles for each data source code, as a percentage
    codes, counts = np.unique(load_dsource_timeseries().dsource, return_counts=True)
    pct = counts / counts.sum() * 100

    left = 0.
    for name, code, colour in zip(dsrce, src_n, col):
        share = pct[codes == code].sum()  # 0.0 when the code is absent
        ax.barh(width=share, y=1, label=name, left=left, color=colour,
                edgecolor=None, linewidth=0)
        left += share

In [None]:
## update dsource classification bar chart

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def dsource_classif_bar_chart(ds, ax=None, w=1, ret_ax=0):
    """
    Create a horizontal stacked bar chart of monthly hydrographic profile counts, coloured by data source.

    Parameters:
    -----------
    ds : xarray.Dataset
        Per-profile dataset with a ``time`` coordinate and an integer-coded
        ``dsource`` (1 = Float, 2 = CTD, 3 = Gliders, 4 = MEOP).
    ax : matplotlib.axes.Axes, optional
        The axes to plot on. If not provided, a new 15x8 figure and axes are created.
    w : float, optional
        Bar thickness, in month units (the bars are edge-aligned).
    ret_ax : int, optional
        If truthy, the function returns the axis object. Default is 0.

    Notes:
    ------
    Months are right-closed bins running from 31 Dec of the year before the
    first profile to 2022-01-01. Month k is drawn at y = k, and time increases
    down the axis. The y-limits (0-204 months) and x-limits (0-7300 profiles)
    are fixed.
    """
    first_year = int(ds.time.dt.year.min())
    t_bins = pd.date_range(f'{first_year - 1}-12-31', '2022-01-01', freq='1M')
    n_bins = t_bins.size - 1

    # Data sources, their integer codes and their colours (same order).
    dsrce = ['Float',    'MEOP',    'CTD',     'Gliders',]
    src_n = [1,4,2,3,]
    col   = ['#9e0142', '#86cfa5', '#f98e52', '#ffffbe',]

    # Assign each profile to its month, matching groupby_bins' right-closed bins:
    # bin k holds t_bins[k] < time <= t_bins[k+1]. Profiles outside every bin
    # (including NaT, which sorts last) are dropped.
    times = ds['time'].values.astype('datetime64[ns]')
    codes = ds['dsource'].values
    k = np.searchsorted(t_bins.values, times, side='left') - 1
    inside = (k >= 0) & (k < n_bins)
    # arr[i, k] = number of profiles of source src_n[i] in month k (zeros, never NaN)
    arr = np.zeros([len(src_n), n_bins])
    for i, code in enumerate(src_n):
        arr[i] = np.bincount(k[inside & (codes == code)], minlength=n_bins)

    # Plotting
    if ax is None:
        plt.figure(figsize=(15, 8))
        ax = plt.gca()

    # Stack the sources left to right within each month.
    left = np.zeros(n_bins)
    y_ax = np.arange(n_bins)
    for i, d in enumerate(dsrce):
        ax.barh(y=y_ax, width=arr[i], label=d, left=left, color=col[i], height=w, edgecolor=None, linewidth=0.1, align='edge')
        left += arr[i]

    # Year ticks: labelled every 24 months, with minor ticks on the years in between.
    ytix = np.arange(0,205,12)
    ax.set_yticks(ytix[::2],)
    ax.yaxis.set_minor_locator(plt.FixedLocator(ytix[1::2]))
    ax.set_yticklabels(pd.date_range(f'{first_year}-01-01', '2022-02-01', freq='1M'
                                    )[ytix[::2]].strftime('%Y'))
    ax.set_ylim(0,204)
    ax.set_xlim(0,7300)
    ax.invert_yaxis()

    if ret_ax:
        return ax

### plots

In [None]:
# Seasonal totals of the monthly profile counts, with zero counts masked to NaN.
n_prof = (
    seasonal_grouping(
        xr.open_dataset('/home/theospira/notebooks/2m_interp/data/n_profs_monthly_time_series.nc')['n_prof']
        .groupby('time.month')
        .sum()
    )
    .sum()
    .rename({'month_bins': 'season'})
)
n_prof = n_prof.where(n_prof > 0)

In [None]:
# Reload plot_formatting, then re-run the star import so the notebook's names
# point at the updated definitions (no kernel restart needed).
importlib.reload(sys.modules['plot_formatting'])
from plot_formatting import *

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

alphbt = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 
          'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

# Colourbar layout settings: aspect, padding, shrink.
asp = 15
pad = 0.07
shr = 0.65

crs = ccrs.PlateCarree()

# Create the gridspec layout # 12.5,21
# 5 rows (4 season rows + bathymetry row) x 6 columns:
# [cbar | maps | maps | spacer | timeseries | cbar]
fig = plt.figure(figsize=(12.5, 21), dpi=600,)
gs  = gridspec.GridSpec(5, 6, figure=fig, height_ratios=[1, 1, 1, 1, 0.9], width_ratios=[0.05,1.25,1.25,0.05,0.9,0.05]) #1])

# Axes list: filled in order below.
ax = []

# Map axes: ax[0:4] = grid column 1 (profile counts), ax[4:8] = grid column 2
# (modal data source); rows 0-3 are the seasons.
for i in range(2):
    for j in range(4):
        ax.append(fig.add_subplot(gs[j, i+1],projection=ccrs.SouthPolarStereo()))
        ax[-1].set_facecolor('darkgrey')

circular_plot_fomatting(fig,ax[:8],ds=ds,si=si,bathym=False,annotation=False,)
for a in ax[:8]:
    plot_gridlines(a)
    
# plot data distribution
for i,a in enumerate(ax[:4]):
    hmp1 = ds.n_prof.isel(season=i).plot(vmin=0,vmax=30,x='lon',cmap='cmo.amp',ax=a,
                                     transform=crs,add_colorbar=False,levels=11)
    a.set_title('')
# add season titles (positioned in the left map's axes coordinates, centred between both map columns)
x,y = 1.025,0.99
for a,s in zip(ax[:4],szn):
    fig.text(x=x,y=y,s=s,transform=a.transAxes,fontsize=22.5,ha='center')

# add cbar from profile distribution in its own axis (grid column 0, rows 1-2)
# NOTE(review): with cax given, shrink/aspect/pad are presumably ignored by matplotlib - confirm.
ax.append(fig.add_subplot(gs[1:-2, 0]))
cb = fig.colorbar(hmp1,cax=ax[-1],shrink=shr,aspect=asp*3,pad=pad,location='left')
cb.set_label(label='Number of Profiles',)
cb.ax.tick_params(labelsize=17.5)

# plot modal data source
# first, make the cmap (colour order matches the source codes 1-4)
c     = ['#9e0142', '#f98e52', '#ffffbe', '#86cfa5', ] # ['Floats', 'CTD', 'Gliders', 'MEOP',]
cmap = mcolors.LinearSegmentedColormap.from_list("", c)
for i,a in enumerate(ax[4:8]):
    hmp = ds.dsource.isel(season=i).plot(x='lon',ax=a,transform=crs,cmap=cmap,
                                         vmin=0,vmax=5,levels=5,add_colorbar=False)
    a.set_title('')

for i,a in enumerate(ax[:8]):
    a.text(x=0.06,y=0.89,s="(" + alphbt[i] + ")",transform=a.transAxes,fontsize=17.5)
    
# plot time series barchart
# Create the subplot for grid column 4, spanning rows 0-3
ax.append(fig.add_subplot(gs[:-1, 4]))
# plot
#profile_timeseries_sp(ax[-1],freq='1M')
dsource_classif_bar_chart(load_dsource_timeseries(),ax=ax[-1],)
ax[-1].set_facecolor('#ededed')
ax[-1].set_xlabel('Number of Profiles',)
#ax[-1].set_ylabel('Year',fontsize=15)
ax[-1].xaxis.set_label_position("top")
ax[-1].xaxis.tick_top()
ax[-1].tick_params(axis='both', which='major', )
ax[-1].text(x=-0.3,y=1.03,s='(i)',transform=ax[-1].transAxes)

# data source as a percentage
# add a new thin axis below the timeseries
ax.append(make_axes_locatable(ax[-1]).append_axes("bottom", size="2%", pad=0.2)) 
total_dsr_pcnt(ax[-1])
ax[-1].yaxis.tick_right()
ax[-1].yaxis.set_label_position("right")
ax[-1].set_xlim(0,100)
ax[-1].set_xticks([0,50,100])
ax[-1].set_xticklabels(['0%','50%','100%'])
ax[-1].set_yticks([])

ax[-1].text(x=-0.125,y=0.,s='(j)',transform=ax[-1].transAxes,ha='center')
ax[-1].tick_params(axis='both', which='major', )

# add cbar for data sources (grid column 5, rows 1-2)
# Ticks at 0.625, 1.875, 3.125, 4.375, presumably the centres of the four colour
# bins produced by levels=5 over 0..5 - confirm.
ax.append(fig.add_subplot(gs[1:-2, 5]))
cb = fig.colorbar(hmp,cax=ax[-1],shrink=shr,aspect=asp*3,pad=pad,ticks=np.arange(0.125, 1, 0.25)*5)
cb.set_ticklabels(['Floats', 'CTD', 'Gliders', 'MEOP',])
cb.ax.tick_params(labelsize=17.5)
cb.ax.minorticks_off()

# Create the subplot for the final row, spanning all but the last column (bathymetry)
ax.append(fig.add_subplot(gs[4, :-1],))#projection=crs))
hmp = plot_bathym(ax[-1],bth)
ax[-1].text(x=0.025,y=1.05,s='(k)',transform=ax[-1].transAxes,ha='center',fontsize=17.5)
# Depth colourbar in the last column of the bottom row, inverted so depth increases downwards.
ax.append(fig.add_subplot(gs[4, -1],))
cb = fig.colorbar(hmp,cax=ax[-1],ticks=np.arange(0,5050,1000),extend='max',label='Depth (m)')
cb.ax.invert_yaxis()
cb.minorticks_off()