# TimesNet Mid-Heavy Configuration - Financial Data Training

This notebook contains a **mid-heavy TimesNet configuration** optimized for:
- High-performance training with substantial computational resources
- Complex pattern recognition in financial time series
- Production-ready model development
- Advanced feature extraction and long-range dependencies

- **Dataset**: Financial time series with 4 targets + 114 covariates (118 total features)
- **Training Time**: ~20-40 minutes per epoch
- **Memory Requirements**: High (recommend 8GB+ GPU memory)

In [None]:
# Import required libraries
import os
import sys
import time
import torch
import numpy as np
import pandas as pd
from datetime import datetime

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath('.')))

from models.TimesNet import Model as TimesNet
from utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metric
from utils.logger import logger
from data_provider.data_loader import Dataset_Custom
from torch.utils.data import DataLoader

# Report the runtime environment once the imports above have succeeded.
cuda_ready = torch.cuda.is_available()
device_label = 'CUDA' if cuda_ready else 'CPU'

print("✅ All imports successful")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"💻 Device: {device_label}")
if cuda_ready:
    # Total memory of GPU 0, converted from bytes to (decimal) gigabytes.
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"🚀 GPU Memory: {gpu_total_bytes / 1e9:.1f}GB")

## 🔧 Mid-Heavy Configuration Parameters

**Purpose**: High-capacity training for complex pattern recognition and production deployment

In [None]:
# ================================
# MID-HEAVY CONFIGURATION - TIMESNET
# ================================

