# Parallel pipieline

In [1]:
import time# Sys
import sys
import os

# Numerical
import numpy as np
import mrcfile
import healpy as hp
# Plotting
import matplotlib.pyplot as plt
# Signal proc / Linalg
from scipy import ndimage
from scipy import signal
from scipy.ndimage import rotate
from scipy.fftpack import fft2, ifft2, fftshift
from skimage import filters, feature, transform
# Optimization
import multiprocessing
from itertools import product
from joblib import Parallel, delayed
from scipy.ndimage import affine_transform

In [2]:
# Functions

def read_mrc_file(file_path):
    """Reads a volume from an MRC file."""
    if not os.path.isfile(file_path):
        print(f"Error: File '{file_path}' not found.")
        sys.exit(1)
    try:
        with mrcfile.open(file_path, permissive=True) as mrc:
            volume = mrc.data.astype(np.float32)
            voxel_size = mrc.voxel_size
            if volume.ndim != 3:
                print("Error: Input volume is not 3-dimensional.")
                sys.exit(1)
    except Exception as e:
        print(f"Error reading MRC file: {e}")
        sys.exit(1)
    return volume, voxel_size

def generate_healpix_pixel_bounds(nside):
    """Generates boundaries of theta and phi for each Healpix pixel."""
    npix = hp.nside2npix(nside)
    pixel_bounds = []
    for pix in range(npix):
        # Get the boundary vertices for this pixel
        vertices = hp.boundaries(nside, pix, step=1, nest=False)  # shape (3, N)
        # Convert x, y, z to theta, phi
        theta_vertices, phi_vertices = hp.vec2ang(vertices.T)  # Transpose to shape (N, 3)
        # Adjust phi to be in [0, 2π)
        phi_vertices = phi_vertices % (2 * np.pi)
        # Handle phi wrapping around 0 and 2π
        phi_diff = np.diff(np.sort(phi_vertices))
        if np.any(phi_diff > np.pi):
            # Correct for wrapping
            phi_vertices[phi_vertices < np.pi] += 2 * np.pi
        # Update min and max after correction
        phi_min = phi_vertices.min() % (2 * np.pi)
        phi_max = phi_vertices.max() % (2 * np.pi)
        theta_min = theta_vertices.min()
        theta_max = theta_vertices.max()
        pixel_bounds.append((theta_min, theta_max, phi_min, phi_max))
    return pixel_bounds

def generate_angles_in_pixel(theta_min, theta_max, phi_min, phi_max, angular_spacing_rad):
    # Generate theta values from theta_min to theta_max with steps of angular_spacing_rad
    theta_vals = np.arange(theta_min, theta_max, angular_spacing_rad)
    # Handle phi wrapping
    if phi_max >= phi_min:
        phi_vals = np.arange(phi_min, phi_max, angular_spacing_rad)
    else:
        # phi_max < phi_min due to wrapping around 2π
        phi_vals_part1 = np.arange(phi_min, 2 * np.pi, angular_spacing_rad)
        phi_vals_part2 = np.arange(0, phi_max, angular_spacing_rad)
        phi_vals = np.concatenate((phi_vals_part1, phi_vals_part2))
    # Create grid of theta and phi values
    theta_grid, phi_grid = np.meshgrid(theta_vals, phi_vals, indexing='ij')
    # Adjust phi_grid to be in [0, 2π)
    phi_grid = phi_grid % (2 * np.pi)
    return theta_grid.flatten(), phi_grid.flatten()

def generate_all_angles(nside, angular_spacing_deg):
    angular_spacing_rad = np.radians(angular_spacing_deg)
    pixel_bounds = generate_healpix_pixel_bounds(nside)
    phi_angles_list = []
    theta_angles_list = []
    for idx, (theta_min, theta_max, phi_min, phi_max) in enumerate(pixel_bounds):
        theta_vals, phi_vals = generate_angles_in_pixel(
            theta_min, theta_max, phi_min, phi_max, angular_spacing_rad
        )
        phi_angles_list.append(phi_vals)
        theta_angles_list.append(theta_vals)
        print(f"Pixel {idx + 1}/{len(pixel_bounds)}: Generated {len(theta_vals)} angles")
    # Concatenate all angles
    phi_angles = np.concatenate(phi_angles_list)
    theta_angles = np.concatenate(theta_angles_list)
    return phi_angles, theta_angles

