# Statistical plots with Seaborn
This is a high-level tour of the Seaborn plotting library for producing statistical graphics in Python. The tour covers Seaborn tools for computing and visualizing linear regressions as well as tools for visualizing univariate distributions (e.g., strip, swarm, and violin plots) and multivariate distributions (e.g., joint plots, pair plots, and heatmaps). This also includes a discussion of grouping categories in plots.



* Visualizing Regression   
   * [Simple Linear Regression](#reg)
   * [Plotting Residuals of Regression](#res)
   * [Higher Order regression](#high)
   * [Grouping linear regression by hue](#hue)
   * [Grouping linear regressions by row or column](#row)
   
* [Visualizing Univariate Data](#uni)
   * [Constructing strip plots](#strip)
   * [Constructing swarm plots](#swarm)
   * [Constructing violin plots](#violin)
* [Visualizing Multivariate Distributions](#multi)  
  * [Plotting joint distributions](#j1)
  * [Plotting distributions pairwise](#pair)
  * [Visualizing correlations with a heatmap](#heat)



<a id = 'reg'> <a>
### Simple Linear regression

In [2]:
# Draw plots in a separate, interactively updated window

%matplotlib auto
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
sns.set(style='darkgrid')

Using matplotlib backend: MacOSX


In [3]:
auto = pd.read_csv('auto-mpg.csv')

In [4]:
auto.head()

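|   | mpg | cyl | displ | hp | weight | accel | yr | origin | name | color | size | marker |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 18.0 | 6 | 250.0 | 88 | 3139 | 14.5 | 71 | US | ford mustang | red | 27.370336 | o |
| 1 | 9.0 | 8 | 304.0 | 193 | 4732 | 18.5 | 70 | US | hi 1200d | green | 62.199511 | o |
| 2 | 36.1 | 4 | 91.0 | 60 | 1800 | 16.4 | 78 | Asia | honda civic cvcc | blue | 9.0 | x |
| 3 | 18.5 | 6 | 250.0 | 98 | 3525 | 19.0 | 77 | US | ford granada | red | 34.515625 | o |
| 4 | 34.3 | 4 | 97.0 | 78 | 2188 | 15.8 | 80 | Europe | audi 4000 | blue | 13.298178 | s |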


In [5]:
# Plot a linear regression between 'weight' and 'hp'
sns.lmplot(x ='weight', y='hp', data=auto)


<seaborn.axisgrid.FacetGrid at 0x11391ed68>

<a id = 'res'> <a>
### Plotting residuals of a regression

Often you want to see not only the regression itself but also its residuals, to get a better idea of how well the regression captures the data. Seaborn provides `sns.residplot()` for that purpose: it plots how far each data point lies above or below the regression line.
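To make the idea of a residual concrete, the following sketch computes residuals by hand with a first-order least-squares fit and then draws the corresponding residual plot. The DataFrame `demo` and its values are synthetic stand-ins for the `hp` and `mpg` columns (an assumption for illustration, not the contents of `auto-mpg.csv`).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for the 'hp' and 'mpg' columns (illustrative values only)
rng = np.random.default_rng(0)
hp = rng.uniform(50, 200, size=100)
mpg = 40 - 0.12 * hp + rng.normal(0, 2, size=100)
demo = pd.DataFrame({"hp": hp, "mpg": mpg})

# Fit a straight line and compute the residuals: observed minus fitted values
slope, intercept = np.polyfit(demo["hp"], demo["mpg"], deg=1)
residuals = demo["mpg"] - (slope * demo["hp"] + intercept)

# residplot fits the regression itself and scatters the residuals against x
ax = sns.residplot(x="hp", y="mpg", data=demo, color="green")
plt.close("all")
```

If the residuals scatter evenly around zero, a straight line is probably adequate; a visible curve in the residuals suggests a higher-order fit.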



In [7]:
plt.close()
# Generate a green residual plot of the regression between 'hp' and 'mpg'
sns.residplot(x='hp', y='mpg', data=auto, color='green')

# Display the plot
plt.show()


<a id = 'high'><a>
### Higher Order regression

When the relationship between two variables is more complex, a simple first-order regression is often not sufficient to capture it accurately. Seaborn makes it simple to compute and visualize regressions of higher order through the `order` argument of `sns.regplot()`.
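As a rough illustration of why the order matters, the sketch below fits first- and second-order polynomials to a curved relationship and compares their sums of squared errors. The data are synthetic (an assumption, chosen only to resemble a `weight` versus `mpg` trend), and the helper `sse` is a hypothetical name introduced for this example.

```python
import numpy as np

# Synthetic, deliberately curved relationship (illustrative values only)
rng = np.random.default_rng(1)
weight = rng.uniform(1600, 5000, size=200)
mpg = 60 - 0.02 * weight + 2e-6 * weight**2 + rng.normal(0, 1.5, size=200)

def sse(order):
    """Sum of squared errors of a least-squares polynomial fit of the given order."""
    coeffs = np.polyfit(weight, mpg, deg=order)
    fitted = np.polyval(coeffs, weight)
    return float(np.sum((mpg - fitted) ** 2))

sse_order1 = sse(1)  # straight line
sse_order2 = sse(2)  # parabola

print(f"order 1 SSE: {sse_order1:.1f}")
print(f"order 2 SSE: {sse_order2:.1f}")
```

A lower error alone does not justify a higher order, since a more flexible polynomial never fits the training data worse; plotting both curves over the scatter, as in the cells that follow, is a useful sanity check.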



In [8]:
plt.close()
# Generate a scatter plot of 'weight' and 'mpg' using red circles
plt.scatter(auto.weight, auto.mpg, label = 'data', color = 'red', marker = 'o')

<matplotlib.collections.PathCollection at 0x1a23d70cf8>

In [9]:
# Overlay a first-order fit; scatter=False draws only the regression line
sns.regplot(x='weight', y='mpg', data=auto, scatter=False, color='blue', label='order 1')

<matplotlib.axes._subplots.AxesSubplot at 0x1a2338f278>

In [10]:
# Overlay a second-order fit on the same axes
sns.regplot(x='weight', y='mpg', data=auto, scatter=False, order=2, color='green', label='order 2')

<matplotlib.axes._subplots.AxesSubplot at 0x1a2338f278>

In [11]:
plt.legend(loc= 'upper right')

<matplotlib.legend.Legend at 0x1a23db96d8>

<a id = 'hue'> <a>
### Grouping linear regression by hue


Often it is useful to compare and contrast trends between different groups. Seaborn makes it possible to fit **linear regressions separately for subsets of the data**, much like a groupby operation. With the `hue` argument, you specify a categorical variable by which to group the observations. Each group of points then gets its own regression, drawn in its own hue.
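The grouping behind `hue` can be sketched with an explicit pandas groupby: fit one line per group, then let `sns.lmplot()` draw the same per-group fits. The DataFrame `demo` is synthetic, and the per-origin slopes are made-up values used only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic data with a different weight-to-hp slope per origin (illustrative only)
rng = np.random.default_rng(2)
frames = []
for origin, true_slope in [("US", 0.05), ("Europe", 0.03), ("Asia", 0.02)]:
    weight = rng.uniform(1600, 5000, size=60)
    hp = true_slope * weight + rng.normal(0, 5, size=60)
    frames.append(pd.DataFrame({"weight": weight, "hp": hp, "origin": origin}))
demo = pd.concat(frames, ignore_index=True)

# One first-order fit per group, computed by hand
slopes = {
    origin: np.polyfit(group["weight"], group["hp"], deg=1)[0]
    for origin, group in demo.groupby("origin")
}

# lmplot performs the equivalent per-group fits and colors them by hue
grid = sns.lmplot(x="weight", y="hp", data=demo, hue="origin")
plt.close("all")
```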



In [12]:
plt.close()


In [19]:
plt.close()
# Plot a linear regression between 'weight' and 'hp', with a hue of 'origin' and palette of 'Set1'
sns.lmplot(x='weight', y='hp', data=auto, hue='origin', palette='Set1')
plt.show()


<a id = 'row'> <a>
### Grouping linear regressions by row or column



In [20]:
plt.close()
# Plot linear regressions between 'weight' and 'hp', one panel per 'origin' laid out column-wise
# (pass row='origin' instead of col='origin' to stack the panels row-wise)
sns.lmplot(x='weight', y='hp', data=auto, col='origin')

<seaborn.axisgrid.FacetGrid at 0x1a2494fd30>

<a id = 'uni'> <a>
## Visualizing Univariate Data

<a id = 'strip'><a> 
### Constructing strip plots

A strip plot is one way of visualizing a univariate distribution split by category. It draws the values of a variable for each category as individual data points. For vertical strip plots (the default), the continuous values are laid out parallel to the y-axis and the distinct categories are spaced out along the x-axis.

* For example, `sns.stripplot(x='type', y='length', data=df)` produces a sequence of *vertical* strip plots of **length** distributions grouped by **type** (assuming length is a continuous column and type is a categorical column of the DataFrame df). 
* Overlapping points can be difficult to distinguish in strip plots. The argument `jitter=True` helps spread out overlapping points.
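The call described in the bullets above can be tried on a small synthetic DataFrame. The name `df` and the columns `type` and `length` follow the example in the text; the values themselves are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic data: a categorical 'type' column and a 'length' column
# rounded to whole numbers so that many points coincide (illustrative only)
rng = np.random.default_rng(3)
df = pd.DataFrame({
    "type": rng.choice(["a", "b", "c"], size=150),
    "length": rng.normal(10, 2, size=150).round(0),
})

# jitter=True spreads coincident points sideways; a small size reduces overlap further
ax = sns.stripplot(x="type", y="length", data=df, jitter=True, size=3)
plt.close("all")
```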



In [21]:
plt.close()
plt.subplot(1, 2, 1)
sns.stripplot(x='cyl', y='hp', data=auto)
plt.xlabel('cyl')
plt.ylabel('hp')

Text(0,0.5,'hp')

In [22]:
# Make the strip plot again using jitter and a smaller point size
plt.subplot(1,2,2)
sns.stripplot(x='cyl', y='hp', data=auto, jitter=True, size=3)
plt.xlabel('cyl')
plt.ylabel('hp')

Text(0,0.5,'hp')

<a id = 'swarm'><a> 
### Constructing swarm plots

An alternative is the swarm plot (`sns.swarmplot()`), which is very similar to the strip plot but **spreads out the points to avoid overlap** and so provides a *better visual overview of the data*.

* Another grouping can be added in using the `hue` keyword. For instance, using `sns.swarmplot(x='type', y='length', data=df, hue='build year')` makes a swarm plot from the DataFrame df with the 'length' column values spread out vertically, horizontally grouped by the column 'type' and each point colored by the categorical column 'build year'.
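A minimal sketch of the `hue` grouping described above, using a synthetic DataFrame `df` with the column names from the text (`type`, `length`, `build year`). All values are invented for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic data: two categorical columns and one continuous column (illustrative only)
rng = np.random.default_rng(4)
df = pd.DataFrame({
    "type": rng.choice(["a", "b", "c"], size=90),
    "length": rng.normal(10, 2, size=90),
    "build year": rng.choice(["2001", "2002"], size=90),
})

# Points are grouped along x by 'type' and colored by 'build year'
ax = sns.swarmplot(x="type", y="length", data=df, hue="build year")
plt.close("all")
```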


In [23]:
plt.close()
# Generate a swarm plot of 'hp' grouped horizontally by 'cyl'
plt.subplot(2, 1, 1)
sns.swarmplot(x = 'cyl', y = 'hp', data = auto)

<matplotlib.axes._subplots.AxesSubplot at 0x1a248ebe10>

In [24]:
plt.subplot(2, 1, 2)
sns.swarmplot(x = 'hp', y = 'cyl', data = auto, hue = 'origin', orient = 'h')

<matplotlib.axes._subplots.AxesSubplot at 0x1a2451d240>

<a id = 'violin'><a> 
### Constructing Violin plots

Both strip and swarm plots visualize all the data points. For large datasets, this can result in *significant overplotting.* It is therefore often useful to use plot types that reduce a dataset to descriptive statistics and provide a good summary of the data. Box-and-whisker plots are a classic way of summarizing univariate distributions, but Seaborn also provides a more sophisticated extension of the standard box plot, called a violin plot, which additionally shows the shape of each distribution.
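To compare the two summaries side by side, the sketch below computes the quartiles that a box plot encodes and then draws a box plot and a violin plot of the same data. The DataFrame `demo` is a synthetic stand-in for the `cyl` and `hp` columns (an assumption, not the real dataset).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for 'cyl' and 'hp' (illustrative values only)
rng = np.random.default_rng(5)
cyl = rng.choice([4, 6, 8], size=300)
hp = 20 * cyl + rng.normal(0, 15, size=300)
demo = pd.DataFrame({"cyl": cyl, "hp": hp})

# The quartiles per category are the statistics a box plot is built from
summary = demo.groupby("cyl")["hp"].quantile([0.25, 0.5, 0.75]).unstack()

# Same data, two summaries: quartile boxes versus smoothed density shapes
fig, (ax_box, ax_violin) = plt.subplots(1, 2, sharey=True)
sns.boxplot(x="cyl", y="hp", data=demo, ax=ax_box)
sns.violinplot(x="cyl", y="hp", data=demo, ax=ax_violin)
plt.close("all")
```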



In [25]:
plt.close()
plt.subplot(2,1,1)
sns.violinplot(x = 'cyl', y = 'hp', data = auto)
plt.xlabel('cyl')
plt.ylabel('hp')

Text(0,0.5,'hp')

In [26]:
plt.subplot(2,1,2)
sns.violinplot(x = 'cyl', y = 'hp', data = auto, color = 'lightgray', inner = None)



<matplotlib.axes._subplots.AxesSubplot at 0x1a23350a90>

In [27]:
# Overlay a strip plot on the violin plot

sns.stripplot(x = 'cyl', y = 'hp', data = auto, jitter = True, size = 1.5)

<matplotlib.axes._subplots.AxesSubplot at 0x1a23350a90>

<a id = 'multi'> <a>
## Visualizing Multivariate Distributions

<a id = 'j1'> <a>
### Plotting Joint Distributions

Seaborn's `sns.jointplot()` provides a means of visualizing bivariate distributions. The basic calling syntax is similar to that of `sns.lmplot()`. By default, calling `sns.jointplot(x=x, y=y, data=data)` renders a few things:
* A scatter plot using the specified columns x and y from the DataFrame data.
* A (univariate) histogram along the top of the scatter plot showing distribution of the column x.
* A (univariate) histogram along the right of the scatter plot showing distribution of the column y.

In [28]:
plt.close()
# Generate a joint plot of 'hp' and 'mpg'
sns.jointplot(x = 'hp', y= 'mpg',  data = auto )

# Display the plot
plt.show()




The Seaborn function `sns.jointplot()` has a parameter `kind` that specifies how to visualize the joint variation of two continuous random variables (i.e., two columns of a DataFrame):

* `kind='scatter'` uses a scatter plot of the data points
* `kind='reg'` uses a regression plot (default order 1)
* `kind='resid'` uses a residual plot
* `kind='kde'` uses a kernel density estimate of the joint distribution
* `kind='hex'` uses a hexbin plot of the joint distribution

In [30]:
# Generate a joint plot of 'hp' and 'mpg' using a hexbin plot
sns.jointplot(x = 'hp', y= 'mpg',  data = auto, kind = 'hex' )
# Display the plot
plt.show()




<a id = 'pair'> <a>
### Plotting distributions pairwise 

The function `sns.pairplot()` constructs a grid of all joint plots pairwise from all pairs of (non-categorical) columns in a DataFrame. The syntax is very simple: `sns.pairplot(df)`, where df is a DataFrame. The non-categorical columns are identified and the corresponding joint plots are plotted in a square grid of subplots. The diagonal of the subplot grid shows the univariate histograms of the individual columns.




In [34]:
# Plot the pairwise joint distributions from the DataFrame
sns.pairplot(auto)
# Save the figure before showing it; saving after plt.show() can write an empty image
plt.savefig('insanes.jpg')
# Display the plot
plt.show()


Plot the pairwise joint distributions separated by continent of origin and display the regressions.

In [37]:
sns.pairplot(auto, hue = 'origin', kind = 'reg')
plt.savefig('pairplot2.png')

<a id = 'heat'> <a>
### Visualizing correlations with a heatmap



In [38]:
auto.describe()

|   | mpg | cyl | displ | hp | weight | accel | yr | size |
|---|---|---|---|---|---|---|---|---|
| count | 392.0 | 392.0 | 392.0 | 392.0 | 392.0 | 392.0 | 392.0 | 392.0 |
| mean | 23.445918 | 5.471939 | 194.41199 | 104.469388 | 2977.584184 | 15.541327 | 75.979592 | 26.62681 |
| std | 7.805007 | 1.705783 | 104.644004 | 38.49116 | 849.40256 | 2.758864 | 3.683737 | 15.20831 |
| min | 9.0 | 3.0 | 68.0 | 46.0 | 1613.0 | 8.0 | 70.0 | 7.227136 |
| 25% | 17.0 | 4.0 | 105.0 | 75.0 | 2225.25 | 13.775 | 73.0 | 13.754831 |
| 50% | 22.75 | 4.0 | 151.0 | 93.5 | 2803.5 | 15.5 | 76.0 | 21.83229 |
| 75% | 29.0 | 8.0 | 275.75 | 126.0 | 3614.75 | 17.025 | 79.0 | 36.29563 |
| max | 46.6 | 8.0 | 455.0 | 230.0 | 5140.0 | 24.8 | 82.0 | 73.387778 |
