In [None]:
# 11. CLASSIFICATION MODEL ARCHITECTURES (REVISED: PRE-NORM & EXTERNAL COMPILE)
# ==============================================================================
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers

# Progress banner: the classification model builder functions are defined below.
print("Building Neural Network Architectures (Mode: Classification)...")

# ------------------------------------------------------------------------------
# MODEL 1: ResNet-MLP (Deep Residual Network) - CLASSIFICATION
# REVISION: Post-Norm -> Pre-Norm & Remove Internal Compile
# ------------------------------------------------------------------------------
def build_mlp_classifier(input_dim, output_dim=5, width=256, depth=4, dropout_rate=0.2):
    """Build an uncompiled pre-norm residual MLP classifier.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of units in the softmax output layer (classes).
        width: Units of the input projection and of every residual block.
        depth: Number of residual blocks.
        dropout_rate: Dropout rate used inside each residual branch.

    Returns:
        An uncompiled Keras ``Model`` named ``ResNet_MLP_Classifier``. The
        caller compiles it, which allows external optimizers/schedulers to
        be injected.
    """
    inputs = layers.Input(shape=(input_dim,), name='input_features')

    # Project to the block width so every skip connection adds equal shapes.
    x = layers.Dense(width, activation='linear')(inputs)

    # --- Residual Blocks Loop (Pre-Norm Style) ---
    for _ in range(depth):
        shortcut = x  # Main path (skip connection)

        # 1. Normalize first (Pre-Norm)
        x_norm = layers.LayerNormalization()(x)

        # 2. Transformation block (Dense -> Dropout -> Dense)
        branch = layers.Dense(width, activation='gelu')(x_norm)
        branch = layers.Dropout(dropout_rate)(branch)
        branch = layers.Dense(width, activation='linear')(branch)

        # 3. Merge back (Add)
        x = layers.Add()([shortcut, branch])

    # Final norm: the residual stream itself is never normalized in pre-norm blocks.
    x = layers.LayerNormalization()(x)

    # Head
    x = layers.Dense(64, activation='gelu')(x)

    # Output layer. The layer must be *applied* to `x`; passing the bare layer
    # object to `Model(outputs=...)` is not a valid graph output.
    outputs = layers.Dense(output_dim, activation='softmax', name='class_output')(x)

    # NOTE: no model.compile here, so optimizers/schedulers can be supplied
    # by the training code.
    model = models.Model(inputs=inputs, outputs=outputs, name="ResNet_MLP_Classifier")
    return model

# ------------------------------------------------------------------------------
# MODEL 2: TABULAR TRANSFORMER - CLASSIFICATION
# REVISION: GAP -> Flatten & Remove Internal Compile
# ------------------------------------------------------------------------------
def build_transformer_classifier(input_dim, output_dim=5, embed_dim=64, num_heads=4, num_blocks=3, dropout=0.1):
    """Build an uncompiled pre-norm transformer classifier for tabular input.

    Each scalar feature is treated as one token: the input is reshaped to
    ``(input_dim, 1)`` and embedded with a kernel-size-1 ``Conv1D``.

    Args:
        input_dim: Number of input features per sample (= number of tokens).
        output_dim: Number of units in the softmax output layer (classes).
        embed_dim: Embedding size of each feature token.
        num_heads: Attention heads per block; ``key_dim`` is
            ``embed_dim // num_heads``.
        num_blocks: Number of transformer blocks.
        dropout: Dropout rate for attention, feed-forward and head.

    Returns:
        An uncompiled Keras ``Model`` named ``Transformer_Classifier``.
    """
    inputs = layers.Input(shape=(input_dim,), name='input_features')

    # 1. Feature Tokenizer
    # Treats each feature as a token (simple embedding via Conv1D).
    x = layers.Reshape((input_dim, 1))(inputs)
    x = layers.Conv1D(embed_dim, kernel_size=1, activation=None)(x)

    # 2. Transformer Blocks (Pre-Norm Style)
    for _ in range(num_blocks):
        # Attention block: Norm -> self-attention -> residual add
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
        attn_output = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=dropout,
        )(x_norm, x_norm)
        x = layers.Add()([x, attn_output])

        # Feed-forward block: Norm -> MLP -> residual add
        x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
        ffn = models.Sequential([
            layers.Dense(embed_dim * 2, activation='gelu'),
            layers.Dropout(dropout),
            layers.Dense(embed_dim)
        ])
        x = layers.Add()([x, ffn(x_norm)])

    # 3. Prediction Head
    # Flatten (rather than global average pooling) keeps one slot per feature
    # position, since tabular column order is fixed.
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Flatten()(x)

    x = layers.Dense(64, activation='gelu')(x)
    x = layers.Dropout(dropout)(x)

    # Output layer. The layer must be *applied* to `x`; passing the bare layer
    # object to `Model(outputs=...)` is not a valid graph output.
    outputs = layers.Dense(output_dim, activation='softmax', name='class_output')(x)

    # NOTE: no model.compile here; the caller supplies optimizer and loss.
    model = models.Model(inputs=inputs, outputs=outputs, name="Transformer_Classifier")
    return model

