# Settings & Configurations

In [None]:
# Fetch the CEFR label table and the requirements list from the project
# repository, then install the listed packages into this runtime.
!wget https://raw.githubusercontent.com/tanntnny/culi-scoring/main/assets/cefr_label.csv
!wget https://raw.githubusercontent.com/tanntnny/culi-scoring/main/requirements.txt
!pip install -r requirements.txt

from google.cloud import storage
from google.colab import auth
import pandas as pd
import numpy as np
import os
import math

# Models
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model, Wav2Vec2Processor, get_cosine_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import torchaudio
from sklearn.model_selection import train_test_split

# Utilities
from tqdm import tqdm
from pydub import AudioSegment

# Authenticate the Colab user; presumably required before the Cloud Storage
# client calls later in this notebook can access the bucket.
auth.authenticate_user()

# Google Storage

In [None]:
# ------------------- Connect the google cloud storage -------------------

# Client, bucket handle and an iterator over every object under the prefix.
client = storage.Client(project='chulaworks')
bucket = client.bucket('dwn-chula')
blobs = bucket.list_blobs(prefix='datasets/ICNALE/SM')

# ------------------- Overview the files in the bucket -------------------

extensions = {}   # file extension -> number of objects carrying it
duped_count = 0   # objects whose name (without extension) ends in ")"

for blob in blobs:
    stem, suffix = os.path.splitext(blob.name)
    # Tally by extension.
    extensions[suffix] = extensions.get(suffix, 0) + 1
    # A trailing ")" on the stem is counted as a duplicated upload.
    if stem[-1] == ")":
        duped_count += 1

for ext, count in extensions.items():
    print(f"{ext}\t:\t{count}")
print("-----------------------------------------------------------------------")
print(f"Total duped files\t:\t{duped_count}")

# Preprocessing

In [None]:
# ------------------- Utilities -------------------

# Processor loaded from the "facebook/wav2vec2-base" checkpoint; collate_fn
# calls it to turn raw waveforms into padded model inputs.
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
# Label table fetched during setup; the code below reads its "CEFR Level"
# and "label" columns.
cefr_label = pd.read_csv("cefr_label.csv")

# Recursively find files in a folder tree
def dig_folder(file):
    """Return every file path found under ``file``, descending into sub-folders.

    Args:
        file: Path to a directory or to a single file.

    Returns:
        A list of paths. A directory contributes the files beneath it, each
        joined onto ``file``. Any other path (including one that does not
        exist) is returned unchanged as a single-element list.
    """
    if not os.path.isdir(file):
        return [file]
    found = []
    # os.listdir() returns entries in arbitrary, filesystem-dependent order.
    # Sorting keeps the result -- and any seeded split built from it --
    # reproducible across machines and runs.
    for entry in sorted(os.listdir(file)):
        found.extend(dig_folder(os.path.join(file, entry)))
    return found

# Convert MP3 to tensor
def mp3_to_tensor(mp3_path, frame_rate=16_000):
    """Decode an MP3 file into a mono waveform resampled to ``frame_rate``.

    The audio is decoded with pydub, converted to a single channel at
    ``frame_rate`` Hz, written out as a temporary WAV file and read back
    with ``torchaudio.load``.

    Args:
        mp3_path: Path of the MP3 file to decode.
        frame_rate: Target sampling rate in Hz.

    Returns:
        The ``(waveform, sample_rate)`` pair produced by ``torchaudio.load``.
    """
    import tempfile  # local import: only this helper needs it

    audio = AudioSegment.from_mp3(mp3_path)
    audio = audio.set_frame_rate(frame_rate).set_channels(1)
    # Use a private scratch directory instead of a fixed "temp.wav" in the
    # working directory: concurrent callers no longer overwrite each other's
    # file, and the WAV is removed even when export or load raises.
    with tempfile.TemporaryDirectory() as scratch_dir:
        wav_path = os.path.join(scratch_dir, "audio.wav")
        audio.export(wav_path, format="wav")
        waveform, sample_rate = torchaudio.load(wav_path)
    return waveform, sample_rate

