Hi all, I split the code into several parts and added comments and explanations generated with ChatGPT.

The Gemma example at this link (https://www.kaggle.com/models/google/gemma/PyTorch/7b-it-quant/2) might be useful for beginners.

Hope this notebook can help. 

---

This notebook illustrates the agent creation process for the **LLM 20 Questions** competition. Running this notebook produces a `submission.tar.gz` file. You may submit this file directly from the **Submit to competition** heading to the right. Alternatively, from the notebook viewer, click the *Output* tab then find and download `submission.tar.gz`. Click **Submit Agent** at the upper-left of the competition homepage to upload your file and make your submission.

In [None]:
%%bash
cd /kaggle/working
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece
git clone https://github.com/google/gemma_pytorch.git > /dev/null
mkdir /kaggle/working/submission/lib/gemma/
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

# Part 1: File Creation

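- `%%writefile` is a cell magic that writes the rest of the cell to a file instead of executing it.
- `submission/main.py` is the path the file is written to. The magic does not create missing directories; here `submission` already exists because the `pip install -t` step above created it.
- In the cell below, the magic line is commented out, so the cell runs as ordinary Python rather than being written to disk.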
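
For readers who want the same effect outside a notebook, here is a minimal plain-Python sketch. It is not IPython's implementation: `write_cell` is a hypothetical helper, and creating the parent directory explicitly is a choice made for this example.

```python
import tempfile
from pathlib import Path


def write_cell(path: str, source: str) -> Path:
    """Write `source` to `path`, overwriting any existing file (hypothetical helper)."""
    target = Path(path)
    # Create the parent directory explicitly so the write cannot fail on a missing folder.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(source, encoding="utf-8")
    return target


# Demo in a temporary directory so nothing under /kaggle/working is touched.
demo_root = tempfile.mkdtemp()
written = write_cell(f"{demo_root}/submission/main.py", "import os\nimport sys\n")
print(written.read_text(encoding="utf-8"))
```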


In [None]:
# #%%writefile submission/main.py
# # Setup
import os
import sys

# Part 2: Importing Required Libraries and Setting Up Weights Path

In [None]:
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
if os.path.exists(KAGGLE_AGENT_PATH):
    sys.path.insert(0, os.path.join(KAGGLE_AGENT_PATH, 'lib'))
else:
    sys.path.insert(0, "/kaggle/working/submission/lib")

import contextlib
import os
import sys
from pathlib import Path

import torch
from gemma.config import get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM

if os.path.exists(KAGGLE_AGENT_PATH):
    WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/7b-it-quant/2")
else:
    WEIGHTS_PATH = "/kaggle/input/gemma/pytorch/7b-it-quant/2"

# Part 3: Prompt Formatting with GemmaFormatter

- `GemmaFormatter`: formats the conversation between the user and the model using predefined tokens to mark the start and end of turns.

In [None]:
import itertools
from typing import Iterable


class GemmaFormatter:
    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    # Initialization
    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        self._system_prompt = system_prompt
        self._few_shot_examples = few_shot_examples
        self._turn_user = f"{self._start_token}user\n{{}}{self._end_token}\n"
        self._turn_model = f"{self._start_token}model\n{{}}{self._end_token}\n"
        self.reset()

    # Returns the current state of the conversation as a string.
    def __repr__(self):
        return self._state

    # Methods of Adding Turns - Adds a user prompt to the conversation
    def user(self, prompt):
        self._state += self._turn_user.format(prompt)
        return self
        
    # Methods of Adding Turns - Adds a model response to the conversation.
    def model(self, prompt):
        self._state += self._turn_model.format(prompt)
        return self

    # Methods of Adding Turns -  Marks the start of a user turn
    def start_user_turn(self):
        self._state += f"{self._start_token}user\n"
        return self

    # Methods of Adding Turns -  Marks the start of a model turn.
    def start_model_turn(self):
        self._state += f"{self._start_token}model\n"
        return self

    # Methods of Adding Turns -  Marks the end of the current turn
    def end_turn(self):
        self._state += f"{self._end_token}\n"
        return self

    # Reset Method
    def reset(self):
        # Initializes `_state` to an empty string
        self._state = ""  

        # Adds the system prompt if provided.
        if self._system_prompt is not None:
            self.user(self._system_prompt)  
            
        # Applies few-shot examples if provided
        if self._few_shot_examples is not None: 
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        formatters = [self.model, self.user] if start_agent == 'model' else [self.user, self.model]
        formatters = itertools.cycle(formatters)
        for fmt, turn in zip(formatters, turns):
            fmt(turn)
        return self

###  Example

In [None]:
# Initialize formatter with a system prompt and few-shot examples
formatter = GemmaFormatter(
    system_prompt="This is a system prompt.",
    few_shot_examples=["Example question?", "Example answer."]
)

# Add a user turn
formatter.user("What is the capital of France?")

# Add a model turn
formatter.model("The capital of France is Paris.")

# Print the formatted conversation
print(formatter)


The `GemmaFormatter` class thus helps in structuring and formatting conversations in a consistent manner, ensuring that turns are properly marked and organized.
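
To make the resulting format concrete, the sketch below rebuilds the string that the example above should print, without using the class. `format_turn` is a hypothetical helper written for this illustration. Note that the system prompt is emitted as an ordinary `user` turn, because `reset()` adds it through `self.user(...)`.

```python
START, END = "<start_of_turn>", "<end_of_turn>"


def format_turn(role: str, text: str) -> str:
    """Wrap one turn in Gemma's start/end markers (hypothetical helper)."""
    return f"{START}{role}\n{text}{END}\n"


# Same turns as the GemmaFormatter example: system prompt, few-shot pair, then one exchange.
turns = [
    ("user", "This is a system prompt."),
    ("user", "Example question?"),
    ("model", "Example answer."),
    ("user", "What is the capital of France?"),
    ("model", "The capital of France is Paris."),
]
conversation = "".join(format_turn(role, text) for role, text in turns)
print(conversation)
```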

# Part 4: Agent Definitions and Utilities

- The `_set_default_tensor_type` context manager temporarily sets the default floating-point dtype for PyTorch tensors to the given type, then resets it to `torch.float` after the `with` block finishes.

In [None]:
import re

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Set the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Reset even if the body of the `with` block raises.
        torch.set_default_dtype(torch.float)

### Example

In [None]:
torch.tensor([1.0, 2.0]).dtype

In [None]:
with _set_default_tensor_type(torch.float64):
    print(torch.tensor([1.0, 2.0]).dtype)

# Part 5: Base GemmaAgent Class

The GemmaAgent class is designed to:

- Initialize and configure a language model.
- Format and handle prompts and responses.
- Use context managers to temporarily set tensor data types.
- Interact with the model to generate responses based on formatted prompts.

In [None]:
class GemmaAgent:
    def __init__(self, variant='7b-it-quant', device='cuda:0', system_prompt=None, few_shot_examples=None):
        self._variant = variant
        self._device = torch.device(device)
        self.formatter = GemmaFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)

        print("Initializing model")
        
        # Model Config.
        model_config = get_config_for_2b() if "2b" in variant else get_config_for_7b()
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
        model_config.quant = "quant" in variant
        
        # Model.
        with _set_default_tensor_type(model_config.get_dtype()):
            model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(WEIGHTS_PATH , f'gemma-{variant}.ckpt')
            model.load_weights(ckpt_path)
            self.model = model.to(self._device).eval()

    def __call__(self, obs, *args):
        self._start_session(obs)
        prompt = str(self.formatter)
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        print(f"{response=}")
        return response

    def _start_session(self, obs: dict):
        raise NotImplementedError

    def _call_llm(self, prompt, max_new_tokens=32, **sampler_kwargs):
        # `**sampler_kwargs` is always a dict (never None), so test for emptiness.
        if not sampler_kwargs:
            sampler_kwargs = {
                'temperature': 0.01,
                'top_p': 0.1,
                'top_k': 1,
            }
        response = self.model.generate(
            prompt,
            device=self._device,
            output_len=max_new_tokens,
            **sampler_kwargs,
        )
        return response

    def _parse_keyword(self, response: str):
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        if match is None:
            keyword = ''
        else:
            keyword = match.group().lower()
        return keyword

    def _parse_response(self, response: str, obs: dict):
        raise NotImplementedError
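
The call flow of `GemmaAgent` (build the prompt, call the model, parse the reply) can be tried without loading any weights. The sketch below is a toy stand-in, not the notebook's class: `ToyAgent`, `ToyQuestioner`, and the canned reply are hypothetical, and `_start_session` returns the prompt directly instead of updating a formatter.

```python
class ToyAgent:
    """Toy stand-in for GemmaAgent: same three-step call flow, no model weights."""

    def __call__(self, obs: dict) -> str:
        prompt = self._start_session(obs)
        raw = self._call_llm(prompt)
        return self._parse_response(raw, obs)

    def _start_session(self, obs: dict) -> str:
        raise NotImplementedError

    def _call_llm(self, prompt: str) -> str:
        # Hypothetical fake "model": always returns the same canned completion.
        return "Is it a living thing? I am curious."

    def _parse_response(self, response: str, obs: dict) -> str:
        raise NotImplementedError


class ToyQuestioner(ToyAgent):
    def _start_session(self, obs: dict) -> str:
        return f"Questions so far: {len(obs['questions'])}. Ask another."

    def _parse_response(self, response: str, obs: dict) -> str:
        # Keep only the text up to and including the first question mark.
        return response[: response.index("?") + 1]


question = ToyQuestioner()({"questions": [], "answers": []})
print(question)
```

Subclasses only supply the two role-specific steps; the base class owns the order in which they run.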

# Part 6: GemmaQuestionerAgent Class

In [None]:
def interleave_unequal(x, y):
    return [
        item for pair in itertools.zip_longest(x, y) for item in pair if item is not None
    ]

class GemmaQuestionerAgent(GemmaAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        self.formatter.reset()
        self.formatter.user("Let's play 20 Questions. You are playing the role of the Questioner.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')
        if obs.turnType == 'ask':
            self.formatter.user("Please ask a yes-or-no question.")
        elif obs.turnType == 'guess':
            self.formatter.user("Now guess the keyword. Surround your guess with double asterisks.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        if obs.turnType == 'ask':
            # Raw string avoids the invalid-escape warning for `\?`.
            match = re.search(r".+?\?", response.replace('*', ''))
            if match is None:
                question = "Is it a person?"
            else:
                question = match.group()
            return question
        elif obs.turnType == 'guess':
            guess = self._parse_keyword(response)
            return guess
        else:
            raise ValueError("Unknown turn type:", obs.turnType)

`GemmaQuestionerAgent`:
- `__init__`: Sets up the agent by calling the parent class's constructor.
- `_start_session`: Interleaves the questions and answers so far and sets up the conversation format.
- `_parse_response`: Interprets the model's response differently depending on whether the agent is asking a question or making a guess.
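
The interleaving step is easier to see on a small input. The sketch below reuses `interleave_unequal` from the cell above and labels each turn the way `apply_turns(..., start_agent='model')` would; the sample questions and answers are made up for illustration. From the Questioner's point of view, its own earlier questions are `model` turns and the answers it received are `user` turns.

```python
import itertools


def interleave_unequal(x, y):
    return [
        item for pair in itertools.zip_longest(x, y) for item in pair if item is not None
    ]


# Two questions asked so far, but only the first has been answered.
questions = ["Is it a person?", "Is it a place?"]
answers = ["no"]
turns = interleave_unequal(questions, answers)

# Alternate roles starting with the model, as start_agent='model' does.
roles = itertools.cycle(["model", "user"])
labelled = list(zip(roles, turns))
print(labelled)
```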

# Part 7: GemmaAnswererAgent Class

In [None]:
class GemmaAnswererAgent(GemmaAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        self.formatter.reset()
        self.formatter.user(f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {obs.keyword} in the category {obs.category}.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='user')
        self.formatter.user(f"The question is about the keyword {obs.keyword} in the category {obs.category}. Give a yes-or-no answer and surround your answer with double asterisks, like **yes** or **no**.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        answer = self._parse_keyword(response)
        return 'yes' if 'yes' in answer else 'no'
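
The answer parsing can be checked in isolation. The sketch below copies the regular expression from `_parse_keyword` into standalone functions (`parse_keyword` and `parse_answer` are names chosen for this example) and runs them on made-up replies. One consequence worth noticing: a reply that says yes without the double asterisks is read as `no`.

```python
import re


def parse_keyword(response: str) -> str:
    """Return the lower-cased text between the first pair of double asterisks, or ''."""
    match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
    return '' if match is None else match.group().lower()


def parse_answer(response: str) -> str:
    """Map a model reply to 'yes' or 'no', defaulting to 'no'."""
    return 'yes' if 'yes' in parse_keyword(response) else 'no'


bold_yes = parse_answer("**Yes**, it is.")
plain_yes = parse_answer("I think so, yes.")
guess = parse_keyword("My guess is **Eiffel Tower**.")
print(bold_yes, plain_yes, guess)
```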

# Part 8: Agent Creation and Function Definitions

In [None]:
# Agent Creation
system_prompt = "You are an AI assistant designed to play the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific person, place, or thing."

few_shot_examples = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
    "Is it a person?", "**no**",
    "Is it a place?", "**yes**",
    "Is it a country?", "**yes** Now guess the keyword.",
    "**France**", "Correct!",
]

# **IMPORTANT:** Define agent as a global so you only have to load
# the agent you need. Loading both will likely lead to OOM.

# Initialize agent variable
agent = None

# Function to get the appropriate agent based on the name
def get_agent(name: str):
    global agent
    
    # If agent is not initialized and the requested agent is a "questioner"
    if agent is None and name == 'questioner':
        # Initialize GemmaQuestionerAgent with specific parameters
        agent = GemmaQuestionerAgent(
            device='cuda:0',  # Device for computation
            system_prompt=system_prompt,  # System prompt for the agent
            few_shot_examples=few_shot_examples,  # Examples to guide the agent's behavior
        )
    # If agent is not initialized and the requested agent is an "answerer"
    elif agent is None and name == 'answerer':
        # Initialize GemmaAnswererAgent with the same parameters
        agent = GemmaAnswererAgent(
            device='cuda:0',
            system_prompt=system_prompt,
            few_shot_examples=few_shot_examples,
        )
    
    # Ensure that the agent is initialized
    assert agent is not None, "Agent not initialized."

    # Return the initialized agent
    return agent

# Function to handle interactions based on observations
def agent_fn(obs, cfg):
    # If observation is for asking a question
    if obs.turnType == "ask":
        # Get the "questioner" agent to respond to the observation
        response = get_agent('questioner')(obs)
    # If observation is for making a guess
    elif obs.turnType == "guess":
        # Get the "questioner" agent to respond to the observation
        response = get_agent('questioner')(obs)
    # If observation is for providing an answer
    elif obs.turnType == "answer":
        # Get the "answerer" agent to respond to the observation
        response = get_agent('answerer')(obs)
    
    # If the response from the agent is either None or very short
    if response is None or len(response) <= 1:
        # Assume a positive response ("yes")
        return "yes"
    else:
        # Return the response received from the agent
        return response

1. **GemmaFormatter Class**: This class handles formatting prompts for the game. It has methods to construct user and model turns, start user and model turns, end turns, reset the state, and apply turns. It ensures consistent formatting of prompts for the agents.

2. **GemmaAgent Class**: This is a base class representing a generic agent in the game. It handles model initialization, the `__call__` flow, and calling the language model (LLM), and it leaves `_start_session` and `_parse_response` for subclasses to implement. It uses the module-level `_set_default_tensor_type` context manager while loading the model.

3. **GemmaQuestionerAgent Class** and **GemmaAnswererAgent Class**: These classes inherit from GemmaAgent and implement specific behaviors for the Questioner and Answerer agents, respectively. They override the `_start_session` and `_parse_response` methods to customize the behavior of the agents.

4. **interleave_unequal Function**: This function interleaves two lists of unequal lengths. It's used to interleave questions and answers in the game.

5. **get_agent Function**: This function initializes and returns an agent based on the input name ('questioner' or 'answerer'). Only one agent is ever created: whichever is requested first is stored in the global `agent` and returned on every later call, regardless of the name passed.

6. **agent_fn Function**: This function acts as the entry point for the game. It determines the type of agent to use based on the observation's turn type ('ask', 'guess', or 'answer') and calls the respective agent's `__call__` method to generate a response.
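
The caching behaviour of `get_agent` can be tried without loading a model. In the sketch below, `StubAgent` is a hypothetical stand-in for the two Gemma agents; the control flow mirrors the notebook's function. It shows that a second request with a different name returns the agent that was created first, which is harmless only if each process plays a single role. That is an assumption about how the environment calls `agent_fn`, not something shown in this notebook.

```python
agent = None


class StubAgent:
    """Hypothetical stand-in for a Gemma agent; records the role it was built for."""

    def __init__(self, role: str):
        self.role = role

    def __call__(self, obs: dict) -> str:
        return f"{self.role} handled {obs['turnType']}"


def get_agent(name: str):
    global agent
    # Build an agent only on the first call; later calls reuse it.
    if agent is None and name in ("questioner", "answerer"):
        agent = StubAgent(name)
    assert agent is not None, "Agent not initialized."
    return agent


first = get_agent("questioner")
second = get_agent("answerer")  # an agent already exists, so no new one is built
print(first is second, second.role)
```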


In [None]:
# !apt install pigz pv > /dev/null

In [None]:
# !tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ gemma/pytorch/7b-it-quant/2