# PhaseMaskNet Multilayer Training

In [None]:

from PIL import Image, ImageDraw
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_ch`` -> ``out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Kept under the attribute name `net` so parameter names stay `net.<idx>.*`.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        return self.net(x)

class DeepCGHUNet(nn.Module):
    """U-Net style encoder/decoder that maps an image to a single-channel map.

    The encoder applies four ``DoubleConv`` stages separated by 2x2 max
    pooling, followed by a bottleneck. The decoder upsamples bilinearly by a
    factor of two at each level and concatenates the matching encoder
    feature map (skip connection) along the channel axis. The head ends in a
    1x1 convolution with no activation, so the output is unbounded.

    Args:
        in_ch: Number of input channels.
        base_feat: Channel count of the first encoder stage; deeper stages
            use 2x, 4x, 8x and 16x this value.

    The output has one channel and the same height/width as the input.
    Inputs whose height or width is not a multiple of 16 are supported: when
    an upsampled map does not match its skip connection (because pooling
    floored an odd size), it is resized to the skip connection's size before
    concatenation.
    """

    def __init__(self, in_ch=1, base_feat=64):
        super().__init__()
        self.enc1 = DoubleConv(in_ch, base_feat)
        self.enc2 = DoubleConv(base_feat, base_feat * 2)
        self.enc3 = DoubleConv(base_feat * 2, base_feat * 4)
        self.enc4 = DoubleConv(base_feat * 4, base_feat * 8)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base_feat * 8, base_feat * 16)
        # Decoder input channels = upsampled channels + skip-connection channels.
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = DoubleConv(base_feat * 16 + base_feat * 8, base_feat * 8)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec2 = DoubleConv(base_feat * 8 + base_feat * 4, base_feat * 4)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = DoubleConv(base_feat * 4 + base_feat * 2, base_feat * 2)
        self.final_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.final_conv = nn.Sequential(
            nn.Conv2d(base_feat * 2 + base_feat, base_feat, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_feat, 1, 1)
        )

    @staticmethod
    def _match_size(x, skip):
        """Return ``x`` resized to the spatial size of ``skip`` if they differ."""
        if x.shape[-2:] != skip.shape[-2:]:
            # Only reached when pooling floored an odd dimension; inputs with
            # sizes divisible by 16 never take this branch.
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Run the encoder/decoder and return a ``(B, 1, H, W)`` tensor."""
        # Encoder: keep every stage's output for the skip connections.
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x3 = self.enc3(self.pool(x2))
        x4 = self.enc4(self.pool(x3))
        x5 = self.bottleneck(self.pool(x4))

        # Decoder: upsample, align with the skip connection, concatenate, refine.
        x = self._match_size(self.up3(x5), x4)
        x = self.dec3(torch.cat([x, x4], dim=1))
        x = self._match_size(self.up2(x), x3)
        x = self.dec2(torch.cat([x, x3], dim=1))
        x = self._match_size(self.up1(x), x2)
        x = self.dec1(torch.cat([x, x2], dim=1))
        x = self._match_size(self.final_up(x), x1)
        x = self.final_conv(torch.cat([x, x1], dim=1))
        return x


In [None]:

class MultiPlaneCrossSectionDataset(Dataset):
    """In-memory dataset of random disc images grouped into per-sample stacks.

    Every sample is a list of ``z_planes`` independent ``size x size`` float32
    images with values in [0, 1], each containing a handful of filled discs.
    All samples are generated once in the constructor using the global
    ``np.random`` state.

    Args:
        size: Side length of each square image, in pixels.
        num_samples: Number of stacks to generate.
        z_planes: Number of images per stack.
        radius_range: ``(low, high)`` passed to ``np.random.randint`` to draw
            each disc radius (``high`` is exclusive).
    """

    def __init__(self, size=512, num_samples=100, z_planes=3, radius_range=(10, 30)):
        self.size = size
        self.num_samples = num_samples
        self.z_planes = z_planes
        self.radius_range = radius_range
        # Generated eagerly; must run after the attributes above are set.
        self.data = self._generate_dataset()

    def _generate_random_blob(self):
        """Draw 5 to 11 filled discs on a black canvas and return it scaled to [0, 1]."""
        canvas = Image.new("L", (self.size, self.size), 0)
        pen = ImageDraw.Draw(canvas)
        n_discs = np.random.randint(5, 12)
        for _ in range(n_discs):
            r = np.random.randint(*self.radius_range)
            # Centres are drawn so the bounding box stays inside the image.
            cx = np.random.randint(r, self.size - r)
            cy = np.random.randint(r, self.size - r)
            bounding_box = (cx - r, cy - r, cx + r, cy + r)
            pen.ellipse(bounding_box, fill=255)
        return np.array(canvas, dtype=np.float32) / 255.0

    def _generate_dataset(self):
        """Build ``num_samples`` stacks, each a list of ``z_planes`` images."""
        samples = []
        for _ in range(self.num_samples):
            stack = [self._generate_random_blob() for _ in range(self.z_planes)]
            samples.append(stack)
        return samples

    def __len__(self):
        """Return the number of stored stacks."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(hint, targets)`` for sample ``idx``.

        ``targets`` is the stack as a ``(Z, H, W)`` tensor; ``hint`` is the
        mean over the stack's planes with a leading channel axis, ``(1, H, W)``.
        """
        stack = np.array(self.data[idx])
        hint = stack.mean(axis=0)
        return torch.tensor(hint).unsqueeze(0), torch.tensor(stack)


In [None]:

def angular_spectrum_phase_to_intensity(phase, z, lam=405e-9, pixel_size=1.5e-6):
    """Propagate a phase-only field by distance ``z`` and return its intensity.

    Implements the angular spectrum method: the unit-amplitude field
    ``exp(i * 2*pi * phase)`` is Fourier transformed over its last two
    dimensions, multiplied by the free-space transfer function and
    transformed back.

    Args:
        phase: Real tensor whose last two dimensions are ``(H, W)``, e.g.
            ``(B, 1, H, W)``. Values are in cycles: 1.0 corresponds to a
            2*pi phase shift, and only the fractional part is used.
        z: Propagation distance, in the same length unit as ``lam`` and
            ``pixel_size``.
        lam: Wavelength.
        pixel_size: Sampling pitch, used for both axes.

    Returns:
        Real tensor with the same shape as ``phase`` holding the squared
        magnitude of the propagated field.
    """
    H, W = phase.shape[-2:]
    dev = phase.device
    fy = torch.fft.fftfreq(H, d=pixel_size, device=dev)
    fx = torch.fft.fftfreq(W, d=pixel_size, device=dev)
    # With 'ij' indexing and (fy, fx) the grids have shape (H, W), matching
    # the layout of the FFT output; this also works for non-square inputs.
    FY, FX = torch.meshgrid(fy, fx, indexing='ij')
    p = (FX ** 2 + FY ** 2) * lam**2
    # Frequencies with p > 1 are clamped, so their transfer function is 1
    # (they pass through unchanged rather than being attenuated).
    sp = torch.sqrt(torch.clamp(1 - p, min=0))
    q = torch.exp(2j * np.pi * z / lam * sp)

    field = torch.exp(1j * 2 * np.pi * (phase % 1.0))
    field_fft = torch.fft.fft2(field)
    propagated = torch.fft.ifft2(field_fft * q)
    intensity = torch.abs(propagated) ** 2
    return intensity


In [None]:

# Synthetic multi-plane targets and a shuffled mini-batch loader over them.
dataset3D = MultiPlaneCrossSectionDataset(num_samples=100)
train_loader = DataLoader(dataset3D, batch_size=2, shuffle=True)

model3D = DeepCGHUNet(in_ch=1).to(device)
optimizer = torch.optim.Adam(model3D.parameters(), lr=1e-3)

# One propagation distance per target plane; entry i is paired with y_stack[:, i].
z_planes = [0.005, 0.01, 0.015]

for epoch in range(1, 61):
    model3D.train()
    total_loss = 0
    for x, y_stack in train_loader:
        # x: averaged hint image; y_stack: per-plane targets, shape (B, Z, H, W).
        x, y_stack = x.to(device), y_stack.to(device)
        optimizer.zero_grad()
        phase = model3D(x)  # Shape: (B, 1, H, W)

        # Sum the per-plane MSE between the simulated intensity and its target.
        loss = 0
        for zi, z in enumerate(z_planes):
            recon = angular_spectrum_phase_to_intensity(phase, z)
            target = y_stack[:, zi:zi + 1]  # keep the channel axis: (B, 1, H, W)
            loss = loss + F.mse_loss(recon, target)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"[Epoch {epoch}] Avg Loss: {total_loss / len(train_loader):.6f}")