class MidHeavyConfig:
    """Hyper-parameter bundle for the mid-heavy TimesNet experiment.

    The defaults are declared as class attributes. ``__init__`` copies them
    onto each instance so that ``vars(config)`` returns the *complete*
    configuration (it is saved alongside the model weights), not only the
    values that were overridden after construction.
    """

    # === DATA CONFIGURATION ===
    data = 'custom'                    # Dataset type (custom for prepared financial data)
    root_path = './data/'              # Root directory for data files
    data_path = 'prepared_financial_data.csv'  # Main data file
    features = 'M'                     # Forecasting mode: 'M'=Multivariate, 'S'=Univariate, 'MS'=Multivariate-to-Univariate
    target = 'log_Close'               # Primary target column (for 'S' mode)
    freq = 'b'                         # Time frequency: 'b'=business day, 'h'=hourly, 'd'=daily

    # === SEQUENCE PARAMETERS ===
    seq_len = 200                      # Input sequence length (lookback window) - MID-HEAVY: longer context
    label_len = 20                     # Start token length for decoder input (overlap with seq_len)
    pred_len = 20                      # Prediction horizon (how many steps to forecast) - MID-HEAVY: longer predictions

    # === TRAIN/VAL/TEST SPLITS ===
    val_len = 20                       # Validation set length in time steps
    test_len = 20                      # Test set length in time steps
    prod_len = 20                      # Production forecast length (future predictions beyond data)

    # === TIMESNET MODEL ARCHITECTURE ===
    # Core dimensions
    enc_in = 118                       # Encoder input size (total features: 4 targets + 114 covariates)
    dec_in = 118                       # Decoder input size (usually same as enc_in)
    c_out = 118                        # Output size (must match enc_in to avoid dimension mismatch)
    d_model = 128                      # Model dimension (embedding size) - MID-HEAVY: large capacity
    d_ff = 256                         # Feed-forward network dimension - MID-HEAVY: large FFN

    # Attention mechanism
    # NOTE(review): these look like generic transformer-style settings; confirm
    # which of them models/TimesNet.py actually reads.
    n_heads = 8                        # Number of attention heads - MID-HEAVY: more heads for complex patterns
    e_layers = 4                       # Number of encoder layers - MID-HEAVY: deeper network
    d_layers = 2                       # Number of decoder layers - MID-HEAVY: deeper decoder

    # TimesNet specific parameters
    top_k = 8                          # Top-k frequencies for TimesNet decomposition - MID-HEAVY: more frequencies
    num_kernels = 8                    # Number of convolution kernels in Inception blocks - MID-HEAVY: more kernels

    # Regularization
    dropout = 0.15                     # Dropout rate for regularization - MID-HEAVY: higher dropout for large model

    # Additional model settings
    # 'timeF' selects time-feature encoding (the data loaders pass timeenc=1 for
    # it and timeenc=0 otherwise). 'fixed' and 'learned' are the alternative
    # temporal-embedding modes - presumably non-trainable vs. trainable; confirm
    # against the embedding module.
    embed = 'timeF'
    activation = 'gelu'                # Activation function: 'gelu', 'relu', 'swish'
    factor = 1                         # Attention factor (usually 1)
    # Presumably the encoder "distilling" switch inherited from Informer-style
    # configs, not knowledge distillation - TODO confirm it is used by TimesNet.
    distil = True
    moving_avg = 50                    # Moving average window for trend decomposition - MID-HEAVY: longer window
    output_attention = False           # Whether to output attention weights

    # === TRAINING CONFIGURATION ===
    train_epochs = 50                  # Number of training epochs - MID-HEAVY: more epochs for convergence
    batch_size = 16                    # Batch size - MID-HEAVY: smaller batch due to larger model
    learning_rate = 0.0005             # Learning rate - MID-HEAVY: smaller for stable training
    patience = 15                      # Early stopping patience - MID-HEAVY: more patience for complex model
    lradj = 'type1'                    # Learning rate adjustment strategy

    # Loss and optimization
    loss = 'MSE'                       # Loss function: 'MSE', 'MAE', 'Huber'
    use_amp = True                     # Automatic mixed precision (recommended for large models)

    # System settings
    num_workers = 8                    # DataLoader workers - MID-HEAVY: more workers for data loading
    seed = 2024                        # Random seed for reproducibility

    # Task specific
    task_name = 'short_term_forecast'  # Task type: 'short_term_forecast' for financial prediction

    # Experiment tracking
    des = 'mid_heavy_config'           # Experiment description
    # Timestamp is evaluated once, when the class body runs.
    checkpoints = f'./checkpoints/TimesNet_mid_heavy_{datetime.now().strftime("%Y%m%d_%H%M")}'

    def __init__(self):
        """Materialize every class-level default as an instance attribute."""
        # Walk the MRO base-first so subclasses can override parent defaults.
        for klass in reversed(type(self).__mro__):
            for name, value in vars(klass).items():
                if name.startswith('_') or callable(value):
                    continue
                setattr(self, name, value)
    
# Instantiate the configuration that every later cell reads as `args`
args = MidHeavyConfig()

# Echo the headline settings so the run log shows what was trained
print("🔧 Mid-Heavy Configuration Loaded:")
for summary_label, summary_value in (
    ("📏 Sequence Length", args.seq_len),
    ("🎯 Prediction Length", args.pred_len),
    ("🧠 Model Dimension", args.d_model),
    ("⚡ Epochs", args.train_epochs),
    ("📊 Batch Size", args.batch_size),
    ("🔄 Attention Heads", args.n_heads),
    ("🏗️ Encoder Layers", args.e_layers),
):
    print(f"   {summary_label}: {summary_value}")

## 🎛️ Tweakable Parameters

Modify these parameters to experiment with different mid-heavy configurations:

In [None]:
# ================================
# TWEAKABLE PARAMETERS - EXPERIMENT
# ================================

# Modify these for mid-heavy experiments.
# Each assignment below overrides the corresponding default on the `args`
# instance; the values shown are the same as the class defaults, so running
# this cell unchanged leaves the configuration as it was.

# --- Sequence parameters (affect context and prediction complexity) ---
args.seq_len = 200         # Try: 150, 200, 300 (longer = more historical context)
args.pred_len = 20         # Try: 15, 20, 30 (longer = more challenging forecasting)
args.label_len = 20        # Try: 15, 20, 30 (usually 10-20% of seq_len)

