In this notebook, we do visual inspection to create a training set for our CNN shred classifier!

In [9]:
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from astropy.io import fits
from astropy.table import Table
import numpy as np


Load the dataset!

We create a copy of this dataset with only a few columns to avoid over-writing on the original dataset!

The only columns we need are RA, DEC, TARGETID, SAMPLE, IMAGE_PATH, FILE_PATH and we will add a column for our labelling!

In [None]:
# shreds_all = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_catalog_v3.fits")
# shreds_all = shreds_all["TARGETID","RA","DEC","FILE_PATH","IMAGE_PATH"]

# ##save this folder and we will be updating for this for visual inspection stuff!
# all_img_paths = []
# for i in range(len(shreds_all)):
#     all_img_paths.append( shreds_all["FILE_PATH"][i] + "/grz_bands_segments.png")

# shreds_all["PNG_PATH"] = all_img_paths

# shreds_all.write("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI.fits",overwrite=True)

In [None]:
# ## read the VI catalog 
# data = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI.fits")

# np.random.seed(42)

# #can we scrambe these row
# shuffled_indices = np.random.permutation(len(data))
# data_shuffled = data[shuffled_indices]

# data_shuffled.write("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI_labelled.fits",overwrite=True)



In [14]:
data = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI_labelled.fits")

#converting the data to pandas df
df = data.to_pandas()

In [15]:
df["IS_SHRED_VI"] == b'nan'

0        False
1        False
2        False
3        False
4        False
         ...  
53372     True
53373     True
53374     True
53375     True
53376     True
Name: IS_SHRED_VI, Length: 53377, dtype: bool

In [7]:
save_path = "/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI_labelled.fits"


In [8]:
# # # === Add 'label' column if not present ===
# if "IS_SHRED_VI" not in df.columns:
#     df["IS_SHRED_VI"] = ""

In [9]:

# === Set start index ===
# index = df[df["IS_SHRED_VI"] == ""].index.min()

# if pd.isna(index):
#     index = 0

# Identify rows that are unprocessed: NaN or empty byte string
# unprocessed = df[df["IS_SHRED_VI"].isna() | (df["IS_SHRED_VI"] == b'')]
# # Get the first index of unprocessed row
# index = unprocessed.index.min()

index = df[df["IS_SHRED_VI"] == b'nan'].index.min()

print(index)

# === Display setup ===
image_widget = widgets.Output()
label_widget = widgets.Label()
next_button = widgets.Button(description="Next")
prev_button = widgets.Button(description="Prev")
shred_button = widgets.Button(description="Fragment")
good_button = widgets.Button(description="Good")

# === Save back to FITS ===
def save_fits():
    df_to_save = df.copy()

    # Ensure all object-type columns (usually strings) are cast to fixed-length strings
    for col in df_to_save.select_dtypes(include='object').columns:
        maxlen = df_to_save[col].astype(str).str.len().max()
        df_to_save[col] = df_to_save[col].astype(f'S{maxlen}')

    hdu = fits.BinTableHDU.from_columns(fits.ColDefs(df_to_save.to_records(index=False)))
    hdu.writeto(save_path, overwrite=True)

# === Save label ===
def save_label(label):
    global index
    df.at[index, "IS_SHRED_VI"] = label
    save_fits()

# === Display image ===
def resize_image_maintain_aspect(image, target_size=1024):
    # Get original dimensions
    width, height = image.size
    
    # Calculate the scaling factor to preserve the aspect ratio
    scaling_factor = target_size / max(width, height)
    
    # Compute the new width and height
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    
    # Resize the image (this won't distort it)
    resized_img = image.resize((new_width, new_height))

    return resized_img
    
def show_image():
    image_widget.clear_output(wait=True)
    title = str(df.iloc[index]['PNG_PATH'])
    title = title.replace("/pscratch/sd/v/virajvm/redo_photometry_plots/all_deshreds/","")
    label_widget.value = f"{index + 1}/{len(df)}: {title}"
    with image_widget:
        try:
            img = Image.open(df.iloc[index]["PNG_PATH"])
            # img = img.resize((256, 256))
            plt.figure(figsize = (20,20))
            img = resize_image_maintain_aspect(img)
            plt.imshow(img)
            plt.axis("off")
            plt.show()
        except Exception as e:
            print(f"Could not load image: {e}")

# === Button Callbacks ===
def on_next(b): 
    global index
    if index < len(df) - 1:
        index += 1
        show_image()

