In [1]:
import os 
os.chdir("..")
from src.cpwc.multires.class_multiressolver import *
import matplotlib.pyplot as plt
import torch
from src.cpwc.tools.ptychography import Ptychography as Ptychography
from src.cpwc.tools.utils import *
# Ask PyTorch to release cached GPU memory that is currently unused.
torch.cuda.empty_cache()

# Seed the global PyTorch and NumPy random generators with 0.
# NOTE(review): `np` is never imported explicitly in this cell; presumably it
# is re-exported by one of the star imports above -- confirm.
torch.manual_seed(0)
np.random.seed(0)
import gc
# Debug aid: print the type and size of every CUDA tensor the garbage
# collector currently tracks, to spot tensors left over from earlier runs.
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            print(type(obj), obj.size())
    except Exception:
        # Best effort: any error raised while inspecting an object is ignored
        # so that the scan continues with the next object.
        pass

# Solver settings; `cycle`, `tol` and `tol_in` are passed to MultiResSolver in
# a later cell.
cycle = [0, -1, -1, -1,  1, 1, 1]
# NOTE(review): `lmbda` is not referenced anywhere else in the visible notebook.
lmbda = 0
# NOTE(review): both tolerance lists have 9 entries, the same number as
# `max_scale` -- presumably one tolerance per scale; confirm.
tol = [1e-10] * 9
tol_in = [1e-10] * 9
# Hard-coded device: every tensor below is created on the GPU, so this cell
# needs a CUDA-capable machine.
device = 'cuda'
max_scale = 9
max_probe_size = 128
# Load the test image (path is relative to the directory selected by the
# os.chdir("..") call above) and min-max normalise it to [0, 1].
# A constant image would make the denominator zero.
image = np.load("test_data/potential.npy")
image = (image - image.min())/(image.max() - image.min()) 
# Move the image to `device` as float64 and view it as a
# (1, 1, 2**max_scale, 2**max_scale) tensor; view() fails if the file does not
# hold exactly that many elements.
image_tensor = torch.tensor(image).double().to(device).view(1, 1, 2**max_scale, 2**max_scale)
# Complex tensor of unit modulus whose phase is the normalised image.
image_tensor_ = torch.exp(1j * image_tensor)
multires = MultiRes(max_scale, device)


def extract_data(nested_list):
    """
    Flatten arbitrarily nested lists into one flat list.

    Elements are visited depth-first, left to right. Only ``list`` instances
    are descended into; every other element (tuples, strings, arrays, ...)
    is kept as a single item.

    Parameters:
    nested_list (iterable): The possibly nested list to flatten.

    Returns:
    list: The non-list elements in visiting order.
    """
    def _leaves(node):
        # Yield the leaves below `node`, descending into sub-lists only.
        for child in node:
            if isinstance(child, list):
                yield from _leaves(child)
            else:
                yield child

    return list(_leaves(nested_list))

def unwrap_2d(phase):
    """
    Unwrap a 2D phase map with two successive passes of ``np.unwrap``.

    The array is first unwrapped along axis 0; the result of that pass is
    then unwrapped along axis 1.

    Parameters:
    phase (numpy array): The 2D phase array to be unwrapped.

    Returns:
    numpy array: The phase array after both unwrapping passes.
    """
    result = phase
    # Order matters: the axis-1 pass operates on the output of the axis-0 pass.
    for axis in (0, 1):
        result = np.unwrap(result, axis=axis)
    return result

def save_data(model,model_name,image):
    """
    Save a solver's loss history, cosine-similarity history and recovered
    phase as ``.npy`` files under ``np_data/``.

    The phase is taken from ``model.c_k[0, 0]``, unwrapped with ``unwrap_2d``
    and shifted by a constant so that its mean equals the mean of ``image``.

    Parameters:
    model: Solver object exposing ``measures["loss"]``, ``measures["csim"]``
        (possibly nested lists) and a complex tensor ``c_k``.
    model_name (str): Prefix used in the three output file names.
    image (numpy array): Reference image; only its mean is used.

    Returns:
    None
    """
    # np.save does not create missing directories, so make sure the target
    # exists before writing (relative to the current working directory).
    os.makedirs("np_data", exist_ok=True)

    loss_data = extract_data(model.measures["loss"])
    cos_sim = extract_data(model.measures["csim"])

    phase = torch.angle(model.c_k[0,0,:,:].to('cpu')).numpy()
    phase = unwrap_2d(phase)
    # Shift by a constant so the recovered phase has the same mean as the
    # reference image.
    phase += (np.mean(image)-np.mean(phase))

    np.save("np_data/{}_overlap_loss.npy".format(model_name), loss_data)
    np.save("np_data/{}_overlap_csim.npy".format(model_name), cos_sim)
    np.save("np_data/{}_overlap_image.npy".format(model_name), phase)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import random
# Candidate seeds 10, 20, ..., 90.
# NOTE(review): not referenced by any other visible cell; the run cell below
# sets `seed` by hand.
seed_list = [i*10 for i in range(1,10)]
def set_seed(seed):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch;
    when CUDA is available, all CUDA devices are seeded as well.

    Parameters:
    seed (int): The seed value shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Switch the working directory to ./dat; all relative paths in the cells
# below are resolved from there.
# NOTE(review): os.chdir returns None, so `a` is always None. The path is
# relative, so running this cell a second time looks for ./dat inside ./dat.
a = os.chdir("./dat")

