In [1]:
1

1

In [2]:
%matplotlib widget
import numpy as np
from numpy.fft import fft2
import matplotlib.pyplot as plt
plt.rcParams['figure.max_open_warning'] = 2000
import time
from pathlib import Path
import h5py
from scipy.ndimage import fourier_shift
import logging

start_time = time.time()

from transforms import (
    prepare_correlation_data,
    correlate_images,
    correlation,
    translate,
    pad_images,
    set_shear_and_scale_ranges,
    set_transform_matrices,
    transform,
    transform_single_image,
    plot_transformed_images,
    normalise_max,
)

pad_factor = 1.3  # larger than 1, approx 1.25 is good
steps = 11  # odd, 5,7 is normal
correlation_method = "phase"  # "phase", "cross", "hybrid"
gpu = True  # True / False
shear_steps = steps
scale_steps = steps

test_dataset = "A"

if gpu:
    import tensorflow as tf

    tf.device("/gpu:0")

if test_dataset == "A":
    import hyperspy.api as hs

    def get_haadf(slist):
        "Helper function for picking out HAADF from velox format"
        for s in slist:
            if "HAADF" in s.metadata.General.title:
                return s

    folder = r"C:\Users\Me\Documents\STEM Images"
    signals = [
        get_haadf(hs.load(str(f))) for f in Path(folder).iterdir() if f.is_file()
    ]
    images = [s.data.astype("float32") for s in signals]
    angles = [
        float(s.original_metadata.Scan.ScanRotation) * 180 / np.pi for s in signals
    ]


elif test_dataset == "B":
    from scipy.misc import face, ascent
    from scipy.ndimage import rotate

    img1 = face(True)  # [:,:768]
    img1 = np.pad(img1, 200, "constant")
    SHIFT = [80, 90]
    ANGLE = 45
    img2 = np.fft.ifft2(fourier_shift(fft2(img1), SHIFT)).real
    img2 = rotate(img2, ANGLE, reshape=False)
    images = [img1, img2]
    angles = (0, ANGLE)

elif test_dataset == "C":
    import h5py

    file = (
        "../data_examples/nonlinear_drift_correction_synthetic_dataset_for_testing.mat"
    )
    f = h5py.File(file, mode="r")
    img1 = f["image00deg"][:]
    img2 = f["image90deg"][:]
    # img3 = f["imageIdeal"][:]
    images = [img1, img2]
    angles = [0, 90]

else:
    print("Specified wrong data?")

images = [normalise_max(img) for img in images]
print("Padding images")
padded_images, weights = pad_images(images, pad_factor=pad_factor)
GB = np.round(
    float(np.prod(np.shape(padded_images)))
    * shear_steps
    * scale_steps
    * padded_images[0].dtype.itemsize
    / 1e9,
    2,
)  # 32 bytes per complex64

print("Estimating memory usage of {}GB".format(GB))
# Set the various sheares, scales, first in terms of range
sheares, scales = set_shear_and_scale_ranges(
    padded_images[0].shape, shear_steps=shear_steps, scale_steps=scale_steps, pix=2
)
# Then in terms of transform matrices
print("Calculating transform matrices")
transform_matrices = set_transform_matrices(angles, sheares, scales)

rotation_matrices, shear_matrices, scale_matrices = transform_matrices

# Scale and shear the masked data
print("Transforming data")
data = transform(
    padded_images, rotation_matrices, shear_matrices, scale_matrices, weights=weights
)

data = data.astype("float32")

data = data.reshape((len(padded_images), -1) + padded_images[0].shape)
data = data.swapaxes(0, 1)

print("Preparing correlation data")
correlation_data = prepare_correlation_data(
    data, weights, method=correlation_method, gpu=gpu
)

print("Correlating")
max_indexes, shifts_list = correlate_images(
    correlation_data, method=correlation_method, gpu=gpu
)
# Plot masked images
print("Calculating final images")
data2 = np.array(data).swapaxes(0, 1)
image_sums = []

for i, img_array in enumerate(data2[1:]):
    max_index = int(max_indexes[i])
    img1 = data2[0, max_index]
    img2 = img_array[max_index]
    shift = shifts_list[i][max_index]
    img2_shifted = np.fft.ifftn(fourier_shift(fft2(img2), shift)).real
    image_sums.append(img1 + img2_shifted)
    print("Image drifted ({}, {}) pixels since first frame".format(shift[0], shift[1]))
# fig, AX = plt.subplots(ncols=len(padded_images) - 1, squeeze=False)
# for i, ax in enumerate(np.reshape(AX, np.prod(AX.shape))):
#     ax.imshow(image_sums[i], cmap="viridis")
#     ax.axis("off")