# Progress banner: both builder functions above are now defined.
print("Classification Architectures Ready (Pre-Norm, Flatten, & External Optimizer Ready).")

In [None]:
# 11. CLASSIFICATION TRAINING: PRE-NORM RESNET-MLP (REVISED: STABILITY & SCHEDULER)
# ==============================================================================
try:
    import keras_tuner as kt
except ImportError:
    !pip install keras_tuner -q
    import keras_tuner as kt

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import os, shutil

# Set Scientific Plot Style
plt.rcParams['font.family'] = 'serif'
plt.rcParams['axes.grid'] = False

print("Starting Pre-Norm ResNet Tuning Pipeline (Mode: Stable & Warmup)...")

# --- 1. LOAD DATA ---
# Prefer the DataFrame already present in the notebook session; otherwise fall
# back to the Parquet file in the working directory.
if 'df_classification' in locals():
    df_use = df_classification.copy()
else:
    try:
        df_use = pd.read_parquet("df_cls_final.parquet")
    except Exception as err:
        # Catch `Exception` (not a bare `except`) so Ctrl-C / SystemExit still
        # propagate, and chain the cause so the real I/O error stays visible.
        raise ValueError("Error: df_classification not found!") from err
    print("   -> Loaded from Parquet.")

# Candidate input features; only those present in `df_use` are actually used.
input_cols = [
    'bp_rp0', 'bp_g', 'g_rp',
    'abs_G0', 'parallax', 'ruwe',
    'l_norm', 'b_norm'
]

# Check columns
# NOTE(review): columns missing from `df_use` are dropped silently here, so the
# model input width may be smaller than len(input_cols) — confirm this is intended.
available_cols = [c for c in input_cols if c in df_use.columns]
target_col = 'label_code'

# Feature matrix and target vector as NumPy arrays.
X = df_use[available_cols].values
y = df_use[target_col].values

# --- FEATURE CORRELATION CHECK ---
# Pairwise correlation of the selected features, shown as an annotated heatmap.
print("Generating Feature Correlation Matrix...")
plt.figure(figsize=(10, 8))
corr_matrix = df_use[available_cols].corr()
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm', square=True)
plt.title("Feature Correlation Matrix")
plt.tight_layout()
plt.show()

# Split Data
# 80/20 stratified split with a fixed seed so the class mix matches in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Scaling (QuantileTransformer is best for non-Gaussian Astro data)
# The transformer is fitted on the training split only and then applied to the test split.
print("   -> Scaling Data (QuantileTransformer)...")
scaler = QuantileTransformer(output_distribution='normal')
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Class Weights (To handle any remaining imbalance)
# 'balanced' weights are computed from the training labels and keyed by class value.
from sklearn.utils import class_weight
classes = np.unique(y_train)
weights = class_weight.compute_class_weight('balanced', classes=classes, y=y_train)
class_weight_dict = dict(zip(classes, weights))
print("Class Weights:", class_weight_dict)

