# Sliding Window on Images

## Load Data

In [36]:
import csv
from tqdm import tqdm

label_dict = {}
with open("../ComParE2017_Cold_4students/lab/ComParE2017_Cold.tsv", "r", encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t")
    rows = list(reader)
    for row in tqdm(rows, desc="Loading labels"):
        label_dict[row["file_name"]] = row["Cold (upper respiratory tract infection)"]

Loading labels: 100%|██████████| 19101/19101 [00:00<00:00, 3822482.02it/s]


In [37]:
import os
def search_in_labels(filename, label_dict):
    base_name = os.path.splitext(filename)[0]
    
    if "_logmel" in base_name:
        base_name = base_name.replace("_logmel", "")
    if "_flipped" in base_name:
        base_name = base_name.replace("_flipped", "")
    
    parts = base_name.split("_")
    if len(parts) >= 2:
        audio_filename = f"{parts[0]}_{parts[1]}.wav"
    else:
        audio_filename = f"{base_name}.wav"
    
    return label_dict.get(audio_filename, None)

## Data Loader

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import torch
import torch.nn.functional as F

class SpectrogramDataset(Dataset):
    def __init__(self, image_paths, label_dict, transform=None, window_size=128, stride=64, is_training=False, num_windows=3):
        self.image_paths = image_paths
        self.label_dict = label_dict
        self.is_training = is_training 
        self.window_size = window_size
        self.stride = stride
        self.num_windows = num_windows

        self.base_transform = transforms.Compose([
            # transforms.RandomAffine(degrees=0, translate=(0.3, 0)),
            transforms.ToTensor()
        ])
        
        self.c_train_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, translate=(0.3, 0)),
            transforms.ToTensor()
        ])
        
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        filename = os.path.basename(image_path)
        label = search_in_labels(filename, self.label_dict)
        label_num = 1 if label == "C" else 0

        image = Image.open(image_path).convert("RGB")

        if self.is_training and label == "C":
            image = self.c_train_transform(image)
        elif self.is_training and label == "NC":
            image = self.base_transform(image)
        else:
            image = transforms.ToTensor()(image)

        _, H, W = image.shape
        assert H == 128, f"Image height must be 128, but got {H}"

        windows = []
        for start in range(0, W - self.window_size + 1, self.stride):
            window = image[:, :, start:start + self.window_size]
            windows.append(window)

        if (W - self.window_size) % self.stride != 0:
            last_window = image[:, :, -self.window_size:]
            windows.append(last_window)

        if len(windows) == 0:
            pad_width = self.window_size - W
            image_padded = F.pad(image, (0, pad_width), mode='constant', value=0)
            window = image_padded[:, :, :self.window_size]
            windows.append(window)

        windows = torch.stack(windows)  # Shape: (num_extracted_windows, 3, 128, 128)

        if self.is_training:
            # 随机选择指定数量的windows
            num_available = windows.shape[0]
            
            if num_available >= self.num_windows:
                rand_indices = torch.randperm(num_available)[:self.num_windows]
                selected_windows = windows[rand_indices]  
            else:
                rand_indices = torch.randint(0, num_available, (self.num_windows,))
                selected_windows = windows[rand_indices]  
            
            labels = torch.full((self.num_windows,), label_num, dtype=torch.long)
            
            return selected_windows, labels
        else:
            return windows, label_num

In [39]:
import os
import glob
from torch.utils.data import DataLoader

data_split = ["train_files", "devel_files"]
img_dir = "../spectrograms_variable_width"  

def collect_image_paths_devel(split_name):
        sub_dir = os.path.join(img_dir, split_name)
        print(f"🔍 Looking for images in: {sub_dir}")
        
        if not os.path.exists(sub_dir):
            print(f"❌ Directory does not exist: {sub_dir}")
            return []
        
        png_files = glob.glob(os.path.join(sub_dir, "*.png"))
        
        filtered_files = [f for f in png_files if "flipped" not in os.path.basename(f)]
        
        print(f"📁 Found {len(png_files)} PNG files in {split_name}")
        print(f"📋 After filtering out 'flipped' files: {len(filtered_files)} files")
        
        return filtered_files

def collect_image_paths(split_name):
    sub_dir = os.path.join(img_dir, split_name)
    print(f"🔍 Looking for images in: {sub_dir}")
    
    if not os.path.exists(sub_dir):
        print(f"❌ Directory does not exist: {sub_dir}")
        return []
    
    png_files = glob.glob(os.path.join(sub_dir, "*.png"))
    print(f"📁 Found {len(png_files)} PNG files in {split_name}")
    
    return png_files