In [None]:
import torch
import numpy as np
import os

# Run the multiresolution solver once for a single seed and store the loss
# curve, the cosine-similarity curve and the recovered phase.
checkpoint_dir = "gd/checkpoints25"
os.makedirs(checkpoint_dir, exist_ok=True)
seed = 0
set_seed(seed)

# Initialize probe and operator
# NOTE(review): `probe` is never used below, and its positional arguments do
# not match the keyword signature used for `linOperator` -- confirm whether
# this line is still needed.
probe = Ptychography(image_tensor_, max_probe_size, max_probe_size, 1, device)
LR = 1e-3
# 96 / 128 = 0.75 probe widths per shift; presumably the "25" in the
# checkpoint directory name refers to the resulting 0.25 overlap -- confirm.
max_shift = 96
# Both arrays have 7 entries, like `cycle`; presumably one entry per stage of
# the cycle -- confirm against MultiResSolver.
I_in = 15 * np.array([1, 15, 10, 5, 10, 30, 100])
I_out = 10 * np.array([0, 0, 0, 0, 0, 0, 350])
linOperator = Ptychography(max_scale=max_scale, max_probe_size=max_probe_size, max_shift=max_shift, device=device)
# Keep the simulated measurements under their own name. Rebinding
# `image_tensor_` here would make a second run of this cell apply the operator
# to the measurements instead of the ground-truth object.
measurements = linOperator.apply(image_tensor_)

# Define loss function
loss = Loss(linOperator, measurements)

# Initialize solver
# NOTE(review): the last recorded run of this cell stopped with
# "AttributeError: 'MultiResSolver' object has no attribute 'gt'". The solver
# class is not visible here, so the `gt=` argument is left unchanged.
model = MultiResSolver(multires, loss, LR=LR, I_in=I_in, I_out=I_out, tol=tol, tol_in=tol_in, cycle=cycle, l1_type="l1_row",gt=image)

# Solve
model.solve_multigrid()
model.print_time()

# Flatten the recorded histories and recover the unwrapped phase, shifted by
# a constant so that its mean equals the mean of the reference image.
mean_img = np.mean(image)
loss_data = extract_data(model.measures["loss"])
cos_sim = extract_data(model.measures["csim"])
phase = torch.angle(model.c_k[0,0,:,:].to('cpu'))
phase = phase.numpy()
phase = unwrap_2d(phase)
phase += (mean_img-np.mean(phase))

# Save the results of this seed
np.save(checkpoint_dir+"/overlap_loss_{}.npy".format(seed), loss_data)
np.save(checkpoint_dir+"/overlap_csim_{}.npy".format(seed), cos_sim)
np.save(checkpoint_dir+"/overlap_image_{}.npy".format(seed), phase)


----------- s = 9 -----------


AttributeError: 'MultiResSolver' object has no attribute 'gt'

In [None]:
# Aggregate the loss curves of the seven 0.25-overlap runs into a mean curve
# and a standard-deviation curve (nothing is plotted in this cell).

# All runs are read from the directory the solver cell writes to. Previously
# the first two runs came from ./gd/checkpoints25 and the other five from
# ./checkpoints25.
# NOTE(review): confirm that ./gd/checkpoints25 is the intended location.
run_dir = "./gd/checkpoints25"
# The statistics stay where the plotting cell expects to find them.
stats_dir = "./checkpoints25"
run_seeds = [0, 10, 20, 30, 40, 50, 60]

loss1, loss2, loss3, loss4, loss5, loss6, loss7 = [
    np.load("{}/overlap_loss_{}.npy".format(run_dir, run_seed))
    for run_seed in run_seeds
]

# Per-iteration mean and standard deviation across the runs. np.stack raises
# a clear error if the runs do not all have the same length.
loss_runs = np.stack([loss1, loss2, loss3, loss4, loss5, loss6, loss7])
mean = loss_runs.mean(axis=0)
std = loss_runs.std(axis=0)

# Save mean and std; np.save does not create missing directories.
os.makedirs(stats_dir, exist_ok=True)
np.save("./checkpoints25/mean_loss.npy", mean)
np.save("./checkpoints25/std_loss.npy", std)




In [None]:
# Overlay the mean loss curve, with a band of +/- one standard deviation, for
# the three overlap settings on a logarithmic y-axis.

mean75 = np.load("./checkpoints75/mean_loss.npy")
std75 = np.load("./checkpoints75/std_loss.npy")

mean50 = np.load("./checkpoints50/mean_loss.npy")
std50 = np.load("./checkpoints50/std_loss.npy")

mean25 = np.load("./checkpoints25/mean_loss.npy")
std25 = np.load("./checkpoints25/std_loss.npy")

plt.figure(figsize = (10, 5),dpi = 300)

# One (mean, std, legend label) triple per curve, drawn in this order.
curves = (
    (mean75, std75, "MRGD w/ 0.75 Overlap"),
    (mean50, std50, "MRGD w/ 0.5 Overlap"),
    (mean25, std25, "MRGD w/ 0.25 Overlap"),
)
for center, spread, legend_text in curves:
    plt.semilogy(center, label=legend_text)
    steps = np.arange(len(center))
    plt.fill_between(steps, center-spread, center+spread, alpha=0.3)

plt.legend()
plt.grid()
