# Phase Retrieval demo.

SPIE Short course on Machine Learning for Image Restoration.  
Author: Jesse Wilson (jesse.wilson@colostate.edu).

Walk through training and evaluation of a convolutional network for phase retrieval from coherent diffractive imaging. This code is provided for educational purposes.

# Preliminaries

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchvision.transforms.functional import gaussian_blur
import numpy as np
from random import randint
from torch.fft import fft2, fftshift, ifft2, ifftshift

# get available GPU
# supports NVIDIA (CUDA), Intel (XPU), and Apple (MPS)
# (CAUTION: AI-generated code -- NOT validated on all systems!)
# Backends are probed in priority order (CUDA, XPU, MPS) with CPU as the
# fallback. The hasattr guards keep the probe from raising on PyTorch builds
# that do not provide the torch.xpu / torch.backends.mps modules.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    # availability check lives in the torch.xpu submodule
    device = torch.device("xpu:0")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Selected device: {device}.")

In [None]:
# load a dataset

batch_size = 64

# ToTensor converts each MNIST image to a tensor; Pad(18) adds an 18-pixel
# border on every side (the resulting shape is printed below).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Pad(18)])
dataset_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
dataset_val = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# use the batch_size defined above rather than a second hard-coded value
dataloader_train = DataLoader(dataset_train, batch_size=batch_size)

# sanity check: print the shape of the first training image and display it
sample_img = dataset_train[0][0]
print(sample_img.shape)
plt.imshow(sample_img.squeeze())
plt.colorbar()
plt.show()

# Set up the forward model
Functions that simulate passing a coherent beam through a circular aperture and then a phase object, propagating the field to the Fraunhofer (far-field) plane, and recording an amplitude-only diffraction pattern.

In [None]:
def createAperture( radius, size=64 ):
    """Build a circular aperture mask on a square pixel grid.

    Pixel coordinates run from ``1 - size//2`` to ``size - size//2`` along
    both axes (-31..32 for the default size), so the zero coordinate sits at
    index ``size//2 - 1``.

    Args:
        radius: aperture radius in pixels.
        size: side length of the square grid in pixels. Defaults to 64.

    Returns:
        Boolean numpy array of shape (size, size); True where
        ``x**2 + y**2 <= radius**2``.
    """
    coords = np.arange(1, size + 1) - size // 2
    X, Y = np.meshgrid(coords, coords)

    return X**2 + Y**2 <= radius**2

def createPhaseObject( img, aperture ):
    """Encode an image as the phase of a complex object behind an aperture.

    Args:
        img: real-valued tensor holding the image(s) to encode.
        aperture: tensor that multiplies the object; it must broadcast
            against ``img``. Zeros in it zero out the object (opaque region).

    Returns:
        Complex tensor ``exp(1j * 3.14/18 * img) * aperture``.
    """
    # each unit of image value contributes 3.14/18 rad of phase
    obj = torch.exp( 3.14j/18.*img )

    # the object is surrounded by an opaque aperture
    return obj * aperture


In [None]:
def forwardModel( obj ):
    """Compute the diffraction pattern of a complex object with a 2-D FFT.

    Args:
        obj: complex tensor whose last two dimensions are the spatial axes;
            any leading dimensions are left untouched.

    Returns:
        Tuple ``(diffr_abs, diffr_phase)``: the amplitude and the phase of
        the centered Fourier transform, each with the same shape as ``obj``.
    """
    # Shift only the spatial axes. With the default dim=None the shifts would
    # also roll every leading (e.g. batch) axis forward and back again, which
    # is wasted work.
    spatial_dims = (-2, -1)
    diffr_complex = fftshift(fft2(ifftshift(obj, dim=spatial_dims)), dim=spatial_dims)

    diffr_phase = diffr_complex.angle() # phase of diffraction pattern (not measurable)
    diffr_abs = diffr_complex.abs()   # amplitude of diffraction pattern (seen by imaging sensor)

    return diffr_abs, diffr_phase

# Neural network definition and quick passthrough test

