# 引入必要的库

In [None]:
import bisect
import random
import h5py
import numpy as np
import os
import sys
import collections
from itertools import cycle
from typing import Tuple, Union, List, Dict
import random as sys_random
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontManager
# matplotlib.use('Qt5Agg')
# Enumerate the fonts matplotlib can see and keep their names for reference.
fm = FontManager()
mat_fonts = {font.name for font in fm.ttflist}
# Presumably chosen so the Chinese plot labels render; confirm the font is
# installed on the machine running this notebook.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']

import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import Dataset, DataLoader


from pywayne.tools import wayne_print
from pywayne.dsp import butter_bandpass_filter

import neptune
from neptune.integrations.python_logger import NeptuneHandler
from neptune.types import File

import logging
# Named logger for this notebook; a Neptune handler is attached to it later on.
logger = logging.getLogger('micro-hand-gesture-logger')

# Device string used for the model and for every batch moved with `.to(device)`.
device = 'cpu'

# Report the torch build and whether the CUDA / MPS backends are available.
torch_version_report = f'{torch.__version__=}'
wayne_print(torch_version_report, 'yellow')
cuda_ready = torch.cuda.is_available()
wayne_print(f'{cuda_ready}', 'green')
mps_ready = torch.backends.mps.is_available()
wayne_print(f'{mps_ready}', 'green')

# 设置随机种子

In [2]:
RANDOM_SEED = 1


