# PhaseMaskNet Inference (2D Single Slice)

Use a trained PhaseMaskNet model (implemented below as the `DeepCGHUNet` class) to predict a phase mask for a new 2D input slice.

## Setup

In [None]:

!pip install piq
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


## Model Definition

In [None]:

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes (``in_ch -> out_ch -> out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage maps in_ch -> out_ch, the second keeps out_ch.
        for src_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Stored as ``net`` with positional indices 0-5, which determines the
        # parameter names ("net.0.weight", ...) used in saved state dicts.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` of shape (N, in_ch, H, W) through both stages."""
        return self.net(x)

class DeepCGHUNet(nn.Module):
    """U-Net style encoder/decoder that maps an image to a one-channel map.

    The encoder applies four ``DoubleConv`` stages separated by 2x2 max
    pooling, followed by a bottleneck at 1/16 resolution. The decoder
    upsamples bilinearly four times, concatenating the matching encoder
    feature map (skip connection) before each convolution stage.

    Args:
        in_ch: Number of channels in the input image.
        base_feat: Channel count of the first encoder stage; deeper stages
            use 2x, 4x, 8x and 16x this value.
    """

    def __init__(self, in_ch=1, base_feat=64):
        super().__init__()
        # Encoder: channel count doubles at every stage.
        self.enc1 = DoubleConv(in_ch, base_feat)
        self.enc2 = DoubleConv(base_feat, base_feat * 2)
        self.enc3 = DoubleConv(base_feat * 2, base_feat * 4)
        self.enc4 = DoubleConv(base_feat * 4, base_feat * 8)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(base_feat * 8, base_feat * 16)
        # Decoder: each DoubleConv receives the upsampled features
        # concatenated with the skip tensor, hence the summed input widths.
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = DoubleConv(base_feat * 16 + base_feat * 8, base_feat * 8)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec2 = DoubleConv(base_feat * 8 + base_feat * 4, base_feat * 4)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = DoubleConv(base_feat * 4 + base_feat * 2, base_feat * 2)
        self.final_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.final_conv = nn.Sequential(
            nn.Conv2d(base_feat * 2 + base_feat, base_feat, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_feat, 1, 1)
        )

    @staticmethod
    def _match_size(x, skip):
        """Resize ``x`` to the spatial size of ``skip`` when they differ.

        Max pooling floors odd sizes, so a fixed x2 upsample can come out one
        pixel short of the encoder feature map. When the sizes already agree
        (input height/width divisible by 16) ``x`` is returned untouched.
        """
        if x.shape[-2:] != skip.shape[-2:]:
            x = nn.functional.interpolate(
                x, size=skip.shape[-2:], mode='bilinear', align_corners=False
            )
        return x

    def forward(self, x):
        """Return a (N, 1, H, W) map for an input of shape (N, in_ch, H, W).

        H and W must be at least 16 so the bottleneck is not empty; they no
        longer need to be multiples of 16.
        """
        # Encoder path; x1..x4 are kept for the skip connections.
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x3 = self.enc3(self.pool(x2))
        x4 = self.enc4(self.pool(x3))
        x5 = self.bottleneck(self.pool(x4))

        # Decoder path: upsample, align with the skip tensor, concatenate
        # along the channel axis, then convolve.
        x = self._match_size(self.up3(x5), x4)
        x = self.dec3(torch.cat([x, x4], dim=1))
        x = self._match_size(self.up2(x), x3)
        x = self.dec2(torch.cat([x, x3], dim=1))
        x = self._match_size(self.up1(x), x2)
        x = self.dec1(torch.cat([x, x2], dim=1))
        x = self._match_size(self.final_up(x), x1)
        x = self.final_conv(torch.cat([x, x1], dim=1))
        return x


## Load Trained Model

In [None]:

# Build the network on the selected device and restore the trained weights.
model = DeepCGHUNet(in_ch=1).to(device)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
model.load_state_dict(
    torch.load(
        "/content/phase_mask_net.pth",
        map_location=device,
        weights_only=True,
    )
)
# Switch to inference behavior (batch norm uses its running statistics).
model.eval()


## Prediction Example

In [None]:

def load_image(path, size=512):
    """Load an image file as a normalized single-channel model input.

    Args:
        path: Image file to read (anything ``PIL.Image.open`` accepts).
        size: Side length in pixels. The image is resized to
            ``size x size``; the aspect ratio is not preserved.

    Returns:
        A float32 tensor of shape (1, 1, size, size) with values in [0, 1],
        placed on the module-level ``device``.
    """
    # The context manager closes the underlying file once the pixel data has
    # been converted, instead of leaving that to garbage collection.
    with Image.open(path) as src:
        gray = src.convert('L').resize((size, size))
    # 8-bit grayscale -> float32 in [0, 1].
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    # Add batch and channel axes: (H, W) -> (1, 1, H, W). The array is
    # already float32, so no extra dtype cast is needed.
    return torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0).to(device)

image_path = "/content/your_input_image.png"  # Example input
input_tensor = load_image(image_path)

# Forward pass without autograd bookkeeping; drop the singleton batch and
# channel axes and move the result to a NumPy array for plotting.
with torch.no_grad():
    pred_phase = model(input_tensor).squeeze().cpu().numpy()

# Wrap the raw network output into [0, 1) before displaying it.
wrapped = np.mod(pred_phase, 1.0)

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(wrapped, cmap='hot')
ax.set_title("Predicted Phase Mask (Wrapped)")
fig.colorbar(im, ax=ax, label="Phase (0 to 1)")
ax.axis('off')
plt.show()
