In [1]:
import numpy as np

In [2]:
import torch
import pytest

from celeste import device, use_cuda
from celeste import train
from celeste.datasets import simulated_datasets
from celeste.models import sourcenet

In [3]:
# Reproducibility: every mock tensor in this notebook is drawn from the numpy
# or torch global RNG, so both are seeded here, before the first draw, to make
# Restart Kernel -> Run All produce the same numbers each time.
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# data parameters
n_bands = 2  # size of the last axis of the log-flux tensors
slen = 50  # presumably the image side length; not used in the cells shown -- TODO confirm
n_galaxy_params = 8  # size of the last axis of the galaxy-parameter tensors

ptile_slen = 8  # presumably the padded-tile side length; not used in the cells shown -- TODO confirm

# number of detection slots per padded tile; true source counts are drawn
# from {0, ..., max_detections}
max_detections = 4

# Mock true parameters

In [4]:
# number of padded tiles; this is the leading (batch) axis of every mock
# tensor created below
n_ptiles = 100

In [5]:
# True number of sources per padded tile: a length-n_ptiles vector whose
# entries are drawn uniformly from {0, ..., max_detections}.
true_n_sources = torch.from_numpy(
    np.random.choice(max_detections + 1, size=n_ptiles)
)

# On/off array marking which detection slots of each tile hold a source.
true_is_on_array = simulated_datasets.get_is_on_from_n_sources(
    true_n_sources,
    max_detections,
)

In [6]:
# True (mock) parameters.
# Slots whose source is off must hold zeros, so every draw is multiplied by
# true_is_on_array with a trailing axis added for broadcasting.

# locations lie in [0, 1)
true_locs = (
    torch.rand(n_ptiles, max_detections, 2) * true_is_on_array[:, :, None]
)

# log fluxes and galaxy parameters are unconstrained reals
true_log_fluxes = (
    torch.randn(n_ptiles, max_detections, n_bands)
    * true_is_on_array[:, :, None]
)
true_galaxy_params = (
    torch.randn(n_ptiles, max_detections, n_galaxy_params)
    * true_is_on_array[:, :, None]
)

In [7]:
# Float indicator (1.0 = galaxy, 0.0 = not) for every detection slot, set by
# thresholding a uniform draw at 0.5.
# NOTE(review): unlike the other true parameters this is not masked by
# true_is_on_array -- confirm that is intended.
true_galaxy_bool = torch.rand(n_ptiles, max_detections).gt(0.5).float()

# Mock estimated sources

In [8]:
# Mock estimated means / log-variances, shaped like the corresponding true
# tensors. Slots that are off should return zeros, hence the multiplication
# by true_is_on_array (with a trailing broadcast axis).
loc_mean = torch.rand(true_locs.size()) * true_is_on_array[:, :, None]
loc_logvar = torch.randn(true_locs.size()) * true_is_on_array[:, :, None]

log_flux_mean = (
    torch.randn(true_log_fluxes.size()) * true_is_on_array[:, :, None]
)
log_flux_logvar = (
    torch.randn(true_log_fluxes.size()) * true_is_on_array[:, :, None]
)

galaxy_params_mean = (
    torch.randn(true_galaxy_params.size()) * true_is_on_array[:, :, None]
)
galaxy_params_logvar = (
    torch.randn(true_galaxy_params.size()) * true_is_on_array[:, :, None]
)

# for each detection slot, the probability (a value in [0, 1), not a
# log-probability) that it is a galaxy
prob_galaxy = torch.rand(n_ptiles, max_detections)

# mock log-probabilities over the number of sources (0, ..., max_detections);
# these are logs of independent uniform draws and are not normalised
n_source_log_probs = torch.rand(n_ptiles, max_detections + 1).log()

In [9]:
import celeste.sleep as sleep

In [10]:
# Run the same helper on each (truth, estimated mean, estimated log-variance)
# triple, in the order locations -> log fluxes -> galaxy parameters.
# NOTE(review): "all combs" presumably means every pairing of true and
# estimated sources within a tile -- confirm in celeste.sleep.
(
    locs_log_probs_all,
    star_params_log_probs_all,
    galaxy_params_log_probs_all,
) = (
    sleep._get_params_logprob_all_combs(truth, mean, logvar)
    for truth, mean, logvar in (
        (true_locs, loc_mean, loc_logvar),
        (true_log_fluxes, log_flux_mean, log_flux_logvar),
        (true_galaxy_params, galaxy_params_mean, galaxy_params_logvar),
    )
)

In [11]:
# Turn the all-combination log-probabilities into four per-permutation
# tensors (locations, star parameters, galaxy parameters, galaxy indicator),
# unpacked in the order the helper returns them.
(
    locs_loss_all_perm,
    star_params_loss_all_perm,
    galaxy_params_loss_all_perm,
    galaxy_bool_loss_all_perm,
) = sleep._get_log_probs_all_perms(
    locs_log_probs_all,
    star_params_log_probs_all,
    galaxy_params_log_probs_all,
    prob_galaxy,
    true_is_on_array,
    true_galaxy_bool,
)

