# Settings & Configurations

In [None]:


# Google Cloud / Colab
from google.cloud import storage
from google.colab import auth

# Data
import pandas as pd
import numpy as np

# Standard library
import os
import math

# Models
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model, Wav2Vec2Processor
# Note the capital "L": torch.utils.data has no `Dataloader`, so that spelling
# raises ImportError and stops the notebook at its first cell.
from torch.utils.data import Dataset, DataLoader
import torchaudio

# Utilities
from tqdm import tqdm
from pydub import AudioSegment

# Sign this Colab session in with the user's Google account before any
# Cloud Storage access below.
auth.authenticate_user()

# Google Storage

In [None]:
# ------------------- Connect to Google Cloud Storage -------------------

client = storage.Client(project='chulaworks')
bucket = client.bucket('dwn-chula')
blobs = bucket.list_blobs(prefix='datasets/ICNALE/SM')

# ------------------- Overview of the files in the bucket -------------------

# Count objects per file extension, and count names whose stem ends in ")"
# (e.g. "clip (1).mp3"), which look like duplicated uploads.
extensions = {}
duped_count = 0

for blob in blobs:
    blob_name = blob.name
    name, ext = os.path.splitext(blob_name)
    extensions[ext] = extensions.get(ext, 0) + 1
    if name[-1] == ")":
        duped_count += 1

for ext, count in extensions.items():
    print(f"{ext}\t:\t{count}")
print("-----------------------------------------------------------------------")
print(f"Total duped files\t:\t{duped_count}")

# Preprocessing

In [None]:
# ------------------- Utilities -------------------

# "facebook/wav2vec2-base" is a pretrained-only checkpoint: its model card notes
# that it has no tokenizer, so Wav2Vec2Processor (feature extractor + tokenizer)
# cannot be loaded from it. Only the audio feature-extraction call is used here
# (waveform + sampling_rate -> model inputs), so load the feature extractor
# directly. The variable keeps its old name so existing callers still work.
from transformers import Wav2Vec2FeatureExtractor

wav2vec_processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

def dig_folder(file):
    """Return every non-directory path at or below ``file``.

    Directories are expanded recursively, in ``os.listdir`` order. Any other
    path, including one that does not exist, comes back as a one-item list.
    """
    if not os.path.isdir(file):
        return [file]
    found = []
    for entry in os.listdir(file):
        found += dig_folder(os.path.join(file, entry))
    return found

def mp3_to_tensor(mp3_path, frame_rate=16_000):
    """Decode an MP3 file into a mono waveform resampled to ``frame_rate`` Hz.

    The clip is decoded with pydub, converted to mono at ``frame_rate``,
    written to a private temporary WAV file and read back with torchaudio.

    Args:
        mp3_path: path of the MP3 file to decode.
        frame_rate: target sample rate in Hz.

    Returns:
        ``(waveform, sample_rate)`` exactly as returned by ``torchaudio.load``.
    """
    import tempfile

    audio = AudioSegment.from_mp3(mp3_path)
    audio = audio.set_frame_rate(frame_rate).set_channels(1)

    # Use a unique temp file per call. A fixed "temp.wav" in the working
    # directory is shared by every concurrent caller (e.g. DataLoader worker
    # processes), so one could overwrite or delete another's file mid-read.
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # pydub's export() returns the opened output file object; close it so
        # a file descriptor is not leaked for every decoded clip.
        audio.export(wav_path, format="wav").close()
        waveform, sample_rate = torchaudio.load(wav_path)
    finally:
        # Remove the temp file even if decoding/loading fails.
        os.remove(wav_path)
    return waveform, sample_rate

def download_blobs(bucket, prefix, destination_folder):
    """Download every object under ``prefix`` into ``destination_folder``.

    The layout below ``prefix`` is mirrored locally. Two kinds of object are
    skipped:

    * "folder" placeholder objects, whose names end in ``/``. There is nothing
      to download, and writing them as files would clash with real
      sub-directories.
    * objects whose path relative to ``prefix`` would land outside
      ``destination_folder``. Prefix listing is a plain string-prefix match, so
      ``prefix="x/audio"`` also returns ``x/audio_old/...``.

    Args:
        bucket: object exposing ``list_blobs(prefix=...)``, whose items have
            ``.name`` and ``.download_to_filename(path)`` (e.g. a
            ``google.cloud.storage.Bucket``).
        prefix: object-name prefix to download.
        destination_folder: local directory; created if it does not exist.
    """
    os.makedirs(destination_folder, exist_ok=True)
    for blob in bucket.list_blobs(prefix=prefix):
        if blob.name.endswith("/"):
            continue
        relative = os.path.relpath(blob.name, prefix)
        escapes = (
            relative in (os.curdir, os.pardir)
            or relative.startswith(os.pardir + os.sep)
        )
        if escapes:
            print(f"Skipped {blob.name} (not below {prefix})")
            continue
        local_path = os.path.join(destination_folder, relative)
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        blob.download_to_filename(local_path)
        print(f"Downloaded {blob.name} to {local_path}")

