In [1]:
import csv
from tqdm import tqdm

label_dict = {}
with open("../ComParE2017_Cold_4students/lab/ComParE2017_Cold.tsv", "r", encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t")
    rows = list(reader)
    for row in tqdm(rows, desc="Loading labels"):
        label_dict[row["file_name"]] = row["Cold (upper respiratory tract infection)"]

Loading labels: 100%|██████████| 19101/19101 [00:00<00:00, 3819384.09it/s]


In [2]:
import os
def search_in_labels(filename, label_dict):
    base_name = os.path.splitext(filename)[0]
    
    if "_logmel" in base_name:
        base_name = base_name.replace("_logmel", "")
    if "_flipped" in base_name:
        base_name = base_name.replace("_flipped", "")
    
    parts = base_name.split("_")
    if len(parts) >= 2:
        audio_filename = f"{parts[0]}_{parts[1]}.wav"
    else:
        audio_filename = f"{base_name}.wav"
    
    return label_dict.get(audio_filename, None)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AttentionMIL(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x: (batch, num_windows, feature_dim)
        attn_weights = self.attention(x)  # (batch, num_windows, 1)
        attn_weights = torch.softmax(attn_weights, dim=1)  # normalize
        weighted_x = x * attn_weights  # (batch, num_windows, feature_dim)
        return weighted_x.sum(dim=1), attn_weights  # 同时返回注意力权重

class SpectrogramSequenceClassifier(nn.Module):
    def __init__(self, 
                 cnn_out_dim=512, 
                 sequence_model='transformer',  # 'transformer', 'lstm', 'attention'
                 hidden_dim=256,
                 num_heads=8,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()
        
        # CNN Encoder for each window
        self.cnn_encoder = nn.Sequential(
            # Block 1
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((4, 4)),  # 输出 (512, 4, 4)
            
            nn.Flatten(), 
            nn.Linear(2048, cnn_out_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout)
        )
        
        self.sequence_model_type = sequence_model
        self.hidden_dim = hidden_dim
        
        # 序列模型选择
        if sequence_model == 'transformer':
            self.pos_encoding = PositionalEncoding(cnn_out_dim, dropout)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=cnn_out_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                dropout=dropout,
                batch_first=True
            )
            self.sequence_model = nn.TransformerEncoder(encoder_layer, num_layers)
            self.global_pool = nn.AdaptiveAvgPool1d(1)  # 全局平均池化
            
        elif sequence_model == 'lstm':
            self.sequence_model = nn.LSTM(
                input_size=cnn_out_dim,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                dropout=dropout if num_layers > 1 else 0,
                batch_first=True,
                bidirectional=True
            )
            lstm_out_dim = hidden_dim * 2  # 双向LSTM
            
        elif sequence_model == 'attention':
            self.sequence_model = SelfAttention(cnn_out_dim, hidden_dim, num_heads, dropout)

        elif sequence_model == 'attention_mil':
            # Attention-based MIL聚合
            self.sequence_model = AttentionMIL(cnn_out_dim, hidden_dim)
            classifier_input_dim = cnn_out_dim

        elif sequence_model == 'hybrid':
            # 混合方法：先Transformer再MIL
            self.pos_encoding = PositionalEncoding(cnn_out_dim, dropout)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=cnn_out_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                dropout=dropout,
                batch_first=True
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
            self.mil_attention = AttentionMIL(cnn_out_dim, hidden_dim)
            classifier_input_dim = cnn_out_dim
            
        # 分类头
        if sequence_model == 'lstm':
            classifier_input_dim = lstm_out_dim
        else:
            classifier_input_dim = cnn_out_dim
            
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 32),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(32, 1)  # 二分类
        )
        
    def forward(self, x):
        batch_size, num_windows, C, H, W = x.shape
        
        # CNN特征提取
        x = x.view(-1, C, H, W)
        cnn_features = self.cnn_encoder(x)
        cnn_features = cnn_features.view(batch_size, num_windows, -1)
        
        # 初始化attention_weights为None
        attention_weights = None
        
        # 序列建模
        if self.sequence_model_type == 'transformer':
            x = self.pos_encoding(cnn_features)
            x = self.sequence_model(x)
            x = x.mean(dim=1)  # 全局平均池化
            
        elif self.sequence_model_type == 'lstm':
            x, _ = self.sequence_model(cnn_features)
            x = x[:, -1, :]
            
        elif self.sequence_model_type == 'attention_mil':
            # 使用Attention MIL聚合
            x, attention_weights = self.sequence_model(cnn_features)
            
        elif self.sequence_model_type == 'hybrid':
            # 先Transformer处理序列，再MIL聚合
            x = self.pos_encoding(cnn_features)
            x = self.transformer(x)
            x, attention_weights = self.mil_attention(x)
        
        # 分类
        output = self.classifier(x)
        
        # 根据训练状态决定返回值
        if self.sequence_model_type == 'attention_mil':
            # 如果是Attention MIL，返回注意力权重
            return output, attention_weights
        else:
            return output

