# Template matching 2xxx SPED dataset

In [1]:
%matplotlib qt5

import numpy as np
import matplotlib.pyplot as plt
import hyperspy.api as hs
import pyxem as pxm

import diffpy
from diffsims.libraries.structure_library import StructureLibrary
from diffsims.generators.diffraction_generator import DiffractionGenerator
from diffsims.generators.library_generator import DiffractionLibraryGenerator
from diffsims.generators.rotation_list_generators import get_beam_directions_grid

from pyxem.utils import indexation_utils as iutls
from pyxem.utils import plotting_utils as putls
from pyxem.utils import polar_transform_utils as ptutls
from pyxem.utils import expt_utils as eutls
from pyxem.utils.plotting_utils import plot_template_over_pattern

import matplotlib.colors as mcolors
from orix.projections import StereographicProjection
from orix import plot, sampling
from orix.crystal_map import CrystalMap, Phase, PhaseList
from orix.quaternion import Orientation, Rotation, symmetry
from orix.vector import Vector3d, Miller
from orix.io import load, save
from orix.projections import StereographicProjection
from orix.vector.vector3d import Vector3d




In [2]:
file = r'D:\Template Matching\Elisabeth paper\SPED_600x600x12_10x10_4p63x4p63_1deg_100Hz_CL12cm_NBD_alpha5_spot1p3_preprocessed.hspy'

experimental_data = hs.load(file)
experimental_data.set_signal_type('electron_diffraction')

In [3]:
#experimental_data = experimental_data.inav[350:450, 100:200]

In [4]:
#experimental_data.plot(cmap='magma_r', norm='symlog')

## Important to get correct maximum_excitation_error s

In [5]:
minimum_intensity = 1E-20 # min intensity included in simulations
shift = 0.001             # shift used in transform func for template matching

In [6]:
s_Al    = 0.07
s_Th100 = 0.22
s_Th001 = 0.016
s_T1    = 0.022
s_list = [s_Al, s_Th100, s_Th001, s_T1]


# 1) Get rotations to simulate DPs at and create structure library of phases

### Importing structures and giving appropriate phase names

In [7]:
# Importing structures
structure_Al   = diffpy.structure.loadStructure('Al_a5p04.cif')
structure_Theta= diffpy.structure.loadStructure('thetaprime.cif')
structure_T1   = diffpy.structure.loadStructure('T1-a_Al-4p04.cif')

# List of phase names:
phases         = ['Al', 'ThetaPrime100', 'ThetaPrime001', 'T1']

# List of structures:
structures = [structure_Al, structure_Theta, structure_Theta, structure_T1]

### A tilt limiting finction

In [8]:
def _new_structure_matrix_from_alignment(old_matrix, x=None, y=None, z=None):
    """Taken from orix v0.9, see
    https://github.com/pyxem/orix/blob/fb269b0456163aa3ac1f80498a9894c53953dccb/orix/crystal_map/phase_list.py#L794-L845.
    
    Explanation of why changing the structure matrix (base) to use orix with
    ReciPro, which uses another alignment, is needed:
    https://orix.readthedocs.io/en/stable/crystal_reference_frame.html.
    """
    if sum([i is None for i in [x, y, z]]) > 1:
        raise ValueError("At least two of x, y, z must be set.")

    # Old direct lattice base (row) vectors in Cartesian coordinates
    old_matrix = Vector3d(old_matrix)
    ad, bd, cd = old_matrix.unit

    # Old reciprocal lattice base vectors in cartesian coordinates
    ar = bd.cross(cd).unit
    br = cd.cross(ad).unit
    cr = ad.cross(bd).unit

    # New unit crystal base
    new_vectors = Vector3d.zero((3,))
    axes_mapping = {"a": ad, "b": bd, "c": cd, "a*": ar, "b*": br, "c*": cr}
    for i, al in enumerate([x, y, z]):
        if al in axes_mapping.keys():
            new_vectors[i] = axes_mapping[al]
    other_idx = {0: (1, 2), 1: (2, 0), 2: (0, 1)}
    for i in range(3):
        if np.isclose(new_vectors[i].norm, 0):
            other0, other1 = other_idx[i]
            new_vectors[i] = new_vectors[other0].cross(new_vectors[other1])

    # New crystal base
    new_matrix = new_vectors.dot(old_matrix.reshape(3, 1)).round(12)

    return new_matrix

