<a href="https://colab.research.google.com/github/sakeefkarim/python_plotting/blob/main/code/Plotting_in_Python.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Plotting in Python
### An Introduction to `seaborn` and `plotnine`

This notebook provides a _high-level_ overview of how to generate basic charts—from scatterplots to heatmaps—using the [`seaborn`](https://seaborn.pydata.org/) and [`plotnine`](https://plotnine.readthedocs.io/en/stable/) libraries in Python. Along the way, we’ll be using methods from [`pandas`](https://pandas.pydata.org/), [`matplotlib`](https://matplotlib.org/) and cognate libraries to modify our data, customize our plotting aesthetics and export our visualizations.


## Preliminaries

As of December 2022, we'll need to upgrade `seaborn` on Google Colab to use the `seaborn.objects` interface. After upgrading the library, ***restart your runtime session***. You can then proceed with the rest of the notebook.

In [None]:
!pip install seaborn --upgrade

Let's load our *essential* packages:

In [None]:
import scipy as sp
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import seaborn.objects as so

To [reiterate](https://colab.research.google.com/drive/17SWEs0aRX70KVKBbpjeBo8_Lyy5d8Elc?usp=sharing): we can _mount_ our Google Drive folders onto a Colab session to save plots, data sets and so on. To programmatically mount your Drive folder(s), run the following lines:

In [None]:
from google.colab import drive
drive.mount('/drive')

For this session, we'll be ***mostly*** playing around with data from [`palmerpenguins`](https://allisonhorst.github.io/palmerpenguins/) — a popular package for exploring, manipulating and visualizing data in `R`. We can easily import the `palmerpenguins` library [via Python](https://github.com/mcnakhaee/palmerpenguins) by either:

- _Installing_ a new library:

In [None]:
!pip install palmerpenguins

from palmerpenguins import load_penguins

penguins = load_penguins()

penguins.head()

- Loading the data set via `seaborn`:

In [None]:
sns.load_dataset('penguins') 

# Note: the penguins dataframe in seaborn does *not* include the year variable!

Here's an example of how we can use another package (i.e., [`pyreadr`](https://pypi.org/project/pyreadr/)) to load data from `palmerpenguins` as an `.rds` file:

In [None]:
!pip install pyreadr

import pyreadr

url = 'https://github.com/sakeefkarim/python_plotting/blob/main/files/palmerpenguins.rds?raw=true'

destination = '/drive/My Drive/Python/palmerpenguins.rds'

pyreadr.download_file(url, destination)

penguins_git = pyreadr.read_r(destination)

# Checking to see which objects are available:

print(penguins_git.keys()) 

# Only one object, stored under the key None, ergo:

penguins_rds = penguins_git[None]

Of course, `pandas` contains built-in functions to import `.dta` (or `.sav`, `.sas7bdat`, *etc.*) files:

In [None]:
penguins_dta = pd.read_stata('https://github.com/sakeefkarim/python_plotting/blob/main/files/palmerpenguins.dta?raw=true')

penguins_dta
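
To see the `.dta` reader at work without a network connection, here is a minimal sketch that writes a toy data frame to a temporary Stata file and reads it back. The frame and the file name `toy.dta` are made up for illustration. As far as I know, `pd.read_spss` (for `.sav` files) additionally depends on the `pyreadstat` package, so it is left out of the sketch.

```python
import os
import tempfile

import pandas as pd

# Hypothetical toy data, standing in for the penguins data frame.
toy = pd.DataFrame({
    "species": ["Adelie", "Gentoo"],
    "body_mass_kg": [3.75, 5.0],
})

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy.dta")
    # write_index=False keeps the row index out of the Stata file.
    toy.to_stata(path, write_index=False)
    toy_back = pd.read_stata(path)

print(toy_back)
```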

# `seaborn`

## Exploratory Visualizations

The `pairplot` function allows users to easily visualize pairwise associations between *all* the numeric variables in a data frame. This can be especially helpful for conducting [exploratory data analyses](https://www.datacamp.com/community/tutorials/exploratory-data-analysis-python).

In [None]:
# This unlocks seaborn's basic 'dark grid' theme:

sns.set_theme()

# Other seaborn themes: http://seaborn.pydata.org/tutorial/aesthetics.html#seaborn-figure-styles

sns.pairplot(penguins)

# The 'hue' parameter (for most seaborn functions) allows analysts to condition on a
# variable of interest:

# sns.pairplot(penguins, hue = 'island')

sns.pairplot(penguins, hue = 'species')

## Basic Scatterplots

As the cell below illustrates, we can use `matplotlib` functions in conjunction with `seaborn` to modify and export our visualizations.

In [None]:
mpl.style.use('fivethirtyeight')

# Basic scatterplot

p1 = sns.scatterplot(x = 'bill_depth_mm', y = 'bill_length_mm', 
                     hue = 'species', 
                     # For more seaborn palettes, see https://seaborn.pydata.org/tutorial/color_palettes.html.
                     palette = 'pastel', 
                     # Size parameter:
                     s = 60, 
                     # Transparency parameter:
                     alpha = 0.8,
                     data = penguins)

p1.set_title('Bill Depth vs Bill Length for Different Penguin Species', 
             weight = 'bold')

p1.set(xlabel='Bill Depth (mm)', 
       ylabel='Bill Length (mm)')

plt.legend(loc='upper right', 
           # Two elements correspond to x and y coordinates:
           bbox_to_anchor= (1.25, 1),
           # Remove legend frame:
           frameon=False, 
           borderaxespad=0)

plt.savefig('/drive/My Drive/Python/penguin_fig1.png', 
            dpi = 300, 
            # Making sure the image isn't cropped!
            bbox_inches='tight')

plt.show()
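
The interplay between the two libraries works because axes-level `seaborn` functions return an ordinary `matplotlib` `Axes` object. Here is a minimal sketch on made-up data, saving to an in-memory buffer instead of Google Drive:

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line in Colab
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data.
toy = pd.DataFrame({
    "x": [1, 2, 3, 4],
    "y": [2.0, 3.5, 3.0, 5.0],
    "group": ["a", "a", "b", "b"],
})

fig, ax_in = plt.subplots()

# seaborn draws onto the Axes we pass in and returns that same Axes.
ax = sns.scatterplot(data=toy, x="x", y="y", hue="group", ax=ax_in)

# From here on, plain matplotlib methods apply.
ax.set_title("Toy scatterplot", weight="bold")
ax.set(xlabel="X", ylabel="Y")

# Save to a buffer; a file path works the same way.
buffer = io.BytesIO()
fig.savefig(buffer, format="png", dpi=100, bbox_inches="tight")
png_bytes = buffer.getvalue()
```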

**Optional**: Below, you'll find supplemental code for modifying scales and formatting axis labels. To run the code, uncomment the lines in the cell by removing the `#` sign.

In [None]:
# import matplotlib.pylab as pylab

# params = {'legend.fontsize': 'medium',
#           'axes.labelsize': 'large',
#           'axes.titlesize':'large',
#           'xtick.labelsize':'medium',
#           'ytick.labelsize':'medium'}

# pylab.rcParams.update(params)

# p1b = sns.scatterplot(x = 'bill_depth_mm', y = 'bill_length_mm', hue = 'island', data = penguins)

# p1b.set_xscale('log')

### Adding Regression Lines to Scatterplots

The `lmplot` function can be used to visualize linear relationships between variables. It can also map _interactions_ or spotlight conditional relationships (i.e., heterogeneous treatment effects) using faceted grids. 

In [None]:
sns.set_theme()

# Capitalizing sex in the data frame:

penguins.sex = penguins.sex.str.capitalize()

# Setting up the plot:

p2 = sns.lmplot(x='bill_depth_mm', y='bill_length_mm', 
                hue = 'species',
                # Unlocks facets/conditional panels:
                col = 'sex',
                # Equivalent to scales = free_x in ggplot:
                facet_kws=dict(sharex=False),
                data=penguins)

# Removing the legend automatically generated via lmplot 
# (so we can use mpl functions to manipulate our legend instead):

p2._legend.remove()

# Removing the 'variable name =' text from facet panels:

p2.set_titles('{col_name}')

p2.set_axis_labels(x_var='Bill Depth (mm)', 
                   y_var='Bill Length (mm)')

plt.legend(title='Species')

plt.show()
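
Unlike `scatterplot`, `lmplot` is a figure-level function that returns a `FacetGrid`. One possible alternative to removing the automatic legend after the fact is to pass `legend=False` up front. The sketch below assumes a made-up data frame and turns off confidence bands (`ci=None`) to keep it quick.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line in Colab

import pandas as pd
import seaborn as sns

# Hypothetical toy data: three points per group within each panel.
toy = pd.DataFrame({
    "x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
    "y": [1.0, 2.1, 2.9, 3.2, 2.0, 1.1, 2.0, 2.5, 3.4, 4.0, 3.1, 2.2],
    "group": ["a"] * 3 + ["b"] * 3 + ["a"] * 3 + ["b"] * 3,
    "panel": ["left"] * 6 + ["right"] * 6,
})

g = sns.lmplot(data=toy, x="x", y="y",
               hue="group", col="panel",
               ci=None,        # skip bootstrapped confidence bands
               legend=False,   # suppress the automatic figure legend
               facet_kws=dict(sharex=False))

g.set_titles("{col_name}")

# Draw a regular matplotlib legend on the last panel instead.
g.axes.flat[-1].legend(title="Group")
```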

## Barplots

In the example below, we'll generate a grouped horizontal barplot that maps variation in penguins' body mass _across_ species and shows how body mass differs by sex within each species.

In [None]:
sns.set(style='ticks', palette='Set2')

# Generating array corresponding to the desired (alphabetical) species order:

order_plot = penguins['species'].sort_values().unique()

p3 = sns.barplot(x='body_mass_g', y='species', 
                 # To create a grouped bar plot:
                 hue = 'sex',
                 # Removing confidence intervals ...
                 errorbar = None,
                 alpha = 0.9,
                 hue_order = ['Female', 'Male'],
                 order = order_plot,
                 data=penguins)

p3.set(xlabel='Body Mass (g)', ylabel='')

plt.legend(loc='lower right', 
           bbox_to_anchor= (1.25, 0), 
           borderaxespad=0)

## Boxplots

Boxplots are a wonderful way to visualize distributions. Just as with barplots, manipulating your `x` and `y` arguments can help you iterate between vertical and horizontal (boxplot) representations of the same underlying data.


In [None]:
mpl.style.use('ggplot')

# Changing our body mass measure to kg in lieu of g:

penguins['body_mass_kg'] = penguins.body_mass_g/1000

p4 = sns.boxplot(y='body_mass_kg', x='species', data=penguins,
                 width=0.5, 
                 hue = 'sex',
                 palette='Set2',
                 order = order_plot)

p4.set(xlabel='', ylabel='Body Mass (kg)')

plt.legend(title = '')
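
The orientation switch mentioned above amounts to swapping the `x` and `y` arguments. Here is a small sketch on made-up data, drawing both versions side by side:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line in Colab
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data.
toy = pd.DataFrame({
    "species": ["A"] * 4 + ["B"] * 4,
    "mass": [3.0, 3.4, 3.9, 4.1, 4.8, 5.2, 5.5, 6.0],
})

fig, (ax_v, ax_h) = plt.subplots(1, 2, figsize=(8, 3))

# Categorical variable on x: vertical boxes.
sns.boxplot(data=toy, x="species", y="mass", ax=ax_v)

# Categorical variable on y: horizontal boxes.
sns.boxplot(data=toy, x="mass", y="species", ax=ax_h)
```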

## Other Ways to Plot Distributions

### Violin Plots

In [None]:
sns.set_theme()

sns.violinplot(x='body_mass_kg', y='species', hue='sex',
               split = True,
               data=penguins)

### Histograms

In [None]:
sns.set_theme(style = 'whitegrid')

sns.histplot(x='body_mass_kg',
             hue = 'species',
             hue_order = order_plot,
             multiple = 'dodge',
             linewidth=.05, data = penguins)

sns.displot(x= 'body_mass_kg', 
            hue='species',
            kind='hist', 
            multiple='layer',
            palette = 'pastel',
            data = penguins)

### Joint Plots

In [None]:
mpl.style.use('fivethirtyeight')

sns.jointplot(x = 'body_mass_kg', y = 'bill_length_mm', 
              hue = 'species', 
              palette = 'pastel',
              s = 80,
              data = penguins)

### Density Plots

In [None]:
sns.set_theme()

sns.displot(x= 'body_mass_kg', 
            hue='species',
            # Kernel density estimate:
            kind='kde', 
            alpha = 0.5,
            multiple = 'fill',
            palette = 'pastel',
            data = penguins)

## Visualizing Time-Varying Data

While our `penguins` data frame _does_ have time-varying data, we're limited to three time points:


In [None]:
# Share of observations by 'year':

penguins['year'].value_counts(normalize=True)

As a result, we'll *briefly* leverage another data frame that's often used to introduce students to data visualizations in `R`: [`gapminder`](https://www.rdocumentation.org/packages/gapminder/versions/0.3.0)!

In [None]:
gapminder = pd.read_excel('https://github.com/sakeefkarim/python_plotting/blob/main/files/gapminder.xlsx?raw=true')

### Line Plots

Let's explore how life expectancy (i.e., `lifeExp`) has evolved over time across continents:

In [None]:
# Removing Oceania (few observations):

gapminder['continent'].value_counts(normalize=True)

gapminder = gapminder[gapminder.continent != 'Oceania']

# Ensuring that continents appear in alphabetical order:

continent_order = gapminder['continent'].sort_values().unique()

# Grouped line plot:

sns.lineplot(x='year', y='lifeExp', 
             hue = 'continent',
             hue_order = continent_order,
             data = gapminder)

plt.xlabel('Year', fontsize=13)

plt.ylabel('Life Expectancy', fontsize=13)


# Facets:

p6 = sns.relplot(x='year', y='lifeExp', 
                 col='continent', 
                 col_order = continent_order,
                 col_wrap = 2,
                 kind='line', 
                 # A single colour suffices, as no hue variable is mapped:
                 color = '#AA336A',
                 data = gapminder)

p6.set_titles('{col_name}')

### Heatmaps

To generate a heatmap in `seaborn`, let's modify our input data frame by:

+ Isolating countries in the Americas.
+ Isolating the following variables: `country`, `year` and `lifeExp`.
+ Reshaping our data from long to wide using the `pivot` method.

In [None]:
# Isolating countries in the Americas:

gapminder_adj = gapminder[ gapminder['continent'] == 'Americas']

# Zeroing-in on variables of interest:

gapminder_adj = gapminder_adj[['country', 'year', 'lifeExp']]

# Pivoting to wide format

gapminder_adj = gapminder_adj.pivot(index = 'country', columns = 'year', values = 'lifeExp')

gapminder_adj.head()
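
To make the long-to-wide reshaping concrete, here is the same `pivot` call on a tiny made-up frame (the country labels `X` and `Y` are placeholders):

```python
import pandas as pd

# Hypothetical long-format data: one row per country-year.
long_df = pd.DataFrame({
    "country": ["X", "X", "Y", "Y"],
    "year": [1952, 2007, 1952, 2007],
    "lifeExp": [40.0, 70.0, 50.0, 80.0],
})

# Wide format: one row per country, one column per year.
wide_df = long_df.pivot(index="country", columns="year", values="lifeExp")

print(wide_df)
```

Each cell of `wide_df` maps onto one tile of a heatmap, which is why `heatmap` expects data in this shape.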

With these modifications in place, let's generate a quick heatmap that captures how life expectancy in the Americas has changed over time:

In [None]:
mpl.style.use('fivethirtyeight')

p5 = sns.heatmap(gapminder_adj, 
                 linewidths =0.5, 
                 # Includes all y-tick labels:
                 yticklabels=True,
                 square=True)

# Rotates x-axis labels (to enhance legibility):

plt.xticks(rotation=30) 

p5.set(xlabel='', ylabel='')

p5.set_title('Life Expectancy in the Americas', 
             size = 18,
             weight = 'bold')

# Changing the plot's dimensions:

plt.gcf().set_size_inches(15, 10)

plt.savefig('/drive/My Drive/Python/new_heatmap.png', bbox_inches='tight', dpi = 300)

## `seaborn`'s Object Interface

In September of 2022, `seaborn`'s developers introduced a new [`seaborn.objects`](https://seaborn.pydata.org/tutorial/objects_interface.html) interface to bring the *grammar of graphics* into the `seaborn` ecosystem. However, `seaborn.objects` is very much in its infancy.

In the sections below, we'll quickly touch on how to generate scatterplots, barplots and line plots using this new interface. As you work your way down the next few sections of this notebook, make sure to uncomment lines of code to add _layers_ to your plot objects.

### Scatterplots

In [None]:
(
    so.Plot(penguins, x = 'body_mass_kg', y = 'bill_length_mm', color = 'species')
    #.add(so.Dot(alpha = 0.4, pointsize=7))
    #.label(x='Body Mass (kg)', y='Bill Length (mm)', color='')
    #.scale(color=so.Nominal(order = ['Adelie', 'Chinstrap', 'Gentoo']))
    #.facet('sex')
    #.share(x=True, y = True)
    #.add(so.Line(linewidth=3.5, alpha =0.8), so.PolyFit())
    #.layout(size=(10, 5)) #width, height
    #.save('/drive/My Drive/Python/penguin_fig2.png', bbox_inches='tight', dpi = 300)
)

### Barplots

In [None]:
(
     so.Plot(penguins, 'year', color='species')
    #.scale(x=so.Nominal(), color=so.Nominal(order = order_plot))
    #.add(so.Bar(), so.Count(), so.Dodge(gap = 0.05))
    #.label(x='', color=str.capitalize)
    #.layout(size=(7, 5)) 
)

### Line Plots

In [None]:
from seaborn import axes_style

(
    so.Plot(gapminder, x='year', y='lifeExp')
    #.facet('continent', wrap = 2, order = continent_order)
    #.add(so.Line(alpha=.35, color = 'grey', linestyle = '--'), so.Agg('median'), group='continent', col=None)
    #.add(so.Line(color = 'black', linewidth = 3), so.Agg('median'))
    #.label(y = 'Life Expectancy', x = str.capitalize)
    #.theme({**axes_style('ticks')})
)

# `plotnine`

The [`plotnine`](https://plotnine.readthedocs.io/en/stable/) library allows Python users to (essentially) use `ggplot2` within Python. In lieu of providing exhaustive examples,<a name='cite_ref-1'></a>[<sup>[1]</sup>](#cite_note-1) I'll once again walk through some basic plotting options: i.e., scatterplots, barplots and line plots.

## Scatterplots

In [None]:
from plotnine import *

gg_p1 = (
          ggplot(penguins, aes(x = 'body_mass_kg', y = 'bill_length_mm', colour = 'species')) + 
          geom_point(size = 4, alpha = 0.5) +
          theme_minimal() +
          theme(legend_position = 'top',
                axis_title = element_text(size = 12)) +
          labs(x = 'Body Mass (kg)', y = 'Bill Length (mm)', colour = '') +
          scale_colour_brewer(type = 'qual', palette = 'Dark2')
          )

ggsave(gg_p1, '/drive/My Drive/Python/penguin_fig3.png', dpi = 300)

# Adding facets

gg_p1 + facet_wrap('~year')

## Barplots

In [None]:
penguins.dropna(inplace=True)

(
ggplot(penguins, aes(x = 'species', y = 'body_mass_kg', fill = 'sex')) + 
geom_col(position = 'dodge', width = 0.5) + 
labs(x = '', y = 'Body Mass (kg)', fill = '') +
coord_flip() +
theme_bw() + facet_wrap('~year') +
scale_fill_grey()
)

## Line Plots

In [None]:
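# numeric_only=True skips non-numeric columns (e.g., country), which
# would otherwise raise a TypeError in recent versions of pandas:

gapminder_grouped = gapminder.groupby(['continent', 'year']).median(numeric_only=True).reset_index()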

(
ggplot(gapminder_grouped, aes(x = 'year', y = 'lifeExp', colour = 'continent')) + 
 geom_line(size = 2) +
 theme_minimal() +
 facet_wrap('continent') +
 theme(legend_position = 'none') +
 labs(x = '', y = 'Life Expectancy') +
 scale_colour_brewer(type = 'qual', palette = 'Dark2')

)
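
The grouping step that feeds the plot above can be checked on a toy frame. This is a sketch with made-up values (the country labels are placeholders). `numeric_only=True` makes the median ignore the text column instead of failing on it.

```python
import pandas as pd

# Hypothetical toy data in the same layout as gapminder.
toy = pd.DataFrame({
    "continent": ["Asia", "Asia", "Asia", "Europe", "Europe"],
    "country": ["P", "Q", "R", "S", "T"],
    "year": [2007] * 5,
    "lifeExp": [60.0, 70.0, 80.0, 75.0, 79.0],
})

# One row per continent-year, holding the median of each numeric column.
grouped = (
    toy.groupby(["continent", "year"])
       .median(numeric_only=True)
       .reset_index()
)

print(grouped)
```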

<a name='cite_note-1'></a>1. [^](#cite_ref-1) You all know quite a bit about `ggplot2` already!