# 位置编码（用于Transformer）
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(1), :].transpose(0, 1)
        return self.dropout(x)

# 自注意力模块
class SelfAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # 自注意力
        attn_output, attn_weights = self.multihead_attn(x, x, x)
        
        # 残差连接和层归一化
        x = self.norm(x + self.dropout(attn_output))
        
        # 全局平均池化
        x = x.mean(dim=1)  # (batch_size, input_dim)
        
        return x

In [4]:
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
import os

class AddGaussianNoise(torch.nn.Module):
    def __init__(self, mean=0., std=0.05):
        super().__init__()
        self.mean = mean
        self.std = std
        
    def forward(self, tensor):
        return tensor + torch.randn_like(tensor) * self.std + self.mean
    
class SequenceSpectrogramDataset(Dataset):
    def __init__(self, image_paths, label_dict, window_size=128, stride=64, 
                 max_windows=10, is_training=True):
        self.image_paths = image_paths
        self.label_dict = label_dict
        self.window_size = window_size
        self.stride = stride
        self.max_windows = max_windows
        self.is_training = is_training
        
        if is_training:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0, 0.02),

            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        filename = os.path.basename(image_path)
        label = search_in_labels(filename, self.label_dict)
        label_num = 1 if label == "C" else 0
        
        # 加载图像
        image = Image.open(image_path).convert("RGB")
        image_tensor = self.transform(image)
        
        _, H, W = image_tensor.shape
        assert H == 128, f"Image height must be 128, but got {H}"
        
        # 滑动窗口提取
        windows = []
        for start in range(0, W - self.window_size + 1, self.stride):
            window = image_tensor[:, :, start:start + self.window_size]
            windows.append(window)
        
        # 处理最后一个窗口
        if (W - self.window_size) % self.stride != 0:
            last_window = image_tensor[:, :, -self.window_size:]
            windows.append(last_window)
        
        # 确保至少有一个窗口
        if len(windows) == 0:
            pad_width = self.window_size - W
            image_padded = F.pad(image_tensor, (0, pad_width), mode='constant', value=0)
            window = image_padded[:, :, :self.window_size]
            windows.append(window)
        
        # 限制最大窗口数量
        if len(windows) > self.max_windows:
            if self.is_training:
                # 随机选择
                indices = torch.randperm(len(windows))[:self.max_windows]
                windows = [windows[i] for i in indices]
            else:
                # 均匀采样
                indices = torch.linspace(0, len(windows)-1, self.max_windows).long()
                windows = [windows[i] for i in indices]
        
        # 如果窗口不足，进行填充
        while len(windows) < self.max_windows:
            windows.append(windows[-1].clone())  # 复制最后一个窗口
        
        # 堆叠成序列
        windows_tensor = torch.stack(windows)  # (max_windows, 3, 128, 128)
        
        return windows_tensor, label_num

In [6]:
import os
import glob
from torch.utils.data import DataLoader, Dataset


data_split = ["train_files", "devel_files"]
img_dir = "../spectrograms_variable_width"  

def collect_image_paths_devel(split_name):
        sub_dir = os.path.join(img_dir, split_name)
        print(f"🔍 Looking for images in: {sub_dir}")
        
        if not os.path.exists(sub_dir):
            print(f"❌ Directory does not exist: {sub_dir}")
            return []
        
        png_files = glob.glob(os.path.join(sub_dir, "*.png"))
        
        filtered_files = [f for f in png_files if "flipped" not in os.path.basename(f)]
        
        print(f"📁 Found {len(png_files)} PNG files in {split_name}")
        print(f"📋 After filtering out 'flipped' files: {len(filtered_files)} files")
        
        return filtered_files

def collect_image_paths(split_name):
    sub_dir = os.path.join(img_dir, split_name)
    print(f"🔍 Looking for images in: {sub_dir}")
    
    if not os.path.exists(sub_dir):
        print(f"❌ Directory does not exist: {sub_dir}")
        return []
    
    png_files = glob.glob(os.path.join(sub_dir, "*.png"))
    print(f"📁 Found {len(png_files)} PNG files in {split_name}")
    
    return png_files

