In [9]:
from transformers import AutoFeatureExtractor, AutoModel
import torch
import torchaudio
from datasets import Dataset, load_from_disk
import numpy as np
# Pick the compute device once; later cells reuse `device` instead of hardcoding "cuda".
use_cuda = torch.cuda.is_available()
print(use_cuda)  # Expected to print True on a GPU node
device = torch.device("cuda" if use_cuda else "cpu")

True


In [11]:
# Load the (presumably LibriSpeech clean-100) dataset previously written with `save_to_disk`.
# NOTE(review): absolute cluster path — consider making this configurable.
LIBRISPEECH_DIR = "/share/data/lang/users/ttic_31110/GMM_DHuBERT/GMM-DistilHuBERT/data/hf_librispeech_clean100"
dataset = load_from_disk(LIBRISPEECH_DIR)
len(dataset)

Loading dataset from disk:   0%|          | 0/47 [00:00<?, ?it/s]

28523

### Test DHuBERT

In [20]:
# Take the first utterance as a smoke-test input for the encoder.
example = dataset[0]
audio_entry = example["audio"]
sampling_rate = audio_entry["sampling_rate"]
waveform = torch.tensor(audio_entry["array"])

In [21]:
# Load the DistilHuBERT feature extractor and base encoder (no task head).
feature_extractor = AutoFeatureExtractor.from_pretrained("ntu-spml/distilhubert")
model = AutoModel.from_pretrained("ntu-spml/distilhubert").to(device)

# Preprocess the waveform into a batch of one and move every tensor to the same
# device as the model (`device` falls back to CPU, so don't hardcode "cuda").
inputs = feature_extractor(waveform, sampling_rate=sampling_rate, return_tensors="pt")
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Inference only: disable autograd to save memory.
with torch.no_grad():
    outputs = model(**inputs)

# Output hidden state shape: (batch, frames, hidden_size)
print("Hidden state shape:", outputs.last_hidden_state.shape)
# 704 tokens (acoustic representation), each a 768-d vector

Hidden state shape: torch.Size([1, 704, 768])


### Try Fine-Tuning on ASR

In [41]:
from transformers import AutoProcessor, AutoModelForCTC, TrainingArguments, Trainer
import torch
from dataclasses import dataclass
from typing import Dict, List, Union
import jiwer


#### processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]

* Apply the processor that wraps a feature extractor and tokenizer to the waveform.
* Returns a dict with a key like `input_values`: padded and normalized audio ready for the model.
* You extract the first (and only) item in the batch using [0].

`example["input_values"]`: a vector of floats the model will ingest.

#### with processor.as_target_processor():

* Switches the processor into target mode, so its calls tokenize text instead of processing audio. (Deprecated in recent `transformers`; the replacement is `processor(text=...)`.)
* This is important for models where the same processor handles both input and label preprocessing.

#### example["labels"] = processor(example["text"]).input_ids

* Tokenizes the reference transcript into a sequence of IDs aka label tokens.
* They’ll be aligned to the model’s output using CTC loss, allowing flexible alignment between audio and text.

In [None]:
# NOTE(review): `ntu-spml/distilhubert` may not ship a tokenizer. If AutoProcessor returns
# only a feature extractor, the `processor.tokenizer` / `text=` calls below will fail, and a
# character vocabulary/tokenizer must be built or borrowed from a CTC checkpoint — confirm.
processor = AutoProcessor.from_pretrained("ntu-spml/distilhubert")

# Size the freshly initialised CTC head to the tokenizer's vocabulary, and use its pad id
# as the CTC blank. Move to `device` (not a hardcoded "cuda") so CPU-only runs still work.
model = AutoModelForCTC.from_pretrained(
    "ntu-spml/distilhubert",
    vocab_size=len(processor.tokenizer),
    pad_token_id=processor.tokenizer.pad_token_id,
).to(device)


