In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Canvas dimensions in pixels; read by the drawing, simulation and plotting cells
height = 256
width = 256

def draw_level(level):
    """Render one of five synthetic test structures on a black canvas.

    Parameters:
    - level: int selecting the pattern
        1 = square outline, 2 = diagonal line + circle, 3 = cross + circle,
        4 = 4x4 grid of square outlines, 5 = mixed primitives.
      Any other value returns the canvas untouched (all zeros).

    Returns:
    - 2D uint8 array of shape (height, width); shapes are drawn as outlines
      with grey values between 150 and 255.
    """
    canvas = np.zeros((height, width), dtype=np.uint8)

    if level == 1:
        # Single square outline in the middle of the canvas
        cv2.rectangle(canvas, (80, 80), (176, 176), color=255, thickness=2)
        return canvas

    if level == 2:
        # Anti-diagonal stroke with a dimmer circle on top
        cv2.line(canvas, (50, 200), (200, 50), color=255, thickness=3)
        cv2.circle(canvas, (128, 128), 30, color=180, thickness=2)
        return canvas

    if level == 3:
        # Horizontal and vertical strokes through the centre, plus a circle
        cv2.line(canvas, (0, 128), (255, 128), color=200, thickness=3)
        cv2.line(canvas, (128, 0), (128, 255), color=200, thickness=3)
        cv2.circle(canvas, (128, 128), 40, color=180, thickness=2)
        return canvas

    if level == 4:
        # One square outline per 64-pixel tile, inset by 8 pixels on each side
        for tile_x in range(0, width, 64):
            for tile_y in range(0, height, 64):
                corner_a = (tile_x + 8, tile_y + 8)
                corner_b = (tile_x + 56, tile_y + 56)
                cv2.rectangle(canvas, corner_a, corner_b, color=150, thickness=2)
        return canvas

    if level == 5:
        # Mixed scene: line, circle, rotated ellipse, closed polygon, square
        cv2.line(canvas, (20, 20), (230, 230), color=255, thickness=2)
        cv2.circle(canvas, (64, 64), 20, color=180, thickness=2)
        cv2.ellipse(canvas, (192, 64), (30, 15), 30, 0, 360, 200, thickness=2)
        polygon = np.array(
            [[150, 200], [170, 210], [180, 230], [160, 240], [140, 220]], np.int32
        )
        cv2.polylines(canvas, [polygon], isClosed=True, color=220, thickness=2)
        cv2.rectangle(canvas, (100, 100), (140, 140), color=200, thickness=2)
        return canvas

    return canvas

# Show the five test patterns side by side, one subplot per level
fig, axes = plt.subplots(1, 5, figsize=(18, 4))
for level, ax in enumerate(axes, start=1):
    img = draw_level(level)
    ax.imshow(img, cmap='gray')
    ax.set_title(f'Level {level}')
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Ground-truth structure (Level 3: cross + circle) used by the simulation cells below
image = draw_level(3)

In [None]:
# Larger standalone view of the Level 3 pattern
plt.figure(figsize=(6, 6))
plt.imshow(draw_level(3), cmap="gray")
plt.axis("off")
plt.title("Level 3")
plt.show()

In [None]:
from collections import Counter

# Parameters
total_fluorophores = 10000
free_floating_ratio = 0.1  # fraction of fluorophores NOT bound to the structure (10%)
jitter_radius = 3  # maximum jitter radius in pixels

# NOTE(review): no RNG seed is set anywhere in this notebook, so every run
# produces a different fluorophore layout.

# Exact counts
num_free = int(total_fluorophores * free_floating_ratio)
num_attached = total_fluorophores - num_free

# Step 1: Get all coordinates on the structure
structure_coords = np.column_stack(np.where(image > 0))  # rows are (y, x)

if len(structure_coords) < 1:
    raise ValueError("No pixels in structure to place attached fluorophores.")

# Step 2: Randomly sample fluorophores from valid structure pixels WITH replacement
sample_indices = np.random.choice(len(structure_coords), size=num_attached, replace=True)
sampled_coords = structure_coords[sample_indices]

