This notebook is a small demonstration of how `xarray`, `ipywidgets`, and `matplotlib` can be combined in a notebook to make interactive plots. We'll start by importing these libraries, along with `cartopy` for map projections. Note the use of the interactive `%matplotlib widget` backend (provided by the `ipympl` package), which is optional.

In [12]:
%matplotlib widget
import xarray as xr
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl
from ipywidgets import interact, interactive
import ipywidgets as widgets
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
from IPython.display import display
from calendar import monthrange

We'll use the North American air temperature tutorial dataset for this example, which `xr.tutorial.open_dataset` downloads and caches on first use:

In [2]:
airtemps = xr.tutorial.open_dataset("air_temperature")
airtemps

Now we will plot a slab of the data by making a function which takes as its only argument the index of the time we want:

In [3]:
def plot1(time):
    #close the figure if reloading
    plt.close(1)
    #let's use a nice projection for the US
    proj = ccrs.LambertConformal(cutoff=0)
    trans = ccrs.PlateCarree()
    fig = plt.figure(1)
    ax = fig.subplots(1,1,subplot_kw={'projection':proj})
    #lon0 lon1 lat0 lat1
    ax.set_extent([230,295,20,55],crs=trans)
    #colorbars don't scale well to the size of cartopy axes, so
    #we make an axis divider
    divider = make_axes_locatable(ax)
    cbax = divider.append_axes('bottom',size='8%', pad='7%',axes_class=plt.Axes)
    
    #now subset and plot, adding some features
    da = airtemps['air'].isel(time=time)
    #fix the color limits to the 5th and 98th percentiles over all times so the scale doesn't change between frames
    da.plot(ax=ax, transform=trans, cmap=plt.cm.Spectral_r, vmin=airtemps.air.quantile(.05), vmax=airtemps.air.quantile(.98),
            cbar_kwargs={'orientation':'horizontal','label':airtemps.air.units,'cax':cbax})
    ax.add_feature(cfeature.STATES)
    ax.coastlines()
    #need to convert to shorter time unit to use strftime
    timestr = da.time.values.astype('datetime64[h]').item().strftime('%Y-%m-%d %HZ')
    t = ax.set_title(timestr)

And let's see how it looks:

In [4]:
plot1(180)


Now that we have a nice figure, let's make it interactive using the `interact` function. `interact` picks a widget type based on each argument you pass; the easiest one for this case is probably a dropdown, which is produced by passing a list (here, a list of `(label, index)` pairs). You could also get a slider by passing an integer, but this has some potential problems that we'll skirt for this example.

In [6]:
times_list = [(t.item().strftime('%Y-%m-%d %HZ'),i) for i,t in enumerate(airtemps.time.values.astype('datetime64[h]'))]
w = interact(plot1,time=times_list)

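A dropdown built from `(label, value)` pairs shows each label but passes the matching value to the plotting function. The sketch below rebuilds such a list with plain NumPy for a few made-up 6-hourly time stamps (standing in for `airtemps.time.values`), to show what the `times_list` comprehension produces. `make_time_options` is a hypothetical helper, not part of any library.

```python
import numpy as np


def make_time_options(times):
    """Return (label, index) pairs for a Dropdown (hypothetical helper).

    `times` is any array of numpy datetime64 values.
    """
    # Cast to hour precision so .item() yields a datetime.datetime with strftime
    hourly = np.asarray(times).astype('datetime64[h]')
    return [(t.item().strftime('%Y-%m-%d %HZ'), i) for i, t in enumerate(hourly)]


# Four made-up 6-hourly time stamps starting at 2013-01-01 00Z
demo_start = np.datetime64('2013-01-01T00', 'h')
demo_times = demo_start + np.arange(4) * np.timedelta64(6, 'h')

demo_options = make_time_options(demo_times)
print(demo_options[:2])  # → [('2013-01-01 00Z', 0), ('2013-01-01 06Z', 1)]
```

The dropdown displays the strings, while `plot1` only ever receives the integer index.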

One challenge here is that there are a LOT of options. You can easily change the array you are plotting to look at an aggregated version of the data. Here we will adapt the example above to step through monthly means only:

In [7]:
#'MS' labels each month by its first day; the older 'M' alias is deprecated in recent pandas
monthly_airtemps = airtemps.resample(time='1MS').mean()
monthly_airtemps
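Conceptually, `resample(time=...).mean()` buckets the time axis into calendar months and averages each bucket. The NumPy-only sketch below mimics that for a 1-D series, assuming we only need the months that actually contain samples (xarray's `resample` also handles multi-dimensional arrays, empty periods, and labelling for you). `monthly_mean` is a hypothetical helper, not an xarray function, and the samples are made up.

```python
import numpy as np


def monthly_mean(times, values):
    """Average `values` over the calendar months of `times` (hypothetical helper).

    Returns (month_labels, means), where month_labels are datetime64[M] values.
    """
    # Truncating to month precision assigns every sample to its month bucket
    months = np.asarray(times).astype('datetime64[M]')
    labels, bucket = np.unique(months, return_inverse=True)
    bucket = bucket.ravel()
    # Sum per bucket, then divide by the number of samples in each bucket
    sums = np.bincount(bucket, weights=values)
    counts = np.bincount(bucket)
    return labels, sums / counts


# Made-up 6-hourly samples: two at the end of January, two at the start of February
sample_times = np.array(['2013-01-31T12', '2013-01-31T18', '2013-02-01T00', '2013-02-01T06'],
                        dtype='datetime64[h]')
sample_values = np.array([270.0, 272.0, 280.0, 284.0])

month_labels, month_means = monthly_mean(sample_times, sample_values)
print(month_labels, month_means)
```

Here January averages to 271 and February to 282, one value per month, which is why the monthly slider has far fewer positions than the 6-hourly dropdown.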

In [55]:
#close the figure if reloading
plt.close(2)

def plot2(time):
    #let's use a nice projection for the US
    proj = ccrs.LambertConformal(cutoff=0)
    trans = ccrs.PlateCarree()
    fig = plt.figure(2)
    fig.clear()
    ax = fig.subplots(1,1,subplot_kw={'projection':proj})
    #lon0 lon1 lat0 lat1
    ax.set_extent([230,295,20,55],crs=trans)
    #colorbars don't scale well to the size of cartopy axes, so
    #we make an axis divider
    divider = make_axes_locatable(ax)
    cbax = divider.append_axes('bottom',size='8%', pad='7%',axes_class=plt.Axes)
    
    #now subset and plot, adding some features
    da = monthly_airtemps['air'].isel(time=time)
    #fix the color limits to the 5th and 99.5th percentiles over all times so the scale doesn't change between frames
    da.plot(ax=ax, transform=trans, cmap=plt.cm.Spectral_r, vmin=monthly_airtemps.air.quantile(.05), 
            vmax=monthly_airtemps.air.quantile(.995),
            cbar_kwargs={'orientation':'horizontal','label':airtemps.air.units,'cax':cbax})
    ax.add_feature(cfeature.STATES)
    ax.coastlines()
    #need to convert to shorter time unit to use strftime
    timestr = da.time.values.astype('datetime64[h]').item().strftime('%Y-%m')
    t = ax.set_title(timestr)

This time we will demo a slider. We pass a fully constructed `IntSlider` rather than an abbreviation so that we can restrict it to valid indices (0 through the number of months minus one) and turn off `continuous_update`, so it doesn't replot until we are done dragging the slider.

In [56]:
w = interact(plot2,time=widgets.IntSlider(min=0,max=monthly_airtemps.time.size-1,continuous_update=False))


Now suppose you want multiple widgets, perhaps for different dimensions in the data or to perform other operations. Since this dataset is only 3D (time, lat, lon), we will split the time dimension into dropdowns for year, month, day, and hour. First we'll need to modify our plotting function slightly. Also note that each example uses a new figure number and plotting function name, which helps avoid conflicts with other figures that may be open in the notebook.

In [48]:
#close the figure if reloading
plt.close(3)

def plot3(year,month,day,hour):
    #let's use a nice projection for the US
    proj = ccrs.LambertConformal(cutoff=0)
    trans = ccrs.PlateCarree()
    fig = plt.figure(3)
    #clear figure since reusing
    fig.clear()
    ax = fig.subplots(1,1,subplot_kw={'projection':proj})
    #lon0 lon1 lat0 lat1
    ax.set_extent([230,295,20,55],crs=trans)
    #colorbars don't scale well to the size of cartopy axes, so
    #we make an axis divider
    divider = make_axes_locatable(ax)
    cbax = divider.append_axes('bottom',size='8%', pad='7%',axes_class=plt.Axes)
    
    #clamp the day so we never request a date that doesn't exist (e.g. February 30)
    dpm = monthrange(int(year),int(month))[1]
    day = min(int(day), dpm)
    #build a zero-padded, timezone-naive timestamp to match the dataset's time coordinate
    time = np.datetime64(f'{int(year):04d}-{int(month):02d}-{day:02d}T{int(hour):02d}:00')
    #now subset and plot, adding some features
    da = airtemps['air'].sel(time=time,method='nearest')
    #fix the color limits to the 5th and 98th percentiles over all times so the scale doesn't change between frames
    da.plot(ax=ax, transform=trans, cmap=plt.cm.Spectral_r, vmin=airtemps.air.quantile(.05), vmax=airtemps.air.quantile(.98),
            cbar_kwargs={'orientation':'horizontal','label':airtemps.air.units,'cax':cbax})
    ax.add_feature(cfeature.STATES)
    ax.coastlines()
    #need to convert to shorter time unit to use strftime
    timestr = da.time.values.astype('datetime64[h]').item().strftime('%Y-%m-%d %HZ')
    t = ax.set_title(timestr)
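The date handling inside `plot3` can be pulled out and checked on its own. The standard-library sketch below takes the four dropdown strings, clamps the day to the length of the month with `calendar.monthrange`, and returns a timezone-naive `datetime`. `build_time` is a hypothetical helper for illustration; `plot3` itself builds an equivalent `np.datetime64` inline.

```python
from calendar import monthrange
from datetime import datetime


def build_time(year, month, day, hour):
    """Turn dropdown strings into a valid, timezone-naive datetime (hypothetical helper)."""
    year, month, day, hour = int(year), int(month), int(day), int(hour)
    # monthrange returns (weekday of the 1st, number of days in the month)
    days_in_month = monthrange(year, month)[1]
    return datetime(year, month, min(day, days_in_month), hour)


# February 2014 has 28 days, so a requested day of '31' is clamped to the 28th
print(build_time('2014', '2', '31', '12'))  # → 2014-02-28 12:00:00
```

Clamping instead of raising keeps every dropdown combination valid, at the cost of silently showing the last day of the month.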

Next we will make our interactive plot. This time we use `interactive`, which returns the widgets without displaying them, so that we can change their layout before calling `display`. Note that the figure below won't appear until you change a dropdown.

In [49]:
times = airtemps.time
years_list = [str(yr) for yr in np.unique(times.dt.year)]
month_list = [str(m) for m in range(1,13)]
day_list = [str(d) for d in range(1,32)]
hr_list = [str(t) for t in range(0,24,6)]
w = interactive(plot3,year=years_list,month=month_list,day=day_list,hour=hr_list)
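#size the widget row from the default figure width; plt.gcf() here could return an unrelated figure or create an empty one
box = widgets.Box(w.children,layout=widgets.Layout(width=f"{mpl.rcParams['figure.figsize'][0]*1.2:.1f}in", flex_flow='row wrap',
                        justify_content='space-around', margin='0 0 0 .5in', flex='1 0 auto'))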
display(box)


One problem with the way we are currently `interact`ing is that we redraw the whole figure every time we update the data slice, which can get pretty slow for more complicated or higher-resolution figures. For a pcolormesh or line plot, we can instead update only the data, using the `QuadMesh.set_array()` or `Line2D.set_data()` method. Let's illustrate.
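To see the idea in isolation, the sketch below uses plain Matplotlib with the non-interactive Agg backend, an assumption made so it runs as a standalone script (no cartopy, xarray, or widgets; don't run it inside this notebook, since it switches the backend). It draws a `pcolormesh` and a line once, then swaps in new values with `QuadMesh.set_array()` and `Line2D.set_data()`, so the same artists are reused instead of new ones being created. The random frames are made up, and passing a 2-D array to `set_array` assumes a recent Matplotlib; `plot4` below passes a flattened copy instead.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so this runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
frames = rng.normal(size=(3, 4, 5))  # three made-up 4x5 "time steps"
x = np.arange(5)

demo_fig, (ax_map, ax_line) = plt.subplots(1, 2)
# Fix vmin/vmax up front: set_array keeps the existing color normalization
mesh = ax_map.pcolormesh(frames[0], vmin=frames.min(), vmax=frames.max())
(line,) = ax_line.plot(x, frames[0, 0])

for i in range(1, len(frames)):
    # Update only the data; no new artists are added to either axes
    mesh.set_array(frames[i])
    line.set_data(x, frames[i, 0])
    ax_map.set_title(f'step {i}')  # updates the existing title Text in place
    demo_fig.canvas.draw_idle()

demo_fig.canvas.draw()  # force a final render
```

Because the axes, projection, features, and colorbar are never rebuilt, each update costs little more than re-rendering the changed artists.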

This time, we need to set up the figure with some starter data before defining our function so that the function can reuse the existing axes. Note that the widgets now appear below the figure, since `plt.figure()` is called before they are displayed. We will also use button widgets to make `ncview`-style forward and back controls.

In [47]:
plt.close(4)
time = 0
#all of this stuff can be reused
#let's use a nice projection for the US
proj = ccrs.LambertConformal(cutoff=0)
trans = ccrs.PlateCarree()
fig = plt.figure(4)
ax = fig.subplots(1,1,subplot_kw={'projection':proj})
#lon0 lon1 lat0 lat1
ax.set_extent([230,295,20,55],crs=trans)
#colorbars don't scale well to the size of cartopy axes, so
#we make an axis divider
divider = make_axes_locatable(ax)
cbax = divider.append_axes('bottom',size='8%', pad='7%',axes_class=plt.Axes)
#add features
ax.add_feature(cfeature.STATES)
ax.coastlines()

#now subset and plot
da0 = airtemps['air'].isel(time=time)
#fix the color limits to the 2nd percentile and the maximum over all times so every frame shares one scale
cs = da0.plot(ax=ax, transform=trans, cmap=plt.cm.Spectral_r, vmin=airtemps.air.quantile(.02), vmax=airtemps.air.max(),
        cbar_kwargs={'orientation':'horizontal','label':airtemps.air.units,'cax':cbax})
#need to convert to shorter time unit to use strftime
timestr = da0.time.values.astype('datetime64[h]').item().strftime('%Y-%m-%d %HZ')
t = ax.set_title(timestr)

def plot4(time):
    #update data and title
    da = airtemps['air'].isel(time=time)
    #set array requires 1D array
    #ensure vmin and vmax were specified when creating pcolormesh as they will be reused
    cs.set_array(da.values.ravel())
    timestr = da.time.values.astype('datetime64[h]').item().strftime('%Y-%m-%d %HZ')
    #set_title updates the axes' existing title Text in place, so no old title needs removing
    ax.set_title(timestr)
    #ask the canvas to redraw with the new data
    fig.canvas.draw_idle()
    
#define event handlers for buttons
#the button handlers need a shared record of the current time index that persists between clicks;
#an IntSlider that is never displayed holds it, so the handlers don't need a global statement
intslid = widgets.IntSlider(value=time,min=0,max=airtemps.time.size-1)
def forward(b):
    time = intslid.value
    if time < airtemps.time.size-1:
        time += 1
    else:
        time = 0
    intslid.value = time
    plot4(time)
    
def backward(b):
    time = intslid.value
    if time > 0:
        time -= 1
    else:
        time = airtemps.time.size-1
    intslid.value = time
    plot4(time)
    
#make widgets and register handlers
lbutton = widgets.Button(description='',icon='arrow-left',layout=widgets.Layout(width='auto',height='auto'))
lbutton.on_click(backward)
rbutton = widgets.Button(description='',icon='arrow-right',layout=widgets.Layout(width='auto',height='auto'))
rbutton.on_click(forward)
box = widgets.HBox([lbutton, rbutton], layout=widgets.Layout(width=f'{fig.get_size_inches()[0]:.1f}in',
                        justify_content='center', margin='0 0 0 .5in'))
display(box)


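The `forward` and `backward` handlers implement wrap-around stepping: stepping past the last time returns to the first, and vice versa. The same logic can be written as a single modular-arithmetic function, which also lets one handler factory serve both buttons. This is a sketch of an alternative formulation, not what the cell above uses; `step_index` and `make_handler` are hypothetical helpers.

```python
from types import SimpleNamespace


def step_index(index, delta, n):
    """Move `index` by `delta` within range(n), wrapping around at both ends."""
    # For positive n, Python's % always returns a value in [0, n),
    # so stepping back from 0 lands on n - 1
    return (index + delta) % n


def make_handler(tracker, delta, n, redraw):
    """Build an on_click callback that steps a stored index (hypothetical helper).

    `tracker` is any object with a mutable `.value` (e.g. the hidden IntSlider);
    `redraw` is called with the new index (e.g. plot4).
    """
    def handler(_button):
        tracker.value = step_index(tracker.value, delta, n)
        redraw(tracker.value)
    return handler


# Stand-in for the hidden slider and the plotting function
state = SimpleNamespace(value=0)
seen = []
back = make_handler(state, -1, 10, seen.append)
back(None)
print(state.value)  # → 9
```

In the notebook, the registrations would then become `lbutton.on_click(make_handler(intslid, -1, airtemps.time.size, plot4))` and the same call with `+1` for `rbutton`.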