def set_torch_seed(seed=RANDOM_SEED):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU and all CUDA devices). Defaults to ``RANDOM_SEED``.

    Side effects:
        Sets ``PYTHONHASHSEED`` in ``os.environ`` and switches cuDNN to its
        deterministic, non-benchmarking mode.
    """
    # NOTE: the running interpreter fixed its hash seed at start-up, so this
    # only influences child processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    sys_random.seed(seed)  # Python random module.
    np.random.seed(seed)  # Numpy module.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 数据集描述

In [None]:
# Print a one-line summary (segment count plus metadata attributes) for every
# scene group in the source HDF5 file. The file is opened read-only.
with h5py.File('aw10_data.h5', 'r') as h5:
    # Scene groups are visited in sorted-name order.
    for scene in sorted(h5.keys()):
        attr = h5[scene].attrs
        # The `=` specifiers echo each expression's source text, so renaming
        # `scene`, `h5` or `attr` would change what gets printed.
        wayne_print(f"{scene=}, {len(h5[scene])=}, {attr['date']=}, {attr['force_level']=}, {attr['handness']=}, {attr['note']=}, {attr['scene_kw']=}, {attr['scene_property']=}\n")

# 数据集预处理

In [None]:
# Add `acc_filtered` / `gyro_filtered` datasets next to the raw signals of every
# segment in the source file. Re-running the cell is safe: existing filtered
# datasets are deleted and rewritten.
plt.close('all')
with h5py.File('aw10_data.h5', 'r+') as h5:
    for scene_id, scene in enumerate(sorted(h5.keys())):
        for idx in h5[scene].keys():
            wayne_print(f'{scene=}, {idx=}')
            kw = f'{scene}/{idx}'
            segment = h5[kw]
            # Materialise both signals as in-memory arrays before touching the file.
            acc = np.array(segment['acc_rawdata'])
            gyro = np.array(segment['gyro_rawdata'])
            # Dividing by 9.81 presumably converts m/s^2 to g before band-passing.
            acc_filtered = butter_bandpass_filter(
                acc / 9.81, order=2, lo=0.1, hi=40, fs=100.0, btype='bandpass', realtime=False
            )
            # The gyroscope signal is passed through unfiltered. Store a real copy:
            # assigning the h5py Dataset object itself would create a hard link, so
            # `gyro_filtered` and `gyro_rawdata` would be the same on-disk dataset.
            gyro_filtered = gyro.copy()

            # Visual sanity check on the very first segment only.
            if scene_id == 0 and idx == '0':
                fig, ax = plt.subplots(2, 2, sharex='all')
                ax[0][0].plot(acc)
                ax[1][0].plot(acc_filtered)
                ax[0][1].plot(gyro)
                ax[1][1].plot(gyro_filtered)
                plt.show()

            for dataset_name, values in (
                ('acc_filtered', acc_filtered),
                ('gyro_filtered', gyro_filtered),
            ):
                # Drop any result left by a previous run before writing the new one.
                if dataset_name in segment:
                    del segment[dataset_name]
                segment[dataset_name] = values
            print(acc.shape, gyro.shape, acc_filtered.shape, gyro_filtered.shape)

# 数据集划分

## 方法一、Intra-individual Stability
最理想情况，单条数据分成60% / 20% / 20%的训练集、验证集、测试集

In [None]:
# Split every scene's segments 60% / 20% / 20% into train.h5 / valid.h5 / test.h5.
dataset_labels = ('train', 'valid', 'test')
signal_names = ('acc_rawdata', 'gyro_rawdata', 'acc_filtered', 'gyro_filtered')

# Start from empty output files so stale segments never survive a re-run.
for dataset_label in dataset_labels:
    file_name = f'{dataset_label}.h5'
    if os.path.exists(file_name):
        os.remove(file_name)

# A dedicated, seeded generator makes the split reproducible without touching
# the global `random` state.
split_rng = random.Random(RANDOM_SEED)

h5s = {}
try:
    for label in dataset_labels:
        h5s[label] = h5py.File(f'{label}.h5', 'a')
    with h5py.File('aw10_data.h5', 'r') as h5:
        for scene in sorted(h5.keys()):
            n_segments = len(h5[scene])
            # Segments are expected to be keyed '0' .. str(n_segments - 1).
            all_idx = list(map(str, range(n_segments)))
            split_rng.shuffle(all_idx)
            training_size = int(0.6 * n_segments)
            validation_size = int(0.2 * n_segments)
            validation_end = training_size + validation_size
            # The test split takes whatever remains after train and valid.
            split_indices = (
                all_idx[:training_size],
                all_idx[training_size:validation_end],
                all_idx[validation_end:],
            )
            for dataset_label, data_indices in zip(dataset_labels, split_indices):
                # Create the scene group even when this split receives no
                # segments, so the attribute copy below cannot hit a missing group.
                out_scene = h5s[dataset_label].require_group(scene)
                # Segments are renumbered from 0 inside each split.
                for i, idx in enumerate(data_indices):
                    out_segment = out_scene.create_group(str(i))
                    for signal_name in signal_names:
                        out_segment[signal_name] = np.array(h5[scene][idx][signal_name])

                for attr_k, attr_v in h5[scene].attrs.items():
                    out_scene.attrs[attr_k] = attr_v
finally:
    # Close every output file that was opened, even if the split failed midway.
    for handle in h5s.values():
        handle.close()

## 方法二、Cross-subject Transferability
1. 80%的人（个体）用于训练，20%的人（个体）用于测试，注意80%的人的数据量为全部数据的60%
2. 在线学习：先让用户打开app，然后对每个场景采集1～3次数据。

# 构建数据集 

In [None]:
# Gesture class names; a label's position in this list is its class index.
y_idx2lbl = ['单击', '双击', '握拳', '左滑', '右滑', '鼓掌', '抖腕', '拍打', '日常']
# Reverse lookup: label name -> class index.
y_lbl2idx = dict(zip(y_idx2lbl, range(len(y_idx2lbl))))
# Label name -> one-hot row taken from an identity matrix.
y_lbl2onehot = {}
for _position, _label in enumerate(y_idx2lbl):
    y_lbl2onehot[_label] = np.eye(len(y_idx2lbl))[_position]
for _mapping in (y_lbl2idx, y_lbl2onehot):
    wayne_print(_mapping, 'green')

class MHGDataSet(Dataset):
    """Map-style dataset over an HDF5 file laid out as ``<scene>/<segment>/<signal>``.

    Each scene group carries a ``scene_kw`` attribute naming its gesture label
    and holds segment groups keyed ``'0'``, ``'1'``, ... that contain
    ``acc_filtered`` and ``gyro_filtered`` datasets. A flat sample index is
    resolved to a (scene, segment) pair through a prefix sum of per-scene
    segment counts, with scenes taken in sorted-name order.

    The file is reopened on every access instead of being kept open.
    """

    def __init__(self, h5_path):
        """Index the scenes stored in ``h5_path`` (prints the prefix sums)."""
        self.h5_path = h5_path
        # prefix_sum[k] is the number of samples in scenes[0:k].
        self.prefix_sum = [0]
        self.scenes = []
        # Not read anywhere in this class; kept for compatibility.
        self._all_data = None
        self.init()

    def init(self):
        """Scan the file once and record scene names and cumulative counts."""
        with h5py.File(self.h5_path, 'r') as h5:
            for scene in sorted(h5.keys()):
                scene_len = len(h5[scene])
                self.prefix_sum.append(self.prefix_sum[-1] + scene_len)
                self.scenes.append(scene)
        print(self.prefix_sum)

    def __len__(self):
        """Return the total number of segments across all scenes."""
        return self.prefix_sum[-1]

    def _locate(self, idx):
        """Resolve a flat index to ``(scene_name, segment_key)``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: if ``idx`` lies outside the dataset.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f'sample index out of range for dataset of size {total}')
        which_scene_idx, which_scene = self.get_scene_by_idx(idx)
        return which_scene, str(idx - self.prefix_sum[which_scene_idx])

    def __getitem__(self, idx):
        """Return ``(x, y)`` for one segment as float32 tensors.

        ``x`` is the filtered accelerometer and gyroscope columns placed side
        by side and transposed so signal columns come first; ``y`` is the
        one-hot vector of the scene's ``scene_kw`` label.
        """
        which_scene, offset = self._locate(idx)

        with h5py.File(self.h5_path, 'r') as h5:
            data = h5[which_scene][offset]
            scene_kw = h5[which_scene].attrs['scene_kw']
            # Only the filtered signals feed the model, so the raw datasets are
            # deliberately not read here.
            acc_filtered = np.array(data['acc_filtered'], dtype=float)
            gyro_filtered = np.array(data['gyro_filtered'], dtype=float)

        x = torch.from_numpy(np.c_[acc_filtered, gyro_filtered]).permute(1, 0)
        y = torch.from_numpy(y_lbl2onehot[scene_kw])
        return x.float(), y.float()

    def get_all_data(self):
        """Load every segment into memory and return ``(x_all, y_all)``.

        ``x_all`` stacks the per-segment arrays and swaps the last two axes,
        matching the layout produced by ``__getitem__``; ``y_all`` stacks the
        one-hot labels. Both are float32. The shapes are printed.

        NOTE: segments are visited in the file's key order, which for string
        keys need not equal the numeric order used by ``__getitem__``; rows of
        ``x_all`` and ``y_all`` still correspond to each other.
        """
        dtype = torch.float32
        x_all, y_all = [], []
        with h5py.File(self.h5_path, 'r') as h5:
            for scene in sorted(h5.keys()):
                onehot = y_lbl2onehot[h5[scene].attrs['scene_kw']]
                for idx in h5[scene]:
                    kw = f'{scene}/{idx}'
                    x_all.append(torch.tensor(np.c_[
                        h5[kw]['acc_filtered'][()],
                        h5[kw]['gyro_filtered'][()]
                    ], dtype=dtype))
                    y_all.append(torch.tensor(onehot, dtype=dtype))

        x_all_tensor = torch.stack(x_all).permute(0, 2, 1)
        y_all_tensor = torch.stack(y_all)
        print(x_all_tensor.shape, y_all_tensor.shape)
        return x_all_tensor, y_all_tensor

    def get_scene_by_idx(self, idx):
        """Return ``(scene_position, scene_name)`` owning flat index ``idx``.

        ``bisect_right`` steps past scenes with zero segments, whose prefix
        sums repeat.
        """
        which_scene_idx = bisect.bisect_right(self.prefix_sum, idx) - 1
        which_scene = self.scenes[which_scene_idx]
        return which_scene_idx, which_scene

    def get_attr_by_idx(self, idx):
        """Return the attributes of the scene owning ``idx`` as a plain dict."""
        which_scene, _ = self._locate(idx)

        with h5py.File(self.h5_path, 'r') as h5:
            return dict(h5[which_scene].attrs)

    def visualize_by_idx(self, idx):
        """Plot the first three and the remaining signal rows of sample ``idx``."""
        plt.close('all')
        x, y = self[idx]
        attr = self.get_attr_by_idx(idx)
        fig, ax = plt.subplots(2, 1, sharex='all')
        ax[0].plot(x.T[:, :3])
        ax[0].legend(('x', 'y', 'z'))
        ax[1].plot(x.T[:, 3:])
        ax[1].legend(('x', 'y', 'z'))
        for axis in ax:
            axis.grid(True)
        # Attribute values are not guaranteed to be str, so convert before joining.
        plt.suptitle('_'.join(map(str, attr.values())))
        plt.show()

