In [None]:
# --- environment, imports and global settings ---
import os
# restrict the GPUs visible to this process; set before torch is imported below
# NOTE(review): the value contains a space after the comma -- confirm the CUDA runtime parses '0, 1' as intended
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1' # select GPU if needed
import math
import numpy as np
# silence all numpy floating-point warnings (divide / over / under / invalid) for the whole session
np.seterr(all = 'ignore')
import torch
import torch.nn as nn
torch.manual_seed(0) # initial seed for random number generator

import time
# run timestamp (YYYYmmdd_HHMM), used in the file name of the saved result
# NOTE(review): this name shadows the stdlib `datetime` module
datetime = time.strftime('%Y%m%d_%H%M')
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.io import loadmat, savemat

# project module; the notebook relies on this star import for PSD, MakeSupport, PhaseRetrieval,
# SubpixelAlignment, PairwiseDistance, PRTF and EigenMode
from PRModule import *

In [None]:
'''
read .mat file and plot input intensity

input data should be intensity
for dpGPS, intensity has to be normalized by photon count
missing data should be masked with NaN
'''

# load input data; missing pixels are expected to be NaN
save_name = 'au_rod' # name for saving result
save_path = '.' # path for saving result
path = './sample_au_rod.mat'
key = 'pattern'
intensity = loadmat(path)[key]
missing = np.isnan(intensity)
print('input data size = {}'.format(intensity.shape))
time.sleep(1)

# left panel: log(1 + I) with colorbar ticks labelled in raw intensity; right panel: PSD vs radius
log_intensity = np.log(intensity + 1)
psd_intensity = PSD(intensity, mask=missing)
fig, (ax_img, ax_psd) = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
img = ax_img.imshow(log_intensity, cmap='turbo')
cbar = fig.colorbar(img, ax=ax_img, fraction=0.046, pad=0.04)
cbar.set_ticks([np.nanmin(log_intensity), np.nanmax(log_intensity)])
cbar.set_ticklabels([np.nanmin(intensity), np.nanmax(intensity)])
ax_img.set_title('log-scaled intensity')
ax_psd.plot(psd_intensity)
ax_psd.set_yscale('log')
ax_psd.set_xlabel('radius')
ax_psd.set_ylabel('average intensity')
ax_psd.set_title('intensity psd')
fig.tight_layout()
plt.show()

In [None]:
# zero-fill the NaN (missing) pixels in place; `missing` still records where they were
intensity[missing] = 0

# support arguments: 'rect' reads radius, 'auto' reads threshold
spargs = {
    'type': 'auto',  # [rect, auto]
    'radius': (8, 4),  # for rect
    'threshold': 0.2,  # for auto
}
support = MakeSupport(intensity, **spargs)

# show the support at full size and cropped to the bounding box of its nonzero pixels
ib, jb = np.nonzero(support)
fig, (ax_full, ax_zoom) = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
ax_full.imshow(support, cmap='gray')
ax_full.set_title('support constraint')
ax_zoom.imshow(support, cmap='gray')
ax_zoom.set_xlim(np.amin(jb), np.amax(jb))
ax_zoom.set_ylim(np.amax(ib), np.amin(ib))
ax_zoom.set_title('support constraint (zoomed)')
fig.tight_layout()
plt.show()

In [None]:
# cast data type to single
# measured magnitude = sqrt(intensity), ifftshifted
# (presumably moves a centered pattern's zero frequency to index (0, 0) -- confirm against PRModule)
input = np.sqrt(np.fft.ifftshift(intensity)).astype(np.single)
# 1.0 where the pattern was missing, 0.0 elsewhere, with the same ifftshift as `input`
unknown = np.fft.ifftshift(missing).astype(np.single)
# np.asarray (rather than .astype) also accepts the torch tensor this cell leaves in `support`,
# so re-running the cell no longer raises AttributeError
support = np.asarray(support, dtype=np.single)
# convert to torch tensors of shape (1, 1, h, w, 1)
h = input.shape[0]
w = input.shape[1]
input = torch.from_numpy(input).view(1, 1, h, w, 1)
unknown = torch.from_numpy(unknown).view(1, 1, h, w, 1)
support = torch.from_numpy(support).view(1, 1, h, w, 1)

In [None]:
# phase retrieval arguments
n_seed = 200  # total number of random initial phases (seeds)
n_batch = 40  # seeds run together in one call of the iterator (one epoch)
n_iter = 1000  # phase retrieval iterations per call

