In [1]:


import os
import torch
import pandas as pd
import torchaudio
from torch.utils.data import Dataset
from transformers import ASTFeatureExtractor

class OptimizedAudioDataset(Dataset):
    def __init__(self, csv_file, tensor_dir, strict_join=True):
        """
        Args:
            csv_file (str): Path to the metadata CSV (e.g., 'combined_data.csv').
            tensor_dir (str): Path to folder containing pre-processed .pt files.
            strict_join (bool): If True, filters out rows where the .pt file is missing.
        """
        self.df = pd.read_csv(csv_file)
        self.tensor_dir = tensor_dir
        
        # 1. Strict Join: Filter out rows where the pre-processed tensor doesn't exist
        if strict_join:
            self.df['file_exists'] = self.df['id'].apply(
                lambda x: os.path.isfile(os.path.join(self.tensor_dir, f"{x}.pt"))
            )
            original_count = len(self.df)
            self.df = self.df[self.df['file_exists'] == True].copy()
            print(f"Strict Join: Kept {len(self.df)} out of {original_count} samples.")

            if len(self.df) == 0:
                raise FileNotFoundError(f"No matching .pt files found in {tensor_dir}!")

        # 2. Define target labels
        self.target_cols = ['cough', 'cold', 'asthma', 'pneumonia', 'test_status']

        # 3. Preprocess labels (Standardizing 'True'/'Positive' to 1, others to 0)
        for col in self.target_cols:
            self.df[col] = self.df[col].apply(
                lambda x: 1 if (x == True or (isinstance(x, str) and 
                (x.lower() in ['y', 'true', 'p', 'positive']))) else 0
            )

        # 4. Convert to Tensors for fast indexing
        self.labels = torch.tensor(self.df[self.target_cols].values, dtype=torch.float32)
        self.ids = self.df['id'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row_id = self.ids[idx]
        tensor_path = os.path.join(self.tensor_dir, f"{row_id}.pt")

        try:
            # FAST LOAD: Directly load the pre-computed spectrogram tensor
            spectrogram = torch.load(tensor_path)
        except Exception as e:
            # Fallback (returns a zero tensor of the correct AST shape [1024, 128])
            print(f"Warning: Could not load {tensor_path} - {e}")
            spectrogram = torch.zeros(1024, 128)

        # Returns: (Spectrogram Tensor, Label Tensor)
        return spectrogram, self.labels[idx]
    def get_pos_weights(self):
        """
        Calculates weights directly from the internal labels tensor.
        """
        # self.labels shape is [N, 5]
        pos_counts = self.labels.sum(dim=0)  # Sum of positives for each class
        total_samples = self.labels.size(0)
        neg_counts = total_samples - pos_counts
        
        # Formula: Negatives / Positives
        # We add a tiny epsilon to avoid division by zero
        weights = neg_counts / (pos_counts + 1e-6)
        
        return weights



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import DataLoader, random_split

# 1. Initialize the full dataset
# Make sure 'audio_files/' is the folder containing your .wav files
full_dataset = OptimizedAudioDataset(csv_file='combined_data.csv', tensor_dir='processed_tensors/')

# 2. Define split sizes (e.g., 80% training, 20% validation)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# 3. Create the DataLoaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=16,       # You can likely increase batch size now
    shuffle=True, 
    num_workers=4,       # Use more workers to load files in parallel
    pin_memory=True      # CRITICAL: Faster CPU->GPU transfer
)

val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,     # No need to shuffle validation
    num_workers=2
)

Strict Join: Kept 2703 out of 2746 samples.


In [3]:
import torch.nn as nn
from transformers import ASTModel