print("🚀 Collecting image paths...")
train_image_paths = collect_image_paths("train_files")
devel_image_paths = collect_image_paths_devel("devel_files")

train_dataset = SequenceSpectrogramDataset(
    image_paths=train_image_paths,
    label_dict=label_dict,
    window_size=128,
    stride=64,
    max_windows=5,
    is_training=True
)
val_dataset = SequenceSpectrogramDataset(
    image_paths=devel_image_paths,
    label_dict=label_dict,
    window_size=128,
    stride=64,
    max_windows=5, 
    is_training=False
)
train_loader = DataLoader(
    train_dataset,  
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,  
    batch_size=32,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)



🚀 Collecting image paths...
🔍 Looking for images in: ../spectrograms_variable_width\train_files
📁 Found 10475 PNG files in train_files
🔍 Looking for images in: ../spectrograms_variable_width\devel_files
📁 Found 10607 PNG files in devel_files
📋 After filtering out 'flipped' files: 9596 files


In [8]:
from sklearn.metrics import accuracy_score, f1_score, recall_score
from tqdm import tqdm

def evaluate_sequence_model(model, dataloader, device, criterion):
    model.eval()
    all_preds, all_labels = [], []
    total_loss = 0.0
    num_batches = 0
    
    with torch.no_grad():
        for batch_windows, batch_labels in tqdm(dataloader, desc="Evaluating", leave=False):
            batch_windows = batch_windows.to(device)
            batch_labels = batch_labels.to(device)
            
            logits = model(batch_windows).squeeze()
            if logits.dim() == 0:
                logits = logits.unsqueeze(0)
            
            # 计算loss
            loss = criterion(logits, batch_labels.float())
            total_loss += loss.item()
            num_batches += 1
                
            preds = (torch.sigmoid(logits) > 0.6).long()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())
    
    # 计算平均loss
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    
    # 计算指标
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    uar = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    
    return avg_loss, f1, uar, all_labels, all_preds

In [85]:
class MILLoss(nn.Module):
    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.0))
        
    def forward(self, logits, labels, attention_weights=None):
        # 主要的分类loss
        classification_loss = self.bce_loss(logits, labels.float())
        
        # 注意力正则化loss（鼓励注意力分布的多样性）
        if attention_weights is not None:
            # 计算注意力熵，鼓励不要过度集中在少数windows
            attention_entropy = -(attention_weights * torch.log(attention_weights + 1e-8)).sum(dim=1).mean()
            total_loss = classification_loss - self.alpha * attention_entropy
        else:
            total_loss = classification_loss
            
        return total_loss

In [None]:

# 创建模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 选择序列模型类型
model = SpectrogramSequenceClassifier(
    cnn_out_dim=512,
    sequence_model='transformer',  # 'transformer', 'lstm', 'attention'
    hidden_dim=128,
    num_heads=4,
    num_layers=1,
    dropout=0.3
).to(device)

# 损失函数和优化器
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(4).to(device))
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6, weight_decay=1e-7)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3)

