In [None]:
# 13. REGRESSION MODEL ARCHITECTURES (REVISED: PRE-NORM & EXTERNAL COMPILE)
# ==============================================================================
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

# --- CUSTOM METRIC: R-SQUARED (R2) ---
def r2_keras(y_true, y_pred):
    """
    Custom R2-score metric for Keras monitoring (batch-level).

    R2 = 1 - (SS_res / SS_tot), where SS_tot is measured around the mean of
    EACH target column (axis 0). A single global mean across all columns
    inflates SS_tot whenever the targets have different means, which makes
    multi-output R2 look better than it is. For a 1-D target the result is
    identical to the global-mean formula.

    Keras averages function metrics over batches, so the epoch value is an
    approximation of the R2 over the whole epoch.

    Args:
        y_true: Ground truth, shape (batch,) or (batch, n_targets).
        y_pred: Predictions with the same shape as y_true.

    Returns:
        Scalar tensor: pooled (variance-weighted) R2 for the batch.
    """
    # Align dtypes (e.g. float64 targets vs float32 predictions).
    y_true = tf.cast(y_true, y_pred.dtype)
    ss_res = tf.reduce_sum(tf.square(y_true - y_pred))
    ss_tot = tf.reduce_sum(tf.square(y_true - tf.reduce_mean(y_true, axis=0)))
    return 1.0 - ss_res / (ss_tot + tf.keras.backend.epsilon())

# ------------------------------------------------------------------------------
# MODEL A: PRE-NORM RESNET-MLP REGRESSOR
# Revision: Post-Norm -> Pre-Norm, Removed Compile
# ------------------------------------------------------------------------------
def _pre_norm_residual_block(stream, width, dropout_rate):
    """Return ``stream + Dense(Dropout(Dense_gelu(LayerNorm(stream))))`` (one pre-norm unit)."""
    normed = layers.LayerNormalization()(stream)
    hidden = layers.Dense(width, activation='gelu')(normed)
    hidden = layers.Dropout(dropout_rate)(hidden)
    delta = layers.Dense(width, activation='linear')(hidden)
    # Identity path first, transformed branch second.
    return layers.Add()([stream, delta])


def build_mlp_regressor(input_dim, width=512, depth=6, dropout_rate=0.1):
    """Build an uncompiled pre-norm residual MLP with two linear regression outputs.

    The model is returned without ``compile`` so the caller can attach its own
    optimizer / learning-rate schedule.

    Args:
        input_dim: Number of input features.
        width: Width of the residual stream and of every residual branch.
        depth: Number of pre-norm residual units.
        dropout_rate: Dropout applied inside each residual branch.

    Returns:
        ``keras.Model`` named "ResNet_MLP_Regressor" with a 2-unit output layer
        named "reg_output" (mass and age targets per the original comments).
    """
    features = layers.Input(shape=(input_dim,), name="input_features")

    # Lift the features to the residual width so the skip connections line up.
    stream = layers.Dense(width, activation='linear')(features)

    for _ in range(depth):
        stream = _pre_norm_residual_block(stream, width, dropout_rate)

    # Pre-norm stacks leave the residual stream un-normalized, so normalize
    # once before the regression head.
    stream = layers.LayerNormalization()(stream)
    stream = layers.Dense(64, activation='gelu')(stream)
    predictions = layers.Dense(2, activation='linear', name='reg_output')(stream)

    return keras.Model(inputs=features, outputs=predictions, name="ResNet_MLP_Regressor")

