# Plotting with Matplotlib

[Matplotlib](https://matplotlib.org/) is a Python 2D plotting library which produces publication-quality figures in a variety of hardcopy formats and interactive environments across platforms. Matplotlib can be used in Python scripts, the Python and IPython shells, the Jupyter notebook, web application servers, and four graphical user interface toolkits.

Matplotlib tries to make easy things easy and hard things possible. You can generate plots, histograms, power spectra, bar charts, error charts, scatter plots, etc., with just a few lines of code.

Matplotlib has also become a de facto standard for creating figures for scientific publications. It is heavily inspired by `MATLAB`, so migrating from `MATLAB` to `matplotlib` should be relatively easy.

# MATLAB-style API vs object-oriented API

`matplotlib` provides two different ways of creating graphs. The first is the `pyplot` API, which aims to replicate the behaviour of `MATLAB`: each `pyplot` function makes some change to a figure, e.g. it creates a figure, creates a plotting area in a figure, plots some lines in a plotting area, or decorates the plot with labels. `pyplot` is mainly intended for interactive plots and simple cases of programmatic plot generation; you can find more information [here](https://matplotlib.org/tutorials/introductory/pyplot.html).

The second way is the object-oriented API. At its core, `matplotlib` is object-oriented, and we **strongly** recommend working directly with the objects, as this gives you much more control over your plots, especially when working with multiple figures or subplots.
In most cases you will create a `Figure` and one or more `Axes` using `pyplot.subplots` and from then on work only with these objects. This is the approach used throughout this tutorial.

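As a minimal sketch of the difference, the snippet below draws the same curve twice: first through the implicit `pyplot` interface, which always acts on the "current" figure and axes, and then through explicit `Figure` and `Axes` objects. The data and titles are placeholders chosen for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0.0, 1.0, 50)

# pyplot (MATLAB-style): each call acts on the "current" figure and axes
plt.figure()
plt.plot(x, x**2)
plt.title("pyplot style")
fig_implicit = plt.gcf()  # the figure pyplot created behind the scenes

# Object-oriented style: keep references and call methods on them
fig, ax = plt.subplots()
ax.plot(x, x**2)
ax.set_title("object-oriented style")
```

The object-oriented version does not depend on which figure happens to be "current", which is what makes it easier to manage once several figures or subplots are involved.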
# General concepts

We will begin by introducing the basic terminology associated with the `matplotlib` library. This is best illustrated with a picture:

![anatomy](https://matplotlib.org/_images/anatomy.png)

**A few key aspects**:

#### Figure
The whole figure. The figure keeps track of all the child `Axes`, a smattering of 'special' `artists` (titles, figure legends, etc.), and the `canvas`. (Don't worry too much about the canvas: it is crucial, as it is the object that actually does the drawing to produce your plot, but as a user it is more or less invisible to you.) A figure can contain any number of `Axes`, but should have at least one to be useful.

#### Axes
This is what you think of as 'a plot': the region of the image containing the data space. A given figure can contain many `Axes`, but a given `Axes` object can only be in one `Figure`. The `Axes` contains two (or three in the case of 3D) `Axis` objects (be aware of the difference between `Axes` and `Axis`), which take care of the data limits. Each `Axes` has a title (set via `set_title()`), an x-label (set via `set_xlabel()`), and a y-label (set via `set_ylabel()`).

The `Axes` class and its member functions are the primary entry point to working with the object-oriented interface.

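The hierarchy described above can be inspected directly. A short sketch using standard attributes (`fig.axes`, `ax.figure`, `ax.xaxis`/`ax.yaxis`); the axis limits are arbitrary values chosen for illustration:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# The Figure keeps track of its child Axes ...
print(fig.axes)
# ... and each Axes knows the single Figure it belongs to
print(ax.figure is fig)  # True

# An Axes owns one Axis object per dimension (two for a 2D plot)
print(type(ax.xaxis).__name__, type(ax.yaxis).__name__)  # XAxis YAxis

# Limits set on the Axes can be read back from the Axes or from its Axis
ax.set_xlim(0, 10)
ax.set_ylim(-1, 1)
print(ax.get_xlim(), ax.xaxis.get_view_interval())
```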
# Learning from examples

The best way to illustrate and understand how `matplotlib` works is through a series of examples. We will first run through several typical use cases for a graphing library and show how to create them with `matplotlib`.
Looking at examples around the web is another excellent way to learn `matplotlib`. There is a very large collection of examples on the `matplotlib` website itself (https://matplotlib.org/gallery.html), as well as thousands of posts on stackoverflow.com.

`matplotlib` is often seen as requiring more lines of code than other tools such as `Gnuplot`, but in exchange it offers very fine control over every element of a figure. If you feel that a large amount of code is necessary to produce even the simplest of graphs, hopefully the benefits of this approach will become clear by the end of this tutorial.

**We have added one exercise per example below; completing all of them would take some time. We suggest you try a couple to begin with, and then skip the ones that do not interest you to focus on a select few. If an example reminds you of a case you have faced in your own work or research, you can instead try to extend the example to reproduce that case with `matplotlib`, rather than following the suggested exercise.**

Here we import `numpy`, which will be used to generate all the data we visualize in this tutorial, as well as `pyplot`, the plotting module of `matplotlib` (`matplotlib` contains more than just plotting, e.g. modules to manipulate colors).

In [None]:
import numpy as np
import matplotlib.pyplot as plt

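Here we activate the `notebook` backend, which provides interactive plots in the classic Jupyter Notebook. For static plots you can use the `inline` backend, which is the default in Jupyter. Note that the `notebook` backend does not work in JupyterLab; there, the `widget` backend from the `ipympl` package (`%matplotlib widget`) provides interactive plots instead.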

In [None]:
%matplotlib notebook

## A simple 1D line plot

In [None]:
# Data for plotting
t = np.linspace(0.0, 2.0, 100)
s = 1 + np.sin(2 * np.pi * t)

# Create the Figure and Axes in one go using the `subplots` function.
# The first two arguments of the subplots function are the number of rows and columns, respectively.
fig, ax = plt.subplots(1, 1)
# Make a simple x,y line plot on our Axes object
ax.plot(t, s)
# Decorate the axes with axis labels and title
ax.set(xlabel='time (s)', ylabel='voltage (mV)',
       title='About as simple as it gets, folks')

# If you are running this outside of a Jupyter notebook, e.g. as a Python script, you will
# need to explicitly display the figure with
# plt.show()

<div class="alert alert-block alert-warning">
    <b>Exercise:</b><br>
Look at the documentation and try to add a grid to the axes to obtain <img src="https://matplotlib.org/_images/sphx_glr_simple_plot_001.png" />
</div>

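Besides displaying a figure, you will often want to write it to disk in one of the hardcopy formats mentioned at the start. A minimal sketch using `Figure.savefig`; the file names `simple_plot.png`/`simple_plot.pdf` and the `dpi` value are arbitrary choices for illustration:

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0.0, 2.0, 100)
s = 1 + np.sin(2 * np.pi * t)

fig, ax = plt.subplots()
ax.plot(t, s)
ax.set(xlabel='time (s)', ylabel='voltage (mV)')

# The output format is inferred from the file extension (.png, .pdf, .svg, ...)
fig.savefig("simple_plot.png", dpi=150, bbox_inches="tight")
# Vector formats such as PDF are often preferred for publications
fig.savefig("simple_plot.pdf")
```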
## Multiple subplots in one figure

In [None]:
x1 = np.linspace(0.0, 5.0)
x2 = np.linspace(0.0, 2.0)
y1 = np.cos(2 * np.pi * x1) * np.exp(-x1)
y2 = np.cos(2 * np.pi * x2)

fig, ax = plt.subplots(2, 1)  # note that ax is now a NumPy array of Axes objects
ax[0].plot(x1, y1, 'o-')
ax[0].set_title('A tale of 2 subplots')
ax[0].set_ylabel('Damped oscillation')

ax[1].plot(x2, y2, '.-')
ax[1].set_xlabel('time (s)')
ax[1].set_ylabel('Undamped')

<div class="alert alert-block alert-warning">
    <b>Exercise:</b><br>
Try adding two more subplots to the right of the existing ones to obtain a 2x2 grid of plots. The new subplots should show a zoomed-in view of a particular region of the subplots on the left.
</div>

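Because `plt.subplots` returns the Axes as a NumPy array, the array's shape follows the grid you ask for. A short sketch; `sharex=True` (a standard `subplots` argument) links the x-axes so that zooming or panning one subplot also moves the other, and `squeeze=False` always returns a 2D array. The plotted curves are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt

# One column of two plots: a 1D array of two Axes with a shared x-axis
fig, axs = plt.subplots(2, 1, sharex=True)
print(axs.shape)  # (2,)

x = np.linspace(0.0, 2.0, 200)
axs[0].plot(x, np.cos(2 * np.pi * x))
axs[1].plot(x, np.sin(2 * np.pi * x))
axs[1].set_xlabel('time (s)')  # only the bottom plot needs an x-label

# With squeeze=False the result is always 2D, indexed as [row, column]
fig2, axs2 = plt.subplots(2, 1, squeeze=False)
print(axs2.shape)  # (2, 1)
```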
## 2D image/heatmap with uniformly sized pixels

In [None]:
N = 100
M = 50
x = np.arange(N+1)
y = np.arange(M+1)
z = np.random.rand(M, N).astype(np.float64)

fig, ax = plt.subplots()
im = ax.imshow(z, origin="lower", extent=[x[0], x[-1], y[0], y[-1]],
               aspect="auto", cmap="viridis")
cb = plt.colorbar(im) # adds the colorbar: we specify that it applies to the `im` object
cb.ax.set_ylabel("Counts")
ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("A 2D image")

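The `extent` argument places the image in data coordinates as `[left, right, bottom, top]`, and `origin="lower"` puts row 0 of the array at the bottom, the usual convention for scientific data (the default, `origin="upper"`, puts it at the top, as for photographs). A small check on a made-up 2x3 array:

```python
import numpy as np
import matplotlib.pyplot as plt

z = np.arange(6).reshape(2, 3)  # 2 rows (y) by 3 columns (x)

fig, ax = plt.subplots()
# Each pixel spans one data unit: x from 0 to 3, y from 0 to 2
im = ax.imshow(z, origin="lower", extent=[0, 3, 0, 2])

# Row 0 of z (values 0, 1, 2) is drawn along the bottom edge, y in [0, 1]
print(im.get_extent(), ax.get_xlim(), ax.get_ylim())
```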
## 2D image/heatmap with non-uniformly sized pixels

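(Note that `pcolormesh` can be noticeably slower than `imshow` for large datasets/images; when the pixels are uniformly sized, `imshow` is usually the better choice.)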

In [None]:
N = 10
M = 5
x = np.arange(N+1)**2
y = np.arange(M+1)
z = np.random.rand(M, N).astype(np.float64)

fig, ax = plt.subplots()
pcmesh = ax.pcolormesh(x, y, z, cmap="viridis")
cb = plt.colorbar(pcmesh)
cb.ax.set_ylabel("Counts")
ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("A 2D image with non-equal sized pixels")

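Unlike `imshow`, `pcolormesh` takes the coordinates of the cell *edges*: for an `(M, N)` array `z` you pass `N+1` x-edges and `M+1` y-edges, which is why the example uses `np.arange(N+1)` and `np.arange(M+1)`. The sketch below, with made-up edge values, shows that each cell's width is the difference between consecutive edges:

```python
import numpy as np
import matplotlib.pyplot as plt

M, N = 2, 3
x_edges = np.array([0.0, 1.0, 4.0, 9.0])  # N+1 edges -> cell widths 1, 3, 5
y_edges = np.array([0.0, 1.0, 2.0])       # M+1 edges -> cell heights 1, 1
z = np.arange(M * N, dtype=float).reshape(M, N)

fig, ax = plt.subplots()
# shading="flat" (one colour per cell) expects edges, one more than z's shape
mesh = ax.pcolormesh(x_edges, y_edges, z, shading="flat", cmap="viridis")

print(np.diff(x_edges))  # [1. 3. 5.]
```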
## 1D histogram plot with error bars

In [None]:
N = 50
x = np.arange(N+1)
y = np.random.rand(N)
e = 0.1*np.random.rand(N)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.bar(0.5*(x[:-1] + x[1:]), y, width=np.ediff1d(x), yerr=e)  # bars centered on each bin, one bin wide
ax.set_xlabel("Some x label [m]")
ax.set_ylabel("A fancy y label [kg]")
ax.set_title("A 1D histogram plot")

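The example above draws random bar heights directly. With real data you would typically bin it first with `np.histogram` and, if the entries are independent counts, use the Poisson estimate `sqrt(N)` as the error on each bin. A minimal sketch under that assumption, using synthetic normally distributed data:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=0)
samples = rng.normal(loc=0.0, scale=1.0, size=1000)  # synthetic data

# Bin the data: `counts` has one entry per bin, `edges` one more than that
counts, edges = np.histogram(samples, bins=20)
centers = 0.5 * (edges[:-1] + edges[1:])
errors = np.sqrt(counts)  # Poisson uncertainty, assuming independent counts

fig, ax = plt.subplots()
ax.bar(centers, counts, width=np.diff(edges), yerr=errors,
       edgecolor="black", capsize=2)
ax.set_xlabel("value")
ax.set_ylabel("counts per bin")
```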
## 2. 2D plots

### 2D filled contour plot

In [None]:
N = 100
M = 50
xx = np.arange(N, dtype=np.float64)
yy = np.arange(M, dtype=np.float64)
x, y = np.meshgrid(xx, yy)
b = N/20.0
c = M/2.0
r = np.sqrt(((x-c)/b)**2 + ((y-c)/b)**2)
z = np.sin(r)
fig = plt.figure()
ax = fig.add_subplot(111)
contf = ax.contourf(x, y, z, cmap="viridis")
cb = plt.colorbar(contf)
cb.ax.set_ylabel("Counts")
ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("2D contours")

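`contourf` picks the contour levels automatically; passing `levels` (an integer count or explicit boundaries) gives you control over them. Overlaying line contours with `ax.contour` and labelling them with `ax.clabel` is a common companion. A short sketch on a made-up radial function:

```python
import numpy as np
import matplotlib.pyplot as plt

x, y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100))
z = np.sin(np.sqrt(x**2 + y**2))

levels = np.linspace(-1.0, 1.0, 9)  # explicit level boundaries

fig, ax = plt.subplots()
filled = ax.contourf(x, y, z, levels=levels, cmap="viridis")
lines = ax.contour(x, y, z, levels=levels, colors="black", linewidths=0.5)
ax.clabel(lines, fmt="%.2f", fontsize=8)  # write each level value on its line
fig.colorbar(filled, ax=ax)
```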
### 2D scatter plot with different symbol sizes

In [None]:
N = 100
x = np.random.rand(N).astype(np.float64)
y = np.random.rand(N).astype(np.float64)
z = np.random.rand(N).astype(np.float64)
s = 300.0*np.random.rand(N).astype(np.float64)
fig = plt.figure()
ax = fig.add_subplot(111)
scat = ax.scatter(x, y, c=z, cmap="jet", s=s)
cb = plt.colorbar(scat)
cb.ax.set_ylabel("Counts")
ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("A 2D scatter plot")

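In `ax.scatter`, the `s` argument is the marker *area* in points squared, not its width, so doubling `s` makes a marker only about 1.4 times wider. If you want the marker diameter to scale linearly with a quantity, square it first. A minimal sketch, where `diameters` is a hypothetical quantity in points and the color range is fixed with `vmin`/`vmax`:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.array([1.0, 2.0, 3.0])
y = np.array([1.0, 1.0, 1.0])
diameters = np.array([5.0, 10.0, 20.0])  # hypothetical marker widths in points

fig, ax = plt.subplots()
# s is an area (points**2): square the diameters so widths scale linearly
scat = ax.scatter(x, y, s=diameters**2, c=diameters, cmap="viridis",
                  vmin=0.0, vmax=20.0)  # fix the color range explicitly
fig.colorbar(scat, ax=ax, label="diameter (pt)")
print(scat.get_sizes())
```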
## 3. Widgets and interactive plots

### 2D heatmap with slider through 3D data cube

In [None]:
from matplotlib.widgets import Slider

data = np.random.rand(10, 10, 10)
idx = 0

fig = plt.figure()
ax = fig.add_subplot(111)
fig.subplots_adjust(bottom=0.15)

im_h = ax.imshow(data[:, :, idx], interpolation='nearest')

ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_title("2D heatmap with slider")
cb = plt.colorbar(im_h)
cb.ax.set_ylabel("Counts")

ax_depth = fig.add_axes([0.23, 0.02, 0.56, 0.04])
slider_depth = Slider(ax_depth, 'depth', 0, data.shape[2] - 1,
                      valinit=idx, valstep=1)  # snap to integer slice indices

def update_depth(val):
    idx = int(round(slider_depth.val))
    im_h.set_data(data[:, :, idx])
    fig.canvas.draw_idle()  # request a redraw with the new slice

slider_depth.on_changed(update_depth)

plt.show()

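One subtlety in the slider example: by default `imshow` computes the color limits from the first slice it displays, and `set_data` does not update them, so later slices are mapped with those initial limits. Passing `vmin`/`vmax` computed over the whole cube makes the color scale explicit and comparable across slices. The sketch below also drives the slider from code with `Slider.set_val`, which is handy for checking the callback without a mouse; the random cube is placeholder data.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

rng = np.random.default_rng(seed=1)
data = rng.random((10, 10, 10))

fig, ax = plt.subplots()
fig.subplots_adjust(bottom=0.15)
# Color limits taken from the full cube, not just the first slice
im = ax.imshow(data[:, :, 0], interpolation="nearest",
               vmin=data.min(), vmax=data.max())
fig.colorbar(im, ax=ax)

ax_depth = fig.add_axes([0.23, 0.02, 0.56, 0.04])
slider = Slider(ax_depth, "depth", 0, data.shape[2] - 1, valinit=0, valstep=1)


def update_depth(val):
    im.set_data(data[:, :, int(val)])
    fig.canvas.draw_idle()


slider.on_changed(update_depth)

# Simulate the user moving the slider to slice 3
slider.set_val(3)
```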
### Scrolling with the mouse through a 3D data cube

In [None]:
class IndexTracker:
    """Show one 2D slice of a 3D array and step through slices with the scroll wheel."""

    def __init__(self, ax, X):
        self.ax = ax
        ax.set_title('use scroll wheel to navigate images')

        self.X = X
        rows, cols, self.slices = X.shape
        self.ind = self.slices // 2  # start in the middle of the cube

        self.im = ax.imshow(self.X[:, :, self.ind])
        self.update()

    def onscroll(self, event):
        print("%s %s" % (event.button, event.step))
        if event.button == 'up':
            self.ind = np.clip(self.ind + 1, 0, self.slices - 1)
        else:
            self.ind = np.clip(self.ind - 1, 0, self.slices - 1)
        self.update()

    def update(self):
        self.im.set_data(self.X[:, :, self.ind])
        # use the stored Axes rather than relying on a global `ax`
        self.ax.set_ylabel('slice %s' % self.ind)
        self.im.axes.figure.canvas.draw_idle()


fig, ax = plt.subplots(1, 1)

X = np.random.rand(20, 20, 40)

tracker = IndexTracker(ax, X)


fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
plt.show()

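Event handlers like `onscroll` can also be exercised without a mouse, which is useful when checking them in a script. A sketch, assuming a recent Matplotlib: it builds a `MouseEvent` (from `matplotlib.backend_bases`) for a scroll-up and dispatches it through the canvas callback registry. The `SliceScroller` class is a hypothetical, stripped-down stand-in for `IndexTracker`.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent


class SliceScroller:
    """Hypothetical minimal version of IndexTracker: scroll to change slice."""

    def __init__(self, ax, X):
        self.X = X
        self.ind = X.shape[2] // 2
        self.im = ax.imshow(X[:, :, self.ind])

    def onscroll(self, event):
        step = 1 if event.button == "up" else -1
        self.ind = int(np.clip(self.ind + step, 0, self.X.shape[2] - 1))
        self.im.set_data(self.X[:, :, self.ind])
        self.im.axes.figure.canvas.draw_idle()


fig, ax = plt.subplots()
X = np.random.rand(20, 20, 40)
scroller = SliceScroller(ax, X)
cid = fig.canvas.mpl_connect("scroll_event", scroller.onscroll)

# Build a synthetic scroll-up event at pixel (100, 100) and dispatch it
event = MouseEvent("scroll_event", fig.canvas, 100, 100, button="up", step=1)
fig.canvas.callbacks.process("scroll_event", event)

fig.canvas.mpl_disconnect(cid)  # stop listening once the handler is no longer needed
```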
### Widget demo

In [None]:
from matplotlib.widgets import Slider, Button, RadioButtons

fig, ax = plt.subplots()
fig.subplots_adjust(left=0.25, bottom=0.25)
t = np.arange(0.0, 1.0, 0.001)
a0 = 5
f0 = 3
delta_f = 5.0
s = a0 * np.sin(2 * np.pi * f0 * t)
line, = ax.plot(t, s, lw=2)
ax.margins(x=0)

axcolor = 'lightgoldenrodyellow'
axfreq = fig.add_axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
axamp = fig.add_axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)

sfreq = Slider(axfreq, 'Freq', 0.1, 30.0, valinit=f0, valstep=delta_f)
samp = Slider(axamp, 'Amp', 0.1, 10.0, valinit=a0)


def update(val):
    amp = samp.val
    freq = sfreq.val
    line.set_ydata(amp*np.sin(2*np.pi*freq*t))
    fig.canvas.draw_idle()


sfreq.on_changed(update)
samp.on_changed(update)

resetax = fig.add_axes([0.8, 0.025, 0.1, 0.04])
button = Button(resetax, 'Reset', color=axcolor, hovercolor='0.975')


def reset(event):
    sfreq.reset()
    samp.reset()


button.on_clicked(reset)

rax = fig.add_axes([0.025, 0.5, 0.15, 0.15], facecolor=axcolor)
radio = RadioButtons(rax, ('red', 'blue', 'green'), active=0)


def colorfunc(label):
    line.set_color(label)
    fig.canvas.draw_idle()


radio.on_clicked(colorfunc)

plt.show()

## 4. 3D plots (this is possible but not recommended!)

In [None]:
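# Registers the '3d' projection; only required for Matplotlib < 3.2,
# newer versions register it automatically
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401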

### 3D line plot

In [None]:
N = 100
M = 10
xx = np.arange(N, dtype=np.float64)
yy = np.arange(M, dtype=np.float64)
x, y = np.meshgrid(xx, yy)
b = M/2.0
c = N/2.0
r = np.sqrt(((x-c)/b)**2 + ((y-c)/b)**2)
z = np.sin(r)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

for i in range(M):
    ax.plot(xx, [i]*N, z[i, :])

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

### 3D scatter plot

In [None]:
N = 100
x = np.random.rand(N).astype(np.float64)
y = np.random.rand(N).astype(np.float64)
z = np.random.rand(N).astype(np.float64)
s = 300.0*np.random.rand(N).astype(np.float64)
c = np.abs(z)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
plot = ax.scatter(x, y, z, s=s, c=c, cmap="jet")

ax.set_xlabel("x coordinate")
ax.set_ylabel("y coordinate")
ax.set_zlabel("z coordinate")
ax.set_title("A 3D scatter plot")

cb = plt.colorbar(plot)
cb.ax.set_ylabel("Counts")

### 3D surface

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

surf = ax.plot_surface(X, Y, Z, linewidth=0, antialiased=False, cmap='viridis')

fig.colorbar(surf, shrink=0.5, aspect=5)
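
The viewing angle of a 3D Axes can also be set from code rather than with the mouse, via `view_init(elev, azim)`, with both angles in degrees. A brief sketch with arbitrary angles, redrawing the same surface:

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on Matplotlib < 3.2)

X, Y = np.meshgrid(np.arange(-5, 5, 0.25), np.arange(-5, 5, 0.25))
Z = np.sin(np.sqrt(X**2 + Y**2))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', linewidth=0)

# Elevation above the x-y plane and azimuth around the z-axis, in degrees
ax.view_init(elev=45, azim=-30)
```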