In [None]:
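# This notebook develops a function (originally planned as mask_stack(); it is
# called as sax.mask_movie() below) that takes a motion-corrected micrograph,
# finds the lattice (diffraction) spots in it, and then masks those spots in
# each frame of the corresponding raw movie.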

# Inputs:
# - the parameters accepted by sax.mask_image()
# - the patch-motion-corrected, dose-weighted micrograph
# - the raw movie
#
# Outputs:
# - the raw movie with each frame masked

In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.fft as sfft
from skimage.exposure import rescale_intensity

import saxtal_functions as sax
from funcs_mrcio import iwrhdr_opened, irdhdr_opened, iwrsec_opened, irdsec_opened

In [2]:
# Input files and processing parameters shared by the cells below
movie_filename, micrograph_filename = "test_movie.mrc", "test_micrograph.mrc"
threads = 16                     # worker count handed to scipy.fft (workers=)
replace_distance_percent = 0.05  # handed to sax.replace_diffraction_spots


In [3]:
# Run the complete masking pipeline end to end. The thread count and the
# replacement distance come from the configuration cell instead of repeated
# literals, so editing the configuration cell actually takes effect here.
sax.mask_movie(movie_filename,
               micrograph_filename,
               threshold_method='sd',
               verbose=True,
               threads=threads,
               sigma=9,
               num_sd=4.0,
               x_window_percent=(0.005, 0.7),
               y_window_percent=(0.005, 0.995),
               mask_hotpixels=True,
               replace_distance_percent=replace_distance_percent,
               return_spots=True)

scipy_batch_fft(): FFT performed in 1073.83 milliseconds.
scipy_batch_fft(): 35.79 milliseconds per frame.
scipy_fft(): FFT performed in 47.13 milliseconds.
Number of diffraction spots found: 5771
Removing hot pixels...


100%|█████████████████████████████████████| 5771/5771 [00:01<00:00, 3319.88it/s]


5225 hot pixels removed.
scipy_inverse_batch_fft(): iFFT performed in 860.25 milliseconds.
scipy_inverse_batch_fft(): 28.68 milliseconds per frame.
Export complete!
test_movie.mrc masked successfully!


In [None]:
(movie, header) = sax.import_movie(movie_filename)  # raw movie stack + its header

In [None]:
# Zero-pad every frame to a square whose side is the larger of the first two
# dimensions, then take a real 2D FFT over axes 0 and 1 of the whole stack.
padded_size = max(movie.shape[0], movie.shape[1])

# perf_counter is the monotonic, high-resolution clock intended for timing
start_time = time.perf_counter()

# NOTE: overwrite_x=True allows scipy to reuse movie's buffer, so the contents
# of `movie` may be destroyed; later cells in this notebook only read movie.shape.
movie_FFT = sfft.rfftn(movie,
                       s=(padded_size, padded_size),
                       axes=(0, 1),
                       overwrite_x=True,
                       workers=threads)

end_time = time.perf_counter()

elapsed_ms = (end_time - start_time) * 1000
print("scipy_fft(): FFT performed in", np.round(elapsed_ms, 2), "milliseconds.")
print("scipy_fft():", np.round(elapsed_ms / movie.shape[2], 2), "milliseconds per image.")

In [None]:
# Look at the 1st, 8th, and 15th frame of the unmasked movie FFT as log10
# amplitude. Only the top-left 500 x 500 corner is shown; no fftshift is
# applied, so this corner contains the zero-frequency term.
for frame_index in (0, 7, 14):
    if frame_index >= movie_FFT.shape[2]:
        continue  # the movie has fewer frames than this index
    # Crop before abs/log10 so only the displayed corner is computed
    fft_corner = movie_FFT[:500, :500, frame_index]
    plt.matshow(np.log10(np.abs(fft_corner)),
                cmap="Greys_r",
                vmax=3,
                vmin=1)
    plt.title(f"Frame {frame_index + 1}: log10 |FFT|")

In [None]:
# Find diffraction spots by running mask_image on the motion-corrected
# micrograph. The thread count and replacement distance come from the
# configuration cell so they stay in sync with the rest of the notebook.
diffraction_spots = sax.mask_image(micrograph_filename,
                                   threshold_method='sd',
                                   verbose=False,
                                   threads=threads,
                                   sigma=9,
                                   num_sd=4.0,
                                   x_window_percent=(0.005, 0.7),
                                   y_window_percent=(0.005, 0.995),
                                   mask_hotpixels=True,
                                   replace_distance_percent=replace_distance_percent,
                                   return_spots=True)

In [None]:
# Pre-allocate the buffer for the masked FFT stack: same shape as movie_FFT,
# always single-precision complex. np.empty leaves it uninitialized, so the
# masking loop is expected to fill every frame.
masked_movie_FFT = np.empty(shape=movie_FFT.shape, dtype="complex64")

In [None]:
# Replace the diffraction spots in every frame's FFT.
# Iterate over the FFT stack's own frame axis (instead of header['nz']) so every
# slice of the uninitialized np.empty output buffer is guaranteed to be written.
assert masked_movie_FFT.shape == movie_FFT.shape, "output buffer shape mismatch"

for z in range(movie_FFT.shape[2]):
    masked_movie_FFT[:, :, z] = sax.replace_diffraction_spots(movie_FFT[:, :, z],
                                                              diffraction_spots,
                                                              replace_distance_percent)

In [None]:
# Look at the 1st, 8th, and 15th frame of the masked movie FFT as log10
# amplitude. Only the top-left 200 x 200 corner is shown; no fftshift is
# applied, so this corner contains the zero-frequency term.
for frame_index in (0, 7, 14):
    if frame_index >= masked_movie_FFT.shape[2]:
        continue  # the movie has fewer frames than this index
    # Crop before abs/log10 so only the displayed corner is computed
    masked_corner = masked_movie_FFT[:200, :200, frame_index]
    plt.matshow(np.log10(np.abs(masked_corner)),
                cmap="Greys_r",
                vmax=3,
                vmin=1)
    plt.title(f"Frame {frame_index + 1} (masked): log10 |FFT|")

In [None]:
# Perform the inverse transform of the masked movie.
#
# The forward rfftn zero-padded each frame to an N x N square and, being a real
# FFT, kept only N//2 + 1 entries along its last transformed axis (axis 1);
# axis 0 still has the full length N. Pass `s` explicitly. Without it, irfftn
# assumes an even length 2*(m-1) for that axis and returns N - 1 columns when N
# is odd.
padded_size = masked_movie_FFT.shape[0]

# perf_counter is the monotonic, high-resolution clock intended for timing
start_time = time.perf_counter()

masked_movie = sfft.irfftn(masked_movie_FFT,
                           s=(padded_size, padded_size),
                           axes=(0, 1),
                           overwrite_x=True,
                           workers=threads)

end_time = time.perf_counter()

elapsed_ms = (end_time - start_time) * 1000
print("scipy_inverse_batch_fft(): iFFT performed in", np.round(elapsed_ms, 2), "milliseconds.")
print("scipy_inverse_batch_fft():", np.round(elapsed_ms / movie.shape[2], 2), "milliseconds per image.")

In [None]:
# Crop away the zero padding from the forward FFT: keep the leading
# movie.shape[k] entries along every axis k.
unpadded_masked_movie = masked_movie[tuple(slice(extent) for extent in movie.shape)]

In [None]:
# Generate a new filename. Strip the extension with os.path.splitext instead of
# assuming a 4-character suffix like ".mrc", so e.g. ".mrcs" inputs also work.
new_movie_filename = "masked_output/" + os.path.splitext(movie_filename)[0] + "_masked.mrc"

# open(..., 'wb') does not create missing directories, so make sure the
# output directory exists before the file is written.
os.makedirs(os.path.dirname(new_movie_filename), exist_ok=True)

# Generate the new header values: dimensions plus min / max / mean intensity.
# NOTE(review): nx is taken from axis 0 of the array - confirm this matches the
# axis order that iwrhdr_opened / iwrsec_opened expect.
nx, ny, nz = unpadded_masked_movie.shape
nxyz = np.array([nx, ny, nz], dtype=np.float32)
dmin = np.min(unpadded_masked_movie)
dmax = np.max(unpadded_masked_movie)
dmean = np.sum(unpadded_masked_movie)/(nx*ny*nz)

In [None]:
# Open the output file, write the header and then the masked frames.
# The `with` block guarantees the file is flushed and closed, even if a write
# raises. Re-running this cell rewrites the file from scratch instead of
# appending a second header at the current file offset.
with open(new_movie_filename, 'wb') as masked_movie_mrc:
    # Write the header to the new file (mode=2 as in the original)
    iwrhdr_opened(masked_movie_mrc,
                  nxyz,
                  dmin,
                  dmax,
                  dmean,
                  mode=2)

    # Write the masked, unpadded movie to the new file
    iwrsec_opened(unpadded_masked_movie, masked_movie_mrc)