In [None]:
import sys
sys.path.append("../")

from smc.sampler import SMCsampler
from smc.prior import StarPrior
from smc.images import ImageModel
from smc.kernel import MetropolisHastings
from smc.aggregate import Aggregate

import torch
# Compute device. GPU_INDEX picks the GPU on a multi-GPU host and must be
# < torch.cuda.device_count() (use 0 on a single-GPU machine).
GPU_INDEX = 7
if torch.cuda.is_available():
    device = torch.device(f"cuda:{GPU_INDEX}")
    # torch.cuda.set_device only accepts CUDA devices (it raises ValueError for
    # a CPU device), so it must only be called on this branch.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
# Tensors created without an explicit device are allocated on `device`.
torch.set_default_device(device)

# Seed torch's RNGs (CPU and all CUDA devices) so Restart & Run All reproduces
# the same run. NOTE(review): assumes the smc package draws its randomness from
# torch's global RNG — confirm.
SEED = 0
torch.manual_seed(SEED)

import matplotlib.pyplot as plt

### Generate images

We specify a prior and an image model and generate one image.

In [None]:
# Ground-truth settings for the simulation. These names are reused further
# down when building the tile-level prior and image model.
image_dim = 16  # simulated image is image_dim x image_dim pixels
true_psf_stdev = 3
true_background = 10
true_flux_scale = 10
true_flux_shape = 1.75

# Full-image prior built from the true flux parameters above.
TruePrior = StarPrior(
    max_objects=10,
    image_height=image_dim,
    image_width=image_dim,
    flux_scale=true_flux_scale,
    flux_shape=true_flux_shape,
)

# Full-image model with the true PSF width and background level.
TrueImageModel = ImageModel(
    image_height=image_dim,
    image_width=image_dim,
    psf_stdev=true_psf_stdev,
    background=true_background,
)

# Simulate a single image; generate() hands back the true counts, locations
# and fluxes alongside the rendered images.
true_counts, true_locs, true_fluxes, images = TrueImageModel.generate(
    Prior=TruePrior,
    num_images=1,
)

In [None]:
# Print each simulated image's ground-truth catalog and display the image.
# Looping over len(images) (instead of a hard-coded 1) keeps this cell in sync
# with num_images in the generation cell.
for i in range(len(images)):
    print(f"image {i+1}",
          "\ncount\n", true_counts[i].round().item(),
          "\ntotal flux\n", true_fluxes[i].sum().item(),
          "\nloc\n", true_locs[i].cpu().numpy(), "\n\n")
    fig, ax = plt.subplots(1, 1)
    im = ax.imshow(images[i].cpu())
    ax.set_title(f"Simulated image {i+1}")
    ax.set_xlabel("column (pixel)")
    ax.set_ylabel("row (pixel)")
    fig.colorbar(im, ax=ax, label="pixel value")
plt.show()

### Perform inference with SMC

We set the side length of the tiles on which we will run the sampler, and we specify a prior and image model at the tile level. We also specify a mutation kernel to be used within the SMC sampler.

We assume that the image background, PSF standard deviation, and flux prior parameters are all known, so the tile-level prior and image model below simply reuse their true values.

In [None]:
# Side length (in pixels) of the square tiles the sampler runs on.
tile_dim = 8

# Tile-level prior; flux prior parameters are taken as known (true values).
# NOTE(review): `pad` is interpreted inside smc.prior — presumably lets stars
# lie slightly outside the tile; confirm.
TilePrior = StarPrior(
    max_objects=4,
    image_height=tile_dim,
    image_width=tile_dim,
    flux_scale=true_flux_scale,
    flux_shape=true_flux_shape,
    pad=4,
)

# Tile-level image model; PSF width and background are taken as known.
TileImageModel = ImageModel(
    image_height=tile_dim,
    image_width=tile_dim,
    psf_stdev=true_psf_stdev,
    background=true_background,
)

# Metropolis-Hastings mutation kernel used inside the SMC sampler.
# locs_stdev / features_stdev are presumably proposal step sizes — confirm.
MHKernel = MetropolisHastings(
    num_iters=200,
    locs_stdev=0.1,
    features_stdev=25,
)

Now we initialize an `SMCsampler` object and run it on the tiles.

In [None]:
# Build the tile-level SMC sampler for the (single) simulated image.
smc = SMCsampler(
    image=images[0],
    tile_dim=tile_dim,
    Prior=TilePrior,
    ImageModel=TileImageModel,
    MutationKernel=MHKernel,
    num_catalogs_per_count=500,
    max_smc_iters=500,
)

# Ground truth for comparison with the posterior summaries. `.item()` prints
# plain numbers (matching the earlier summary cell) rather than tensor reprs
# carrying device annotations.
print(f"True count: {true_counts[0].item()}")
print(f"True total flux: {true_fluxes[0].sum().item()}\n")

smc.run(print_progress=True)
smc.summarize()

Now we instantiate an `Aggregate` object with the tile-level results from above:

In [None]:
# Hand the tile-level SMC output (prior, image model, tiled image, sampled
# counts/locations/features and their weights) to the aggregation step.
agg = Aggregate(
    smc.Prior,
    smc.ImageModel,
    smc.tiled_image,
    smc.counts,
    smc.locs,
    smc.features,
    smc.weights_intercount,
)

And we run the aggregation procedure to obtain image-level catalogs:

In [None]:
# Merge the tile-level catalogs into image-level catalogs; the results are
# read afterwards from agg.weights, agg.counts and agg.features.
agg.run()

We compute the posterior mean number of light sources:

In [None]:
# Posterior mean count: weighted sum of the image-level catalog counts.
weighted_counts = agg.weights * agg.counts
weighted_counts.squeeze().sum().round(decimals=2).item()

And we compute the posterior mean total flux:

In [None]:
# Posterior mean total flux: weighted sum over catalogs of each catalog's
# features summed along axis 3 (presumably the per-star axis — confirm).
# Rounded to 2 decimals to match the posterior mean count cell instead of
# displaying float32 noise.
(agg.weights * agg.features.sum(3)).squeeze().sum().round(decimals=2).item()