In [None]:
import os

# The notebook is presumably stored one level below the repository root
# (the `src.*` imports and relative data paths below rely on that) — TODO confirm.
# The guard makes this cell idempotent: without it, every re-run of the cell
# climbs one directory further and breaks the imports that follow.
if not os.path.isdir("src"):
    os.chdir("..")

# NOTE(review): the star imports below are kept because later cells use names
# whose origin is not visible here (e.g. MultiRes, MultiResSolver, Loss);
# replace them with explicit imports once the providing modules are confirmed.
from src.cpwc.multires.class_multiressolver import *
import matplotlib.pyplot as plt
import torch
# Imported explicitly so `np` does not depend on what a star import re-exports.
import numpy as np
from src.cpwc.tools.ptychography import Ptychography as Ptychography
from src.cpwc.tools.utils import *
torch.cuda.empty_cache()

# Set seeds so the stochastic steps of the notebook are reproducible
torch.manual_seed(0)
np.random.seed(0)

In [None]:
import gc

# Debug aid: walk every object tracked by the garbage collector and report the
# type and size of each tensor that currently lives on a CUDA device.
for tracked in gc.get_objects():
    try:
        lives_on_gpu = torch.is_tensor(tracked) and tracked.is_cuda
        if lives_on_gpu:
            print(type(tracked), tracked.size())
    except Exception:
        # Inspecting arbitrary tracked objects can raise; this is a best-effort
        # report, so such objects are simply skipped.
        continue

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
max_scale = 9            # the image is reshaped to 2**max_scale x 2**max_scale below
max_probe_size = 128
max_shift = 32
# Use the GPU when one is visible and fall back to the CPU otherwise, instead of
# failing on the first `.to(device)`.
# NOTE(review): assumes the project classes honor `device` on CPU — TODO confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): I_in / I_out / cycle each have six entries and are handed to the
# solver as-is; presumably one entry per step of the multigrid cycle — confirm
# against MultiResSolver.
I_in = 15*np.array([1, 15, 10, 10, 30, 100])
I_out = 10*np.array([0, 0, 0, 30,40,300])
cycle = [0, -1, -1, 1,  1, 1]
lmbda = 0                # NOTE(review): not used anywhere in this cell
LR = 1e-2
# NOTE(review): the length 9 matches max_scale; presumably one tolerance per
# scale — confirm before changing max_scale.
tol = [1e-10] * 9
tol_in = [1e-10] * 9
noise_level = 1          # scales the Poisson rate; the counts are divided by it again

# ---------------------------------------------------------------------------
# Forward operator and ground truth
# ---------------------------------------------------------------------------
linOperator = Ptychography(max_scale = max_scale,max_probe_size = max_probe_size ,max_shift = max_shift,device=device)

# Ground-truth phase image, rescaled by 1/255.
image = plt.imread('test_data/peppers.jpg')/ 255
# `.view` only checks the element count, so a wrongly shaped image with the
# right number of pixels would be silently scrambled. Fail loudly instead.
assert image.shape == (2**max_scale, 2**max_scale), (
    f"expected a single-channel {2**max_scale}x{2**max_scale} image, got shape {image.shape}"
)
image_tensor = torch.tensor(image).double().to(device).view(1, 1, 2**max_scale, 2**max_scale)

# Phase-only complex object exp(i * image).
phase_object = torch.exp(1j * image_tensor)

# Initial guess: the object subsampled by 8 in both spatial axes, then each
# sample repeated into a 2x2 patch through a Kronecker product with ones.
m = torch.ones(2, 2).to(device= device)
d0 = phase_object[:,:,::8,::8]
d0 = torch.kron(d0, m)

multires = MultiRes(max_scale, device)

# ---------------------------------------------------------------------------
# Simulated measurements with Poisson noise
# ---------------------------------------------------------------------------
clean_measurements = linOperator.apply(phase_object)
noisy_measurements = np.random.poisson(clean_measurements.cpu().detach() * noise_level) / noise_level
# `image_tensor_` ends up holding the noisy measurements, which is what the
# loss consumes (name kept for any other cell that uses it).
image_tensor_ = torch.tensor(noisy_measurements).to(torch.complex64).to(device)

# ---------------------------------------------------------------------------
# Solve
# ---------------------------------------------------------------------------
loss = Loss(linOperator,image_tensor_)
model1 = MultiResSolver(multires, loss, LR = LR,
                        I_in = I_in,
                        I_out = I_out,
                        tol = tol,
                        tol_in = tol_in,
                        cycle = cycle,
                        l1_type = "l1_row",
                        d0 = d0)

model1.solve_multigrid()
model1.print_time()

In [None]:
def unwrap_2d(phase):
    """
    Unwrap a 2D phase map by applying NumPy's 1D unwrap along each axis in turn.

    Parameters:
    phase (numpy array): The 2D wrapped phase array.

    Returns:
    unwrapped_phase (numpy array): The phase after unwrapping along axis 0
    (down the rows) and then along axis 1 (across the columns).
    """
    unwrapped_phase = phase
    # The order matters: axis 0 first, then axis 1 on the intermediate result.
    for axis in (0, 1):
        unwrapped_phase = np.unwrap(unwrapped_phase, axis=axis)
    return unwrapped_phase

def plot_results(model, image, save_path="poisson_folder/poisson0_0000001k.png"):
    """
    Plot the ground-truth phase, the reconstructed phase and their squared error.

    The reconstructed phase is the angle of ``model.c_k[0, 0, :, :]``, unwrapped
    with ``unwrap_2d`` and shifted by a constant so that its mean equals the mean
    of ``image``. The mean of the pixel-wise squared error is printed.

    Parameters:
    model: Object exposing a tensor attribute ``c_k`` that is indexed as
        ``c_k[0, 0, :, :]`` (presumably the complex reconstruction — confirm
        against the solver).
    image (numpy array): 2D ground-truth phase; must have the same shape as the
        reconstructed phase.
    save_path (str or None): Where the figure is written. Missing parent
        directories are created. Pass None to skip saving.

    Returns:
    None
    """
    image = np.asarray(image)
    mean_img = np.mean(image)

    # `.detach()` is a no-op for tensors that do not track gradients and
    # prevents `.numpy()` from raising for those that do.
    phase = torch.angle(model.c_k[0, 0, :, :].detach().to('cpu')).numpy()
    phase = unwrap_2d(phase)
    # Remove the constant offset between reconstruction and ground truth by
    # matching the means before comparing.
    phase += (mean_img - np.mean(phase))

    error = np.abs(image - phase)**2

    panels = (
        (image, r"(a) Phase of GT $(\angle x)$"),
        (phase, r"(b) Phase of Reconstruction $(\angle \hat{x})$"),
        (error, r"(c) $||\angle x - \angle \hat{x}||^2$"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 5), dpi=200)
    for ax, (data, title) in zip(axes, panels):
        shown = ax.imshow(data, cmap='gray')
        ax.set_title(title)
        fig.colorbar(shown, ax=ax)
    fig.tight_layout()

    if save_path is not None:
        # savefig does not create directories; make sure the target exists.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path)

    mean_error = np.mean(error)
    print(mean_error)
    return None

# Compare the solved model's reconstructed phase with the ground-truth image.
plot_results(model=model1, image=image)