# Smoke-test the training split on one arbitrary sample: size, tensors,
# owning scene, scene attributes and a plot.
dataset = MHGDataSet('train.h5')
print(len(dataset))
sample_idx = 400
x, y = dataset[sample_idx]
print(x.shape, y)
print(dataset.get_scene_by_idx(sample_idx))
print(dataset.get_attr_by_idx(sample_idx))
dataset.visualize_by_idx(sample_idx)

# One dataset object per split, plus one over the unsplit source file.
train_dataset, valid_dataset, test_dataset, total_dataset = [
    MHGDataSet(h5_file)
    for h5_file in ('train.h5', 'valid.h5', 'test.h5', 'aw10_data.h5')
]

# Load the whole source file into memory (rebinds the x / y sample above).
x, y = total_dataset.get_all_data()

# 定义网络结构

## baseline： vanilla CNN

In [None]:
# Optimisation hyperparameters for the baseline run.
train_batch_size = valid_batch_size = 32
epoch_num = 1000
learning_rate = 0.001

# Snapshot of the run configuration, stored alongside the experiment.
dl_params = dict(
    learning_rate=learning_rate,
    optimizer='Adam',
    train_batch_size=train_batch_size,
    valid_batch_size=valid_batch_size,
    epoch_num=epoch_num,
    metrics='CrossEntropy',
)