class CovidAudioClassifier(nn.Module):
    def __init__(self, num_labels=5):
        super(CovidAudioClassifier, self).__init__()
        # 1. The Backbone (Body)
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        # 2. The Task Head
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 5)  # The hidden layer allows features to "interact" before the final vote
        )

    def forward(self, x):
        # Pass through AST
        outputs = self.ast(x)
        pooled_output = outputs.last_hidden_state[:, 0, :] # Use the [CLS] token

        # Pass through the Head
        logits = self.classifier(pooled_output)
        return logits

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm # For a nice progress bar

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0

    for spectrograms, labels in tqdm(dataloader, desc="Training"):
        # 1. Move data to GPU/CPU
        spectrograms = spectrograms.to(device)
        labels = labels.to(device) # Labels are [Batch, 5]

        # 2. Forward pass
        optimizer.zero_grad()
        logits = model(spectrograms)

        # 3. Calculate Loss
        # BCEWithLogitsLoss applies Sigmoid internally for stability
        loss = criterion(logits, labels)

        # 4. Backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for spectrograms, labels in tqdm(dataloader, desc="Validating"):
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)

            logits = model(spectrograms)
            loss = criterion(logits, labels)
            running_loss += loss.item()

            # Convert logits to binary predictions (0 or 1)
            # We use a threshold of 0.5 after applying sigmoid
            preds = (torch.sigmoid(logits) > 0.5).float()

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    # Concatenate all results for metric calculation
    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Simple Accuracy: Percentage of ALL individual labels predicted correctly
    accuracy = (all_preds == all_labels).float().mean()

    return running_loss / len(dataloader), accuracy

In [5]:

import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T

class MultiLabelHingeLoss(nn.Module):
    def __init__(self, pos_weight=None):
        super(MultiLabelHingeLoss, self).__init__()
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        """
        logits: Raw output from model [Batch, 5]
        targets: 0 or 1 labels [Batch, 5]
        """
        # 1. Convert targets from {0, 1} to {-1, 1}
        # 0 -> -1, 1 -> +1
        targets_signed = 2 * targets - 1
        
        # 2. Hinge Logic: max(0, 1 - y * y_pred)
        # We want the correct class logit to be > 1.0 (margin)
        hinge_raw = 1 - (targets_signed * logits)
        loss = torch.clamp(hinge_raw, min=0)
        
        # 3. Apply Class Weights
        # If the target is Positive (1), multiply loss by pos_weight.
        # Otherwise, keep weight as 1.0.
        if self.pos_weight is not None:
            # Expand pos_weight to match batch size if necessary, though broadcasting usually handles it
            weights = torch.where(targets == 1, self.pos_weight, torch.ones_like(logits))
            loss = loss * weights
            
        return loss.mean()

# 1. Device & Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 15 # Increased epochs (Augmentation makes training harder/slower, which is good)
learning_rate = 1e-5

# 2. Model Setup (INCREASED DROPOUT)
model = CovidAudioClassifier(num_labels=5).to(device)
# Update the dropout layer in the existing model instance
model.classifier[2] = nn.Dropout(p=0.5) 

# 3. Augmentation Transforms (SpecAugment)
# We apply this ONLY during training
time_masking = T.TimeMasking(time_mask_param=80) # Mask up to 80 time steps
freq_masking = T.FrequencyMasking(freq_mask_param=40) # Mask up to 40 freq bins

# 4. Loss & Optimizer
class_weights = full_dataset.get_pos_weights().to(device)
criterion = MultiLabelHingeLoss(pos_weight=class_weights) # Using your Hinge Loss
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01) # Added weight_decay

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# 5. Training Loop with "Save Best" Logic
print(f"Starting Robust Training on {device}...")

best_val_loss = float('inf')

for epoch in range(num_epochs):
    
    # --- TRAINING PHASE ---
    model.train()
    running_loss = 0.0
    
    for spectrograms, labels in train_loader:
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        
        # --- NEW: APPLY SPECAUGMENT ---
        # AST Input shape: [Batch, Time (1024), Freq (128)]
        # Torchaudio Masks expect: [..., Freq, Time]
        # 1. Swap axes to [Batch, Freq, Time]
        aug_spec = spectrograms.transpose(1, 2)
        
        # 2. Apply Masks
        aug_spec = freq_masking(aug_spec)
        aug_spec = time_masking(aug_spec)
        
        # 3. Swap back to [Batch, Time, Freq]
        aug_spec = aug_spec.transpose(1, 2)
        # -----------------------------
        
        optimizer.zero_grad()
        logits = model(aug_spec) # Use augmented specs
        loss = criterion(logits, labels)
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    avg_train_loss = running_loss / len(train_loader)

    # --- VALIDATION PHASE ---
    # Note: We do NOT use SpecAugment here. Validation must be clean.
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # --- SCHEDULER & CHECKPOINTING ---
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    
    # Save Best Model Logic
    if val_loss < best_val_loss:
        print(f"--> NEW BEST! (Loss dropped from {best_val_loss:.4f} to {val_loss:.4f}). Saving model.")
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        print("--> No improvement.")
        
    print("-" * 30)

