In [1]:
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
from evaluate import load
from itertools import islice
from tqdm import tqdm
from time import time
from collections import defaultdict
import itertools

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# How many examples are taken from the test split for each evaluation run.
NUM_EXAMPLES = 50

In [3]:
# Materialise the first NUM_EXAMPLES rows into a list. A bare `islice` is a
# one-shot iterator: the first evaluation loop would exhaust it and every later
# loop over `librispeech_test_clean` would silently iterate over nothing.
librispeech_test_clean = list(
    islice(
        load_dataset("librispeech_asr", "clean", split="test"),
        NUM_EXAMPLES,
    )
)

Downloading builder script: 100%|██████████| 11.5k/11.5k [00:00<00:00, 16.2MB/s]
Downloading metadata: 100%|██████████| 10.1k/10.1k [00:00<00:00, 13.9MB/s]
Downloading readme: 100%|██████████| 10.2k/10.2k [00:00<00:00, 18.9MB/s]
Downloading data: 100%|██████████| 338M/338M [00:40<00:00, 8.37MB/s]
Downloading data:   4%|▍         | 13.4M/347M [00:02<00:54, 6.16MB/s]


KeyboardInterrupt: 

In [4]:
# The processor and the model are loaded from the same checkpoint name, and
# the model is placed on the CPU.
WHISPER_CHECKPOINT = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT)
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_CHECKPOINT)
model = model.to("cpu")

In [5]:
# Display the module tree of the loaded model.
model

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0): WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_lay

In [6]:
def get_model_size(model):
    """Return the combined size of a module's parameters and buffers in MiB.

    Every tensor yielded by ``model.parameters()`` and ``model.buffers()``
    contributes ``nelement() * element_size()`` bytes; the byte total is then
    divided by ``1024**2``.

    NOTE(review): tensors that are not reachable through ``parameters()`` or
    ``buffers()`` are not counted -- confirm this is acceptable before
    comparing the sizes of differently-transformed models.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables of
            tensors (e.g. a ``torch.nn.Module``).

    Returns:
        The size in mebibytes as a float.
    """

    def _total_bytes(tensors):
        # Bytes per tensor = number of elements * bytes per element.
        return sum(tensor.nelement() * tensor.element_size() for tensor in tensors)

    combined_bytes = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    return combined_bytes / 1024**2

# Report the size of `model` (parameters plus buffers) in MB.
size_all_mb = get_model_size(model)
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 922.146MB


In [7]:
def map_to_pred(batch, model):
    """Transcribe one dataset example with ``model`` and return normalised text.

    Uses the notebook-level ``processor`` both to turn the raw audio into
    input features and to decode/normalise the generated token ids.

    Args:
        batch: A single example providing ``batch["audio"]`` with the keys
            ``"array"`` and ``"sampling_rate"``.
        model: Object exposing ``generate``; if it also exposes ``device`` the
            input features are moved there, otherwise they are moved to CPU.

    Returns:
        ``processor.tokenizer._normalize`` applied to the decoded first
        generated sequence.
    """
    audio = batch["audio"]
    input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features

    # Follow the model's own device instead of assuming CPU, so the same helper
    # keeps working if a model is moved elsewhere; fall back to the previous
    # hard-coded "cpu" when the model does not advertise a device.
    target_device = getattr(model, "device", "cpu")

    with torch.no_grad():
        predicted_ids = model.generate(input_features.to(target_device))[0]
    transcription = processor.decode(predicted_ids)
    return processor.tokenizer._normalize(transcription)

In [8]:
# Evaluate the baseline model: collect normalised references/predictions,
# report the average wall-clock time per example in milliseconds, then the WER
# as a percentage.
res = defaultdict(list)

t = time()
for el in tqdm(librispeech_test_clean):
    res["reference"].append(processor.tokenizer._normalize(el['text']))
    res["prediction"].append(map_to_pred(el, model))
elapsed = time() - t

# Average over the examples that were actually processed rather than over
# NUM_EXAMPLES: the two differ when the data source yields fewer items (for
# instance an iterator that an earlier cell already consumed).
num_evaluated = len(res["reference"])
if num_evaluated == 0:
    raise RuntimeError(
        "No examples were evaluated: `librispeech_test_clean` is empty or an "
        "already-consumed iterator; re-create it before running this cell."
    )

t = (elapsed / num_evaluated) * 1000
print(f"avg. time on example: {t}")

wer = load("wer")
print(100 * wer.compute(references=res["reference"], predictions=res["prediction"]))


50it [01:23,  1.67s/it]


avg. time on example: 1671.0307693481445
3.982777179763186


In [9]:
# Dynamic quantization only swaps module types that have a dynamic quantized
# implementation. `nn.Conv1d` is not among them, so listing it had no effect
# and wrongly suggested the convolutions were quantized as well. Only
# `nn.Linear` layers are requested here.
q_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)

In [10]:
# Size of the dynamically quantized model, measured with the same helper as
# the baseline.
# NOTE(review): `get_model_size` only sums `parameters()` and `buffers()`;
# confirm the quantized layers' weights are reachable that way, otherwise this
# figure understates the real size.
q_model_size = get_model_size(q_model)
print('model size: {:.3f}MB'.format(q_model_size))

model size: 165.478MB


In [13]:
# Evaluate the quantized model: collect normalised references/predictions,
# report the average wall-clock time per example in milliseconds, then the WER
# as a percentage.
res = defaultdict(list)

