# Quant the model to int4

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Define the quantization configuration
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # Optional: Use 'nf4' for Normal Float 4 quantization
)

# Load the model with the quantization configuration
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    quantization_config=quantization_config,
    device_map="auto",  # Automatically distribute the model across available devices
)

# Load the Italian tokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")

# Save the quantized model and tokenizer
model_save_path = "gemma-2-2b-it-int4"
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

print(f"Quantized model and tokenizer saved to {model_save_path}")


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.17it/s]


Quantized model and tokenizer saved to gemma-2-2b-it-int4


# Test the quant int4 model

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the quantized model and tokenizer
model_name = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Function to generate text based on a prompt
def generate_text(prompt, max_new_tokens=500):
    # Encode the prompt text
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
    # Generate text
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    
    # Decode the generated tokens
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

# Example usage
prompt = "Please explain as much as possible regarding pytorch"
generated_text = generate_text(prompt)
print(generated_text)


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.02it/s]


Please explain as much as possible regarding pytorch and its applications.

## PyTorch: A Deep Dive

PyTorch is a powerful open-source machine learning framework developed by Facebook (now Meta). It's known for its flexibility, dynamic computation graph, and ease of use, making it a popular choice for researchers and developers alike.

**Key Features:**

* **Dynamic Computation Graph:** Unlike static graphs in frameworks like TensorFlow, PyTorch allows you to define and modify your model's structure during runtime. This flexibility is crucial for experimentation and debugging.
* **Tensor Operations:** PyTorch leverages tensors, multi-dimensional arrays, for efficient numerical computations. Tensors are the building blocks of your data and models.
* **GPU Acceleration:** PyTorch seamlessly integrates with GPUs, enabling faster training and inference for complex models.
* **Pythonic API:** PyTorch's API is designed to be intuitive and familiar to Python users, making it easy to learn and