# Lora Fine-tuning on WAV2WEC2

## Load Data

In [1]:
import csv
from tqdm import tqdm

label_dict = {}
with open("../ComParE2017_Cold_4students/lab/ComParE2017_Cold.tsv", "r", encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t")
    rows = list(reader)
    for row in tqdm(rows, desc="Loading labels"):
        label_dict[row["file_name"]] = row["Cold (upper respiratory tract infection)"]

Loading labels: 100%|██████████| 19101/19101 [00:00<00:00, 3820841.32it/s]


In [2]:
def search_in_ground_truth(file_id: str, label_dict: dict) -> str:
    wav_name = file_id + ".wav"
    return label_dict.get(wav_name, None)

In [3]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [4]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveEmbeddingExtractor(nn.Module):
    def __init__(self, input_dim=3072, projection_dim=256, hidden_dim=512):
        super(ContrastiveEmbeddingExtractor, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        self.projection_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim),
            L2Norm(dim=1)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )
    
    def forward(self, x, return_features=False):
        features = self.encoder(x)
        
        logits = self.classifier(features)
        
        if return_features:
            projections = self.projection_head(features)
            return logits.squeeze(), projections, features
        else:
            return logits.squeeze()

class L2Norm(nn.Module):
    def __init__(self, dim=1):
        super(L2Norm, self).__init__()
        self.dim = dim
    
    def forward(self, x):
        return F.normalize(x, p=2, dim=self.dim)
    
class SupervisedContrastiveLoss(nn.Module):

    def __init__(self, temperature=0.1, minority_weight=2.0):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.minority_weight = minority_weight
        
    def forward(self, projections, labels):
        device = projections.device
        batch_size = projections.shape[0]
        
        similarity_matrix = torch.matmul(projections, projections.T) / self.temperature
        
        labels = labels.unsqueeze(1)
        mask = torch.eq(labels, labels.T).float().to(device)
        
        mask = mask - torch.eye(batch_size).to(device)
        
        exp_sim = torch.exp(similarity_matrix)
        
        pos_sim = exp_sim * mask
        
        neg_mask = torch.ones_like(mask) - torch.eye(batch_size).to(device)
        all_sim = exp_sim * neg_mask
        
        losses = []
        for i in range(batch_size):
            if mask[i].sum() > 0:  
                pos_sum = pos_sim[i].sum()
                neg_sum = all_sim[i].sum()
                
                if neg_sum > 0:
                    loss_i = -torch.log(pos_sum / neg_sum)
                    
                    if labels[i] == 1:  
                        loss_i = loss_i * self.minority_weight
                    
                    losses.append(loss_i)
        
        if len(losses) > 0:
            return torch.stack(losses).mean()
        else:
            return torch.tensor(0.0).to(device)

class CombinedLoss(nn.Module):
    def __init__(self, classification_loss, contrastive_loss, alpha=0.3):
        super(CombinedLoss, self).__init__()
        self.classification_loss = classification_loss
        self.contrastive_loss = contrastive_loss
        self.alpha = alpha  
        
    def forward(self, logits, projections, labels):
        cls_loss = self.classification_loss(logits, labels.float())
        
        cont_loss = self.contrastive_loss(projections, labels)
        
        total_loss = (1 - self.alpha) * cls_loss + self.alpha * cont_loss
        
        return total_loss, cls_loss, cont_loss

