
# Phase Mask Generation Model (PhaseMaskNet)

## Purpose
Train a model (PhaseMaskNet) to predict phase masks that can reconstruct cross-sectional images at specified z-heights for holographic projection.

This project aims to eventually integrate metasurfaces into additive manufacturing workflows by projecting complex layers during the build process.

---



## Setup
- Install `piq` with pip; `torch`, `Pillow` (`PIL`), `numpy`, and `matplotlib` are preinstalled on Google Colab.
- Mount Google Drive if the dataset is stored there (the setup code assumes Colab).

---


```python
# Colab setup: piq is the only package that needs installing
!pip install piq
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import piq

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```



## Data Preparation

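- Loads cross-section images from a single folder.
- Reads `.png` files and converts them to single-channel grayscale.
- Resizes every image to a fixed square size (currently 128×128, which is low resolution; **increasing it is recommended**). The size must be divisible by 4 because PhaseMaskNet downsamples twice.
- Optionally applies data augmentation (a random horizontal flip with probability 0.5).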

The dataset expects a folder structure like:

```
/path/to/cross_sections/
    slice_001.png
    slice_002.png
    ...
```
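A minimal sketch of the per-image steps that follow resizing, written in plain NumPy so it runs without PyTorch. `preprocess_slice` is a hypothetical helper, not part of the notebook. It also shows why a flipped array needs `.copy()`: `np.fliplr` returns a view with a negative stride, and `torch.from_numpy` does not accept negative strides.

```python
import numpy as np


def preprocess_slice(gray_u8: np.ndarray, flip: bool) -> np.ndarray:
    """Hypothetical helper mirroring CrossSectionDataset.__getitem__ after resizing:
    scale 8-bit grayscale to [0, 1], optionally flip left-right, add a channel axis."""
    img = gray_u8.astype(np.float32) / 255.0
    if flip:
        # np.fliplr returns a view with a negative column stride;
        # copying yields a contiguous array that torch.from_numpy accepts.
        img = np.fliplr(img).copy()
    return img[np.newaxis, ...]  # shape (1, H, W), like .unsqueeze(0)


# Example: a 4x4 ramp image with values 0..255
ramp = (np.arange(16).reshape(4, 4) * 17).astype(np.uint8)
out = preprocess_slice(ramp, flip=True)
print(out.shape, np.fliplr(ramp).strides, out.strides)
```

The first printed stride tuple contains a negative entry; the strides of `out` are all positive.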

---


```python
class CrossSectionDataset(Dataset):
    """Grayscale cross-section slices, resized to size x size and scaled to [0, 1]."""

    def __init__(self, img_dir, size=128, augment=True):
        self.img_dir = img_dir
        self.size = size          # must be divisible by 4 for PhaseMaskNet
        self.augment = augment    # random horizontal flip when True
        self.img_files = sorted(
            f for f in os.listdir(img_dir)
            if f.lower().endswith('.png')
        )

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        img = Image.open(img_path).convert('L')          # force single-channel grayscale
        img = img.resize((self.size, self.size))
        img = np.array(img, dtype=np.float32) / 255.0    # scale to [0, 1]

        if self.augment and np.random.rand() > 0.5:
            # fliplr returns a negative-stride view that torch.from_numpy rejects,
            # so make a contiguous copy
            img = np.fliplr(img).copy()

        return torch.from_numpy(img).unsqueeze(0)         # shape (1, H, W)
```



## Model Architecture - PhaseMaskNet

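- A two-level U-Net-style encoder-decoder with skip connections (32, 64, and 128 feature channels by default).
- Outputs a single-channel map, interpreted as the phase mask, at the same resolution as the input slice.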
- **This architecture is the core result of our project.**
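The shape bookkeeping of this architecture can be traced without building the network. The sketch below is a hypothetical helper (`trace_phasemasknet` is not part of the notebook) that follows the layer settings used here: 3×3 convolutions with `padding=1` keep the size, `MaxPool2d(2)` halves it with flooring, and `ConvTranspose2d(kernel_size=2, stride=2)` doubles it. It shows why the input side must be divisible by 4 and why each feature level must double the previous one.

```python
def trace_phasemasknet(size, features=(32, 64, 128), out_channels=1):
    """Return (stage, channels, side length) tuples for a square input of side `size`.
    Raises ValueError where torch.cat or a decoder DoubleConv would fail."""
    f0, f1, f2 = features
    # decoder2 = DoubleConv(f2, f1) receives f1 (upconv2) + f1 (enc2) channels;
    # decoder1 = DoubleConv(f1, f0) receives f0 (upconv1) + f0 (enc1) channels.
    if f2 != 2 * f1 or f1 != 2 * f0:
        raise ValueError("each feature level must double the previous one")

    enc1 = size            # encoder1: size unchanged
    enc2 = enc1 // 2       # pool1 + encoder2
    bott = enc2 // 2       # pool2 + bottleneck
    up2 = bott * 2         # upconv2
    if up2 != enc2:
        raise ValueError(f"skip mismatch at decoder2: {up2} vs {enc2}")
    up1 = up2 * 2          # upconv1
    if up1 != enc1:
        raise ValueError(f"skip mismatch at decoder1: {up1} vs {enc1}")

    return [
        ("encoder1", f0, enc1),
        ("encoder2", f1, enc2),
        ("bottleneck", f2, bott),
        ("decoder2", f1, up2),
        ("decoder1", f0, up1),
        ("conv_final", out_channels, up1),
    ]


for stage in trace_phasemasknet(128):
    print(stage)
```

For example, a 130×130 input fails because pooling floors 65 to 32, and upsampling back gives 64 rather than 65. At 512×512 the bottleneck is 128×128.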

---


```python
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) blocks; padding=1 keeps H and W unchanged."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class PhaseMaskNet(nn.Module):
    """Two-level U-Net. Input height and width must be divisible by 4."""

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128)):
        super().__init__()
        # Encoder
        self.encoder1 = DoubleConv(in_channels, features[0])
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = DoubleConv(features[0], features[1])
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(features[1], features[2])

        # Decoder: upsample, concatenate the matching encoder output, then convolve
        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(features[2], features[1])
        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(features[1], features[0])

        self.conv_final = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)                          # (N, f0, H, W)
        enc2 = self.encoder2(self.pool1(enc1))           # (N, f1, H/2, W/2)
        bottleneck = self.bottleneck(self.pool2(enc2))   # (N, f2, H/4, W/4)

        dec2 = self.upconv2(bottleneck)
        dec2 = self.decoder2(torch.cat((dec2, enc2), dim=1))

        dec1 = self.upconv1(dec2)
        dec1 = self.decoder1(torch.cat((dec1, enc1), dim=1))

        # Sigmoid keeps the output in [0, 1], the range piq.ssim expects with
        # data_range=1.0; it can be read as a normalized phase (x 2*pi for radians)
        return torch.sigmoid(self.conv_final(dec1))
```



## Loss Function - SSIM

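- SSIM (structural similarity) compares local means, variances, and covariances over sliding windows, so it emphasizes structure rather than pixel-wise differences. The loss is `1 - SSIM`, which is 0 for a perfect match.
- `piq.ssim` requires both inputs to lie in `[0, data_range]`.
- As written, the training loop compares the network output directly with the input slice, so the network learns to reproduce the slice itself rather than a phase mask that reconstructs the slice after propagation to a z-height.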
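For intuition, SSIM can be sketched in a few lines of NumPy. This is a simplified illustration, not `piq`'s implementation: it assumes the commonly used defaults (11×11 Gaussian window with σ = 1.5, K1 = 0.01, K2 = 0.03) and averages the SSIM map over "valid" window positions only, so values may differ from `piq.ssim` in details such as boundary handling.

```python
import numpy as np


def gaussian_kernel(size=11, sigma=1.5):
    """Normalized 1-D Gaussian window."""
    coords = np.arange(size) - (size - 1) / 2.0
    g = np.exp(-(coords ** 2) / (2 * sigma ** 2))
    return g / g.sum()


def filter2d_valid(img, g):
    """Separable Gaussian filter (rows, then columns), keeping only the 'valid' region."""
    rows = np.apply_along_axis(lambda r: np.convolve(r, g, mode="valid"), 1, img)
    return np.apply_along_axis(lambda c: np.convolve(c, g, mode="valid"), 0, rows)


def ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Mean SSIM of two 2-D images in [0, data_range] (each side >= 11 pixels)."""
    g = gaussian_kernel()
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mu_x, mu_y = filter2d_valid(x, g), filter2d_valid(y, g)
    var_x = filter2d_valid(x * x, g) - mu_x ** 2
    var_y = filter2d_valid(y * y, g) - mu_y ** 2
    cov_xy = filter2d_valid(x * y, g) - mu_x * mu_y
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    return float(ssim_map.mean())


rng = np.random.default_rng(0)
x = rng.random((32, 32))
print(1 - ssim(x, x), 1 - ssim(x, 1 - x))  # loss is ~0 for identical images
```

An inverted image has negative local covariance with the original, so its SSIM is negative and the loss exceeds 1.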

---


```python
def ssim_loss(predicted, target):
    # Both tensors must lie in [0, 1] for data_range=1.0; the loss is 0 for a perfect match
    return 1 - piq.ssim(predicted, target, data_range=1.0)
```
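One way to make the loss match the stated goal (a phase mask that reconstructs a slice at a given z-height) is to propagate the predicted phase numerically and compare the resulting intensity with the target slice. The sketch below uses the angular spectrum method in NumPy. This is a swapped-in technique, not the notebook's method, and the unit-amplitude illumination, 633 nm wavelength, 8 µm pixel pitch, and helper names are placeholder assumptions.

```python
import numpy as np


def angular_spectrum_intensity(phase, z, wavelength=633e-9, pitch=8e-6):
    """Propagate a unit-amplitude, phase-only field (phase in radians) by distance z
    in metres and return the intensity |U_z|^2 on the same pixel grid.
    wavelength and pitch are placeholder values, not taken from the project."""
    h, w = phase.shape
    field = np.exp(1j * phase)                    # phase-only mask, uniform illumination
    fx = np.fft.fftfreq(w, d=pitch)               # spatial frequencies in cycles/m
    fy = np.fft.fftfreq(h, d=pitch)
    FX, FY = np.meshgrid(fx, fy)                  # shape (h, w)
    arg = 1.0 / wavelength ** 2 - FX ** 2 - FY ** 2
    kz = 2 * np.pi * np.sqrt(np.maximum(arg, 0.0))
    transfer = np.where(arg > 0, np.exp(1j * kz * z), 0)  # drop evanescent waves
    field_z = np.fft.ifft2(np.fft.fft2(field) * transfer)
    return np.abs(field_z) ** 2


def reconstruct_slice(mask_norm, z, **optics):
    """Hypothetical helper: map a [0, 1] network output to phase in [0, 2*pi],
    propagate it, and rescale the intensity to [0, 1] for comparison with a slice."""
    intensity = angular_spectrum_intensity(2 * np.pi * mask_norm, z, **optics)
    return intensity / intensity.max()


rng = np.random.default_rng(0)
mask = rng.random((32, 32))
recon = reconstruct_slice(mask, z=5e-3)
print(recon.shape, recon.min(), recon.max())
```

A training loss would apply the same steps to tensors so that gradients reach the network; `torch.fft.fft2`, `torch.fft.ifft2`, and `torch.fft.fftfreq` provide the corresponding operations with autograd support. The reconstruction, not the raw mask, would then be passed to `ssim_loss` together with the target slice.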



## Training Setup

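- A simple PyTorch training loop with Adam (learning rate 1e-3, batch size 4).
- Trains for a fixed number of epochs (`num_epochs = 30`) and prints the mean loss per epoch; check that the loss has stopped decreasing before relying on the model.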
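To stop when the loss converges instead of after a fixed epoch count, a plateau check on `train_losses` is one option. `has_converged` is a hypothetical helper, and its `patience` and `min_delta` defaults are assumptions to tune, not values from the project.

```python
def has_converged(losses, patience=5, min_delta=1e-4):
    """Return True when the best loss in the last `patience` epochs improves on the
    best loss seen before that window by less than `min_delta`."""
    if len(losses) <= patience:
        return False
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_before - best_recent < min_delta


print(has_converged([0.5, 0.3, 0.2, 0.2, 0.2, 0.2], patience=3))
```

Inside the epoch loop, `if has_converged(train_losses): break` after appending `epoch_loss` would end training early while keeping `num_epochs` as an upper bound.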

---


```python
# Hyperparameters
learning_rate = 1e-3
batch_size = 4
num_epochs = 30

dataset = CrossSectionDataset(img_dir='/content/drive/MyDrive/slices', size=128)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = PhaseMaskNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_losses = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs in train_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        loss = ssim_loss(outputs, inputs)   # target is the input slice itself

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean loss over the batches of this epoch
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}")

torch.save(model.state_dict(), "phase_mask_net.pth")
```



## Results and Visualization

- Plots the first few input slices next to the corresponding predicted phase masks.

---


```python
model.eval()
dataset.augment = False                    # show slices without random flips
n_samples = min(4, len(dataset))
# squeeze=False keeps axs two-dimensional even when n_samples == 1
fig, axs = plt.subplots(n_samples, 2, figsize=(6, n_samples * 3), squeeze=False)

for i in range(n_samples):
    img = dataset[i].unsqueeze(0).to(device)   # add batch dimension: (1, 1, H, W)
    with torch.no_grad():
        output = model(img)

    input_img = img.squeeze().cpu().numpy()
    output_img = output.squeeze().cpu().numpy()

    axs[i, 0].imshow(input_img, cmap='gray')
    axs[i, 0].set_title("Input Slice")
    axs[i, 0].axis('off')

    axs[i, 1].imshow(output_img, cmap='gray')
    axs[i, 1].set_title("Predicted Phase Mask")
    axs[i, 1].axis('off')

plt.tight_layout()
plt.show()
```



## Future Work and Next Steps

1. **Upgrade resolution**: move to 512×512 or higher.
2. **Use slices from real-world STL models** for more realistic training data.
3. **Deploy model**: inference notebook provided.
4. **Explore full 3D models**.

---