# 训练
num_epochs = 50
best_uar = 0.0
patience = 5
early_stopping_counter = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    
    for batch_windows, batch_labels in progress_bar:
        batch_windows = batch_windows.to(device)  # (batch_size, max_windows, 3, 128, 128)
        batch_labels = batch_labels.to(device)    # (batch_size,)
        
        optimizer.zero_grad()
        
        # 前向传播
        logits = model(batch_windows).squeeze()   # (batch_size,)
        if logits.dim() == 0:
            logits = logits.unsqueeze(0)
            
        loss = criterion(logits, batch_labels.float())
        
        # 反向传播
        loss.backward()
        optimizer.step()
        
        # 计算预测
        preds = (torch.sigmoid(logits) > 0.6).long()
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())
        
        running_loss += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # 计算指标
    train_acc = accuracy_score(all_labels, all_preds)
    train_f1 = f1_score(all_labels, all_preds)
    
    # 验证
    avg_loss, f1, val_uar, val_labels, val_preds = evaluate_sequence_model(model, val_loader, device, criterion)
    
    print(f"Epoch {epoch+1}:")
    print(f"  Train - Loss: {running_loss/len(train_loader):.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
    print(f"  Val - Loss: {avg_loss:.4f}, F1: {f1:.4f}, UAR: {val_uar:.4f}")
    
    # 保存最佳模型
    if val_uar > best_uar:
        best_uar = val_uar
        torch.save(model.state_dict(), 'best_sequence_model_2.pth')
        print(f"  🌟 New best UAR: {best_uar:.4f}")
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        print(f"  Early stopping counter: {early_stopping_counter}/{patience}")
    if early_stopping_counter >= patience:
        print("Early stopping triggered. Stopping training.")
        break
    
    scheduler.step(val_uar)

print(f"Training complete! Best UAR: {best_uar:.4f}")

Epoch 1: 100%|██████████| 655/655 [00:57<00:00, 11.41it/s, loss=0.8698]
                                                             

Epoch 1:
  Train - Loss: 1.0702, Acc: 0.8148, F1: 0.0000
  Val - Loss: 0.8332, F1: 0.0000, UAR: 0.5000
  🌟 New best UAR: 0.5000


Epoch 2: 100%|██████████| 655/655 [00:57<00:00, 11.45it/s, loss=1.3632]
                                                             

Epoch 2:
  Train - Loss: 1.0519, Acc: 0.8149, F1: 0.0271
  Val - Loss: 0.8168, F1: 0.1683, UAR: 0.5431
  🌟 New best UAR: 0.5431


Epoch 3: 100%|██████████| 655/655 [00:56<00:00, 11.51it/s, loss=0.8785]
                                                             

Epoch 3:
  Train - Loss: 1.0326, Acc: 0.8116, F1: 0.1373
  Val - Loss: 0.7855, F1: 0.1699, UAR: 0.5436
  🌟 New best UAR: 0.5436


Epoch 4: 100%|██████████| 655/655 [00:57<00:00, 11.49it/s, loss=1.0145]
                                                             

Epoch 4:
  Train - Loss: 1.0009, Acc: 0.8167, F1: 0.2649
  Val - Loss: 0.7860, F1: 0.2359, UAR: 0.5729
  🌟 New best UAR: 0.5729


Epoch 5: 100%|██████████| 655/655 [00:56<00:00, 11.51it/s, loss=0.7162]
                                                             

Epoch 5:
  Train - Loss: 0.9466, Acc: 0.8311, F1: 0.3989
  Val - Loss: 0.7870, F1: 0.2486, UAR: 0.5814
  🌟 New best UAR: 0.5814


Epoch 6: 100%|██████████| 655/655 [00:56<00:00, 11.53it/s, loss=0.5947]
                                                             

Epoch 6:
  Train - Loss: 0.8739, Acc: 0.8516, F1: 0.5332
  Val - Loss: 0.7957, F1: 0.2568, UAR: 0.5889
  🌟 New best UAR: 0.5889


Epoch 7: 100%|██████████| 655/655 [00:56<00:00, 11.51it/s, loss=0.5074]
                                                             

Epoch 7:
  Train - Loss: 0.8242, Acc: 0.8564, F1: 0.5804
  Val - Loss: 0.8267, F1: 0.2933, UAR: 0.6303
  🌟 New best UAR: 0.6303


Epoch 8: 100%|██████████| 655/655 [00:57<00:00, 11.46it/s, loss=0.7319]
                                                             

Epoch 8:
  Train - Loss: 0.7883, Acc: 0.8624, F1: 0.6042
  Val - Loss: 0.8353, F1: 0.2935, UAR: 0.6316
  🌟 New best UAR: 0.6316


Epoch 9: 100%|██████████| 655/655 [00:57<00:00, 11.48it/s, loss=0.6039]
                                                             

Epoch 9:
  Train - Loss: 0.7606, Acc: 0.8618, F1: 0.6130
  Val - Loss: 0.9824, F1: 0.2739, UAR: 0.6384
  🌟 New best UAR: 0.6384


Epoch 10: 100%|██████████| 655/655 [00:59<00:00, 11.06it/s, loss=0.3813]
                                                             

Epoch 10:
  Train - Loss: 0.7406, Acc: 0.8631, F1: 0.6218
  Val - Loss: 0.9685, F1: 0.2814, UAR: 0.6435
  🌟 New best UAR: 0.6435


Epoch 11: 100%|██████████| 655/655 [00:57<00:00, 11.37it/s, loss=0.8488]
                                                             

Epoch 11:
  Train - Loss: 0.7177, Acc: 0.8666, F1: 0.6372
  Val - Loss: 0.8634, F1: 0.2827, UAR: 0.6182
  Early stopping counter: 1/5


Epoch 12: 100%|██████████| 655/655 [00:57<00:00, 11.44it/s, loss=0.3843]
                                                             

Epoch 12:
  Train - Loss: 0.7111, Acc: 0.8684, F1: 0.6396
  Val - Loss: 0.9733, F1: 0.2842, UAR: 0.6417
  Early stopping counter: 2/5


Epoch 13: 100%|██████████| 655/655 [00:57<00:00, 11.49it/s, loss=0.2157]
                                                             

Epoch 13:
  Train - Loss: 0.6890, Acc: 0.8736, F1: 0.6548
  Val - Loss: 0.9375, F1: 0.2710, UAR: 0.6178
  Early stopping counter: 3/5


Epoch 14: 100%|██████████| 655/655 [00:57<00:00, 11.48it/s, loss=0.2539]
                                                             

Epoch 14:
  Train - Loss: 0.6848, Acc: 0.8752, F1: 0.6608
  Val - Loss: 1.0400, F1: 0.2671, UAR: 0.6236
  Early stopping counter: 4/5


Epoch 15: 100%|██████████| 655/655 [00:57<00:00, 11.45it/s, loss=0.2869]
                                                             

Epoch 15:
  Train - Loss: 0.6586, Acc: 0.8830, F1: 0.6769
  Val - Loss: 1.0461, F1: 0.2807, UAR: 0.6416
  Early stopping counter: 5/5
Early stopping triggered. Stopping training.
Training complete! Best UAR: 0.6435




: 

In [87]:
# 创建模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SpectrogramSequenceClassifier(
    cnn_out_dim=256,
    sequence_model='attention_mil',  # 使用MIL
    hidden_dim=128,
    dropout=0.3
).to(device)

