# 配置驱动的HAR训练流程 - 任务2.3完整实现

这个notebook实现了完全由配置文件驱动的训练执行引擎。
通过修改配置文件即可启动不同的实验，无需修改代码。

## 步骤 1: 导入所有必需的库
首先，我们导入所有需要的标准库、第三方库和项目内部模块。

In [None]:
# 标准库导入
import os
import sys
import time
import yaml
import random
import logging
import pickle
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple

# 第三方库导入
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, confusion_matrix, classification_report
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

# --- 项目内部导入 (占位符) ---
# 在这个独立的notebook中，我们将必要的辅助函数和模型直接包含进来
#真实的模块导入
# Optional project modules. Both groups are guarded because this notebook is
# meant to run standalone: when the real packages are missing, the names are
# set to None here and rebound by the in-notebook definitions in step 2.
try:
    from config.config_loader import ConfigLoader
    from config.config_bridge import ConfigBridge
except ImportError:
    print("警告: 配置模块未找到，将只使用传统模式")
    ConfigLoader = None
    ConfigBridge = None

try:
    import utils_torch as utils
    import model_cbranchformer as model
except ImportError:
    # Previously these imports were unguarded and aborted the whole notebook
    # before the simplified fallbacks could be defined.
    print("警告: utils_torch / model_cbranchformer 未找到，将使用本Notebook内置的简化实现")
    utils = None
    model = None

print("所有模块导入完成")

## 步骤 2: 定义辅助工具、模型和配置桥接器
为了使此Notebook可以独立运行，我们将原本在 `utils_torch.py`、`model_cbranchformer.py` 和 `config_bridge.py` 中的关键代码直接定义在这里。注意：下面定义的 `utils`、`model` 和 `ConfigBridge` 会覆盖步骤 1 中导入的同名对象，因此后续单元格使用的始终是这里的简化实现。

In [None]:
# === 从 utils_torch.py 移入的关键代码 ===
class HARDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each sample with its label.

    Args:
        data: Sized, indexable collection of samples.
        labels: Sized, indexable collection of labels, one per sample.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data, labels):
        # Fail fast: a length mismatch would otherwise surface as an
        # IndexError in the middle of an epoch, or silently drop labels.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

def load_dataset_pytorch(dataset_name, client_count, data_config, random_seed, base_path):
    """Generate a mock HAR dataset that mimics the real loader's return object.

    This is a demo stand-in: no files are read, and ``client_count``,
    ``data_config`` and ``base_path`` are accepted only for signature
    compatibility.

    Args:
        dataset_name: Dataset identifier; ``'UCI'`` (case-insensitive) selects
            7352/2947 samples of shape (128, 6), anything else selects
            10000/2000 samples of shape (100, 9).
        client_count: Unused.
        data_config: Unused.
        random_seed: Seed for the mock data. When not ``None`` the data is
            drawn from a dedicated generator, so equal seeds give equal data
            and the global torch RNG is left untouched.
        base_path: Unused.

    Returns:
        An object exposing ``central_train_data``, ``central_train_label``,
        ``central_test_data``, ``central_test_label`` and the two
        ``central_dev_*`` attributes (both ``None``, so the caller splits its
        own validation set).
    """
    print(f"正在生成 {dataset_name} 的模拟数据...")

    # Sample counts and window shape per dataset name.
    if dataset_name.upper() == 'UCI':
        train_samples, test_samples = 7352, 2947
        seq_len, channels = 128, 6
    else:  # default / any other dataset
        train_samples, test_samples = 10000, 2000
        seq_len, channels = 100, 9
    num_classes = 6

    # Honor the seed argument (it used to be ignored); None keeps the old
    # behavior of drawing from the global RNG.
    generator = None
    if random_seed is not None:
        generator = torch.Generator().manual_seed(int(random_seed))

    train_data = torch.randn(train_samples, seq_len, channels, generator=generator)
    train_label = torch.randint(0, num_classes, (train_samples,), generator=generator)
    test_data = torch.randn(test_samples, seq_len, channels, generator=generator)
    test_label = torch.randint(0, num_classes, (test_samples,), generator=generator)

    class MockDataset:
        def __init__(self):
            self.central_train_data = train_data
            self.central_train_label = train_label
            self.central_test_data = test_data
            self.central_test_label = test_label
            # Left empty so the main program carves out its own dev split.
            self.central_dev_data = None
            self.central_dev_label = None

    return MockDataset()

