In [1]:
# Imports: pandas for the data, matplotlib for all the plotting
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
# Embed fonts as TrueType (Type 42) instead of Type 3 in PDF/PostScript output,
# typically so the text in exported figures stays editable in other tools
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42
# Render figures directly underneath each cell
%matplotlib inline

# Small multiples

In [2]:
# Load the country-level dataset, then preview its first three rows
df = pd.read_csv(filepath_or_buffer="country-data.csv")
df.head(n=3)

OSError: File b'country-data.csv' does not exist

## Plotting everything

In [None]:
# Bhutan alone, on a brand-new figure with a single ax
fig, ax = plt.subplots()
df.query("Country == 'Bhutan'").plot(
    x='Year', y='GDP_per_capita', legend=False, ax=ax
)
ax.set_title("Bhutan")

In [None]:
# Iran alone, on a brand-new figure with a single ax
fig, ax = plt.subplots()
df.query("Country == 'Iran'").plot(
    x='Year', y='GDP_per_capita', legend=False, ax=ax
)
ax.set_title("Iran")

# Using ax

In [None]:
# A single figure with a single ax that both countries will share
fig, ax = plt.subplots()

# Each .plot() call adds one more line to that same ax
df.query("Country == 'Bhutan'").plot(x='Year', y='GDP_per_capita', label='Bhutan', ax=ax)
df.query("Country == 'Iran'").plot(x='Year', y='GDP_per_capita', label='Iran', ax=ax)
ax.set_title("Iran and Bhutan")

# Two separate charts, different images

In [None]:
# Two separate figures this time, each with its own ax
fig, ax = plt.subplots()
fig2, ax2 = plt.subplots()

# Bhutan goes on the first figure, Iran on the second
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax, label='Bhutan')
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', ax=ax2, label='Iran')

# Each figure shows only one country, so each gets its own title
ax.set_title("Bhutan")
ax2.set_title("Iran")

# Using `plt.subplots` with `nrows`

Those single `ax` objects weren't very useful. But we can pass arguments to `.subplots` instead of calling it with nothing! Let's try passing `nrows=2` to get back two axes to plot on.

In [None]:
# Asking for TWO subplots, ax1 and ax2, stacked in two rows.
# plt.subplots returns (fig, axes), so the axes must be unpacked
# inside their own parentheses: fig, (ax1, ax2) - not fig, ax1, ax2
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

# Let's put them side-by-side

Use `nrows=1` and increase `ncols` to position graphics next to each other.

In [None]:
# One row, two columns: the two subplots sit next to each other.
# The axes come back as a single one-dimensional array, which we
# unpack inside parentheses.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# Left panel: Bhutan
df.query("Country == 'Bhutan'").plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Bhutan")

# Right panel: Iran
df.query("Country == 'Iran'").plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Iran")

# tight_layout() stops titles and tick labels from running into each other
plt.tight_layout()

# Let's make it four

In [None]:
# Asking for FOUR subplots, ax1 through ax4, all in a single row.
# Be sure to put them in parentheses
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4)

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# Use ax3 to plot France
df[df['Country'] == 'France'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax3)
ax3.set_title("France")

# Use ax4 to plot Mexico
df[df['Country'] == 'Mexico'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax4)
ax4.set_title("Mexico")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

# Using figsize

Use `figsize` (width, height in inches) to make the figure large enough to fit everything you're going to be graphing.

In [None]:
# Asking for FOUR subplots, ax1 through ax4, all in a single row.
# figsize is (width, height) in inches - wide enough for four panels.
# Be sure to put the axes in parentheses
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4, figsize=(14, 5))

# Use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# Use ax3 to plot France
df[df['Country'] == 'France'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax3)
ax3.set_title("France")

# Use ax4 to plot Venezuela
df[df['Country'] == 'Venezuela'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax4)
ax4.set_title("Venezuela")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

# Using nrows and ncols gets weird

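You can't just say `fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=2, ncols=2)` when asking for multiple rows. With more than one row *and* more than one column, the axes come back as a grid (here, two rows of two), so four flat names don't match and Python raises a `ValueError`. You have to unpack row by row instead: `fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)`.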

In [None]:
# This is the WRONG way to unpack a 2x2 grid. With nrows=2 and ncols=2,
# plt.subplots returns its axes as two rows of two - not four separate
# axes - so unpacking them into four flat names raises a ValueError
# ("not enough values to unpack").
# We catch and print the error so "Restart & Run All" can continue past
# this demonstration. The working version unpacks row by row:
#   fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
try:
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
except ValueError as err:
    print("ValueError:", err)

# `sharex` and `sharey`

In [None]:
# Asking for FOUR subplots in a 2x2 grid. The axes come back as two rows
# of two, so unpack them row by row with nested parentheses.
# sharex/sharey put every subplot on the same x and y scales.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(8, 8), sharex=True, sharey=True)

# Top left: use ax1 to plot Bhutan
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax1)
ax1.set_title("Bhutan")

# Top right: use ax2 to plot Iran
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax2)
ax2.set_title("Iran")

# Bottom left: use ax3 to plot France
df[df['Country'] == 'France'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax3)
ax3.set_title("France")

# Bottom right: use ax4 to plot Venezuela
df[df['Country'] == 'Venezuela'].plot(x='Year', y='GDP_per_capita', legend=False, ax=ax4)
ax4.set_title("Venezuela")

# If you don't do tight_layout() you'll have weird overlaps
plt.tight_layout()

# Using `plt.subplot`

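The other big way of making small multiples that you'll see everywhere is `plt.subplot` (notice it's `subplot`, not `subplots`). It creates one axis at a time: `plt.subplot(nrows, ncols, index)` picks cell number `index` (counting from 1, left to right, top to bottom) out of an imaginary `nrows` × `ncols` grid.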

In [None]:
# 1 row, 2 columns, and we'd like the first (left) element.
# plt.subplot indexes grid cells starting from 1, left to right.
ax1 = plt.subplot(1, 2, 1)
df[df['Country'] == 'Belarus'].plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Belarus")

# 1 row, 2 columns, and we'd like the second (right) element.
ax2 = plt.subplot(1, 2, 2)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Bhutan")

plt.tight_layout()

In [None]:
# 1 row, 2 columns, and we'd like the first element: the whole left half.
ax1 = plt.subplot(1, 2, 1)
df[df['Country'] == 'Belarus'].plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Belarus")

# 2 rows, 2 columns, and we'd like the fourth element: the bottom-right quarter.
# Each plt.subplot call can use its own imaginary grid.
ax2 = plt.subplot(2, 2, 4)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Bhutan")

plt.tight_layout()

In [None]:
# 1 row, 2 columns, and we'd like the first element: the whole left half.
ax1 = plt.subplot(1, 2, 1)
df[df['Country'] == 'Belarus'].plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Belarus")

# 2 rows, 2 columns, and we'd like the second element: the top-right quarter.
ax2 = plt.subplot(2, 2, 2)
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Iran")

# 2 rows, 2 columns, and we'd like the fourth element: the bottom-right quarter.
ax3 = plt.subplot(2, 2, 4)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax3, legend=False)
ax3.set_title("Bhutan")

plt.tight_layout()

# Sharing axes using `plt.subplot`

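You can't just pass `sharex=True` and `sharey=True` like you can with `plt.subplots`! One workaround is to set the same limits on every axis by hand, as below. (`plt.subplot` does accept `sharex=`/`sharey=` if you hand it an existing axis to share with, e.g. `plt.subplot(2, 2, 2, sharey=ax1)`.)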

In [None]:
# plt.subplot can't take sharey=True the way plt.subplots can, so we fake
# a shared y axis by giving every panel the same y-limits by hand.
shared_ylim = (0, 15000)

# 1 row, 2 columns, and we'd like the first element: the whole left half.
ax1 = plt.subplot(1, 2, 1)
df[df['Country'] == 'Belarus'].plot(x='Year', y='GDP_per_capita', ax=ax1, legend=False)
ax1.set_title("Belarus")
ax1.set_ylim(shared_ylim)

# 2 rows, 2 columns, and we'd like the second element: the top-right quarter.
ax2 = plt.subplot(2, 2, 2)
df[df['Country'] == 'Iran'].plot(x='Year', y='GDP_per_capita', ax=ax2, legend=False)
ax2.set_title("Iran")
ax2.set_ylim(shared_ylim)

# 2 rows, 2 columns, and we'd like the fourth element: the bottom-right quarter.
ax3 = plt.subplot(2, 2, 4)
df[df['Country'] == 'Bhutan'].plot(x='Year', y='GDP_per_capita', ax=ax3, legend=False)
ax3.set_title("Bhutan")
ax3.set_ylim(shared_ylim)

plt.tight_layout()

# Looping

But this isn't going to work for looping. In fact, *most things don't work for looping*, but I figured out A Reasonable Method.

In [None]:
# Request a whole 3x3 grid of axes in one call (positional nrows, ncols);
# every panel shares its x and y scales, and `axes` holds all nine.
fig, axes = plt.subplots(3, 3, sharex=True, sharey=True)
axes

Right now it's a 3×3 grid (a NumPy array holding 3 rows of 3 axes apiece), which will be hard to loop over.

In [None]:
# Look at the grid again: 3 rows, each holding 3 axes
axes

In [None]:
# How many rows of axes? That's the grid's first dimension
axes.shape[0]

In [None]:
# The top row of the grid: three axes
axes[0, :]

In [None]:
# The middle row of the grid: three more axes
axes[1, :]

Luckily we can steal code from the internet to turn it into just one normal list

In [None]:
# Flatten the grid row by row into one plain list
# http://stackoverflow.com/questions/952914/making-a-flat-list-out-of-list-of-lists-in-python
[ax for row in axes for ax in row]

...and take them off one piece at a time...

In [None]:
# Same flattening as above, but keep the list so we can pop axes off it
axes_list = [ax for row in axes for ax in row]

In [None]:
# 3 rows x 3 columns, flattened: 9 axes in one list
len(axes_list)

In [None]:
# Remove the first axis from the front of the list and save it as 'ax'.
# pop(0) mutates axes_list, so every run of this cell shortens it by one
ax = axes_list.pop(0)
ax

In [None]:
# Only 8 left now (assuming each pop cell above ran exactly once)
len(axes_list)

In [None]:
# Remove another axis from the front of the list, save it as 'ax'
ax = axes_list.pop(0)
ax

In [None]:
# Only 7 left now (assuming each pop cell above ran exactly once)
len(axes_list)

In [None]:
# Remove another axis from the front of the list, save it as 'ax'
ax = axes_list.pop(0)
ax

In [None]:
# Only 6 left now (assuming each pop cell above ran exactly once)
len(axes_list)

And we can just do this until we have none left!

# Putting it together

In [None]:
# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True, figsize=(10,5))
# Flatten the 2x5 grid into one list so we can take axes off one at a time
axes_list = [item for sublist in axes for item in sublist]

countries = ["Iran", "Iraq", "Venezuela", "China", "Bhutan", "Bangladesh", "Mexico", "Poland", "Kazakhstan", "Nepal"]

subset_df = df[df['Country'].isin(countries)]

# groupby yields one (country name, rows) pair per country, sorted by name
for countryname, selection in subset_df.groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    # Hide every tick mark. tick_params needs real booleans: string values
    # like 'off' were deprecated in matplotlib 2.2, and any non-empty
    # string is truthy, so 'off' can end up showing the ticks.
    ax.tick_params(
        which='both',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Now use the matplotlib .remove() method to
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

In [None]:
# We can ask for ALL THE AXES and put them into axes
fig, axes = plt.subplots(nrows=2, ncols=5, sharey=True, figsize=(10,5))
# Flatten the 2x5 grid into one list so we can take axes off one at a time
axes_list = [item for sublist in axes for item in sublist]

countries = ["Iran", "Iraq", "Venezuela", "China", "Bhutan", "Bangladesh", "Mexico", "Poland", "Kazakhstan", "Nepal"]

subset_df = df[df['Country'].isin(countries)]

# groupby yields one (country name, rows) pair per country, sorted by name
for countryname, selection in subset_df.groupby("Country"):
    ax = axes_list.pop(0)
    selection.plot(x='Year', y='GDP_per_capita', label=countryname, ax=ax, legend=False)
    ax.set_title(countryname)
    # tick_params needs real booleans: string values like 'on'/'off' were
    # deprecated in matplotlib 2.2, and any non-empty string is truthy.
    # No minor ticks anywhere...
    ax.tick_params(
        which='minor',
        bottom=False,
        left=False,
        right=False,
        top=False
    )
    # ...and major ticks only along the bottom
    ax.tick_params(
        which='major',
        bottom=True,
        left=False,
        right=False,
        top=False
    )
    ax.grid(linewidth=0.25)
    ax.set_xticks((1950,1970, 1990, 2010))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel("")

# Now use the matplotlib .remove() method to
# delete anything we didn't use
for ax in axes_list:
    ax.remove()

plt.tight_layout()