In [2]:
import saxtal_functions as sax
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from tqdm import tqdm

In [3]:
# Input micrograph and the path the masked image is written to
filename = 'test_data/test_lattices/000022735715047590394_SFO2_118_0019_patch_aligned_doseweighted_bin_2.mrc'
filename_out = filename.rsplit('.', 1)[0] + '_masked.mrc'

# General run options
pixel_size = 1.048  # presumably Angstrom per pixel -- TODO confirm units
verbose = False
show_plots = False
threads = 16

# Spectrum smoothing and spot thresholding
threshold_method = 'sd'
sigma = 15  # passed to sax.generate_diff_spectrum (Gaussian smoothing of the FFT)
num_sd = 3.0
num_sd_secondpass = 2.0
x_window_percent = (0, 0.7)
y_window_percent = (0, 0.995)

# Lattice search
miller_index_buffer = 2
box_radius = 10
min_lattice_size = 5

# Masking
mask_hotpixels = False  # hot-pixel filtering is deprecated -- leave off
mask_radius = 5
replace_distance_percent = 0.05
return_spots = True

In [3]:
# To test multiple lattices with verbose output and plots enabled (and return_spots=False), use these settings instead:

# filename = 'test_data/test_lattices/000022735715047590394_SFO2_118_0019_patch_aligned_doseweighted_bin_2.mrc'
# filename_out = 'test_data/test_lattices/000022735715047590394_SFO2_118_0019_patch_aligned_doseweighted_bin_2_masked.mrc'
# threshold_method='sd' 
# pixel_size = 1.048
# verbose=True
# show_plots=True
# threads=16
# sigma=15 
# num_sd=3.0
# num_sd_secondpass=2.0
# x_window_percent=(0, 0.7)
# y_window_percent=(0, 0.995)
# miller_index_buffer=2
# box_radius=10
# min_lattice_size=5
# mask_hotpixels=False
# mask_radius=5
# replace_distance_percent=0.05
# return_spots=False

In [4]:
# Run the full masking pipeline in a single call. Every tunable setting comes
# from the configuration cell; the return value is displayed as the cell output.
mask_image_options = dict(
    verbose=verbose,
    show_plots=show_plots,
    threads=threads,
    sigma=sigma,
    num_sd=num_sd,
    num_sd_secondpass=num_sd_secondpass,
    x_window_percent=x_window_percent,
    y_window_percent=y_window_percent,
    miller_index_buffer=miller_index_buffer,
    box_radius=box_radius,
    min_lattice_size=min_lattice_size,
    mask_hotpixels=mask_hotpixels,
    mask_radius=mask_radius,
    replace_distance_percent=replace_distance_percent,
    return_spots=return_spots,
)
sax.mask_image(filename, filename_out, threshold_method, pixel_size, **mask_image_options)

[5.73]


array([[4900, 4846, 5040, 5185, 5132, 5323, 5169, 5473, 5419, 5366, 5258,
        5204, 5616, 5562, 5508, 5706, 5652, 5545, 5437,   89,   35, 5742,
        5688, 5582, 5529, 5477,  232,  178,  125,   17, 5671,  324,  270,
         215,  161,  109,  465,  359,  303,  250,  144, 5693,  557,  341,
         287,   66,  699,  651,  539,  485,  432,  373,  268,  789,  774,
         719,  391],
       [ 224,  324,  182,  150,  239,  118,  385,   74,  167,  259,  439,
         531,   38,  129,  220,   92,  184,  366,  554,   53,  146,  238,
         330,  513,  606,  700,   16,  108,  200,  384,  567,   70,  167,
         254,  346,  438,   34,  216,  311,  400,  583,  955,   88,  455,
         539,  913,   52,  138,  327,  417,  510,  604,  791,  106,  343,
         436,  977]])

In [None]:
# Step-by-step version of the masking pipeline, keeping intermediate results
# (spectrum, spots, lattice) available for inspection in later cells.
# All settings come from the configuration cell.

# Import the image
image, header = sax.import_mrc(filename)

# Perform an FFT of the image
padded_fft = sax.scipy_fft(image, verbose, threads)

# Subtract the FFT from a Gaussian-smoothed FFT
log_diff_spectrum, smoothed_spectrum = sax.generate_diff_spectrum(padded_fft, sigma)

# Find diffraction spots. An if/elif/else chain ensures exactly one
# thresholding branch runs.
if threshold_method == "quantile":
    # NOTE(review): the configuration cell does not define `quantile`; set it
    # before selecting this method. Assumes saxtal_functions provides
    # find_diffraction_spots_quantile -- confirm.
    diffraction_indices, diffraction_amplitudes = sax.find_diffraction_spots_quantile(
        log_diff_spectrum, quantile, x_window_percent, y_window_percent
    )
elif threshold_method == "sd":
    diffraction_indices, diffraction_amplitudes = sax.find_diffraction_spots_sd(
        log_diff_spectrum, num_sd, x_window_percent, y_window_percent
    )
else:
    # Fail loudly: without spots, every later step would fail with a NameError.
    raise ValueError(
        f"Unknown threshold_method {threshold_method!r}. "
        "Please specify 'quantile' or 'sd' using the threshold_method parameter."
    )

