# Plotting with Matplotlib - Solutions

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Use the interactive ipympl ("widget") backend: the Slider and selector
# examples further down need an interactive canvas that reacts to mouse events.
%matplotlib widget

## A simple 1D line plot

In [None]:
# Sample 100 points of a sine wave (shifted up by 1) over a 2 s window
t = np.linspace(0.0, 2.0, 100)
s = 1 + np.sin(2 * np.pi * t)

# `subplots(nrows, ncols)` builds the Figure and its Axes in a single call
fig, ax = plt.subplots(1, 1)

# Thick, red, dashed x,y line drawn on our Axes object
ax.plot(t, s, color="red", linewidth=4, linestyle="dashed")

# # Uncomment the next 2 lines for errorbars
# e = np.random.random(100) * 0.3
# ax.errorbar(t, s, yerr=e)

# Axis labels and title in one `set` call, then a background grid
ax.set(xlabel='time (s)',
       ylabel='voltage (mV)',
       title='About as simple as it gets, folks')
ax.grid()

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    - Look at the documentation and try to add a grid to the axes to obtain <img src="https://matplotlib.org/_images/sphx_glr_simple_plot_001.png" width="300px"/><br>
    - Change the line color to red, the line thickness to 4 and the line style to dashed.<br>
    - Add some random y error bars.
</div>

## Multiple subplots in one figure

In [None]:
x1 = np.linspace(0.0, 5.0)
x2 = np.linspace(0.0, 2.0)
y1 = np.cos(2 * np.pi * x1) * np.exp(-x1)   # damped cosine
y2 = np.cos(2 * np.pi * x2)                 # undamped cosine

# With 2 rows and 2 columns, `ax` is a 2x2 NumPy array of Axes: ax[row, col]
fig, ax = plt.subplots(2, 2)

# Top row: the damped oscillation, full range (left) and zoomed on x in [0, 1] (right)
for col, title in enumerate(['A tale of 4 subplots', 'A zoomed plot']):
    ax[0, col].plot(x1, y1, 'o-')
    ax[0, col].set_title(title)
    ax[0, col].set_ylabel('Damped oscillation')
ax[0, 1].set_xlim([0, 1])

# Bottom row: the undamped oscillation, full range (left) and zoomed (right)
for col in range(2):
    ax[1, col].plot(x2, y2, '.-')
    ax[1, col].set_xlabel('time (s)')
    ax[1, col].set_ylabel('Undamped')
ax[1, 1].set_xlim([1, 2])
ax[1, 1].set_ylim([-1, 0])

# Tighten the spacing so titles, axis labels and tick labels of
# neighbouring subplots don't overlap
fig.tight_layout()

<div class="alert alert-block alert-info">
    <b>Exercise:</b><br>
Try adding two more subplots to the right of the existing ones to obtain a 2x2 grid of plots. The new subplots should show a zoom onto a particular area of the subplots on the left.
</div>

## 2D image/heatmap with uniformly sized pixels

In [None]:
N = 100
M = 50
# Pixel edge coordinates: one more edge than there are pixels along each axis
x = np.arange(N+1)
y = np.arange(M+1)
z = np.random.random([M, N])   # M rows (y) by N columns (x)

fig, ax = plt.subplots()
# origin="lower" draws row 0 at the bottom; `extent` stretches the pixels over the edge
# coordinates; vmin/vmax clamp the color scale to [0.2, 0.8]
im = ax.imshow(z,
               extent=[x[0], x[-1], y[0], y[-1]],
               origin="lower",
               aspect="auto",
               cmap="magma",
               vmin=0.2,
               vmax=0.8)
# The colorbar is attached to the `im` mappable
cb = plt.colorbar(im)
cb.ax.set_ylabel("Counts")
ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("A 2D image")

<div class="alert alert-block alert-info">
    <b>Exercise:</b><br>
- Change the colormap to one of the many pre-defined <a href="https://matplotlib.org/3.1.1/gallery/color/colormap_reference.html">colormaps</a> that ship with matplotlib.
<br>
- In a second step, try to set the colorbar limits to [0.2, 0.8].
</div>

