In [None]:
import time
from ML4transients.data_access import DatasetLoader
from ML4transients.training import PytorchDataset

# Root of the normalised UDEEP dataset read throughout this notebook.
DATA_DIR = '/sps/lsst/groups/transients/HSC/fouchez/raphael/data/UDEEP_norm'

print("Loading dataset...")
# perf_counter is monotonic and high-resolution, so it is the right clock for
# elapsed-time measurement (time.time can jump if the system clock is adjusted).
t0 = time.perf_counter()
dataset = DatasetLoader(DATA_DIR)
print(f"Dataset loaded in {time.perf_counter() - t0:.2f} seconds")

# Cutouts


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Show a small grid of sample cutouts from the first visit. Each panel is titled
# with its diaSourceId and the 'is_injection' value from the matching feature row.
N_ROWS, N_COLS = 2, 3  # grid layout; number of samples shown = N_ROWS * N_COLS

first_visit = dataset.visits[0]
print(f"Loading cutouts from visit {first_visit}")

cutout_loader = dataset.cutouts[first_visit]
feature_loader = dataset.features[first_visit]
sample_ids = cutout_loader.ids[:N_ROWS * N_COLS]

# len() works whether ids is a list or a numpy array (truthiness of an array is ambiguous).
if len(sample_ids) == 0:
    print("No cutouts available")