print("🚀 Collecting image paths...")
train_image_paths = collect_image_paths("train_files")
devel_image_paths = collect_image_paths_devel("devel_files")

train_dataset = SpectrogramDataset(
    image_paths=train_image_paths,
    label_dict=label_dict,
    transform=None,
    window_size=128,
    stride=32,
    is_training=True
)
devel_dataset = SpectrogramDataset(
    image_paths=devel_image_paths,
    label_dict=label_dict,
    transform=None,
    window_size=128,
    stride=32,
    is_training=False
)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
devel_loader = DataLoader(
    devel_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)



🚀 Collecting image paths...
🔍 Looking for images in: ../spectrograms_variable_width\train_files
📁 Found 10475 PNG files in train_files
🔍 Looking for images in: ../spectrograms_variable_width\devel_files
📁 Found 10607 PNG files in devel_files
📋 After filtering out 'flipped' files: 9596 files


## CNN


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedCNNBinaryClassifier(nn.Module):
    def __init__(self, input_shape=(3, 128, 128), num_classes=1):
        super(ImprovedCNNBinaryClassifier, self).__init__()
        
        self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))  
        
        self.fc1 = nn.Linear(512 * 4 * 4, 512)  
        self.dropout1 = nn.Dropout(0.5)
        
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.3)
        
        self.fc3 = nn.Linear(256, num_classes)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        
        x = F.relu(self.bn4(self.conv4(x)))
        
        x = self.adaptive_pool(x) 
        
        x = x.view(x.size(0), -1)  # [batch, 512*4*4]
        
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        
        x = self.fc3(x)
        
        return x

## Training


In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ImprovedCNNBinaryClassifier().to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(4).to(device))  
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-6)

num_epochs = 100

threshold = 0.5

### Training Loop

In [47]:
from tqdm import trange
from sklearn.metrics import accuracy_score, f1_score

def eval_with_voting(model, dataset, criterion, device, threshold=0.5, vote_mode='soft'):
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for i in trange(len(dataset), desc="Validating"):
            windows, label = dataset[i]  # windows: (N, C, H, W)
            label = torch.tensor(label).to(device)

            windows = windows.to(device)  # shape: (N, C, H, W)
            logits = model(windows).squeeze()  # shape: (N,) or (N,1)

            # 确保 logits 是 1D
            if logits.dim() > 1:
                logits = logits.squeeze()
            if logits.dim() == 0:
                logits = logits.unsqueeze(0)

            probs = torch.sigmoid(logits)

            # 🎯 投票
            if vote_mode == 'soft':
                avg_prob = probs.mean().item()
                final_pred = 1 if avg_prob > threshold else 0
            else:  # 'hard'
                window_preds = (probs > threshold).long()
                final_pred = torch.mode(window_preds).values.item()

            all_preds.append(final_pred)
            all_labels.append(label.item())

            # 平均窗口 loss
            repeated_label = label.repeat(len(logits))
            sample_loss = criterion(logits, repeated_label.float()).item()
            total_loss += sample_loss

    avg_loss = total_loss / len(dataset)
    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)

    # 🎯 返回 loss, accuracy, f1, predictions, labels
    return avg_loss, acc, f1, all_preds, all_labels

In [48]:
import time
from sklearn.metrics import accuracy_score, f1_score, recall_score

best_val_loss = float('inf')
best_uar = 0.0
patience = 4
patience_counter = 0
training_losses = []
validation_losses = []
start_time = time.time()
early_stop_counter = 0