## 2D image/heatmap with non-uniformly sized pixels

In [None]:
# Deliberately large grid (5000 x 2000 cells) to show how slow pcolormesh becomes
N = 5000
M = 2000
x = np.arange(N+1)**2   # quadratically growing column edges -> pixels of unequal width
y = np.arange(M+1)
z = np.random.random([M, N])

fig, ax = plt.subplots()
# Unlike imshow, pcolormesh accepts arbitrary (non-uniform) cell edges
pcmesh = ax.pcolormesh(x, y, z)
cb = plt.colorbar(pcmesh)
cb.ax.set_ylabel("Counts")
ax.set(xlabel="x coordinate", ylabel="y coordinate")
ax.set_title("A 2D image with non-equal sized pixels")

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    The `pcolormesh` function is slow for large datasets/images. See this for yourself by making the data above much larger and re-plotting.
</div>

## 2D filled/empty contour plot

In [None]:
N = 100
xx = np.linspace(1.0, 5.0, N)
yy = np.linspace(1.0, 5.0, N)
x, y = np.meshgrid(xx, yy)
z = np.sin(x)**10 + np.cos(10 + y*x) * np.cos(x)

fig, ax = plt.subplots(3, 1, figsize=(9, 9))

# Top: the raw field, one colored quad per grid point
pcmesh = ax[0].pcolormesh(x, y, z, cmap="RdBu")
cb1 = plt.colorbar(pcmesh, ax=ax[0])
cb1.ax.set_ylabel("Counts")
ax[0].set_title("pcolormesh")

# Middle: filled contours. K level boundaries delimit K-1 filled bands, so
# 31 evenly spaced levels give the 30 filled contours asked for in the exercise.
contf = ax[1].contourf(x, y, z, cmap="RdBu", levels=np.linspace(-1.2, 1.2, 31))
cb2 = plt.colorbar(contf, ax=ax[1])
cb2.ax.set_ylabel("Counts")
ax[1].set_title("filled contours")

# Bottom: contour lines only, with automatically chosen levels
cont = ax[2].contour(x, y, z, cmap="RdBu")
cb3 = plt.colorbar(cont, ax=ax[2])
cb3.ax.set_ylabel("Counts")
ax[2].set_title("contours")

# Adjust spacing between the stacked subplots so titles and tick labels
# don't overlap
fig.tight_layout()

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Change the number of filled contours to 30 in the middle panel.
</div>

## Histogram

In [None]:
y = np.random.normal(50.0, scale=20.0, size=1000)

fig, ax = plt.subplots()
# 30 log-spaced bins between 1 and 150 need 31 bin *edges*:
# np.logspace(0, log10(150), 31) runs from 10**0 = 1 to 10**log10(150) = 150.
# Samples outside [1, 150] (including any negative draws) are not counted.
ax.hist(y, bins=np.logspace(0, np.log10(150.0), 31))
# On a log x axis the log-spaced bins are drawn with equal widths
ax.set_xscale("log")
ax.set_xlabel("Some x label [m]")
ax.set_ylabel("A fancy y label [kg]")
ax.set_title("A 1D histogram plot")

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Change the bins to be 30 log-spaced bins between 1 and 150.
</div>

## 2D scatter plot with different symbol sizes and colors

In [None]:
N = 100
x = np.random.random(N)
y = np.random.random(N)
z = np.random.random(N)          # mapped onto the marker color
s = 300.0 * np.random.random(N)  # marker areas (in points^2)

fig, ax = plt.subplots()
# Square markers ('s') at 50% opacity, colored by z and sized by s
scat = ax.scatter(x, y, s=s, c=z, marker='s', alpha=0.5)
cb = plt.colorbar(scat)
cb.ax.set_ylabel("The colored quantity")
ax.set(xlabel="x coordinate", ylabel="y coordinate", title="A 2D scatter plot")