In [None]:
class IntegratedModel(nn.Module):
    def __init__(self, wav2vec2_model, processor, downstream_model):
        super(IntegratedModel, self).__init__()
        self.wav2vec2_model = wav2vec2_model
        self.processor = processor
        self.downstream_model = downstream_model
        
    def forward(self, waveforms, return_features=False):
        device = waveforms.device
        
        waveforms_np = waveforms.cpu().numpy()
        
        inputs = self.processor(
            waveforms_np, 
            sampling_rate=16000, 
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=160000
        )
        
        input_values = inputs['input_values'].to(device)
        attention_mask = inputs.get('attention_mask', None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
 
        out = self.wav2vec2_model(
                input_values, 
                attention_mask=attention_mask,
                output_hidden_states=True
            )
        hs = out.hidden_states
        first_k = hs[1:3]
        
        pooled = []
        for layer in first_k:
            m = layer.mean(dim=1)
            mx = layer.max(dim=1).values
            pooled.extend([m, mx])
        
        embedding = torch.cat(pooled, dim=-1)

        # embedding = out.last_hidden_state.mean(dim=1)
        result = self.downstream_model(embedding, return_features=True)
        if isinstance(result, tuple) and len(result) == 3:
                return result
        else:
            logits = result
            return logits, embedding, embedding

In [6]:
from transformers import AutoModel, AutoTokenizer
import torch

model = AutoModel.from_pretrained("facebook/wav2vec2-base-960h")

print("Wav2Vec2 Model Structure:")
for name, module in model.named_modules():
    if 'attention' in name or 'query' in name or 'key' in name or 'value' in name:
        print(f"  {name}: {type(module)}")

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Wav2Vec2 Model Structure:
  encoder.layers.0.attention: <class 'transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SdpaAttention'>
  encoder.layers.0.attention.k_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.0.attention.v_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.0.attention.q_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.0.attention.out_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.1.attention: <class 'transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SdpaAttention'>
  encoder.layers.1.attention.k_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.1.attention.v_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.1.attention.q_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.1.attention.out_proj: <class 'torch.nn.modules.linear.Linear'>
  encoder.layers.2.attention: <class 'transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SdpaAttention'>
  encoder.layers.2.

In [7]:
from peft import get_peft_model, LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,  
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=["encoder.layers.10.attention.q_proj",
                    "encoder.layers.10.attention.v_proj",
                    "encoder.layers.10.attention.k_proj",
                    "encoder.layers.11.attention.q_proj",
                    "encoder.layers.11.attention.v_proj",
                    ],
    base_model_name_or_path="facebook/wav2vec2-base-960h" 
)

In [8]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
upstream_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
lora_model = get_peft_model(upstream_model, lora_config)

model_downstream = ContrastiveEmbeddingExtractor(input_dim=3072,projection_dim=512,hidden_dim=256).to(device)
model_downstream.load_state_dict(torch.load("best_contrastive_embedding_model_first.pth", map_location=device))

model = IntegratedModel(lora_model, processor, model_downstream).to(device)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
for param in model.parameters():
    param.requires_grad = False
for name, param in model.named_parameters():
    if 'lora' in name:
        param.requires_grad = True
    if'downstream_model' in name:
        param.requires_grad = True

for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")


wav2vec2_model.base_model.model.masked_spec_embed: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.0.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.0.layer_norm.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.0.layer_norm.bias: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.1.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.2.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.3.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.4.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.5.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_extractor.conv_layers.6.conv.weight: requires_grad=False
wav2vec2_model.base_model.model.feature_projec

In [None]:
import os
import librosa
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset

input_root = "../ComParE2017_Cold_4students/wav/"
data_dir = "processed_files"
data_split = ["train_files", "devel_files"]

train_dir = os.path.join(input_root, data_split[0], data_dir)
devel_dir = os.path.join(input_root, data_split[1], data_dir)

train_files = [f for f in os.listdir(train_dir) if f.endswith('.wav')]
devel_files = [f for f in os.listdir(devel_dir) if f.endswith('.wav')]

class Wav2Vec2Dataset(Dataset):
    def __init__(self, file_list, label_dict, split, is_training = False, input_root="../ComParE2017_Cold_4students/wav/", max_length=160000):
        self.file_list = file_list
        self.label_dict = label_dict
        self.input_root = input_root 
        self.split = split
        self.max_length = max_length
        self.is_training = is_training
    
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        file_path = os.path.join(self.input_root, self.split, "processed_files", file_name)
        
        label = search_in_ground_truth(file_name[:-4], self.label_dict)
        
        try:
            waveform, sr = librosa.load(file_path, sr=16000)
            
            if self.is_training and label == "C":
                try:
                    file_path_additional = os.path.join(self.input_root, self.split, file_name)
                    if os.path.exists(file_path_additional):
                        additional_waveform, _ = librosa.load(file_path_additional, sr=16000)
                        
                        if np.random.random() > 0.5:
                            waveform = additional_waveform
                            #print(f"Using additional file for {file_name}")
                        
                except Exception as e:
                    print(f"Could not load additional file {file_path_additional}: {e}")
            
            if len(waveform) > self.max_length:
                waveform = waveform[:self.max_length]
            elif len(waveform) < self.max_length:
                padding = self.max_length - len(waveform)
                waveform = np.pad(waveform, (0, padding), mode='constant', constant_values=0)
            
            if len(waveform) == 0:
                raise ValueError("Empty audio file")
                
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            waveform = np.zeros(self.max_length)
        
        label_tensor = torch.tensor(1 if label == "C" else 0)
        
        return torch.tensor(waveform, dtype=torch.float32), label_tensor

