# MazeruHAR 动态训练引擎

欢迎使用 MazeruHAR 的配置驱动训练引擎。此 Notebook 旨在提供一个灵活、可重现且易于使用的训练流程。

**核心理念:** **代码只写一次，实验配置万千。**

您只需要执行以下两个简单步骤即可开始训练：

1.  **配置实验**: 修改位于 `config/` 目录下的 `.yaml` 配置文件。您可以复制 `config/default_configs/shl_config.yaml` 并根据您的需求进行调整，比如更换数据集、模型架构或超参数。
2.  **运行 Notebook**: 在下面的“**实验配置**”单元格（步骤 2）中设置好配置文件的路径，然后从头到尾依次运行此 Notebook 的所有单元格即可。

---

## 步骤 1: 环境设置与库导入

此单元格负责导入所有必需的库并设置初始环境。它整合了项目所需的所有依赖项。

In [None]:
# 标准库导入
import os
import sys
import time
import yaml
import random
import logging
import json
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple

# 第三方库导入
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau
from sklearn.metrics import f1_score, confusion_matrix, classification_report
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score

# --- Project-internal module imports ---
# Try to import the project modules from the source tree. NOTE(review): despite
# the fallback message printed below, this notebook contains NO inline fallback
# definitions for these names; on ImportError only a message is printed.
# ConfigurableTrainer does not reference ConfigLoader or utils, and it imports
# data_layer / model_layer lazily inside its own methods, so a failure here is
# not fatal by itself.
try:
    # Prefer importing from the project file system.
    from config.config_loader import ConfigLoader
    from model_layer.dynamic_har_model import create_dynamic_har_model
    import utils_torch as utils
    print("成功从项目文件导入模块。")
except ImportError as e:
    print(f"从文件导入模块失败: {e}。将使用 Notebook 内联定义。")
    
    

print("所有模块已准备就绪。")

## 步骤 2: 实验配置

**这是您唯一需要修改的单元格。**

请将 `CONFIG_PATH` 变量设置为您想要使用的配置文件路径。所有实验参数都将从此文件加载。

In [None]:
# ======================= Core configuration =======================
# Change only this path to start a new experiment. By default it points at
# the bundled SHL configuration under config/default_configs/.
CONFIG_PATH = 'config/default_configs/shl_config.yaml'
# ==================================================================

# Report up front whether the selected configuration file is present,
# so a bad path is noticed before any training work begins.
if os.path.exists(CONFIG_PATH):
    print(f"✓ 将使用配置文件: '{CONFIG_PATH}'")
else:
    print(f"❌ 错误: 配置文件 '{CONFIG_PATH}' 未找到!")
    print("请确保路径正确，或创建一个新的配置文件。")

## 步骤 3: 核心训练器类

下面的 `ConfigurableTrainer` 类是整个训练流程的核心。它封装了从配置加载、环境设置、数据处理、模型创建、训练、评估到结果可视化的所有逻辑。您无需修改此类中的任何代码。