print("Starting training...\n")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []

    print(f'\n{"="*80}')
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'{"="*80}\n')

    total_windows = 0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training")

    for batch_data, batch_labels in progress_bar:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        
        # 检查数据形状并处理
        if len(batch_data.shape) == 5:  # (batch_size, num_windows, 3, 128, 128)
            batch_size, num_windows = batch_data.shape[:2]
            
            # 重塑为 (batch_size * num_windows, 3, 128, 128)
            batch_data = batch_data.view(-1, 3, 128, 128)
            batch_labels = batch_labels.view(-1)  # (batch_size * num_windows,)
            
            total_windows += batch_size * num_windows
            total_samples += batch_size
            
        else:  # (batch_size, 3, 128, 128) - 单个window情况
            total_windows += batch_data.shape[0]
            total_samples += batch_data.shape[0]

        # Forward
        optimizer.zero_grad()
        logits = model(batch_data).squeeze()
        
        # 处理单样本情况
        if logits.dim() == 0:
            logits = logits.unsqueeze(0)
            
        loss = criterion(logits, batch_labels.float())

        # Backward
        loss.backward()
        optimizer.step()

        # 计算预测
        preds = (torch.sigmoid(logits) > threshold).long()
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

        running_loss += loss.item()
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'windows': f'{total_windows}',
            'samples': f'{total_samples}'
        })
    
    print(f"Total windows processed in epoch {epoch+1}: {total_windows}")
    print(f"Total samples processed in epoch {epoch+1}: {total_samples}")
    
    if train_dataset.num_windows > 1:
        print(f"Average windows per sample: {total_windows/total_samples:.1f}")

    epoch_loss = running_loss / len(train_loader)
    training_losses.append(epoch_loss)
    
    train_accuracy = accuracy_score(all_labels, all_preds)
    train_f1 = f1_score(all_labels, all_preds, zero_division=0)

    # 验证阶段 - 使用voting评估
    avg_val_loss, val_accuracy, val_f1, val_preds, val_labels = eval_with_voting(
        model=model,
        dataset=devel_dataset,
        criterion=criterion,
        device=device,
        threshold=threshold,
        vote_mode='soft'
    )

    val_uar = recall_score(val_labels, val_preds, average='macro', zero_division=0)
    validation_losses.append(avg_val_loss)

    print(f"\nEpoch [{epoch+1}] Summary:")
    print(f"  📈 Training   - Loss: {epoch_loss:.4f}, ACCR: {train_accuracy:.4f}, F1: {train_f1:.4f}")
    print(f"  📊 Validation - Loss: {avg_val_loss:.4f}, UAR: {val_uar:.4f}, F1: {val_f1:.4f}")
    
    # 详细的类别召回率
    if len(set(val_labels)) > 1 and len(set(val_preds)) > 1:
        class_recalls = recall_score(val_labels, val_preds, average=None, zero_division=0)
        print(f"  🎯 Class Recalls - Healthy: {class_recalls[0]:.4f}, Cold: {class_recalls[1]:.4f}")

    # 保存最佳模型
    if val_uar > best_uar:
        best_uar = val_uar
        early_stop_counter = 0
        
        # 保存完整的模型信息
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_uar': best_uar,
            'train_loss': epoch_loss,
            'val_loss': avg_val_loss,
            'num_windows': train_dataset.num_windows
        }, "best_sliding_window.pth")
        
        print(f"🌟 New best UAR: {best_uar:.4f}, saving model...")
    else:
        early_stop_counter += 1
        print(f"⏳ No improvement for {early_stop_counter}/{patience} epochs")
        
        if early_stop_counter >= patience:
            print(f"❌ No improvement in UAR for {patience} epochs, early stopping...")
            break

print(f"\n🎉 Training complete in {(time.time() - start_time)/60:.2f} min")
print(f"🏆 Best Validation UAR: {best_uar:.4f}")

# 打印训练统计信息
if hasattr(train_dataset, 'num_windows'):
    print(f"📊 Training configuration:")
    print(f"   - Windows per sample: {train_dataset.num_windows}")
    print(f"   - Window size: {train_dataset.window_size}")
    print(f"   - Stride: {train_dataset.stride}")

# 保存训练历史
training_history = {
    'training_losses': training_losses,
    'validation_losses': validation_losses,
    'best_uar': best_uar,
    'num_epochs_trained': epoch + 1,
    'early_stopped': early_stop_counter >= patience,
    'num_windows': getattr(train_dataset, 'num_windows', 1)
}

torch.save(training_history, 'training_history.pth')
print(f"💾 Training history saved to 'training_history.pth'")

Starting training...


Epoch [1/100]



Epoch 1 Training: 100%|██████████| 164/164 [01:35<00:00,  1.72it/s, loss=0.2920, windows=31425, samples=10475]


Total windows processed in epoch 1: 31425
Total samples processed in epoch 1: 10475
Average windows per sample: 3.0


Validating: 100%|██████████| 9596/9596 [01:05<00:00, 145.70it/s]



Epoch [1] Summary:
  📈 Training   - Loss: 0.5563, ACCR: 0.8902, F1: 0.7214
  📊 Validation - Loss: 1.0710, UAR: 0.5004, F1: 0.0020
  🎯 Class Recalls - Healthy: 0.9999, Cold: 0.0010
🌟 New best UAR: 0.5004, saving model...

