# Section 1: Introduction to NumPy

## Importing Modules
If vanilla Python seems rather lackluster, that's because it is. Fortunately, the scientific stack adds a broad and powerful array of Python packages that fill in the gaps. Once installed, packages are easily loaded with the `import` statement.

In [1]:
import numpy
print(numpy.__version__)

1.16.3


Functions in a package are accessed like attributes of an object (e.g. `numpy.arange`). For convenience, we will import packages using a shorthand alias.

In [2]:
import numpy as np
print(np.__version__)

1.16.3


## NumPy Arrays
### Why arrays improve on lists
Arrays (`ndarray` objects) are the most basic type in the NumPy package. A one-dimensional NumPy array is a vector with shape `(N,)`, similar to a Python list. In contrast to lists, however, arrays support element-wise arithmetic and come with many more attributes and methods. The examples below demonstrate the improvement of arrays over lists.

In [3]:
## Define example list.
example_list = list(range(5))

print(example_list)
print(example_list * 3)            # scalar * list
print(example_list * example_list) # list * list

[0, 1, 2, 3, 4]
[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]


TypeError: can't multiply sequence by non-int of type 'list'

In contrast, NumPy arrays support these operations element-wise. We use the **arange** command to initialize an array of sequential integers.

In [4]:
arr = np.arange(5)

print(arr, type(arr))
print(arr * 3)
print(arr * arr)

[0 1 2 3 4] <class 'numpy.ndarray'>
[ 0  3  6  9 12]
[ 0  1  4  9 16]
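To make the contrast concrete, here is a minimal sketch of the list equivalent of `arr * arr`: with a list, an element-wise product needs an explicit comprehension, whereas the array version is a single expression.

```python
import numpy as np

example_list = list(range(5))
arr = np.arange(5)

# With a list, element-wise products need an explicit comprehension.
list_squares = [x * x for x in example_list]

# With an array, the same operation is a single vectorized expression.
arr_squares = arr * arr

print(list_squares)  # [0, 1, 4, 9, 16]
print(arr_squares)   # [ 0  1  4  9 16]
```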


Every array has a data type (`dtype`), which can be looked up and converted.

In [5]:
print(arr, arr.dtype)   # Print current datatype.
arr = arr.astype(float) # Convert to float.
print(arr, arr.dtype)   # Print new datatype.

[0 1 2 3 4] int64
[0. 1. 2. 3. 4.] float64
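Conversions can also go the other way. One thing worth knowing, sketched below with a few illustrative values: converting floats to integers with `astype(int)` truncates toward zero rather than rounding, so round first if rounding is what you want.

```python
import numpy as np

vals = np.array([1.7, -1.7, 2.5])

as_int = vals.astype(int)            # truncates toward zero, no rounding
rounded = vals.round().astype(int)   # round first, then convert

print(as_int)   # [ 1 -1  2]
print(rounded)  # [ 2 -2  2]
```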


NumPy arrays also store metadata about their contents. These attributes can be helpful, especially the **shape** attribute.

In [6]:
print('Array shape:', arr.shape) # Print shape of array.
print('Array bytes:', arr.nbytes) # Print bytes used by the array's elements.

Array shape: (5,)
Array bytes: 40
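The 40 bytes follow directly from the other metadata. In this sketch, `nbytes` equals the number of elements (`size`) times the bytes per element (`itemsize`), which is 8 for `float64`.

```python
import numpy as np

arr = np.arange(5).astype(float)

print(arr.size)      # 5 elements
print(arr.itemsize)  # 8 bytes per float64 element
print(arr.nbytes)    # 40 = 5 * 8
```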


Arrays also have a number of built-in methods not available for lists.

In [7]:
print('Round:', arr.round()) # Round array.
print('Min:', arr.min())     # Get min of array.
print('Max:', arr.max())     # Get max of array.
print('Sum:', arr.sum())     # Get sum of array.
print('Mean:',arr.mean())    # Get mean of array.

Round: [0. 1. 2. 3. 4.]
Min: 0.0
Max: 4.0
Sum: 10.0
Mean: 2.0


### Generating NumPy Arrays
There are many ways of generating NumPy arrays. The simplest is to convert a Python list to a NumPy array using the **array** command.

In [8]:
## Making an array from a list using the array command.
example_list = [4, 7, 9.4]
arr = np.array(example_list)

print(example_list, type(example_list))
print(arr, type(arr)) 

[4, 7, 9.4] <class 'list'>
[4.  7.  9.4] <class 'numpy.ndarray'>


NumPy provides equivalents of the standard R/Matlab commands for generating arrays.

In [9]:
print('np.arange(5)        = %s' %np.arange(5))         # Array of 5 sequential integers.
print('np.zeros(5)         = %s' %np.zeros(5))          # Array of 5 zeros.
print('np.ones(5)          = %s' %np.ones(5))           # Array of 5 ones.
print('np.linspace(0,10,5) = %s' %np.linspace(0,10,5))  # Length-5 evenly-spaced array from 0 to 10.
print('np.logspace(0,2,5) = %s' %np.logspace(0,2,5))    # Length-5 log-spaced array from 10**0 to 10**2.

np.arange(5)        = [0 1 2 3 4]
np.zeros(5)         = [0. 0. 0. 0. 0.]
np.ones(5)          = [1. 1. 1. 1. 1.]
np.linspace(0,10,5) = [ 0.   2.5  5.   7.5 10. ]
np.logspace(0,2,5) = [  1.           3.16227766  10.          31.6227766  100.        ]
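Two of these commands are easy to confuse. As this sketch shows, **arange** takes a step size and excludes the stop value, while **linspace** takes a number of points and includes the endpoint. The NumPy documentation recommends **linspace** when the step is not an integer.

```python
import numpy as np

# arange excludes the stop value and takes a step size...
a = np.arange(0, 1, 0.25)     # 4 values: 0, 0.25, 0.5, 0.75

# ...whereas linspace includes the endpoint and takes a number of points.
b = np.linspace(0, 1, 5)      # 5 values: 0, 0.25, 0.5, 0.75, 1.0

print(a)
print(b)
```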


## NumPy Matrices
### Why matrices improve on lists
Python can technically represent a matrix as a list of lists, though doing so is inefficient. Just as arrays improve on lists, NumPy matrices dramatically improve upon the numerical capabilities of core Python. (Here, a "matrix" simply means a two-dimensional NumPy array, not the separate `np.matrix` class.)

In [10]:
nested_lists = [[1,2,3],
                [4,5,6],
                [7,8,9]]

print(nested_lists)
print(nested_lists[1][2])   # To extract the 2nd row, 3rd column, two separate indexing operations are necessary.

[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
6


NumPy matrices make this much easier!

In [11]:
mat = np.array(nested_lists)

print(mat)
print(mat[1,2])

[[1 2 3]
 [4 5 6]
 [7 8 9]]
6


Indexing of NumPy matrices (and arrays, for that matter) follows the slicing conventions of lists. Commas separate the index or slice for each axis: the first position selects rows and the second selects columns.

In [13]:
print('mat[1,2]  = %s' %mat[1,2])    # Second row, third column.
print('mat[0,:]  = %s' %mat[0,:])    # All the first row.
print('mat[:,-1] = %s' %mat[:,-1])   # All of the final column.

mat[1,2]  = 6
mat[0,:]  = [1 2 3]
mat[:,-1] = [3 6 9]
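One important difference from lists: slicing a list returns a copy, but a basic NumPy slice returns a *view* of the original data. A minimal sketch of what that means in practice:

```python
import numpy as np

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

row = mat[0, :]          # basic slicing returns a view, not a copy
row[0] = 100             # ...so writing to the slice changes mat
print(mat[0, 0])         # 100

row_copy = mat[1, :].copy()   # use .copy() for an independent array
row_copy[0] = -1
print(mat[1, 0])         # still 4
```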


NumPy matrices have all the same attributes and methods as NumPy arrays, but many methods now accept an `axis` argument, so functions can be applied to each row or column in addition to the entire matrix.

In [14]:
## Sum across matrix.
print( mat.sum() )          

45


In [15]:
## Sum down each column (axis=0).
print( mat.sum(axis=0) )

[12 15 18]


In [16]:
## Sum along each row (axis=1).
print( mat.sum(axis=1) )

[ 6 15 24]


Importantly, all NumPy arrays and matrices have a **reshape** method for transforming them into different dimensions.

In [17]:
print('Original shape', mat.shape)

# Reshape to column vector
mat = mat.reshape(9,1)
print('Column vector', mat.shape)

# Reshape to row vector
mat = mat.reshape(1,9)
print('Row vector', mat.shape)

Original shape (3, 3)
Column vector (9, 1)
Row vector (1, 9)


The `order` flag of **reshape** changes how the elements are arranged in the new shape: `order='C'` (the default) fills it row by row, while `order='F'` fills it column by column.

In [18]:
print('Original:', mat)

Original: [[1 2 3 4 5 6 7 8 9]]


In [19]:
## Reshape (row organized)
print(mat.reshape(3,3,order='C'))

[[1 2 3]
 [4 5 6]
 [7 8 9]]


In [20]:
## Reshape (column organized)
print(mat.reshape(3,3,order='F')) 

[[1 4 7]
 [2 5 8]
 [3 6 9]]


The dimensions of matrices can also be quickly changed with **flatten**, which collapses an array to one dimension, and **squeeze**, which removes axes of length one.

In [21]:
## Reshape to new dimensions.
mat = mat.reshape(3,3,1)
print('Original:', mat.shape)

## Flatten matrix.
print('Flatten:', mat.flatten().shape )

## Squeeze matrix.
print('Squeeze:', mat.squeeze().shape )

Original: (3, 3, 1)
Flatten: (9,)
Squeeze: (3, 3)
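Two details are worth checking, sketched here on the same 3x3x1 shape. First, **flatten** always returns a new copy, so changing the result leaves the original intact. Second, **squeeze** only drops axes of length one, so an array without any is returned with its shape unchanged.

```python
import numpy as np

mat = np.arange(1, 10).reshape(3, 3, 1)

flat = mat.flatten()    # always returns a new 1-D copy
flat[0] = 99
print(mat[0, 0, 0])     # 1: the original is unchanged

sq = mat.squeeze()      # drops only axes of length 1
print(sq.shape)         # (3, 3)
print(np.arange(6).reshape(2, 3).squeeze().shape)  # (2, 3): nothing to drop
```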


### Generating NumPy Matrices
Just as with arrays, there are a number of ways of generating NumPy matrices. The simplest is to use the **array** command on a list of lists. 

In [22]:
nested_lists = [[0, 1, 1],[2, 3, 5], [8, 13, 21]]
mat = np.array(nested_lists)

print(nested_lists)
print(mat)

[[0, 1, 1], [2, 3, 5], [8, 13, 21]]
[[ 0  1  1]
 [ 2  3  5]
 [ 8 13 21]]


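The same commands previously introduced to generate NumPy arrays can also be used to generate matrices. For **zeros** and **ones**, pass a shape instead of a length; for **arange** and **linspace**, reshape the result. (In a notebook, only the value of the last line in a cell is displayed.)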

In [23]:
np.zeros( [3,3] )               # 3x3 matrix of zeros.
np.ones( [3,3] )                # 3x3 matrix of ones.
np.arange(9).reshape(3,3)       # 3x3 matrix of sequential integers.
np.linspace(0,8,9).reshape(3,3) # 3x3 matrix of evenly-spaced values from 0 to 8.
np.identity(3)                  # 3x3 identity matrix.

array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])

Matrices can also be formed by joining NumPy arrays. There are several functions for doing this, including `row_stack` (an alias of `vstack`), `column_stack`, `hstack`, `vstack`, and `concatenate`, as well as the index shortcuts `np.r_` and `np.c_`. We demonstrate the functions below.

In [30]:
arr = np.arange(5)
print(arr)

## Join rows.
rows = np.row_stack([arr,arr])
print(rows)

## Join columns.
cols = np.column_stack([arr,arr])
print(cols)

[0 1 2 3 4]
[[0 1 2 3 4]
 [0 1 2 3 4]]
[[0 0]
 [1 1]
 [2 2]
 [3 3]
 [4 4]]


In [31]:
## np.hstack = join arrays horizontally (along columns; 1-D arrays are joined end to end).
print(np.hstack([arr,arr]))
print(np.hstack([arr.reshape(5,1), arr.reshape(5,1)]))

[0 1 2 3 4 0 1 2 3 4]
[[0 0]
 [1 1]
 [2 2]
 [3 3]
 [4 4]]


In [32]:
## np.vstack = join arrays vertically (along rows; 1-D arrays become rows).
print(np.vstack([arr,arr]))
print(np.vstack([arr.reshape(5,1), arr.reshape(5,1)]))

[[0 1 2 3 4]
 [0 1 2 3 4]]
[[0]
 [1]
 [2]
 [3]
 [4]
 [0]
 [1]
 [2]
 [3]
 [4]]


In [33]:
## np.concatenate = join arrays along specified axis.
## Default is first axis.
print(np.concatenate([arr, arr], axis=0))
print(np.concatenate([arr.reshape(5,1), arr.reshape(5,1)], axis=1))

[0 1 2 3 4 0 1 2 3 4]
[[0 0]
 [1 1]
 [2 2]
 [3 3]
 [4 4]]
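The index shortcuts `np.r_` and `np.c_` can do the same jobs with less typing. Note that they are indexed with square brackets rather than called with parentheses. A minimal sketch comparing them to the functions above:

```python
import numpy as np

arr = np.arange(5)

# np.r_ concatenates along the first axis (like np.concatenate for 1-D input).
print(np.r_[arr, arr])        # [0 1 2 3 4 0 1 2 3 4]

# np.c_ stacks 1-D inputs as columns (like np.column_stack).
print(np.c_[arr, arr].shape)  # (5, 2)
```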


### Generating Random Data
NumPy also includes many functions for generating random data. 

In [None]:
## Set the RNG seed!
np.random.seed(47404)
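Setting the seed is what makes random draws reproducible. In the sketch below, re-seeding with the same value regenerates exactly the same samples. (The newer `np.random.default_rng` interface was added in NumPy 1.17, so it is not available in the 1.16 release used here.)

```python
import numpy as np

np.random.seed(47404)
first = np.random.normal(0, 1, 5)

np.random.seed(47404)            # resetting the same seed...
second = np.random.normal(0, 1, 5)

print(np.array_equal(first, second))  # True: ...reproduces the same draws
```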

In [None]:
## Generate ten random integers between 0-9.
print( np.random.randint(0,10,10) )

In [None]:
## Generate five random samples of a normal distribution with mu=0,sd=1.
print( np.random.normal(0,1,5) )

In [None]:
## Generate 10 random coin flips.
print( np.random.binomial(1,0.5,10))

In [None]:
## Choose five numbers from 0-9 without replacement.
print( np.random.choice(np.arange(10), 5, replace=False) )
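Random data is also a convenient way to try out boolean masking. Comparing an array to a value produces an array of `True`/`False`, and indexing with that mask keeps only the matching elements. A minimal sketch:

```python
import numpy as np

np.random.seed(47404)
samples = np.random.normal(0, 1, 10)

mask = samples > 0          # boolean array, same shape as samples
positives = samples[mask]   # keep only the elements where mask is True

print(mask.dtype)           # bool
print(mask.sum(), positives.size)  # number of True entries == number kept

samples[samples < 0] = 0    # masks also work on the left of an assignment
```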


## Core NumPy Functions
NumPy also introduces a number of useful functions designed to operate efficiently over NumPy arrays. The following is a non-exhaustive overview of some important NumPy functions.

### Matrix Math
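The key distinction in matrix math is between element-wise and matrix products: `*` multiplies matching entries, while `@` (Python 3.5+) or `np.dot` performs matrix multiplication. A minimal sketch with two small 2x2 arrays chosen for illustration:

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[0, 1], [1, 0]])

print(A * B)         # element-wise product: [[0 2] [3 0]]
print(A @ B)         # matrix product:       [[2 1] [4 3]]
print(np.dot(A, B))  # same as A @ B for 2-D arrays
```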

### Rounding Functions

In [None]:
mat = np.linspace(0,1,5)
print('Original: %s' %mat)
print('np.round: %s' %np.round(mat, 1) )
print('np.floor: %s' %np.floor(mat) ) 
print('np.ceil:  %s' %np.ceil(mat) )
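Note that **np.round** rounds exact halves to the nearest even number ("banker's rounding") rather than always rounding up. This is likely why 0.25 rounds to 0.2 in the cell above. A minimal sketch:

```python
import numpy as np

halves = np.array([0.5, 1.5, 2.5, 3.5])
print(np.round(halves))  # [0. 2. 2. 4.]: ties go to the even neighbor
```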

### Mathematical Functions

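NumPy includes a variety of mathematical functions. All of them operate on entire arrays or matrices at once, and reductions such as **sum**, **cumsum**, **prod**, and **diff** also accept an `axis` argument to operate along rows or columns.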

In [None]:
np.sum;       # Sum of the elements of an array or matrix.
np.cumsum;    # Cumulative sum over an array.
np.prod;      # Product of the elements of an array.
np.divide;    # Element-wise division of two arrays.
np.diff;      # Differences between consecutive elements of an array.
np.exp;       # Element-wise exponential (e**x).
np.log;       # Element-wise natural logarithm.
np.log10;     # Element-wise base-10 logarithm.
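Applied to a small example array (values chosen for illustration), these functions give the following results:

```python
import numpy as np

arr = np.array([1, 2, 4, 7])

print(np.sum(arr))              # 14
print(np.cumsum(arr))           # [ 1  3  7 14]
print(np.prod(arr))             # 56 (1*2*4*7)
print(np.diff(arr))             # [1 2 3]
print(np.divide(arr, 2))        # [0.5 1.  2.  3.5]
print(np.log10([1, 10, 100]))   # [0. 1. 2.]
```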

### Summary Functions

NumPy includes many functions to summarize an array. With the exception of **corrcoef**, each of these can be applied across an entire matrix or, via the `axis` argument, along its rows or columns.

In [None]:
np.min;           # Return the smallest element.
np.max;           # Return the largest element.
np.argmin;        # Return the index of the smallest element.
np.argmax;        # Return the index of the largest element.
np.mean;          # Compute the mean of an array.
np.median;        # Compute the median of an array.
np.std;           # Compute the standard deviation of an array.
np.var;           # Compute the variance (sd^2) of an array.
np.percentile;    # Compute the xth percentile of an array.
np.corrcoef;      # Compute the Pearson correlations between rows (default) or columns of a matrix.
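If you are coming from R, note that **np.std** and **np.var** divide by *n* by default (`ddof=0`), whereas R's `sd()` and `var()` divide by *n - 1*. Passing `ddof=1` gives the sample versions, as sketched here:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

print(np.var(x))          # 1.25  (divides by n; ddof=0 is the default)
print(np.var(x, ddof=1))  # 1.666... (divides by n - 1, like R's var())
print(np.std(x, ddof=1))  # sample standard deviation, like R's sd()
```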

In [None]:
## To give a few examples.
mat = np.vstack([ np.arange(5), np.arange(5)[::-1] ])
print('Original:\n%s' %mat)

In [None]:
## Compute percentile.
print( '70%% (all):  %s' %np.percentile(mat, 70) )

## Compute 70th percentile of each row.
print('70%% (rows): %s' %np.percentile(mat, 70, axis=1) )

## Compute 70th percentile of each column.
print('70%% (cols): %s' %np.percentile(mat, 70, axis=0) )

In [None]:
## Compute correlation.
print('Correlation:\n', np.corrcoef(mat))

### Set Functions
NumPy includes functions for sorting arrays and for identifying unique or shared elements within or between arrays.

In [None]:
## Define two arrays for example.
arr1 = np.array([41, 16, 34, 0, 2, 20, 19, 14, 22, 15, 18, 9, 35, 41])
arr2 = np.array([42, 22, 40, 7, 33, 0, 12, 19, 44, 10, 31, 11, 11, 49])

In [None]:
## Sort elements (ascending order).
np.sort(arr1)

In [None]:
## Return unique elements.
np.unique(arr1)

In [None]:
## Return unique elements, count number of appearances.
np.unique(arr1, return_counts=True)

In [None]:
## For each element of array 1, test whether it appears in array 2.
## (np.isin is the newer equivalent of np.in1d.)
np.isin(arr1, arr2)

In [None]:
## Return all unique elements of arrays 1 & 2.
np.union1d(arr1, arr2)

In [None]:
## Return all elements belonging to both arrays 1 & 2.
np.intersect1d(arr1, arr2)
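These functions combine naturally. The sketch below reuses the two example arrays. It finds values that occur more than once in `arr1` by filtering the `return_counts` output, then counts how many elements of `arr1` also appear in `arr2`.

```python
import numpy as np

arr1 = np.array([41, 16, 34, 0, 2, 20, 19, 14, 22, 15, 18, 9, 35, 41])
arr2 = np.array([42, 22, 40, 7, 33, 0, 12, 19, 44, 10, 31, 11, 11, 49])

# Values that occur more than once in arr1.
vals, counts = np.unique(arr1, return_counts=True)
print(vals[counts > 1])            # [41]

# Shared values, and how many elements of arr1 appear in arr2.
print(np.intersect1d(arr1, arr2))  # [ 0 19 22]
print(np.isin(arr1, arr2).sum())   # 3
```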

### Replacing List Comprehensions

NumPy includes a number of very helpful functions that can replace list comprehensions (np.where) and for loops (np.apply_along_axis, np.apply_over_axes). These are often more concise than writing out a full for loop, although apply_along_axis still loops in Python internally and so is not necessarily faster. We will demonstrate these functions with a simple example of standard-scoring (z-scoring) a matrix.

In [None]:
## Define the standard score (z-score) function.
def zscore(arr): 
    return (arr - arr.mean()) / arr.std()

## Define a simple matrix.
mat = np.arange(12).reshape(2,6)
print(mat)

Use **apply_along_axis** to apply our function to each row.

In [None]:
zmat = np.apply_along_axis(zscore, axis=1, arr=mat)
print(zmat.round(2))
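For a simple function like this, the same result can be computed without calling the function once per row, by using broadcasting. In this sketch, `keepdims=True` keeps each row's mean and standard deviation as a `(2, 1)` column, so they line up against every element of that row.

```python
import numpy as np

def zscore(arr):
    return (arr - arr.mean()) / arr.std()

mat = np.arange(12).reshape(2, 6)

# apply_along_axis calls zscore once per row.
zmat_apply = np.apply_along_axis(zscore, axis=1, arr=mat)

# Broadcasting version: row statistics kept as (2, 1) columns.
zmat_vec = (mat - mat.mean(axis=1, keepdims=True)) / mat.std(axis=1, keepdims=True)

print(np.allclose(zmat_apply, zmat_vec))  # True
```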

Use the **where** command to set all negative numbers to 0 and all others to 1. In this three-argument form, **where** is analogous to the **ifelse** command in R.

In [None]:
amat = np.where(zmat < 0, 0, 1)
print(amat)

If only the condition is given, **where** returns the indices of the elements where the condition is met (one array of indices per axis), analogous to the **which** command in R.

In [None]:
print( np.where(zmat < 0 ) )

### Linear Algebra Functions

NumPy includes an entire submodule, `np.linalg`, dedicated to efficient linear algebra functions (SciPy's `scipy.linalg` reimplements these and adds more). See the `np.linalg` documentation for a full list of commands.

In [None]:
## Define a simple matrix.
mat = np.arange(16).reshape(4,4)
print(mat)

In [None]:
## Transpose the matrix
print(mat.T)           

In [None]:
## Return diagonal of matrix
print(np.diag(mat))

In [None]:
## Return upper triangle of matrix (entries below the diagonal set to 0)
print(np.triu(mat))    

In [None]:
## Matrix-multiply the matrix by itself.
print(np.dot(mat, mat))    

In [None]:
## Equivalent method syntax (in Python 3.5+, mat @ mat also works):
print(mat.dot(mat))

In [None]:
## Linear algebra operations include:
np.linalg.norm;        # Vector or matrix norm
np.linalg.inv;         # Inverse of a square matrix
np.linalg.det;         # Determinant of a square matrix
np.linalg.eig;         # Eigenvalues and vectors of a square matrix
np.linalg.cholesky;    # Cholesky decomposition of a positive-definite matrix
np.linalg.svd;         # Singular value decomposition of a matrix
np.linalg.lstsq;       # Solve linear least-squares problem
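A few of these in action, as a minimal sketch on a small symmetric positive-definite matrix chosen for illustration. For solving `A x = b`, `np.linalg.solve` is generally preferred over computing `np.linalg.inv(A) @ b`, though both give the same answer here.

```python
import numpy as np

A = np.array([[2.0, 1.0], [1.0, 3.0]])  # symmetric positive-definite
b = np.array([3.0, 5.0])

x = np.linalg.solve(A, b)       # solves A @ x = b without forming inv(A)
print(x)                        # approximately [0.8 1.4]
print(np.linalg.det(A))         # approximately 5.0

L = np.linalg.cholesky(A)       # lower-triangular L with L @ L.T == A
print(np.allclose(L @ L.T, A))  # True
```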

# Introduction to Data Visualization
## Matplotlib
Matplotlib is the core plotting package of the scientific Python distribution. Its **pyplot** interface was modeled on Matlab's plotting commands, so much of the syntax and style of Matplotlib reflects Matlab plotting.

We will go through the syntax of plotting five common types of plots: bar plots, line plots, scatter plots, histograms, and heatmaps. We will also cover adding details to plots (e.g. axes, titles, legends, error bars), making multiple plots in one figure, and scaling/sizing plots.

Similar to plotting in R, pure Matplotlib is a little clunky and a lot of code is needed to make more visually appealing plots. For this reason, we will introduce the Seaborn package later. Seaborn is similar to ggplot2 in that, with a tidy dataframe and some knowledge of the syntax, beautiful plots can quickly/easily be generated. But, it's better to crawl before walking, so we'll start with Matplotlib first.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline 

## NOTE: The second line is a bit of notebook magic!
## It's a Jupyter notebook command that displays
## plots inline, directly below the cell that made them.

For plots to be displayed outside of the notebook (e.g. when running a script), omit the `%matplotlib inline` line and call **plt.show()** when finished plotting.
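As a minimal sketch of a standalone script, assume it runs without a display. The non-interactive `Agg` backend renders off-screen, and `fig.savefig` writes the figure to disk. The filename `example_plot.png` is hypothetical. On a desktop, you would drop the `matplotlib.use` line and call `plt.show()` instead.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove for an interactive window
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

fig = plt.figure(figsize=(6, 3))   # the figure controls canvas size and saving
ax = plt.subplot(111)              # the axis holds the plot itself
ax.plot(x, np.sin(x), color='b', linewidth=2)
ax.set_xlabel('x')
ax.set_title('sin(x)')
plt.tight_layout()

fig.savefig('example_plot.png', dpi=100)  # hypothetical output filename
# plt.show()  # in an interactive session, this opens the plot in a window
```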

### Figures & Axes
A brief note: In Matplotlib jargon, an axis (an `Axes` object, not to be confused with a single x- or y-axis) is a plot (e.g. barplot, scatterplot), and a figure is the surrounding object containing all plots. The most basic figure contains a single axis (i.e. one plot). More complex figures may have multiple plots of different sizes and numbers per row.

This distinction is important because certain graphical tweaks can only be applied to figures or axes. For example, figures control the size of the canvas, the spacing of plots, and saving figures. Axes control plot-specific features, including labels, titles, and legends. To start, we will only generate figures with one plot. Later, we will introduce drawing multiple plots per figure.

### Barplots
Barplots are probably the least intuitive plot in Matplotlib because the user must specify the x-position (by default, the center) and width of each bar, in contrast to other languages that automatically assign x-coordinates to the bars. Though clunky, this does give the user some additional control.

In this example, we will plot the average response of each subject, computed from the `data` DataFrame.

In [None]:
## Open new figure canvas. Define its size.
fig = plt.figure(figsize=(12,4))

## Make a new axis: The numbers correspond to
## row, column, and plot index. For example,
## subplot(211) would mean creating the 
## first plot of a figure with 2 rows and 1 column.
ax = plt.subplot(111)

## Use groupby to compute average response.
avg_resp = data.groupby('subj').respcat.mean()

## Now we have to specify the positions of the bars
## along the x-axis. We will center each bar on a
## sequential integer for each new subject
## (i.e., 0, 1, 2, ... N-1).
n_subj = len(avg_resp)
xpos = np.arange(n_subj)

## Now we can plot.
width = 0.9
ax.bar(xpos, avg_resp, width=width, color='#7ec0ee');   # Bars are centered on xpos.

## Fix x-axis.
ax.set_xlim(xpos.min() - width, xpos.max()+width);   # Pad by one bar width
                                                     # on both sides.
ax.set_xticks(xpos);                                 # Set x-tickmarks at center of bars.
ax.set_xticklabels(avg_resp.index, fontsize=12);     # Set subjects (in bar order) as x-ticklabels.
ax.set_xlabel('Subjects', fontsize=18);              # Set x-axis label.

## Fix y-axis.
ax.set_ylim(0,1);
ax.set_ylabel('Average Response', fontsize=18);

## Set title.
ax.set_title('Example Barplot', fontsize=24);

## Autoscale image.
plt.tight_layout();    # Reduce whitespace outside of plot.

Grouped barplots must be built manually, as in the following example, where we plot the average response and response time data by subject.

In [None]:
## Open new figure canvas. Define its size.
fig = plt.figure(figsize=(12,4))

## Make a new axis.
ax = plt.subplot(111)

## Use groupby to compute averages.
gb = data.groupby('subj')
avg_resp = gb.respcat.mean()
std_resp = gb.respcat.std()
avg_rt = gb.rt.mean()
std_rt = gb.rt.std()

## Specify the center positions of the bars.
n_subj = 5
resp_pos = [0.0, 2.5, 5.0, 7.5, 10.0]
rt_pos = [1.0, 3.5, 6.0, 8.5, 11.0]

## Make barplots. Here we use the yerr flag to make error bars.
## The label argument names each condition for the legend.
width = 1.0
ax.bar(resp_pos, avg_resp[:n_subj], width, label = 'Choice',
       yerr=std_resp[:n_subj], color='#7ec0ee');
ax.bar(rt_pos, avg_rt[:n_subj], width, label='RT',
       yerr=std_rt[:n_subj], color='#71eeb8');


## Fix x-axis.
ax.set_xlim(-width, max(rt_pos)+width);
ax.set_xticks(np.array(rt_pos) - width / 2.);
ax.set_xticklabels(avg_resp.index[:n_subj], fontsize=12);   # One label per group of bars.
ax.set_xlabel('Subjects', fontsize=18);

## Set title.
ax.set_title('Example Grouped Barplot', fontsize=24);

## Set horizontal line. Starts at zero and travels 
## length of plot.
ax.hlines(0, -width, max(rt_pos)+width, color='black')

## Add legend.
ax.legend(loc=2, frameon=False, fontsize=18)

## Autoscale image.
plt.tight_layout();    # Reduce whitespace outside of plot.

### Lineplots
Lineplots are more intuitive than barplots, requiring at minimum only the x- and y-datapoints. Many tweaks and embellishments can similarly be added.

Here we will plot the nominal likelihood-of-take of the risky bet against the diff values. In the example, we use the Matplotlib color shorthands. These are:
* b: blue 
* g: green 
* r: red 
* c: cyan 
* m: magenta 
* y: yellow
* k: black 
* w: white

In [None]:
## Open new figure canvas. Define its size.
fig = plt.figure(figsize=(12,4))

## Make a new axis.
ax = plt.subplot(111)

## Group trials by gain-minus-loss difference.
gb = data.groupby('diff')
respnum_avg = gb.respnum.mean()

## Plot average gain response.
## In this example, we use the matplotlib color shorthands. 
ax.plot( respnum_avg.index, respnum_avg, color='b', linewidth=3 );

## We will also shade within 1sd of the line. To do this, we use
## fill_between, which asks for the x-points and the y-lower/upper
## bounds along which to fill. The alpha parameter below reflects
## the transparency from {transparent = 0.0, opaque = 1.0}.
respnum_std = gb.respnum.std()
ax.fill_between( respnum_avg.index, respnum_avg - respnum_std, respnum_avg + respnum_std,
                color='b', alpha=0.2)

## We will also add dotted lines to demarcate the bounds of +- 1sd. 
ax.plot( respnum_avg.index, respnum_avg - respnum_std, linewidth=0.5, linestyle='--', color='k' )
ax.plot( respnum_avg.index, respnum_avg + respnum_std, linewidth=0.5, linestyle='--', color='k' )

## Add details. (Several of these could be combined into one ax.set(...) call.)
ax.set_xlim(respnum_avg.index.min(), respnum_avg.index.max())
ax.set_xlabel('Gain - Loss', fontsize=18)
ax.set_yticks([1,2,3,4])
ax.set_yticklabels(['Strong Accept', 'Weak Accept', 'Weak Reject', 'Strong Reject'])
ax.invert_yaxis()
ax.tick_params(axis='both',which='major',labelsize=14)
ax.set_title('Example Line Plot', fontsize=24)

## Adjust layout.
plt.tight_layout()

### Scatterplots
The syntax of scatterplots is similar to that of lineplots. Whereas lineplots have different [linestyles](https://matplotlib.org/examples/lines_bars_and_markers/line_styles_reference.html), scatterplots have different [marker styles](https://matplotlib.org/api/markers_api.html).

Here we make a scatterplot of the average reaction time by diff value.

In [None]:
## Open new figure canvas. Define its size.
fig = plt.figure(figsize=(8,4))

## Make a new axis.
ax = plt.subplot(111)

## Plot average RT. "s" controls the size of the marker; marker controls the shape.
## Edgecolor adds an outline to the marker.
rt = data.groupby(['diff']).rt.mean()
ax.scatter( rt.index, rt, s=100, marker='o', color='c', edgecolor='k');

## Add details.
ax.set_xlim(-11, 36)
ax.set_xlabel('Gain - Loss', fontsize=18)
ax.set_ylabel('Reaction Time (s)', fontsize=18)
ax.tick_params(axis='both',which='major',labelsize=14)
ax.set_title('Example Scatter Plot', fontsize=24)

## Adjust layout.
plt.tight_layout()

### Histograms
Fortunately, histograms are easy to make. Here we plot the reaction time distributions of two subjects.

In [None]:
## Open new figure canvas. Define its size.
fig = plt.figure(figsize=(8,4))

## Make a new axis.
ax = plt.subplot(111)

## RT distribution for subj06.
ax.hist( data.loc[data.subj == 'subj06', 'rt'], bins=25, 
        label='subj06', color='#1b9e77', alpha=0.75 );

## RT distribution for subj10.
ax.hist( data.loc[data.subj == 'subj10', 'rt'], bins=25, 
        label='subj10', color='#7570b3', alpha=0.75 );

## Add details.
ax.set_xlabel('Reaction Time (s)', fontsize=18)
ax.set_ylabel('Frequency', fontsize=18)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.set_title('Example Histogram', fontsize=24)
ax.legend(loc='best', fontsize=16, frameon=False);

plt.tight_layout();

### Heatmaps
Heatmaps are very useful plots, but slightly counterintuitive in Matplotlib. We will go through an example looking at the average likelihood-of-take as a function of gain and loss. A full list of colormaps can be found [here](https://matplotlib.org/examples/color/colormaps_reference.html).  

In [None]:
## Open new figure canvas. Define its size.
fig = plt.figure(figsize=(8,6))

## Make a new axis.
ax = plt.subplot(111)

## Average response for each gain-loss combination.
resp = data.groupby(['gain','loss']).respcat.mean()

## Extract as an array and reshape (there are 16 unique values for each).
## (.as_matrix() was removed in pandas 1.0; .values works across versions.)
resp = resp.values.reshape(16,16).T

## Plot. The parameters are as follows:
#### aspect: defines scaling. automatic scaling is preferable.
#### interpolation: smoothing of image. We want no smoothing.
#### origin: upper or lower, we want smaller values to begin in lower corner.
#### cmap: what colormap to use.
#### vmin, vmax: min,max values of colormap.
cbar = ax.imshow(resp, aspect='auto', interpolation='none', 
                 origin='lower', cmap='Blues', vmin=0, vmax=1)

## Add details.
ax.set_xticks(np.arange(0,16,2))
ax.set_xticklabels(np.unique(data.gain)[::2], fontsize=14)
ax.set_xlabel('Gain', fontsize=24)

ax.set_yticks(np.arange(0,16,2))
ax.set_yticklabels(np.unique(data.loss)[::2], fontsize=14)
ax.set_ylabel('Loss', fontsize=24)

## Add colorbar.
cbar = plt.colorbar(cbar, ax=ax);
cbar.ax.tick_params(labelsize=14) 
cbar.set_label('Average Response', fontsize=20)

### Embedding Multiple Plots in a Figure
With Matplotlib, there are three main methods (3.5 if subplot and subplots are counted separately) for constructing a figure with multiple embedded plots. These range from minimal to maximal control of layout, and are as follows:
1. subplot/subplots: generates equal sized plots in a figure.
2. subplot2grid: generates plots of different sizes along a grid, minimal spacing options.
3. gridspec: generates plots of different sizes, many spacing options.

We have already used subplot. Here we briefly show how subplot and subplots can each be used to make a figure with multiple plots.

In [None]:
## Subplot example: the figure must be created first.
fig = plt.figure(figsize=(5,5))

## Make 2x2 figure.
ax = plt.subplot(2,2,1)
ax.text(0.5,0.5,'ax1', fontsize=18, ha='center', va='center');

ax = plt.subplot(2,2,2)
ax.text(0.5,0.5,'ax2', fontsize=18, ha='center', va='center');

ax = plt.subplot(2,2,3)
ax.text(0.5,0.5,'ax3', fontsize=18, ha='center', va='center');

ax = plt.subplot(2,2,4)
ax.text(0.5,0.5,'ax4', fontsize=18, ha='center', va='center');

In [None]:
## Subplots example: the figure is created by the same command.

## Make 2x2 figure. Note that figsize can be passed directly.
## axes is a 2x2 NumPy array of axes.
fig, axes = plt.subplots(2,2,figsize=(5,5))

## Iteratively add text.
for n in range(4):
    axes[n//2,n%2].text(0.5,0.5,'ax%s' %(n+1), fontsize=18, 
                        ha='center', va='center');
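Because `axes` is a NumPy array, the row/column arithmetic above can also be avoided by iterating over `axes.flat`, which visits the plots row by row. A minimal sketch, assuming a headless `Agg` backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(5, 5))
print(axes.shape)  # (2, 2)

## axes.flat yields the axes in row-major order: top-left, top-right, ...
for n, ax in enumerate(axes.flat):
    ax.text(0.5, 0.5, 'ax%s' % (n + 1), fontsize=18, ha='center', va='center')
```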

To use subplot2grid(), you provide the shape of the grid and the (row, column) location of the subplot's upper-left cell within the grid. Here we present an example for a 3x3 grid.

In [None]:
## subplot2grid example: the figure must be created first.
fig = plt.figure(figsize=(5,5))

## Call subplot2grid. The first argument specifies the shape
## of the grid. The second argument specifies the (row, col)
## where the axis starts. colspan/rowspan give the number of
## columns/rows the axis spans within the grid.

ax = plt.subplot2grid((3, 3), (0, 0), colspan=3)
ax.text(0.5,0.5,'ax1', fontsize=18, ha='center', va='center');

ax = plt.subplot2grid((3, 3), (1, 0), colspan=2)
ax.text(0.5,0.5,'ax2', fontsize=18, ha='center', va='center');

ax = plt.subplot2grid((3, 3), (1, 2), rowspan=2)
ax.text(0.5,0.5,'ax3', fontsize=18, ha='center', va='center');

ax = plt.subplot2grid((3, 3), (2, 0))
ax.text(0.5,0.5,'ax4', fontsize=18, ha='center', va='center');

ax = plt.subplot2grid((3, 3), (2, 1))
ax.text(0.5,0.5,'ax5', fontsize=18, ha='center', va='center');

Gridspec objects are similar to subplot2grid in that they allow different sized plots within a figure. Gridspec objects also allow configuring the spacing of axes within the figure. To give an example, we embed two sets of five plots with a large gap between them.

In [None]:
import matplotlib.gridspec as gridspec

## Initialize figure.
fig = plt.figure(figsize=(10,5))

## Define first 3x3 grid. 
gs = gridspec.GridSpec(3, 3)

## Update spacing parameters so that this grid spans only
## the left part of the figure (0.05 to 0.45 of its width).
gs.update(left=0.05, right=0.45, wspace=0.05)

## Create plots by indexing into grid.
ax1 = plt.subplot(gs[0, :])
ax2 = plt.subplot(gs[1, :-1])
ax3 = plt.subplot(gs[1:, -1])
ax4 = plt.subplot(gs[-1, 0])
ax5 = plt.subplot(gs[-1, -2])

## Define second 3x3 grid. 
gs = gridspec.GridSpec(3, 3)

## Update spacing parameters so that this grid spans only
## the right part of the figure (0.55 to 0.95 of its width).
gs.update(left=0.55, right=0.95, wspace=0.05)

## Create plots by indexing into grid.
ax1 = plt.subplot(gs[0, :])
ax2 = plt.subplot(gs[1, :-1])
ax3 = plt.subplot(gs[1:, -1])
ax4 = plt.subplot(gs[-1, 0])
ax5 = plt.subplot(gs[-1, -2])