# 使用MIL loss
criterion = MILLoss(alpha=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-7)

In [None]:
import time
from sklearn.metrics import accuracy_score, f1_score, recall_score
num_epochs = 50
best_uar = 0.0
patience = 5
early_stopping_counter = 0
start_time = time.time()

# 记录训练历史
training_losses = []
validation_losses = []
training_uars = []
validation_uars = []

print("Starting training with Attention MIL...")
print(f"Model: {model.sequence_model_type}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

for epoch in range(num_epochs):
    # =================== 训练阶段 ===================
    model.train()
    running_loss = 0.0
    all_train_preds, all_train_labels = [], []
    
    print(f'\n{"="*80}')
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'{"="*80}')
    
    train_progress = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    
    for batch_idx, (batch_windows, batch_labels) in enumerate(train_progress):
        batch_windows = batch_windows.to(device)
        batch_labels = batch_labels.to(device)
        
        optimizer.zero_grad()
        
        # 修复：正确解包模型输出
        logits, _ = model(batch_windows)  # 忽略attention_weights
        logits = logits.squeeze()
        if logits.dim() == 0:
            logits = logits.unsqueeze(0)
            
        loss = criterion(logits, batch_labels.float())
        loss.backward()
        optimizer.step()
        
        # 计算训练预测
        with torch.no_grad():
            preds = (torch.sigmoid(logits) > 0.5).long()
            all_train_preds.extend(preds.cpu().numpy())
            all_train_labels.extend(batch_labels.cpu().numpy())
        
        running_loss += loss.item()
        train_progress.set_postfix({
            'loss': f'{loss.item():.4f}',
            'avg_loss': f'{running_loss/(batch_idx+1):.4f}'
        })
    
    # 计算训练指标
    epoch_train_loss = running_loss / len(train_loader)
    train_accuracy = accuracy_score(all_train_labels, all_train_preds)
    train_f1 = f1_score(all_train_labels, all_train_preds, zero_division=0)
    train_uar = recall_score(all_train_labels, all_train_preds, average='macro', zero_division=0)
    
    training_losses.append(epoch_train_loss)
    training_uars.append(train_uar)
    
    # =================== 验证阶段 ===================
    model.eval()
    val_running_loss = 0.0
    all_val_preds, all_val_labels = [], []
    all_attention_weights = []
    
    val_progress = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
    
    with torch.no_grad():
        for batch_idx, (batch_windows, batch_labels) in enumerate(val_progress):
            batch_windows = batch_windows.to(device)
            batch_labels = batch_labels.to(device)
            
            # 修复：正确解包模型输出
            logits, attention_weights = model(batch_windows)
            
            # 保存attention weights（如果存在）
            if attention_weights is not None:
                all_attention_weights.append(attention_weights.cpu())
                
            logits = logits.squeeze()
            if logits.dim() == 0:
                logits = logits.unsqueeze(0)
            
            loss = criterion(logits, batch_labels.float())
            val_running_loss += loss.item()
            
            # 计算预测
            preds = (torch.sigmoid(logits) > 0.5).long()
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(batch_labels.cpu().numpy())
            
            val_progress.set_postfix({
                'val_loss': f'{loss.item():.4f}',
                'avg_val_loss': f'{val_running_loss/(batch_idx+1):.4f}'
            })
    
    # 计算验证指标
    epoch_val_loss = val_running_loss / len(val_loader)
    val_accuracy = accuracy_score(all_val_labels, all_val_preds)
    val_f1 = f1_score(all_val_labels, all_val_preds, zero_division=0)
    val_uar = recall_score(all_val_labels, all_val_preds, average='macro', zero_division=0)
    
    validation_losses.append(epoch_val_loss)
    validation_uars.append(val_uar)
    
    # =================== 打印结果 ===================
    print(f"\nEpoch [{epoch+1}] Summary:")
    print(f"{'─'*60}")
    print(f"Training   - Loss: {epoch_train_loss:.4f}, Acc: {train_accuracy:.4f}, F1: {train_f1:.4f}, UAR: {train_uar:.4f}")
    print(f"Validation - Loss: {epoch_val_loss:.4f}, Acc: {val_accuracy:.4f}, F1: {val_f1:.4f}, UAR: {val_uar:.4f}")
    
    # 显示类别召回率
    if len(set(all_val_labels)) > 1 and len(set(all_val_preds)) > 1:
        class_recalls = recall_score(all_val_labels, all_val_preds, average=None, zero_division=0)
        print(f"Class Recalls - Healthy: {class_recalls[0]:.4f}, Cold: {class_recalls[1]:.4f}")
    
    # 显示注意力权重统计（如果有的话）
    if all_attention_weights:
        attention_concat = torch.cat(all_attention_weights, dim=0)  # (total_samples, num_windows, 1)
        attention_mean = attention_concat.mean(dim=0).squeeze()  # (num_windows,)
        attention_std = attention_concat.std(dim=0).squeeze()   # (num_windows,)
        
        print(f"Attention weights - Mean: {attention_mean.numpy()}")
        print(f"Attention weights - Std:  {attention_std.numpy()}")
        
        # 找出最重要的窗口
        most_important_window = torch.argmax(attention_mean).item()
        print(f"Most important window: {most_important_window} (weight: {attention_mean[most_important_window]:.4f})")
    
    # =================== 保存最佳模型 ===================
    if val_uar > best_uar:
        best_uar = val_uar
        early_stopping_counter = 0
        
        torch.save(model.state_dict(), 'best_attention_mil_model.pth')
        print(f"🌟 New best UAR: {best_uar:.4f}, model saved!")
        
    else:
        early_stopping_counter += 1
        print(f"⏳ No improvement for {early_stopping_counter}/{patience} epochs")
        
        if early_stopping_counter >= patience:
            print(f"❌ Early stopping triggered after {patience} epochs without improvement")
            break
    
    print(f"Best UAR so far: {best_uar:.4f}")