Epoch [2/100]



Epoch 2 Training: 100%|██████████| 164/164 [01:28<00:00,  1.84it/s, loss=0.2382, windows=31425, samples=10475]


Total windows processed in epoch 2: 31425
Total samples processed in epoch 2: 10475
Average windows per sample: 3.0


Validating: 100%|██████████| 9596/9596 [01:04<00:00, 148.66it/s]



Epoch [2] Summary:
  📈 Training   - Loss: 0.4034, ACCR: 0.9346, F1: 0.8229
  📊 Validation - Loss: 0.9679, UAR: 0.5914, F1: 0.2435
  🎯 Class Recalls - Healthy: 0.7506, Cold: 0.4322
🌟 New best UAR: 0.5914, saving model...

Epoch [3/100]



Epoch 3 Training: 100%|██████████| 164/164 [01:31<00:00,  1.78it/s, loss=0.1884, windows=31425, samples=10475]


Total windows processed in epoch 3: 31425
Total samples processed in epoch 3: 10475
Average windows per sample: 3.0


Validating: 100%|██████████| 9596/9596 [01:04<00:00, 147.70it/s]



Epoch [3] Summary:
  📈 Training   - Loss: 0.3457, ACCR: 0.9424, F1: 0.8440
  📊 Validation - Loss: 0.8355, UAR: 0.5957, F1: 0.2740
  🎯 Class Recalls - Healthy: 0.9065, Cold: 0.2849
🌟 New best UAR: 0.5957, saving model...

Epoch [4/100]



Epoch 4 Training: 100%|██████████| 164/164 [01:56<00:00,  1.40it/s, loss=0.3578, windows=31425, samples=10475]


Total windows processed in epoch 4: 31425
Total samples processed in epoch 4: 10475
Average windows per sample: 3.0


Validating: 100%|██████████| 9596/9596 [01:06<00:00, 144.42it/s]



Epoch [4] Summary:
  📈 Training   - Loss: 0.3121, ACCR: 0.9495, F1: 0.8630
  📊 Validation - Loss: 1.1764, UAR: 0.5209, F1: 0.0822
  🎯 Class Recalls - Healthy: 0.9983, Cold: 0.0435
⏳ No improvement for 1/4 epochs

Epoch [5/100]



Epoch 5 Training: 100%|██████████| 164/164 [01:36<00:00,  1.70it/s, loss=0.2914, windows=31425, samples=10475]


Total windows processed in epoch 5: 31425
Total samples processed in epoch 5: 10475
Average windows per sample: 3.0


Validating: 100%|██████████| 9596/9596 [01:17<00:00, 123.03it/s]



Epoch [5] Summary:
  📈 Training   - Loss: 0.2902, ACCR: 0.9474, F1: 0.8600
  📊 Validation - Loss: 2.5997, UAR: 0.5000, F1: 0.0000
⏳ No improvement for 2/4 epochs

Epoch [6/100]



Epoch 6 Training: 100%|██████████| 164/164 [01:47<00:00,  1.52it/s, loss=0.1329, windows=31425, samples=10475]


Total windows processed in epoch 6: 31425
Total samples processed in epoch 6: 10475
Average windows per sample: 3.0


Validating: 100%|██████████| 9596/9596 [01:06<00:00, 144.11it/s]



Epoch [6] Summary:
  📈 Training   - Loss: 0.2639, ACCR: 0.9491, F1: 0.8665
  📊 Validation - Loss: 1.5688, UAR: 0.5231, F1: 0.0939
  🎯 Class Recalls - Healthy: 0.9948, Cold: 0.0514
⏳ No improvement for 3/4 epochs

Epoch [7/100]



Epoch 7 Training:   9%|▉         | 15/164 [00:10<01:47,  1.39it/s, loss=0.1855, windows=2880, samples=960]


KeyboardInterrupt: 

In [44]:
from sklearn.metrics import accuracy_score, f1_score, recall_score
uar = recall_score(val_labels, val_preds, average='macro')
print(f"Validation UAR: {uar:.4f}")

if len(set(val_labels)) > 1 and len(set(val_preds)) > 1:
    class_recalls = recall_score(val_labels, val_preds, average=None)
    print(f"Class 0 (Healthy) Recall: {class_recalls[0]:.4f}")
    print(f"Class 1 (Cold) Recall: {class_recalls[1]:.4f}")

Validation UAR: 0.6368
Class 0 (Healthy) Recall: 0.5971
Class 1 (Cold) Recall: 0.6766