class VanillaCNN(nn.Module):
    """Three-layer 1-D CNN baseline that maps 6-channel windows to class logits.

    Layout: Conv(6->12) + BN + ReLU + MaxPool(2), Conv(12->12) + BN + ReLU +
    MaxPool(4), Conv(12->6) + BN + ReLU, flatten, Linear.

    Args:
        num_classes: Number of output logits.
        input_length: Number of time steps per window. The two pooling stages
            shrink it to ``input_length // 2 // 4``, which fixes the width of
            the final linear layer. The default of 60 reproduces the original
            ``6 * 7`` input features.

    Raises:
        ValueError: if ``input_length`` is too short to survive both pools.
    """

    def __init__(self, num_classes, input_length=60):
        super().__init__()
        pooled_length = input_length // 2 // 4
        if pooled_length < 1:
            raise ValueError(
                f'input_length={input_length} is too short; at least 8 samples are required'
            )
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(6, 12, kernel_size=3, padding=1, stride=1)
        self.bn1 = nn.BatchNorm1d(12)
        self.maxpool1 = nn.MaxPool1d(2, stride=2)
        self.conv2 = nn.Conv1d(12, 12, kernel_size=3, padding=1, stride=1)
        self.bn2 = nn.BatchNorm1d(12)
        self.maxpool2 = nn.MaxPool1d(4, stride=4)
        self.conv3 = nn.Conv1d(12, 6, kernel_size=3, padding=1, stride=1)
        self.bn3 = nn.BatchNorm1d(6)
        self.relu = nn.ReLU()
        # 6 output channels times the pooled sequence length.
        self.fc = nn.Linear(6 * pooled_length, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``.

        No softmax is applied here.
        """
        xx = self.maxpool1(self.relu(self.bn1(self.conv1(x))))
        xx = self.maxpool2(self.relu(self.bn2(self.conv2(xx))))
        xx = self.relu(self.bn3(self.conv3(xx)))
        xx = torch.flatten(xx, start_dim=1)
        return self.fc(xx)
    
def count_parameters(model):
    """Print and return the total number of elements in the model's parameters."""
    ret = 0
    for tensor in model.parameters():
        ret += tensor.numel()
    # The `=` specifier prints the variable name too, so it must stay `ret`.
    wayne_print(f'{ret=}', 'yellow')
    return ret

def weight_init(net=None):
    """Re-initialise the weights of a network in place, by layer type.

    * GRU / LSTM / RNN: Xavier-uniform input weights, orthogonal recurrent
      weights, zero biases.
    * Linear: Xavier-normal weight, zero bias.
    * Conv1d / Conv2d: Kaiming-normal weight (fan-out, ReLU); bias untouched.
    * BatchNorm1d / BatchNorm2d: weight 1, bias 0.

    Args:
        net: Module to initialise. When omitted, the module-level ``model`` is
            used, which keeps the original zero-argument call working.
    """
    target = model if net is None else net
    # Visit each module exactly once and dispatch on its type.
    for m in target.modules():
        if isinstance(m, (nn.GRU, nn.LSTM, nn.RNN)):
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    torch.nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    torch.nn.init.orthogonal_(param.data)
                elif 'bias' in name:
                    param.data.fill_(0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # Layers built with bias=False have no bias tensor.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # affine=False batch norms carry neither weight nor bias.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
                    
# Build the baseline classifier with one output per gesture label, move it to
# the configured device, then report its structure and parameter count.
num_gesture_classes = len(y_idx2lbl)
model = VanillaCNN(num_gesture_classes)
model = model.to(device)
print(num_gesture_classes)
print(model)
count_parameters(model)

# 测试

In [None]:
# Run the model once on a recorded window stored as comma-separated text.
data = np.loadtxt('data.txt', delimiter=',')
# Band-pass the first three columns exactly as in preprocessing; the remaining
# columns are used as recorded.
data[:,:3] = butter_bandpass_filter(
    data[:,:3] / 9.81,  # convert to g
    order=2,
    lo=0.1,
    hi=40,
    fs=100.0,
    btype='bandpass',
    realtime=False
)
# print(data.shape)
# plt.close('all')
# plt.plot(data[:,:3])
# plt.show()


# Transpose to channels-first and add a batch axis (expected [1, 6, 60] per the
# original note; confirm against data.txt).
x = torch.from_numpy(data.T).float().unsqueeze(0)
print(x.shape)
# Follow the configured device rather than a hard-coded 'cpu'.
x = x.to(device)

# Inference must run in eval mode: in training mode BatchNorm normalises with
# this single sample's statistics and folds them into its running averages.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model(x)
finally:
    # Leave the model in whatever mode it was in before this cell.
    model.train(was_training)
#     probabilities = torch.nn.functional.softmax(output, dim=1)
#     predicted_class = torch.argmax(output, dim=1).item()
#     confidence = probabilities[0][predicted_class].item()

# # 获取预测结果
# class_names = ['单击', '双击', '握拳', '左滑', '右滑', '鼓掌', '抖腕', '拍打', '日常']
# predicted_label = class_names[predicted_class]

# print(f"Predicted gesture: {predicted_label} (confidence: {confidence:.3f})")


# 定义DataLoader

In [8]:
# Mini-batch iterators over the train and validation splits; both shuffle.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=train_batch_size,
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    shuffle=True,
    batch_size=valid_batch_size,
)

# 配置Neptune

In [None]:
use_neptune = True

# Seed the generators whether or not experiment tracking is enabled.
set_torch_seed()

if use_neptune:
    # SECURITY: the API token is read from the NEPTUNE_API_TOKEN environment
    # variable instead of being committed with the notebook. The token that was
    # previously hard-coded here should be revoked and replaced.
    run = neptune.init_run(
        project="wangyendt/Micro-Hand-Gesture",
        api_token=os.environ.get("NEPTUNE_API_TOKEN"),
        source_files=["**/*.py", "**/*.ipynb", "config.yaml"],
        description='Micro Hand Gesture EDA',
        tags=['vanilla CNN', 'v0.0.1']
    )
    # Record the hyperparameters plus the seed and the model size.
    run["parameters"] = dl_params
    run['parameters/seed'] = RANDOM_SEED
    run['parameters/num_of_params'] = count_parameters(model)

    # Forward records from the notebook logger to the Neptune run.
    logger.addHandler(NeptuneHandler(run=run))

# 定义混淆矩阵函数

In [10]:
def calculate_confusion_matrix(y_pred, y_true, class_mapping, dataset_name):
    """Compute a confusion matrix and draw it as an annotated heatmap.

    Args:
        y_pred: Predicted class indices.
        y_true: Ground-truth class indices.
        class_mapping: Indexable mapping from class index to display label.
        dataset_name: Name shown in the figure title.

    Returns:
        ``(cm, classes, fig)``: the confusion matrix restricted to the classes
        that occur in either input, the sorted list of those classes, and the
        matplotlib figure. The figure is not shown; all previously open
        figures are closed first.
    """
    y_pred = np.array(y_pred)
    y_true = np.array(y_true)

    # Only classes present in the predictions or the ground truth get a row.
    classes = sorted(set(np.unique(y_pred)) | set(np.unique(y_true)))

    # Pass the labels explicitly so matrix rows and tick labels cannot drift apart.
    cm = confusion_matrix(y_true, y_pred, labels=classes)

    labels = [class_mapping[i] for i in classes]

    plt.close('all')

    fig, ax = plt.subplots(figsize=(10, 8))

    # Bottom x tick labels are suppressed; predictions are labelled on top instead.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[], yticklabels=labels, ax=ax)
    # Label the heatmap axes directly: after twiny() the pyplot "current axes"
    # is the twin, whose y-axis is invisible, so plt.ylabel() would be hidden.
    ax.set_ylabel('True')

    # Secondary x-axis on top carrying the predicted-class labels.
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ax2.set_xticks(np.arange(len(labels)) + 0.5)
    ax2.set_xticklabels(labels, ha='left')
    ax2.set_xlabel('Predicted')
    ax2.xaxis.set_label_position('top')
    ax2.set_title(f'Confusion Matrix of {dataset_name}')

    # Keep labels from being clipped at the figure border.
    fig.tight_layout()

    return cm, classes, fig

# 根据混淆矩阵计算各种metric函数

In [11]:
def calculate_advanced_metrics(confusion_matrix: np.ndarray, 
                             y_true: np.ndarray = None,
                             y_pred: np.ndarray = None,
                             y_prob: np.ndarray = None) -> dict:
    """Derive classification metrics from a confusion matrix.

    Rows of the matrix are true classes and columns are predicted classes.
    Every ratio with a zero denominator is reported as 0 rather than NaN, so
    empty or single-class matrices still yield finite numbers.

    Args:
        confusion_matrix: Square ``(n, n)`` array-like of counts.
        y_true: True labels, shape ``(N,)``. Only used for ROC data.
        y_pred: Predicted labels, shape ``(N,)``. Accepted but not used.
        y_prob: Predicted probabilities, shape ``(N, n_classes)``. ROC data is
            computed only when both ``y_true`` and ``y_prob`` are given.

    Returns:
        A dict with ``basic_metrics`` (precision, recall, f1_score with
        per-class lists and macro averages, plus accuracy),
        ``advanced_metrics`` (specificity, balanced accuracy, Cohen's kappa,
        multiclass Matthews correlation coefficient) and ``roc_data`` (``None``
        unless probabilities were supplied).
    """
    cm = np.asarray(confusion_matrix)
    n_classes = cm.shape[0]

    # Per-class counts, promoted to float so every division is a true division.
    tp = np.diag(cm).astype(float)
    predicted_totals = cm.sum(axis=0).astype(float)  # column sums
    actual_totals = cm.sum(axis=1).astype(float)  # row sums
    total = float(cm.sum())
    fp = predicted_totals - tp
    fn = actual_totals - tp
    tn = total - (tp + fp + fn)

    def _safe_ratio(numerator, denominator):
        # Element-wise division that yields 0 wherever the denominator is not positive.
        return np.divide(
            numerator,
            denominator,
            out=np.zeros_like(numerator, dtype=float),
            where=denominator > 0,
        )

    def _macro(values):
        # Mean over classes; 0 for an empty matrix instead of NaN.
        return float(np.mean(values)) if values.size else 0.0

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)
    f1_score = _safe_ratio(2 * precision * recall, precision + recall)

    # Cohen's kappa: agreement beyond what the marginals alone would produce.
    if total > 0:
        observed_accuracy = tp.sum() / total
        expected_accuracy = np.sum(predicted_totals * actual_totals) / (total * total)
    else:
        observed_accuracy = 0.0
        expected_accuracy = 0.0
    kappa_denominator = 1 - expected_accuracy
    # Undefined when chance agreement is already perfect; report 0 in that case.
    kappa = (observed_accuracy - expected_accuracy) / kappa_denominator if kappa_denominator != 0 else 0.0

    balanced_accuracy = _macro(recall)

    # Multiclass Matthews correlation coefficient.
    mcc_numerator = total * tp.sum() - np.sum(actual_totals * predicted_totals)
    mcc_denominator = np.sqrt(
        (total ** 2 - np.sum(predicted_totals * predicted_totals))
        * (total ** 2 - np.sum(actual_totals * actual_totals))
    )
    mcc = mcc_numerator / mcc_denominator if mcc_denominator != 0 else 0.0

    # ROC curves need the per-sample probabilities, not just the matrix.
    roc_data = None
    if y_true is not None and y_prob is not None:
        roc_data = calculate_multiclass_roc(y_true, y_prob, n_classes)

    return {
        'basic_metrics': {
            'precision': {
                'per_class': precision.tolist(),
                'macro_avg': _macro(precision)
            },
            'recall': {
                'per_class': recall.tolist(),
                'macro_avg': _macro(recall)
            },
            'f1_score': {
                'per_class': f1_score.tolist(),
                'macro_avg': _macro(f1_score)
            },
            'accuracy': float(observed_accuracy)
        },
        'advanced_metrics': {
            'specificity': {
                'per_class': specificity.tolist(),
                'macro_avg': _macro(specificity)
            },
            'balanced_accuracy': float(balanced_accuracy),
            'cohen_kappa': float(kappa),
            'matthews_correlation_coefficient': float(mcc)
        },
        'roc_data': roc_data
    }

