## This notebook processes the experimental data and generates the results reported in the experimental section of the paper

In [None]:
import numpy as np
from scipy.io import loadmat
import tensorflow as tf
import os

# These are used for plotting later on
from cycler import cycler
import matplotlib as mpl
from matplotlib import pyplot as plt

# Root of the data folder. Change this to wherever the data lives on your system.
data_folder_prefix = "./data"

In [None]:
# Ground-truth objects for the test set. Only the phase of the complex
# objects is compared against the reconstructions in the cells below.
prefix = data_folder_prefix + '/Simulated/R_Sweep/'
test_images = np.angle(np.load(prefix + 'test_images.npy'))

# Radial distance (in pixels) of every pixel from the centre of a 256x256
# frame. Later cells zero out pixels with `Rs > 108` so that reconstructions
# are only compared inside a central disc.
Xs, Ys = (grid - grid.mean() for grid in np.mgrid[:256, :256])
Rs = np.sqrt(np.square(Xs) + np.square(Ys))

In [None]:
def find_exp_global_phase(im1, im2, size=256):
    """Shift ``im1`` by a constant so that its mean matches the mean of ``im2``.

    Used to remove the arbitrary global phase offset of an experimental
    iterative reconstruction before it is compared with the ground truth.

    Parameters
    ----------
    im1 : array_like
        Image to correct (e.g. a reconstructed phase map).
    im2 : array_like
        Reference image whose mean ``im1`` is matched to.
    size : int or None, optional
        Side length of the square output. The corrected image is reshaped to
        ``(size, size)``, so ``im1`` must contain ``size**2`` elements.
        Pass ``None`` to keep ``im1``'s own shape (e.g. for non-square crops).

    Returns
    -------
    numpy.ndarray
        ``im1 + (mean(im2) - mean(im1))``, shaped as described for ``size``.
    """
    im1 = np.asarray(im1)
    # The offset is a single scalar, so the image layout does not matter here.
    offset = np.mean(im2) - np.mean(im1)
    corrected = im1 + offset
    if size is None:
        return corrected
    return corrected.reshape(size, size)

In [None]:
# Preview the first nine ground-truth phase images on a 3x3 grid.
plt.rcParams['figure.figsize'] = [20, 20]
fig = plt.figure()
fig.subplots_adjust(hspace=.05, wspace=.05)

for panel in range(1, 10):
    ax = fig.add_subplot(3, 3, panel)
    for axis in (ax.axes.xaxis, ax.axes.yaxis):
        axis.set_visible(False)
    plt.imshow(test_images[panel - 1], cmap='viridis')

## The cell below computes the MS-SSIM values at different photon levels for the generative, non-pretrained network

In [None]:
# load the MS-SSIM loss function
ssim = tf.image.ssim_multiscale

# Generative network, not pretrained: for each photon level, score every alpha
# checkpoint with MS-SSIM against the ground truth and keep the best alpha.
pcc_mean11 = []  # NOTE(review): never filled or read in this notebook
ssim_mean11 = []
ssim_std11 = []

R = 0.5
alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]
# Border (pixels) cropped from each side before MS-SSIM; presumably chosen so
# the 162x162 crop is still large enough for ssim_multiscale -- TODO confirm.
remove = 47