# --- Model architecture (affect capacity and computational cost) ---
args.d_model = 128         # Try: 96, 128, 192 (larger = more representation power)
args.d_ff = 256            # Try: 192, 256, 384 (usually 1.5-2x d_model)
args.n_heads = 8           # Try: 4, 8, 16 (must divide d_model evenly; 6 or 12 do not divide 128)
args.e_layers = 4          # Try: 3, 4, 6 (more layers = deeper feature extraction)
args.d_layers = 2          # Try: 1, 2, 3 (decoder depth)

# --- TimesNet specific (affect frequency decomposition) ---
args.top_k = 8             # Try: 6, 8, 10 (more frequencies = richer patterns)
args.num_kernels = 8       # Try: 6, 8, 12 (more kernels = more feature maps)
args.moving_avg = 50       # Try: 30, 50, 75 (trend decomposition window)

# --- Training parameters (affect learning and convergence) ---
args.train_epochs = 50     # Try: 30, 50, 80
args.batch_size = 16       # Try: 8, 16, 24 (smaller for large models)
args.learning_rate = 0.0005  # Try: 0.0003, 0.0005, 0.001
args.dropout = 0.15        # Try: 0.1, 0.15, 0.2 (higher for larger models)
args.patience = 15         # Try: 10, 15, 20

# --- Loss functions to experiment with ---
# NOTE(review): before enabling one of these, confirm the criterion created in
# the model-setup cell actually supports that loss name.
# args.loss = 'MSE'        # Standard mean squared error
# args.loss = 'MAE'        # Mean absolute error (robust to outliers)
# args.loss = 'Huber'      # Combination of MSE and MAE
# args.loss = 'MAPE'       # Mean absolute percentage error
# args.loss = 'SMAPE'      # Symmetric mean absolute percentage error

print("🎛️ Parameters ready for tweaking")
print("💡 Tip: Start with default values, then adjust one parameter at a time")
print("⚠️  Note: Larger models require more GPU memory and training time")

## 📊 Data Loading and Preparation

Load the prepared financial dataset with all targets and covariates:

In [None]:
# Pick the compute device and seed every RNG the run relies on
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

seed_value = args.seed
torch.manual_seed(seed_value)
np.random.seed(seed_value)
if cuda_available:
    # Seed the current CUDA device and then all visible CUDA devices.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

print(f"🎲 Random seed set to: {seed_value}")
print(f"💻 Using device: {device}")

# Make sure the checkpoint directory exists before training writes into it
checkpoint_dir = args.checkpoints
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"📁 Checkpoints will be saved to: {checkpoint_dir}")

In [None]:
# Resolve the CSV location and preview its contents when the file is present
data_path = os.path.join(args.root_path, args.data_path)
print(f"📂 Loading data from: {data_path}")

if os.path.exists(data_path):
    df = pd.read_csv(data_path)

    date_column = df['date']
    first_date, last_date = date_column.min(), date_column.max()
    # Everything except the date column and the four target columns.
    covariate_count = df.shape[1] - 5

    print(f"✅ Data loaded successfully")
    print(f"   📊 Shape: {df.shape}")
    print(f"   📅 Date range: {first_date} to {last_date}")
    print(f"   🎯 Target columns: log_Open, log_High, log_Low, log_Close")
    print(f"   🔧 Feature columns: {covariate_count} (excluding date and targets)")

    # Show first few rows
    print("\n📋 First 3 rows:")
    display(df.head(3))
else:
    # Only report the problem here; `df` stays undefined in this case.
    print(f"❌ Data file not found: {data_path}")
    print("Please run the data preparation script first:")
    print("python example_data_preparation.py")

## 🏗️ Model Setup and Data Loaders

Create TimesNet model and prepare data loaders for training:

In [None]:
# Build the dataset / DataLoader pair for each of the three splits
print("🔄 Creating data loaders...")


