# Phase Mask Multi-Layer Inference

## Purpose
Load the trained multi-layer network and use it to predict, from an input slice image, a phase mask intended to reconstruct multiple slices.

## Setup

In [None]:

!pip install piq
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Use the GPU whenever PyTorch reports one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


## Load Trained Model

In [None]:

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes, from ``in_channels`` to ``out_channels``.

    All layers are held in one ``nn.Sequential`` named ``double_conv``. That
    attribute name and the layer order determine the ``state_dict`` keys
    (``double_conv.0.weight``, ``double_conv.1.running_mean``, ...), so both
    must stay unchanged for saved checkpoints to keep loading.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 maps in_channels -> out_channels; stage 2 keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        out = self.double_conv(x)
        return out

class PhaseMaskNetMultiLayer(nn.Module):
    """Two-level U-Net mapping an input image to a phase-mask map.

    Encoder: ``encoder1`` -> 2x max-pool -> ``encoder2`` -> 2x max-pool ->
    ``bottleneck``. Decoder: each level upsamples 2x with a transposed conv,
    concatenates the matching encoder features (skip connection) along the
    channel axis, and refines the result with a ``DoubleConv``. A final 1x1
    convolution projects to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output map.
        features: Channel widths ``(f0, f1, f2)`` for encoder level 1,
            encoder level 2 and the bottleneck. Any indexable sequence of
            three ints works. Widths need not double between levels.

    Note:
        Input height and width must be divisible by 4 (two 2x pooling steps).
        Otherwise the upsampled maps and the encoder maps differ in size and
        the skip-connection ``torch.cat`` fails.
    """

    # Tuple default: a shared mutable list default is an anti-pattern even
    # though this constructor never mutates it.
    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128)):
        super().__init__()
        # Attribute names and registration order define the state_dict keys
        # of saved checkpoints, so they must stay as they are.
        self.encoder1 = DoubleConv(in_channels, features[0])
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = DoubleConv(features[0], features[1])
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(features[1], features[2])
        self.upconv2 = nn.ConvTranspose2d(features[2], features[1], 2, 2)
        # The decoder input is cat(upsampled, skip): two tensors with
        # features[1] channels each. Previously this was hard-coded as
        # features[2], which is only correct when features[2] == 2*features[1].
        # For the default widths the value (128) is unchanged.
        self.decoder2 = DoubleConv(2 * features[1], features[1])
        self.upconv1 = nn.ConvTranspose2d(features[1], features[0], 2, 2)
        # Same reasoning: cat of two features[0]-channel tensors (64 for the
        # default widths, identical to the previous features[1]).
        self.decoder1 = DoubleConv(2 * features[0], features[0])
        self.conv_final = nn.Conv2d(features[0], out_channels, 1)

    def forward(self, x):
        """Return the ``out_channels`` map for ``x`` at the input resolution."""
        # Contracting path.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        bottleneck = self.bottleneck(self.pool2(enc2))
        # Expanding path with skip connections (concatenate on channel dim).
        dec2 = self.upconv2(bottleneck)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return self.conv_final(dec1)

# Build the network on the selected device and restore the trained weights.
model = PhaseMaskNetMultiLayer().to(device)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. The file is consumed as a
# state_dict, which loads fine under this restriction.
model.load_state_dict(
    torch.load(
        '/content/phase_mask_multilayer_net.pth',
        map_location=device,
        weights_only=True,
    )
)
# Evaluation mode: BatchNorm layers use their stored running statistics.
model.eval()


## Predict and Visualize

In [None]:

def load_image(path, size=128):
    """Load an image file as a normalized grayscale tensor on ``device``.

    The image is converted to 8-bit grayscale (PIL mode 'L'), resized to
    ``size`` x ``size`` and scaled from [0, 255] to [0.0, 1.0].

    Args:
        path: Path to the image file.
        size: Side length in pixels of the square output. For
            ``PhaseMaskNetMultiLayer`` this should be a multiple of 4.

    Returns:
        A float32 tensor of shape (1, 1, size, size) on the module-level
        ``device``.
    """
    # Context manager guarantees the file handle is released. Pillow may keep
    # it open after loading (e.g. for multi-frame images) when not closed.
    with Image.open(path) as src:
        gray = src.convert('L')  # convert() yields a loaded, file-independent image
    gray = gray.resize((size, size))
    pixels = np.array(gray, dtype=np.float32) / 255.0
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    tensor = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0)
    return tensor.to(device)

# Predict a phase mask for one slice and show it next to the input.
image_path = '/content/drive/MyDrive/layer1/slice_001.png'
img = load_image(image_path)

# Pure inference: disable autograd tracking.
with torch.no_grad():
    pred = model(img)

# Strip the singleton batch/channel axes and bring the data to NumPy for plotting.
input_img, pred_img = (t.squeeze().cpu().numpy() for t in (img, pred))

# (image, title) for each subplot, left to right.
panels = (
    (input_img, 'Input Slice'),
    (pred_img, 'Predicted Phase Mask (Multi-layer)'),
)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
for ax, (panel, title) in zip(axs, panels):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()