In [None]:
# ===================================================================================
# ===== Core trainer: config-driven SHL multimodal training pipeline =====
# ===================================================================================
class ConfigurableTrainer:
    """Config-driven trainer for the SHL multimodal HAR pipeline.

    The YAML file at ``config_path`` is loaded into ``self.config`` (nested
    dict) and flattened one level into ``self.params`` for convenient lookup.
    ``run()`` drives the full pipeline: environment setup, data loading, model
    construction, training with early stopping, final test evaluation on the
    best weights, and export of history/metrics/plots into ``output_dir``.
    """

    # Used when the config does not set ``output_dir``. A single default is
    # used everywhere so the log file and all artifacts share one directory.
    DEFAULT_OUTPUT_DIR = './results/shl_experiment'
    # Dedicated logger name; its handlers are replaced on every
    # setup_environment() call so re-running a notebook cell neither
    # duplicates output nor keeps writing to a stale log file.
    LOGGER_NAME = 'MazeruHAR.trainer'

    def __init__(self, config_path: str):
        """Load the config at ``config_path``; all other state is set up lazily by ``run()``."""
        self.config_path = config_path
        # self.config is the raw nested dict parsed from YAML.
        self.config = self._load_config(config_path)
        self.params = self._extract_parameters()
        self.device = None
        self.output_dir = None
        self.logger = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}
        # Targets / predictions of the most recent test-set evaluation,
        # used to draw the confusion matrix.
        self.test_targets = []
        self.test_predictions = []

    def _load_config(self, path: str) -> Dict[str, Any]:
        """Parse the YAML file at ``path`` with ``yaml.safe_load``.

        Returns:
            The top-level mapping; an empty file yields ``{}``.

        Raises:
            ValueError: if the top level of the file is not a mapping.
        """
        with open(path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        if config is None:
            return {}
        if not isinstance(config, dict):
            raise ValueError(f"配置文件顶层必须是映射(字典), 实际为: {type(config).__name__}")
        return config

    def _extract_parameters(self) -> Dict[str, Any]:
        """Flatten the config one level into a single dict.

        Keys of every dict-valued top-level section are merged into the
        result; non-dict top-level values are copied as-is. Later sections
        override earlier ones on key collisions.
        """
        params: Dict[str, Any] = {}
        for key, value in (self.config or {}).items():
            if isinstance(value, dict):
                params.update(value)
            else:
                params[key] = value
        return params

    def _section(self, name: str) -> Dict[str, Any]:
        """Return top-level config section ``name`` if it is a dict, else ``{}``."""
        section = (self.config or {}).get(name)
        return section if isinstance(section, dict) else {}

    def _dataset_param(self, key: str, default: Any = None) -> Any:
        """Look up ``key`` in the ``dataset`` section first, then in the flat params.

        Keys such as ``name`` or ``step_size`` may also exist in other
        sections and overwrite each other in the flattened ``self.params``;
        reading the dataset section first keeps dataset settings intact.
        """
        dataset_cfg = self._section('dataset')
        if key in dataset_cfg:
            return dataset_cfg[key]
        return self.params.get(key, default)

    def _configure_logger(self, log_file: Path) -> logging.Logger:
        """(Re)configure the trainer logger to write to ``log_file`` and stdout.

        Existing handlers on the trainer logger are removed and closed first,
        so calling this repeatedly never stacks handlers.
        """
        logger = logging.getLogger(self.LOGGER_NAME)
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()
        formatter = logging.Formatter('%(asctime)s [%(levelname)s] - %(message)s')
        # Explicit UTF-8: the log messages contain Chinese text, which the
        # platform default encoding (e.g. cp1252 on Windows) cannot encode.
        handlers = (logging.FileHandler(log_file, encoding='utf-8'), logging.StreamHandler(sys.stdout))
        for handler in handlers:
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        # Do not propagate to the root logger: avoids duplicate console lines
        # when the host (e.g. a notebook kernel) already configured root.
        logger.propagate = False
        return logger

    def setup_environment(self):
        """Seed RNGs, select the device, create ``output_dir`` and configure logging."""
        seed = self.params.get('seed', 42)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        device_pref = self.params.get('device', 'auto')
        if device_pref == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device_pref)

        # Resolved exactly once, here; run() must not override it, otherwise the
        # log file and the artifacts end up in different directories.
        self.output_dir = Path(self.params.get('output_dir') or self.DEFAULT_OUTPUT_DIR)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.logger = self._configure_logger(self.output_dir / 'training.log')
        self.logger.info(f"环境设置完成。设备: {self.device}, 输出目录: {self.output_dir}")

    def load_data(self) -> Tuple[DataLoader, DataLoader, DataLoader, np.ndarray]:
        """Build train/val/test DataLoaders for SHL via the project's data_layer.

        Returns:
            ``(train_loader, dev_loader, test_loader, train_labels)`` where
            ``train_labels`` is a numpy array of every training label (used
            for class weighting).
        """
        dataset_name = self._dataset_param('name', 'SHL')
        self.logger.info(f"加载SHL数据集: {dataset_name}")

        from data_layer import SHLDataParser, UniversalHARDataset

        parser_config = {
            'name': 'SHL',
            'data_path': self._dataset_param('path', './datasets/'),
            'window_size': self._dataset_param('window_size', 128),
            'step_size': self._dataset_param('step_size', 64),
            'sample_rate': self._dataset_param('sample_rate', 100),
            # NOTE(review): the modalities are hard-coded here and do not read
            # dataset.modalities from the config — confirm this is intended.
            'modalities': {
                'imu': {'enabled': True, 'channels': 6},
                'pressure': {'enabled': True, 'channels': 1}
            }
        }

        parser = SHLDataParser(parser_config)
        train_dataset = UniversalHARDataset(parser, split='train')
        dev_dataset = UniversalHARDataset(parser, split='val')
        test_dataset = UniversalHARDataset(parser, split='test')

        # Iterates the whole training set once to collect labels; this loads
        # every sample and may be slow for large datasets.
        train_labels = np.array([label for _, label in train_dataset])

        use_cuda = self.device is not None and self.device.type == 'cuda'
        common_params = {
            'batch_size': self.params['batch_size'],
            'num_workers': self.params.get('num_workers', 0),
            # Pinned memory only speeds up host->GPU copies; on CPU-only
            # machines PyTorch merely warns, so default it to the device type.
            'pin_memory': self.params.get('pin_memory', use_cuda),
        }

        train_loader = DataLoader(train_dataset, shuffle=True, **common_params)
        dev_loader = DataLoader(dev_dataset, shuffle=False, **common_params)
        test_loader = DataLoader(test_dataset, shuffle=False, **common_params)

        self.logger.info("SHL多模态数据加载完成。")
        self.logger.info(f"训练集大小: {len(train_dataset)}")
        self.logger.info(f"验证集大小: {len(dev_dataset)}")
        self.logger.info(f"测试集大小: {len(test_dataset)}")

        return train_loader, dev_loader, test_loader, train_labels

    def build_model(self):
        """Build the multimodal model from ``self.config`` and move it to ``self.device``.

        Raises:
            ValueError: if ``architecture`` or ``architecture.experts`` is missing.
        """
        self.logger.info("构建SHL多模态模型...")

        from model_layer import create_dynamic_har_model

        architecture = (self.config or {}).get('architecture')
        if architecture is None:
            raise ValueError("配置文件缺少'architecture'部分")
        if not isinstance(architecture, dict) or 'experts' not in architecture:
            raise ValueError("配置文件缺少'architecture.experts'部分")

        # create_dynamic_har_model receives the whole config dict.
        self.model = create_dynamic_har_model(self.config).to(self.device)

        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(f"模型构建完成。总参数量: {total_params:,}")

        if hasattr(self.model, 'experts'):
            self.logger.info("专家模型信息:")
            for expert_name, expert in self.model.experts.items():
                expert_params = sum(p.numel() for p in expert.parameters() if p.requires_grad)
                self.logger.info(f"  {expert_name}: {expert_params:,} 参数")

    @staticmethod
    def _balanced_class_weights(labels, num_classes: Optional[int] = None) -> np.ndarray:
        """Return per-class loss weights indexed by class id.

        Uses the same formula as scikit-learn's ``class_weight='balanced'``
        (``n_samples / (n_present_classes * count_c)``), but stores the weight
        of class ``c`` at index ``c``. This keeps the vector aligned with the
        logit columns that ``nn.CrossEntropyLoss`` weights. Classes absent
        from ``labels`` get weight 1.0. The length is ``max(label) + 1``, or
        ``num_classes`` if that is larger.

        Raises:
            ValueError: if ``labels`` is empty, not 1-D, or contains negative ids.
        """
        y = np.asarray(labels).astype(np.int64)
        if y.ndim != 1:
            raise ValueError(f"训练标签必须是一维的类别索引, 实际形状: {y.shape}")
        if y.size == 0:
            raise ValueError("训练标签为空, 无法计算类别权重")
        classes, counts = np.unique(y, return_counts=True)  # classes are sorted
        if classes[0] < 0:
            raise ValueError(f"类别标签必须为非负整数, 实际最小值: {classes[0]}")
        size = max(int(classes[-1]) + 1, int(num_classes or 0))
        weights = np.ones(size, dtype=np.float64)
        weights[classes] = y.size / (len(classes) * counts)
        return weights

    def setup_training_components(self, train_labels: np.ndarray):
        """Create the weighted loss, the optimizer and the (optional) LR scheduler.

        Config keys read: ``num_classes`` (optional; sizes the class-weight
        vector), ``label_smoothing``, ``optimizer`` (adam/adamw/sgd),
        ``learning_rate``, ``weight_decay``, ``scheduler``
        (cosine/step/plateau/none), ``epochs`` and, for step/plateau,
        ``scheduler_step_size``, ``scheduler_gamma``, ``scheduler_factor``,
        ``scheduler_patience``.

        Raises:
            ValueError: for an unsupported optimizer name or invalid labels.
        """
        # NOTE(review): if the model outputs more classes than appear in the
        # training labels, set ``num_classes`` so the weight vector matches.
        weights = self._balanced_class_weights(train_labels, self.params.get('num_classes'))
        class_weights = torch.as_tensor(weights, dtype=torch.float32, device=self.device)
        self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=self.params.get('label_smoothing', 0.0))

        opt_name = str(self.params.get('optimizer', 'adamw')).lower()
        opt_map = {'adam': optim.Adam, 'adamw': optim.AdamW, 'sgd': optim.SGD}
        if opt_name not in opt_map:
            raise ValueError(f"不支持的优化器: {opt_name} (可选: {', '.join(opt_map)})")
        self.optimizer = opt_map[opt_name](self.model.parameters(), lr=self.params['learning_rate'], weight_decay=self.params.get('weight_decay', 1e-4))

        # A null value in YAML disables the scheduler; names are case-insensitive.
        scheduler_name = str(self.params.get('scheduler', 'cosine') or 'none').lower()
        if scheduler_name == 'cosine':
            self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.params['epochs'])
        elif scheduler_name == 'step':
            # Dedicated keys: the flat ``step_size`` key is the dataset window step.
            self.scheduler = StepLR(
                self.optimizer,
                step_size=self.params.get('scheduler_step_size', 30),
                gamma=self.params.get('scheduler_gamma', 0.1),
            )
        elif scheduler_name == 'plateau':
            # Monitors validation F1, hence mode='max'.
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                mode='max',
                factor=self.params.get('scheduler_factor', 0.1),
                patience=self.params.get('scheduler_patience', 10),
            )
        else:
            self.scheduler = None
        self.logger.info("训练组件设置完成。")

    def _step_scheduler(self, val_metric: float):
        """Advance the LR scheduler once per epoch (plateau schedulers get ``val_metric``)."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(val_metric)
        else:
            self.scheduler.step()

    def train(self, train_loader, dev_loader):
        """Train with per-epoch validation, checkpointing and early stopping on val F1.

        Each batch from the loaders is ``(data_dict, targets)`` where
        ``data_dict`` maps modality name -> tensor. When training ends, the
        weights of the best validation epoch are restored into
        ``self.model``, so a later test evaluation uses the best model rather
        than the last one.
        """
        self.logger.info("--- 开始训练 ---")
        epochs = self.params['epochs']
        patience = self.params.get('early_stopping_patience', 10)
        # ``or 0`` also covers an explicit ``gradient_clip_norm: null``.
        clip_norm = self.params.get('gradient_clip_norm') or 0
        num_batches = max(len(train_loader), 1)
        # -inf guarantees the first epoch always yields a restorable snapshot.
        best_val_f1 = float('-inf')
        best_state = None
        patience_counter = 0

        for epoch in range(epochs):
            self.model.train()
            train_loss, train_correct, train_total = 0.0, 0, 0

            for data_dict, targets in train_loader:
                data_dict = {k: v.to(self.device) for k, v in data_dict.items()}
                targets = targets.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(data_dict)
                loss = self.criterion(outputs, targets)
                loss.backward()

                if clip_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)

                self.optimizer.step()

                train_loss += loss.item()
                predicted = outputs.detach().argmax(dim=1)
                train_total += targets.size(0)
                train_correct += (predicted == targets).sum().item()

            val_loss, val_f1, val_acc = self.evaluate(dev_loader, is_test=False)
            # Stepped after validation so ReduceLROnPlateau can see val F1;
            # epoch-based schedulers are unaffected by this ordering.
            self._step_scheduler(val_f1)

            train_acc = train_correct / train_total if train_total else 0.0
            avg_train_loss = train_loss / num_batches

            self.logger.info(
                f"Epoch {epoch+1}/{epochs} | "
                f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}"
            )

            self.history['train_loss'].append(avg_train_loss)
            self.history['train_accuracy'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_accuracy'].append(val_acc)
            self.history['val_f1'].append(val_f1)

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                # In-memory CPU snapshot: lets us restore the best weights even
                # when save_checkpoints is disabled, and never picks up a stale
                # best_model.pth left over from an earlier run.
                best_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
                if self.params.get('save_checkpoints', True):
                    torch.save(self.model.state_dict(), self.output_dir / 'best_model.pth')
                    self.logger.info(f"新最佳模型已保存，验证F1分数: {best_val_f1:.4f}")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                self.logger.info(f"早停触发! 最佳验证F1: {best_val_f1:.4f}")
                break

        if best_state is not None:
            self.model.load_state_dict(best_state)
            self.logger.info(f"已恢复最佳模型权重 (验证F1: {best_val_f1:.4f})")

    def evaluate(self, data_loader, is_test=False):
        """Evaluate the model on ``data_loader``.

        Args:
            data_loader: yields ``(data_dict, targets)`` batches.
            is_test: when True, the collected targets/predictions are stored
                in ``self.test_targets`` / ``self.test_predictions``.

        Returns:
            ``(avg_loss, weighted_f1, accuracy)``.

        Raises:
            ValueError: if the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data_dict, targets in data_loader:
                data_dict = {k: v.to(self.device) for k, v in data_dict.items()}
                targets = targets.to(self.device)

                outputs = self.model(data_dict)
                loss = self.criterion(outputs, targets)

                total_loss += loss.item()
                num_batches += 1
                predicted = outputs.argmax(dim=1)

                all_preds.extend(predicted.cpu().tolist())
                all_targets.extend(targets.cpu().tolist())

        if not all_targets:
            raise ValueError("评估数据为空, 无法计算指标")

        if is_test:
            self.test_targets = all_targets
            self.test_predictions = all_preds

        avg_loss = total_loss / num_batches
        accuracy = accuracy_score(all_targets, all_preds)
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, f1, accuracy

    def plot_learning_curves(self):
        """Plot loss/accuracy curves from ``self.history``, save to ``learning_curves.png`` and show them."""
        fig = plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='Train Loss')
        plt.plot(self.history['val_loss'], label='Validation Loss')
        plt.title('Loss vs. Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(self.history['train_accuracy'], label='Train Accuracy')
        plt.plot(self.history['val_accuracy'], label='Validation Accuracy')
        plt.title('Accuracy vs. Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.tight_layout()
        save_path = self.output_dir / 'learning_curves.png'
        plt.savefig(save_path)
        self.logger.info(f"学习曲线已保存至 {save_path}")
        plt.show()
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)

    def plot_confusion_matrix(self, y_true, y_pred, labels):
        """Save a confusion-matrix heatmap to ``confusion_matrix.png``.

        ``labels`` are used as tick labels and should correspond to the sorted
        union of classes in ``y_true`` and ``y_pred`` (the matrix order).
        """
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        save_path = self.output_dir / 'confusion_matrix.png'
        plt.savefig(save_path)
        self.logger.info(f"混淆矩阵已保存至 {save_path}")
        plt.close()

    def validate_shl_config(self):
        """Check that the config is a complete SHL multimodal configuration.

        Requires the ``dataset``/``architecture``/``training`` sections,
        ``dataset.name == 'SHL'``, both ``imu`` and ``pressure`` modalities
        (list of ``{'name': ...}`` entries, list of names, or a dict keyed by
        modality), and the ``imu_expert``/``pressure_expert`` experts.

        Raises:
            ValueError: describing the first missing/invalid item.
        """
        required_sections = ['dataset', 'architecture', 'training']
        config = self.config or {}
        for section in required_sections:
            if section not in config:
                raise ValueError(f"配置文件缺少必需的 '{section}' 部分")

        dataset_config = self._section('dataset')
        if dataset_config.get('name') != 'SHL':
            raise ValueError("配置文件不是为SHL数据集设计的")

        modalities = dataset_config.get('modalities') or []
        if isinstance(modalities, dict):
            modalities_present = set(modalities)
        else:
            modalities_present = {m.get('name') if isinstance(m, dict) else m for m in modalities}
        if 'imu' not in modalities_present or 'pressure' not in modalities_present:
            raise ValueError("SHL配置必须在 dataset.modalities 中同时包含 'imu' 和 'pressure' 模态")

        experts = self._section('architecture').get('experts') or {}
        expected_experts = ['imu_expert', 'pressure_expert']
        for expert in expected_experts:
            if expert not in experts:
                raise ValueError(f"配置文件缺少必需的专家模型: {expert}")

        self.logger.info("✓ SHL配置文件验证通过")

    def _export_results(self, test_loss: float, test_f1: float, test_acc: float):
        """Write history and test metrics to ``results.json``, then draw plots.

        Plotting is best-effort: a plotting failure (e.g. no display) is
        logged as a warning so the already-computed results are not lost.
        """
        results = {
            'history': self.history,
            'test': {'loss': test_loss, 'accuracy': test_acc, 'f1': test_f1},
        }
        results_path = self.output_dir / 'results.json'
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        self.logger.info(f"训练历史与测试指标已保存至 {results_path}")

        try:
            self.plot_learning_curves()
            targets = getattr(self, 'test_targets', [])
            preds = getattr(self, 'test_predictions', [])
            if targets:
                # Same ordering confusion_matrix uses: sorted union of classes.
                labels = sorted(set(targets) | set(preds))
                self.plot_confusion_matrix(targets, preds, labels)
        except Exception as e:
            self.logger.warning(f"结果可视化失败: {e}")

    def run(self):
        """Run the complete SHL multimodal experiment end to end.

        Any exception is logged (when logging is already set up) and re-raised
        with its original traceback.
        """
        try:
            self.setup_environment()

            # validate_shl_config() only runs when the config sets
            # ``use_config: true``; otherwise validation is skipped.
            if self.params.get('use_config', False):
                self.validate_shl_config()

            self.logger.info("开始SHL多模态训练实验")
            self.logger.info(f"使用配置: {self.config_path}")

            train_loader, dev_loader, test_loader, train_labels = self.load_data()
            self.build_model()
            self.setup_training_components(train_labels)
            # train() restores the best-validation weights before returning,
            # so the test below evaluates the best model.
            self.train(train_loader, dev_loader)

            self.logger.info("进行最终测试...")
            test_loss, test_f1, test_acc = self.evaluate(test_loader, is_test=True)
            self.logger.info(f"最终测试结果 - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}")

            self._export_results(test_loss, test_f1, test_acc)

            self.logger.info("SHL多模态训练实验完成!")

        except Exception as e:
            # setup_environment() may have failed before the logger existed;
            # never let logging mask the original error.
            if self.logger is not None:
                self.logger.error(f"训练过程中发生错误: {str(e)}", exc_info=True)
            raise

## 步骤 4: 执行训练流程

最后，我们实例化 `ConfigurableTrainer` 类并调用其 `run` 方法来启动整个训练和评估流程。所有操作都将由之前加载的配置驱动。

In [None]:
if __name__ == '__main__':
    try:
        # Use the path chosen in the step-2 "experiment configuration" cell,
        # as the notebook instructions promise. Fall back to the bundled SHL
        # config when that cell has not been executed (CONFIG_PATH undefined).
        config_file = globals().get('CONFIG_PATH', 'config/default_configs/shl_config.yaml')

        if not os.path.exists(config_file):
            print(f"❌ SHL配置文件 {config_file} 不存在!")
            print("请先创建SHL数据集的配置文件")
        else:
            print(f"✓ 使用SHL配置文件: {config_file}")
            trainer = ConfigurableTrainer(config_path=config_file)
            trainer.run()

    except KeyboardInterrupt:
        print("\n用户中断训练")
    except Exception as e:
        # Top-level notebook boundary: report the failure with its traceback
        # instead of letting the cell crash.
        print(f"\nSHL训练流程发生错误: {str(e)}")
        import traceback
        traceback.print_exc()