# 配置驱动的HAR训练流程 - 任务2.3完整实现

这个notebook实现了完全由配置文件驱动的训练执行引擎。
通过修改配置文件即可启动不同的实验，无需修改代码。

## 步骤 1: 导入所有必需的库
首先，我们导入所有需要的标准库、第三方库和项目内部模块。

In [None]:
# 标准库导入
import os
import sys
import time
import yaml
import random
import logging
import pickle
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple

# 第三方库导入
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, confusion_matrix, classification_report
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

# --- Project-internal imports (optional) ---
# This notebook is self-contained: the helpers and models it needs are defined
# inline further down. The real config modules are still tried first; if they
# cannot be imported, both names fall back to None (the trainer treats a None
# ConfigBridge as "no config support").
# NOTE(review): a local `ConfigBridge` class defined later in this notebook
# rebinds that name, so this import (or its None fallback) only matters if that
# cell is not executed.
try:
    from config.config_loader import ConfigLoader
    from config.config_bridge import ConfigBridge
except ImportError:
    print("警告: 配置模块未找到，将只使用传统模式")
    ConfigLoader = ConfigBridge = None

print("所有模块导入完成")

## 步骤 2: 定义辅助工具、模型和配置桥接器
为了使此Notebook可以独立运行，我们将原本在 `utils_torch.py`、`model_cbranchformer.py` 和 `config_bridge.py` 中的关键代码直接定义在这里。注意：这里的数据加载器只生成随机模拟数据，模型也是简化版（全连接网络），因此训练结果仅用于演示流程，不代表真实性能。

In [None]:
# === 从 utils_torch.py 移入的关键代码 ===
class HARDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset that pairs samples with their labels.

    ``dataset[idx]`` returns ``(data[idx], labels[idx])``; ``len(dataset)`` is
    ``len(data)``. No transforms or copies are applied.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

    def __len__(self):
        return len(self.data)

def load_dataset_pytorch(dataset_name, client_count, data_config, random_seed, base_path, num_classes=6):
    """Generate a random stand-in for the real HAR dataset loader (demo only).

    Mirrors the call signature of the real loader but reads nothing from disk.

    Args:
        dataset_name: Dataset identifier. ``'UCI'`` (case-insensitive) selects
            UCI-HAR-like sizes (7352/2947 samples, 128 steps, 6 channels);
            anything else uses 10000/2000 samples, 100 steps, 9 channels.
        client_count: Accepted for interface compatibility; ignored.
        data_config: Accepted for interface compatibility; ignored.
        random_seed: Seed for a private ``torch.Generator``, so the mock data
            is reproducible regardless of the global RNG state. ``None`` falls
            back to the global torch RNG.
        base_path: Accepted for interface compatibility; ignored.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.

    Returns:
        An object with ``central_train_data/label`` and
        ``central_test_data/label`` tensors. ``central_dev_data`` and
        ``central_dev_label`` are ``None`` so the caller carves out its own
        validation split.
    """
    print(f"正在生成 {dataset_name} 的模拟数据...")

    # Pick sizes based on the dataset name.
    if dataset_name.upper() == 'UCI':
        train_samples, test_samples = 7352, 2947
        seq_len, channels = 128, 6
    else:  # default / any other dataset
        train_samples, test_samples = 10000, 2000
        seq_len, channels = 100, 9

    # A private generator makes the data a function of `random_seed` alone,
    # instead of whatever state the global torch RNG happens to be in.
    generator = None
    if random_seed is not None:
        generator = torch.Generator().manual_seed(int(random_seed))

    train_data = torch.randn(train_samples, seq_len, channels, generator=generator)
    train_label = torch.randint(0, num_classes, (train_samples,), generator=generator)
    test_data = torch.randn(test_samples, seq_len, channels, generator=generator)
    test_label = torch.randint(0, num_classes, (test_samples,), generator=generator)

    class MockDataset:
        def __init__(self):
            self.central_train_data = train_data
            self.central_train_label = train_label
            self.central_test_data = test_data
            self.central_test_label = test_label
            # No predefined dev split: the caller splits one off the training set.
            self.central_dev_data = None
            self.central_dev_label = None

    return MockDataset()