def prepare(example):
    """Turn one raw dataset row into model inputs and CTC label ids.

    Args:
        example: dataset row with an ``"audio"`` dict (``"array"``, ``"sampling_rate"``)
            and a ``"text"`` transcript.

    Returns:
        The same row with ``"input_values"`` (processed waveform, batch dim removed)
        and ``"labels"`` (token ids of the transcript) added.
    """
    audio = example["audio"]
    example["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    # `processor(text=...)` tokenizes the transcript; it replaces the deprecated
    # `with processor.as_target_processor():` context manager.
    example["labels"] = processor(text=example["text"]).input_ids
    return example


data = dataset.map(prepare, remove_columns=dataset.column_names)
data.set_format(type="torch", columns=["input_values", "labels"])

#### data collator: 
Pads a list of samples and converts it into batched tensors for the model. It also prepares the labels for the CTC loss by setting padded label positions to -100 so they are ignored.

* `padding`: Can be `True`, `"longest"`, or `"max_length"` – determines how padding is applied.

`__call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]`:
* `features`: a list of dicts of the form `{"input_values": tensor, "labels": [int, int, ...]}`

In [None]:
# Custom DataCollatorCTCWithPadding
from typing import Any


@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pad a list of preprocessed samples into one CTC training batch.

    Attributes:
        processor: object exposing ``feature_extractor.pad`` (audio) and
            ``tokenizer.pad`` (labels), e.g. a Wav2Vec2-style processor.
        padding: padding strategy forwarded to both ``pad`` calls
            (``True``/``"longest"``, ``"max_length"``, ...).
    """

    # `@dataclass` generates the keyword __init__ used by
    # `DataCollatorCTCWithPadding(processor=..., padding=...)`.
    processor: Any
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Collate samples into padded tensors.

        Args:
            features: list of dicts with ``"input_values"`` and ``"labels"`` entries.

        Returns:
            The batch produced by ``feature_extractor.pad`` plus a ``"labels"`` tensor
            whose padded positions are set to -100.
        """
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        # Audio and labels have independent lengths, so pad them separately.
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # -100 marks padded label positions so no loss is computed on them.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch

data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,  # pad each batch to its longest sample
)

In [None]:
training_args = TrainingArguments(
    output_dir="./checkpoints_distilhubert_asr",
    per_device_train_batch_size=16,
    # Called `evaluation_strategy` in transformers < 4.41 (deprecated/removed in newer releases).
    # NOTE(review): eval_steps is unset, so evaluation runs every `logging_steps` — confirm cost.
    eval_strategy="steps",
    num_train_epochs=3,
    save_steps=1000,
    save_total_limit=2,
    logging_steps=100,
    # Mixed precision uses less memory but requires a CUDA device; fall back to fp32 on CPU.
    fp16=torch.cuda.is_available(),
    # Resuming is decided in `trainer.train(resume_from_checkpoint=...)`. The
    # TrainingArguments field of the same name expects a path and is not read by Trainer.
    report_to="wandb",
)

In [None]:
# Text normalization applied to both hypotheses and references before scoring, so
# casing, punctuation and extra whitespace are not counted as recognition errors.
transform = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
])


def compute_metrics(pred):
    """Compute WER, CER and sentence error rate (SER) for one Trainer evaluation pass.

    Args:
        pred: evaluation prediction with ``predictions`` (CTC logits, presumably
            (batch, frames, vocab) — confirm) and ``label_ids`` (label ids padded
            with -100 by the data collator).

    Returns:
        Dict with float values for ``"wer"``, ``"cer"`` and ``"ser"``.
    """
    logits = pred.predictions
    # Some models return a tuple (logits, ...); the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    # Greedy decoding: most likely token per frame; batch_decode groups repeated tokens.
    pred_ids = np.argmax(logits, axis=-1)
    pred_str = processor.batch_decode(pred_ids)

    # Map -100 back to the pad id on a copy, so the caller's label_ids are not mutated.
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
    # group_tokens=False: repeated characters in references are genuine (e.g. "ll").
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    pred_str = [transform(hyp) for hyp in pred_str]
    label_str = [transform(ref) for ref in label_str]

    wer = jiwer.wer(label_str, pred_str)
    cer = jiwer.cer(label_str, pred_str)

    # Sentence Error Rate: fraction of utterances with at least one error.
    ser = sum(hyp != ref for hyp, ref in zip(pred_str, label_str)) / len(label_str)

    return {
        "wer": wer,
        "cer": cer,
        "ser": ser,
    }

In [None]:
import os

from transformers.trainer_utils import get_last_checkpoint

# Train on the preprocessed `data`: the raw `dataset` lacks the `input_values`/`labels`
# columns the collator reads. Hold out a small, reproducible slice for evaluation.
# NOTE(review): a held-out LibriSpeech dev set would be a better eval set if available.
splits = data.train_test_split(test_size=0.02, seed=42)
train_dataset, eval_dataset = splits["train"], splits["test"]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# `resume_from_checkpoint=True` raises when output_dir holds no checkpoint (first run),
# so resume only from a checkpoint that actually exists.
output_dir = training_args.output_dir
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
trainer.train(resume_from_checkpoint=last_checkpoint)