# Legend built from a representative sample of the marker sizes
handles, labels = scat.legend_elements(prop="sizes")
legend = ax.legend(handles, labels, title="Sizes", loc="upper right")

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Change the scatter markers from circles to squares and change their opacity to 50%.
</div>

## Quiver (vectors) and stream plots

In [None]:
N = 100
w = 5.0
xx = np.linspace(-w, w, N)
# meshgrid: x varies along the columns, y along the rows (row 0 is y = -w)
x, y = np.meshgrid(xx, xx)
u = -1.0 - x**2 + y
v = -1.0 + x - y**2

fig, ax = plt.subplots(1, 2, figsize=(9, 4))

vmag = np.sqrt(u**2 + v**2)

# Left: velocity magnitude as an image with the vectors overplotted.
# origin="lower" is required so that row 0 (y = -w) is drawn at the bottom,
# matching the quiver arrows; the default origin="upper" flips the image vertically.
ax[0].imshow(vmag, origin="lower", extent=[xx[0], xx[-1], xx[0], xx[-1]])
M = 4
# Plot only 1 in every M^2 points to avoid over-crowding the figure
ax[0].quiver(x[::M, ::M], y[::M, ::M], u[::M, ::M], v[::M, ::M], scale=200.0)
ax[0].set_title('Quiver (vector field)')

# Right: streamlines colored by the velocity magnitude
strm = ax[1].streamplot(x, y, u, v, color=vmag)
cb = fig.colorbar(strm.lines, ax=ax[1])
cb.ax.set_ylabel("Velocity magnitude")
ax[1].set_title('Streamlines')

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    - Overplot the arrows or streamlines onto an image or contour plot of the velocity magnitude.
    <br>
    - Change the color of the arrows or streamlines to reflect the velocity magnitude via a colormap.
</div>

## Patches: drawing geometric shapes

In [None]:
import matplotlib.patches as mpatches

fig, ax = plt.subplots(figsize=(6, 6))

# Filled geometric shapes, added to the Axes one patch at a time:
# a rectangle (lower-left corner, width, height), a circle (centre, radius)
# and a triangle given by its three vertices
for patch in (mpatches.Rectangle([0.0, 0.0], 1.0, 2.0),
              mpatches.Circle([3.0, 3.0], 1.6, color="#FF5733"),
              mpatches.Polygon([[1.0, 1.5], [3.0, 3.0], [4.0, 0.5]], color="#4CFF33")):
    ax.add_patch(patch)

# An arrow from (1, 1) with displacement (2, 3), plus a free-floating text label
ax.arrow(1.0, 1.0, 2, 3, width=0.1, color="cyan")
ax.text(0, 4, "My text")

# Patches don't autoscale the view, so set the limits by hand
ax.set_xlim([-0.5, 5.0])
ax.set_ylim([-0.5, 5.0])

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Add an arrow and some text in the figure above.
</div>

## Grouped bar chart with labels

In [None]:
labels = ['G1', 'G2', 'G3', 'G4', 'G5']
cats_means = [20, 34, 30, 35, 27]
dogs_means = [25, 32, 34, 20, 25]

x = np.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars

fig, ax = plt.subplots()
# Shift each group's bars half a bar width to either side of its tick
rects1 = ax.bar(x - width/2, cats_means, width, label='Cats')
rects2 = ax.bar(x + width/2, dogs_means, width, label='Dogs')

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('Scores')
ax.set_title('Scores by group and animal')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()

def autolabel(rects):
    """Attach a text label above each bar in *rects*, displaying its height.

    Parameters
    ----------
    rects : iterable of matplotlib.patches.Rectangle
        The bars returned by ``Axes.bar``.

    Each label is drawn on the Axes the bar belongs to (``rect.axes``), so the
    function does not depend on whatever the global ``ax`` happens to be when
    it is called (e.g. after a later cell has reassigned ``ax``).
    """
    for rect in rects:
        height = rect.get_height()
        rect.axes.annotate('{}'.format(height),
                           xy=(rect.get_x() + rect.get_width() / 2, height),
                           xytext=(0, 3),  # 3 points vertical offset
                           textcoords="offset points",
                           ha='center', va='bottom')


