In [None]:
import os
import numpy as np
# silence ALL NumPy floating-point warnings (divide, over, under, invalid) for
# the whole kernel session; note this also hides e.g. NaNs produced by
# sqrt/log of negative values in later cells
np.seterr(all = 'ignore')
import torch
import torch.nn as nn
# seed torch's global RNG for reproducible runs (NumPy's global RNG is not
# seeded here)
torch.manual_seed(0)

import time
# timestamp string of the form 'YYYYmmdd-HHMMSS'; not used in the cells
# visible here — presumably used later for naming outputs.
# NOTE(review): this name shadows the stdlib `datetime` module.
datetime = time.strftime('%Y%m%d-%H%M%S')
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.io import loadmat, savemat

# project module; presumably provides PSD, MakeSupport and PhaseRetrieval,
# which the following cells use without defining them (star import, so the
# full set of imported names is not visible from here)
from PRModule import *

In [None]:
# Read the .mat file and plot the input intensity.
#
# - input data should be intensity (not magnitude)
# - for dpGPS, intensity has to be normalized by photon count
# - missing data should be masked with NaN

# load input data
path = './sample_au_rod.mat'
key = 'pattern'
intensity = loadmat(path)[key]
unknown = np.isnan(intensity)  # True where data is missing
print('input data size = {}'.format(intensity.shape))
# NOTE(review): presumably a pause so the print above is flushed before any
# progress output from PSD; drop it if it is not needed.
time.sleep(1)
# plot intensity and psd
# log1p(I) == log(1 + I), but stays accurate for the tiny values produced by
# photon-count normalization; NaN (missing) pixels stay NaN and are masked by imshow
QQQ1 = np.log1p(intensity)
QQQ2 = PSD(intensity, unknown)
fig, (ax_img, ax_psd) = plt.subplots(1, 2, figsize = (10, 5), dpi = 100)
im = ax_img.imshow(QQQ1, cmap = 'turbo')
ax_img.set_xticks([])
ax_img.set_yticks([])
cbar = fig.colorbar(im, ax = ax_img, fraction = 0.046, pad = 0.04)
# ticks sit at the log-scaled extremes but are labelled with the raw intensity
# values, formatted compactly instead of printing full float reprs
cbar.set_ticks([np.nanmin(QQQ1), np.nanmax(QQQ1)])
cbar.set_ticklabels(['{:.4g}'.format(v) for v in (np.nanmin(intensity), np.nanmax(intensity))])
ax_img.set_title('log-scaled intensity')
ax_psd.plot(QQQ2)
ax_psd.set_yscale('log')
ax_psd.set_xlabel('radius')
ax_psd.set_ylabel('average intensity')
ax_psd.set_title('intensity psd')
fig.tight_layout()
plt.show()

In [None]:
# remove NaN values (missing pixels) by zero-filling them in place.
# The mask is recomputed from `intensity` instead of reusing `unknown`, which
# the next cell turns into a torch tensor; this keeps the cell re-runnable
# (on the first run the two masks are identical).
intensity[np.isnan(intensity)] = 0
# support arguments
spargs = {
    'type' : 'auto', # [rect, auto]
    'radius' : (8, 4), # for rect
    'threshold' : 0.2 # for auto
}
# make support
support = MakeSupport(intensity, **spargs)
fig, ax = plt.subplots(figsize = (5, 5), dpi = 100)
ax.imshow(support, cmap = 'gray')
ax.set_xticks([])
ax.set_yticks([])
ax.set_title('support')
plt.show()

In [None]:
# cast data to single precision and convert to torch tensors of shape
# (1, 1, h, w, 1), the layout handed to PhaseRetrieval in a later cell.
#
# measured magnitude = sqrt(intensity), ifftshifted (presumably so a centred
# diffraction pattern has its zero frequency at index (0, 0) — TODO confirm).
# NOTE(review): with np.seterr(all='ignore') set at the top, any negative
# intensity silently becomes NaN here.
magnitude = np.sqrt(np.fft.ifftshift(intensity)).astype(np.single)
h, w = magnitude.shape[0], magnitude.shape[1]
# convert to torch tensor; `input` shadows the Python builtin but is kept
# because later cells refer to it
input = torch.from_numpy(magnitude).view(1, 1, h, w, 1)
# np.asarray accepts both the numpy arrays built by the earlier cells and the
# CPU tensors produced by a previous run of this cell, so re-running this cell
# no longer fails on `.astype` of a torch tensor
unknown = torch.from_numpy(np.asarray(unknown).astype(np.single)).view(1, 1, h, w, 1)
support = torch.from_numpy(np.asarray(support).astype(np.single)).view(1, 1, h, w, 1)

In [None]:
# phase retrieval arguments: number of seeds, batch size, number of iterations
n_seed, n_batch, n_iter = 200, 40, 1000

# keyword arguments forwarded to PhaseRetrieval
info = dict(
    algorithm = 'GPS-R',   # one of HIO, GPS-R, GPS-F, dpGPS-R, dpGPS-F
    error = 'R',           # one of R, NLL
    shrinkwrap = True,     # True / False

    # --- HIO ---
    beta = 0.9,            # projection coefficient
    boundary_push = 0.1,   # fraction of total iterations spent in the final boundary-push stage
    # --- shared by GPS and dpGPS ---
    sigma = (0, 0.01, 0.4, 0.1, 0.7, 1),  # relaxation parameters for the magnitude constraint
    alpha_count = 10,      # number of finer frequency filters relaxing the support constraint
    # --- GPS ---
    t = 1,                 # proximal-operator parameter for the magnitude constraint
    s = 0.9,               # proximal-operator parameter for the support constraint

    # --- dpGPS ---
    limit = 0.25,          # preconditioner values lie in [1 - limit, 1 + limit]
    deep = True,           # True selects the deep-learning-based preconditioner
    # --- shrinkwrap ---
    sigma_initial = 3,     # initial sigma
    sigma_limit = 1.5,     # lower bound of sigma
    ratio_update = 0.01,   # sigma update ratio
    threshold = 0.1,       # threshold for defining the new support
    interval = 50,         # shrinkwrap interval
)

In [None]:
# initialize phase retrieval iterator
iterator = PhaseRetrieval(input, support, unknown, **info)
# select computing device
isCUDA = True
gpu_ids = [2, 3]  # preferred GPUs for DataParallel when several GPUs are present
if isCUDA:
    if not torch.cuda.is_available():
        # RuntimeError subclasses Exception, so existing handlers still catch it
        raise RuntimeError('CUDA is not available.')
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        # Setting os.environ['CUDA_VISIBLE_DEVICES'] at this point is ignored
        # once the calls above have initialised CUDA, so the GPUs are passed to
        # DataParallel explicitly instead. Fall back to every visible GPU when
        # none of the preferred indices exist on this machine.
        device_ids = [i for i in gpu_ids if i < n_gpu] or list(range(n_gpu))
        device = torch.device('cuda', device_ids[0])
        iterator = nn.DataParallel(iterator, device_ids = device_ids)
    else:
        device = torch.device('cuda')
else:
    device = torch.device('cpu')
# allocate iterator (DataParallel expects the module on device_ids[0])
iterator = iterator.to(device)