Loading weights: 100%|██████████| 199/199 [00:00<00:00, 402.55it/s, Materializing param=layernorm.weight]                                 
[1mASTModel LOAD REPORT[0m from: MIT/ast-finetuned-audioset-10-10-0.4593
Key                         | Status     |  | 
----------------------------+------------+--+-
classifier.layernorm.weight | UNEXPECTED |  | 
classifier.dense.weight     | UNEXPECTED |  | 
classifier.dense.bias       | UNEXPECTED |  | 
classifier.layernorm.bias   | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


Starting Robust Training on cuda...


Validating: 100%|██████████| 68/68 [00:47<00:00,  1.43it/s]


Epoch 1/15
Train Loss: 1.6347 | Val Loss: 1.5251 | Val Acc: 0.6159
--> NEW BEST! (Loss dropped from inf to 1.5251). Saving model.
------------------------------


Validating: 100%|██████████| 68/68 [00:47<00:00,  1.42it/s]


Epoch 2/15
Train Loss: 1.3854 | Val Loss: 1.4338 | Val Acc: 0.6129
--> NEW BEST! (Loss dropped from 1.5251 to 1.4338). Saving model.
------------------------------


Validating: 100%|██████████| 68/68 [00:47<00:00,  1.42it/s]


Epoch 3/15
Train Loss: 1.2360 | Val Loss: 1.3579 | Val Acc: 0.6854
--> NEW BEST! (Loss dropped from 1.4338 to 1.3579). Saving model.
------------------------------


Validating: 100%|██████████| 68/68 [00:47<00:00,  1.43it/s]


Epoch 4/15
Train Loss: 1.0390 | Val Loss: 1.2252 | Val Acc: 0.7257
--> NEW BEST! (Loss dropped from 1.3579 to 1.2252). Saving model.
------------------------------


Validating: 100%|██████████| 68/68 [00:47<00:00,  1.42it/s]

Epoch 5/15
Train Loss: 0.9904 | Val Loss: 1.3626 | Val Acc: 0.7416
--> No improvement.
------------------------------



Validating: 100%|██████████| 68/68 [00:47<00:00,  1.42it/s]

Epoch 6/15
Train Loss: 0.8775 | Val Loss: 1.2660 | Val Acc: 0.7534
--> No improvement.
------------------------------



Validating: 100%|██████████| 68/68 [00:47<00:00,  1.43it/s]

Epoch 7/15
Train Loss: 0.7817 | Val Loss: 1.4827 | Val Acc: 0.7963
--> No improvement.
------------------------------



Validating: 100%|██████████| 68/68 [00:47<00:00,  1.42it/s]

Epoch 8/15
Train Loss: 0.6797 | Val Loss: 1.4570 | Val Acc: 0.7922
--> No improvement.
------------------------------



Validating: 100%|██████████| 68/68 [00:47<00:00,  1.43it/s]

Epoch 9/15
Train Loss: 0.5762 | Val Loss: 1.3812 | Val Acc: 0.7837
--> No improvement.
------------------------------



Validating: 100%|██████████| 68/68 [00:48<00:00,  1.39it/s]

Epoch 10/15
Train Loss: 0.5271 | Val Loss: 1.3293 | Val Acc: 0.8111
--> No improvement.
------------------------------





KeyboardInterrupt: 

In [6]:
# Save the model currently sitting in memory
torch.save(model.state_dict(), 'hinge_dropout.pth')
print("Saved safely!")

Saved safely!
