In [None]:
import sys
sys.path.append("../")

from smc.sampler import SMCsampler
from smc.prior import StarPrior
from smc.images import ImageModel
from smc.kernel import MetropolisHastings

import torch
# Preferred GPU index on the original multi-GPU machine; adjust for your hardware.
GPU_INDEX = 5
# Global seed so repeated runs draw the same random numbers from torch's RNG
# (GPU kernels may still introduce small nondeterminism).
SEED = 0

if torch.cuda.is_available():
    # Fall back to the first GPU when the preferred index does not exist.
    gpu_index = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device("cuda", gpu_index)
    # set_device only accepts CUDA devices (it raises ValueError for "cpu"),
    # so it must stay inside this branch.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Tensors created without an explicit device now land on `device`.
torch.set_default_device(device)
torch.manual_seed(SEED)

import matplotlib.pyplot as plt

### Generate images

We specify a prior over star catalogs and an image model, and use them to simulate one 32 × 32 image.

In [None]:
# Ground-truth generative model on the full 32 x 32 image: a prior over star
# catalogs (up to max_objects = 40 stars) and an image model with the given
# PSF width and background level.
TruePrior = StarPrior(
    max_objects=40,
    image_height=32,
    image_width=32,
    min_flux=10000,
)

TrueImageModel = ImageModel(
    image_height=32,
    image_width=32,
    psf_stdev=0.75,
    background=20000,
)

# Simulate a single image; the true counts, locations and fluxes of the drawn
# catalog are returned alongside it for later comparison.
(true_counts,
 true_locs,
 true_fluxes,
 images) = TrueImageModel.generate(Prior=TruePrior, num_images=1)

In [None]:
# Show every simulated image together with its ground-truth catalog.
# Iterating over images.shape[0] (not a hard-coded range) keeps this cell
# correct if num_images above is changed.
for i in range(images.shape[0]):
    print(f"image {i + 1}")
    print("count:", true_counts[i].round().item())
    print("total flux:", true_fluxes[i].sum().item())
    print("locs:\n", true_locs[i].cpu().numpy(), "\n")

    fig, ax = plt.subplots(1, 1)
    ax.imshow(images[i].cpu())
    ax.set_title(f"Simulated image {i + 1}")
    # Render now so each figure appears directly below its printed catalog.
    plt.show()

### Perform inference with SMC

We set the side length of the tiles on which we will run the sampler, and we specify a prior and image model at the tile level. We also specify a mutation kernel to be used within the SMC sampler.

In [None]:
# Side length of the square tiles the sampler operates on (same units as
# image_height / image_width).
tile_dim = 4

# Tile-level prior with max_objects = 3.
# NOTE(review): `pad = 2` presumably lets star locations extend slightly beyond
# the tile edges — confirm in StarPrior.
TilePrior = StarPrior(
    max_objects=3,
    image_height=tile_dim,
    image_width=tile_dim,
    min_flux=10000,
    pad=2,
)

# Tile-level image model; psf_stdev and background are the same values used
# for TrueImageModel.
TileImageModel = ImageModel(
    image_height=tile_dim,
    image_width=tile_dim,
    psf_stdev=0.75,
    background=20000,
)

# Metropolis-Hastings kernel used as the SMC mutation step, with separate
# proposal scales for locations and features.
MHKernel = MetropolisHastings(
    num_iters=100,
    locs_stdev=0.1,
    features_stdev=2000,
)

Now we initialize an `SMCsampler` object and run it on the tiles.

In [None]:
# Build the sampler on the first simulated image, split into
# tile_dim x tile_dim tiles, using the tile-level prior, image model and
# mutation kernel defined above.
smc = SMCsampler(image=images[0],
                 tile_dim=tile_dim,
                 Prior=TilePrior,
                 ImageModel=TileImageModel,
                 MutationKernel=MHKernel,
                 num_catalogs_per_count=1000,
                 max_smc_iters=500)

# Print plain numbers (not tensor reprs carrying device info), formatted the
# same way as in the image display cell, for comparison with the sampler output.
print(f"True count: {true_counts[0].round().item()}")
print(f"True total flux: {true_fluxes[0].sum().item()}\n")

smc.run(print_progress=True)
smc.summarize()

### Reshaping tensors

Scratch work on reshaping the sampler's tiled tensors so that neighboring tiles can be combined.

In [None]:
from einops import rearrange, reduce

In [None]:
# NOTE(review): commented-out scratch work; nothing in this cell executes.
# It sketches a multi-level version of the reshaping: k times, unfold 2 x 2
# blocks along the first two axes of each tensor, then move those block axes
# to the end with einops. The live, single-level merge is done in the cells
# under "resample" and "merge". Consider deleting this cell or moving the
# idea into a module.
# NOTE(review): `tile_dim = data.shape[0]` may be the number of tiles per side
# rather than the tile side length, depending on smc.tiled_image's layout —
# confirm before reviving this code.
# data = smc.tiled_image
# counts = smc.counts
# locs = smc.locs
# features = smc.features
# weights = smc.weights_intercount

# tile_dim = data.shape[0]

# k = torch.tensor(tile_dim).log2().int()
# num_aggregation_levels = 2 * k

# for _ in range(k):
#     data = data.unfold(0, 2, 2).unfold(1, 2, 2).squeeze([0,1])
#     counts = counts.unfold(0, 2, 2).unfold(1, 2, 2).squeeze([0,1])
#     locs = locs.unfold(0, 2, 2).unfold(1, 2, 2).squeeze([0,1])
#     features = features.unfold(0, 2, 2).unfold(1, 2, 2).squeeze([0,1])
#     weights = weights.unfold(0, 2, 2).unfold(1, 2, 2).squeeze([0,1])

# data = rearrange(data, 'h w ... -> ... h w')
# counts = rearrange(counts, 'n ... -> ... n')
# locs = rearrange(locs, 'n m l ... -> ... n m l')
# features = rearrange(features, 'n m ... -> ... n m')
# weights = rearrange(weights, 'n ... -> ... n')

In [None]:
# Scratch notes (not executed): einops patterns that fold a leading tile axis
# `t` into the image's height axis (first pattern) or width axis (second).
# # this is how to collapse the image along the height axis
# rearrange(data, 't ... h w -> ... (t h) w')
# # ... and along the width axis
# rearrange(data, 't ... h w -> ... h (t w)')

In [None]:
# Pull the sampler's per-tile results into notebook variables.
# counts / locs / features / weights are cloned because later cells modify
# them; bare references would silently modify `smc`'s own state as well.
# The tiled image is only read, so a reference suffices.
data = smc.tiled_image
counts = smc.counts.clone()
locs = smc.locs.clone()
features = smc.features.clone()
weights = smc.weights_intercount.clone()

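### Resample

Within each tile, resample catalogs in proportion to their weights, then reset the weights to uniform.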

In [None]:
# Multinomial resampling within each tile: draw smc.num_catalogs catalog
# indices per tile, with probability proportional to that tile's weights.
# Then replace each tile's catalogs by the drawn ones and reset the weights to
# uniform.
#
# Reads from `smc` directly and binds fresh tensors, so re-running this cell
# restarts from the sampler's output instead of resampling already-resampled
# catalogs, and the sampler's internal state is never mutated.
num_tiles_h, num_tiles_w = smc.weights_intercount.shape[:2]

resampled_index = (
    smc.weights_intercount.flatten(0, 1)  # one row of weights per tile
    .multinomial(smc.num_catalogs, replacement=True)
    # NOTE(review): multinomial already returns indices below the row length;
    # this clamp only matters if the rows are longer than smc.num_catalogs.
    .clamp(min=0, max=smc.num_catalogs - 1)
    # Restore the tile grid so resampled_index[h, w] is the full vector of
    # drawn catalog indices for tile (h, w), not a single scalar.
    .unflatten(0, (num_tiles_h, num_tiles_w))
)

# Broadcastable tile coordinates: indexing with (tile_h, tile_w, resampled_index)
# selects, for every tile (h, w), the catalogs listed in resampled_index[h, w].
tile_h = torch.arange(num_tiles_h, device=resampled_index.device).view(-1, 1, 1)
tile_w = torch.arange(num_tiles_w, device=resampled_index.device).view(1, -1, 1)

counts = smc.counts[tile_h, tile_w, resampled_index]
locs = smc.locs[tile_h, tile_w, resampled_index]
features = smc.features[tile_h, tile_w, resampled_index]
weights = torch.full_like(smc.weights_intercount, 1 / smc.num_catalogs)

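### Merge

Combine vertically adjacent pairs of tiles: stack their pixels, add their counts, and concatenate their star locations and features.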

In [None]:
# Stack each pair of tiles along dim 0 (presumably the tile-row axis) into one
# tile of twice the height. unfold puts the pair on a trailing axis `t`, and
# the rearrange places tile t = 0 above tile t = 1.
new_data = rearrange(
    data.unfold(0, 2, 2),
    'nh nw h w t -> nh nw (t h) w',
)

In [None]:
# The star count of a merged tile is the sum over the two tiles it combines.
paired_counts = counts.unfold(0, 2, 2)  # the pair lives on the new axis, dim 3
new_counts = paired_counts.sum(dim=3)

In [None]:
# Concatenate the star slots of each tile pair (tile t = 0's slots first, then
# t = 1's), then move occupied slots to the front.
new_locs = rearrange(locs.unfold(0, 2, 2), 'nh nw N M l t -> nh nw N (t M) l')
# NOTE(review): no offset is added to the t = 1 tile's locations. If locs are
# tile-local coordinates, they must be shifted by tile_dim along the height
# axis to land correctly in the merged tile — confirm SMCsampler's convention.

# A slot counts as occupied if any of its coordinates is nonzero (empty slots
# are assumed to be all zeros). Using one mask per star, rather than per
# coordinate, keeps each star's coordinates together.
locs_mask = (new_locs != 0).any(dim=-1).int()
# stable=True keeps occupied stars in their original slot order. The default
# unstable sort may permute them arbitrarily, which would break the
# correspondence between a star's location and its features.
locs_index = torch.sort(locs_mask, dim=3, descending=True, stable=True).indices
new_locs = torch.gather(new_locs, dim=3,
                        index=locs_index.unsqueeze(-1).expand_as(new_locs))

In [None]:
# Concatenate the feature slots of each tile pair (tile t = 0's slots first,
# then t = 1's), then move occupied slots (nonzero features) to the front.
new_features = rearrange(features.unfold(0, 2, 2), 'nh nw N M t -> nh nw N (t M)')
features_mask = (new_features != 0).int()
# stable=True keeps occupied stars in their original slot order. The default
# unstable sort may permute them arbitrarily, which would break the
# correspondence between a star's features and its location.
features_index = torch.sort(features_mask, dim=3, descending=True, stable=True).indices
new_features = torch.gather(new_features, dim=3, index=features_index)

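### Log-likelihood

Evaluate the merged catalogs under an image model sized to a merged (2 × `tile_dim`) × `tile_dim` tile.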

In [None]:
# Image model for a merged tile: two tiles stacked vertically form a
# (2 * tile_dim) x tile_dim tile. Deriving the size from tile_dim keeps it in
# sync with the merge above. PSF and background match TileImageModel.
MergedImageModel = ImageModel(image_height=2 * tile_dim,
                              image_width=tile_dim,
                              psf_stdev=0.75,
                              background=20000)
MergedImageModel.loglikelihood(new_data, new_locs, new_features).shape

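### Prior

Evaluate the prior log-probability of the merged catalogs; a merged tile may hold up to twice as many stars as a single tile.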

In [None]:
# Prior for a merged tile of size (2 * tile_dim) x tile_dim. It allows up to
# twice TilePrior's max_objects = 3, since it holds the stars of two tiles.
MergedPrior = StarPrior(max_objects=2 * 3,
                        image_height=2 * tile_dim,
                        image_width=tile_dim,
                        min_flux=10000,
                        pad=2)
MergedPrior.log_prob(new_counts, new_locs, new_features).shape