In [None]:
"""
3D drift estimation example.

Runs ok on google colab with T4 GPU, but for some reason the CPU multithreading is much slower than on windows
"""

%cd /content
!rm -rf /content/drift-estimation
!git clone https://github.com/qnano/drift-estimation.git
%cd /content/drift-estimation
!cmake .
!make clean & make

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from dme.dme import dme_estimate

# DME backend switch passed as `useCuda` to dme_estimate.
# True requires CUDA >= 10.1 update 2 to be installed; False presumably
# selects the (multithreaded) CPU code path instead.
use_cuda = True


# Simulate an SMLM dataset in 3D with blinking molecules
def smlm_simulation(
        drift_trace,
        fov_width, # field of view size in pixels
        loc_error, # localization error XYZ
        n_sites, # number of locations where molecules blink on and off
        n_frames,
        on_prob = 0.1, # probability of a binding site generating a localization in a frame
        ):

    """
    Simulate a blinking 3D SMLM acquisition with a known drift trace.

    Binding sites are placed uniformly in [0, fov_width) x [0, fov_width) in XY
    and [-1, 1) in Z (a typical 2D acquisition: large XY range, small Z range).
    In every frame each site independently produces a localization with
    probability ``on_prob``. That localization is the site position shifted by
    the frame's drift plus Gaussian localization noise.

    Parameters
    ----------
    drift_trace : array-like
        Indexed per frame; ``drift_trace[i]`` is added to every localization of
        frame ``i`` and must broadcast against shape (k, 3), e.g. an
        (n_frames, 3) array of XYZ drift.
    fov_width : float
        Field of view size in pixels (XY extent of the binding sites).
    loc_error : float or array-like of length 3
        Standard deviation of the Gaussian localization noise, per axis (XYZ),
        e.g. XY in pixels and Z in um.
    n_sites : int
        Number of binding sites that blink on and off.
    n_frames : int
        Number of frames to simulate.
    on_prob : float
        Probability that a site produces a localization in a given frame.

    Returns
    -------
    localizations : ndarray, shape (N, 3)
        Simulated XYZ localizations of all frames, in frame order.
    framenum : ndarray of int32, shape (N,)
        Frame index of each localization.
    """

    # typical 2D acquisition with small Z range and large XY range
    binding_sites = np.random.uniform([0,0,-1], [fov_width,fov_width,1], size=(n_sites,3))

    localizations = []
    framenum = []

    for i in range(n_frames):
        # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; use builtin bool.
        on = np.random.binomial(1, on_prob, size=n_sites).astype(bool)
        # boolean-mask indexing returns a copy, so binding_sites is never modified
        locs = binding_sites[on]
        # shift by this frame's drift and add localization error
        locs = locs + drift_trace[i] + np.random.normal(0, loc_error, size=locs.shape)
        framenum.append(np.full(len(locs), i, dtype=np.int32))
        localizations.append(locs)

    if not localizations:
        # n_frames == 0: np.concatenate would raise on an empty list
        return np.empty((0, 3)), np.empty(0, dtype=np.int32)

    return np.concatenate(localizations), np.concatenate(framenum)

# Fix the RNG so Restart & Run All reproduces the same dataset and results.
SEED = 0
np.random.seed(SEED)

n_frames = 10000
fov_width = 200
drift_mean = (0.001,0,0)        # mean per-frame drift step (XYZ) -> slow trend in X
drift_stdev = (0.02,0.02,0.02)  # std of the per-frame random-walk step (XYZ)
loc_error = np.array((0.1,0.1,0.03)) # pixel, pixel, um
# Ground truth drift trace: a random walk (cumulative sum of steps), centred to
# zero mean on every axis.
drift_trace = np.cumsum(np.random.normal(drift_mean, drift_stdev, size=(n_frames,3)), 0)
drift_trace -= drift_trace.mean(0)


localizations, framenum = smlm_simulation(drift_trace, fov_width, loc_error,
                                          n_sites=200,
                                          n_frames=n_frames)
print(f"Total localizations: {len(localizations)}")

# Per-localization uncertainty handed to DME: the same loc_error row for every localization.
crlb = np.tile(loc_error, (len(localizations), 1))

estimated_drift,_ = dme_estimate(localizations, framenum,
             crlb,
             framesperbin = 10,  # note that small frames per bin use many more iterations
             imgshape=[fov_width, fov_width],
             coarseFramesPerBin=200,
             coarseSigma=[0.2,0.2,0.2],  # coarse pass with wide sigma: 2x loc_error in XY, ~7x in Z
             useCuda=use_cuda)

# Per-axis RMS error of the estimate vs. ground truth (XY in px, Z in um).
rmsd = np.sqrt(np.mean((estimated_drift-drift_trace)**2, 0))
print(f"RMSD of drift estimate compared to true drift: {rmsd}")

# The estimate is drawn with a constant vertical offset purely so it does not
# hide the true trace; the offset is stated in the legend to avoid confusion.
plot_offset = 0.2
fig,ax=plt.subplots(3, figsize=(7,6))
for i in range(3):
    ax[i].plot(drift_trace[:,i],label='True drift')
    ax[i].plot(estimated_drift[:,i]+plot_offset,label=f'Estimated drift (+{plot_offset} offset)')
    ax[i].set_title(['x', 'y', 'z'][i])

    unit = ['px', 'px', 'um'][i]
    ax[i].set_ylabel(f'Drift [{unit}]')
ax[-1].set_xlabel('Frame')
ax[0].legend()
plt.tight_layout()
plt.show()



Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  on = np.random.binomial(1, on_prob, size=n_sites).astype(np.bool)


Total localizations: 200047
Computing XY drift
RCC: Computing image cross correlations. Image stack shape: (10, 408, 408). Size: 6 MB


100%|██████████| 45/45 [00:04<00:00, 10.71it/s]


Computing Z drift
RCC: Computing image cross correlations. Image stack shape: (10, 404, 404). Size: 6 MB


100%|██████████| 45/45 [00:04<00:00, 10.39it/s]


Computing initial coarse drift estimate... (200 frames/bin)


0it [00:00, ?it/s]