In [None]:
def _save_dataset(ds, train_path, val_path, test_path, n_samples: int, overwrite=False, njobs=1):
    assert n_samples % 3 == 0
    tpath, vpath, ttpath = Path(train_path), Path(val_path), Path(test_path)

    if tpath.exists() or vpath.exists() or ttpath.exists():
        if not overwrite:
            raise ValueError("Overwrite turned on, but files exists.")

    results = Parallel(n_jobs=njobs)(delayed(_task)(ds, ii) for ii in tqdm(range(n_samples)))
    output = torch.cat(results)
    assert output.shape[0] == n_samples

In [None]:
from bliss.datasets.blends import GalsimBlends
from bliss.catalog import TileCatalog
from bliss.models.detection_encoder import DetectionEncoder
import matplotlib.pyplot as plt 

In [None]:
def _setup_blend_galaxy_generator(
    catalog_file: str,
    density: float,
    max_number: float,
    slen: int,
    bp: int,
    seed: int,
    max_mag: float = 27.3,
):
    """Assemble a ``btk`` Catsim blend generator backed by density sampling.

    Args:
        catalog_file: File passed to ``btk.catalog.CatsimCatalog.from_file``.
        density: Forwarded to ``DensitySampling`` as ``density``.
        max_number: Forwarded to ``DensitySampling`` as ``max_number``
            (the minimum is fixed at 0).
        slen: Side length used for the maximum shift and, together with ``bp``,
            for the stamp size.
        bp: Padding added on both sides of ``slen`` when sizing the stamp.
        seed: Seed shared by the sampling function and the generator.
        max_mag: Forwarded to ``DensitySampling`` as ``max_mag`` (applied to ``"i_ab"``).

    Returns:
        The ``btk.draw_blends.CatsimGenerator`` configured for the "LSST" survey,
        with batch size 1, a single job, and no added noise.
    """
    catsim_catalog = btk.catalog.CatsimCatalog.from_file(catalog_file)

    # Padded stamp side and shift limit, both scaled by PIXEL_SCALE.
    padded_side = (slen + 2 * bp) * PIXEL_SCALE  # arcsecs
    shift_limit = slen * PIXEL_SCALE / 2  # in arcseconds

    density_sampler = btk.sampling_functions.DensitySampling(
        max_number=max_number,
        min_number=0,
        density=density,
        stamp_size=padded_side,
        max_shift=shift_limit,
        seed=seed,
        max_mag=max_mag,
        mag_name="i_ab",
    )

    lsst_survey = btk.survey.get_surveys("LSST")

    blend_generator = btk.draw_blends.CatsimGenerator(
        catsim_catalog,
        density_sampler,
        lsst_survey,
        batch_size=1,  # batching is handled by the torch dataset
        stamp_size=padded_side,
        njobs=1,
        add_noise="none",  # noise and background are added later
        seed=seed,  # same seed as the sampling function
    )
    return blend_generator


In [None]:
# TODO: switch to a simpler generator that produces at most 1 source
# (star or galaxy) on only 1 tile, supplied via the `generator_setup` argument.

ds = GalsimBlends(
    catalog_file='../../../data/OneDegSq.fits',
    stars_file='../../../data/stars_med_june2018.fits',
    tile_slen=4,
    max_sources_per_tile=1,
    bp=24,
    slen=40,
    seed=0,
    galaxy_density=100,
    star_density=100,
    generator_setup=...,  # Ellipsis placeholder until the simpler generator is written
)

ds.max_n_galaxies, ds.max_n_stars

(4, 4)

In [None]:
# save data from dataset

In [None]:
# instantiate model
"""
        _target_: bliss.models.detection_encoder.DetectionEncoder
        input_transform:
            _target_: bliss.models.detection_encoder.ConcatBackgroundTransform
        n_bands: 1
        tile_slen: 4
        ptile_slen: 52
        max_detections: 1
        channel: 8
        spatial_dropout: 0.0
        dropout: 0.0
        hidden: 128
        annotate_probs: True
        slack: 1.0
        optimizer_params:
            lr: 1e-4
"""
from bliss.models.encoder_layers import (
    ConcatBackgroundTransform,
    EncoderCNN,
    LogBackgroundTransform,
    make_enc_final,
)
# TODO: change the model so that it only uses 1 tile
detection_encoder = DetectionEncoder(
    input_transform=ConcatBackgroundTransform(),
    n_bands=1,
    tile_slen=4,
    ptile_slen=52,
    max_detections=1,
    channel=8,
    spatial_dropout=0.0,
    dropout=0.0,
    hidden=128,
    annotate_probs=True,
    slack=1.0,
    optimizer_params={'lr':1e-4},
)

In [None]:
# create datamodule to train from saved data files

In [None]:
# train on only the small batch

In [None]:
# look at output encoded parameters and compare with true parameters of the tile

# do they seem sensible? 

