# 1. Introduction

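This notebook outlines the creation, compilation, and training of a deep learning network for audio classification using the [TorchSuite](https://github.com/sergio-sanz-rodriguez/torchsuite) framework. Each one-second clip of the Speech Commands dataset is converted into a spectrogram image and classified with a Vision Transformer.

Reference: [PyTorch tutorial: Speech Command Classification with torchaudio](https://pytorch.org/tutorials/intermediate/speech_command_classification_with_torchaudio_tutorial.html)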

# 2. Importing Libraries

In [None]:
#!pip install torcheval
import os
import shutil
import torch
import glob
import random
import librosa
import torch.backends.cudnn as cudnn
import torch.nn as nn
import IPython.display as ipd
import matplotlib.pyplot as plt
import pandas as pd
import torchaudio

from pathlib import Path
from torchinfo import summary
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchaudio.datasets import SPEECHCOMMANDS


# Import custom libraries
from utils.classification_utils import set_seeds, predict_and_play_audio, load_model
from engines.classification import ClassificationEngine
from engines.schedulers import FixedLRSchedulerWrapper
from models.vision_transformer import create_vit
from dataloaders.audio_dataloaders import load_audio, create_dataloaders_spectrogram, AudioSpectrogramTransforms

import warnings

# Runtime configuration
# NOTE(review): TORCH_USE_CUDA_DSA looks like a PyTorch build-time option, and this line runs
# after `import torch`; it may have no effect here — confirm before relying on it.
os.environ['TORCH_USE_CUDA_DSA'] = "1"
# Hide known-noisy warnings coming from torch.autograd.graph and onnxscript.converter
warnings.filterwarnings("ignore", category=UserWarning, module="torch.autograd.graph")
warnings.filterwarnings("ignore", category=FutureWarning, module="onnxscript.converter")

# Data locations (modify as needed):
#   TARGET_DIR_NAME      -> raw Speech Commands v0.02 download
#   TRAIN_DIR / TEST_DIR -> class-per-folder split created in section 4
TARGET_DIR_NAME = Path("data/SpeechCommands/speech_commands_v0.02")
TRAIN_DIR = Path("data/SpeechCommands/train")
TEST_DIR = Path("data/SpeechCommands/test")

# Switch to True to download and re-split the dataset (section 4)
IMPORT_DATASET = False

# General constants
SEED = 42
NUM_WORKERS = os.cpu_count()

# Trained checkpoints are written to (and later read from) this folder
MODEL_DIR = Path("outputs")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Seed the random number generators through the project helper
set_seeds(SEED)

# 3. Specifying the Target Device

In [None]:
# Activate cuda benchmark
cudnn.benchmark = True

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

if device == "cuda":
    !nvidia-smi

# 4. Importing Dataset

In [None]:
if IMPORT_DATASET:
    # Fetch Speech Commands v0.02 through torchaudio into ./data; the next cell
    # expects the extracted class folders under TARGET_DIR_NAME
    os.makedirs("data", exist_ok=True)
    download_kwargs = dict(
        root="./data",
        url="speech_commands_v0.02",
        folder_in_archive="SpeechCommands",
        download=True,
        subset=None,
    )
    dataset = SPEECHCOMMANDS(**download_kwargs)

In [None]:
if IMPORT_DATASET:
    # Build a train/test split from the official lists: every file listed in
    # validation_list.txt or testing_list.txt (entries look like "<class>/<file>.wav")
    # is copied to TEST_DIR, every other .wav file is copied to TRAIN_DIR.
    val_test_files = set()
    for filename in ["validation_list.txt", "testing_list.txt"]:
        with open(os.path.join(TARGET_DIR_NAME, filename), "r") as f:
            val_test_files.update(f.read().splitlines())

    # Ensure output directories exist
    os.makedirs(TRAIN_DIR, exist_ok=True)
    os.makedirs(TEST_DIR, exist_ok=True)

    # The background-noise folder is not a keyword class and is not needed in this
    # notebook, so it is skipped here instead of being copied and deleted afterwards.
    background_noise_name = "_background_noise_"

    # Loop over all class folders (sorted for a deterministic processing order)
    for class_name in sorted(os.listdir(TARGET_DIR_NAME)):
        class_path = os.path.join(TARGET_DIR_NAME, class_name)
        if class_name == background_noise_name or not os.path.isdir(class_path):
            continue

        # Create class folders in train/ and test/
        os.makedirs(os.path.join(TRAIN_DIR, class_name), exist_ok=True)
        os.makedirs(os.path.join(TEST_DIR, class_name), exist_ok=True)

        # Copy every .wav file of the class into the split it belongs to
        for file_name in os.listdir(class_path):
            if not file_name.endswith(".wav"):
                continue

            src_path = os.path.join(class_path, file_name)
            # List entries always use "/" as separator, independent of the OS
            dest_folder = TEST_DIR if f"{class_name}/{file_name}" in val_test_files else TRAIN_DIR
            dest_path = os.path.join(dest_folder, class_name)
            shutil.copy(src_path, dest_path)

    # Remove background-noise folders if they are present in the split folders
    background_noise_train = TRAIN_DIR / background_noise_name
    background_noise_test = TEST_DIR / background_noise_name
    for noise_dir in (background_noise_train, background_noise_test):
        if noise_dir.exists():
            shutil.rmtree(noise_dir)

    # Remove the raw download and its archive; only the train/test split is kept
    if TARGET_DIR_NAME.exists():
        shutil.rmtree(TARGET_DIR_NAME)

    zip_file = Path("data/speech_commands_v0.02.tar.gz")
    if zip_file.exists():
        os.remove(zip_file)

    print("Dataset restructuring completed!")

# 5. Preparing Dataloaders

In [None]:
# Audio front-end settings
new_sample_rate = 8000   # target sample rate passed to the spectrogram transform
target_length = 8000     # samples per clip: 1 s at new_sample_rate
IMG_SIZE = 384           # side of the square spectrogram image
BATCH_SIZE = 32
ACCUM_STEPS = 2          # gradient-accumulation steps used in the training cell
FFT_POINTS = 1024

# NOTE(review): hop_length is not passed to any transform in this notebook; kept for reference
hop_length = round(target_length / (IMG_SIZE - 1))

# Read one training clip only to learn the dataset's native sample rate. The path is
# built from TRAIN_DIR so it follows the path configuration at the top of the notebook.
waveform, sample_rate = load_audio(str(TRAIN_DIR / "backward" / "0a2b400e_nohash_0.wav"))


def make_spectrogram_transform(mean_std_norm=True):
    """Return an AudioSpectrogramTransforms configured with this cell's settings.

    Train and test currently use identical settings (augmentation is off for both),
    so a single factory keeps the two pipelines from drifting apart.

    Args:
        mean_std_norm: forwarded unchanged to AudioSpectrogramTransforms.
    """
    return AudioSpectrogramTransforms(
        augmentation=False,
        mean_std_norm=mean_std_norm,
        fft_analysis_method="single",  # other options: "time_freq", "freq_band"
        sample_rate=sample_rate,
        new_sample_rate=new_sample_rate,
        target_length=target_length,
        n_fft=FFT_POINTS,
        img_size=(IMG_SIZE, IMG_SIZE),
        augment_magnitude=2,  # presumably only used when augmentation=True
    )


# Transformations for training and test datasets
get_transform_train = make_spectrogram_transform()
get_transform_test = make_spectrogram_transform()

# Create dataloaders
train_dataloader, test_dataloader, class_names = create_dataloaders_spectrogram(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=get_transform_train,
    test_transform=get_transform_test,
    batch_size=BATCH_SIZE,
    num_workers=0,
    random_seed=SEED
)

# Verify classes and batches
print(f"Classes: {class_names}")
print(f"Train batches: {len(train_dataloader)}, Test batches: {len(test_dataloader)}")

In [None]:
# The classifier head needs one output per class discovered by the dataloader
NUM_CLASSES = len(class_names)
print(f"Number of classes: {NUM_CLASSES}")

# 6. Audio Visualization and Playback

In [None]:
# Plot the waveform and spectrogram of randomly chosen training clips
train_dataset = train_dataloader.dataset
train_set_size = len(train_dataset)

num_samples = 10
# squeeze=False keeps `axs` two-dimensional, so axs[row] works even for num_samples == 1
fig, axs = plt.subplots(num_samples, 2, figsize=(15, num_samples*3), squeeze=False)

# The display transform is the same for every clip, so build it once outside the loop.
# mean_std_norm=False presumably leaves the spectrogram unnormalized for plotting.
audio_spectrogram_transforms = AudioSpectrogramTransforms(
    augmentation=False,
    mean_std_norm=False,
    fft_analysis_method="single",  # other options: "time_freq", "freq_band"
    sample_rate=sample_rate,
    new_sample_rate=new_sample_rate,
    target_length=target_length,
    n_fft=FFT_POINTS,
    img_size=(IMG_SIZE, IMG_SIZE),
    augment_magnitude=2
)

for row in range(num_samples):
    # Randomly select an index from the training set
    idx = torch.randint(0, train_set_size, (1,)).item()

    # Load the waveform and compute its spectrogram
    waveform, _ = load_audio(train_dataset.files[idx])
    spectrogram = audio_spectrogram_transforms(waveform)

    # Class index and class name of the clip
    label_idx = train_dataset.labels[idx]
    label = class_names[label_idx]

    wave_ax, spec_ax = axs[row]

    # Plot waveform (transposed so that time runs along the x axis)
    wave_ax.plot(waveform.t().numpy())
    wave_ax.set_title(f"Waveform - Label: {label} - Idx: {label_idx}")
    wave_ax.set_xlabel("Time")
    wave_ax.set_ylabel("Amplitude")
    wave_ax.set_xticks([])
    wave_ax.set_yticks([])

    # Plot spectrogram (channels moved last for imshow)
    spec_ax.imshow(spectrogram.permute(1, 2, 0).detach().numpy(), aspect='auto', origin='lower', cmap='magma')
    spec_ax.set_title(f"Spectrogram - Label: {label} - Idx: {label_idx}")
    spec_ax.set_xlabel("Time")
    spec_ax.set_ylabel("Frequency")
    spec_ax.set_xticks([])
    spec_ax.set_yticks([])

fig.tight_layout()
plt.show()

In [None]:
# Play out some audio files
try:
    waveform_first, _ = load_audio(train_dataloader.dataset.files[0])
except:
    waveform_first, _ = load_audio(train_dataloader.dataset.dataset.files[0])
ipd.Audio(waveform_first.numpy(), rate=sample_rate)

In [None]:
# Play back the second training clip (same dataset lookup as the previous cell)
try:
    second_file = train_dataloader.dataset.files[1]
except AttributeError:
    second_file = train_dataloader.dataset.dataset.files[1]

# Use the sample rate returned for this clip rather than a global from another cell
waveform_second, sample_rate_second = load_audio(second_file)
ipd.Audio(waveform_second.numpy(), rate=sample_rate_second)

# 7. Creating the Transformer Model

In [None]:
# Build the ViT classifier ("vitbase16_2" variant) with one output per class
model = create_vit(
    vit_model="vitbase16_2",
    num_classes=NUM_CLASSES,
    dropout=0.1,
    seed=SEED,
    device=device,
)

# Fine-tune the whole network: mark every parameter as trainable
model.requires_grad_(True)

# Compile with the "aot_eager" backend, then place the model on the target device
model = torch.compile(model, backend="aot_eager")
model.to(device)

# 8. Training the Model

In [None]:
# Training hyperparameters
EPOCHS = 20
LR = 1e-4
ETA_MIN = 1e-6
model_type = "model_spectrogram"
model_name = model_type + ".pth"

# Create AdamW optimizer
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR,
    weight_decay=0.0001
)

# Create loss function
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# Cosine-anneal the learning rate from LR towards ETA_MIN. T_max and fixed_epoch follow
# EPOCHS so the schedule stays aligned with the run length if EPOCHS is changed.
# The wrapper presumably pins the LR to `fixed_lr` from `fixed_epoch` on — confirm in
# engines.schedulers.
scheduler = FixedLRSchedulerWrapper(
    scheduler=CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=ETA_MIN),
    fixed_lr=ETA_MIN,
    fixed_epoch=EPOCHS)

# Set seeds
set_seeds(SEED)

# Instantiate the classification engine with the created model and the target device
engine = ClassificationEngine(
    model=model,
    log_verbose=True,
    device=device)

# Configure the training method
results = engine.train(
    target_dir=MODEL_DIR,                       # Directory where the model will be saved
    model_name=model_name,                      # Name of the model
    save_best_model=["loss", "acc", "pauc"],    # Save the best models based on different criteria
    keep_best_models_in_memory=False,           # Do not keep the models in memory, for training time and memory efficiency
    train_dataloader=train_dataloader,          # Train dataloader
    test_dataloader=test_dataloader,            # Validation/test dataloader
    apply_validation=True,                      # Enable validation step
    num_classes=NUM_CLASSES,                    # Number of classes
    optimizer=optimizer,                        # Optimizer
    loss_fn=loss_fn,                            # Loss function
    recall_threshold=1.0,                       # False positive rate at recall_threshold recall
    recall_threshold_pauc=0.0,                  # Partial AUC score above recall_threshold_pauc recall
    scheduler=scheduler,                        # Scheduler
    epochs=EPOCHS,                              # Total number of epochs
    amp=True,                                   # Enable Automatic Mixed Precision (AMP)
    enable_clipping=False,                      # Disable gradient clipping; only useful if training becomes unstable
    debug_mode=False,                           # Disable debug mode
    accumulation_steps=ACCUM_STEPS              # Effective batch size = BATCH_SIZE x ACCUM_STEPS
    )

In [None]:
# Rebuild the architecture and load the best-accuracy checkpoint into a separate
# engine (`engine2`), which the evaluation cells below use for predictions
if device == "cuda":
    torch.cuda.empty_cache()

model = create_vit(
    vit_model="vitbase16_2",
    num_classes=NUM_CLASSES,
    dropout=0.1,
    seed=SEED,
    device=device
    )

# Compile model (same backend as during training)
model = torch.compile(model, backend="aot_eager")

# Locate the checkpoint matching `model_spectrogram_acc_*.pth` (presumably the
# engine's best-accuracy checkpoint). Sorting makes the choice deterministic.
model_file = sorted(glob.glob(os.path.join(MODEL_DIR, "model_spectrogram_acc_*.pth")))
if not model_file:
    raise FileNotFoundError(
        f"No checkpoint matching 'model_spectrogram_acc_*.pth' found in {MODEL_DIR}; "
        "run the training cell first.")
if len(model_file) > 1:
    warnings.warn(f"Several checkpoints match; using {os.path.basename(model_file[0])}")
model_name = os.path.basename(model_file[0])

# Instantiate engine for predictions and load the checkpoint weights
engine2 = ClassificationEngine(
        model=model,
        log_verbose=True,
        device=device)

engine2.load(target_dir=MODEL_DIR, model_name=model_name)

In [None]:
# Load the best-accuracy checkpoint again, this time directly into `model`
if device == "cuda":
    torch.cuda.empty_cache()

# Locate the checkpoint matching `model_spectrogram_acc_*.pth` (presumably the
# engine's best-accuracy checkpoint). Sorting makes the choice deterministic.
model_file = sorted(glob.glob(os.path.join(MODEL_DIR, "model_spectrogram_acc_*.pth")))
if not model_file:
    raise FileNotFoundError(
        f"No checkpoint matching 'model_spectrogram_acc_*.pth' found in {MODEL_DIR}; "
        "run the training cell first.")
model_name = os.path.basename(model_file[0])

# Same architecture and seed as used for training
model = create_vit(
    vit_model="vitbase16_2",
    num_classes=NUM_CLASSES,
    dropout=0.1,
    seed=SEED,
    device=device
    )

model = torch.compile(model, backend="aot_eager")
model = load_model(model, MODEL_DIR, model_name)
model.to(device)

In [None]:
# Predict random test clips and play them back next to their true labels
test_dataset = test_dataloader.dataset

# NOTE(review): earlier code fell back to `.dataset.dataset`, i.e. the loader's dataset may
# wrap the object that owns `files`/`labels` (e.g. a torch Subset). A position in the wrapper
# must then be translated through its `indices`; indexing the inner dataset directly with
# the wrapper position could pick clips that are not part of the wrapped split.
if hasattr(test_dataset, "files"):
    base_test_dataset, position_to_index = test_dataset, None
else:
    base_test_dataset = test_dataset.dataset
    position_to_index = getattr(test_dataset, "indices", None)

# Up to 24 random positions (never more than the test set holds)
num_samples = min(24, len(test_dataset))
random_indices = random.sample(range(len(test_dataset)), num_samples)

# Load audio files and their labels
waveform_list = []
label_list = []
sample_rate_list = []
for idx in random_indices:
    file_idx = idx if position_to_index is None else position_to_index[idx]
    waveform, sample_rate = load_audio(base_test_dataset.files[file_idx])
    actual_label = class_names[base_test_dataset.labels[file_idx]]

    waveform_list.append(waveform)
    label_list.append(actual_label)
    sample_rate_list.append(sample_rate)

# Predict and play back
predict_and_play_audio(
    model=model,
    waveform_list=waveform_list,
    label_list=label_list,
    sample_rate_list=sample_rate_list,
    class_names=class_names,
    transform=get_transform_test,
    device=device
)

In [None]:
# Run engine2 (best-accuracy checkpoint) over TEST_DIR with the test-time transform and
# collect the per-prediction records together with the classification report.
# sample_fraction=1 presumably means every test file is used.
pred_list, classif_report = engine2.predict_and_store(
    test_dir=TEST_DIR,
    class_names=class_names,
    transform=get_transform_test,
    seed=SEED,
    sample_fraction=1,
)

In [None]:
# Tabulate the classification report for rich notebook display
report_df = pd.DataFrame(classif_report)
report_df

In [None]:
# Throughput = 1 / mean prediction time; assumes `time_for_pred` is seconds per prediction
mean_pred_time = pd.DataFrame(pred_list)['time_for_pred'].mean()
speed = round(1.0 / mean_pred_time, 2)

# Label the measurement with the hardware it actually ran on (was always "GPU")
hardware = "GPU" if device == "cuda" else "CPU"
print(f'{hardware}: Predicted Images per Sec [fps]: {speed}')