In [None]:
# ============================================================
# CELL 1: SETUP & IMPORTS
# ============================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
import gc
import os
from tqdm.auto import tqdm
from scipy import stats

# Settings
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

# Custom color palette
COLORS = {
    'tde': '#e74c3c',      # Red for TDE
    'non_tde': '#3498db',  # Blue for non-TDE
    'primary': '#2c3e50',
    'secondary': '#95a5a6'
}

# Data path: defaults to the Kaggle input mount; set MALLORN_DATA_PATH to run the
# notebook anywhere else without editing code.
DATA_PATH = Path(os.environ.get(
    "MALLORN_DATA_PATH",
    "/kaggle/input/project/mallorn-astronomical-classification-challenge",
))

# Fail fast with an actionable message rather than a bare traceback from iterdir().
if not DATA_PATH.is_dir():
    raise FileNotFoundError(
        f"Data directory not found: {DATA_PATH}. "
        "Set the MALLORN_DATA_PATH environment variable to the competition data folder."
    )

print("‚úÖ Setup complete!")
print(f"üìÅ Data path: {DATA_PATH}")

# List files
print("\nüìÇ Files in data directory:")
for item in sorted(DATA_PATH.iterdir()):
    if item.is_dir():
        print(f"   üìÅ {item.name}/")
    else:
        size_mb = item.stat().st_size / (1024*1024)
        print(f"   üìÑ {item.name} ({size_mb:.2f} MB)")

In [None]:
# ============================================================
# CELL 2: LOAD METADATA (LOG FILES)
# ============================================================

def _read_log(filename):
    """Announce and read one CSV file from DATA_PATH."""
    print(f"üì• Loading {filename}...")
    return pd.read_csv(DATA_PATH / filename)

# Per-object metadata for both splits, plus the submission template
train_log = _read_log("train_log.csv")
test_log = _read_log("test_log.csv")
sample_sub = _read_log("sample_submission.csv")

divider = "=" * 60
n_train, n_test = len(train_log), len(test_log)

print("\n" + divider)
print("üìä DATASET OVERVIEW")
print(divider)
print(f"üîπ Training objects: {n_train:,}")
print(f"üîπ Testing objects:  {n_test:,}")
print(f"üîπ Total objects:    {n_train + n_test:,}")

# Preview the first rows of each log
for label, frame in (("TRAIN", train_log), ("TEST", test_log)):
    print("\n" + divider)
    print(f"üìã {label} LOG - First 5 rows")
    print(divider)
    display(frame.head())

In [None]:
# ============================================================
# CELL 3: ANALYZE TRAIN LOG STRUCTURE
# ============================================================

print("="*60)
print("üìä TRAIN LOG - DETAILED ANALYSIS")
print("="*60)

print("\nüîπ Data Types:")
print(train_log.dtypes)

print("\nüîπ Missing Values:")
missing = train_log.isnull().sum()
missing_pct = (missing / len(train_log) * 100).round(2)
missing_df = pd.DataFrame({'Missing Count': missing, 'Missing %': missing_pct})
# Show the table only when something is missing; otherwise pandas would print an
# "Empty DataFrame" stub right before the all-clear message.
if missing.sum() == 0:
    print("   ‚úÖ No missing values in train_log!")
else:
    display(missing_df[missing_df['Missing Count'] > 0])

print("\nüîπ Statistical Summary (Numeric Columns):")
display(train_log.describe())

print("\nüîπ Unique Values per Column:")
# nunique() computes every column in one call (NaN excluded, like Series.nunique)
for col, n_unique in train_log.nunique().items():
    print(f"   {col}: {int(n_unique):,} unique values")

In [None]:
# ============================================================
# CELL 4: TARGET VARIABLE ANALYSIS
# ============================================================

print("="*60)
print("üéØ TARGET VARIABLE ANALYSIS")
print("="*60)

# Target distribution. Reindex to [0, 1] so both classes are always present and in a
# fixed order -- the hard-coded bar/pie labels below depend on that order.
target_counts = train_log['target'].value_counts().reindex([0, 1], fill_value=0)
n_objects = len(train_log)
print("\nüîπ Target Distribution:")
print(f"   Non-TDE (0): {target_counts[0]:,} ({target_counts[0]/n_objects*100:.2f}%)")
print(f"   TDE (1):     {target_counts[1]:,} ({target_counts[1]/n_objects*100:.2f}%)")
if target_counts[1] > 0:
    print(f"\n   ‚ö†Ô∏è Imbalance Ratio: {target_counts[0]/target_counts[1]:.1f} : 1")
else:
    # Avoid a division by zero when the training log contains no TDE.
    print("\n   ‚ö†Ô∏è Imbalance Ratio: undefined (no TDE objects)")

