In [19]:
import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoModel

print(torch.cuda.is_available())  # Should return True

# Define the device here, before the `.to(device)` call below, so this cell
# also works on a fresh kernel (Restart & Run All).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the DistilHuBERT feature extractor and the base encoder (AutoModel, no task head)
MODEL_NAME = "ntu-spml/distilhubert"
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)


True


In [20]:
# Load the locally stored LibriSpeech train-clean-100 subset (no download).
LIBRISPEECH_ROOT = "/scratch/pippalin2/jupyter/GMM-DistilHuBERT/data"  # where your LibriSpeech folder lives
TARGET_SAMPLE_RATE = 16000  # DistilHuBERT (DHuBERT) expects 16 kHz audio

training_dataset = torchaudio.datasets.LIBRISPEECH(
    root=LIBRISPEECH_ROOT,
    url="train-clean-100",  # must match the subfolder name
    download=False,
)

# Each item is a 6-tuple; only the first three fields (audio, rate, text) are used here.
waveform, sample_rate, transcript, _speaker_id, _chapter_id, _utterance_id = training_dataset[0]
print("Transcript:", transcript)

# Bring the audio to the model's expected rate only when it differs.
if sample_rate != TARGET_SAMPLE_RATE:
    waveform = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
    )(waveform)

Transcript: CHAPTER ONE MISSUS RACHEL LYNDE IS SURPRISED MISSUS RACHEL LYNDE LIVED JUST WHERE THE AVONLEA MAIN ROAD DIPPED DOWN INTO A LITTLE HOLLOW FRINGED WITH ALDERS AND LADIES EARDROPS AND TRAVERSED BY A BROOK


### Test DistilHuBERT (DHuBERT) on one utterance

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The feature extractor yields {'input_values': tensor of shape [1, num_samples]};
# every tensor is then moved to the same device as the model.
extracted = feature_extractor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
inputs = {key: value.to(device) for key, value in extracted.items()}

# Forward pass only: no gradients are needed for this smoke test.
with torch.no_grad():
    outputs = model(**inputs)

# One vector per acoustic frame; for the first utterance this printed
# torch.Size([1, 704, 768]), i.e. 704 frames of 768-d representations.
print("Hidden state shape:", outputs.last_hidden_state.shape)

Hidden state shape: torch.Size([1, 704, 768])


### Try Fine-Tuning on ASR

In [None]:
from datasets import load_dataset

# LibriSpeech from the Hugging Face hub. The "clean" config names its splits
# "train.100" (train-clean-100), "train.360", "validation" and "test".
# Slice syntax such as "train.100[:1%]" loads a small subset for quick smoke tests;
# the default below loads the full train-clean-100 split.
TRAIN_SPLIT = "train.100"
librispeech = load_dataset("librispeech_asr", "clean", split=TRAIN_SPLIT)

# Show one example: the audio dict and its reference transcript
print(librispeech[0]["audio"])
print(librispeech[0]["text"])


In [None]:
import json
from pathlib import Path

from transformers import AutoFeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor

# "ntu-spml/distilhubert" is a self-supervised checkpoint that (as far as we can tell)
# ships no tokenizer, so AutoProcessor would return a bare feature extractor without
# `.tokenizer`, `batch_decode` or `as_target_processor`. Build a character-level CTC
# vocabulary from the training transcripts and combine it with the feature extractor.
VOCAB_DIR = Path("./distilhubert-asr")
VOCAB_DIR.mkdir(parents=True, exist_ok=True)
vocab_path = VOCAB_DIR / "vocab.json"

# Sorted so the token ids are reproducible across runs; spaces get the "|" delimiter instead.
transcript_chars = sorted(set("".join(librispeech["text"])) - {" "})
vocab = {char: idx for idx, char in enumerate(transcript_chars)}
vocab["|"] = len(vocab)      # word delimiter: the CTC tokenizer maps spaces to it
vocab["[UNK]"] = len(vocab)  # characters unseen in the training transcripts
vocab["[PAD]"] = len(vocab)  # padding token; HubertForCTC uses the model's pad_token_id as CTC blank
vocab_path.write_text(json.dumps(vocab), encoding="utf-8")

ctc_tokenizer = Wav2Vec2CTCTokenizer(
    str(vocab_path), unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
)
processor = Wav2Vec2Processor(
    feature_extractor=AutoFeatureExtractor.from_pretrained("ntu-spml/distilhubert"),
    tokenizer=ctc_tokenizer,
)