def return_client_by_dataset(dataset_name):
    """Return the client count for ``dataset_name``.

    Mock implementation: the argument is ignored and the answer is always a
    single client, i.e. centralized training.
    """
    centralized_client_count = 1
    return centralized_client_count
# Namespace standing in for the `utils_torch` module: it exposes the
# in-notebook helpers under the attribute names the trainer looks up.
class utils:
    HARDataset = HARDataset
    load_dataset_pytorch = load_dataset_pytorch
    return_client_by_dataset = return_client_by_dataset

# === 从 model_cbranchformer.py 移入的关键代码 (简化版) ===
class cbranchformer_har_base(nn.Module):
    """Heavily simplified demo stand-in for the HART model.

    The network is flatten -> linear(128) -> ReLU -> dropout -> linear.

    Args:
        input_shape: Pair whose two entries are multiplied to obtain the
            flattened feature count.
        activity_count: Number of output classes.
        **kwargs: Only ``dropout_rate`` (default 0.5) is read; every other
            keyword is accepted and ignored.
    """

    def __init__(self, input_shape, activity_count, **kwargs):
        super().__init__()
        first_dim, second_dim = input_shape[0], input_shape[1]
        drop_prob = kwargs.get('dropout_rate', 0.5)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(first_dim * second_dim, 128)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)
        self.fc2 = nn.Linear(128, activity_count)
        print("创建了一个简化的 cbranchformer_har_base 模型")

    def forward(self, x):
        """Return class logits for a batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(self.dropout(hidden))

class MobileHART_XS(nn.Module):
    """Heavily simplified demo stand-in for MobileHART-XS.

    The network is flatten -> linear(64) -> ReLU -> linear.

    Args:
        input_shape: Pair whose two entries are multiplied to obtain the
            flattened feature count.
        activity_count: Number of output classes.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_shape, activity_count, **kwargs):
        super().__init__()
        flat_features = input_shape[0] * input_shape[1]

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_features, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, activity_count)
        print("创建了一个简化的 MobileHART_XS 模型")

    def forward(self, x):
        """Return class logits for a batch ``x``."""
        hidden = self.fc1(self.flatten(x))
        return self.fc2(self.relu(hidden))

# Namespace standing in for the `model_cbranchformer` module: it exposes the
# simplified in-notebook models under the names the trainer looks up.
class model:
    cbranchformer_har_base = cbranchformer_har_base
    MobileHART_XS = MobileHART_XS

# === 从 config_bridge.py 移入的关键代码 ===
class _DotDict(dict):
    """Dictionary whose keys can also be read as attributes.

    It stays a real ``dict``, so ``.values()``, ``.get()`` and
    ``json.dumps`` keep working. Keys that collide with dict method names
    (``items``, ``keys``, ``get`` ...) must be read with ``cfg['items']``.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        try:
            return self[key]
        except KeyError:
            # AttributeError keeps getattr(obj, name, default) / hasattr working.
            raise AttributeError(key) from None


class ConfigBridge:
    """Load a YAML experiment file and expose it as raw dicts and dot-notation.

    Attributes set by ``__init__``:
        raw_config: The parsed YAML mapping (``{}`` for an empty file).
        config: The same data with nested mappings readable as attributes.
        use_new_config: Always ``True`` for this in-notebook bridge.
    """

    def __init__(self, config_path: str):
        """Parse ``config_path`` as YAML.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        # Explicit UTF-8: the config carries non-ASCII text and the platform
        # default encoding is not UTF-8 everywhere.
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            self.raw_config = yaml.safe_load(f) or {}
        self.config = self._to_dot_notation(self.raw_config)
        self.use_new_config = True
        print("配置桥接器初始化完成")

    def _to_dot_notation(self, data):
        """Recursively wrap mappings in ``_DotDict``; other values pass through.

        The previous implementation returned bare ``type`` objects, which have
        no ``.values()`` / ``.get()`` and therefore broke the trainer's
        parameter extraction.
        """
        if isinstance(data, dict):
            return _DotDict({k: self._to_dot_notation(v) for k, v in data.items()})
        if isinstance(data, list):
            return [self._to_dot_notation(i) for i in data]
        return data

    def get_dataset_config(self) -> Dict[str, Any]:
        """Return the raw ``dataset`` section, or ``{}`` when absent."""
        return self.raw_config.get('dataset', {})

    def get_training_config(self) -> Dict[str, Any]:
        """Return the raw ``training`` section, or ``{}`` when absent."""
        return self.raw_config.get('training', {})