In [11]:
from torch.utils.data import DataLoader

train_set = Wav2Vec2Dataset(train_files, label_dict, "train_files", is_training=True, input_root= "../ComParE2017_Cold_4students/wav/")
devel_set = Wav2Vec2Dataset(devel_files, label_dict, "devel_files", input_root)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
devel_loader = DataLoader(devel_set, batch_size=4, shuffle=False)

print("=== Train Set - First 5 samples ===")
for i in range(min(5, len(train_set))):
    input_values, label = train_set[i]
    print(f"Sample {i+1}:")
    print(f"  File: {train_set.file_list[i] if hasattr(train_set, 'file_list') else 'unknown'}")
    print(f"  Input shape: {input_values.shape}")
    print(f"  Label: {label.item()} ({'Cold' if label.item() == 1 else 'Healthy'})")
    print()

=== Train Set - First 5 samples ===
Sample 1:
  File: train_0001.wav
  Input shape: torch.Size([160000])
  Label: 1 (Cold)

Sample 2:
  File: train_0002.wav
  Input shape: torch.Size([160000])
  Label: 0 (Healthy)

Sample 3:
  File: train_0003.wav
  Input shape: torch.Size([160000])
  Label: 0 (Healthy)

Sample 4:
  File: train_0004.wav
  Input shape: torch.Size([160000])
  Label: 1 (Cold)

Sample 5:
  File: train_0005.wav
  Input shape: torch.Size([160000])
  Label: 0 (Healthy)



In [12]:
num_epochs= 50
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7, weight_decay=1e-8)
classification_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(5).to(device))
contrastive_loss = SupervisedContrastiveLoss(temperature=0.05, minority_weight=2.0)
criterion = CombinedLoss(
    classification_loss=classification_loss,
    contrastive_loss=contrastive_loss,
    alpha=0.4
)
threshold = 0.6


In [13]:
for name, param in model.named_parameters():
    if 'lora' in name:
        print(name, param.requires_grad, param.grad.norm() if param.grad is not None else "NO GRAD")


wav2vec2_model.base_model.model.encoder.layers.10.attention.k_proj.lora_A.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.10.attention.k_proj.lora_B.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.10.attention.v_proj.lora_A.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.10.attention.v_proj.lora_B.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.10.attention.q_proj.lora_A.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.10.attention.q_proj.lora_B.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.11.attention.v_proj.lora_A.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.11.attention.v_proj.lora_B.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.11.attention.q_proj.lora_A.default.weight True NO GRAD
wav2vec2_model.base_model.model.encoder.layers.11.attention.q_proj.lora_B.default.

In [14]:
#print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {trainable_params}")

Total trainable parameters: 1468417


In [15]:
def detailed_lora_check(model):
    print("🔍 Detailed LoRA Parameter Check:")
    
    total_params = 0
    lora_params = 0
    trainable_params = 0
    
    for name, param in model.named_parameters():
        total_params += 1
        
        if 'lora' in name.lower():
            lora_params += 1
            print(f"LoRA param: {name}")
            print(f"  - requires_grad: {param.requires_grad}")
            print(f"  - shape: {param.shape}")
            print(f"  - device: {param.device}")
            print(f"  - dtype: {param.dtype}")
            
            if param.requires_grad:
                trainable_params += 1
    
    print(f"\n📊 Summary:")
    print(f"  Total parameters: {total_params}")
    print(f"  LoRA parameters: {lora_params}")
    print(f"  Trainable LoRA parameters: {trainable_params}")
    
    if hasattr(model.wav2vec2_model, 'peft_config'):
        print(f"  PEFT config: {model.wav2vec2_model.peft_config}")
    
    if hasattr(model.wav2vec2_model, 'active_adapters'):
        print(f"  Active adapters: {model.wav2vec2_model.active_adapters}")
    
    return lora_params, trainable_params

