In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Pick the compute device: a specific GPU when CUDA is available, otherwise the CPU.
# set_default_device makes tensors created without an explicit device (e.g. the
# torch.zeros(...) buffers in later cells) land on `device`.
CUDA_DEVICE_INDEX = 5  # machine-specific GPU choice; adjust for the host you run on

if torch.cuda.is_available():
    device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
    # set_device only accepts CUDA devices; calling it with a CPU device raises
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
torch.set_default_device(device)

### Load results

In [None]:
# Load the saved ground truth and the SEP / SMC outputs.
# map_location puts every loaded tensor on `device`. Files written from a GPU then
# still load on a CPU-only machine, and loaded tensors share a device with tensors
# created in later cells (e.g. the quantile levels).
RESULTS_DIR = "results"


def load_result(name):
    """Return the object saved at ``<RESULTS_DIR>/<name>.pt``, mapped onto ``device``."""
    return torch.load(f"{RESULTS_DIR}/{name}.pt", map_location=device)


true_counts = load_result("true_counts")
true_fluxes = load_result("true_fluxes")
true_locs = load_result("true_locs")
true_total_intensities = load_result("true_total_intensities")
images = load_result("images")
num_images = images.shape[0]
max_objects = true_fluxes.shape[1]

sep_estimated_count = load_result("sep_estimated_count")
sep_reconstruction = load_result("sep_reconstruction")

smc_posterior_mean_count = load_result("smc_posterior_mean_count")
smc_reconstruction = load_result("smc_reconstruction")
smc_runtime = load_result("smc_runtime")
smc_num_iters = load_result("smc_num_iters")

# The cells below compare these element-wise and index them per image,
# so each must hold exactly one entry per image.
assert len(true_counts) == num_images, "true_counts does not match the number of images"
assert sep_estimated_count.shape == true_counts.shape, "SEP counts do not line up with true counts"
assert len(smc_posterior_mean_count) == num_images, "SMC counts do not match the number of images"

### SEP results

In [None]:
# Overall SEP count metrics across all images.
# A SEP estimate counts as correct only when it equals the true count exactly.
sep_count_error = sep_estimated_count - true_counts
if not sep_count_error.is_floating_point():
    # .mean() refuses integer tensors, which is possible if counts were saved as ints
    sep_count_error = sep_count_error.float()

sep_prop_correct = (sep_estimated_count == true_counts).sum() / num_images
sep_mse = (sep_count_error ** 2).mean()
sep_mae = sep_count_error.abs().mean()

# .item() prints plain numbers instead of tensor reprs (which include device/dtype),
# matching how the SMC cell reports its metrics
print(f"proportion correct = {sep_prop_correct.item()}")
print(f"MSE = {sep_mse.item()}")
print(f"MAE = {sep_mae.item()}")

In [None]:
# Break the SEP count metrics down by the true number of sources per image.
# Each *_by_num tensor is indexed by that true count. A count with no images is
# reported and its statistics are left as NaN: quantile() rejects an empty input,
# so they cannot be computed.
# NOTE(review): this covers true counts 0..max_objects-1, where
# max_objects = true_fluxes.shape[1]. If an image can hold exactly max_objects
# sources, that group is not reported -- confirm the intended range.
sep_mean_estimated_s_by_num = torch.zeros(max_objects)
sep_bounds_estimated_s_by_num = torch.zeros(max_objects, 2)
sep_num_correct_by_num = torch.zeros(max_objects)
sep_prop_correct_by_num = torch.zeros(max_objects)
sep_mse_by_num = torch.zeros(max_objects)
sep_mae_by_num = torch.zeros(max_objects)
sep_bounds_mae_by_num = torch.zeros(max_objects, 2)


def quantile_bounds(values):
    """Return the 5% and 95% quantiles of a 1-D floating-point tensor."""
    # quantile() needs the levels in the same dtype and on the same device as the data
    levels = torch.tensor((0.05, 0.95), dtype=values.dtype, device=values.device)
    return values.quantile(levels)


for num in range(max_objects):
    print(f"true number of sources = {num}")
    in_group = true_counts == num
    num_in_group = int(in_group.sum().item())
    if num_in_group == 0:
        for undefined_stat in (sep_mean_estimated_s_by_num, sep_bounds_estimated_s_by_num,
                               sep_prop_correct_by_num, sep_mse_by_num,
                               sep_mae_by_num, sep_bounds_mae_by_num):
            undefined_stat[num] = float("nan")
        print("no images with this true number of sources\n\n\n")
        continue

    estimated = sep_estimated_count[in_group]
    if not estimated.is_floating_point():
        # mean() and quantile() require a floating-point tensor
        estimated = estimated.float()
    group_true = true_counts[in_group]
    abs_errors = (estimated - group_true).abs()

    sep_mean_estimated_s_by_num[num] = estimated.mean()
    print(f"mean estimated number of sources = {sep_mean_estimated_s_by_num[num].item()}")
    sep_bounds_estimated_s_by_num[num] = quantile_bounds(estimated)

    sep_mse_by_num[num] = (abs_errors ** 2).mean()
    print(f"MSE across {num_in_group} images = ", sep_mse_by_num[num].item())

    sep_mae_by_num[num] = abs_errors.mean()
    print(f"MAE across {num_in_group} images = ", sep_mae_by_num[num].item())
    sep_bounds_mae_by_num[num] = quantile_bounds(abs_errors)

    sep_num_correct_by_num[num] = (estimated.round() == group_true).sum()
    sep_prop_correct_by_num[num] = sep_num_correct_by_num[num] / num_in_group
    print(f"correct number of sources detected in {int(sep_num_correct_by_num[num].item())} of the {num_in_group} images (accuracy = {sep_prop_correct_by_num[num].item()})\n\n\n")