i1, i2 = plot_transformed_images(
    padded_images,
    angles,
    shifts_list,
    max_indexes,
    sheares,
    scales,
    shear_steps,
    scale_steps,
)

print("--- %s seconds ---" % (time.time() - start_time))

Padding images
Estimating memory usage of 0.43GB
Calculating transform matrices
Transforming data
Preparing correlation data
Correlating


HBox(children=(IntProgress(value=0, max=121), HTML(value='')))


Calculating final images
Image drifted (0.0, -1.0) pixels since first frame


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

--- 15.986728191375732 seconds ---


In [2]:
from tqdm.auto import tqdm

In [7]:
plt.figure()
plt.imshow(data2[1,0])
    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x23d4abc9dd8>

In [28]:
I1, I2 = i1.copy(), i2.copy()

In [29]:
fig, ax = plt.subplots()
row_indices = np.nonzero(i1.mean(axis=1))[0]

rollrange = np.arange(-5, 5, 1)
for row_index in tqdm(row_indices):
    diff = []
    row2 = I2[row_index]
    for roll in rollrange:
        temprow = I1[row_index].copy()
        temprow = np.roll(temprow, roll)
        
        nonzeromask = (temprow > 0) & (row2 > 0)
        
        meansquare = np.mean(np.abs((temprow[nonzeromask] - row2[nonzeromask])))
        diff.append(meansquare)
    
    diff_norm = diff/np.max(diff)
    ax.plot(rollrange, diff_norm)
    best_roll = rollrange[np.argmin(diff_norm)] if diff_norm.min() < 0.5 else 0
    I1[row_index] = np.roll(I1[row_index], best_roll)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(IntProgress(value=0, max=512), HTML(value='')))




In [30]:
fig, ax = plt.subplots()
col_indices = np.nonzero(i2.mean(axis=0))[0]

rollrange = np.arange(-5, 5, 1)
for col_index in tqdm(col_indices):
    diff = []
    col1 = I1[:,col_index]
    for roll in rollrange:
        tempcol = I2[:,col_index].copy()
        tempcol = np.roll(tempcol, roll)
        
        nonzeromask = (tempcol > 0) & (col1 > 0)

        meansquare = np.mean(np.abs((tempcol[nonzeromask] - col1[nonzeromask])))
        diff.append(meansquare)
    
    diff_norm = diff/np.max(diff)
    ax.plot(rollrange, diff_norm)
    best_roll = rollrange[np.argmin(diff_norm)] if diff_norm.min() < 0.5 else 0
    I2[:,col_index] = np.roll(I2[:,col_index], best_roll)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(IntProgress(value=0, max=664), HTML(value='')))

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)





In [27]:
fig, AX = plt.subplots(ncols=2, nrows=2)
ax1, ax2, ax3, ax4 = AX.flatten()
ax1.imshow(i1)
ax2.imshow(I1)
ax3.imshow(i2)
ax4.imshow(I2)

fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.imshow(i1)
ax2.imshow(I1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x1dfcda29588>

In [48]:
scale = signals[0].axes_manager[0].scale
units = signals[0].axes_manager[0].units

s = hs.signals.Signal2D(i1+i2)
for ax, ax_orig in zip(s.axes_manager._axes, signals[0].axes_manager._axes):
    ax.scale = ax_orig.scale
    ax.units = ax_orig.units
    ax_orig.name = ax_orig.name

In [50]:
def log(s):
     return np.log(s.fft(shift=True, apodization=True).amplitude)

In [51]:
hs.plot.plot_images([log(s), log(signals[0])])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.axes._subplots.AxesSubplot at 0x252615b5b38>,
 <matplotlib.axes._subplots.AxesSubplot at 0x252618f13c8>]

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x252632247b8>]

In [23]:
diff

[0.01034282262505044,
 0.010341032086210292,
 0.0103394161422629,
 0.010337268058566544,
 0.010334976755118878,
 0.010332069531845206,
 0.010329426676503014,
 0.010325663092799767,
 0.01032250077622829,
 0.010318944721737226,
 0.010315684938124766,
 0.01031335500782157,
 0.010311510306817732,
 0.010310401852147958,
 0.010310428965563707,
 0.010310437774838971,
 0.010310731262247326,
 0.010310650079380802,
 0.010309863096337265,
 0.010312525023922383]

In [7]:
A = np.array([[1,2,3], [4,5,6]])

In [8]:
A[0]

array([1, 2, 3])

In [6]:
np.roll(, 1, 1)

array([[3, 1, 2],
       [6, 4, 5]])

In [2]:
i1

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])