# Gemma 7B-it Template

Hi KaggleX team, this is a template for anyone having trouble with Gemma 7B, which is hard to run because of its GPU memory requirements.

In this notebook we run the quantized (int8) version of the instruction-tuned Gemma 7B model ('7b-it-quant'). Quantization shrinks the weights enough that
the 7B model fits on a single GPU (here on Kaggle, select either a P100 or a T4 GPU).
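
As a rough back-of-the-envelope check on why int8 helps, the sketch below estimates the memory taken by the weights alone. The parameter count (about 8.5 billion), the 16 GiB per GPU, and the 2 GiB of headroom for activations and the KV cache are all assumptions for illustration, not values taken from this notebook.

```python
GIB = 1024 ** 3

# ASSUMPTIONS for illustration only (not measured in this notebook)
ASSUMED_PARAMS = 8.5e9        # approximate parameter count of Gemma 7B
ASSUMED_GPU_GIB = 16          # memory of one P100 or T4
ASSUMED_HEADROOM_GIB = 2      # space left for activations and the KV cache

def weight_gib(num_params, bytes_per_param):
    """Memory needed for the weights alone, in GiB."""
    return num_params * bytes_per_param / GIB

budget = ASSUMED_GPU_GIB - ASSUMED_HEADROOM_GIB
for name, nbytes in [("float32", 4), ("bfloat16", 2), ("int8", 1)]:
    size = weight_gib(ASSUMED_PARAMS, nbytes)
    verdict = "fits" if size <= budget else "does not fit"
    print(f"{name:>8}: {size:5.1f} GiB of weights -> {verdict} in a {budget} GiB budget")
```

Under these assumptions only the int8 weights leave comfortable room on a single card, which is the reason this template uses the quantized checkpoint.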

**Things to keep in mind**:
* This notebook defaults to Gemma '7b-it-quant', but I have added code so you can also use the 2b-it model.
* Run this notebook on Kaggle. It can be made to work on Google Colab, but you will need to modify some parts.
* This is a PyTorch implementation, not Keras.
* I have added comments marking where you should and should not customize the code for your own projects.


*Best of luck! - Valentin, KaggleX Team D*

**Before you start:** Make sure to pick an accelerator (GPU) --> either a P100 or the T4 x2 on Kaggle works fine.

-- Run this cell which will set up needed directories and download some requirements:

In [1]:
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git
!mkdir /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 (from 1)
Receiving objects: 100% (239/239), 2.18 MiB | 11.85 MiB/s, done.
Resolving deltas: 100% (135/135), done.


-- Run this to import the needed Python packages and set up the import path:

In [2]:
# Imports and path setup
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

# Do not modify, used for PyTorch
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Set the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the default dtype even if loading the model raises
        torch.set_default_dtype(torch.float)

## Step 1. Create Formatter.
We first need to make the formatter used for the instruction-tuned models. These **\<TOKENS>** are *required*, so don't modify them unless you no longer want to use the instruction-tuned models.

In [3]:
# Class for formatting the input data
class GemmaFormatter:
    """This class is for the special case of the instruction-tuned models, which require these tokens."""
    
    # DO NOT modify these tokens, they are needed for the 7b-it and 2b-it models (instruction-tuned)
    _start_of_turn_user = "<start_of_turn>user\n"
    _start_of_turn_model = "<start_of_turn>model\n"
    _end_of_turn = "<end_of_turn>\n"
    def __init__(self, user_instruct = None, model_response = None):
        self._user_instruct = user_instruct
        self._model_response = model_response
        self.dataset = []
        
    def user_input(self, prompt):
        # This will format a single user prompt
        user_prompt =  self._start_of_turn_user + prompt + "\n" + self._end_of_turn + self._start_of_turn_model
        return user_prompt
        
    def apply_format(self):
        # Will return the full dataset with tokens applied to both user input and model response
        for user_ins, model_res in zip(self._user_instruct, self._model_response):
            user_prompt =  self._start_of_turn_user + user_ins + "\n" + self._end_of_turn
            model_prompt = self._start_of_turn_model + model_res + "\n" + self._end_of_turn
            print(user_prompt)
            print(model_prompt)

            self.dataset.append([user_prompt, model_prompt])
            
        return self.dataset
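
To see exactly which string the model receives at inference time, the single-prompt path of the class above can be reproduced as a standalone function. This is a minimal sketch that copies the same three tokens; `format_user_prompt` is a name made up for this example.

```python
# Same tokens as in GemmaFormatter, repeated here so the sketch runs on its own
START_USER = "<start_of_turn>user\n"
START_MODEL = "<start_of_turn>model\n"
END_TURN = "<end_of_turn>\n"

def format_user_prompt(prompt: str) -> str:
    """Wrap one user prompt and open the model's turn, like GemmaFormatter.user_input."""
    return START_USER + prompt + "\n" + END_TURN + START_MODEL

example = format_user_prompt("Hello")
# repr() makes the newlines visible
print(repr(example))
```

The string ends with an open model turn, so the model continues from there with its reply.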

We will now test the formatter. You pass in your data as two lists (the user prompts and the matching model responses), and it formats and returns the data.

In [4]:
# IMPORTANT: this is an example, replace these lists with your own dataset
# These are lists of the desired user inputs and model responses
# This is how I want my chess bot to respond, based on user_input_array

user_input_array = ["Narrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3", "Narrate this chess pgn game: 1. e4 c5 2. Nf3 g6 3. Ne5 Bg7"]
model_output_array = ["""White starts of with pawn to e4, balck responds with pawn to e5. White then plays
Knight to f3, black plays Knight c6. White castles king side, and black plays d3.""", 
           """White starts of with pawn to e4, balck responds with pawn c5. White then plays
Knight to f3, black plays pawn to g6. White continues with Knight e5, black counters with Bishop g7."""]
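
One thing to watch with your own data: `apply_format` pairs the two lists with `zip`, which silently stops at the shorter list, so a missing response would drop a prompt without any warning. A small check like the one below can catch that early. `check_pairs` is a hypothetical helper, not part of the notebook or of gemma_pytorch.

```python
def check_pairs(user_inputs, model_outputs):
    """Return (prompt, response) pairs, or raise if the lists differ in length."""
    if len(user_inputs) != len(model_outputs):
        raise ValueError(
            f"{len(user_inputs)} prompts but {len(model_outputs)} responses"
        )
    return list(zip(user_inputs, model_outputs))

pairs = check_pairs(["prompt A", "prompt B"], ["response A", "response B"])
print(len(pairs), "pairs")

try:
    check_pairs(["prompt A", "prompt B"], ["response A"])
except ValueError as err:
    print("caught:", err)
```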

In [5]:
# Test the formatter class so you can see how it works
test_formatter = GemmaFormatter(user_input_array, model_output_array)

data = test_formatter.apply_format()

print(data)

<start_of_turn>user
Narrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3
<end_of_turn>

<start_of_turn>model
White starts of with pawn to e4, balck responds with pawn to e5. White then plays
Knight to f3, black plays Knight c6. White castles king side, and black plays d3.
<end_of_turn>

<start_of_turn>user
Narrate this chess pgn game: 1. e4 c5 2. Nf3 g6 3. Ne5 Bg7
<end_of_turn>

<start_of_turn>model
White starts of with pawn to e4, balck responds with pawn c5. White then plays
Knight to f3, black plays pawn to g6. White continues with Knight e5, black counters with Bishop g7.
<end_of_turn>

[['<start_of_turn>user\nNarrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3\n<end_of_turn>\n', '<start_of_turn>model\nWhite starts of with pawn to e4, balck responds with pawn to e5. White then plays\nKnight to f3, black plays Knight c6. White castles king side, and black plays d3.\n<end_of_turn>\n'], ['<start_of_turn>user\nNarrate this chess pgn game: 1. e4 c5 2. Nf3 g6 3. Ne5 Bg7\n<en

## Step 2. Create Agent

Now we make our 'Agent' class. It wraps the LLM that we feed our test prompt to and get results from, and/or fine-tune if needed.

In [6]:
class GemmaAgent:
    """This is the Agent which we will load into the GPU device and then call the LLM model and get a response"""
    def __init__(self, variant='7b-it-quant', device='cuda:0', user_instruct=None,
                 model_response=None):
        self._variant = variant
        self._device = torch.device(device)
        self.formatter = GemmaFormatter(user_instruct = user_instruct, model_response = model_response)
        
        # Initialize model variant
        print("Initializing model...")
        _weights_dir  = '/kaggle/input/gemma/pytorch/2b-it/2' if "2b" in variant else '/kaggle/input/gemma/pytorch/7b-it-quant/2'
        model_config = get_config_for_2b() if "2b" in variant else get_config_for_7b()
        model_config.tokenizer = os.path.join(_weights_dir, "tokenizer.model")
        model_config.quant = "quant" in variant

        with _set_default_tensor_type(model_config.get_dtype()):
            print("Default dtype set to:", model_config.get_dtype())
            model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(_weights_dir , f'gemma-{variant}.ckpt')
            model.load_weights(ckpt_path)
            self.model = model.to(self._device).eval()
    
    # This function will call the 'generate' function of your model
    def call_llm(self, prompt, max_tokens=100, **sampler_kwargs):
        print("Formatting your input prompt with tokens...")
        prompt = self.formatter.user_input(prompt)
        print(prompt)
        
        # **sampler_kwargs is an empty dict (never None) when no extra arguments are passed
        if not sampler_kwargs:
            # IMPORTANT: Modify these parameters for your LLM or add/remove as needed, these are defaults
            sampler_kwargs = {
                'temperature': 0.1,
                'top_p': 0.5,
                'top_k': 10
            }

        # This is the output from the model
        response = self.model.generate(
            prompt,
            device=self._device,
            output_len=max_tokens,
            **sampler_kwargs)

        return response
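
A Python detail that matters for `call_llm`: a `**kwargs` parameter always arrives as a dict. It is empty when nothing extra is passed, never `None`. The sketch below shows this, plus one possible way to handle defaults by merging them with whatever the caller passes, so that overriding only `temperature` still keeps the other defaults. `resolve_sampler_kwargs` is a hypothetical helper, not part of the class above.

```python
# Same default values as in call_llm
DEFAULT_SAMPLER_KWARGS = {"temperature": 0.1, "top_p": 0.5, "top_k": 10}

def resolve_sampler_kwargs(**sampler_kwargs):
    """Merge caller overrides on top of the defaults."""
    # With no extra arguments, sampler_kwargs is {} here, not None
    print("received:", sampler_kwargs)
    return {**DEFAULT_SAMPLER_KWARGS, **sampler_kwargs}

print(resolve_sampler_kwargs())
print(resolve_sampler_kwargs(temperature=0.3))
```

Merging is a design choice: it differs from replacing the whole dict, which would fall back to the library's own defaults for any parameter the caller leaves out.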

Ok, now we have the formatter and the model, so let's load it into memory and test that it works.

In [7]:
# Load the GemmaAgent class as 'my_chat_bot', with variant='7b-it-quant' or '2b-it'
my_chat_bot = GemmaAgent(variant='7b-it-quant')

Initializing model...
Default dtype set to: torch.bfloat16


In [8]:
# Let's use a sample test prompt
test_prompt = user_input_array[0]
print(test_prompt)

Narrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3


In [9]:
# This will format the prompt above and call the llm and print response
response = my_chat_bot.call_llm(test_prompt)
print(response)

Formatting your input prompt with tokens...
<start_of_turn>user
Narrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3
<end_of_turn>
<start_of_turn>model

Sure, here's the narrative of the game:

The white pieces start with the Queen's pawn and the knight moves forward to challenge the pawn's control. The queenside castling moves the king and the queen to the opposite sides of the king, creating a symmetrical position. The black pieces respond with the development of the dark-bishop and the dramatic move of moving the pawn to d3, aiming to create a kingside counter-gambit.


In [10]:
# Now let's try the same prompt with different parameters
sampler_kwargs = {
    'temperature': 0.3,
    'top_p': 0.3,
    'top_k': 5}

response = my_chat_bot.call_llm(test_prompt, 
                                max_tokens=200,
                                **sampler_kwargs)
print(response)

Formatting your input prompt with tokens...
<start_of_turn>user
Narrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3
<end_of_turn>
<start_of_turn>model

Sure, here's the narrative of the game:

The white pieces start the game by moving the e-pawn forward to e5 and the knight moves to f3 to challenge the pawn. The black pieces respond with the move d3 to counter the pawn advance and to bring the king out of the back row.

The moves that follow are the standard opening moves of the Queen's Pawn Game, but the black side has chosen the Najdorf Variation of the Queen's Pawn Game, which is a popular variation that leads to a variety of positions.
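
For intuition about what `temperature`, `top_k` and `top_p` do, here is a generic plain-Python sketch of the three filters applied to a toy list of logits. It illustrates the standard technique only and is not the gemma_pytorch sampler; `filter_distribution` and `toy_logits` are made-up names.

```python
import math

def filter_distribution(logits, temperature=1.0, top_k=None, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p filtering."""
    # Temperature: divide the logits before softmax; values below 1 sharpen the distribution
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank tokens from most to least likely
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most likely tokens
    if top_k is not None:
        ranked = ranked[:top_k]

    # Top-p: keep the shortest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize so the surviving probabilities sum to 1
    kept_total = sum(probs[i] for i in kept)
    return {i: probs[i] / kept_total for i in kept}

toy_logits = [2.0, 1.0, 0.5, 0.1, -1.0]
sharp = filter_distribution(toy_logits, temperature=0.3, top_k=5, top_p=0.3)
broad = filter_distribution(toy_logits, temperature=1.0, top_k=3, top_p=1.0)
print(sharp)
print(broad)
```

With the low temperature and low `top_p`, only the single most likely token survives, so sampling becomes close to deterministic. The second call keeps three candidates, which leaves room for more varied output.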


### OK! Cool, it does as instructed and gives a good story! (although it does not understand specific chess moves, and has some wrong interpretations)

# The End - have fun testing these models!

## --- Step 3. LoRA (NOTE: This does not work yet. If someone can fix it, ty :)

Now let's try this again, but after we fine-tune with LoRA on our test dataset.

In [11]:
# Reminder this is an example dataset from my own project
# Replace with your dataset

user_input_array = ["Narrate this chess pgn game: 1. e4 e5 2. Nf3 Nc6 3. O-O d3", "Narrate this chess pgn game: 1. e4 c5 2. Nf3 g6 3. Ne5 Bg7"]
model_output_array = ["""White starts of with pawn to e4, balck responds with pawn to e5. White then plays
Knight to f3, black plays Knight c6. White castles king side, and black plays d3.""", 
           """White starts of with pawn to e4, balck responds with pawn c5. White then plays
Knight to f3, black plays pawn to g6. White continues with Knight e5, black counters with Bishop g7."""]

In [12]:
# We need to install a new package for LoRA use
!pip install -q -U peft

In [13]:
# Import packages needed for LoRA
from peft import LoraConfig, get_peft_model
import torch

Let's take a look at our model architecture to find which layers to apply LoRA to.

In [14]:
my_chat_bot.model.named_modules

<bound method Module.named_modules of GemmaForCausalLM(
  (embedder): Embedding()
  (model): GemmaModel(
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear()
          (o_proj): Linear()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear()
          (up_proj): Linear()
          (down_proj): Linear()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (sampler): Sampler()
)>

Ok, looks like these are the attention layers:
```
(self_attn): GemmaAttention(
  (qkv_proj): Linear()
  (o_proj): Linear()
)
```

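Do not run the next cell, it will give errors. A likely cause (not confirmed here) is that the `Linear()` layers printed above are gemma_pytorch's own class rather than `torch.nn.Linear`, and PEFT only wraps the layer types it recognizes: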

In [15]:
# from peft import get_peft_model, LoraConfig, TaskType

# # LoRA configuration
# peft_config = LoraConfig(
# #     task_type=TaskType.CAUSAL_LM,
#     inference_mode=False,
#     target_modules=["qkv_proj", "o_proj"],  # Specify the target modules where you want to apply LoRA
#     r=8,
#     lora_alpha=32,
#     lora_dropout=0.1
# )

# # Apply LoRA to our model from above
# lora_chat_bot = get_peft_model(my_chat_bot.model, peft_config)

# # Move the model to GPU if available, (Assumes a T4x2)
# device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# lora_chat_bot.to(device)

# # Now the model is ready with LoRA applied
# print(lora_chat_bot)
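
To get a feel for what the config above would train if it worked, here is a quick count of the LoRA parameters for `r=8` on `qkv_proj` and `o_proj`. LoRA adds two small matrices per targeted layer, A with shape (r x in_features) and B with shape (out_features x r), and scales their product by `lora_alpha / r`. The layer shapes below are assumptions about Gemma 7B and are not printed anywhere in this notebook, so check them against the model config before relying on the numbers.

```python
# ASSUMED Gemma 7B shapes (not shown in this notebook, verify against the model config)
HIDDEN_SIZE = 3072
NUM_HEADS = 16
NUM_KV_HEADS = 16
HEAD_DIM = 256
NUM_LAYERS = 28  # matches "(0-27): 28 x GemmaDecoderLayer" in the printout above

# Values from the LoraConfig in the cell above
R = 8
LORA_ALPHA = 32

def lora_params(in_features, out_features, r):
    """Trainable parameters LoRA adds to one linear layer: A (r x in) plus B (out x r)."""
    return r * (in_features + out_features)

qkv_out = (NUM_HEADS + 2 * NUM_KV_HEADS) * HEAD_DIM   # fused query, key and value outputs
qkv = lora_params(HIDDEN_SIZE, qkv_out, R)
o = lora_params(NUM_HEADS * HEAD_DIM, HIDDEN_SIZE, R)
total = NUM_LAYERS * (qkv + o)
scaling = LORA_ALPHA / R

print(f"per layer: qkv_proj={qkv:,} o_proj={o:,}")
print(f"total trainable LoRA parameters: {total:,}")
print(f"scaling factor lora_alpha / r: {scaling}")
```

Under these assumed shapes that is about 5 million trainable parameters, a small fraction of the full model, which is the whole appeal of LoRA on limited GPUs.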