In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cmocean as cm
import cartopy.crs as ccrs
import cartopy.feature as cfeature

import scipy.io
import scipy
from scipy.stats import stats


xr_mode: an xarray implementation of the mode (most frequent value) along a dimension

In [3]:
def mode_ufunc(data):
    """
    mode_ufunc : calculates the mode along the first axis of a 3D array

    Inputs:
    ==============
    data : 3D array-like; the mode is taken along axis 0 separately for
           every (axis-1, axis-2) position

    Returns
    ==============
    mode : 2D float array of shape (data.shape[1], data.shape[2]);
           positions whose series contains any NaN are left as NaN

    """
    values = np.asarray(data)
    nrow, ncol = values.shape[1], values.shape[2]

    ### Output defaults to NaN (np.nan: the np.NaN alias is gone in NumPy 2.0)
    mode = np.full((nrow, ncol), np.nan)

    for irow in range(nrow):
        for icol in range(ncol):
            series = values[:, irow, icol]
            ### only compute the mode for series without NaNs
            if np.isnan(series).any():
                continue
            result = scipy.stats.mode(series, axis=0)
            ### result[0] is a length-1 array in older SciPy and a scalar in
            ### newer SciPy; ravel()[0] extracts the value in both cases
            mode[irow, icol] = np.ravel(result[0])[0]

    return mode

def xr_mode(data, input_core_dims=['time', 'lat','lon'], output_core_dims=['lat','lon']):
    """
    xr_mode : apply mode_ufunc to an xarray object with xr.apply_ufunc

    Inputs:
    ==============
    data             : 3D DataArray
    input_core_dims  : dimensions handed to mode_ufunc, in this order; for a
                       3D input the mode is taken along the first of them
    output_core_dims : dimensions of the 2D array mode_ufunc returns

    Returns
    ==============
    mode : DataArray holding the mode along the first input core dimension

    Dependencies
    ==============
    import xarray as xr

    """
    # apply_ufunc expects one list of core dims per input / output
    core_in = [input_core_dims]
    core_out = [output_core_dims]
    return xr.apply_ufunc(mode_ufunc, data,
                          input_core_dims=core_in,
                          output_core_dims=core_out)


biome_average: 
    
a function to average within Fay & McKinley biomes

In [9]:
def biome_average(ds=None):
    """
    biome_average : average a field within each Fay & McKinley biome

    Inputs:
    ==============
    ds : field on the same grid as the 'mean_biomes' map returned by
         read_biomes(); presumably a DataArray, since `.values` of the
         reduced result must be a scalar

    Returns
    ==============
    ds_out : Dataset with variable 'biome_mean' on a 1x1 degree (lat, lon)
             grid. Every cell of a non-ice biome holds the mean of ds over
             that biome; ice biomes (1 and 8) and cells that belong to no
             listed biome are NaN.

    Raises
    ==============
    ValueError : if ds is None

    """
    if ds is None:
        raise ValueError('biome_average requires an input field `ds`')

    ds_biomes = read_biomes()
    biomes_xr = ds_biomes['mean_biomes']
    biomes_np = biomes_xr.values

    ### dictionary of biomes (ice biomes 1 and 8 are deliberately excluded)
    biomes_names = {
             #'NP ICE': 1,
             'NP SPSS': 2,
             'NP STSS': 3,
             'NP STPS': 4,
             'PEQU-W': 5,
             'PEQU-E': 6,
             'SP STPS': 7,
             #'NA ICE': 8,
             'NA SPSS': 9,
             'NA STSS': 10,
             'NA STPS': 11,
             'AEQU': 12,
             'SA STPS': 13,
             'IND STPS': 14,
             'SO STSS': 15,
             'SO SPSS': 16,
             'SO ICE': 17}

    ### Write the means into a fresh NaN array instead of overwriting the
    ### biome-ID map. This keeps the read_biomes() data untouched. It also
    ### ensures a biome mean that happens to equal a later biome number is
    ### never re-matched and overwritten in a later iteration.
    biome_mean = np.full(np.shape(biomes_np), np.nan)
    for num in biomes_names.values():
        biome_mean[biomes_np == num] = ds.where(biomes_xr == num).mean().values

    ### dataset with biome averages
    ds_out = xr.Dataset(
        {
        'biome_mean':(['lat','lon'], biome_mean),
        },

        coords={
        'lat': (['lat'], np.arange(-89.5,90,1)),
        'lon': (['lon'], np.arange(0.5,360,1))
        })

    return ds_out

biome_median : 

a function to take the median within Fay & McKinley biomes