# Count fluorophores per pixel (allows multiple bindings); keyed by (y, x).
# The frame-simulation cell rebuilds the attached positions from this counter.
pixel_counts = Counter(map(tuple, sampled_coords))

# Independent integer jitter in [-jitter_radius, +jitter_radius] per axis for
# every attached fluorophore, clipped to the image. Done in one vectorised draw,
# the same way as the free-floating fluorophores below.
# NOTE(review): these jittered positions are only plotted in this cell; the
# frame simulation uses the un-jittered pixel positions — confirm that is intended.
attached_jitter = np.random.randint(-jitter_radius, jitter_radius + 1, size=(num_attached, 2))
attached_y_jittered = np.clip(sampled_coords[:, 0] + attached_jitter[:, 0], 0, height - 1)
attached_x_jittered = np.clip(sampled_coords[:, 1] + attached_jitter[:, 1], 0, width - 1)

# Step 3: Generate free-floating fluorophores uniformly over the image, with jitter too
free_x = np.random.randint(0, width, num_free)
free_y = np.random.randint(0, height, num_free)

free_x_jittered = free_x + np.random.randint(-jitter_radius, jitter_radius + 1, size=num_free)
free_y_jittered = free_y + np.random.randint(-jitter_radius, jitter_radius + 1, size=num_free)

free_x_jittered = np.clip(free_x_jittered, 0, width - 1)
free_y_jittered = np.clip(free_y_jittered, 0, height - 1)

# Plotting: attached and free-floating fluorophores with jitter
plt.figure(figsize=(6, 6))
plt.imshow(image, cmap='gray')

# Plot attached fluorophores (each point)
plt.scatter(attached_x_jittered, attached_y_jittered, s=5, color='red', alpha=0.5, label='Attached fluorophores')

# Plot free-floating fluorophores
plt.scatter(free_x_jittered, free_y_jittered, s=5, color='blue', alpha=0.3, label='Free fluorophores')

plt.title(f"Fluorophores with jitter up to ±{jitter_radius} pixels")
plt.axis('off')
plt.legend()
plt.show()


In [None]:
num_frames = 2000
activation_prob = 0.003  # Probability that any given fluorophore is active in a given frame

# Rebuild one (x, y) row per attached fluorophore from the per-pixel counts
# (pixel_counts is keyed by (y, x), so the order is swapped here).
# dtype=int + reshape keep the array integer and 2-D even when it is empty.
attached_coords_expanded = np.array(
    [(x, y) for (y, x), count in pixel_counts.items() for _ in range(count)],
    dtype=int,
).reshape(-1, 2)  # shape (num_attached, 2)

# Free fluorophores
free_coords = np.column_stack((free_x, free_y))  # shape (num_free, 2)

# Combine all
all_coords = np.vstack([attached_coords_expanded, free_coords])
num_fluorophores = len(all_coords)

# For every frame, draw one Bernoulli(activation_prob) sample per fluorophore in a
# single vectorised call and keep the coordinates of the active ones. This replaces
# num_fluorophores * num_frames scalar np.random.rand() calls; each frame still
# lists its active fluorophores in all_coords order, as (x, y) tuples.
frames = []
for _ in range(num_frames):
    active = np.random.rand(num_fluorophores) < activation_prob
    frames.append([tuple(coord) for coord in all_coords[active]])

# (Optional) Visualize one sample frame (fixed index, not random)
frame_idx = 1
fx, fy = zip(*frames[frame_idx]) if frames[frame_idx] else ([], [])

plt.figure(figsize=(6, 6))
plt.imshow(image, cmap='gray')
plt.scatter(fx, fy, s=5, color='red', alpha=0.8, label=f'Frame {frame_idx}')
plt.title(f"Fluorophores in Frame {frame_idx}")
plt.axis('off')
plt.legend()
plt.show()


In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

# Figure with the structure as backdrop and an initially empty scatter layer
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image, cmap='gray')
scatter = ax.scatter([], [], s=5, color='blue')
ax.set_title("Fluorophore Activation Frames")
ax.axis('off')


def init():
    """Clear the scatter layer before the first frame is drawn."""
    scatter.set_offsets(np.empty((0, 2)))
    return (scatter,)


