#### Load libraries and set CUDA device

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.distributions import Poisson, Normal, Uniform, Distribution, Categorical

import numpy as np
import sep

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device only accepts CUDA devices (it raises ValueError for a CPU
# device), so only select the current CUDA device when we actually have one.
if device.type == "cuda":
    torch.cuda.set_device(device)

#### TruncatedDiagonalMVN class for MH kernel

In [None]:
class TruncatedDiagonalMVN(Distribution):
    """A truncated diagonal multivariate normal distribution.

    Every coordinate is an independent univariate ``Normal(mu, sigma)`` truncated
    to the interval ``[a, b]``. The bounds are broadcast to the shape of ``mu``.

    Args:
        mu: location of the untruncated normal; its shape is the shape of one draw.
        sigma: scale of the untruncated normal (broadcastable to ``mu``).
        a: lower truncation bound (scalar or broadcastable to ``mu``).
        b: upper truncation bound (scalar or broadcastable to ``mu``).
    """

    def __init__(self, mu, sigma, a, b):
        super().__init__(validate_args=False)

        self.dim = mu.size()

        self.lb = a * torch.ones_like(mu)
        self.ub = b * torch.ones_like(mu)

        base = Normal(mu, sigma)
        # Per-coordinate mass of the base normal inside [a, b], i.e. the
        # truncation normalizing constant, kept in log space.
        prob_in_box_hw = base.cdf(self.ub) - base.cdf(self.lb)
        self.log_prob_in_box = prob_in_box_hw.log()

        self.base_dist = base

    def __repr__(self):
        return f"{self.__class__.__name__}({self.base_dist})"

    def sample(self, sample_shape=torch.Size(), **args):
        """Draw samples by inverse-CDF sampling of each truncated coordinate.

        Args:
            sample_shape: leading batch shape of the draw. The default is one
                draw shaped like ``mu``. Extra keyword arguments are accepted and
                ignored, for backward compatibility.

        Returns:
            Tensor of shape ``sample_shape + mu.shape`` with values inside ``[a, b]``.
        """
        # Sampling a d-dim truncated diagonal MVN <=> sampling d truncated univariate
        # normals. p is kept away from 0 and 1 so icdf stays finite. Drawing on the
        # bounds' device avoids depending on a module-level device.
        shape = torch.Size(sample_shape) + self.dim
        p = torch.rand(shape, dtype=self.lb.dtype, device=self.lb.device).clamp(min=1e-6, max=1.0 - 1e-6)
        p_tilde = self.base_dist.cdf(self.lb) + p * self.log_prob_in_box.exp()
        return self.base_dist.icdf(p_tilde)

    @property
    def mode(self):
        """Per-coordinate mode: the base location clamped into ``[a, b]``.

        The truncated normal density is unimodal. When the base location lies
        outside the box, the density is monotone on the box, so the mode is the
        nearest bound.
        """
        # The mean of a normal is also its mode. (Distribution.mean, which the
        # original check read, is not implemented for this class.)
        return torch.minimum(torch.maximum(self.base_dist.mean, self.lb), self.ub)

    def log_prob(self, value):
        """Elementwise log density of the truncated normal at ``value``."""
        assert (value >= self.lb).all() and (value <= self.ub).all()
        # Subtracting the log probability that the base RV lies in the box is
        # equivalent, in log space, to dividing the normal pdf by the normalizing
        # constant.
        return self.base_dist.log_prob(value) - self.log_prob_in_box

    def cdf(self, value):
        """Joint CDF P(X <= value), reduced over the last dimension.

        The coordinates are independent, so the joint CDF is the product of the
        per-coordinate truncated CDFs (F(x) - F(a)) / (F(b) - F(a)). Values are
        clamped into the box, so the result is ~0 below ``a`` and 1 above ``b``.
        """
        value = torch.minimum(torch.maximum(value, self.lb), self.ub)
        cdf_at_val = self.base_dist.cdf(value)
        cdf_at_lb = self.base_dist.cdf(self.lb)
        # The 1e-9 keeps the log finite at the lower bound. The normalizer is
        # subtracted per coordinate before summing, so the shapes agree.
        log_cdf = ((cdf_at_val - cdf_at_lb + 1e-9).log() - self.log_prob_in_box).sum(dim=-1)
        return log_cdf.exp()

#### Image characteristics and point spread function

In [None]:
num_images = 1000                                # the number of images in our dataset
max_objects_generated = 10                       # maximum number of sources placed in a synthetic image
D = max_objects_generated + 1                    # there are (max_objects_generated + 1) possible source counts for each image: {0,1,2,...,max_objects_generated}
img_dim = 15                                     # the height and width of our images
eta = 3.25                                       # PSF width: psf() uses exp(-d^2 / (2*eta**2)), so eta is the Gaussian standard deviation (pixels), not the variance
H = img_dim                                      # height of images
W = img_dim                                      # width of images
min_flux = torch.tensor(6400., device = device)  # minimum flux (also sets the scale of the flux prior and the background)
background_intensity = 3 * min_flux              # constant background intensity added to every pixel

def psf(H, W, D, u_h, u_w, eta):
    """Evaluate normalized isotropic Gaussian PSFs on an H x W pixel grid.

    Pixel ``j`` along each axis sits at coordinate ``j + 1``. Each PSF is
    exp(-((h - u_h)^2 + (w - u_w)^2) / (2 * eta^2)), so ``eta`` acts as the
    Gaussian standard deviation. Each PSF is then normalized to sum to 1 over the
    grid.

    Args:
        H, W: grid height and width.
        D: number of sources. ``u_h``/``u_w`` are reshaped to ``(D, -1)``, so a
            trailing particle axis P is allowed.
        u_h: source coordinates compared against the H marginal.
        u_w: source coordinates compared against the W marginal.
        eta: PSF standard deviation.

    Returns:
        Tensor of shape ``(W, H, D, P)`` with all size-1 axes squeezed out.
        NOTE: axis 0 is indexed by the W marginal and axis 1 by the H marginal.
    """
    # Build the grid on the device of the source coordinates so the function does
    # not depend on a module-level device.
    grid_device = u_h.device
    psf_marginal_H = 1 + torch.arange(H, dtype=torch.float32, device=grid_device)
    psf_marginal_W = 1 + torch.arange(W, dtype=torch.float32, device=grid_device)

    sq_dist = ((psf_marginal_H.view(1, H, 1, 1) - u_h.view(1, 1, D, -1)) ** 2
               + (psf_marginal_W.view(W, 1, 1, 1) - u_w.view(1, 1, D, -1)) ** 2)
    kernel = (-sq_dist / (2 * eta ** 2)).exp()
    # Normalize each (source, particle) PSF to unit total mass over the grid.
    kernel = kernel / kernel.sum([0, 1]).view(1, 1, D, -1)

    return kernel.squeeze()

#### Generate synthetic images

In [None]:
# Fix the RNG so the synthetic dataset is reproducible. The results below depend on
# the exact order of the sample() calls in the loop.
torch.manual_seed(0)

# Priors for number of objects, fluxes, locations
# s: uniform over the D source counts {0, ..., D-1} (a Poisson(4) prior is the
#    commented-out alternative).
# flux: Normal with mean 10*min_flux and standard deviation 2*min_flux.
# u: each coordinate uniform on [0, H) and [0, W) respectively.
s_prior = Categorical((1/D)*torch.ones(D, device = device)) # Poisson(torch.tensor(4., device=device))
flux_prior = Normal(10 * min_flux, 2 * min_flux)
u_prior = Uniform(torch.zeros(2, device = device), torch.tensor((H,W), device = device))

# Create tensors to store data for multiple images
# (flux/u have D slots per image; inactive slots stay 0)
s = torch.zeros(num_images, device=device)
flux = torch.zeros(num_images, D, device=device)
u = torch.zeros(num_images, D, 2, device=device)
true_intensity = torch.zeros(num_images, H, W, device=device)
images = torch.zeros(num_images, H, W, device=device)
observed_flux = torch.zeros(num_images, device=device)



for i in range(num_images):
    # Sample number of objects, fluxes, locations
    s[i] = s_prior.sample()
    # Slots 1..s[i] are active. Slot 0 is never active, so with D slots an
    # image holds at most D-1 sources.
    s_indicator = torch.logical_and(torch.arange(D, device = device) <= s[i],
                                    torch.arange(D, device = device) > torch.zeros(1, device=device))
    # Samples are drawn for every slot and then zeroed for inactive slots.
    flux[i] = flux_prior.sample([D]) * s_indicator
    u[i] = u_prior.sample([D]) * s_indicator.unsqueeze(1)

    # Compute true intensity of image
    # psf(...) has shape (W, H, D); summing over the source axis gives a (W, H) map.
    # NOTE(review): assigning it into an (H, W) slot relies on H == W.
    star_intensity = (flux[i].view(1, 1, D) * psf(H, W, D, u[i][:,0], u[i][:,1], eta)).sum(2)
    true_intensity[i] = background_intensity + star_intensity

    # Sample image
    # Poisson pixel counts around the true intensity. The sample is already on the
    # intensity's device, so .to(device) is a no-op here.
    images[i] = Poisson(true_intensity[i]).sample().to(device)

    # Compute observed flux
    # Total background-subtracted counts in the image.
    observed_flux[i] = (images[i] - background_intensity).sum([0,1])
    
    print(f"image {i+1}\n", "s\n", s[i], "\n\ntotal flux\n", flux[i].sum(), "\n\nu\n", u[i], "\n\n\n")
    # fig, (true, observed) = plt.subplots(1,2)
    # _ = true.imshow(true_intensity[i].cpu())
    # _ = observed.imshow(images[i].cpu())

----
----
----

#### SEP

In [None]:
# Grid search over SEP parameters to obtain optimal performance.
# Every (detection threshold, minarea, deblend_cont, deblend_nthresh) combination
# is run on every background-subtracted image. The number of detected sources is
# scored against the true count s (proportion correct, MSE, MAE).
num_detection_thresholds_to_try = 25
detection_thresholds = torch.linspace(start = 0, end = 1000, steps = num_detection_thresholds_to_try, device=device)

num_minarea_to_try = 3
minarea = torch.linspace(start = 1, end = 9, steps = num_minarea_to_try, device=device)

num_deblend_cont_to_try = 3
deblend_cont = torch.linspace(start = 1e-4, end = 5e-2, steps = num_deblend_cont_to_try, device=device)

num_dblend_nthresh_to_try = 1
deblend_nthresh = torch.linspace(start = 32, end = 32, steps = num_dblend_nthresh_to_try, device=device)

sep_estimated_s = torch.zeros(num_detection_thresholds_to_try, num_minarea_to_try, num_deblend_cont_to_try, num_dblend_nthresh_to_try, num_images, device=device)
sep_prop_correct = torch.zeros(num_detection_thresholds_to_try, num_minarea_to_try, num_deblend_cont_to_try, num_dblend_nthresh_to_try, device=device)
sep_mse = torch.zeros(num_detection_thresholds_to_try, num_minarea_to_try, num_deblend_cont_to_try, num_dblend_nthresh_to_try, device=device)
sep_mae = torch.zeros(num_detection_thresholds_to_try, num_minarea_to_try, num_deblend_cont_to_try, num_dblend_nthresh_to_try, device=device)

# Background-subtracted images as one NumPy array. This is computed once, instead
# of once per (parameter combination, image). Each images_bgsub_np[img] is a
# C-contiguous 2-D view.
images_bgsub_np = (images - background_intensity).cpu().numpy()

for t in range(num_detection_thresholds_to_try):
    for m in range(num_minarea_to_try):
        for c in range(num_deblend_cont_to_try):
            for h in range(num_dblend_nthresh_to_try):
                # sep.extract takes plain Python numbers (float thresh/deblend_cont,
                # int minarea/deblend_nthresh). Don't rely on implicit conversion of
                # (possibly CUDA) float tensors.
                thresh_val = detection_thresholds[t].item()
                minarea_val = int(round(minarea[m].item()))
                deblend_cont_val = deblend_cont[c].item()
                deblend_nthresh_val = int(round(deblend_nthresh[h].item()))

                print("detectection_threshold = ", thresh_val)
                print("minarea = ", minarea[m].item())
                print("deblend_cont = ", deblend_cont_val)
                print("deblend_nthresh = ", deblend_nthresh[h].item(), "\n")
                for img in range(num_images):
                    detected_sources = sep.extract(images_bgsub_np[img],
                                                   thresh = thresh_val, minarea = minarea_val, deblend_cont = deblend_cont_val,
                                                   deblend_nthresh = deblend_nthresh_val, clean = False)
                    sep_estimated_s[t, m, c, h, img] = len(detected_sources)

                estimated = sep_estimated_s[t, m, c, h, :]
                sep_prop_correct[t,m,c,h] = (estimated == s).sum()/num_images
                sep_mse[t,m,c,h] = ((estimated - s)**2).mean()
                sep_mae[t,m,c,h] = (estimated - s).abs().mean()

                print("proportion correct:", sep_prop_correct[t,m,c,h].item())
                print("MSE:", sep_mse[t,m,c,h].item())
                print("MAE:", sep_mae[t,m,c,h].item(), "\n\n\n")

In [None]:
# Report every grid point attaining the minimum SEP MSE. When several points tie,
# the last one in (threshold, minarea, deblend_cont, deblend_nthresh) loop order
# is kept as the "optimal" setting. nonzero() yields indices in row-major order,
# which is the same as nested-loop order.
best_sep_mse = sep_mse.min()
for i_t, i_m, i_c, i_h in (sep_mse == best_sep_mse).nonzero().tolist():
    print("detectection_threshold = ", detection_thresholds[i_t].item())
    print("minarea = ", minarea[i_m].item())
    print("deblend_cont = ", deblend_cont[i_c].item())
    print("deblend_nthresh = ", deblend_nthresh[i_h].item(), "\n")

    detection_threshold_optim = detection_thresholds[i_t]
    minarea_optim = minarea[i_m]
    deblend_cont_optim = deblend_cont[i_c]
    deblend_nthresh_optim = deblend_nthresh[i_h]

In [None]:
# Run SEP with the optimal parameters and build a PSF-based reconstruction of each
# image from the detected sources (background + sum of flux-weighted PSFs).
sep_estimated_s = torch.zeros(num_images, device=device)
# sep_loc_x / sep_loc_y / sep_flux are overwritten per image; after the loop they
# hold the detections of the last image.
sep_loc_x = torch.zeros(num_images, device=device)
sep_loc_y = torch.zeros(num_images, device=device)
sep_flux = torch.zeros(num_images, device=device)
sep_reconstruction = torch.zeros(num_images, H, W, device=device)

# sep.extract takes plain Python numbers, not torch tensors.
sep_kwargs = dict(thresh = float(detection_threshold_optim),
                  minarea = int(round(float(minarea_optim))),
                  deblend_cont = float(deblend_cont_optim),
                  deblend_nthresh = int(round(float(deblend_nthresh_optim))),
                  clean = False)
images_bgsub_np = (images - background_intensity).cpu().numpy()

for img in range(num_images):
    detected_sources = sep.extract(images_bgsub_np[img], **sep_kwargs)

    # Structured-array fields are strided views, so they are copied to contiguous
    # arrays before conversion. SEP coordinates are 0-indexed pixel centers, while
    # psf() places pixel j at coordinate j + 1; hence the +1 shift.
    sep_loc_x = torch.as_tensor(np.ascontiguousarray(detected_sources['x']), dtype=torch.float32, device=device) + 1
    sep_loc_y = torch.as_tensor(np.ascontiguousarray(detected_sources['y']), dtype=torch.float32, device=device) + 1
    sep_flux = torch.as_tensor(np.ascontiguousarray(detected_sources['flux']), dtype=torch.float32, device=device)
    n_detected = len(detected_sources)
    sep_estimated_s[img] = n_detected

    if n_detected > 0:
        # psf() squeezes size-1 axes; reshape restores the (W, H, n) layout so the
        # single-detection case needs no special handling.
        source_psfs = psf(H, W, n_detected, sep_loc_x, sep_loc_y, eta).reshape(W, H, n_detected)
        sep_reconstruction[img] = (source_psfs * sep_flux.view(1, 1, n_detected)).sum(2) + background_intensity
    else:
        # No detections: the reconstruction is pure background.
        sep_reconstruction[img] = background_intensity

----
----
----

#### SMC-Deblender

In [None]:
def tempered_log_p_x_given_z(image, flux, u, tempering_factor):
    """Return the tempered Poisson log-likelihood of ``image`` for every particle.

    The expected pixel counts are the background plus the flux-weighted PSFs of all
    source slots. The per-particle log-likelihood is split into equal consecutive
    blocks of particles, and row k is multiplied by ``tempering_factor[k]``.

    Args:
        image: observed image (viewed as (img_dim, img_dim, 1)).
        flux: fluxes indexed [slot, particle].
        u: locations indexed [slot, particle, coordinate].
        tempering_factor: one exponent per block.

    Returns:
        Tensor of shape (num_blocks, num_particles // num_blocks).
    """
    n_blocks, n_particles = flux.size(0), flux.size(1)
    per_block = n_particles // n_blocks

    source_psfs = psf(H, W, n_blocks, u[:, :, 0], u[:, :, 1], eta)
    expected_counts = (source_psfs * flux.view(1, 1, n_blocks, -1)).sum(2) + background_intensity
    log_lik = Poisson(expected_counts).log_prob(image.view(img_dim, img_dim, 1)).sum([0, 1])

    log_lik_by_block = torch.stack(torch.split(log_lik, per_block, dim=0), dim=0)
    return tempering_factor.unsqueeze(1) * log_lik_by_block



def log_target(image, s, flux, u, tempering_factor):
    """Unnormalized log target density of every particle, for the MH kernel.

    The flux-prior and location-prior log densities are summed over the active
    source slots (1..s), and the tempered likelihood is added. The log prior of s
    is omitted: it appears with the same s in the numerator and the denominator
    of the MH acceptance ratio and cancels.

    Returns:
        Tensor with one entry per particle.
    """
    n_blocks, n_particles = flux.size(0), flux.size(1)

    slot = torch.arange(n_blocks, device = device).view(n_blocks, 1)
    active = (slot <= s) & (slot > torch.zeros(n_particles, device=device))

    log_prior_flux = (flux_prior.log_prob(flux) * active).sum(0)
    log_prior_u = (u_prior.log_prob(u) * active.unsqueeze(2)).sum(2).sum(0)
    log_lik = tempered_log_p_x_given_z(image, flux, u, tempering_factor).flatten(0)

    return log_prior_flux + log_prior_u + log_lik



def MCMC_kernel(image, s_tminus1, s_t, flux_tminus1, u_tminus1, tau_tminus1):
    """Move every particle with random-walk Metropolis-Hastings steps.

    Fluxes get a Gaussian random-walk proposal. Locations get a random walk
    truncated to [0, img_dim]. Both proposals are masked to the active source
    slots given by ``s_tminus1``. The chain targets ``log_target`` at tempering
    exponent ``tau_tminus1``. It stops when the relative change of the mean
    squared jump distance (measured from the starting state) falls below 1% for
    both fluxes and locations, or after 500 iterations.

    Args:
        image: observed image.
        s_tminus1: per-particle source counts defining the active slots.
        s_t: per-particle source counts used in the target of proposed states.
        flux_tminus1: starting fluxes indexed [slot, particle].
        u_tminus1: starting locations indexed [slot, particle, coordinate].
        tau_tminus1: per-block tempering exponents.

    Returns:
        [flux_new, u_new], shaped like the inputs.
    """
    num_blocks = flux_tminus1.size(0)
    num_particles = flux_tminus1.size(1)
    block_size = num_particles // num_blocks

    slot = torch.arange(num_blocks, device=device).view(num_blocks, 1)
    s_indicator = torch.logical_and(slot <= s_tminus1,
                                    slot > torch.zeros(num_particles, device=device))
    u_indicator = s_indicator.unsqueeze(2)

    flux_sd_scale = 2500
    flux_proposal_sd = flux_sd_scale * torch.ones(1, device=device)

    u_sd_scale = 0.25
    u_proposal_sd = u_sd_scale * torch.ones(1, device=device)
    u_lb = torch.tensor(0, device=device)
    u_ub = torch.tensor(img_dim, device=device)

    def _log_q_u(u_to, u_from):
        # Log density, summed over active slots, of proposing locations u_to from u_from.
        log_q = TruncatedDiagonalMVN(u_from, u_proposal_sd, u_lb, u_ub).log_prob(u_to)
        return (log_q * u_indicator).sum(2).sum(0)

    flux_rel_sq_jump_dist_tol = 1e-2
    u_rel_sq_jump_dist_tol = 1e-2
    flux_sq_jump_dist_prev = torch.tensor(1e-10, device=device)
    u_sq_jump_dist_prev = torch.tensor(1e-10, device=device)

    flux_prev = flux_tminus1
    u_prev = u_tminus1
    # Log target at the current state. Only this term can be cached across
    # iterations; the proposal-density terms depend on the new proposal.
    log_target_prev = log_target(image, s_tminus1, flux_prev, u_prev, tau_tminus1)

    # Set upper bound on number of MH iterations
    max_MH_iter = 500

    for MH_iter in range(max_MH_iter):
        # Random walk Metropolis proposals for fluxes and locations
        flux_proposed = Normal(flux_prev, flux_proposal_sd).sample() * s_indicator
        u_proposed = TruncatedDiagonalMVN(u_prev, u_proposal_sd, u_lb, u_ub).sample() * u_indicator

        log_target_proposed = log_target(image, s_t, flux_proposed, u_proposed, tau_tminus1)

        # MH ratio with the Hastings correction q(u_prev | u_proposed) / q(u_proposed | u_prev).
        # The truncated proposal is not symmetric, so both q terms must be
        # recomputed for every proposal.
        log_alpha = ((log_target_proposed + _log_q_u(u_prev, u_proposed))
                     - (log_target_prev + _log_q_u(u_proposed, u_prev)))
        alpha = log_alpha.exp().clamp(max=1)

        # Accept if prob <= alpha. A NaN alpha (e.g. from -inf - -inf) compares
        # False and is therefore rejected, rather than zeroing the state.
        prob = Uniform(torch.zeros(num_particles), torch.ones(num_particles)).sample().to(device)
        accept = prob <= alpha
        flux_new = torch.where(accept.unsqueeze(0), flux_proposed, flux_prev)
        u_new = torch.where(accept.view(1, -1, 1), u_proposed, u_prev)

        # Relative squared jumping distance for fluxes (mean over nonzero block entries)
        flux_sq_jump_dist_new_by_block = torch.stack(torch.split((flux_new - flux_tminus1)**2, block_size, dim=1), dim=1).mean(2)
        flux_sq_jump_dist_new = ((flux_sq_jump_dist_new_by_block * (flux_sq_jump_dist_new_by_block != 0)).sum()) / (flux_sq_jump_dist_new_by_block != 0).sum()
        flux_rel_sq_jump_dist = (flux_sq_jump_dist_new - flux_sq_jump_dist_prev) / flux_sq_jump_dist_prev

        # Relative squared jumping distance for locations
        u_sq_jump_dist_new_by_block = torch.stack(torch.split(((u_new - u_tminus1)**2).sum(2), block_size, dim=1), dim=1).mean(2)
        u_sq_jump_dist_new = ((u_sq_jump_dist_new_by_block * (u_sq_jump_dist_new_by_block != 0)).sum()) / (u_sq_jump_dist_new_by_block != 0).sum()
        u_rel_sq_jump_dist = (u_sq_jump_dist_new - u_sq_jump_dist_prev) / u_sq_jump_dist_prev

        # Stop once the relative squared jumping distance is below tolerance for both
        if flux_rel_sq_jump_dist < flux_rel_sq_jump_dist_tol and u_rel_sq_jump_dist < u_rel_sq_jump_dist_tol:
            break

        # Carry the log target of the (possibly updated) current state forward
        log_target_prev = torch.where(accept, log_target_proposed, log_target_prev)
        flux_prev = flux_new
        u_prev = u_new

        flux_sq_jump_dist_prev = flux_sq_jump_dist_new
        u_sq_jump_dist_prev = u_sq_jump_dist_new

    print("num MH iters:", MH_iter)

    return [flux_new, u_new]



def bisection_f(image, flux_tminus1, u_tminus1, delta, ess_min):
    """Root function used to choose the tempering increment ``delta``.

    Let the incremental weights be w_i = p(x | z_i) ** delta. The per-block ESS is
    then (sum_i w_i)^2 / sum_i w_i^2. This function returns ESS - ess_min for each
    block, computed in log space for stability. The current particle weights do
    not enter the formula.
    """
    log_sum_w = tempered_log_p_x_given_z(image, flux_tminus1, u_tminus1, delta).logsumexp(dim=1)
    log_sum_w_sq = tempered_log_p_x_given_z(image, flux_tminus1, u_tminus1, 2*delta).logsumexp(dim=1)

    ess = (2*log_sum_w - log_sum_w_sq).exp()
    return ess - ess_min



def AdaptiveTempering(image, flux_tminus1, u_tminus1, tau_tminus1, tol, max_iter, ess_min):
    """Choose the next tempering exponent by bisection on the per-block ESS.

    For each block, bisection finds the increment delta in [0, 1 - tau_tminus1]
    at which ``bisection_f`` (ESS(delta) - ess_min) crosses zero. This assumes the
    ESS decreases as delta grows. If the ESS is still >= ess_min at the full
    remaining step, that block takes the full step to tau = 1. All blocks then
    share the smallest increment.

    Args:
        image, flux_tminus1, u_tminus1: passed through to ``bisection_f``.
        tau_tminus1: current per-block tempering exponents.
        tol: bisection stops once every bracket is at most ``tol`` wide.
        max_iter: maximum number of bisection iterations.
        ess_min: ESS threshold.

    Returns:
        Per-block tempering exponents, all equal to tau_tminus1 + min increment.
    """
    num_blocks = flux_tminus1.size(0)

    a = torch.zeros(num_blocks, device = device)
    b = 1 - tau_tminus1
    full_step = 1 - tau_tminus1

    # Evaluate the bracket ends once. After each step they are updated from f(c)
    # instead of being recomputed (each evaluation is a full likelihood pass).
    f_a = bisection_f(image, flux_tminus1, u_tminus1, a, ess_min)
    f_b = bisection_f(image, flux_tminus1, u_tminus1, b, ess_min)
    # Without a sign change on [a, b], bisection would collapse to the midpoint and
    # only halve the remaining distance to tau = 1. Such blocks take the full step.
    take_full_step = f_b >= 0

    c = (a + b) / 2

    # Compute increase in tau for every block using the bisection method
    for _ in range(max_iter):
        if torch.all((b - a).abs() <= tol):
            break

        f_c = bisection_f(image, flux_tminus1, u_tminus1, c, ess_min)

        move_a = f_a.sign() == f_c.sign()
        move_b = f_b.sign() == f_c.sign()
        a = torch.where(move_a, c, a)
        f_a = torch.where(move_a, f_c, f_a)
        b = torch.where(move_b, c, b)
        f_b = torch.where(move_b, f_c, f_b)

        c = (a + b) / 2

    c = torch.where(take_full_step, full_step, c)

    # For all blocks, set the increase in tau to be the minimum increase across the blocks
    c = c.min(0).values.repeat(num_blocks)

    return c + tau_tminus1

In [None]:
def SMC(image, num_blocks, num_particles, max_iters):
    """Run a block-structured, adaptively tempered SMC sampler for one image.

    The particles are split into ``num_blocks`` equal consecutive blocks, and every
    particle in block k has source count s = k (active slots 1..k). Each iteration:
      1. picks the next tempering exponent with ``AdaptiveTempering``;
      2. stratified-resamples every block whose ESS fell below half the block size;
      3. moves the particles with ``MCMC_kernel``;
      4. reweights them with the tempered likelihood increment.
    The loop stops when tau reaches 1 or after ``max_iters - 1`` iterations.

    Args:
        image: observed image.
        num_blocks: number of source-count blocks (s in {0, ..., num_blocks - 1}).
        num_particles: total particles; must be divisible by ``num_blocks``
            (the per-block splits are stacked).
        max_iters: bound on the iteration counter.

    Returns:
        [s_t, flux_t, u_t, normalized_weights_all_t, normalized_weights_block_t, final_iter]
    """
    block_size = num_particles // num_blocks

    ## Initialize
    s_tminus1 = torch.ones(num_particles, device = device) * torch.arange(num_blocks, device = device).repeat_interleave(block_size)
    slot = torch.arange(num_blocks, device = device).view(num_blocks, 1)
    s_indicator = torch.logical_and(slot <= s_tminus1,
                                    slot > torch.zeros(num_particles, device=device))
    flux_tminus1 = flux_prior.sample([num_blocks, num_particles]) * s_indicator
    u_tminus1 = u_prior.sample([num_blocks, num_particles]) * s_indicator.unsqueeze(2)

    # Unnormalized weights for t=0 are all equal, since proposal = prior
    log_unnormalized_weights_tminus1 = torch.zeros(num_particles, device = device)

    # Normalized weights for t=0, both within each block and across all particles
    normalized_weights_block_tminus1 = torch.stack(torch.split(log_unnormalized_weights_tminus1, block_size, dim=0), dim = 0).softmax(1)
    normalized_weights_all_tminus1 = log_unnormalized_weights_tminus1.softmax(0)

    # Effective sample sizes for t=0 (ESS = block_size, since the weights are all equal)
    ess = 1/(normalized_weights_block_tminus1**2).sum(1)
    ess_min = 0.5*num_particles//num_blocks

    # Set initial tempering exponent to zero
    tau_tminus1 = torch.zeros(num_blocks, device=device)

    # Defaults, so the return is well-defined even if the loop body never runs
    # (max_iters <= 1) or SMC does not converge before max_iters.
    final_iter = max_iters
    s_t, flux_t, u_t = s_tminus1, flux_tminus1, u_tminus1
    normalized_weights_block_t = normalized_weights_block_tminus1
    normalized_weights_all_t = normalized_weights_all_tminus1

    for t in range(1, max_iters):
        ### ADAPTIVE TEMPERING
        tau_t = AdaptiveTempering(image, flux_tminus1, u_tminus1, tau_tminus1, 1e-6, 50, ess_min)

        if t % 10 == 0:
            print(f"\n======= iteration {t} =======")
            print("tau\n", tau_t.unique())

        ### ADAPTIVE STRATIFIED RESAMPLING
        for block_num in range(num_blocks):
            if ess[block_num] < ess_min:
                block = slice(block_num * block_size, (block_num + 1) * block_size)

                # One uniform draw per stratum [i/N, (i+1)/N), mapped through the block's weight CDF
                bins = normalized_weights_block_tminus1[block_num, :].cumsum(0)
                unif = (torch.arange(block_size, device=device) + torch.rand(block_size, device=device)) / block_size
                # clamp so that cumsum round-off (bins[-1] < unif) cannot produce index == block_size
                resample_indices = torch.bucketize(unif, bins).clamp(min = 0, max = block_size - 1)

                # Resample whole particles: all source slots (rows) of a particle must
                # move together, otherwise sources from different particles get mixed.
                flux_tminus1[:, block] = flux_tminus1[:, block][:, resample_indices]
                u_tminus1[:, block, :] = u_tminus1[:, block, :][:, resample_indices, :]

                # Resampled particles are equally weighted within the block. The
                # block's total mass in the all-particle weights is preserved.
                normalized_weights_block_tminus1[block_num, :] = 1 / block_size
                block_mass = normalized_weights_all_tminus1[block].sum(0)
                normalized_weights_all_tminus1[block] = block_mass / block_size

        ### PROPAGATE
        # Each particle has the same s as its parent
        s_t = s_tminus1

        flux_t, u_t = MCMC_kernel(image, s_tminus1, s_t, flux_tminus1, u_tminus1, tau_tminus1)

        ### UPDATE WEIGHTS
        # Log incremental weights: likelihood raised to the tempering increment
        log_incremental_weights_t = tempered_log_p_x_given_z(image, flux_t, u_t, tau_t - tau_tminus1).flatten(0)

        # Log unnormalized weights
        log_unnormalized_weights_t = normalized_weights_all_tminus1.log() + log_incremental_weights_t

        # Normalized weights (floored so that later log() calls stay finite)
        normalized_weights_block_t = torch.stack(torch.split(log_unnormalized_weights_t, block_size, dim=0), dim = 0).softmax(1).clamp(1e-40)
        normalized_weights_all_t = log_unnormalized_weights_t.softmax(0).clamp(1e-40)

        # Effective sample sizes
        ess = 1/(normalized_weights_block_t**2).sum(1)

        if torch.all(1 - tau_t.unique() < 1e-6):
            final_iter = t
            break

        ### Update
        s_tminus1 = s_t
        flux_tminus1 = flux_t
        u_tminus1 = u_t
        normalized_weights_block_tminus1 = normalized_weights_block_t
        normalized_weights_all_tminus1 = normalized_weights_all_t
        tau_tminus1 = tau_t

    return [s_t, flux_t, u_t, normalized_weights_all_t, normalized_weights_block_t, final_iter]

In [None]:
torch.manual_seed(0)

# Set SMC sizes: number of source-count blocks, number of particles, and the cap on
# SMC iterations. (The tempering schedule is chosen adaptively inside SMC.)
max_objects_smc = max_objects_generated + 2
num_blocks = max_objects_smc + 1        # allow SMC to consider source counts in {0, 1, ..., max_objects_generated + 2}, i.e. up to two more than were generated
num_particles = 500*num_blocks          # 500 particles per block
max_iters = 1000

# Create tensors to store SMC results
post_mean_s_smc = torch.zeros(num_images, device=device)
prob_s_smc = torch.zeros(num_images, num_blocks, device=device)
num_iters_smc = torch.zeros(num_images, device=device)
reconstruction_smc = torch.zeros(num_images, H, W, device=device)

# Run SMC sampler for all images
# SMC returns [s_t, flux_t, u_t, normalized_weights_all_t, normalized_weights_block_t, final_iter].
for j in range(num_images):
    print(f"image {j+1} of {num_images}")
    smc = SMC(images[j], num_blocks, num_particles, max_iters)
    
    
    # Posterior mean of s: sum over particles of (all-particle weight * s).
    post_mean_s_smc[j] = (smc[3] * smc[0]).sum()
    # Posterior probability of each source count: total weight in each particle block.
    prob_s_smc[j] = torch.stack(torch.split(smc[3], num_particles//num_blocks, dim=0), dim=0).sum(1)
    num_iters_smc[j] = smc[5]
    
    # Reconstruct the image from the single highest-weight particle.
    smc_argmax_index = smc[3].argmax()
    smc_argmax_flux = smc[1][:,smc_argmax_index]
    smc_argmax_loc = smc[2][:,smc_argmax_index,:]
    reconstruction_smc[j] = (psf(H, W, num_blocks, smc_argmax_loc[:,0], smc_argmax_loc[:,1], eta) * smc_argmax_flux.view(1, 1, num_blocks)).sum(2) + background_intensity
    print(f"\nimage {j+1} took {num_iters_smc[j].int()} iterations.\nfor image {j+1}, true s is {s[j]} and estimated s is {post_mean_s_smc[j]}.\n")
    
    # Running error metrics over the images processed so far.
    print(f"MSE across {j+1} images:", ((post_mean_s_smc[:(j+1)] - s[:(j+1)])**2).mean().item())
    print(f"MAE across {j+1} images:", ((post_mean_s_smc[:(j+1)] - s[:(j+1)]).abs()).mean().item())
    print(f"correct number of sources detected in {(post_mean_s_smc[:(j+1)].round() == s[:(j+1)]).sum()} of the {j+1} images (accuracy = {(post_mean_s_smc[:(j+1)].round() == s[:(j+1)]).sum()/(j+1)})\n\n\n")

----
----
----

#### Save results

In [None]:
# Persist the synthetic data, the SEP baseline and the SMC results as torch tensors.
import os

# torch.save does not create missing parent directories, so make sure "results/" exists.
os.makedirs("results", exist_ok=True)

# Synthetic images
torch.save(s, "results/s.pt")
torch.save(flux, "results/flux.pt")
torch.save(u, "results/u.pt")
torch.save(true_intensity, "results/true_intensity.pt")
torch.save(images, "results/images.pt")

# SEP results
torch.save(sep_estimated_s, "results/sep_estimated_s.pt")
torch.save(sep_reconstruction, "results/sep_reconstruction.pt")

# SMC results
torch.save(post_mean_s_smc, "results/post_mean_s_smc.pt")
torch.save(prob_s_smc, "results/prob_s_smc.pt")
torch.save(reconstruction_smc, "results/reconstruction_smc.pt")
torch.save(num_iters_smc, "results/num_iters_smc.pt")