# Define the band-pass filter function
def bandpass_filter(img, low_sigma, high_sigma):
    low_pass = ndimage.gaussian_filter(img, high_sigma)
    high_pass = img - ndimage.gaussian_filter(img, low_sigma)
    return high_pass - low_pass

def rotate_and_project_opt(volume, phi, theta, psi):
    """Efficiently rotates the volume and projects it along the Z-axis."""
    # Compute combined rotation matrix
    cos_phi, sin_phi = np.cos(np.radians(phi)), np.sin(np.radians(phi))
    cos_theta, sin_theta = np.cos(np.radians(theta)), np.sin(np.radians(theta))
    cos_psi, sin_psi = np.cos(np.radians(psi)), np.sin(np.radians(psi))

    # Rotation matrices for each axis
    R_phi = np.array([
        [cos_phi, -sin_phi, 0],
        [sin_phi, cos_phi, 0],
        [0, 0, 1]
    ])

    R_theta = np.array([
        [cos_theta, 0, sin_theta],
        [0, 1, 0],
        [-sin_theta, 0, cos_theta]
    ])

    R_psi = np.array([
        [cos_psi, -sin_psi, 0],
        [sin_psi, cos_psi, 0],
        [0, 0, 1]
    ])

    # Combined rotation matrix
    R = R_psi @ R_theta @ R_phi

    # Center the volume for rotation
    center = np.array(volume.shape) / 2
    offset = center - R @ center

    # Apply affine transformation (rotation)
    rotated = affine_transform(volume, R, offset=offset, order=1)

    # Project along Z-axis
    projection = np.sum(rotated, axis=0)

    return projection



In [3]:
def cross_correlate_top_left(image, template):
    # Ensure the template is zero-padded to the size of the image
    template_padded = np.zeros_like(image)
    template_shape = template.shape
    # Place the template at the top-left corner of the padded array
    template_padded[:template_shape[0], :template_shape[1]] = template

    # Compute the FFTs
    image_fft = fft2(image)
    template_fft = fft2(template_padded)

    # Compute cross-correlation
    cross_corr = np.real(ifft2(image_fft * np.conj(template_fft)))

    # Normalize the cross-correlation
    # Compute local sums for normalization (using convolution with ones)
    template_sum = np.sum(template ** 2)
    image_local_sums = ndimage.uniform_filter(image ** 2, size=template_shape)
    denominator = np.sqrt(template_sum * image_local_sums)
    denominator[denominator == 0] = 1  # Prevent division by zero

    normalized_cross_corr = cross_corr / denominator
    return normalized_cross_corr

In [4]:
import numpy as np
from scipy.signal import fftconvolve
from scipy.ndimage import uniform_filter

def cross_correlate_fourier_opt(image, template):
    # Compute cross-correlation using fftconvolve
    cross_corr = fftconvolve(image, template[::-1, ::-1], mode='same')
    
    # Normalize the cross-correlation
    template_sum = np.sum(template ** 2)
    image_local_sums = uniform_filter(image ** 2, size=template.shape)
    denominator = np.sqrt(template_sum * image_local_sums)
    denominator[denominator == 0] = 1  # Prevent division by zero
    normalized_cross_corr = cross_corr / denominator
    
    return normalized_cross_corr

In [5]:
import numpy as np
from scipy.fft import fft2, ifft2
from scipy.ndimage import uniform_filter

def cross_correlate_fourier(image, template):
    # Ensure the template is zero-padded to the size of the image
    template_padded = np.zeros_like(image)
    template_shape = template.shape
    # Place the template at the top-left corner of the padded array
    template_padded[:template_shape[0], :template_shape[1]] = template
    
    # Compute the FFTs
    image_fft = fft2(image)
    template_fft = fft2(template_padded)
    
    # Compute cross-correlation in Fourier space
    cross_corr = np.real(ifft2(image_fft * np.conj(template_fft)))
    
    # Normalize the cross-correlation
    template_sum = np.sum(template ** 2)
    image_local_sums = uniform_filter(image ** 2, size=template_shape)
    denominator = np.sqrt(template_sum * image_local_sums)
    denominator[denominator == 0] = 1  # Prevent division by zero
    normalized_cross_corr = cross_corr / denominator
    
    return normalized_cross_corr