# ------------------------------------------------------------------------------
# MODEL B: PRE-NORM FT-TRANSFORMER REGRESSOR
# Revision: Pooling -> Flatten, Post-Norm -> Pre-Norm
# ------------------------------------------------------------------------------
class NumericalFeatureTokenizer(layers.Layer):
    """Embed each scalar input feature with its OWN learned weight and bias vector.

    Maps ``(batch, num_features)`` -> ``(batch, num_features, embed_dim)`` with
    token ``j = x_j * W[j] + b[j]`` (the FT-Transformer numerical tokenizer).

    Why not ``Reshape -> Conv1D(kernel_size=1)``: a 1x1 convolution shares ONE
    projection across all positions, so every token is ``x_j * w + b`` with the
    same ``w`` and ``b``. All tokens then lie on a single line in embedding
    space and self-attention cannot tell which feature a token came from.

    The layer is not registered as serializable; pass
    ``custom_objects={"NumericalFeatureTokenizer": NumericalFeatureTokenizer}``
    to ``keras.models.load_model`` when reloading a saved model.
    """

    def __init__(self, num_features, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.embed_dim = embed_dim

    def build(self, input_shape):
        # One (embed_dim,) weight vector per feature.
        self.feature_weight = self.add_weight(
            name="feature_weight",
            shape=(self.num_features, self.embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )
        # Random (non-zero) bias keeps features distinguishable even when
        # their scaled value is 0.
        self.feature_bias = self.add_weight(
            name="feature_bias",
            shape=(self.num_features, self.embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # (batch, F) -> (batch, F, 1); broadcast with (F, E) -> (batch, F, E)
        return inputs[:, :, None] * self.feature_weight + self.feature_bias

    def get_config(self):
        config = super().get_config()
        config.update({"num_features": self.num_features, "embed_dim": self.embed_dim})
        return config


def build_transformer_regressor(input_dim, embed_dim=64, num_heads=4, num_blocks=3, dropout=0.1):
    """Build an uncompiled pre-norm FT-Transformer with two linear regression outputs.

    Args:
        input_dim: Number of scalar input features (one token per feature).
        embed_dim: Token embedding width.
        num_heads: Attention heads; each uses key_dim = embed_dim // num_heads.
        num_blocks: Number of pre-norm (attention + feed-forward) blocks.
        dropout: Dropout rate for attention, feed-forward and head.

    Returns:
        ``keras.Model`` named "Transformer_Regressor", ``(batch, input_dim)`` ->
        ``(batch, 2)``; compile is left to the caller (external scheduling).

    Raises:
        ValueError: if ``embed_dim < num_heads`` (key_dim would be 0).
    """
    if embed_dim < num_heads:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be >= num_heads ({num_heads}) so that key_dim >= 1."
        )

    inputs = layers.Input(shape=(input_dim,), name="input_features")

    # 1. Feature Tokenizer: each feature -> its own learned embedding vector
    x = NumericalFeatureTokenizer(input_dim, embed_dim, name="feature_tokenizer")(inputs)

    # 2. Transformer Blocks (Pre-Norm Style)
    for _ in range(num_blocks):
        # --- Attention Sub-layer ---
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x) # Pre-Norm
        attn_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout
        )(x_norm, x_norm)
        x = layers.Add()([x, attn_output])

        # --- Feed Forward Sub-layer ---
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x) # Pre-Norm
        ffn = keras.Sequential([
            layers.Dense(embed_dim * 2, activation='gelu'),
            layers.Dropout(dropout),
            layers.Dense(embed_dim),
        ])
        ffn_output = ffn(x_norm)
        x = layers.Add()([x, ffn_output])

    # 3. Prediction Head
    # Flatten (not GlobalAveragePooling) keeps feature-specific information.
    x = layers.LayerNormalization(epsilon=1e-6)(x) # Final Norm
    x = layers.Flatten()(x)

    x = layers.Dense(64, activation='gelu')(x)
    x = layers.Dropout(dropout)(x)

    # Output: 2 Neurons (Mass, Age)
    outputs = layers.Dense(2, activation='linear', name='reg_output')(x)

    # NOTE: model.compile REMOVED
    model = keras.Model(inputs=inputs, outputs=outputs, name="Transformer_Regressor")
    return model

print("Regression Architectures Ready (Pre-Norm Version).")

In [None]:
# 14. MLP TRAINING & EVALUATION (REVISED: PRE-NORM + WARMUP SCHEDULER)
# ==============================================================================
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses, callbacks
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer, StandardScaler
from sklearn.metrics import r2_score, mean_absolute_error
import os

# Set Scientific Plot Style
plt.rcParams['font.family'] = 'serif'
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

# --- 1. DATA SETUP ---
print("Preparing Data for MLP Regressor...")
# Load Data. Only I/O errors (e.g. the file does not exist yet) mean
# "preprocessing not run"; other failures (missing parquet engine, bugs) now
# propagate instead of being swallowed by a bare `except:` and misreported.
try:
    df_reg = pd.read_parquet("df_reg_flame_final.parquet")
except OSError as err:
    print("Error: Parquet file not found. Please run Preprocessing (Section 8) first!")
    print(f"   -> Details: {err}")
    df_reg = pd.DataFrame()
else:
    print(f"   -> Data loaded: {len(df_reg):,} rows.")

