In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from transformers import BitsAndBytesConfig
from torch import bfloat16

def load_model(model_id, bit_count=32):
    """Load a causal language model and wrap it in a text-generation pipeline.

    Args:
        model_id: Model identifier or path forwarded to ``from_pretrained``
            for both the model and the tokenizer.
        bit_count: Precision selector. ``32`` loads the model without any
            quantization config, ``8`` loads it with bitsandbytes 8-bit
            quantization, and ``4`` loads it with 4-bit NF4 quantization
            (double quantization, bfloat16 compute dtype).

    Returns:
        A ``(pipe, model, device)`` tuple: the text-generation pipeline, the
        loaded model, and the ``torch.device`` chosen from CUDA availability.

    Raises:
        ValueError: If ``bit_count`` is not one of 4, 8 or 32.
    """
    # Fail fast instead of silently returning None for an unknown precision.
    if bit_count not in (4, 8, 32):
        raise ValueError(
            f"Unsupported bit_count {bit_count!r}; expected 4, 8 or 32."
        )

    # Returned for the caller's convenience; model placement itself is
    # delegated to device_map='auto' below.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if bit_count == 32:
        quantization_config = None
    elif bit_count == 8:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,  # 4-bit quantization
            bnb_4bit_quant_type='nf4',  # Normalized float 4
            bnb_4bit_use_double_quant=True,  # Second quantization after the first
            bnb_4bit_compute_dtype=bfloat16,  # Computation type
        )

    load_kwargs = {"device_map": "auto"}
    if quantization_config is not None:
        load_kwargs["quantization_config"] = quantization_config
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    # Everything below is shared by all precisions, so it must live outside
    # the branch above; otherwise only the 4-bit path would return a value.
    print(f"Model Size: {model.get_memory_footprint():,} bytes")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    pipe = pipeline(model=model, tokenizer=tokenizer, task='text-generation')

    return pipe, model, device

In [None]:
# Precision to load the model with (see load_model's bit_count parameter).
bit_count = 4
model_id = "mistralai/Mistral-7B-v0.1"

# Pass bit_count explicitly; otherwise load_model falls back to its default
# of 32 and the setting above is silently ignored.
pipe, model, device = load_model(model_id, bit_count=bit_count)
