# BLISS AstroPy Affiliate Package Tutorial

## Introduction

Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources.

## Installation

In [1]:
%env BLISS_HOME=/home/zhteoh/730-astropy-integration

env: BLISS_HOME=/home/zhteoh/730-astropy-integration


In [2]:
!pip install -e $BLISS_HOME

Obtaining file:///home/zhteoh/730-astropy-integration
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: bliss-deblender
  Building editable for bliss-deblender (pyproject.toml) ... done
  Created wheel for bliss-deblender: filename=bliss_deblender-0.1.2-py3-none-any.whl size=4250 sha256=718c133f465dfd758fb1563ff1b26c8f6afc4497e997a35992f5c1ae00d43289
  Stored in directory: /tmp/pip-ephem-wheel-cache-x5uyc7s4/wheels/6e/ff/3a/b3c8dc9273fa82437c76778098de2f5c5c660d410e2f8504bc
Successfully built bliss-deblender
Installing collected packages: bliss-deblender
  Attempting uninstall: bliss-deblender
    Found existing installation: bliss-deblender 0.1.2
    Uninstalling bliss-deblender-0.1.2:
      Successfully uninstalled bliss-deblender-0.1.2
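
A quick way to confirm that the editable install succeeded is to query the installed distribution's version. The sketch below uses the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical, and the distribution name `bliss-deblender` is taken from the pip output above.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Prints the version string if the package is installed, otherwise None.
print(installed_version("bliss-deblender"))
```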

# Tutorial

## Train the model

### Generate synthetic image data

In [3]:
from bliss.api import generate

In [4]:
generate(
    n_batches=3, 
    batch_size=64, 
    max_images_per_file=128, 
    cached_data_path="/data/scratch/zhteoh/730-tutorial/dataset"
)

Data will be saved to /data/scratch/zhteoh/730-tutorial/dataset


Simulating images in batches for file: 100%|██████████| 2/2 [00:56<00:00, 28.47s/it]
Simulating images in batches for file: 100%|██████████| 2/2 [00:55<00:00, 27.63s/it]
Generating and writing cached dataset files: 100%|██████████| 2/2 [01:52<00:00, 56.17s/it]
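
The progress bars above report two files of two batches each. That is consistent with each file holding `max_images_per_file // batch_size` batches and the number of files being rounded up to cover `n_batches`. This is an inference from the log, not a documented rule of BLISS; the helper `expected_num_files` is hypothetical.

```python
import math


def expected_num_files(n_batches: int, batch_size: int, max_images_per_file: int) -> int:
    """Estimate how many cached files are written (assumed: whole batches per file)."""
    batches_per_file = max_images_per_file // batch_size
    return math.ceil(n_batches / batches_per_file)


print(expected_num_files(n_batches=3, batch_size=64, max_images_per_file=128))  # 2
```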


#### Pass additional custom configuration parameters

In [5]:
generate(
    n_batches=3,  # required
    batch_size=64,  # required
    max_images_per_file=128,  # required
    cached_data_path="/data/scratch/zhteoh/730-tutorial/dataset",  # required
    simulator={"prior": {"mean_sources": 0.02}},  # optional
    generate={"file_prefix": "dataset"},  # optional
)

Data will be saved to /data/scratch/zhteoh/730-tutorial/dataset


Simulating images in batches for file: 100%|██████████| 2/2 [00:14<00:00,  7.25s/it]
Simulating images in batches for file: 100%|██████████| 2/2 [00:14<00:00,  7.39s/it]
Generating and writing cached dataset files: 100%|██████████| 2/2 [00:29<00:00, 14.70s/it]
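
The optional keyword arguments are nested dictionaries that override individual configuration entries. How BLISS merges them internally is not shown here; the sketch below illustrates the general idea with a hypothetical `deep_merge` helper and made-up default values.

```python
from copy import deepcopy


def deep_merge(defaults: dict, overrides: dict) -> dict:
    """Return a copy of `defaults` with nested `overrides` applied on top."""
    merged = deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Recurse so that sibling keys in the nested dict are preserved.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Hypothetical defaults; the real BLISS configuration has other entries and values.
defaults = {"simulator": {"prior": {"mean_sources": 0.2, "max_sources": 5}}}
overrides = {"simulator": {"prior": {"mean_sources": 0.02}}}

config = deep_merge(defaults, overrides)
print(config["simulator"]["prior"])  # {'mean_sources': 0.02, 'max_sources': 5}
```

Only the overridden leaf changes; every other entry keeps its default.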


In [6]:
# Check that the dataset is generated
!ls /data/scratch/zhteoh/730-tutorial/dataset
!du -sh /data/scratch/zhteoh/730-tutorial/dataset
# !cat /data/scratch/zhteoh/730-tutorial/dataset/hparams.yaml

import torch
with open("/data/scratch/zhteoh/730-tutorial/dataset/dataset_0.pt", "rb") as f:
    dataset = torch.load(f)
print(len(dataset))
print(dataset[0]["images"].shape)

dataset_0.pt  dataset_1.pt  hparams.yaml
18M	/data/scratch/zhteoh/730-tutorial/dataset
128
torch.Size([1, 80, 80])


### Train the model

In [7]:
from bliss.api import train

#### Without pretrained weights

In [8]:
train(
    weight_save_path="/data/scratch/zhteoh/730-tutorial/output/tutorial_encoder/0.pt",
)

Global seed set to 42

                 from  n    params  module                                  arguments                     
  0                -1  1      3328  yolov5.models.common.Conv               [2, 64, 5, 1]                 
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]                
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]               
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]              
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]              
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]                 
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]              
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]                 
  8           

Sanity Checking: 0it [00:00, ?it/s]

Generating fixed validation set: 100%|██████████| 10/10 [04:45<00:00, 28.59s/it]


Training: 0it [00:00, ?it/s]

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/zhteoh/730-astropy-integration/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_2025915/1460248005.py", line 1, in <module>
    train(
  File "/home/zhteoh/730-astropy-integration/bliss/api.py", line 127, in train
  File "/home/zhteoh/730-astropy-integration/bliss/train.py", line 76, in train
    model_checkpoint = torch.load(checkpoint_callback.best_model_path, map_location="cpu")
  File "/home/zhteoh/730-astropy-integration/.venv/lib/python3.10/site-packages/torch/serialization.py", line 771, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/home/zhteoh/730-astropy-integration/.venv/lib/python3.10/site-packages/torch/serialization.py", line 270, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/home/zhteoh/730-astropy-integration/.venv/lib/python3.10/site-packages/torch/serializ

#### With pretrained weights

Download the pretrained weights for your sky survey (SDSS in this example).

In [9]:
import os

from bliss.api import load_pretrained_weights_for_survey

# Fail early if the target directory for the weights does not exist.
assert os.path.exists("/data/scratch/zhteoh/730-tutorial/pretrained_weights")

load_pretrained_weights_for_survey(
    survey="sdss",
    pretrained_weights_path="/data/scratch/zhteoh/730-tutorial/pretrained_weights/sdss_pretrained.pt",
)

#### Train on cached generated disk dataset

In [10]:
from bliss.api import train_on_cached_data

In [11]:
train_on_cached_data(
    weight_save_path="/data/scratch/zhteoh/730-tutorial/output/tutorial_encoder/0.pt",
    cached_data_path="/data/scratch/zhteoh/730-tutorial/dataset",
    train_n_batches=2,
    batch_size=64,
    val_split_file_idxs=[1],
    training={"pretrained_weights": "/data/scratch/zhteoh/730-tutorial/pretrained_weights/sdss_pretrained.pt"}
)

Global seed set to 42

                 from  n    params  module                                  arguments                     
  0                -1  1      3328  yolov5.models.common.Conv               [2, 64, 5, 1]                 
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]                
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]               
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]              
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]              
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]                 
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]              
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]                 
  8           

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 9, global step 20: 'val/loss' reached 0.26338 (best 0.26338), saving model to '/home/zhteoh/730-astropy-integration/output/version_7/checkpoints/epoch=9-val_loss=0.263.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 40: 'val/loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 60: 'val/loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 80: 'val/loss' was not in top 1
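
`val_split_file_idxs=[1]` selects cached files by index for validation. Assuming the remaining files are used for training (the exact behavior inside BLISS is not shown here), the partition can be sketched as follows; `split_files` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def split_files(files: Sequence[str], val_idxs: Sequence[int]) -> Tuple[List[str], List[str]]:
    """Partition files into (train, val) lists by index."""
    val_set = set(val_idxs)
    out_of_range = [i for i in val_set if not 0 <= i < len(files)]
    if out_of_range:
        raise IndexError(f"Validation indices out of range: {sorted(out_of_range)}")
    train = [f for i, f in enumerate(files) if i not in val_set]
    val = [f for i, f in enumerate(files) if i in val_set]
    return train, val


train_files, val_files = split_files(["dataset_0.pt", "dataset_1.pt"], val_idxs=[1])
print(train_files, val_files)  # ['dataset_0.pt'] ['dataset_1.pt']
```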


## Run the model

### Using sample dataset

#### Download the sample dataset

In [12]:
from bliss.surveys.sdss import PhotoFullCatalog, SloanDigitalSkySurvey

sdss_data_path = "/home/zhteoh/730-astropy-integration/data/sdss"
photo_cat = PhotoFullCatalog.from_file(sdss_data_path, run=94, camcol=1, field=12, band=2)
sdss = SloanDigitalSkySurvey(sdss_data_path, 94, 1, (12,), (2,))

#### Get predictions for the sample dataset

In [13]:
from bliss.api import predict_sdss

est_cat, est_cat_table, galaxy_params_table = predict_sdss(
    data_path="/home/zhteoh/730-astropy-integration/case_studies/astropy_integration_730/data/sdss", 
    weight_save_path="/home/zhteoh/730-astropy-integration/case_studies/astropy_integration_730/data/pretrained_models/sdss.pt",
    # predict={"dataset": {"run": 94, "camcol": 1, "fields": [12]}}
)


                 from  n    params  module                                  arguments                     
  0                -1  1      3328  yolov5.models.common.Conv               [2, 64, 5, 1]                 
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]                
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]               
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]              
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]              
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]                 
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]              
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]                 
  8                -1  1   4720640  

In [14]:
from IPython.display import HTML, display

# Render the saved prediction HTML inline.
with open("./predict.html", "r") as f:
    html_str = f.read()

display(HTML(html_str))

In [15]:
est_cat_table.show_in_notebook(display_length=5)

idx,star_log_fluxes,star_fluxes,galaxy_bools,star_bools,"galaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b]"
unit,dex(nmgy),nmgy,,,
0,6.943897,1036.8025,1,0,"( 1985.7, 0.3987, 3.3787, 0.51905, 1.921, 0.51818, 0.84206)"
1,8.34118,4193.034,1,0,"( 7689, 0.17275, 2.4744, 0.45257, 1.437, 0.40004, 0.47546)"
2,6.7338595,840.3845,0,1,"( 995.96, 0.34902, 3.0425, 0.41535, 1.3364, 0.43808, 0.61264)"
3,8.78471,6533.5776,0,1,"( 5537.7, 0.064413, 2.9679, 0.48888, 1.6053, 0.28075, 0.43291)"
4,11.414507,90626.945,0,1,"( 52242, 0.02524, 3.0408, 0.495, 1.8073, 0.19888, 0.41041)"
5,7.0182014,1116.7761,1,0,"( 1597.9, 0.43475, 3.71, 0.45056, 1.2087, 0.474, 0.64957)"
6,9.667396,15794.161,0,1,"( 1938.7, 0.23446, 3.1722, 0.51027, 1.8463, 0.30432, 0.75711)"
7,12.765108,349796.94,0,1,"( 1.2707e+05, 0.017796, 3.0484, 0.44899, 1.5763, 0.21264, 0.46637)"
8,8.596792,5414.264,1,0,"( 15978, 0.17065, 3.7518, 0.78209, 1.3361, 0.90699, 0.57194)"
9,7.6643662,2131.0417,1,0,"( 5006.7, 0.724, 4.8921, 0.13465, 1.4052, 0.26175, 0.89717)"
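
In the rows above, `star_log_fluxes` matches the natural logarithm of `star_fluxes`, even though the unit label reads `dex(nmgy)`. The check below verifies this for the first two rows with the standard library. It also converts a flux in nanomaggies to a magnitude with the usual SDSS relation m = 22.5 - 2.5 log10(f); the helper name `nmgy_to_mag` is hypothetical.

```python
import math

# (star_log_fluxes, star_fluxes) pairs copied from the first two table rows.
rows = [(6.943897, 1036.8025), (8.34118, 4193.034)]

for log_flux, flux in rows:
    # The tabulated log flux agrees with the natural log of the flux.
    assert math.isclose(math.log(flux), log_flux, rel_tol=1e-5)


def nmgy_to_mag(flux_nmgy: float) -> float:
    """Convert a flux in nanomaggies to a (Pogson) magnitude."""
    return 22.5 - 2.5 * math.log10(flux_nmgy)


print(round(nmgy_to_mag(1036.8025), 2))  # 14.96
```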


In [16]:
galaxy_params_table.show_in_notebook(display_length=5)

idx,galaxy_flux,galaxy_disk_frac,galaxy_beta_radians,galaxy_disk_q,galaxy_a_d,galaxy_bulge_q,galaxy_a_b
unit,nmgy,,rad,,arcsec,,arcsec
0,1985.6808,0.39870256,3.3786595,0.5190524,1.9210416,0.5181762,0.84205616
1,7688.9707,0.17275049,2.4743593,0.4525674,1.4369837,0.40003857,0.47546133
2,995.95703,0.3490162,3.0424857,0.41534755,1.3363634,0.43808466,0.61263925
3,5537.7144,0.06441329,2.9679139,0.48888382,1.6053165,0.28075293,0.43291458
4,52242.484,0.025240203,3.0408385,0.4949989,1.8073436,0.19888058,0.41041377
5,1597.8751,0.43475053,3.709997,0.4505621,1.2087125,0.4740012,0.64956796
6,1938.6857,0.23445645,3.1721957,0.5102745,1.8463191,0.30431896,0.7571146
7,127069.7,0.01779637,3.048355,0.44899118,1.5762504,0.21264386,0.4663707
8,15977.64,0.17064832,3.751763,0.78208774,1.3361334,0.9069929,0.5719446
9,5006.736,0.7239966,4.8920856,0.13464788,1.4051633,0.26175326,0.8971654


#### Save predicted catalog to FITS file

In [17]:
# overwrite=True replaces est_cat.fits if it already exists (for example, when re-running this cell).
est_cat_table.write("est_cat.fits", format="fits", overwrite=True)

In [18]:
# Check that catalog is saved as intended
from astropy.table import Table

est_cat_table = Table.read("est_cat.fits", format="fits")
est_cat_table.show_in_notebook(display_length=5)

idx,star_log_fluxes,star_fluxes,galaxy_bools,star_bools,"galaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b]"
0,6.943897,1036.8025,1,0,"( 1985.7, 0.3987, 3.3787, 0.51905, 1.921, 0.51818, 0.84206)"
1,8.34118,4193.034,1,0,"( 7689, 0.17275, 2.4744, 0.45257, 1.437, 0.40004, 0.47546)"
2,6.7338595,840.3845,0,1,"( 995.96, 0.34902, 3.0425, 0.41535, 1.3364, 0.43808, 0.61264)"
3,8.78471,6533.5776,0,1,"( 5537.7, 0.064413, 2.9679, 0.48888, 1.6053, 0.28075, 0.43291)"
4,11.414507,90626.945,0,1,"( 52242, 0.02524, 3.0408, 0.495, 1.8073, 0.19888, 0.41041)"
5,7.0182014,1116.7761,1,0,"( 1597.9, 0.43475, 3.71, 0.45056, 1.2087, 0.474, 0.64957)"
6,9.667396,15794.161,0,1,"( 1938.7, 0.23446, 3.1722, 0.51027, 1.8463, 0.30432, 0.75711)"
7,12.765108,349796.94,0,1,"( 1.2707e+05, 0.017796, 3.0484, 0.44899, 1.5763, 0.21264, 0.46637)"
8,8.596792,5414.264,1,0,"( 15978, 0.17065, 3.7518, 0.78209, 1.3361, 0.90699, 0.57194)"
9,7.6643662,2131.0417,1,0,"( 5006.7, 0.724, 4.8921, 0.13465, 1.4052, 0.26175, 0.89717)"


#### Evaluate prediction

In [19]:
import torch

from bliss.metrics import BlissMetrics

# Move both catalogs to the same device (CPU here) before comparing them.
est_cat_cpu = est_cat.to(torch.device("cpu"))
photo_cat_cpu = photo_cat.to(torch.device("cpu"))

metrics = BlissMetrics()
results = metrics(est_cat_cpu, photo_cat_cpu)

print(results)

{'detection_precision': tensor(0.), 'detection_recall': tensor(0.), 'f1': tensor(nan), 'avg_distance': tensor(59.98706), 'n_matches': tensor(0), 'n_matches_gal_coadd': tensor(0), 'class_acc': tensor(nan), 'gal_tp': tensor(0), 'gal_fp': tensor(0), 'gal_fn': tensor(0), 'gal_tn': tensor(0)}
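
The `nan` reported for `f1` follows directly from precision and recall both being zero: F1 = 2PR / (P + R) becomes 0 / 0. The sketch below reproduces this with plain floats. It illustrates the formula only and is not the `BlissMetrics` implementation; `f1_score` is a hypothetical helper.

```python
import math


def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; NaN when both are zero."""
    denominator = precision + recall
    if denominator == 0:
        # Mirrors tensor arithmetic, where 0 / 0 evaluates to NaN instead of raising.
        return math.nan
    return 2 * precision * recall / denominator


print(f1_score(0.0, 0.0))  # nan
print(f1_score(0.5, 0.5))  # 0.5
```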


### Using user-specified dataset

#### Download online dataset

In [5]:
from astropy.coordinates import SkyCoord
from astroquery.sdss import SDSS
from pathlib import Path

from bliss.api import load_survey

# pos = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') # 1011/3/44
# pos = SkyCoord("1h8m05.73s +13d10m20.3s", frame="icrs") # 4829/5/27
pos = SkyCoord("1h2m05.83s -2d11m20.3s", frame="icrs") # 2699/4/71
region = SDSS.query_region(pos, radius="5 arcsec")
run, camcol, field = region["run"][0], region["camcol"][0], region["field"][0]
print("run:", run, "camcol:", camcol, "field:", field)
load_survey("sdss", run, camcol, field, download_dir=Path("/home/zhteoh/730-astropy-integration/case_studies/astropy_integration_730/data/sdss"))

run: 2699 camcol: 4 field: 71
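
`SkyCoord` parses the sexagesimal string into right ascension and declination. To make the conversion explicit, the sketch below redoes the arithmetic by hand for the position used above (one hour of right ascension corresponds to 15 degrees). The helper names are hypothetical.

```python
def hms_to_deg(hours: float, minutes: float, seconds: float) -> float:
    """Convert right ascension from hours/minutes/seconds to degrees."""
    return 15.0 * (hours + minutes / 60.0 + seconds / 3600.0)


def dms_to_deg(sign: int, degrees: float, arcmin: float, arcsec: float) -> float:
    """Convert declination from degrees/arcminutes/arcseconds to degrees."""
    return sign * (degrees + arcmin / 60.0 + arcsec / 3600.0)


ra = hms_to_deg(1, 2, 5.83)        # "1h2m05.83s"
dec = dms_to_deg(-1, 2, 11, 20.3)  # "-2d11m20.3s"
print(f"ra={ra:.4f} deg, dec={dec:.4f} deg")  # ra=15.5243 deg, dec=-2.1890 deg
```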


#### Get predictions for the downloaded dataset

In [8]:
from bliss.api import predict_sdss

est_cat_dl, est_cat_table_dl, galaxy_params_table_dl = predict_sdss(
    data_path="/home/zhteoh/730-astropy-integration/case_studies/astropy_integration_730/data/sdss",
    weight_save_path="/home/zhteoh/730-astropy-integration/case_studies/astropy_integration_730/data/pretrained_models/sdss.pt",
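    # run/camcol/field must refer to a field that has been downloaded;
    # 1011/3/44 corresponds to the first commented-out position in the download cell.
    predict={"dataset": {"run": 1011, "camcol": 3, "fields": [44]}}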
)


                 from  n    params  module                                  arguments                     
  0                -1  1      3328  yolov5.models.common.Conv               [2, 64, 5, 1]                 
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]                
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]               
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]              
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]              
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]                 
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]              
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]                 
  8                -1  1   4720640  
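
The prediction cell requests run 1011, camcol 3, field 44, while the download cell printed run 2699, camcol 4, field 71. To keep the two in sync, the override can be built from the queried values. The dictionary layout is copied from the `predict_sdss` call; the helper name `dataset_override` is hypothetical, and the cast to `int` is a precaution in case the query returns NumPy integers.

```python
def dataset_override(run, camcol, field) -> dict:
    """Build the `predict` override for a single SDSS field."""
    return {"dataset": {"run": int(run), "camcol": int(camcol), "fields": [int(field)]}}


print(dataset_override(2699, 4, 71))
# {'dataset': {'run': 2699, 'camcol': 4, 'fields': [71]}}
```

The result could then be passed as `predict=dataset_override(run, camcol, field)`.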

In [9]:
est_cat_table_dl.show_in_notebook(display_length=5)

idx,star_log_fluxes,star_fluxes,galaxy_bools,star_bools,"galaxy_params [galaxy_flux, galaxy_disk_frac, galaxy_beta_radians, galaxy_disk_q, galaxy_a_d, galaxy_bulge_q, galaxy_a_b]"
unit,dex(nmgy),nmgy,,,
0,7.2089057,1351.4126,1,0,"( 2245.7, 0.52317, 3.1349, 0.58725, 1.1594, 0.54102, 0.69279)"
1,6.876845,969.56244,1,0,"( 1588.6, 0.53903, 3.0247, 0.54364, 1.2282, 0.52448, 0.76001)"
2,13.763748,949554.44,1,0,"( 2.187e+07, 0.40593, 4.0932, 0.82353, 1.1891, 0.6698, 0.63926)"
3,6.8527656,946.4949,1,0,"( 2937.8, 0.56703, 2.7753, 0.51234, 1.9366, 0.49868, 0.99211)"
4,7.574844,1948.5559,1,0,"( 3570, 0.53911, 3.4633, 0.58454, 0.97503, 0.58442, 0.6165)"
5,8.9579935,7769.752,1,0,"( 16860, 0.068614, 3.8155, 0.51641, 0.91751, 0.48461, 0.44293)"
6,7.090758,1200.8175,1,0,"( 11507, 0.88426, 3.7154, 0.6639, 3.4336, 0.57231, 1.2085)"
7,8.141618,3434.4695,1,0,"( 14164, 0.83768, 3.7812, 0.44161, 1.0834, 0.39799, 0.59891)"
8,7.5652614,1929.9733,1,0,"( 7038.8, 0.89499, 1.9568, 0.6597, 1.1934, 0.50816, 0.89573)"
9,6.7617164,864.1241,0,1,"( 1056.4, 0.42049, 3.3001, 0.40526, 1.2182, 0.43805, 0.60383)"


In [10]:
galaxy_params_table_dl.show_in_notebook(display_length=5)

idx,galaxy_flux,galaxy_disk_frac,galaxy_beta_radians,galaxy_disk_q,galaxy_a_d,galaxy_bulge_q,galaxy_a_b
unit,nmgy,,rad,,arcsec,,arcsec
0,2245.7,0.52317137,3.134928,0.5872512,1.1593739,0.5410212,0.6927947
1,1588.575,0.5390294,3.0247016,0.5436415,1.2282202,0.5244816,0.76000875
2,21870434.0,0.4059322,4.0931606,0.8235268,1.1890823,0.66980326,0.63925546
3,2937.8188,0.5670294,2.7752528,0.51234394,1.9366459,0.49868476,0.9921101
4,3569.9858,0.5391074,3.463295,0.58454233,0.9750287,0.5844166,0.61649823
5,16860.148,0.068613924,3.81548,0.51640934,0.9175066,0.48460945,0.44293204
6,11506.901,0.88426423,3.7153625,0.66389686,3.433607,0.5723128,1.2085218
7,14164.047,0.837677,3.7812374,0.44161278,1.0833614,0.3979892,0.5989149
8,7038.7925,0.89498544,1.9568386,0.65969616,1.1934271,0.50815684,0.8957292
9,1056.3546,0.42049375,3.3000593,0.40526482,1.2181686,0.43805063,0.6038323
