In [None]:
import os 
os.chdir("..")
from src.cpwc.multires.class_multiressolver import *
import matplotlib.pyplot as plt
import torch
from src.cpwc.tools.ptychography import Ptychography as Ptychography
from src.cpwc.tools.utils import *
torch.cuda.empty_cache()

# Set seeds 
torch.manual_seed(0)
np.random.seed(0)
import gc
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            print(type(obj), obj.size())
    except Exception:
        pass

cycle = [0, -1, -1, -1,  1, 1, 1]
lmbda = 0
tol = [1e-10] * 9
tol_in = [1e-10] * 9
device = 'cuda'
max_scale = 9
max_probe_size = 128
image = np.load("test_data/potential.npy")
image = (image - image.min())/(image.max() - image.min()) 
image_tensor = torch.tensor(image).double().to(device).view(1, 1, 2**max_scale, 2**max_scale)
image_tensor_ = torch.exp(1j * image_tensor)
multires = MultiRes(max_scale, device)


def extract_data(nested_list):
    result = []
    for item in nested_list:
        if isinstance(item, list):  # If the item is a list, recurse into it
            result.extend(extract_data(item))
        else:  # If the item is not a list, add it to the result
            result.append(item)
    return result

def unwrap_2d(phase):
    """
    Unwraps a 2D phase array using NumPy's 1D unwrap function.
    
    Parameters:
    phase (numpy array): The 2D phase array to be unwrapped.
    
    Returns:
    unwrapped_phase (numpy array): The 2D unwrapped phase array.
    """
    # Unwrap along the first axis (rows)
    unwrapped_phase = np.unwrap(phase, axis=0)
    
    # Unwrap along the second axis (columns)
    unwrapped_phase = np.unwrap(unwrapped_phase, axis=1)
    
    return unwrapped_phase

def save_data(model,model_name,image):
    image = image[::,::]
    mean_img = np.mean(image)
    loss_data = extract_data(model.measures["loss"])
    cos_sim = extract_data(model.measures["csim"])
    phase = torch.angle(model.c_k[0,0,:,:].to('cpu'))
    phase = phase.numpy()
    phase = unwrap_2d(phase)
    phase += (mean_img-np.mean(phase)) 

    np.save("np_data/{}_overlap_loss.npy".format(model_name), loss_data)
    np.save("np_data/{}_overlap_csim.npy".format(model_name), cos_sim)
    np.save("np_data/{}_overlap_image.npy".format(model_name), phase)

In [1]:
import random
seed_list = [i*10 for i in range(1,10)]
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
a = os.chdir("./dat")

NameError: name 'os' is not defined

In [None]:
import torch
import numpy as np
import os

seed = 60
set_seed(seed)

# Initialize probe and operator
probe = Ptychography(image_tensor_, max_probe_size, max_probe_size, 1, device)
LR = 1e-2
max_shift = 32
I_in = 15 * np.array([1, 15, 10, 5, 10, 30, 100])
I_out = 5 * np.array([0, 0, 0, 30, 30, 40, 300])
linOperator = Ptychography(max_scale=max_scale, max_probe_size=max_probe_size, max_shift=max_shift, device=device)
image_tensor_ = linOperator.apply(image_tensor_)

# Define loss function
loss = Loss(linOperator, image_tensor_,lmbda = 1e-7)

# Initialize solver
model = MultiResSolver(multires, loss, LR=LR, I_in=I_in, I_out=I_out, tol=tol, tol_in=tol_in, cycle=cycle, l1_type="l1_row",gt=image)

# Solve
model.solve_multigrid()
model.print_time()

image = image[::,::]
mean_img = np.mean(image)
loss_data2 = extract_data(model.measures["loss"])
cos_sim2 = extract_data(model.measures["csim"])
phase = torch.angle(model.c_k[0,0,:,:].to('cpu'))
phase = phase.numpy()
phase = unwrap_2d(phase)
phase += (mean_img-np.mean(phase)) 
# Extract and store new loss values
# Save checkpoint and loss values
#np.save(checkpoint_dir+"/overlap_loss_{}.npy".format(seed), loss_data)
#np.save(checkpoint_dir+"/overlap_csim_{}.npy".format(seed), cos_sim)
#np.save(checkpoint_dir+"/overlap_image_{}.npy".format(seed), phase)


In [2]:
image = image[::,::]
mean_img = np.mean(image)
loss_data = extract_data(model.measures["loss"])
cos_sim = extract_data(model.measures["csim"])
phase = torch.angle(model.c_k[0,0,:,:].to('cpu'))
phase = phase.numpy()
phase = unwrap_2d(phase)
phase += (mean_img-np.mean(phase)) 
plt.imshow(phase, cmap='gray')
plt.axis('off')


NameError: name 'image' is not defined

In [None]:
np.save("./comp/overlap_loss.npy", loss_data)
np.save("./comp/overlap_csim.npy", cos_sim)
np.save("./comp/overlap_image.npy", phase)

In [None]:
image = "./comp/overlap_image.npy"
image_reg = "./comp/reg_overlap_image.npy"

plt.figure()
plt.imshow(np.load(image), cmap='gray')
plt.axis('off')
plt.figure()
plt.imshow(np.load(image_reg), cmap='gray')
plt.axis('off')
plt.show()


plt.imsave("./comp/overlap_image.tiff", np.load(image), cmap='gray')
plt.imsave("./comp/reg_overlap_image.tiff", np.load(image_reg), cmap='gray')

In [None]:
#Plot the loss values using the shorter array where len(loss) != len(loss_reg)
loss = np.load("./comp/overlap_loss.npy")
loss_reg = np.load("./comp/reg_overlap_loss.npy")
plt.figure(dpi = 600,figsize = (10,5))
plt.semilogy(loss, label='MRGD')
plt.semilogy(loss_reg, label='MRGD w/Regularization')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.show()