if not df_reg.empty:
    # --- CONFIG INPUT FEATURES ---
    # Ensuring teff_gspphot is included for better regression accuracy
    input_cols = [
        'bp_rp0', 'bp_g', 'g_rp',
        'abs_G0',
        'parallax', 'ruwe',
        'l_norm',
        'teff_gspphot'  # [IMPORTANT] Effective Temperature Feature
    ]

    # Safety Check for missing columns
    missing_cols = [c for c in input_cols if c not in df_reg.columns]
    if missing_cols:
        print(f"Warning: The following columns are missing: {missing_cols}")
        input_cols = [c for c in input_cols if c in df_reg.columns]

    print(f"   -> Using {len(input_cols)} Input Features.")

    X = df_reg[input_cols].values
    # We predict Log-Mass and Log-Age to stabilize training
    y = df_reg[['log_mass', 'log_age']].values

    # Split Data (the test split is used ONLY for the final evaluation)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Scaling (fitted on the training split only)
    # QuantileTransformer is ideal for MLP to Gaussianize input features
    print("   -> Scaling Data...")
    scaler_X = QuantileTransformer(output_distribution='normal', random_state=42)
    X_train_scaled = scaler_X.fit_transform(X_train)
    X_test_scaled = scaler_X.transform(X_test)

    scaler_y = StandardScaler()
    y_train_scaled = scaler_y.fit_transform(y_train)
    y_test_scaled = scaler_y.transform(y_test)

    # Validation data for early stopping is carved out of the TRAINING split.
    # Monitoring the test set let EarlyStopping(restore_best_weights=True) pick
    # the epoch that scored best on test data, biasing the test metrics below.
    VAL_FRACTION = 0.1
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train_scaled, y_train_scaled, test_size=VAL_FRACTION, random_state=42
    )

    # --- HYPERPARAMETER & SCHEDULER ---
    BATCH_SIZE = 2048
    EPOCHS = 50

    # Warmup + Cosine Decay (replaces ReduceLROnPlateau for stability).
    # Keras runs ceil(n / batch_size) steps per epoch (the last partial batch is
    # a step too); floor division ended the schedule early and produced 0 steps
    # when the data is smaller than one batch.
    steps_per_epoch = max(1, int(np.ceil(len(X_fit) / BATCH_SIZE)))
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(0.1 * total_steps)  # 10% Warmup duration
    # CosineDecay decays over `decay_steps` AFTER the warmup phase, so the
    # decay length must exclude warmup; otherwise the schedule spans 110% of
    # training and never reaches the alpha floor.
    decay_steps = max(1, total_steps - warmup_steps)

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-5,      # Start small (Warmup)
        decay_steps=decay_steps,         # Cosine phase following warmup
        alpha=0.01,                      # Decay to 1% of max LR
        warmup_target=1e-3,              # Max LR for MLP
        warmup_steps=warmup_steps
    )

    # --- 3. BUILD & COMPILE ---
    print(f"\nBuilding & Compiling MLP Model (Pre-Norm)...")

    if 'build_mlp_regressor' in locals():
        model_mlp = build_mlp_regressor(
            input_dim=X_train.shape[1],
            width=512,
            depth=6,
            dropout_rate=0.1
        )
    else:
        raise ValueError("Error: 'build_mlp_regressor' function is not defined.")

    # Manual Compile with Huber Loss
    optimizer = optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-4)

    model_mlp.compile(
        optimizer=optimizer,
        loss=losses.Huber(delta=1.0), # Huber loss is robust against outliers in astronomical data
        metrics=['mae', 'mse', r2_keras],
        jit_compile=True
    )

    # --- 4. TRAINING LOOP ---
    callbacks_list = [
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
    ]

    history_mlp = model_mlp.fit(
        X_fit, y_fit,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks_list,
        verbose=1
    )

    # --- 5. PHYSICAL EVALUATION (CRITICAL STEP) ---
    print("\nEvaluating Results in Physical Units...")

    # 1. Predict (standardized log space)
    preds_scaled = model_mlp.predict(X_test_scaled, verbose=0)

    # 2. Inverse Standard Scaler (Back to Log Scale)
    preds_log = scaler_y.inverse_transform(preds_scaled)
    y_true_log = scaler_y.inverse_transform(y_test_scaled)

    # 3. CONVERT TO PHYSICAL UNITS (10^x)
    pred_mass_phys = 10 ** preds_log[:, 0]
    true_mass_phys = 10 ** y_true_log[:, 0]

    pred_age_phys  = 10 ** preds_log[:, 1]
    true_age_phys  = 10 ** y_true_log[:, 1]

    # 4. Calculate Physical Errors
    mae_mass = mean_absolute_error(true_mass_phys, pred_mass_phys)
    mae_age = mean_absolute_error(true_age_phys, pred_age_phys)

    r2_mass = r2_score(true_mass_phys, pred_mass_phys)
    r2_age = r2_score(true_age_phys, pred_age_phys)

    # Raw strings keep LaTeX "\odot" literally; in a normal string "\o" is an
    # invalid escape sequence (SyntaxWarning on Python 3.12+).
    print("-" * 50)
    print(f"PHYSICAL EVALUATION RESULTS (Pre-Norm MLP):")
    print(rf"   -> Mass ($M_\odot$): MAE = {mae_mass:.3f} $M_\odot$ | R2 = {r2_mass:.4f}")
    print(f"   -> Age (Gyr)    : MAE = {mae_age:.3f} Gyr | R2 = {r2_age:.4f}")
    print("-" * 50)

    # --- VISUALIZATION (PHYSICAL UNITS) ---
    fig, ax = plt.subplots(2, 2, figsize=(14, 12))
    plt.suptitle(r"MLP Evaluation: Physical Units ($M_{\odot}$ & Gyr)", fontsize=16, fontweight='bold')

    # Scatter Plot: Mass
    ax[0,0].scatter(true_mass_phys, pred_mass_phys, s=2, alpha=0.3, color='tab:blue')
    min_m, max_m = true_mass_phys.min(), true_mass_phys.max()
    ax[0,0].plot([min_m, max_m], [min_m, max_m], 'r--', lw=2)
    ax[0,0].set_title(f"Mass: Predicted vs Actual ($R^2={r2_mass:.3f}$)")
    ax[0,0].set_xlabel(r"Actual Mass ($M_{\odot}$)")
    ax[0,0].set_ylabel(r"Predicted Mass ($M_{\odot}$)")
    ax[0,0].set_xscale('log'); ax[0,0].set_yscale('log')

    # Scatter Plot: Age
    ax[0,1].scatter(true_age_phys, pred_age_phys, s=2, alpha=0.3, color='tab:orange')
    min_a, max_a = true_age_phys.min(), true_age_phys.max()
    ax[0,1].plot([min_a, max_a], [min_a, max_a], 'r--', lw=2)
    ax[0,1].set_title(f"Age: Predicted vs Actual ($R^2={r2_age:.3f}$)")
    ax[0,1].set_xlabel("Actual Age (Gyr)")
    ax[0,1].set_ylabel("Predicted Age (Gyr)")
    ax[0,1].set_xscale('log'); ax[0,1].set_yscale('log')

    # Residual Plot: Mass
    res_mass = pred_mass_phys - true_mass_phys
    ax[1,0].scatter(true_mass_phys, res_mass, s=2, alpha=0.3, color='purple')
    ax[1,0].axhline(0, color='black', linestyle='--')
    ax[1,0].set_title(rf"Mass Residuals (MAE: {mae_mass:.3f} $M_\odot$)")
    ax[1,0].set_xlabel(r"Actual Mass ($M_{\odot}$)")
    ax[1,0].set_ylabel("Error ($Pred - True$)")
    ax[1,0].set_xscale('log')

    # Residual Plot: Age
    res_age = pred_age_phys - true_age_phys
    ax[1,1].scatter(true_age_phys, res_age, s=2, alpha=0.3, color='green')
    ax[1,1].axhline(0, color='black', linestyle='--')
    ax[1,1].set_title(f"Age Residuals (MAE: {mae_age:.3f} Gyr)")
    ax[1,1].set_xlabel("Actual Age (Gyr)")
    ax[1,1].set_ylabel("Error ($Pred - True$)")
    ax[1,1].set_xscale('log')

    plt.tight_layout()
    plt.savefig('mlp_physical_evaluation.png', dpi=300)
    plt.show()
    print("Plot saved: mlp_physical_evaluation.png")

    # Save Model
    model_mlp.save('best_mlp_regressor_full_physics.keras')
    print("Model saved: best_mlp_regressor_full_physics.keras")