autolabel(rects1)
autolabel(rects2)

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Attach a text label above each bar in rects1 and rects2, displaying its height.
</div>

## Pie chart

In [None]:
# Pie chart, where the slices will be ordered and plotted counter-clockwise:
labels = 'Frogs', 'Hogs', 'Dogs', 'Logs'
sizes = [15, 30, 45, 10]

fig, ax = plt.subplots(2, 2)

# One pie per subplot (row-major order); in subplot i, slice i is pulled
# out ("exploded") by 0.1 while every other slice stays in place.
for i, sub_ax in enumerate(ax.flat):
    explode = [0.1 if j == i else 0 for j in range(len(sizes))]
    sub_ax.pie(sizes, explode=explode, labels=labels, autopct='%1.1f%%',
               shadow=True, startangle=90)
    # Equal aspect ratio so each pie is drawn as a circle
    sub_ax.axis('equal')

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Create 4 subplots where a different section is exploded in each subplot.
</div>

## Log plots

In [None]:
# Data for plotting
t = np.arange(0.01, 20.0, 0.01)

# Create figure
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

# log y axis
ax1.semilogy(t, np.exp(-t / 5.0))
ax1.set(title='semilogy')
ax1.grid()

# log x axis
y = np.sin(2 * np.pi * t)
ax2.semilogx(t, y)
ax2.set(title='semilogx')
ax2.grid()

# log x and y axis, with base 2 on x only.
# The old `basex=` keyword was deprecated in Matplotlib 3.3 and later removed,
# so re-set the x scale with `base=` instead (requires Matplotlib >= 3.3).
ax3.loglog(t, 20 * np.exp(-t / 10.0))
ax3.set_xscale("log", base=2)
ax3.set(title='loglog base 2 on x')
ax3.grid()

# Image with log x axis. aspect="auto" lets the image fill its panel rather than
# being reshaped by a fixed aspect ratio; `nonpositive=` replaces the removed
# `nonposx=` keyword.
ax4.imshow(np.random.random([20, 20]), extent=[1, 100.0, 0, 19], aspect="auto")
ax4.set_xscale("log", nonpositive='clip')
ax4.set(title='log image')

fig.tight_layout()

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Create a 2D image where the x axis is logarithmic.
</div>

## Controlling axis placement

This example illustrates how to place axes at exact positions with respect to the figure, using the `add_axes()` function. This is useful, for instance, for placing side panels around a central plot.

In [None]:
N = 1000
nbins = 50
x = np.random.normal(0.0, scale=20.0, size=N)
y = np.random.normal(0.0, scale=20.0, size=N)

# Layout in figure-fraction coordinates: a square central panel of side `dx`
# with its lower-left corner at (xymin, xymin), flanked by 0.2-wide side panels
dx = 0.65
xymin = 0.1
fig = plt.figure(figsize=(8, 8))
# add_axes expects [left, bottom, width, height]
ax1 = fig.add_axes([xymin, xymin, dx, dx])                    # central scatter panel
ax2 = fig.add_axes([xymin+dx, xymin, 0.2, dx], sharey=ax1)    # right histogram, y linked to ax1
ax3 = fig.add_axes([xymin, xymin+dx, dx, 0.2], sharex=ax1)    # top histogram, x linked to ax1

ax1.scatter(x, y)
# orientation='horizontal' draws horizontal bars, i.e. a histogram along y
ax2.hist(y, bins=nbins, orientation='horizontal')
ax3.hist(x, bins=nbins)

# Put the histograms' ticks on their outer edges so they don't collide with the main panel
ax2.yaxis.tick_right()
ax3.xaxis.tick_top()

# Inset inside the central panel, zoomed on the region around the origin
ax4 = fig.add_axes([0.6, 0.6, 0.1, 0.1])
ax4.scatter(x, y)
ax4.set(xlim=[-5, 5], ylim=[-5, 5])
ax4.set_title("Zoom")

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    - Create a new set of axes inside the central panel (inset) that shows a zoom onto the central area of the main panel.
    <br>
    - Connect the axes of the main panel and subplots so that zooming in on the central panel also zooms in on the histograms (hint: search for "shared" axes).