In [12]:
# shape check: one galaxy probability per (padded tile, detection slot),
# i.e. (n_ptiles, max_detections)
prob_galaxy.shape

torch.Size([100, 4])

In [13]:
# shape check: one galaxy indicator per (padded tile, detection slot),
# i.e. (n_ptiles, max_detections)
true_galaxy_bool.shape

torch.Size([100, 4])

In [14]:
# shape check: first axis is n_ptiles.
# NOTE(review): the second axis (24) is presumably the 4! orderings of the
# max_detections = 4 slots -- confirm in celeste.sleep.
galaxy_bool_loss_all_perm.shape

torch.Size([100, 24])

In [15]:
# shape check for the per-permutation location term; first axis is n_ptiles
locs_loss_all_perm.shape

torch.Size([100, 24])

In [16]:
# shape check for the per-permutation star-parameter (log-flux) term; first
# axis is n_ptiles
star_params_loss_all_perm.shape

torch.Size([100, 24])

In [17]:
# Evaluate the parameter loss of the mock estimates against the mock truth.
# The call is left as a bare expression so the returned tuple is displayed.
# NOTE(review): in the recorded output the first element is a scalar tensor
# and the second has n_ptiles entries; what each element means should be
# confirmed against celeste.sleep._get_params_loss.
sleep._get_params_loss(
    loc_mean,
    loc_logvar,
    log_flux_mean,
    log_flux_logvar,
    galaxy_params_mean,
    galaxy_params_logvar,
    prob_galaxy,
    n_source_log_probs,
    true_locs,
    true_log_fluxes,
    true_galaxy_params,
    true_is_on_array,
    true_galaxy_bool,
)

(tensor(12.3114),
 tensor([5.6985, 2.3032, 0.8383, 0.0712, 0.7623, 1.5677, 0.5747, 0.2081, 0.9482,
         0.6053, 2.5197, 1.8836, 1.6294, 0.4612, 0.8105, 0.9680, 0.1845, 1.2788,
         1.7605, 0.1172, 0.1371, 0.2990, 0.4544, 1.4849, 0.0694, 0.1654, 0.6528,
         0.2866, 0.9981, 0.1647, 0.1124, 0.3828, 0.6232, 0.0159, 0.6374, 1.9130,
         0.4327, 0.0778, 1.0626, 0.9690, 0.1888, 0.0383, 0.8296, 0.1526, 0.5219,
         0.6841, 1.4540, 0.2348, 1.2847, 0.3865, 0.8628, 0.1956, 1.4547, 1.3304,
         0.8966, 0.1362, 1.7701, 0.0438, 0.1197, 1.0410, 2.8199, 1.0190, 2.4812,
         0.2110, 0.3077, 1.5948, 1.8083, 1.0171, 0.1898, 1.2035, 3.3004, 1.6085,
         0.8116, 0.1355, 0.4099, 1.2152, 0.0538, 0.0144, 0.0116, 2.4922, 0.8345,
         0.4131, 0.8482, 0.0702, 2.5615, 0.6038, 0.1136, 0.9831, 1.2291, 0.4431,
         2.0209, 0.1337, 0.6964, 0.4840, 0.2178, 1.4894, 0.2808, 1.3756, 0.7579,
         1.2536]),
 tensor([ -5.8630, -17.7511, -24.3982,  -4.7383,  -3.0827, -17.0694, -18

In [18]:
from torch.nn import functional

In [19]:
# Recompute the per-tile source count by summing the on/off array over the
# detection axis. This rebinds true_n_sources from the earlier simulation cell.
true_n_sources = true_is_on_array.sum(dim=1)

# One-hot encode the counts, with one class per column of n_source_log_probs.
one_hot_encoding = functional.one_hot(
    true_n_sources,
    num_classes=n_source_log_probs.size(1),
)

In [20]:
# display the recomputed per-tile source counts (one entry per padded tile)
true_n_sources

tensor([3, 4, 3, 4, 3, 4, 4, 2, 2, 0, 3, 2, 1, 1, 0, 0, 2, 1, 1, 3, 1, 1, 4, 2,
        2, 3, 2, 4, 3, 1, 3, 4, 3, 0, 3, 4, 3, 2, 1, 1, 0, 2, 3, 4, 3, 0, 1, 0,
        0, 0, 3, 4, 4, 4, 4, 3, 4, 2, 2, 0, 4, 0, 1, 0, 1, 3, 1, 2, 0, 1, 4, 1,
        1, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 0, 1, 3, 4, 3, 4, 4, 4, 4, 4, 4, 2, 4,
        0, 2, 2, 2])