for photon_level in [1, 10, 100.0, 1e3]:
    pic = []               # reconstructions, one entry per evaluated alpha
    ssim_temp = []         # mean MS-SSIM, one entry per evaluated alpha
    std_temp = []          # std of MS-SSIM, one entry per evaluated alpha
    evaluated_alphas = []  # keeps the three lists above aligned with alpha

    # loop over all the alpha values
    for alpha in alphas:
        # The run may have been saved under either date. Exactly one match is
        # required, so we never silently score the wrong (or a stale) run.
        suffix = '-exp-not-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)
        matches = [date + suffix for date in ['2021-10-05', '2021-10-06']
                   if os.path.isfile(date + suffix)]
        if not matches:
            print("missing files for R-%0.2f-alpha-%0.2f-photon-%0.2f" % (R, alpha, photon_level))
            continue
        if len(matches) > 1:
            print("multiple files for R-%0.2f-alpha-%0.2f-photon-%0.2f" % (R, alpha, photon_level))
            continue
        rec_test_output = loadmat(matches[0])['rec_test_output']
        pic.append(rec_test_output)

        loss_ssim_list = []
        for gen, true in zip(rec_test_output, test_images):
            true = tf.expand_dims(true, -1)

            # Zero the reconstruction outside the central disc. This is done
            # in place, so the images stored in `pic` are masked too.
            gen[Rs > 108] = 0
            gen = tf.expand_dims(gen, -1)

            true = true[remove:-remove, remove:-remove, :]
            gen = gen[remove:-remove, remove:-remove, :]

            # the dynamic range of the cropped ground truth is MS-SSIM's max_val
            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))
            loss_ssim_list.append(ssim_data)

        evaluated_alphas.append(alpha)
        ssim_temp.append(np.mean(loss_ssim_list))
        std_temp.append(np.std(loss_ssim_list))

    print('photon_level =', photon_level)
    if not ssim_temp:
        # Keep one entry per photon level so the bar chart stays aligned.
        print('no results found for this photon level; recording NaN')
        ssim_mean11.append(np.nan)
        ssim_std11.append(np.nan)
        continue

    idx = int(np.argmax(ssim_temp))
    print('the best alpha is:', evaluated_alphas[idx])
    ssim_mean11.append(ssim_temp[idx])
    ssim_std11.append(std_temp[idx])

    # Show the first 16 reconstructions for the best alpha.
    plt.rcParams['figure.figsize'] = [20, 20]
    fig = plt.figure()
    fig.subplots_adjust(hspace=.05, wspace=.05)

    for i in range(16):
        ax = fig.add_subplot(4, 4, i + 1)
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)
        plt.imshow(pic[idx][i], cmap='viridis')

## The cell below computes the MS-SSIM values at different photon levels for the generative, pretrained network

In [None]:
# load the MS-SSIM loss function (rebound here so this cell does not rely on
# the previous one having been run)
ssim = tf.image.ssim_multiscale

# Generative network, pretrained: for each photon level, score every alpha
# checkpoint with MS-SSIM against the ground truth and keep the best alpha.
ssim_mean22 = []
ssim_std22 = []

R = 0.5
alphas = [1, 1/2, 1/4, 1/8, 1/16, 1/32]
remove = 47  # border (pixels) cropped from each side before MS-SSIM

for photon_level in [1, 10, 100.0, 1e3]:
    pic = []               # reconstructions, one entry per evaluated alpha
    ssim_temp = []         # mean MS-SSIM, one entry per evaluated alpha
    std_temp = []          # std of MS-SSIM, one entry per evaluated alpha
    evaluated_alphas = []  # keeps the three lists above aligned with alpha

    # loop over all the alpha values
    for alpha in alphas:
        # Load the generative *pretrained* network results. The run may have
        # been saved under either date, and exactly one match is required.
        suffix = '-exp-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)
        matches = [date + suffix for date in ['2021-10-05', '2021-10-06']
                   if os.path.isfile(date + suffix)]
        if not matches:
            print("missing files for R-%0.2f-alpha-%0.2f-photon-%0.2f" % (R, alpha, photon_level))
            continue
        if len(matches) > 1:
            print("multiple files for R-%0.2f-alpha-%0.2f-photon-%0.2f" % (R, alpha, photon_level))
            continue
        rec_test_output = loadmat(matches[0])['rec_test_output']
        pic.append(rec_test_output)

        loss_ssim_list = []
        for gen, true in zip(rec_test_output, test_images):
            true = tf.expand_dims(true, -1)

            # Zero the reconstruction outside the central disc (in place, so
            # the images stored in `pic` are masked too).
            gen[Rs > 108] = 0
            gen = tf.expand_dims(gen, -1)

            true = true[remove:-remove, remove:-remove, :]
            gen = gen[remove:-remove, remove:-remove, :]

            # the dynamic range of the cropped ground truth is MS-SSIM's max_val
            ssim_data = ssim(true, gen, tf.math.reduce_max(true) - tf.math.reduce_min(true))
            loss_ssim_list.append(ssim_data)

        evaluated_alphas.append(alpha)
        ssim_temp.append(np.mean(loss_ssim_list))
        std_temp.append(np.std(loss_ssim_list))

    print('photon_level =', photon_level)
    if not ssim_temp:
        # Keep one entry per photon level so downstream plots stay aligned.
        print('no results found for this photon level; recording NaN')
        ssim_mean22.append(np.nan)
        ssim_std22.append(np.nan)
        continue

    idx = int(np.argmax(ssim_temp))
    print('the best alpha is:', evaluated_alphas[idx])
    ssim_mean22.append(ssim_temp[idx])
    ssim_std22.append(std_temp[idx])

    # Show the first 16 reconstructions for the best alpha.
    plt.rcParams['figure.figsize'] = [20, 20]
    fig = plt.figure()
    fig.subplots_adjust(hspace=.05, wspace=.05)

    for i in range(16):
        ax = fig.add_subplot(4, 4, i + 1)
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)
        plt.imshow(pic[idx][i], cmap='viridis')

