# Tutorial about drift analysis and correction

Lateral drift correction is useful in most SMLM experiments. To determine the amount of drift, a method based on image cross-correlation or an iterative closest point (ICP) algorithm can be applied.

We demonstrate drift analysis and correction on simulated data.

In [None]:
from pathlib import Path

%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats

import locan as lc

In [None]:
lc.show_versions(system=False, dependencies=False, verbose=False)

## Synthetic data

We use synthetic data that follows a Neyman-Scott spatial distribution (blobs). The intensity values are exponentially distributed (shifted by a constant offset of 500), and the number of localizations per frame follows a Poisson distribution:

In [None]:
rng = np.random.default_rng(seed=1)

In [None]:
intensity_mean = 1000
localizations_per_frame_mean = 3

In [None]:
dat_blob = lc.simulate_Thomas(parent_intensity=1e-4, region=((0, 1000), (0, 1000)), cluster_mu=1000, cluster_std=10, seed=rng)
dat_blob.dataframe['intensity'] = stats.expon.rvs(scale=intensity_mean, size=len(dat_blob), loc=500, random_state=rng)
dat_blob.dataframe['frame'] = lc.simulate_frame_numbers(n_samples=len(dat_blob), lam=localizations_per_frame_mean, seed=rng)

dat_blob = lc.LocData.from_dataframe(dataframe=dat_blob.data)

print('Data head:')
print(dat_blob.data.head(), '\n')
print('Summary:')
dat_blob.print_summary()
print('Properties:')
print(dat_blob.properties)

In [None]:
lc.render_2d(dat_blob, bin_size=10, rescale='equal');
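For readers unfamiliar with Neyman-Scott processes, the following is a minimal NumPy sketch of a Thomas-type process, independent of locan. It assumes that `cluster_mu` is the expected number of offspring per parent and `cluster_std` the standard deviation of the Gaussian spread; `rng_demo` and all variable names are illustrative. Edge handling is ignored here, so points may fall outside the region, and `lc.simulate_Thomas` may differ in such details.

```python
import numpy as np

rng_demo = np.random.default_rng(0)

side = 1000.0            # square region (0, side) x (0, side)
parent_intensity = 1e-4  # expected number of parents per unit area
cluster_mu = 1000        # assumed: expected number of offspring per parent
cluster_std = 10.0       # assumed: std of the Gaussian offspring spread

# 1. Parents: Poisson-distributed count, uniformly placed in the region.
n_parents = rng_demo.poisson(parent_intensity * side**2)
parents = rng_demo.uniform(0, side, size=(n_parents, 2))

# 2. Offspring: Poisson-distributed count per parent, Gaussian spread around it.
n_offspring = rng_demo.poisson(cluster_mu, size=n_parents)
offspring = np.repeat(parents, n_offspring, axis=0)
offspring = offspring + rng_demo.normal(0, cluster_std, size=offspring.shape)

print(n_parents, len(offspring))
```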

## Add linear drift

We add linear drift with a velocity given in length units per frame.

In [None]:
dat_blob_with_drift = lc.add_drift(dat_blob, velocity=(0.002, 0.001), seed=rng)
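Conceptually, linear drift displaces each localization by the velocity multiplied by its frame number. The following NumPy sketch illustrates this; `add_linear_drift` is a hypothetical helper, not part of locan, and `lc.add_drift` may differ in details.

```python
import numpy as np

def add_linear_drift(coordinates, frames, velocity):
    """Shift each point by velocity * frame (hypothetical helper, not locan API)."""
    coordinates = np.asarray(coordinates, dtype=float)
    frames = np.asarray(frames, dtype=float)
    velocity = np.asarray(velocity, dtype=float)
    # Broadcasting yields one displacement vector per localization.
    return coordinates + frames[:, np.newaxis] * velocity[np.newaxis, :]

# The same emitter position observed in three different frames.
points = np.array([[10.0, 20.0], [10.0, 20.0], [10.0, 20.0]])
frames = np.array([0, 1000, 2000])
drifted = add_linear_drift(points, frames, velocity=(0.002, 0.001))
print(drifted)
```

With a velocity of (0.002, 0.001) per frame, a localization recorded in frame 1000 is displaced by (2, 1) length units.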

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
lc.render_2d(dat_blob_with_drift, ax=axes[0], bin_size=10);
lc.render_2d(dat_blob_with_drift, ax=axes[1], bin_size=2, rescale='equal', bin_range=((0, 500),(0, 500)));
lc.render_2d_mpl(dat_blob_with_drift, ax=axes[2], other_property='frame', bin_size=2, bin_range=((0, 500),(0, 500)), cmap='viridis');

## Estimate RMS errors

Knowing the ground truth, you can compute the root mean squared error (RMSE) per coordinate between the original localization coordinates and those after drift, and later those after correction.

In [None]:
def rmse(locdata, other_locdata):
    return np.sqrt(np.mean(np.square(np.subtract(locdata.coordinates, other_locdata.coordinates)), axis=0))

In [None]:
rmse(dat_blob, dat_blob_with_drift).round(2)
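As a small sanity check of this error measure, the sketch below applies the same formula to plain arrays; `rmse_arrays` is an illustrative name. Note that the comparison is row-wise, so both inputs must list the localizations in the same order.

```python
import numpy as np

def rmse_arrays(coordinates, other_coordinates):
    """Per-axis RMSE between two coordinate arrays of identical shape and order."""
    difference = np.asarray(coordinates, dtype=float) - np.asarray(other_coordinates, dtype=float)
    return np.sqrt(np.mean(np.square(difference), axis=0))

truth = np.zeros((4, 2))
shifted = truth + np.array([3.0, 4.0])  # constant shift of every point
error = rmse_arrays(truth, shifted)
print(error)
```

For a constant shift, the per-axis RMSE equals the absolute shift along each axis.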

## Estimate drift

Drift can be estimated by comparing different chunks of successive localizations using either an "iterative closest point" (ICP) algorithm or a "cross-correlation" algorithm. By default, the ICP algorithm is applied.

In [None]:
%%time
drift = lc.Drift(chunk_size=10_000, target='first', method='icp').compute(dat_blob_with_drift)
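The principle of iterative closest point registration can be sketched in a few lines. This toy version estimates a pure translation with brute-force nearest-neighbor matching; it is not locan's implementation, and `icp_translation` is a hypothetical helper.

```python
import numpy as np

def icp_translation(source, target, n_iterations=10):
    """Estimate the translation that registers `source` onto `target` (toy ICP)."""
    source = np.asarray(source, dtype=float)
    target = np.asarray(target, dtype=float)
    offset = np.zeros(source.shape[1])
    for _ in range(n_iterations):
        moved = source + offset
        # 1. Match every moved source point to its nearest target point.
        distances = np.linalg.norm(moved[:, np.newaxis, :] - target[np.newaxis, :, :], axis=2)
        nearest = np.argmin(distances, axis=1)
        # 2. Update the offset by the mean residual of the matched pairs.
        offset = offset + np.mean(target[nearest] - moved, axis=0)
    return offset

# Target: regular grid; source: the same grid displaced by a small "drift".
grid = np.arange(0.0, 100.0, 10.0)
target_points = np.array([(x, y) for x in grid for y in grid])
source_points = target_points + np.array([1.5, -0.8])
estimated_offset = icp_translation(source_points, target_points)
print(estimated_offset)
```

The estimated offset is the negative of the applied displacement, because it moves the drifted points back onto the target. The matching step works here because the displacement is small compared to the point spacing.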

Transformations to register the different data chunks are represented by a transformation matrix and a transformation offset that together specify an affine transformation. The transformation parameters are kept under the `transformations` attribute.

In [None]:
drift.transformations
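An affine transformation maps a coordinate vector x to A x + b, with matrix A and offset b. The sketch below applies such a transformation to a few points with NumPy; the values are made up for illustration, and the row-vector convention is an assumption that may differ from how locan stores the parameters.

```python
import numpy as np

matrix = np.eye(2)                # identity: no rotation, scaling, or shear
offset = np.array([-2.0, -1.0])   # pure translation

points = np.array([[12.0, 21.0], [14.0, 22.0]])
# Each row is one point, so multiply by the transposed matrix.
transformed = points @ matrix.T + offset
print(transformed)
```

For pure lateral drift, the matrix is expected to stay near the identity and the offset carries the displacement.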

The parameters can be visualized using the `plot` method. The matrix in this case is close to the identity matrix.

In [None]:
drift.plot(transformation_component='matrix', element=None);
plt.legend();

In [None]:
drift.plot(transformation_component='offset', element=None)
plt.legend();

## Model drift

A continuous transformation model as a function of frame number is estimated by fitting the individual transformation components with the specified fit models. Fit models can be provided as `DriftComponent` or by a string representing standard model functions.

In [None]:
from lmfit.models import ConstantModel, LinearModel, PolynomialModel

drift.fit_transformations(slice_data=slice(None), offset_models=(lc.DriftComponent('spline', s=100), 'linear'), verbose=True);
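To illustrate what fitting a linear model to per-chunk offsets means, here is a sketch using `np.polyfit` on made-up, noise-free values. It is independent of locan and lmfit; the chunk frames and offsets are hypothetical.

```python
import numpy as np

# Hypothetical per-chunk results: a representative frame and the estimated x-offset.
chunk_frames = np.array([500.0, 1500.0, 2500.0, 3500.0])
chunk_offsets = 0.002 * chunk_frames + 0.1  # slope 0.002 per frame, intercept 0.1

# Least-squares line: offset(frame) = slope * frame + intercept
slope, intercept = np.polyfit(chunk_frames, chunk_offsets, deg=1)

# The fitted model can be evaluated at any frame, not only at the chunk frames.
offset_at_frame_2000 = slope * 2000 + intercept
print(slope, intercept, offset_at_frame_2000)
```

A spline, as used for the first offset component above, follows the same idea but allows the drift to deviate from a straight line.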

The fit models are represented as `DriftComponent` objects and can be accessed through the `transformation_models` attribute.

In [None]:
drift.transformation_models

In [None]:
drift.transformation_models['offset'][0].type

In [None]:
drift.transformation_models['offset'][0].eval(0)

Each `DriftComponent` carries detailed information about the fit under the `model_result` attribute. In most cases, except splines, this will be an `lmfit.ModelResult` object.

In [None]:
drift.transformation_models['offset'][0].model_result

In [None]:
drift.transformation_models['offset'][1].type

In [None]:
drift.transformation_models['offset'][1].model_result

## Drift correction

The estimated drift can be corrected by applying one transformation per localization chunk (`from_model=False`).

In [None]:
%%time
drift.apply_correction(from_model=False);

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
lc.render_2d(drift.locdata_corrected, ax=axes[0], bin_size=2, rescale='equal', bin_range=((0, 200),(0, 200)));
lc.render_2d_mpl(drift.locdata_corrected, ax=axes[1], other_property='frame', bin_size=2, bin_range=((0, 200),(0, 200)), cmap='viridis');

In [None]:
rmse(dat_blob, drift.locdata_corrected).round(2)

Alternatively, the estimated drift can be corrected by applying a transformation to each individual localization using the drift models (`from_model=True`).

In [None]:
%%time
drift.apply_correction(from_model=True)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
lc.render_2d(drift.locdata_corrected, ax=axes[0], bin_size=2, rescale='equal', bin_range=((0, 200),(0, 200)));
lc.render_2d_mpl(drift.locdata_corrected, ax=axes[1], other_property='frame', bin_size=2, bin_range=((0, 200),(0, 200)), cmap='viridis');

In [None]:
rmse(dat_blob, drift.locdata_corrected).round(2)
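The difference between chunk-wise and model-based correction can be illustrated with a one-dimensional toy example. It assumes, for simplicity, that each chunk is assigned the mean drift of its frames and that the model matches the true linear drift exactly; neither assumption is taken from locan's implementation.

```python
import numpy as np

frames = np.arange(4000)
velocity = 0.002                        # drift per frame along one axis
true_x = np.full(frames.shape, 100.0)   # a single static emitter position
drifted_x = true_x + velocity * frames

# Chunk-wise correction: one offset per chunk of 1000 frames.
chunk_size = 1000
chunk_index = frames // chunk_size
n_chunks = chunk_index.max() + 1
chunk_offsets = np.array(
    [(velocity * frames[chunk_index == i]).mean() for i in range(n_chunks)]
)
corrected_chunkwise = drifted_x - chunk_offsets[chunk_index]

# Model-based correction: the offset is evaluated for every single frame.
corrected_model = drifted_x - velocity * frames

def rms(residual):
    return np.sqrt(np.mean(np.square(residual)))

rms_chunkwise = rms(corrected_chunkwise - true_x)
rms_model = rms(corrected_model - true_x)
print(rms_chunkwise, rms_model)
```

In this idealized setting, the chunk-wise correction leaves the drift accumulated within each chunk uncorrected, while the model-based correction removes it.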

In [None]:
drift.locdata_corrected.meta

## Drift analysis by a cross-correlation algorithm

The same kind of drift estimation and correction can be applied using the image cross-correlation algorithm.

In [None]:
%%time
drift = lc.Drift(chunk_size=10_000, target='first', method='cc').\
        compute(dat_blob_with_drift).\
        fit_transformations(slice_data=slice(None), offset_models=(LinearModel(), LinearModel()), verbose=True).\
        apply_correction(from_model=True);
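The idea behind image cross-correlation can be sketched with NumPy's FFT: the position of the correlation peak gives the shift between two images. This toy version finds integer-pixel shifts of a periodically shifted random image; `estimate_shift_cc` is a hypothetical helper and not locan's implementation, which operates on rendered localization chunks.

```python
import numpy as np

def estimate_shift_cc(reference, moving):
    """Integer-pixel shift of `moving` relative to `reference` via FFT cross-correlation."""
    # Circular cross-correlation computed in Fourier space.
    cross_power = np.fft.fft2(moving) * np.conj(np.fft.fft2(reference))
    correlation = np.fft.ifft2(cross_power).real
    peak = np.unravel_index(np.argmax(correlation), correlation.shape)
    # Peak indices beyond the midpoint correspond to negative shifts.
    shift = [p if p <= n // 2 else p - n for p, n in zip(peak, correlation.shape)]
    return np.array(shift)

rng_demo = np.random.default_rng(0)
reference_image = rng_demo.random((64, 64))
# Shift by 3 pixels along axis 0 and -5 pixels along axis 1 (with wrap-around).
moving_image = np.roll(reference_image, shift=(3, -5), axis=(0, 1))
estimated_shift = estimate_shift_cc(reference_image, moving_image)
print(estimated_shift)
```

Real data requires sub-pixel peak localization and is not periodic, so practical implementations add steps such as peak fitting or padding.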

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
lc.render_2d(drift.locdata_corrected, ax=axes[0], bin_size=2, rescale='equal', bin_range=((0, 200),(0, 200)));
lc.render_2d_mpl(drift.locdata_corrected, ax=axes[1], other_property='frame', bin_size=2, bin_range=((0, 200),(0, 200)), cmap='viridis');

In [None]:
rmse(dat_blob, drift.locdata_corrected).round(2)