# algorithm settings, forwarded both to the PhaseRetrieval constructor and to every iterator call
info = dict(
    algorithm='dpGPS-R',  # [HIO, GPS-R, GPS-F, dpGPS-R, dpGPS-F]
    error='R',  # [R, NLL]
    shrinkwrap=True,  # [True, False]

    # parameters of HIO
    beta=0.9,  # projection coefficient
    boundary_push=0.1,  # final boundary push stage ratio in total iteration
    # common parameters of GPS, dpGPS
    sigma=(0, 0.01, 0.4, 0.1, 0.7, 1),  # parameter for relaxing magnitude constraint
    alpha_count=10,  # number of finer frequency filters for relaxing support constraint
    # parameters of GPS
    t=1,  # parameter of proximal operator on magnitude constraint
    s=0.9,  # parameter of proximal operator on support constraint
    # parameters of dpGPS
    inner_iter=3,  # number of inner inexact iterations
    # dpGPS setting
    limit=0.25,  # preconditioner value range [1-limit, 1+limit]
    deep=True,  # whether to select the deep learning-based preconditioner

    # shrinkwrap setting
    sigma_initial=3,  # initial sigma
    sigma_limit=1.5,  # lower boundary of sigma
    ratio_update=0.01,  # sigma update ratio
    threshold=0.1,  # threshold for defining new support
    interval=50,  # shrinkwrap interval
)

In [None]:
# build the phase retrieval module from the prepared tensors and settings
iterator = PhaseRetrieval(input, support, unknown, **info)
# select computing device: CUDA (wrapped in DataParallel when several GPUs are visible) or CPU
isCUDA = True
if not isCUDA:
    device = torch.device('cpu')
elif not torch.cuda.is_available():
    raise Exception('CUDA is not available.')
else:
    device = torch.device('cuda')
    if torch.cuda.device_count() > 1:
        # DataParallel splits each input batch across the visible GPUs
        iterator = nn.DataParallel(iterator)
# move the (possibly wrapped) module to the chosen device
iterator = iterator.to(device)

In [None]:
# select output type: 'u' allocates 1 channel per pixel, 'z' allocates 2 and sets the iterator's toggle
otype = 'u' # [u, z]
output = torch.zeros(n_seed, 1, h, w, 1 if otype == 'u' else 2)
# error of every seed at every iteration
path = torch.zeros(n_seed, n_iter)
n_max = math.ceil(n_seed / n_batch)
t_elap = time.time()
for n in tqdm(range(n_max)):
    # seeds [start, stop) form this batch; only the last batch can be shorter than n_batch
    start = n * n_batch
    stop = min((n + 1) * n_batch, n_seed)
    # random initial phase in [0, 2*pi), encoded as (cos, sin) along the last axis
    theta = torch.rand(stop - start, 1, h, w, 1) * 2 * math.pi
    init = torch.cat((torch.cos(theta), torch.sin(theta)), dim=-1).to(device)
    # run the iterations and store this batch's results and error trace
    output[start:stop], path[start:stop] = iterator(n_iter, init, toggle=(otype == 'z'), **info)
t_elap = time.time() - t_elap

In [None]:
# error curve averaged over all seeds, one point per iteration
fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
ax.plot(path.mean(dim=0))
ax.set_title('average error per iteration')
plt.show()

In [None]:
# convert results to numpy ndarray
# Squeeze only the channel axes (dim 1 and the trailing dim; the trailing one is kept when it has
# size 2, i.e. otype 'z'). A bare squeeze() would also drop the seed axis when n_seed == 1 (or an
# image axis of size 1) and break the indexing below.
output = output.squeeze(-1).squeeze(1).numpy()
path = path.numpy()
# sort seeds by the minimum error reached over their iterations, best first
error = np.amin(path, axis = 1)
order = np.argsort(error)
output = output[order]
path = path[order]
error = error[order]
# subpixel alignment
output = SubpixelAlignment(output, subpixel = 10)
# save result
isSave = True
if isSave:
    savemat('{}/{}_{}.mat'.format(save_path, save_name, datetime), {
        'result' : output, 'error' : error, 'info' : info, 't_elap' : t_elap})

