<a href="https://colab.research.google.com/github/worldbank/dec-python-course/blob/main/1-foundations/4-api-and-dataviz/foundations-s4-dataviz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run the lines below if needed
!pip install wbgapi
!pip install seaborn

In [None]:
from IPython.display import Image

In [None]:
import pandas as pd
pd.options.mode.chained_assignment = None  # silence pandas' SettingWithCopyWarning
import matplotlib.pyplot as plt
import seaborn as sns
import wbgapi as wb

# Introduction to data visualization in Python

Python has several data visualization packages. Two libraries are arguably the most widely used: `matplotlib` and `seaborn`.

# `matplotlib`

- One of the oldest and most established Python data visualization libraries
- Very powerful
- Allows low-level customization of plots
- "Wordy" syntax that can quickly become complex
- Very popular in scientific programming

Remember this "picture"? It was actually a plot created with `matplotlib`.

<img src="https://github.com/worldbank/dec-python-course/blob/main/1-foundations/4-api-and-dataviz/img/black-hole.jpg?raw=true" width=400 />

# `seaborn`

- Built on top of `matplotlib`
- Nicer defaults
- High-level syntax
- Much easier to use than `matplotlib` but allows less customization

We're going to use `matplotlib` and `seaborn` in this session.

# Initial data

We'll start by fetching some data from the WB API.

In [None]:
countries = ['MEX', 'CAN', 'USA']
years = range(2010, 2020)
df = wb.data.DataFrame('SP.POP.TOTL', countries, years, labels=True)

In [None]:
df

We're going to do a bit of data wrangling to give this data the shape that we need for data visualization, which is the long format.

In [None]:
df = df.reset_index(drop=True)                                  # drop the index with the economy ISO-3 code
df = pd.wide_to_long(df, stubnames='YR', i='Country', j='year') # reshape wide to long: each row is now a country-year
df = df.reset_index()                                           # move Country and year from the index to columns
df = df.rename(columns={'YR': 'Population'})                    # rename the value column
df.head()                                                       # display the first 5 rows
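To see what `pd.wide_to_long()` does in isolation, here is a minimal sketch on a hypothetical two-country table shaped like the API output (the country names and values are placeholders, not real data). The stub `'YR'` matches the columns `YR2010` and `YR2011`, and their numeric suffix becomes the new `year` column.

```python
import pandas as pd

# Hypothetical wide table shaped like the WB API output (placeholder values)
wide = pd.DataFrame({
    'Country': ['Country A', 'Country B'],
    'YR2010': [10, 20],
    'YR2011': [11, 21],
})

# Every column starting with 'YR' becomes one row per country; its suffix goes to 'year'
long = pd.wide_to_long(wide, stubnames='YR', i='Country', j='year')
long = long.reset_index().rename(columns={'YR': 'Population'})
print(long.sort_values(['Country', 'year']))
```

The two-row wide table becomes a four-row long table with one row per country-year.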

**Important:** All the examples in this session follow a good practice in data visualization: all data wrangling is done outside the visualization code. Whenever wrangling is needed, we do it with Pandas and pass only the wrangled data to the plotting functions.

# Bar plots

We'll create a simple bar plot of Mexico's total population in millions.

- We need to wrangle the data so that it only includes observations from Mexico, with population expressed in millions
- For bar plots, the data we pass to `matplotlib` and `seaborn` is essentially the x-axis data and the y-axis data

In [None]:
y_data = df[df['Country']=='Mexico']['Population'] / 1000000 # y-axis data: population in millions
x_data = df[df['Country']=='Mexico']['year']                 # x-axis data: years we have data for Mexico

## Using `matplotlib`

In [None]:
# Simplest bar plot with default options
plt.bar(x_data, y_data)

In [None]:
# Adding some customization
plot_title = 'Mexico - Total Population in millions' 
plt.bar(x_data, y_data)
plt.title(plot_title)
plt.xlabel('Year')
plt.ylabel('Population')
plt.xticks(x_data);

- `plt` (the `matplotlib.pyplot` module) has a feature that is uncommon in Python: instead of returning a new object, each call modifies the *current figure* in place
- Each call to `plt` adds customizations on top of what the previous lines drew
- In a notebook, the figure is displayed automatically at the end of the code block
- This does not carry over across code blocks, though: once a block finishes, its figure is displayed and closed, so a new code block starts from an empty figure

In [None]:
# This shows nothing: the figure from the previous block was already displayed and closed
plt.show()

- The semicolon (`;`) at the last line of a block tells the notebook to omit printing the return value of the last line (try removing it to see the difference)
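If you need to keep working on a figure across code blocks, `matplotlib` also offers an object-oriented interface: `plt.subplots()` returns explicit `Figure` and `Axes` objects that you customize through their methods (`ax.set_title()` instead of `plt.title()`, and so on). A minimal sketch with placeholder values, not real population figures:

```python
import matplotlib.pyplot as plt

years = [2010, 2011, 2012]
values = [1.0, 2.0, 3.0]  # placeholder values for illustration only

# Explicit Figure and Axes objects instead of the implicit "current figure"
fig, ax = plt.subplots()
ax.bar(years, values)
ax.set_title('Placeholder title')
ax.set_xlabel('Year')
ax.set_ylabel('Population')
ax.set_xticks(years)
```

Because `fig` is a regular variable, evaluating `fig` in a later code block displays the figure again.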

## Using `seaborn`

In [None]:
# Method 1
sns.barplot(x=x_data, y=y_data)

In [None]:
# Method 1 with plot title and same color for all bars
# color='C0' is the first color of matplotlib's default color cycle,
# the same blue as the first bar of the previous plot
sns.barplot(x=x_data, y=y_data, color='C0')
plt.title(plot_title);

In [None]:
# Method 2
df_mexico = df[df['Country']=='Mexico'].copy()              # .copy() makes an independent subset we can modify
df_mexico['Population'] = df_mexico['Population'] / 1000000 # population in millions

sns.barplot(data=df_mexico, x='year', y='Population', color='C0')
plt.title(plot_title);

A few details to note:

- For this example we defined the x and y-axis data as Pandas Series, but they can also be lists (or other containers) of numbers, or NumPy arrays
- Seaborn accepts two methods to plot visualizations:
    + You either pass the x and y-axis data in the arguments `x`, `y`
    + Or you pass a Pandas dataframe in the argument `data` and set `x` and `y` equal to the names of the columns holding the x-axis and y-axis data
- Compare the syntax of both libraries to get the same result:

```
# matplotlib
plt.bar(x_data, y_data)
plt.title(plot_title)
plt.xlabel('Year')
plt.ylabel('Population')
plt.xticks(x_data)

# seaborn
sns.barplot(x=x_data, y=y_data, color='C0')
plt.title(plot_title)
```

- `matplotlib` has a heavier syntax; you'll also notice this in the next examples
- Did you notice how we assigned the title in the `seaborn` example? `matplotlib` syntax can be used on top of `seaborn` plots
- `seaborn` sets the x and y-axis labels automatically. In versions before 0.13 it also gives a different color to every bar by default; newer versions use a single color unless `hue` is set
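The point that plot inputs need not be Pandas Series also holds for `matplotlib`. A quick sketch with placeholder values that draws the same bars from a list, a NumPy array and a Pandas Series, then compares the resulting bar heights:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

years = [2010, 2011, 2012]
values = [1.0, 2.0, 3.0]  # placeholder values, not real data

# The same data as a list, a NumPy array and a Pandas Series
inputs = {
    'list': (years, values),
    'NumPy array': (np.array(years), np.array(values)),
    'Pandas Series': (pd.Series(years), pd.Series(values)),
}

heights = {}
for name, (x, y) in inputs.items():
    fig, ax = plt.subplots()
    bars = ax.bar(x, y)  # a BarContainer with one rectangle per bar
    heights[name] = [float(bar.get_height()) for bar in bars]
    plt.close(fig)       # we only inspect the bars, so the figure is not displayed

print(heights)
```

All three inputs produce bars of the same heights.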

# Line plots

- Line plots have a very similar syntax to bar plots in `matplotlib`, but they use the function `plt.plot()` instead of `plt.bar()`
- In `seaborn`, the function is `sns.lineplot()`

## `matplotlib`

In [None]:
plt.plot(x_data, y_data)
plt.title(plot_title)
plt.xlabel('Year')
plt.ylabel('Population')
plt.xticks(x_data);

In [None]:
# Now setting the range of the y-axis:
plt.plot(x_data, y_data)
plt.title(plot_title)
plt.xlabel('Year')
plt.ylabel('Population')
plt.xticks(x_data)
plt.ylim(100, 130); # y-axis from 100 to 130

## `seaborn`

In [None]:
sns.lineplot(data=df_mexico, x='year', y='Population', color='C0')
plt.title(plot_title)
plt.ylim(100, 130);
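Since `df` is in long format with three countries, you could also draw one line per country. A minimal `matplotlib` sketch on a hypothetical long-format table with placeholder values and the same column names as `df`; `groupby()` yields one sub-DataFrame per country. In `seaborn`, passing `hue='Country'` to `sns.lineplot()` has a similar effect.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical long-format data with the same columns as df (placeholder values)
long_df = pd.DataFrame({
    'Country': ['Country A'] * 3 + ['Country B'] * 3,
    'year': [2010, 2011, 2012] * 2,
    'Population': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2],
})

fig, ax = plt.subplots()
# groupby yields (country name, rows for that country) pairs
for country, group in long_df.groupby('Country'):
    ax.plot(group['year'], group['Population'], label=country)
ax.set_xlabel('Year')
ax.set_ylabel('Population')
ax.legend(title='Country')
```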

# Scatter plots

We'll create a scatter plot of GDP per capita and life expectancy for 2010.

## Fetching the data

We'll get the data for all the economies available in the WB API. To do this we first need a list with the codes of all the countries.

In [None]:
all_units_df = wb.economy.DataFrame()
all_units_df.head()

In [None]:
all_countries_df = all_units_df[all_units_df['aggregate']==False] # leaving only non-aggregate (economy-level) obs
all_countries_df.head()

In [None]:
countries_list = list(all_countries_df.index) # list with country codes (the index of all_countries_df)
print(countries_list[0:5])
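A more idiomatic way to keep the rows where a boolean column is `False` is the `~` (negation) operator. A small sketch on a hypothetical table shaped like the `wb.economy.DataFrame()` output (the ids and names are placeholders):

```python
import pandas as pd

# Hypothetical table shaped like wb.economy.DataFrame() output (placeholder ids and names)
units = pd.DataFrame(
    {'name': ['Country A', 'Region X', 'Country B'],
     'aggregate': [False, True, False]},
    index=['AAA', 'XXX', 'BBB'],
)

# ~ flips True/False, so this keeps only the non-aggregate rows
countries = units[~units['aggregate']]
print(list(countries.index))  # ['AAA', 'BBB']
```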

We also need to tell the WB API client library which series (indicators) to retrieve:
- `NY.GDP.PCAP.KD`: GDP per capita (constant 2015 US$)
- `SP.DYN.LE00.IN`: Life expectancy at birth, total (years)

In [None]:
indicators = ['NY.GDP.PCAP.KD', 'SP.DYN.LE00.IN']

Retrieving the data:

In [None]:
df = wb.data.DataFrame(indicators, countries_list, time=2010, labels=True)

In [None]:
df.head()

Finally, we add a column with each country's income level from the data frame `all_units_df`, matching rows by country name:

In [None]:
df = df.merge(all_units_df[['name', 'incomeLevel']],
              left_on = 'Country',
              right_on = 'name')
df.head()

In [None]:
df = df.drop(columns='name') # drop the duplicated country name column
df = df.rename(columns={'NY.GDP.PCAP.KD': 'gdpPerCapita', 'SP.DYN.LE00.IN': 'lifeExpectancy'}) # readable column names

In [None]:
df.head()
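The merge above uses the default `how='inner'`, which silently drops any row without a match in the other table (for example, a country whose name is spelled differently). A minimal sketch of how you could check for unmatched rows with `indicator=True`, using hypothetical toy tables with placeholder names and values:

```python
import pandas as pd

# Hypothetical toy tables (placeholder names and values)
data = pd.DataFrame({'Country': ['Country A', 'Country B', 'Country C'],
                     'gdp': [1.0, 2.0, 3.0]})
meta = pd.DataFrame({'name': ['Country A', 'Country B'],
                     'incomeLevel': ['LIC', 'HIC']})

# how='left' keeps every row of `data`; indicator=True adds a `_merge` column
merged = data.merge(meta, left_on='Country', right_on='name',
                    how='left', indicator=True)
unmatched = merged[merged['_merge'] == 'left_only']
print(unmatched['Country'].tolist())  # ['Country C']
```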

## With `matplotlib`

In [None]:
x = df['gdpPerCapita']
y = df['lifeExpectancy']

In [None]:
# Simple scatter plot
plt.scatter(x, y);

In [None]:
# Adding some customization
plt.scatter(x, y, s=10) # s=10 indicates the size of the markers
plt.title('Country GDP per capita and life expectancy')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Life expectancy in years');
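GDP per capita is highly skewed across countries, so many points may end up bunched near the left of the plot. One common option, not used in the rest of this session, is a logarithmic x-axis. A minimal sketch with placeholder values; in the `plt` style used in this notebook, `plt.xscale('log')` does the same as `ax.set_xscale('log')`.

```python
import matplotlib.pyplot as plt

# Placeholder values spanning several orders of magnitude (not real data)
gdp = [500, 5_000, 50_000]
life = [55, 70, 80]

fig, ax = plt.subplots()
ax.scatter(gdp, life, s=10)
ax.set_xscale('log')  # equal distances on the x-axis now represent equal ratios
ax.set_xlabel('GDP per capita (constant 2015 USD, log scale)')
ax.set_ylabel('Life expectancy in years')
```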

Now adding different colors per income level:

In [None]:
# We first separate x and y-axis values by income level

## Low income
x_low = df[df['incomeLevel']=='LIC']['gdpPerCapita']
y_low = df[df['incomeLevel']=='LIC']['lifeExpectancy']

## Lower-middle income
x_lm = df[df['incomeLevel']=='LMC']['gdpPerCapita']
y_lm = df[df['incomeLevel']=='LMC']['lifeExpectancy']

## Upper-middle income
x_um = df[df['incomeLevel']=='UMC']['gdpPerCapita']
y_um = df[df['incomeLevel']=='UMC']['lifeExpectancy']

## High income
x_upper = df[df['incomeLevel']=='HIC']['gdpPerCapita']
y_upper = df[df['incomeLevel']=='HIC']['lifeExpectancy']

In [None]:
plt.scatter(x_low, y_low, label='LIC', s=10)
plt.scatter(x_lm, y_lm, label='LMC', s=10)
plt.scatter(x_um, y_um, label='UMC', s=10)
plt.scatter(x_upper, y_upper, label='HIC', s=10)
plt.title('Country GDP per capita and life expectancy')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Life expectancy in years')
plt.legend(title='Income level');

`matplotlib` lets you add multiple plot elements on top of each other. This is also how you group units into categories: one `plt.scatter()` call per group, each with its own `label` for the legend.

Alternatively, you can get to the same result using loops:

In [None]:
# We first separate x and y-axis values by income level

groups = ['LIC', 'LMC', 'UMC', 'HIC']
gdp_dict = {}
le_dict  = {}

for group in groups:
    
    gdp_dict[group] = df[df['incomeLevel']==group]['gdpPerCapita']
    le_dict[group]  = df[df['incomeLevel']==group]['lifeExpectancy']

In [None]:
for group in groups:
    plt.scatter(gdp_dict[group], le_dict[group], label=group, s=10)
plt.title('Country GDP per capita and life expectancy')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Life expectancy in years')
plt.legend(title='Income level');
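A variant of the same loop (one option among several) skips the dictionaries and pulls each group directly with `groupby()` and `get_group()`. Looping over the `groups` list rather than over the groupby object keeps the legend in income order, since `groupby()` sorts groups alphabetically by default. The table below is hypothetical toy data with the same column names as `df`.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy data with the same columns as df (placeholder values)
toy = pd.DataFrame({
    'incomeLevel': ['HIC', 'LIC', 'LMC', 'UMC', 'LIC'],
    'gdpPerCapita': [40_000, 600, 2_000, 8_000, 800],
    'lifeExpectancy': [80, 60, 67, 74, 62],
})

groups = ['LIC', 'LMC', 'UMC', 'HIC']  # the legend order we want
grouped = toy.groupby('incomeLevel')

fig, ax = plt.subplots()
for group in groups:
    sub = grouped.get_group(group)  # rows for this income level only
    ax.scatter(sub['gdpPerCapita'], sub['lifeExpectancy'], label=group, s=10)
ax.legend(title='Income level')
```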

## With `seaborn`

In [None]:
# No legend
sns.scatterplot(data=df, x='gdpPerCapita', y='lifeExpectancy')
plt.title('Country GDP per capita and life expectancy')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Life expectancy in years');

In [None]:
# With legend and groups
sns.scatterplot(data=df, x='gdpPerCapita', y='lifeExpectancy', hue='incomeLevel')
plt.title('Country GDP per capita and life expectancy')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Life expectancy in years')
plt.legend(title='Income level');

# Histograms

## `matplotlib`

In [None]:
# Basic histogram with default options
# x: GDP per capita
plt.hist(x);

In [None]:
plt.hist(x, bins=40) # bins sets the number of equal-size bins
plt.title('Histogram of country GDP per capita')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Number of countries');
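`plt.hist()` also returns the values it plots: the count in each bin, the bin edges, and the bar objects. With `bins=n`, the range from the minimum to the maximum value is split into `n` equal-width bins, so there are `n + 1` edges. A small check with placeholder values:

```python
import matplotlib.pyplot as plt

values = [0, 1, 2, 3, 4, 5, 6, 7]  # placeholder values

fig, ax = plt.subplots()
# counts: observations per bin; edges: bin boundaries; patches: the bars drawn
counts, edges, patches = ax.hist(values, bins=4)
print(counts)  # [2. 2. 2. 2.]
print(edges)
```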

## `seaborn`

In [None]:
# Default options
sns.histplot(data=df, x='gdpPerCapita');

In [None]:
# More customization
sns.histplot(data=df, x='gdpPerCapita', bins=40)
plt.title('Histogram of country GDP per capita')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Number of countries');

# Saving a plot

Both `matplotlib` and `seaborn` save figures with the same `matplotlib` function: `plt.savefig()`. Call it in the same code block as the plot, after all customizations have been added.

In [None]:
# Saving a matplotlib plot
plt.hist(x, bins=40) # bins sets the number of equal-size bins
plt.title('Histogram of country GDP per capita')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Number of countries')
plt.savefig('histogram_matplotlib.png');

In [None]:
# Saving a seaborn plot
sns.histplot(data=df, x='gdpPerCapita', bins=40)
plt.title('Histogram of country GDP per capita')
plt.xlabel('GDP per capita (constant 2015 USD)')
plt.ylabel('Number of countries')
plt.savefig('histogram_seaborn.png');
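`savefig()` also takes options that control the output. A minimal sketch using the object-oriented form `fig.savefig()` with two commonly used parameters: `dpi` (resolution in dots per inch) and `bbox_inches='tight'` (trims surrounding white space). The data and file name are placeholders, and the file format is inferred from the `.png` extension.

```python
import os
import tempfile

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.hist([1, 2, 2, 3, 3, 3], bins=3)  # placeholder values
ax.set_title('Placeholder histogram')

# Save into a temporary folder so the sketch does not clutter the working directory
out_path = os.path.join(tempfile.mkdtemp(), 'histogram.png')
# dpi sets the resolution; bbox_inches='tight' trims extra white space around the plot
fig.savefig(out_path, dpi=150, bbox_inches='tight')
plt.close(fig)  # release the figure once it is saved (it will not be displayed inline)

print(os.path.exists(out_path))  # True
```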

**Final note only if you're working on Colab:** Remember to go to `File` > `Save a copy in Drive` to save a copy of this notebook in your Google account.