In [6]:
import cupy as cp
from scipy.ndimage import uniform_filter

def cross_correlate_fourier_gpu(image, template):
    # Convert image and template to GPU arrays
    image_gpu = cp.asarray(image)
    template_gpu = cp.asarray(template)
    
    # Compute cross-correlation using fftconvolve on GPU
    cross_corr_gpu = cp.fft.ifft2(cp.fft.fft2(image_gpu) * cp.conj(cp.fft.fft2(template_gpu, s=image_gpu.shape))).real
    
    # Normalize the cross-correlation
    template_sum = cp.sum(template_gpu ** 2)
    image_local_sums = cp.asarray(uniform_filter(cp.asnumpy(image_gpu ** 2), size=template.shape))
    denominator = cp.sqrt(template_sum * image_local_sums)
    denominator[denominator == 0] = 1  # Prevent division by zero
    normalized_cross_corr_gpu = cross_corr_gpu / denominator
    
    # Convert result back to CPU array
    normalized_cross_corr = cp.asnumpy(normalized_cross_corr_gpu)
    
    return normalized_cross_corr

In [7]:
def process_combination(volume, phi, theta, psi, I_filtered, projection_shape, ncc_max, ncc_mean, ncc_M2, n):
    """Process a single combination of angles and update relevant statistics."""
    # Rotate and project
    projection = rotate_and_project(volume, phi, theta, psi)

    # Check Dimensions
    if projection.shape != projection_shape:
        print(f"Warning: Projection dimensions {projection.shape} do not match expected shape {projection_shape}.")

    # Normalize projection
    projection = projection - np.mean(projection)

    # Band pass filter the projection
    proj_filtered = bandpass_filter(projection, low_sigma=1, high_sigma=5)

    # Normalize the filtered projection
    T_mean = np.mean(proj_filtered)
    T_std = np.std(proj_filtered)
    T_norm = (proj_filtered - T_mean) / T_std

    # Calculate cross-correlation
    cross_corr = cross_correlate_top_left(I_filtered, T_norm)

    # Update max elements
    local_ncc_max = np.maximum(ncc_max, cross_corr)

    # Update mean and variance using Welford's algorithm
    local_n = n + 1
    delta = cross_corr - ncc_mean
    local_ncc_mean = ncc_mean + delta / local_n
    delta2 = cross_corr - local_ncc_mean
    local_ncc_M2 = ncc_M2 + delta * delta2

    return local_ncc_max, local_ncc_mean, local_ncc_M2, local_n

# Prepare the input arguments for parallel processing
def parallelize_processing(volume, phi_angles, theta_angles, psi_angles, I_filtered, projection_shape):
    # Shared initial values
    ncc_max = np.zeros_like(I_filtered, dtype=np.float32)
    ncc_mean = np.zeros_like(I_filtered, dtype=np.float32)
    ncc_M2 = np.zeros_like(I_filtered, dtype=np.float32)
    n = 0

    # Generate all combinations of angles
    angle_combinations = [(phi, theta, psi) for phi, theta in zip(phi_angles, theta_angles) for psi in psi_angles]

    # Process combinations in parallel
    results = Parallel(n_jobs=-1)(
        delayed(process_combination)(
            volume, phi, theta, psi, I_filtered, projection_shape, ncc_max, ncc_mean, ncc_M2, n
        ) for phi, theta, psi in angle_combinations
    )

    # Aggregate results
    for local_ncc_max, local_ncc_mean, local_ncc_M2, local_n in results:
        ncc_max = np.maximum(ncc_max, local_ncc_max)
        ncc_mean = local_ncc_mean
        ncc_M2 = local_ncc_M2
        n = local_n

    # Calculate the variance
    ncc_variance = ncc_M2 / (n - 1) if n > 1 else np.zeros_like(I_filtered, dtype=np.float32)

    return ncc_max, ncc_mean, ncc_variance

