# PhaseMaskNet Training (2D Single Slice)

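Train a U-Net (`DeepCGHUNet`) to predict a phase mask from a single 2D cross-sectional slice. The predicted phase is numerically propagated to the target plane, and the resulting intensity is compared with the slice itself, so no ground-truth masks are needed.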

## Setup

In [None]:

# Colab-only setup: install piq (imported below as piq_ssim) and mount Google
# Drive, under which the training image directory (img_dir) lives.
!pip install piq
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from piq import ssim as piq_ssim

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


## CrossSectionDataset

In [None]:

class CrossSectionDataset(Dataset):
    """Grayscale PNG cross-section slices served as ``(input, target)`` tensor pairs.

    Every file directly inside ``img_dir`` whose name ends in ``.png``
    (case-insensitive) is one sample, in sorted filename order. Each item is a
    pair of ``(1, size, size)`` float32 tensors with values in ``[0, 1]``:

    * target -- the slice converted to 8-bit grayscale, bicubic-resized to
      ``size x size`` and scaled by 1/255 (randomly flipped when ``augment``);
    * input  -- the same array, plus clipped Gaussian noise when ``augment``.

    Noise is applied to the input only. The network then learns to be robust
    to it instead of being asked to reproduce a random noise pattern. With
    ``augment=False`` the input and target are identical and deterministic.

    Args:
        img_dir: Directory scanned (non-recursively) for PNG files.
        size: Side length, in pixels, every image is resized to.
        augment: Enable random vertical/horizontal flips and input noise.
        noise_std: Standard deviation of the additive input noise; only used
            when ``augment`` is True.
    """

    def __init__(self, img_dir, size=512, augment=True, noise_std=0.02):
        self.img_dir = img_dir
        self.size = size
        self.augment = augment
        self.noise_std = noise_std
        self.img_files = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith('.png')
        )

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        # Context manager releases the file handle deterministically; convert()
        # loads the pixel data, so `img` stays valid after the file is closed.
        with Image.open(img_path) as src:
            img = src.convert("L").resize((self.size, self.size), Image.BICUBIC)
        img_np = np.array(img, dtype=np.float32) / 255.0
        return self._to_pair(img_np)

    def _to_pair(self, img_np):
        """Build ``(input, target)`` tensors of shape (1, H, W) from a 2-D array in [0, 1]."""
        if self.augment:
            # Independent 50% flips along axis 0 and axis 1.
            if np.random.rand() < 0.5:
                img_np = np.flip(img_np, axis=0)
            if np.random.rand() < 0.5:
                img_np = np.flip(img_np, axis=1)
        # np.flip returns a negative-stride view, which torch.from_numpy rejects.
        target = np.ascontiguousarray(img_np, dtype=np.float32)
        inp = target
        if self.augment and self.noise_std > 0:
            noise = np.random.normal(0.0, self.noise_std, size=target.shape)
            inp = np.clip(target + noise.astype(np.float32), 0.0, 1.0)
        return (
            torch.from_numpy(inp).unsqueeze(0).float(),
            torch.from_numpy(target).unsqueeze(0).float(),
        )


## Model: DeepCGHUNet

In [None]:

class DoubleConv(nn.Module):
    """Two ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Spatial size is preserved; channels go ``in_ch -> out_ch``. Keep the layer
    order as is: it determines the ``net.<i>`` state-dict keys saved checkpoints use.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class DeepCGHUNet(nn.Module):
    """Four-level U-Net producing a single-channel map at the input resolution.

    Encoder widths are ``base_feat * (1, 2, 4, 8)`` with a ``base_feat * 16``
    bottleneck, each level separated by 2x2 max-pooling. Each decoder stage
    bilinearly upsamples (``align_corners=False``), concatenates the matching
    encoder skip and applies a DoubleConv. The head is a 3x3 conv + ReLU
    followed by a 1x1 conv to one channel, with no output activation.

    Input height/width need not be divisible by 16. When ``MaxPool2d`` has
    floored an odd size, 2x upsampling no longer matches the skip tensor, so
    the decoder features are resized to the skip's exact size instead. For
    sizes divisible by 16 the computation is unchanged from plain 2x upsampling.
    """

    def __init__(self, in_ch=1, base_feat=64):
        super().__init__()
        self.enc1 = DoubleConv(in_ch, base_feat)
        self.enc2 = DoubleConv(base_feat, base_feat * 2)
        self.enc3 = DoubleConv(base_feat * 2, base_feat * 4)
        self.enc4 = DoubleConv(base_feat * 4, base_feat * 8)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base_feat * 8, base_feat * 16)
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = DoubleConv(base_feat * 16 + base_feat * 8, base_feat * 8)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec2 = DoubleConv(base_feat * 8 + base_feat * 4, base_feat * 4)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = DoubleConv(base_feat * 4 + base_feat * 2, base_feat * 2)
        self.final_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.final_conv = nn.Sequential(
            nn.Conv2d(base_feat * 2 + base_feat, base_feat, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_feat, 1, 1)
        )

    @staticmethod
    def _upsample_to(up, x, skip):
        """Apply the 2x ``up`` module, or resize ``x`` to ``skip``'s H/W when 2x would not match."""
        size = tuple(skip.shape[-2:])
        if (2 * x.shape[-2], 2 * x.shape[-1]) == size:
            return up(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x3 = self.enc3(self.pool(x2))
        x4 = self.enc4(self.pool(x3))
        x5 = self.bottleneck(self.pool(x4))

        x = self.dec3(torch.cat([self._upsample_to(self.up3, x5, x4), x4], dim=1))
        x = self.dec2(torch.cat([self._upsample_to(self.up2, x, x3), x3], dim=1))
        x = self.dec1(torch.cat([self._upsample_to(self.up1, x, x2), x2], dim=1))
        x = self.final_conv(torch.cat([self._upsample_to(self.final_up, x, x1), x1], dim=1))
        return x


## Training PhaseMaskNet

In [None]:

img_dir = "/content/drive/MyDrive/cross_section_datasets/circles_only"
dataset = CrossSectionDataset(img_dir, augment=True)
if len(dataset) < 2:
    # Fail early with a clear message instead of an empty train_loader later,
    # which would divide by zero when the epoch loss is averaged.
    raise RuntimeError(
        f"Found {len(dataset)} PNG file(s) in {img_dir!r}; "
        "need at least 2 for an 80/20 train/test split."
    )

# 80/20 split, seeded so the same images are held out on every run.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_ds, test_ds = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(0)
)
# Held-out images should not get the random flips applied when augment=True.
# random_split returns Subsets, so reuse their indices against an
# augment=False view of the same directory. Both datasets list files in the
# same sorted order, so the indices refer to the same images.
test_ds = torch.utils.data.Subset(
    CrossSectionDataset(img_dir, augment=False), test_ds.indices
)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=4, shuffle=False)

model = DeepCGHUNet(in_ch=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def angular_spectrum_propagation(phase, wavelength=405e-9, pixel_size=1.5e-6, z=0.005):
    """Propagate a unit-amplitude, phase-only field a distance ``z`` and return its intensity.

    ``phase`` is in cycles (1.0 == 2*pi rad). It is wrapped to [0, 1) before
    forming ``exp(i*2*pi*phase)``. The 2-D spectrum is then multiplied by the
    paraxial (Fresnel) transfer function
    ``exp(-i*pi*wavelength*z*(fx**2 + fy**2))``. The constant ``exp(i*k*z)``
    factor is dropped because it does not change intensity. Despite the name,
    this is the Fresnel approximation, not the exact angular-spectrum kernel
    ``exp(i*2*pi*z*sqrt(1/wavelength**2 - fx**2 - fy**2))``.

    Both the field and the transfer function have unit modulus. By Parseval,
    the returned intensity therefore has mean 1 over each H x W plane.

    NOTE(review): there is no zero-padding, so propagation is circular
    (energy leaving one edge re-enters at the opposite edge). The
    transfer-function sampling condition z <= N * pixel_size**2 / wavelength
    is about 2.8 mm for N = 512 at the default pitch and wavelength, which is
    below the default z = 5 mm. Check for aliasing if these defaults are kept.

    Args:
        phase: Real tensor whose last two dimensions are (H, W). Any leading
            dimensions (e.g. batch, channel) are propagated independently.
        wavelength: Wavelength, in the same length unit as ``pixel_size``
            and ``z`` (the defaults read as metres: 405 nm, 1.5 um, 5 mm).
        pixel_size: Sampling pitch, identical in x and y.
        z: Propagation distance.

    Returns:
        Real tensor with the same shape as ``phase`` holding ``|field|**2``.
    """
    H, W = phase.shape[-2:]
    # fftfreq uses FFT bin ordering, matching the layout of torch.fft.fft2.
    fx = torch.fft.fftfreq(W, d=pixel_size, device=phase.device)
    fy = torch.fft.fftfreq(H, d=pixel_size, device=phase.device)
    # indexing='xy' yields (H, W) grids with FX varying along the last axis.
    FX, FY = torch.meshgrid(fx, fy, indexing='xy')
    H_z = torch.exp(-1j * np.pi * wavelength * z * (FX**2 + FY**2))

    wrapped_phase = phase % 1.0
    field = torch.exp(1j * 2 * np.pi * wrapped_phase)
    # fft2/ifft2 act on the last two dims; H_z broadcasts over leading dims.
    propagated = torch.fft.ifft2(torch.fft.fft2(field) * H_z)
    return torch.abs(propagated) ** 2

def gradient_loss(pred):
    """Anisotropic total-variation penalty for a 4-D tensor.

    Returns the mean absolute difference between neighbours along the last
    axis plus the mean absolute difference along the second-to-last axis.
    """
    horizontal = torch.diff(pred, dim=3).abs().mean()
    vertical = torch.diff(pred, dim=2).abs().mean()
    return horizontal + vertical

# Hyperparameters (previously inlined literals; values unchanged).
num_epochs = 200
tv_weight = 0.1  # weight of the total-variation smoothness penalty on the phase

if len(train_loader) == 0:
    raise RuntimeError("train_loader is empty; check img_dir and the train/test split")

# NOTE(review): angular_spectrum_propagation propagates a unit-modulus field
# with a unit-modulus transfer function, so each predicted intensity map has
# mean 1 (Parseval). Targets are 8-bit slices scaled to [0, 1]. Unless a
# target's mean is exactly 1, the MSE term cannot reach zero. Consider
# normalising one side (e.g. matching means). Left unchanged here.
for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    for x, target_intensity in train_loader:
        x, target_intensity = x.to(device), target_intensity.to(device)
        optimizer.zero_grad()
        pred_phase = model(x)
        pred_intensity = angular_spectrum_propagation(pred_phase)
        loss = F.mse_loss(pred_intensity, target_intensity) + tv_weight * gradient_loss(pred_phase)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of per-batch losses (the last batch may be smaller).
    avg_loss = total_loss / len(train_loader)
    print(f"[Epoch {epoch}] Loss: {avg_loss:.6f}")

# Held-out evaluation with the same objective; test_loader is otherwise unused.
# eval() makes BatchNorm use its running statistics.
model.eval()
test_loss = 0.0
with torch.no_grad():
    for x_test, target_test in test_loader:
        x_test, target_test = x_test.to(device), target_test.to(device)
        phase_test = model(x_test)
        intensity_test = angular_spectrum_propagation(phase_test)
        test_loss += (
            F.mse_loss(intensity_test, target_test) + tv_weight * gradient_loss(phase_test)
        ).item()
if len(test_loader) > 0:
    print(f"[Test] Loss: {test_loss / len(test_loader):.6f}")

# Relative path: resolves against the notebook's working directory (in Colab
# typically /content, which is not persisted). Consider saving under
# /content/drive to keep the weights.
torch.save(model.state_dict(), "phase_mask_net.pth")