detailed_lora_check(model)

🔍 Detailed LoRA Parameter Check:
LoRA param: wav2vec2_model.base_model.model.encoder.layers.10.attention.k_proj.lora_A.default.weight
  - requires_grad: True
  - shape: torch.Size([16, 768])
  - device: cuda:0
  - dtype: torch.float32
LoRA param: wav2vec2_model.base_model.model.encoder.layers.10.attention.k_proj.lora_B.default.weight
  - requires_grad: True
  - shape: torch.Size([768, 16])
  - device: cuda:0
  - dtype: torch.float32
LoRA param: wav2vec2_model.base_model.model.encoder.layers.10.attention.v_proj.lora_A.default.weight
  - requires_grad: True
  - shape: torch.Size([16, 768])
  - device: cuda:0
  - dtype: torch.float32
LoRA param: wav2vec2_model.base_model.model.encoder.layers.10.attention.v_proj.lora_B.default.weight
  - requires_grad: True
  - shape: torch.Size([768, 16])
  - device: cuda:0
  - dtype: torch.float32
LoRA param: wav2vec2_model.base_model.model.encoder.layers.10.attention.q_proj.lora_A.default.weight
  - requires_grad: True
  - shape: torch.Size([16, 768])
 

(10, 10)

In [None]:
print("🔄 Creating simplest model...")
simple_model = IntegratedModel(lora_model, processor, model_downstream).to(device)

waveforms, labels = next(iter(train_loader))
waveforms = waveforms[:2].to(device)
labels = labels[:2].to(device)

try:
    logits, projections, features = simple_model(waveforms, return_features=True)
    print(f"✅ Simple model works!")
    print(f"  Logits: {logits.shape}")
    print(f"  Projections: {projections.shape}")
    print(f"  Features: {features.shape}")
    
    simple_model.zero_grad()
    loss, _, _ = criterion(logits, projections, labels)
    loss.backward()
    
    grad_count = 0
    for name, param in simple_model.named_parameters():
        if 'lora' in name.lower() and param.grad is not None:
            grad_norm = param.grad.norm().item()
            if grad_norm > 1e-10:
                grad_count += 1
    
    print(f"✅ LoRA gradients working: {grad_count} parameters have gradients")
    
except Exception as e:
    print(f"❌ Simple model failed: {e}")
    import traceback
    traceback.print_exc()

🔄 Creating simplest model...
✅ Simple model works!
  Logits: torch.Size([2])
  Projections: torch.Size([2, 512])
  Features: torch.Size([2, 256])
✅ LoRA gradients working: 0 parameters have gradients


In [None]:
import time
from sklearn.metrics import accuracy_score, f1_score, recall_score

best_val_uar = 0.0  
patience = 10
early_stop_counter = 0

train_losses = []
val_losses = []
train_uar_scores = [] 
val_uar_scores = []    

print("🚀 Start CL...\n")