In [None]:
# 15. TRANSFORMER REGRESSION ARCHITECTURE (V2: PRE-NORM STABILITY)
# ==============================================================================
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses

# --- CUSTOM METRIC: R-SQUARED (R2) ---
def r2_keras(y_true, y_pred):
    """
    Custom R2 metric for Keras (batch-level).

    Formula: 1 - (SS_res / SS_tot), where SS_tot is measured around the mean
    of EACH target column (axis 0). A single global mean across all columns
    inflates SS_tot whenever the targets have different means, overstating
    multi-output R2. For a 1-D target the result is unchanged.

    Keras averages function metrics over batches, so the epoch value is an
    approximation of the R2 over the whole epoch.

    Args:
        y_true: Ground truth, shape (batch,) or (batch, n_targets).
        y_pred: Predictions with the same shape as y_true.

    Returns:
        Scalar tensor: pooled (variance-weighted) R2 for the batch.
    """
    # Align dtypes (e.g. float64 targets vs float32 predictions).
    y_true = tf.cast(y_true, y_pred.dtype)
    ss_res = tf.reduce_sum(tf.square(y_true - y_pred))
    ss_tot = tf.reduce_sum(tf.square(y_true - tf.reduce_mean(y_true, axis=0)))
    return 1.0 - ss_res / (ss_tot + tf.keras.backend.epsilon())

