For models that we want to use in a quantized state (e.g. Llama 3 70B), compute and store quantized versions of the models to reduce load times.

In [None]:
!pip install --quiet --upgrade transformers

In [None]:
!pip install --quiet --upgrade torch

In [None]:
!pip install --quiet --upgrade bitsandbytes accelerate

In [None]:
%%time

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import BitsAndBytesConfig

model_id = "meta-llama/Meta-Llama-3-70B-Instruct"

## 8-bit quantization
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True
# )

## 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # bnb_4bit_compute_dtype=torch.bfloat16
    bnb_4bit_compute_dtype="bfloat16"
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    # torch_dtype=torch.bfloat16,
    torch_dtype="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="sequential" ## using sequential instead of auto/balanced since otherwise lm_head gets put on CPU
)

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True
)

In [None]:
# Local target directory: "models/" followed by the model id with every "/"
# replaced by "__".
quantized_model_dir = f"models/{model_id.replace('/', '__')}"

# Write the quantized model and its tokenizer into the same directory.
model.save_pretrained(quantized_model_dir)
# Bind the return value to `_` so the cell does not echo it.
_ = tokenizer.save_pretrained(quantized_model_dir)

In [None]:
# Report device status, model details and the model's device placement via the
# project's `utils` helpers (module not shown here; behavior inferred from the
# function names — confirm against utils).
import utils
utils.print_device_info()
utils.print_model_info(model)
utils.print_device_map(model)

In [None]:
# Free the in-memory model before it is reloaded from disk further below.
import gc
import torch

del model  # drop the notebook's reference to the loaded model
gc.collect()  # run a garbage-collection pass now that the reference is gone
torch.cuda.empty_cache()  # release unoccupied cached CUDA memory held by PyTorch's allocator

In [None]:
# Print device info again after the cleanup cell (presumably to confirm that
# memory was released — `utils` is a project helper not shown here).
import utils
utils.print_device_info()

In [None]:
%%time

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained(
    quantized_model_dir,
    local_files_only=True
)

model = AutoModelForCausalLM.from_pretrained(
    quantized_model_dir,
    device_map="sequential",
    local_files_only=True
)

In [None]:
# Report device status, model details and device placement for the reloaded
# model via the project's `utils` helpers (module not shown here; behavior
# inferred from the function names — confirm against utils).
import utils
utils.print_device_info()
utils.print_model_info(model)
utils.print_device_map(model)