In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy import signal
from scipy.fftpack import fft2, ifft2, fftshift
from skimage import filters, feature, transform
import mrcfile
#import multiprocessing
from functools import partial
from joblib import Parallel, delayed

# Load the demo micrograph

In [None]:
# Read the micrograph as float32; squeeze drops any length-1 axes in case the
# file stores the image with an extra singleton dimension.
with mrcfile.open('0_data/00040_3_0.mrc', permissive=True) as micrograph_file:
    I = np.squeeze(micrograph_file.data.astype(np.float32))

# Show the raw micrograph
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(I, cmap='gray')
ax.set_title('Cryo-EM Micrograph')
ax.axis('off')
plt.show()

# Load template projections

In [None]:

# Load the reference projections: a stack whose slices are the 2-D templates
# matched against the micrograph, copied into a float32 array.
with mrcfile.open('projections_mpi.mrcs', permissive=True) as projection_stack:
    templates = np.array(projection_stack.data, dtype=np.float32)

In [None]:
# Preview a single reference projection; set `template_index` to inspect others.
template_index = 0
T = templates[template_index]

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(T, cmap='gray')
ax.set_title(f'Template Projection {template_index}')
ax.axis('off')
plt.show()

template_rows, template_cols = templates.shape[1:3]
print(f"Template dimension : {template_rows, template_cols}")

# Band-pass filtering

In [None]:
def bandpass_filter(img, low_sigma, high_sigma):
    """Band-pass filter an image with a difference of Gaussians.

    Blurring with the small ``low_sigma`` suppresses fine-scale
    (high-frequency) noise; subtracting a blur with the larger ``high_sigma``
    removes slowly varying (low-frequency) background. What remains is the
    band of spatial frequencies in between.

    Parameters
    ----------
    img : ndarray
        Input image.
    low_sigma : float or sequence of float
        Std. dev. of the Gaussian that sets the high-frequency cut-off.
    high_sigma : float or sequence of float
        Std. dev. of the Gaussian that sets the low-frequency cut-off;
        expected to be larger than ``low_sigma``.

    Returns
    -------
    ndarray
        ``gaussian_filter(img, low_sigma) - gaussian_filter(img, high_sigma)``.
    """
    # Note: the form ``(img - G_low) - G_high`` is NOT a band-pass. It keeps
    # the high-frequency residual and subtracts the background a second time,
    # so a constant image would come out as ``-img`` instead of zero.
    smoothed = ndimage.gaussian_filter(img, low_sigma)
    background = ndimage.gaussian_filter(img, high_sigma)
    return smoothed - background

# Band-pass cut-offs, defined once so the micrograph and every template are
# filtered identically.
LOW_SIGMA = 1
HIGH_SIGMA = 5

# Apply the band-pass filter to the micrograph
I_filtered = bandpass_filter(I, low_sigma=LOW_SIGMA, high_sigma=HIGH_SIGMA)

# Apply the same band-pass filter to each template. A comprehension keeps the
# iteration variable from rebinding `T`, the template selected for preview
# earlier.
T_filtered_list = [
    bandpass_filter(template, low_sigma=LOW_SIGMA, high_sigma=HIGH_SIGMA)
    for template in templates
]

# For demonstration, show the first filtered template
T_filtered = T_filtered_list[0]

# Display the filtered micrograph and template side by side
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(I_filtered, cmap='gray')
ax[0].set_title('Filtered Cryo-EM Micrograph')
ax[0].axis('off')

ax[1].imshow(T_filtered, cmap='gray')
ax[1].set_title('Filtered Template')
ax[1].axis('off')
plt.show()


In [None]:
def zscore(arr):
    """Return ``arr`` shifted to zero mean and scaled to unit standard deviation.

    A constant input (std == 0) is only mean-centred, giving zeros, instead of
    being divided by zero. That division would yield NaNs, which propagate into
    every correlation map and aggregate computed from the array.
    """
    centred = arr - np.mean(arr)
    std = np.std(arr)
    return centred / std if std > 0 else centred

# Normalize the micrograph
I_norm = zscore(I_filtered)

# Normalize each filtered template
T_norm_list = [zscore(T_filt) for T_filt in T_filtered_list]


In [64]:
def pad_template_to_image(template, image_shape):
    """Zero-pad a 2-D template so it sits centred in an array of ``image_shape``.

    When a margin is odd, the extra row/column goes to the bottom/right edge.

    Parameters
    ----------
    template : ndarray, shape (h, w)
        Template to pad.
    image_shape : tuple of int
        Target (rows, cols); only the first two entries are used.

    Returns
    -------
    ndarray
        Array of shape ``image_shape[:2]`` with the template's dtype.

    Raises
    ------
    ValueError
        If the template is larger than the target along either axis.
    """
    template = np.asarray(template)
    out_rows, out_cols = image_shape[0], image_shape[1]
    pad_y = out_rows - template.shape[0]
    pad_x = out_cols - template.shape[1]
    # np.pad would fail with an obscure "negative values" error; say what is wrong.
    if pad_y < 0 or pad_x < 0:
        raise ValueError(
            f"template of shape {template.shape} is larger than the image "
            f"shape {tuple(image_shape)}"
        )
    pad_top = pad_y // 2
    pad_left = pad_x // 2
    return np.pad(
        template,
        ((pad_top, pad_y - pad_top), (pad_left, pad_x - pad_left)),
        mode='constant',
    )