# ------------------------------------------------------------------------------
# MODEL DEFINITION: PRE-NORM FT-TRANSFORMER
# ------------------------------------------------------------------------------
class NumericalFeatureTokenizer(layers.Layer):
    """Embed each scalar input feature with its OWN learned weight and bias vector.

    Maps ``(batch, num_features)`` -> ``(batch, num_features, embed_dim)`` with
    token ``j = x_j * W[j] + b[j]`` (the FT-Transformer numerical tokenizer).

    Why not ``Reshape -> Conv1D(kernel_size=1)``: a 1x1 convolution shares ONE
    projection across all positions, so every token is ``x_j * w + b`` with the
    same ``w`` and ``b``. All tokens then lie on a single line in embedding
    space and self-attention cannot tell which feature a token came from.

    The layer is not registered as serializable; pass
    ``custom_objects={"NumericalFeatureTokenizer": NumericalFeatureTokenizer}``
    to ``keras.models.load_model`` when reloading a saved model.
    """

    def __init__(self, num_features, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.embed_dim = embed_dim

    def build(self, input_shape):
        # One (embed_dim,) weight vector per feature.
        self.feature_weight = self.add_weight(
            name="feature_weight",
            shape=(self.num_features, self.embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )
        # Random (non-zero) bias keeps features distinguishable even when
        # their scaled value is 0.
        self.feature_bias = self.add_weight(
            name="feature_bias",
            shape=(self.num_features, self.embed_dim),
            initializer="glorot_uniform",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # (batch, F) -> (batch, F, 1); broadcast with (F, E) -> (batch, F, E)
        return inputs[:, :, None] * self.feature_weight + self.feature_bias

    def get_config(self):
        config = super().get_config()
        config.update({"num_features": self.num_features, "embed_dim": self.embed_dim})
        return config


def build_transformer_regressor_v2(input_dim, embed_dim=64, num_heads=8, num_blocks=4, dropout=0.1):
    """
    Builds a Pre-Norm FT-Transformer for Regression (returned uncompiled).

    Key Architectural Choices:
    1. Per-feature tokenizer: every input feature gets its own learned weight
       and bias vector (NumericalFeatureTokenizer), so attention can tell
       features apart. A shared Conv1D(kernel_size=1) cannot.
    2. Pre-Norm: LayerNormalization is applied BEFORE Attention/FFN, keeping an
       un-normalized residual path; a final LayerNorm precedes the head.
    3. Flattening: Instead of GlobalAveragePooling, the tokens are flattened
       to preserve the distinct information of each input parameter.

    Args:
        input_dim: Number of scalar input features (one token per feature).
        embed_dim: Token embedding width.
        num_heads: Attention heads; each uses key_dim = embed_dim // num_heads.
        num_blocks: Number of pre-norm (attention + feed-forward) blocks.
        dropout: Dropout rate for attention, feed-forward and head.

    Returns:
        ``keras.Model`` named "PreNorm_Transformer_V2", ``(batch, input_dim)``
        -> ``(batch, 2)``; compile is left to the caller (external scheduling).

    Raises:
        ValueError: if ``embed_dim < num_heads`` (key_dim would be 0).
    """
    if embed_dim < num_heads:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be >= num_heads ({num_heads}) so that key_dim >= 1."
        )

    inputs = layers.Input(shape=(input_dim,), name="input_features")

    # 1. Feature Tokenizer: each feature -> its own learned embedding vector
    x = NumericalFeatureTokenizer(input_dim, embed_dim, name="feature_tokenizer")(inputs)

    # 2. Transformer Blocks (Pre-Norm Architecture)
    for _ in range(num_blocks):
        # --- Sub-layer 1: Multi-Head Attention ---
        # Normalize FIRST
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x)

        attn_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout
        )(x_norm, x_norm)

        # Skip Connection (Add)
        x = layers.Add()([x, attn_output])

        # --- Sub-layer 2: Feed Forward Network (FFN) ---
        # Normalize FIRST
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x)

        ffn = keras.Sequential([
            layers.Dense(embed_dim * 2, activation='gelu'), # Expand dim 2x
            layers.Dropout(dropout),
            layers.Dense(embed_dim), # Project back
        ])
        ffn_output = ffn(x_norm)

        # Skip Connection (Add)
        x = layers.Add()([x, ffn_output])

    # 3. Output Head (Flattening Strategy)
    x = layers.LayerNormalization(epsilon=1e-6)(x) # Final Norm
    x = layers.Flatten()(x)

    # Deep Regressor Head
    x = layers.Dense(128, activation='gelu')(x)
    x = layers.Dropout(dropout)(x)
    x = layers.Dense(64, activation='gelu')(x)

    # Output: 2 targets (Mass, Age)
    outputs = layers.Dense(2, activation='linear', name='reg_output')(x)

    # NOTE: model.compile is intentionally omitted to allow external scheduling
    model = keras.Model(inputs=inputs, outputs=outputs, name="PreNorm_Transformer_V2")
    return model

