# TCN-Transformer Model Training (MODEL-4)

This notebook demonstrates training the TCN-Transformer Pattern 1 architecture for remaining useful life (RUL) prediction from raw bearing vibration signals.

## Architecture Overview

```
Input (batch, 32768, 2)  # Raw vibration signals (H, V channels)
    ↓
DualChannelStem          # Per-sensor Conv1D(k=7) → GELU → Conv1D(k=3)
    ↓
DualChannelTCN           # Dilated convolutions (d=1,2,4,8,16,32)
    ↓
TemporalDownsample       # Reduce sequence length (32768 → 2048)
    ↓
BidirectionalCrossAttn   # H attends to V, V attends to H
    ↓
ChannelFusion            # Concatenate H and V features
    ↓
TemporalAggregator       # LSTM (v1) or Transformer (v2)
    ↓
RULHead                  # Dense → ReLU → RUL prediction
```
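The bidirectional cross-attention step (H attends to V, V attends to H) can be sketched in NumPy as single-head scaled dot-product attention, where one channel's features act as queries against the other channel's keys and values. This is a simplified, hypothetical sketch with no learned query/key/value projections, no multi-head split, and no residual connections; the actual `BidirectionalCrossAttn` layer in `src.models.pattern1` may differ in all of these respects.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, keys_values):
    """Single-head scaled dot-product attention without learned projections.

    queries:     (seq_q, d) features of the attending channel
    keys_values: (seq_k, d) features of the attended channel
    Returns an array of shape (seq_q, d).
    """
    d = queries.shape[-1]
    scores = queries @ keys_values.T / np.sqrt(d)  # (seq_q, seq_k)
    weights = softmax(scores, axis=-1)              # each row sums to 1
    return weights @ keys_values

def bidirectional_cross_attention(h_feats, v_feats):
    # H attends to V, and V attends to H
    h_out = cross_attention(h_feats, v_feats)
    v_out = cross_attention(v_feats, h_feats)
    return h_out, v_out

rng = np.random.default_rng(0)
h = rng.standard_normal((2048, 64))  # downsampled H-channel features
v = rng.standard_normal((2048, 64))  # downsampled V-channel features
h_att, v_att = bidirectional_cross_attention(h, v)
fused = np.concatenate([h_att, v_att], axis=-1)  # ChannelFusion by concatenation
print(h_att.shape, fused.shape)  # (2048, 64) (2048, 128)
```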

> **Note:** This notebook documents exploratory model development. The architectures investigated here informed the final benchmark models but are not part of the production evaluation pipeline.

```python
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

# Add the project root (the parent of the notebook's working directory) to the import path
PROJECT_ROOT = os.path.dirname(os.getcwd())
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f'TensorFlow version: {tf.__version__}')
print(f'NumPy version: {np.__version__}')
```

## 1. Import Model Components

```python
from src.models.pattern1 import (
    create_tcn_transformer_lstm,
    create_tcn_transformer_transformer,
    TCNTransformerConfig,
    StemConfig,
    TCNConfig,
    AttentionConfig,
    LSTMAggregatorConfig,
    build_tcn_transformer_model,
    print_model_summary,
)

from src.training.config import TrainingConfig, compile_model, build_callbacks
from src.training.metrics import rmse, mae, phm08_score, print_evaluation_report
from src.training.cv import leave_one_bearing_out_cv, CVSplit

from src.data.loader import XJTUBearingLoader
from src.data.rul_labels import generate_rul_labels

print('Imports successful!')
```

## 2. Create and Inspect Model Architectures

```python
# Create LSTM aggregator variant (v1)
model_lstm = create_tcn_transformer_lstm(
    input_length=32768,
    filters=64,
    dilations=[1, 2, 4, 8, 16, 32],
    lstm_units=64,
)

print('=== TCN-Transformer with LSTM Aggregator (v1) ===')
print_model_summary(model_lstm)
```

```python
# Create Transformer aggregator variant (v2)
model_transformer = create_tcn_transformer_transformer(
    input_length=32768,
    filters=64,
    dilations=[1, 2, 4, 8, 16, 32],
    num_transformer_layers=2,
    num_heads=4,
)

print('=== TCN-Transformer with Transformer Aggregator (v2) ===')
print_model_summary(model_transformer)
```

```python
# Display the full layer-by-layer Keras summary of the v1 model
model_lstm.summary()
```

## 3. TCN Receptive Field Analysis

```python
from src.models.pattern1.tcn import TCNEncoder, TCNConfig

# Analyze the receptive field of the TCN encoder
tcn_config = TCNConfig(
    filters=64,
    kernel_size=3,
    dilations=[1, 2, 4, 8, 16, 32],
)

tcn = TCNEncoder(config=tcn_config)
tcn.build((None, 32768, 64))  # (batch, time, channels)

rf = tcn.compute_receptive_field()  # in samples
sampling_rate = 25600  # Hz
rf_ms = rf / sampling_rate * 1000

print('TCN Configuration:')
print(f'  Kernel size: {tcn_config.kernel_size}')
print(f'  Dilations: {tcn_config.dilations}')
print(f'  Receptive field: {rf} samples ({rf_ms:.2f} ms at 25.6 kHz)')
print(f'  Input coverage: {rf / 32768 * 100:.2f}% of full signal')
```
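The receptive field of a stack of dilated convolutions with kernel size `k` grows as `1 + (k - 1) * sum(dilations)` when there is one convolution per dilation level. The sketch below computes that value and checks it empirically by pushing a unit impulse through causal dilated convolutions in NumPy. If each TCN residual block applies two convolutions per dilation (a common design), the `(k - 1) * sum(dilations)` term doubles; which convention `TCNEncoder.compute_receptive_field()` follows is an assumption here, and the helper names are hypothetical.

```python
import numpy as np

def receptive_field(kernel_size, dilations, convs_per_block=1):
    # Each dilated conv widens the receptive field by (kernel_size - 1) * dilation
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)

def causal_dilated_conv(x, kernel, dilation):
    """Causal 1-D dilated convolution: out[t] depends only on x[t], x[t - d], ..."""
    k = len(kernel)
    pad = (k - 1) * dilation
    xp = np.concatenate([np.zeros(pad), x])  # left zero-padding keeps it causal
    out = np.zeros_like(x)
    for t in range(len(x)):
        taps = xp[t + pad - np.arange(k) * dilation]  # x[t - j * dilation], j = 0..k-1
        out[t] = np.dot(kernel, taps)
    return out

dilations = [1, 2, 4, 8, 16, 32]
analytic_rf = receptive_field(3, dilations)

# Empirical check: send a unit impulse through the stack and measure its spread
y = np.zeros(512)
y[0] = 1.0
for d in dilations:
    y = causal_dilated_conv(y, np.ones(3), d)
nonzero = np.flatnonzero(y)
empirical_rf = int(nonzero[-1] - nonzero[0] + 1)

print(analytic_rf, empirical_rf)  # 127 127
print(f'{analytic_rf / 25600 * 1000:.2f} ms at 25.6 kHz')  # 4.96 ms at 25.6 kHz
print(receptive_field(3, dilations, convs_per_block=2))  # 253
```

With either convention, the receptive field covers well under 1% of a 32768-sample recording, which is one motivation for the attention and aggregation stages that follow the TCN.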

## 4. Load Data

```python
# Load bearing data
loader = XJTUBearingLoader()
metadata = loader.get_metadata()

print(f'Dataset: {metadata["total_files"]} files')
print(f'Conditions: {list(metadata["conditions"].keys())}')
print('Bearings per condition: 5')
```

```python
# Load a single bearing for demonstration
condition = '35Hz12kN'
bearing_id = 'Bearing1_1'

print(f'Loading {bearing_id} from {condition}...')
signals, file_paths = loader.load_bearing(condition, bearing_id)

# Generate RUL labels
rul_labels = generate_rul_labels(
    num_files=len(file_paths),
    strategy='piecewise_linear',
    max_rul=125
)

print(f'Signal shape: {signals.shape}')  # (num_files, 32768, 2)
print(f'RUL labels shape: {rul_labels.shape}')
print(f'RUL range: [{rul_labels.min():.1f}, {rul_labels.max():.1f}]')
```
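A piecewise-linear RUL target is commonly defined as a linear countdown of the remaining files (run-to-failure snapshots), clipped at a maximum value so that early, healthy-looking samples share a constant label. The sketch below is a hypothetical reimplementation for illustration only; the exact convention inside `generate_rul_labels` (for example whether the last file maps to 0, or whether labels are normalized) is assumed, not confirmed.

```python
import numpy as np

def piecewise_linear_rul(num_files, max_rul=125):
    # Linear countdown: the last file (index num_files - 1) gets RUL 0
    linear = np.arange(num_files - 1, -1, -1, dtype=float)
    # Clip the early-life plateau at max_rul
    return np.minimum(linear, max_rul)

labels = piecewise_linear_rul(200, max_rul=125)
print(labels[:3], labels[-3:])  # [125. 125. 125.] [2. 1. 0.]
```

For a bearing with at most `max_rul + 1` files, the cap is never reached and the labels are purely linear.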

## 5. Quick Training Demo

Train on a small subset of one bearing's data to verify that the pipeline runs end to end. The resulting metrics are a smoke test, not an estimate of model performance.

```python
# Use a subset for the demo (first 50 samples)
n_demo = min(50, len(signals))
X_demo = signals[:n_demo]
y_demo = rul_labels[:n_demo].reshape(-1, 1)

# Chronological 80/20 train/val split (no shuffling): assuming files are in time
# order, validation samples come from later in the bearing's life
split_idx = int(0.8 * n_demo)
X_train, X_val = X_demo[:split_idx], X_demo[split_idx:]
y_train, y_val = y_demo[:split_idx], y_demo[split_idx:]

print(f'Training samples: {len(X_train)}')
print(f'Validation samples: {len(X_val)}')
```

```python
# Create and compile model
model = create_tcn_transformer_lstm(input_length=32768, filters=32)  # Smaller for demo

training_config = TrainingConfig(
    learning_rate=1e-3,
    batch_size=4,
    epochs=10,
)

compile_model(model, training_config)
print('Model compiled!')
```

```python
# Train
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=training_config.epochs,
    batch_size=training_config.batch_size,
    verbose=1,
)
```

```python
# Plot training history
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(history.history['loss'], label='Train Loss')
ax.plot(history.history['val_loss'], label='Val Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Huber Loss')
ax.set_title('Training History')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
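The y-axis label assumes `compile_model` uses a Huber loss. Huber loss is quadratic for small residuals and linear for large ones, so it is less sensitive to outlying RUL errors than mean squared error. A NumPy sketch with threshold `delta=1.0` (the `tf.keras.losses.Huber` default; whether `compile_model` uses that value is an assumption):

```python
import numpy as np

def huber_loss(y_true, y_pred, delta=1.0):
    err = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    abs_err = np.abs(err)
    quadratic = 0.5 * err ** 2                # used where |err| <= delta
    linear = delta * (abs_err - 0.5 * delta)  # used where |err| > delta
    return float(np.mean(np.where(abs_err <= delta, quadratic, linear)))

print(huber_loss([0.0, 0.0], [0.5, 3.0]))  # 1.3125
```

The two branches meet at `|err| = delta` with equal value and slope, which is what makes the loss smooth.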

## 6. Evaluate Predictions

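```python
# Make predictions on the full demo subset (training + validation samples)
y_pred = model.predict(X_demo, verbose=0)

# Calculate metrics. X_demo includes the training samples, so these numbers
# are optimistic and only confirm that the pipeline runs end to end.
y_true = y_demo.flatten()
y_pred_flat = y_pred.flatten()

print('=== Evaluation Metrics ===')
print(f'RMSE: {rmse(y_true, y_pred_flat):.2f}')
print(f'MAE: {mae(y_true, y_pred_flat):.2f}')
print(f'PHM08 Score: {phm08_score(y_true, y_pred_flat):.2f}')

# Held-out RMSE: the last 20% of the demo subset was not used for training
print(f'Val RMSE: {rmse(y_val.flatten(), y_pred_flat[split_idx:]):.2f}')
```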
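For reference, RMSE, MAE, and the PHM08 scoring function can be sketched in NumPy as below. The PHM08 score, as commonly defined for the 2008 PHM data challenge, sums exponential penalties on the error `d = predicted - true`, dividing by 13 for early predictions (`d < 0`) and by 10 for late ones (`d >= 0`), so over-estimating RUL is penalized more heavily. Whether `src.training.metrics.phm08_score` uses exactly these constants, and a sum rather than a mean, is an assumption; the `_np` helpers are hypothetical names chosen to avoid clashing with the imported metrics.

```python
import numpy as np

def rmse_np(y_true, y_pred):
    err = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    return float(np.sqrt(np.mean(err ** 2)))

def mae_np(y_true, y_pred):
    err = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    return float(np.mean(np.abs(err)))

def phm08_score_np(y_true, y_pred):
    d = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    # Early predictions (d < 0) are penalized more gently than late ones (d >= 0)
    penalties = np.where(d < 0, np.exp(-d / 13.0) - 1.0, np.exp(d / 10.0) - 1.0)
    return float(np.sum(penalties))

y_true_ex = [50.0, 50.0]
y_pred_ex = [40.0, 60.0]  # one prediction 10 units early, one 10 units late
print(rmse_np(y_true_ex, y_pred_ex), mae_np(y_true_ex, y_pred_ex))  # 10.0 10.0
print(phm08_score_np([50.0], [40.0]) < phm08_score_np([50.0], [60.0]))  # True
```

Because the PHM08 score is a sum, it grows with the number of samples, so it is only comparable between evaluations on the same set.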

```python
# Plot predictions vs ground truth
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time series plot
ax1 = axes[0]
ax1.plot(y_true, 'b-', label='Ground Truth', linewidth=2)
ax1.plot(y_pred_flat, 'r--', label='Prediction', linewidth=2)
ax1.set_xlabel('Sample Index')
ax1.set_ylabel('RUL')
ax1.set_title('RUL Prediction vs Ground Truth')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Scatter plot
ax2 = axes[1]
ax2.scatter(y_true, y_pred_flat, alpha=0.6)
max_val = max(y_true.max(), y_pred_flat.max())
ax2.plot([0, max_val], [0, max_val], 'k--', label='Perfect Prediction')
ax2.set_xlabel('True RUL')
ax2.set_ylabel('Predicted RUL')
ax2.set_title('Prediction Scatter Plot')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_aspect('equal')

plt.tight_layout()
plt.show()
```

## 7. Custom Configuration Example

```python
# Example: Create custom configuration
custom_config = TCNTransformerConfig(
    input_length=32768,
    num_channels=2,
    stem_config=StemConfig(
        filters=32,
        kernel_size_1=7,
        kernel_size_2=3,
        use_batch_norm=True,
    ),
    tcn_config=TCNConfig(
        filters=32,
        kernel_size=3,
        dilations=[1, 2, 4, 8, 16],  # Shorter dilations
        use_batch_norm=True,
        dropout_rate=0.2,
    ),
    attention_config=AttentionConfig(
        num_heads=2,
        key_dim=32,
        dropout_rate=0.1,
    ),
    aggregator_type='lstm',
    lstm_config=LSTMAggregatorConfig(
        units=32,
        bidirectional=True,
        pooling='last',
    ),
    fusion_mode='concat',
    hidden_dim=32,
    dropout_rate=0.1,
    use_downsampling=True,
    downsample_factor=16,
)

custom_model = build_tcn_transformer_model(custom_config, name='custom_tcn_transformer')
print_model_summary(custom_model)
```
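With `downsample_factor=16`, a 32768-sample sequence is reduced to 32768 / 16 = 2048 steps before attention. One simple way to do this is non-overlapping average pooling, sketched below in NumPy with a hypothetical helper; whether `TemporalDownsample` uses average pooling, strided convolution, or another method is an assumption here.

```python
import numpy as np

def average_pool_1d(x, factor):
    """Non-overlapping average pooling along the time axis.

    x: (seq_len, channels); seq_len must be divisible by factor.
    Returns an array of shape (seq_len // factor, channels).
    """
    seq_len, channels = x.shape
    if seq_len % factor != 0:
        raise ValueError("seq_len must be divisible by factor")
    # Group consecutive time steps into blocks of `factor` and average each block
    return x.reshape(seq_len // factor, factor, channels).mean(axis=1)

features = np.random.default_rng(0).standard_normal((32768, 32))
pooled = average_pool_1d(features, factor=16)
print(pooled.shape)  # (2048, 32)
```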

## 8. Model Comparison Summary

| Model | Parameters | Aggregator | Use Case |
|-------|------------|------------|----------|
| TCN-Transformer-LSTM (v1) | ~418K | Bidirectional LSTM | General-purpose baseline |
| TCN-Transformer-Transformer (v2) | ~1.2M | Transformer Encoder | Self-attention aimed at longer-range temporal dependencies; roughly 3× the parameters of v1 |

## Notes

### Architecture Highlights
- **Per-sensor stem**: Learns channel-specific features with Conv1D(k=7) → GELU → Conv1D(k=3)
- **Multi-resolution TCN**: Dilated convolutions capture patterns at multiple time scales
- **Cross-attention**: Allows H and V channels to exchange information
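- **Temporal downsampling**: Reduces sequence length before attention (32768 → 2048 with `downsample_factor=16`) to keep attention memory manageable
- **Non-negative RUL output**: The ReLU in the RUL head ensures non-negative predictions (it does not enforce a monotonic decrease over time)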
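The memory argument for downsampling follows from the attention score matrix, which has one entry per (query, key) pair and therefore grows quadratically with sequence length. A back-of-the-envelope sketch, assuming float32 scores, one batch element, and one head (actual usage scales with batch size and number of heads, and the helper name is hypothetical):

```python
def attention_score_bytes(seq_len, num_heads=1, batch_size=1, bytes_per_value=4):
    # One (seq_len x seq_len) score matrix per head per batch element
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value

full = attention_score_bytes(32768)
down = attention_score_bytes(32768 // 16)
print(f"{full / 2**30:.1f} GiB vs {down / 2**20:.1f} MiB ({full // down}x smaller)")
# 4.0 GiB vs 16.0 MiB (256x smaller)
```

A downsampling factor of 16 therefore shrinks each score matrix by 16² = 256.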

### Training Tips
1. Start with smaller `filters` (32) for faster iteration
2. Use `use_downsampling=True` with `downsample_factor=16` to manage memory
3. Monitor validation loss to detect overfitting
4. Consider windowing for very large datasets (see `src/data/dataset.py`)
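Windowing splits each long recording into shorter, possibly overlapping segments that share the parent file's RUL label. A minimal NumPy sketch, assuming hypothetical `window` and `stride` parameters; `src/data/dataset.py` may implement this differently:

```python
import numpy as np

def make_windows(signals, labels, window, stride):
    """Slice (num_files, length, channels) signals into overlapping windows.

    Each window inherits the RUL label of the file it was cut from.
    Returns windows of shape (num_windows, window, channels) and labels of shape (num_windows,).
    """
    num_files, length, _ = signals.shape
    starts = range(0, length - window + 1, stride)
    windows = [signals[i, s:s + window] for i in range(num_files) for s in starts]
    window_labels = [labels[i] for i in range(num_files) for _ in starts]
    return np.stack(windows), np.asarray(window_labels)

X_w, y_w = make_windows(np.zeros((3, 32768, 2)), np.array([2.0, 1.0, 0.0]),
                        window=4096, stride=2048)
print(X_w.shape, y_w.shape)  # (45, 4096, 2) (45,)
```

Splitting by file (or by bearing) before windowing keeps overlapping windows from the same recording out of both the training and validation sets.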