In [None]:
import sys
sys.path.append("../")

from smc.sampler import SMCsampler
from smc.prior import StarPrior
from smc.images import ImageModel
from smc.kernel import MetropolisHastings
from smc.aggregate import Aggregate

import torch
# Select the compute device. GPU 2 is preferred when the machine has at least
# three GPUs; otherwise fall back to the first GPU, or to the CPU when CUDA is
# unavailable.
if torch.cuda.is_available():
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f"cuda:{gpu_index}")
    # torch.cuda.set_device only accepts CUDA devices (it raises ValueError for
    # a CPU device), so it must stay inside the CUDA branch.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# New tensors are created on `device` unless a device is given explicitly.
torch.set_default_device(device)

import matplotlib.pyplot as plt

### Generate images

We specify a prior and an image model and generate one image.

In [None]:
# Settings for the synthetic image. `image_dim` is used as both the image
# height and the image width; the remaining values are passed to the prior
# and the image model below.
image_dim = 16
true_psf_stdev = 1.5
true_background = 10000
true_flux_mean = 80000
true_flux_stdev = 20000

# Prior over catalogs used to generate the image.
TruePrior = StarPrior(
    max_objects=8,
    image_height=image_dim,
    image_width=image_dim,
    flux_mean=true_flux_mean,
    flux_stdev=true_flux_stdev,
)

# Image model built from the PSF standard deviation and background above.
TrueImageModel = ImageModel(
    image_height=image_dim,
    image_width=image_dim,
    psf_stdev=true_psf_stdev,
    background=true_background,
)

# Generate a single image together with the counts, locations, and fluxes
# that `generate` returns alongside it.
true_counts, true_locs, true_fluxes, images = TrueImageModel.generate(
    Prior=TruePrior,
    num_images=1,
)