def update(frame_num):
    """Show the fluorophores active in frame `frame_num` and refresh the title."""
    active_points = frames[frame_num]
    # An empty frame needs an explicit (0, 2) array so the offsets stay 2-D
    offsets = np.array(active_points) if active_points else np.empty((0, 2))
    scatter.set_offsets(offsets)
    ax.set_title(f"Frame {frame_num + 1} / {num_frames}")
    return (scatter,)


# Build the animation: one step per simulated frame, 40 ms between steps
ani = animation.FuncAnimation(
    fig,
    update,
    frames=num_frames,
    init_func=init,
    interval=40,
    blit=True,
    repeat=True,
)

# Render as an inline JavaScript player (cell output)
HTML(ani.to_jshtml())


In [None]:
# Same animation as above, but on a plain black background without the structure
fig2, ax2 = plt.subplots(figsize=(6, 6))
fig2.patch.set_facecolor('black')   # figure background
ax2.set_facecolor('black')           # axes background

scatter2 = ax2.scatter([], [], s=5, color='cyan')
ax2.set_title("Fluorophore Activation Frames (No Structure)", color='white')

# No imshow here, so the axis limits are set by hand; the y range is given
# top-down so the orientation matches image coordinates.
ax2.axis('off')
ax2.set_xlim(0, width)
ax2.set_ylim(height, 0)

# Hide every axis spine
for spine in ax2.spines.values():
    spine.set_visible(False)


def init2():
    """Clear the scatter layer before the first frame is drawn."""
    scatter2.set_offsets(np.empty((0, 2)))
    return (scatter2,)


def update2(frame_num):
    """Show the fluorophores active in frame `frame_num` and refresh the title."""
    active_points = frames[frame_num]
    offsets = np.array(active_points) if active_points else np.empty((0, 2))
    scatter2.set_offsets(offsets)
    ax2.set_title(f"Frame {frame_num + 1} / {num_frames}", color='white')
    return (scatter2,)


ani2 = animation.FuncAnimation(
    fig2,
    update2,
    frames=num_frames,
    init_func=init2,
    interval=40,
    blit=True,
    repeat=True,
)

HTML(ani2.to_jshtml())


In [None]:
import cv2

def gaussian_psf(size=15, sigma=3):
    """
    Build a normalized 2D Gaussian point spread function.

    Defined here because no other visible cell of this notebook defines it,
    so the call below would otherwise fail on a fresh kernel.

    Parameters:
    - size: int, side length of the square kernel in pixels
    - sigma: float, standard deviation of the Gaussian in pixels

    Returns:
    - kernel: 2D numpy array of shape (size, size) whose entries sum to 1
    """
    offsets = np.arange(size) - (size - 1) / 2.0  # pixel distances from the kernel centre
    profile = np.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
    kernel = np.outer(profile, profile)  # separable: 2D Gaussian = outer product of 1D profiles
    return kernel / kernel.sum()


# Generate Gaussian PSF kernel
psf = gaussian_psf(size=15, sigma=3)

# Initialize empty list for blurred frames
blurred_frames = []

for frame_points in frames:
    # Create blank image for frame
    frame_img = np.zeros((height, width), dtype=np.float32)

    # Mark fluorophore positions as 1 (make sure coordinates are valid).
    # Several fluorophores on the same pixel still give 1.0, not their count.
    for x, y in frame_points:
        if 0 <= y < height and 0 <= x < width:
            frame_img[y, x] = 1.0

    # Blur with the PSF (same size output, float32). filter2D computes a
    # correlation, which equals convolution for this symmetric kernel.
    blurred = cv2.filter2D(frame_img, -1, psf, borderType=cv2.BORDER_REPLICATE)
    blurred_frames.append(blurred)

# Visualize one blurred frame example; the title reports the index actually shown
example_idx = 1
plt.figure(figsize=(6,6))
plt.imshow(blurred_frames[example_idx], cmap='inferno')
plt.title(f"Blurred Fluorophores Frame {example_idx}")
plt.axis('off')
plt.colorbar(label='Intensity')
plt.show()