else:
    fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(12, 8))
    axes = axes.flatten()

    for ax, dia_id in zip(axes, sample_ids):
        cutout = cutout_loader.get_by_id(dia_id)
        features = feature_loader.get_by_id(dia_id)
        # Report 'n/a' rather than False when no feature row exists, so a missing
        # row is not mistaken for a confirmed non-injection.
        if features is None or features.empty:
            is_injection = 'n/a'
        else:
            is_injection = features['is_injection'].iloc[0]

        im = ax.imshow(cutout, cmap='RdYlGn', origin='lower')
        ax.set_title(f'ID: {dia_id}\nInjection: {is_injection}')
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide grid slots left empty when the visit has fewer cutouts than panels.
    for ax in axes[len(sample_ids):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

    # `cutout` is the last one drawn in the loop above; label the summary accordingly.
    print(f"Last cutout shape: {cutout.shape}")
    print(f"Last cutout value range: [{cutout.min():.3f}, {cutout.max():.3f}]")

# LC loading: timing comparison and cutout visualization

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np

# Compare the cost of fetching one source's light curve without and with its
# cutouts, using a sample diaSourceId from the first visit.
SAMPLE_INDEX = 10  # position of the sample source within the first visit's ids

first_visit = dataset.visits[0]
visit_ids = dataset.cutouts[first_visit].ids
# Clamp so a visit with <= SAMPLE_INDEX sources still yields a sample instead of IndexError.
sample_dia_source_id = visit_ids[min(SAMPLE_INDEX, len(visit_ids) - 1)]

print(f"Performance comparison for diaSourceId: {sample_dia_source_id}")
print("=" * 60)

# Timings use perf_counter: monotonic and high-resolution, unlike time.time.
# NOTE(review): Method 2 runs after Method 1 and may benefit from any caching done
# by the first call - confirm, or alternate the order when benchmarking.

# Method 1: Lightcurve only
print("\nMethod 1: Lightcurve columns only")
start_time = time.perf_counter()
data_lc_only = dataset.get_complete_lightcurve_data(sample_dia_source_id, load_cutouts=False)
time_lc_only = time.perf_counter() - start_time

if data_lc_only:
    print(f"Success: {data_lc_only['num_sources']} sources, {len(data_lc_only['lightcurve'])} LC points")
    print(f"Time: {time_lc_only:.3f}s")
else:
    print("Failed")

print("-" * 60)

# Method 2: Lightcurve + cutouts
print("\nMethod 2: Lightcurve + cutouts")
start_time = time.perf_counter()
data_full = dataset.get_complete_lightcurve_data(sample_dia_source_id, load_cutouts=True)
time_full = time.perf_counter() - start_time

if data_full:
    print(f"Success: {data_full['num_sources']} sources, {len(data_full['cutouts'])} cutouts")
    print(f"Time: {time_full:.3f}s")
else:
    print("Failed")

print("=" * 60)

# Performance summary
if data_lc_only and data_full:
    # Guard the ratio: a coarse clock can report 0 s for a fast call.
    if time_lc_only > 0:
        speedup = time_full / time_lc_only
        print(f"\nSpeedup: {speedup:.1f}x faster (lightcurve-only)")
    else:
        print("\nSpeedup: n/a (lightcurve-only call took no measurable time)")
    print(f"Time saved: {time_full - time_lc_only:.3f}s")

In [None]:
# Plot every cutout returned by the full light-curve query above. Panels are ordered
# by midpointMjdTai when the light curve has both 'diaSourceId' and 'midpointMjdTai'.
MAX_COLS = 8  # maximum panels per row

if data_full and data_full['cutouts']:
    cutouts = data_full['cutouts']
    lightcurve = data_full['lightcurve']

    # Sort by time if available; ids absent from the light curve sort as time 0.
    if 'diaSourceId' in lightcurve.columns and 'midpointMjdTai' in lightcurve.columns:
        time_map = dict(zip(lightcurve['diaSourceId'], lightcurve['midpointMjdTai']))
        sorted_ids = sorted(cutouts.keys(), key=lambda x: time_map.get(x, 0))
    else:
        sorted_ids = list(cutouts.keys())

    n_cutouts = len(cutouts)
    cols = min(MAX_COLS, n_cutouts)
    rows = (n_cutouts + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so flatten() works for every
    # grid size, including a single panel. Wrapping a lone Axes in a Python list
    # instead would break, because lists have no .flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, src_id in zip(axes, sorted_ids):
        im = ax.imshow(cutouts[src_id], cmap='RdYlGn', origin='lower')
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide the unused slots in the last row.
    for ax in axes[n_cutouts:]:
        ax.set_visible(False)

    plt.suptitle(f'Cutouts for object {data_full["object_id"]} ({n_cutouts} total)')
    plt.tight_layout()
    plt.show()
else:
    print("No cutouts available")

# LC inference and high-confidence selection

In [None]:
# Re-create the loader as the original cell did. Presumably this makes inference run
# against fresh loader state rather than anything cached above - confirm; drop the
# reload if it is not needed. Load time is measured separately so that the reported
# inference time covers inference only.
t_load = time.perf_counter()
dataset = DatasetLoader('/sps/lsst/groups/transients/HSC/fouchez/raphael/data/UDEEP_norm')
print(f"Dataset reloaded in {time.perf_counter() - t_load:.2f} seconds")

# NOTE(review): `inference_snn` is accessed as an attribute, so it is presumably a
# lazily evaluated property that runs inference on first access - confirm.
t1 = time.perf_counter()
ensemble_df = dataset.lightcurves.inference_snn
print(f"Inference done in {time.perf_counter() - t1:.2f} seconds")

In [None]:
# Select high-confidence SN sources, then write the matching subset of the dataset
# to a new directory. Each step is timed with perf_counter (monotonic clock).
dataset = DatasetLoader('/sps/lsst/groups/transients/HSC/fouchez/raphael/data/UDEEP_norm')

HIGH_CONF_OUTPUT_DIR = '/sps/lsst/groups/transients/HSC/fouchez/raphael/data/UDEEP_norm_high_conf_small'
PROB_THRESHOLD = 0.9  # passed as prob_threshold to save_high_conf_subset_dataset
STD_THRESHOLD = 0.001  # passed as std_threshold to save_high_conf_subset_dataset

t2 = time.perf_counter()
# NOTE(review): get_high_conf_sn_sources() uses its default thresholds, while the save
# below passes PROB_THRESHOLD / STD_THRESHOLD explicitly. Confirm the defaults match;
# otherwise `high_conf` and the saved subset may select different sources.
high_conf = list(dataset.lightcurves.get_high_conf_sn_sources())
print(f"Get high conf LC done in {time.perf_counter() - t2:.2f} seconds")

t3 = time.perf_counter()
dataset.lightcurves.save_high_conf_subset_dataset(
    HIGH_CONF_OUTPUT_DIR,
    prob_threshold=PROB_THRESHOLD,
    std_threshold=STD_THRESHOLD,
)
print(f"Save high conf LC done in {time.perf_counter() - t3:.2f} seconds")



In [None]:
# Summarise the selection with a bounded preview instead of dumping the whole list.
# NOTE(review): the previous slice `high_conf[100 :]` printed every entry *after* the
# first 100, which is unbounded output and looks like a transposed `[:100]` - confirm.
print(f"{len(high_conf)} high-confidence sources")
print(high_conf[:100])