In [10]:
def biome_median(ds=None):
    """
    biome_median : median of a field within each Fay & McKinley biome

    Inputs:
    ==============
    ds : field on the same grid as the 'mean_biomes' map returned by
         read_biomes(); presumably a DataArray, since `.values` of the
         reduced result must be a scalar

    Returns
    ==============
    ds_out : Dataset on a 1x1 degree (lat, lon) grid. Every cell of a non-ice
             biome holds the median of ds over that biome; ice biomes (1 and
             8) and cells that belong to no listed biome are NaN.
             NOTE: the variable is named 'biome_mean' (kept for callers that
             index it by that name) although it holds medians.

    Raises
    ==============
    ValueError : if ds is None

    """
    if ds is None:
        raise ValueError('biome_median requires an input field `ds`')

    ds_biomes = read_biomes()
    biomes_xr = ds_biomes['mean_biomes']
    biomes_np = biomes_xr.values

    ### dictionary of biomes (ice biomes 1 and 8 are deliberately excluded)
    biomes_names = {
             #'NP ICE': 1,
             'NP SPSS': 2,
             'NP STSS': 3,
             'NP STPS': 4,
             'PEQU-W': 5,
             'PEQU-E': 6,
             'SP STPS': 7,
             #'NA ICE': 8,
             'NA SPSS': 9,
             'NA STSS': 10,
             'NA STPS': 11,
             'AEQU': 12,
             'SA STPS': 13,
             'IND STPS': 14,
             'SO STSS': 15,
             'SO SPSS': 16,
             'SO ICE': 17}

    ### Write the medians into a fresh NaN array instead of overwriting the
    ### biome-ID map. This keeps the read_biomes() data untouched. It also
    ### ensures a biome median equal to a later biome number is never
    ### re-matched and overwritten in a later iteration.
    biome_median_np = np.full(np.shape(biomes_np), np.nan)
    for num in biomes_names.values():
        biome_median_np[biomes_np == num] = ds.where(biomes_xr == num).median().values

    ### dataset with biome medians
    ds_out = xr.Dataset(
        {
        'biome_mean':(['lat','lon'], biome_median_np),
        },

        coords={
        'lat': (['lat'], np.arange(-89.5,90,1)),
        'lon': (['lon'], np.arange(0.5,360,1))
        })

    return ds_out

XY plot

In [11]:
from mpl_toolkits.axes_grid1 import AxesGrid
import matplotlib as mpl
import matplotlib.pyplot as plt

class xy_plot(object):
    """
    xy_plot : class to make simple XY (line) plots on a matplotlib AxesGrid

    Inputs
    ==============
    nrows_ncols : number of rows and columns of panels (default=(1, 1))
    fig         : figure handle; a new 8.5x11 figure is made if None
    rect        : position of the grid in the figure (default=[1,1,1])
    share_all   : share x and y axes among all panels (default=False)
    axes_pad    : padding between panels (default=0.2)

    Methods
    ==============
    add_plot(x, y, ax)                              : line plot on panel `ax`
    set_title(title, ax)
    square_ax(ax)                                   : make panel `ax` visually square
    set_xlim(xlim, ax) / set_ylim(ylim, ax)         : (min, max)
    set_xticks(xticks, ax) / set_yticks(yticks, ax) : (start, stop, step)
    set_xlabel(xlabel, fontsize, ax) / set_ylabel(ylabel, fontsize, ax)

    In every method `ax` is the panel index; None selects panel 0.

    Example
    ==============
    xy = xy_plot(fig=plt.figure(figsize=(6, 6)))
    xy.add_plot(x, y, ax=0, color='k')
    xy.set_ylabel('depth', fontsize=16, ax=0)

    """
    def __init__(self, 
                 nrows_ncols=(1, 1),
                 fig=None, 
                 rect=[1,1,1],
                 share_all=False,
                 axes_pad = 0.2):

        ### Setup figure and axes
        if fig is None:
            fig = plt.figure(figsize=(8.5,11))

        # label_mode='keep' leaves tick labels alone; newer matplotlib only
        # accepts 'L', '1', 'all' or 'keep' and rejects an empty string.
        self.grid = AxesGrid(fig, 
                             rect=rect, 
                             share_all=share_all,
                             nrows_ncols = nrows_ncols,
                             axes_pad = axes_pad,
                             label_mode = 'keep')

    def _panel(self, ax):
        """Return panel `ax` of the grid; None selects the first panel."""
        return self.grid[0 if ax is None else ax]

    @staticmethod
    def _tick_values(ticks):
        """(start, stop, step) -> np.arange(start, stop+step, step), so stop
        is included when it lies on the step grid."""
        return np.arange(ticks[0], ticks[1] + ticks[2], ticks[2])

    def add_plot(self, x=None, y=None, ax=None, *args, **kwargs):
        """
        add_plot(x, y, ax, *args, **kwargs) : line plot on panel `ax`

        Inputs:
        ==============
        x, y            : data to plot
        ax              : panel index (None selects panel 0)
        *args, **kwargs : passed to Axes.plot

        Returns
        ==============
        sub : whatever Axes.plot returns (list of lines)

        """
        sub = self._panel(ax).plot(x, y, *args, **kwargs)
        return sub

    def set_title(self, title, ax=None, *args, **kwargs):
        """Set the title of panel `ax`; extra args go to Axes.set_title."""
        self._panel(ax).set_title(title, *args, **kwargs)

    def square_ax(self, ax=None):
        """Set the aspect ratio so panel `ax` is drawn square for its current limits."""
        panel = self._panel(ax)
        x0, x1 = panel.get_xlim()
        y0, y1 = panel.get_ylim()
        panel.set_aspect(abs(x1 - x0) / abs(y1 - y0))
        # return value unused; call kept because Axes.get_position() may
        # apply the pending aspect
        panel.get_position()

    def set_xlim(self, xlim=None, ax=None):
        """Set x limits of panel `ax` from (min, max)."""
        self._panel(ax).set_xlim(xlim[0], xlim[1])

    def set_ylim(self, ylim=None, ax=None):
        """Set y limits of panel `ax` from (min, max)."""
        self._panel(ax).set_ylim(ylim[0], ylim[1])

    def set_yticks(self, yticks=None, ax=None):
        """Set y ticks of panel `ax` from (start, stop, step)."""
        self._panel(ax).set_yticks(self._tick_values(yticks))

    def set_xticks(self, xticks=None, ax=None):
        """Set x ticks of panel `ax` from (start, stop, step)."""
        self._panel(ax).set_xticks(self._tick_values(xticks))

    def set_ylabel(self, ylabel=None, fontsize=None, ax=None):
        """Set the y label of panel `ax`; fontsize=None keeps matplotlib's default size."""
        label_kw = {} if fontsize is None else {'fontsize': fontsize}
        self._panel(ax).set_ylabel(ylabel, **label_kw)

    def set_xlabel(self, xlabel=None, fontsize=None, ax=None):
        """Set the x label of panel `ax`; fontsize=None keeps matplotlib's default size."""
        label_kw = {} if fontsize is None else {'fontsize': fontsize}
        self._panel(ax).set_xlabel(xlabel, **label_kw)
    