class_labels = ['Non-TDE\n(0)', 'TDE\n(1)']
class_colors = [COLORS['non_tde'], COLORS['tde']]

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Bar chart
ax1 = axes[0]
bars = ax1.bar(class_labels, target_counts.values,
               color=class_colors, edgecolor='black', linewidth=1.5)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Target Distribution (Count)', fontsize=12, fontweight='bold')
# Offset count labels relative to the tallest bar (a fixed +50 is too small or too
# large depending on dataset size) and leave headroom so they are not clipped.
max_count = max(target_counts.max(), 1)
label_pad = max_count * 0.01
for bar, count in zip(bars, target_counts.values):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_pad,
             f'{count:,}', ha='center', va='bottom', fontsize=11, fontweight='bold')
ax1.set_ylim(0, max_count * 1.12)

# 2. Pie chart (TDE wedge pulled out for visibility)
ax2 = axes[1]
ax2.pie(target_counts.values,
        labels=['Non-TDE', 'TDE'],
        autopct='%1.1f%%',
        colors=class_colors,
        explode=(0, 0.1),
        startangle=90,
        wedgeprops={'edgecolor': 'black', 'linewidth': 1.5})
ax2.set_title('Target Distribution (%)', fontsize=12, fontweight='bold')

# 3. Log scale bar chart (to see TDE better)
ax3 = axes[2]
ax3.bar(class_labels, target_counts.values,
        color=class_colors, edgecolor='black', linewidth=1.5)