# =================== 训练完成 ===================
total_time = (time.time() - start_time) / 60
print(f"\n{'='*80}")
print(f"🎉 Training completed in {total_time:.2f} minutes!")
print(f"🏆 Best Validation UAR: {best_uar:.4f}")
print(f"📁 Best model saved as: 'best_attention_mil_model.pth'")
print(f"{'='*80}")

# =================== 保存训练历史 ===================
training_history = {
    'training_losses': training_losses,
    'validation_losses': validation_losses,
    'training_uars': training_uars,
    'validation_uars': validation_uars,
    'best_uar': best_uar,
    'total_epochs': epoch + 1,
    'early_stopped': early_stopping_counter >= patience,
    'sequence_model_type': model.sequence_model_type
}

torch.save(training_history, 'attention_mil_training_history.pth')
print(f"💾 Training history saved to 'attention_mil_training_history.pth'")

Starting training with Attention MIL...
Model: attention_mil
Parameters: 688,322

Epoch [1/50]


Training Epoch 1:   0%|          | 0/655 [00:00<?, ?it/s]

Training Epoch 1: 100%|██████████| 655/655 [06:01<00:00,  1.81it/s, loss=0.3734, avg_loss=0.7100]
Validation Epoch 1: 100%|██████████| 300/300 [01:18<00:00,  3.82it/s, val_loss=0.4355, avg_val_loss=0.5082]



Epoch [1] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.7100, Acc: 0.8167, F1: 0.0476, UAR: 0.5107
Validation - Loss: 0.5082, Acc: 0.8909, F1: 0.1089, UAR: 0.5258
Class Recalls - Healthy: 0.9884, Cold: 0.0633
Attention weights - Mean: [0.2754279  0.19137117 0.18316384 0.17952251 0.17051463]
Attention weights - Std:  [0.04251278 0.02831747 0.02631738 0.02507317 0.02371163]
Most important window: 0 (weight: 0.2754)
🌟 New best UAR: 0.5258, model saved!
Best UAR so far: 0.5258