## The cell below computes the MS-SSIM values at different photon levels for the non-generative, non-pretrained network

In [None]:
# Non-generative network, not pretrained. There is a single checkpoint per
# photon level, so its MS-SSIM mean/std are recorded directly.
ssim_mean33 = []
ssim_std33 = []

R = 0.5
remove = 47  # pixels cropped from every side before scoring

for photon_level in [1, 10, 100, 1000]:
    mat_path = '2021-10-05exp-not-pretrained-R-0.5-peak-' + str(photon_level) + '.mat'
    rec_test_output = loadmat(mat_path)['rec_test_output']

    # Preview the first 16 reconstructions on a 4x4 grid.
    plt.rcParams['figure.figsize'] = [20, 20]
    fig = plt.figure()
    fig.subplots_adjust(hspace=.05, wspace=.05)
    for panel in range(16):
        ax = fig.add_subplot(4, 4, panel + 1)
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)
        plt.imshow(rec_test_output[panel], cmap='viridis')

    # MS-SSIM of every reconstruction against its ground truth, after zeroing
    # the reconstruction outside the central disc and cropping the border.
    scores = []
    for gen, true in zip(rec_test_output, test_images):
        gen[Rs > 108] = 0
        true_crop = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]
        gen_crop = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]
        data_range = tf.math.reduce_max(true_crop) - tf.math.reduce_min(true_crop)
        scores.append(ssim(true_crop, gen_crop, data_range))

    mean_score, std_score = np.mean(scores), np.std(scores)
    print('photon_level =', photon_level)
    ssim_mean33.append(mean_score)
    ssim_std33.append(std_score)

## The cell below computes the MS-SSIM values at different photon levels for the non-generative, pretrained network

In [None]:
# Non-generative network, pretrained: one checkpoint per photon level, scored
# with MS-SSIM against the ground truth.
ssim_mean44 = []
ssim_std44 = []

R = 0.5
remove = 47  # pixels cropped from every side before scoring

for photon_level in [1, 10, 100, 1000]:
    rec_test_output = loadmat(
        '2021-10-05exp-pretrained-R-0.5-peak-' + str(photon_level) + '.mat'
    )['rec_test_output']

    # 4x4 preview of the first reconstructions
    plt.rcParams['figure.figsize'] = [20, 20]
    fig = plt.figure()
    fig.subplots_adjust(hspace=.05, wspace=.05)
    for k in range(16):
        ax = fig.add_subplot(4, 4, k + 1)
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)
        plt.imshow(rec_test_output[k], cmap='viridis')

    per_image = []
    for recon, target in zip(rec_test_output, test_images):
        # mask the reconstruction (in place) outside the central disc
        recon[Rs > 108] = 0
        target_t = tf.expand_dims(target, -1)[remove:-remove, remove:-remove, :]
        recon_t = tf.expand_dims(recon, -1)[remove:-remove, remove:-remove, :]
        per_image.append(
            ssim(target_t, recon_t,
                 tf.math.reduce_max(target_t) - tf.math.reduce_min(target_t))
        )

    avg = np.mean(per_image)
    spread = np.std(per_image)
    print('photon_level =', photon_level)
    ssim_mean44.append(avg)
    ssim_std44.append(spread)

## The cell below computes the MS-SSIM values at different photon levels for the End-to-End (E2E) network

In [None]:
# End-to-End network: one output file per photon level, scored with MS-SSIM.
ssim_mean55 = []
ssim_std55 = []

remove = 47  # pixels cropped from every side before scoring