def calculate_multiclass_roc(y_true: np.ndarray, y_prob: np.ndarray, n_classes: int) -> Dict:
    """Compute one-vs-rest ROC curves plus a micro-averaged curve.

    A curve is produced for every column of ``y_prob``, because integer labels
    in ``y_true`` index those columns directly. This also covers the
    two-class case and the case where ``n_classes`` (for example taken from a
    confusion matrix that lacks absent classes) is smaller than the number of
    probability columns.

    Args:
        y_true: Integer class labels, shape ``(N,)``.
        y_prob: Predicted probabilities, shape ``(N, n_columns)``.
        n_classes: Number of classes the caller expects; must not exceed the
            number of probability columns.

    Returns:
        A dict with ``fpr``, ``tpr`` and ``roc_auc``; each maps every class
        index and the key ``"micro"`` to its curve or area.

    Raises:
        ValueError: if the inputs have inconsistent shapes.
    """
    y_true = np.asarray(y_true).ravel()
    y_prob = np.asarray(y_prob)
    if y_prob.ndim != 2 or y_prob.shape[0] != y_true.shape[0]:
        raise ValueError('y_prob must have shape (N, n_classes) with one row per label in y_true')
    n_columns = y_prob.shape[1]
    if n_classes > n_columns:
        raise ValueError(f'n_classes={n_classes} exceeds the {n_columns} probability columns')

    # One indicator column per probability column. Built by hand because
    # label_binarize collapses a two-class problem into a single column.
    y_true_bin = (y_true[:, None] == np.arange(n_columns)[None, :]).astype(int)

    fpr = {}
    tpr = {}
    roc_auc = {}

    # One-vs-rest curve for every class.
    for i in range(n_columns):
        fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro average: pool every (sample, class) decision into one binary problem.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_prob.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    return {
        'fpr': fpr,
        'tpr': tpr,
        'roc_auc': roc_auc
    }