t = time()
for el in tqdm(librispeech_test_clean):
    res["reference"].append(processor.tokenizer._normalize(el['text']))
    res["prediction"].append(map_to_pred(el, q_model))
elapsed = time() - t

# Average over the examples that were actually processed rather than over
# NUM_EXAMPLES: the two differ when the data source yields fewer items (for
# instance an iterator that an earlier cell already consumed).
num_evaluated = len(res["reference"])
if num_evaluated == 0:
    raise RuntimeError(
        "No examples were evaluated: `librispeech_test_clean` is empty or an "
        "already-consumed iterator; re-create it before running this cell."
    )

t = (elapsed / num_evaluated) * 1000
print(f"avg. time on example: {t}")

wer = load("wer")
print(100 * wer.compute(references=res["reference"], predictions=res["prediction"]))


50it [01:02,  1.24s/it]


avg. time on example: 1242.180733680725
4.413347685683531


In [16]:
# Display the encoder's layer list to see the sub-module names of each layer.
model.model.encoder.layers

ModuleList(
  (0): WhisperEncoderLayer(
    (self_attn): WhisperAttention(
      (k_proj): Linear(in_features=768, out_features=768, bias=False)
      (v_proj): Linear(in_features=768, out_features=768, bias=True)
      (q_proj): Linear(in_features=768, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (activation_fn): GELUActivation()
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (1): WhisperEncoderLayer(
    (self_attn): WhisperAttention(
      (k_proj): Linear(in_features=768, out_features=768, bias=False)
      (v_proj): Linear(in_features=768, out_features=768, bias=True)
      (q_proj): Linear(in_features=768, out_features=768, bias=True)
      (out_proj): Linear(in_features=

In [48]:
import copy
from torch.nn.utils import prune

# Work on a deep copy: later cells apply pruning to `p_model`, not to `model`.
p_model = copy.deepcopy(model)

In [49]:
# Print the first encoder layer of the copy to inspect its sub-modules.
module = p_model.model.encoder.layers[0]
print(module)

WhisperEncoderLayer(
  (self_attn): WhisperAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=False)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (activation_fn): GELUActivation()
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
  (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)


In [50]:
# Collect, for every encoder layer, the weight of each attention projection and
# of both feed-forward layers as (module, parameter_name) pairs.
# Each tensor is listed exactly once: a duplicated entry would be counted twice
# in the global magnitude ranking and have a mask applied to it twice.
# Iterating over the layer list itself (instead of a fixed `range(12)`) keeps
# the selection correct for checkpoints with a different encoder depth.
parameters_to_prune = tuple(
    (linear, "weight")
    for layer in p_model.model.encoder.layers
    for linear in (
        layer.self_attn.k_proj,
        layer.self_attn.v_proj,
        layer.self_attn.q_proj,
        layer.self_attn.out_proj,
        layer.fc1,
        layer.fc2,
    )
)

# Remove the 20% of the selected weights with the smallest absolute value,
# ranked globally across all listed tensors rather than per tensor.
# NOTE(review): this modifies `p_model` in place, so re-running the cell prunes
# the already-pruned copy again -- re-create `p_model` first when re-running.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [51]:
def get_model_size(model):
    """Return the combined size of a module's parameters and buffers in MiB.

    NOTE(review): this re-defines ``get_model_size`` from an earlier cell with
    the same logic; consider deleting this copy and reusing the first one.

    Every tensor yielded by ``model.parameters()`` and ``model.buffers()``
    contributes ``nelement() * element_size()`` bytes; the byte total is then
    divided by ``1024**2``.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables of
            tensors (e.g. a ``torch.nn.Module``).

    Returns:
        The size in mebibytes as a float.
    """

    def _total_bytes(tensors):
        # Bytes per tensor = number of elements * bytes per element.
        return sum(tensor.nelement() * tensor.element_size() for tensor in tensors)

    combined_bytes = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    return combined_bytes / 1024**2

# Report the size of the pruned copy (parameters plus buffers) in MB.
# NOTE(review): the recorded figure is larger than the unpruned model's;
# presumably pruning keeps the original weights plus a mask for each pruned
# tensor until it is made permanent -- confirm against the prune docs.
size_all_mb = get_model_size(p_model)
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 1221.396MB


In [52]:
# Evaluate the pruned model: collect normalised references/predictions,
# report the average wall-clock time per example in milliseconds, then the WER
# as a percentage.
res = defaultdict(list)

t = time()
for el in tqdm(librispeech_test_clean):
    res["reference"].append(processor.tokenizer._normalize(el['text']))
    res["prediction"].append(map_to_pred(el, p_model))
elapsed = time() - t

# Average over the examples that were actually processed rather than over
# NUM_EXAMPLES: the two differ when the data source yields fewer items (for
# instance an iterator that an earlier cell already consumed).
num_evaluated = len(res["reference"])
if num_evaluated == 0:
    raise RuntimeError(
        "No examples were evaluated: `librispeech_test_clean` is empty or an "
        "already-consumed iterator; re-create it before running this cell."
    )

t = (elapsed / num_evaluated) * 1000
print(f"avg. time on example: {t}")

wer = load("wer")
print(100 * wer.compute(references=res["reference"], predictions=res["prediction"]))

50it [00:58,  1.17s/it]


avg. time on example: 1169.2960262298584
3.982777179763186
