In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
from typing import List

import torch
import pandas as pd
import numpy as np

from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

import matplotlib.pyplot as plt

from bliss.simulator.simulated_dataset import CachedSimulatedDataset, FileDatum


In [3]:
with initialize(config_path="../", version_base=None):
    cfg = compose("config", overrides={"surveys.sdss.load_image_data=true"})

In [4]:
cfg.prior

{'_target_': 'case_studies.galaxy_clustering.prior.GalaxyClusterPrior', 'survey_bands': ['u', 'g', 'r', 'i', 'z'], 'reference_band': 2, 'star_color_model_path': '${simulator.survey.dir_path}/color_models/star_gmm_nmgy.pkl', 'gal_color_model_path': '${simulator.survey.dir_path}/color_models/gal_gmm_nmgy.pkl', 'n_tiles_h': 56, 'n_tiles_w': 56, 'tile_slen': 2, 'batch_size': 2, 'max_sources': 6, 'mean_sources': 0.48, 'min_sources': 0, 'prob_galaxy': 0.2, 'star_flux_exponent': 0.9859821185389767, 'star_flux_truncation': 5685.588160703261, 'star_flux_loc': -1.162430157551662, 'star_flux_scale': 1.4137911256506595, 'galaxy_flux_truncation': 1013, 'galaxy_flux_exponent': 0.47, 'galaxy_flux_scale': 0.6301037, 'galaxy_flux_loc': 0.0, 'galaxy_a_concentration': 0.39330758068481686, 'galaxy_a_loc': 0.8371888967872619, 'galaxy_a_scale': 4.432725319432478, 'galaxy_a_bd_ratio': 2.0}

In [5]:
cfg.cached_simulator

{'_target_': 'bliss.case_studies.galaxy_clustering.simulated_dataset.GalaxyClusterCachedSimulatedDataset', 'batch_size': 64, 'splits': '0:80/80:90/90:100', 'num_workers': 1, 'cached_data_path': '${generate.cached_data_path}', 'file_prefix': '${generate.file_prefix}'}

In [6]:
os.getcwd()

'/home/kapnadak/bliss/case_studies/galaxy_clustering/notebooks'

In [7]:
col_names = ["RA", "DEC", "X", "Y", "MEM", "FLUX_R", "FLUX_G", "FLUX_I", "FLUX_Z", "TSIZE", "FRACDEV", "G1", "G2"]
sample_data_0 = pd.read_csv("../data/padded_catalogs/galsim_des_padded_000.dat", sep=" ", header=None, names=col_names)
sample_data_1 = pd.read_csv("../data/padded_catalogs/galsim_des_padded_001.dat", sep=" ", header=None, names=col_names)
sample_data_2 = pd.read_csv("../data/padded_catalogs/galsim_des_padded_002.dat", sep=" ", header=None, names=col_names)

In [8]:
sample_data_0.head()

Unnamed: 0,RA,DEC,X,Y,MEM,FLUX_R,FLUX_G,FLUX_I,FLUX_Z,TSIZE,FRACDEV,G1,G2
0,50.728375,-40.258012,3997.335589,1974.244275,1.0,126.118851,245.340043,165.57364,181.391205,2.989006,0,-0.105925,0.064605
1,50.718273,-40.232888,3815.497087,2426.468418,1.0,210.027165,331.354642,289.43141,337.502472,3.922377,0,0.009929,0.113251
2,50.732064,-40.245886,4063.726655,2192.505405,1.0,553.309492,782.609501,643.397903,511.653956,1.756963,0,0.66846,0.023192
3,50.706336,-40.217926,3600.626683,2695.785562,1.0,479.605685,887.104739,536.080929,541.227174,3.749274,0,-0.725819,0.074286
4,50.697505,-40.228009,3441.662618,2514.299651,1.0,210.508576,608.632044,400.418016,855.955095,2.66363,0,-0.147868,-0.394306


In [9]:
len(sample_data_0)

2200

In [10]:
len(sample_data_1)

2200

In [11]:
len(sample_data_2)

2200

In [12]:
d = dict()
d["plocs"] = torch.tensor([sample_data_0[["X", "Y"]].to_numpy(), sample_data_1[["X", "Y"]].to_numpy()])
n_sources = torch.sum(d["plocs"][:,:,0] != 0, axis=1)
d["n_sources"] = n_sources

  d["plocs"] = torch.tensor([sample_data_0[["X", "Y"]].to_numpy(), sample_data_1[["X", "Y"]].to_numpy()])