# --- SCHEDULER CONFIGURATION ---
# Read by the tuner build function and the final retraining below.
BATCH_SIZE = 512
EPOCHS_SEARCH = 20
EPOCHS_FINAL = 50

# --- PRE-NORM MODEL DEFINITION (CORE ARCHITECTURE) ---
def build_resnet_prenorm_tuner(hp):
    """Keras Tuner hypermodel: compiled pre-norm residual MLP classifier.

    Hyperparameters registered on ``hp``: ``width`` (128-256, step 64),
    ``dropout_rate`` (0.2-0.5, step 0.1), ``num_blocks`` (1-3) and
    ``learning_rate`` (1e-3 or 5e-4, used as the warmup target).

    Reads the module-level ``X_train``, ``X_train_scaled``, ``BATCH_SIZE`` and
    ``EPOCHS_SEARCH``.

    Returns:
        A Keras model with a 5-unit softmax output, compiled with AdamW
        (weight decay 1e-4), sparse categorical cross-entropy and accuracy.
    """
    feature_input = layers.Input(shape=(X_train.shape[1],))

    # Search space for the network body
    width = hp.Int('width', min_value=128, max_value=256, step=64)
    dropout_rate = hp.Float('dropout_rate', 0.2, 0.5, step=0.1)

    # Linear projection so the residual stream has `width` channels
    hidden = layers.Dense(width, activation='linear')(feature_input)

    # Pre-norm residual blocks: hidden + Dense(Dropout(Dense(Norm(hidden))))
    block_count = hp.Int('num_blocks', 1, 3)
    for _ in range(block_count):
        normed = layers.LayerNormalization()(hidden)
        residual = layers.Dense(width, activation='gelu')(normed)
        residual = layers.Dropout(dropout_rate)(residual)
        residual = layers.Dense(width, activation='linear')(residual)
        hidden = layers.Add()([hidden, residual])

    # Output head: final norm, small GELU layer, 5-way softmax
    hidden = layers.LayerNormalization()(hidden)
    hidden = layers.Dense(64, activation='gelu')(hidden)
    class_probs = layers.Dense(5, activation='softmax')(hidden)

    # Warmup + cosine-decay schedule sized for the search budget.
    # NOTE(review): the step count uses all of X_train_scaled; if the caller
    # fits with validation_split, fewer steps run per epoch — confirm intended.
    search_steps = (len(X_train_scaled) // BATCH_SIZE) * EPOCHS_SEARCH
    peak_lr = hp.Choice('learning_rate', [1e-3, 5e-4])
    schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-5,
        decay_steps=search_steps,
        alpha=0.01,
        warmup_target=peak_lr,
        warmup_steps=int(0.1 * search_steps),
    )

    model = keras.Model(inputs=feature_input, outputs=class_probs)
    adamw = optimizers.AdamW(learning_rate=schedule, weight_decay=1e-4)
    model.compile(
        optimizer=adamw,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# --- TUNING EXECUTION ---
# Remove results of any previous run of this project so the search starts fresh.
if os.path.exists('tuning_dir/gaia_resnet_v2'): shutil.rmtree('tuning_dir/gaia_resnet_v2')

# Hyperband search over build_resnet_prenorm_tuner, maximising validation accuracy.
tuner = kt.Hyperband(build_resnet_prenorm_tuner,
                     objective='val_accuracy',
                     max_epochs=EPOCHS_SEARCH,
                     factor=3,
                     directory='tuning_dir',
                     project_name='gaia_resnet_v2')

print("\nSearching for Best Hyperparameters...")
# Stop a trial once validation loss has not improved for 3 epochs.
stop_early = keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

# 20% of the training split is held out for validation inside each trial.
tuner.search(X_train_scaled, y_train,
             epochs=EPOCHS_SEARCH,
             validation_split=0.2,
             callbacks=[stop_early],
             batch_size=BATCH_SIZE,
             class_weight=class_weight_dict,
             verbose=1)

# Keep only the single best hyperparameter set.
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(f"Best Width: {best_hps.get('width')}, Best LR: {best_hps.get('learning_rate')}")

# --- FINAL RETRAINING (FULL SCHEDULER) ---
print("\nRetraining Best Model (Pre-Norm)...")

# Rebuild model
# A fresh model is built from the best hyperparameters (new, untrained weights).
model_final = build_resnet_prenorm_tuner(best_hps)

# Recalculate Scheduler for Final Epochs
# The build function sized its schedule for EPOCHS_SEARCH; resize it for EPOCHS_FINAL.
total_steps_final = (len(X_train_scaled) // BATCH_SIZE) * EPOCHS_FINAL
lr_final = best_hps.get('learning_rate')

# Warmup over the first 10% of the steps up to the tuned learning rate, then cosine decay.
new_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-6,
    decay_steps=total_steps_final,
    alpha=0.01,
    warmup_target=lr_final,
    warmup_steps=int(0.1 * total_steps_final)
)

# Recompile
# Replaces the optimizer created inside the build function with one using the new schedule.
model_final.compile(optimizer=optimizers.AdamW(learning_rate=new_schedule, weight_decay=1e-4),
                    loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train with a 20% validation hold-out; early stopping restores the best-val_loss weights.
history = model_final.fit(
    X_train_scaled, y_train,
    epochs=EPOCHS_FINAL,
    validation_split=0.2,
    batch_size=BATCH_SIZE,
    class_weight=class_weight_dict,
    callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)],
    verbose=1
)