def _make_split(flag, shuffle, drop_last):
    """Return ``(dataset, loader)`` for the split named by ``flag``.

    All splits share the window sizes, feature mode, frequency and split
    lengths taken from ``args``; only the split flag and the loader's
    shuffle / drop_last behaviour differ.
    """
    dataset = Dataset_Custom(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        # timeenc is 1 only for the 'timeF' embedding, otherwise 0
        timeenc=0 if args.embed != 'timeF' else 1,
        freq=args.freq,
        val_len=args.val_len,
        test_len=args.test_len
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        drop_last=drop_last
    )
    return dataset, loader


# Training batches are shuffled and the incomplete last batch is dropped;
# validation and test keep their order and every sample.
train_data, train_loader = _make_split('train', shuffle=True, drop_last=True)
val_data, val_loader = _make_split('val', shuffle=False, drop_last=False)
test_data, test_loader = _make_split('test', shuffle=False, drop_last=False)

print(f"✅ Data loaders created:")
for split_label, split_loader in (
    ("🎓 Training", train_loader),
    ("🔍 Validation", val_loader),
    ("🧪 Test", test_loader),
):
    print(f"   {split_label} batches: {len(split_loader)}")

In [None]:
# Initialize TimesNet model
print("🧠 Initializing TimesNet model...")