# Return some info if function is verbose
if verbose:
    print("Number of first-pass spots found: " + str(diffraction_indices.shape[0]))

# NOTE(review): the steps below are presumably meant to run inside a while loop
# (find lattice -> mask -> re-detect); that loop is not implemented yet.

# Look for the lattice, using the same settings that are passed to mask_image
combined_nonredundant_lattice, unit_cell_dimensions, highest_resolution = sax.find_lattice(
    diffraction_indices,
    diffraction_amplitudes,
    log_diff_spectrum,
    pixel_size,
    show_plots=show_plots,
    verbose=verbose,
    num_sd_secondpass=num_sd_secondpass,
    miller_index_buffer=miller_index_buffer,
    box_radius=box_radius,
    min_lattice_size=min_lattice_size,
)

# Filter out hot pixels - deprecated, leave mask_hotpixels off
if mask_hotpixels:
    if verbose:
        print("Removing hot pixels...")
    # Record the count before filtering so the number removed can be reported.
    num_spots = combined_nonredundant_lattice.shape[0]
    # NOTE(review): remove_hotpixels is not defined in this notebook; assumes it
    # lives in saxtal_functions -- confirm.
    combined_nonredundant_lattice = combined_nonredundant_lattice[
        sax.remove_hotpixels(combined_nonredundant_lattice, verbose)
    ]
    if verbose:
        print(str(num_spots - combined_nonredundant_lattice.shape[0]) + " hot pixels removed.")


In [None]:
# Print one index pair per detected lattice spot. Transposing turns each column
# into a row -- assumes combined_nonredundant_lattice is a 2 x N array (TODO confirm).
for spot in np.transpose(combined_nonredundant_lattice):
    print(spot)

In [None]:
# Mask the detected lattice spots in the FFT: build the per-spot mask indices,
# then replace those FFT values (same call pattern as the step-by-step cells below).
# NOTE(review): replace_distance_percent is not passed here -- confirm whether
# sax.replace_diffraction_spots accepts it.
mask_indices_array = sax.generate_lattice_mask_indices(combined_nonredundant_lattice)
masked_fft = sax.replace_diffraction_spots(padded_fft, mask_indices_array)

In [None]:
# Restart the step-by-step pipeline from the raw image.
image, header = sax.import_mrc(filename)
# Thread count comes from the configuration cell.
padded_fft = sax.scipy_fft(image, verbose, threads)
log_diff_spectrum, smoothed_spectrum = sax.generate_diff_spectrum(padded_fft, sigma)

In [None]:
# First-pass spot detection on the unmasked spectrum (standard-deviation threshold)
diffraction_indices, diffraction_amplitudes = sax.find_diffraction_spots_sd(
    log_diff_spectrum,
    num_sd,
    x_window_percent,
    y_window_percent,
)

In [None]:
# Look for a lattice among the first-pass spots. Plots and verbose output stay
# on for interactive inspection; the numeric settings come from the configuration
# cell so they remain in sync with the mask_image run.
combined_nonredundant_lattice, unit_cell_dimensions, highest_resolution = sax.find_lattice(
    diffraction_indices,
    diffraction_amplitudes,
    log_diff_spectrum,
    pixel_size=pixel_size,
    show_plots=True,
    verbose=True,
    num_sd_secondpass=num_sd_secondpass,
    miller_index_buffer=miller_index_buffer,
    box_radius=box_radius,
    min_lattice_size=min_lattice_size,
)

print(unit_cell_dimensions, highest_resolution)

In [None]:
# Build the mask indices covering every spot of the detected lattice
mask_indices_array = sax.generate_lattice_mask_indices(
    combined_nonredundant_lattice
)

In [None]:
# Replace the FFT values at the lattice mask indices computed above
masked_fft = sax.replace_diffraction_spots(
    padded_fft,
    mask_indices_array,
)

In [None]:
# TODO: From here, calculate a new log_diff_spectrum from the masked FFT and rerun the pipeline.
# If the remaining lattice still has more than 5 points, decide how to handle it (remove it?).

In [None]:
# Recompute the difference spectrum, now from the masked FFT
log_diff_spectrum, smoothed_spectrum = sax.generate_diff_spectrum(
    masked_fft, sigma
)

In [None]:
# Re-detect spots on the spectrum of the masked FFT (standard-deviation threshold)
diffraction_indices, diffraction_amplitudes = sax.find_diffraction_spots_sd(
    log_diff_spectrum,
    num_sd,
    x_window_percent,
    y_window_percent,
)

In [None]:
# Search the masked spectrum for any remaining lattice. Plots and verbose output
# stay on for interactive inspection; the numeric settings come from the
# configuration cell so they remain in sync with the mask_image run.
combined_nonredundant_lattice, unit_cell_dimensions, highest_resolution = sax.find_lattice(
    diffraction_indices,
    diffraction_amplitudes,
    log_diff_spectrum,
    pixel_size=pixel_size,
    show_plots=True,
    verbose=True,
    num_sd_secondpass=num_sd_secondpass,
    miller_index_buffer=miller_index_buffer,
    box_radius=box_radius,
    min_lattice_size=min_lattice_size,
)

print(combined_nonredundant_lattice, unit_cell_dimensions, highest_resolution)