In [1]:
import os
import numpy as np
import pandas as pd

# ============================================================================
# LOAD DATA
# ============================================================================

# Both sample tables are read from the same data directory (relative to the
# notebook's working directory).
DATA_DIR = os.path.join("..", "data", "plasticc")
metadata_csv = os.path.join(DATA_DIR, "sample_metadata.csv")
lightcurves_csv = os.path.join(DATA_DIR, "sample_lightcurves.csv")

sample_metadata = pd.read_csv(metadata_csv)
sample_lc = pd.read_csv(lightcurves_csv)

# Quick sanity report on what was loaded.
print(f"Metadata: {len(sample_metadata)} objects")
print(f"Lightcurves: {len(sample_lc):,} rows")

# ============================================================================
# COMPLETE FEATURE EXTRACTION
# ============================================================================

def extract_features_complete(obj_id, lc_data, target_class):
    """
    Extract the engineered light-curve features for a single object.

    Only rows of ``lc_data`` that belong to ``obj_id`` and satisfy
    ``detected == 1``, ``flux > 0`` and ``flux_err > 0`` are used. All such
    rows are pooled into one time series ordered by ``mjd``.

    Args:
        obj_id: Value matched against the ``object_id`` column.
        lc_data: DataFrame with at least the columns ``object_id``, ``mjd``,
            ``flux``, ``flux_err`` and ``detected``.
        target_class: Class code; 90 is labelled ``'SNIa'`` and every other
            value is labelled ``'SNII'``.

    Returns:
        A dict holding ``transient_id``, ``label``, 15 numeric features
        (5 magnitude, 3 flux, 4 time, 3 slope) and ``n_points``; or ``None``
        when the object is rejected because it has fewer than 5 usable rows,
        its brightest point lies less than 1.0 (in ``mjd`` units) from the
        first or last usable observation, or its total time span is below 10.
    """
    # Keep only this object's usable detections. Positive flux is required
    # because the magnitude conversion below takes a logarithm.
    lc = lc_data[lc_data['object_id'] == obj_id]
    usable = (lc['detected'] == 1) & (lc['flux'] > 0) & (lc['flux_err'] > 0)
    lc_clean = lc[usable]

    if len(lc_clean) < 5:
        return None

    lc_clean = lc_clean.sort_values('mjd')
    flux = lc_clean['flux'].values
    time = lc_clean['mjd'].values

    # Flux -> magnitude. No zero point is added, so the values are only
    # meaningful relative to each other (brighter == smaller magnitude).
    mags = -2.5 * np.log10(flux)

    # ------------------------------------------------------------------
    # Magnitude, flux and time summary statistics
    # ------------------------------------------------------------------
    mag_min = mags.min()
    mag_max = mags.max()
    mag_mean = mags.mean()
    mag_std = mags.std()
    mag_range = mag_max - mag_min

    flux_max = flux.max()
    flux_mean = flux.mean()
    flux_std = flux.std()

    time_span = time.max() - time.min()

    # ------------------------------------------------------------------
    # Peak location and rise / decline durations
    # ------------------------------------------------------------------
    # The peak is the brightest point, i.e. the minimum magnitude.
    peak_idx = np.argmin(mags)
    peak_time = time[peak_idx]
    rise_time = peak_time - time[0]
    decline_time = time[-1] - peak_time

    # Quality cuts first, so everything below can rely on them. Written as
    # "not >=" so that NaN timestamps are rejected instead of slipping
    # through a "<" comparison.
    if not (rise_time >= 1.0 and decline_time >= 1.0):
        return None
    if time_span < 10:
        return None

    # decline_time >= 1.0 is guaranteed here, so the division is safe.
    # Clipping bounds the influence of extreme ratios.
    rise_decline_ratio = np.clip(rise_time / decline_time, 0.01, 100)

    # ------------------------------------------------------------------
    # Slope features (magnitude change per unit time)
    # ------------------------------------------------------------------
    dt = np.diff(time)
    dmag = np.diff(mags)

    # Interval i joins observation i and i + 1. Intervals shorter than 0.1
    # are ignored to avoid dividing by a near-zero time step.
    valid = dt > 0.1

    # Classify intervals by their ORIGINAL position relative to the peak.
    # Slicing the already-filtered slope array with peak_idx would shift
    # the rise/decline boundary whenever an interval had been dropped.
    interval = np.arange(dt.size)
    before_peak = valid & (interval < peak_idx)
    after_peak = valid & (interval >= peak_idx)

    slopes = dmag[valid] / dt[valid]
    rise_slopes = dmag[before_peak] / dt[before_peak]
    decline_slopes = dmag[after_peak] / dt[after_peak]

    mean_rise_slope = rise_slopes.mean() if rise_slopes.size else 0.0
    mean_decline_slope = decline_slopes.mean() if decline_slopes.size else 0.0
    max_slope = np.abs(slopes).max() if slopes.size else 0.0

    return {
        'transient_id': obj_id,
        'label': 'SNIa' if target_class == 90 else 'SNII',
        # Magnitude features (5)
        'mag_min': mag_min,
        'mag_max': mag_max,
        'mag_mean': mag_mean,
        'mag_std': mag_std,
        'mag_range': mag_range,
        # Flux features (3)
        'flux_max': flux_max,
        'flux_mean': flux_mean,
        'flux_std': flux_std,
        # Time features (4)
        'time_span': time_span,
        'rise_time': rise_time,
        'decline_time': decline_time,
        'rise_decline_ratio': rise_decline_ratio,
        # Slope features (3)
        'mean_rise_slope': mean_rise_slope,
        'mean_decline_slope': mean_decline_slope,
        'max_slope': max_slope,
        # Metadata
        'n_points': len(lc_clean),
    }