#Defining a function that returns the tilt range of interest around a given zone axis for a given structure:
def get_tilt_range_around_zone(phase_object, zone_axis, euler_grid, max_tilt):
    '''
    phase_object : Orix phase object (Ex. Phase(name='T1', space_group=191, structure=T1_structure))
    zone_axis : The zone axis you want to rotate about
    euler_grid : the total rotation grid needed to describe whole crystal
    max_tilt : The maximum angle deviation from zone axis to be included
    '''
    # Input
    uvw = zone_axis
    phase = phase_object.deepcopy()
    
    # Change structure matrix (crystal lattice base)
    # NB! Use with care, since orix assumes another alignment, e1||a, e3||c*!
    lat = phase.structure.lattice
    new_base = _new_structure_matrix_from_alignment(lat.base, x="a*", z="c")
    lat.setLatBase(new_base)
        
    # Sample grid
    g_grid = Rotation.from_euler(np.deg2rad(euler_grid))

    # Rotate Zs (optic axis) into *cartesian* crystal coordinates
    rz = Vector3d.zvector()
    hz = g_grid * rz

    # Cartesian crystal coordinates in Miller indices <uvw> (unit cell coordinates)
    hz_miller = Miller(xyz=hz.data, phase=phase)
    hz_miller.coordinate_format = "uvw"

    # Rotate into fundamental sector defined in orix
    hz_miller = hz_miller.in_fundamental_sector()

    # Define zone axis and symmetrically equivalent directions
    za = Miller(uvw=uvw, phase=phase)
    za2 = za.symmetrise(unique=True)

    # Get disorientation angle (smallest misorientation angle under symmetry)
    hz_miller2 = hz_miller.reshape(hz_miller.size, 1)
    za2 = za2.reshape(1, za2.size)
    all_angles = za2.angle_with(hz_miller2)
    angles = all_angles.min(axis=1)

    # Get orientations within threshold
    mask = angles <= np.deg2rad(threshold)
    masked_euler_grid = euler_grid[mask]
    
    if False:
        fig = plt.figure()
        ax = fig.add_subplot(projection="ipf", symmetry=phase.point_group)
        ax.scatter(hz_miller, c=angles)
        ax.scatter(hz_miller[mask], c="w", ec="k")
        fig.tight_layout()
#        fig = hz_miller.scatter(
#        hemisphere="both",
#        c=angles,
#        axes_labels=["Xc", "Yc"],
#        return_figure=True,
#        figure_kwargs=dict(figsize=(15, 15)),
#        )
#        fig.tight_layout()
    
    return masked_euler_grid

### Getting rotation lists for the different symmetries

In [9]:
resolution = 0.5 # maximum angle in degrees between nearest templates. Pretty rough grid for speed.

cubic_grid = get_beam_directions_grid("cubic", resolution, mesh="spherified_cube_edge") #Al
hex_grid   = get_beam_directions_grid("hexagonal", resolution, mesh="spherified_cube_edge") #T1
tetra_grid = get_beam_directions_grid("tetragonal", resolution, mesh="spherified_cube_edge") #theta

print("Number of patterns \ncubic: ", cubic_grid.shape[0], '\ntetra:', tetra_grid.shape[0], '\nhex:', hex_grid.shape[0])

Number of patterns 
cubic:  4186 
tetra: 12376 
hex: 16209


  phi2 = sign * np.nan_to_num(np.arccos(x_comp / norm_proj))


### Now limit the rotations within a threshold

In [10]:
# Threshold value for tilts away from the possible precipitate orientations in degrees:
threshold = 3

### Al

In [11]:
# Create a phase object (orix) for a crystal structure and give the expected zone axis:
uvw = [0, 0, 1]
Al_phase = Phase(name='Al', space_group=225, structure=structure_Al)