def plot_roc_curves(roc_data: Dict, n_classes: int, class_names: List[str] = None):
    """Draw per-class and micro-averaged ROC curves on a fresh figure.

    Args:
        roc_data: Dict with ``fpr``, ``tpr`` and ``roc_auc`` entries, each keyed
            by class index and by ``"micro"``.
        n_classes: Number of per-class curves to draw.
        class_names: Optional display names indexed by class; a generic label
            is used when omitted.

    Returns:
        The matplotlib figure. It is not shown, and all previously open
        figures are closed first.
    """
    plt.close('all')
    fig = plt.figure(figsize=(10, 8))

    false_pos = roc_data['fpr']
    true_pos = roc_data['tpr']
    areas = roc_data['roc_auc']

    # Colours repeat when there are more classes than palette entries.
    palette = ['aqua', 'darkorange', 'cornflowerblue', 'green', 'red', 'purple', 'brown', 'pink', 'gray']

    # One curve per class.
    for i in range(n_classes):
        if class_names is None:
            class_label = f'类别 {i}'
        else:
            class_label = class_names[i]
        plt.plot(
            false_pos[i],
            true_pos[i],
            color=palette[i % len(palette)],
            lw=2,
            label=f'ROC曲线 {class_label} (AUC = {areas[i]:0.2f})',
        )

    # Micro-averaged curve.
    plt.plot(
        false_pos['micro'],
        true_pos['micro'],
        label=f'微平均ROC曲线 (AUC = {areas["micro"]:0.2f})',
        color='deeppink',
        linestyle=':',
        linewidth=4,
    )

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率')
    plt.ylabel('真阳性率')
    plt.title('多分类ROC曲线 (One-vs-Rest)')
    plt.legend(loc="lower right")
    plt.grid(True)

    return fig