# Download blobs from a google cloud storage bucket
def download_blobs(bucket, prefix, destination_folder):
    """Download every object under ``prefix`` into ``destination_folder``.

    The part of each object name that follows ``prefix`` is reproduced as a
    relative path below ``destination_folder``; missing local directories are
    created on the way. One line is printed per downloaded object.

    Args:
        bucket: Bucket object exposing ``list_blobs(prefix=...)``.
        prefix: Object-name prefix selecting what to download.
        destination_folder: Local directory that receives the files.
    """
    os.makedirs(destination_folder, exist_ok=True)
    for blob in bucket.list_blobs(prefix=prefix):
        # Names ending in "/" are zero-byte "folder" placeholders. They map
        # onto a directory path, so downloading one to a file would fail;
        # the directories are created below for real files anyway.
        if blob.name.endswith("/"):
            continue
        local_path = os.path.join(destination_folder, os.path.relpath(blob.name, prefix))
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        blob.download_to_filename(local_path)
        print(f"Downloaded {blob.name} to {local_path}")

# Create a data configuration DataFrame
def create_data_config(prefix):
    """Build the path/label table for every labelled file under ``prefix``.

    The label of a file is derived from its base name: the second-to-last
    underscore-separated field, an underscore, and the first character of
    the last field. Files whose derived label is not listed in the
    "CEFR Level" column of ``cefr_label`` are left out, and so are files
    whose name does not have that shape.

    Args:
        prefix: Root folder (or single file) to scan.

    Returns:
        A DataFrame with the columns ``path`` and ``label``.
    """
    # Build the lookup once instead of scanning the column for every file.
    known_levels = set(cefr_label["CEFR Level"].values)
    paths = []
    labels = []
    for f in dig_folder(prefix):
        parts = os.path.basename(f).split("_")
        # Stray files (no underscore, or nothing after the last one) cannot
        # carry a label; skip them instead of failing with IndexError.
        if len(parts) < 2 or not parts[-1]:
            continue
        label = parts[-2] + "_" + parts[-1][0]
        if label in known_levels:
            paths.append(f)
            labels.append(label)
    return pd.DataFrame({
        'path': paths,
        'label': labels,
    })

# ------------------- Dataset -------------------

class ICNALE_SM_Dataset(Dataset):
    """In-memory dataset of ``(waveform, label)`` pairs.

    All audio is decoded eagerly in the constructor. ``data_config`` must
    provide the columns ``path`` and ``label``; the textual label is mapped
    to the value in the ``label`` column of ``cefr_label`` whose
    "CEFR Level" matches. Rows without a match are dropped.
    """

    def __init__(self, data_config):
        self.samples = []  # (waveform as numpy array, mapped label)
        rows = tqdm(
            data_config.iterrows(),
            total=len(data_config),
            desc="Initiating ICNALE SM Dataset",
        )
        for _, record in rows:
            audio, _ = mp3_to_tensor(record['path'])
            audio = audio.squeeze().numpy()
            matches = cefr_label.loc[
                cefr_label["CEFR Level"] == record['label'], "label"
            ].values
            if len(matches) == 0:
                continue  # label not present in the table
            self.samples.append((audio, matches[0]))

    def __len__(self):
        """Number of samples kept."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(waveform, label)`` pair stored at ``idx``."""
        return self.samples[idx]


# ------------------- Preprocess -------------------

# Local folder that receives the audio copied from the bucket.
os.makedirs("./dataset/audio", exist_ok=True)

# Copy the audio folder from the bucket into the local dataset directory.
!gsutil -m cp -r gs://dwn-chula/datasets/ICNALE/SM/ICNALE_SM_Audio ./dataset/audio

data_config = create_data_config("./dataset/audio") # DataFrame with the columns 'path' and 'label'

# 80/20 train/eval split, stratified on the label column, with a fixed seed.
train_data_config, eval_data_config = train_test_split(
    data_config,
    test_size=0.2,
    random_state=42,
    stratify=data_config['label']
)

# Models

In [None]:
# ------------------- Models -------------------

class MeanPooler(nn.Module):
    """Average hidden states over the sequence axis, ignoring masked positions."""

    def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, T, H); mask: (B, T), non-zero where a position counts.
        weights = mask.unsqueeze(-1).float()  # (B, T, 1)
        # Number of kept positions per row; the clamp avoids dividing by zero
        # when a row is fully masked.
        kept = weights.sum(dim=1).clamp(min=1e-9)
        total = (hidden_states * weights).sum(dim=1)
        return total / kept  # (B, H)