Epoch [2/50]


Training Epoch 2: 100%|██████████| 655/655 [06:10<00:00,  1.77it/s, loss=0.7667, avg_loss=0.5829]
Validation Epoch 2: 100%|██████████| 300/300 [01:22<00:00,  3.65it/s, val_loss=0.4979, avg_val_loss=0.5048]



Epoch [2] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.5829, Acc: 0.8559, F1: 0.5247, UAR: 0.6911
Validation - Loss: 0.5048, Acc: 0.8877, F1: 0.0739, UAR: 0.5149
Class Recalls - Healthy: 0.9872, Cold: 0.0425
Attention weights - Mean: [0.25964236 0.19059387 0.18453002 0.18350157 0.18173222]
Attention weights - Std:  [0.05666114 0.03416167 0.03050242 0.02907554 0.02840365]
Most important window: 0 (weight: 0.2596)
⏳ No improvement for 1/5 epochs
Best UAR so far: 0.5258

Epoch [3/50]


Training Epoch 3: 100%|██████████| 655/655 [06:13<00:00,  1.75it/s, loss=0.2464, avg_loss=0.4892]
Validation Epoch 3: 100%|██████████| 300/300 [01:23<00:00,  3.61it/s, val_loss=0.5071, avg_val_loss=0.5028]



Epoch [3] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.4892, Acc: 0.8754, F1: 0.6394, UAR: 0.7676
Validation - Loss: 0.5028, Acc: 0.8926, F1: 0.1472, UAR: 0.5377
Class Recalls - Healthy: 0.9873, Cold: 0.0880
Attention weights - Mean: [0.23608373 0.19401203 0.1900369  0.189817   0.1900504 ]
Attention weights - Std:  [0.04206767 0.02924171 0.02587952 0.02447437 0.02393272]
Most important window: 0 (weight: 0.2361)
🌟 New best UAR: 0.5377, model saved!
Best UAR so far: 0.5377

Epoch [4/50]


Training Epoch 4: 100%|██████████| 655/655 [06:23<00:00,  1.71it/s, loss=0.2563, avg_loss=0.4530]
Validation Epoch 4: 100%|██████████| 300/300 [01:23<00:00,  3.58it/s, val_loss=0.5796, avg_val_loss=0.5187]



Epoch [4] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.4530, Acc: 0.8836, F1: 0.6635, UAR: 0.7816
Validation - Loss: 0.5187, Acc: 0.8559, F1: 0.2138, UAR: 0.5604
Class Recalls - Healthy: 0.9348, Cold: 0.1860
Attention weights - Mean: [0.2153361  0.19616434 0.19484645 0.19573659 0.19791654]
Attention weights - Std:  [0.02811105 0.02314117 0.01997692 0.01869156 0.01816249]
Most important window: 0 (weight: 0.2153)
🌟 New best UAR: 0.5604, model saved!
Best UAR so far: 0.5604

Epoch [5/50]


Training Epoch 5: 100%|██████████| 655/655 [06:23<00:00,  1.71it/s, loss=0.1432, avg_loss=0.4290]
Validation Epoch 5: 100%|██████████| 300/300 [01:22<00:00,  3.64it/s, val_loss=0.5376, avg_val_loss=0.5078]



Epoch [5] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.4290, Acc: 0.8891, F1: 0.6790, UAR: 0.7903
Validation - Loss: 0.5078, Acc: 0.8778, F1: 0.1894, UAR: 0.5503
Class Recalls - Healthy: 0.9652, Cold: 0.1355
Attention weights - Mean: [0.21706188 0.1956438  0.19458053 0.19534472 0.19736908]
Attention weights - Std:  [0.03263332 0.02534659 0.02171865 0.02037849 0.01994068]
Most important window: 0 (weight: 0.2171)
⏳ No improvement for 1/5 epochs
Best UAR so far: 0.5604

Epoch [6/50]


Training Epoch 6: 100%|██████████| 655/655 [06:19<00:00,  1.73it/s, loss=0.3550, avg_loss=0.4073]
Validation Epoch 6: 100%|██████████| 300/300 [01:21<00:00,  3.66it/s, val_loss=0.4949, avg_val_loss=0.4946]



