# Rigid registration of atomic resolution STEM data 
This notebook walks through using the rigidregistration Python package to register rapidly acquired frames to one another, generating a high signal-to-noise image with reduced scan artifacts. For more detail on the technique, see **"Image registration of low signal-to-noise cryo-STEM data"**, Ultramicroscopy (2018), [DOI: 10.1016/j.ultramic.2018.04.008](https://doi.org/10.1016/j.ultramic.2018.04.008).

A high level overview of the process is:
1) Data is loaded and inspected

2) Frames are Fourier filtered to weight information as a function of spatial frequency, in order to improve evaluation of shifts between images.

3) Every frame in the stack is cross-correlated with every other frame to measure the shift between them (given by the position of the maximum in the cross-correlation).

4) Incorrect shifts are identified by evaluating transitivity or by identifying outliers, and are corrected by enforcing transitivity or by omitting images or cross-correlations.

5) Images are translated to eliminate the relative shifts, and the frames are averaged along the stacking axis to yield a high signal-to-noise average image.


### Getting started

The first few cells below are preparatory: they import the necessary Python libraries and functions, load and pre-inspect the data, and instantiate (create) the 'imstack' object, which contains most of the functions that will be used.

In this example, data formatted as **.tif files** are loaded using the `tifffile` package. For other file formats common to electron microscopy data (e.g., .dm3, .ser...), `hyperspy` or `ncempy` can be used to read in the data.

In [None]:
# Import libraries and functions
import numpy as np
import matplotlib.pyplot as plt
from time import time
from tifffile import imread
import rigidregistration

%matplotlib notebook

In [None]:
# Load data.
# The final axis of the stack variable should iterate over images.
# For best performance, data should be normalized between 0 and 1.

f = "../Data/BSCMO_0047 3.7 Mx_0.tif"                 # Filepath to data
stack = np.moveaxis(imread(f), 0, -1).astype(float)   # Move the frame axis last so stack[:, :, i] is the i-th image
stack = stack - stack.min()                           # Shift the minimum to 0
stack = stack / stack.max()                           # Normalize data between 0 and 1
print("Analyzing {}.".format(f))
print("Shape: {}.".format(stack.shape))

In [None]:
# Inspect data in preparation for registration
# On the left is a single frame of the stack; on the right is the FFT of the frame (log magnitude)

for i in range(5, 6):                      # Select which images from the stack to display
    im = stack[:, :, i]
    fft_log = np.log(np.abs(np.fft.fftshift(np.fft.fft2(im))) + 1e-9)   # Log-magnitude FFT, zero frequency centered
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    ax1.matshow(im, cmap='gray')
    ax2.matshow(fft_log, cmap='gray', vmin=np.average(fft_log))         # Clip display below the mean log magnitude
    ax1.grid(False)
    ax2.grid(False)
    plt.show()

In [None]:
# Instantiate imstack object, and get all FFTs

s=rigidregistration.stackregistration.imstack(stack)    # Instantiate imstack object.
s.getFFTs()                                             # Calculate the FFTs of each image in the stack

### Fourier masking
Select a Fourier mask, which defines how strongly information at each spatial frequency is weighted when calculating the image shifts.
**There are two choices to be made:** <span style="color:orange"> the cutoff frequency </span>and <span style="color:orange"> mask shape</span>. It is also possible to define a custom Fourier mask.

**1. The cutoff frequency**

In all cases, the parameter ***`n`*** controls the mask cutoff frequency; features smaller than ~n pixels will be ignored during image correlation.
For data with higher SNR, choosing a mask with n at the information limit is frequently sufficient.
For low-SNR data, choosing a mask with a cutoff frequency near the primary Bragg peaks is often preferable: this heavily weights low-frequency information to avoid unit-cell hops, while ideally retaining just enough lattice information to 'lock in' to the lattice.

**2. The mask shape**

Supported apodization functions for the makeFourierMask() method are `"bandpass", "lowpass", "hann", "hamming", "blackman", "gaussian", "none"`.
For lattices lacking high rotational symmetry, an anisotropic mask is generally preferable to avoid overweighting one lattice direction; this is discussed further below.
Detailed functional forms can be found in the source code.
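As a rough illustration of how the cutoff `n` and an apodization shape combine, the standalone NumPy sketch below builds a radially symmetric Hann low-pass mask on an unshifted FFT grid. It assumes the cutoff frequency is 1/n cycles per pixel, matching the description of `n` above; the helper `hann_lowpass` is hypothetical, and rigidregistration's actual functional forms may differ.

```python
import numpy as np

def hann_lowpass(shape, n):
    """Hypothetical radial Hann low-pass mask in np.fft.fft2 layout (zero frequency at [0, 0])."""
    ky = np.fft.fftfreq(shape[0])[:, None]   # spatial frequencies in cycles/pixel
    kx = np.fft.fftfreq(shape[1])[None, :]
    k = np.sqrt(kx**2 + ky**2)               # radial spatial frequency
    kc = 1.0 / n                             # assumed cutoff: features smaller than ~n pixels are suppressed
    # Raised cosine from 1 at k = 0 down to 0 at k = kc; zero beyond the cutoff
    return np.where(k < kc, 0.5 * (1 + np.cos(np.pi * k / kc)), 0.0)

mask = hann_lowpass((64, 64), n=4)
img = np.random.default_rng(0).random((64, 64))
weighted_fft = np.fft.fft2(img) * mask       # frequency-weighted FFT that would then be correlated
```

Smaller `n` pushes the cutoff outward and keeps more high-frequency detail; larger `n` keeps only the coarsest features.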

**3. Defining a custom mask**

The imstack object contains several meshgrids corresponding to Fourier-space coordinates, which make it straightforward to define a custom Fourier mask.
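For example, a custom mask can be assembled from Fourier-space coordinate grids. The sketch below builds those grids itself with `np.fft.fftfreq` rather than using imstack's own meshgrids (whose attribute names are not given here), and multiplies a Gaussian low-pass by a notch along the kx = 0 line. The notch is a hypothetical example: intensity variations that depend only on the row index y (for instance some scan-line noise, assuming x is the fast-scan direction) put all of their power on that line.

```python
import numpy as np

def lowpass_with_scan_notch(shape, n, notch_width=0.01):
    """Hypothetical custom mask: Gaussian low-pass times a notch along the kx = 0 line."""
    ky = np.fft.fftfreq(shape[0])[:, None]     # Fourier-space coordinates (cycles/pixel), FFT layout
    kx = np.fft.fftfreq(shape[1])[None, :]
    sigma = 1.0 / n                            # assumed width, mirroring the meaning of n above
    lowpass = np.exp(-(kx**2 + ky**2) / (2 * sigma**2))
    notch = 1.0 - np.exp(-kx**2 / (2 * notch_width**2))   # exactly 0 where kx == 0
    return lowpass * notch

mask = lowpass_with_scan_notch((64, 64), n=4)
```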

In [None]:
# Explore mask shape options

# The displayed plots are:
# Left: The full FFT in black and white, with a colored, semitransparent overlay showing the mask
# Middle: The FFT multiplied by the weighting mask
# Right: The cross correlation of a selected pair of images, weighted by the chosen mask


masktypes=["bandpass","hann","none"]   # List of mask shapes

# other options include "lowpass", "hamming", "blackman", "gaussian"

n=4                                                                              # Set cutoff frequency

i,j = 5,9                                    # Choose image pair to cross correlate. 
for masktype in masktypes:                   # Iterate over mask types
    s.makeFourierMask(mask=masktype,n=n)     # Set the selected Fourier mask
    s.show_Fourier_mask(i=i,j=j)             # Display results

In [None]:
# Vary n, the cutoff frequency

# In a typical workflow, this parameter is varied and the effect on the cross correlation is optimized

masktype="hann"

i,j = 5,9                                    # Choose image pair
for n in np.arange(2,10,4):                  # Select n values to test
    s.makeFourierMask(mask=masktype,n=n)     # Set the selected Fourier mask
    s.show_Fourier_mask(i=i,j=j)             # Display the results

In [None]:
# Elliptical Gaussian masks

# The makeFourierMask_eg() method creates an elliptical Gaussian mask with parameters n1, n2, and theta.
# n1 and n2 define the cutoff frequencies along the two principal axes of the ellipse.
# theta defines the mask tilt; it is set in degrees here and converted to radians with np.radians() before being passed in.

n1=6
n2=2
theta=60
i,j = 1,10
s.makeFourierMask_eg(n1=n1,n2=n2,theta=np.radians(theta))   # Set the selected mask
s.show_Fourier_mask(i=i,j=j)                                # Display the results
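The elliptical Gaussian idea can be sketched in plain NumPy by rotating the frequency coordinates by theta and using a different width along each rotated axis. This hypothetical `elliptical_gaussian_mask` helper is not the `makeFourierMask_eg()` implementation; it assumes widths of 1/n1 and 1/n2 cycles per pixel, with n1 applying to the axis at angle theta from kx.

```python
import numpy as np

def elliptical_gaussian_mask(shape, n1, n2, theta):
    """Hypothetical elliptical Gaussian Fourier mask; theta in radians, unshifted FFT layout."""
    ky = np.fft.fftfreq(shape[0])[:, None]
    kx = np.fft.fftfreq(shape[1])[None, :]
    # Rotate frequency coordinates so u lies along the axis at angle theta from kx
    u = kx * np.cos(theta) + ky * np.sin(theta)
    v = -kx * np.sin(theta) + ky * np.cos(theta)
    s1, s2 = 1.0 / n1, 1.0 / n2              # assumed widths along u and v (cycles/pixel)
    return np.exp(-(u**2 / (2 * s1**2) + v**2 / (2 * s2**2)))

mask = elliptical_gaussian_mask((64, 64), n1=6, n2=2, theta=np.radians(60))
```

Because a larger n gives a narrower Gaussian in this sketch, the mask falls off faster along the direction associated with the larger n.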


Define the final mask to use. For most datasets, a <span style="color:red">hann mask</span> with a carefully chosen width (n) works well. For structures whose lattice parameters differ strongly between directions (such as BSCCO), an elliptical mask may work better; in that case, use the <span style="color:red">s.makeFourierMask_eg</span> function as above and pass in <span style="color:red">n1, n2, and theta</span>.


In [None]:
# input your final mask parameters here and run this cell before moving on to calculating the image shifts

masktype="hann"

i,j = 5,9                                    # Choose image pair
n =??                                        # Select n value
s.makeFourierMask(mask=masktype,n=n)         # Set the selected Fourier mask
s.show_Fourier_mask(i=i,j=j)                 # Display the results



# <span style="color:red">Checkpoint:</span>
Compare your selected masks and resulting cross correlations with your group -- why did you choose the mask you did? What features did you look for in the FFT and the cross correlation to make this choice?

### Calculate image shifts
Calculate the relative shifts between all pairs of images from their cross correlations.

Analytically, for two functions which are identical except for some shift, the shift is given by the position of the maximum of their cross correlation.
After calculating the cross correlation, its maximum may be found in one of two ways: finding the brightest pixel, or fitting Gaussian functions. The method used is controlled by the findMaxima parameter.

#### 1. Brightest pixel 

The shift is given directly by the position of the brightest pixel in the cross correlation. This is the fastest approach, and is selected by setting ***`findMaxima`***=<span style="color:orange">"pixel"</span>. In this approach, each relative shift is determined with a resolution of 1 pixel. However, the final shifts applied to the images before averaging are computed from the relative shifts between all pairs of images, so they may still have subpixel resolution, with an accuracy that generally improves with the number of images in the stack.
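A minimal NumPy sketch of the brightest-pixel approach: the cross correlation is computed through the FFT, and peak indices past the array midpoint are mapped to negative shifts because the FFT cross correlation is periodic. The helper names and the single application of an optional Fourier mask are assumptions for illustration, not rigidregistration internals.

```python
import numpy as np

def cross_correlation(a, b, mask=None):
    """Periodic cross correlation of a with b via the FFT, optionally weighted by a Fourier mask."""
    F = np.fft.fft2(a) * np.conj(np.fft.fft2(b))
    if mask is not None:
        F = F * mask
    return np.real(np.fft.ifft2(F))

def pixel_shift(a, b, mask=None):
    """Integer (dy, dx) such that a is approximately np.roll(b, (dy, dx), axis=(0, 1))."""
    cc = cross_correlation(a, b, mask)
    ny, nx = cc.shape
    iy, ix = np.unravel_index(np.argmax(cc), cc.shape)
    # Indices past the midpoint wrap around to negative shifts
    dy = iy - ny if iy > ny // 2 else iy
    dx = ix - nx if ix > nx // 2 else ix
    return int(dy), int(dx)

rng = np.random.default_rng(0)
b = rng.standard_normal((64, 64))
a = np.roll(b, (3, -5), axis=(0, 1))
shift = pixel_shift(a, b)
```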

#### 2. Gaussian fitting 

Fitting a continuous function to the maximum of the cross correlation is a simple way to find the shift between an image pair with subpixel resolution.  For images of atomic lattices, a Gaussian is a natural choice for a fitting function.

For images of crystal lattices, 'unit cell hop' errors can occur, wherein the calculated shift between a pair of images is incorrect by a multiple of the primitive lattice vectors. Unit cell hops become increasingly common due to sampling error when the real-space sampling is so coarse that each atomic column is only a handful of pixels across. Fitting a continuous function can correct for this sampling error: Gaussians are fit to the regions around each of the several brightest pixels, and the cross correlation maximum is then found from these continuous fits.
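The simplest form of Gaussian subpixel refinement passes a Gaussian through the brightest pixel and its two neighbors along each axis, which reduces to finding the vertex of a parabola through the logarithms of the three values. This three-point estimator is a swapped-in simplification, not the windowed fit that `findMaxima='gf'` performs; it is exact for noise-free Gaussian peaks and requires the three values to be positive.

```python
import numpy as np

def gaussian_vertex(fm, f0, fp):
    """Offset of the peak of a Gaussian through samples at -1, 0, +1 (all values must be > 0)."""
    lm, l0, lp = np.log(fm), np.log(f0), np.log(fp)
    return 0.5 * (lm - lp) / (lm - 2 * l0 + lp)

def subpixel_peak(cc):
    """Brightest pixel of cc, refined along each axis with a three-point Gaussian fit."""
    ny, nx = cc.shape
    iy, ix = np.unravel_index(np.argmax(cc), cc.shape)
    # Neighbors wrap around because an FFT cross correlation is periodic
    dy = gaussian_vertex(cc[(iy - 1) % ny, ix], cc[iy, ix], cc[(iy + 1) % ny, ix])
    dx = gaussian_vertex(cc[iy, (ix - 1) % nx], cc[iy, ix], cc[iy, (ix + 1) % nx])
    return iy + dy, ix + dx

# Synthetic peak centered between pixels
y, x = np.mgrid[0:32, 0:32]
cc = np.exp(-((y - 10.3)**2 + (x - 7.6)**2) / (2 * 1.5**2))
peak = subpixel_peak(cc)
```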

This method is selected by setting ***`findMaxima`***=<span style="color:orange">"gf"</span>.  Before running the Gaussian fitting method, three additional parameters for the fitting should be set by calling s.setGaussianFitParams().  These are:

  * ***`sigma_guess`***: sets the initial guess for the standard deviation of the Gaussian fits, in pixels. This may be estimated simply and quickly by observing the peak widths in the cross correlations or the width of atomic columns in the raw data.

  * ***`window_radius`***: sets the size of the region about each of the brightest pixels which is used to fit a Gaussian. It should be set such that neighboring cross correlation peaks are excluded; the window used is a square region with side length **2×window_radius+1**.

  * ***`num_peaks`***: sets how many of the brightest pixels to fit gaussians to. Typically **3-5** are sufficient to handle sampling problems. 
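The roles of `num_peaks` and `window_radius` can be sketched as follows: take the `num_peaks` brightest pixels, blanking a square of side 2×window_radius+1 around each so the next candidate comes from a different peak; fit a Gaussian within each square window; and keep the candidate with the largest fitted amplitude. The fit here is a log-quadratic least-squares fit of an axis-aligned Gaussian, and the blanking step and helper names are assumptions; rigidregistration's fitting routine may differ in detail. The example reproduces a sampling-induced unit cell hop: the true peak sits between pixels, so a weaker lattice peak owns the brightest pixel.

```python
import numpy as np

def fit_gaussian_window(cc, iy, ix, window_radius):
    """Least-squares fit of an axis-aligned Gaussian to log(cc) in a (2r+1)x(2r+1) window."""
    ny, nx = cc.shape
    offs = np.arange(-window_radius, window_radius + 1)
    dy, dx = np.meshgrid(offs, offs, indexing="ij")
    win = cc[np.ix_((iy + offs) % ny, (ix + offs) % nx)]   # periodic wrap, like the cross correlation
    logw = np.log(np.clip(win, 1e-12, None)).ravel()       # log requires positive values
    dy, dx = dy.ravel().astype(float), dx.ravel().astype(float)
    A = np.column_stack([np.ones(dx.size), dx, dy, dx**2, dy**2])
    c0, c1, c2, c3, c4 = np.linalg.lstsq(A, logw, rcond=None)[0]
    x0, y0 = -c1 / (2 * c3), -c2 / (2 * c4)               # vertex of the fitted paraboloid
    amp = np.exp(c0 - c1**2 / (4 * c3) - c2**2 / (4 * c4))
    return iy + y0, ix + x0, amp

def best_peak(cc, num_peaks=3, window_radius=2):
    """Fit num_peaks candidate peaks and return (y, x, amplitude) of the strongest fit."""
    ny, nx = cc.shape
    offs = np.arange(-window_radius, window_radius + 1)
    work = cc.astype(float).copy()
    fits = []
    for _ in range(num_peaks):
        iy, ix = np.unravel_index(np.argmax(work), work.shape)
        fits.append(fit_gaussian_window(cc, iy, ix, window_radius))
        work[np.ix_((iy + offs) % ny, (ix + offs) % nx)] = -np.inf   # exclude this peak from later searches
    return max(fits, key=lambda fit: fit[2])

# True peak between pixels (amplitude 1.0) plus a weaker lattice peak centered on a pixel (0.95)
y, x = np.mgrid[0:40, 0:40]
cc = (np.exp(-((y - 20.5)**2 + (x - 20.5)**2) / 2)
      + 0.95 * np.exp(-((y - 10)**2 + (x - 10)**2) / 2))
y0, x0, amp = best_peak(cc, num_peaks=2, window_radius=2)
```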
  
#### For most datasets, Gaussian fitting will yield a higher-quality registration with only a modest amount of additional computation time.

Values for <span style="color:red">sigma_guess, window_radius</span> should be set based on the size of the peaks in the cross correlation.

This cell typically takes 1-3 minutes to run.

In [None]:
# Calculate image shifts using gaussian fitting to get subpixel precision in each fit
# Be sure to see parameter descriptions above! The fits can take some time. 
# If you want to make sure things are progressing, set verbose to True to get a message after each shift

findMaxima = 'gf'
s.setGaussianFitParams(num_peaks=3,sigma_guess= ??,window_radius= ??)

t0=time()                                                  # Start time 
s.findImageShifts(findMaxima=findMaxima,verbose=False)     # Find shifts.  
t=time()-t0                                                # End time
print("Performed {} correlations in {} minutes {:.1f} seconds".format(s.nz*(s.nz-1)//2, int(t//60), t%60))

### Find and correct outliers in shift matrix
The previous step determines the relative shifts between all pairs of images.  Here, any incorrectly calculated shifts -- which may result from noisy, low SNR data -- are identified and corrected.  First, the shift matrix is displayed and inspected.  Next, outliers are identified.  Outliers are then corrected.

**1. Display the shift matrix**

For a stack of $N$ images, there are $N-1$ relative shifts for each image. The complete set of relative shifts is stored in an $N \times N$ matrix. Element $(i,j)$ of the shift matrix gives the relative shift of image $i$ with respect to image $j$.<sup>1</sup> To be physically consistent, the relative image shifts must add vectorially, i.e. $\mathbf{r}_{ij} + \mathbf{r}_{jk} = \mathbf{r}_{ik}$. In this step, we enforce physical consistency in the shift matrix. Visually, a correct shift matrix should appear "smooth" (though not necessarily varying monotonically).
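A short NumPy sketch of the consistency condition: if each frame has an absolute drift position $\mathbf{p}_i$ and the shift matrix is defined as $\mathbf{r}_{ij} = \mathbf{p}_i - \mathbf{p}_j$ (an assumed sign convention), every triple satisfies $\mathbf{r}_{ij} + \mathbf{r}_{jk} = \mathbf{r}_{ik}$, and corrupting a single element breaks it.

```python
import numpy as np

def transitivity_error(R):
    """Largest |R[i, j] + R[j, k] - R[i, k]| over all triples; R has shape (N, N, 2)."""
    err = R[:, :, None, :] + R[None, :, :, :] - R[:, None, :, :]
    return np.abs(err).max()

# Hypothetical drift positions (y, x), in pixels, for five frames
p = np.array([[0.0, 0.0], [1.2, -0.4], [2.1, -1.0], [3.3, -1.3], [4.0, -2.1]])
R = p[:, None, :] - p[None, :, :]     # R[i, j] = p_i - p_j

R_bad = R.copy()
R_bad[1, 3] += [2.0, 0.0]             # simulate a unit cell hop in one pair
errors = (transitivity_error(R), transitivity_error(R_bad))
```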

**2. Identify outliers**

Several approaches are possible to identify outliers in the shift matrix.  

The recommended method is to enforce transitivity, using `s.get_outliers()`. There is one required parameter and one optional parameter. The required parameter is a <span style="color:orange">threshold</span> value; higher threshold values permit greater deviations from perfect transitivity. The optional parameter <span style="color:orange">maxpaths</span> is the number of transitivity relationships used to evaluate a given relative image shift. For example, $\mathbf{r}_{12} + \mathbf{r}_{24} = \mathbf{r}_{14}$ and $\mathbf{r}_{13} + \mathbf{r}_{34} = \mathbf{r}_{14}$ are two distinct transitivity relationships that can be used to evaluate the self-consistency of the relative image shift from image 1 to image 4.
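The transitivity test can be sketched as follows: each relative shift $\mathbf{r}_{ij}$ is compared with the median of up to `maxpaths` two-step estimates $\mathbf{r}_{ik} + \mathbf{r}_{kj}$ and flagged when the difference exceeds `threshold`. This is a hypothetical, simplified stand-in for `s.get_outliers()`, assuming $\mathbf{r}_{ij} = \mathbf{p}_i - \mathbf{p}_j$; the library's path selection and scoring may differ.

```python
import numpy as np

def transitivity_outliers(R, threshold=1.0, maxpaths=10):
    """Flag R[i, j] (shape (N, N, 2)) that disagrees with the median of paths R[i, k] + R[k, j]."""
    N = R.shape[0]
    mask = np.zeros((N, N), dtype=bool)
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            ks = [k for k in range(N) if k not in (i, j)][:maxpaths]
            estimates = np.array([R[i, k] + R[k, j] for k in ks])
            deviation = np.linalg.norm(R[i, j] - np.median(estimates, axis=0))
            mask[i, j] = deviation > threshold
    return mask

# Consistent shift matrix for 8 frames (R[i, j] = p_i - p_j), then one corrupted pair
rng = np.random.default_rng(1)
p = rng.normal(scale=3.0, size=(8, 2))
R = p[:, None, :] - p[None, :, :]
R[2, 5] += [7.0, -4.0]
R[5, 2] -= [7.0, -4.0]
mask = transitivity_outliers(R, threshold=1.0, maxpaths=10)
```

Using the median rather than the mean keeps a good pair from being flagged when one of its paths happens to pass through the corrupted element.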

A simpler method to detect outliers is to require that each matrix element does not differ by too great an amount from its nearest-neighbor elements, which roughly corresponds to enforcing that the shift matrix is "smooth". The single required parameter is a <span style="color:orange">threshold</span> value, passed to `s.get_outliers_NN()` as `max_shift`.
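Similarly, a simplified sketch of the nearest-neighbor check compares each element of the shift matrix with the median of its (up to four) adjacent elements and flags it when the difference exceeds a maximum shift. The helper `nn_outliers` is hypothetical and only approximates the idea behind `s.get_outliers_NN(max_shift=...)`; the example uses a linear drift, for which the shift matrix is perfectly smooth.

```python
import numpy as np

def nn_outliers(R, max_shift=3.0):
    """Flag R[i, j] whose distance from the median of its adjacent matrix elements exceeds max_shift."""
    N = R.shape[0]
    mask = np.zeros((N, N), dtype=bool)
    for i in range(N):
        for j in range(N):
            neighbors = [R[a, b] for a, b in ((i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1))
                         if 0 <= a < N and 0 <= b < N]
            mask[i, j] = np.linalg.norm(R[i, j] - np.median(neighbors, axis=0)) > max_shift
    return mask

# Linear drift of (1.5, -0.5) px per frame gives a smooth shift matrix; corrupt one pair
p = np.arange(8)[:, None] * np.array([1.5, -0.5])
R = p[:, None, :] - p[None, :, :]
R[2, 5] += [7.0, -4.0]
R[5, 2] -= [7.0, -4.0]
mask = nn_outliers(R, max_shift=3.0)
```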

Finally, if necessary, outliers can be directly identified manually.

**3. Correct outliers**

Outliers are corrected by estimating the correct values of any identified outliers from the transitivity relations.
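A sketch of correction by transitivity: each flagged $\mathbf{r}_{ij}$ is replaced with the mean of $\mathbf{r}_{ik} + \mathbf{r}_{kj}$ over intermediate frames $k$ for which both legs are unflagged. The helper `correct_outliers` is hypothetical and is not necessarily how `s.make_corrected_Rij()` works.

```python
import numpy as np

def correct_outliers(R, mask):
    """Replace flagged R[i, j] with the mean of R[i, k] + R[k, j] over trusted paths; R is (N, N, 2)."""
    Rc = R.copy()
    N = R.shape[0]
    for i, j in zip(*np.nonzero(mask)):
        ks = [k for k in range(N) if k not in (i, j) and not mask[i, k] and not mask[k, j]]
        if ks:   # leave the element unchanged if no trusted path exists
            Rc[i, j] = np.mean([R[i, k] + R[k, j] for k in ks], axis=0)
    return Rc

rng = np.random.default_rng(1)
p = rng.normal(scale=3.0, size=(8, 2))
R_true = p[:, None, :] - p[None, :, :]   # R[i, j] = p_i - p_j
R = R_true.copy()
R[2, 5] += [7.0, -4.0]
R[5, 2] -= [7.0, -4.0]
mask = np.zeros((8, 8), dtype=bool)
mask[2, 5] = mask[5, 2] = True           # outliers, e.g. from a transitivity check
R_corrected = correct_outliers(R, mask)
```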


<sup>1</sup>  But who needs ImageJ, amirite?

In [None]:
# Show Xij and Yij matrices
s.show_Rij()

In [None]:
# Identify outliers by enforcing transitivity, using additional optional features

# For many datasets, if the filtering and fit parameters above were selected well, no outliers will be present
# For others, particularly cryo datasets, even perfect parameter selection will still leave some outliers
# For this tutorial, the BSCMO dataset can be registered with no outliers, while the BSCCO cryo dataset will always have some

s.set_bad_images([])                             # Flag entire images (rows/columns of shift matrix) as unusable
                                                 # Passing in an empty list [] will use every image


s.get_outliers(threshold=??,maxpaths=??)       # Set outlier threshold and maxpaths
                                               # 10 for both parameters is a reasonable starting point,
                                               # decrease threshold to mask more pairs, and increase to mask fewer
                                               # 10 maxpaths is normally sufficient but increase to improve 
                                               # transitivity assessment



s.show_Rij(mask=True)

If you are having trouble identifying outliers using transitivity above, you can use this cell to identify outliers using nearest neighbors to enforce "smoothness." Be careful: this is a more error-prone method of outlier identification,
<b>so only run it if you are not able to identify the outliers using transitivity</b>.

In [None]:
# Identify outliers using nearest neighbors to enforce "smoothness"
# Can be useful if transitivity doesn't correctly identify outliers

s.set_bad_images([])                           # Flag entire images (rows/columns of shift matrix) as unusable


s.get_outliers_NN(max_shift= ??)                 # Set the outlier threshold (max_shift)


s.show_Rij(mask=True)

In [None]:
# Correct outliers

s.make_corrected_Rij()    # Correct outliers using the transitivity relations
s.show_Rij_c()            # Display the corrected shift matrix

# <span style="color:red">Checkpoint:</span>
Compare your original and corrected shift matrices with your group. Are your original (uncorrected) matrices similar to those of your group members? How did different parameter choices for the masks and Gaussian fits affect the matrices, and what seemed to work best?

Why are the shift matrices symmetric across the diagonal? What is different about the X vs Y shift matrix? What does this tell you about the sample's drift?

For outlier identification and correction, what methods were used / what parameters seemed to work best? Did you set any bad images? Were the final corrected shift matrices similar?

### Calculate average image

To obtain the average image, each image in the stack is shifted by an amount which is calculated from the shift matrix.  The entire, shifted image stack is then averaged.  Several functions are available for displaying and saving the resulting average image, and for summarizing the processing that's been applied to the data for quick review.

**1. Shifting and averaging**

The final shifts which are applied to each image in the stack are determined by averaging each row of the shift matrix, i.e. the shift applied to the $i$'th image is 

$\mathbf{r}_i = \frac{1}{N}\sum_{j=1}^{N}\mathbf{r}_{ij}$

Shifts are applied in Fourier space using the shift theorem. Running `s.get_averaged_image()` calculates the shifts, shifts the images, and calculates the average image. The final shifts which have been applied to the data are stored in ***`s.shifts_x`*** and ***`s.shifts_y`***, and the shifted image stack is stored in ***`s.stack_registered`***.
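The shift-and-average step can be sketched with the Fourier shift theorem: multiplying an image's FFT by $\exp[-2\pi i(k_y\,dy + k_x\,dx)]$ translates it by $(dy, dx)$, including by fractional pixels. The sketch assumes $\mathbf{r}_{ij} = \mathbf{p}_i - \mathbf{p}_j$, where $\mathbf{p}_i$ is frame $i$'s drift position, so shifting frame $i$ by $-\mathbf{r}_i$ moves every frame to the mean position; `fourier_shift` and `register_and_average` are hypothetical helpers, not `s.get_averaged_image()` itself.

```python
import numpy as np

def fourier_shift(img, dy, dx):
    """Translate img by (dy, dx) pixels (possibly fractional) using the Fourier shift theorem."""
    ky = np.fft.fftfreq(img.shape[0])[:, None]
    kx = np.fft.fftfreq(img.shape[1])[None, :]
    phase = np.exp(-2j * np.pi * (ky * dy + kx * dx))
    return np.real(np.fft.ifft2(np.fft.fft2(img) * phase))   # periodic: content wraps at the edges

def register_and_average(stack, R):
    """stack: (ny, nx, N) images; R: (N, N, 2) with R[i, j] = p_i - p_j. Returns (average, registered, r)."""
    r = R.mean(axis=1)                        # r_i = (1/N) sum_j R[i, j] = p_i - mean(p)
    registered = np.stack([fourier_shift(stack[:, :, i], -r[i, 0], -r[i, 1])
                           for i in range(stack.shape[2])], axis=2)
    return registered.mean(axis=2), registered, r

# Synthetic stack: one image rolled by known integer drifts
rng = np.random.default_rng(0)
base = rng.standard_normal((32, 32))
p = np.array([[0, 0], [1, -2], [3, -1], [4, -4]])
stack = np.stack([np.roll(base, (int(py), int(px)), axis=(0, 1)) for py, px in p], axis=2)
R = (p[:, None, :] - p[None, :, :]).astype(float)
average, registered, r = register_and_average(stack, R)
```

Since the mean position here is fractional, every frame receives a subpixel shift, which the phase multiplication handles without explicit interpolation.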

**2. Displaying and saving reports and averaged images**

The averaged image can be displayed and saved by running `s.show()` and `s.save()`, respectively.
A summary of all the processing that's been performed on the data can be displayed and saved by running `s.show_report()` and `s.save_report()`, respectively.  Saving a report is highly recommended, as it allows quick assessment of the fidelity of the final images.

The `s.save()` method saves a .tif file of the average image.  Metadata with the processing parameters used to create the average image are stored in the description string of the .tif file.

Note that because the images have been shifted, a region about the edges of the final image will no longer be meaningful and should be discarded. The `s.save()` method automatically discards this edge data. To keep the full field of view in the final output, pass `crop=False` to `s.save()`; in this case, be sure to exclude the edge region from any final analysis. The min/max values delineating meaningful data are stored as ***`s.xmin`***, ***`s.xmax`***, ***`s.ymin`***, and ***`s.ymax`***.
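The meaningful region can be estimated from the applied shifts: with periodic (Fourier) shifting, a frame shifted by a positive dy has its first ceil(dy) rows wrapped in from the opposite edge, and a negative shift corrupts the last rows instead. The hypothetical `valid_bounds` helper below works under that assumption; the exact sign conventions and rounding behind `s.xmin`, `s.xmax`, `s.ymin`, and `s.ymax` may differ.

```python
import math

def valid_bounds(shifts_y, shifts_x, ny, nx):
    """Return (xmin, xmax, ymin, ymax) of the region untouched by wrap-around; max bounds are exclusive."""
    ymin = math.ceil(max(max(shifts_y), 0))        # rows wrapped in at the top by positive shifts
    ymax = ny + math.floor(min(min(shifts_y), 0))  # rows wrapped in at the bottom by negative shifts
    xmin = math.ceil(max(max(shifts_x), 0))
    xmax = nx + math.floor(min(min(shifts_x), 0))
    return xmin, xmax, ymin, ymax

bounds = valid_bounds([2.0, -3.0, 0.5], [-1.5, 0.0, 4.0], ny=100, nx=80)
```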

In [None]:
# Create registered image stack and average

# This cell may take a minute to run

s.get_averaged_image()   # To skip calculating the image shifts or correcting the shift matrix, pass
                         # get_shifts=False or correct_Rij=False, respectively

In [None]:
# Display final image

s.show()                 

# <span style="color:red">Checkpoint:</span>
Compare your averaged images and Fourier transforms with the group. Do all of them look similar? What could be the cause of any differences? Be sure to zoom in and compare the averaged image and Fourier transform with the single frame from the start of the notebook.

In [None]:
# Save the average image as a Tiff to the same directory the stack was in, with _registered added to the filename

s.save(f[:-4]+"_registered.tif")     # To keep the image uncropped, use crop=False.  The appropriate 
                                                # cropping boundaries are stored as metadata in the description string
                                                # of the output .tif file.

In [None]:
# Save report of registration procedure

s.save_report(f[:-4]+"_sample_report.pdf")