# Speech Commands Classification using Audio Spectrogram Transformer (AST)

This notebook loads the Speech Commands dataset, converts the audio into Mel spectrograms,
and fine-tunes a Vision Transformer (ViT) whose transformer weights are initialized from a pretrained
Audio Spectrogram Transformer (AST) checkpoint, so that it works as an AST for spoken-command classification.
Training is supervised and uses early stopping on the validation loss to limit overfitting.


Dependencies:
- PyTorch
- TorchAudio
- timm (for pretrained ViT models)
- Librosa (for audio processing)
- NumPy
- Matplotlib

"""



In [None]:
import os
import torch
import torchaudio
from torch.utils.data import DataLoader
from torchaudio.datasets import SPEECHCOMMANDS
import torchaudio.transforms as transforms

# Create the dataset root directory up front; an already-existing directory is fine.
os.makedirs(name="./SpeechCommands", exist_ok=True)

class SubsetSC(SPEECHCOMMANDS):
    """
    Speech Commands dataset restricted to one of its official splits.

    Split membership comes from ``validation_list.txt`` and ``testing_list.txt``
    inside the dataset directory; the training split is every clip listed by the
    base class that appears in neither file.

    Args:
        root: Directory the dataset is downloaded to / read from.
        subset: ``"training"``, ``"validation"``, ``"testing"``, or ``None`` to keep
            every clip the base class lists.

    Raises:
        ValueError: If ``subset`` is not one of the values above.
    """

    # Accepted values for the ``subset`` argument.
    _SUBSETS = (None, "training", "validation", "testing")

    def __init__(self, root: str = "./SpeechCommands", subset: "str | None" = None):
        # Fail fast on a typo instead of silently returning the full dataset
        # (and before triggering a download).
        if subset not in self._SUBSETS:
            raise ValueError(
                f"subset must be one of {self._SUBSETS!r}, got {subset!r}"
            )
        super().__init__(root=root, download=True)

        def load_list(filename):
            """Return the absolute, normalized paths listed in ``filename``."""
            filepath = os.path.join(self._path, filename)
            with open(filepath, encoding="utf-8") as f:
                # normpath converts the list's "word/clip.wav" entries to the
                # platform separator so they compare equal to the entries of
                # self._walker (presumably built from filesystem paths by the base
                # class); without it the training exclusion below misses on Windows.
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in f
                    if line.strip()  # ignore blank lines, e.g. a trailing newline
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(load_list("validation_list.txt") + load_list("testing_list.txt"))
            self._walker = [w for w in self._walker if os.path.normpath(w) not in excludes]


In [None]:
# Collate function to handle varying waveform lengths
def collate_fn(batch):
    """
    Zero-pad every waveform on the right so the whole batch shares one length.

    Args:
        batch: Sequence of ``(waveform, sample_rate, label, speaker_id,
            utterance_number)`` tuples; dimension 1 of each waveform is the one
            that gets padded.

    Returns:
        ``(stacked_waveforms, sample_rate_tensor, labels, speaker_ids,
        utterance_numbers)``. The last three are tuples in batch order.
    """
    waves, rates, names, speakers, utterances = zip(*batch)
    longest = max(w.shape[1] for w in waves)

    def pad_to_longest(w):
        # Append zeros after the last sample until the clip reaches `longest`.
        return torch.nn.functional.pad(w, (0, longest - w.shape[1]))

    stacked = torch.stack(list(map(pad_to_longest, waves)))
    return stacked, torch.tensor(rates), names, speakers, utterances



In [None]:
# Build one dataset object per official split, in training / validation / testing order.
train_set, val_set, test_set = (
    SubsetSC(root="./SpeechCommands", subset=split)
    for split in ("training", "validation", "testing")
)



In [None]:
# Wrap each split in a DataLoader sharing batch size and collate function;
# only the training loader shuffles.
batch_size = 64
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)



In [None]:
# Sanity check: fetch the first training batch (if any) and print what it contains.
batch = next(iter(train_loader), None)
if batch is not None:
    waveforms, sample_rates, labels, speaker_ids, utterance_numbers = batch
    print("Waveforms shape:", waveforms.shape)  # Should be [batch_size, 1, max_length]
    print("Sample rate:", sample_rates[0])
    print("Labels:", labels)


## Model Setup: Vision Transformer (ViT) Adapted for Audio Spectrogram Transformer (AST)

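We use a ViT model as the backbone. Its patch embedding is replaced so that it accepts single-channel 128×128 spectrogram inputs,
its transformer weights are copied from a pretrained AST (AudioSet) checkpoint, and its classifier head is replaced for speech command recognition.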

In [None]:
import timm
import torch.nn as nn

# Load a pretrained Vision Transformer model
model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True)

# AST-style input geometry: single-channel AST_INPUT_SIZE x AST_INPUT_SIZE spectrograms
# cut into non-overlapping AST_PATCH_SIZE x AST_PATCH_SIZE patches
# (128 / 8 = 16 -> a 16 x 16 grid of 256 patch tokens).
AST_INPUT_SIZE = 128
AST_PATCH_SIZE = 8
embed_dim = model.embed_dim  # token width of the backbone (768 for ViT-Base)
grid_size = AST_INPUT_SIZE // AST_PATCH_SIZE

# Keep the patch-embedding metadata consistent with the new input geometry.
model.patch_embed.img_size = (AST_INPUT_SIZE, AST_INPUT_SIZE)  # Adjust expected input size
model.patch_embed.patch_size = (AST_PATCH_SIZE, AST_PATCH_SIZE)
model.patch_embed.grid_size = (grid_size, grid_size)
model.patch_embed.num_patches = grid_size * grid_size
model.default_cfg['img_size'] = AST_INPUT_SIZE

# Load pretrained AST weights.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on any machine;
# load_state_dict below copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ast_pretrained_weight = torch.load("audioset_16_16_0.4422.pth", map_location="cpu")

# Modify the model for AST input: a single-channel patch projection and a positional
# embedding sized for (patches + 1 CLS token). Both are freshly initialized here;
# they are NOT copied from the AST checkpoint.
model.patch_embed.proj = nn.Conv2d(
    1,
    embed_dim,
    kernel_size=(AST_PATCH_SIZE, AST_PATCH_SIZE),
    stride=(AST_PATCH_SIZE, AST_PATCH_SIZE),
)
num_tokens_ast = grid_size * grid_size + 1  # 256 patches + 1 CLS token
model.pos_embed = nn.Parameter(torch.randn(1, num_tokens_ast, embed_dim) * .02)

# Transfer weights from AST to ViT: CLS token, every transformer block, and the final norm.
# Checkpoint keys carry a "module.v." prefix (presumably DataParallel wrapping the AST's
# inner ViT "v" — confirm against the checkpoint if keys are missing).
v = model.state_dict()
v['cls_token'] = ast_pretrained_weight['module.v.cls_token']
for i in range(len(model.blocks)):  # Transfer transformer block weights
    for key in ['norm1', 'attn.qkv', 'attn.proj', 'norm2', 'mlp.fc1', 'mlp.fc2']:
        for param in ('weight', 'bias'):
            v[f'blocks.{i}.{key}.{param}'] = ast_pretrained_weight[f'module.v.blocks.{i}.{key}.{param}']
v['norm.weight'] = ast_pretrained_weight['module.v.norm.weight']
v['norm.bias'] = ast_pretrained_weight['module.v.norm.bias']
model.load_state_dict(v)

In [None]:
# Replace the pretrained classification head with one sized for the speech-command vocabulary.
num_speech_commands = 35  # presumably the 35 keywords of Speech Commands v0.02 — confirm against the data
model.head = nn.Linear(
    in_features=model.head.in_features,
    out_features=num_speech_commands,
)



In [None]:
# Cross-entropy objective; Adam updates every model parameter (with L2 weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

In [None]:
# Early-stopping bookkeeping: the best validation loss so far and how many
# consecutive epochs have passed without beating it (stop once this reaches `patience`).
best_val_loss = float('inf')
epochs_without_improvement = 0
patience = 5

In [None]:
# Training loop with tqdm for progress tracking
import csv
import matplotlib.pyplot as plt
from tqdm import tqdm

# (Re)create the metrics file so that it holds only the header row; per-epoch rows are appended later.
csv_filename = "training_metrics.csv"
with open(csv_filename, newline='', mode='w') as file:
    writer = csv.writer(file)
    writer.writerow(("Epoch", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"))

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer the GPU when present

In [None]:
num_epochs = 100
train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []

# ---------------------------------------------------------------------------
# Input pipeline.
# collate_fn yields (waveforms, sample_rates, labels, speaker_ids, utterance_numbers),
# with string labels and raw audio. The adapted model needs (batch, 1, 128, 128)
# spectrograms and integer class indices, so the conversion happens here, on `device`.
# ---------------------------------------------------------------------------
SAMPLE_RATE = 16000  # assumed rate of every clip; verified per batch in _prepare_batch
N_MELS = 128         # spectrogram height expected by the adapted patch embedding
TARGET_FRAMES = 128  # spectrogram width expected by the adapted patch embedding

# NOTE(review): 25 ms window / 10 ms hop at 16 kHz is an assumed front-end; tune as needed.
mel_transform = transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=1024, win_length=400, hop_length=160, n_mels=N_MELS
).to(device)
amplitude_to_db = transforms.AmplitudeToDB().to(device)

# Class vocabulary: the parent-folder names of the training clips
# (assumes each clip path ends in .../<word>/<clip>.wav).
class_names = sorted({os.path.basename(os.path.dirname(path)) for path in train_set._walker})
label_to_index = {name: index for index, name in enumerate(class_names)}
if len(class_names) != model.head.out_features:
    raise ValueError(
        f"Training split has {len(class_names)} classes but the model head outputs "
        f"{model.head.out_features}."
    )


def _prepare_batch(batch):
    """Convert one collate_fn batch into ``(spectrograms, targets)`` on ``device``.

    Returns:
        spectrograms: float tensor (batch, 1, N_MELS, TARGET_FRAMES), standardized per example.
        targets: long tensor (batch,) of indices into ``class_names``.

    Raises:
        ValueError: if any clip's sample rate differs from SAMPLE_RATE.
        KeyError: if a label is not part of the training vocabulary.
    """
    waveforms, sample_rates, labels, _speaker_ids, _utterance_numbers = batch
    if not bool((sample_rates == SAMPLE_RATE).all()):
        raise ValueError(f"Expected {SAMPLE_RATE} Hz audio, got {sample_rates.unique().tolist()}")
    # (batch, 1, samples) -> (batch, 1, N_MELS, frames), in dB.
    spectrograms = amplitude_to_db(mel_transform(waveforms.to(device)))
    # Per-example standardization keeps the patch-embedding inputs near unit scale.
    mean = spectrograms.mean(dim=(-2, -1), keepdim=True)
    std = spectrograms.std(dim=(-2, -1), keepdim=True)
    spectrograms = (spectrograms - mean) / (std + 1e-6)
    # Right-pad with 0 (the post-standardization mean) or crop the time axis to TARGET_FRAMES.
    frames = spectrograms.shape[-1]
    if frames < TARGET_FRAMES:
        spectrograms = torch.nn.functional.pad(spectrograms, (0, TARGET_FRAMES - frames))
    else:
        spectrograms = spectrograms[..., :TARGET_FRAMES]
    targets = torch.tensor(
        [label_to_index[label] for label in labels], dtype=torch.long, device=device
    )
    return spectrograms, targets


def evaluate_model(model, loader):
    """Return ``(mean loss, accuracy in %)`` of ``model`` over ``loader`` without updating it.

    An empty loader yields ``(nan, 0.0)``, which never counts as an improvement.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            spectrograms, targets = _prepare_batch(batch)
            outputs = model(spectrograms)
            loss_sum += criterion(outputs, targets).item() * targets.size(0)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    if total == 0:
        return float("nan"), 0.0
    return loss_sum / total, 100 * correct / total


# Parameters are moved in place, so the optimizer created earlier still tracks them.
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for batch in tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] Training", leave=False):
        spectrograms, labels = _prepare_batch(batch)
        optimizer.zero_grad()
        outputs = model(spectrograms)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item() * spectrograms.size(0)
    # Average over the samples actually seen this epoch.
    train_loss = running_loss / total if total else float("nan")
    train_accuracy = 100 * correct / total if total else 0.0
    val_loss, val_accuracy = evaluate_model(model, val_loader)

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    with open(csv_filename, mode='a', newline='') as file:
        # The writer must be bound to this open handle; the header cell's file is already closed.
        csv.writer(file).writerow([epoch + 1, train_loss, train_accuracy, val_loss, val_accuracy])

    tqdm.write(
        f"Epoch {epoch + 1}: train loss {train_loss:.4f}, acc {train_accuracy:.2f}% | "
        f"val loss {val_loss:.4f}, acc {val_accuracy:.2f}%"
    )

    if val_loss < best_val_loss:
        best_val_loss, epochs_without_improvement = val_loss, 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping after {epoch + 1} epochs (best val loss {best_val_loss:.4f}).")
            break

print("Training complete!")