Epoch [6] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.4073, Acc: 0.8970, F1: 0.7009, UAR: 0.8022
Validation - Loss: 0.4946, Acc: 0.8916, F1: 0.1950, UAR: 0.5533
Class Recalls - Healthy: 0.9819, Cold: 0.1246
Attention weights - Mean: [0.22216052 0.19499806 0.19321266 0.19393492 0.1956938 ]
Attention weights - Std:  [0.04474369 0.03279345 0.02850872 0.02708638 0.02699584]
Most important window: 0 (weight: 0.2222)
⏳ No improvement for 2/5 epochs
Best UAR so far: 0.5604

Epoch [7/50]


Training Epoch 7: 100%|██████████| 655/655 [06:20<00:00,  1.72it/s, loss=0.5948, avg_loss=0.3973]
Validation Epoch 7: 100%|██████████| 300/300 [01:22<00:00,  3.64it/s, val_loss=0.5361, avg_val_loss=0.4921]



Epoch [7] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.3973, Acc: 0.8972, F1: 0.7019, UAR: 0.8031
Validation - Loss: 0.4921, Acc: 0.8716, F1: 0.2761, UAR: 0.5897
Class Recalls - Healthy: 0.9469, Cold: 0.2324
Attention weights - Mean: [0.22090124 0.19436733 0.19320996 0.19451559 0.19700588]
Attention weights - Std:  [0.04239969 0.02986862 0.02616129 0.02487054 0.02461621]
Most important window: 0 (weight: 0.2209)
🌟 New best UAR: 0.5897, model saved!
Best UAR so far: 0.5897

Epoch [8/50]


Training Epoch 8: 100%|██████████| 655/655 [06:17<00:00,  1.73it/s, loss=0.0976, avg_loss=0.3842]
Validation Epoch 8: 100%|██████████| 300/300 [01:21<00:00,  3.66it/s, val_loss=0.4979, avg_val_loss=0.4894]



Epoch [8] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.3842, Acc: 0.9018, F1: 0.7147, UAR: 0.8101
Validation - Loss: 0.4894, Acc: 0.8831, F1: 0.2618, UAR: 0.5804
Class Recalls - Healthy: 0.9639, Cold: 0.1968
Attention weights - Mean: [0.2111472  0.19748296 0.19537196 0.19643593 0.19956195]
Attention weights - Std:  [0.04013909 0.03106999 0.02692309 0.02550613 0.02502733]
Most important window: 0 (weight: 0.2111)
⏳ No improvement for 1/5 epochs
Best UAR so far: 0.5897

Epoch [9/50]


Training Epoch 9: 100%|██████████| 655/655 [06:18<00:00,  1.73it/s, loss=0.6740, avg_loss=0.3707]
Validation Epoch 9: 100%|██████████| 300/300 [01:22<00:00,  3.63it/s, val_loss=0.4917, avg_val_loss=0.5272]



Epoch [9] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.3707, Acc: 0.9051, F1: 0.7268, UAR: 0.8187
Validation - Loss: 0.5272, Acc: 0.8929, F1: 0.2474, UAR: 0.5727
Class Recalls - Healthy: 0.9783, Cold: 0.1672
Attention weights - Mean: [0.2167499  0.19705632 0.19399977 0.19475359 0.19744042]
Attention weights - Std:  [0.04897913 0.03652412 0.03155878 0.03004584 0.02981487]
Most important window: 0 (weight: 0.2167)
⏳ No improvement for 2/5 epochs
Best UAR so far: 0.5897

Epoch [10/50]


Training Epoch 10: 100%|██████████| 655/655 [06:21<00:00,  1.72it/s, loss=0.0619, avg_loss=0.3644]
Validation Epoch 10: 100%|██████████| 300/300 [01:22<00:00,  3.65it/s, val_loss=0.4797, avg_val_loss=0.5297]



Epoch [10] Summary:
────────────────────────────────────────────────────────────
Training   - Loss: 0.3644, Acc: 0.9021, F1: 0.7232, UAR: 0.8203
Validation - Loss: 0.5297, Acc: 0.8985, F1: 0.1589, UAR: 0.5423
Class Recalls - Healthy: 0.9936, Cold: 0.0910
Attention weights - Mean: [0.2114687  0.19630946 0.19562002 0.1968387  0.19976309]
Attention weights - Std:  [0.0486172  0.03781758 0.03272303 0.03127244 0.03145121]
Most important window: 0 (weight: 0.2115)
⏳ No improvement for 3/5 epochs
Best UAR so far: 0.5897

Epoch [11/50]


Training Epoch 11:  27%|██▋       | 178/655 [01:40<04:30,  1.76it/s, loss=0.5250, avg_loss=0.3563]


KeyboardInterrupt: 