# 1. Introduction

This notebook outlines the creation, compilation, and training of a deep learning network for audio classification using the [TorchSuite](https://github.com/sergio-sanz-rodriguez/torchsuite) framework.
 
Reference tutorial: [Speech Command Classification with torchaudio](https://pytorch.org/tutorials/intermediate/speech_command_classification_with_torchaudio_tutorial.html)

# 2. Importing Libraries

In [None]:
#!pip install torcheval
import os
import torch
import glob
import random
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import IPython.display as ipd
import matplotlib.pyplot as plt
import pandas as pd
import torchaudio

from pathlib import Path
#from torchinfo import summary
from torch.optim.lr_scheduler import CosineAnnealingLR
#from torchaudio.transforms import Resample
from pydub import AudioSegment

# Import custom libraries
from utils.classification_utils import set_seeds, predict_and_play_audio, load_model
from engines.classification import ClassificationEngine
from engines.schedulers import FixedLRSchedulerWrapper
from dataloaders.audio_dataloaders import load_audio, create_dataloaders_waveform, PadWaveform, AudioWaveformTransforms
from models.wav2vec2 import Wav2Vec2Classifier
#from sklearn.utils.class_weight import compute_class_weight

import warnings
# Set the TORCH_USE_CUDA_DSA environment variable for this process
os.environ["TORCH_USE_CUDA_DSA"] = "1"

# Silence two specific (category, module) warning sources
for noisy_category, noisy_module in (
    (UserWarning, "torch.autograd.graph"),
    (FutureWarning, "onnxscript.converter"),
):
    warnings.filterwarnings("ignore", category=noisy_category, module=noisy_module)

# Input location (modify as needed); the training/validation folders are not used here
#TRAIN_DIR = Path("data/train")
#TEST_DIR = Path("data/validation")
INFERENCE_DIR = Path("train_soundscapes")

# Folder holding the model weights; created on the fly if it is missing
MODEL_DIR = Path("outputs")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# General-purpose constants
SEED = 42
NUM_WORKERS = os.cpu_count()
IMPORT_DATASET = False

# Make the run repeatable
set_seeds(SEED)

# 3. Specifying the Target Device

In [None]:
# Activate cuda benchmark
cudnn.benchmark = True

# Set device
device = "cpu" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

if device == "cuda":
    !nvidia-smi

device = 'cpu'

In [None]:
def load_model(
        model: torch.nn.Module,
        model_weights_dir: str,
        model_weights_name: str) -> torch.nn.Module:

    """
    Loads the weights of a PyTorch model from a target directory.

    NOTE: this definition shadows the `load_model` imported from
    `utils.classification_utils` at the top of the notebook.

    Args:
    model: A target PyTorch model whose parameters are overwritten in place.
    model_weights_dir: A directory where the model weights are located.
    model_weights_name: The file name of the weights to load.
      Should include either ".pth" or ".pt" as the file extension.

    Example usage:
    model = load_model(model=model,
                       model_weights_dir="models",
                       model_weights_name="05_going_modular_tinyvgg_model.pth")

    Returns:
    The same model instance, with the stored state dict loaded into it.

    Raises:
    AssertionError: if model_weights_name does not end with ".pt" or ".pth".
    """
    # Validate the extension before touching the file system
    assert model_weights_name.endswith((".pth", ".pt")), "model_name should end with '.pt' or '.pth'"

    # Create the model path
    model_path = Path(model_weights_dir) / model_weights_name

    # Load the model
    print(f"[INFO] Loading model from: {model_path}")

    # map_location="cpu" lets a checkpoint that was saved from a CUDA device be read on a
    # CPU-only machine; load_state_dict then copies the values into the model's existing
    # parameters, so the model stays on whatever device it already lives on.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    return model

# 4. Preparing Submission File

In [None]:
# Read the sample submission and discard the rows labelled 0, 1 and 2
sample_submission = pd.read_csv('sample_submission.csv')
df_submission = sample_submission.drop(index=[0, 1, 2])

# Every column after the first one is kept as a submission label
submission_labels = df_submission.columns[1:]
submission_labels

In [None]:
# Load the label metadata table
df_info = pd.read_csv('label_to_info.csv')

# Lookup table: value of the "index" column -> value of the "label" column
idx_to_label = {idx: label for idx, label in zip(df_info["index"], df_info["label"])}
idx_to_label

In [None]:
# --- Audio geometry ---
# The input sample rate is not fixed here; it is read from each file when it is loaded.
CHUNK_DURATION_SEC = 5                                 # length of one scored chunk, in seconds
NEW_SAMPLE_RATE = 8_000                                # Hz, rate handed to the waveform transform
TARGET_LENGTH = CHUNK_DURATION_SEC * NEW_SAMPLE_RATE   # samples in one 5-sec chunk at the new rate

# --- Model ---
NUM_CLASSES = len(df_info)                             # one class per row of the label table

# --- Remaining settings ---
# NOTE(review): these look like training hyper-parameters and are not referenced by the
# inference cells of this notebook - confirm whether they are still needed.
BATCH_SIZE = 64
ACCUM = 1
NUM_SAMPLES_PER_CLASS = 500
VAL_PERCENTAGE = 0.2
AUGMENT_MAGNITUDE = 2
LR = 1e-5
ETA_MIN = 1e-7
EPOCHS = 25

# 5. Making Predictions

In [None]:
# Build the classifier and restore its weights from MODEL_DIR
classifier = Wav2Vec2Classifier(num_classes=NUM_CLASSES)
model = load_model(
    model=classifier,
    model_weights_dir=MODEL_DIR,
    model_weights_name='model_wave_8khz_pauc_epoch8.pth',
)

# Place the model on the target device and switch it to evaluation mode
model = model.to(device).eval()

# Context manager entered around every forward pass to disable gradient tracking
inference_context = torch.no_grad()

In [None]:
# Rows collected across ALL audio files: [row_id, prob_for_label_1, prob_for_label_2, ...]
results = []

# Iterate over .ogg files (sorted so that the row order is the same on every run)
for audio_path in sorted(INFERENCE_DIR.glob("*.ogg")):

    # Load audio
    waveform, SAMPLE_RATE = torchaudio.load(audio_path)

    # Pre-processing: rebuilt per file because it depends on that file's sample rate,
    # and moved to the target device once per file rather than once per chunk
    transform = AudioWaveformTransforms(
        augmentation=False,
        sample_rate=SAMPLE_RATE,
        new_sample_rate=NEW_SAMPLE_RATE,
        target_length=TARGET_LENGTH
        ).to(device)

    # Compute the chunk duration in samples
    MAX_CHUNK_DURATION_SAMPLES = CHUNK_DURATION_SEC * SAMPLE_RATE

    # Ensure mono audio
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Number of whole chunks in the file. A trailing partial chunk is not scored, but a
    # file shorter than one chunk still yields a single chunk (zero-padded below).
    total_duration_sec = waveform.shape[1] / SAMPLE_RATE
    num_chunks = max(1, int(total_duration_sec // CHUNK_DURATION_SEC))

    # Filename without the .ogg extension, used as the row_id prefix
    file_id = audio_path.stem

    # Process each chunk
    for i in range(num_chunks):
        start_chunk = i * MAX_CHUNK_DURATION_SAMPLES
        end_chunk = start_chunk + MAX_CHUNK_DURATION_SAMPLES
        chunk = waveform[:, int(start_chunk):int(end_chunk)]

        # Zero-pad on the right if the chunk is shorter than the desired duration
        if chunk.shape[1] < MAX_CHUNK_DURATION_SAMPLES:
            pad_size = MAX_CHUNK_DURATION_SAMPLES - chunk.shape[1]
            chunk = F.pad(chunk, (0, pad_size))

        chunk = chunk.to(device)

        with inference_context:
            chunk = transform(chunk)
            probs = torch.softmax(model(chunk), dim=1).cpu().numpy().flatten()

        # Generate row_id: filename without .ogg + _<chunk_end_time>
        row_id = f"{file_id}_{(i + 1) * CHUNK_DURATION_SEC}"

        # Re-order the model outputs to the submission column order; a submission label
        # the model does not know gets probability 0.0
        label_to_prob = {str(idx_to_label[j]): probs[j] for j in range(len(probs))}
        ordered_probs = [label_to_prob.get(label, 0.0) for label in submission_labels]
        results.append([row_id] + ordered_probs)

# Convert to dataframe once, after every file has been processed, so that the table
# holds the rows of all files (and exists even when no .ogg file was found)
submission_df = pd.DataFrame(results, columns=["row_id"] + list(submission_labels))
display(submission_df)

# Save to CSV
#submission_df.to_csv("submission.csv", index=False)

In [None]:
# Show the submission table built by the prediction cell above
submission_df