# ============================================================================
# EXTRACT FEATURES FOR ALL OBJECTS
# ============================================================================

print("\n" + "=" * 70)
print("EXTRACTING FEATURES")
print("=" * 70)

# Split the light-curve table by object once, instead of letting every call
# scan the full table for its rows.
lc_by_object = dict(iter(sample_lc.groupby('object_id')))
# Zero-row frame with the right columns, for objects without light-curve rows.
empty_lc = sample_lc.iloc[0:0]

all_features = []
failed_count = 0
n_objects = len(sample_metadata)

# Iterate the two needed columns directly: iterrows() builds a Series per row
# and can upcast object_id when the row mixes dtypes. Counting with
# enumerate keeps the progress report independent of the index labels.
id_target_pairs = zip(sample_metadata['object_id'], sample_metadata['target'])
for position, (obj_id, target) in enumerate(id_target_pairs, start=1):
    object_lc = lc_by_object.get(obj_id, empty_lc)
    features = extract_features_complete(obj_id, object_lc, target)

    # None means the object was rejected by the extractor.
    if features is not None:
        all_features.append(features)
    else:
        failed_count += 1

    if position % 200 == 0:
        print(f"  Processed {position}/{n_objects} objects...")

features_df = pd.DataFrame(all_features)

print(f"\n✓ Successfully extracted features: {len(features_df)}")
print(f"✗ Failed (insufficient data): {failed_count}")
print("\nClass distribution:")
print(features_df['label'].value_counts())

# ============================================================================
# FEATURE STATISTICS BY CLASS
# ============================================================================

print("\n" + "=" * 70)
print("FEATURE STATISTICS BY CLASS")
print("=" * 70)

# Every numeric column except the object identifier is treated as a feature.
numeric_features = list(features_df.select_dtypes(include=[np.number]).columns)
numeric_features.remove('transient_id')

# One grouped selection serves both per-class summaries.
per_class = features_df.groupby('label')[numeric_features]

for heading, summarise in (
    ("\nMeans:", per_class.mean),
    ("\nStandard deviations:", per_class.std),
):
    print(heading)
    print(summarise())

# ============================================================================
# FEATURE CORRELATION WITH LABEL
# ============================================================================

from scipy.stats import pointbiserialr

print("\n" + "=" * 70)
print("FEATURE CORRELATION WITH LABEL")
print("=" * 70)

# Binary encoding of the label: 1 for 'SNIa', 0 for every other label.
label_numeric = (features_df['label'] == 'SNIa').astype(int)

correlations = []
for col in numeric_features:
    # n_points is bookkeeping about the light curve, not a ranked feature.
    if col == 'n_points':
        continue
    corr, pval = pointbiserialr(label_numeric, features_df[col])
    correlations.append({
        'feature': col,
        # Absolute value is used for ranking; the sign is kept separately.
        'correlation': abs(corr),
        'correlation_signed': corr,
        'p_value': pval
    })

corr_df = pd.DataFrame(correlations).sort_values('correlation', ascending=False)
print("\nTop 10 features by correlation:")
print(corr_df.head(10))

# Identify top 3 for quantum
top_3 = corr_df.head(3)['feature'].tolist()
print(f"\n🎯 TOP 3 FEATURES FOR QUANTUM: {top_3}")

# ============================================================================
# SAVE FEATURES
# ============================================================================

features_path = os.path.join(DATA_DIR, "transient_features.csv")
# index=False: the row index is not written to the file.
features_df.to_csv(features_path, index=False)
print(f"\n✓ Saved features: {features_path}")

print("\n" + "=" * 70)
print("FEATURE EXTRACTION COMPLETE!")
print("=" * 70)
print("Next step: Train classical models")

Metadata: 2000 objects
Lightcurves: 381,810 rows

EXTRACTING FEATURES
  Processed 200/2000 objects...
  Processed 400/2000 objects...
  Processed 600/2000 objects...
  Processed 800/2000 objects...
  Processed 1000/2000 objects...
  Processed 1200/2000 objects...
  Processed 1400/2000 objects...
  Processed 1600/2000 objects...
  Processed 1800/2000 objects...
  Processed 2000/2000 objects...

✓ Successfully extracted features: 1072
✗ Failed (insufficient data): 928

Class distribution:
label
SNII    549
SNIa    523
Name: count, dtype: int64

FEATURE STATISTICS BY CLASS

Means:
        mag_min   mag_max  mag_mean   mag_std  mag_range    flux_max  \
label                                                                  
SNII  -5.835673 -3.386111 -4.947218  0.683593   2.449562  478.061295   
SNIa  -5.761671 -3.060064 -4.694934  0.769396   2.701607  408.937967   

        flux_mean    flux_std   time_span  rise_time  decline_time  \
label                                               