# Introduction

This notebook demonstrates two approaches to nanocrystal segmentation:
1. Virtual dark-field (VDF) imaging-based segmentation
2. Non-negative matrix factorisation (NMF)-based segmentation

The segmentation is demonstrated on a scanning precession electron diffraction (SPED) dataset of partly overlapping MgO nanoparticles, some of which share the same orientation. The SPED data are available from [1]. An article explaining the methods and discussing the results is under review.

[1] T Bergh. (2019) *Scanning precession electron diffraction data of partly overlapping magnesium oxide nanoparticles.* doi: 10.5281/zenodo.3382874.

This functionality was introduced in pyxem-0.10.0 (November 2019) and has been checked to run. Bugs are always possible, so do not trust the code blindly; if you experience any issues, please report them here: https://github.com/pyxem/pyxem-demos/issues

# Contents

1. <a href='#gen'> Setting up, Loading Data, Pre-processing</a>
2. <a href='#vdf'> Virtual Image Based Segmentation</a>
3. <a href='#nmf'> NMF Based Segmentation</a>

# <a id='gen'></a> 1. Setting up, Loading Data, Pre-processing

Import pyxem and other required libraries

In [None]:
%matplotlib qt
import numpy as np
import hyperspy.api as hs
import matplotlib.pyplot as plt
import pyxem as pxm

Load demonstration data

In [None]:
s = pxm.load_hspy('MgO2_16_TX5_c_bin2,2.hdf5',
                  lazy=False,
                  assign_to='electron_diffraction2d')

Plot data to inspect

In [None]:
s.plot(cmap='magma_r')

Remove the background

In [None]:
sigma_min = 1.7
sigma_max = 13.2

s_rb = s.remove_background('gaussian_difference',
                           sigma_min=sigma_min,
                           sigma_max=sigma_max)
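The `'gaussian_difference'` background removal subtracts a strongly blurred copy of each pattern (width `sigma_max`) from a lightly blurred copy (width `sigma_min`): the narrow blur keeps the diffraction spots while the wide blur estimates the smooth diffuse background. The NumPy sketch below illustrates the idea on a single 2-D pattern. It is not pyxem's implementation; the 4-sigma kernel truncation and the edge padding are assumptions made for the sketch.

```python
import numpy as np


def gaussian_kernel_1d(sigma):
    """Normalised 1-D Gaussian kernel, truncated at 4 sigma (an assumption)."""
    radius = int(np.ceil(4 * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def gaussian_blur(image, sigma):
    """Separable Gaussian blur; edge padding keeps the output shape equal to the input."""
    kernel = gaussian_kernel_1d(sigma)
    r = len(kernel) // 2
    padded = np.pad(image.astype(float), r, mode='edge')

    def conv(v):
        return np.convolve(v, kernel, mode='valid')

    rows = np.apply_along_axis(conv, 1, padded)  # blur each row
    return np.apply_along_axis(conv, 0, rows)    # then each column


def difference_of_gaussians(image, sigma_min, sigma_max):
    """Narrow blur (keeps the spots) minus wide blur (estimates the background)."""
    return gaussian_blur(image, sigma_min) - gaussian_blur(image, sigma_max)


# Usage: one bright spot on a flat background of 5 counts
pattern = np.full((64, 64), 5.0)
pattern[32, 32] += 100.0
filtered = difference_of_gaussians(pattern, sigma_min=1.7, sigma_max=13.2)
```

The flat background cancels, while the spot survives as a positive peak at its original position.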

Find the position of the direct beam in the background-subtracted data.

In [None]:
shifts = s_rb.get_direct_beam_position(method='cross_correlate',
                                       square_width=15,
                                       radius_start=2,
                                       radius_finish=6)
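The cross-correlation idea can be sketched as follows: correlate the pattern with a disk-shaped template centred in the pattern and take the position of the correlation maximum as the offset of the direct beam. This simplified NumPy version uses a single, fixed disk radius and a circular FFT correlation. It does not reproduce how pyxem uses `square_width`, `radius_start` and `radius_finish`, and the sign convention of pyxem's returned shifts is not assumed to match.

```python
import numpy as np


def disk(shape, centre, radius):
    """Binary disk (1 inside, 0 outside) of `radius` pixels at centre = (row, col)."""
    rows, cols = np.indices(shape)
    inside = (rows - centre[0]) ** 2 + (cols - centre[1]) ** 2 <= radius ** 2
    return inside.astype(float)


def direct_beam_offset(pattern, radius):
    """Offset (d_row, d_col) of a disk-like direct beam from the pattern centre,
    taken at the maximum of the circular (FFT-based) cross-correlation with a
    centred disk template."""
    h, w = pattern.shape
    template = disk(pattern.shape, (h // 2, w // 2), radius)
    xcorr = np.fft.ifft2(np.fft.fft2(pattern) * np.conj(np.fft.fft2(template))).real
    d_row, d_col = np.unravel_index(np.argmax(xcorr), xcorr.shape)
    # Circular lags: map lags in [h/2, h) onto negative offsets
    if d_row >= h // 2:
        d_row -= h
    if d_col >= w // 2:
        d_col -= w
    return int(d_row), int(d_col)


# Usage: a beam displaced by (+3, -2) pixels from the centre of a 64 x 64 pattern
pattern = 0.1 + 10 * disk((64, 64), (32 + 3, 32 - 2), radius=4)
print(direct_beam_offset(pattern, radius=4))  # (3, -2)
```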

Visualise the direct beam positions

In [None]:
plt.figure()
plt.imshow(shifts.data.reshape(s.data.shape[0],s.data.shape[1],2)[..., 0])
plt.figure()
plt.imshow(shifts.data.reshape(s.data.shape[0],s.data.shape[1],2)[..., 1])

Apply direct beam shifts to the original dataset to align.

In [None]:
s.align2D(shifts=shifts.data, fill_value=0, crop=True)
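Alignment translates every pattern so that its direct beam sits at the centre, filling the uncovered pixels with `fill_value`. The minimal integer-pixel sketch below shows only that translation; the helper `shift_with_fill` is hypothetical, ignores `crop`, and does not reproduce any sub-pixel interpolation that `align2D` may perform.

```python
import numpy as np


def shift_with_fill(pattern, d_row, d_col, fill_value=0):
    """Translate a 2-D pattern by whole pixels (positive = down/right);
    pixels uncovered by the shift are set to fill_value."""
    h, w = pattern.shape
    out = np.full_like(pattern, fill_value)
    if abs(d_row) >= h or abs(d_col) >= w:
        return out  # shifted completely out of the frame
    # Matching source/destination slices for the region that stays in frame
    src_r = slice(max(0, -d_row), h - max(0, d_row))
    dst_r = slice(max(0, d_row), h - max(0, -d_row))
    src_c = slice(max(0, -d_col), w - max(0, d_col))
    dst_c = slice(max(0, d_col), w - max(0, -d_col))
    out[dst_r, dst_c] = pattern[src_r, src_c]
    return out


# Usage: a beam found 1 pixel below and 1 pixel left of the centre (3, 3) of a 7 x 7 pattern
p = np.zeros((7, 7))
p[4, 2] = 1.0
centred = shift_with_fill(p, -1, +1)  # shift by minus the measured offset
print(centred[3, 3])  # 1.0
```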

Set calibrations

In [None]:
scale = 0.03246
scale_real = 2.56
s.set_diffraction_calibration(scale)
s.set_scan_calibration(scale_real)

s_rb.set_diffraction_calibration(scale)
s_rb.set_scan_calibration(scale_real)

# <a id='vdf'></a> 2. Virtual Image Based Segmentation

Find the diffraction peaks in all PED patterns.
The parameters were found by interactive peak finding:

`peaks = s_rb.find_peaks_interactive(imshow_kwargs={'cmap': 'magma_r'})`

In [None]:
peaks = s_rb.find_peaks(method='laplacian_of_gaussians', 
                        min_sigma=0.7,
                        max_sigma=10,
                        num_sigma=30, 
                        threshold=0.046, 
                        overlap=0.5, 
                        log_scale=False,
                        exclude_border=True)

Visualise the number of diffraction peaks found per probe position.

In [None]:
diff_map = peaks.get_diffracting_pixels_map()
diff_map.plot()

Refine the peak positions

In [None]:
from pyxem.generators.subpixelrefinement_generator import SubpixelrefinementGenerator
from pyxem.signals.diffraction_vectors import DiffractionVectors

Padding is used so that peaks at the edges of the diffraction patterns are treated correctly.

In [None]:
refine_gen = SubpixelrefinementGenerator(dp=s_rb,
                                         vectors=peaks,
                                         padding=12)
peaks_refined = DiffractionVectors(
    refine_gen.center_of_mass_method(square_size=4))
peaks_refined.axes_manager.set_signal_dimension(0)
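Centre-of-mass refinement replaces each integer peak position with the intensity-weighted centroid of a small window around it; zero padding keeps windows at the pattern edge in bounds, which is the role of `padding` above. The sketch below works on one pattern. How pyxem maps `square_size` onto the window is not reproduced, so `half_width` is a hypothetical parameter of this sketch.

```python
import numpy as np


def refine_centre_of_mass(pattern, peaks, half_width=2, padding=12):
    """Refine integer peak positions (row, col) to the intensity-weighted centroid
    of a (2*half_width + 1)^2 window. Zero padding lets windows overhang the edge."""
    if padding < half_width:
        raise ValueError("padding must be at least half_width")
    padded = np.pad(pattern.astype(float), padding, mode='constant')
    offsets = np.arange(-half_width, half_width + 1)
    refined = []
    for row, col in peaks:
        r, c = int(row) + padding, int(col) + padding
        window = padded[r - half_width:r + half_width + 1,
                        c - half_width:c + half_width + 1]
        total = window.sum()
        if total == 0:  # no intensity to refine against: keep the input position
            refined.append((float(row), float(col)))
            continue
        d_row = (window.sum(axis=1) * offsets).sum() / total
        d_col = (window.sum(axis=0) * offsets).sum() / total
        refined.append((row + d_row, col + d_col))
    return np.array(refined)


# Usage: a peak spread over two pixels, and a peak in the corner of the pattern
pattern = np.zeros((16, 16))
pattern[5, 5] = pattern[5, 6] = 1.0
pattern[0, 0] = 3.0
refined = refine_centre_of_mass(pattern, [(5, 5), (0, 0)])
```

The first peak moves half a pixel to the right; the corner peak is handled without an indexing error thanks to the padding.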

Find the unique diffraction peaks by clustering

In [None]:
distance_threshold = scale*0.89
min_samples = 10

unique_peaks = peaks_refined.get_unique_vectors(
    method='DBSCAN',
    distance_threshold=distance_threshold,
    min_samples=min_samples)
print(np.shape(unique_peaks.data)[0], ' unique vectors were found.')
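DBSCAN groups points that have at least `min_samples` neighbours within a distance `eps`. Applied to the refined peak positions gathered from all probe positions, each dense cluster then corresponds to one unique diffraction vector, with `distance_threshold` presumably playing the role of `eps`. The from-scratch DBSCAN below illustrates this; taking each cluster's mean as its unique vector is an assumption of the sketch, not necessarily pyxem's choice.

```python
import numpy as np


def dbscan(points, eps, min_samples):
    """Minimal DBSCAN: returns a cluster label per point, -1 marks noise.
    A point's neighbourhood includes the point itself."""
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    neighbours = [np.flatnonzero(row <= eps) for row in dist]
    is_core = np.array([len(nb) >= min_samples for nb in neighbours])
    labels = np.full(len(points), -1)
    cluster = 0
    for seed in range(len(points)):
        if labels[seed] != -1 or not is_core[seed]:
            continue
        labels[seed] = cluster
        stack = [seed]
        while stack:  # grow the cluster outwards from core points
            p = stack.pop()
            if not is_core[p]:
                continue  # border points join a cluster but do not expand it
            for q in neighbours[p]:
                if labels[q] == -1:
                    labels[q] = cluster
                    stack.append(q)
        cluster += 1
    return labels


def unique_vectors(points, eps, min_samples):
    """One representative vector (the cluster mean) per DBSCAN cluster."""
    labels = dbscan(points, eps, min_samples)
    return np.array([points[labels == k].mean(axis=0) for k in range(labels.max() + 1)])


# Usage: two peaks each seen at 12 probe positions, plus one peak seen only once
rng = np.random.default_rng(0)
vectors = np.vstack([rng.normal(0.0, 0.005, (12, 2)),
                     rng.normal(0.5, 0.005, (12, 2)),
                     [[1.2, -0.8]]])
centres = unique_vectors(vectors, eps=0.05, min_samples=10)
```

The isolated vector is labelled as noise and therefore produces no unique vector.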

Visualise the detected unique peaks by plotting them on the maximum of the signal over all probe positions (`s_rb.max()`).

In [None]:
radius_px = s_rb.axes_manager.signal_shape[0]/2
reciprocal_radius = radius_px * scale

In [None]:
unique_peaks.plot_diffraction_vectors(
    method='DBSCAN',
    unique_vectors=unique_peaks,
    distance_threshold=distance_threshold,
    xlim=reciprocal_radius,
    ylim=reciprocal_radius,
    min_samples=min_samples,
    image_to_plot_on=s_rb.max(),
    image_cmap='magma_r',
    plot_label_colors=False)

Visualise both the clusters and the unique peaks obtained after DBSCAN clustering. 

*NB: The cluster colours are randomly generated, so rerun the cell if two nearby clusters are hard to tell apart.*

In [None]:
peaks_refined.plot_diffraction_vectors(
    method='DBSCAN',
    xlim=reciprocal_radius, 
    ylim=reciprocal_radius,
    unique_vectors=unique_peaks, 
    distance_threshold=distance_threshold,
    min_samples=min_samples, 
    image_to_plot_on=s_rb.max(), 
    image_cmap='gray_r',
    plot_label_colors=True, 
    distance_threshold_all=scale*0.1)

Filter the unique vectors by magnitude to exclude the direct beam from the subsequent analysis.

In [None]:
gmags = unique_peaks.get_magnitudes()
# Set the magnitudes of vectors within 10 pixels of the centre (the direct beam) to zero...
gmags.data[gmags.data<10*scale] = 0
# ...and keep only the vectors with a non-zero magnitude
Gs = unique_peaks.data[np.where(gmags.data)]
Gs = pxm.DiffractionVectors(Gs)
print(np.shape(Gs.data)[0], ' unique vectors.')
Gs.axes_manager.set_signal_dimension(0)

Plot the unique vectors

In [None]:
Gs.plot_diffraction_vectors(unique_vectors=Gs,
                            distance_threshold=distance_threshold,
                            xlim=reciprocal_radius,
                            ylim=reciprocal_radius,
                            min_samples=min_samples,
                            image_to_plot_on=s_rb.max(),
                            image_cmap='magma',
                            plot_label_colors=False)

Optionally, save the unique peaks with `np.save('peaks.npy', Gs.data)` and load them again with `Gs = pxm.DiffractionVectors(np.load('peaks.npy', allow_pickle=True))` followed by `Gs.axes_manager.set_signal_dimension(0)`.

### Calculate VDF images for all unique peaks

In [None]:
from pyxem.generators.vdf_generator import VDFGenerator

In [None]:
# Radius of the virtual aperture used for the VDF images (2 pixels)
radius = scale*2

vdfgen = VDFGenerator(s_rb, Gs)
VDFs = vdfgen.get_vector_vdf_images(radius=radius)
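A virtual dark-field image is formed by integrating, at every probe position, the diffracted intensity inside a small virtual aperture of radius `radius` centred on one diffraction vector. The NumPy sketch below does this for a 4-D array, assuming the vector is given as (gx, gy) in calibrated units relative to the pattern centre; pyxem's exact conventions may differ.

```python
import numpy as np


def vdf_image(data, g, radius, scale):
    """Virtual dark-field image: for every probe position, sum the intensity
    inside a disk of `radius` centred on the vector g = (gx, gy).
    data has shape (scan_y, scan_x, det_y, det_x); the pattern centre is the
    pixel (det_y // 2, det_x // 2) and `scale` is the calibration per pixel."""
    det_y, det_x = data.shape[-2:]
    ky = (np.arange(det_y) - det_y // 2) * scale
    kx = (np.arange(det_x) - det_x // 2) * scale
    KY, KX = np.meshgrid(ky, kx, indexing='ij')
    aperture = (KX - g[0]) ** 2 + (KY - g[1]) ** 2 <= radius ** 2
    return (data * aperture).sum(axis=(-2, -1))


# Usage: a 3 x 4 scan of 16 x 16 patterns with scale = 1 for readability
data = np.zeros((3, 4, 16, 16))
data[:, :, 8, 8] = 100.0   # direct beam at every probe position
data[1, 2, 5, 10] = 7.0    # one diffracted spot at g = (2, -3), probe position (1, 2)
img = vdf_image(data, g=(2.0, -3.0), radius=1.0, scale=1.0)
```

Only probe position (1, 2) lights up; the direct beam lies outside the virtual aperture.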

In [None]:
VDFs.plot(cmap='magma', scalebar=False)

## 1(b) Watershed segmentation

First, find suitable parameters by inspecting the watershed segmentation of a single VDF image.

In [None]:
from pyxem.utils.segment_utils import separate_watershed

In [None]:
min_distance = 5.5
min_size = 10
max_size = None
max_number_of_grains = np.inf
marker_radius = 2
exclude_border = 2

In [None]:
i = 27
sep_i = separate_watershed(
    VDFs.inav[i].data, min_distance=min_distance, min_size=min_size,
    max_size=max_size, max_number_of_grains=max_number_of_grains,
    exclude_border=exclude_border, marker_radius=marker_radius,
    threshold=True, plot_on=True)

Perform segmentation on all the VDF images

In [None]:
segs = VDFs.get_vdf_segments(min_distance=min_distance,
                             min_size=min_size,
                             max_size=max_size,
                             max_number_of_grains=max_number_of_grains,
                             exclude_border=exclude_border,
                             marker_radius=marker_radius,
                             threshold=True)
print(np.shape(segs.segments)[0],' segments were found.')

In [None]:
segs.segments.plot(cmap='magma_r')

## 1(c) Correlation of the VDF image segments

Calculate normalised cross-correlations between the VDF image segments to identify those that are related to the same crystal. Segments whose mutual correlation exceeds *corr_threshold* are summed. A summed segment is discarded if it was formed from fewer than *vector_threshold* segments, since this number corresponds to the number of detected diffraction peaks associated with the single crystal. The *vector_threshold* criterion thus removes segment images that result from noise or incorrect segmentation.
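The correlation step can be sketched with the zero-lag normalised cross-correlation (the Pearson coefficient of two images) and a simple greedy grouping. This is an illustration under stated assumptions: the grouping order is a choice made for the sketch, `segment_threshold` is not modelled, and pyxem's `correlate_vdf_segments` may differ in detail.

```python
import numpy as np


def norm_corr(a, b):
    """Zero-lag normalised cross-correlation of two images (Pearson coefficient)."""
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a ** 2).sum() * (b ** 2).sum())
    return float((a * b).sum() / denom) if denom > 0 else 0.0


def correlate_and_sum(segments, corr_threshold, vector_threshold):
    """Greedy grouping: the first unused segment collects every unused segment
    correlated with it above corr_threshold; groups with fewer than
    vector_threshold members are discarded, the others are summed."""
    unused = list(range(len(segments)))
    summed = []
    while unused:
        seed = unused[0]
        group = [j for j in unused
                 if j == seed or norm_corr(segments[seed], segments[j]) > corr_threshold]
        unused = [j for j in unused if j not in group]
        if len(group) >= vector_threshold:
            summed.append(segments[group].sum(axis=0))
    return np.array(summed)


# Usage: crystal A is seen in three segments, crystal B in only two
a = np.zeros((8, 8)); a[1:3, 1:3] = 1.0
b = np.zeros((8, 8)); b[5:7, 4:7] = 1.0
segments = np.array([a, 2 * a, 3 * a, b, 2 * b])
result = correlate_and_sum(segments, corr_threshold=0.7, vector_threshold=3)
```

With `vector_threshold=3`, only crystal A survives, as the sum of its three segments.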

In [None]:
corr_threshold=0.7
vector_threshold=5
segment_threshold=4

In [None]:
corrsegs = segs.correlate_vdf_segments(
    corr_threshold=corr_threshold, vector_threshold=vector_threshold,
    segment_threshold=segment_threshold)
print(np.shape(corrsegs.segments)[0],' correlated segments were found.')

Simulate virtual diffraction patterns for each summed segment

In [None]:
sigma = scale*1.5

virtual_sig = corrsegs.get_virtual_electron_diffraction(
    calibration=scale, shape=(int(radius_px*2), int(radius_px*2)), sigma=sigma)
virtual_sig.set_diffraction_calibration(scale)
#hs.plot.plot_signals([corrsegs.segments, virtual_sig], cmap='magma_r')
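The virtual pattern for a summed segment can be thought of as Gaussian spots of width `sigma` placed at the diffraction vectors that contributed to it. In this sketch every spot has unit height and the vectors are given as (gx, gy) relative to the pattern centre; whether pyxem weights the spots (for example by segment intensity) is not assumed.

```python
import numpy as np


def virtual_pattern(vectors, shape, calibration, sigma):
    """Sum of unit-height 2-D Gaussians of width sigma (calibrated units), one per
    diffraction vector (gx, gy); the origin is the pixel (rows // 2, cols // 2)."""
    rows, cols = shape
    ky = (np.arange(rows) - rows // 2) * calibration
    kx = (np.arange(cols) - cols // 2) * calibration
    KY, KX = np.meshgrid(ky, kx, indexing='ij')
    pattern = np.zeros(shape)
    for gx, gy in vectors:
        pattern += np.exp(-((KX - gx) ** 2 + (KY - gy) ** 2) / (2 * sigma ** 2))
    return pattern


# Usage: two spots, 10 pixels right of and 5 pixels above the centre
calib = 0.03246  # same value as `scale` above
vectors = [(10 * calib, 0.0), (0.0, -5 * calib)]
vp = virtual_pattern(vectors, shape=(64, 64), calibration=calib, sigma=1.5 * calib)
```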

Plot the final results from the VDF image-based segmentation

In [None]:
hs.plot.plot_images(corrsegs.segments, cmap='magma_r', axes_decor='off',
                    per_row=np.shape(corrsegs.segments)[0],
                    suptitle='', scalebar=False, scalebar_color='white',
                    colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})
hs.plot.plot_images(virtual_sig, cmap='magma_r', axes_decor='off',
                    per_row=np.shape(corrsegs.segments)[0],
                    suptitle='', scalebar=False, scalebar_color='white',
                    colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right': 0.78})

In [None]:
# Free memory: the background-subtracted data are not used in the NMF analysis
s_rb = None

# <a id='nmf'></a> 3. NMF Based Segmentation

For the NMF-based segmentation, the required pre-processing (binning and alignment) was already done at the start of the notebook.

#### Create a signal mask for the direct beam
Create a signal mask so that the region in the centre of each PED pattern, including the direct beam, is excluded from the decomposition.

In [None]:
sm = pxm.Diffraction2D(s.inav[0,0])
signal_mask = sm.get_direct_beam_mask(radius=10)
signal_mask.plot()

#### Perform singular value decomposition (SVD)

In [None]:
s.change_dtype('float32')
s.decomposition(algorithm='svd',
                normalize_poissonian_noise=True,
                centre='variables',
                signal_mask=signal_mask.data)
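The options above combine three steps: masking out detector pixels, Poissonian noise normalisation and SVD. The sketch below shows them on a data matrix with one row per probe position and one column per detector pixel. Two assumptions: masked pixels are marked `True` (following hyperspy's convention that `signal_mask` marks pixels not used in the decomposition), and the normalisation is the Keenan and Kotula weighting, i.e. division by the square roots of the row and column sums. Centring is omitted for simplicity. The returned explained-variance ratio is the quantity shown in a scree plot.

```python
import numpy as np


def centre_disk_mask(shape, radius):
    """Boolean mask, True inside a disk of `radius` pixels around the pattern centre."""
    rows, cols = np.indices(shape)
    return (rows - shape[0] // 2) ** 2 + (cols - shape[1] // 2) ** 2 <= radius ** 2


def explained_variance_ratio(data, signal_mask):
    """data: (n_probe_positions, n_detector_pixels), non-negative counts.
    signal_mask: flattened boolean mask, True = pixel excluded.
    Returns the fraction of the total variance carried by each SVD component."""
    X = data[:, ~signal_mask].astype(float)
    # Poissonian noise normalisation: divide by sqrt(row sums) and sqrt(column sums)
    row = np.sqrt(X.sum(axis=1))
    col = np.sqrt(X.sum(axis=0))
    row[row == 0] = 1.0
    col[col == 0] = 1.0
    X = X / row[:, None] / col[None, :]
    singular_values = np.linalg.svd(X, compute_uv=False)
    variance = singular_values ** 2
    return variance / variance.sum()


# Usage: noise-free data built from two non-negative components
rng = np.random.default_rng(1)
mask = centre_disk_mask((16, 16), radius=3).ravel()
factors_demo = rng.uniform(0.1, 1.0, (2, 256))  # two 16 x 16 component patterns
loadings_demo = rng.uniform(0.0, 1.0, (50, 2))  # their weights at 50 probe positions
evr = explained_variance_ratio(loadings_demo @ factors_demo, mask)
```

Because the data contain exactly two components, all of the variance sits in the first two entries of `evr`; with real, noisy data the scree plot instead shows a gradual tail.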

In [None]:
s.plot_decomposition_results()

#### Investigate the scree plot and use it as a guide to determine the number of components

In [None]:
num_comp=11

ax = s.plot_explained_variance_ratio(
    n=200, threshold=num_comp, hline=True, xaxis_labeling='ordinal',
    signal_fmt={'color':'k', 'marker':'.'}, 
    noise_fmt={'color':'gray', 'marker':'.'})

### NMF

In [None]:
s.decomposition(normalize_poissonian_noise=True,
                algorithm='nmf',
                output_dimension=num_comp,
                centre = 'variables',
                signal_mask=signal_mask.data)
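NMF factorises the non-negative data matrix X (probe positions x detector pixels) as X ≈ W H with W, H ≥ 0: the rows of H are the factors (component patterns) and the columns of W the loadings (maps). The minimal sketch below uses the classic Lee and Seung multiplicative updates for the Frobenius norm; hyperspy may use a different solver, initialisation and stopping rule.

```python
import numpy as np


def nmf_multiplicative(X, n_components, n_iter=200, seed=0, eps=1e-12):
    """Lee-Seung multiplicative updates minimising ||X - W H||_F with W, H >= 0.
    X: (n_probe_positions, n_pixels), non-negative.
    Returns W (loadings), H (factors) and the error after every iteration."""
    rng = np.random.default_rng(seed)
    W = rng.uniform(0.1, 1.0, (X.shape[0], n_components))
    H = rng.uniform(0.1, 1.0, (n_components, X.shape[1]))
    errors = []
    for _ in range(n_iter):
        H *= (W.T @ X) / (W.T @ W @ H + eps)  # update factors; stays non-negative
        W *= (X @ H.T) / (W @ H @ H.T + eps)  # update loadings; stays non-negative
        errors.append(np.linalg.norm(X - W @ H))
    return W, H, errors


# Usage: factorise non-negative rank-2 data
rng = np.random.default_rng(2)
X_demo = rng.uniform(0.0, 1.0, (40, 2)) @ rng.uniform(0.0, 1.0, (2, 100))
W, H, errors = nmf_multiplicative(X_demo, n_components=2)
```

The multiplicative form keeps W and H non-negative without explicit clipping, and the reconstruction error does not increase from one iteration to the next.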

In [None]:
s_nmf = s.get_decomposition_model(components=np.arange(num_comp))
#s_nmf.plot_decomposition_results()
factors = s_nmf.get_decomposition_factors()
loadings = s_nmf.get_decomposition_loadings()

Plot the NMF results

In [None]:
hs.plot.plot_images(loadings, cmap='magma_r', axes_decor='off',
                    per_row=num_comp, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right': 0.78})
hs.plot.plot_images(factors, cmap='magma_r', axes_decor='off',
                    per_row=num_comp, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right': 0.78})

Discard the components related to the background (\#0) and to the carbon film (\#4).

In [None]:
from hyperspy.signals import Signal2D

In [None]:
# Remove component #0 (background) and #4 (carbon film)
factors = Signal2D(np.delete(factors.data, [0, 4], axis=0))
loadings = Signal2D(np.delete(loadings.data, [0, 4], axis=0))

In [None]:
hs.plot.plot_images(factors, cmap='magma_r', axes_decor='off',
                    per_row=9, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})

hs.plot.plot_images(loadings, cmap='magma_r', axes_decor='off',
                    per_row=9, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})

## 2(b) Correlation

NMF often splits a single crystal into several components. Therefore, the correlations between the loadings and between the component patterns (factors) are calculated, and when both the loading and the factor correlations exceed their respective thresholds, those loadings and factors are summed.
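A sketch of this correlate-and-sum step: two components are merged only if both their loadings and their factors correlate above the respective thresholds. Using `np.corrcoef` on the flattened images assumes the zero-lag normalised cross-correlation, and the transitive grouping (if A merges with B and B with C, all three are summed) is an assumption of this sketch, not necessarily what `correlate_learning_segments` does.

```python
import numpy as np


def merge_components(loadings, factors, th_loadings, th_factors):
    """Sum NMF components whose loadings AND factors both correlate above the
    thresholds. loadings: (k, scan_y, scan_x); factors: (k, det_y, det_x).
    Components are grouped transitively with a small union-find."""
    k = len(loadings)
    corr_l = np.corrcoef(loadings.reshape(k, -1))
    corr_f = np.corrcoef(factors.reshape(k, -1))
    parent = list(range(k))

    def find(i):
        while parent[i] != i:
            i = parent[i]
        return i

    for i in range(k):
        for j in range(i + 1, k):
            if corr_l[i, j] > th_loadings and corr_f[i, j] > th_factors:
                parent[find(j)] = find(i)

    groups = {}
    for i in range(k):
        groups.setdefault(find(i), []).append(i)
    members = [groups[root] for root in sorted(groups)]
    merged_l = np.array([loadings[m].sum(axis=0) for m in members])
    merged_f = np.array([factors[m].sum(axis=0) for m in members])
    return merged_l, merged_f


# Usage: components 0 and 1 describe the same crystal, component 2 another one
L_demo = np.zeros((3, 8, 8))
L_demo[0, 0:4, 0:4] = 1.0
L_demo[1, 1:5, 1:5] = 1.0
L_demo[2, 5:8, 5:8] = 1.0
F_demo = np.zeros((3, 8, 8))
F_demo[0, 2, 2] = F_demo[0, 2, 6] = 1.0
F_demo[1, 2, 2] = F_demo[1, 2, 6] = F_demo[1, 6, 2] = 1.0
F_demo[2, 5, 5] = F_demo[2, 0, 7] = 1.0
merged_loadings, merged_factors = merge_components(
    L_demo, F_demo, th_loadings=0.3, th_factors=0.45)
```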

#### Investigate the normalised cross-correlations
First, calculate the matrices of normalised cross-correlations for both the loadings and the factors to find suitable correlation threshold values.

In [None]:
from pyxem.utils.segment_utils import norm_cross_corr

num_comp = np.shape(loadings.data)[0]

corr_list_loadings = np.zeros((num_comp, num_comp))
for i in np.arange(num_comp):
    corr_list_loadings[i] = list(map(
        lambda x: norm_cross_corr(x, template=loadings.data[i]), loadings.data))

corr_list_factors = np.zeros((num_comp, num_comp))
for i in np.arange(num_comp):
    corr_list_factors[i] = list(map(
        lambda x: norm_cross_corr(x, template=factors.data[i]), factors.data))

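plt.figure()
plt.imshow(corr_list_factors, cmap='cool', vmin=corr_list_factors.min(), vmax=1.0)
plt.title('Factors: normalised cross-correlation')
plt.colorbar()

plt.figure()
plt.imshow(corr_list_loadings, cmap='cool', vmin=corr_list_loadings.min(), vmax=1.0)
plt.title('Loadings: normalised cross-correlation')
plt.colorbar()
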
In [None]:
from pyxem.signals.segments import LearningSegment

In [None]:
learn = LearningSegment(factors=factors, loadings=loadings)

In [None]:
corr_th_factors = 0.45
corr_th_loadings = 0.3

Perform correlation and summation of the factors and loadings

In [None]:
learn_corr = learn.correlate_learning_segments(
    corr_th_factors=corr_th_factors,
    corr_th_loadings=corr_th_loadings)

Plot the NMF results after correlation and summation

In [None]:
hs.plot.plot_images(learn_corr.loadings, cmap='magma_r', axes_decor='off',
                    per_row=7, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})
hs.plot.plot_images(learn_corr.factors, cmap='magma_r', axes_decor='off',
                    per_row=7, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})

## 2(c) Watershed segmentation

Since a single loading map can contain several crystals, watershed segmentation is performed on the correlated loadings.

First, investigate how the parameters influence the segmentation of a single loading map.

In [None]:
from pyxem.utils.segment_utils import separate_watershed

In [None]:
min_distance = 10
min_size = 50
max_size = None
max_number_of_grains = np.inf
marker_radius = 2
exclude_border = 1
threshold = True

In [None]:
i = 1
sep_i = separate_watershed(
    learn_corr.loadings.data[i], min_distance=min_distance,
    min_size=min_size, max_size=max_size,
    max_number_of_grains=max_number_of_grains,
    exclude_border=exclude_border,
    marker_radius=marker_radius, threshold=threshold, plot_on=True)

Set a threshold for the minimum intensity value that a loading segment must contain in order to be kept. 

In [None]:
min_intensity_threshold = 10000

In [None]:
learn_corr_seg = learn_corr.separate_learning_segments(
    min_intensity_threshold=min_intensity_threshold,
    min_distance=min_distance, min_size=min_size,
    max_size=max_size,
    max_number_of_grains=max_number_of_grains,
    exclude_border=exclude_border,
    marker_radius=marker_radius, threshold=threshold)

Plot the final results from the NMF-based segmentation

In [None]:
hs.plot.plot_images(learn_corr_seg.loadings, 
                    cmap='magma_r', axes_decor='off',
                    per_row=10, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})

hs.plot.plot_images(learn_corr_seg.factors, 
                    cmap='magma_r', axes_decor='off',
                    per_row=10, suptitle='', scalebar=False,
                    scalebar_color='white', colorbar=False,
                    padding={'top': 0.95, 'bottom': 0.05,
                             'left': 0.05, 'right':0.78})