# 模型训练、模型验证、模型保存

In [None]:
import psutil
import os


def get_memory_usage():
    """Return the resident set size of the current process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024  # MB


current_memory_mb = get_memory_usage()
print(f"当前内存使用: {current_memory_mb:.2f} MB")

In [None]:
# Per-epoch mean training loss and per-epoch validation loss.
train_losses = []
valid_losses = []
use_wandb = False

# Best accuracy seen so far for each evaluation set; unseen sets start at 0.0.
best_model_accuracy = collections.defaultdict(float)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# The evaluation tensors do not change between epochs, so read each HDF5 file
# once here instead of on every epoch.
eval_tensors = {
    'train': train_dataset.get_all_data(),
    'valid': valid_dataset.get_all_data(),
    'total': total_dataset.get_all_data(),
}

try:
    for epoch in range(epoch_num):
        # ---- training pass ----
        running_train_loss = []
        model.train()
        # Scoped enable: works even if autograd was disabled globally earlier.
        with torch.enable_grad():
            for i, (x, y) in enumerate(train_loader):
                x = x.to(device)
                y = y.to(device)
                y_out = model(x)
                loss = criterion(y_out, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss = loss.item()
                log = f'epoch={epoch+1},i={i+1}/{len(train_loader)},training loss={train_loss}'
                sys.stdout.write("\r{0}".format(log))
                sys.stdout.flush()
                running_train_loss.append(train_loss)
        train_losses.append(np.mean(running_train_loss))
        print()
        print(f'epoch={epoch+1}, training loss={train_losses[-1]:.5f}')

        # ---- evaluation pass ----
        model.eval()
        # The full dataset is only evaluated every 50th epoch.
        if epoch % 50 == 0:
            vset_names = ('train', 'valid', 'total')
        else:
            vset_names = ('train', 'valid')

        # Scoped no_grad: autograd is not left disabled once the loop ends.
        with torch.no_grad():
            for vset_name in vset_names:
                vx, vy = eval_tensors[vset_name]
                vy_out = model(vx.to(device)).cpu()
                vloss = criterion(vy_out, vy)
                vloss_value = vloss.item()
                if vset_name == 'valid':
                    valid_losses.append(vloss_value)
                log = f'epoch={epoch+1}, vset loss={vloss_value}'
                sys.stdout.write("\r{0}".format(log))
                sys.stdout.flush()

                vy_prob = torch.nn.functional.softmax(vy_out, dim=1).numpy()
                vy_pred = np.argmax(vy_out.numpy(), axis=1)
                vy_gt = np.argmax(vy.numpy(), axis=1)
                cm, classes, cm_fig = calculate_confusion_matrix(vy_pred, vy_gt, y_idx2lbl, vset_name)

                print(cm.shape, vy_gt.shape, vy_pred.shape)
                metrics = calculate_advanced_metrics(
                    cm,
                    y_true=vy_gt,
                    y_pred=vy_pred,
                    y_prob=vy_prob
                )
                # ROC data may be absent; keep roc_fig defined either way so the
                # uploads below cannot hit an unbound or stale figure.
                roc_fig = None
                if metrics['roc_data']:
                    roc_fig = plot_roc_curves(metrics['roc_data'], cm.shape[1])

                # Checkpoint whenever this set's accuracy improves.
                accuracy = metrics['basic_metrics']['accuracy']
                if accuracy > best_model_accuracy[vset_name]:
                    best_model_accuracy[vset_name] = accuracy
                    best_model_path = f'{vset_name}_epoch={epoch}_accuracy={accuracy:.3f}.pt'
                    torch.save(
                        model.state_dict(),
                        best_model_path
                    )
                    if use_neptune:
                        run[f'model/saved_model/{vset_name}_epoch={epoch}_accuracy={accuracy:.3f}'].upload(best_model_path)
                        run[f'{vset_name}/best_accuracy'].log(accuracy)
                        run[f'{vset_name}/best_accuracy_figs/epoch_{epoch+1}_confusion_matrix'].upload(cm_fig)
                        if roc_fig is not None:
                            run[f'{vset_name}/best_accuracy_figs/epoch_{epoch+1}_roc'].upload(roc_fig)

                if use_neptune:
                    # Log the loss measured on this set, not the training loss.
                    run[f'{vset_name}/loss'].log(vloss_value)
                    run[f'{vset_name}/figs/epoch_{epoch+1}_confusion_matrix'].upload(cm_fig)
                    run[f'{vset_name}/accuracy'].log(metrics['basic_metrics']['accuracy'])
                    run[f'{vset_name}/precision'].log(metrics['basic_metrics']['precision']['macro_avg'])
                    run[f'{vset_name}/recall'].log(metrics['basic_metrics']['recall']['macro_avg'])
                    run[f'{vset_name}/f1_score'].log(metrics['basic_metrics']['f1_score']['macro_avg'])
                    run[f'{vset_name}/cohen_kappa'].log(metrics['advanced_metrics']['cohen_kappa'])
                    run[f'{vset_name}/matthews_correlation_coefficient'].log(metrics['advanced_metrics']['matthews_correlation_coefficient'])
                    run[f'{vset_name}/balanced_accuracy'].log(metrics['advanced_metrics']['balanced_accuracy'])
                    if roc_fig is not None:
                        run[f'{vset_name}/figs/epoch_{epoch+1}_roc'].upload(roc_fig)

        if use_neptune:
            run[f'epoch'] = epoch
finally:
    # Always close the Neptune run, even when training is interrupted or fails.
    if use_neptune:
        run.stop()