In [None]:
# 64x64 encoder decoder network with fully connected bottleneck
# uses strided convs instead of conv->maxpool
# uses strided transposed convs instead of upsampling -> conv
# note: started modifying my previous code for encoder-decoder, got stuck on 
#       exact syntax and details for flattening/unflattening, so the initial
#       draft of the bottleneck was generated by prompting MS Copilot:
#       
#       Simple pytorch code for encoder-decoder with a fully-connected bottleneck. 
#       Should take 64x64 images, use strided convolutions for downsampling 
#       and transposed convs for upsampling, and have a 7x7 bottleneck.
#
class EncDec( nn.Module ):
    """Convolutional encoder-decoder with a fully connected bottleneck.

    Maps a single-channel 64x64 input to a single-channel 64x64 output.
    Four stride-2 convolutions reduce 64x64 to 3x3, two linear layers act on
    the flattened 3x3 feature maps, and four stride-2 transposed convolutions
    expand back to 64x64.

    Args:
        nFilt: number of feature channels used by every hidden layer.
            Defaults to 32.
    """
    def __init__(self, nFilt=32):
        super().__init__()
        self.nFilt = nFilt
        self.enc = nn.Sequential(
            nn.Conv2d(1,self.nFilt,kernel_size=3,stride=2), # 64x64 -> 31x31
            nn.LeakyReLU(),
            nn.Conv2d(self.nFilt,self.nFilt,kernel_size=3,stride=2), # 31x31 -> 15x15
            nn.LeakyReLU(),
            nn.Conv2d(self.nFilt,self.nFilt,kernel_size=3,stride=2), # 15x15 -> 7x7
            nn.LeakyReLU(),
            nn.Conv2d(self.nFilt,self.nFilt,kernel_size=3,stride=2), # 7x7 -> 3x3
            nn.LeakyReLU()
        )

        self.dec = nn.Sequential(
            nn.ConvTranspose2d(self.nFilt,self.nFilt,kernel_size=3,stride=2), # 3x3 -> 7x7
            nn.LeakyReLU(),
            nn.ConvTranspose2d(self.nFilt,self.nFilt,kernel_size=3,stride=2), # 7x7 -> 15x15
            nn.LeakyReLU(),
            nn.ConvTranspose2d(self.nFilt,self.nFilt,kernel_size=3,stride=2), # 15x15 -> 31x31
            nn.LeakyReLU(),
            # output_padding=1 supplies the extra row/column needed to reach 64
            nn.ConvTranspose2d(self.nFilt,1,kernel_size=3,stride=2,output_padding=1), # 31x31 -> 64x64
        )

        # number of features in the flattened 3x3 encoder output
        n_flat = 3*3*self.nFilt
        self.bottleneck = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(n_flat, n_flat),
            nn.LeakyReLU(),
            nn.Linear(n_flat, n_flat),
            nn.LeakyReLU(),
            nn.Unflatten(1,(self.nFilt,3,3))
        )

        # Kaiming-initialize the encoder convolutions. The isinstance check
        # matches nn.Conv2d only, so the ConvTranspose2d and Linear layers
        # keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
                nn.init.constant_(m.bias,0)

    def forward(self,x):
        """Run encoder, bottleneck and decoder on a (batch, 1, 64, 64) input."""
        x = self.enc(x)
        x = self.bottleneck(x)
        x = self.dec(x)

        return x

# instantiate the network and move its parameters to the selected device
net = EncDec().to(device)

In [None]:
# quick test: pick a random validation image, build the phase object,
# simulate its diffraction pattern, and pass the amplitude through the network
ind = randint(0, len(dataset_val) - 1)
img = dataset_val[ind][0].to(device).unsqueeze(0)  # add a leading batch dimension

aperture = torch.Tensor(createAperture(16)).to(device)

obj_complex = createPhaseObject(img, aperture)
diffr_abs, diffr_phase = forwardModel(obj_complex)

obj_phase_est = net(diffr_abs)

# plot the results
# each entry: (subplot position, image to show, title, draw a colorbar?)
panels = [
    (231, obj_complex.abs().cpu().squeeze(), 'abs(object)', False),
    (232, obj_complex.angle().cpu().squeeze(), 'phase(object)', True),
    (233, obj_phase_est.detach().cpu().squeeze(), 'network output', False),
    (234, diffr_abs.cpu().squeeze(), '(abs(diffraction pattern))', True),
    (235, diffr_phase.cpu().squeeze(), 'phase(diffraction pattern)', True),
]

plt.figure(figsize=(8,6))
for position, data, title, with_colorbar in panels:
    plt.subplot(position)
    plt.imshow(data)
    if with_colorbar:
        plt.colorbar()
    plt.axis('off')
    plt.title(title)

plt.tight_layout()
plt.show()

# Train the network

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=0.003)
loss_fn = nn.L1Loss()

# NOTE(review): nTrain is not used anywhere below -- the loop iterates over the
# full dataloader_train. Confirm whether training on a subset was intended.
nTrain = 5000

# aperture mask used to build the complex objects
aperture = createAperture(16)
aperture = torch.Tensor(aperture).to(device)


loss_train_vec = []  # one entry per optimizer step

# training loop (supervised: the target is the true object phase)
n_epochs = 1500
for epoch in range(1, n_epochs + 1):  # epochs numbered 1..n_epochs inclusive
    for img, label in dataloader_train:
        img = img.to(device)
        obj = createPhaseObject( img, aperture )
        diffr_abs, diffr_phase = forwardModel(obj)

        # the network only sees the diffraction amplitude
        obj_phase_est = net(diffr_abs)

        loss = loss_fn(obj_phase_est,obj.angle())
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        loss_train_vec.append(loss.item())

    # after each epoch: show the first object of the last batch (true phase
    # and network estimate) above the running training-loss curve
    clear_output(wait=True)
    plt.figure(figsize=[8,8])
    plt.subplot(221)
    plt.imshow(obj[0].angle().cpu().squeeze())
    plt.subplot(222)
    plt.imshow(obj_phase_est[0].detach().cpu().squeeze())
    plt.subplot(212)
    plt.plot(loss_train_vec)
    #plt.plot(loss_val_vec)
    # only the training curve is plotted; add 'val' back together with the
    # validation curve above
    plt.legend(['train'])
    plt.show()
    