for photon_level in [1, 10, 100, 1e3]:
    npy_path = 'exp-End-to-End-test-output-R-0.5-photon-' + str(photon_level) + '.npy'
    rec_test_output = np.load(npy_path)
    # quick sanity check of the value range of the first output
    first = rec_test_output[0]
    print(np.max(first), np.min(first))

    # 4x4 preview of the first reconstructions
    plt.rcParams['figure.figsize'] = [20, 20]
    fig = plt.figure()
    fig.subplots_adjust(hspace=.05, wspace=.05)
    for panel in range(16):
        ax = fig.add_subplot(4, 4, panel + 1)
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)
        plt.imshow(rec_test_output[panel], cmap='viridis')

    scores = []
    for gen, true in zip(rec_test_output, test_images):
        gen[Rs > 108] = 0  # zero (in place) outside the central disc
        true_crop = tf.expand_dims(true, -1)[remove:-remove, remove:-remove, :]
        gen_crop = tf.expand_dims(gen, -1)[remove:-remove, remove:-remove, :]
        data_range = tf.math.reduce_max(true_crop) - tf.math.reduce_min(true_crop)
        scores.append(ssim(true_crop, gen_crop, data_range))

    mean_score, std_score = np.mean(scores), np.std(scores)
    print('photon_level =', photon_level)
    ssim_mean55.append(mean_score)
    ssim_std55.append(std_score)

In [None]:
# Mean End-to-End MS-SSIM, one entry per photon level (1, 10, 100, 1e3).
print(repr(ssim_mean55))

## The cell below calculates the pixel offset between the experimental RPI reconstructions and the ground truth

In [None]:
from tqdm import tqdm

# Estimate the (x, y) pixel offset between the experimental iterative
# reconstruction and the ground truth. For every test image, try all integer
# shifts in [-max_shift, max_shift) along each axis, match the means, and
# record the L2 distance between the reconstruction and the shifted ground-
# truth crop. Column (i + max_shift) * n_shifts + (j + max_shift) holds shift (i, j).
remove = 47     # border (pixels) cropped from each side, as in the SSIM cells
max_shift = 10  # must stay below `remove` so the shifted slices remain valid
n_shifts = 2 * max_shift

iter_test = np.load('test-approx-%d-iter-%d-lr-%0.2f.npy'
                    % (1e3, 100, 0.5)).astype(np.float32)[:, remove:-remove, remove:-remove]/np.pi

# Size by the data actually compared. A fixed row count either overflows
# (IndexError) or leaves all-zero rows whose argmin biases the vote.
n_pairs = min(len(iter_test), len(test_images))
loss = np.zeros((n_pairs, n_shifts * n_shifts))

for idx, (img1, img2) in tqdm(enumerate(zip(iter_test, test_images)), total=n_pairs):
    im1_mean = np.mean(img1)  # independent of the shift; computed once per image
    for i in range(-max_shift, max_shift):
        for j in range(-max_shift, max_shift):
            imgcrop = img2[remove+i:-remove+i, remove+j:-remove+j]
            img1_final = img1 + (np.mean(imgcrop) - im1_mean)
            loss[idx, (i+max_shift)*n_shifts + j+max_shift] = np.sqrt(np.sum(np.square(img1_final - imgcrop)))

In [None]:
# Each image votes for the shift that minimised its loss, and the most common
# vote becomes the global pixel offset. Column k encodes the shift (i, j) as
# k = (i + 10) * 20 + (j + 10), with i, j in [-10, 10).
a = np.argmin(loss, axis=1)
counts = np.bincount(a)
row, col = divmod(int(np.argmax(counts)), 20)
shiftx, shifty = row - 10, col - 10
print("the offset x and y are:", shiftx, shifty)

## The cell below computes the MS-SSIM values at different photon levels for the iterative algorithm

In [None]:
from tqdm import tqdm

# load the MS-SSIM loss function
ssim = tf.image.ssim_multiscale

# Iterative reconstruction: MS-SSIM mean/std per photon level.
iterative_std1 = []
iterative_ssim1 = []

lr = 0.5
R = 0.5
remove = 47  # border cropped from each side; leaves (128 - remove) * 2 = 162 px

# Ground-truth phase range. It is used both to clip the reconstructions and as
# the display colour limits, and it does not change inside the loops.
gt_min, gt_max = np.min(test_images), np.max(test_images)

