In [None]:
import tifffile
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.ndimage import convolve
from skimage.transform import resize
import os


In [None]:

# -----------------------
# CONFIGURATION
# -----------------------
# NOTE(review): absolute, machine-specific Windows paths; consider making them
# relative to a configurable project root before sharing this notebook.
datapath = r"D:\Masters (we in boys)\Themed Research Project\Themed Research Project\2. Resources\data\thunderstrom output\from scratch"
tiff_filename = "Artificial dataset.tif"
csv_filename = "ground truth.csv"
upsampling_factor = 8                 # upsampled pixels per camera pixel along each axis
camera_pixelsize = 80                 # in nm
gaussian_sigma = 1                    # std-dev of the heatmap Gaussian, in upsampled pixels
patch_size = 26 * upsampling_factor   # patch side in upsampled pixels (26 camera pixels)
num_patches_per_frame = 500           # random candidate patch centres drawn per frame
max_examples = 10000                  # stop extracting once this many patches are kept
min_emitters = 7                      # candidate patches with fewer spike pixels are discarded
save_path = r"D:\Masters (we in boys)\Themed Research Project\Themed Research Project\2. Resources\Recreation\from scratch_training_data.pkl"

# Seed the global NumPy RNG so the random patch sampling is reproducible
# under Restart Kernel -> Run All.
random_seed = 42
np.random.seed(random_seed)


In [None]:

# -----------------------
# LOAD DATA
# -----------------------
tiff_path = os.path.join(datapath, tiff_filename)
csv_path = os.path.join(datapath, csv_filename)

stack = tifffile.imread(tiff_path)
df = pd.read_csv(csv_path)

# A single-frame TIFF is read as a 2-D array; promote it to a one-frame stack
# so the (frames, rows, cols) unpacking below works in both cases.
if stack.ndim == 2:
    stack = stack[np.newaxis, ...]
if stack.ndim != 3:
    raise ValueError(f"Expected a (frames, rows, cols) image stack, got shape {stack.shape}")

# Fail early, with a clear message, if the ground-truth table lacks the
# columns the extraction loop reads.
required_columns = {"frame", "x [nm]", "y [nm]"}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise ValueError(f"{csv_path} is missing columns: {sorted(missing_columns)}")

num_frames, height, width = stack.shape
# Size of the upsampled (high-resolution) grid.
Mhr, Nhr = height * upsampling_factor, width * upsampling_factor
# Accumulators for the extracted training examples.
patches, heatmaps, spikes = [], [], []


In [None]:

# -----------------------
# HEATMAP PSF
# -----------------------
# Normalised 2-D Gaussian used to blur the binary spike map into a heatmap.
# The kernel radius scales with gaussian_sigma (ceil(3 * sigma), at least 1
# pixel) so larger sigmas are not truncated; gaussian_sigma = 1 gives a 7x7 kernel.
psf_radius = max(1, int(np.ceil(3 * gaussian_sigma)))
psf_offsets = np.arange(-psf_radius, psf_radius + 1)
psf_kernel = np.exp(-((psf_offsets[:, None]**2 + psf_offsets[None, :]**2) / (2 * gaussian_sigma**2)))
psf_kernel /= psf_kernel.sum()



In [None]:
# Quick visual sanity check of the first raw (camera-resolution) frame.
fig, ax = plt.subplots()
ax.imshow(stack[0], cmap='gray')
ax.set_title("Frame 1")
ax.axis('off')
plt.show()

In [None]:

# -----------------------
# UTILITY: Random valid centers
# -----------------------
def get_random_patch_centers(shape, patch_half, num_points, rng=None):
    """Draw random (row, col) patch centres that keep a patch inside the image.

    A centre ``(r, c)`` is valid when the slice ``[r - patch_half : r + patch_half]``
    (and likewise for columns) lies entirely inside the image, i.e.
    ``patch_half <= r <= shape[0] - patch_half``. Both ends are inclusive, so
    patches may touch the last row/column.

    Parameters
    ----------
    shape : tuple of int
        ``(rows, cols)`` of the image the patches are cut from.
    patch_half : int
        Half of the patch side length (the margin kept from each border).
    num_points : int
        Number of centres to draw.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        ``np.random.seed`` keeps results reproducible.

    Returns
    -------
    list of tuple
        ``num_points`` ``(row, col)`` pairs.

    Raises
    ------
    ValueError
        If the image is smaller than ``2 * patch_half`` along either axis.
    """
    source = np.random if rng is None else rng
    # Generator exposes ``integers``; RandomState and the legacy module expose ``randint``.
    draw = getattr(source, "integers", None) or source.randint

    upper_bounds = []
    for axis_len in shape[:2]:
        high = axis_len - patch_half + 1  # +1 because the upper bound of the draw is exclusive
        if high <= patch_half:
            raise ValueError(
                f"Image of shape {tuple(shape)} is too small for patches of half-size {patch_half}"
            )
        upper_bounds.append(high)

    rows = draw(patch_half, upper_bounds[0], size=num_points)
    cols = draw(patch_half, upper_bounds[1], size=num_points)
    return list(zip(rows, cols))