# Prepare the input arguments for parallel processing
def parallelize_processing1(volume, phi_angles, theta_angles, psi_angles, I_filtered):
    # Shared initial values
    ncc_max = np.zeros_like(I_filtered, dtype=np.float32)
    ncc_mean = np.zeros_like(I_filtered, dtype=np.float32)
    ncc_M2 = np.zeros_like(I_filtered, dtype=np.float32)
    n = 0

    # Generate all combinations of angles
    angle_combinations = [(phi, theta, psi) for phi, theta in zip(phi_angles, theta_angles) for psi in psi_angles]

    # Process combinations in parallel
    results = Parallel(n_jobs=-1)(
        delayed(process_combination1)(
            volume, phi, theta, psi, I_filtered, ncc_max, ncc_mean, ncc_M2, n
        ) for phi, theta, psi in angle_combinations
    )

    # Aggregate results
    for local_ncc_max, local_ncc_mean, local_ncc_M2, local_n in results:
        ncc_max = np.maximum(ncc_max, local_ncc_max)
        ncc_mean = local_ncc_mean
        ncc_M2 = local_ncc_M2
        n = local_n

    # Calculate the variance
    ncc_variance = ncc_M2 / (n - 1) if n > 1 else np.zeros_like(I_filtered, dtype=np.float32)

    return ncc_max, ncc_mean, ncc_variance

def process_combination1(volume, phi, theta, psi, I_filtered, ncc_max, ncc_mean, ncc_M2, n):
    """Process a single combination of angles and update relevant statistics."""
    start_time = time.time()
    # Rotate and project
    projection = rotate_and_project_opt(volume, phi, theta, psi)

    # Check Dimensions
    #if projection.shape != projection_shape:
    #    print(f"Warning: Projection dimensions {projection.shape} do not match expected shape {projection_shape}.")

    # Normalize projection
    projection = projection - np.mean(projection)

    # Band pass filter the projection
    proj_filtered = bandpass_filter(projection, low_sigma=1, high_sigma=5)

    # Normalize the filtered projection
    T_mean = np.mean(proj_filtered)
    T_std = np.std(proj_filtered)
    T_norm = (proj_filtered - T_mean) / T_std

    # Calculate cross-correlation
    #cross_corr = cross_correlate_fourier_opt(I_filtered, T_norm)
    cross_corr = cross_correlate_fourier_gpu(I_filtered, T_norm)
    
    
    # Update max elements
    local_ncc_max = np.maximum(ncc_max, cross_corr)

    # Update mean and variance using Welford's algorithm
    local_n = n + 1
    delta = cross_corr - ncc_mean
    local_ncc_mean = ncc_mean + delta / local_n
    delta2 = cross_corr - local_ncc_mean
    local_ncc_M2 = ncc_M2 + delta * delta2
    end_time = time.time()
    print(f"Cross correlation complete for : {phi, theta, psi} in {end_time-start_time} seconds")
    return local_ncc_max, local_ncc_mean, local_ncc_M2, local_n

In [8]:
if __name__ == "__main__":
    # Initialize necessary variables here (e.g., volume, phi_angles, theta_angles, psi_angles, etc.)
    # Example:
    # volume = ...
    #volume_file = input("Enter the path to the input MRC volume file: ").strip()
    volume_file = '0_data/7ood.mrc'
    volume, voxel_size = read_mrc_file(volume_file)
    # Get nside from the user
    try:
        nside = int(input("Enter NSIDE parameter (must be a power of 2 integer): "))
        if not hp.isnsideok(nside):
            raise ValueError
    except ValueError:
        print("Error: NSIDE must be a valid power of 2 integer.")
        sys.exit(1)

    print(f"Using NSIDE = {nside} for Healpix sampling.")

    # Get angular spacing from the user
    try:
        angular_spacing = float(input("Enter angular spacing in degrees: "))
        if angular_spacing <= 0:
            raise ValueError
    except ValueError:
        print("Error: Angular spacing must be a positive number.")
        sys.exit(1)
    # phi_angles = ...
    # theta_angles = ...
    # psi_angles = ...
    # Generate phi and theta angles
    phi_angles, theta_angles = generate_all_angles(nside, angular_spacing)
    print(f"Generated {len(phi_angles)} orientations using Healpix.")

    # Generate psi angles
    try:
        psi_step = float(input("Enter psi angular step in degrees: "))
        if psi_step <= 0:
            raise ValueError
    except ValueError:
        print("Error: Psi angular step must be a positive number.")
        sys.exit(1)

    psi_angles = np.arange(0, 360, psi_step)
    print(f"Using {len(psi_angles)} in-plane rotations (psi angles).")

    # Total number of projections
    total_projections = len(phi_angles) * len(psi_angles)
    print(f"Total number of projections to generate: {total_projections}")
    # I_filtered = ...
    with mrcfile.open('0_data/00040_3_0_wctf.mrc', permissive=True) as mrc:
        I = np.squeeze(mrc.data.astype(np.float32))
    # Apply the band-pass filter to the micrograph
    I_filtered = bandpass_filter(I, low_sigma=1, high_sigma=5)
    
    ncc_max, ncc_mean, ncc_variance = parallelize_processing1(
        volume, phi_angles, theta_angles, psi_angles, I_filtered,
    )