# Pad each band-pass-filtered, normalized template to the micrograph size.
# Use `T_norm_list` (not the raw `templates`) so the templates get the same
# preprocessing as `I_norm`, the image they are correlated against.
templates_padded = [pad_template_to_image(T_n, I_norm.shape) for T_n in T_norm_list]


### Serial cross-correlation

In [65]:
# Cross-correlate the micrograph with every padded template via the FFT.
correlation_maps = []

# The micrograph spectrum is shared by all templates, so compute it once.
# Both inputs are real, so rfft2/irfft2 do roughly half the work of the full
# complex fft2/ifft2.
I_fft = np.fft.rfft2(I_norm)

for T_padded in templates_padded:
    T_fft = np.fft.rfft2(T_padded)
    # corr[s] = sum_y I_norm[y + s] * T_padded[y]. Passing `s=` restores the
    # exact output shape (irfft2 would otherwise assume an even width).
    corr = np.fft.irfft2(I_fft * np.conj(T_fft), s=I_norm.shape)
    # Templates are padded *centred* in the frame, so the raw peak for a
    # particle centred at p lands near p - shape // 2 (wrapped around).
    # fftshift rolls the map by shape // 2, putting peaks at particle centres
    # (to within a pixel for odd sizes). The later peak picking and the
    # overlay on `I` rely on this.
    corr = np.fft.fftshift(corr)
    # Normalize the correlation map (z-score); skip the division for a
    # constant map so it does not turn into NaNs.
    corr_std = np.std(corr)
    corr_norm = corr - np.mean(corr)
    if corr_std > 0:
        corr_norm = corr_norm / corr_std
    correlation_maps.append(corr_norm)


KeyboardInterrupt: 

In [None]:
# Aggregate the correlation maps by taking the maximum at each pixel.
# A running np.maximum avoids np.max(list, axis=0). That call first copies
# every map into one (n_templates, H, W) array while the list is still alive,
# which doubles peak memory.
if not correlation_maps:
    raise ValueError("correlation_maps is empty; run the cross-correlation cell first")
NCC_max = correlation_maps[0].copy()
for corr_map in correlation_maps[1:]:
    np.maximum(NCC_max, corr_map, out=NCC_max)

# Display the aggregated cross-correlation map
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(NCC_max, cmap='jet')
ax.set_title('Aggregated Normalized Cross-Correlation Map')
fig.colorbar(im, ax=ax)
ax.axis('off')
plt.show()


In [None]:
# NOTE(review): this cell used to repeat the aggregation and plot from the
# previous cell verbatim, which only redrew the same figure. It now checks
# the aggregated map before thresholding instead.
assert NCC_max.shape == I_norm.shape, "NCC_max must lie on the micrograph pixel grid"
# The threshold below is mean + 3*std of NCC_max; any NaN/inf would poison it.
assert np.isfinite(NCC_max).all(), "NaNs/infs in NCC_max would break thresholding"


In [None]:
# Cut-off at 3 standard deviations above the mean of the aggregated map;
# values below it are zeroed and everything else is kept unchanged.
threshold = NCC_max.mean() + 3 * NCC_max.std()
NCC_thresh = np.where(NCC_max < threshold, 0, NCC_max)


In [None]:
# Minimum spacing between picked peaks: half the template edge. The size is
# read from the template stack rather than from `T`, whose value depends on
# which earlier cell last rebound it. The smaller edge is used so non-square
# templates are handled too.
template_size = min(templates.shape[1], templates.shape[2])

# Find local maxima of the thresholded map that also exceed `threshold`
coordinates = feature.peak_local_max(
    NCC_thresh,
    min_distance=template_size // 2,
    threshold_abs=threshold,
)

# Overlay the detected peaks (row, col) on the original micrograph
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(I, cmap='gray')
ax.scatter(coordinates[:, 1], coordinates[:, 0], c='r', marker='x')
ax.set_title('Detected Particles')
ax.axis('off')
plt.show()


In [None]:
# Assign the best-matching template to each detected particle.
# `correlation_maps` holds exactly one map per template (the cross-correlation
# loop iterates once per padded template), so the argmax over maps is the
# template index itself. Splitting it into (template, rotation) indices needs a
# `rotation_angles` list that this notebook never defines (-> NameError). No
# in-plane rotation search is visible here, so 'angle' is recorded as None.
detected_particles = []
for y, x in coordinates:
    # Correlation score of every template at this peak
    local_corrs = np.array([corr_map[y, x] for corr_map in correlation_maps])
    template_idx = int(np.argmax(local_corrs))
    detected_particles.append({
        'position': (int(x), int(y)),
        'template_index': template_idx,
        'angle': None,
        'score': float(local_corrs[template_idx]),
    })

# Display the results
for idx, particle in enumerate(detected_particles):
    print(f"Particle {idx+1}: Position={particle['position']}, "
          f"Template Index={particle['template_index']}, "
          f"Score={particle['score']:.2f}")