map_and_taylor

In [12]:
def map_and_taylor(ds_biomes = '', 
                   ds_aae = '', 
                   std_star = 'std_star_av', 
                   corr = 'r_av', 
                   fig='', 
                   ax='', 
                   title='',
                   ds_data=None):
    """
    map_and_taylor : map of biome-averaged AAE next to a Taylor diagram

    Inputs
    ==============
    ds_biomes : DataArray of biome numbers on the 1x1 degree (lat, lon) grid
    ds_aae    : DataArray of AAE on the same grid; averaged within each biome
    std_star  : name of the normalized-standard-deviation variable in ds_data
    corr      : name of the correlation variable in ds_data
    fig       : figure; the Taylor diagram is added at subplot 122
    ax        : cartopy GeoAxes that receives the map
    title     : map title
    ds_data   : Dataset with a 'biomes' variable plus `std_star` and `corr`.
                If None, a global `ds_data` in the notebook namespace is used.

    Returns
    ==============
    None; draws into `ax` and `fig`

    Raises
    ==============
    ValueError : if no ds_data is passed and no global ds_data exists

    """
    ### Resolve the Taylor-diagram data first so a missing input fails
    ### before anything is drawn
    if ds_data is None:
        ds_data = globals().get('ds_data')
    if ds_data is None:
        raise ValueError('map_and_taylor needs `ds_data` '
                         '(pass it or define a global ds_data)')

    ### dictionary of biomes (ice biomes 1 and 8 are deliberately excluded)
    biome_numbers = {
             #'NP ICE': 1,
             'NP SPSS': 2,
             'NP STSS': 3,
             'NP STPS': 4,
             'PEQU-W': 5,
             'PEQU-E': 6,
             'SP STPS': 7,
             #'NA ICE': 8,
             'NA SPSS': 9,
             'NA STSS': 10,
             'NA STPS': 11,
             'AEQU': 12,
             'SA STPS': 13,
             'IND STPS': 14,
             'SO STSS': 15,
             'SO SPSS': 16,
             'SO ICE': 17}

    ### What is plotted in colormap: biome-mean AAE. It is written into a
    ### fresh NaN array so a mean equal to a later biome number is never
    ### re-matched and overwritten. Ice biomes and cells outside the listed
    ### biomes stay NaN.
    biome_ids = ds_biomes.values
    np_biomes = np.full(np.shape(biome_ids), np.nan)
    for num in biome_numbers.values():
        np_biomes[biome_ids == num] = ds_aae.where(ds_biomes == num).mean().values

    lon = np.arange(0.5, 360, 1)
    lat = np.arange(-89.5, 90, 1)

    ### Colormap and range
    vrange = [0, 14]
    cmap = cm.cm.amp
    transform = ccrs.PlateCarree(central_longitude=0)

    ### Plot data
    sub = ax.pcolormesh(lon, lat, np_biomes,
                        transform=transform,
                        cmap=cmap,
                        vmin=vrange[0],
                        vmax=vrange[1])

    ### Define coastline
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '110m', edgecolor='face', facecolor=[0.4, 0.4, 0.4]))
    ax.coastlines(facecolor='k')

    ### Colorbar attached explicitly to the map axes (not the "current" axes)
    cbar = ax.figure.colorbar(sub, ax=ax, orientation="horizontal", pad=0.01, shrink=0.8)
    cbar.ax.tick_params(labelsize=14)
    cbar.set_label(r'AAE [$\mu atm$]',fontsize=16)

    ax.set_title(title, fontsize=16)

    ############################
    ### Taylor diagram
    ############################
    ### Reference std
    stdref = 1

    dia = TaylorDiagram(stdref, fig=fig, rect=122, label='Reference', extend=False, srange=(0, 1.500001))

    dia.samplePoints[0].set_color('r')              # Mark reference point as a red star
    dia.add_grid()                                  # Add grid
    dia._ax.axis[:].major_ticks.set_tick_out(True)  # Put ticks outward

    ### One numbered marker per biome at its mean (std*, r)
    for label, num in biome_numbers.items():
        biome_stats = ds_data.where(ds_data['biomes']==num).mean()
        dia.add_scatter(biome_stats[std_star], biome_stats[corr], 
                        s=120, c='k', marker='$%d$' % (num), label=label)

    # Add a figure legend
    fig.legend(dia.samplePoints,
               [ p.get_label() for p in dia.samplePoints ],
               numpoints=1, prop=dict(size='small'), loc='right')

Save Figure function