class PrototypicalClassifier(nn.Module):
    """Distance-to-prototype classification head.

    Holds ``k`` learnable prototype vectors per class plus a learnable
    log-temperature. The logit of a class is the negated mean squared
    Euclidean distance to that class's prototypes, divided by the temperature.
    """

    def __init__(self, embed_dim: int, num_classes: int, k: int = 3):
        super().__init__()
        self.k = k
        self.num_classes = num_classes
        # One row per prototype; class c owns rows [c * k, (c + 1) * k).
        initial = torch.randn(num_classes * k, embed_dim)
        self.prototypes = nn.Parameter(initial / math.sqrt(embed_dim))
        # Stored as a log so the temperature exp(log_tau) stays positive.
        self.log_tau = nn.Parameter(torch.zeros(()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, D) against prototypes: (C * K, D)
        batch_size = x.shape[0]
        sq_dists = torch.cdist(x, self.prototypes, p=2).pow(2)  # (B, C * K)
        per_class = sq_dists.view(batch_size, self.num_classes, self.k).mean(dim=2)  # (B, C)
        temperature = self.log_tau.exp()
        return -per_class / temperature

class SpeechModel(nn.Module):
    """Wav2Vec2 encoder -> masked mean pooling -> MLP -> prototype classifier.

    Args:
        num_classes: Number of output classes.
        k: Number of prototypes per class in the classification head.
    """

    def __init__(self, num_classes: int, k: int = 3):
        super().__init__()
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        hidden_size = self.encoder.config.hidden_size
        self.pooler = MeanPooler()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.LayerNorm(256),
        )
        self.metric_head = PrototypicalClassifier(embed_dim=256, num_classes=num_classes, k=k)

    def _frame_mask(self, attention_mask: torch.Tensor, num_frames: int) -> torch.Tensor:
        """Convert a per-sample mask (B, num_samples) to a per-frame mask (B, num_frames).

        Each row's count of valid samples is pushed through the length
        formula of every convolution in the encoder's feature extractor,
        ``floor((length - kernel) / stride) + 1``. The first ``length``
        frames of the row are then marked valid. This assumes padding sits
        at the end of each sequence.
        """
        lengths = attention_mask.sum(dim=-1)
        config = self.encoder.config
        for kernel, stride in zip(config.conv_kernel, config.conv_stride):
            lengths = torch.div(lengths - kernel, stride, rounding_mode="floor") + 1
        frame_index = torch.arange(num_frames, device=attention_mask.device)
        return frame_index.unsqueeze(0) < lengths.unsqueeze(1)

    def forward(self, input_values: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (B, num_classes) for a batch of raw waveforms.

        Args:
            input_values: Padded waveforms, (B, num_samples).
            attention_mask: (B, num_samples), non-zero on real samples.
        """
        # NOTE(review): the Wav2Vec2 docs advise against passing an attention
        # mask to group-norm checkpoints such as wav2vec2-base. The call is
        # left as it was -- confirm whether that is intended.
        out = self.encoder(input_values=input_values, attention_mask=attention_mask)
        hidden = out.last_hidden_state  # (B, T, H), with T far smaller than num_samples
        # The encoder downsamples time, so the sample-level mask cannot be
        # multiplied with the hidden states directly; pool with a mask at
        # frame resolution instead.
        frame_mask = self._frame_mask(attention_mask, hidden.shape[1])
        pooled = self.pooler(hidden, frame_mask)
        z = self.mlp(pooled)
        logits = self.metric_head(z)
        return logits

# Training Loop

In [None]:
# ------------------- Utilities -------------------

def collate_fn(batch):
    """Collate ``(waveform, label)`` pairs into one padded model batch.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs as produced by the dataset.

    Returns:
        The processor output with ``input_values`` and ``attention_mask``
        padded to the longest waveform in the batch, plus a ``labels``
        tensor of dtype long.
    """
    waveforms, labels = zip(*batch)
    proc_out = wav2vec_processor(
        list(waveforms),
        sampling_rate=16_000,
        return_tensors="pt",
        padding=True,
        # The wav2vec2-base preprocessor config turns the mask off by
        # default, but run_epoch reads batch["attention_mask"]; request it
        # explicitly so the key is always present.
        return_attention_mask=True,
    )
    proc_out["labels"] = torch.tensor(labels, dtype=torch.long)
    return proc_out

def make_class_weights(dataset: Dataset, num_classes: int) -> torch.Tensor:
    """Compute inverse-frequency class weights for a weighted loss.

    Args:
        dataset: Iterable of ``(sample, label)`` pairs, with integer labels
            in ``[0, num_classes)``.
        num_classes: Length of the returned weight vector.

    Returns:
        A float tensor of shape ``(num_classes,)``. Classes that occur get a
        weight proportional to ``1 / count``, scaled so that the weights of
        the occurring classes average to 1. Classes that never occur get
        weight 0. An empty dataset yields all ones.
    """
    counts = torch.zeros(num_classes)
    for _, label in dataset:
        counts[int(label)] += 1

    seen = counts > 0
    if not seen.any():
        # Nothing to balance against: fall back to uniform weights.
        return torch.ones(num_classes)

    # Invert only the classes that occur. Inverting an (almost) zero count
    # gives a huge weight that swallows the normalisation and leaves every
    # real class with a weight of ~0.
    weights = torch.zeros(num_classes)
    weights[seen] = 1.0 / counts[seen]
    return weights / weights.sum() * seen.sum()

def run_epoch(model, loader, criterion, optimiser=None, scaler=None, scheduler=None):
    """Run one pass over ``loader``, training when an optimiser is given.

    Args:
        model: Module called as ``model(input_values, attention_mask)``.
        loader: Iterable of dict-like batches with the keys ``input_values``,
            ``attention_mask`` and ``labels``.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimiser: When given, the model is put in training mode and updated
            after every batch; when ``None`` the pass is evaluation only.
        scaler: Optional gradient scaler; when given, the forward pass runs
            under CUDA autocast and the update goes through the scaler.
        scheduler: Optional learning-rate scheduler stepped once per batch,
            after the optimiser update.

    Returns:
        ``(mean loss per sample, accuracy)`` over the whole pass.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    is_train = optimiser is not None
    use_amp = scaler is not None
    model.train(is_train)
    device = next(model.parameters()).device
    total_loss, correct, n = 0.0, 0, 0

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Gradients are only tracked while training; evaluation then does not
        # build (and hold in memory) an autograd graph.
        with torch.set_grad_enabled(is_train), torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(batch["input_values"], batch["attention_mask"])
            loss = criterion(logits, batch["labels"])
        if is_train:
            optimiser.zero_grad()
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimiser)
                scaler.update()
            else:
                loss.backward()
                optimiser.step()
            if scheduler is not None:
                scheduler.step()
        preds = logits.argmax(1)
        total_loss += loss.item() * preds.size(0)
        correct += (preds == batch["labels"]).sum().item()
        n += preds.size(0)

    if n == 0:
        raise ValueError("run_epoch received a loader that yielded no samples")
    return total_loss / n, correct / n