print("Transformer V2 Architecture (Pre-Norm) Ready.")

In [None]:
# 16. TRANSFORMER V2 TRAINING (REVISED: SAVE PLOTS & LEARNING CURVES)
# ==============================================================================
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from sklearn.metrics import r2_score

from sklearn.model_selection import train_test_split

# Ensure data exists (from Section 14)
if 'X_train_scaled' in locals():

    # --- [NEW] 0. FEATURE CORRELATION MATRIX (OPTIONAL CHECK) ---
    # Plot correlation only if the dataframe is still in memory
    if 'df_reg' in locals() and 'input_cols' in locals():
        print("Generating Feature Correlation Matrix...")
        plt.figure(figsize=(12, 10))

        # Merge inputs and targets for correlation analysis
        cols_to_plot = input_cols + ['log_mass', 'log_age']
        # Filter strictly existing columns to prevent errors
        cols_to_plot = [c for c in cols_to_plot if c in df_reg.columns]

        df_corr = df_reg[cols_to_plot].corr()
        sns.heatmap(df_corr, annot=True, fmt=".2f", cmap='coolwarm', square=True, cbar_kws={"shrink": .8})
        plt.title("Feature & Target Correlation Matrix (Transformer Data)")
        plt.tight_layout()
        plt.savefig('transformer_v2_feature_correlation.png', dpi=300)
        plt.show()
        print("Plot saved: transformer_v2_feature_correlation.png")

    # --- 1. CONFIGURATION ---
    BATCH_SIZE = 512
    EPOCHS = 50

    # Validation data for early stopping comes from the TRAINING split (same
    # fraction/seed as Section 14). Using the test set here let
    # restore_best_weights select the epoch that scored best on test data,
    # biasing the test R2 reported below.
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train_scaled, y_train_scaled, test_size=0.1, random_state=42
    )

    # --- 2. LEARNING RATE SCHEDULER (Cosine Decay + Warmup) ---
    # Strategy: ramp linearly from 1e-5 to the peak LR, then cosine-decay.
    # Keras runs ceil(n / batch_size) steps per epoch (last partial batch
    # included); floor division ended the schedule early.
    steps_per_epoch = max(1, int(np.ceil(len(X_fit) / BATCH_SIZE)))
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(0.1 * total_steps) # 10% steps for warmup
    # CosineDecay decays over `decay_steps` AFTER warmup, so exclude warmup
    # from the decay length; otherwise the final LR never reaches the floor.
    decay_steps = max(1, total_steps - warmup_steps)

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-5,      # Warmup start (very small)
        decay_steps=decay_steps,         # Cosine phase following warmup
        alpha=0.01,                      # Final LR = 1% of Max LR
        warmup_target=1e-3,              # Peak LR
        warmup_steps=warmup_steps
    )

    # Optimizer with Weight Decay (AdamW) for better regularization
    optimizer = optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-4)

    # --- 3. BUILD & COMPILE ---
    print("\nBuilding Transformer V2 (Pre-Norm)...")
    # Verify function definition
    if 'build_transformer_regressor_v2' in locals():
        model_trans_v2 = build_transformer_regressor_v2(
            input_dim=X_train_scaled.shape[1],
            embed_dim=64,
            num_heads=8,
            num_blocks=4,
            dropout=0.1
        )
    else:
        raise ValueError("Error: 'build_transformer_regressor_v2' function not defined.")

    model_trans_v2.compile(
        optimizer=optimizer,
        loss=losses.Huber(delta=1.0),
        metrics=['mae', 'mse', r2_keras],
        jit_compile=True
    )

    # --- 4. TRAINING ---
    print("Starting Training with Warmup Scheduler...")
    callbacks_list = [
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
    ]

    history_trans_v2 = model_trans_v2.fit(
        X_fit, y_fit,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=callbacks_list,
        verbose=1
    )

    # --- 5. EVALUATION & VISUALIZATION ---
    print("\nEvaluating Transformer V2...")

    # A. Learning Curves (Critical for checking Warmup effect)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history_trans_v2.history['loss'], label='Train Loss')
    plt.plot(history_trans_v2.history['val_loss'], label='Val Loss')
    plt.title('Loss Curve (Huber)')
    plt.xlabel('Epochs')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history_trans_v2.history['r2_keras'], label='Train R2')
    plt.plot(history_trans_v2.history['val_r2_keras'], label='Val R2')
    plt.title('R2 Score Curve')
    plt.xlabel('Epochs')
    plt.legend()
    plt.tight_layout()
    plt.savefig('transformer_v2_learning_curves.png', dpi=300)
    plt.show()
    print("Plot saved: transformer_v2_learning_curves.png")

    # B. Prediction Plots (held-out test split, unseen during model selection)
    preds_scaled = model_trans_v2.predict(X_test_scaled, verbose=0)
    preds = scaler_y.inverse_transform(preds_scaled)
    y_true = scaler_y.inverse_transform(y_test_scaled)

    r2_mass = r2_score(y_true[:,0], preds[:,0])
    r2_age = r2_score(y_true[:,1], preds[:,1])

    print(f"   -> R2 Score (Mass): {r2_mass:.4f}")
    print(f"   -> R2 Score (Age) : {r2_age:.4f}")

    fig, ax = plt.subplots(1, 2, figsize=(14, 6))
    plt.suptitle(f"Transformer V2 Results (Log Scale)", fontsize=16)

    # Mass (Log Scale)
    ax[0].scatter(y_true[:,0], preds[:,0], s=1, alpha=0.3, color='tab:blue')
    ax[0].plot([y_true[:,0].min(), y_true[:,0].max()], [y_true[:,0].min(), y_true[:,0].max()], 'r--')
    ax[0].set_title(f"Log Mass ($R^2={r2_mass:.3f}$)")
    ax[0].set_xlabel("True Log Mass"); ax[0].set_ylabel("Pred Log Mass")

    # Age (Log Scale)
    ax[1].scatter(y_true[:,1], preds[:,1], s=1, alpha=0.3, color='tab:orange')
    ax[1].plot([y_true[:,1].min(), y_true[:,1].max()], [y_true[:,1].min(), y_true[:,1].max()], 'r--')
    ax[1].set_title(f"Log Age ($R^2={r2_age:.3f}$)")
    ax[1].set_xlabel("True Log Age"); ax[1].set_ylabel("Pred Log Age")

    plt.tight_layout()
    plt.savefig('transformer_v2_prediction_analysis.png', dpi=300)
    plt.show()
    print("Plot saved: transformer_v2_prediction_analysis.png")

    # Save Final Model
    model_trans_v2.save('best_transformer_v2_regressor.keras')
    print("Model saved: best_transformer_v2_regressor.keras")

else:
    print("Error: Please run Section 14 first to load and scale data.")