In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded before each
# cell executes and edits to local source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [30]:
from datasets import load_dataset
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor
import numpy as np

In [38]:
# Reference model: the stock Flax Whisper implementation shipped with transformers.
# With `_do_init=False`, `from_pretrained` hands back the parameters separately from the model.
whisper_checkpoint = "openai/whisper-tiny.en"
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    whisper_checkpoint,
    _do_init=False,
)
processor = WhisperProcessor.from_pretrained(whisper_checkpoint)

In [39]:
# Dummy LibriSpeech data ("clean" config, validation split) used as evaluation audio.
librispeech = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)

def preprocess(batch):
    """Convert one dataset example's raw audio into model input features.

    Args:
        batch: A single example holding an ``"audio"`` entry with the raw
            waveform under ``"array"`` and, optionally, its
            ``"sampling_rate"``.

    Returns:
        The same ``batch`` object, with the first row of the processor's
        ``input_features`` stored under ``"input_features"``.
    """
    audio = batch["audio"]
    # Forward the sampling rate recorded with the audio instead of always
    # asserting 16 kHz, so the processor sees the true rate of the waveform.
    # Examples that carry no rate keep the previous behaviour (16000).
    try:
        sampling_rate = audio["sampling_rate"]
    except KeyError:
        sampling_rate = 16000
    features = processor(audio["array"], sampling_rate=sampling_rate, return_tensors="np")
    # Only one waveform is passed in, so keep just the first row of the output.
    batch["input_features"] = features.input_features[0]
    return batch

# Featurise every example; the original columns are dropped from the mapped dataset.
dataset_processed = librispeech.map(
    preprocess,
    remove_columns=librispeech.column_names,
)

# Serve the processed examples in NumPy format, four at a time.
numpy_dataset = dataset_processed.with_format("numpy")
eval_dataloader = numpy_dataset.iter(batch_size=4)

Found cached dataset librispeech_asr_dummy (/Users/sanchitgandhi/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
Loading cached processed dataset at /Users/sanchitgandhi/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-4935437da9c0dd4b.arrow


In [40]:
# Take one batch of input features and build the matching decoder prompt:
# a single decoder start token per example.
batch = next(iter(eval_dataloader))
num_examples = batch["input_features"].shape[0]
# Token ids are integers, so create the prompt with an integer dtype directly;
# `np.ones(...) * token_id` would yield a float64 array instead.
decoder_input_ids = np.full((num_examples, 1), model.config.decoder_start_token_id, dtype=np.int32)

In [41]:
# Forward pass through the reference model; `logits` is the baseline that the
# later cells compare their own logits against.
reference_outputs = model(
    batch["input_features"],
    decoder_input_ids=decoder_input_ids,
    params=params,
)
logits = reference_outputs.logits

In [42]:
# Generate token ids with the reference model (at most 64 new tokens) and decode them to text.
pred_ids = model.generate(
    batch["input_features"],
    params=params,
    max_new_tokens=64,
)
pred_str = processor.batch_decode(
    pred_ids.sequences,
    skip_special_tokens=True,
)
pred_str

[' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
 " Nor is Mr. Quilter's manner less interesting than his matter.",
 ' He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.',
 " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of rocky Ithaca."]

In [43]:
from models import FlaxWhisperForConditionalGeneration as FlaxScanRematWhisperForConditionalGeneration

In [44]:
# Load the same checkpoint into the scan/remat implementation from the local `models` module.
# This rebinds `model` and `params`; the previously computed `logits` stay available for comparison.
model, params = FlaxScanRematWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en",
    _do_init=False,
)

In [45]:
# model structure is entirely equivalent to the original model -> check we get the same outputs
# test forward pass
new_logits = model(batch["input_features"], decoder_input_ids=decoder_input_ids, params=params).logits
print("Max diff in logits: ", np.max(np.abs(new_logits - logits)))
# Enforce the comparison instead of relying on someone reading the printed number.
# NOTE(review): atol=1e-4 sits above the ~1e-5 differences recorded in this notebook's outputs;
# it may need loosening on hardware that uses lower matmul precision.
np.testing.assert_allclose(np.asarray(new_logits), np.asarray(logits), rtol=0, atol=1e-4)

# test generate
pred_ids = model.generate(batch["input_features"], params=params, max_new_tokens=64)
pred_str = processor.batch_decode(pred_ids.sequences, skip_special_tokens=True)
pred_str

0.0


[' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
 " Nor is Mr. Quilter's manner less interesting than his matter.",
 ' He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.',
 " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of rocky Ithaca."]

In [46]:
# enable gradient checkpointing -> check we get the same outputs
# (this switches the setting on the existing `model` object in place)
model.enable_gradient_checkpointing()

# test forward pass
new_logits = model(batch["input_features"], decoder_input_ids=decoder_input_ids, params=params).logits
print("Max diff in logits: ", np.max(np.abs(new_logits - logits)))
# Enforce the comparison instead of relying on someone reading the printed number.
# NOTE(review): atol=1e-4 sits above the ~1e-5 differences recorded in this notebook's outputs;
# it may need loosening on hardware that uses lower matmul precision.
np.testing.assert_allclose(np.asarray(new_logits), np.asarray(logits), rtol=0, atol=1e-4)

# test generate
pred_ids = model.generate(batch["input_features"], params=params, max_new_tokens=64)
pred_str = processor.batch_decode(pred_ids.sequences, skip_special_tokens=True)
pred_str

Max diff in logits:  1.5258789e-05


[' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
 " Nor is Mr. Quilter's manner less interesting than his matter.",
 ' He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.',
 " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of rocky Ithaca."]

In [47]:
# enable scan -> check we get the same outputs
model.enable_scan()  # to enable scan in the nn.Module
# NOTE(review): `params` is overwritten with its converted form, so re-running this cell would
# feed already-converted params back into `convert_unroll_to_scan` -- presumably unsupported;
# reload the model/params (or restart the kernel) before running it again.
params = model.convert_unroll_to_scan(params) # to convert the unrolled params to scan

# test forward pass
new_logits = model(batch["input_features"], decoder_input_ids=decoder_input_ids, params=params).logits
print("Max diff in logits: ", np.max(np.abs(new_logits - logits)))
# Enforce the comparison instead of relying on someone reading the printed number.
# NOTE(review): atol=1e-4 sits above the ~1e-5 differences recorded in this notebook's outputs;
# it may need loosening on hardware that uses lower matmul precision.
np.testing.assert_allclose(np.asarray(new_logits), np.asarray(logits), rtol=0, atol=1e-4)

# test generate
pred_ids = model.generate(batch["input_features"], params=params, max_new_tokens=64)
pred_str = processor.batch_decode(pred_ids.sequences, skip_special_tokens=True)
pred_str

Max diff in logits:  1.6212463e-05


[' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
 " Nor is Mr. Quilter's manner less interesting than his matter.",
 ' He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind.',
 " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of rocky Ithaca."]