# Phase Mask Multi-Layer Training

## Purpose
Train a model to predict a single phase mask capable of reconstructing three cross-sectional slices at different depths.

## Setup

In [None]:

!pip install piq
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import piq

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)


## Data Preparation

In [None]:

class CrossSectionMultiLayerDataset(Dataset):
    """Dataset of same-named grayscale PNG images spread across several directories.

    The sample list is the sorted set of ``.png`` files (case-insensitive
    extension) found in the first directory. Sample ``idx`` loads the file
    with that name from every directory, in the order given.

    Args:
        img_dir_list: Non-empty sequence of directory paths, one per layer.
        size: Side length every image is resized to (square output).

    Raises:
        ValueError: If ``img_dir_list`` is empty.
        FileNotFoundError: If a directory does not exist, or a file listed
            in the first directory is missing from one of the others.
    """

    def __init__(self, img_dir_list, size=128):
        if not img_dir_list:
            raise ValueError("img_dir_list must contain at least one directory")

        self.img_dirs = img_dir_list
        self.size = size
        self.img_files = sorted(
            f for f in os.listdir(img_dir_list[0])
            if f.lower().endswith('.png')
        )

        # Fail fast: __getitem__ opens the same file name in every directory,
        # so a file missing from a later layer would otherwise only surface
        # in the middle of an epoch.
        for layer_dir in img_dir_list[1:]:
            present = set(os.listdir(layer_dir))
            missing = [
                f for f in self.img_files
                # The exists() fallback covers case-insensitive filesystems,
                # where a differently-cased name still opens fine.
                if f not in present
                and not os.path.exists(os.path.join(layer_dir, f))
            ]
            if missing:
                preview = ", ".join(missing[:5])
                raise FileNotFoundError(
                    f"{len(missing)} file(s) from {img_dir_list[0]} are "
                    f"missing in {layer_dir}: {preview}"
                )

    def __len__(self):
        """Return the number of PNG files found in the first directory."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` from every directory and stack the results.

        Returns:
            A float32 tensor of shape ``(len(img_dirs), size, size)`` with
            values scaled from 8-bit grayscale to ``[0, 1]``.
        """
        name = self.img_files[idx]
        layers = []
        for layer_dir in self.img_dirs:
            # The context manager releases the file handle deterministically;
            # convert()/resize() return images independent of the open file.
            with Image.open(os.path.join(layer_dir, name)) as img:
                gray = img.convert('L').resize((self.size, self.size))
            layers.append(np.array(gray, dtype=np.float32) / 255.0)
        return torch.from_numpy(np.stack(layers, axis=0))


## Model Architecture

In [None]:

class PhaseMaskNetMultiLayer(nn.Module):
    """Two-level U-Net-style encoder-decoder with skip connections.

    Layout: encoder1 -> pool -> encoder2 -> pool -> bottleneck, then two
    transposed-conv upsampling steps, each followed by concatenation with
    the matching encoder output and a double-conv decoder stage. A final
    1x1 convolution maps to ``out_channels``; no activation is applied to
    the output.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final 1x1 conv.
        features: Channel widths for encoder1, encoder2 and the bottleneck.
            Only the first three entries are read.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128)):
        super().__init__()
        f1, f2, f3 = features[0], features[1], features[2]

        self.encoder1 = self._double_conv(in_channels, f1)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = self._double_conv(f1, f2)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = self._double_conv(f2, f3)

        self.upconv2 = nn.ConvTranspose2d(f3, f2, 2, 2)
        # The decoder sees the upsampled map concatenated with the skip
        # tensor, i.e. twice the skip width. For the default widths this
        # equals the previous hard-coded value, so checkpoints still load.
        self.decoder2 = self._double_conv(2 * f2, f2)
        self.upconv1 = nn.ConvTranspose2d(f2, f1, 2, 2)
        self.decoder1 = self._double_conv(2 * f1, f1)

        self.conv_final = nn.Conv2d(f1, out_channels, 1)

    @staticmethod
    def _double_conv(in_ch, out_ch):
        """Build two 3x3 conv -> batch-norm -> ReLU stages (padding keeps H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``. ``H`` and ``W``
                must be divisible by 4 (two 2x poolings); otherwise the
                upsampled maps no longer match the skip tensors and the
                concatenation fails.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` (raw, unbounded).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        bottleneck = self.bottleneck(self.pool2(enc2))

        # Upsample, then merge with the encoder output of the same resolution.
        dec2 = self.decoder2(torch.cat((self.upconv2(bottleneck), enc2), dim=1))
        dec1 = self.decoder1(torch.cat((self.upconv1(dec2), enc1), dim=1))
        return self.conv_final(dec1)


## Training Setup

In [None]:

# One directory per layer; the dataset loads the same file name from each.
layer_dirs = [
    '/content/drive/MyDrive/layer1',
    '/content/drive/MyDrive/layer2',
    '/content/drive/MyDrive/layer3'
]

dataset = CrossSectionMultiLayerDataset(layer_dirs, size=128)
# Stop here with a clear message if nothing was found: an empty dataset
# would otherwise only surface later (e.g. as a division by zero when the
# epoch loss is averaged over len(train_loader)).
if len(dataset) == 0:
    raise ValueError(f"No PNG images found in {layer_dirs[0]}")
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Default architecture (1 input channel, 1 output channel) on the chosen device.
model = PhaseMaskNetMultiLayer().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def ssim_loss(predicted, target):
    """Return ``1 - SSIM(predicted, target)`` so that a perfect match gives 0.

    The similarity is computed by ``piq.ssim`` with ``data_range=1.0``,
    i.e. both tensors are treated as images whose values span [0, 1].
    """
    similarity = piq.ssim(predicted, target, data_range=1.0)
    return 1 - similarity

num_epochs = 30
train_losses = []  # mean training loss of each epoch, in order

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for targets in train_loader:
        targets = targets.to(device)
        # Keep only the first layer (channel 0) as a 1-channel image; it is
        # used both as the network input and as the loss target.
        # NOTE(review): channels 1 and 2 (the other layer directories) are
        # loaded but never enter the loss, so this trains a reconstruction
        # of layer 1 only. Confirm whether a multi-depth objective is intended.
        inputs = targets[:, 0:1, :, :]

        # The network's last layer is a plain convolution, so its output is
        # unbounded, while ssim_loss calls piq.ssim with data_range=1.0,
        # which rejects values outside [0, 1]. Squash to (0, 1) first; the
        # sigmoid keeps gradients flowing, unlike a hard clamp.
        outputs = torch.sigmoid(model(inputs))

        loss = ssim_loss(outputs, inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}")

# Only the weights are saved. The sigmoid above is not part of the model,
# so apply it to the model output again when using the checkpoint.
torch.save(model.state_dict(), "phase_mask_multilayer_net.pth")
