# PhaseMaskNet Multilayer Inference

In [None]:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from piq import psnr, ssim
from torch.utils.data import Dataset

# Default compute device: the CUDA GPU when one is visible, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:

def angular_spectrum_phase_to_intensity(phase, z, lam=405e-9, pixel_size=1.5e-6):
    """Propagate a unit-amplitude phase-only field by ``z`` and return its intensity.

    Angular spectrum method: the field ``exp(i * 2*pi * phase)`` is Fourier
    transformed, multiplied by the free-space transfer function
    ``exp(i * 2*pi * z / lam * sqrt(1 - (lam*fx)^2 - (lam*fy)^2))`` and
    transformed back. Evanescent spatial frequencies (where the square-root
    argument is negative) are removed.

    Args:
        phase: Real tensor whose last two dimensions are (H, W), e.g.
            (B, 1, H, W). Values are in cycles (1.0 == 2*pi rad) and are
            wrapped to [0, 1) before use.
        z: Propagation distance, in the same length unit as ``lam`` and
            ``pixel_size`` (the defaults look like metres: 405 nm, 1.5 um).
        lam: Wavelength.
        pixel_size: Sampling pitch of ``phase`` along both axes.

    Returns:
        Tensor with the same shape as ``phase`` holding ``|U_z|^2``.
    """
    H, W = phase.shape[-2:]
    dev = phase.device
    fx = torch.fft.fftfreq(W, d=pixel_size, device=dev)
    fy = torch.fft.fftfreq(H, d=pixel_size, device=dev)
    # Rows index y and columns index x, so the frequency grids are (H, W)
    # and broadcast correctly against the field for non-square inputs too.
    FY, FX = torch.meshgrid(fy, fx, indexing="ij")
    p = (FX ** 2 + FY ** 2) * lam ** 2
    propagating = p <= 1.0
    sp = torch.sqrt(torch.clamp(1 - p, min=0))
    q = torch.exp(2j * np.pi * z / lam * sp)
    # Clamping alone would give exp(0) == 1 for evanescent waves, i.e. pass
    # them through unattenuated; they do not propagate, so zero them out.
    q = torch.where(propagating, q, torch.zeros_like(q))

    field = torch.exp(1j * 2 * np.pi * (phase % 1.0))
    propagated = torch.fft.ifft2(torch.fft.fft2(field) * q)
    return torch.abs(propagated) ** 2


In [None]:

class MultiPlaneCrossSectionDataset(Dataset):
    """Synthetic stacks of binary "blob" cross-sections, one image per plane.

    Each sample consists of ``z_planes`` independent ``size x size`` float32
    images, each holding 5-11 filled circles (value 1.0) on a 0.0 background.
    All samples are generated eagerly in ``__init__``.

    Args:
        size: Side length of every image, in pixels.
        num_samples: Number of samples to generate.
        z_planes: Number of planes (images) per sample. This is a count, not
            a list of distances.
        radius_range: ``(low, high)`` circle radius in pixels; ``high`` is
            exclusive, as with ``np.random.randint``.
        seed: Optional seed for a private ``np.random.RandomState`` so the
            data is reproducible. ``None`` (default) draws from NumPy's
            global random state.

    Raises:
        ValueError: If ``radius_range`` is not ``0 <= low < high`` or a circle
            of the largest radius cannot fit inside a ``size x size`` image.
    """

    def __init__(self, size=512, num_samples=1, z_planes=3, radius_range=(10, 30), seed=None):
        low, high = radius_range
        if not 0 <= low < high:
            raise ValueError(f"radius_range must satisfy 0 <= low < high, got {radius_range}")
        # Centres are drawn from [r, size - r), which is empty unless 2 * r < size.
        if 2 * (high - 1) >= size:
            raise ValueError(
                f"radius_range {radius_range} is too large for size={size}: "
                "every radius r must satisfy 2 * r < size"
            )
        self.size = size
        self.num_samples = num_samples
        self.z_planes = z_planes
        self.radius_range = radius_range
        # The np.random module shares NumPy's global state (the unseeded
        # behaviour); RandomState exposes the same randint API.
        self._rng = np.random if seed is None else np.random.RandomState(seed)
        self.data = self._generate_dataset()

    def _generate_random_blob(self):
        """Draw one ``size x size`` float32 image of random filled circles in [0, 1]."""
        rng = self._rng
        img = Image.new("L", (self.size, self.size), 0)
        draw = ImageDraw.Draw(img)
        for _ in range(rng.randint(5, 12)):
            r = rng.randint(*self.radius_range)
            # Keep the whole circle inside the frame.
            x = rng.randint(r, self.size - r)
            y = rng.randint(r, self.size - r)
            draw.ellipse((x - r, y - r, x + r, y + r), fill=255)
        return np.array(img, dtype=np.float32) / 255.0

    def _generate_dataset(self):
        """Build ``num_samples`` lists of ``z_planes`` independent blob images."""
        return [
            [self._generate_random_blob() for _ in range(self.z_planes)]
            for _ in range(self.num_samples)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(hint, targets)`` for sample ``idx``.

        ``targets`` stacks the sample's planes, shape ``(z_planes, H, W)``;
        ``hint`` is their per-pixel mean with a leading channel axis, shape
        ``(1, H, W)``. Both are freshly allocated tensors.
        """
        stack = np.stack(self.data[idx])
        input_hint = stack.mean(axis=0)
        return torch.from_numpy(input_hint).unsqueeze(0), torch.from_numpy(stack)


In [None]:

def _minmax_normalize(arr):
    """Linearly rescale a NumPy array to [0, 1] as float32.

    A constant array (max == min) maps to all zeros rather than the 0/0 NaNs
    a plain min-max division would produce (which would poison SSIM/PSNR).
    """
    lo = arr.min()
    span = arr.max() - lo
    if span <= 0:
        return np.zeros(arr.shape, dtype=np.float32)
    return ((arr - lo) / span).astype(np.float32)


def _format_mm(z):
    """Format a distance given in metres as a compact millimetre string (e.g. '2.5')."""
    return f"{z * 1e3:g}"


def _plane_metrics(recon_norm, target_norm):
    """Return (SSIM, PSNR in dB, 0.5-threshold pixel accuracy) for two [0, 1] images."""
    rt = torch.from_numpy(recon_norm)[None, None]
    tt = torch.from_numpy(target_norm)[None, None]
    return (
        ssim(rt, tt, data_range=1.0).item(),
        psnr(rt, tt, data_range=1.0).item(),
        ((rt > 0.5) == (tt > 0.5)).float().mean().item(),
    )


def evaluate_and_visualize_3d_accuracy(model, z_planes=(0.005, 0.01, 0.015), lam=405e-9, pixel_size=1.5e-6):
    """Run ``model`` on one synthetic sample and score its phase mask per plane.

    A fresh ``MultiPlaneCrossSectionDataset`` sample with one target per entry
    of ``z_planes`` is generated; its mean-over-planes hint image is fed to
    ``model``. Channel 0 of the first output item is treated as the phase
    mask, propagated to each distance with
    ``angular_spectrum_phase_to_intensity``, min-max normalised and compared
    with the matching target (SSIM, PSNR, binary pixel accuracy). Results are
    plotted and printed; nothing is returned.

    Args:
        model: Module called on a ``(1, 1, H, W)`` hint on ``device``. It is
            switched to eval mode as a side effect.
        z_planes: Propagation distances (metres, matching the default ``lam``
            and ``pixel_size``), one per target plane. Any sequence works.
        lam: Wavelength passed to the propagator.
        pixel_size: Pixel pitch passed to the propagator.

    Raises:
        ValueError: If ``z_planes`` is empty.
    """
    z_planes = tuple(z_planes)
    if not z_planes:
        raise ValueError("z_planes must contain at least one distance")
    n_planes = len(z_planes)

    model.eval()

    # One target per requested distance: the dataset's default of 3 planes
    # would raise IndexError (or silently drop targets) for other counts.
    dataset = MultiPlaneCrossSectionDataset(num_samples=1, z_planes=n_planes)
    x, y_stack = dataset[0]
    input_tensor = x.unsqueeze(0).to(device)

    recon_layers = []
    ssim_vals, psnr_vals, bin_accs = [], [], []
    with torch.no_grad():
        pred_phase = model(input_tensor)
        wrapped_phase = pred_phase[0, 0].cpu().numpy() % 1.0

        for i, z in enumerate(z_planes):
            recon = angular_spectrum_phase_to_intensity(pred_phase, z, lam, pixel_size)
            recon_norm = _minmax_normalize(recon[0, 0].cpu().numpy())
            target_norm = _minmax_normalize(y_stack[i].numpy())

            plane_ssim, plane_psnr, plane_acc = _plane_metrics(recon_norm, target_norm)
            ssim_vals.append(plane_ssim)
            psnr_vals.append(plane_psnr)
            bin_accs.append(plane_acc)
            recon_layers.append(recon_norm)

    n_cols = n_planes + 1
    fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 8))

    axs[0, 0].imshow(x[0], cmap='gray')
    axs[0, 0].set_title("Input (Hint Image)")
    axs[0, 0].axis('off')

    im_phase = axs[1, 0].imshow(wrapped_phase, cmap='hot', vmin=0, vmax=1)
    axs[1, 0].set_title("Predicted Phase Mask")
    axs[1, 0].axis('off')
    plt.colorbar(im_phase, ax=axs[1, 0], shrink=0.8, label='Phase (0–1)')

    for i, z in enumerate(z_planes):
        col = i + 1
        # Same label in plot and summary; int() truncation merged sub-mm planes.
        z_mm = _format_mm(z)
        axs[0, col].imshow(y_stack[i], cmap='gray')
        axs[0, col].set_title(f"Target @ z={z_mm} mm")
        axs[0, col].axis('off')

        im = axs[1, col].imshow(recon_layers[i], cmap='hot')
        axs[1, col].set_title(
            f"Recon @ z={z_mm} mm\n"
            f"SSIM: {ssim_vals[i]:.3f} | PSNR: {psnr_vals[i]:.1f} dB\n"
            f"Bin Acc: {bin_accs[i]*100:.1f}%"
        )
        axs[1, col].axis('off')
        plt.colorbar(im, ax=axs[1, col], shrink=0.8, label='Norm. Intensity')

    plt.tight_layout()
    plt.show()

    print("Evaluation Summary Across Planes:")
    for z, plane_ssim, plane_psnr, plane_acc in zip(z_planes, ssim_vals, psnr_vals, bin_accs):
        print(f"z = {_format_mm(z)} mm:")
        print(f"  SSIM:          {plane_ssim:.4f} ({plane_ssim*100:.2f}%)")
        print(f"  PSNR:          {plane_psnr:.2f} dB")
        print(f"  Pixel Accuracy: {plane_acc*100:.2f}%")