# ------------------- Training Loop -------------------

# =================== CONFIGURATION ===================

# Assumes the "label" column of cefr_label holds the ids 0 .. len - 1 -- TODO confirm.
NUM_CLASSES = len(cefr_label)
K_PROTOTYPES = 3
BATCH_SIZE = 8
EPOCHS = 10
LR = 5e-5
WARMUP_FRAC = 0.1
WEIGHT_DECAY = 0.01

# ======================================================

train_dataset = ICNALE_SM_Dataset(train_data_config)
eval_dataset = ICNALE_SM_Dataset(eval_data_config)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SpeechModel(num_classes=NUM_CLASSES, k=K_PROTOTYPES).to(device)

class_weights = make_class_weights(train_dataset, NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimiser = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# The schedule is sized in optimiser steps (batches), so it has to be stepped
# once per batch; run_epoch does that through its `scheduler` argument.
total_steps = len(train_loader) * EPOCHS
# Mixed precision only makes sense on CUDA; with None, run_epoch takes the
# plain full-precision path.
scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None
sched = get_cosine_schedule_with_warmup(
        optimiser, int(total_steps * WARMUP_FRAC), total_steps
)

# Start below any reachable accuracy so the first epoch always saves a checkpoint.
best_eval = float("-inf")
for epoch in range(1, EPOCHS + 1):
    print(f"Training on epoch {epoch} ... ")
    train_loss, train_acc = run_epoch(
        model, train_loader, criterion, optimiser, scaler, scheduler=sched
    )
    eval_loss, eval_acc = run_epoch(model, eval_loader, criterion)
    print(
        f"Epoch {epoch:02d} | "
        f"train: loss={train_loss:.4f} acc={train_acc*100:.2f}% | "
        f"dev: loss={eval_loss:.4f} acc={eval_acc*100:.2f}%"
    )
    if eval_acc > best_eval:
        best_eval = eval_acc
        torch.save(model.state_dict(), "best_cefr_w2v2.pt")
        print("  ! New best model saved.")

    print(f"Best eval accuracy: {best_eval*100:.2f}%")
