# MODEL-5: Pattern 2 - 2D CNN + Temporal Architecture

This notebook demonstrates the Pattern 2 architecture for remaining useful life (RUL) prediction of bearings using:
- **STFT Spectrograms**: Time-frequency representation of vibration signals
- **2D CNN Backbone**: Spatial feature extraction from spectrograms
- **Late Fusion**: Combining horizontal and vertical channel embeddings
- **Temporal Aggregator**: LSTM or Transformer for sequence modeling
- **RUL Head**: Non-negative RUL prediction (optional uncertainty)

## Architecture Overview

```
Spectrogram (128, 128, 2)
    ↓
DualChannelCNN2DBackbone (per-channel or shared weights)
    ↓
LateFusion (concat, add, or weighted)
    ↓
TemporalAggregator (LSTM v1 or Transformer v2)
    ↓
RULHead (optional uncertainty via Gaussian output)
```
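The `LateFusion` step in the diagram offers three modes. Below is a minimal NumPy sketch of how they differ, mainly in the width of what they pass to the temporal aggregator. The helper `late_fuse`, the fixed `alpha`, and the embedding shapes are illustrative assumptions, not the `src.models.pattern2` implementation.

```python
import numpy as np

def late_fuse(h_emb, v_emb, mode="concat", alpha=0.5):
    """Fuse horizontal and vertical embeddings of shape (batch, ..., dim).

    Hypothetical helper for illustration only: it mirrors the three modes
    named in the diagram, not the project's LateFusion layer.
    """
    if h_emb.shape != v_emb.shape:
        raise ValueError("channel embeddings must have the same shape")
    if mode == "concat":
        # Keep both embeddings side by side: the feature dimension doubles
        return np.concatenate([h_emb, v_emb], axis=-1)
    if mode == "add":
        # Element-wise sum: the feature dimension is unchanged
        return h_emb + v_emb
    if mode == "weighted":
        # Convex combination; in a trained model alpha would usually be learned
        return alpha * h_emb + (1.0 - alpha) * v_emb
    raise ValueError(f"unknown fusion mode: {mode}")

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 16, 64))  # (batch, time steps, embedding dim), assumed shape
v = rng.normal(size=(4, 16, 64))
for mode in ("concat", "add", "weighted"):
    print(mode, late_fuse(h, v, mode).shape)
# concat (4, 16, 128)
# add (4, 16, 64)
# weighted (4, 16, 64)
```

Concatenation keeps channel-specific features separate at the cost of doubling the aggregator's input width; `add` and `weighted` keep the width fixed but mix the two channels' features.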

> **Note:** This notebook documents exploratory model development. The architectures investigated here informed the final benchmark models but are not part of the production evaluation pipeline.

In [None]:
import os
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

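# Add project root to path
sys.path.insert(0, os.path.abspath('..'))

# Figures are saved here throughout the notebook, so create the directory up front
os.makedirs('../outputs/models', exist_ok=True)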

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")

## 1. Load Data and Generate Spectrograms

First, we load the raw bearing data and generate spectrograms using STFT.
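For orientation, here is a minimal NumPy sketch of how one two-channel vibration snapshot can become a (128, 128, 2) log-magnitude spectrogram. The frame length, hop, Hann window, `log1p` scaling, and the 32768-sample snapshot length (1.28 s at 25.6 kHz) are assumptions chosen so the shapes match the model input; `extract_spectrogram` in `src.features.stft` may use different settings.

```python
import numpy as np

FS = 25_600   # sampling rate used in this notebook (Hz)
N_FFT = 256   # assumed frame length (samples)
HOP = 256     # assumed hop; equal to N_FFT (no overlap) so 32768 samples -> 128 frames
N_BINS = 128  # keep the first 128 of the N_FFT // 2 + 1 = 129 rfft bins

def stft_log_spectrogram(x):
    """Log-magnitude STFT of a 1-D signal, returned as (freq bins, time frames)."""
    window = np.hanning(N_FFT)
    n_frames = 1 + (len(x) - N_FFT) // HOP
    frames = np.stack([x[i * HOP : i * HOP + N_FFT] * window for i in range(n_frames)])
    mag = np.abs(np.fft.rfft(frames, axis=-1))[:, :N_BINS]  # drop the Nyquist bin
    return np.log1p(mag).T

def two_channel_spectrogram(horizontal, vertical):
    """Stack per-channel spectrograms into a (freq, time, 2) array."""
    return np.stack(
        [stft_log_spectrogram(horizontal), stft_log_spectrogram(vertical)], axis=-1
    )

# Synthetic 1.28 s snapshot: a 3 kHz tone plus noise in each channel (stand-in for real data)
t = np.arange(32_768) / FS
rng = np.random.default_rng(0)
h = np.sin(2 * np.pi * 3_000 * t) + 0.1 * rng.normal(size=t.size)
v = 0.5 * np.sin(2 * np.pi * 3_000 * t) + 0.1 * rng.normal(size=t.size)
spec = two_channel_spectrogram(h, v)
print(spec.shape)  # (128, 128, 2)
```

With a 256-point frame at 25.6 kHz each frequency bin spans 100 Hz, so the 3 kHz tone lands in bin 30. Front ends typically overlap frames (hop smaller than the frame length) and then resize or pool to a fixed number of frames.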

In [None]:
from src.data.loader import XJTUBearingLoader
from src.data.rul_labels import generate_rul_labels, RULStrategy
from src.features.stft import extract_spectrogram

# Initialize loader
loader = XJTUBearingLoader()
metadata = loader.get_metadata()

print("Dataset overview:")
print(f"  Conditions: {list(metadata.keys())}")
print(f"  Total bearings: {sum(len(c) for c in metadata.values())}")
print("  Sample rate: 25.6 kHz")

In [None]:
# Load data for one condition (35Hz12kN) for demonstration
condition = '35Hz12kN'
bearings = metadata[condition]

print(f"Loading data for condition: {condition}")
print(f"Bearings: {list(bearings.keys())}")

# Load first bearing to get file count
bearing_id = 'Bearing1_1'
signals, files = loader.load_bearing(condition, bearing_id)
print(f"\n{bearing_id}: {len(files)} files, signal shape per file: {signals[0].shape}")

In [None]:
# Generate spectrograms for a subset of the data, sampled evenly across the
# bearing's full run-to-failure history so early, mid, and late life are all represented
n_samples = min(50, len(signals))  # Limit for demo
sample_idx = np.linspace(0, len(signals) - 1, n_samples).round().astype(int)

print(f"Generating {n_samples} spectrograms...")
spectrograms = np.stack([
    extract_spectrogram(signals[i], sampling_rate=25600.0)
    for i in sample_idx
])
print(f"Spectrograms shape: {spectrograms.shape}")

In [None]:
# Generate RUL labels for the full bearing history
rul_labels = generate_rul_labels(
    total_files=len(files),
    strategy=RULStrategy.PIECEWISE_LINEAR,
    max_rul=125
)

# Select the labels that match the sampled spectrograms
y_subset = np.asarray(rul_labels)[sample_idx].astype(np.float32).reshape(-1, 1)
print(f"RUL labels shape: {y_subset.shape}")
print(f"RUL range: {y_subset.min():.1f} - {y_subset.max():.1f}")
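A piecewise-linear RUL target such as `RULStrategy.PIECEWISE_LINEAR` with `max_rul=125` is commonly defined as a linear countdown to zero at failure, capped at `max_rul`. The sketch below assumes that definition, with RUL counted in snapshots; `piecewise_linear_rul` is a hypothetical stand-in, and `generate_rul_labels` may differ in units or details.

```python
import numpy as np

def piecewise_linear_rul(total_files, max_rul=125):
    """RUL per snapshot: linear countdown to 0 at the last file, capped at max_rul.

    Assumed convention (RUL in snapshots, 0 at failure); hypothetical helper,
    not the project's generate_rul_labels.
    """
    linear = np.arange(total_files - 1, -1, -1, dtype=np.float32)  # total-1, ..., 1, 0
    return np.minimum(linear, max_rul)

print(piecewise_linear_rul(10, max_rul=5))  # [5. 5. 5. 5. 5. 4. 3. 2. 1. 0.]
```

The flat segment reflects the usual assumption that, early in life, a healthy bearing's vibration carries little information about how much life remains, so the model is not asked to tell large RUL values apart.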

## 2. Visualize Spectrograms

In [None]:
# Plot spectrograms at different lifecycle stages
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Indices: early, mid, late
indices = [0, n_samples // 2, n_samples - 1]
titles = ['Early (Healthy)', 'Mid (Degrading)', 'Late (Near Failure)']

for col, (idx, title) in enumerate(zip(indices, titles)):
    # Horizontal channel
    axes[0, col].imshow(spectrograms[idx, :, :, 0], aspect='auto', origin='lower', cmap='viridis')
    axes[0, col].set_title(f'{title} - Horizontal\nRUL: {y_subset[idx, 0]:.1f}')
    axes[0, col].set_xlabel('Time Frame')
    axes[0, col].set_ylabel('Frequency Bin')
    
    # Vertical channel
    axes[1, col].imshow(spectrograms[idx, :, :, 1], aspect='auto', origin='lower', cmap='viridis')
    axes[1, col].set_title(f'{title} - Vertical')
    axes[1, col].set_xlabel('Time Frame')
    axes[1, col].set_ylabel('Frequency Bin')

plt.tight_layout()
plt.savefig('../outputs/models/pattern2_spectrograms.png', dpi=150)
plt.show()

## 3. Build and Compare Pattern 2 Models

In [None]:
from src.models.pattern2 import (
    create_pattern2_lstm,
    create_pattern2_transformer,
    create_simple_pattern2,
    create_pattern2_with_uncertainty,
    print_model_summary,
)

# Create model variants
print("=" * 60)
print("Pattern 2 Model Variants")
print("=" * 60)

# LSTM variant
print("\n1. LSTM Aggregator (v1):")
model_lstm = create_pattern2_lstm()
print_model_summary(model_lstm)

# Transformer variant
print("\n2. Transformer Aggregator (v2):")
model_transformer = create_pattern2_transformer()
print_model_summary(model_transformer)

# Simple variant (no temporal aggregation)
print("\n3. Simple (No Temporal Aggregation):")
model_simple = create_simple_pattern2()
print_model_summary(model_simple)
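The "Simple" variant differs from the other two in whether the CNN output is treated as a sequence. One common way to do that is to keep the time axis and flatten frequency and channels into per-step features; a model without temporal aggregation would instead pool both axes away. The sketch below assumes spectrograms are stored as frequency × time (as in the plots above) and uses a hypothetical 8 × 16 × 32 post-CNN feature map. This notebook does not confirm that `DualChannelCNN2DBackbone` reshapes exactly this way.

```python
import numpy as np

def feature_map_to_sequence(fmap):
    """(batch, freq, time, channels) CNN feature map -> (batch, time, freq * channels).

    Each time frame becomes one step of a sequence an LSTM or Transformer can read.
    """
    b, f, t, c = fmap.shape
    # Move time next to batch, then flatten frequency and channels into features
    return fmap.transpose(0, 2, 1, 3).reshape(b, t, f * c)

def global_average(fmap):
    """Average over freq and time: what a head without temporal aggregation might see."""
    return fmap.mean(axis=(1, 2))

fmap = np.random.default_rng(0).normal(size=(2, 8, 16, 32))  # assumed post-CNN shape
print(feature_map_to_sequence(fmap).shape)  # (2, 16, 256)
print(global_average(fmap).shape)           # (2, 32)
```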

## 4. Train Pattern 2 LSTM Model

In [None]:
from src.training.config import TrainingConfig, compile_model, build_callbacks
from sklearn.model_selection import train_test_split

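# Split data. Note: a random split puts neighboring snapshots of the same bearing
# in both sets, so the validation metrics below are likely optimistic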
X_train, X_val, y_train, y_val = train_test_split(
    spectrograms, y_subset, test_size=0.2, random_state=42
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
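Because a random split of snapshots from one bearing lets near-identical neighboring snapshots appear in both sets, a leave-one-bearing-out split is a common alternative once several bearings are loaded. Below is a minimal NumPy sketch; `leave_one_bearing_out` is a hypothetical helper, and scikit-learn's `LeaveOneGroupOut` (passing bearing IDs as `groups` to `split`) provides the same behavior.

```python
import numpy as np

def leave_one_bearing_out(bearing_ids):
    """Yield (held_out_id, train_idx, val_idx), holding out one bearing per fold.

    bearing_ids: one ID per sample (e.g. 'Bearing1_1'). All snapshots of a
    bearing land on the same side of the split, so neighboring snapshots
    can't leak from training into validation.
    """
    bearing_ids = np.asarray(bearing_ids)
    for held_out in np.unique(bearing_ids):
        val_mask = bearing_ids == held_out
        yield held_out, np.flatnonzero(~val_mask), np.flatnonzero(val_mask)

ids = ["Bearing1_1"] * 3 + ["Bearing1_2"] * 2 + ["Bearing1_3"] * 2
for held_out, train_idx, val_idx in leave_one_bearing_out(ids):
    print(held_out, train_idx, val_idx)
```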

In [None]:
# Create and compile model
model = create_pattern2_lstm()

# Training configuration
config = TrainingConfig(
    epochs=30,
    batch_size=8,
)

# Compile model
compile_model(model, config)

# Build callbacks (simplified for demo)
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        verbose=1
    ),
]
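The two callbacks interact through their patience values: with `patience=5` for `ReduceLROnPlateau` and `patience=10` for `EarlyStopping`, a plateau halves the learning rate at least once before training stops. The simplified trace below illustrates this. It is not Keras's exact logic (it treats any decrease as an improvement and ignores `min_delta`, `cooldown`, and `min_lr`), and the starting learning rate of 1e-3 is an assumption, since the `TrainingConfig` default isn't shown.

```python
def simulate_callbacks(val_losses, lr0=1e-3, lr_patience=5, lr_factor=0.5, stop_patience=10):
    """Simplified trace of ReduceLROnPlateau + EarlyStopping on a val_loss series.

    Returns (learning rate used in each epoch, epoch training stops at or None,
    best epoch whose weights restore_best_weights would keep).
    """
    best, best_epoch = float("inf"), None
    lr, lr_wait, stop_wait = lr0, 0, 0
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:
            # Improvement resets both patience counters
            best, best_epoch = loss, epoch
            lr_wait = stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if stop_wait >= stop_patience:
            return lrs, epoch, best_epoch
        if lr_wait >= lr_patience:
            lr *= lr_factor  # takes effect from the next epoch
            lr_wait = 0
    return lrs, None, best_epoch

# Loss improves for 3 epochs, then plateaus
lrs, stopped, best = simulate_callbacks([1.0, 0.9, 0.8] + [0.85] * 12)
print(stopped, best)   # 12 2
print(lrs[7], lrs[8])  # 0.001 0.0005
```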

In [None]:
# Train model
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=config.epochs,
    batch_size=config.batch_size,
    callbacks=callbacks,
    verbose=1
)

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss
axes[0].plot(history.history['loss'], label='Train Loss')
axes[0].plot(history.history['val_loss'], label='Val Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss (Huber)')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True)

# MAE
axes[1].plot(history.history['mae'], label='Train MAE')
axes[1].plot(history.history['val_mae'], label='Val MAE')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('MAE')
axes[1].set_title('Training and Validation MAE')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig('../outputs/models/pattern2_training_history.png', dpi=150)
plt.show()

## 5. Evaluate Model

In [None]:
from src.training.metrics import rmse, mae, phm08_score

# Make predictions
y_pred_train = model.predict(X_train, verbose=0)
y_pred_val = model.predict(X_val, verbose=0)

# Compute metrics
print("Training Metrics:")
print(f"  RMSE: {rmse(y_train.flatten(), y_pred_train.flatten()):.4f}")
print(f"  MAE:  {mae(y_train.flatten(), y_pred_train.flatten()):.4f}")

print("\nValidation Metrics:")
print(f"  RMSE: {rmse(y_val.flatten(), y_pred_val.flatten()):.4f}")
print(f"  MAE:  {mae(y_val.flatten(), y_pred_val.flatten()):.4f}")
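The cell above imports `phm08_score` but reports only RMSE and MAE. For reference, this is a NumPy sketch of all three, with the PHM08 score in its usual form from the 2008 PHM data challenge: with d = predicted RUL - true RUL, each sample contributes exp(-d/13) - 1 if d < 0 (early) and exp(d/10) - 1 otherwise (late), summed over samples. Whether `src.training.metrics` uses the same constants, or a sum rather than a mean, is an assumption.

```python
import numpy as np

def rmse(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))

def mae(y_true, y_pred):
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.mean(np.abs(y_pred - y_true)))

def phm08_score(y_true, y_pred, a_early=13.0, a_late=10.0):
    """PHM08-style score: summed exponential penalties on d = pred - true.

    Late predictions (d > 0, overestimating remaining life) are penalized
    more steeply than early ones. Lower is better; 0 is perfect.
    """
    d = np.asarray(y_pred, float) - np.asarray(y_true, float)
    penalties = np.where(d < 0, np.exp(-d / a_early) - 1.0, np.exp(d / a_late) - 1.0)
    return float(np.sum(penalties))
```

The asymmetry is the point of the score: predicting failure too late is treated as costlier than predicting it too early.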

In [None]:
# Plot predictions vs actual
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training set
axes[0].scatter(y_train, y_pred_train, alpha=0.7, label='Predictions')
axes[0].plot([0, 130], [0, 130], 'r--', label='Perfect')
axes[0].set_xlabel('True RUL')
axes[0].set_ylabel('Predicted RUL')
axes[0].set_title('Training Set: Predicted vs True RUL')
axes[0].legend()
axes[0].grid(True)

# Validation set
axes[1].scatter(y_val, y_pred_val, alpha=0.7, label='Predictions')
axes[1].plot([0, 130], [0, 130], 'r--', label='Perfect')
axes[1].set_xlabel('True RUL')
axes[1].set_ylabel('Predicted RUL')
axes[1].set_title('Validation Set: Predicted vs True RUL')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig('../outputs/models/pattern2_predictions.png', dpi=150)
plt.show()

## 6. Model with Uncertainty Quantification

In [None]:
# Create model with uncertainty output
model_unc = create_pattern2_with_uncertainty()
print("Pattern 2 with Uncertainty:")
print_model_summary(model_unc)

# Note: training this variant needs a custom loss for the (mean, variance) output,
# such as a Gaussian negative log-likelihood; here only the model structure is shown

In [None]:
# Inspect the uncertainty output (the model is untrained, so the values only illustrate the output format)
mean_pred, var_pred = model_unc.predict(X_val[:5], verbose=0)
std_pred = np.sqrt(var_pred)

print("Uncertainty Output Example:")
print(f"{'True RUL':<10} {'Mean':<10} {'Std':<10} {'95% CI':<20}")
print("-" * 53)
for i in range(5):
    ci_low = mean_pred[i, 0] - 1.96 * std_pred[i, 0]
    ci_high = mean_pred[i, 0] + 1.96 * std_pred[i, 0]
    print(f"{y_val[i, 0]:<10.1f} {mean_pred[i, 0]:<10.4f} {std_pred[i, 0]:<10.4f} [{ci_low:.2f}, {ci_high:.2f}]")
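The note in this section says training the uncertainty variant needs a custom loss. A common choice for a (mean, variance) output is the Gaussian negative log-likelihood. The NumPy sketch below shows it and checks that, for a fixed error, the loss is lowest when the predicted variance equals the squared error. `gaussian_nll` is a hypothetical helper; in the notebook it would need to be written with TensorFlow ops and wired to however `create_pattern2_with_uncertainty` packs its two outputs, which is not shown here.

```python
import numpy as np

def gaussian_nll(y_true, mean, var, eps=1e-6):
    """Mean Gaussian negative log-likelihood for a (mean, variance) prediction.

    Per sample: 0.5 * [log(2 * pi * var) + (y - mean)^2 / var].
    eps floors the variance so the log and division stay finite.
    """
    var = np.maximum(np.asarray(var, float), eps)
    resid = np.asarray(y_true, float) - np.asarray(mean, float)
    return float(np.mean(0.5 * (np.log(2 * np.pi * var) + resid ** 2 / var)))

# For an error of 3, the loss should be lowest at variance 3**2 = 9
candidate_vars = np.array([1.0, 4.0, 9.0, 16.0, 25.0])
losses = [gaussian_nll([3.0], [0.0], [v]) for v in candidate_vars]
print(candidate_vars[int(np.argmin(losses))])  # 9.0
```

This minimum is what would make the 1.96-sigma intervals above meaningful after training: the variance head is pushed toward the actual squared error, so the intervals approach 95% coverage only if the Gaussian assumption holds and calibration is checked.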

## 7. Summary

### Model Comparison

| Model | Parameters | Input Shape | Notes |
|-------|------------|-------------|-------|
| Pattern 2 LSTM | ~693K | (128, 128, 2) | Bidirectional LSTM aggregator |
| Pattern 2 Transformer | ~3.8M | (128, 128, 2) | 2-layer Transformer aggregator |
| Pattern 2 Simple | ~423K | (128, 128, 2) | No temporal aggregation |
| Pattern 2 Uncertainty | ~694K | (128, 128, 2) | Outputs mean + variance |
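The parameter column can be sanity-checked with the standard per-layer formulas for Keras `Conv2D`, `LSTM`, and `Dense` layers. The layer sizes in the example are hypothetical, since the exact Pattern 2 configurations aren't listed here; for a built model, Keras's `model.count_params()` gives the total directly.

```python
def conv2d_params(kh, kw, c_in, c_out, bias=True):
    """Conv2D weights: one kh x kw x c_in kernel (plus optional bias) per output channel."""
    return (kh * kw * c_in + int(bias)) * c_out

def lstm_params(input_dim, units, bidirectional=False):
    """Keras-style LSTM: 4 gates, each with input, recurrent, and bias weights."""
    per_direction = 4 * (units * (input_dim + units) + units)
    return 2 * per_direction if bidirectional else per_direction

def dense_params(n_in, n_out):
    """Dense layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

# Hypothetical layer sizes, for illustration only
print(conv2d_params(3, 3, 2, 32))                # 608
print(lstm_params(256, 64, bidirectional=True))  # 164352
print(dense_params(128, 1))                      # 129
```

LSTM parameters grow roughly with the square of `units`, so widening the temporal aggregator quickly adds to the total.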

### Key Findings

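1. **2D CNN + Temporal architecture** runs end to end on (128, 128, 2) spectrograms and produces RUL predictions
2. **Late fusion** combines horizontal and vertical channel embeddings into a single representation (no comparison against other fusion strategies is run here)
3. **LSTM and Transformer aggregators** both build on the same backbone; only the LSTM variant is trained in this notebook
4. **Uncertainty quantification** outputs a mean and variance per sample, from which 95% intervals can be formed; the uncertainty model is untrained here, so its intervals are not yet meaningful
5. Training loss decreases on the demo subset, indicating learning is occurring; because the train/validation split is random within a single bearing, the validation metrics are likely optimistic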

### Next Steps

1. Train on full dataset with proper cross-validation
2. Compare with Pattern 1 (TCN-Transformer) architecture
3. Implement sequence-of-spectrograms mode for temporal context
4. Add CWT scalogram frontend (FEAT-8) for comparison

In [None]:
# Save the trained model
os.makedirs('../outputs/models', exist_ok=True)
model_path = '../outputs/models/pattern2_lstm_demo.keras'
model.save(model_path)
print(f"Model saved to {model_path}")