# Correcting distortion in SAED data #

## A NOTE BEFORE STARTING ##

Since the ``emicroml`` git repository tracks this notebook under its original
basename ``correcting_distortion_in_saed_data.ipynb``, we recommend that you
copy the original notebook and rename it to any other basename that is not one
of the original basenames that appear in the ``<root>/examples`` directory
before executing any of the notebook cells below, where ``<root>`` is the root
of the ``emicroml`` repository. For example, you could rename it
``correcting_distortion_in_saed_data_copy.ipynb``. This way you can explore the
notebook by executing and modifying cells without changing the original
notebook, which is being tracked by git.

## Import necessary modules ##

In [None]:
# For pattern matching.
import re

# For listing files in a given directory.
import os



# For general array handling.
import numpy as np

# For creating and plotting figures.
import hyperspy.api as hs
import matplotlib.pyplot as plt

# For minimizing objective functions.
import scipy.optimize



# For loading ML models for distortion estimation in CBED.
import emicroml.modelling.cbed.distortion.estimation

In [None]:
%matplotlib ipympl

## Introduction ##

In this notebook, we show how one can use one of the machine learning (ML)
models that are trained as a result of executing the "action" described in the
page [Training machine learning
models](https://mrfitzpa.github.io/emicroml/examples/modelling/cbed/distortion/estimation/train_ml_model_set.html)
to correct distortion in selected area electron diffraction (SAED)
data. Strictly speaking, each ML model is trained to estimate distortion in
convergent beam electron diffraction (CBED) patterns. However, by exploiting the
fact that distortions predominantly come from post-specimen lenses,
e.g. projection lenses, we can estimate and correct distortion in SAED data
as follows:

1. Collect the target experimental SAED data;
2. Modify only the pre-specimen lenses to produce CBED data;
3. Use an ML model to estimate the distortion field in the CBED data;
4. Correct the distortion in the SAED data using the distortion field from step 3.

We demonstrate steps 3 and 4 using pre-collected experimental SAED and CBED
patterns of a calibration sample of single-crystal Au oriented in the \[100\]
direction. This experimental data was collected on a modified Hitachi SU9000
scanning electron microscope operated at 20 keV.

In order to execute the cells in this notebook as intended, a set of Python
libraries needs to be installed in the Python environment within which the cells
of the notebook are to be executed. See [this
page](https://mrfitzpa.github.io/emicroml/examples/prerequisites_for_execution_without_slurm.html)
for instructions on how to do so. Additionally, a subset of the output that
results from performing the aforementioned action is required to execute the
cells in this notebook as intended. One can obtain this subset of output by
executing said action; however, this requires significant computational
resources, including significant walltime. Alternatively, one can copy this
subset of output from a Federated Research Data Repository dataset by following
the instructions given on [this
page](https://mrfitzpa.github.io/emicroml/examples/modelling/cbed/distortion/estimation/copying_subset_of_output_from_frdr_dataset.html).

You can find the documentation for the ``emicroml`` library
[here](https://mrfitzpa.github.io/emicroml/_autosummary/emicroml.html).
It is recommended that you consult the documentation of this
library as you explore the notebook. Moreover, users should
execute the cells in the order that they appear, i.e. from top to
bottom, as some cells reference variables that are set in other
cells above them. **Users should make sure to navigate the
documentation for the version of ``emicroml`` that they are
currently using.**

## Loading and visualizing the SAED and CBED patterns ##

Let's load and visualize the target SAED pattern:

In [None]:
path_to_data_dir = "../data"
filename = (path_to_data_dir 
            + "/for_demo_of_distortion_correction_in_saed_data"
            + "/distorted_saed_pattern.npy")

kwargs = {"file": filename}
distorted_saed_pattern_image = np.load(**kwargs)

kwargs = {"data": distorted_saed_pattern_image}
distorted_saed_pattern_signal = hs.signals.Signal2D(**kwargs)



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_saed_pattern_signal.plot(**kwargs)

This SAED pattern is subject to optical distortion, which we want to correct. To
do this, keeping the sample in place, we modify only the pre-specimen lenses to
produce a CBED pattern, which should be subject to approximately the same
distortion.

Let's load and visualize the corresponding CBED pattern:

In [None]:
filename = (path_to_data_dir 
            + "/for_demo_of_distortion_correction_in_saed_data"
            + "/distorted_cbed_pattern.npy")

kwargs = {"file": filename}
distorted_cbed_pattern_image = np.load(**kwargs)

kwargs = {"data": distorted_cbed_pattern_image}
distorted_cbed_pattern_signal = hs.signals.Signal2D(**kwargs)



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_cbed_pattern_signal.plot(**kwargs)

Next, we apply a mask that blocks everything except most of the zero-order Laue
zone (ZOLZ) reflections:

In [None]:
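# Array shapes are ordered as (number of rows, number of columns).
N_y, N_x = distorted_cbed_pattern_signal.data.shape

# Widths, in pixels, of the left, right, bottom, and top borders to mask.
L = 70
R = N_x-420
B = N_y-512
T = 110

# Note that this modifies the signal data in place.
distorted_cbed_pattern_signal.data[:, :L] = 0
distorted_cbed_pattern_signal.data[:, N_x-R:] = 0
distorted_cbed_pattern_signal.data[:T, :] = 0
distorted_cbed_pattern_signal.data[N_y-B:, :] = 0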



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
distorted_cbed_pattern_signal.plot(**kwargs)
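
The masking above, like the boolean mask used later for the SAED patterns, keeps only the rectangle left over after removing borders of widths ``L``, ``R``, ``B``, and ``T`` pixels from the left, right, bottom, and top of the image. The following is a minimal sketch of a reusable helper for this; ``border_mask`` is a hypothetical name and is not part of ``emicroml`` or HyperSpy:

```python
import numpy as np

def border_mask(shape, left, right, bottom, top):
    """Return a boolean mask of the given (num_rows, num_cols) shape that is
    True only inside the rectangle left over after removing borders of the
    given widths, in pixels, from each side of the image."""
    num_rows, num_cols = shape
    mask = np.zeros((num_rows, num_cols), dtype=bool)
    mask[top:num_rows-bottom, left:num_cols-right] = True
    return mask

# Small demonstration on a 10 x 12 image.
demo_mask = border_mask((10, 12), left=1, right=2, bottom=3, top=4)
print(demo_mask.astype(int))
```

With ``L``, ``R``, ``B``, and ``T`` as defined in the masking cell, ``data[~border_mask(data.shape, L, R, B, T)] = 0`` should zero out the same pixels as the four slicing assignments.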

We found that masking can improve the performance of our ML model. One possible
explanation is that at low beam energies and small CBED disk sizes, the Ewald
sphere curvature can be quite pronounced, and the small-angle approximation may
not hold across the entire angular field of view of a given CBED pattern. Both
effects may undermine our assumption that, in the absence of distortion, the
CBED pattern should depict only near-perfect circular CBED disks of a common
radius. These effects should matter mainly at larger scattering angles, which is
why we left most of the ZOLZ reflections unmasked.
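
To get a feel for why Ewald sphere curvature matters more at low beam energies, the sketch below computes the relativistic electron wavelength $\lambda$ and the height (sagitta) of the Ewald sphere above the ZOLZ plane, $1/\lambda - \sqrt{1/\lambda^2 - g^2} \approx \lambda g^2/2$, at the reciprocal distance $g$ of a $\{200\}$ reflection of Au. It assumes a lattice parameter of 4.078 Å for Au and compares 20 keV with 200 keV purely for illustration; it does not model the lenses or the detector, so it says nothing directly about pixel-level deviations.

```python
import numpy as np
import scipy.constants as const

def relativistic_electron_wavelength(beam_energy_in_eV):
    """Electron wavelength, in meters, including the relativistic correction."""
    E = beam_energy_in_eV * const.e        # Kinetic energy, in joules.
    rest_energy = const.m_e * const.c**2   # Rest energy, in joules.
    momentum = np.sqrt(2*const.m_e*E*(1 + E/(2*rest_energy)))
    return const.h / momentum

def ewald_sphere_sagitta(beam_energy_in_eV, g):
    """Height, in 1/m, of the Ewald sphere above the ZOLZ plane at a transverse
    reciprocal distance g (in 1/m), using the convention k = 1/wavelength."""
    k = 1 / relativistic_electron_wavelength(beam_energy_in_eV)
    return k - np.sqrt(k**2 - g**2)

a_Au = 4.078e-10  # Assumed lattice parameter of Au, in meters.
g_200 = 2 / a_Au  # Length of a {200} reciprocal lattice vector, in 1/m.

for energy in (20e3, 200e3):
    wavelength = relativistic_electron_wavelength(energy)
    sagitta = ewald_sphere_sagitta(energy, g_200)
    print(f"{energy/1e3:.0f} keV: wavelength = {wavelength*1e12:.3f} pm, "
          f"sagitta at g_200 = {sagitta*1e-9:.4f} 1/nm "
          f"({100*sagitta/g_200:.2f}% of g_200)")
```

Since the sagitta scales roughly linearly with the wavelength, it is about 3.4 times larger at 20 keV than at 200 keV for the same reflection, and it grows quadratically with $g$.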

## Estimating the distortion in the CBED pattern ##

Now let's load an ML model so that we can estimate the distortion in the CBED
pattern:

In [None]:
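path_to_ml_model_state_dicts = path_to_data_dir + "/ml_models/ml_model_1"

# Select the ML model state dictionary saved at the largest learning rate step.
# The step indices are compared as integers rather than strings, so that e.g.
# step 10 is ranked above step 9.
pattern = r"ml_model_at_lr_step_[0-9]+\.pth"
ml_model_state_dict_basenames = \
    [name
     for name in os.listdir(path_to_ml_model_state_dicts)
     if re.fullmatch(pattern, name)]
ml_model_state_dict_basename = \
    max(ml_model_state_dict_basenames,
        key=lambda name: int(name.split("_")[-1].split(".")[0]))

ml_model_state_dict_filename = \
    path_to_ml_model_state_dicts + "/" + ml_model_state_dict_basename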



module_alias = emicroml.modelling.cbed.distortion.estimation
kwargs = {"ml_model_state_dict_filename": ml_model_state_dict_filename,
          "device_name": None}  # Default to CUDA device if available.
ml_model = module_alias.load_ml_model_from_file(**kwargs)

_ = ml_model.eval()

With the ML model loaded, let's estimate the distortion in the CBED pattern:

In [None]:
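# Use the masked CBED pattern, i.e. the signal data modified in place above, as
# the input to the ML model.
distorted_cbed_pattern_images = distorted_cbed_pattern_signal.data[None, :, :]
sampling_grid_dims_in_pixels = distorted_cbed_pattern_images.shape[1:]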

kwargs = {"cbed_pattern_images": distorted_cbed_pattern_images,
          "sampling_grid_dims_in_pixels": sampling_grid_dims_in_pixels}
distortion_models = ml_model.predict_distortion_models(**kwargs)

distortion_model = distortion_models[0]

Note that any input distorted CBED pattern must be a square image whose
dimensions, in units of pixels, equal
``2*(ml_model.core_attrs["num_pixels_across_each_cbed_pattern"],)``. This is
because a given ML model is trained on images of fixed dimensions, in units of
pixels.
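
A small pre-flight check can catch a dimension mismatch before calling ``predict_distortion_models``. The helper below is hypothetical (it is not part of ``emicroml``) and only inspects array shapes:

```python
import numpy as np

def check_cbed_pattern_dims(cbed_pattern_images,
                            num_pixels_across_each_cbed_pattern):
    """Raise a ValueError unless cbed_pattern_images is a stack of shape
    (num_images, N, N) with N == num_pixels_across_each_cbed_pattern."""
    images = np.asarray(cbed_pattern_images)
    if images.ndim != 3:
        raise ValueError("Expected a stack of images of shape "
                         "(num_images, num_rows, num_cols).")
    expected = 2*(num_pixels_across_each_cbed_pattern,)
    if images.shape[1:] != expected:
        raise ValueError("Each CBED pattern must have dimensions {} in pixels, "
                         "but got {}.".format(expected, images.shape[1:]))
    return True
```

In the notebook, one would call it as ``check_cbed_pattern_dims(distorted_cbed_pattern_images, ml_model.core_attrs["num_pixels_across_each_cbed_pattern"])``.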

Let's visualize the predicted distortion field:

In [None]:
slice_step = 16



quiver_kwargs = {"angles": "uv",
                 "pivot": "middle",
                 "scale_units": "width"}



attr_name = "sampling_grid"
sampling_grid = getattr(distortion_model, attr_name)
# force=True copies the tensors to the CPU first, in case they live on a GPU.
sampling_grid = (sampling_grid[0].numpy(force=True),
                 sampling_grid[1].numpy(force=True))

X = sampling_grid[0][::slice_step, ::slice_step]
Y = sampling_grid[1][::slice_step, ::slice_step]



fig, ax = plt.subplots()

attr_name = "flow_field_of_coord_transform"
flow_field = getattr(distortion_model, attr_name)
# force=True copies the tensors to the CPU first, in case they live on a GPU.
flow_field = (flow_field[0].numpy(force=True),
              flow_field[1].numpy(force=True))

U = flow_field[0][::slice_step, ::slice_step]
V = flow_field[1][::slice_step, ::slice_step]

kwargs = quiver_kwargs
ax.quiver(X, Y, U, V, **kwargs)

title_font_size = 15

ax.set_title("Flow Field Of Coordinate Transformation", 
             fontsize=title_font_size)

axis_label_font_size = title_font_size
ax.set_xlabel("fractional horizontal coordinate", 
              fontsize=axis_label_font_size)
ax.set_ylabel("fractional vertical coordinate", 
              fontsize=axis_label_font_size)

for spatial_dim in ("x", "y"):
    major_tick_width = 1.5
    major_tick_length = 8
    minor_tick_width = major_tick_width
    minor_tick_length = major_tick_length//2
    tick_label_size = 15

    kwargs = {"axis": spatial_dim,
              "which": "major",
              "direction": "out",
              "left": True,
              "right": True, 
              "width": major_tick_width, 
              "length": major_tick_length, 
              "labelsize": tick_label_size}
    ax.tick_params(**kwargs)

    kwargs["which"] = "minor"
    kwargs["width"] = minor_tick_width
    kwargs["length"] = minor_tick_length
    ax.tick_params(**kwargs)

# Hide all ticks and tick labels, leaving only the axis labels and the title.
ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

for side in ['top','bottom','left','right']:
    ax.spines[side].set_linewidth(major_tick_width)

ax.set_aspect('equal')
plt.tight_layout()
plt.show()
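
A single number summarizing the strength of the predicted distortion can complement the quiver plot. The sketch below assumes, as the axis labels of the plot suggest, that the flow field is expressed in fractional coordinates in which 1 corresponds to the full image width; this unit convention is an assumption, and ``max_flow_field_displacement_in_pixels`` is a hypothetical helper, demonstrated here on synthetic data:

```python
import numpy as np

def max_flow_field_displacement_in_pixels(flow_field_x, flow_field_y,
                                          num_pixels_across):
    """Largest displacement magnitude of a flow field, converted from
    fractional coordinates (assumed: 1 = image width) to pixels."""
    magnitudes = np.hypot(flow_field_x, flow_field_y)
    return float(magnitudes.max() * num_pixels_across)

# Synthetic example: a uniform shift of (0.003, 0.004) in fractional units.
demo_flow_x = np.full((8, 8), 0.003)
demo_flow_y = np.full((8, 8), 0.004)
print(max_flow_field_displacement_in_pixels(demo_flow_x, demo_flow_y, 500))
```

In the notebook, one could pass ``flow_field[0]``, ``flow_field[1]``, and ``sampling_grid_dims_in_pixels[1]``.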

## Correcting the distortion in the SAED pattern ##

Let's use the predicted distortion model to correct the distortion in the SAED
pattern:

In [None]:
kwargs = \
    {"distorted_images": distorted_saed_pattern_image[None, None, :, :]}
undistorted_then_resampled_images = \
    distortion_model.undistort_then_resample_images(**kwargs)

undistorted_saed_pattern_image = \
    undistorted_then_resampled_images[0, 0].numpy(force=True)



kwargs = {"data": undistorted_saed_pattern_image}
undistorted_saed_pattern_signal = hs.signals.Signal2D(**kwargs)



kwargs = {"axes_off": True, 
          "scalebar": False, 
          "colorbar": False, 
          "gamma": 0.2,
          "cmap": "plasma", 
          "title": ""}
undistorted_saed_pattern_signal.plot(**kwargs)

## Assessing the accuracy of the distortion correction ##

We know that the sample is a calibration sample of single-crystal Au oriented in
the \[100\] direction. As such, in the absence of distortion, the ZOLZ
reflections should lie approximately on a square lattice. We say approximately
because, strictly speaking, the ZOLZ reflections furthest from the direct beam
should deviate from a square lattice by approximately 2 pixels, due to the
curvature of the Ewald sphere at a beam energy of 20 keV. Nevertheless, fitting
square lattices to the most visible ZOLZ reflections in both the distorted and
the undistorted SAED patterns, and comparing the errors of the fits, should be a
reasonable way to assess the accuracy of the distortion correction.
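
The expectation of a square lattice follows from the fcc structure of Au: kinematically, reflections $(h, k, l)$ are allowed only when $h$, $k$, and $l$ are all even or all odd. For the \[100\] zone axis, the ZOLZ contains the reflections with $h = 0$, so $k$ and $l$ must both be even, and the allowed reflections form a square lattice with spacing $2/a$, where $a$ is the lattice parameter. A minimal sketch that enumerates them (kinematic selection rule only):

```python
import itertools
import numpy as np

def fcc_reflection_is_allowed(h, k, l):
    """fcc selection rule: the structure factor is nonzero only if the Miller
    indices h, k, l are all even or all odd (zero counts as even)."""
    return len({h % 2, k % 2, l % 2}) == 1

# For the [100] zone axis, ZOLZ reflections (h, k, l) satisfy h = 0.
max_index = 4
demo_zolz_indices = [(0, k, l)
                     for k, l in itertools.product(range(-max_index, max_index+1),
                                                   repeat=2)
                     if fcc_reflection_is_allowed(0, k, l)]

# Distinct distances from the direct beam, in units of 1/a.
demo_g_lengths = sorted({float(np.hypot(k, l))
                         for _, k, l in demo_zolz_indices if (k, l) != (0, 0)})
print(len(demo_zolz_indices), demo_g_lengths[:2])
```

The two shortest distances, $2/a$ and $2\sqrt{2}/a$, are the side and the diagonal of the square unit cell of the ZOLZ lattice.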

The first step is to locate the ZOLZ reflections that are sufficiently visible
in the SAED patterns. We can do this by applying masks and peak-finding
algorithms. Let's define a function that locates the visible ZOLZ reflections:

In [None]:
def find_visible_zolz_reflections(saed_pattern_signal):
    # Specify a mask that only reveals the visible ZOLZ reflections.
    N_x, N_y = saed_pattern_signal.data.shape[::-1]

    L = 30
    R = N_x-345
    B = N_y-455
    T = 110

    rectangular_mask_image = np.zeros((N_y, N_x), dtype=bool)
    rectangular_mask_image[T:N_y-B, L:N_x-R] = True

    rectangular_mask_signal = hs.signals.Signal2D(data=rectangular_mask_image)



    # Apply the mask to the SAED pattern and then apply the Difference of 
    # Gaussian peak-finding method to find candidate peaks that may be located 
    # at visible ZOLZ reflection locations.
    masked_saed_pattern_signal = saed_pattern_signal*rectangular_mask_signal

    kwargs = {"method": "difference_of_gaussian", 
              "overlap": 0, 
              "threshold": 0.0025, 
              "min_sigma": 1,
              "max_sigma": 2,
              "interactive": False, 
              "show_progressbar": False}
    find_peaks_result = masked_saed_pattern_signal.find_peaks(**kwargs)
    candidate_peak_locations = find_peaks_result.data[0][:, ::-1]



    # Some ZOLZ reflections have satellite peaks that get picked up by the above
    # peak-finding algorithm, hence we need to remove them. 
    selection = tuple()
    num_candidate_peaks = len(candidate_peak_locations)
    ref_distance = None
    num_iterations = 2

    for iteration_idx in range(num_iterations):
        nearest_neighbour_distances = tuple()
        
        for candidate_peak_idx in range(num_candidate_peaks):
            candidate_peak_location = \
                candidate_peak_locations[candidate_peak_idx]

            displacements = candidate_peak_locations-candidate_peak_location
            distances = np.sort(np.linalg.norm(displacements, axis=(1,)))[1:]

            nearest_neighbour_distance = distances[0]
            nearest_neighbour_distances += (nearest_neighbour_distance,)

            if ref_distance is not None:
                if 2*nearest_neighbour_distance >= ref_distance:
                    selection += \
                        (candidate_peak_idx,)
                else:
                    lattice_spacing_estimate = \
                        distances[2*distances > ref_distance][0]

                    abs_diff = np.abs(lattice_spacing_estimate-ref_distance)
                    rel_diff = abs_diff / ref_distance
                    tol = 0.06
                
                    if rel_diff < tol:
                        selection += (candidate_peak_idx,)
                

        if ref_distance is None:
            nn_distances = np.array(nearest_neighbour_distances)

            outlier_threshold = 2
            outlier_registry = (np.abs(nn_distances - nn_distances.mean())
                                > outlier_threshold*nn_distances.std())

            ref_distance = nn_distances[~outlier_registry].mean()

    visible_zolz_reflections = \
        tuple(candidate_peak_locations[(selection,)].tolist())



    # Plot the SAED pattern with markers at the locations of the visible ZOLZ 
    # reflections.
    kwargs = {"axes_off": True, 
              "scalebar": False, 
              "colorbar": False, 
              "gamma": 0.2,
              "cmap": "plasma", 
              "title": ""}
    saed_pattern_signal.plot(**kwargs)

    for zolz_reflection in visible_zolz_reflections:
        kwargs = {"color": "black", 
                  "sizes": 3, 
                  "offsets": zolz_reflection}
        marker = hs.plot.markers.Points(**kwargs)
        saed_pattern_signal.add_marker(marker, permanent=False)

    return visible_zolz_reflections
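
The first pass of the satellite-removal loop above estimates a reference lattice spacing as the mean nearest-neighbour distance after rejecting distances more than two standard deviations from the mean. A vectorized sketch of that step, applied to hypothetical synthetic peaks (a 5 x 5 square lattice with a 20-pixel spacing plus one satellite peak), shows why the satellite does not bias the estimate; the helper names are hypothetical:

```python
import numpy as np

def nearest_neighbour_distances(points):
    """Distance from each point to its nearest other point."""
    points = np.asarray(points, dtype=float)
    displacements = points[:, None, :] - points[None, :, :]
    distances = np.linalg.norm(displacements, axis=-1)
    np.fill_diagonal(distances, np.inf)  # Ignore each point's self-distance.
    return distances.min(axis=1)

def reference_lattice_spacing(points, outlier_threshold=2):
    """Mean nearest-neighbour distance after discarding distances that lie more
    than outlier_threshold standard deviations from the mean."""
    nn_distances = nearest_neighbour_distances(points)
    is_outlier = (np.abs(nn_distances - nn_distances.mean())
                  > outlier_threshold*nn_distances.std())
    return nn_distances[~is_outlier].mean()

# Synthetic 5 x 5 square lattice with a 20-pixel spacing plus one satellite.
xs, ys = np.meshgrid(np.arange(5)*20.0, np.arange(5)*20.0)
demo_lattice = np.stack([xs.ravel(), ys.ravel()], axis=1)
demo_satellite = np.array([[-2.0, -2.0]])  # Close to the lattice point (0, 0).
demo_points = np.concatenate([demo_lattice, demo_satellite])
print(reference_lattice_spacing(demo_points))
```

The satellite and its host peak both have a nearest-neighbour distance of about 2.8 pixels, which is flagged as an outlier, so the estimate stays at the true spacing. The pairwise distance matrix needs memory quadratic in the number of peaks, which is fine for the few dozen peaks found here.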

Now let's locate the sufficiently visible ZOLZ reflections of the distorted SAED
pattern:

In [None]:
kwargs = \
    {"saed_pattern_signal": distorted_saed_pattern_signal}
zolz_reflection_selection_of_distorted_saed_pattern = \
    find_visible_zolz_reflections(**kwargs)

Now let's do the same for the undistorted SAED pattern:

In [None]:
kwargs = \
    {"saed_pattern_signal": undistorted_saed_pattern_signal}
zolz_reflection_selection_of_undistorted_saed_pattern = \
    find_visible_zolz_reflections(**kwargs)

Now we need to perform the fits. The objective function that we will minimize is
the root-mean-square (RMS) Euclidean distance between the ZOLZ reflections and
their nearest points on the square lattice fit, expressed in units of the image
width:

In [None]:
def objective(x, visible_zolz_reflections, N_x):
    u_O_x, u_O_y, b, theta = x

    # u_O_x: horizontal pixel coordinate of the origin of the square lattice fit.
    # u_O_y: vertical pixel coordinate of the origin of the square lattice fit.
    # b: length of each primitive lattice vector, in pixels.
    # theta: rotation angle applied to the lattice, in radians.
    # N_x: number of pixels across the SAED pattern.

    N = N_x

    result = 0.0

    for (k_x, k_y) in visible_zolz_reflections:
        to_round = ((k_x-u_O_x)*np.cos(theta) + (k_y-u_O_y)*np.sin(theta)) / b
        rounded = np.round(to_round)
        result += (to_round-rounded)**2

        to_round = (-(k_x-u_O_x)*np.sin(theta) + (k_y-u_O_y)*np.cos(theta)) / b
        rounded = np.round(to_round)
        result += (to_round-rounded)**2

    result *= ((b/N)**2) / len(visible_zolz_reflections)
    result = np.sqrt(result)

    return result
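
The loop in ``objective`` can also be written with NumPy broadcasting. The sketch below is an equivalent reformulation rather than the notebook's own code, checked on a hypothetical synthetic lattice: a perfect lattice gives an error of essentially zero, and shifting the origin by half a lattice vector along both directions gives $(b/N_x)\sqrt{1/2}$.

```python
import numpy as np

def vectorized_objective(x, visible_zolz_reflections, N_x):
    """RMS distance, in units of the image width N_x, between each reflection
    and the nearest node of the square lattice parameterized by x."""
    u_O_x, u_O_y, b, theta = x
    d = (np.asarray(visible_zolz_reflections, dtype=float)
         - np.array([u_O_x, u_O_y]))
    c, s = np.cos(theta), np.sin(theta)
    # Lattice coordinates of each reflection, in units of b.
    a = np.stack([(d[:, 0]*c + d[:, 1]*s) / b,
                  (-d[:, 0]*s + d[:, 1]*c) / b], axis=1)
    residuals = a - np.round(a)
    return (b / N_x) * np.sqrt(np.mean(np.sum(residuals**2, axis=1)))

# Synthetic lattice with origin (100, 120), spacing 30 px, rotation 0.2 rad.
demo_b, demo_theta = 30.0, 0.2
demo_u_O = np.array([100.0, 120.0])
demo_b_1 = demo_b*np.array([np.cos(demo_theta), np.sin(demo_theta)])
demo_b_2 = demo_b*np.array([-np.sin(demo_theta), np.cos(demo_theta)])
demo_reflections = np.array([demo_u_O + m_1*demo_b_1 + m_2*demo_b_2
                             for m_1 in range(-3, 4) for m_2 in range(-3, 4)])

demo_true_params = (*demo_u_O, demo_b, demo_theta)
demo_shifted_origin = demo_u_O + 0.5*demo_b_1 + 0.5*demo_b_2
demo_shifted_params = (*demo_shifted_origin, demo_b, demo_theta)
print(vectorized_objective(demo_true_params, demo_reflections, 500))
print(vectorized_objective(demo_shifted_params, demo_reflections, 500))
```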

We define the fitting error to be the final value of the objective function.
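
Written out, with $\mathbf{k}_i$ the pixel location of the $i$-th of the $n$ visible ZOLZ reflections and $\mathbf{u}_O = (u_{O;x}, u_{O;y})$ the origin of the lattice, the quantity computed by ``objective`` is

$$
\mathbf{a}_i = \frac{1}{b}
\begin{pmatrix} \cos\theta & \sin\theta \\ -\sin\theta & \cos\theta \end{pmatrix}
\left(\mathbf{k}_i - \mathbf{u}_O\right),
\qquad
\varepsilon = \frac{b}{N_x}
\sqrt{\frac{1}{n}\sum_{i=1}^{n}
\left\lVert \mathbf{a}_i - \operatorname{round}\left(\mathbf{a}_i\right)\right\rVert^2},
$$

where $\operatorname{round}$ acts componentwise, so that $\mathbf{a}_i - \operatorname{round}(\mathbf{a}_i)$ is the offset, in units of $b$, from $\mathbf{k}_i$ to the nearest lattice node.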

Next, let's define a function that performs the fit:

In [None]:
def fit_visible_zolz_reflections_to_square_lattice(visible_zolz_reflections,
                                                   saed_pattern_signal):
    visible_zolz_reflections = np.array(visible_zolz_reflections)
    N_x, N_y = saed_pattern_signal.data.shape[::-1]

    # Use the visible ZOLZ reflection closest to this reference pixel location
    # as the initial guess for the origin of the square lattice.
    ref_point = np.array((236, 242))
    distances_to_ref_point = np.linalg.norm(visible_zolz_reflections-ref_point,
                                            axis=1)
    u_O_guess = visible_zolz_reflections[np.argmin(distances_to_ref_point)]
    u_O_x_guess, u_O_y_guess = u_O_guess

    # Use the distance from the origin guess to its nearest neighbouring
    # reflection as the initial guess for the lattice spacing.
    displacements = visible_zolz_reflections-u_O_guess
    b_guess = np.sort(np.linalg.norm(displacements, axis=(1,)))[1]

    # Use the angle of a nearest neighbour lying between 45 and 90 degrees
    # from the horizontal axis as the initial guess for the lattice rotation.
    theta_guess = None
    for zolz_reflection in visible_zolz_reflections:
        displacement = zolz_reflection-u_O_guess
        distance = np.linalg.norm(displacement)
        if 1.1*b_guess > distance > 0:
            if displacement[1] > displacement[0] > 0:
                theta_guess = np.arctan2(displacement[1], 
                                         displacement[0])

    if theta_guess is None:
        raise ValueError("Could not determine an initial guess for the "
                         "rotation angle of the square lattice.")



    initial_guesses = (u_O_x_guess,
                       u_O_y_guess,
                       b_guess,
                       theta_guess)


    
    u_O_x_bounds = (0, N_x)
    u_O_y_bounds = (0, N_y)
    b_bounds = (0.5*b_guess, 1.5*b_guess)
    theta_bounds = (0.75*theta_guess, 1.25*theta_guess)

    bounds = (u_O_x_bounds,
              u_O_y_bounds,
              b_bounds,
              theta_bounds)

    

    kwargs = {"fun": objective,
              "args": (visible_zolz_reflections, N_x),
              "x0": initial_guesses,
              "bounds": bounds}
    minimization_result = scipy.optimize.minimize(**kwargs)



    u_O_x, u_O_y, b, theta = minimization_result.x

    u_O = np.array((u_O_x, u_O_y))
    b_1 = b*np.array((np.cos(theta), np.sin(theta)))
    b_2 = b*np.array((-np.sin(theta), np.cos(theta)))



    kwargs = {"axes_off": True, 
              "scalebar": False, 
              "colorbar": False, 
              "gamma": 0.2,
              "cmap": "plasma", 
              "title": ""}
    saed_pattern_signal.plot(**kwargs)



    M = 10

    for m_1 in range(-M, M+1):
        for m_2 in range(-M, M+1):
            lattice_position = (u_O + m_1*b_1 + m_2*b_2)
        
            displacements = visible_zolz_reflections-lattice_position
            distance = np.sort(np.linalg.norm(displacements, axis=(1,)))[0]            

            if 2*distance < b:
                kwargs = {"color": "black", 
                          "sizes": 3, 
                          "offsets": lattice_position.tolist()}
                marker = hs.plot.markers.Points(**kwargs)
                saed_pattern_signal.add_marker(marker, permanent=False)


    
    fitting_error = minimization_result.fun
    unformatted_msg = ("The error of the fit is: {}, "
                       "in units of the image width.")
    msg = unformatted_msg.format(fitting_error)
    print(msg)

    return fitting_error

Let's perform the fit for the distorted SAED pattern, using the above function:

In [None]:
kwargs = {"visible_zolz_reflections": \
          zolz_reflection_selection_of_distorted_saed_pattern, 
          "saed_pattern_signal": \
          distorted_saed_pattern_signal}
_ = fit_visible_zolz_reflections_to_square_lattice(**kwargs)

The black dots in the figure directly above form the best square lattice fit to
the sufficiently visible ZOLZ reflections.

Now let's do the same fitting procedure for the undistorted SAED pattern:

In [None]:
kwargs = {"visible_zolz_reflections": \
          zolz_reflection_selection_of_undistorted_saed_pattern, 
          "saed_pattern_signal": \
          undistorted_saed_pattern_signal}
_ = fit_visible_zolz_reflections_to_square_lattice(**kwargs)

As we can see, both visually and from the lattice-fit errors, our ML approach
corrects an appreciable amount of the distortion in the SAED pattern.