Using NSIDE = 1 for Healpix sampling.
Pixel 1/12: Generated 9 angles
Pixel 2/12: Generated 18 angles
Pixel 3/12: Generated 27 angles
Pixel 4/12: Generated 9 angles
Pixel 5/12: Generated 12 angles
Pixel 6/12: Generated 9 angles
Pixel 7/12: Generated 9 angles
Pixel 8/12: Generated 12 angles
Pixel 9/12: Generated 9 angles
Pixel 10/12: Generated 18 angles
Pixel 11/12: Generated 27 angles
Pixel 12/12: Generated 9 angles
Generated 168 orientations using Healpix.
Using 12 in-plane rotations (psi angles).
Total number of projections to generate: 2016
Cross correlation complete for : (1.0471975511965976, 0.5235987755982988, 90.0) in 59.7474639415741 seconds
Cross correlation complete for : (0.0, 0.5235987755982988, 330.0) in 59.964335203170776 seconds
Cross correlation complete for : (0.5235987755982988, 0.0, 150.0) in 60.21707057952881 seconds
Cross correlation complete for : (0.0, 1.0471975511965976, 0.0) in 59.8497416973114 seconds
Cross correlation complete for : (0.5235987755982988, 1.0471

OutOfMemoryError: Out of memory allocating 46,190,592 bytes (allocated so far: 0 bytes).

In [None]:
# Display the results
plt.figure(figsize=(6, 6))
plt.imshow(ncc_max, cmap='gray')
plt.title('Max Cross-Correlation Map')
plt.colorbar()
plt.axis('off')
plt.show()

plt.figure(figsize=(6, 6))
plt.imshow(ncc_mean, cmap='gray')
plt.title('Mean Cross-Correlation Map')
plt.colorbar()
plt.axis('off')
plt.show()

plt.figure(figsize=(6, 6))
plt.imshow(ncc_variance, cmap='gray')
plt.title('Variance Cross-Correlation Map')
plt.colorbar()
plt.axis('off')
plt.show()


# Aggregate the correlation maps by taking the maximum at each pixel
mip = (ncc_max - ncc_mean)/np.sqrt(ncc_variance)

# Display the aggregated cross-correlation map
plt.figure(figsize=(6, 6))
plt.imshow(mip, cmap='gray')
plt.title('Aggregated MIP map')
plt.colorbar()
plt.axis('off')
plt.show()

In [None]:


# Count the number of elements above 8
num_elements_above_8 = np.sum(mip > 7.43)
print(f"Number of elements in mip_scaled above 8: {num_elements_above_8}")

In [7]:
# Find the row and column indices of values greater than 7 in the mip array
indices = np.where(mip > 7.43)
rows, cols = indices

# Print the row and column indices in pairs
for row, col in zip(rows, cols):
    print(f"Row: {row}, Col: {col}")

In [None]:
plt.figure(figsize=(6, 6))
plt.imshow(I_filtered, cmap='gray')

# Add circles for each pair of coordinates
for col, row in zip(cols, rows):
	circle = plt.Circle((col, row), radius=50, color='red', fill=False, linewidth=0.2)
	plt.gca().add_patch(circle)

#plt.scatter(cols, rows, c='red', s=1)
plt.title('Row, Col Indices on Image I')
plt.axis('off')
plt.show()