In [None]:
# Animated view of the PSF-blurred frames on a black background
fig3, ax3 = plt.subplots(figsize=(6, 6))
fig3.patch.set_facecolor('black')
ax3.set_facecolor('black')

# Start from an all-zero image; vmin/vmax pin the colour scale for every frame
im = ax3.imshow(
    np.zeros((height, width)),
    cmap='inferno',
    vmin=0,
    vmax=0.1,
)
ax3.axis('off')
ax3.set_title("Blurred Fluorophores Animation", color='white')


def update_blurred(frame_num):
    """Swap in blurred frame `frame_num` and refresh the title."""
    im.set_data(blurred_frames[frame_num])
    ax3.set_title(f"Frame {frame_num + 1} / {num_frames}", color='white')
    return [im]


ani3 = animation.FuncAnimation(
    fig3,
    update_blurred,
    frames=num_frames,
    interval=40,
    blit=True,
    repeat=True,
)

HTML(ani3.to_jshtml())


In [None]:
# Collapse the whole blurred stack into one image by summing over the frame axis
stacked_blurred = np.sum(blurred_frames, axis=0)

# Scale so the brightest pixel is 1 (display only)
stacked_blurred_norm = stacked_blurred / np.max(stacked_blurred)

plt.figure(figsize=(6, 6))
plt.imshow(stacked_blurred_norm, cmap="inferno")
plt.axis("off")
plt.title("Stacked Sum of All Blurred Frames")
plt.show()



In [None]:
from scipy.signal import convolve2d
import numpy as np
import math

def richardson_lucy(image, psf, iterations=20):
    """
    Richardson-Lucy deconvolution of a single 2D image.

    Starting from a flat estimate of 0.5, each iteration re-blurs the estimate
    with the PSF, compares it to the observed image as a ratio, and multiplies
    the estimate by that ratio filtered with the mirrored PSF.

    Parameters:
    - image: 2D numpy array, observed blurred image
    - psf: 2D numpy array, point spread function (normalized)
    - iterations: int, number of multiplicative update steps

    Returns:
    - estimate: 2D float array with the same shape as `image`
    """
    # PSF flipped along both axes
    flipped_psf = psf[::-1, ::-1]

    # Flat initial guess
    estimate = np.full(image.shape, 0.5)

    for _ in range(iterations):
        reblurred = convolve2d(estimate, psf, mode='same')
        # The small constant keeps the ratio finite where the re-blurred estimate is 0
        ratio = image / (reblurred + 1e-7)
        correction = convolve2d(ratio, flipped_psf, mode='same')
        estimate *= correction

    return estimate


In [None]:
from tqdm import tqdm

# One deconvolution per blurred frame
num_frames = len(blurred_frames)
iterations = 30

# Run Richardson-Lucy on every frame, with a progress bar
deconvolved_frames = [
    richardson_lucy(blurred, psf, iterations=iterations)
    for blurred in tqdm(blurred_frames, desc="Deconvolving Frames")
]

# Near-square grid dimensions that can hold num_frames panels
# (computed here; no visible cell reads them)
cols = math.ceil(math.sqrt(num_frames))
rows = math.ceil(num_frames / cols)




In [None]:
# Show one deconvolved frame; the title reports the index actually displayed
# (previously the title said "Frame 0" while index 1 was shown)
example_idx = 1
plt.figure(figsize=(6, 6))
plt.imshow(deconvolved_frames[example_idx], cmap='inferno')
plt.title(f"Deconvolved Frame {example_idx}")
plt.axis('off')
plt.colorbar(label='Intensity')
plt.show()

In [None]:
from scipy.ndimage import maximum_filter, label, find_objects