# --- EVALUATION ---
# A. Learning Curves
# One row per panel: (train key, train legend, val key, val legend, title, y-axis label).
curve_panels = [
    ('loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss Curve (Pre-Norm)', 'Loss'),
    ('accuracy', 'Train Acc', 'val_accuracy', 'Val Acc', 'Accuracy Curve', 'Accuracy'),
]

plt.figure(figsize=(12, 5))
for panel_idx, panel in enumerate(curve_panels, start=1):
    train_key, train_legend, val_key, val_legend, panel_title, y_label = panel
    plt.subplot(1, 2, panel_idx)
    plt.plot(history.history[train_key], label=train_legend)
    plt.plot(history.history[val_key], label=val_legend)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()

# Save the figure before showing it.
plt.tight_layout()
plt.savefig('learning_curves_prenorm.png', dpi=300)
plt.show()

# B. Classification Report & Matrix
# Predicted class = index of the largest softmax probability.
y_pred_probs = model_final.predict(X_test_scaled)
y_pred = np.argmax(y_pred_probs, axis=1)

# Standard English Labels
class_labels = ['Main Sequence', 'Sub-Giant', 'Red Giant', 'Supergiant', 'White Dwarf']
# Pin the label ids explicitly so the report and the matrix always have one
# row per name, even when a class is absent from y_test / y_pred (otherwise
# classification_report raises on a target_names length mismatch).
# Assumes label codes 0..4 follow the order of `class_labels`, as the
# positional target_names mapping already did.
label_ids = np.arange(len(class_labels))

print("\nClassification Report:")
print(classification_report(y_test, y_pred, labels=label_ids, target_names=class_labels))

plt.figure(figsize=(10, 8))
cm = confusion_matrix(y_test, y_pred, labels=label_ids)
# Row-normalize (fraction of each actual class). Rows with zero support are
# divided by 1 instead of 0 so they show 0.00 rather than NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = cm.astype('float') / np.maximum(row_totals, 1)

sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels)
plt.title('Normalized Confusion Matrix (Pre-Norm ResNet)')
plt.ylabel('Actual Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('confusion_matrix_prenorm.png', dpi=300)
plt.show()

# Persist the trained model in the native Keras format.
model_final.save('best_resnet_prenorm_classifier.keras')
print("Model saved: best_resnet_prenorm_classifier.keras")

In [None]:
# 12. TRANSFORMER CLASSIFICATION TRAINING (REVISED: PRE-NORM, WARMUP & STABILITY)
# ==============================================================================
print("Starting FT-Transformer Pipeline (Mode: Pre-Norm & Warmup)...")

# --- DATA PREPARATION ---
# Prefer the DataFrame already present in the notebook session; otherwise fall
# back to the Parquet file in the working directory.
if 'df_classification' in locals():
    df_use = df_classification.copy()
else:
    try:
        df_use = pd.read_parquet("df_cls_final.parquet")
    except Exception as err:
        # Catch `Exception` (not a bare `except`) so Ctrl-C / SystemExit still
        # propagate, and chain the cause so the real I/O error stays visible.
        raise ValueError("Error: Data not found! Please run Preprocessing first.") from err

# Input features and target; unlike a filtered column list, a missing column raises here.
input_cols = ['bp_rp0', 'bp_g', 'g_rp', 'abs_G0', 'parallax', 'ruwe', 'l_norm', 'b_norm']
X = df_use[input_cols].values
y = df_use['label_code'].values

# --- [VISUALIZATION] FEATURE CORRELATION (TRANSFORMER) ---
# Pairwise correlation of the input features, saved to PNG and displayed.
print("Generating Feature Correlation Matrix (Transformer)...")
plt.figure(figsize=(10, 8))
corr_matrix = df_use[input_cols].corr()
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm', square=True)
plt.title("Feature Correlation Matrix (Transformer)")
plt.tight_layout()
plt.savefig('transformer_feature_correlation.png', dpi=300)
plt.show()

# 80/20 stratified split with a fixed seed; the scaler is fitted on the training split only.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
scaler = QuantileTransformer(output_distribution='normal')
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Read by build_ft_transformer_prenorm for the input shape and the output width.
NUM_FEATURES = X_train_scaled.shape[1]
NUM_CLASSES = 5
# 'balanced' class weights from the training labels, keyed by class value.
# NOTE(review): `class_weight` is not imported in this cell — presumably it comes
# from `from sklearn.utils import class_weight` in an earlier cell; confirm.
classes = np.unique(y_train)
weights = class_weight.compute_class_weight('balanced', classes=classes, y=y_train)
class_weight_dict = dict(zip(classes, weights))
print("Class Weights:", class_weight_dict)

# --- SCHEDULER CONFIGURATION ---
# Read by the tuner build function and the final training below.
BATCH_SIZE = 512
EPOCHS_SEARCH = 15
EPOCHS_FINAL = 50

# --- FT-TRANSFORMER ARCHITECTURE (REVISED: PRE-NORM) ---
def build_ft_transformer_prenorm(hp):
    """Keras Tuner hypermodel: compiled pre-norm transformer over feature tokens.

    Hyperparameters registered on ``hp``: ``embed_dim`` (32 or 64),
    ``num_blocks`` (1-3), per block ``num_heads_{i}`` (2 or 4),
    ``attn_dropout_{i}`` and ``ffn_dropout_{i}`` (0.0-0.2, step 0.1),
    ``learning_rate`` (1e-3 or 5e-4, warmup target) and ``weight_decay``
    (1e-4 or 1e-5).

    Reads the module-level ``NUM_FEATURES``, ``NUM_CLASSES``,
    ``X_train_scaled``, ``BATCH_SIZE`` and ``EPOCHS_SEARCH``.

    Returns:
        A Keras model with a ``NUM_CLASSES``-unit softmax output, compiled with
        AdamW, sparse categorical cross-entropy and accuracy.
    """
    feature_input = layers.Input(shape=(NUM_FEATURES,))

    # Feature tokenizer: one token per feature, embedded by a kernel-size-1 Conv1D
    tokens = layers.Reshape((NUM_FEATURES, 1))(feature_input)
    embed_dim = hp.Int('embed_dim', min_value=32, max_value=64, step=32)
    tokens = layers.Conv1D(filters=embed_dim, kernel_size=1, activation=None)(tokens)

    # Pre-norm transformer blocks: each sub-layer is Norm -> transform -> residual add
    block_count = hp.Int('num_blocks', 1, 3)
    for block_idx in range(block_count):
        # Self-attention sub-layer
        attn_in = layers.LayerNormalization(epsilon=1e-6)(tokens)
        num_heads = hp.Int(f'num_heads_{block_idx}', 2, 4, step=2)
        attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads,
            dropout=hp.Float(f'attn_dropout_{block_idx}', 0.0, 0.2, step=0.1),
        )
        attn_out = attention(attn_in, attn_in)
        tokens = layers.Add()([tokens, attn_out])

        # Feed-forward sub-layer (hidden size is twice the embedding size)
        ffn_in = layers.LayerNormalization(epsilon=1e-6)(tokens)
        feed_forward = keras.Sequential([
            layers.Dense(embed_dim * 2, activation="gelu"),
            layers.Dropout(hp.Float(f'ffn_dropout_{block_idx}', 0.0, 0.2, step=0.1)),
            layers.Dense(embed_dim),
        ])
        ffn_out = feed_forward(ffn_in)
        tokens = layers.Add()([tokens, ffn_out])

    # Head: final norm, flatten all token embeddings, small GELU layer, softmax
    tokens = layers.LayerNormalization(epsilon=1e-6)(tokens)
    head = layers.Flatten()(tokens)
    head = layers.Dense(64, activation='gelu')(head)
    head = layers.Dropout(0.1)(head)
    class_probs = layers.Dense(NUM_CLASSES, activation='softmax')(head)

    # Warmup + cosine-decay schedule sized for the search budget.
    # NOTE(review): the step count uses all of X_train_scaled; if the caller
    # fits with validation_split, fewer steps run per epoch — confirm intended.
    search_steps = (len(X_train_scaled) // BATCH_SIZE) * EPOCHS_SEARCH
    peak_lr = hp.Choice('learning_rate', values=[1e-3, 5e-4])
    schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-5,
        decay_steps=search_steps,
        alpha=0.01,
        warmup_target=peak_lr,
        warmup_steps=int(0.1 * search_steps),
    )

    weight_decay = hp.Choice('weight_decay', values=[1e-4, 1e-5])

    model = keras.Model(inputs=feature_input, outputs=class_probs)
    adamw = optimizers.AdamW(learning_rate=schedule, weight_decay=weight_decay)
    model.compile(
        optimizer=adamw,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# --- TUNING EXECUTION ---
# Remove results of any previous run of this project (the tuner is also created with overwrite=True).
if os.path.exists('tuning_dir/gaia_transformer_v2'): shutil.rmtree('tuning_dir/gaia_transformer_v2')

# Hyperband search over build_ft_transformer_prenorm, maximising validation accuracy.
tuner = kt.Hyperband(build_ft_transformer_prenorm,
                     objective='val_accuracy',
                     max_epochs=EPOCHS_SEARCH,
                     factor=3,
                     directory='tuning_dir',
                     project_name='gaia_transformer_v2',
                     overwrite=True)

print("\nTuning FT-Transformer (Pre-Norm)...")
# Stop a trial once validation loss has not improved for 3 epochs.
stop_early = keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

# 20% of the training split is held out for validation inside each trial.
tuner.search(X_train_scaled, y_train,
             epochs=EPOCHS_SEARCH,
             batch_size=BATCH_SIZE,
             validation_split=0.2,
             callbacks=[stop_early],
             class_weight=class_weight_dict,
             verbose=1)

# Keep only the single best hyperparameter set.
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

# --- FINAL TRAINING (FULL SCHEDULER) ---
print("\nTraining Final Model (Pre-Norm + Full Warmup)...")

# Rebuild manually to reset weights
best_transformer = build_ft_transformer_prenorm(best_hps)

# Fraction of the training split held out for early stopping. The test split
# must NOT be used here: with restore_best_weights=True the checkpoint would be
# selected on the very data later reported as test performance.
FINAL_VAL_SPLIT = 0.2

# Update scheduler for the full EPOCHS_FINAL run, counting only the samples
# actually fitted once the validation fraction has been held out.
n_fit_samples = int(len(X_train_scaled) * (1 - FINAL_VAL_SPLIT))
total_steps_final = (n_fit_samples // BATCH_SIZE) * EPOCHS_FINAL
lr_final_val = best_hps.get('learning_rate')

# Warmup over the first 10% of the steps up to the tuned learning rate, then cosine decay.
new_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-5,
    decay_steps=total_steps_final,
    alpha=0.01,
    warmup_target=lr_final_val,
    warmup_steps=int(0.1 * total_steps_final)
)

weight_decay_final = best_hps.get('weight_decay')

# Recompile: replaces the search-budget optimizer created inside the build function.
best_transformer.compile(optimizer=optimizers.AdamW(learning_rate=new_schedule, weight_decay=weight_decay_final),
                         loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Validation data comes from the training split, so X_test stays untouched
# until the final evaluation.
history = best_transformer.fit(
    X_train_scaled, y_train,
    epochs=EPOCHS_FINAL,
    batch_size=BATCH_SIZE,
    validation_split=FINAL_VAL_SPLIT,
    class_weight=class_weight_dict,
    callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)],
    verbose=1
)

