The basic function of this notebook is to take in a STEM image, run it through SingleOrigin, and then do (spatial) statistical analysis on it:
1. Read in a STEM image
2. Create a model of the structure in the image from a cif file
3. Fit the atom columns using SingleOrigin
4. Perform statistical analysis!

First, let's take care of steps 1 through 3 all together:

In [None]:
from importlib import reload
from pathlib import Path
from typing import Literal
import numpy as np
import tifffile as tif
import SingleOrigin as so
reload(so)
from copy import deepcopy
from abtem_scripts import graphical, analysis_backend
reload(graphical)
reload(analysis_backend)

In [None]:
image_path = Path(r"C:\Users\charles\Documents\AlScN\img\AlScN0.5.tif")
image_cropped = np.array(tif.imread(image_path))

## Uncomment the following two lines if the image is not pre-cropped
# image_cropped = qc.gui_crop(image_cropped)
# image_cropped = so.image_norm(image_cropped)

## Uncomment the following line to enable highpass filtering for the image (implemented in real space)
# _, image_cropped = divide_image_frequencies(image_cropped, s=350, show_images=True)

image_cropped = so.image_norm(image_cropped)

In [None]:
cif_path = graphical.gui_get_path()
uc = so.UnitCell(cif_path, origin_shift=[0, 0, 0])
uc.atoms.replace("Ti/O", "Ti/Zr", inplace=True)

za = [1, 1, 0]  # Zone axis direction
a1 = [-1, 1, 0]  # Apparent horizontal axis in projection
a2 = [0, 0, -1]  # Most vertical axis in projection

# Ignore light elements for HAADF
uc.project_zone_axis(za, a1, a2, ignore_elements=["N"])
uc.combine_prox_cols(toler=1e-2)

# Uncomment the following line to check this output if changing the u.c.
# uc.plot_unit_cell()

hr_img = so.HRImage(image_cropped)
lattice = hr_img.add_lattice("BZT", uc)

In [None]:
# NOTE: There are a couple of steps in this cell that require interaction and will time out if ignored

# If some FFT peaks are weak or absent (such as forbidden reflections),
#  specify the order of the first peak that is clearly visible
lattice.fft_get_basis_vect(a1_order=1, a2_order=1, sigma=2)

# lattice.get_roi_mask_std(r=15, buffer=20, thresh=0.25, show_mask=True)  # Alternative to directly cropping the image in cell 2
lattice.roi_mask = np.ones(image_cropped.shape)  # Use this line if the image was pre-cropped or cropped in cell 2

lattice.define_reference_lattice()

In [None]:
lattice.fit_atom_columns(buffer=0, local_thresh_factor=0, use_background_param=True,
                         use_bounds=True, use_circ_gauss=False, parallelize=True,
                         peak_grouping_filter=None)

# The sublattice used here must have exactly one column per projected unit cell.  If no sublattice
#  meets this criterion, specify a single column in the projected cell instead.
lattice.refine_reference_lattice(filter_by='elem', sites_to_use='Ba')

In [None]:
lattice.get_fitting_residuals()

In [None]:
hr_img.plot_atom_column_positions(scatter_kwargs_dict={"s": 20}, scalebar_len_nm=None,
                                  color_dict={"Ba": "#FF0060", "Ti/Zr": "#84BABA"},
                                  outlier_disp_cutoff=100, fit_or_ref="ref")

In [None]:
hr_img.plot_disp_vects(sites_to_plot=["Ti/Zr"], arrow_scale_factor=2)

Now we need to wrangle our data into a form that is digestible by (and sensible for) the statistical methods provided by pysal.  There is a *lot* going on here.  For the implementation details, check out the functions in `analysis_backend.py`.  Basically:
1. We're making a copy of the `lattice.at_cols` DataFrame and immediately dropping some columns we won't need; this gives us a more convenient object to work with
2. We need to throw out any (egregious) outliers: sometimes the fitting will leave one or two columns with crazy values for their intensity, and these will badly mess up our analysis, so we discard them
3. We then need to normalize the column intensities, so they fall in the range 0-1.  There are a variety of different ways we can do this, but we need to do it somehow
4. Then we want to discard columns which are irrelevant to our intensity statistics (i.e. columns corresponding to sites which have a different composition); if we compare, for example, Ti/Zr columns to Ba columns, we won't be able to learn anything meaningful about the Ti/Zr columns
5. Step 4 messes up our indexing, so we use a nearest-neighbors search to re-index the remaining columns
6. Next, we'll assign each column its neighborhood (as a mini adjacency list); this neighborhood is used for several things (adding boundary members to clusters when plotting, during the intensity normalization process for some methods, and when calculating dispersions from perfect sites), but importantly it is *not* the same as the neighborhood that is used to calculate the actual statistics
7. Finally, we will calculate the "dispersion" of fitted column positions from their perfect positions; these static shifts are our proxy for polarization
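
To make steps 2, 3, and 6 a bit more concrete, here is a minimal sketch on a toy table.  This is *not* the code in `analysis_backend.py`, which may work quite differently: the toy DataFrame, the `norm_int` column, the median-absolute-deviation cutoff, and the use of scipy's `cKDTree` are all assumptions made for illustration.

```python
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree

# Hypothetical stand-in for the fitted-column table: a 5x5 grid of columns.
# The column names are illustrative, not guaranteed to match lattice.at_cols.
gx, gy = np.meshgrid(np.arange(5.0), np.arange(5.0))
toy = pd.DataFrame({
    "x_fit": gx.ravel(),
    "y_fit": gy.ravel(),
    "total_col_int": 100.0 + 5.0 * np.sin(np.arange(25)),
})
toy.loc[12, "total_col_int"] = 1e4  # simulate one badly fitted column

# Step 2 (assumed approach): reject egregious outliers using the median
# absolute deviation (MAD), which is itself insensitive to the outlier.
med = toy["total_col_int"].median()
abs_dev = (toy["total_col_int"] - med).abs()
mad = abs_dev.median()
keep = abs_dev <= 10.0 * 1.4826 * mad
toy = toy[keep].reset_index(drop=True)

# Step 3 (one of several options): global min-max normalization to [0, 1].
lo, hi = toy["total_col_int"].min(), toy["total_col_int"].max()
toy["norm_int"] = (toy["total_col_int"] - lo) / (hi - lo)

# Step 6 (assumed approach): store each column's 8 nearest neighbors as a
# mini adjacency list of row indices, using a k-d tree for the search.
coords = toy[["x_fit", "y_fit"]].to_numpy()
tree = cKDTree(coords)
_, idx = tree.query(coords, k=9)  # k + 1, because the nearest hit is the point itself
toy["neighborhood"] = [row[1:].tolist() for row in idx]
```

Note that the neighbor indices refer to row positions *after* the outlier has been dropped and the index reset, which is why the neighborhood has to be rebuilt any time rows are removed.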

In [None]:
frame = deepcopy(lattice.at_cols)
frame.drop(["site_frac", "x", "y", "weight"], axis=1, inplace=True)  # We don't need these cols
frame.reset_index(drop=True, inplace=True)

analysis_backend.reject_outliers(frame, mode="total_col_int")
analysis_backend.normalize_intensity(frame, lattice.a_2d, n=8, method="global_minmax", kind="Ti/Zr")
analysis_backend.drop_elements(frame, ["Ba"])  # Doing this messes up the indexing in `neighborhood`
# So we'll rebuild the neighborhood in the new indexing
tree = analysis_backend.grow_tree(frame, lattice.a_2d, "Ti/Zr")
frame["neighborhood"] = frame.apply(lambda row: analysis_backend.get_near_neighbors(row, frame, lattice.a_2d, tree, n=8), axis=1)
analysis_backend.disp_calc(frame, normalize=True, kind="Ti/Zr")

Now we can actually do statistics:
1. In the first cell, set the parameters for the analysis to be performed
2. In the second cell, we'll run the analysis and print/plot the results
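
For intuition about what these statistics measure, here is a rough, plain-numpy sketch of a global Moran's I with queen adjacency and a permutation-based pseudo p-value.  It is only an illustration on a toy regular grid: it does not use pysal, it is not what `analysis_backend.get_stats` does internally, and the helper names (`queen_weights`, `morans_i`, `permutation_p`, `n_perm`) are made up for this example.

```python
import numpy as np

def queen_weights(nrows, ncols):
    """Row-standardized queen-contiguity weights for a regular grid."""
    n = nrows * ncols
    w = np.zeros((n, n))
    for r in range(nrows):
        for c in range(ncols):
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr, cc = r + dr, c + dc
                    inside = 0 <= rr < nrows and 0 <= cc < ncols
                    if (dr, dc) != (0, 0) and inside:
                        w[r * ncols + c, rr * ncols + cc] = 1.0
    return w / w.sum(axis=1, keepdims=True)

def morans_i(values, w):
    """Global Moran's I: spatially weighted covariance over total variance."""
    z = values - values.mean()
    return len(values) / w.sum() * (z @ w @ z) / (z @ z)

def permutation_p(values, w, n_perm=999, seed=0):
    """Observed I plus a one-sided pseudo p-value from random relabelings."""
    rng = np.random.default_rng(seed)
    observed = morans_i(values, w)
    sims = np.array([morans_i(rng.permutation(values), w) for _ in range(n_perm)])
    p_value = (np.sum(sims >= observed) + 1) / (n_perm + 1)
    return observed, p_value

# Toy data: a 6x6 grid whose left half is 0 and right half is 1,
# i.e. values that are strongly clustered in space.
w = queen_weights(6, 6)
clustered = np.tile(np.r_[np.zeros(3), np.ones(3)], 6)
i_obs, p_val = permutation_p(clustered, w)
print(f"Moran's I = {i_obs:.3f}, pseudo p-value = {p_val:.3f}")
```

A strongly positive I with a small pseudo p-value indicates that similar values sit next to each other more often than chance would suggest; the local variants assign a statistic like this to each column rather than to the image as a whole.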

In [None]:
kind: Literal["moran_global", "moran_local", "moran_global_bivariate", "moran_local_bivariate",
              "geary_global", "geary_local", "geary_local_multivariate"] = "geary_local"
adj_type: Literal["rook", "queen", "king", "bishop", "120"] = "queen"
columns: list[str] = ["dot_normal_disp"]
p: int = 10000
printstats: bool = True
sig: float = 0.05

In [None]:
sts = analysis_backend.get_stats(df=frame, adj_type=adj_type, a2d=lattice.a_2d, kind=kind,
                                 columns=columns, p=p, printstats=printstats)
match kind:
    case "moran_global" | "moran_global_bivariate":
        analysis_backend.add_stats_to_frame(frame, sts, "moran")
    case "moran_local" | "moran_local_bivariate":
        analysis_backend.add_stats_to_frame(frame, sts, "moran")
        analysis_backend.plot_moran_clusters(frame, image_cropped, sig=sig)
    case "geary_global":
        analysis_backend.add_stats_to_frame(frame, sts, "geary")
    case "geary_local" | "geary_local_multivariate":
        analysis_backend.add_stats_to_frame(frame, sts, "geary")
        analysis_backend.plot_geary_clusters(frame, image_cropped, sig=sig)
    case _:
        raise ValueError(f"Unrecognized analysis kind: {kind!r}")