<a href="https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/gemma-transformers-streamlined.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gemma in 🤗 Transformers

Set up the Python environment:

In [None]:
!pip install --upgrade --quiet transformers accelerate bitsandbytes

Define the quantization config for 4-bit inference (weights stored in 4-bit, computation in float16):

In [14]:
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
)
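As a rough check on why 4-bit loading helps, the memory needed for the weights scales with the number of bits stored per parameter. The sketch below assumes a parameter count of about 2.5 billion for `google/gemma-2b`. That figure is an approximation, not something stated in this notebook. The sketch also ignores activations, the KV cache, quantization constants, and any layers bitsandbytes keeps in higher precision, so real footprints will be somewhat larger.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate memory for the model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

# ASSUMPTION: approximate parameter count for gemma-2b, for illustration only.
N_PARAMS = 2.5e9

for label, bits in [("float32", 32), ("float16", 16), ("4-bit", 4)]:
    print(f"{label:>8}: ~{weight_memory_gb(N_PARAMS, bits):.2f} GB")
```

With `bnb_4bit_compute_dtype=torch.float16`, the weights are stored in 4-bit but dequantized to float16 for the matrix multiplications. This trades a little compute for roughly a 4x reduction in weight memory relative to float16.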

Load the pre-trained model and tokenizer:

In [43]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", low_cpu_mem_usage=True, quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")


Step 1: encode the inputs

In [48]:
input_ids = tokenizer("Recipe for pasta:", return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
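Conceptually, the tokenizer maps text to integer IDs and returns them batched as a 2-D structure of shape `(batch_size, sequence_length)`. `.to(model.device)` then moves that tensor to the device that holds the model's weights. The toy sketch below uses a hypothetical whitespace vocabulary (`TOY_VOCAB`) and a hypothetical BOS id, both invented for illustration. Gemma's real tokenizer is a SentencePiece subword model with a far larger vocabulary, and by default it also prepends a beginning-of-sequence token.

```python
# HYPOTHETICAL toy vocabulary; Gemma's real tokenizer uses subword units.
TOY_VOCAB = {"<pad>": 0, "<eos>": 1, "<bos>": 2, "Recipe": 3, "for": 4, "pasta:": 5}
BOS_ID = TOY_VOCAB["<bos>"]

def toy_encode(text: str) -> list[list[int]]:
    """Split on whitespace, map words to ids, prepend BOS, and wrap in a batch of size 1."""
    ids = [BOS_ID] + [TOY_VOCAB[word] for word in text.split()]
    return [ids]  # shape: (batch_size=1, sequence_length)

input_ids = toy_encode("Recipe for pasta:")
print(input_ids)  # [[2, 3, 4, 5]]
```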

Step 2: generate autoregressively

In [59]:
from transformers import set_seed

set_seed(42)
pred_ids = model.generate(input_ids, do_sample=True, temperature=0.6, max_new_tokens=128)
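With `do_sample=True`, generation is a loop. At each step the model produces next-token logits, which are divided by `temperature` and turned into probabilities. One token is then sampled and appended to the sequence, and the loop stops after `max_new_tokens` or at an end-of-sequence token. The plain-Python sketch below mirrors that loop. `toy_logits` is a hypothetical stand-in for the model's forward pass. The sketch omits other filters that `generate` may apply depending on the generation config, such as top-k.

```python
import math
import random
from typing import Optional

def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def toy_logits(token_ids: list[int], vocab_size: int = 6) -> list[float]:
    """HYPOTHETICAL model stand-in: strongly favours the id after the last token."""
    favoured = (token_ids[-1] + 1) % vocab_size
    return [3.0 if i == favoured else 0.0 for i in range(vocab_size)]

def sample_generate(prompt_ids: list[int], max_new_tokens: int, temperature: float,
                    eos_id: Optional[int] = None, seed: int = 42) -> list[int]:
    rng = random.Random(seed)  # local seeded RNG, similar in spirit to set_seed(42)
    ids = list(prompt_ids)     # output keeps the prompt, then appends new tokens
    for _ in range(max_new_tokens):
        probs = softmax_with_temperature(toy_logits(ids), temperature)
        next_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break
    return ids

pred_ids = sample_generate([2, 3, 4, 5], max_new_tokens=8, temperature=0.6)
print(pred_ids)
```

A temperature below 1 (such as the 0.6 used above) sharpens the distribution toward the most likely tokens, while values above 1 flatten it.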

Step 3: decode the outputs

In [60]:
pred_text = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
print(pred_text[0])

Recipe for pasta:

1. Boil 1.50 ltr of water
2. Add 100 gms of butter and let it melt
3. Add 100 gms of all purpose flour and keep stirring
4. Once the flour is well blended, add 1/2 ltr of the boiling water and keep stirring
5. Keep adding the boiling water a.t.a.t you reach a smooth and stretchable dough
6. Keep kneading the dough (keep kneading for atleast 15 mts)
7. Once the dough is smooth and stretchable, keep rolling it
8. Once the
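`batch_decode` maps each row of IDs back to a string, and `skip_special_tokens=True` drops special tokens such as `<bos>` and `<eos>`. A decoder-only model's output contains the prompt followed by the new tokens, so the decoded text above begins with "Recipe for pasta:". The recipe most likely stops mid-sentence because generation hit the `max_new_tokens=128` limit rather than an end-of-sequence token. The toy sketch below illustrates the decoding step with a hypothetical vocabulary invented for illustration, not Gemma's real one.

```python
# HYPOTHETICAL toy vocabulary standing in for Gemma's real tokenizer.
TOY_VOCAB = {"<pad>": 0, "<eos>": 1, "<bos>": 2, "Recipe": 3, "for": 4, "pasta:": 5}
ID_TO_TOKEN = {i: tok for tok, i in TOY_VOCAB.items()}
SPECIAL_IDS = {TOY_VOCAB["<pad>"], TOY_VOCAB["<eos>"], TOY_VOCAB["<bos>"]}

def toy_batch_decode(batch_ids: list[list[int]], skip_special_tokens: bool = True) -> list[str]:
    """Decode each row of ids to text, optionally dropping special tokens."""
    texts = []
    for ids in batch_ids:
        tokens = [ID_TO_TOKEN[i] for i in ids
                  if not (skip_special_tokens and i in SPECIAL_IDS)]
        texts.append(" ".join(tokens))
    return texts

print(toy_batch_decode([[2, 3, 4, 5, 1]]))  # ['Recipe for pasta:']
print(toy_batch_decode([[2, 3, 4, 5, 1]], skip_special_tokens=False))  # ['<bos> Recipe for pasta: <eos>']
```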