# --- EVALUATION & VISUALIZATION ---
print("\nTransformer Evaluation:")
# evaluate() returns the loss followed by the compiled metric (accuracy).
test_metrics = best_transformer.evaluate(X_test_scaled, y_test, batch_size=512)
loss, acc = test_metrics
print(f"   Accuracy: {acc:.2%}")

# A. Learning Curves
# One row per panel: (train key, train legend, val key, val legend, title).
curve_panels = [
    ('loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss Curve (Pre-Norm)'),
    ('accuracy', 'Train Acc', 'val_accuracy', 'Val Acc', 'Accuracy Curve'),
]

plt.figure(figsize=(12, 5))
for panel_idx, panel in enumerate(curve_panels, start=1):
    train_key, train_legend, val_key, val_legend, panel_title = panel
    plt.subplot(1, 2, panel_idx)
    plt.plot(history.history[train_key], label=train_legend)
    plt.plot(history.history[val_key], label=val_legend)
    plt.title(panel_title)
    plt.legend()

# Save the figure before showing it.
plt.tight_layout()
plt.savefig('transformer_learning_curves_prenorm.png', dpi=300)
plt.show()

# B. Confusion Matrix
# Predicted class = index of the largest softmax probability.
y_pred = np.argmax(best_transformer.predict(X_test_scaled, batch_size=512), axis=1)
labels = ['Main Sequence', 'Sub-Giant', 'Red Giant', 'Supergiant', 'White Dwarf']
# Pin the label ids explicitly so the matrix and the report always have one
# row per name, even when a class is absent from y_test / y_pred (otherwise
# classification_report raises on a target_names length mismatch).
# Assumes label codes 0..4 follow the order of `labels`, as the positional
# target_names mapping already did.
label_ids = np.arange(len(labels))

cm = confusion_matrix(y_test, y_pred, labels=label_ids)
# Row-normalize (fraction of each actual class). Rows with zero support are
# divided by 1 instead of 0 so they show 0.00 rather than NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = cm.astype('float') / np.maximum(row_totals, 1)

plt.figure(figsize=(10, 8))
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.title('Normalized Confusion Matrix - Pre-Norm Transformer')
plt.ylabel('Actual Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('transformer_confusion_matrix_prenorm.png', dpi=300)
plt.show()

print(classification_report(y_test, y_pred, labels=label_ids, target_names=labels))
# Persist the trained model in the native Keras format.
best_transformer.save("best_transformer_classifier_prenorm.keras")
print("Model saved: best_transformer_classifier_prenorm.keras")