def return_client_by_dataset(dataset_name):
    """Stand-in for the real helper: always report a single client.

    This demo only runs centralised training, so ``dataset_name`` is ignored.
    """
    single_client = 1  # centralised training
    return single_client

# Bundle the stand-in helpers under a `utils` namespace object so the trainer
# can call them as `utils.<name>`.
utils = type(
    'utils',
    (),
    dict(
        HARDataset=HARDataset,
        load_dataset_pytorch=load_dataset_pytorch,
        return_client_by_dataset=return_client_by_dataset,
    ),
)

# === 从 model_cbranchformer.py 移入的关键代码 (简化版) ===
class cbranchformer_har_base(nn.Module):
    """Simplified stand-in for the C-Branchformer HAR model (demo only).

    Flattens each sample and classifies it with a two-layer MLP:
    Linear(prod(input_shape[:2]) -> 128) -> ReLU -> Dropout -> Linear(128 -> classes).

    Args:
        input_shape: Per-sample shape; the product of its first two entries
            is the flattened input width (assumed ``(seq_len, channels)``).
        activity_count: Number of output logits.
        **kwargs: Only ``dropout_rate`` (default 0.5) is read; any other
            keyword is accepted and ignored.
    """

    def __init__(self, input_shape, activity_count, **kwargs):
        super().__init__()
        flat_width = input_shape[0] * input_shape[1]
        drop_p = kwargs.get('dropout_rate', 0.5)
        # Submodules are registered in the same order as the reference
        # implementation so state_dict keys and init RNG usage match.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_width, 128)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_p)
        self.fc2 = nn.Linear(128, activity_count)
        print("创建了一个简化的 cbranchformer_har_base 模型")

    def forward(self, x):
        hidden = self.dropout(self.relu(self.fc1(self.flatten(x))))
        return self.fc2(hidden)

class MobileHART_XS(nn.Module):
    """Simplified stand-in for MobileHART-XS (demo only).

    Flattens each sample and applies Linear(prod(input_shape[:2]) -> 64) ->
    ReLU -> Linear(64 -> activity_count). There is no dropout; extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, input_shape, activity_count, **kwargs):
        super().__init__()
        flat_width = input_shape[0] * input_shape[1]
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(flat_width, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, activity_count)
        print("创建了一个简化的 MobileHART_XS 模型")

    def forward(self, x):
        flat = self.flatten(x)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

# Bundle the stand-in model classes under a `model` namespace object so the
# trainer can reference them as `model.<ClassName>`.
model = type(
    'model',
    (),
    dict(
        cbranchformer_har_base=cbranchformer_har_base,
        MobileHART_XS=MobileHART_XS,
    ),
)

# === 从 config_bridge.py 移入的关键代码 ===
class ConfigBridge:
    """Load a YAML experiment config and expose it in two forms.

    Attributes:
        raw_config: The parsed YAML document as a plain dict (``{}`` for an
            empty file). The ``get_*_config()`` accessors read from it.
        config: The same data converted so nested mappings support attribute
            access (e.g. ``config.architecture.experts``) plus ``keys()``,
            ``values()``, ``items()`` and ``get()``.
        use_new_config: Set to ``True`` once loading succeeds.
    """

    def __init__(self, config_path: str):
        """Read, parse and convert the YAML file at ``config_path``.

        Each failure prints a short message and is then re-raised.

        Raises:
            FileNotFoundError: The file does not exist.
            yaml.YAMLError: The file is not valid YAML.
            ValueError: The top level of the YAML document is not a mapping.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
            # safe_load() returns None for an empty document; normalise it so the
            # accessors below never call .get() on None.
            if loaded is None:
                loaded = {}
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"配置文件顶层必须是映射(字典)，实际为: {type(loaded).__name__}"
                )
            self.raw_config = loaded
            self.config = self._to_dot_notation(self.raw_config)
            self.use_new_config = True
            print(f"✓ 配置文件加载成功: {config_path}")
        except FileNotFoundError:
            print(f"✗ 配置文件不存在: {config_path}")
            raise
        except yaml.YAMLError as e:
            print(f"✗ YAML解析错误: {e}")
            raise
        except Exception as e:
            print(f"✗ 配置加载失败: {e}")
            raise

    def _to_dot_notation(self, data):
        """Recursively convert dicts to ``DotDict`` objects.

        Lists are converted element-wise; any other value is returned unchanged.
        """
        if isinstance(data, dict):
            class DotDict:
                """Attribute-style view over a mapping with dict-like helpers."""

                def __init__(self, data_dict):
                    self._dict = data_dict
                    # Converted children are kept in their own dict so that
                    # get()/values()/items() only ever return config entries,
                    # never one of this class's methods or internal attributes.
                    self._converted = {
                        k: self._to_dot_notation(v) for k, v in data_dict.items()
                    }
                    for k, v in self._converted.items():
                        # Only string keys can become attributes (YAML also allows
                        # e.g. integer keys); the others stay reachable via get().
                        if isinstance(k, str):
                            setattr(self, k, v)

                def values(self):
                    return list(self._converted.values())

                def keys(self):
                    return self._dict.keys()

                def items(self):
                    return list(self._converted.items())

                def get(self, key, default=None):
                    return self._converted.get(key, default)

                def _to_dot_notation(self, data):
                    if isinstance(data, dict):
                        return DotDict(data)
                    elif isinstance(data, list):
                        return [self._to_dot_notation(i) for i in data]
                    else:
                        return data

            return DotDict(data)
        elif isinstance(data, list):
            return [self._to_dot_notation(i) for i in data]
        else:
            return data

    def get_dataset_config(self) -> Dict[str, Any]:
        """Return the ``dataset`` section as a plain dict (``{}`` if missing or null)."""
        return self.raw_config.get('dataset') or {}

    def get_training_config(self) -> Dict[str, Any]:
        """Return the ``training`` section as a plain dict (``{}`` if missing or null)."""
        return self.raw_config.get('training') or {}

    def get_visualization_config(self) -> Dict[str, Any]:
        """Return the ``visualization`` section as a plain dict (``{}`` if missing or null)."""
        return self.raw_config.get('visualization') or {}