# Activity

In [None]:
# your turn: change one thing above and run it again. A few ideas
# - change pupil diameter
# - change learning rate
# - change neural network architecture
# - add validation tracking to the training loop
# - test on out-of-distribution images
# - add noise to simulated image
# - compare with Fienup's iterative phase retrieval algorithm


# Extras

## Physics-informed (unsupervised) training
Note: this proof of concept code _barely_ works, and will need some fine tuning and experimentation to be robust.

In [None]:
class GradientMagnitude(nn.Module):
    """Mean squared Sobel gradient magnitude of a single-channel image batch.

    Args:
        dev: device on which the Sobel kernels are created.
    """
    def __init__(self,dev):
        super().__init__()
        # 3x3 Sobel kernels shaped (1, 1, 3, 3) for F.conv2d. They are
        # registered as non-persistent buffers so they follow the module
        # across .to(...) calls without being added to the state_dict.
        sobel_x = torch.tensor([[1.,0.,-1.],[2.,0.,-2.],[1.,0.,-1.]])
        sobel_y = torch.tensor([[1.,2.,1.],[0.,0.,0.],[-1.,-2.,-1.]])
        self.register_buffer("sobel_x", sobel_x.unsqueeze(0).unsqueeze(0).to(dev), persistent=False)
        self.register_buffer("sobel_y", sobel_y.unsqueeze(0).unsqueeze(0).to(dev), persistent=False)

    def forward(self, x):
        """Return the mean of grad_x**2 + grad_y**2 over the batch.

        Args:
            x: tensor of shape (batch, 1, H, W) with H, W >= 3 (the
                convolutions are unpadded).

        Returns:
            Scalar tensor; zero for an image that is constant everywhere.
        """
        grad_x = F.conv2d(x,self.sobel_x)
        grad_y = F.conv2d(x,self.sobel_y)

        # squared magnitude: avoids the sqrt, whose gradient is undefined at zero
        sq_grad = grad_x**2 + grad_y**2
        return torch.mean(sq_grad)

In [None]:
# unsupervised training
# given measured diffraction amplitude pattern, find object phase such that estimated diffraction pattern matches measured
# NOTE(review): this cell reuses the existing `net`, so it continues from
# whatever weights the network currently holds -- re-create the network first
# if training from scratch is intended.

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
#loss_fn_gradmag = GradientMagnitude(device)


# aperture mask used to build the complex objects
aperture = createAperture(16)
aperture = torch.Tensor(aperture).to(device)


loss_train_vec = []  # one entry per optimizer step
#loss_val_vec = []

# training loop
n_epochs = 1500
for epoch in range(1, n_epochs + 1):  # epochs numbered 1..n_epochs inclusive
    for img, label in dataloader_train:
        img = img.to(device)
        obj = createPhaseObject( img, aperture )
        diffr_abs, diffr_phase = forwardModel(obj)

        obj_phase_est = net(diffr_abs)

        # push the estimate back through the same forward model so it can be
        # compared with the measurement in the diffraction plane
        obj_est = createPhaseObject(obj_phase_est, aperture)
        diffr_abs_est, diffr_phase_est = forwardModel(obj_est)

        # unsupervised loss (does NOT make use of object phase)
        # alternatives to experiment with:
        #loss = loss_fn(diffr_abs_est,diffr_abs) + 0.1*torch.mean(torch.abs(obj_phase_est*(1-aperture))) + 0.1*torch.mean(torch.relu(-1*obj_phase_est))+0.001*loss_fn_gradmag(obj_phase_est)
        #loss = loss_fn(diffr_abs_est,diffr_abs) + 0.00001*loss_fn_gradmag(obj_phase_est)
        # active loss: amplitude mismatch + penalty on large estimates
        # + penalty on negative estimates
        loss  = loss_fn(diffr_abs_est,diffr_abs) + 0.1*torch.mean(obj_phase_est**2) + torch.mean(torch.relu(-1*obj_phase_est))
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        #scheduler.step(loss)

        loss_train_vec.append(loss.item())


    # after each epoch: show the first object of the last batch (true phase
    # and network estimate) above the running training-loss curve
    clear_output(wait=True)
    plt.figure(figsize=[8,8])
    plt.subplot(221)
    plt.imshow(obj[0].angle().cpu().squeeze(),vmin=0,vmax=0.2)
    plt.subplot(222)
    plt.imshow(obj_phase_est[0].detach().cpu().squeeze())
    plt.colorbar()
    plt.subplot(212)
    plt.plot(loss_train_vec)
    plt.show()
    
