In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def load_image_grayscale(path):
    """Load an image file as a single-channel float32 array scaled to [0, 1].

    Parameters
    ----------
    path : str or path-like
        Location of any image file Pillow can open.

    Returns
    -------
    numpy.ndarray
        2-D float32 array of shape (H, W): 8-bit luminance values divided
        by 255.
    """
    # The context manager releases the file handle as soon as the pixels
    # are decoded; without it the handle stays open until garbage collection.
    with Image.open(path) as img:
        gray = img.convert('L')  # 'L' = 8-bit grayscale (luminance)
    return np.asarray(gray, dtype=np.float32) / 255.0  # Normalize to [0, 1]


def compute_fft2(image):
    """Return the centred 2-D discrete Fourier transform of *image*.

    ``np.fft.fftshift`` moves the zero-frequency (DC) term from index (0, 0)
    to index (H // 2, W // 2). This makes the spectrum easier to view and
    lets masks be placed symmetrically about its centre.
    """
    return np.fft.fftshift(np.fft.fft2(image))

def generate_circular_masks(shape, grid_size, radius, overlap=0.15):
    """Build a rectangular grid of binary disc masks centred on the spectrum.

    Parameters
    ----------
    shape : (H, W)
        Size of the array each mask must match.
    grid_size : (rows, cols)
        Number of discs along the vertical and horizontal axes.
    radius : float
        Disc radius in pixels.
    overlap : float, optional
        Fraction of the disc diameter shared by neighbouring discs
        (default 0.15). Centre-to-centre spacing is
        ``2 * radius * (1 - overlap)``.

    Returns
    -------
    list of numpy.ndarray
        ``rows * cols`` float32 arrays of shape (H, W), 1.0 inside a disc and
        0.0 outside. For each row centre, one mask per column centre.
    """
    height, width = shape
    n_rows, n_cols = grid_size
    step = 2 * radius * (1 - overlap)

    def centres(mid, count):
        # `count` evenly spaced positions, `step` apart, symmetric about `mid`.
        half_span = (count - 1) / 2 * step
        return np.linspace(mid - half_span, mid + half_span, count)

    # The grid is symmetric about (W // 2, H // 2), the pixel where
    # np.fft.fftshift places the zero-frequency term.
    col_centres = centres(width // 2, n_cols)
    row_centres = centres(height // 2, n_rows)

    rr, cc = np.ogrid[:height, :width]
    r_sq = radius ** 2
    # Disc centres are snapped to the nearest integer pixel.
    return [
        ((cc - int(round(x))) ** 2 + (rr - int(round(y))) ** 2 <= r_sq).astype(np.float32)
        for y in row_centres
        for x in col_centres
    ]

def apply_mask_to_fft(fft_image, mask):
    """Keep only the part of a spectrum selected by *mask*.

    This is a plain element-wise product, so entries where the mask is 0 are
    zeroed and entries where it is 1 pass through unchanged.
    """
    filtered = fft_image * mask
    return filtered


def compute_ifft2(fft_masked):
    """Invert a centred (fftshift-ed) 2-D spectrum back to the spatial domain.

    ``np.fft.ifftshift`` first undoes the centring so the DC term is back at
    index (0, 0). The magnitude of the complex inverse transform is returned.
    """
    unshifted = np.fft.ifftshift(fft_masked)
    spatial = np.fft.ifft2(unshifted)
    return np.abs(spatial)


def show_image_and_spectrum(image, fft_image, title_prefix=""):
    """Plot an image next to the log-magnitude of its spectrum.

    ``log1p(|F|)`` compresses the large dynamic range of the FFT magnitude so
    that detail away from the DC term is visible. Both panels use a gray
    colormap without axes, and *title_prefix* is prepended to each title.
    """
    log_magnitude = np.log1p(np.abs(fft_image))
    panels = (
        (image, f'{title_prefix}Image'),
        (log_magnitude, f'{title_prefix}FFT Magnitude'),
    )

    _, axes = plt.subplots(1, 2, figsize=(10, 4))
    for ax, (data, label) in zip(axes, panels):
        ax.imshow(data, cmap='gray')
        ax.set_title(label)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def visualize_masks_overlay(fft_image, masks, alpha=0.3):
    """Draw the summed masks semi-transparently over the log-magnitude spectrum.

    Overlapping masks add up, so overlap regions get larger values and a
    different 'jet' colour than single-coverage regions. *alpha* sets the
    overlay's opacity.
    """
    log_spectrum = np.log1p(np.abs(fft_image))
    # Seed the sum with a zero array shaped like the spectrum, so the result
    # keeps that shape and dtype even when `masks` is empty.
    coverage = sum(masks, np.zeros_like(log_spectrum))

    plt.figure(figsize=(6, 6))
    plt.imshow(log_spectrum, cmap='gray')
    plt.imshow(coverage, cmap='jet', alpha=alpha)
    plt.title("Circular Mask Grid Overlay")
    plt.axis('off')
    plt.show()


def display_low_res_images(images, grid_size):
    """Show a grid of grayscale images, one per subplot, titled by position.

    Parameters
    ----------
    images : sequence of 2-D arrays
        Images in row-major order: ``images[i * cols + j]`` goes to row i,
        column j.
    grid_size : (rows, cols)
        Layout of the subplot grid.

    Cells with no corresponding image (fewer than ``rows * cols`` images)
    are left blank instead of raising IndexError.
    """
    rows, cols = grid_size
    # squeeze=False keeps `axs` a 2-D array even for 1xN, Nx1 and 1x1 grids.
    # With the default squeeze=True those return a 1-D array or a single Axes,
    # and 2-D indexing would fail.
    _, axs = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    for idx, ax in enumerate(axs.flat):  # .flat iterates row-major
        if idx < len(images):
            i, j = divmod(idx, cols)
            ax.imshow(images[idx], cmap='gray')
            ax.set_title(f'Patch ({i},{j})')
        ax.axis('off')
    plt.tight_layout()
    plt.show()


# ---- Driver: load an image, inspect its spectrum, build the mask grid ----
import os

# The default is a machine-specific absolute path. Set the IMAGE_PATH
# environment variable to use another image without editing the notebook.
image_path = os.environ.get(
    'IMAGE_PATH', '/Users/yenwangcheng/Downloads/Venusfliegenfalle.jpeg'
)

# Load and process image
img = load_image_grayscale(image_path)
fshift = compute_fft2(img)  # centred spectrum of the grayscale image

# Show original and spectrum
show_image_and_spectrum(img, fshift, title_prefix="Original ")

# Create masks: a 5x5 grid of radius-60 px discs whose neighbours overlap by
# 15% of the diameter (centre spacing 2 * 60 * 0.85 = 102 px).
# NOTE: the outermost discs reach 2 * 102 + 60 = 264 px from the centre, so
# they are clipped on images smaller than about 529 px per side.
grid_size = (5, 5)
radius = 60
masks = generate_circular_masks(img.shape, grid_size, radius, overlap=0.15)

# Visualize masks on spectrum
visualize_masks_overlay(fshift, masks)