for epoch in range(num_epochs):
    model.train()
    
    total_loss = 0.0
    total_cls_loss = 0.0
    total_cont_loss = 0.0
    all_preds, all_labels = [], []
    
    print(f'\n{"="*80}')
    print(f'Epoch [{epoch+1}/{num_epochs}] - Contrastive Embedding Learning')
    print(f'{"="*80}\n')
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    
    for batch_idx, (waveforms, labels) in enumerate(progress_bar):
        
        waveforms = waveforms.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        logits, projections, features = model(waveforms, return_features=True)
        
        if logits.dim() == 0:  
            logits = logits.unsqueeze(0)  
        elif logits.dim() == 1 and logits.shape[0] != labels.shape[0]:
            logits = logits.repeat(labels.shape[0])

        if labels.dim() == 0:
            labels = labels.unsqueeze(0)

        loss, cls_loss, cont_loss = criterion(logits, projections, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        """
        for name, param in model.named_parameters():
            if 'lora' in name:
                print(name, param.requires_grad, param.grad.norm() if param.grad is not None else "NO GRAD")
        """

        optimizer.step()
        
        with torch.no_grad():
            preds = (torch.sigmoid(logits) > threshold).long()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
        
        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_cont_loss += cont_loss.item()
        
        progress_bar.set_postfix({
            'Total': f'{loss.item():.4f}',
            'Cls': f'{cls_loss.item():.4f}',
            'Cont': f'{cont_loss.item():.4f}'
        })
    
    avg_loss = total_loss / len(train_loader)
    avg_cls_loss = total_cls_loss / len(train_loader)
    avg_cont_loss = total_cont_loss / len(train_loader)
    
    train_acc = accuracy_score(all_labels, all_preds)
    train_f1 = f1_score(all_labels, all_preds, zero_division=0)
    train_uar = recall_score(all_labels, all_preds, average='macro', zero_division=0)  # 添加UAR
    
    model.eval()
    val_preds, val_labels = [], []
    total_val_loss = 0.0
    
    with torch.no_grad():
        for waveforms, labels in tqdm(devel_loader, desc="Validating"):
            waveforms = waveforms.to(device)
            labels = labels.to(device)
            
            logits, projections, features = model(waveforms, return_features=True)
            if logits.dim() == 0:  
                logits = logits.unsqueeze(0)  
            elif logits.dim() == 1 and logits.shape[0] != labels.shape[0]:
                logits = logits.repeat(labels.shape[0])

            if labels.dim() == 0:
                labels = labels.unsqueeze(0)
                
            val_loss, val_cls_loss, val_cont_loss = criterion(logits, projections, labels)
            total_val_loss += val_loss.item()
            
            preds = (torch.sigmoid(logits) > threshold).long()
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    avg_val_loss = total_val_loss / len(devel_loader)
    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, zero_division=0)
    val_uar = recall_score(val_labels, val_preds, average='macro', zero_division=0)  
    
    train_losses.append(avg_loss)
    val_losses.append(avg_val_loss)
    train_uar_scores.append(train_uar) 
    val_uar_scores.append(val_uar)      
    
    current_lr = optimizer.param_groups[0]['lr']
    
    print(f"\nEpoch [{epoch+1}] Summary:")
    print(f"  🎯 Learning Rate: {current_lr:.2e}")
    print(f"  📈 Training   - Loss: {avg_loss:.4f} (Cls: {avg_cls_loss:.4f}, Cont: {avg_cont_loss:.4f})")
    print(f"                 Acc: {train_acc:.4f}, F1: {train_f1:.4f}, UAR: {train_uar:.4f}")  
    print(f"  📊 Validation - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, UAR: {val_uar:.4f}") 
    
    if len(set(val_labels)) > 1 and len(set(val_preds)) > 1:
        class_recalls = recall_score(val_labels, val_preds, average=None, zero_division=0)
        print(f"  🎯 Class Recalls - Healthy: {class_recalls[0]:.4f}, Cold: {class_recalls[1]:.4f}")
    
    if val_uar > best_val_uar:
        best_val_uar = val_uar
        early_stop_counter = 0
        torch.save(model.state_dict(), "best_contrastive_embedding_model_lora.pth")
        print(f"🌟 New best UAR: {best_val_uar:.4f}, saving model...")
    else:
        early_stop_counter += 1
        print(f"⏳ No improvement for {early_stop_counter}/{patience} epochs")
        if early_stop_counter >= patience:
            print(f"❌ Early stopping after {patience} epochs without improvement")
            break

print(f"\n🎉 Finish Training! Best UAR: {best_val_uar:.4f}")

training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_uar_scores': train_uar_scores,
    'val_uar_scores': val_uar_scores,
    'best_val_uar': best_val_uar,
    'total_epochs': epoch + 1,
    'early_stopped': early_stop_counter >= patience
}

torch.save(training_history, 'contrastive_training_history.pth')
print(f"💾 Training history saved to 'contrastive_training_history.pth'")

🚀 Start CL...


Epoch [1/50] - Contrastive Embedding Learning



Epoch 1: 100%|██████████| 2377/2377 [06:45<00:00,  5.86it/s, Total=0.1774, Cls=0.2957, Cont=0.0000]
Validating: 100%|██████████| 2399/2399 [03:50<00:00, 10.42it/s]