</div>

## Interactive 2D heatmap with slider through 3D data cube

`matplotlib` ships with its own set of [widgets](https://matplotlib.org/stable/api/widgets_api.html), which allow you to create very useful and capable interactive data visualizations. In this example, we plot a 2D slice through a 3D data cube and use a slider to navigate the 3rd dimension, updating the slice as we move the slider (note that performance can vary depending on whether you are running the notebook kernel locally or remotely).

In [None]:
from matplotlib.widgets import Slider
 
data = np.random.random([10, 10, 10])
idx = 0

fig = plt.figure()
ax = fig.add_subplot(111)
fig.subplots_adjust(bottom=0.15)   # leave room below the image for the slider

# Show the first depth slice; nearest interpolation keeps individual cells crisp
im_object = ax.imshow(data[:, :, idx], interpolation='nearest')

ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("2D heatmap with slider")
cb = plt.colorbar(im_object)
cb.ax.set_ylabel("Temperature")

# Slider covering the valid depth indices 0 .. data.shape[2]-1
slider_ax = fig.add_axes([0.23, 0.02, 0.56, 0.04])
slider = Slider(slider_ax, 'Depth', 0, data.shape[2]-1, valinit=idx)


def onscroll(event):
    """Move the slider one depth index per mouse-wheel notch, clamped to the valid range."""
    step = 1 if event.button == "up" else -1
    new_idx = np.clip(int(round(slider.val)) + step, 0, data.shape[-1] - 1)
    slider.set_val(new_idx)


def update_depth(val):
    """Slider callback: display the depth slice nearest to the current slider value."""
    depth = int(round(slider.val))
    im_object.set_data(data[:, :, depth])


slider.on_changed(update_depth)

fig.canvas.mpl_connect("scroll_event", onscroll)

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    Connect a mouse scroll event to the figure canvas to also navigate the 3rd dimension with the mouse wheel (see matplotlib's <a href="https://matplotlib.org/3.1.1/users/event_handling.html">event handling</a>).
</div>

## Interactive linked scatter and histogram

In this example, we use the `RectangleSelector` to select points in a 2D scatter plot that represents the x and y coordinates of some data points (the solution below already replaces it with the `LassoSelector`, as asked in the exercise). The points also have a 3rd dimension/property, e.g. temperature. The temperatures of all the points are histogrammed in the right-hand panel. The histogram of only the selected points is overlaid in red and updated every time the selector is moved/resized.

In [None]:
from matplotlib.widgets import RectangleSelector, LassoSelector
from matplotlib.path import Path

fig, ax = plt.subplots(1, 2, figsize=(9, 5))
N = 1000
x = np.random.normal(0.0, scale=20.0, size=N)
y = np.random.normal(0.0, scale=20.0, size=N)
# Make a bimodal distribution for the temperature
z = np.concatenate([np.random.normal(-50.0, scale=20.0, size=N//2),
                    np.random.normal(50.0, scale=20.0, size=N//2)])
ax[0].set_xlabel("X coordinate")
ax[0].set_ylabel("Y coordinate")
ax[1].set_xlabel("Temperature")
ax[1].set_ylabel("Counts")
# Make scatter plot on the left
ax[0].scatter(x, y, alpha=.7)
# Make histogram on the right, with well defined bins (they will be re-used later)
bins = np.linspace(np.amin(z)-1.0, np.amax(z)+1.0, 50)
ax[1].hist(z, bins=bins)
# Number of bars in the full histogram: any patch beyond this index belongs to
# a previous selection overlay and must be removed before drawing a new one
npatches = len(ax[1].patches)

# (x, y) positions of the scatter points, in data coordinates
xys = ax[0].collections[0].get_offsets()

def update_select(verts):
    """LassoSelector callback: overlay a red histogram of the lassoed points.

    Parameters
    ----------
    verts : list of (x, y) pairs
        Vertices of the lasso drawn on ``ax[0]``, in data coordinates.
    """
    path = Path(verts)
    select = np.nonzero(path.contains_points(xys))[0]
    # Remove the previous overlay bars. Assigning to ``ax[1].patches`` fails on
    # Matplotlib >= 3.5 (it is a read-only view there), whereas Artist.remove()
    # works on all versions. Iterate over a copy since removal mutates the Axes.
    for patch in list(ax[1].patches[npatches:]):
        patch.remove()
    # Draw new red histogram from selected points
    ax[1].hist(z[select], bins=bins, alpha=0.5, color='r')
    # Ask the canvas to redraw so the new overlay shows up immediately
    fig.canvas.draw_idle()

# `props` replaces the `lineprops` keyword (renamed in Matplotlib 3.5, later removed).
# useblit=False is needed for the LassoSelector in Jupyter notebooks.
selector = LassoSelector(ax[0], onselect=update_select, useblit=False,
                         props={'color': 'red', 'linewidth': 4})

<div class="alert alert-block alert-info">
    <b>Exercise:</b>
    <br>
    - Try to replace the RectangleSelector with the <a href="https://matplotlib.org/3.1.0/api/widgets_api.html?highlight=widgets#matplotlib.widgets.LassoSelector">LassoSelector</a> or a different kind of selector (hint: in Jupyter notebooks, one has to make sure that useblit=False for the LassoSelector).
    <br>
    - Try to add a RectangleSelector in the histogram panel that would highlight scatter points in the left panel.
</div>

In [None]:
from matplotlib.widgets import RectangleSelector

fig, ax = plt.subplots(1, 2, figsize=(9, 5))
N = 1000
x = np.random.normal(0.0, scale=20.0, size=N)
y = np.random.normal(0.0, scale=20.0, size=N)
# Make a bimodal distribution for the temperature
z = np.concatenate([np.random.normal(-50.0, scale=20.0, size=N//2),
                    np.random.normal(50.0, scale=20.0, size=N//2)])

ax[0].set_xlabel("X coordinate")
ax[0].set_ylabel("Y coordinate")
ax[1].set_xlabel("Temperature")
ax[1].set_ylabel("Counts")

# Make scatter plot on the left
ax[0].scatter(x, y, alpha=.7)
# Broadcast the scatter's (first) facecolor to one RGBA row per point (N x 4),
# so that individual points can be recolored later
fc = np.broadcast_to(ax[0].collections[0].get_facecolors()[0], (N, 4)).copy()
# Save original R value of RGBA
red_original = fc[0, 0]

# Make histogram on the right, with well defined bins
bins = np.linspace(np.amin(z)-1.0, np.amax(z)+1.0, 50)
ax[1].hist(z, bins=bins)

# Function to be called when rectangle selector is updated
def update_scatter(eclick, erelease):
    """RectangleSelector callback: highlight scatter points in the selected temperature range.

    Parameters
    ----------
    eclick, erelease : matplotlib mouse events
        Press and release events from the selector on ``ax[1]``. Only their
        ``xdata`` (temperature) is used; the rectangle's vertical extent is ignored.
    """
    # Order the bounds defensively so the range is valid whichever way they arrive
    t_lo, t_hi = sorted((eclick.xdata, erelease.xdata))
    selected = (z >= t_lo) & (z <= t_hi)
    # Reset all facecolors
    fc[:, 0] = red_original
    # Push the red channel to 1 for the selected points
    fc[selected, 0] = 1
    ax[0].collections[0].set_facecolors(fc)
    # Ask the canvas to redraw so the recolored scatter shows up immediately
    fig.canvas.draw_idle()

# Create rectangle selector and install callback to update function.
# The former drawtype='box' argument is omitted: 'box' is the default, and the
# keyword was removed in Matplotlib 3.7.
selector = RectangleSelector(ax[1], update_scatter, useblit=True,
                             button=[1, 3],  # don't use middle button
                             interactive=True)