def extract_points_from_frame(frame, threshold=0.1, footprint_size=3):
    """
    Extract (x, y) coordinates of bright points from a frame.

    A pixel is a candidate when it equals the maximum of its
    footprint_size x footprint_size neighbourhood and exceeds `threshold`.
    Touching candidates are grouped into connected components and one point
    is reported per component.

    Parameters:
    - frame: 2D numpy array (deconvolved image)
    - threshold: float, minimum intensity to consider a point
    - footprint_size: int, neighborhood size for local maxima detection

    Returns:
    - points: list of (x, y) tuples of Python ints (pixel coordinates),
      one per detected component, in label order
    """
    # Find local maxima by comparing frame to max filtered version
    footprint = np.ones((footprint_size, footprint_size))
    local_max = (frame == maximum_filter(frame, footprint=footprint))

    # Apply threshold to filter out low intensity points
    detected = local_max & (frame > threshold)

    # Label connected components (each max)
    labeled, _ = label(detected)

    points = []
    # find_objects returns bounding boxes in label order, so box k belongs to label k
    for component_id, slc in enumerate(find_objects(labeled), start=1):
        region = frame[slc]
        # The bounding box can contain pixels of OTHER components (e.g. a peak
        # enclosed by a ring-shaped plateau). Mask them out so the argmax can
        # only land on a pixel that belongs to this component.
        own_pixels = labeled[slc] == component_id
        masked = np.where(own_pixels, region, -np.inf)
        max_pos = np.unravel_index(np.argmax(masked), region.shape)
        # Translate to image coordinates
        y = int(max_pos[0] + slc[0].start)
        x = int(max_pos[1] + slc[1].start)
        points.append((x, y))
    return points


In [None]:
# Run the localisation step on every deconvolved frame (default threshold/footprint)
all_points = list(map(extract_points_from_frame, deconvolved_frames))

# Report how many points were found in each frame (1-based frame numbers)
for i, pts in enumerate(all_points):
    n_detected = len(pts)
    print(f"Frame {i+1}: Detected {n_detected} points")


In [None]:
# Accumulator image: each detection adds a fixed amount to its pixel
reconstructed = np.zeros_like(deconvolved_frames[0], dtype=np.float32)

# Contribution of a single detected point (tweak if needed)
intensity_per_point = 1.0

for points in all_points:
    for (x, y) in points:
        # Skip anything that falls outside the image
        if y < 0 or y >= reconstructed.shape[0] or x < 0 or x >= reconstructed.shape[1]:
            continue
        reconstructed[y, x] += intensity_per_point

# Scale so the most frequently hit pixel is 1 (display only)
reconstructed_norm = reconstructed / np.max(reconstructed)

# Display reconstructed image
plt.figure(figsize=(6, 6))
plt.imshow(reconstructed_norm, cmap="inferno")
plt.axis("off")
plt.title("Reconstructed Image from All Frames")
plt.show()


In [None]:
# Presence/absence map: a pixel is True once any frame reported a point there
reconstructed_mask = np.zeros_like(deconvolved_frames[0], dtype=bool)

for points in all_points:
    for (x, y) in points:
        # Skip anything that falls outside the image
        if y < 0 or y >= reconstructed_mask.shape[0] or x < 0 or x >= reconstructed_mask.shape[1]:
            continue
        reconstructed_mask[y, x] = True

# Float copy for display (1 = presence, 0 = absence)
cleaned_image = reconstructed_mask.astype(float)

plt.figure(figsize=(6, 6))
plt.imshow(cleaned_image, cmap="inferno")
plt.axis("off")
plt.title("Cleaned Reconstruction (Duplicates Removed)")
plt.show()


In [None]:
from scipy.spatial import cKDTree
import numpy as np

def filter_free_floating(points, radius=5):
    """
    Remove points with no neighbors within a specified radius.

    Args:
        points: Nx2 array-like of (x, y) fluorophore coordinates; a list of
            tuples is accepted, and an empty input of any shape is allowed
        radius: float, neighborhood radius in pixels

    Returns:
        filtered_points: Mx2 numpy array (M <= N) of the points that have at
            least one other point within `radius`, in their original order.
            Empty input yields an array of shape (0, 2).
    """
    points = np.asarray(points)

    # np.array([]) has shape (0,); normalise so callers always get a 2-column
    # array and can stack results from different frames.
    if points.size == 0:
        return points.reshape(0, 2)

    tree = cKDTree(points)
    # Query neighbors within radius for each point (each result includes the point itself)
    neighbors = tree.query_ball_point(points, r=radius)

    # Keep points that have at least one other neighbor (len > 1)
    keep = np.fromiter((len(nbrs) > 1 for nbrs in neighbors), dtype=bool, count=len(points))

    return points[keep]