### SMC results

In [None]:
# Overall SMC count metrics, comparing the posterior-mean count with the true count.
# An image is "correct" when its rounded posterior-mean count equals the true count.
# .item() turns 0-d tensors into plain Python numbers so the printout does not
# show tensor reprs (which include device and dtype).
smc_count_error = smc_posterior_mean_count - true_counts
smc_num_correct = int((smc_posterior_mean_count.round() == true_counts).sum().item())

print(f"MSE across {num_images} images:", (smc_count_error ** 2).mean().item())
print(f"MAE across {num_images} images:", smc_count_error.abs().mean().item())
print(f"correct number of sources detected in {smc_num_correct} of the {num_images} images (accuracy = {smc_num_correct / num_images})")
# torch.median returns the lower of the two middle values for even-length input;
# int() truncates toward zero, as the earlier .int() casts did.
print(f"number of iterations: minimum = {int(smc_num_iters.min().item())}, median = {int(smc_num_iters.median().item())}, maximum = {int(smc_num_iters.max().item())}")
print(f"runtime: minimum = {int(smc_runtime.min().item())}, median = {int(smc_runtime.median().item())}, maximum = {int(smc_runtime.max().item())}\n\n\n")

# Per-image summary. No estimated total flux is printed: none of the results
# loaded above is a per-image SMC flux estimate.
# TODO: add one once it is saved alongside the other SMC outputs.
for i in range(num_images):
    print(f"image {i+1} of {num_images} took {int(smc_num_iters[i].item())} iterations:   ",
          "true s:", true_counts[i].int().item(),
          "   estimated s:", "{:.3f}".format(smc_posterior_mean_count[i].item()),
          "   true total flux:", true_fluxes[i].sum().round().int().item())

In [None]:
# NOTE(review): this cell is entirely commented out and refers to names that are not
# defined anywhere in this notebook (D, s, post_mean_s_smc, num_iters_smc).
# It mirrors the per-true-count SEP breakdown earlier in the notebook. To revive it,
# the loaded equivalents look like max_objects, true_counts, smc_posterior_mean_count
# and smc_num_iters -- TODO confirm the mapping, or delete this cell.
# smc_mean_post_mean_s_by_num = torch.zeros(D, device=device)
# smc_bounds_post_mean_s_by_num = torch.zeros(D, 2, device=device)
# smc_num_correct_by_num = torch.zeros(D, device=device)
# smc_prop_correct_by_num = torch.zeros(D, device=device)
# smc_mse_by_num = torch.zeros(D, device=device)
# smc_mae_by_num = torch.zeros(D, device=device)
# smc_bounds_mae_by_num = torch.zeros(D, 2, device=device)

# for num in range(D):
#     print(f"true number of sources = {num}")
    
#     smc_mean_post_mean_s_by_num[num] = post_mean_s_smc[s==num].mean()
#     print(f"estimated number of sources for images where s = {num}:", smc_mean_post_mean_s_by_num[num].item())
#     smc_bounds_post_mean_s_by_num[num] = post_mean_s_smc[s==num].quantile(torch.tensor((0.05, 0.95),device=device))
    
#     smc_mse_by_num[num] = ((post_mean_s_smc[s==num] - s[s==num])**2).mean()
#     print(f"MSE across {(s==num).sum()} images:", (smc_mse_by_num[num].item()))
    
#     smc_mae_by_num[num] = ((post_mean_s_smc[s==num] - s[s==num]).abs()).mean()
#     print(f"MAE across {(s==num).sum()} images:", (smc_mae_by_num[num].item()))
#     smc_bounds_mae_by_num[num] = ((post_mean_s_smc[s==num] - s[s==num]).abs()).quantile(torch.tensor((0.05, 0.95),device=device))
    
#     smc_num_correct_by_num[num] = (post_mean_s_smc[s==num].round() == s[s==num]).sum()
#     smc_prop_correct_by_num[num] = smc_num_correct_by_num[num]/(s==num).sum()
#     print(f"correct number of sources detected in {smc_num_correct_by_num[num].int()} of the {(s==num).sum()} images (accuracy = {smc_prop_correct_by_num[num]})")
    
#     print(f"number of iterations: minimum = {num_iters_smc[s==num].min().int()}, median = {num_iters_smc[s==num].median().int()}, maximum = {num_iters_smc[s==num].max().int()}\n")