Epoch [1] Summary:
  🎯 Learning Rate: 1.00e-07
  📈 Training   - Loss: 0.6598 (Cls: 0.9120, Cont: 0.2813)
                 Acc: 0.8967, F1: 0.0061, UAR: 0.5007
  📊 Validation - Loss: 0.6892, Acc: 0.8946, F1: 0.0000, UAR: 0.5000
🌟 New best UAR: 0.5000, saving model...

Epoch [2/50] - Contrastive Embedding Learning



Epoch 2: 100%|██████████| 2377/2377 [06:46<00:00,  5.85it/s, Total=0.1014, Cls=0.1690, Cont=0.0000] 
Validating: 100%|██████████| 2399/2399 [03:54<00:00, 10.25it/s]



Epoch [2] Summary:
  🎯 Learning Rate: 1.00e-07
  📈 Training   - Loss: 0.7535 (Cls: 1.0673, Cont: 0.2829)
                 Acc: 0.8979, F1: 0.0000, UAR: 0.5000
  📊 Validation - Loss: 0.8023, Acc: 0.8946, F1: 0.0000, UAR: 0.5000
⏳ No improvement for 1/10 epochs

Epoch [3/50] - Contrastive Embedding Learning



Epoch 3: 100%|██████████| 2377/2377 [06:49<00:00,  5.81it/s, Total=0.0532, Cls=0.0887, Cont=0.0000] 
Validating: 100%|██████████| 2399/2399 [03:55<00:00, 10.18it/s]



Epoch [3] Summary:
  🎯 Learning Rate: 1.00e-07
  📈 Training   - Loss: 0.8979 (Cls: 1.3170, Cont: 0.2693)
                 Acc: 0.8979, F1: 0.0000, UAR: 0.5000
  📊 Validation - Loss: 0.9376, Acc: 0.8946, F1: 0.0000, UAR: 0.5000
⏳ No improvement for 2/10 epochs

Epoch [4/50] - Contrastive Embedding Learning



Epoch 4: 100%|██████████| 2377/2377 [06:48<00:00,  5.81it/s, Total=8.5861, Cls=14.3101, Cont=0.0000]
Validating: 100%|██████████| 2399/2399 [03:55<00:00, 10.21it/s]



Epoch [4] Summary:
  🎯 Learning Rate: 1.00e-07
  📈 Training   - Loss: 1.0572 (Cls: 1.5879, Cont: 0.2610)
                 Acc: 0.8979, F1: 0.0000, UAR: 0.5000
  📊 Validation - Loss: 1.0803, Acc: 0.8946, F1: 0.0000, UAR: 0.5000
⏳ No improvement for 3/10 epochs

Epoch [5/50] - Contrastive Embedding Learning



Epoch 5: 100%|██████████| 2377/2377 [06:47<00:00,  5.84it/s, Total=0.0023, Cls=0.0039, Cont=0.0000] 
Validating: 100%|██████████| 2399/2399 [03:51<00:00, 10.37it/s]



Epoch [5] Summary:
  🎯 Learning Rate: 1.00e-07
  📈 Training   - Loss: 1.1844 (Cls: 1.8085, Cont: 0.2484)
                 Acc: 0.8979, F1: 0.0000, UAR: 0.5000
  📊 Validation - Loss: 1.1813, Acc: 0.8946, F1: 0.0000, UAR: 0.5000
⏳ No improvement for 4/10 epochs

Epoch [6/50] - Contrastive Embedding Learning



Epoch 6:   6%|▌         | 136/2377 [00:22<06:16,  5.95it/s, Total=0.0176, Cls=0.0293, Cont=0.0000] 


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# training_history = torch.load('contrastive_training_history.pth')
# train_losses = training_history['train_losses']
# val_losses = training_history['val_losses']
# train_uar_scores = training_history['train_uar_scores']
# val_uar_scores = training_history['val_uar_scores']

# 创建图表
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Contrastive Learning Training Progress', fontsize=16, fontweight='bold')