for photon_level in [1, 10, 100, 1e3]:
    for iters in [100]:
        iter_input = np.load('test-approx-%d-iter-%d-lr-%0.2f.npy'
                             % (photon_level, iters, lr)).astype(np.float32)

        iterloss_ssim_list = []
        gen_list = []
        true_list = []
        for gen, true in tqdm(zip(iter_input, test_images)):
            # Ground truth shifted by the offset (shiftx, shifty) estimated in
            # the pixel-offset cells above.
            true = true[remove+shiftx:-remove+shiftx, remove+shifty:-remove+shifty]
            # Reconstruction cropped and scaled by 1/pi, as in the offset search.
            gen = gen[remove:-remove, remove:-remove] / np.pi

            # Remove the global phase offset, then clip to the ground-truth range.
            gen = find_exp_global_phase(gen, true, size=(128-remove)*2)
            gen = np.clip(gen, gt_min, gt_max)
            gen_list.append(gen)
            true_list.append(true)

            # the dynamic range of the cropped ground truth is MS-SSIM's max_val
            ssim_data = ssim(tf.expand_dims(true, -1), tf.expand_dims(gen, -1),
                             tf.math.reduce_max(true) - tf.math.reduce_min(true))
            iterloss_ssim_list.append(ssim_data)

        # Show the first 16 corrected reconstructions on the ground-truth scale.
        plt.rcParams['figure.figsize'] = [20, 20]
        fig = plt.figure()
        fig.subplots_adjust(hspace=.05, wspace=.05)

        for panel in range(16):
            ax = fig.add_subplot(4, 4, panel + 1)
            ax.axes.xaxis.set_visible(False)
            ax.axes.yaxis.set_visible(False)
            plt.imshow(gen_list[panel], cmap='viridis')
            plt.clim(gt_min, gt_max)

        iterative_ssim1.append(np.mean(iterloss_ssim_list))
        iterative_std1.append(np.std(iterloss_ssim_list))

## The cell below generates the plots used in the paper

In [None]:
# Grouped bar chart of MS-SSIM (mean +/- std over the test images) for each
# method and photon level. `cycler`, `mpl` and `plt` come from the import cell.

# The style 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in
# Matplotlib 3.6, and the old name was removed in 3.8. Use whichever exists.
_whitegrid = 'seaborn-v0_8-whitegrid'
plt.style.use(_whitegrid if _whitegrid in plt.style.available else 'seaborn-whitegrid')

mpl.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'chocolate', 'olive'])

plt.rcParams['figure.figsize'] = [20, 16]

# The results lists are ordered 1 -> 1e3 photons. Labels and values are both
# flipped so that the highest photon count is on the left.
labels = np.flip(['10$^0$ photons', '10$^1$ photons', '10$^2$ photons', '10$^3$ photons'])

x = np.arange(len(labels))  # the label locations
width = 0.13  # the width of the bars
bar_style = dict(ecolor='black', capsize=5)

fig, ax = plt.subplots()
# NOTE(review): ssim_mean33/ssim_std33 (non-generative, not pretrained) and
# ssim_mean22/ssim_std22 (generative, pretrained) are computed but not plotted.
rects_non_gen = ax.bar(x - width * 3/2, np.flip(ssim_mean44), width, yerr=np.flip(ssim_std44), label='Non-Generative', **bar_style)
rects_gen = ax.bar(x - width * 1/2, np.flip(ssim_mean11), width, yerr=np.flip(ssim_std11), label='Generative', **bar_style)
rects_iter = ax.bar(x + width * 1/2, np.flip(iterative_ssim1), width, yerr=np.flip(iterative_std1), label='Iterative', **bar_style)
rects_e2e = ax.bar(x + width * 3/2, np.flip(ssim_mean55), width, yerr=np.flip(ssim_std55), label='End-to-End', **bar_style)

# Add some text for labels and custom x-axis tick labels.
ax.set_ylabel('MS-SSIM', fontsize=35)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.tick_params(axis='both', labelsize=35)  # sets the size of all tick labels
ax.legend(fontsize=39)
ax.patch.set_edgecolor('black')
ax.patch.set_linewidth(3)


def autolabel(rects):
    """Attach a text label above each bar in *rects*, displaying its height."""
    for rect in rects:
        height = rect.get_height()
        # annotate on the bar's own axes rather than a global `ax`
        rect.axes.annotate('{}'.format(height),
                           xy=(rect.get_x() + rect.get_width() / 2, height),
                           xytext=(0, 3),  # 3 points vertical offset
                           textcoords="offset points",
                           ha='center', va='bottom')


fig.tight_layout()
plt.ylim([0, 1])
plt.show()