In [None]:

# -----------------------
# MAIN EXTRACTION LOOP
# -----------------------
# Re-initialise the accumulators here so that re-running this cell on its own
# does not append duplicates onto the results of a previous run.
patches, heatmaps, spikes = [], [], []
k = 0

# Loop-invariant quantities.
patch_half = patch_size // 2
hr_pixelsize = camera_pixelsize / upsampling_factor  # nm per upsampled pixel

# Split the ground truth by frame once instead of re-filtering the whole table
# on every iteration; frames without emitters fall back to empty arrays.
empty_coords = (np.empty(0), np.empty(0))
coords_by_frame = {
    frame_id: (group['x [nm]'].to_numpy(), group['y [nm]'].to_numpy())
    for frame_id, group in df.groupby('frame')
}

# CSV frame numbers are 1-based: frame `frm` is stack[frm - 1].
for frm in range(1, num_frames + 1):
    frame = stack[frm - 1].astype(np.float32)
    frame_us = resize(frame, (Mhr, Nhr), order=0, preserve_range=True, anti_aliasing=False)

    # Emitter positions in nm -> upsampled pixel indices; out-of-range
    # emitters are clamped onto the border pixel.
    col_x_nm, col_y_nm = coords_by_frame.get(frm, empty_coords)
    col_x = np.clip((col_x_nm / hr_pixelsize).astype(int), 0, Nhr - 1)
    col_y = np.clip((col_y_nm / hr_pixelsize).astype(int), 0, Mhr - 1)

    # Binary spike map (several emitters in one pixel collapse to a single 1)
    # and its Gaussian-blurred heatmap.
    spike_img = np.zeros((Mhr, Nhr), dtype=np.uint8)
    spike_img[col_y, col_x] = 1
    heatmap_img = convolve(spike_img.astype(float), psf_kernel, mode='constant')

    centers = get_random_patch_centers((Mhr, Nhr), patch_half, num_patches_per_frame)

    for r, c in centers:
        # Slice by patch_size (not 2 * patch_half) so odd patch sizes also work.
        rows = slice(r - patch_half, r - patch_half + patch_size)
        cols = slice(c - patch_half, c - patch_half + patch_size)

        spike_patch = spike_img[rows, cols]
        if spike_patch.shape != (patch_size, patch_size):
            continue  # skip patches clipped by the image border
        if spike_patch.sum() < min_emitters:
            continue

        patches.append(frame_us[rows, cols].astype(np.float32))
        heatmaps.append(heatmap_img[rows, cols].astype(np.float32))
        spikes.append(spike_patch.astype(np.bool_))
        k += 1

        if k >= max_examples:
            break

    # Report progress before a possible early exit so the last frame is logged too.
    print(f"Frame {frm}/{num_frames} — Total patches: {k}")
    if k >= max_examples:
        break

assert k == len(patches) == len(heatmaps) == len(spikes)


In [None]:

def debug_first_frame(datapath, tiff_filename, csv_filename,
                      upsampling_factor=8, camera_pixelsize=80,
                      patch_size=208, min_emitters=7, num_patches=50,
                      frame=1):
    """Visualise emitters and random candidate patches on one frame.

    Loads the TIFF stack and ground-truth CSV and upsamples frame ``frame``
    with nearest-neighbour interpolation. It then overlays the ground-truth
    emitter positions and ``num_patches`` randomly placed candidate patches,
    and reports in the title how many candidates contain at least
    ``min_emitters`` spike pixels.

    Parameters
    ----------
    datapath, tiff_filename, csv_filename : str
        Directory and file names of the image stack and ground-truth table.
    upsampling_factor : int
        Upsampled pixels per camera pixel along each axis.
    camera_pixelsize : float
        Camera pixel size in nm (the CSV coordinates are in nm).
    patch_size : int
        Side length of a candidate patch, in upsampled pixels.
    min_emitters : int
        Minimum number of spike pixels for a patch to count as valid.
    num_patches : int
        Number of random candidate patches to draw.
    frame : int
        1-based frame number to inspect (default: the first frame).

    Raises
    ------
    ValueError
        If the upsampled frame is smaller than ``patch_size``.
    """
    # Load image stack and ground-truth table.
    tiff_path = os.path.join(datapath, tiff_filename)
    csv_path = os.path.join(datapath, csv_filename)
    stack = tifffile.imread(tiff_path)
    if stack.ndim == 2:
        # A single-frame TIFF is read as 2-D; treat it as a one-frame stack.
        stack = stack[np.newaxis, ...]
    df = pd.read_csv(csv_path)

    # Nearest-neighbour upsampling of the selected frame (CSV frames are 1-based).
    image = stack[frame - 1].astype(np.float32)
    Mhr, Nhr = image.shape[0] * upsampling_factor, image.shape[1] * upsampling_factor
    frame_us = resize(image, (Mhr, Nhr), order=0, preserve_range=True, anti_aliasing=False)

    # Emitter positions in nm -> upsampled pixel indices, clamped to the image.
    df_frame = df[df['frame'] == frame]
    px_size_hr = camera_pixelsize / upsampling_factor
    col_x = np.clip((df_frame['x [nm]'].to_numpy() / px_size_hr).astype(int), 0, Nhr - 1)
    col_y = np.clip((df_frame['y [nm]'].to_numpy() / px_size_hr).astype(int), 0, Mhr - 1)

    spike_img = np.zeros((Mhr, Nhr), dtype=np.uint8)
    spike_img[col_y, col_x] = 1

    # Random candidate centres. A patch spans [r - patch_half, r - patch_half + patch_size),
    # so the largest centre that keeps it inside the image is
    # Mhr - patch_size + patch_half (inclusive; randint's high bound is exclusive).
    patch_half = patch_size // 2
    high_r = Mhr - patch_size + patch_half + 1
    high_c = Nhr - patch_size + patch_half + 1
    if high_r <= patch_half or high_c <= patch_half:
        raise ValueError(
            f"Upsampled frame of shape {(Mhr, Nhr)} is smaller than patch_size={patch_size}"
        )
    centers = list(zip(np.random.randint(patch_half, high_r, size=num_patches),
                       np.random.randint(patch_half, high_c, size=num_patches)))

    # Count candidates that would pass the min_emitters filter used for extraction.
    valid_patch_count = 0
    for r, c in centers:
        top, left = r - patch_half, c - patch_half
        if spike_img[top:top + patch_size, left:left + patch_size].sum() >= min_emitters:
            valid_patch_count += 1

    # Plot emitters and candidate patch outlines over the upsampled frame.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(frame_us, cmap='gray')
    ax.scatter(col_x, col_y, s=10, c='r', marker='+', label='Emitters')
    for r, c in centers:
        ax.add_patch(plt.Rectangle((c - patch_half, r - patch_half), patch_size, patch_size,
                                   edgecolor='blue', facecolor='none', linewidth=1))
    ax.set_title(f"Frame {frame}: Emitters = {len(col_x)} | Candidate patches = {num_patches} | Valid = {valid_patch_count}")
    ax.legend()
    ax.axis('off')
    plt.show()

In [None]:
# Sanity-check patch sampling on the first frame. Pass the configuration values
# explicitly so this view always matches the extraction settings instead of
# relying on the function's own defaults.
debug_first_frame(
    datapath=datapath,
    tiff_filename=tiff_filename,
    csv_filename=csv_filename,
    upsampling_factor=upsampling_factor,
    camera_pixelsize=camera_pixelsize,
    patch_size=patch_size,
    min_emitters=min_emitters,
)


In [None]:

# -----------------------
# SAVE WITH PICKLE
# -----------------------
if not patches:
    raise RuntimeError(
        "No training patches were extracted; nothing to save. "
        "Check min_emitters / patch_size against the data."
    )

# Stack the arrays BEFORE opening the output file, so a failure here cannot
# truncate an existing pickle at save_path.
dataset = {
    "patches": np.stack(patches),
    "heatmaps": np.stack(heatmaps),
    "spikes": np.stack(spikes),
}

save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Report the number actually being written; len(patches) stays correct even if
# `k` is stale from an out-of-order cell execution.
print(f"Saving {len(patches)} examples to {save_path}")
# NOTE: unpickling can execute arbitrary code; only load this file from trusted sources.
with open(save_path, "wb") as f:
    pickle.dump(dataset, f)