In [None]:
# Drop isolated detections (no other point within 6 px) frame by frame
filtered_all_points = [
    filter_free_floating(np.array(frame_points), radius=6)
    for frame_points in all_points
]

# Optional: per-frame summary of how many points survived the filter
for i, filt in enumerate(filtered_all_points):
    orig = all_points[i]
    print(f"Frame {i}: original points = {len(orig)}, filtered points = {len(filt)}")

In [None]:
# Accumulator image built from the neighbour-filtered detections only
reconstructed_filtered = np.zeros_like(deconvolved_frames[0], dtype=np.float32)

# Contribution of a single detected point (adjust if needed)
intensity_per_point = 1.0

for points in filtered_all_points:
    for (x, y) in points:
        # Skip anything that falls outside the image
        if y < 0 or y >= reconstructed_filtered.shape[0] or x < 0 or x >= reconstructed_filtered.shape[1]:
            continue
        reconstructed_filtered[y, x] += intensity_per_point

# Scale so the most frequently hit pixel is 1 (display only)
reconstructed_filtered_norm = reconstructed_filtered / np.max(reconstructed_filtered)

# Display the cleaned reconstructed image
plt.figure(figsize=(6, 6))
plt.imshow(reconstructed_filtered_norm, cmap="inferno")
plt.axis("off")
plt.title("Reconstructed Image from Filtered Points")
plt.show()


In [None]:
# Remove duplicates from filtered points in each frame.
# reshape(-1, 2) keeps frames without any points as (0, 2) arrays instead of (0,).
unique_filtered_points = [
    np.array(sorted(set(map(tuple, frame_points)))).reshape(-1, 2)
    for frame_points in filtered_all_points
]

# Initialize empty reconstruction image (float for accumulation)
reconstructed_unique = np.zeros_like(deconvolved_frames[0], dtype=np.float32)

intensity_per_point = 1.0

for points in unique_filtered_points:
    for (x, y) in points:
        if 0 <= y < reconstructed_unique.shape[0] and 0 <= x < reconstructed_unique.shape[1]:
            reconstructed_unique[y, x] += intensity_per_point

# Normalize for display
reconstructed_unique_norm = reconstructed_unique / reconstructed_unique.max()

# Gather every point for the scatter plot. Empty frames are skipped so the
# stack never mixes differently shaped arrays.
nonempty_frames = [frame_points for frame_points in unique_filtered_points if len(frame_points)]
if nonempty_frames:
    stacked_points = np.vstack(nonempty_frames)
    # Points are stored as (x, y): column 0 is x, column 1 is y.
    # (The previous unpacking order swapped them and drew the scatter transposed.)
    xs, ys = stacked_points[:, 0], stacked_points[:, 1]
else:
    xs, ys = [], []

# Plot white points on black background
plt.figure(figsize=(6, 6))
plt.imshow(np.zeros_like(reconstructed_unique_norm), cmap='gray')  # black background
plt.scatter(xs, ys, color='white', s=1, alpha=0.8)
plt.title("Unique Filtered Points (White on Black)")
plt.axis('off')
plt.show()


In [None]:
# Ground-truth structure pixels as a set of (y, x) tuples for O(1) membership tests
structure_pixels = {tuple(row_col) for row_col in structure_coords}

# Tally how many reported points land exactly on a structure pixel
points_on_structure = 0
total_points = 0

for frame_points in unique_filtered_points:
    total_points += len(frame_points)
    for (x, y) in frame_points:
        # Points are (x, y); the lookup set is keyed (y, x)
        if (y, x) not in structure_pixels:
            continue
        points_on_structure += 1

# Fraction of reported points that sit on the structure; 0 when nothing was reported
if total_points > 0:
    accuracy = points_on_structure / total_points
else:
    accuracy = 0

print(f"Accuracy: {accuracy:.4f} ({points_on_structure} / {total_points} points on structure)")