model = TimesNet(args).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ TimesNet model initialized")
print(f"   🔢 Total parameters: {total_params:,}")
print(f"   🎯 Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes (float32) per parameter.
print(f"   💾 Model size: ~{total_params * 4 / 1e6:.1f} MB")

# Setup optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# Build the criterion from args.loss so the loss that is reported (and saved
# with the config) is the loss that is actually optimized. Previously the
# criterion was always MSE regardless of args.loss.
_LOSS_FACTORIES = {
    'MSE': torch.nn.MSELoss,
    'MAE': torch.nn.L1Loss,
    # HuberLoss is missing from older PyTorch releases; SmoothL1Loss with its
    # default beta is the equivalent fallback there.
    'Huber': getattr(torch.nn, 'HuberLoss', torch.nn.SmoothL1Loss),
}
if args.loss not in _LOSS_FACTORIES:
    # Keep the run going, but make the substitution visible and record it.
    print(f"⚠️  Unsupported loss '{args.loss}' - falling back to MSE")
    args.loss = 'MSE'
criterion = _LOSS_FACTORIES[args.loss]()

# Early stopping
early_stopping = EarlyStopping(patience=args.patience, verbose=True)

print(f"⚙️ Optimizer: Adam (lr={args.learning_rate})")
print(f"📉 Loss function: {args.loss}")
print(f"⏰ Early stopping patience: {args.patience}")

## 🚀 Training Loop

Train the TimesNet model with progress tracking:

In [None]:
# Training function
def train_epoch(model, train_loader, optimizer, criterion, device, use_amp=False,
                scaler=None, pred_len=None, label_len=None, n_targets=4):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module called as ``model(batch_x, batch_x_mark, dec_inp, batch_y_mark)``.
        train_loader: Sized iterable of ``(batch_x, batch_y, batch_x_mark,
            batch_y_mark)`` tuples. Tensors are sliced on three axes -
            presumably (batch, time, feature).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(prediction, target)``.
        device: Device the batches are moved to.
        use_amp: Request automatic mixed precision. Only honoured on CUDA
            devices; on any other device the full-precision path is used.
        scaler: Optional ``GradScaler`` to reuse. Pass the same instance every
            epoch to keep the loss scale across epochs; when omitted, a fresh
            scaler is created for this call.
        pred_len: Forecast horizon; defaults to the notebook-level ``args.pred_len``.
        label_len: Decoder start-token length; defaults to ``args.label_len``.
        n_targets: Number of leading feature columns the loss is computed on.

    Returns:
        Mean of the per-batch loss values as a float.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    # Fall back to the global notebook config only when no explicit value is given.
    pred_len = args.pred_len if pred_len is None else pred_len
    label_len = args.label_len if label_len is None else label_len

    num_batches = len(train_loader)
    if num_batches == 0:
        # With drop_last=True a dataset smaller than one batch yields nothing;
        # fail with a clear message instead of a ZeroDivisionError at the end.
        raise ValueError(
            "train_loader yielded no batches; reduce batch_size or provide more data"
        )

    # CUDA autocast / GradScaler do nothing useful on CPU, so skip AMP there.
    use_amp = bool(use_amp) and torch.device(device).type == 'cuda'
    if use_amp and scaler is None:
        scaler = torch.cuda.amp.GradScaler()

    model.train()
    total_loss = 0.0
    report_every = max(1, num_batches // 5)

    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
        optimizer.zero_grad()

        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        batch_x_mark = batch_x_mark.float().to(device)
        batch_y_mark = batch_y_mark.float().to(device)

        # Decoder input: the first label_len steps of batch_y followed by
        # zeros in place of the pred_len steps to be forecast.
        horizon_zeros = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
        dec_inp = torch.cat([batch_y[:, :label_len, :], horizon_zeros], dim=1).float().to(device)

        target = batch_y[:, -pred_len:, :n_targets]

        if use_amp:
            with torch.cuda.amp.autocast():
                outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                # Focus loss on the target columns (first n_targets columns)
                loss = criterion(outputs[:, -pred_len:, :n_targets], target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            # Focus loss on the target columns (first n_targets columns)
            loss = criterion(outputs[:, -pred_len:, :n_targets], target)

            loss.backward()
            optimizer.step()

        total_loss += loss.item()

        # Progress update every 20% of batches
        if (i + 1) % report_every == 0:
            print(f"    Batch {i+1}/{num_batches} - Loss: {loss.item():.6f}")

    return total_loss / num_batches

def validate_epoch(model, val_loader, criterion, device,
                   pred_len=None, label_len=None, n_targets=4):
    """Compute the mean loss of ``model`` over ``val_loader`` without gradients.

    Args:
        model: Module called as ``model(batch_x, batch_x_mark, dec_inp, batch_y_mark)``.
        val_loader: Sized iterable of ``(batch_x, batch_y, batch_x_mark,
            batch_y_mark)`` tuples. Tensors are sliced on three axes -
            presumably (batch, time, feature).
        criterion: Loss callable taking ``(prediction, target)``.
        device: Device the batches are moved to.
        pred_len: Forecast horizon; defaults to the notebook-level ``args.pred_len``.
        label_len: Decoder start-token length; defaults to ``args.label_len``.
        n_targets: Number of leading feature columns the loss is computed on.

    Returns:
        Mean of the per-batch loss values as a float. The model is left in
        eval mode.

    Raises:
        ValueError: If ``val_loader`` contains no batches.
    """
    # Fall back to the global notebook config only when no explicit value is given.
    pred_len = args.pred_len if pred_len is None else pred_len
    label_len = args.label_len if label_len is None else label_len

    num_batches = len(val_loader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError at the end.
        raise ValueError("val_loader yielded no batches; check the validation split size")

    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch_x, batch_y, batch_x_mark, batch_y_mark in val_loader:
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            batch_x_mark = batch_x_mark.float().to(device)
            batch_y_mark = batch_y_mark.float().to(device)

            # Decoder input: the first label_len steps of batch_y followed by
            # zeros in place of the pred_len steps to be forecast.
            horizon_zeros = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
            dec_inp = torch.cat([batch_y[:, :label_len, :], horizon_zeros], dim=1).float().to(device)

            outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            # Focus loss on the target columns (first n_targets columns)
            loss = criterion(
                outputs[:, -pred_len:, :n_targets],
                batch_y[:, -pred_len:, :n_targets],
            )

            total_loss += loss.item()

    return total_loss / num_batches

print("🎯 Training functions defined")
print("📝 Ready to start training...")

In [None]:
# Main training loop: train, validate, adjust the LR, then checkpoint
banner = "=" * 60
print(f"🚀 Starting training for {args.train_epochs} epochs...")
print(f"⚡ Using AMP: {args.use_amp}")
print(banner)

train_losses, val_losses = [], []
start_time = time.time()

for epoch in range(args.train_epochs):
    epoch_num = epoch + 1  # 1-based number used for display and file names
    epoch_start = time.time()

    print(f"\n📊 Epoch {epoch_num}/{args.train_epochs}")

    # One optimisation pass followed by one evaluation pass
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, args.use_amp)
    val_loss = validate_epoch(model, val_loader, criterion, device)

    # Learning rate adjustment for the epoch that just finished
    adjust_learning_rate(optimizer, epoch_num, args)

    # Keep the loss history for the plots and the summary
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    epoch_time = time.time() - epoch_start

    print(f"    📈 Train Loss: {train_loss:.6f}")
    print(f"    📉 Val Loss: {val_loss:.6f}")
    print(f"    ⏱️  Epoch Time: {epoch_time:.1f}s")

    # Let the early-stopping helper see this epoch's validation loss
    early_stopping(val_loss, model, args.checkpoints)
    if early_stopping.early_stop:
        print("\n⏹️  Early stopping triggered")
        break

    # Periodic snapshot: skipped unless the epoch number is a multiple of 10
    if epoch_num % 10 != 0:
        continue
    checkpoint_path = os.path.join(args.checkpoints, f'checkpoint_epoch_{epoch_num}.pth')
    snapshot = {
        'epoch': epoch_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(snapshot, checkpoint_path)
    print(f"    💾 Checkpoint saved: {checkpoint_path}")

total_time = time.time() - start_time

best_val_loss = min(val_losses)
best_val_epoch = val_losses.index(best_val_loss) + 1

print("\n" + banner)
print("🎉 Training completed!")
print(f"⏰ Total training time: {total_time:.1f}s ({total_time/60:.1f} minutes)")
print(f"📊 Final train loss: {train_losses[-1]:.6f}")
print(f"📉 Final val loss: {val_losses[-1]:.6f}")
print(f"🏆 Best val loss: {best_val_loss:.6f} (epoch {best_val_epoch})")

## 📈 Training Visualization

Plot training and validation loss curves:

In [None]:
import matplotlib.pyplot as plt

# Draw the loss history twice, side by side: linear axis, then log axis
plt.figure(figsize=(12, 5))

panel_specs = (
    ('Loss', 'Training and Validation Loss', False),
    ('Loss (log scale)', 'Training and Validation Loss (Log Scale)', True),
)
for panel_index, (y_label, panel_title, use_log_axis) in enumerate(panel_specs, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(train_losses, label='Training Loss', color='blue', alpha=0.7)
    plt.plot(val_losses, label='Validation Loss', color='red', alpha=0.7)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(panel_title)
    if use_log_axis:
        plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Text summary of the same history
epochs_completed = len(train_losses)
first_loss, last_loss = train_losses[0], train_losses[-1]
improvement_pct = (first_loss - last_loss) / first_loss * 100
# Flag overfitting when the final validation loss is more than 10% above the best one
overfit_label = 'Yes' if val_losses[-1] > min(val_losses) * 1.1 else 'No'

print(f"📊 Training summary:")
print(f"   🎯 Epochs completed: {epochs_completed}")
print(f"   📈 Loss improvement: {improvement_pct:.1f}%")
print(f"   ⚠️  Overfitting check: {overfit_label}")

## 🧪 Model Testing

Evaluate the trained model on the test set:

In [None]:
# Load best model
best_model_path = os.path.join(args.checkpoints, 'checkpoint.pth')
if os.path.exists(best_model_path):
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the current one (e.g. saved on GPU, loaded on CPU).
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    print(f"✅ Loaded best model from: {best_model_path}")
else:
    print("⚠️  Using current model weights (best model checkpoint not found)")

# Test the model
print("\n🧪 Testing model...")

model.eval()
test_loss = 0
all_predictions = []
all_targets = []

with torch.no_grad():
    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        batch_x_mark = batch_x_mark.float().to(device)
        batch_y_mark = batch_y_mark.float().to(device)

        # Decoder input: the first label_len steps of batch_y followed by
        # zeros in place of the pred_len steps to be forecast.
        dec_inp = torch.zeros_like(batch_y[:, -args.pred_len:, :]).float()
        dec_inp = torch.cat([batch_y[:, :args.label_len, :], dec_inp], dim=1).float().to(device)

        outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

        # Focus on target columns (first 4)
        pred = outputs[:, -args.pred_len:, :4]
        true = batch_y[:, -args.pred_len:, :4]

        loss = criterion(pred, true)
        test_loss += loss.item()

        # Store predictions and targets on the CPU for the NumPy metrics below
        all_predictions.append(pred.cpu().numpy())
        all_targets.append(true.cpu().numpy())

        if i == 0:  # Log the loss of the first batch only
            print(f"    Test batch {i+1}/{len(test_loader)} - Loss: {loss.item():.6f}")

# Mean of the per-batch losses
test_loss /= len(test_loader)

# Concatenate all predictions and targets along the first (sample) axis
all_predictions = np.concatenate(all_predictions, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

print(f"\n📊 Test Results:")
print(f"   📉 Test Loss: {test_loss:.6f}")
print(f"   📐 Predictions shape: {all_predictions.shape}")
print(f"   🎯 Targets shape: {all_targets.shape}")

# Calculate additional metrics over all samples, steps and target columns
mae = np.mean(np.abs(all_predictions - all_targets))
mse = np.mean((all_predictions - all_targets) ** 2)
rmse = np.sqrt(mse)

print(f"\n📈 Additional Metrics:")
print(f"   MAE: {mae:.6f}")
print(f"   MSE: {mse:.6f}")
print(f"   RMSE: {rmse:.6f}")

## 🔮 Model Analysis

Analyze model performance and visualize predictions:

In [None]:
# Plot true vs. predicted horizons for the first few test samples
target_names = ['log_Open', 'log_High', 'log_Low', 'log_Close']
n_samples = min(3, all_predictions.shape[0])

fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4*n_samples))
if n_samples == 1:
    # A single row comes back 1-D; restore the (row, column) layout
    axes = axes.reshape(1, -1)

time_steps = range(args.pred_len)
for row in range(n_samples):
    for col, column_name in enumerate(target_names):
        panel = axes[row, col]

        # Solid blue = ground truth, dashed red = model forecast
        panel.plot(time_steps, all_targets[row, :, col], 'b-', label='True', alpha=0.7, linewidth=2)
        panel.plot(time_steps, all_predictions[row, :, col], 'r--', label='Predicted', alpha=0.7, linewidth=2)

        panel.set_title(f'Sample {row+1}: {column_name}')
        panel.set_xlabel('Time Step')
        panel.set_ylabel('Value')
        panel.legend()
        panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"📊 Visualized {n_samples} prediction samples for all targets")

In [None]:
# Performance analysis by target
print("\n🎯 Performance by Target:")
print("=" * 50)

for i, target_name in enumerate(target_names):
    target_pred = all_predictions[:, :, i]
    target_true = all_targets[:, :, i]

    # Use target_-prefixed names so this loop does not overwrite the overall
    # mae / mse / rmse computed in the test cell, which the save cell records
    # as the test metrics.
    target_mae = np.mean(np.abs(target_pred - target_true))
    target_mse = np.mean((target_pred - target_true) ** 2)
    target_rmse = np.sqrt(target_mse)

    # Pearson correlation between the flattened predictions and targets
    corr = np.corrcoef(target_pred.flatten(), target_true.flatten())[0, 1]

    print(f"{target_name:12} | MAE: {target_mae:.6f} | RMSE: {target_rmse:.6f} | Corr: {corr:.4f}")

print("\n💡 Tips for improvement:")
print("   - If correlation is low: try longer seq_len or more layers")
print("   - If MAE is high: try different loss functions (MAE, Huber)")
print("   - If overfitting: increase dropout or reduce model size")
print("   - If underfitting: increase model capacity or training epochs")

## 💾 Save Model and Results

Save the trained model and experiment results:

In [None]:
# Save final model

# Snapshot the complete configuration. vars(args) alone can miss settings that
# exist only as class-level defaults, so merge the class attributes with the
# instance attributes (instance values win).
config_snapshot = {
    name: value
    for source in (vars(type(args)), vars(args))
    for name, value in source.items()
    if not name.startswith('_') and not callable(value)
}

# Recompute the overall test metrics here instead of relying on the notebook
# globals mae / mse / rmse, which another cell may have reassigned.
overall_error = all_predictions - all_targets
overall_mae = float(np.mean(np.abs(overall_error)))
overall_mse = float(np.mean(overall_error ** 2))
overall_rmse = float(np.sqrt(overall_mse))

final_model_path = os.path.join(args.checkpoints, 'final_model.pth')
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config_snapshot,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'test_loss': test_loss,
    'test_metrics': {
        'mae': overall_mae,
        'mse': overall_mse,
        'rmse': overall_rmse
    }
}, final_model_path)

print(f"💾 Final model saved to: {final_model_path}")

# Save experiment summary
summary_path = os.path.join(args.checkpoints, 'experiment_summary.txt')
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write("TimesNet Mid-Heavy Configuration - Experiment Summary\n")
    f.write("=" * 60 + "\n\n")
    f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Configuration: {args.des}\n\n")

    f.write("Model Architecture:\n")
    f.write(f"  - Sequence Length: {args.seq_len}\n")
    f.write(f"  - Prediction Length: {args.pred_len}\n")
    f.write(f"  - Model Dimension: {args.d_model}\n")
    f.write(f"  - Feed-forward Dim: {args.d_ff}\n")
    f.write(f"  - Attention Heads: {args.n_heads}\n")
    f.write(f"  - Encoder Layers: {args.e_layers}\n")
    f.write(f"  - Decoder Layers: {args.d_layers}\n")
    f.write(f"  - Top-k Frequencies: {args.top_k}\n")
    f.write(f"  - Kernels: {args.num_kernels}\n")
    f.write(f"  - Total Parameters: {total_params:,}\n\n")

    f.write("Training Configuration:\n")
    # Epochs actually run (may be fewer than args.train_epochs after early stopping)
    f.write(f"  - Epochs: {len(train_losses)}\n")
    f.write(f"  - Batch Size: {args.batch_size}\n")
    f.write(f"  - Learning Rate: {args.learning_rate}\n")
    f.write(f"  - Dropout: {args.dropout}\n")
    f.write(f"  - Use AMP: {args.use_amp}\n")
    f.write(f"  - Training Time: {total_time:.1f}s\n\n")

    f.write("Results:\n")
    f.write(f"  - Final Train Loss: {train_losses[-1]:.6f}\n")
    f.write(f"  - Final Val Loss: {val_losses[-1]:.6f}\n")
    f.write(f"  - Best Val Loss: {min(val_losses):.6f}\n")
    f.write(f"  - Test Loss: {test_loss:.6f}\n")
    f.write(f"  - Test MAE: {overall_mae:.6f}\n")
    f.write(f"  - Test RMSE: {overall_rmse:.6f}\n")

print(f"📝 Experiment summary saved to: {summary_path}")

print("\n🎉 Mid-Heavy Configuration Training Complete!")
print(f"📁 All results saved in: {args.checkpoints}")
print("\n💡 Next steps:")
print("   - Compare with light and medium configurations")
print("   - Try different hyperparameters")
print("   - Experiment with different loss functions")
print("   - Implement production forecasting")