In [1]:
import numpy as np
import pandas as pd
import xarray as xr
from netCDF4 import num2date
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [2]:
def define_dpm():
    """Build the days-per-month lookup table, keyed by calendar name.

    Each value is a 13-element list. Index 0 holds a 0 placeholder so that a
    1-based month number (1 = January ... 12 = December) can be used directly
    as the list index. A fresh dict with independent lists is returned on
    every call.
    """
    def _month_table(february_days):
        # Twelve month lengths that differ only in the length of February.
        return [0, 31, february_days, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]

    table = {}
    # Calendars whose base February has 28 days.
    for name in ('noleap', '365_day', 'standard', 'gregorian',
                 'proleptic_gregorian'):
        table[name] = _month_table(28)
    # Calendars in which every February has 29 days.
    for name in ('all_leap', '366_day'):
        table[name] = _month_table(29)
    # Twelve uniform 30-day months.
    table['360_day'] = [0] + [30] * 12
    return table

In [3]:
def leap_year(year, calendar='standard'):
    """Determine if year is a leap year.

    Parameters
    ----------
    year : int
        Calendar year to test.
    calendar : str
        Calendar name. Only 'standard', 'gregorian', 'proleptic_gregorian'
        and 'julian' can yield True; any other name returns False.

    Returns
    -------
    bool
        True when `year` is a leap year in `calendar`.

    Notes
    -----
    * 'julian': every year divisible by 4 is a leap year.
    * 'proleptic_gregorian': century years not divisible by 400 are excluded,
      for all years.
    * 'standard' / 'gregorian': the century exclusion applies only from 1583
      onwards; earlier years follow the divisible-by-4 rule.
    """
    if calendar not in ('standard', 'gregorian',
                        'proleptic_gregorian', 'julian'):
        return False
    if year % 4 != 0:
        return False
    if calendar == 'julian':
        return True

    # Century years that are not multiples of 400 (1700, 1800, 1900, 2100...).
    century_exception = (year % 100 == 0) and (year % 400 != 0)
    if calendar == 'proleptic_gregorian':
        return not century_exception

    # 'standard' / 'gregorian': the exception only holds from 1583 onwards.
    # The comparison must be `>=`; `<` would make 1900 a leap year and 1500 a
    # common year.
    return not (century_exception and year >= 1583)

def get_dpm(time, calendar='standard'):
    """
    Return an array of days per month corresponding to the entries of `time`.

    Parameters
    ----------
    time : index-like
        Object supporting `len()` and exposing equal-length `.month` and
        `.year` sequences, one entry per time step (`season_mean` passes
        `ds.time.to_index()`).
    calendar : str
        Key into the table returned by `define_dpm()`. An unknown name raises
        KeyError.

    Returns
    -------
    numpy.ndarray
        Integer array of length `len(time)` holding the number of days in the
        month of each time step, with one day added to February when
        `leap_year(year, calendar)` is true.
    """
    # Builtin `int` is used for the dtype: the `np.int` alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    month_length = np.zeros(len(time), dtype=int)

    cal_days = define_dpm()[calendar]

    for i, (month, year) in enumerate(zip(time.month, time.year)):
        month_length[i] = cal_days[month]
        # Only February changes length in a leap year.
        if month == 2 and leap_year(year, calendar=calendar):
            month_length[i] += 1
    return month_length

In [6]:
def create_plots(ds2020, ds2050):  # only used for comparisons
    """Draw a 4x3 grid of seasonal 'chl' maps and show it.

    Rows are the seasons DJF, MAM, JJA and SON. Columns are `ds2020`,
    `ds2050` and their difference (`ds2050 - ds2020`). Both inputs must have
    a 'chl' variable that can be selected with `.sel(season=...)`. Every
    panel is masked to the cells where the first slice of `ds2050['chl']` is
    not null, drawn on a PlateCarree map limited to the extent
    [-20, 20, 50, 80], and overlaid with land, coastline and border features.

    Returns None; the figure is displayed with `plt.show()`.
    """
    ds_diff = ds2050 - ds2020

    notnull = pd.notnull(ds2050['chl'][0])
    land_110m = cfeature.NaturalEarthFeature('physical', 'land', '110m')
    proj = ccrs.PlateCarree()
    extent = [-20, 20, 50, 80]

    fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 16),
                             subplot_kw={'projection': proj})

    # Column order of the grid: first input, second input, their difference.
    column_data = (ds2020, ds2050, ds_diff)

    for row, season in enumerate(('DJF', 'MAM', 'JJA', 'SON')):
        for col, dataset in enumerate(column_data):
            panel = axes[row, col]
            seasonal_chl = dataset['chl'].sel(season=season).where(notnull)
            seasonal_chl.plot.pcolormesh(
                ax=panel, cmap='Spectral_r', transform=ccrs.PlateCarree(),
                add_colorbar=True, extend='both')
            panel.set_extent(extent, crs=proj)
            panel.add_feature(land_110m, color="lightgrey")
            panel.add_feature(cfeature.COASTLINE, edgecolor="black")
            panel.add_feature(cfeature.BORDERS, linestyle=':')

        # Only the left-most panel of each row carries the season label.
        axes[row, 0].set_ylabel(season)
        axes[row, 1].set_ylabel('')
        axes[row, 2].set_ylabel('')

    # Strip tick labels and x-axis labels from every panel.
    for ax in axes.flat:
        ax.axes.get_xaxis().set_ticklabels([])
        ax.axes.get_yaxis().set_ticklabels([])
        ax.axes.axis('tight')
        ax.set_xlabel('')

    # Column headings go on the top row only.
    for col, heading in enumerate(('ds2020', 'ds2050', 'Difference')):
        axes[0, col].set_title(heading)

    plt.tight_layout()

    fig.suptitle('Seasonal Chlorophyll', fontsize=16, y=1.02)
    plt.show()

In [7]:
def season_mean(ds, calendar='standard'):
    """Return the natural log of the month-length-weighted seasonal mean of `ds`.

    Parameters
    ----------
    ds : xarray object with a `time` coordinate
        Data to aggregate; `ds.time.to_index()` must expose `.month` and
        `.year` (see `get_dpm`).
    calendar : str
        Calendar name forwarded to `get_dpm` to obtain the days per month.

    Returns
    -------
    Same kind of xarray object as `ds`, with `time` replaced by a `season`
    dimension. Each value is `np.log` of the weighted average for that season.

    Raises
    ------
    AssertionError
        If the weights of any season do not sum to 1.

    NOTE(review): despite the name, the result is log-transformed, not the
    plain mean - confirm callers expect this.
    """
    # Number of days in the month of each time step, size = len(time).
    month_length = xr.DataArray(get_dpm(ds.time.to_index(), calendar=calendar),
                                coords=[ds.time], name='month_length')

    # Normalise the month lengths within each season so they act as weights.
    weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()

    # The weights of every season present must sum to 1.0. Comparing against
    # a scalar keeps the check valid when `ds` covers fewer than four seasons.
    np.testing.assert_allclose(weights.groupby('time.season').sum().values, 1.0)

    # Weighted average per season, then the natural log of that average.
    return np.log((ds * weights).groupby('time.season').sum(dim='time'))