def create_data_config(prefix):
    """Index every file under the local folder ``prefix`` into a DataFrame.

    Columns:
        path:  each path returned by ``dig_folder(prefix)``.
        label: the first four characters of the last ``_``-separated token of
               the file's basename.

    NOTE(review): every file found is included, not only audio. The last token
    still carries the file extension, so for example a basename ending in
    ``_B1_1.mp3`` would be labelled ``"1.mp"``. Confirm this rule matches the
    ICNALE naming scheme and the intended classes.
    """
    paths = dig_folder(prefix)
    labels = []
    for file_path in paths:
        last_token = os.path.basename(file_path).split('_')[-1]
        labels.append(last_token[:4])
    return pd.DataFrame({'path': paths, 'label': labels})

# ------------------- Dataset -------------------

class ICNALE_SM_Dataset(Dataset):
    """Map-style dataset over ICNALE SM audio files and their string labels.

    Args:
        data_config: DataFrame with ``path`` and ``label`` columns, such as the
            one built by ``create_data_config``.

    Each item is ``(features, label)``. ``features`` is whatever
    ``wav2vec_processor`` returns for the decoded clip, and ``label`` is the
    raw string label.

    NOTE(review): ``run_epoch`` expects dict batches holding "input_values",
    "attention_mask" and integer "labels" tensors. Items in this format need a
    collate_fn that pads waveforms and encodes labels before they can be used
    with it.
    """

    def __init__(self, data_config):
        # DataFrame.iterrows() yields (index, row) pairs, so indexing each pair
        # by column name fails. Zipping the two columns directly avoids that
        # and is much faster than iterating rows.
        self.samples = list(zip(data_config['path'], data_config['label']))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        waveform, sample_rate = mp3_to_tensor(path)
        # This is the processor's model input, not an embedding: the wav2vec2
        # encoder itself runs later, inside the model.
        features = wav2vec_processor(
            waveform.squeeze().numpy(),
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        return (features, label)

# ------------------- Preprocess -------------------

# Mirror the ICNALE SM audio folder from the bucket to local disk, then index it.
gcloud_audio_prefix = 'datasets/ICNALE/SM/ICNALE_SM_Audio'
audio_path = 'dataset/audio'

download_blobs(bucket, gcloud_audio_prefix, audio_path)

data_config = create_data_config(audio_path)  # columns: path, label
dataset = ICNALE_SM_Dataset(data_config)


# Pipeline

In [None]:
# ------------------- Models -------------------

class MeanPooler(nn.Module):
    """Masked mean over the time axis (dim 1).

    Only positions where ``mask`` is non-zero contribute. Rows whose mask is
    all zero yield zeros rather than NaN, because the count is clamped.
    """

    def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, T, H); mask: (B, T), sharing B and T with hidden_states.
        weights = mask.float()[..., None]                    # (B, T, 1)
        denom = torch.clamp(weights.sum(dim=1), min=1e-9)    # (B, 1)
        return (hidden_states * weights).sum(dim=1) / denom  # (B, H)

class PrototypicalClassifier(nn.Module):
    """Distance-based classifier with ``k`` learnable prototypes per class.

    Logits are negative squared Euclidean distances to the prototypes. They are
    reduced over each class's ``k`` prototypes and divided by a learnable
    temperature ``exp(log_tau)``, which starts at 1.

    Args:
        embed_dim: dimensionality D of the input embeddings.
        num_classes: number of classes C.
        k: number of prototypes per class K.
        reduction: how the K per-class distances are combined.
            ``"mean"`` (default, original behaviour) averages them. The mean of
            squared distances equals the squared distance to the prototypes'
            centroid plus a per-class constant, so with ``"mean"`` the K
            prototypes act like one centroid plus a bias.
            ``"min"`` uses the nearest prototype, so a class can cover several
            separate modes.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "min")

    def __init__(self, embed_dim: int, num_classes: int, k: int = 3, reduction: str = "mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.k = k
        self.num_classes = num_classes
        self.reduction = reduction
        # Class-major layout: rows c*k .. c*k + k - 1 belong to class c, which
        # matches the view(B, C, K) in forward().
        self.prototypes = nn.Parameter(
            torch.randn(num_classes * k, embed_dim) / math.sqrt(embed_dim)
        )
        self.log_tau = nn.Parameter(torch.zeros(()))  # log of the temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings ``x`` of shape (B, D) to logits of shape (B, C)."""
        dists = torch.cdist(x, self.prototypes, p=2) ** 2         # (B, C*K), squared Euclidean
        dists = dists.view(x.size(0), self.num_classes, self.k)   # (B, C, K)
        if self.reduction == "min":
            dists = dists.amin(dim=2)                             # (B, C)
        else:
            dists = dists.mean(dim=2)                             # (B, C)
        return -dists / torch.exp(self.log_tau)