In [13]:
def save_figure(fig, fig_name, fig_dir = '/local/data/artemis/workspace/vbennington/figures/'):
    """
    save_figure : save `fig` as fig_dir/fig_name with a tight bounding box

    Inputs
    ==============
    fig      : matplotlib Figure (anything with a savefig method)
    fig_name : file name including extension, e.g. 'map.png'
    fig_dir  : output directory; the trailing separator is optional.
               The default is a machine-specific absolute path, so pass
               fig_dir explicitly on other systems.

    """
    import os

    # os.path.join inserts the separator only when fig_dir lacks one
    fig_path = os.path.join(fig_dir, fig_name)
    fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)

xr_add_cyclic_point

an xarray implementation of cartopy's add_cyclic_point

In [14]:
from cartopy.util import add_cyclic_point

def xr_add_cyclic_point(data, cyclic_coord=None):
    '''
    xr_add_cyclic_point : apply cartopy's add_cyclic_point to an xarray
                          object through xr.apply_ufunc

    Inputs
    =============
    data         : xarray object to extend (loaded into memory first)
    cyclic_coord : name of the dimension to wrap around (e.g. 'lon');
                   it is moved to the last position, which is the axis
                   add_cyclic_point acts on by default

    Returns
    =============
    cyclic_data : the object with one extra wrapped-around point along
                  cyclic_coord. NOTE(review): coordinate values along
                  cyclic_coord do not appear to be carried over by
                  apply_ufunc - confirm and re-assign them if needed.

    '''
    loaded = data.load()
    extended = xr.apply_ufunc(add_cyclic_point, loaded,
                              input_core_dims=[[cyclic_coord]],
                              output_core_dims=[['tmp_new']])
    # apply_ufunc needs a new name for the lengthened dimension; restore it
    return extended.rename({'tmp_new': cyclic_coord})

SpatialMap2 : 

an updated version of Luke's original class

In [3]:
import matplotlib.path as mpath
import numpy as np
import xarray as xr
import cmocean as cm
import cartopy.crs as ccrs
import cartopy.feature
from cartopy.mpl.geoaxes import GeoAxes
from mpl_toolkits.axes_grid1 import AxesGrid
import matplotlib as mpl
import matplotlib.pyplot as plt

from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import matplotlib.ticker as mticker