# Limit tilt range to orientations around zone axis to the threshold value:
Al_grid_masked_001 = get_tilt_range_around_zone(Al_phase, uvw, cubic_grid, threshold)
Al_grid_masked_001.size//3

20

### Theta Prime

In [12]:
# Create a phase object (orix) for a crystal structure and give the expected zone axis:
uvw_001 = [0, 0, 1]
uvw_100 = [1, 0, 0]
ThetaPrime_phase = Phase(name='ThetaPrime', space_group=119, structure=structure_Theta)

# Limit tilt range to orientations around zone axis to the threshold value:
ThetaPrime_grid_masked_001 = get_tilt_range_around_zone(ThetaPrime_phase, uvw_001, tetra_grid, threshold)
ThetaPrime_grid_masked_100 = get_tilt_range_around_zone(ThetaPrime_phase, uvw_100, tetra_grid, threshold)
print(ThetaPrime_grid_masked_001.size//3)
print(ThetaPrime_grid_masked_100.size//3)

20
35


### T1

In [13]:
# Create a phase object (orix) for a crystal structure and give the expected zone axis:
uvw_041 = [0,-4,1] #*421
T1_phase = Phase(name='T1', space_group=191, structure=structure_T1)

# Limit tilt range to orientations around zone axis to the threshold value:
T1_grid_masked_041 = get_tilt_range_around_zone(T1_phase, uvw_041, hex_grid, threshold)
T1_grid_masked_041.size//3

147

## Create a structure library

In [14]:
# List of rotations:
rot_grids = [Al_grid_masked_001, ThetaPrime_grid_masked_100, ThetaPrime_grid_masked_001, T1_grid_masked_041]

# Collect all in a structure library used to simulate our diffraction patterns:
library_Al    = StructureLibrary([phases[0]], [structures[0]], [rot_grids[0]])
library_Th100 = StructureLibrary([phases[1]], [structures[1]], [rot_grids[1]])
library_Th001 = StructureLibrary([phases[2]], [structures[2]], [rot_grids[2]])
library_T1    = StructureLibrary([phases[3]], [structures[3]], [rot_grids[3]])
library_phases = [library_Al, library_Th100, library_Th001, library_T1]

# 2) Prepare microscope for simulations and create diffraction library

### "Turn on the microscope"

In [15]:
diff_gen = DiffractionGenerator(accelerating_voltage=200,
                                precession_angle=0,
                                shape_factor_model="linear",
                                #scattering_params=None,
                                minimum_intensity=minimum_intensity,
                                )

lib_gen = DiffractionLibraryGenerator(diff_gen)

### And then simulate library for given s with resolution as set earlier

In [16]:
# half size of the images
half_shape = (experimental_data.data.shape[-2]//2, experimental_data.data.shape[-1]//2)

diffraction_calibration = experimental_data.axes_manager[2].scale
# maximum radius in reciprocal space to calculate spot intensities for
reciprocal_radius = np.sqrt(half_shape[0]**2 + half_shape[1]**2)*diffraction_calibration
print(reciprocal_radius)
# Calculate the simulated DP library
from diffsims.libraries.diffraction_library import DiffractionLibrary

diff_lib_full = DiffractionLibrary()
for i, s in enumerate(s_list):
    # Calculate the simulated DP library
    diff_lib_full[phases[i]] = lib_gen.get_diffraction_library(library_phases[i],
                                            calibration=diffraction_calibration,
                                            reciprocal_radius=reciprocal_radius,
                                            half_shape=half_shape,
                                            with_direct_beam=False,
                                            max_excitation_error=s)[phases[i]] # As given earlier
diff_lib = diff_lib_full
diff_lib.keys() # This shows you which phases you have, can then access diff_lib['key'].keys() and so on...

1.723304078565359


                                                  

dict_keys(['Al', 'ThetaPrime100', 'ThetaPrime001', 'T1'])

### Can check random simulation (And see if ZOLZ included when s too low)

In [17]:
print(phases)
if(True):
    diff_lib[phases[0]]['simulations'][0].plot()
    diff_lib[phases[1]]['simulations'][0].plot()
    diff_lib[phases[2]]['simulations'][0].plot()
    diff_lib[phases[3]]['simulations'][0].plot()

['Al', 'ThetaPrime100', 'ThetaPrime001', 'T1']


# 3) Do template matching

In [18]:
# log shift function used in template matching, small shift amplifies weak reflection (but also noise...)
print(shift)
def log_shift(raw):
    ''''
    Parameters
    -------
    raw : np.array((128, 128)). The raw data.
    shift : float. Introduces a shift for the log. To account for pixels with 0 value.
    
    Returns
    -------
    log_shift : np.array((128,128)). The log of the raw data.
    '''
    
    log_shift = np.log10(raw+shift) - np.log10(shift)# - 0.1
    return log_shift

    # or we can do a gamma correction rather
def gamma_corr(image):
    copied = image.copy()
    copied = copied**0.5
    return copied

0.001


In [19]:
# Parameters for matching:
delta_r = 1                         # res in inkreasing k value 
delta_theta = 1                     # res in rotation basically, azimuthal angle
max_r = experimental_data.axes_manager[2].size//2  # Max k to check IN PIXELS, Default (None) is from center of DP to corner.
intensity_transform_function = log_shift           # Function to apply to both image and template intensities on an element by element basis prior to comparison. Note that the function is performed on the CPU.
find_direct_beam     = False        # Lets try first without doing this as its kind of a pre processing thing...
direct_beam_position = None         # I think that: if find_dir_beam True and this none it automatically finds it 
normalize_image      = True         # Normalize the images in the correlation coefficient calculation
normalize_templates  = True         # Normalize the templates in the correlation coefficient calculation
frac_keep = 1 
n_keep    = None                    # Not sure what i ought to have here... look into this
n_best    = 5                       # keeping 5 best matches

result, phasedict = iutls.index_dataset_with_template_rotation(experimental_data,
                                                                diff_lib,
                                                                n_best                       = n_best,
                                                                frac_keep                    = frac_keep,
                                                                n_keep                       = n_keep,
                                                                delta_r                      = delta_r,
                                                                delta_theta                  = delta_theta,
                                                                max_r                        = max_r,
                                                                intensity_transform_function = intensity_transform_function,
                                                                normalize_images             = normalize_image,
                                                                normalize_templates          = normalize_templates,
                                                                )

[########################################] | 100% Completed | 36min 49.4s


### Can now get a phase map

In [20]:
# Creating custom colormap for phases:
from matplotlib.colors import to_rgba
from matplotlib.colors import LinearSegmentedColormap

color_names = ['linen', 'darkorange', 'dodgerblue', 'forestgreen']
colors = [to_rgba(c) for c in color_names]

cmap = LinearSegmentedColormap.from_list('gt_cmap', colors, N=len(color_names))

# Making a dataset with the phase indexes found:
phase_map = hs.signals.BaseSignal(result['phase_index'][:,:,0])

# Plot phase map with our cmap above:
experimental_data.plot(navigator=phase_map, 
                       norm='symlog', 
                       cmap='magma', 
                       navigator_kwds=dict(colorbar=True, cmap=cmap)
                      )
# Print the phase correspinding to numbers in phase map:
print(phasedict)
print('\nUsed: \ns =', s, '\nshift =', shift, '\nmin intensity =', minimum_intensity)

{0: 'Al', 1: 'ThetaPrime100', 2: 'ThetaPrime001', 3: 'T1'}

Used: 
s = 0.022 
shift = 0.001 
min intensity = 1e-20


In [34]:
#phase_map = hs.load('phaseMap_Almasked_DoG_shift_w_subtr0p05_indivS.hdf5')
phase_map = hs.load('phaseMap_Almasked_DoG_shift_w_subtr0p05_indivS.hdf5')
ground_truth = hs.load('Ground_truth_all.hdf5')

diff = phase_map - ground_truth
diff = np.array(diff, dtype='bool')
plt.figure()
plt.imshow(diff, cmap='gray')
plt.axis('off')
#plt.savefig('Difference_map_peaks_s0p023.png')

diff_map = hs.signals.BaseSignal(diff)
#diff_map.save('Difference_map_S8.hdf5')

error = np.count_nonzero(np.abs((phase_map-ground_truth.data)))/(512*512)
print(f'{error:.2%}')

1.76%


In [36]:
fig, ax = plt.subplots()
ax.imshow(phase_map, cmap=cmap)
ax.axis('off')
#plt.savefig('phasemap_DoGmasked',  transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)
#plt.savefig(dir + filnavn, transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)

In [26]:
phase_map.save('phasemap_raw_S8.hdf5')
diff_map.save('diffmap_raw_S8.hdf5')

In [23]:
corr_map = hs.signals.BaseSignal(result['correlation'][:,:,0])

Al_mask = corr_map > 0.06

phase_map_corr = phase_map*Al_mask

corr_map.plot(cmap='viridis')
phase_map_corr.plot()

print(corr_map.data)

[[0.04909621 0.05034316 0.05080305 ... 0.04574449 0.04560146 0.04578072]
 [0.04998801 0.05046719 0.05094641 ... 0.04545602 0.04543443 0.04579739]
 [0.05006605 0.05055881 0.05057929 ... 0.04577555 0.04572659 0.04567333]
 ...
 [0.04997284 0.05030674 0.05037801 ... 0.0510587  0.05092442 0.05104286]
 [0.0495874  0.04990137 0.05042325 ... 0.05116148 0.05116171 0.05101952]
 [0.04958849 0.04997083 0.05022321 ... 0.05114369 0.05126114 0.05126593]]


# Can plot the matched sim on DP

In [None]:
px = 25
py = 60
n_sol = 0

solution = result["orientation"]
# Query the necessary info from the solution
sim_sol_index = result["template_index"][py, px, n_sol]   # Collects (py,px) in solution map n
mirrored_sol = result["mirrored_template"][py, px, n_sol] # mirrored template represents whether the original template best fits (False) or the mirror image (True)
in_plane_angle = solution[py, px, n_sol, 0] # Orientations, the first angle is the in plane angle
#in_plane_angle = np.deg2rad(in_plane_angle)
print('Orientation of (', px,',', py, ')','is', solution[py, px, n_sol])
print('in_plane_angle =', in_plane_angle)


# Query the appropriate template:

found_phase = phases[result['phase_index'][py, px, n_sol]]
print(found_phase)

used_sim = diff_lib[found_phase]["simulations"][sim_sol_index]  # Collects the simulation fitted from template match
used_sim.plot()
print(diff_lib[found_phase]["orientations"][sim_sol_index])
fig, ax = plt.subplots(ncols=2, figsize=(8, 4))

# plotting phase map and the point we chose
ax[0].imshow(phase_map, cmap=cmap)
ax[0].scatter([px], [py], c="r", marker='x', s=200)

def log_norm(raw):
    raw.data *= 1 / raw.data.max()
    return np.log10(raw)

# LOG NORMED PATTERN:
def log_norm(raw):
    return np.log10(raw)
pattern = experimental_data.inav[px,py]
pattern.data *= 1/pattern.data.max()
pattern_log = pattern.map(log_norm, inplace = False)
plt.imshow(pattern_log, cmap='Greys')

# plotting the diffraction pattern and template
plot_template_over_pattern(pattern_log.data,   # DP at px, py
                                 used_sim,                       # simulated pattern found to fit best for DP at px, py
                                 ax=ax[1],                       # Optional, to get DP in same fig as correlation map
                                 in_plane_angle = in_plane_angle,#
                                 coordinate_system = "cartesian", 
                                 size_factor = 10,               # This only changes size of plotted X for DP spot marking, not real change
                                 #vmax=0.01,                      # Binned dataset has very high intensities... should fix this...
                                 mirrored_template=mirrored_sol,
                                 find_direct_beam=False,         # Gets better when True, should improve pre-process centering...
                                 cmap = "Greys",
                                 marker_color = "green",
                                 #direct_beam_position = (63,63)
                                )
for i in ax:
    i.axis("off")

Orientation of ( 25 , 60 ) is [960.   0.  90.]
in_plane_angle = 960.0
ThetaPrime001
[ 0.  0. 90.]


  return np.log10(raw)


[########################################] | 100% Completed |  0.1s


  return np.log10(raw)