## 步骤 3: 定义核心训练器类 (已修复并整合)
这里是修复和整合后的 `ConfigurableTrainer` 类。所有方法都被正确地定义在类中，解决了 `NameError` 并改善了代码结构。

In [None]:
class ConfigurableTrainer:
    """Config-driven HAR trainer (full implementation of task 2.3).

    Parameters are read from a YAML file through ``ConfigBridge`` when a
    ``config_path`` is given and loads successfully. Otherwise a hard-coded
    parameter set is used. ``run()`` drives the whole pipeline: environment
    setup, data loading, model creation, training with early stopping,
    plotting and test-set evaluation. Artifacts (logs, best checkpoint,
    figures) are written under ``params['output_dir']``.
    """

    # Keys that create_model() passes to the model constructor explicitly. They
    # must be stripped before the remaining parameters are forwarded via **.
    _EXPLICIT_MODEL_ARGS = ('input_shape', 'activity_count')

    def __init__(self, config_path: Optional[str] = None):
        """Create the trainer.

        Args:
            config_path: Path to the YAML experiment config, or ``None`` to use
                the hard-coded parameters. If loading fails, the trainer falls
                back to the hard-coded parameters instead of raising.
        """
        self.config_path = config_path
        self.use_config = config_path is not None and ConfigBridge is not None

        if self.use_config:
            try:
                self.config_bridge = ConfigBridge(config_path)
                self.config = self.config_bridge.config
                self.use_config = self.config_bridge.use_new_config
            except Exception as e:
                # Any load/parse failure degrades to hard-coded parameters.
                print(f"配置文件加载失败: {e}")
                print("回退到传统硬编码模式")
                self.use_config = False
                self.config_bridge = None
                self.config = None
        else:
            self.config_bridge = None
            self.config = None

        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.device = None
        self.output_dir = None
        self.logger = None
        self.gradient_clip_norm = 0

        # Per-epoch metrics, appended by run() and plotted by plot_learning_curves().
        self.history = {
            'train_loss': [], 'train_accuracy': [],
            'val_loss': [], 'val_accuracy': []
        }

        print(f"训练器初始化完成，使用配置模式: {'新配置系统' if self.use_config else '传统硬编码'}")

    def setup_logging(self, output_dir: Path, verbose: bool = True):
        """Send log records to ``<output_dir>/logs/training.log`` and stdout.

        Args:
            output_dir: Experiment output directory; ``logs/`` is created in it.
            verbose: ``True`` logs at INFO level, ``False`` at WARNING.
        """
        log_level = logging.INFO if verbose else logging.WARNING
        log_dir = output_dir / 'logs'
        log_dir.mkdir(parents=True, exist_ok=True)

        # Detach *and close* existing root handlers; otherwise every re-run in
        # the same kernel leaks the previous run's open log file.
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
            handler.close()

        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                # Messages contain non-ASCII text; don't depend on the locale encoding.
                logging.FileHandler(log_dir / 'training.log', encoding='utf-8'),
                logging.StreamHandler(sys.stdout),
            ],
        )
        self.logger = logging.getLogger(__name__)
        self.logger.info("日志系统初始化完成")

    def get_parameters(self) -> Dict[str, Any]:
        """Return the flat parameter dict from the config file or the hard-coded defaults."""
        if self.use_config:
            return self._get_config_parameters()
        return self._get_hardcoded_parameters()

    def _get_config_parameters(self) -> Dict[str, Any]:
        """Build the flat parameter dict from the loaded YAML config.

        Raises:
            ValueError: If the ``dataset`` or ``training`` section is missing or empty.
            Exception: Any missing key/attribute error is logged (or printed)
                and re-raised.
        """
        try:
            dataset_config = self.config_bridge.get_dataset_config()
            training_config = self.config_bridge.get_training_config()
            visualization_config = self.config_bridge.get_visualization_config()

            if not dataset_config or not training_config:
                raise ValueError("配置文件不完整")

            # The model input width is the sum of all modalities' channels.
            total_channels = sum(mod['channels'] for mod in dataset_config['modalities'])

            # Several experts may be configured; only the first one is used.
            experts_config = self.config.architecture.experts
            main_expert = list(experts_config.values())[0]
            architecture_type = main_expert.type.upper()

            return {
                'experiment_name': self.config.name,
                'description': getattr(self.config, 'description', ''),
                'dataset_name': dataset_config['name'],
                'data_config': 'BALANCED',
                'activity_labels': dataset_config['activity_labels'],
                'activity_count': len(dataset_config['activity_labels']),
                'architecture': architecture_type,
                'segment_size': dataset_config['modalities'][0]['sequence_length'],
                'num_input_channels': total_channels,
                'input_shape': (dataset_config['modalities'][0]['sequence_length'], total_channels),
                'batch_size': training_config['batch_size'],
                'learning_rate': training_config['learning_rate'],
                'local_epoch': training_config['epochs'],
                'dropout_rate': self.config.architecture.dropout_rate,
                'weight_decay': training_config.get('weight_decay', 1e-4),
                'gradient_clip_norm': training_config.get('gradient_clip_norm', 1.0),
                'label_smoothing': training_config.get('label_smoothing', 0.1),
                'optimizer_name': training_config.get('optimizer', 'adam'),
                'scheduler_name': training_config.get('scheduler', 'cosine'),
                'early_stopping_patience': training_config.get('early_stopping_patience', 10),
                'projection_dim': getattr(main_expert.params, 'projection_dim', 192),
                'frame_length': getattr(main_expert.params, 'frame_length', 16),
                'time_step': getattr(main_expert.params, 'time_step', 16),
                'filter_attention_head': getattr(main_expert.params, 'filter_attention_head', 4),
                'conv_kernels': getattr(main_expert.params, 'conv_kernels', [3, 7, 15, 31, 31, 31]),
                'token_based': getattr(main_expert.params, 'token_based', False),
                'random_seed': self.config.seed,
                'device': self.config.device,
                'output_dir': self.config.output_dir,
                'save_checkpoints': self.config.save_checkpoints,
                'verbose': self.config.verbose,
                'show_train_verbose': 1 if self.config.verbose else 0,
                'plot_learning_curves': visualization_config.get('plot_learning_curves', True),
                'plot_confusion_matrix': visualization_config.get('plot_confusion_matrix', True),
            }
        except Exception as e:
            error_msg = f"配置参数获取失败: {e}"
            if self.logger:
                self.logger.error(error_msg)
            else:
                print(f"✗ {error_msg}")
            raise

    def _get_hardcoded_parameters(self) -> Dict[str, Any]:
        """Return the hard-coded fallback parameters (kept for backward compatibility)."""
        return {
            'experiment_name': 'Hardcoded_Experiment',
            'description': '使用硬编码参数的实验',
            'dataset_name': 'UCI', 'data_config': 'BALANCED',
            'activity_labels': ['Walk', 'Upstair', 'Downstair', 'Sit', 'Stand', 'Lay'],
            'activity_count': 6, 'architecture': 'HART',
            'segment_size': 128, 'num_input_channels': 6,
            'input_shape': (128, 6), 'batch_size': 256,
            'learning_rate': 5e-3, 'local_epoch': 50,
            'dropout_rate': 0.3, 'weight_decay': 1e-4,
            'gradient_clip_norm': 1.0, 'label_smoothing': 0.1,
            'optimizer_name': 'adam', 'scheduler_name': 'cosine',
            'early_stopping_patience': 10,
            'projection_dim': 192, 'frame_length': 16, 'time_step': 16,
            'filter_attention_head': 4, 'conv_kernels': [3, 7, 15, 31, 31, 31],
            'token_based': False, 'random_seed': 42, 'device': 'auto',
            'output_dir': './results/hardcoded_experiment',
            'save_checkpoints': True, 'verbose': True, 'show_train_verbose': 1,
            'plot_learning_curves': True, 'plot_confusion_matrix': True,
        }

    @staticmethod
    def _resolve_device(device_config) -> torch.device:
        """Map the ``device`` setting to a ``torch.device``.

        ``'auto'`` prefers CUDA, then Apple MPS, then CPU. Any other value is
        passed straight to ``torch.device``.
        """
        if device_config != 'auto':
            return torch.device(device_config)
        if torch.cuda.is_available():
            return torch.device('cuda')
        # torch.backends.mps does not exist in older PyTorch releases.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    def setup_environment(self, params: Dict[str, Any]) -> None:
        """Seed RNGs, select the device, create the output dir and set up logging."""
        try:
            seed = params['random_seed']
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

            self.device = self._resolve_device(params['device'])
            print(f"使用设备: {self.device}")

            self.output_dir = Path(params['output_dir'])
            self.output_dir.mkdir(parents=True, exist_ok=True)

            self.setup_logging(self.output_dir, params['verbose'])
            # A null value in the config means "no clipping".
            self.gradient_clip_norm = params.get('gradient_clip_norm') or 0

            # NOTE: benchmark=True / deterministic=False favours cuDNN speed over
            # bit-exact reproducibility, despite the seeding above.
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
            self.logger.info("环境设置完成")
        except Exception as e:
            print(f"环境设置失败: {e}")
            raise

    def load_data(self, params: Dict[str, Any]) -> Tuple[torch.utils.data.DataLoader, ...]:
        """Load the dataset and wrap the splits in DataLoaders.

        If the loader provides no dev split, 12.5% of the training set is split
        off (stratified by label) as the validation set.

        Returns:
            ``(train_loader, dev_loader, test_loader, train_labels)``, where
            ``train_labels`` is a NumPy array of the final training labels.
        """
        self.logger.info(f"开始加载数据集: {params['dataset_name']}")
        try:
            client_count = utils.return_client_by_dataset(params['dataset_name'])
            loaded_dataset = utils.load_dataset_pytorch(
                params['dataset_name'], client_count, params['data_config'],
                params['random_seed'], './datasets/'
            )

            train_data, train_label = loaded_dataset.central_train_data, loaded_dataset.central_train_label
            test_data, test_label = loaded_dataset.central_test_data, loaded_dataset.central_test_label

            if getattr(loaded_dataset, 'central_dev_data', None) is not None:
                dev_data, dev_label = loaded_dataset.central_dev_data, loaded_dataset.central_dev_label
            else:
                train_np, dev_np, train_label_np, dev_label_np = train_test_split(
                    train_data.numpy(), train_label.numpy(), test_size=0.125,
                    random_state=params['random_seed'], stratify=train_label.numpy()
                )
                train_data, train_label = torch.FloatTensor(train_np), torch.LongTensor(train_label_np)
                dev_data, dev_label = torch.FloatTensor(dev_np), torch.LongTensor(dev_label_np)

            train_dataset = utils.HARDataset(train_data, train_label)
            dev_dataset = utils.HARDataset(dev_data, dev_label)
            test_dataset = utils.HARDataset(test_data, test_label)

            # num_workers=0 avoids multiprocessing issues with notebook-defined classes.
            common_loader_params = {'batch_size': params['batch_size'], 'num_workers': 0, 'pin_memory': False}
            train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_params)
            dev_loader = DataLoader(dev_dataset, shuffle=False, **common_loader_params)
            test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_params)

            self.logger.info(f"数据加载完成: 训练={len(train_dataset)}, 验证={len(dev_dataset)}, 测试={len(test_dataset)}")
            self.logger.info("注意: DataLoader使用单进程模式(num_workers=0)以保证兼容性")

            return train_loader, dev_loader, test_loader, train_label.cpu().numpy()
        except Exception as e:
            self.logger.error(f"数据加载失败: {e}")
            raise

    @classmethod
    def _model_kwargs(cls, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``params`` without the keys create_model() passes explicitly."""
        return {k: v for k, v in params.items() if k not in cls._EXPLICIT_MODEL_ARGS}

    def create_model(self, params: Dict[str, Any]) -> nn.Module:
        """Instantiate the configured architecture on ``self.device``.

        Unknown architecture names fall back to ``MobileHART_XS`` with a warning.
        """
        self.logger.info(f"创建模型: {params['architecture']}, 输入形状={params['input_shape']}, 类别数={params['activity_count']}")
        try:
            model_map = {
                "HART": model.cbranchformer_har_base,
                "MOBILEHART_XS": model.MobileHART_XS
            }
            architecture = params['architecture']
            model_class = model_map.get(architecture)
            if model_class is None:
                self.logger.warning(f"未知架构 {architecture}，回退到 MobileHART_XS")
                model_class = model.MobileHART_XS
            # params also contains 'input_shape' / 'activity_count'; forwarding
            # them again via ** would raise "got multiple values for keyword".
            self.model = model_class(
                input_shape=params['input_shape'],
                activity_count=params['activity_count'],
                **self._model_kwargs(params),
            ).to(self.device)

            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            self.logger.info(f"模型创建完成，总参数数量: {total_params:,}")
            return self.model
        except Exception as e:
            self.logger.error(f"模型创建失败: {e}")
            raise

    def setup_training(self, params: Dict[str, Any], train_labels: np.ndarray) -> None:
        """Create the class-weighted loss, optimizer and LR scheduler.

        Raises:
            ValueError: For an unsupported optimizer, or training labels outside
                ``[0, activity_count)``.
        """
        try:
            num_classes = params['activity_count']
            present_classes = np.unique(train_labels)
            if present_classes.size and (present_classes.min() < 0 or present_classes.max() >= num_classes):
                raise ValueError(f"训练标签超出范围 [0, {num_classes})")
            present_weights = class_weight.compute_class_weight(
                'balanced', classes=present_classes, y=train_labels
            )
            # CrossEntropyLoss needs one weight per model output, but
            # compute_class_weight only covers classes present in train_labels;
            # absent classes get a neutral weight of 1.0.
            weights = np.ones(num_classes, dtype=np.float64)
            weights[present_classes] = present_weights
            class_weights = torch.FloatTensor(weights).to(self.device)
            self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=params['label_smoothing'])

            optimizer_name = params['optimizer_name'].lower()
            opt_class = {'adam': optim.Adam, 'adamw': optim.AdamW, 'sgd': optim.SGD}.get(optimizer_name)
            if not opt_class:
                raise ValueError(f"不支持的优化器: {optimizer_name}")
            self.optimizer = opt_class(self.model.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay'])

            scheduler_name = params['scheduler_name']
            # Match case-insensitively like the optimizer; None/null means no scheduler.
            scheduler_key = scheduler_name.lower() if isinstance(scheduler_name, str) else None
            if scheduler_key == 'cosine':
                self.scheduler = CosineAnnealingLR(self.optimizer, T_max=params['local_epoch'])
            elif scheduler_key == 'step':
                self.scheduler = StepLR(self.optimizer, step_size=30, gamma=0.1)
            elif scheduler_key == 'plateau':
                self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', patience=5)
            else:
                self.scheduler = None

            self.logger.info(f"训练组件设置完成: 优化器={optimizer_name}, 调度器={scheduler_name}")
        except Exception as e:
            self.logger.error(f"训练组件设置失败: {e}")
            raise

    def train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch.

        Returns:
            ``(mean of per-batch losses, accuracy in percent)``.
        """
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        for data, targets in train_loader:
            data, targets = data.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(data)
            loss = self.criterion(outputs, targets)
            loss.backward()
            if self.gradient_clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_norm)
            self.optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
        avg_loss = total_loss / len(train_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float]:
        """Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            ``(mean of per-batch losses, accuracy in percent)``.
        """
        self.model.eval()
        total_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.model(data)
                loss = self.criterion(outputs, targets)
                total_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        avg_loss = total_loss / len(val_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def run(self):
        """Run the full pipeline: setup, training with early stopping, plots, test evaluation.

        Any exception is logged (or printed with a traceback if logging is not
        set up yet) and re-raised.
        """
        try:
            params = self.get_parameters()
            self.setup_environment(params)
            self.logger.info(f"--- 实验开始: {params['experiment_name']} ---")
            self.logger.info(f"参数:\n{json.dumps(params, indent=2, ensure_ascii=False)}")

            train_loader, dev_loader, test_loader, train_labels = self.load_data(params)
            self.create_model(params)
            self.setup_training(params, train_labels)

            self.logger.info("--- 开始训练 ---")
            best_val_accuracy, patience_counter = 0.0, 0
            best_state = None
            early_stopping = params.get('early_stopping_patience', 10)

            for epoch in range(params['local_epoch']):
                self.logger.info(f"\nEpoch {epoch+1}/{params['local_epoch']}")
                train_loss, train_acc = self.train_epoch(train_loader)
                val_loss, val_acc = self.validate_epoch(dev_loader)
                self.logger.info(f"训练 - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
                self.logger.info(f"验证 - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
                self.history['train_loss'].append(train_loss)
                self.history['train_accuracy'].append(train_acc)
                self.history['val_loss'].append(val_loss)
                self.history['val_accuracy'].append(val_acc)

                if self.scheduler:
                    # ReduceLROnPlateau needs the monitored metric; the others step per epoch.
                    if isinstance(self.scheduler, ReduceLROnPlateau):
                        self.scheduler.step(val_loss)
                    else:
                        self.scheduler.step()

                if val_acc > best_val_accuracy:
                    best_val_accuracy, patience_counter = val_acc, 0
                    # In-memory snapshot: the test evaluation must use this run's
                    # best weights even when checkpoints are disabled, and must
                    # never pick up a stale best_model.pth from an earlier run.
                    best_state = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
                    if params['save_checkpoints']:
                        torch.save(self.model.state_dict(), self.output_dir / 'best_model.pth')
                        self.logger.info(f"保存最佳模型 (验证准确率: {best_val_accuracy:.2f}%)")
                else:
                    patience_counter += 1

                if patience_counter >= early_stopping:
                    self.logger.info(f"早停触发 (patience: {early_stopping})")
                    break

            self.logger.info(f"训练完成！最佳验证准确率: {best_val_accuracy:.2f}%")
            if params.get('plot_learning_curves', True):
                self.plot_learning_curves()

            self.logger.info("--- 开始测试评估 ---")
            if best_state is not None:
                self.model.load_state_dict(best_state)
            self.evaluate_model(test_loader, params)
            self.logger.info("--- 实验完成 ---")

        except Exception as e:
            error_msg = f"训练过程中发生错误: {e}"
            if self.logger:
                self.logger.error(error_msg, exc_info=True)
            else:
                import traceback
                print(f"✗ {error_msg}")
                traceback.print_exc()
            raise

    def evaluate_model(self, test_loader: DataLoader, params: Dict[str, Any]) -> None:
        """Log test accuracy and weighted F1 and optionally plot the confusion matrix."""
        self.model.eval()
        test_loss, correct, total = 0.0, 0, 0
        all_preds, all_targets = [], []
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.model(data)
                loss = self.criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        test_acc = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')
        self.logger.info(f"测试结果: Accuracy={test_acc:.4f}, F1-Score={f1:.4f}")
        if params['plot_confusion_matrix']:
            self.plot_confusion_matrix(all_targets, all_preds, params['activity_labels'])

    def plot_confusion_matrix(self, y_true, y_pred, labels):
        """Save (and show) a confusion-matrix heatmap with one row/column per entry of ``labels``."""
        # Pin the matrix to every class index so its size always matches the
        # tick labels, even when a class never occurs in y_true or y_pred.
        cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        cm_path = self.output_dir / 'confusion_matrix.png'
        plt.savefig(cm_path)
        self.logger.info(f"混淆矩阵已保存至 {cm_path}")
        plt.show()
        plt.close(fig)  # release the figure on non-interactive backends

    def plot_learning_curves(self):
        """Save (and show) loss and accuracy curves from ``self.history``."""
        fig = plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='Train Loss')
        plt.plot(self.history['val_loss'], label='Validation Loss')
        plt.title('Loss Curves')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.subplot(1, 2, 2)
        plt.plot(self.history['train_accuracy'], label='Train Accuracy')
        plt.plot(self.history['val_accuracy'], label='Validation Accuracy')
        plt.title('Accuracy Curves')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.tight_layout()
        curves_path = self.output_dir / 'learning_curves.png'
        plt.savefig(curves_path)
        self.logger.info(f"学习曲线已保存至 {curves_path}")
        plt.show()
        plt.close(fig)  # release the figure on non-interactive backends

print("✓ ConfigurableTrainer 类已修复并完整加载")

## 步骤 4: 创建配置文件
配置文件是整个流程的核心驱动。下面的单元格使用 IPython 的 `%%writefile` 魔法命令创建名为 `config.yaml` 的文件（因此只能在 Jupyter/IPython 中运行）。你可以直接修改这个单元格中的内容来改变实验参数，而无需改动上面的类定义代码。

In [None]:
%%writefile config.yaml

# 实验元数据
name: "UCI_HART_Baseline_Experiment"
description: "使用HART模型在UCI-HAR数据集上进行的基线实验"
seed: 42
device: "auto"  # 'auto', 'cpu', 'cuda', 'mps'
output_dir: "./results/uci_hart_baseline"
save_checkpoints: true
verbose: true

# 数据集配置
dataset:
  name: "UCI"
  activity_labels: ['Walk', 'Upstair', 'Downstair', 'Sit', 'Stand', 'Lay']
  modalities:
    - name: "accel"
      channels: 3
      sequence_length: 128
    - name: "gyro"
      channels: 3
      sequence_length: 128

# 模型架构配置
architecture:
  dropout_rate: 0.3
  # 定义专家模型，可以有多个，但目前只使用第一个
  experts:
    expert1:
      type: "HART"
      params:
        projection_dim: 192
        frame_length: 16
        time_step: 16
        filter_attention_head: 4
        conv_kernels: [3, 7, 15, 31, 31, 31]
        token_based: false

# 训练过程配置
training:
  epochs: 20  # 为了快速演示，减少epoch数量
  batch_size: 128
  learning_rate: 0.001
  optimizer: "adamw"  # 'adam', 'adamw', 'sgd'
  weight_decay: 0.0001
  scheduler: "cosine" # 'cosine', 'step', 'plateau', null
  label_smoothing: 0.1
  gradient_clip_norm: 1.0
  early_stopping_patience: 10

# 可视化配置
visualization:
  plot_learning_curves: true
  plot_confusion_matrix: true


## 步骤 5: 启动训练
现在，一切准备就绪。运行下面的单元格来实例化 `ConfigurableTrainer` 并启动由 `config.yaml` 文件定义的完整训练流程。

In [None]:
if __name__ == '__main__':
    # Use config.yaml (written by the cell above) when it exists; otherwise the
    # trainer runs with its hard-coded parameters.
    config_file = 'config.yaml'
    try:
        if os.path.exists(config_file):
            print(f"✓ 使用配置文件: {config_file}")
            selected_config = config_file
        else:
            print(f"⚠ 配置文件 {config_file} 不存在，将使用传统硬编码模式")
            selected_config = None

        trainer = ConfigurableTrainer(config_path=selected_config)
        trainer.run()

    except KeyboardInterrupt:
        print("\n用户中断训练")
    except Exception:
        # Errors raised inside run() were already logged/printed there; only
        # report the abnormal termination here.
        print("\n训练流程因未捕获的异常而终止。请检查日志。")