def on_prev(b):
    global index
    if index > 0:
        index -= 1
        show_image()

def on_shred(b):
    save_label("fragment")
    on_next(b)

def on_good(b):
    save_label("good")
    on_next(b)

# === Wire Buttons ===
next_button.on_click(on_next)
prev_button.on_click(on_prev)
shred_button.on_click(on_shred)
good_button.on_click(on_good)



1810


In [10]:
# === Layout ===
buttons = widgets.HBox([prev_button, next_button, good_button, shred_button])
display(label_widget, image_widget, buttons)

# === Start ===
show_image()

Label(value='')

Output()

HBox(children=(Button(description='Prev', style=ButtonStyle()), Button(description='Next', style=ButtonStyle()…

1046/53377: b'south/sweep-110p020-120p025/1164p245/BGS_BRIGHT_tgid_39628361462449758/grz_bands_segments.png'
-> why is the yellow compact source not identified as a star??


1605/53377: b'south/sweep-230p000-240p005/2319p042/BGS_BRIGHT_tgid_39627891255808577/grz_bands_segments.png'


1645/53377: b'south/sweep-050m005-060p000/0567m042/BGS_BRIGHT_tgid_39627683042166732/grz_bands_segments.png'


1715/53377: b'south/sweep-130m010-140m005/1347m077/ELG_tgid_39627600196274367/grz_bands_segments.png'


1897/53377: b'south/sweep-310p005-320p010/3134p057/BGS_BRIGHT_tgid_39627928744495248/grz_bands_segments.png'

Maybe one of the reasons is that it was targeted?

1074/53377: b'south/sweep-210m005-220p000/2168m025/BGS_BRIGHT_tgid_39627727967355167/grz_bands_segments.png'
-> looks like a fragment object and good for apeture, but the source subtraction is too much I think?

1118/53377: south/sweep-120p000-130p005/1280p032/ELG_tgid_39627865389534167/grz_bands_segments.png'
-> example of object where 2 galaxies look very similar color and so aperture photometry fails ... 


1131/53377:south/sweep-180m005-190p000/1818m032/ELG_tgid_39627709264959315/grz_bands_segments.png'
-> example of ELG source on outskirt of very diffuse large galaxy and so aperture mag pipeline fails

In [21]:
data = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI_labelled.fits")

##let us get the corresponding meta data for these objects!
data_main = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_catalog_v3.fits")
data_main = data_main["TARGETID","RA","DEC", "MAG_G","MAG_R","MAG_Z","FRACFLUX_G","FRACFLUX_R","FRACFLUX_Z"]

In [27]:
np.min(np.unique(data["TARGETID"]) - np.unique(data_main["TARGETID"]))

0

In [None]:
/pscratch/sd/v/virajvm/redo_photometry_plots/all_deshreds/south/sweep-110p020-120p025/1164p245/BGS_BRIGHT_tgid_39628361462449758/

In [111]:
data_training = data[(data["IS_SHRED_VI"] == "good") | (data["IS_SHRED_VI"] == "fragment")]

In [132]:
# import sys
# import os
# rootdir = '/global/u1/v/virajvm/'
# sys.path.append(os.path.join(rootdir, 'DESI2_LOWZ/desi_dwarfs/code'))

# from desi_lowz_funcs import match_c_to_catalog

# ##I need to cross-match this to get the files file image data!

# shreds_all = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_catalog_v3.fits")
# temp = shreds_all["TARGETID","RA","DEC","IMAGE_PATH"]

# idx, _, _ = match_c_to_catalog(c_cat=data, catalog_cat=temp, c_ra="RA",c_dec="DEC",catalog_ra="RA",catalog_dec="DEC")

# print(temp[idx]["TARGETID"].data - data["TARGETID"].data)

# data["IMAGE_PATH"] = temp[idx]["IMAGE_PATH"]

# data.write("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_VI_labelled.fits", overwrite=True )

In [82]:
data_main = Table.read("/pscratch/sd/v/virajvm/catalog_dr1_dwarfs/desi_y1_dwarf_shreds_catalog_v3.fits")



In [105]:
len(data_main[data_main["PCNN_FRAGMENT"] < 0.3]) / len(data_main)

0.2443309341382246

In [112]:
data_cnn_shred = data_main[data_main["PCNN_FRAGMENT"] >= 0.3]
data_cnn_clean = data_main[data_main["PCNN_FRAGMENT"] < 0.3]
