<a href="https://colab.research.google.com/github/ykitaguchi77/statistics_for_articles/blob/main/Wavefront_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from PIL import Image, ImageDraw, ImageFont
import math
import numpy as np
from scipy.fft import fft2, fftshift
import matplotlib.pyplot as plt
from scipy.signal import convolve2d

def draw_landolt_c_corrected(image_size=(224, 224), visual_acuity=1.0, distance_m=5,
                             orientation='right', background_color=(255, 255, 255), ring_color=(0, 0, 0)):
    """
    Draw a Landolt C optotype centred on a new RGB image.

    The ring is drawn as a filled outer disc and a background-coloured inner
    disc. A background-coloured rectangle is then drawn over one side to cut
    the gap. Proportions follow the standard Landolt C: ring thickness and
    gap width are both 1/5 of the outer diameter.

    Args:
        image_size: (width, height) of the output image in pixels.
        visual_acuity: Decimal acuity the optotype represents. The outer
            diameter is 80 px / visual_acuity, clamped to
            [10, 0.9 * min(image_size)] pixels. Must be > 0.
        distance_m: Nominal viewing distance. Kept for API compatibility; it
            does not influence the drawing, because the pixel scale is fixed
            (80 px outer diameter at acuity 1.0). It is not derived from the
            physical size at which the gap subtends 1/visual_acuity arcmin.
        orientation: 'right', 'left', 'up' or 'down' -- the side on which the
            gap opens. Any other value falls back to 'right'.
        background_color: RGB fill for the background, the hole and the gap.
        ring_color: RGB fill for the ring.

    Returns:
        A PIL image in "RGB" mode.

    Raises:
        ValueError: if visual_acuity is not positive.
    """
    if visual_acuity <= 0:
        raise ValueError(f"visual_acuity must be positive, got {visual_acuity!r}")

    image = Image.new("RGB", image_size, background_color)
    draw = ImageDraw.Draw(image)

    # Display scale: an 80 px outer diameter represents acuity 1.0.
    reference_acuity_diameter_pixels = 80
    outer_diameter_pixels = reference_acuity_diameter_pixels / visual_acuity
    outer_diameter_pixels = max(10, min(min(image_size) * 0.9, outer_diameter_pixels))

    ring_thickness_pixels = outer_diameter_pixels / 5.0
    gap_size_pixels = ring_thickness_pixels  # Standard Landolt C: gap == stroke width

    center_x, center_y = image_size[0] // 2, image_size[1] // 2
    outer_radius = outer_diameter_pixels / 2.0
    inner_radius = outer_radius - ring_thickness_pixels
    half_gap = gap_size_pixels / 2

    # Outer disc in the ring colour, then the hole in the background colour.
    draw.ellipse(
        (center_x - outer_radius, center_y - outer_radius,
         center_x + outer_radius, center_y + outer_radius),
        fill=ring_color
    )
    draw.ellipse(
        (center_x - inner_radius, center_y - inner_radius,
         center_x + inner_radius, center_y + inner_radius),
        fill=background_color
    )

    # Gap rectangles span the ring radially, overshooting by 1 px on each side.
    # This keeps fractional-coordinate rounding from leaving a sliver of ring.
    gap_boxes = {
        'right': (center_x + inner_radius - 1, center_y - half_gap,
                  center_x + outer_radius + 1, center_y + half_gap),
        'left': (center_x - outer_radius - 1, center_y - half_gap,
                 center_x - inner_radius + 1, center_y + half_gap),
        'up': (center_x - half_gap, center_y - outer_radius - 1,
               center_x + half_gap, center_y - inner_radius + 1),
        'down': (center_x - half_gap, center_y + inner_radius - 1,
                 center_x + half_gap, center_y + outer_radius + 1),
    }
    draw.rectangle(gap_boxes.get(orientation, gap_boxes['right']), fill=background_color)

    return image