In [None]:
# Print a summary of each generated image and plot it with the true source
# locations (red crosses) and white grid lines at quarters of the image.
for i in range(len(images)):
    print(f"image {i+1}",
          "\ncount\n", true_counts[i].round().item(),
          "\ntotal flux\n", true_fluxes[i].sum().item(),
          "\nloc\n", true_locs[i].cpu().numpy(), "\n\n")

    fig, ax = plt.subplots(1, 1)
    im = ax.imshow(images[i].cpu(), cmap='viridis')
    _ = fig.colorbar(im)

    # Index by image `i` rather than squeezing the whole batch, so the plot is
    # still correct when more than one image is generated.
    # NOTE(review): assumes true_locs is (num_images, max_objects, 2) with
    # (row, col) ordering in the last axis — TODO confirm against StarPrior.
    num_sources = int(true_counts[i])
    source_locs = true_locs[i, :num_sources].cpu()
    # Shift by half a pixel: imshow places pixel centers at integer coordinates.
    _ = ax.scatter(source_locs[:, 1] - 0.5,
                   source_locs[:, 0] - 0.5, marker='x', color='red')

    # Grid lines at 1/4, 2/4, and 3/4 of the image side, in both directions.
    grid_positions = [k * image_dim // 4 - 0.5 for k in (1, 2, 3)]
    _ = ax.vlines(x=grid_positions, ymin=0 - 0.5, ymax=image_dim - 0.5, color='white')
    _ = ax.hlines(y=grid_positions, xmin=0 - 0.5, xmax=image_dim - 0.5, color='white')

### Perform inference with SMC

We set the side length of the tiles on which we will run the sampler, and we specify a prior and image model at the tile level. We also specify a mutation kernel to be used within the SMC sampler.

We'll assume that the image background, PSF standard deviation, and flux prior parameters are all known.

In [None]:
# Side length of the tiles the sampler runs on; used as both the tile height
# and the tile width below.
tile_dim = 16

# Tile-level prior. It reuses the flux parameters that generated the image
# (they are treated as known) and additionally sets `pad`.
TilePrior = StarPrior(
    max_objects=8,
    image_height=tile_dim,
    image_width=tile_dim,
    flux_mean=true_flux_mean,
    flux_stdev=true_flux_stdev,
    pad=2,
)

# Tile-level image model, with the PSF standard deviation and background that
# generated the image (also treated as known).
TileImageModel = ImageModel(
    image_height=tile_dim,
    image_width=tile_dim,
    psf_stdev=true_psf_stdev,
    background=true_background,
)

# Mutation kernel that is handed to the SMC sampler.
MHKernel = MetropolisHastings(
    num_iters=100,
    locs_stdev=0.1,
    features_stdev=2500,
    features_min=20000,
)

Now we initialize an `SMCsampler` object and run it on the tiles.

In [None]:
# Build the sampler for the first (and only) generated image, using the
# tile-level prior, image model, and mutation kernel defined above.
smc = SMCsampler(
    image=images[0],
    tile_dim=tile_dim,
    Prior=TilePrior,
    ImageModel=TileImageModel,
    MutationKernel=MHKernel,
    num_catalogs_per_count=1000,
    max_smc_iters=500,
)

# Print the ground truth first so it can be compared with the sampler output.
print(f"True count: {true_counts[0]}")
print(f"True total flux: {true_fluxes[0].sum()}\n")

smc.run(print_progress=True)

Now we instantiate an `Aggregate` object with the tile-level results from above:

In [None]:
# Hand the sampler's prior, image model, tiled image, and tile-level results
# (counts, locations, features, weights) to the aggregation step.
agg = Aggregate(
    smc.Prior,
    smc.ImageModel,
    smc.tiled_image,
    smc.counts,
    smc.locs,
    smc.features,
    smc.weights_intercount,
)

We run the aggregation procedure to obtain image-level catalogs:

In [None]:
# Run the aggregation procedure on the tile-level results given to `Aggregate`.
agg.run()

We compute the posterior mean number of light sources:

In [None]:
# Bare last expression: the notebook displays the posterior mean count.
agg.posterior_mean_counts

And we compute the posterior mean total flux:

In [None]:
# Bare last expression: the notebook displays the posterior mean total flux.
agg.posterior_mean_total_flux

Finally, we reconstruct the image using one of the catalogs we just sampled:

In [None]:
# Pick one sampled catalog uniformly at random. torch.randint's `high` bound
# is exclusive, so the number of catalogs itself is the correct upper bound:
# subtracting one would never select the last catalog and would raise when
# only a single catalog exists.
num_catalogs = agg.counts.shape[-1]
index = torch.randint(low=0, high=num_catalogs, size=[1])

# Rebuild the Poisson rate from the chosen catalog (PSF times features, summed
# over sources, plus the background) and draw one reconstructed image from it.
psf = agg.ImageModel.psf(agg.locs[:, :, index])
rate = (psf * agg.features[:, :, index].unsqueeze(-3).unsqueeze(-4)).sum(-1) + agg.ImageModel.background
reconstruction = torch.distributions.Poisson(rate).sample().squeeze([0, 1]).permute((2, 0, 1))

fig, ax = plt.subplots(1, 1)
im = ax.imshow(reconstruction[0].cpu())
_ = fig.colorbar(im)

# True source locations in red. The squeeze-based indexing relies on exactly
# one image having been generated.
true_source_locs = true_locs.squeeze()[:true_counts.squeeze()].cpu()
_ = ax.scatter(true_source_locs[:, 1] - 0.5,
               true_source_locs[:, 0] - 0.5, marker='x', color='red')

# Locations from the sampled catalog in blue. `index` is a one-element tensor,
# so the indexed result keeps a leading axis of length 1.
sampled_count = agg.counts.squeeze()[index]
sampled_locs = agg.locs.squeeze()[index, :sampled_count].cpu()
_ = ax.scatter(sampled_locs[..., 1] - 0.5,
               sampled_locs[..., 0] - 0.5, marker='x', color='blue')

In [None]:
# Display the true fluxes returned by `TrueImageModel.generate`.
true_fluxes

In [None]:
# Display the locations of the catalog selected by `index`.
agg.locs.squeeze()[index]

In [None]:
# Display the features of the catalog selected by `index`.
agg.features.squeeze()[index]