class SpeechModel(nn.Module):
    """wav2vec2 encoder -> masked mean pooling -> MLP -> prototype classifier.

    Args:
        num_classes: number of target classes.
        k: number of prototypes per class in the metric head.
    """

    def __init__(self, num_classes: int, k: int = 3):
        super().__init__()
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        hidden_size = self.encoder.config.hidden_size
        self.pooler = MeanPooler()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.LayerNorm(256),
        )
        self.metric_head = PrototypicalClassifier(embed_dim=256, num_classes=num_classes, k=k)

    def forward(self, input_values: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (B, num_classes).

        Args:
            input_values: raw-audio model inputs, presumably (B, num_samples).
            attention_mask: mask over ``input_values`` (same sample axis).
        """
        # NOTE(review): the transformers docs advise not passing attention_mask
        # to wav2vec2-base (group-norm feature encoder); instead, zero-pad the
        # inputs. Confirm whether the mask should reach the encoder here.
        out = self.encoder(input_values=input_values, attention_mask=attention_mask)
        hidden = out.last_hidden_state  # (B, num_frames, hidden_size)
        # The convolutional feature encoder downsamples audio samples into
        # frames, so the sample-level mask does not line up with `hidden`.
        # Convert it to a frame-level mask first; this is the same helper that
        # Wav2Vec2ForSequenceClassification uses for its mean pooling.
        frame_mask = self.encoder._get_feature_vector_attention_mask(
            hidden.shape[1], attention_mask
        )
        pooled = self.pooler(hidden, frame_mask)
        z = self.mlp(pooled)
        return self.metric_head(z)

# ------------------- Training Loop -------------------

def run_epoch(model, loader, criterion, optimiser=None, scaler=None):
    """Run one pass over ``loader``: train if ``optimiser`` is given, else evaluate.

    Args:
        model: module called as ``model(input_values, attention_mask)`` that
            returns logits with classes on dim 1.
        loader: iterable of dict batches with "input_values", "attention_mask"
            and "labels" entries.
        criterion: ``criterion(logits, labels)`` returning a scalar loss. It is
            assumed to be a per-sample mean, since it is re-weighted by batch
            size below.
        optimiser: if given, the model is put in train mode and updated on
            every batch. Otherwise the model is put in eval mode and autograd
            is disabled.
        scaler: optional ``GradScaler``. When given, forward passes run under
            CUDA autocast and the backward pass goes through the scaler.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Both are NaN if
        ``loader`` yields no samples.
    """
    is_train = optimiser is not None
    model.train(is_train)
    device = next(model.parameters()).device
    use_amp = scaler is not None
    total_loss, correct, n = 0.0, 0, 0

    # Without this, evaluation would build autograd graphs it never uses.
    with torch.set_grad_enabled(is_train):
        for batch in loader:
            # Only tensors can be moved; anything else passes through unchanged.
            batch = {
                key: value.to(device) if torch.is_tensor(value) else value
                for key, value in batch.items()
            }
            labels = batch["labels"]
            with torch.autocast(device_type="cuda", enabled=use_amp):
                logits = model(batch["input_values"], batch["attention_mask"])
                loss = criterion(logits, labels)
            if is_train:
                optimiser.zero_grad()
                if use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimiser)
                    scaler.update()
                else:
                    loss.backward()
                    optimiser.step()
            preds = logits.argmax(dim=1)
            batch_size = preds.size(0)
            total_loss += loss.item() * batch_size
            correct += (preds == labels).sum().item()
            n += batch_size

    if n == 0:
        return float("nan"), float("nan")
    return total_loss / n, correct / n

# Evaluation