## 步骤 3: 定义核心训练器类
这里是包含完整 `ConfigurableTrainer` 类的代码。`validate_epoch` 方法已经被修复和补全。

In [None]:
class ConfigurableTrainer:
    """配置驱动的训练器类 - 完整实现任务2.3"""
    
    def __init__(self, config_path: Optional[str] = None):
        """
        初始化训练器
        
        Args:
            config_path: 配置文件路径，如果为None则使用传统硬编码方式
        """
        self.config_path = config_path
        self.use_config = config_path is not None and ConfigBridge is not None
        
        # 初始化配置桥接器
        if self.use_config:
            try:
                self.config_bridge = ConfigBridge(config_path)
                self.config = self.config_bridge.config
                self.use_config = self.config_bridge.use_new_config
            except Exception as e:
                print(f"配置文件加载失败: {e}")
                print("回退到传统硬编码模式")
                self.use_config = False
                self.config_bridge = None
                self.config = None
        else:
            self.config_bridge = None
            self.config = None
        
        # 初始化训练状态
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.device = None
        self.output_dir = None
        self.logger = None
        
        # 训练历史
        self.history = {
            'train_loss': [],
            'train_accuracy': [],
            'val_loss': [],
            'val_accuracy': []
        }
        
        print(f"训练器初始化完成，使用配置模式: {'新配置系统' if self.use_config else '传统硬编码'}")
    
    def setup_logging(self, output_dir: Path, verbose: bool = True):
        """设置日志系统"""
        log_level = logging.INFO if verbose else logging.WARNING
        log_dir = output_dir / 'logs'
        log_dir.mkdir(parents=True, exist_ok=True)
        
        # 清除已有的处理器
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_dir / 'training.log'),
                logging.StreamHandler(sys.stdout)
            ]
        )
        self.logger = logging.getLogger(__name__)
        self.logger.info("日志系统初始化完成")
    
    def get_parameters(self) -> Dict[str, Any]:
        """获取训练参数，支持配置文件和硬编码两种模式"""
        if self.use_config:
            return self._get_config_parameters()
        else:
            return self._get_hardcoded_parameters()
    
    def _get_config_parameters(self) -> Dict[str, Any]:
        """从配置文件获取参数"""
        dataset_config = self.config_bridge.get_dataset_config()
        training_config = self.config_bridge.get_training_config()
        
        # 计算总通道数
        total_channels = sum(mod['channels'] for mod in dataset_config['modalities'])
        
        # 获取主要专家模型类型
        main_expert = list(self.config.architecture.experts.values())[0]
        architecture_type = main_expert.type.upper()
        
        return {
            # 实验信息
            'experiment_name': self.config.name,
            'description': getattr(self.config, 'description', ''),
            
            # 数据集参数
            'dataset_name': dataset_config['name'],
            'data_config': 'BALANCED',  # 默认值，可配置
            'activity_labels': dataset_config['activity_labels'],
            'activity_count': len(dataset_config['activity_labels']),
            
            # 模型参数
            'architecture': architecture_type,
            'segment_size': dataset_config['modalities'][0]['sequence_length'],
            'num_input_channels': total_channels,
            'input_shape': (dataset_config['modalities'][0]['sequence_length'], total_channels),
            
            # 训练参数
            'batch_size': training_config['batch_size'],
            'learning_rate': training_config['learning_rate'],
            'local_epoch': training_config['epochs'],
            'dropout_rate': self.config.architecture.dropout_rate,
            
            # 高级参数
            'weight_decay': training_config.get('weight_decay', 1e-4),
            'gradient_clip_norm': training_config.get('gradient_clip_norm', 1.0),
            'label_smoothing': training_config.get('label_smoothing', 0.1),
            'optimizer_name': training_config.get('optimizer', 'adam'),
            'scheduler_name': training_config.get('scheduler', 'cosine'),
            
            # 模型特定参数
            'projection_dim': main_expert.params.get('projection_dim', 192),
            'frame_length': main_expert.params.get('frame_length', 16),
            'time_step': main_expert.params.get('time_step', 16),
            'filter_attention_head': main_expert.params.get('filter_attention_head', 4),
            'conv_kernels': main_expert.params.get('conv_kernels', [3, 7, 15, 31, 31, 31]),
            'token_based': main_expert.params.get('token_based', False),
            
            # 系统参数
            'random_seed': self.config.seed,
            'device': self.config.device,
            'output_dir': self.config.output_dir,
            'save_checkpoints': self.config.save_checkpoints,
            'verbose': self.config.verbose,
            
            # 可视化参数
            'show_train_verbose': 1 if self.config.verbose else 0,
            'plot_learning_curves': getattr(self.config, 'visualization', {}).get('plot_learning_curves', True),
            'plot_confusion_matrix': getattr(self.config, 'visualization', {}).get('plot_confusion_matrix', True)
        }
    
    def _get_hardcoded_parameters(self) -> Dict[str, Any]:
        """使用硬编码参数（向后兼容）"""
        return {
            # 实验信息
            'experiment_name': 'Traditional_Hardcoded_Experiment',
            'description': '使用传统硬编码参数的实验',
            
            # 硬编码默认参数（基于原main_torch.ipynb）
            'dataset_name': 'UCI',
            'data_config': 'BALANCED',
            'activity_labels': ['Walk', 'Upstair', 'Downstair', 'Sit', 'Stand', 'Lay'],
            'activity_count': 6,
            'architecture': 'HART',
            'segment_size': 128,
            'num_input_channels': 6,
            'input_shape': (128, 6),
            
            # 训练参数
            'batch_size': 256,
            'learning_rate': 5e-3,
            'local_epoch': 50,
            'dropout_rate': 0.3,
            'weight_decay': 1e-4,
            'gradient_clip_norm': 1.0,
            'label_smoothing': 0.1,
            'optimizer_name': 'adam',
            'scheduler_name': 'cosine',
            
            # 模型参数（基于原代码）
            'projection_dim': 192,
            'frame_length': 16,
            'time_step': 16,
            'filter_attention_head': 4,
            'conv_kernels': [3, 7, 15, 31, 31, 31],
            'token_based': False,
            
            # 系统参数
            'random_seed': 1,
            'device': 'auto',
            'output_dir': './results/traditional_experiment',
            'save_checkpoints': True,
            'verbose': True,
            'show_train_verbose': 1,
            'plot_learning_curves': True,
            'plot_confusion_matrix': True
        }
    
    def setup_environment(self, params: Dict[str, Any]) -> None:
        """设置训练环境"""
        # 设置随机种子
        random_seed = params['random_seed']
        os.environ['PYTHONHASHSEED'] = str(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        random.seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
        # 设置设备
        if params['device'] == 'auto':
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(params['device'])
        
        # 创建输出目录
        self.output_dir = Path(params['output_dir'])
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # 设置日志
        self.setup_logging(self.output_dir, params['verbose'])
        
        self.logger.info(f"环境设置完成: 设备={self.device}, 随机种子={random_seed}, 输出目录={self.output_dir}")
        self.logger.info(f"实验名称: {params['experiment_name']}")
        if params['description']:
            self.logger.info(f"实验描述: {params['description']}")
    
    def load_data(self, params: Dict[str, Any]) -> Tuple[torch.utils.data.DataLoader, ...]:
        """Load the dataset and wrap train/dev/test splits in data loaders.

        When the loaded dataset carries no dev split, 12.5% of the training
        data is held out with a stratified split.

        Args:
            params: Reads ``dataset_name``, ``data_config``, ``random_seed``
                and ``batch_size``.

        Returns:
            ``(train_loader, dev_loader, test_loader, train_labels)`` where
            ``train_labels`` is the training label array as NumPy.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        dataset_name = params['dataset_name']
        data_config = params['data_config']
        seed = params['random_seed']
        batch_size = params['batch_size']

        def make_loader(dataset, shuffle):
            # All three loaders share every setting except shuffling.
            return torch.utils.data.DataLoader(
                dataset, batch_size=batch_size, shuffle=shuffle,
                num_workers=2, pin_memory=True
            )

        self.logger.info(f"开始加载数据集: {dataset_name}")

        try:
            # Delegate to the existing loading helpers.
            client_count = utils.return_client_by_dataset(dataset_name)
            bundle = utils.load_dataset_pytorch(
                dataset_name, client_count, data_config, seed, './datasets/'
            )

            train_x, train_y = bundle.central_train_data, bundle.central_train_label
            test_x, test_y = bundle.central_test_data, bundle.central_test_label

            dev_x = getattr(bundle, 'central_dev_data', None)
            if dev_x is not None:
                dev_y = bundle.central_dev_label
            else:
                # No dev split provided: hold one out of the training set.
                labels_np = train_y.cpu().numpy()
                parts = train_test_split(
                    train_x.cpu().numpy(),
                    labels_np,
                    test_size=0.125,
                    random_state=seed,
                    stratify=labels_np
                )
                train_x, dev_x = torch.FloatTensor(parts[0]), torch.FloatTensor(parts[1])
                train_y, dev_y = torch.LongTensor(parts[2]), torch.LongTensor(parts[3])

            train_dataset = utils.HARDataset(train_x, train_y)
            dev_dataset = utils.HARDataset(dev_x, dev_y)
            test_dataset = utils.HARDataset(test_x, test_y)

            # Only the training loader shuffles.
            train_loader = make_loader(train_dataset, True)
            dev_loader = make_loader(dev_dataset, False)
            test_loader = make_loader(test_dataset, False)

            self.logger.info(
                f"数据加载完成: 训练集={len(train_dataset)}, "
                f"验证集={len(dev_dataset)}, 测试集={len(test_dataset)}"
            )

            return train_loader, dev_loader, test_loader, train_y.cpu().numpy()

        except Exception as e:
            self.logger.error(f"数据加载失败: {e}")
            raise
    
    def create_model(self, params: Dict[str, Any]) -> nn.Module:
        """Instantiate the network selected by ``params['architecture']``.

        ``"HART"`` builds ``cbranchformer_har_base``; every other value builds
        ``MobileHART_XS``. The model is moved to ``self.device`` and stored
        on ``self.model``.

        Returns:
            The created model.

        Raises:
            Exception: Any construction error is logged and re-raised unchanged.
        """
        architecture = params['architecture']
        input_shape = params['input_shape']
        activity_count = params['activity_count']

        self.logger.info(f"创建模型: {architecture}, 输入形状={input_shape}, 类别数={activity_count}")

        try:
            if architecture != "HART":
                # MobileHART or any other architecture name
                network = model.MobileHART_XS(
                    input_shape=input_shape,
                    activity_count=activity_count
                )
            else:
                hart_options = dict(
                    projection_dim=params['projection_dim'],
                    patch_size=params['frame_length'],
                    time_step=params['time_step'],
                    num_heads=3,
                    filter_attention_head=params['filter_attention_head'],
                    conv_kernels=params['conv_kernels'],
                    dropout_rate=params['dropout_rate'],
                    use_tokens=params['token_based'],
                )
                network = model.cbranchformer_har_base(
                    input_shape=input_shape,
                    activity_count=activity_count,
                    **hart_options
                )
            self.model = network.to(self.device)

            # Count trainable parameters only.
            total_params = 0
            for tensor in self.model.parameters():
                if tensor.requires_grad:
                    total_params += tensor.numel()
            self.logger.info(f"模型创建完成，总参数数量: {total_params:,}")

            return self.model

        except Exception as e:
            self.logger.error(f"模型创建失败: {e}")
            raise
    
    def setup_training(self, params: Dict[str, Any], train_labels: np.ndarray) -> None:
        """Create the loss function, optimizer and LR scheduler.

        Args:
            params: Reads ``activity_count``, ``label_smoothing``,
                ``optimizer_name``, ``learning_rate``, ``weight_decay``,
                ``scheduler_name`` and ``local_epoch``.
            train_labels: Training labels used for balanced class weights.

        Raises:
            Exception: Any setup error is logged and re-raised unchanged.
        """
        try:
            # Balanced class weights, keyed by the actual class id. Keying by
            # position (as before) misaligned the weights, or raised KeyError,
            # whenever a class was absent from the training labels.
            flat_labels = np.asarray(train_labels).ravel()
            present_classes = np.unique(flat_labels)
            balanced = class_weight.compute_class_weight(
                class_weight='balanced',
                classes=present_classes,
                y=flat_labels
            )
            class_weights = {
                int(cls): float(weight)
                for cls, weight in zip(present_classes, balanced)
            }
            # Classes with no training sample get the neutral weight 1.0.
            class_weights_tensor = torch.FloatTensor(
                [class_weights.get(i, 1.0) for i in range(params['activity_count'])]
            ).to(self.device)

            # Loss function
            self.criterion = nn.CrossEntropyLoss(
                weight=class_weights_tensor,
                label_smoothing=params['label_smoothing']
            )

            # Optimizer: 'adam', 'adamw', anything else falls back to SGD.
            optimizer_name = params['optimizer_name'].lower()
            if optimizer_name == 'adam':
                self.optimizer = optim.Adam(
                    self.model.parameters(),
                    lr=params['learning_rate'],
                    weight_decay=params['weight_decay']
                )
            elif optimizer_name == 'adamw':
                self.optimizer = optim.AdamW(
                    self.model.parameters(),
                    lr=params['learning_rate'],
                    weight_decay=params['weight_decay']
                )
            else:
                self.optimizer = optim.SGD(
                    self.model.parameters(),
                    lr=params['learning_rate'],
                    momentum=0.9,
                    weight_decay=params['weight_decay']
                )

            # Scheduler: a null/None name (allowed by the config) disables it
            # instead of crashing on None.lower().
            scheduler_name = str(params['scheduler_name'] or 'none').lower()
            if scheduler_name == 'cosine':
                self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    self.optimizer,
                    T_max=params['local_epoch'],
                    eta_min=1e-6
                )
            elif scheduler_name == 'step':
                self.scheduler = optim.lr_scheduler.StepLR(
                    self.optimizer,
                    # Fewer than 3 epochs would give step_size 0, which StepLR
                    # cannot handle.
                    step_size=max(1, params['local_epoch'] // 3),
                    gamma=0.1
                )
            else:
                self.scheduler = None

            self.logger.info(f"训练设置完成: 优化器={optimizer_name}, 调度器={scheduler_name}")
            self.logger.info(f"类权重: {[f'{i}:{w:.3f}' for i, w in class_weights.items()]}")

        except Exception as e:
            self.logger.error(f"训练设置失败: {e}")
            raise
    
    def train_epoch(self, train_loader: torch.utils.data.DataLoader, 
                   params: Dict[str, Any]) -> Tuple[float, float]:
        """训练一个epoch"""
        self.model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            
            # 前向传播
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            
            # 反向传播
            loss.backward()
            
            # 梯度裁剪
            if params['gradient_clip_norm'] > 0:
                nn.utils.clip_grad_norm_(
                    self.model.parameters(), 
                    max_norm=params['gradient_clip_norm']
                )
            
            self.optimizer.step()
            
            # 统计
            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
        
        avg_loss = train_loss / max(train_total, 1)
        avg_accuracy = train_correct / max(train_total, 1)
        return avg_loss, avg_accuracy
    
    def validate_epoch(self, val_loader: torch.utils.data.DataLoader) -> Tuple[float, float]:
        """验证一个epoch (已修复)"""
        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(val_loader):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                
                # 前向传播
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                
                # 统计
                val_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        
        # 计算平均损失和准确率
        avg_loss = val_loss / max(val_total, 1)
        avg_accuracy = val_correct / max(val_total, 1)
        
        return avg_loss, avg_accuracy

    def evaluate(self, test_loader: torch.utils.data.DataLoader, params: Dict[str, Any]) -> None:
        """Report loss, accuracy and weighted F1 on the test set.

        Loss and accuracy come from ``validate_epoch``; a second pass collects
        per-sample predictions for the F1 score and the optional confusion
        matrix.

        Args:
            test_loader: Iterable of ``(inputs, targets)`` batches.
            params: Reads ``plot_confusion_matrix`` and ``activity_labels``.
        """
        self.model.eval()
        test_loss, test_acc = self.validate_epoch(test_loader)  # reuse the validation logic

        all_preds, all_targets = [], []
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                batch_preds = self.model(inputs).max(1)[1]
                all_preds.extend(batch_preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        f1 = f1_score(all_targets, all_preds, average='weighted')
        self.logger.info(f"测试结果: Loss={test_loss:.4f}, Accuracy={test_acc:.4f}, F1-Score={f1:.4f}")

        if not params['plot_confusion_matrix']:
            return
        self.plot_confusion_matrix(all_targets, all_preds, params['activity_labels'])
    
    def plot_confusion_matrix(self, y_true, y_pred, labels):
        """Draw, save and show the confusion matrix.

        The image is written to ``<output_dir>/confusion_matrix.png``.

        Args:
            y_true: True class ids.
            y_pred: Predicted class ids.
            labels: Display names, ordered so that position ``i`` names
                class id ``i``.
        """
        # Pin the matrix to one row/column per known class. Without this a
        # class missing from both y_true and y_pred shrinks the matrix and
        # the tick labels no longer line up with it.
        class_ids = list(range(len(labels)))
        cm = confusion_matrix(y_true, y_pred, labels=class_ids)

        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        cm_path = self.output_dir / 'confusion_matrix.png'
        plt.savefig(cm_path)
        self.logger.info(f"混淆矩阵已保存至 {cm_path}")
        plt.show()
        # pyplot keeps figures alive until closed; release it so repeated
        # runs do not accumulate open figures.
        plt.close(fig)

    def plot_learning_curves(self):
        """Draw, save and show the loss and accuracy curves from ``self.history``.

        The image is written to ``<output_dir>/learning_curves.png``.
        """
        fig = plt.figure(figsize=(12, 5))

        # (subplot index, train key, train legend, val key, val legend, title, y label)
        panels = (
            (1, 'train_loss', 'Train Loss', 'val_loss', 'Validation Loss',
             'Loss Curves', 'Loss'),
            (2, 'train_accuracy', 'Train Accuracy', 'val_accuracy', 'Validation Accuracy',
             'Accuracy Curves', 'Accuracy'),
        )
        for index, train_key, train_label, val_key, val_label, title, y_label in panels:
            plt.subplot(1, 2, index)
            plt.plot(self.history[train_key], label=train_label)
            plt.plot(self.history[val_key], label=val_label)
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.legend()

        plt.tight_layout()
        curves_path = self.output_dir / 'learning_curves.png'
        plt.savefig(curves_path)
        self.logger.info(f"学习曲线已保存至 {curves_path}")
        plt.show()
        # pyplot keeps figures alive until closed; release it so repeated
        # runs do not accumulate open figures.
        plt.close(fig)

    def run(self):
        """Execute the full pipeline: setup, training, checkpointing, evaluation.

        The model with the best validation accuracy is checkpointed (when
        ``save_checkpoints`` is set) and restored before the test evaluation.
        """
        params = self.get_parameters()
        self.setup_environment(params)
        self.logger.info(f"--- 实验开始: {params['experiment_name']} ---")
        # default=str keeps this purely informational dump from aborting the
        # run if a config value is not JSON-serializable.
        self.logger.info(f"参数:\n{json.dumps(params, indent=2, ensure_ascii=False, default=str)}")

        train_loader, dev_loader, test_loader, train_labels = self.load_data(params)
        self.create_model(params)
        self.setup_training(params, train_labels)

        self.logger.info("--- 开始训练 ---")
        best_val_accuracy = 0.0
        best_epoch = -1
        checkpoint_path = self.output_dir / 'best_model.pth'
        # Tracks whether THIS run wrote a checkpoint, so evaluation never
        # loads a missing file or a stale one left by an earlier run.
        checkpoint_saved = False
        start_time = time.time()

        for epoch in range(params['local_epoch']):
            train_loss, train_acc = self.train_epoch(train_loader, params)
            val_loss, val_acc = self.validate_epoch(dev_loader)

            if self.scheduler:
                self.scheduler.step()

            self.history['train_loss'].append(train_loss)
            self.history['train_accuracy'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_accuracy'].append(val_acc)

            self.logger.info(f"Epoch [{epoch+1}/{params['local_epoch']}] | "
                            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            # The first epoch always counts as the best so far, even at 0.0
            # accuracy; afterwards only strict improvements do.
            if best_epoch < 0 or val_acc > best_val_accuracy:
                best_val_accuracy = val_acc
                best_epoch = epoch
                if params['save_checkpoints']:
                    torch.save(self.model.state_dict(), checkpoint_path)
                    checkpoint_saved = True
                    self.logger.info(f"在 Epoch {epoch+1} 保存了新的最佳模型")

        total_time = time.time() - start_time
        self.logger.info(f"--- 训练完成, 总耗时: {total_time:.2f} 秒 ---")
        self.logger.info(f"最佳验证准确率: {best_val_accuracy:.4f} (在 Epoch {best_epoch+1}) ")

        self.logger.info("--- 在测试集上进行最终评估 ---")
        if checkpoint_saved:
            self.model.load_state_dict(
                torch.load(checkpoint_path, map_location=self.device)
            )
            self.logger.info("已加载最佳模型进行评估")

        self.evaluate(test_loader, params)

        if params['plot_learning_curves']:
            self.plot_learning_curves()

        self.logger.info("--- 实验结束 ---")

## 步骤 4: 创建配置文件
这是本脚本的核心驱动力。下面的单元格将创建一个名为 `config.yaml` 的文件。你可以直接修改这个单元格中的内容来改变实验参数，而无需触碰上面的类定义代码。

In [None]:
%%writefile config.yaml

# 实验元数据
name: "UCI_HART_Baseline_Experiment"
description: "使用HART模型在UCI-HAR数据集上进行的基线实验"
seed: 42
device: "auto"  # 'auto', 'cpu', 'cuda'
output_dir: "./results/uci_hart_baseline"
save_checkpoints: true
verbose: true

# 数据集配置
dataset:
  name: "UCI"
  activity_labels: ['Walk', 'Upstair', 'Downstair', 'Sit', 'Stand', 'Lay']
  modalities:
    - name: "accel"
      channels: 3
      sequence_length: 128
    - name: "gyro"
      channels: 3
      sequence_length: 128

# 模型架构配置
architecture:
  dropout_rate: 0.3
  # 定义专家模型，可以有多个，但目前只使用第一个
  experts:
    expert1:
      type: "HART"
      params:
        projection_dim: 192
        frame_length: 16
        time_step: 16
        filter_attention_head: 4
        conv_kernels: [3, 7, 15, 31, 31, 31]
        token_based: false

# 训练过程配置
training:
  epochs: 20  # 为了快速演示，减少epoch数量
  batch_size: 128
  learning_rate: 0.001
  optimizer: "adamw"  # 'adam', 'adamw', 'sgd'
  weight_decay: 0.0001
  scheduler: "cosine" # 'cosine', 'step', null
  label_smoothing: 0.1
  gradient_clip_norm: 1.0

# 可视化配置
visualization:
  plot_learning_curves: true
  plot_confusion_matrix: true


## 步骤 5: 启动训练
现在，一切准备就绪。运行下面的单元格来实例化 `ConfigurableTrainer` 并启动由 `config.yaml` 文件定义的完整训练流程。

In [None]:
if __name__ == '__main__':
    # Path of the configuration file that drives this run
    config_file = 'config.yaml'
    
    # Build the trainer and execute the full training/evaluation pipeline
    trainer = ConfigurableTrainer(config_path=config_file)
    trainer.run()


### 如何运行传统硬编码模式？
如果你想不使用配置文件，而是运行代码中硬编码的参数，只需在实例化训练器时不传入 `config_path` 即可。这会自动回退到 `_get_hardcoded_parameters` 方法中定义的参数。

**示例:**
```python
# trainer = ConfigurableTrainer(config_path=None)
# trainer.run()
```