class SpatialMap2(object):
    """
    SpatialMap2 : class to plot nice spatial maps on a cartopy AxesGrid with
                  a colorbar correctly positioned in the figure

    Inputs
    ==============
    nrows_ncols      : number of rows and columns of maps (default=(1, 1))
    region           : 'world' (Robinson projection) or 'southern-ocean'
                       (South Polar Stereographic) (default='world')
    fig              : figure handle; a new 8.5x11 figure is made if None
    rect             : position of the grid in the figure (default=[1,1,1])
    colorbar         : toggle for colorbar axes (default=True)
    cbar_location    : colorbar location (default='bottom')
    cbar_mode        : AxesGrid cbar_mode (default='single')
    cbar_orientation : stored as self.cbar_orientation (not used when drawing)
    cbar_size        : colorbar size (default='7%')
    cbar_pad         : pad between maps and colorbar (default=0.1)
    axes_pad         : pad between maps (default=0.2)

    Methods
    ==============
    add_plot()
    add_colorbar()
    set_ticks()
    set_title()
    set_cbar_title()
    set_cbar_ylabel() / set_cbar_xlabel()
    set_cbar_xticklabels() / set_cbar_yticklabels()

    Add at some point
    ==============
    # worldmap.cbar.ax.yaxis.set_ticks_position("left") # way to easily set tick location
    # worldmap.cbar.ax.yaxis.set_label_position('left') # set label position


    Example
    ==============
    # download WOA data
    ds = xr.open_dataset('https://data.nodc.noaa.gov/thredds/dodsC/ncei/woa/salinity/decav/1.00/woa18_decav_s00_01.nc', decode_times=False)
    data = ds['s_mn'].where(ds['depth']==0, drop=True).mean(['time','depth'])
    # plot spatial map
    worldmap = SpatialMap2(fig=plt.figure(figsize=(7,7)), region='world')
    sub = worldmap.add_plot(lon=ds['lon'], lat=ds['lat'], data=data, ax=0, vrange=[30, 37])
    col = worldmap.add_colorbar(sub, ax=0)

    """

    def __init__(self, 
                 nrows_ncols=(1, 1),
                 region='world', 
                 fig=None, 
                 rect=[1,1,1],  
                 colorbar=True, 
                 cbar_location='bottom',
                 cbar_mode='single',
                 cbar_orientation = 'horizontal',
                 cbar_size='7%', 
                 cbar_pad=0.1, 
                 axes_pad = 0.2):

        self.region = region
        self.cbar_orientation = cbar_orientation

        ### Setup figure and axes
        if fig is None:
            fig = plt.figure(figsize=(8.5,11))

        ### Define projection; fail clearly on an unknown region instead of
        ### hitting an undefined `projection` below
        if self.region.upper() == 'SOUTHERN-OCEAN':
            projection = ccrs.SouthPolarStereo()
        elif self.region.upper() == 'WORLD':
            projection = ccrs.Robinson(central_longitude=0)
        else:
            raise ValueError("region must be 'world' or 'southern-ocean', got %r" % (region,))

        ### Setup axesgrid
        # label_mode='keep' leaves tick labels alone; newer matplotlib only
        # accepts 'L', '1', 'all' or 'keep' and rejects an empty string.
        axes_class = (GeoAxes, dict(map_projection=projection))
        self.grid = AxesGrid(fig, 
                             rect=rect, 
                             axes_class=axes_class,
                             share_all=False,
                             nrows_ncols = nrows_ncols,
                             axes_pad = axes_pad,
                             cbar_location = cbar_location,
                             cbar_mode= cbar_mode if colorbar==True else None,
                             cbar_pad = cbar_pad if colorbar==True else None,
                             cbar_size = cbar_size,
                             label_mode = 'keep')

    def add_plot(self, lon=None, lat=None, 
                 data=None, 
                 ax=None, 
                 land=True, 
                 coastline=True, 
                 linewidth_coast=0.25, 
                 ncolors=101, 
                 vrange=[-25, 25], 
                 cmap=cm.cm.balance, 
                 facecolor=[0.25,0.25,0.25],
                 *args, **kwargs):
        """
        add_plot(lon, lat, data, ax, **kwargs) : pcolormesh `data` on map `ax`

        Inputs:
        ==============
        lon, lat        : coordinate vectors (defaults: 1-degree cell centers,
                          lon -179.5..179.5, lat -89.5..89.5)
        data            : 2D field to plot
        ax              : map index in the grid (None selects map 0)
        land            : draw Natural Earth 110m land in `facecolor`
        coastline       : draw coastlines with width `linewidth_coast`
        ncolors         : number of color boundaries, used when vrange has
                          only two entries
        vrange          : [vmin, vmax] or [vmin, vmax, ncolors]
        cmap            : colormap
        facecolor       : land color
        *args, **kwargs : passed to pcolormesh

        Returns
        ==============
        sub : the mesh returned by pcolormesh (pass it to add_colorbar)

        """
        if ax is None:
            ax = 0
        geo_ax = self.grid[ax]

        ### vrange may carry the number of colors as a third entry; otherwise
        ### fall back to `ncolors` (the default vrange has only two entries)
        self.vrange = vrange[0:2]
        self.ncolors = vrange[2] if len(vrange) > 2 else ncolors
        self.cmap = cmap

        ### Default to a global 1-degree grid when coordinates are not given
        self.lon = np.arange(-179.5,180,1) if lon is None else lon
        self.lat = np.arange(-89.5,90,1) if lat is None else lat

        self.transform = ccrs.PlateCarree(central_longitude=0)
        self.bounds = np.linspace(self.vrange[0], self.vrange[1], self.ncolors)
        self.norm = mpl.colors.BoundaryNorm(self.bounds, self.cmap.N)

        # Define southern ocean region
        if self.region.upper()=='SOUTHERN-OCEAN':
            # Compute a circle in axes coordinates, which we can use as a boundary
            # for the map. We can pan/zoom as much as we like - the boundary will be
            # permanently circular.
            theta = np.linspace(0, 2*np.pi, 100)
            center, radius = [0.5, 0.5], 0.5
            verts = np.vstack([np.sin(theta), np.cos(theta)]).T
            circle = mpath.Path(verts * radius + center)
            geo_ax.set_boundary(circle, transform=geo_ax.transAxes)

            # Limit the map to latitudes south of 35S
            geo_ax.set_extent([-180, 180, -90, -35], ccrs.PlateCarree())

        ### land mask (cartopy.feature is the module imported in this cell)
        if land is True:
            geo_ax.add_feature(cartopy.feature.NaturalEarthFeature('physical', 'land', '110m', 
                                                                   edgecolor='None', 
                                                                   facecolor=facecolor))

        ## add Coastline
        if coastline is True:
            geo_ax.coastlines(facecolor=facecolor, linewidth=linewidth_coast)

        ### The color range is carried by self.norm, whose boundaries span
        ### vrange. matplotlib >= 3.5 raises if norm and vmin/vmax are both given.
        sub = geo_ax.pcolormesh(self.lon, self.lat, data,
                                norm=self.norm,
                                transform=self.transform,
                                cmap=self.cmap,
                                *args, **kwargs)
        return sub

    def add_colorbar(self, sub, ax=0, *args, **kwargs):
        """
        add_colorbar(sub, ax, **kwargs) : draw the colorbar for `sub` in the
        grid's colorbar axes number `ax`

        Inputs:
        ==============
        sub : mesh returned by add_plot()
        ax  : index into self.grid.cbar_axes (default=0)

        Returns
        ==============
        col : the colorbar (pass it to set_ticks / set_cbar_* methods)

        """
        # NOTE: 'extend' can leave odd whitespace around the colorbar; see
        # https://github.com/matplotlib/matplotlib/issues/9778 for background.
        col = self.grid.cbar_axes[ax].colorbar(sub, *args, **kwargs)
        return col

    ### Class methods
    def set_ticks(self, col, tmin, tmax, dt, *args, **kwargs):
        """
        set_ticks(col, tmin, tmax, dt, **kwargs)

        Inputs:
        ==============
        col  : colorbar returned by add_colorbar()
        tmin : min tick value
        tmax : max tick value (included when on the dt grid)
        dt   : delta tick value

        """
        col.set_ticks(ticks=np.arange(tmin, tmax+dt, dt), *args, **kwargs)

    def set_title(self, title, ax, *args, **kwargs):
        """
        set_title(title, ax, *args, **kwargs)

        Inputs:
        ==============
        title : title text
        ax    : map index in the grid

        """
        self.grid[ax].set_title(title, *args, **kwargs)

    def set_cbar_title(self, col, title, *args, **kwargs):
        """
        set_cbar_title(col, title, *args, **kwargs) : title above the colorbar axes

        Inputs:
        ==============
        col   : colorbar returned by add_colorbar()
        title : colorbar title text

        """
        col.ax.set_title(title, *args, **kwargs)

    def set_cbar_ylabel(self, col, ylabel, *args, **kwargs):
        """
        set_cbar_ylabel(col, ylabel, *args, **kwargs)

        Inputs:
        ==============
        col    : colorbar returned by add_colorbar()
        ylabel : y-axis label of the colorbar axes

        """
        col.ax.set_ylabel(ylabel, *args, **kwargs)

    def set_cbar_xlabel(self, col, ylabel, *args, **kwargs):
        """
        set_cbar_xlabel(col, ylabel, *args, **kwargs)

        Inputs:
        ==============
        col    : colorbar returned by add_colorbar()
        ylabel : x-axis label of the colorbar axes (parameter name kept for
                 backward compatibility)

        """
        col.ax.set_xlabel(ylabel, *args, **kwargs)

    def set_cbar_xticklabels(self, col, labels, *args, **kwargs):
        """
        set_cbar_xticklabels(col, labels, *args, **kwargs)

        Inputs:
        ==============
        col    : colorbar returned by add_colorbar()
        labels : custom x tick labels (horizontal colorbar)

        """
        col.ax.set_xticklabels(labels, *args, **kwargs)

    def set_cbar_yticklabels(self, col, labels, *args, **kwargs):
        """
        set_cbar_yticklabels(col, labels, *args, **kwargs)

        Inputs:
        ==============
        col    : colorbar returned by add_colorbar()
        labels : custom y tick labels (vertical colorbar)

        """
        col.ax.set_yticklabels(labels, *args, **kwargs)