In [13]:
d["fluxes"] = torch.tensor([sample_data_0[["FLUX_R", "FLUX_G", "FLUX_I", "FLUX_Z"]].to_numpy(), sample_data_1[["FLUX_R", "FLUX_G", "FLUX_I", "FLUX_Z"]].to_numpy()])
d["membership"] = torch.tensor([sample_data_0[["MEM"]].to_numpy(), sample_data_1[["MEM"]].to_numpy()])
d["hlr"] = torch.tensor([sample_data_0[["TSIZE"]].to_numpy(), sample_data_1[["TSIZE"]].to_numpy()])
d["fracdev"] = torch.tensor([sample_data_0[["FRACDEV"]].to_numpy(), sample_data_1[["FRACDEV"]].to_numpy()])
d["g1g2"] = torch.tensor([sample_data_0[["G1", "G2"]].to_numpy(), sample_data_1[["G1", "G2"]].to_numpy()])

In [14]:
from bliss.catalog import FullCatalog, TileCatalog
TileCatalog.allowed_params.update(["membership", "fracdev", "g1g2"])

In [15]:
fc = FullCatalog(height=5000, width=5000, d=d)

In [16]:
fc.batch_size

2

In [17]:
tc = fc.to_tile_catalog(20, 20)

In [18]:
tc.height, tc.width

(5000, 5000)

In [19]:
tile_catalog_dict = tc.to_dict()

In [20]:
tile_catalog_dict.keys()

dict_keys(['locs', 'n_sources', 'fluxes', 'membership', 'hlr', 'fracdev', 'g1g2'])

In [21]:
np.where(tile_catalog_dict["n_sources"][0] != 0)

(array([  0,   0,   0, ..., 249, 249, 249]),
 array([  7,   8, 126, ...,  65,  89, 175]))

In [22]:
tile_catalog_dict["n_sources"][0].sum()

tensor(1989)

In [23]:
tc.get_indices_of_on_sources()[0][0]

tensor([ 902040, 1042920,  398080,  ...,  833074,  834488,  833071])

In [24]:
full_locs = tc.get_full_locs_from_tiles()
full_locs[0].shape

torch.Size([250, 250, 20, 2])

In [25]:
tc.n_sources[1].sum()

tensor(2041)

In [26]:
tile_catalog_dict["fluxes"][0].shape

torch.Size([250, 250, 20, 4])

In [27]:
tile_catalog_dict["membership"][0].sum() # number of clustered galaxies

tensor(381.)

In [28]:
tile_catalog_dict["g1g2"][0].shape

torch.Size([250, 250, 20, 2])

In [29]:
(fc["membership"][0] == 0).sum() # number of background galaxies + number of zero padded rows

tensor(1819)

In [30]:
fc["membership"][0].sum() + (fc["membership"][0] == 0).sum()

tensor(2200., dtype=torch.float64)

In [None]:
from torchmetrics import Metric
class ClusterMembershipAccuracy(Metric):
    def __init__(self):
        super().__init__()

        self.add_state("membership_tp", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("membership_tn", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("membership_fp", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("membership_fn", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("n_matches", default=torch.zeros(1), dist_reduce_fx="sum")

    def update(self, true_cat, est_cat, matching):
        for i in range(true_cat.batch_size):
            tcat_matches, ecat_matches = matching[i]
            self.n_matches += tcat_matches.size(0)

            true_membership = true_cat.membership[i][tcat_matches]
            est_membership = est_cat.membership[i][ecat_matches]

            self.membership_tp += (true_membership * est_membership).sum()
            self.membership_tn += (~true_membership * ~est_membership).sum()
            self.membership_fp += (~true_membership * est_membership).sum()
            self.membership_fn += (true_membership * ~est_membership).sum()

    def compute(self):
        precision = self.membership_tp / (self.membership_tp + self.membership_fp)
        recall = self.membership_tp / (self.membership_tp + self.membership_fn)
        accuracy = (self.membership_tp + self.membership_tn) / (self.membership_tp + self.membership_tn + self.membership_fp + self.membership_fn)
        f1 = 2 * precision * recall / (precision + recall)
        return {
            "membership_accuracy": accuracy,
            "membership_precision": precision,
            "membership_recall": recall,
            "membership_f1": f1
        }