In [None]:
# plot best reconstruction result (after sorting, index 0 has the lowest error)
plt.figure(figsize = (10, 5), dpi = 100)
plt.subplot(121)
plt.imshow(output[0], cmap = 'turbo')
plt.title('best reconstruction result')
plt.subplot(122)
plt.imshow(output[0], cmap = 'turbo')
# bounding box of the positive pixels; ib / jb are reused by the zoomed plots in later cells
ib, jb = np.nonzero(output[0] > 0)
if ib.size == 0:
    # no positive pixel: fall back to the full frame instead of calling np.amin on an empty array
    ib = np.array([0, output[0].shape[0] - 1])
    jb = np.array([0, output[0].shape[1] - 1])
plt.xlim(np.amin(jb), np.amax(jb))
plt.ylim(np.amax(ib), np.amin(ib))
cbar = plt.colorbar(fraction = 0.046, pad = 0.04)
plt.title('best reconstruction result (zoomed)')
plt.tight_layout()
plt.show()

In [None]:
# plot the best (lowest-error) reconstructions: r-space on the top row, log k-space magnitude below.
# Cap at the number of available seeds so small runs (n_seed < 5) do not raise IndexError.
n_show = min(5, len(output))
plt.figure(figsize = (25, 10), dpi = 100)
# plot r-space, zoomed to the bounding box ib / jb found in the best-reconstruction cell
for rank in range(n_show):
    plt.subplot(2, 5, rank + 1)
    plt.imshow(output[rank], cmap = 'turbo')
    plt.xlim(np.amin(jb), np.amax(jb))
    plt.ylim(np.amax(ib), np.amin(ib))
    plt.xticks([])
    plt.yticks([])
    plt.colorbar(fraction = 0.046, pad = 0.04)
    plt.title('best {} (error = {:.4f})'.format(rank + 1, error[rank]))
# plot k-space: log(1 + |fftshift(fft2(x))|)
for rank in range(n_show):
    plt.subplot(2, 5, rank + 6)
    plt.imshow(np.log(np.abs(np.fft.fftshift(np.fft.fft2(output[rank]))) + 1), cmap = 'turbo')
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()

In [None]:
# distance between every pair of r-space results
dpair = PairwiseDistance(output)
# PRTF with the measured magnitude (sqrt of the zero-filled intensity) as reference and the
# missing-pixel mask passed through, followed by its PSD profile
prtf = PRTF(output, mask=missing, ref=np.sqrt(intensity))
prtf_psd = PSD(prtf, mask=missing)

In [None]:
# summary metrics: per-seed error histogram, pairwise-distance histogram, PRTF profile
fig, (ax_err, ax_dist, ax_prtf) = plt.subplots(1, 3, figsize=(15, 5), dpi=100)
ax_err.hist(error)
ax_err.set_xlabel('error')
ax_err.set_ylabel('count')
ax_err.set_title('error distribution')
ax_dist.hist(dpair)
ax_dist.set_xlabel('distance')
ax_dist.set_ylabel('count')
ax_dist.set_title('pairwise distance')
ax_prtf.plot(prtf_psd)
# dotted reference line at 0.5
ax_prtf.axhline(0.5, color='k', linestyle=':', linewidth=1)
ax_prtf.set_xlabel('radius')
ax_prtf.set_ylabel('average prtf')
ax_prtf.set_title('prtf psd')
fig.tight_layout()
plt.show()

In [None]:
# eigenmode decomposition of the r-space results with k = 3;
# unpacked as (eigenmodes, singular values, approximations)
(eigen, singular, approx) = EigenMode(output, k=3)

In [None]:
# eigenmode summary on a 2x4 grid (panel 1 is intentionally left empty):
#   panels 2-4: eigenmodes, titled with their singular values
#   panel 5   : pixel-wise average over all seeds
#   panels 6-8: low-rank (eigen) approximations
# every panel is zoomed to the bounding box ib / jb of the best reconstruction
plt.figure(figsize=(20, 10), dpi=100)
for pos in range(2, 9):
    if pos <= 4:
        k = pos - 2
        image = eigen[k]
        label = 'eigenmode {} (s = {:.3f})'.format(k + 1, singular[k])
    elif pos == 5:
        image = np.mean(output, axis=0)
        label = 'average'
    else:
        k = pos - 6
        image = approx[k]
        label = 'eigenapprox {}'.format(k + 1)
    plt.subplot(2, 4, pos)
    plt.imshow(image, cmap='turbo')
    plt.xlim(np.amin(jb), np.amax(jb))
    plt.ylim(np.amax(ib), np.amin(ib))
    plt.xticks([])
    plt.yticks([])
    cbar = plt.colorbar(fraction=0.046, pad=0.04)
    plt.title(label)
plt.tight_layout()
plt.show()