Taylor Diagram

In [16]:
#%%writefile TaylorDiagram.py

import numpy as NP
import matplotlib.pyplot as PLT
%matplotlib inline

class TaylorDiagram(object):
    """
    Taylor diagram.
    Plot model standard deviation and correlation to reference (data)
    sample in a single-quadrant polar plot, with r=stddev and
    theta=arccos(correlation).
    Modified from gist : https://gist.github.com/ycopin/3342888

    Attributes set by __init__:
    * _ax          : the floating axes that is actually drawn (grid, labels)
    * ax           : auxiliary polar axes; plot on it in (theta, radius)
    * samplePoints : artists added so far, reference point first; useful for
                     building a legend
    """

    def __init__(self, refstd,
                 fig=None, rect=111, label='_', srange=(0, 1.5), extend=False):
        """
        Set up Taylor diagram axes, i.e. single quadrant polar
        plot, using `mpl_toolkits.axisartist.floating_axes`.
        Parameters:
        * refstd: reference standard deviation to be compared to
        * fig: input Figure or None
        * rect: subplot definition
        * label: reference label
        * srange: stddev axis extension, in units of *refstd*
        * extend: extend diagram to negative correlations
        """

        from matplotlib.projections import PolarAxes
        import mpl_toolkits.axisartist.floating_axes as FA
        import mpl_toolkits.axisartist.grid_finder as GF

        self.refstd = refstd            # Reference standard deviation

        tr = PolarAxes.PolarTransform()

        # Correlation labels
        rlocs = NP.array([0, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1])
        if extend:
            # Diagram extended to negative correlations
            self.tmax = NP.pi
            rlocs = NP.concatenate((-rlocs[:0:-1], rlocs))
        else:
            # Diagram limited to positive correlations
            self.tmax = NP.pi/2
        tlocs = NP.arccos(rlocs)        # Conversion to polar angles
        gl1 = GF.FixedLocator(tlocs)    # Positions
        tf1 = GF.DictFormatter(dict(zip(tlocs, map(str, rlocs))))

        # Standard deviation axis extent (in units of reference stddev)
        self.smin = srange[0] * self.refstd
        self.smax = srange[1] * self.refstd

        ghelper = FA.GridHelperCurveLinear(
            tr,
            extremes=(0, self.tmax, self.smin, self.smax),
            grid_locator1=gl1, tick_formatter1=tf1)

        if fig is None:
            fig = PLT.figure()

        ax = FA.FloatingSubplot(fig, rect, grid_helper=ghelper)
        fig.add_subplot(ax)

        # Adjust axes
        ax.axis["top"].set_axis_direction("bottom")   # "Angle axis"
        ax.axis["top"].toggle(ticklabels=True, label=True)
        ax.axis["top"].major_ticklabels.set_axis_direction("top")
        ax.axis["top"].label.set_axis_direction("top")
        ax.axis["top"].label.set_text("Correlation")

        ax.axis["left"].set_axis_direction("bottom")  # "X axis"
        ax.axis["left"].label.set_text("Normalized Standard deviation")

        ax.axis["right"].set_axis_direction("top")    # "Y-axis"
        ax.axis["right"].toggle(ticklabels=True)
        ax.axis["right"].major_ticklabels.set_axis_direction(
            "bottom" if extend else "left")

        ax.axis["bottom"].set_visible(False)          # Unused

        self._ax = ax                   # Graphical axes
        self.ax = ax.get_aux_axes(tr)   # Polar coordinates

        # Add reference point and stddev contour
        l, = self.ax.plot([0], self.refstd, 'k*',
                          ls='', ms=10, label=label)
        t = NP.linspace(0, self.tmax)
        r = NP.zeros_like(t) + self.refstd
        self.ax.plot(t, r, 'k--', label='_')

        # Collect sample points for latter use (e.g. legend)
        self.samplePoints = [l]

    @staticmethod
    def _theta(corrcoef):
        """
        Polar angle for a correlation coefficient. The value is clipped to
        [-1, 1] first: round-off can push a computed correlation just past
        +/-1, where arccos is NaN and the point would silently not be drawn.
        """
        return NP.arccos(NP.clip(corrcoef, -1, 1))

    def add_sample(self, stddev, corrcoef, *args, **kwargs):
        """
        Add sample (*stddev*, *corrcoef*) to the Taylor diagram as a
        plotted marker/line. *args* and *kwargs* are passed directly to
        `Axes.plot`. Returns the Line2D, which is also appended to
        self.samplePoints.
        """

        l, = self.ax.plot(self._theta(corrcoef), stddev,
                          *args, **kwargs)  # (theta, radius)

        self.samplePoints.append(l)

        return l

    def add_scatter(self, stddev, corrcoef, *args, **kwargs):
        """
        Add sample (*stddev*, *corrcoef*) to the Taylor diagram with
        `Axes.scatter` (e.g. to use text markers such as '$3$').
        *args* and *kwargs* are passed directly to `Axes.scatter`. Returns
        the scatter collection, which is also appended to self.samplePoints.
        """
        l = self.ax.scatter(self._theta(corrcoef), stddev,
                            *args, **kwargs)  # (theta, radius)
        self.samplePoints.append(l)
        return l

    def add_grid(self, *args, **kwargs):
        """Add a grid to the drawn axes; arguments go to Axes.grid."""

        self._ax.grid(*args, **kwargs)

    def add_contours(self, levels=5, **kwargs):
        """
        Add constant centered RMS difference contours, defined by *levels*.
        Returns the contour set.
        """

        rs, ts = NP.meshgrid(NP.linspace(self.smin, self.smax),
                             NP.linspace(0, self.tmax))
        # Compute centered RMS difference (law of cosines)
        rms = NP.sqrt(self.refstd**2 + rs**2 - 2*self.refstd*rs*NP.cos(ts))

        contours = self.ax.contour(ts, rs, rms, levels, **kwargs)

        return contours