ax3.set_yscale('log')
ax3.set_ylabel('Count (Log Scale)', fontsize=11)
ax3.set_title('Target Distribution (Log Scale)', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('target_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nüí° KEY INSIGHT: Dataset is HIGHLY IMBALANCED!")
print("   ‚Üí Need to use F1 Score (not Accuracy)")
print("   ‚Üí Consider: class weights, oversampling, threshold tuning")

In [None]:
# ============================================================
# CELL 5: SPECTRAL TYPE (SpecType) ANALYSIS
# ============================================================

print("="*60)
print("üî¨ SPECTRAL TYPE ANALYSIS")
print("="*60)

# SpecType distribution (most common first)
spectype_counts = train_log['SpecType'].value_counts()
print("\nüîπ All Spectral Types in Training Data:")
print(spectype_counts.to_string())

# Counts per (SpecType, target). Force both target columns to exist so the rename
# cannot fail, and order rows like spectype_counts so both panels below agree.
spectype_by_target = (
    train_log.groupby(['SpecType', 'target']).size()
    .unstack(fill_value=0)
    .reindex(index=spectype_counts.index, columns=[0, 1], fill_value=0)
)
spectype_by_target.columns = ['Non-TDE', 'TDE']
print("\nüîπ Spectral Types by Target:")
display(spectype_by_target)

# Identify TDE SpecTypes
tde_spectypes = train_log.loc[train_log['target'] == 1, 'SpecType'].unique()
print(f"\nüéØ TDE Spectral Types: {tde_spectypes}")
tde_spectype_set = set(tde_spectypes)

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# 1. All SpecTypes (red bars are types that occur among TDE-labelled objects)
ax1 = axes[0]
colors = [COLORS['tde'] if st in tde_spectype_set else COLORS['non_tde']
          for st in spectype_counts.index]
bars = ax1.barh(spectype_counts.index, spectype_counts.values, color=colors, edgecolor='black')
ax1.set_xlabel('Count', fontsize=11)
ax1.set_title('Distribution of Spectral Types\n(Red = TDE)', fontsize=12, fontweight='bold')
ax1.invert_yaxis()

# Add count labels, offset relative to the longest bar instead of a fixed +10
label_pad = max(spectype_counts.max(), 1) * 0.005
for bar, count in zip(bars, spectype_counts.values):
    ax1.text(bar.get_width() + label_pad, bar.get_y() + bar.get_height()/2,
             f'{count:,}', va='center', fontsize=9)

# 2. Grouped bar chart: Non-TDE and TDE counts side by side for each SpecType
ax2 = axes[1]
y_pos = np.arange(len(spectype_by_target))
width = 0.35

ax2.barh(y_pos - width/2, spectype_by_target['Non-TDE'].values, width, label='Non-TDE',
         color=COLORS['non_tde'], edgecolor='black')
ax2.barh(y_pos + width/2, spectype_by_target['TDE'].values, width, label='TDE',
         color=COLORS['tde'], edgecolor='black')

ax2.set_yticks(y_pos)
ax2.set_yticklabels(spectype_by_target.index)
ax2.set_xlabel('Count', fontsize=11)
ax2.set_title('Spectral Types: TDE vs Non-TDE', fontsize=12, fontweight='bold')
ax2.legend()
ax2.invert_yaxis()

plt.tight_layout()
plt.savefig('spectype_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

# Most common types among non-TDE objects only (spectype_counts mixes both classes)
most_common_non_tde = (
    train_log.loc[train_log['target'] == 0, 'SpecType']
    .value_counts()
    .head(3)
    .index.tolist()
)

print("\nüí° KEY INSIGHTS:")
print(f"   ‚Üí There are {len(spectype_counts)} different spectral types")
print(f"   ‚Üí TDE is labeled as: {tde_spectypes}")
print(f"   ‚Üí Most common non-TDE types: {most_common_non_tde}")

In [None]:
# ============================================================
# CELL 6: REDSHIFT (Z) ANALYSIS
# ============================================================

print("="*60)
print("üåå REDSHIFT (Z) ANALYSIS")
print("="*60)

print("\nüîπ Train Redshift Statistics:")
print(train_log['Z'].describe())

print("\nüîπ Test Redshift Statistics:")
print(test_log['Z'].describe())

# Check Z_err
print("\nüîπ Z_err in Train (should be empty/NaN):")
print(f"   Missing: {train_log['Z_err'].isna().sum()} / {len(train_log)}")

print("\nüîπ Z_err in Test:")
print(test_log['Z_err'].describe())

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Train Z distribution by target
ax1 = axes[0, 0]
train_log.loc[train_log['target'] == 0, 'Z'].hist(bins=50, alpha=0.7, label='Non-TDE',
                                                   color=COLORS['non_tde'], ax=ax1, edgecolor='black')
train_log.loc[train_log['target'] == 1, 'Z'].hist(bins=50, alpha=0.7, label='TDE',
                                                   color=COLORS['tde'], ax=ax1, edgecolor='black')
ax1.set_xlabel('Redshift (Z)', fontsize=11)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Training: Redshift Distribution by Target', fontsize=12, fontweight='bold')
ax1.legend()

# 2. Train vs Test Z distribution
ax2 = axes[0, 1]
train_log['Z'].hist(bins=50, alpha=0.7, label='Train', color='green', ax=ax2, edgecolor='black')
test_log['Z'].hist(bins=50, alpha=0.7, label='Test', color='orange', ax=ax2, edgecolor='black')
ax2.set_xlabel('Redshift (Z)', fontsize=11)
ax2.set_ylabel('Count', fontsize=11)
ax2.set_title('Train vs Test: Redshift Distribution', fontsize=12, fontweight='bold')
ax2.legend()

# 3. Z by SpecType (boxplot), boxes ordered by median redshift.
# Groups are passed to Axes.boxplot in exactly the order used for the tick labels,
# so each label sits under its own box (pandas' `by=` would sort alphabetically).
ax3 = axes[1, 0]
spectype_order = train_log.groupby('SpecType')['Z'].median().sort_values().index
z_by_spectype = [train_log.loc[train_log['SpecType'] == st, 'Z'].dropna().values
                 for st in spectype_order]
ax3.boxplot(z_by_spectype)
# Axes.boxplot places boxes at x = 1..n by default
ax3.set_xticks(range(1, len(spectype_order) + 1))
ax3.set_xticklabels(spectype_order, rotation=45, ha='right')
ax3.set_xlabel('Spectral Type', fontsize=11)
ax3.set_ylabel('Redshift (Z)', fontsize=11)
ax3.set_title('Redshift by Spectral Type', fontsize=12, fontweight='bold')

# 4. Z_err in test (important!)
ax4 = axes[1, 1]
test_log['Z_err'].hist(bins=50, color='purple', alpha=0.7, ax=ax4, edgecolor='black')
ax4.set_xlabel('Z_err (Redshift Error)', fontsize=11)
ax4.set_ylabel('Count', fontsize=11)
ax4.set_title('Test Set: Redshift Error Distribution\n‚ö†Ô∏è Train has NO Z_err!', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('redshift_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

# Statistical comparison
print("\nüìä REDSHIFT COMPARISON:")
print(f"   Train Z: mean={train_log['Z'].mean():.4f}, std={train_log['Z'].std():.4f}")
print(f"   Test Z:  mean={test_log['Z'].mean():.4f}, std={test_log['Z'].std():.4f}")

print("\nüí° KEY INSIGHTS:")
print("   ‚Üí ‚ö†Ô∏è TRAIN has spectroscopic Z (NO error)")
print("   ‚Üí ‚ö†Ô∏è TEST has photometric Z (WITH error Z_err)")
print("   ‚Üí This is a DOMAIN SHIFT between train and test!")
print("   ‚Üí TDE typically at lower redshift than some SNe")

In [None]:
# ============================================================
# CELL 7: EXTINCTION (EBV) ANALYSIS
# ============================================================

print("="*60)
print("üå´Ô∏è EXTINCTION (EBV) ANALYSIS")
print("="*60)

print("\nüîπ Train EBV Statistics:")
print(train_log['EBV'].describe())

print("\nüîπ Test EBV Statistics:")
print(test_log['EBV'].describe())

# Two-sample Kolmogorov-Smirnov test quantifies train/test similarity instead of
# merely asserting it. D is the largest gap between the two empirical CDFs
# (0 = identical); with large samples p becomes tiny even for small D, so read D.
ebv_ks = stats.ks_2samp(train_log['EBV'].dropna(), test_log['EBV'].dropna())
print(f"\nüîπ Train vs Test EBV KS test: D={ebv_ks.statistic:.4f}, p={ebv_ks.pvalue:.2e}")

non_tde_mask = train_log['target'] == 0
tde_mask = train_log['target'] == 1

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. EBV distribution by target
ax1 = axes[0]
train_log.loc[non_tde_mask, 'EBV'].hist(bins=50, alpha=0.7, label='Non-TDE',
                                        color=COLORS['non_tde'], ax=ax1, edgecolor='black')
train_log.loc[tde_mask, 'EBV'].hist(bins=50, alpha=0.7, label='TDE',
                                    color=COLORS['tde'], ax=ax1, edgecolor='black')
ax1.set_xlabel('E(B-V)', fontsize=11)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Training: EBV by Target', fontsize=12, fontweight='bold')
ax1.legend()

# 2. Train vs Test
ax2 = axes[1]
train_log['EBV'].hist(bins=50, alpha=0.7, label='Train', color='green', ax=ax2, edgecolor='black')
test_log['EBV'].hist(bins=50, alpha=0.7, label='Test', color='orange', ax=ax2, edgecolor='black')
ax2.set_xlabel('E(B-V)', fontsize=11)
ax2.set_ylabel('Count', fontsize=11)
ax2.set_title('Train vs Test: EBV Distribution', fontsize=12, fontweight='bold')
ax2.legend()

# 3. EBV vs Z scatter
ax3 = axes[2]
ax3.scatter(train_log.loc[non_tde_mask, 'Z'], train_log.loc[non_tde_mask, 'EBV'],
            alpha=0.3, s=10, c=COLORS['non_tde'], label='Non-TDE')
ax3.scatter(train_log.loc[tde_mask, 'Z'], train_log.loc[tde_mask, 'EBV'],
            alpha=0.8, s=30, c=COLORS['tde'], label='TDE', marker='*')
ax3.set_xlabel('Redshift (Z)', fontsize=11)
ax3.set_ylabel('E(B-V)', fontsize=11)
ax3.set_title('EBV vs Redshift', fontsize=12, fontweight='bold')
ax3.legend()

plt.tight_layout()
plt.savefig('ebv_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nüí° KEY INSIGHTS:")
print("   ‚Üí EBV measures dust extinction along line of sight")
print("   ‚Üí Lower EBV = less dust = cleaner observations")
print(f"   ‚Üí Train vs test EBV KS distance D={ebv_ks.statistic:.3f} (0 = identical distributions)")

In [None]:
# ============================================================
# CELL 8: LOAD LIGHTCURVES (Sample first)
# ============================================================

print("="*60)
print("üí´ LOADING LIGHTCURVE DATA")
print("="*60)

def load_lightcurves_for_split(split_name, data_type='train'):
    """Load the lightcurve CSV for one data split.

    Reads ``DATA_PATH / split_name / f"{data_type}_full_lightcurves.csv"``.

    Args:
        split_name: Split directory name (values of ``train_log['split']``).
        data_type: File-name prefix, e.g. ``'train'`` or ``'test'``.

    Returns:
        The split's observations as a DataFrame, or None when the file is missing.
        A missing file is reported on stdout (the setup cell silences Python
        warnings), so skipped splits are not invisible.
    """
    file_path = DATA_PATH / split_name / f"{data_type}_full_lightcurves.csv"
    if not file_path.exists():
        print(f"   WARNING: missing lightcurve file, skipping: {file_path}")
        return None
    return pd.read_csv(file_path)

# Load all training lightcurves; iterate splits in sorted order so the row order of
# train_lc is deterministic across runs.
print("\nüì• Loading ALL training lightcurves...")
train_splits = sorted(train_log['split'].unique())
print(f"   Found {len(train_splits)} splits: {train_splits}")

all_train_lcs = []
for split in tqdm(train_splits, desc="Loading train splits"):
    lc = load_lightcurves_for_split(split, 'train')
    if lc is not None:
        all_train_lcs.append(lc)

# pd.concat([]) raises an opaque "No objects to concatenate"; fail with context instead.
if not all_train_lcs:
    raise FileNotFoundError(f"No training lightcurve files found under {DATA_PATH}")

train_lc = pd.concat(all_train_lcs, ignore_index=True)
del all_train_lcs
gc.collect()

print(f"\n‚úÖ Training lightcurves loaded!")
print(f"   Shape: {train_lc.shape}")
print(f"   Columns: {train_lc.columns.tolist()}")

# Sanity check: every labelled object should have at least one observation.
n_without_lc = int((~train_log['object_id'].isin(train_lc['object_id'])).sum())
print(f"   Objects in train_log without lightcurves: {n_without_lc:,}")

# Display sample
print("\nüìã Sample lightcurve data:")
display(train_lc.head(10))

# Memory usage
mem_mb = train_lc.memory_usage(deep=True).sum() / (1024*1024)
print(f"\nüíæ Memory usage: {mem_mb:.2f} MB")

In [None]:
# ============================================================
# CELL 9: LIGHTCURVE STATISTICS
# ============================================================

rule = "=" * 60
print(rule)
print("üìä LIGHTCURVE STATISTICS")
print(rule)

# Column dtypes and a numeric summary of the raw observations
print("\nüîπ Lightcurve Data Info:")
print(train_lc.dtypes)
print("\nüîπ Statistical Summary:")
display(train_lc.describe())

# Sampling density: number of rows recorded for each object
obs_per_object = train_lc.groupby('object_id').size()
print("\nüîπ Observations per Object:")
print("   Min:    {}".format(obs_per_object.min()))
print("   Max:    {}".format(obs_per_object.max()))
print("   Mean:   {:.1f}".format(obs_per_object.mean()))
print("   Median: {}".format(obs_per_object.median()))

# Photometric bands present in the data
filters = train_lc['Filter'].unique()
print("\nüîπ Filters (bands): {}".format(sorted(filters)))

# Total rows per band
obs_per_filter = train_lc.groupby('Filter').size()
print("\nüîπ Observations per Filter:")
print(obs_per_filter.sort_index())

# Observing baseline per object: last epoch minus first epoch
time_span = train_lc.groupby('object_id')['Time (MJD)'].agg(['min', 'max'])
time_span['duration'] = time_span['max'] - time_span['min']
print("\nüîπ Time Span per Object (days):")
for label, reducer in (("Min duration:    ", "min"),
                       ("Max duration:    ", "max"),
                       ("Mean duration:   ", "mean"),
                       ("Median duration: ", "median")):
    print("   {}{:.1f}".format(label, getattr(time_span['duration'], reducer)()))

In [None]:
# ============================================================
# CELL 10: LIGHTCURVE DISTRIBUTIONS
# ============================================================

fig, axes = plt.subplots(2, 3, figsize=(16, 10))

# 1. Observations per object
ax1 = axes[0, 0]
obs_per_object.hist(bins=50, ax=ax1, color='steelblue', edgecolor='black')
ax1.axvline(obs_per_object.median(), color='red', linestyle='--', label=f'Median: {obs_per_object.median():.0f}')
ax1.set_xlabel('Number of Observations', fontsize=11)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Observations per Object', fontsize=12, fontweight='bold')
ax1.legend()

# 2. Observations per filter
ax2 = axes[0, 1]
obs_per_filter.sort_index().plot(kind='bar', ax=ax2, color='teal', edgecolor='black')
ax2.set_xlabel('Filter', fontsize=11)
ax2.set_ylabel('Total Observations', fontsize=11)
ax2.set_title('Observations per Filter (Band)', fontsize=12, fontweight='bold')
ax2.tick_params(axis='x', rotation=0)

# 3. Time span distribution
ax3 = axes[0, 2]
time_span['duration'].hist(bins=50, ax=ax3, color='purple', edgecolor='black')
ax3.set_xlabel('Time Span (days)', fontsize=11)
ax3.set_ylabel('Count', fontsize=11)
ax3.set_title('Lightcurve Duration', fontsize=12, fontweight='bold')

# 4. Flux distribution
ax4 = axes[1, 0]
# Clip extreme values for visualization
flux_clipped = train_lc['Flux'].clip(-500, 2000)
flux_clipped.hist(bins=100, ax=ax4, color='orange', edgecolor='black', alpha=0.7)
ax4.set_xlabel('Flux (ŒºJy)', fontsize=11)
ax4.set_ylabel('Count', fontsize=11)
ax4.set_title('Flux Distribution (clipped)', fontsize=12, fontweight='bold')

# 5. Flux error distribution
ax5 = axes[1, 1]
train_lc['Flux_err'].hist(bins=100, ax=ax5, color='red', edgecolor='black', alpha=0.7)
ax5.set_xlabel('Flux Error (ŒºJy)', fontsize=11)
ax5.set_ylabel('Count', fontsize=11)
ax5.set_title('Flux Error Distribution', fontsize=12, fontweight='bold')

# 6. SNR distribution (|Flux| / Flux_err). A zero Flux_err gives inf, which clip()
# folds into the top bin.
ax6 = axes[1, 2]
snr = train_lc['Flux'].abs() / train_lc['Flux_err']
snr_clipped = snr.clip(0, 50)
snr_clipped.hist(bins=50, ax=ax6, color='green', edgecolor='black', alpha=0.7)
ax6.axvline(3, color='red', linestyle='--', label='SNR=3 (detection threshold)')
ax6.set_xlabel('Signal-to-Noise Ratio', fontsize=11)
ax6.set_ylabel('Count', fontsize=11)
ax6.set_title('SNR Distribution', fontsize=12, fontweight='bold')
ax6.legend()

plt.tight_layout()
plt.savefig('lightcurve_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

# Derive the summary numbers from the data rather than hard-coding them.
# Known LSST bands are listed in wavelength order; any unexpected band follows.
lsst_band_order = ['u', 'g', 'r', 'i', 'z', 'y']
bands_present = set(train_lc['Filter'].dropna().unique())
bands = ([b for b in lsst_band_order if b in bands_present]
         + sorted(map(str, bands_present - set(lsst_band_order))))
negative_flux_pct = (train_lc['Flux'] < 0).mean() * 100
above_snr3_pct = (snr > 3).mean() * 100

print("\nüí° KEY INSIGHTS:")
print(f"   ‚Üí Average {obs_per_object.mean():.0f} observations per object")
print(f"   ‚Üí {len(bands)} filters: {', '.join(bands)} (LSST bands)")
print(f"   ‚Üí {negative_flux_pct:.1f}% of flux values are NEGATIVE (baseline subtraction)")
print(f"   ‚Üí {above_snr3_pct:.1f}% of observations have SNR > 3 (typically means 'detected')")

In [None]:
# ============================================================
# CELL 11: EXAMPLE LIGHTCURVES - TDE vs NON-TDE
# ============================================================

print("="*60)
print("üåü EXAMPLE LIGHTCURVES")
print("="*60)

N_EXAMPLES = 3  # example objects per class = subplot columns

# Candidate objects per class
tde_objects = train_log.loc[train_log['target'] == 1, 'object_id'].values
non_tde_objects = train_log.loc[train_log['target'] == 0, 'object_id'].values

# Select random samples (seeded for reproducibility; min() guards small classes,
# since choice(..., replace=False) raises when asked for more items than exist)
np.random.seed(42)
sample_tde = np.random.choice(tde_objects, min(N_EXAMPLES, len(tde_objects)), replace=False)
sample_non_tde = np.random.choice(non_tde_objects, min(N_EXAMPLES, len(non_tde_objects)), replace=False)

# Color map for filters
filter_colors = {'u': 'purple', 'g': 'green', 'r': 'red',
                 'i': 'brown', 'z': 'gray', 'y': 'black'}

def plot_lightcurve(ax, object_id, data, title_prefix=''):
    """Plot one object's lightcurve as one errorbar series per filter.

    Args:
        ax: Matplotlib Axes to draw on.
        object_id: Object to plot; must be present in ``train_log``, which supplies
            the SpecType and Z shown in the title.
        data: Long-format lightcurve DataFrame with ``object_id``, ``Filter``,
            ``Time (MJD)``, ``Flux`` and ``Flux_err`` columns.
        title_prefix: Text for the first line of the subplot title.
    """
    obj_data = data[data['object_id'] == object_id]
    obj_info = train_log[train_log['object_id'] == object_id].iloc[0]

    for filt in sorted(obj_data['Filter'].unique()):
        filt_data = obj_data[obj_data['Filter'] == filt].sort_values('Time (MJD)')
        ax.errorbar(filt_data['Time (MJD)'], filt_data['Flux'],
                    yerr=filt_data['Flux_err'], fmt='o-',
                    label=filt, color=filter_colors.get(filt, 'black'),
                    markersize=4, alpha=0.7, capsize=2)

    # Zero-flux reference line
    ax.axhline(0, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Time (MJD)')
    ax.set_ylabel('Flux (ŒºJy)')
    ax.set_title(f"{title_prefix}\n{obj_info['SpecType']} | Z={obj_info['Z']:.3f}", fontsize=10)
    ax.legend(loc='best', fontsize=8)

# Plot: TDE examples on the top row, non-TDE on the bottom row.
# squeeze=False keeps `axes` 2-D even when N_EXAMPLES == 1.
fig, axes = plt.subplots(2, N_EXAMPLES, figsize=(16, 10), squeeze=False)

example_rows = [(sample_tde, 'üî¥ TDE Example'), (sample_non_tde, 'üîµ Non-TDE Example')]
for row, (samples, prefix) in enumerate(example_rows):
    for i, obj_id in enumerate(samples):
        plot_lightcurve(axes[row, i], obj_id, train_lc, f'{prefix} {i+1}')
    # Hide panels left empty when a class has fewer than N_EXAMPLES objects
    for i in range(len(samples), N_EXAMPLES):
        axes[row, i].set_visible(False)

plt.tight_layout()
plt.savefig('example_lightcurves.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nüí° OBSERVATIONS from lightcurves:")
print("   ‚Üí TDEs typically show smooth rise and decline")
print("   ‚Üí Different filters show similar evolution (color information)")
print("   ‚Üí Time span varies significantly between objects")

In [None]:
# ============================================================
# CELL 12: TDE vs NON-TDE LIGHTCURVE COMPARISON
# ============================================================

# Section banner: rule, title, rule (one per line)
print("="*60, "üìä TDE vs NON-TDE COMPARISON", "="*60, sep="\n")

# Compute statistics per object
def compute_object_stats(lc_df, filters=('u', 'g', 'r', 'i', 'z', 'y')):
    """Compute per-object summary statistics from a long-format lightcurve table.

    Uses vectorized ``groupby`` reductions rather than a Python loop over objects
    with repeated boolean masks, which is slow on the full training set.

    Args:
        lc_df: Lightcurve observations with columns ``object_id``, ``Time (MJD)``,
            ``Flux``, ``Flux_err`` and ``Filter``.
        filters: Bands for which per-band features are computed. Observations in
            other bands still count toward the overall statistics.

    Returns:
        DataFrame with one row per ``object_id`` (sorted) and columns
        ``object_id, n_obs, time_span, flux_mean, flux_std, flux_max, flux_min,
        flux_range, snr_mean, negative_flux_ratio`` followed by
        ``{band}_n_obs, {band}_flux_mean, {band}_flux_max`` for each band in
        ``filters``. A band with no observations for an object gets ``n_obs = 0``
        and NaN flux statistics. ``flux_std`` uses ddof=1 (NaN for one point).
    """
    flux = lc_df['Flux']
    # Helper columns so every per-object statistic is a plain groupby reduction.
    # NaN flux compares False, i.e. counts as non-negative.
    lc = lc_df.assign(
        _abs_snr=flux.abs() / lc_df['Flux_err'],
        _is_negative=(flux < 0).astype(float),
    )
    by_object = lc.groupby('object_id')
    times = by_object['Time (MJD)']
    obj_flux = by_object['Flux']

    out = pd.DataFrame({
        'n_obs': by_object.size(),
        'time_span': times.max() - times.min(),
        'flux_mean': obj_flux.mean(),
        'flux_std': obj_flux.std(),
        'flux_max': obj_flux.max(),
        'flux_min': obj_flux.min(),
    })
    out['flux_range'] = out['flux_max'] - out['flux_min']
    out['snr_mean'] = by_object['_abs_snr'].mean()
    out['negative_flux_ratio'] = by_object['_is_negative'].mean()

    # Per-band stats, realigned to every object (objects lacking the band -> 0 / NaN)
    for band in filters:
        band_flux = lc.loc[lc['Filter'] == band].groupby('object_id')['Flux']
        out[f'{band}_n_obs'] = band_flux.size().reindex(out.index, fill_value=0).astype(int)
        out[f'{band}_flux_mean'] = band_flux.mean().reindex(out.index)
        out[f'{band}_flux_max'] = band_flux.max().reindex(out.index)

    return out.reset_index()

# Compute per-object stats over all training lightcurves
print("\n‚è≥ Computing object-level statistics...")
object_stats = compute_object_stats(train_lc)

# Attach labels/metadata. validate='one_to_one' raises if train_log has duplicate
# object_ids (which would silently duplicate rows); the inner join drops objects
# without a train_log entry, so report how many were lost.
n_objects_before_merge = len(object_stats)
object_stats = object_stats.merge(
    train_log[['object_id', 'target', 'SpecType', 'Z', 'EBV']],
    on='object_id',
    validate='one_to_one',
)
n_dropped = n_objects_before_merge - len(object_stats)
if n_dropped:
    print(f"   WARNING: {n_dropped:,} objects had no train_log entry and were dropped")

print(f"‚úÖ Stats computed for {len(object_stats)} objects")
display(object_stats.head())

# Save for later use
object_stats.to_csv('object_stats.csv', index=False)
print("üíæ Saved to object_stats.csv")

In [None]:
# ============================================================
# CELL 13: TDE vs NON-TDE FEATURE COMPARISON
# ============================================================

# Split by target
tde_stats = object_stats[object_stats['target'] == 1]
non_tde_stats = object_stats[object_stats['target'] == 0]

print(f"TDE objects: {len(tde_stats)}")
print(f"Non-TDE objects: {len(non_tde_stats)}")

# Features to compare
features_to_compare = ['n_obs', 'time_span', 'flux_mean', 'flux_std',
                       'flux_max', 'flux_range', 'snr_mean', 'negative_flux_ratio', 'Z']


def finite_values(series):
    """Return `series` without NaN or +/-inf entries (a KDE cannot be fit to them)."""
    return series.replace([np.inf, -np.inf], np.nan).dropna()


# Grid sized from the feature list so adding features does not overflow the axes
n_cols = 3
n_rows = int(np.ceil(len(features_to_compare) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, feat in zip(axes, features_to_compare):
    # KDE plots. gaussian_kde raises on fewer than 2 distinct values, so such
    # degenerate samples are skipped instead of crashing the whole figure.
    for values, label, color in ((finite_values(non_tde_stats[feat]), 'Non-TDE', COLORS['non_tde']),
                                 (finite_values(tde_stats[feat]), 'TDE', COLORS['tde'])):
        if values.nunique() >= 2:
            values.plot(kind='kde', ax=ax, label=label, color=color, linewidth=2)

    ax.set_xlabel(feat, fontsize=11)
    ax.set_title(f'{feat} Distribution', fontsize=12, fontweight='bold')
    ax.legend()

    # Mann-Whitney U (rank-based, so inf is fine) needs >= 1 value per group.
    non_tde_vals = non_tde_stats[feat].dropna()
    tde_vals = tde_stats[feat].dropna()
    if len(non_tde_vals) and len(tde_vals):
        u_stat, pval = stats.mannwhitneyu(non_tde_vals, tde_vals, alternative='two-sided')
        significance = "***" if pval < 0.001 else "**" if pval < 0.01 else "*" if pval < 0.05 else "ns"
        p_label = f'p={pval:.2e} {significance}'
    else:
        p_label = 'p=n/a'
    ax.text(0.95, 0.95, p_label,
            transform=ax.transAxes, ha='right', va='top', fontsize=9,
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Hide any unused panels in the grid
for ax in axes[len(features_to_compare):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig('tde_vs_nontde_features.png', dpi=150, bbox_inches='tight')
plt.show()

# Statistical comparison table
print("\nüìä STATISTICAL COMPARISON:")
print("="*70)
print(f"{'Feature':<20} {'Non-TDE Mean':>15} {'TDE Mean':>15} {'Difference':>15}")
print("="*70)
for feat in features_to_compare:
    non_tde_mean = non_tde_stats[feat].mean()
    tde_mean = tde_stats[feat].mean()
    # Relative difference vs the non-TDE baseline. Divide by |baseline| so the sign
    # shows direction even when the baseline mean is negative (e.g. flux_mean);
    # a zero/NaN baseline has no meaningful percentage, so print n/a, not 0%.
    if pd.notna(non_tde_mean) and pd.notna(tde_mean) and non_tde_mean != 0:
        diff_str = f"{(tde_mean - non_tde_mean) / abs(non_tde_mean) * 100:.1f}%"
    else:
        diff_str = "n/a"
    print(f"{feat:<20} {non_tde_mean:>15.3f} {tde_mean:>15.3f} {diff_str:>15}")

print("\nüí° KEY DISCRIMINATING FEATURES:")
print("   ‚Üí Look for features with significant differences (*** p < 0.001)")
print("   ‚Üí These will be important for classification!")

In [None]:
# ============================================================
# CELL 14: EDA SUMMARY
# ============================================================

banner = "=" * 70
print(banner)
print("üìã EDA SUMMARY - KEY FINDINGS")
print(banner)

# Class sizes, computed once and reused by every summary field
n_train_objects = len(train_log)
n_non_tde = int((train_log['target'] == 0).sum())
n_tde = int((train_log['target'] == 1).sum())

summary_template = """
üîπ DATASET SIZE:
   - Training: {train_n:,} objects
   - Testing:  {test_n:,} objects
   - Total lightcurve observations: {lc_n:,}

üîπ CLASS IMBALANCE:
   - Non-TDE: {non_tde_n:,} ({non_tde_pct:.1f}%)
   - TDE:     {tde_n:,} ({tde_pct:.1f}%)
   - Ratio:   {ratio:.0f}:1

üîπ DOMAIN SHIFT (Train vs Test):
   ‚ö†Ô∏è Train has spectroscopic redshift (Z) with NO error
   ‚ö†Ô∏è Test has photometric redshift (Z) WITH error (Z_err)
   ‚Üí Model must be robust to redshift uncertainty!

üîπ KEY FEATURES IDENTIFIED:
   - Redshift (Z): TDEs may be at different z distribution
   - Extinction (EBV): Similar distribution, good for correction
   - Flux statistics: mean, std, max, range differ between classes
   - Time span: Lightcurve duration
   - Per-band features: 6 LSST filters (u,g,r,i,z,y)

üîπ NEXT STEPS:
   1. Feature Engineering based on domain knowledge
   2. Handle class imbalance
   3. Build baseline model
   4. Optimize for F1 score
"""

print(summary_template.format(
    train_n=n_train_objects,
    test_n=len(test_log),
    lc_n=len(train_lc),
    non_tde_n=n_non_tde,
    non_tde_pct=n_non_tde / n_train_objects * 100,
    tde_n=n_tde,
    tde_pct=n_tde / n_train_objects * 100,
    ratio=n_non_tde / n_tde,
))

print("‚úÖ EDA COMPLETE!")
print("\nüìå Saved files:")
for artifact in (
    "target_distribution.png",
    "spectype_distribution.png",
    "redshift_analysis.png",
    "ebv_analysis.png",
    "lightcurve_distributions.png",
    "example_lightcurves.png",
    "tde_vs_nontde_features.png",
    "object_stats.csv",
):
    print(f"   - {artifact}")