# 1. Loss曲线
epochs = range(1, len(train_losses) + 1)
ax1.plot(epochs, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=4)
ax1.plot(epochs, val_losses, 'r-s', label='Validation Loss', linewidth=2, markersize=4)
ax1.set_title('Training & Validation Loss', fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. UAR曲线
ax2.plot(epochs, train_uar_scores, 'b-o', label='Training UAR', linewidth=2, markersize=4)
ax2.plot(epochs, val_uar_scores, 'r-s', label='Validation UAR', linewidth=2, markersize=4)
ax2.set_title('Training & Validation UAR', fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('UAR Score')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Loss详细分析
ax3.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
ax3.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2)
ax3.fill_between(epochs, train_losses, alpha=0.3, color='blue')
ax3.fill_between(epochs, val_losses, alpha=0.3, color='red')
ax3.set_title('Loss Curves (Filled)', fontweight='bold')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Loss')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. 最佳性能标注
best_epoch = np.argmax(val_uar_scores) + 1
best_uar = max(val_uar_scores)

ax4.plot(epochs, val_uar_scores, 'g-o', label='Validation UAR', linewidth=3, markersize=6)
ax4.axhline(y=best_uar, color='red', linestyle='--', alpha=0.7, label=f'Best UAR: {best_uar:.4f}')
ax4.axvline(x=best_epoch, color='red', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
ax4.scatter([best_epoch], [best_uar], color='red', s=100, zorder=5)
ax4.annotate(f'Best: {best_uar:.4f}\nEpoch: {best_epoch}', 
             xy=(best_epoch, best_uar), xytext=(best_epoch+1, best_uar-0.02),
             arrowprops=dict(arrowstyle='->', color='red'),
             fontsize=10, fontweight='bold')
ax4.set_title('Validation UAR with Best Performance', fontweight='bold')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('UAR Score')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('contrastive_training_progress.png', dpi=300, bbox_inches='tight')
plt.show()

# 打印训练总结
print("📊 Training Summary:")
print(f"   🏆 Best Validation UAR: {best_uar:.4f} (Epoch {best_epoch})")
print(f"   📈 Final Training Loss: {train_losses[-1]:.4f}")
print(f"   📊 Final Validation Loss: {val_losses[-1]:.4f}")
print(f"   🎯 Total Epochs: {len(train_losses)}")
if 'early_stopped' in locals() and early_stopped:
    print(f"   ⏹️  Early stopped: Yes")
else:
    print(f"   ⏹️  Early stopped: No")

# 额外的详细loss图
plt.figure(figsize=(12, 8))

# 分类损失和对比损失分别绘制（如果有的话）
if 'total_cls_loss' in locals() and 'total_cont_loss' in locals():
    plt.subplot(2, 2, 1)
    # 这里需要你保存每个epoch的分类损失和对比损失
    # 假设你有这些数据
    plt.title('Classification vs Contrastive Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

# Loss差异图
plt.subplot(2, 2, 2)
loss_diff = np.array(val_losses) - np.array(train_losses)
plt.plot(epochs, loss_diff, 'purple', linewidth=2, marker='o')
plt.title('Validation - Training Loss Difference')
plt.xlabel('Epoch')
plt.ylabel('Loss Difference')
plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
plt.grid(True, alpha=0.3)

# UAR改进图
plt.subplot(2, 2, 3)
if len(val_uar_scores) > 1:
    uar_improvement = np.diff(val_uar_scores)
    plt.bar(range(2, len(val_uar_scores)+1), uar_improvement, 
            color=['green' if x > 0 else 'red' for x in uar_improvement],
            alpha=0.7)
    plt.title('UAR Improvement by Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('UAR Change')
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    plt.grid(True, alpha=0.3)

# 学习曲线平滑图
plt.subplot(2, 2, 4)
# 应用移动平均平滑
def moving_average(data, window=3):
    if len(data) < window:
        return data
    return np.convolve(data, np.ones(window)/window, mode='valid')

if len(train_losses) >= 3:
    smooth_train = moving_average(train_losses, 3)
    smooth_val = moving_average(val_losses, 3)
    smooth_epochs = range(2, len(smooth_train)+2)
    
    plt.plot(smooth_epochs, smooth_train, 'b-', linewidth=3, label='Smoothed Training Loss')
    plt.plot(smooth_epochs, smooth_val, 'r-', linewidth=3, label='Smoothed Validation Loss')
    plt.title('Smoothed Loss Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('detailed_training_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("📈 训练图表已保存:")
print("   📊 contrastive_training_progress.png - 主要训练进度")
print("   📈 detailed_training_analysis.png - 详细分析图表")