xy_plot_OLD : 

legacy class to make simple XY plots on a single axes (see xy_plot above for the AxesGrid version)

In [17]:
import numpy as np
import xarray as xr
import cmocean as cm
%matplotlib inline
class xy_plot_OLD(object):
    """
    xy_plot_OLD : legacy class for simple XY (line) plots on a single axes

    Methods
    ==============
    add_line(x, y)   : line plot
    set_title(title) : axes title (fontsize 16 unless given)
    ax_square()      : make the axes visually square
    xy_properties()  : limits, ticks, labels, title and tick-label sizes
    legend()         : frameless legend
    """
    def __init__(self, fig=None, rect=111,  *args, **kwargs):
        '''
        Create a figure if none is given (extra *args/**kwargs go to
        plt.figure in that case) and add one subplot at `rect`.
        '''
        ### Setup figure and axes
        if fig is None:
            fig = plt.figure(figsize=(8.5,11), *args, **kwargs)

        ### add_subplot already attaches the new axes to the figure
        self.ax = fig.add_subplot(rect)

    ### Class methods
    def add_line(self, x, y, *args, **kwargs):
        """Plot y against x; extra args/kwargs go to Axes.plot."""
        self.ax.plot(x, y,  *args, **kwargs)

    def set_title(self, title, *args, **kwargs):
        """Set the axes title; fontsize defaults to 16 but can be overridden."""
        kwargs.setdefault('fontsize', 16)
        self.ax.set_title(title, *args, **kwargs)

    def ax_square(self):
        """Set the aspect ratio so the axes is drawn square for its current limits."""
        x0, x1 = self.ax.get_xlim()
        y0, y1 = self.ax.get_ylim()
        self.ax.set_aspect(abs(x1-x0)/abs(y1-y0))
        # return value unused; call kept because Axes.get_position() may
        # apply the pending aspect
        self.ax.get_position()

    def xy_properties(self, fs_label=24, fs_ticks=18, 
                      xlim=None, xticks=None, xlabel='',
                      ylim=None, yticks=None, ylabel='',
                      title=''):
        """
        Set axis limits, ticks, labels, title and tick-label sizes.

        xlim / ylim     : (min, max), skipped when None
        xticks / yticks : (start, stop, step), skipped when None
        xlabel, ylabel, title : always set, with fontsize fs_label
        fs_ticks        : tick-label size on both axes
        """
        if xlim is not None:
            self.ax.set_xlim(xlim[0], xlim[1])

        if xticks is not None:
            self.ax.set_xticks(np.arange(xticks[0], xticks[1]+xticks[2], xticks[2]))

        if ylim is not None:
            self.ax.set_ylim(ylim[0], ylim[1])

        if yticks is not None:
            self.ax.set_yticks(np.arange(yticks[0], yticks[1]+yticks[2], yticks[2]))

        self.ax.set_ylabel(ylabel, fontsize=fs_label)
        self.ax.set_xlabel(xlabel, fontsize=fs_label)
        self.ax.set_title(title, fontsize=fs_label)
        self.ax.xaxis.set_tick_params(labelsize=fs_ticks)
        self.ax.yaxis.set_tick_params(labelsize=fs_ticks)

    def legend(self, *args, **kwargs):
        """Add a legend without a frame; arguments go to Axes.legend."""
        self.ax.legend(frameon=False,*args, **kwargs)