In [None]:
def norm_to_one(tensor):
    """Linearly rescale ``tensor`` so that its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    tensor : array_like
        Values to rescale (any shape).

    Returns
    -------
    numpy.ndarray
        ``(tensor - min) / (max - min)``. If all elements are equal (zero
        range), an all-zeros array of the same shape is returned instead of
        the NaNs a 0/0 division would produce.
    """
    arr = np.asarray(tensor)
    arr_min = np.min(arr)
    shifted = arr - arr_min
    span = np.max(arr) - arr_min
    if span == 0:
        # np.true_divide(shifted, 1) gives the dtype the normal path would
        # return (float64 for ints, unchanged for floats).
        return np.zeros_like(np.true_divide(shifted, 1))
    return shifted / span

In [None]:
import matplotlib.patches as patches

# Comparison figure: one row per photon level (1e3, 100, 10, 1). The columns are
#   0: iterative, 1 iteration      1: iterative, 100 iterations
#   2: End-to-End                  3: non-generative (pretrained)
#   4: generative (not pretrained) 5: ground truth
# Every panel shows test image `idx`, zeroed outside radius `circuit`, cropped
# by 20 px, rotated with np.rot90(..., 3), on the ground-truth colour scale.
plt.rcParams['figure.figsize'] = [20+4, 20]
fig = plt.figure()
fig.subplots_adjust(hspace=.05, wspace=.05)

R = 0.5

background = 0
circuit = 108
remove = 47  # border (pixels) excluded when matching means, as in the SSIM cells
idx = 3      # which test image is shown

photon_levels = [1e3, 100, 10, 1]
# alpha of the generative checkpoint shown in each row (same order as photon_levels)
row_alphas = np.flip([0.03125, 0.5, 0.25, 1])

# The ground truth is the same for every row, so it is loaded once.
prefix = data_folder_prefix + '/Simulated/R_Sweep/'
test_images = np.load(prefix + 'test_images-R-0.50.npy')
# get the phase of the complex objects
test_images = np.angle(test_images)
vmin, vmax = np.min(test_images), np.max(test_images)
gt_crop_mean = np.mean(test_images[:, remove:-remove, remove:-remove], axis=(1, 2))

for row, (photon_level, alpha) in enumerate(zip(photon_levels, row_alphas)):
    # Iterative reconstructions, scaled by 1/pi and shifted so that each
    # image's mean over the cropped region matches the ground truth's.
    prox_1 = np.load('test-approx-%d-iter-1-lr-1.00.npy'
                     % photon_level).astype(np.float32)/np.pi
    b1 = gt_crop_mean - np.mean(prox_1[:, remove:-remove, remove:-remove], axis=(1, 2))

    prox_100 = np.load('test-approx-%d-iter-100-lr-0.50.npy'
                       % photon_level).astype(np.float32)/np.pi
    b100 = gt_crop_mean - np.mean(prox_100[:, remove:-remove, remove:-remove], axis=(1, 2))

    e2e = np.load('exp-End-to-End-test-output-R-0.5-photon-' + str(photon_level) + '.npy')
    not_gen = loadmat('2021-10-05exp-pretrained-R-0.5-peak-%d.mat' % photon_level)['rec_test_output']

    # The generative runs may be saved under either date (as in the scoring
    # cells). Use the first file that exists, defaulting to 2021-10-05.
    suffix = '-exp-not-pretrained-alpha-%0.2f-R-%0.2f-photon-%0.2f.mat' % (alpha, R, photon_level)
    gen_path = next((date + suffix for date in ['2021-10-05', '2021-10-06']
                     if os.path.isfile(date + suffix)), '2021-10-05' + suffix)
    gen_not_pre = loadmat(gen_path)['rec_test_output']

    panels = [
        prox_1[idx] + b1[idx],
        prox_100[idx] + b100[idx],
        e2e[idx],
        not_gen[idx],
        gen_not_pre[idx],
        test_images[idx].copy(),  # copy so the global ground truth is not masked
    ]

    # NOTE(review): the grid has 5 rows but only 4 are filled.
    for col, panel in enumerate(panels):
        panel[Rs > circuit] = background

        ax = fig.add_subplot(5, 6, row * 6 + col + 1)
        # NOTE(review): a small circle at pixel (5, 5) is drawn on every panel.
        # It looks like a leftover from an unfinished clip-path experiment;
        # confirm it belongs in the paper figure.
        circle = plt.Circle((5, 5), 0.5, color='b', fill=False)
        ax.add_patch(circle)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        plt.imshow(np.rot90(panel[20:-20, 20:-20], 3), cmap='viridis')
        plt.clim(vmin, vmax)