
# Fine-Tuning Wav2Vec 2.0 with LoRA on GTZAN

This notebook demonstrates how to fine-tune a pre-trained **Wav2Vec 2.0** model for **music genre classification** on the **GTZAN** dataset using **LoRA (Low-Rank Adaptation)** with the **PEFT** library.  
We cover: environment setup, data preparation, LoRA integration, training with validation, evaluation with metrics (accuracy, F1, confusion matrix), logging, checkpointing, and inference.


## 1. Environment Setup

In [None]:
# If running on Colab or a fresh environment, uncomment the following:
# !pip install -U transformers datasets peft torch torchaudio librosa numpy pandas matplotlib scikit-learn accelerate tqdm

import os
import math
import json
import random
import time
from pathlib import Path
from dataclasses import dataclass
import numpy as np
import pandas as pd
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split, StratifiedKFold
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoModelForAudioClassification,
    Wav2Vec2FeatureExtractor,
    get_linear_schedule_with_warmup,
    set_seed
)
from peft.tuners.lora import LoraConfig as LoraConfigLowLevel, LoraModel
from peft import PeftModel
from google.colab import drive
import soundfile as sf
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift


# Matplotlib defaults for clean, single-plot visuals
# (IPython magic: render figures inline in the notebook output)
%matplotlib inline

# Report the runtime so logs show which torch build / accelerator was used
print('torch:', torch.__version__)
print('cuda available:', torch.cuda.is_available())
# Single device handle reused by every later cell (GPU when present, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Some optional perf tweaks
# Best-effort: any failure here (e.g. a torch build without this function) is
# deliberately ignored because the setting is only a speed optimisation.
try:
    torch.set_float32_matmul_precision('medium')
except Exception:
    pass

# Enable the cuDNN benchmark flag, then seed through the transformers helper
# with a fixed value (42).
torch.backends.cudnn.benchmark = True
set_seed(42)


## 2. Configuration

In [None]:
# ---- Filesystem locations ----
DATA_ROOT = Path('files')  # Change if needed
OUTPUT_DIR = Path('outputs_w2v2_lora_gtzan')
# Create the output folder (and any missing parents) once, up front
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- Audio ----
SAMPLE_RATE = 16000                       # Hz
CLIP_SECONDS = 29                         # clip duration in seconds
MAX_LENGTH = CLIP_SECONDS * SAMPLE_RATE   # clip length expressed in samples

# ---- Genre labels and the two lookup tables derived from them ----
LABEL_LIST = ['blues','classical','country','disco','hiphop','jazz','metal','pop','reggae','rock']
NUM_LABELS = len(LABEL_LIST)
LABEL2ID = dict(zip(LABEL_LIST, range(NUM_LABELS)))
ID2LABEL = dict(enumerate(LABEL_LIST))

# ---- Pre-trained checkpoint ----
BASE_MODEL = 'facebook/wav2vec2-base'

# ---- Optimisation hyper-parameters ----
BATCH_SIZE = 8
EVAL_BATCH_SIZE = 4
LR = 2e-4
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 100
WARMUP_RATIO = 0.1          # fraction of all optimizer steps used for LR warmup
GRAD_ACCUM_STEPS = 2        # micro-batches accumulated per optimizer step
EARLY_STOP_PATIENCE = 3
MAX_GRAD_NORM = 1.0
USE_MIXED_PRECISION = True

# ---- LoRA hyper-parameters ----
LORA_R = 32
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ['q_proj','k_proj','v_proj','out_proj']  # attention projections inside wav2vec2 encoder

# ---- Logging cadence (in optimizer steps) ----
LOG_EVERY_STEPS = 20

In [None]:
# Snapshot every setting of this run in one dictionary and persist it next to
# the other outputs, so the run can be traced back to its hyper-parameters.

config_dict = dict(
    DATA_ROOT=str(DATA_ROOT.resolve()),
    OUTPUT_DIR=str(OUTPUT_DIR.resolve()),
    SAMPLE_RATE=SAMPLE_RATE,
    CLIP_SECONDS=CLIP_SECONDS,
    MAX_LENGTH=MAX_LENGTH,
    LABEL_LIST=LABEL_LIST,
    BASE_MODEL=BASE_MODEL,
    BATCH_SIZE=BATCH_SIZE,
    EVAL_BATCH_SIZE=EVAL_BATCH_SIZE,
    LR=LR,
    WEIGHT_DECAY=WEIGHT_DECAY,
    NUM_EPOCHS=NUM_EPOCHS,
    WARMUP_RATIO=WARMUP_RATIO,
    GRAD_ACCUM_STEPS=GRAD_ACCUM_STEPS,
    EARLY_STOP_PATIENCE=EARLY_STOP_PATIENCE,
    MAX_GRAD_NORM=MAX_GRAD_NORM,
    USE_MIXED_PRECISION=USE_MIXED_PRECISION,
    # LoRA settings are grouped under their own key
    LORA=dict(
        r=LORA_R,
        alpha=LORA_ALPHA,
        dropout=LORA_DROPOUT,
        target_modules=LORA_TARGET_MODULES,
    ),
)

# Written as pretty-printed JSON (2-space indent) to OUTPUT_DIR/config.json
with open(OUTPUT_DIR / 'config.json', 'w') as f:
    f.write(json.dumps(config_dict, indent=2))


## 3. Data Preparation

We scan the folder structure, build a DataFrame of file paths and labels, then create a `torch.utils.data.Dataset` with:
- **Random crop** for training, **center crop** for eval/test
- On-the-fly waveform loading with `librosa` (or `torchaudio`)
- Padding/truncation to fixed `MAX_LENGTH`

### 3.1 Index the dataset

We use a **DataFrame** to store the path to the .wav files and the labels.

In [None]:
# Build the file index (one row per audio clip) used by the rest of the notebook
def index_gtzan(root: Path, allowed_labels, label2id=None):
    """Scan ``root/<label>/`` for ``.wav`` files and return them as a DataFrame.

    Args:
        root: Directory that contains one sub-directory per label.
        allowed_labels: Iterable of label names (sub-directory names) to index.
        label2id: Optional mapping from label name to integer id. Defaults to
            the notebook-level ``LABEL2ID`` table.

    Returns:
        DataFrame with columns ``path`` (str), ``label`` (str) and
        ``label_id`` (int64). Within each label the files are listed in
        sorted path order.

    Raises:
        ValueError: If a requested label has no entry in the id mapping.
        FileNotFoundError: If no ``.wav`` file is found at all.
    """
    mapping = LABEL2ID if label2id is None else label2id

    # Fail early instead of silently producing NaN ids further down
    unknown = [label for label in allowed_labels if label not in mapping]
    if unknown:
        raise ValueError(f"Labels without an id in the label mapping: {unknown}")

    rows = []
    for label in allowed_labels:
        class_dir = Path(root) / label

        if not class_dir.is_dir():
            print(f"[WARN] Missing class dir: {class_dir}")
            continue

        # rglob yields entries in filesystem order, which differs between
        # machines; sorting keeps the later seeded shuffle/split reproducible.
        for wav in sorted(class_dir.rglob('*.wav')):
            rows.append({'path': str(wav), 'label': label})

    # Explicit columns so an empty scan still yields a well-formed frame
    df = pd.DataFrame(rows, columns=['path', 'label'])

    if df.empty:
        raise FileNotFoundError("No .wav files found. Check DATA_ROOT and structure.")

    # Numeric class id for each row, derived from the label name
    df['label_id'] = df['label'].map(mapping).astype('int64')

    return df

In [None]:
# Index every clip, shuffle the rows once with a fixed seed, then renumber them
df = index_gtzan(DATA_ROOT, LABEL_LIST)
df = df.sample(frac=1.0, random_state=42)
df = df.reset_index(drop=True)
print(df.head(), '\nTotal files:', len(df))

### 3.2 Split the data

We split the dataset into 3 parts: **train (80%), validation (10%) and test (10%)**.

In [None]:
# Stratified 80/10/10 split: hold out 20% of the clips first, then cut that
# hold-out in half to obtain the validation and test sets (10% each).
# Stratifying on label_id keeps the genre proportions equal across the splits.
train_df, tmp_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label_id'])
val_df, test_df = train_test_split(tmp_df, test_size=0.5, random_state=42, stratify=tmp_df['label_id'])

for name, d in [('train',train_df),('val',val_df),('test',test_df)]:
    print(name, 'size:', len(d))

### 3.3 Load the files

In [None]:
# Read one audio file from disk with librosa
def load_wav(path, sr=SAMPLE_RATE):
    """Load the audio file at ``path`` as a mono waveform resampled to ``sr``.

    Args:
        path: Location of the audio file.
        sr: Target sampling rate passed to ``librosa.load``.

    Returns:
        The waveform array returned by ``librosa.load`` (the sampling rate it
        also returns is discarded).
    """
    # librosa loads to float32 in [-1,1]
    waveform = librosa.load(path, sr=sr, mono=True)[0]
    return waveform

In [None]:
# Dataset that loads, crops, normalises and pads one GTZAN clip per item
class GTZANDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(input_values, label)`` pairs.

    Each item is loaded from disk on access, cropped to at most ``max_length``
    samples (random crop when ``split == 'train'``, center crop otherwise),
    passed through the processor and zero-padded to exactly ``max_length``.

    Args:
        df: DataFrame with at least the columns ``path`` and ``label_id``.
        processor: Callable feature extractor; called with the raw waveform
            and ``sampling_rate`` and expected to return ``input_values``.
        split: ``'train'`` enables random cropping; any other value uses a
            deterministic center crop.
        max_length: Fixed clip length in samples. Defaults to the
            notebook-level ``MAX_LENGTH``.
    """

    def __init__(self, df, processor, split='train', max_length=None):
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.split = split
        self.max_length = MAX_LENGTH if max_length is None else int(max_length)

    def __len__(self):
        return len(self.df)

    def _crop(self, y):
        """Return a window of at most ``max_length`` samples taken from ``y``."""
        excess = len(y) - self.max_length
        if excess <= 0:
            # Clip is already short enough; padding happens after normalisation
            return y
        if self.split == 'train':
            # Random window start in [0, excess]; uses torch's RNG so it is
            # covered by the torch seed
            start = int(torch.randint(0, excess + 1, (1,)).item())
        else:
            # Deterministic center crop for validation / test
            start = excess // 2
        return y[start:start + self.max_length]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        y = self._crop(load_wav(row['path'], sr=SAMPLE_RATE))

        # Processor expects raw float; we pass sampling rate for proper normalization
        inputs = self.processor(y, sampling_rate=SAMPLE_RATE, return_tensors='pt')
        input_values = inputs['input_values'][0]  # (T,)

        # Pad AFTER the processor so the padding zeros do not take part in
        # whatever normalisation the processor applies to the waveform
        missing = self.max_length - input_values.shape[-1]
        if missing > 0:
            input_values = torch.nn.functional.pad(input_values, (0, missing))

        label = torch.tensor(int(row['label_id']), dtype=torch.long)

        return input_values, label

We use a custom **data collator** that pads the variable-length `input_values` of a batch to a common length and gathers the integer labels into a single tensor.

In [None]:
@dataclass
class Collator:
    """Batch-collation callable for ``(input_values, label)`` items.

    Calling the instance with a list of such items returns a tuple
    ``(input_values, labels)``: the inputs are padded to a common length by
    ``processor.pad`` and the labels are gathered into one ``torch.long``
    tensor.
    """

    # Object providing ``pad(features, return_tensors=...)``; annotated as a
    # plain object because no narrower type is required here.
    processor: object

    def __call__(self, batch):
        input_values = [item[0] for item in batch]
        # int() accepts both Python ints and 0-d tensors, so the label tensor
        # is always built from plain integers
        labels = torch.as_tensor([int(item[1]) for item in batch], dtype=torch.long)

        padded = self.processor.pad({"input_values": input_values}, return_tensors="pt")

        return padded['input_values'], labels

In [None]:
# Fetch the processor (feature extractor) that belongs to the base checkpoint
processor = AutoProcessor.from_pretrained(
    BASE_MODEL,
    use_safetensors=True,
)

# One collator instance, shared by every DataLoader
collate_fn = Collator(processor=processor)

### 3.4 Load the data

In [None]:
# One dataset per split; the split name is forwarded to the dataset
train_ds, val_ds, test_ds = (
    GTZANDataset(frame, processor, split=split_name)
    for split_name, frame in (('train', train_df), ('val', val_df), ('test', test_df))
)

In [None]:
def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader using the settings shared by all splits."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
    )


# Only the training loader shuffles; evaluation loaders keep a fixed order
train_loader = _make_loader(train_ds, BATCH_SIZE, shuffle=True)
val_loader = _make_loader(val_ds, EVAL_BATCH_SIZE, shuffle=False)
test_loader = _make_loader(test_ds, EVAL_BATCH_SIZE, shuffle=False)

# Cell output: number of clips per split
tuple(len(ds) for ds in (train_ds, val_ds, test_ds))


## 4. Model Setup (Wav2Vec2) and LoRA Integration

We load a pre-trained `Wav2Vec2` base model for **audio classification** and adapt it to **10 genres**.  
Then we configure **LoRA** to inject low-rank adapters into the **attention projection layers** and ensure only LoRA parameters are trainable.

We also report parameter counts: **total vs. trainable**.


In [None]:
# Load the base checkpoint's configuration and override the fields needed for
# single-label classification over NUM_LABELS genres, including both
# directions of the label <-> id mapping.
config = AutoConfig.from_pretrained(
    BASE_MODEL,
    num_labels=NUM_LABELS,
    label2id=LABEL2ID,
    id2label=ID2LABEL,
    problem_type='single_label_classification'
)

In [None]:
# Instantiate wav2vec2 with a classification head on top, using the
# label-aware configuration prepared above in this notebook section
base_model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL, config=config, use_safetensors=True)

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapters are injected into the attention projections listed in
# LORA_TARGET_MODULES. `modules_to_save` marks the classification head layers
# as fully trainable AND stores them together with the adapter, so a reloaded
# checkpoint carries the trained head as well.
lora_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    modules_to_save=["projector", "classifier"],
)

# get_peft_model returns a PeftModel: its save_pretrained() writes the adapter
# weights plus adapter_config.json, which is what PeftModel.from_pretrained()
# needs in order to reload the checkpoint. A bare LoraModel does not provide
# that adapter-aware save.
model = get_peft_model(base_model, lora_cfg, adapter_name="default")
# Re-assign instead of leaving `.to()` as the last expression, which would
# print the full module tree as cell output
model = model.to(device)

In [None]:
# Train only the LoRA adapters and the classification head; freeze the rest.
# The head consists of BOTH `projector` and `classifier`: neither exists in the
# pre-training checkpoint, so both start from random weights and must be
# trained (a frozen random projector would cap what the classifier can learn).
for name, param in model.named_parameters():
    is_adapter_or_head = any(key in name for key in ('lora_', 'projector', 'classifier'))
    # If a layer is wrapped by PEFT's modules_to_save mechanism, the wrapper
    # also keeps an `original_module` copy; that copy is left frozen here.
    is_wrapped_original = '.original_module.' in name
    param.requires_grad = is_adapter_or_head and not is_wrapped_original

In [None]:
# Helper for the parameter report
def count_parameters(m):
    """Count the parameters of module ``m``.

    Args:
        m: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Tuple ``(total, trainable)`` where ``trainable`` counts only the
        parameters whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for tensor in m.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size
    return total, trainable

In [None]:
# Print the absolute counts and the trainable share of the wrapped model
total_params, trainable_params = count_parameters(model)
print(
    f"Total parameters: {total_params:,}",
    f"Trainable parameters (LoRA + head): {trainable_params:,}",
    f"Trainable %: {100*trainable_params/total_params:.2f}%",
    sep="\n",
)


## 5. Training Pipeline

We implement a PyTorch training loop with:
- Loss: Cross-entropy
- Optimizer: AdamW
- LR scheduling with warmup
- Mixed precision (optional)
- Early stopping (on validation loss)
- Checkpointing best adapters
- Progress bars and periodic logs


In [None]:
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Optimizer steps per epoch: number of batches divided by the accumulation
# factor, rounded up (integer ceiling division)
num_update_steps_per_epoch = -(-len(train_loader) // GRAD_ACCUM_STEPS)
# Total optimizer steps over the whole run, and how many of them are warmup
t_total = num_update_steps_per_epoch * NUM_EPOCHS
warmup_steps = int(WARMUP_RATIO * t_total)

In [None]:
# Optimise only the parameters that are still trainable
optimizer = AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
# Linear warmup followed by linear decay over the planned number of steps
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=t_total,
)
# Gradient scaler for mixed precision; switched by USE_MIXED_PRECISION
scaler = GradScaler(enabled=USE_MIXED_PRECISION)

In [None]:
# Early-stopping bookkeeping: best validation loss so far and the number of
# consecutive epochs without improvement
best_val_loss = math.inf
epochs_no_improve = 0
# Per-epoch metric history (one list per tracked metric)
history = {metric: [] for metric in ('train_loss', 'val_loss', 'val_acc', 'val_f1')}

In [None]:
def evaluate(loader, net=None):
    """Run a full evaluation pass over ``loader``.

    Args:
        loader: Iterable yielding ``(input_values, labels)`` batches.
        net: Model to evaluate. Defaults to the notebook-level ``model``.
            The model is switched to eval mode and left in that mode.

    Returns:
        Tuple ``(loss, accuracy, macro_f1, y_true, y_pred)`` where ``loss`` is
        the per-sample mean loss (``0.0`` for an empty loader) and
        ``y_true`` / ``y_pred`` are NumPy arrays of label ids.
    """
    net = model if net is None else net
    net.eval()
    loss_sum, n_seen = 0.0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for input_values, labels in loader:
            input_values = input_values.to(device)
            labels = labels.to(device)
            with autocast(enabled=USE_MIXED_PRECISION):
                outputs = net(input_values=input_values, labels=labels)

            # Weight each batch loss by its size: a plain mean of batch means
            # over-weights a smaller final batch.
            batch_size = labels.size(0)
            loss_sum += outputs.loss.item() * batch_size
            n_seen += batch_size

            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.detach().cpu().numpy().tolist())
            all_labels.extend(labels.detach().cpu().numpy().tolist())

    val_loss = loss_sum / n_seen if n_seen else 0.0
    val_acc = accuracy_score(all_labels, all_preds)
    val_f1 = f1_score(all_labels, all_preds, average='macro')

    return val_loss, val_acc, val_f1, np.array(all_labels), np.array(all_preds)

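**Training loop**: after every training epoch we run validation, save the checkpoint whenever the validation loss improves, and stop early once it has not improved for several consecutive epochs.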

In [None]:
# Main training loop: gradient accumulation + mixed precision, validation after
# every epoch, checkpoint on improvement, early stopping on validation loss.
global_step = 0                     # number of optimizer steps taken so far
num_batches = len(train_loader)
for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    running_loss = 0.0
    pbar = tqdm(enumerate(train_loader, start=1), total=num_batches, desc=f"Epoch {epoch}/{NUM_EPOCHS}")
    optimizer.zero_grad(set_to_none=True)

    for step, (input_values, labels) in pbar:
        input_values = input_values.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with autocast(enabled=USE_MIXED_PRECISION):
            outputs = model(input_values=input_values, labels=labels)
            # Scale so the accumulated gradient is the mean over the group
            loss = outputs.loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Track the un-scaled loss for reporting
        running_loss += loss.item() * GRAD_ACCUM_STEPS

        # Step when an accumulation group is complete, and also on the last
        # batch of the epoch: otherwise a trailing partial group would be
        # back-propagated but never applied, then wiped by the next zero_grad.
        # (A partial group is still divided by GRAD_ACCUM_STEPS, so its update
        # is slightly smaller than that of a full group.)
        if step % GRAD_ACCUM_STEPS == 0 or step == num_batches:
            # Unscale before clipping so the threshold applies to true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

            # Refresh the progress-bar stats once every LOG_EVERY_STEPS updates
            if global_step % LOG_EVERY_STEPS == 0:
                pbar.set_postfix({'train_loss': f"{running_loss / step:.4f}", 'lr': f"{scheduler.get_last_lr()[0]:.2e}"})

    train_epoch_loss = running_loss / num_batches

    # Validation
    val_loss, val_acc, val_f1, y_true, y_pred = evaluate(val_loader)
    history['train_loss'].append(train_epoch_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)

    print(f"\nEpoch {epoch}: train_loss={train_epoch_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f} | val_f1={val_f1:.4f}")

    # Early stopping on val_loss (1e-6 margin ignores numerical noise)
    if val_loss < best_val_loss - 1e-6:
        best_val_loss = val_loss
        epochs_no_improve = 0

        # Save the best checkpoint together with the processor
        save_dir = OUTPUT_DIR / 'best_lora'
        model.save_pretrained(save_dir)
        processor.save_pretrained(save_dir)
        print(f"[Checkpoint] Saved best adapters to: {save_dir}")
    else:
        epochs_no_improve += 1
        print(f"[EarlyStopping] No improvement for {epochs_no_improve} epoch(s).")
        # Use the configured patience (also recorded in config.json) instead
        # of a separate hard-coded value
        if epochs_no_improve >= EARLY_STOP_PATIENCE:
            print("[EarlyStopping] Stopping training.")
            break

In [None]:
# Persist the per-epoch metrics as CSV (one row per epoch, no index column)
history_frame = pd.DataFrame(history)
history_frame.to_csv(OUTPUT_DIR / 'history.csv', index=False)
print("Training complete.")


## 6. Evaluation

We load the **best** saved adapters (by validation loss) and evaluate on the **test** set.
We compute:
- Accuracy
- Macro F1-score
- Per-class precision/recall/F1
- Confusion matrix

We also plot training curves.


In [None]:
# Reload the best adapters for evaluation.
# PeftModel.from_pretrained() needs an `adapter_config.json` inside the
# checkpoint directory, so we test for that file rather than just the folder:
# a directory that holds some other kind of checkpoint would make the load fail.
best_dir = OUTPUT_DIR / 'best_lora'
if (best_dir / 'adapter_config.json').exists():
    eval_model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL, config=config, use_safetensors=True)
    # NOTE(review): the freshly created base model has an untrained head; the
    # trained head is only restored if the adapter checkpoint contains it.
    eval_model = PeftModel.from_pretrained(eval_model, best_dir)
else:
    # Fall back to the in-memory model (state after the last trained epoch)
    print(f"[WARN] No loadable adapter checkpoint in {best_dir}, evaluating current model.")
    eval_model = model

In [None]:
# Move the evaluation model to the target device and switch it to eval mode;
# both calls return the module itself, which is shown as the cell output
eval_model.to(device).eval()

In [None]:
# Evaluate on the held-out test set: collect the per-batch losses and the
# predicted / true label ids, then derive the summary metrics.
losses = []
all_preds, all_labels = [], []

with torch.no_grad():
    for input_values, labels in tqdm(test_loader, desc="Testing"):
        input_values, labels = input_values.to(device), labels.to(device)

        with torch.cuda.amp.autocast(enabled=USE_MIXED_PRECISION):
            outputs = eval_model(input_values=input_values, labels=labels)

        losses.append(outputs.loss.item())
        # Predicted class = index of the largest logit
        preds = outputs.logits.argmax(dim=-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# Mean of the batch losses (0.0 when the loader produced no batch)
test_loss = float(np.mean(losses)) if losses else 0.0
y_true, y_pred = np.array(all_labels), np.array(all_preds)
test_acc = accuracy_score(y_true, y_pred)
test_f1 = f1_score(y_true, y_pred, average='macro')

In [None]:
# One-line summary of the test metrics, four decimals each
print(" | ".join([
    f"Test loss: {test_loss:.4f}",
    f"Test acc: {test_acc:.4f}",
    f"Test macro-F1: {test_f1:.4f}",
]))

In [None]:
print("\nPer-class report:")
# Pin the label ids explicitly: without `labels`, scikit-learn infers the
# classes from the values present in y_true / y_pred, and the call fails when
# that count differs from len(target_names). zero_division=0 reports 0.0
# for a class that was never predicted.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(LABEL_LIST))),
    target_names=LABEL_LIST,
    zero_division=0,
))

In [None]:
# Fix the class order to the label ids 0..len(LABEL_LIST)-1 so the matrix is
# always square with one row/column per genre, even if a genre is absent from
# both y_true and y_pred (otherwise the DataFrame labels would not fit).
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(LABEL_LIST))))
# Rows = true genre, columns = predicted genre
cm_df = pd.DataFrame(cm, index=LABEL_LIST, columns=LABEL_LIST)
cm_df

In [None]:
# Loss curves, read back from the saved history file
hist = pd.read_csv(OUTPUT_DIR / 'history.csv')
plt.figure()
for curve in ('train_loss', 'val_loss'):
    plt.plot(hist[curve], label=curve)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training & Validation Loss')
plt.legend()
plt.show()

In [None]:
# Validation accuracy and macro-F1 per epoch, drawn on one figure
plt.figure()
for curve in ('val_acc', 'val_f1'):
    plt.plot(hist[curve], label=curve)
plt.xlabel('Epoch')
plt.ylabel('Score')
plt.title('Validation Accuracy & Macro F1')
plt.legend()
plt.show()