# ANSI Zernike Polynomials (Normalized)
def zernike_polynomial(n, m, rho, phi):
    """
    Evaluate the ANSI (OSA) normalised Zernike polynomial Z_n^m.

    Args:
        n: Radial order.
        m: Azimuthal frequency. m > 0 selects the cosine term, m < 0 the sine
            term, and m == 0 the rotationally symmetric term.
        rho: Normalised radial coordinate(s), presumably within [0, 1].
        phi: Azimuthal angle(s) in radians, broadcast-compatible with rho.

    Returns:
        Array shaped like rho. All zeros when n - |m| is odd.
    """
    abs_m = abs(m)
    if (n - abs_m) % 2:
        # Parity rule: Z_n^m is only non-trivial when n - |m| is even.
        return np.zeros_like(rho)

    half_sum = (n + abs_m) // 2
    half_diff = (n - abs_m) // 2

    radial = np.zeros_like(rho)
    for k in range(half_diff + 1):
        # Multinomial-type integer coefficient of rho**(n - 2k); the division is exact.
        numerator = (-1) ** k * math.factorial(n - k)
        denominator = (math.factorial(k)
                       * math.factorial(half_sum - k)
                       * math.factorial(half_diff - k))
        radial += (numerator // denominator) * rho ** (n - 2 * k)

    if m == 0:
        if n == 0:
            return np.ones_like(rho) * np.sqrt(1)
        return np.sqrt(n + 1) * radial

    scaled = np.sqrt(2 * (n + 1)) * radial
    if m > 0:
        return scaled * np.cos(m * phi)
    return scaled * np.sin(abs_m * phi)

def calculate_wavefront_aberration(zernike_coeffs, pupil_radius_pixels, image_size_pixels):
    """
    Build a wavefront map as a weighted sum of ANSI Zernike polynomials.

    The square grid (side image_size_pixels) spans [-1, 1] * pupil_radius_pixels
    on both axes. Radius is then normalised by pupil_radius_pixels, so the
    unit pupil touches the grid edges.

    Args:
        zernike_coeffs: Mapping of (n, m) -> coefficient. Zero coefficients are
            skipped. The map carries the coefficients' units (presumably
            micrometres).
        pupil_radius_pixels: Radius used to normalise to the unit pupil.
        image_size_pixels: Side length of the square output grid.

    Returns:
        (wavefront, pupil_mask): the float map (zero outside the pupil) and
        the boolean mask where rho <= 1.
    """
    axis = np.linspace(-1, 1, image_size_pixels) * pupil_radius_pixels
    grid_x, grid_y = np.meshgrid(axis, axis)
    rho = np.sqrt(grid_x**2 + grid_y**2) / pupil_radius_pixels
    phi = np.arctan2(grid_y, grid_x)

    pupil_mask = rho <= 1.0
    wavefront = np.zeros((image_size_pixels, image_size_pixels), dtype=float)
    active_terms = ((nm, c) for nm, c in zernike_coeffs.items() if c != 0)
    for (n, m), coeff in active_terms:
        wavefront[pupil_mask] += coeff * zernike_polynomial(n, m, rho[pupil_mask], phi[pupil_mask])
    return wavefront, pupil_mask

def calculate_psf(wavefront_aberration, pupil_mask, wavelength_nm=550, pupil_diameter_mm=4.0, image_size_pixels=224):
    """
    Compute an incoherent PSF as |FFT(pupil function)|^2, normalised to unit sum.

    Args:
        wavefront_aberration: Wavefront map. It is multiplied by
            2*pi / (wavelength in um), so it is presumably in micrometres.
        pupil_mask: Boolean mask selecting pupil samples (same shape as the map).
        wavelength_nm: Wavelength in nanometres.
        pupil_diameter_mm: Accepted for API compatibility but NOT used. The
            PSF's angular pixel scale is therefore not tied to a physical pupil.
        image_size_pixels: Side length of the square pupil grid.

    Returns:
        Square float array with the zero-frequency term shifted to the centre.
        It sums to 1 unless it is identically zero.

    NOTE(review): no zero-padding is applied. A pupil that spans the whole grid
    gives a PSF sampled at roughly one pixel per lambda/D, which is below Nyquist.
    """
    wave_number = 2 * np.pi / (wavelength_nm * 1e-3)  # rad per micrometre
    pupil = np.zeros((image_size_pixels, image_size_pixels), dtype=complex)
    pupil[pupil_mask] = np.exp(1j * wave_number * wavefront_aberration[pupil_mask])
    intensity = np.abs(fftshift(fft2(pupil))) ** 2
    total = np.sum(intensity)
    return intensity / total if total > 0 else intensity

def apply_psf_to_image(image_array, psf_array):
    """
    Blur a 2-D image with a PSF by direct convolution.

    The PSF is renormalised to unit sum (when its sum is positive) to preserve
    mean brightness. Image borders are handled with symmetric reflection.
    The result is clipped to [0, max(image_array)].

    NOTE(review): with an even-sized PSF the 'same' crop may sit one pixel
    off the kernel's fftshift centre -- verify if sub-pixel registration matters.
    """
    kernel_sum = np.sum(psf_array)
    kernel = psf_array / kernel_sum if kernel_sum > 0 else psf_array
    blurred = convolve2d(image_array, kernel, mode='same', boundary='symm')
    # A non-negative unit-sum kernel cannot exceed the image's own peak,
    # so this clip only trims round-off overshoot.
    return np.clip(blurred, 0, np.max(image_array))

if __name__ == '__main__':
    # Demo: draw Landolt C optotypes, build a PSF from Zernike coefficients,
    # blur one optotype with it, and save the results under colab_base_path.

    # User specified parameters
    image_resolution = (224, 224)
    default_orientation = 'right'
    bg_color = (255, 255, 255) # White
    fg_color = (0, 0, 0)       # Black
    visual_acuities = [0.1, 0.5, 1.0]
    generated_landolt_paths_corrected = []

    # Define base path for Colab
    colab_base_path = "/content/" # Or simply use file names directly

    print("Generating corrected Landolt C images...")
    for va in visual_acuities:
        img = draw_landolt_c_corrected(
            image_size=image_resolution,
            visual_acuity=va,
            orientation=default_orientation,
            background_color=bg_color,
            ring_color=fg_color
        )
        path = f"{colab_base_path}corrected_landolt_c_va_{va}_{default_orientation}.png"
        img.save(path)
        generated_landolt_paths_corrected.append(path)
        print(f"Saved corrected Landolt C for VA {va} ({default_orientation}) to {path}")

    # Example with different orientation for corrected version
    img_up_corrected = draw_landolt_c_corrected(
        image_size=image_resolution,
        visual_acuity=1.0,
        orientation='up',
        background_color=bg_color,
        ring_color=fg_color
    )
    path_up_corrected = f"{colab_base_path}corrected_landolt_c_va_1.0_up.png"
    img_up_corrected.save(path_up_corrected)
    generated_landolt_paths_corrected.append(path_up_corrected)
    print(f"Saved corrected Landolt C for VA 1.0 (up) to {path_up_corrected}")
    print("Corrected Landolt C generation examples complete.")

    # --- PSF Generation Example ---
    print("\nStarting PSF generation example...")
    psf_image_size = 224
    pupil_diameter_mm_param = 6.0
    wavelength_nm_param = 550
    # Coefficients in micrometres, keyed by ANSI (n, m).
    example_zernike_coeffs = {
        (0,0): 0.0, (1,-1): 0.0, (1,1): 0.0,
        (2,-2): 0.0,  # Oblique astigmatism
        (2,0): 0.0,   # Defocus
        (2,2): 1.0,   # Astigmatism (0/90 deg) -- the only active term here
        (3,-3): 0.0, (3,-1): 0.0, (3,1): 0.0, (3,3): 0.0,
        (4,-4): 0.0, (4,-2): 0.0,
        (4,0): 0.0,   # Spherical Aberration
        (4,2): 0.0, (4,4): 0.0,
    }
    # Pupil radius is half the grid, so the pupil spans the whole grid.
    pupil_radius_in_psf_pixels = psf_image_size // 2
    wavefront_ab, pupil_m = calculate_wavefront_aberration(
        zernike_coeffs=example_zernike_coeffs,
        pupil_radius_pixels=pupil_radius_in_psf_pixels,
        image_size_pixels=psf_image_size
    )
    psf_map = calculate_psf(
        wavefront_aberration=wavefront_ab,
        pupil_mask=pupil_m,
        wavelength_nm=wavelength_nm_param,
        pupil_diameter_mm=pupil_diameter_mm_param,
        image_size_pixels=psf_image_size
    )
    print(f"PSF map generated with shape: {psf_map.shape}")
    plt.figure(figsize=(6,6))
    psf_display = psf_map
    if np.max(psf_map) > 0:
        psf_display = np.log1p(psf_map / np.max(psf_map) * 1000) # Log scale for better visualization
    plt.imshow(psf_display, cmap='hot')
    # List every non-zero term so the title describes what was actually simulated.
    active_terms = ", ".join(
        f"Z({n},{m})={coeff:g}"
        for (n, m), coeff in example_zernike_coeffs.items()
        if coeff != 0
    ) or "none"
    # "\\mu" is doubled because "\m" is an invalid string escape (SyntaxWarning on 3.12+).
    plt.title(f"Point Spread Function (log scale)\n{active_terms} $\\mu m$, Pupil: {pupil_diameter_mm_param}mm")
    plt.colorbar()
    psf_image_path_corrected = f"{colab_base_path}psf_example_corrected.png"
    plt.savefig(psf_image_path_corrected)
    plt.close() # Close the plot to prevent it from displaying in the notebook if not desired after saving
    print(f"Saved example PSF image to {psf_image_path_corrected}")
    print("PSF generation example complete.")

    # --- Convolution Step with corrected Landolt C ---
    print("\nStarting convolution of CORRECTED Landolt C with PSF...")
    # Use one of the newly generated corrected Landolt C images
    landolt_c_to_convolve_path_corrected = f"{colab_base_path}corrected_landolt_c_va_1.0_{default_orientation}.png"
    # Set only once the comparison figure has been saved; the display step checks it.
    comparison_image_path_corrected = None
    try:
        landolt_c_img_pil_corrected = Image.open(landolt_c_to_convolve_path_corrected).convert('L') # Convert to grayscale
        landolt_c_array_corrected = np.array(landolt_c_img_pil_corrected) / 255.0 # Normalize to [0, 1]

        blurred_landolt_c_array_corrected = apply_psf_to_image(landolt_c_array_corrected, psf_map)

        # Convert back to uint8 for saving as image
        blurred_landolt_c_img_corrected = Image.fromarray((blurred_landolt_c_array_corrected * 255).astype(np.uint8))
        blurred_image_path_corrected = f"{colab_base_path}blurred_corrected_landolt_c_va_1.0_{default_orientation}.png"
        blurred_landolt_c_img_corrected.save(blurred_image_path_corrected)
        print(f"Saved blurred CORRECTED Landolt C image to {blurred_image_path_corrected}")

        # Display comparison
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(landolt_c_array_corrected, cmap='gray', vmin=0, vmax=1)
        axes[0].set_title('Corrected Landolt C (VA 1.0)')
        axes[0].axis('off')

        axes[1].imshow(psf_display, cmap='hot')
        axes[1].set_title('PSF (log scale)')
        axes[1].axis('off')

        axes[2].imshow(blurred_landolt_c_array_corrected, cmap='gray', vmin=0, vmax=1)
        axes[2].set_title('Blurred Corrected Landolt C')
        axes[2].axis('off')

        plt.tight_layout()
        comparison_path = f"{colab_base_path}comparison_corrected_landolt_psf_blurred.png"
        plt.savefig(comparison_path)
        plt.close(fig) # Close the plot
        comparison_image_path_corrected = comparison_path
        print(f"Saved comparison image for corrected Landolt C to {comparison_image_path_corrected}")

    except FileNotFoundError:
        print(f"Error: Could not find corrected Landolt C image at {landolt_c_to_convolve_path_corrected} for convolution.")
    except Exception as e:
        print(f"An error occurred during convolution with corrected Landolt C: {e}")

    print("Corrected convolution example complete.")

    # To see the images in Colab. Guarded so the script still finishes outside
    # Colab (no google.colab / cv2) and when the comparison image was never saved.
    if comparison_image_path_corrected is not None:
        try:
            from google.colab.patches import cv2_imshow
            import cv2
        except ImportError:
            print(f"Colab display unavailable; open {comparison_image_path_corrected} manually.")
        else:
            img = cv2.imread(comparison_image_path_corrected)
            if img is not None:
                cv2_imshow(img)
    # #Or use PIL to display:
    # display(Image.open(comparison_image_path_corrected))

In [None]:
from PIL import Image, ImageDraw, ImageFont
import math
import numpy as np
from scipy.fft import fft2, fftshift
import matplotlib.pyplot as plt
from scipy.signal import convolve2d

# --- Helper functions to calculate Diopters and Axis ---
def calculate_defocus_diopter(c20_um, pupil_diameter_mm):
    """
    Convert the Zernike defocus coefficient C(2,0) into dioptres.

    Uses the ANSI Z80.28 power-vector relation (spherical-equivalent term):
        M = -4 * sqrt(3) * C(2,0) / r**2
    Here C(2,0) is in micrometres and r is the pupil radius in millimetres,
    so um / mm**2 = 1/m = D.

    Args:
        c20_um: Zernike coefficient C(2,0) in micrometres.
        pupil_diameter_mm: Pupil diameter in millimetres.

    Returns:
        Defocus in dioptres, or float('nan') for a non-positive pupil diameter.
    """
    if pupil_diameter_mm <= 0:
        return float('nan')
    radius_squared = (pupil_diameter_mm / 2.0) ** 2
    return (-4 * math.sqrt(3) * c20_um) / radius_squared

def calculate_astigmatism_diopter_axis(c2n2_um, c2p2_um, pupil_diameter_mm):
    """
    Convert Zernike astigmatism coefficients into minus-cylinder power and axis.

    Power-vector components (ANSI Z80.28 / Thibos et al.), with coefficients in
    micrometres and r the pupil radius in millimetres:
        J0  = -2 * sqrt(6) * C(2,2)  / r**2
        J45 = -2 * sqrt(6) * C(2,-2) / r**2
    Conversion to sphero-cylinder notation (minus-cylinder convention):
        Cylinder = -2 * sqrt(J0**2 + J45**2)
        Axis     = 0.5 * atan2(J45, J0), in degrees, wrapped into [0, 180)
    The factor 2 matters: J0 = (-C/2) cos(2*axis) and J45 = (-C/2) sin(2*axis),
    so |J| is only half of the cylinder power.

    Args:
        c2n2_um: Zernike coefficient C(2,-2) (oblique astigmatism), micrometres.
        c2p2_um: Zernike coefficient C(2,2) (0/90 deg astigmatism), micrometres.
        pupil_diameter_mm: Pupil diameter in millimetres.

    Returns:
        (cylinder_D, axis_deg). The axis is reported as 0.0 when the cylinder is
        negligible (< 1e-9 D). Returns (nan, nan) for a non-positive pupil diameter.
    """
    if pupil_diameter_mm <= 0:
        return float('nan'), float('nan')
    pupil_radius_mm = pupil_diameter_mm / 2.0

    j0 = (-2 * math.sqrt(6) * c2p2_um) / (pupil_radius_mm**2)
    j45 = (-2 * math.sqrt(6) * c2n2_um) / (pupil_radius_mm**2)

    d_cylinder = -2.0 * math.hypot(j0, j45)

    if abs(d_cylinder) < 1e-9:  # Negligible astigmatism: axis is undefined
        axis_deg = 0.0
    else:
        axis_deg = math.degrees(0.5 * math.atan2(j45, j0))
        if axis_deg < 0:
            axis_deg += 180.0  # atan2 range (-pi, pi] -> axis in [0, 180)
    return d_cylinder, axis_deg

# --- Landolt C, Zernike, PSF, and Convolution functions ---
def draw_landolt_c_corrected(image_size=(224, 224), visual_acuity=1.0, distance_m=5,
                             orientation='right', background_color=(255, 255, 255), ring_color=(0, 0, 0)):
    """
    Draw a Landolt C optotype centred on a new RGB image.

    The ring is drawn as a filled outer disc and a background-coloured inner
    disc. A background-coloured rectangle is then drawn over one side to cut
    the gap. Proportions follow the standard Landolt C: ring thickness and
    gap width are both 1/5 of the outer diameter.

    Args:
        image_size: (width, height) of the output image in pixels.
        visual_acuity: Decimal acuity the optotype represents. The outer
            diameter is 80 px / visual_acuity, clamped to
            [10, 0.9 * min(image_size)] pixels. Must be > 0.
        distance_m: Nominal viewing distance. Kept for API compatibility; it
            does not influence the drawing, because the pixel scale is fixed
            (80 px outer diameter at acuity 1.0).
        orientation: 'right', 'left', 'up' or 'down' -- the side on which the
            gap opens. Any other value falls back to 'right'.
        background_color: RGB fill for the background, the hole and the gap.
        ring_color: RGB fill for the ring.

    Returns:
        A PIL image in "RGB" mode.

    Raises:
        ValueError: if visual_acuity is not positive.
    """
    if visual_acuity <= 0:
        raise ValueError(f"visual_acuity must be positive, got {visual_acuity!r}")

    image = Image.new("RGB", image_size, background_color)
    draw = ImageDraw.Draw(image)

    # Display scale: an 80 px outer diameter represents acuity 1.0.
    reference_acuity_diameter_pixels = 80
    outer_diameter_pixels = reference_acuity_diameter_pixels / visual_acuity
    outer_diameter_pixels = max(10, min(min(image_size) * 0.9, outer_diameter_pixels))

    ring_thickness_pixels = outer_diameter_pixels / 5.0
    gap_size_pixels = ring_thickness_pixels  # Standard Landolt C: gap == stroke width

    center_x, center_y = image_size[0] // 2, image_size[1] // 2
    outer_radius = outer_diameter_pixels / 2.0
    inner_radius = outer_radius - ring_thickness_pixels
    half_gap = gap_size_pixels / 2

    # Outer disc in the ring colour, then the hole in the background colour.
    draw.ellipse(
        (center_x - outer_radius, center_y - outer_radius,
         center_x + outer_radius, center_y + outer_radius),
        fill=ring_color
    )
    draw.ellipse(
        (center_x - inner_radius, center_y - inner_radius,
         center_x + inner_radius, center_y + inner_radius),
        fill=background_color
    )

    # Gap rectangles span the ring radially, overshooting by 1 px on each side.
    # This keeps fractional-coordinate rounding from leaving a sliver of ring.
    gap_boxes = {
        'right': (center_x + inner_radius - 1, center_y - half_gap,
                  center_x + outer_radius + 1, center_y + half_gap),
        'left': (center_x - outer_radius - 1, center_y - half_gap,
                 center_x - inner_radius + 1, center_y + half_gap),
        'up': (center_x - half_gap, center_y - outer_radius - 1,
               center_x + half_gap, center_y - inner_radius + 1),
        'down': (center_x - half_gap, center_y + inner_radius - 1,
                 center_x + half_gap, center_y + outer_radius + 1),
    }
    draw.rectangle(gap_boxes.get(orientation, gap_boxes['right']), fill=background_color)
    return image

def zernike_polynomial(n, m, rho, phi):
    """
    Evaluate the ANSI (OSA) normalised Zernike polynomial Z_n^m.

    Args:
        n: Radial order.
        m: Azimuthal frequency (m > 0: cosine term, m < 0: sine term).
        rho: Normalised radial coordinate(s), presumably within [0, 1].
        phi: Azimuthal angle(s) in radians, broadcast-compatible with rho.

    Returns:
        Array shaped like rho. All zeros when n - |m| is odd.
    """
    order_gap = n - abs(m)
    if order_gap % 2 != 0:
        return np.zeros_like(rho)

    def _radial_coeff(k):
        # Integer coefficient of rho**(n - 2k); the floor division is exact.
        return ((-1) ** k * math.factorial(n - k)) // (
            math.factorial(k)
            * math.factorial((n + abs(m)) // 2 - k)
            * math.factorial(order_gap // 2 - k)
        )

    radial = np.zeros_like(rho)
    for k in range(order_gap // 2 + 1):
        radial += _radial_coeff(k) * rho ** (n - 2 * k)

    if m == 0:
        return np.ones_like(rho) * np.sqrt(1) if n == 0 else np.sqrt(n + 1) * radial
    angular = np.cos(m * phi) if m > 0 else np.sin(abs(m) * phi)
    return np.sqrt(2 * (n + 1)) * radial * angular

def calculate_wavefront_aberration(zernike_coeffs, pupil_radius_pixels, image_size_pixels):
    """
    Build a wavefront map as a weighted sum of ANSI Zernike polynomials.

    The square grid (side image_size_pixels) spans
    [-pupil_radius_pixels, pupil_radius_pixels] on both axes, so the unit
    pupil touches the grid edges.

    Args:
        zernike_coeffs: Mapping of (n, m) -> coefficient (presumably micrometres).
            Zero coefficients are skipped.
        pupil_radius_pixels: Radius used to normalise to the unit pupil.
        image_size_pixels: Side length of the square output grid.

    Returns:
        (wavefront, pupil_mask): the float map (zero outside the pupil) and
        the boolean mask where rho <= 1.
    """
    span = np.linspace(-pupil_radius_pixels, pupil_radius_pixels, image_size_pixels)
    xx, yy = np.meshgrid(span, span)
    rho = np.sqrt(xx**2 + yy**2) / pupil_radius_pixels
    phi = np.arctan2(yy, xx)

    pupil_mask = rho <= 1.0
    inside_rho, inside_phi = rho[pupil_mask], phi[pupil_mask]
    wavefront = np.zeros((image_size_pixels, image_size_pixels), dtype=float)
    for (n, m), coeff in zernike_coeffs.items():
        if coeff != 0:
            wavefront[pupil_mask] += coeff * zernike_polynomial(n, m, inside_rho, inside_phi)
    return wavefront, pupil_mask

def calculate_psf(wavefront_aberration, pupil_mask, wavelength_nm=550, image_size_pixels=224):
    """
    Compute an incoherent PSF as |FFT(pupil function)|^2, normalised to unit sum.

    Args:
        wavefront_aberration: Wavefront map. It is multiplied by
            2*pi / (wavelength in um), so it is presumably in micrometres.
        pupil_mask: Boolean mask selecting pupil samples (same shape as the map).
        wavelength_nm: Wavelength in nanometres.
        image_size_pixels: Side length of the square pupil grid.

    Returns:
        Square float array with the zero-frequency term shifted to the centre.
        It sums to 1 unless it is identically zero.

    NOTE(review): no zero-padding is applied. A pupil that spans the whole grid
    gives a PSF sampled at roughly one pixel per lambda/D, which is below Nyquist.
    """
    k = 2 * np.pi / (wavelength_nm * 1e-3)  # rad per micrometre
    field = np.zeros((image_size_pixels, image_size_pixels), dtype=complex)
    field[pupil_mask] = np.exp(1j * k * wavefront_aberration[pupil_mask])
    psf = np.abs(fftshift(fft2(field))) ** 2
    norm = np.sum(psf)
    if norm > 0:
        psf = psf / norm
    return psf

def apply_psf_to_image(image_array, psf_array):
    """
    Blur a 2-D image with a PSF by direct convolution (symmetric borders).

    The PSF is renormalised to unit sum (when its sum is positive) to preserve
    mean brightness. The result is clipped to [0, max(image_array)].
    A non-negative unit-sum kernel cannot push pixels past the input's own
    peak, so the clip only trims round-off overshoot. Unlike a fixed upper
    bound of 1.0, it leaves 0-255 images intact as well as [0, 1] ones.

    NOTE(review): with an even-sized PSF the 'same' crop may sit one pixel
    off the kernel's fftshift centre -- verify if sub-pixel registration matters.
    """
    kernel_total = np.sum(psf_array)
    kernel = psf_array / kernel_total if kernel_total > 0 else psf_array
    blurred = convolve2d(image_array, kernel, mode='same', boundary='symm')
    return np.clip(blurred, 0, np.max(image_array))


if __name__ == '__main__':
    # Demo: draw a Landolt C, convert Zernike coefficients to a clinical
    # refraction, build the PSF, blur the optotype, and save/show the results.
    image_resolution = (224, 224)
    default_orientation = 'right'
    bg_color = (255, 255, 255)
    fg_color = (0, 0, 0)
    visual_acuities = [1.0]
    generated_landolt_paths_corrected = []
    colab_base_path = "/content/"

    print("Generating corrected Landolt C images...")
    for va in visual_acuities:
        img_pil = draw_landolt_c_corrected(
            image_size=image_resolution,
            visual_acuity=va,
            orientation=default_orientation,
            background_color=bg_color,
            ring_color=fg_color
        )
        path = f"{colab_base_path}corrected_landolt_c_va_{va:.1f}_{default_orientation}.png"
        img_pil.save(path)
        generated_landolt_paths_corrected.append(path)
        print(f"Saved corrected Landolt C for VA {va:.1f} ({default_orientation}) to {path}")
        # Optionally display the generated Landolt C using matplotlib
        plt.figure(figsize=(4,4))
        plt.imshow(img_pil)
        plt.title(f"Landolt C VA {va:.1f} ({default_orientation})")
        plt.axis('off')
        plt.show()
        plt.close()
    print("Corrected Landolt C generation examples complete.")

    print("\nStarting PSF generation example...")
    psf_image_size = 224
    pupil_diameter_mm_param = 4.0
    wavelength_nm_param = 550

    # Coefficients in micrometres, keyed by ANSI (n, m).
    example_zernike_coeffs = {
        (0,0): 0.0, (1,-1): 0.0, (1,1): 0.0,
        (2,-2): 0.1, (2,0): 0.5, (2,2): -0.3,
        (3,-3): 0.3, (3,-1): 0.0, (3,1): 0.0, (3,3): 0.0,
        (4,0): 0.0
    }

    c20_val = example_zernike_coeffs.get((2,0), 0.0)
    c2n2_val = example_zernike_coeffs.get((2,-2), 0.0)
    c2p2_val = example_zernike_coeffs.get((2,2), 0.0)

    defocus_D = calculate_defocus_diopter(c20_val, pupil_diameter_mm_param)
    astig_cyl_D, astig_axis_deg = calculate_astigmatism_diopter_axis(c2n2_val, c2p2_val, pupil_diameter_mm_param)
    # Sphere from the power vector: S = M - C/2 (minus-cylinder form).
    sphere_D = defocus_D - astig_cyl_D / 2.0

    print("\n--- Zernike Coefficients & Clinical Representation ---")
    print(f"Pupil Diameter: {pupil_diameter_mm_param} mm")
    print(f"Zernike (2,0) [Defocus]: {c20_val:.3f} um  => Defocus: {defocus_D:.3f} D")
    print(f"Zernike (2,-2) [Astig @45deg]: {c2n2_val:.3f} um")
    print(f"Zernike (2,2) [Astig @0/90deg]: {c2p2_val:.3f} um")
    if abs(astig_cyl_D) > 1e-9:
        print(f"  => Calculated Astigmatism: Cylinder: {astig_cyl_D:.3f} D, Axis: {astig_axis_deg:.1f} degrees")
    else:
        print(f"  => Calculated Astigmatism: Cylinder: {astig_cyl_D:.3f} D (negligible)")
    print(f"  => Sphero-cylinder: S {sphere_D:+.3f} D / C {astig_cyl_D:+.3f} D x {astig_axis_deg:.0f}")
    print("----------------------------------------------------")

    pupil_radius_in_psf_pixels = psf_image_size // 2
    wavefront_ab, pupil_m = calculate_wavefront_aberration(
        zernike_coeffs=example_zernike_coeffs,
        pupil_radius_pixels=pupil_radius_in_psf_pixels,
        image_size_pixels=psf_image_size
    )
    psf_map = calculate_psf(
        wavefront_aberration=wavefront_ab,
        pupil_mask=pupil_m,
        wavelength_nm=wavelength_nm_param,
        image_size_pixels=psf_image_size
    )
    print(f"PSF map generated with shape: {psf_map.shape}")

    current_fig = plt.figure(figsize=(7,7)) # Get current figure for explicit closing later
    psf_display = psf_map
    if np.max(psf_map) > 0:
        psf_display = np.log1p(psf_map / np.max(psf_map) * 1000)

    # Mathtext backslashes are doubled: "\m", "\c" and "\l" are invalid string
    # escapes (SyntaxWarning on 3.12+). Raw strings are not an option here
    # because these lines also need real "\n" newlines.
    psf_title_str = f"Point Spread Function (log scale)\n"
    psf_title_str += f"$Z_2^0$: {c20_val:.2f} $\\mu m$ (Defocus: {defocus_D:.2f} D)\n"
    if abs(astig_cyl_D) > 1e-9:
        psf_title_str += f"$Z_2^{{-2}}$: {c2n2_val:.2f}, $Z_2^2$: {c2p2_val:.2f} $\\mu m$ (Cyl: {astig_cyl_D:.2f} D @ {astig_axis_deg:.0f}$^\\circ$)\n"
    else:
        psf_title_str += f"$Z_2^{{-2}}$: {c2n2_val:.2f}, $Z_2^2$: {c2p2_val:.2f} $\\mu m$ (Cyl: {astig_cyl_D:.2f} D)\n"
    psf_title_str += f"Pupil: {pupil_diameter_mm_param}mm, $\\lambda$: {wavelength_nm_param}nm"

    plt.title(psf_title_str, fontsize=10)
    img_ax = plt.imshow(psf_display, cmap='hot')
    plt.colorbar(img_ax)

    psf_image_path_corrected = f"{colab_base_path}psf_example_with_diopters.png"
    plt.savefig(psf_image_path_corrected)
    print(f"Saved example PSF image to {psf_image_path_corrected}")
    plt.show() # Display the PSF plot
    plt.close(current_fig) # Close the specific figure
    print("PSF generation example complete.")


    print("\nStarting convolution of CORRECTED Landolt C with PSF...")
    if not generated_landolt_paths_corrected:
        print("Error: No Landolt C images were generated. Skipping convolution.")
    else:
        # generated_landolt_paths_corrected[i] was produced from visual_acuities[i].
        landolt_c_to_convolve_path_corrected = generated_landolt_paths_corrected[0]
        convolved_va = visual_acuities[0]
        try:
            landolt_c_img_pil_corrected = Image.open(landolt_c_to_convolve_path_corrected).convert('L')
            landolt_c_array_corrected = np.array(landolt_c_img_pil_corrected) / 255.0

            blurred_landolt_c_array_corrected = apply_psf_to_image(landolt_c_array_corrected, psf_map)

            blurred_landolt_c_img_corrected = Image.fromarray((blurred_landolt_c_array_corrected * 255).astype(np.uint8))
            blurred_image_path_corrected = f"{colab_base_path}blurred_corrected_landolt_c_with_diopters.png"
            blurred_landolt_c_img_corrected.save(blurred_image_path_corrected)
            print(f"Saved blurred CORRECTED Landolt C image to {blurred_image_path_corrected}")

            comparison_fig, axes = plt.subplots(1, 3, figsize=(18, 6)) # Get figure for explicit closing
            axes[0].imshow(landolt_c_array_corrected, cmap='gray', vmin=0, vmax=1)
            axes[0].set_title(f'Original Landolt C (VA {convolved_va:.1f})')
            axes[0].axis('off')

            psf_subplot_title = f"PSF\nDef: {defocus_D:.2f} D\n"
            if abs(astig_cyl_D) > 1e-9:
                psf_subplot_title += f"Cyl: {astig_cyl_D:.2f} D @ {astig_axis_deg:.0f}$^\\circ$"
            else:
                psf_subplot_title += f"Cyl: {astig_cyl_D:.2f} D (negligible)"

            axes[1].imshow(psf_display, cmap='hot')
            axes[1].set_title(psf_subplot_title, fontsize=10)
            axes[1].axis('off')

            axes[2].imshow(blurred_landolt_c_array_corrected, cmap='gray', vmin=0, vmax=1)
            axes[2].set_title('Blurred Landolt C')
            axes[2].axis('off')

            plt.tight_layout()
            comparison_image_path_corrected = f"{colab_base_path}comparison_landolt_psf_blurred_with_diopters.png"
            plt.savefig(comparison_image_path_corrected)
            print(f"Saved comparison image to {comparison_image_path_corrected}")
            plt.show() # Display the comparison plot
            plt.close(comparison_fig) # Close the specific figure

        except FileNotFoundError:
            print(f"Error: Could not find Landolt C image at {landolt_c_to_convolve_path_corrected} for convolution.")
        except Exception as e:
            # Top-level demo boundary: report with a traceback instead of hiding the cause.
            import traceback
            print(f"An error occurred during convolution: {e}")
            traceback.print_exc()

    print("Corrected convolution example complete.")