PolarStereoMap : 

a class to generate polar stereographic maps

This was merged into SpatialMap2 (use region='southern-ocean') and is now obsolete

In [18]:
import matplotlib.path as mpath
import numpy as np
import xarray as xr
import cmocean as cm
import cartopy.crs as ccrs
import cartopy.feature
from cartopy.mpl.geoaxes import GeoAxes
from mpl_toolkits.axes_grid1 import AxesGrid
import matplotlib as mpl
import matplotlib.pyplot as plt

class PolarStereoMap(object):
    """
    PolarStereoMap : South Polar Stereographic map of a 2D field (south of
                     35S) with a horizontal colorbar underneath.
                     Obsolete: the same functionality lives in SpatialMap2
                     with region='southern-ocean'.

    Inputs
    ==============
    data     : 2D field to plot
    fig      : figure handle; a new 8.5x11 figure is made if None
    rect     : position of the grid in the figure (default=111)
    cmap     : colormap (default=cm.cm.balance)
    colorbar : draw a colorbar as self.cbar (default=True)
    ncolors  : number of color boundaries (default=101)
    vrange   : colorbar range (default=[0, 1])
    lon, lat : coordinate vectors (defaults: 1-degree cell centers)

    Methods
    ==============
    set_ticks(), set_title(), set_cbar_title(), set_cbar_labels()
    (the colorbar methods need colorbar=True)
    """

    def __init__(self, data, fig=None, rect=111, 
                 cmap=cm.cm.balance, colorbar=True, ncolors=101, vrange = [0, 1], 
                 lon=np.arange(0.5,360,1), lat=np.arange(-89.5,90,1)):

        ### Setup figure and axes
        if fig is None:
            fig = plt.figure(figsize=(8.5,11))

        projection = ccrs.SouthPolarStereo()
        axes_class = (GeoAxes, dict(map_projection=projection))
        # label_mode='keep' leaves tick labels alone; newer matplotlib only
        # accepts 'L', '1', 'all' or 'keep' and rejects an empty string.
        grid = AxesGrid(fig, rect, axes_class=axes_class,
                        share_all=False,
                        nrows_ncols = (1, 1),
                        axes_pad = 0.2,
                        cbar_location = 'bottom',
                        cbar_mode="edge",
                        cbar_pad = 0.1,
                        cbar_size = '7%',
                        label_mode = 'keep')
        self.ax = grid[0]

        # Compute a circle in axes coordinates, which we can use as a boundary
        # for the map. We can pan/zoom as much as we like - the boundary will be
        # permanently circular.
        theta = np.linspace(0, 2*np.pi, 100)
        center, radius = [0.5, 0.5], 0.5
        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
        circle = mpath.Path(verts * radius + center)
        self.ax.set_boundary(circle, transform=self.ax.transAxes)

        # Limit the map to latitudes south of 35S
        self.ax.set_extent([-180, 180, -90, -35], ccrs.PlateCarree())
        self.ax.gridlines()
        self.ax.add_feature(cartopy.feature.LAND)

        ### Latitude and longitude
        self.lon = lon
        self.lat = lat

        ### colorbar normalization: ncolors boundaries spanning vrange
        self.transform = ccrs.PlateCarree(central_longitude=0)
        self.bounds = np.linspace(vrange[0], vrange[1], ncolors)
        self.norm = mpl.colors.BoundaryNorm(self.bounds, cmap.N)

        ### land mask
        self.ax.add_feature(cartopy.feature.NaturalEarthFeature('physical', 'land', '110m', 
                                                                edgecolor='face', 
                                                                facecolor=[0.4, 0.4, 0.4]))
        self.ax.coastlines(facecolor=[0.4, 0.4, 0.4])

        ### Add data to map. The color range is carried by self.norm;
        ### matplotlib >= 3.5 raises if norm and vmin/vmax are both given.
        sub = self.ax.pcolormesh(self.lon, self.lat, data,
                                 norm=self.norm,
                                 transform=self.transform,
                                 cmap = cmap)

        ### Colorbar drawn into the grid's colorbar axes
        if colorbar is True:
            self.cbar = plt.colorbar(sub, cax=grid.cbar_axes[0], orientation='horizontal')

    ### Class methods
    def set_ticks(self, tmin, tmax, dt, *args, **kwargs):
        """Colorbar ticks tmin, tmin+dt, ... up to tmax (when on the dt grid)."""
        self.cbar.set_ticks(np.arange(tmin, tmax+dt, dt))

    def set_title(self, title, *args, **kwargs):
        """Map title; fontsize defaults to 16 but can be overridden."""
        kwargs.setdefault('fontsize', 16)
        self.ax.set_title(title, **kwargs)

    def set_cbar_title(self, title, *args, **kwargs):
        """Colorbar label; fontsize defaults to 16, other kwargs go to set_label."""
        kwargs.setdefault('fontsize', 16)
        self.cbar.set_label(title, **kwargs)

    def set_cbar_labels(self, labels, *args, **kwargs):
        """Custom tick labels for the (horizontal) colorbar."""
        self.cbar.ax.set_xticklabels(labels, **kwargs)  # horizontal colorbar