def prepare(example):
    """Turn one raw LibriSpeech example into DistilHuBERT inputs and CTC labels.

    Args:
        example: a dataset row with "audio" (dict with "array" and "sampling_rate")
            and "text" (the transcript).

    Returns:
        The same dict with "input_values" (feature-extractor output for this single
        clip) and "labels" (character-level token ids of the transcript) added.
    """
    audio = example["audio"]

    # Extract audio features; [0] drops the batch dimension of the single clip.
    example["input_values"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_values[0]

    # Encode labels (character-level) with the CTC tokenizer directly instead of the
    # deprecated `as_target_processor()` context manager.
    example["labels"] = processor.tokenizer(example["text"]).input_ids

    return example


# Apply preprocessing
processed_ds = librispeech.map(prepare, remove_columns=librispeech.column_names)


In [None]:
from transformers import HubertForCTC

# Put a CTC head on the pretrained DistilHuBERT encoder. The checkpoint presumably has
# no ASR head, so `lm_head` starts randomly initialised. Its output size must match the
# CTC vocabulary, and the tokenizer's pad id serves as the CTC blank.
model = HubertForCTC.from_pretrained(
    "ntu-spml/distilhubert",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)


In [None]:
import jiwer
import numpy as np


def compute_metrics(pred):
    """Compute the word error rate (WER) of a CTC model's greedy predictions.

    Args:
        pred: a transformers EvalPrediction / PredictionOutput whose `predictions`
            are logits (vocabulary on the last axis) and whose `label_ids` are the
            padded reference token ids.

    Returns:
        {"wer": float} from jiwer.wer(reference, hypothesis).
    """
    pad_id = processor.tokenizer.pad_token_id
    logits = np.asarray(pred.predictions)

    # Greedy CTC decoding: most likely token per frame.
    pred_ids = logits.argmax(axis=-1)
    # Trainer pads logits with -100 when concatenating batches of different lengths.
    # argmax over such an all-padding frame would pick token 0, so map those frames
    # to the pad token, which decoding drops.
    pred_ids[np.all(logits == -100, axis=-1)] = pad_id
    pred_str = processor.batch_decode(pred_ids)

    # Label padding is -100 (ignored by the loss) and is not a valid token id.
    # Replace it with the pad token on a copy so the caller's array is untouched.
    label_ids = np.where(np.asarray(pred.label_ids) == -100, pad_id, pred.label_ids)
    # group_tokens=False: references are not CTC output, so repeated letters must not be merged.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": jiwer.wer(label_str, pred_str)}


In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
from datasets import load_dataset
from transformers import Trainer, TrainingArguments


@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pad a batch of CTC training examples.

    `processor.feature_extractor.pad` alone only pads `input_values`. The
    variable-length `labels` lists would stay ragged and could not be stacked
    into a tensor. This pads audio and labels separately and replaces label
    padding with -100, which HubertForCTC's CTC loss ignores.
    """

    processor: Any
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        audio_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.feature_extractor.pad(
            audio_features, padding=self.padding, return_tensors="pt"
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features, padding=self.padding, return_tensors="pt"
        )
        # Positions added by padding (attention_mask == 0) become -100.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )
        return batch


# Step-wise evaluation needs an eval set: use a slice of the official LibriSpeech
# clean validation split, preprocessed exactly like the training data.
EVAL_SPLIT = "validation[:10%]"
librispeech_val = load_dataset("librispeech_asr", "clean", split=EVAL_SPLIT)
eval_ds = librispeech_val.map(prepare, remove_columns=librispeech_val.column_names)

training_args = TrainingArguments(
    output_dir="./distilhubert-asr",
    group_by_length=True,
    per_device_train_batch_size=8,
    eval_strategy="steps",  # formerly `evaluation_strategy`, which newer transformers rejects
    num_train_epochs=3,
    fp16=torch.cuda.is_available(),  # mixed precision requires a GPU
    save_steps=100,
    eval_steps=100,
    logging_steps=25,
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
    gradient_checkpointing=True,
    logging_dir="./logs",
)

# Keep the full processor (feature extractor + CTC vocabulary) next to the
# checkpoints so the fine-tuned model can be reloaded and decoded later.
processor.save_pretrained(training_args.output_dir)

trainer = Trainer(
    model=model,
    data_collator=DataCollatorCTCWithPadding(processor=processor),
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=processed_ds,
    eval_dataset=eval_ds,
    # The feature extractor's first model input is `input_values`, which is what
    # `group_by_length` measures when bucketing examples by length.
    processing_class=processor.feature_extractor,
)

trainer.train()



In [None]:
from datasets import load_dataset

# Load a 1% slice of LibriSpeech test-clean. In the "clean" config this split is
# named "test"; "test.clean" only exists in the "all" config.
librispeech_test = load_dataset("librispeech_asr", "clean", split="test[:1%]")
# Drop the raw audio/text columns so only model inputs and labels remain.
test_ds = librispeech_test.map(prepare, remove_columns=librispeech_test.column_names)

# Predict and score
pred = trainer.predict(test_ds)
print("WER:", compute_metrics(pred)["wer"])


In [25]:
!pip install jiwer

Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting click>=8.1.8 (from jiwer)
  Downloading click-8.2.0-py3-none-any.whl.metadata (2.5 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading click-8.2.0-py3-none-any.whl (102 kB)
Downloading rapidfuzz-3.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m80.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rapidfuzz, click, jiwer
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [jiwer]32m2/3[0m [jiwer]
Successfully installed click-8.2.0 jiwer-3.1.0 rapidfuzz-3.13.0
