# 1 Intro

This notebook demonstrates how to create basic plots with Matplotlib's default style. In most cases, we apply little or no customization, to showcase Matplotlib's basic functionality.

In addition, we explore how to achieve the same plots using [Seaborn](https://seaborn.pydata.org/index.html), a high-level data visualization library built on Matplotlib. Seaborn excels at producing (reasonably) good-looking plots with minimal effort, so if you want to quickly create a basic plot that does not need extensive customization, Seaborn can sometimes be a good choice.

At the end of this notebook, we create a more advanced plot, a [bubble plot](https://en.wikipedia.org/wiki/Bubble_chart), which gives you a taste of how extensively Matplotlib plots can be customized.

The main reference for this notebook is the Matplotlib documentation page [*Overview of Matplotlib Common Plots*](https://matplotlib.org/stable/plot_types/index.html).

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# 2 Data

We will use the gapminder dataset (1952-2007) for all the plots. This dataset contains time series of development statistics for countries around the world. The famous animated bubble plot of the [World Health Chart](https://www.gapminder.org/fw/world-health-chart/) is based on this dataset. You can find a more comprehensive and up-to-date gapminder dataset [here](https://www.gapminder.org/data/).

We first load the dataset and then take a quick look at it.

In [None]:
# load the data
data_url = "https://github.com/tdmdal/datasets-teaching/raw/main/gapminder/gapminder.csv"
gapminder = pd.read_csv(data_url)

In [None]:
# display first 5 rows
gapminder.head()

In [None]:
# check column types
gapminder.dtypes

In [None]:
# display summary statistics
gapminder.describe()

In [None]:
# display unique years
gapminder["year"].unique()

In [None]:
# display unique countries
gapminder["country"].unique()

# 3 Basic Plots

## 3.1 Simple bar plot

We create a bar plot of Canada's population in 1997, 2002 and 2007.

In [None]:
# prepare data
gm_canada_after97 = gapminder[(gapminder["country"] == "Canada") & (gapminder["year"] >= 1997)]
gm_canada_after97.head()
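The filter above combines two boolean masks with `&`. The parentheses around each comparison are required because `&` binds more tightly than `==` and `>=` in Python. As an alternative, `DataFrame.query` accepts the same condition as a string. A minimal sketch on a small frame whose values are made up for illustration (they are not gapminder data):

```python
import pandas as pd

# made-up rows with the same column names as gapminder (values are illustrative only)
toy = pd.DataFrame({
    "country": ["Canada", "Canada", "Canada", "Japan"],
    "year": [1992, 1997, 2002, 2002],
    "pop": [28_000_000, 30_000_000, 31_000_000, 127_000_000],
})

# boolean-mask version, as in the cell above; each comparison needs its own parentheses
mask_result = toy[(toy["country"] == "Canada") & (toy["year"] >= 1997)]

# equivalent query-string version
query_result = toy.query("country == 'Canada' and year >= 1997")

print(query_result)
```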

### Matplotlib

In [None]:
# matplotlib default
fig, ax = plt.subplots()

# https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.bar.html
ax.bar(x=gm_canada_after97["year"], height=gm_canada_after97["pop"])

# set y label and plot title
ax.set_ylabel("Population")
ax.set_title("Canada Population")

# optional in notebook
plt.show()

In [None]:
# matplotlib default + minor refinement
fig, ax = plt.subplots()

# convert population to M unit; specify bar width
ax.bar(x=gm_canada_after97["year"], height=gm_canada_after97["pop"]/1_000_000, width=2)

# set y label and plot title
ax.set_ylabel("Population (M)")
ax.set_title("Canada Population")

# specify xticks
xticks = [1997, 2002, 2007]
ax.set_xticks(xticks)

# optional in notebook
plt.show()
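The `width` argument of `ax.bar` is measured in x-axis data units, here years. The default width of 0.8 is tiny next to the 5-year gap between bars, which is why the first plot shows thin bars; `width=2` fills 40% of each gap. A quick check on the bar rectangles, using made-up population values:

```python
import matplotlib.pyplot as plt

years = [1997, 2002, 2007]
pop_m = [30.0, 31.5, 33.0]  # made-up populations in millions, for illustration only

fig, (ax_default, ax_wide) = plt.subplots(ncols=2, sharey=True)
bars_default = ax_default.bar(years, pop_m)     # default width=0.8
bars_wide = ax_wide.bar(years, pop_m, width=2)  # width in x data units (years)

for bar in bars_wide:
    # with the default align="center", each bar spans year - 1 .. year + 1
    print(bar.get_x(), bar.get_width())
```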

### Seaborn

In [None]:
# seaborn default
# https://seaborn.pydata.org/generated/seaborn.barplot.html
ax = sns.barplot(x="year", y="pop", data=gm_canada_after97)
ax.set_ylabel("Population")
ax.set_title("Canada Population")

# optional in notebook
plt.show()
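Unlike the Matplotlib cells, `sns.barplot` treats `x` as categorical by default, so the bars sit at positions 0, 1, 2 and are labeled with the years rather than being placed at 1997, 2002 and 2007 on a numeric axis. Matplotlib behaves the same way when the x values are strings, which this sketch (with made-up heights) illustrates:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# string x values are treated as categories and placed at 0, 1, 2, ...
bars = ax.bar(["1997", "2002", "2007"], [30.0, 31.5, 33.0])  # made-up heights

# bar centers on the x-axis
centers = [bar.get_x() + bar.get_width() / 2 for bar in bars]
print(centers)
```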

### Exercise

Create a bar plot of GDP per capita for Canada in 1997, 2002 and 2007.

In [None]:
# your code here



## 3.2 Grouped bar plot

We create a grouped bar plot comparing GDP per capita between Canada and USA in 1997, 2002 and 2007.

In [None]:
# prepare data
gm_canada_us_after97 = gapminder[(gapminder["country"].isin(["Canada", "United States"])) & (gapminder["year"] >= 1997)]
gm_canada_us_after97.head()

### Matplotlib

In [None]:
# matplotlib default + minor refinement

# further prepare data
can = gm_canada_us_after97[gm_canada_us_after97["country"]=="Canada"]
us = gm_canada_us_after97[gm_canada_us_after97["country"]=="United States"]

fig, ax = plt.subplots()

# define bar width and offset for grouped bar plot
bar_width = 1.6
offset = bar_width / 2

# plot 3 bars for Canada and then 3 for USA
# x gives the bar centers (align="center" by default); shifting Canada left and
# USA right by half a bar width places the two bars side by side without overlap
ax.bar(x=can["year"]-offset, height=can["gdpPercap"], width=bar_width, label="Canada")
ax.bar(x=us["year"]+offset, height=us["gdpPercap"], width=bar_width, label="USA")

# set y label and plot title
ax.set_ylabel("GDP per capita (USD)")
ax.set_title("GDP per capita (Canada vs. USA)")

# specify xticks
xticks = [1997, 2002, 2007]
ax.set_xticks(xticks)

# add a legend based on the bar labels
ax.legend(loc='upper left')

# optional in notebook
plt.show()
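The ±`offset` trick generalizes to any number of groups: with `n` groups of bar width `w`, group `i` is shifted by `(i - (n - 1) / 2) * w`, so the bars stay centered on each year. A sketch with a hypothetical helper `group_offsets` (not a Matplotlib function) and made-up values for three series:

```python
import matplotlib.pyplot as plt
import numpy as np

def group_offsets(n_groups, bar_width):
    """Hypothetical helper: x shifts that center n_groups adjacent bars on each tick."""
    return (np.arange(n_groups) - (n_groups - 1) / 2) * bar_width

years = np.array([1997, 2002, 2007])
# made-up values for three series, for illustration only
series = {"A": [20, 25, 30], "B": [30, 35, 40], "C": [15, 18, 22]}

bar_width = 1.2
offsets = group_offsets(len(series), bar_width)

fig, ax = plt.subplots()
for (label, values), shift in zip(series.items(), offsets):
    ax.bar(years + shift, values, width=bar_width, label=label)
ax.set_xticks(years)
ax.legend(loc="upper left")
```

With two groups and `bar_width = 1.6`, the helper reproduces the `-0.8` / `+0.8` shifts used in the cell above.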

### Seaborn

In [None]:
# https://seaborn.pydata.org/generated/seaborn.barplot.html
# note the hue argument
ax = sns.barplot(x="year", y="gdpPercap", hue="country", data=gm_canada_us_after97)
ax.set_ylabel("GDP per capita (USD)")
ax.set_title("GDP per capita (Canada vs. USA)")

# optional in notebook
plt.show()

### Pandas `DataFrame.plot.bar()`

In [None]:
# pivot the data frame to the right format before using DataFrame.plot.bar()
gm_plot = pd.pivot_table(gm_canada_us_after97, values='gdpPercap', index='year', columns=['country'])
gm_plot.head()
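`pd.pivot_table` aggregates duplicate (year, country) pairs, by default with the mean. Since this data has one row per country and year, `DataFrame.pivot`, which only reshapes and raises an error on duplicates, should give the same table. A sketch on made-up values:

```python
import pandas as pd

# made-up long-format data: one row per (year, country), illustrative values only
long_df = pd.DataFrame({
    "year": [1997, 1997, 2002, 2002],
    "country": ["Canada", "United States", "Canada", "United States"],
    "gdpPercap": [28000.0, 35000.0, 33000.0, 39000.0],
})

wide_pt = pd.pivot_table(long_df, values="gdpPercap", index="year", columns=["country"])
wide_pv = long_df.pivot(index="year", columns="country", values="gdpPercap")

# rows are years, columns are countries: the layout DataFrame.plot.bar() groups by
print(wide_pv)
```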

In [None]:
# grouped bar plot using pandas DataFrame.plot.bar()
ax = gm_plot.plot.bar()
ax.set_ylabel("GDP per capita (USD)")
ax.set_title("GDP per capita (Canada vs. USA)")

# optional in notebook
plt.show()

### Exercise

Create a grouped bar plot comparing GDP per capita of Canada, USA and Japan in 1997, 2002 and 2007.

In [None]:
# your code here



## 3.3 Line plot

Create a line plot for GDP per capita (Canada vs USA) from 1952 to 2007.

### Matplotlib

In [None]:
# prepare data
gm_canada = gapminder[gapminder["country"] == "Canada"]
gm_usa = gapminder[gapminder["country"] == "United States"]

In [None]:
fig, ax = plt.subplots()

# https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html
ax.plot(gm_canada["year"], gm_canada["gdpPercap"], marker=".", label="Canada")
ax.plot(gm_usa["year"], gm_usa["gdpPercap"], marker=".", label="USA")

# set y-axis label and plot title
ax.set_ylabel("GDP per capita (USD)")
ax.set_title("GDP per capita (Canada vs. USA)")

# specify xticks
xticks = list(range(1952, 2008, 5))
ax.set_xticks(xticks)

# add a legend based on the line labels
ax.legend(loc='upper left')

# optional in notebook
plt.show()
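Calling `ax.plot` once per country works well for two lines. For more countries, looping over `groupby` avoids creating one filtered data frame per country. A sketch on made-up values (the `long_df` name is illustrative):

```python
import matplotlib.pyplot as plt
import pandas as pd

# made-up values for illustration only
long_df = pd.DataFrame({
    "country": ["Canada"] * 3 + ["United States"] * 3 + ["Japan"] * 3,
    "year": [1997, 2002, 2007] * 3,
    "gdpPercap": [28, 33, 36, 35, 39, 43, 28, 28, 31],
})

fig, ax = plt.subplots()
# one line per group; groupby sorts the country names by default
for country, grp in long_df.groupby("country"):
    ax.plot(grp["year"], grp["gdpPercap"], marker=".", label=country)

ax.legend(loc="upper left")
```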

### Seaborn

In [None]:
# prepare data for seaborn plot
gm_canada_usa = gapminder[(gapminder["country"].isin(["Canada", "United States"]))]

In [None]:
# https://seaborn.pydata.org/generated/seaborn.lineplot.html
ax = sns.lineplot(data=gm_canada_usa, x="year", y="gdpPercap",
                  hue="country", markers=True, style="country")

# specify xticks
xticks = list(range(1952, 2008, 5))
ax.set_xticks(xticks)

# set y-axis label and plot title
ax.set_ylabel("GDP per capita (USD)")
ax.set_title("GDP per capita (Canada vs. USA)")

# remove the default legend title
ax.get_legend().set_title(None)

# optional in notebook
plt.show()

### Exercise

Create a line plot of life expectancy (Canada vs USA) from 1952 to 2007.

In [None]:
# your code here



## 3.4 Histogram

Plot histograms for life expectancy across countries in 1952 and 2007.

### Matplotlib

In [None]:
# prepare data for matplotlib plot
gm_1952 = gapminder[gapminder["year"] == 1952]
gm_2007 = gapminder[gapminder["year"] == 2007]

In [None]:
# two histograms on the same plot (one ax object)
fig, ax = plt.subplots()

# https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.hist.html
ax.hist(gm_1952["lifeExp"], bins=20, alpha=0.5, label=1952)
ax.hist(gm_2007["lifeExp"], bins=20, alpha=0.5, label=2007)

ax.set_ylabel("Count")
ax.set_xlabel("Life Expectancy (Year)")
ax.set_title("Histogram of Life Expectancy across Countries")

ax.legend(loc='upper left')

# optional in notebook
plt.show()
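With `bins=20`, each `hist` call picks 20 bins over its own data range, so the two overlaid histograms generally use different bin edges and widths. One way to make them directly comparable is to compute a common set of edges first, for example with `np.histogram_bin_edges` on the combined data. A sketch with randomly generated stand-in data (not gapminder values):

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# random stand-ins for the two life-expectancy samples, illustration only
sample_a = rng.normal(50, 12, size=140)
sample_b = rng.normal(67, 12, size=140)

# 20 equal-width bins spanning both samples
edges = np.histogram_bin_edges(np.concatenate([sample_a, sample_b]), bins=20)

fig, ax = plt.subplots()
counts_a, bins_a, _ = ax.hist(sample_a, bins=edges, alpha=0.5, label="sample A")
counts_b, bins_b, _ = ax.hist(sample_b, bins=edges, alpha=0.5, label="sample B")
ax.legend(loc="upper left")
```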

In [None]:
# two histograms on two subplots (two ax objects)
fig, ax = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)

# https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.hist.html
ax[0].hist(gm_1952["lifeExp"], bins=20, edgecolor="white", alpha=0.5)
ax[1].hist(gm_2007["lifeExp"], bins=20, edgecolor="white", alpha=0.5)

# sharey=True hides the y tick labels on the right subplot; turn them back on
# https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.tick_params.html#matplotlib.axes.Axes.tick_params
ax[1].yaxis.set_tick_params(labelleft=True)

ax[0].set_ylabel("Count")

ax[0].set_xlabel("Life Expectancy (Year)")
ax[1].set_xlabel("Life Expectancy (Year)")

ax[0].set_title("1952")
ax[1].set_title("2007")

fig.suptitle("Histogram of Life Expectancy across Countries")

# optional in notebook
plt.show()

### Seaborn

In [None]:
# prepare data for seaborn histogram plot
gm_1952_2007 = gapminder[(gapminder["year"].isin([1952, 2007]))]

# two histograms on the same plot (one ax object)
ax = sns.histplot(data=gm_1952_2007, x="lifeExp", hue="year", bins=20, multiple="layer")

ax.set_ylabel("Count")
ax.set_xlabel("Life Expectancy (Year)")
ax.set_title("Histogram of Life Expectancy across Countries")

# optional in notebook
plt.show()

In [None]:
# two histograms on two subplots (one ax object per year)

# prepare data for seaborn histogram plot
gm_1952_2007 = gapminder[(gapminder["year"].isin([1952, 2007]))]

# displot is figure-level and returns a FacetGrid, not an ax object
# note the col argument: one subplot (column) per year
g = sns.displot(data=gm_1952_2007, x="lifeExp", col="year", bins=20)

# https://seaborn.pydata.org/generated/seaborn.FacetGrid.set_xlabels.html
g.set_xlabels("Life Expectancy (Year)")

# make room at the top for the overall title
# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots_adjust.html
g.figure.subplots_adjust(top=0.85)
g.figure.suptitle("Histogram of Life Expectancy across Countries")

# optional in notebook
plt.show()

### Exercise

Create a histogram of life expectancy across countries in 2007. Plot the probability density (instead of the count) and overlay a density curve.
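As a reminder of what "probability density" means here: each bar's height is scaled so that height × bin width, summed over all bins, equals 1, which puts the histogram on the same scale as a density curve such as a KDE. A quick NumPy check of that property on random stand-in data (not gapminder values):

```python
import numpy as np

rng = np.random.default_rng(1)
values = rng.normal(70, 8, size=142)  # random stand-in data, illustration only

# density=True rescales counts so the histogram's total area is 1
heights, edges = np.histogram(values, bins=20, density=True)
area = np.sum(heights * np.diff(edges))
print(area)
```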

In [None]:
# using matplotlib

from scipy.stats import gaussian_kde
import numpy as np

# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html
density = gaussian_kde(gm_2007["lifeExp"])
xs = np.linspace(gm_2007["lifeExp"].min(), gm_2007["lifeExp"].max(), 100)
ys = density(xs)

fig, ax = plt.subplots()

# insert your code here

# optional in notebook
plt.show()

In [None]:
# using seaborn



## 3.5 Scatter plot

Produce a scatter plot of life expectancy against GDP per capita in 2007.

### Matplotlib

In [None]:
fig, ax = plt.subplots()

# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html
ax.scatter(gm_2007["gdpPercap"], gm_2007["lifeExp"], alpha=0.5)

ax.set_ylabel("Life expectancy (Year)")
ax.set_xlabel("GDP per capita (USD)")
ax.set_title("Life expectancy vs. GDP per capita (2007)")

# optional in notebook
plt.show()
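GDP per capita is heavily right-skewed, so many countries crowd into the left edge of this scatter plot. One common option is a logarithmic x-axis via `ax.set_xscale("log")`; a `ticker.StrMethodFormatter` (also used in the bubble plot refinement) adds thousands separators to the tick labels. A sketch on made-up points:

```python
import matplotlib.pyplot as plt
from matplotlib import ticker

# made-up (GDP per capita, life expectancy) pairs, illustration only
gdp = [300, 1_200, 5_000, 12_000, 40_000]
life = [45, 55, 65, 72, 80]

fig, ax = plt.subplots()
ax.scatter(gdp, life, alpha=0.5)

# logarithmic x-axis spreads out the crowded low-GDP region
ax.set_xscale("log")
# set the formatter after set_xscale, which installs its own default formatter
ax.xaxis.set_major_formatter(ticker.StrMethodFormatter("{x:,.0f}"))

ax.set_xlabel("GDP per capita (USD, log scale)")
ax.set_ylabel("Life expectancy (Year)")
```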

### Seaborn

In [None]:
# https://seaborn.pydata.org/generated/seaborn.scatterplot.html
ax = sns.scatterplot(data=gm_2007, x="gdpPercap", y="lifeExp", alpha=0.5)

ax.set_ylabel("Life expectancy (Year)")
ax.set_xlabel("GDP per capita (USD)")
ax.set_title("Life expectancy vs. GDP per capita (2007)")

# optional in notebook
plt.show()

### Exercise

Produce two scatter sub-plots of life expectancy against GDP per capita in 1952 and 2007.

In [None]:
# Your code here



# 4 Bubble plot

It isn't trivial to create a good-looking bubble plot with Matplotlib: getting the bubble sizes and the legend right takes some effort. Below, we first produce a bubble plot with almost no customization and then refine it.

We won't attempt to customize the bubble plot created using Seaborn.
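A note on bubble size: the `s` argument of `scatter` is the marker area in points² (points squared), not its diameter. Dividing population by a constant therefore makes bubble *area* proportional to population, so a country with four times the population gets a bubble only twice as wide. A sketch checking this on made-up populations:

```python
import matplotlib.pyplot as plt
import numpy as np

size_factor = 200_000  # same scaling constant as the bubble plot cells
pop = np.array([10_000_000, 40_000_000])  # made-up populations

fig, ax = plt.subplots()
bubbles = ax.scatter([1, 2], [1, 1], s=pop / size_factor)

areas = bubbles.get_sizes()  # marker areas in points**2
diameters = np.sqrt(areas)   # widths scale with the square root of the area
print(areas, diameters[1] / diameters[0])
```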

## Matplotlib

In [None]:
# Matplotlib bubble plot with no refinement

fig, ax = plt.subplots()

# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html
scatter = ax.scatter(
    x = gm_2007["gdpPercap"],
    y = gm_2007["lifeExp"],
    s = gm_2007["pop"]/200000,
    c = gm_2007["continent"].astype("category").cat.codes,
)

ax.set_ylabel("Life expectancy (Year)")
ax.set_xlabel("GDP per capita (USD)")
ax.figure.suptitle("Life expectancy vs. GDP per capita (2007)")

# optional in notebook
plt.show()
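Here `c` receives integer codes rather than continent names: `astype("category")` collects the distinct names, sorted alphabetically, into `.cat.categories`, and `.cat.codes` gives each row's position in that list. The colormap then maps those integers to colors. A small made-up example:

```python
import pandas as pd

# made-up continent column, illustration only
continent = pd.Series(["Europe", "Asia", "Africa", "Europe", "Americas"])

as_cat = continent.astype("category")
print(list(as_cat.cat.categories))  # distinct names, sorted alphabetically
print(list(as_cat.cat.codes))       # each row's index into the categories
```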

In [None]:
# Matplotlib bubble plot with refinement

from matplotlib import ticker

fig, ax = plt.subplots()

# set the factor to adjust the size of the bubble
size_factor = 200_000

# obtain continent names in category-code order (for legend labels later);
# category codes index into .cat.categories
continent_list_ordered = gm_2007["continent"].astype("category").cat.categories.tolist()

# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html
scatter = ax.scatter(
    x = gm_2007["gdpPercap"],
    y = gm_2007["lifeExp"],
    s = gm_2007["pop"]/size_factor,
    c = gm_2007["continent"].astype("category").cat.codes,
    cmap = "Accent",
    alpha = 0.6,
    edgecolors = "black",
    linewidths = 0.5
)

ax.set_ylabel("Life expectancy (Year)")
ax.set_xlabel("GDP per capita (USD)")
ax.figure.suptitle("Life expectancy vs. GDP per capita (2007)")

# https://matplotlib.org/stable/api/spines_api.html
ax.spines[["top", "right"]].set_visible(False)

# cap the y axis at 85 so the topmost bubble isn't clipped
# https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_ylim.html
ax.set_ylim(bottom=None, top=85)

# format the x-axis major ticker label
ax.xaxis.set_major_formatter(ticker.StrMethodFormatter("{x:,.0f}"))

# handle legend
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/scatter_with_legend.html
handles1, labels1 = scatter.legend_elements(prop="colors")
labels1 = continent_list_ordered
legend1 = ax.legend(handles1, labels1, loc="lower right", title="Continent", framealpha=0)
ax.add_artist(legend1)

handles2, labels2 = scatter.legend_elements(prop="sizes",
                                            num=[10_000_000, 30_000_000],
                                            # fmt="{x:,.0e}",
                                            func = lambda x: x*size_factor,
                                            alpha=0.3)
# manually set the label instead of using the default formatter or fmt argument
labels2 = ["10M", "30M"]
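# the last ax.legend() call becomes the axes' own legend and is drawn automatically;
# adding it again with add_artist would draw it twice
legend2 = ax.legend(handles2, labels2, loc="center right", title="Population", framealpha=0)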

# optional in notebook
plt.show()
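In `legend_elements(prop="sizes", ...)`, `func` converts marker sizes (areas) back into data units, here population, and the values passed in `num` are given in those data units. Matplotlib then interpolates the marker size for each requested value (values outside the plotted range are dropped), so the legend bubbles match the plotted ones. A reduced sketch with made-up populations:

```python
import matplotlib.pyplot as plt
import numpy as np

size_factor = 200_000
pop = np.array([5_000_000, 20_000_000, 40_000_000])  # made-up populations

fig, ax = plt.subplots()
scatter = ax.scatter([1, 2, 3], [1, 2, 3], s=pop / size_factor)

# num is given in population units; func maps marker area -> population
handles, labels = scatter.legend_elements(
    prop="sizes",
    num=[10_000_000, 30_000_000],
    func=lambda s: s * size_factor,
)
ax.legend(handles, ["10M", "30M"], title="Population")

# legend markersize is in points, i.e. the square root of the marker area
legend_areas = [h.get_markersize() ** 2 for h in handles]
print(legend_areas)
```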

## Seaborn

In [None]:
# use the default seaborn plot
# it still needs refinement, but we leave it as is for now
gm_2007_new = gm_2007.copy()
gm_2007_new["pop"] = gm_2007_new["pop"] / 1_000_000

# https://seaborn.pydata.org/generated/seaborn.scatterplot.html
ax = sns.scatterplot(
    data=gm_2007_new,
    x="gdpPercap",
    y="lifeExp",
    size="pop",
    hue="continent",
    alpha=0.5,
    sizes=(20, 800)
)

# sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

ax.set_ylabel("Life expectancy (Year)")
ax.set_xlabel("GDP per capita (USD)")
ax.set_title("Life expectancy vs. GDP per capita (2007)")

# optional in notebook
plt.show()