# Quantized inference example

Gazelle supports quantization natively through Hugging Face Transformers. This notebook demonstrates quantized loading with 8-bit and 4-bit configurations via bitsandbytes.

Half precision (i.e. bfloat16) requires 24GB of VRAM.
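As a rough sanity check on these numbers, the sketch below estimates how much memory the weights alone take at different bit widths. The parameter count `N_PARAMS` is a hypothetical value chosen for illustration, not a figure from this notebook. Measured usage will be higher than these estimates, since activations, the CUDA context, and any modules left unquantized also take memory.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


# Hypothetical parameter count, used only for illustration.
N_PARAMS = 7e9

for label, bits in [("bfloat16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label:>8}: {weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

Treat the result as a lower bound: it tells you whether a configuration could possibly fit on a card, not whether it will.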

### 8-bit quantization

This configuration uses ~15.3GB of VRAM. For shorter sequences, it should fit on a 16GB card (e.g. T4, V100).

In [1]:
import torch
import transformers
from transformers import BitsAndBytesConfig

from gazelle import (
    GazelleConfig,
    GazelleForConditionalGeneration,
    GazelleProcessor,
)

model_id = "tincans-ai/gazelle-v0.2"
config = GazelleConfig.from_pretrained(model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
    "facebook/wav2vec2-base-960h"
)

def inference_collator(
    audio_input,
    prompt="Transcribe the following \n<|audio|>",
    audio_dtype=torch.float16,
):
    # Convert the 16 kHz waveform into the audio encoder's input values.
    audio_values = audio_processor(
        audio=audio_input, return_tensors="pt", sampling_rate=16000
    ).input_values
    msgs = [
        {"role": "user", "content": prompt},
    ]
    # Tokenize the chat-formatted prompt (these are model inputs, not labels).
    input_ids = tokenizer.apply_chat_template(
        msgs, return_tensors="pt", add_generation_prompt=True
    )
    return {
        "audio_values": audio_values.squeeze(0).to("cuda").to(audio_dtype),
        "input_ids": input_ids.to("cuda"),
    }



In [2]:
import torchaudio
from IPython.display import Audio

test_audio, sr = torchaudio.load("test6.wav")

if sr != 16000:
    test_audio = torchaudio.transforms.Resample(sr, 16000)(test_audio)

In [2]:
quantization_config_8bit = BitsAndBytesConfig(
    load_in_8bit=True,
)

model = GazelleForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    quantization_config=quantization_config_8bit,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.27s/it]


In [11]:
inputs = inference_collator(test_audio, "<|audio|>")
tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'<s> [INST] <|audio|>  [/INST]My greatest accomplishment is being able to help people.</s>'

In [7]:
!nvidia-smi

Wed Mar 20 09:03:54 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.10              Driver Version: 551.61         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
|  0%   50C    P3            103W /  450W |   15253MiB /  24564MiB |     13%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
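The `tokenizers` warning above appears because `!nvidia-smi` forks the notebook process after the tokenizer has already been used. It is harmless here. One way to silence it is to set the environment variable the warning names, as in this minimal sketch; it should run in the first cell, before `transformers` is imported.

```python
import os

# Disable tokenizer parallelism so that forking a subprocess later
# (for example with `!nvidia-smi`) no longer triggers the warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```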


### 4-bit quantization

This setup uses 4-bit weights with bfloat16 compute. It needs only ~9.3GB of VRAM (9339MiB in the `nvidia-smi` output below).

In [3]:
quantization_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


model = GazelleForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    quantization_config=quantization_config_4bit,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.26s/it]


In [4]:
inputs = inference_collator(test_audio, "<|audio|>")
tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'<s> [INST] <|audio|>  [/INST]My greatest accomplishment is being able to raise my children to be good people.</s>'

In [5]:
!nvidia-smi

Wed Mar 20 09:07:22 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.10              Driver Version: 551.61         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
|  0%   53C    P2            122W /  450W |    9339MiB /  24564MiB |     19%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                



You can also enable double (nested) quantization, described at https://huggingface.co/blog/4bit-transformers-bitsandbytes#nested-quantization

This further reduces VRAM usage to under 9GB!
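To see where the saving comes from, the sketch below works through the per-parameter overhead of the quantization constants. The block sizes and bit widths are assumptions taken from the QLoRA paper's description of double quantization; the installed bitsandbytes version may use different defaults. `saved_megabytes` is a hypothetical helper, and the 7e9 parameter count in the last line is illustrative only.

```python
# Assumed values, following the QLoRA paper's description.
QUANT_BLOCK = 64       # weights that share one quantization constant
CONSTANT_BITS = 32     # each constant stored as float32
DQ_BLOCK = 256         # first-level constants sharing one second-level constant
DQ_CONSTANT_BITS = 8   # first-level constants re-quantized to 8 bits

# Overhead in bits per parameter without double quantization.
plain_overhead = CONSTANT_BITS / QUANT_BLOCK

# With double quantization: 8-bit constants plus a float32 constant
# shared by every DQ_BLOCK of them.
nested_overhead = (
    DQ_CONSTANT_BITS / QUANT_BLOCK + CONSTANT_BITS / (QUANT_BLOCK * DQ_BLOCK)
)

saved_bits = plain_overhead - nested_overhead


def saved_megabytes(n_params: float) -> float:
    """Memory saved by double quantization, in MB (1 MB = 1e6 bytes)."""
    return n_params * saved_bits / 8 / 1e6


print(f"plain:  {plain_overhead:.3f} bits/param")
print(f"nested: {nested_overhead:.3f} bits/param")
print(f"saved for a hypothetical 7e9 parameters: {saved_megabytes(7e9):.0f} MB")
```

Under these assumptions the saving is a few hundred megabytes, the same order of magnitude as the drop between the two 4-bit runs in this notebook.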

In [3]:
quantization_config_4bit_dq = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)


model = GazelleForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    quantization_config=quantization_config_4bit_dq,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it]


In [4]:
inputs = inference_collator(test_audio, "<|audio|>")
tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'<s> [INST] <|audio|>  [/INST]My greatest accomplishment is being able to raise my children to be good people.</s>'

In [5]:
!nvidia-smi

Wed Mar 20 09:09:30 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.10              Driver Version: 551.61         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
|  0%   50C    P3             64W /  450W |    8942MiB /  24564MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
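For a side-by-side comparison, the following sketch pulls the memory figures out of the three `nvidia-smi` rows shown in this notebook. `parse_memory` is a hypothetical helper written for this example. Note that `nvidia-smi` reports usage for the whole GPU, including the display and any other processes, so these figures are upper bounds on what the model itself needs.

```python
import re

# Memory rows copied from the three nvidia-smi outputs in this notebook.
ROWS = {
    "8-bit": "|  0%   50C    P3            103W /  450W |   15253MiB /  24564MiB |     13%      Default |",
    "4-bit": "|  0%   53C    P2            122W /  450W |    9339MiB /  24564MiB |     19%      Default |",
    "4-bit + double quant": "|  0%   50C    P3             64W /  450W |    8942MiB /  24564MiB |      1%      Default |",
}

MEMORY_PATTERN = re.compile(r"(\d+)MiB\s*/\s*(\d+)MiB")


def parse_memory(row: str) -> tuple[int, int]:
    """Return (used_mib, total_mib) from an nvidia-smi memory row."""
    match = MEMORY_PATTERN.search(row)
    if match is None:
        raise ValueError("no 'used MiB / total MiB' field found in row")
    used, total = map(int, match.groups())
    return used, total


for name, row in ROWS.items():
    used, total = parse_memory(row)
    print(f"{name}: {used} MiB used ({used / total:.0%} of {total} MiB)")
```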
                                                

