While working through the [Starter Notebook](https://www.kaggle.com/code/ryanholbrook/llm-20-questions-starter-notebook), I created this fully documented version jointly with my buddy ChatGPT. The code is 100% the same, no changes, but I hope you find the comments useful.

This notebook illustrates the agent creation process for the **LLM 20 Questions** competition. Running this notebook produces a `submission.tar.gz` file. You may submit this file directly from the **Submit to competition** heading to the right. Alternatively, from the notebook viewer, click the *Output* tab, then find and download `submission.tar.gz`. Click **Submit Agent** at the upper-left of the competition homepage to upload your file and make your submission.

In [None]:
%%bash
cd /kaggle/working
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece
git clone https://github.com/google/gemma_pytorch.git > /dev/null
mkdir /kaggle/working/submission/lib/gemma/
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

The above cell sets up the environment by installing required Python packages and preparing the `gemma_pytorch` library for use in the agent. This ensures that all necessary dependencies are bundled together in the submission file:

- `%%bash` is a cell magic command in Jupyter notebooks that indicates that the entire cell should be executed as a bash script. This means that all the lines following the `%%bash` will be interpreted and run as bash commands in the shell, rather than as Python code.

- `cd /kaggle/working` changes the current directory to `/kaggle/working`, which is a working directory in the Kaggle environment where you can store files and set up your workspace.

- `pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece` installs the `immutabledict` and `sentencepiece` Python packages.
   - `-q` suppresses the output from the installation process (quiet mode).
   - `-U` upgrades the packages if they are already installed.
   - `-t /kaggle/working/submission/lib` specifies the target directory where the packages should be installed. This ensures that the dependencies are included in the submission package.

- `git clone https://github.com/google/gemma_pytorch.git > /dev/null` clones the [`gemma_pytorch` repository](https://github.com/google/gemma_pytorch) from GitHub into the current directory (`/kaggle/working`).
   - `> /dev/null` redirects the standard output of the `git clone` command to `/dev/null`, keeping it out of the notebook's output. `/dev/null` is a special file in Unix-like operating systems, including the environment used by Kaggle. It acts as a black hole for data: anything written to it is discarded and cannot be retrieved. It is not a directory where you can find files; it is simply a way to suppress output. Only standard output is redirected this way, so messages a command writes to standard error still appear.

- `mkdir /kaggle/working/submission/lib/gemma/` creates a new directory named `gemma` inside the `submission/lib` directory. This directory will hold the files from the `gemma_pytorch` repository.

- `mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/` moves all the files from the `gemma` directory inside the cloned `gemma_pytorch` repository to the newly created `gemma` directory in `submission/lib`. This ensures that the necessary files from the `gemma_pytorch` repository are included in the submission package.
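
Installing into `submission/lib` matters because `main.py` later puts that directory at the front of `sys.path`. Here is a minimal sketch of the mechanism, assuming a throwaway temporary directory and a made-up module name (`demo_dep`) in place of the real packages; the helper `import_from_lib` is also hypothetical:

```python
import importlib
import sys
import tempfile
from pathlib import Path


def import_from_lib(lib_dir: Path, module_name: str):
    """Put lib_dir at the front of sys.path, then import module_name from it."""
    sys.path.insert(0, str(lib_dir))
    importlib.invalidate_caches()
    return importlib.import_module(module_name)


# Stand-in for /kaggle/working/submission/lib (hypothetical temp directory).
lib_dir = Path(tempfile.mkdtemp()) / "submission" / "lib"
lib_dir.mkdir(parents=True)

# Stand-in for a package that `pip install -t` would have placed there.
(lib_dir / "demo_dep.py").write_text("VALUE = 42\n")

demo_dep = import_from_lib(lib_dir, "demo_dep")
print(demo_dep.VALUE)  # 42
```

Anything found in the bundled directory is imported from there, which is what lets the agent run with its own copies of the dependencies.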

These are the packages that are installed:

**immutabledict** is a package that provides an immutable dictionary. Immutable dictionaries are useful when you need to ensure that the dictionary cannot be modified after it has been created. This can help prevent accidental changes to the data structure, which is particularly important in certain applications like configurations, constants, or when working with concurrent programming.
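
To see what "immutable" means in practice without installing anything, the sketch below uses the standard library's `types.MappingProxyType` as a stand-in. It is a read-only view of a dict rather than the `immutabledict` class itself, so treat it as an illustration of the behaviour, not of that package's API.

```python
from types import MappingProxyType

# A regular dict wrapped in a read-only view (stand-in for an immutable dict).
config = MappingProxyType({"variant": "7b-it-quant", "device": "cuda:0"})

print(config["variant"])  # reading works as usual

try:
    config["device"] = "cpu"  # any write attempt is rejected
    write_failed = False
except TypeError:
    write_failed = True

print(write_failed)  # True
```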

**sentencepiece** is the tokenizer library used by Gemma.

In [None]:
%%writefile submission/main.py
# Setup
import os
import sys

# **IMPORTANT:** Set up your system path like this to make your code work
# both in notebooks and in the simulations environment.
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
if os.path.exists(KAGGLE_AGENT_PATH):
    sys.path.insert(0, os.path.join(KAGGLE_AGENT_PATH, 'lib'))
else:
    sys.path.insert(0, "/kaggle/working/submission/lib")

import contextlib
import os
import sys
from pathlib import Path

# Import necessary modules for handling tensor operations and model configurations
import torch
from gemma.config import get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM

# Define the path for model weights.
# Check if the script is running in the Kaggle simulation environment.
if os.path.exists(KAGGLE_AGENT_PATH):
    # If running in the Kaggle simulation environment, set the weights path accordingly
    WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/7b-it-quant/2")
else:
    # If running locally or in a different environment, use the local input directory for weights
    WEIGHTS_PATH = "/kaggle/input/gemma/pytorch/7b-it-quant/2"

# Prompt Formatting
import itertools
from typing import Iterable


class GemmaFormatter:
    """
    Class for formatting prompts and responses for the 20 Questions game.
    """
    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        """
        Initialize the GemmaFormatter with an optional system prompt and few-shot examples.
        
        Args:
            system_prompt (str): Optional text added as the first user turn.
            few_shot_examples (Iterable): Optional alternating user/model turns,
                starting with a user turn, added after the system prompt.
        """
        self._system_prompt = system_prompt
        self._few_shot_examples = few_shot_examples
        self._turn_user = f"{self._start_token}user\n{{}}{self._end_token}\n"
        self._turn_model = f"{self._start_token}model\n{{}}{self._end_token}\n"
        self.reset()

    def __repr__(self):
        """
        Return the current state of the prompt.
        
        Returns:
            str: The current formatted state.
        """
        return self._state

    def user(self, prompt):
        """
        Add a user's prompt to the current state.
        
        Args:
            prompt (str): The user's prompt.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        self._state += self._turn_user.format(prompt)
        return self

    def model(self, prompt):
        """
        Add a model turn to the current state.
        
        Args:
            prompt (str): The model's text for this turn.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        self._state += self._turn_model.format(prompt)
        return self

    def start_user_turn(self):
        """
        Start a new user turn.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        self._state += f"{self._start_token}user\n"
        return self

    def start_model_turn(self):
        """
        Start a new model turn.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        self._state += f"{self._start_token}model\n"
        return self

    def end_turn(self):
        """
        End the current turn.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        self._state += f"{self._end_token}\n"
        return self

    def reset(self):
        """
        Reset the formatter to its initial state,
        including system prompt and few-shot examples if provided.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        self._state = ""
        if self._system_prompt is not None:
            self.user(self._system_prompt)
        if self._few_shot_examples is not None:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """
        Apply a sequence of turns to the formatter.
        
        Args:
            turns (Iterable): The sequence of turns to apply.
            start_agent (str): Specifies whether 'user' or 'model' starts first.
        
        Returns:
            self: The GemmaFormatter instance.
        """
        formatters = [self.model, self.user] if start_agent == 'model' else [self.user, self.model]
        formatters = itertools.cycle(formatters)
        for fmt, turn in zip(formatters, turns):
            fmt(turn)
        return self


# Agent Definitions
import re


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """
    Context manager for setting the default tensor type in PyTorch.

    This context manager temporarily sets the default dtype to the specified dtype
    and, when the block finishes normally, resets it to torch.float
    (regardless of which default was active before).

    Args:
        dtype (torch.dtype): The desired default tensor type to set.

    Yields:
        None
    """
    # Set the default torch dtype to the given dtype.
    torch.set_default_dtype(dtype)
    yield
    # Restore the default tensor type to float
    torch.set_default_dtype(torch.float)


class GemmaAgent:
    """
    Base class for an agent that interacts with the Gemma language model.
    This class handles model initialization, prompt formatting, and response generation.
    """
    def __init__(self, variant='7b-it-quant', device='cuda:0', system_prompt=None, few_shot_examples=None):
        """
        Initialize the GemmaAgent with the specified configuration.

        Args:
            variant (str): The model variant to use (e.g., '7b-it-quant').
            device (str): The device to run the model on (e.g., 'cuda:0' for GPU).
            system_prompt (str): Initial prompt for the system.
            few_shot_examples (Iterable): Examples to initialize the prompt.
        """
        self._variant = variant
        self._device = torch.device(device)
        self.formatter = GemmaFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)

        print("Initializing model")
        # Get the model configuration based on the variant
        model_config = get_config_for_2b() if "2b" in variant else get_config_for_7b()
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
        model_config.quant = "quant" in variant

        # Set the default tensor type and initialize the model
        with _set_default_tensor_type(model_config.get_dtype()):
            model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(WEIGHTS_PATH , f'gemma-{variant}.ckpt')
            model.load_weights(ckpt_path)
            self.model = model.to(self._device).eval()

    def __call__(self, obs, *args):
        """
        Handle a new observation by starting a session, generating a prompt, and parsing the response.

        Args:
            obs (dict): The observation dictionary containing game state information.
            *args: Additional arguments.

        Returns:
            str: The generated response from the model.
        """
        self._start_session(obs)
        prompt = str(self.formatter)
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        print(f"{response=}")
        return response

    def _start_session(self, obs: dict):
        """
        Start a new session for the agent.

        Args:
            obs (dict): The observation dictionary containing game state information.

        Raises:
            NotImplementedError: This method should be implemented by subclasses.
        """
        raise NotImplementedError

    def _call_llm(self, prompt, max_new_tokens=32, **sampler_kwargs):
        """
        Generate a response from the language model using the provided prompt.

        Args:
            prompt (str): The input prompt for the language model.
            max_new_tokens (int): The maximum number of new tokens to generate.
            **sampler_kwargs: Additional keyword arguments for the sampling process.

        Returns:
            str: The generated response from the model.
        """
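        # NOTE: `**sampler_kwargs` always arrives as a dict (empty when no extra
        # keyword arguments are passed), never None, so this branch does not run
        # and the defaults of `self.model.generate` apply instead.
        if sampler_kwargs is None: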
            sampler_kwargs = {
                'temperature': 0.01,
                'top_p': 0.1,
                'top_k': 1,
        }
        response = self.model.generate(
            prompt,
            device=self._device,
            output_len=max_new_tokens,
            **sampler_kwargs,
        )
        return response

    def _parse_keyword(self, response: str):
        """
        Extract the keyword from the model's response.

        Args:
            response (str): The model's response.

        Returns:
            str: The extracted keyword, or an empty string if no keyword is found.
        """
        # Find the first substring enclosed between double asterisks (**).
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        if match is None:
            keyword = ''
        else:
            keyword = match.group().lower()
        return keyword

    def _parse_response(self, response: str, obs: dict):
        """
        Parse the model's response based on the observation.

        Args:
            response (str): The model's response.
            obs (dict): The observation dictionary containing game state information.

        Raises:
            NotImplementedError: This method should be implemented by subclasses.
        """
        raise NotImplementedError


def interleave_unequal(x, y):
    """
    Interleave two lists, x and y, keeping the leftover elements of the longer list.
    
    This function takes two lists and interleaves their elements. If the lists have unequal lengths,
    zip_longest pads the shorter one with None, and every None is then dropped from the result
    (including any None that was already present in x or y).

    Args:
        x (list): The first list to interleave.
        y (list): The second list to interleave.

    Returns:
        list: A list with elements from x and y interleaved, excluding None values.
    
    Example:
        >>> interleave_unequal([1, 2, 3], ['a', 'b'])
        [1, 'a', 2, 'b', 3]
    """
    return [
        item for pair in itertools.zip_longest(x, y) for item in pair if item is not None
    ]


class GemmaQuestionerAgent(GemmaAgent):
    """
    Agent for playing the role of the Questioner in the 20 Questions game.
    
    This agent is responsible for generating questions based on the game state and 
    parsing responses from the language model to continue the questioning process.
    """
    def __init__(self, *args, **kwargs):
        """
        Initialize the GemmaQuestionerAgent with the provided arguments.
        
        Args:
            *args: Positional arguments to pass to the base class initializer.
            **kwargs: Keyword arguments to pass to the base class initializer.
        """
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """
        Start a new session for the Questioner agent by resetting the formatter 
        and applying the initial prompt and previous turns.
        
        Args:
            obs (dict): The observation dictionary containing game state information.
                - obs.questions (list): List of previous questions asked.
                - obs.answers (list): List of corresponding answers received.
                - obs.turnType (str): Type of the current turn ('ask' or 'guess').
        """
        self.formatter.reset()
        self.formatter.user("Let's play 20 Questions. You are playing the role of the Questioner.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')
        if obs.turnType == 'ask':
            self.formatter.user("Please ask a yes-or-no question.")
        elif obs.turnType == 'guess':
            self.formatter.user("Now guess the keyword. Surround your guess with double asterisks.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """
        Parse the model's response based on the type of turn in the game.
        
        Args:
            response (str): The response generated by the language model.
            obs (dict): The observation dictionary containing game state information.
                - obs.turnType (str): Type of the current turn ('ask' or 'guess').
        
        Returns:
            str: The parsed question or guess based on the turn type.
        
        Raises:
            ValueError: If the turn type in the observation is unknown.
        """
        if obs.turnType == 'ask':
            match = re.search(".+?\?", response.replace('*', ''))
            if match is None:
                question = "Is it a person?"
            else:
                question = match.group()
            return question
        elif obs.turnType == 'guess':
            guess = self._parse_keyword(response)
            return guess
        else:
            raise ValueError("Unknown turn type:", obs.turnType)


class GemmaAnswererAgent(GemmaAgent):
    """
    Agent for playing the role of the Answerer in the 20 Questions game.
    
    This agent is responsible for providing yes-or-no answers based on the game state and 
    parsing responses from the language model to continue the answering process.
    """
    def __init__(self, *args, **kwargs):
        """
        Initialize the GemmaAnswererAgent with the provided arguments.
        
        Args:
            *args: Positional arguments to pass to the base class initializer.
            **kwargs: Keyword arguments to pass to the base class initializer.
        """
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """
        Start a new session for the Answerer agent by resetting the formatter 
        and applying the initial prompt and previous turns.
        
        Args:
            obs (dict): The observation dictionary containing game state information.
                - obs.questions (list): List of previous questions asked.
                - obs.answers (list): List of corresponding answers given.
                - obs.keyword (str): The keyword that the questioner is trying to guess.
                - obs.category (str): The category of the keyword.
        """
        self.formatter.reset()
        self.formatter.user(f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {obs.keyword} in the category {obs.category}.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='user')
        self.formatter.user(f"The question is about the keyword {obs.keyword} in the category {obs.category}. Give yes-or-no answer and surround your answer with double asterisks, like **yes** or **no**.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """
        Parse the model's response to extract the yes-or-no answer.
        
        Args:
            response (str): The response generated by the language model.
            obs (dict): The observation dictionary containing game state information.
        
        Returns:
            str: 'yes' if the answer contains 'yes', otherwise 'no'.
        """
        answer = self._parse_keyword(response)
        return 'yes' if 'yes' in answer else 'no'


# Agent Creation
system_prompt = "You are an AI assistant designed to play the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific person, place, or thing."

few_shot_examples = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
    "Is it a person?", "**no**",
    "Is is a place?", "**yes**",
    "Is it a country?", "**yes** Now guess the keyword.",
    "**France**", "Correct!",
]


# **IMPORTANT:** Define agent as a global so you only have to load
# the agent you need. Loading both will likely lead to OOM.
agent = None


def get_agent(name: str):
    """
    Initialize and return the appropriate agent (questioner or answerer) based on the provided name.
    
    This function uses a global variable to ensure that only one instance of the agent is created.
    If the agent is not already initialized, it creates a new instance of either 
    GemmaQuestionerAgent or GemmaAnswererAgent based on the provided name.
    
    Args:
        name (str): The name of the agent to initialize ('questioner' or 'answerer').

    Returns:
        GemmaAgent: An instance of either GemmaQuestionerAgent or GemmaAnswererAgent.
    
    Raises:
        AssertionError: If the agent name is not recognized or the agent fails to initialize.
    """
    global agent
    
    if agent is None and name == 'questioner':
        agent = GemmaQuestionerAgent(
            device='cuda:0',
            system_prompt=system_prompt,
            few_shot_examples=few_shot_examples,
        )
    elif agent is None and name == 'answerer':
        agent = GemmaAnswererAgent(
            device='cuda:0',
            system_prompt=system_prompt,
            few_shot_examples=few_shot_examples,
        )
    assert agent is not None, "Agent not initialized."

    return agent


def agent_fn(obs, cfg):
    """
    Determine the type of turn and invoke the appropriate agent to generate a response.
    
    This function examines the turn type in the observation and uses the corresponding agent 
    (questioner or answerer) to generate a response. If the response is None or at most one character long, it returns "yes".
    
    Args:
        obs (dict): The observation dictionary containing game state information.
            - obs.turnType (str): The type of the current turn ('ask', 'guess', or 'answer').
        cfg (dict): Configuration settings for the agent (not used in the current implementation).

    Returns:
        str: The response generated by the agent, or "yes" if the response is None or at most one character long.
    """
    if obs.turnType == "ask":
        response = get_agent('questioner')(obs)
    elif obs.turnType == "guess":
        response = get_agent('questioner')(obs)
    elif obs.turnType == "answer":
        response = get_agent('answerer')(obs)
    if response is None or len(response) <= 1:
        return "yes"
    else:
        return response

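Several pieces of `main.py` can be tried out without loading the Gemma weights. The sketch below re-implements, in simplified standalone form, the turn template from `GemmaFormatter`, the double-asterisk regex from `_parse_keyword`, and the `**kwargs` behaviour relevant to `_call_llm`. The function names (`format_turns`, `parse_keyword`, `received_kwargs`) are made up for this illustration and are not part of the starter notebook.

```python
import itertools
import re

START, END = "<start_of_turn>", "<end_of_turn>"


def format_turns(turns, start_agent="user"):
    """Render alternating turns in the same layout GemmaFormatter produces."""
    roles = ["user", "model"] if start_agent == "user" else ["model", "user"]
    return "".join(
        f"{START}{role}\n{text}{END}\n"
        for role, text in zip(itertools.cycle(roles), turns)
    )


def parse_keyword(response):
    """Return the first **...** span in lower case, or '' if there is none."""
    match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
    return "" if match is None else match.group().lower()


def received_kwargs(**sampler_kwargs):
    """Show what a function sees when no extra keyword arguments are passed."""
    return sampler_kwargs


prompt = format_turns(["Is it a person?", "**no**"], start_agent="model")
print(prompt)
print(parse_keyword("I think it is **France**."))  # france
print(parse_keyword("no asterisks here") == "")    # True
print(received_kwargs())                           # {} (an empty dict, not None)
```

Because the keyword-argument dict is empty rather than `None`, a check such as `if sampler_kwargs is None` never triggers.
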
In [None]:
!apt install pigz pv > /dev/null

The above cell installs `pigz` for fast compression and `pv` for progress monitoring:

- **apt install pigz pv**: Installs `pigz` (a parallel implementation of gzip) and `pv` (Pipe Viewer, which allows monitoring the progress of data through a pipeline).
- **> /dev/null**: Redirects the command's output to `/dev/null`, suppressing the output to keep the notebook clean.
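
Note that `>` redirects only standard output; anything a command writes to standard error still shows up. Here is a small Python sketch of the same idea, assuming `subprocess.DEVNULL` as the equivalent of `/dev/null` and a child Python process as a stand-in for `apt`:

```python
import subprocess
import sys

# Child process that writes one line to stdout and one to stderr.
child_code = "import sys; print('to stdout'); print('to stderr', file=sys.stderr)"

result = subprocess.run(
    [sys.executable, "-c", child_code],
    stdout=subprocess.DEVNULL,  # same effect as `> /dev/null`
    stderr=subprocess.PIPE,     # captured here so we can inspect it
    text=True,
)

print(result.stdout)          # None: stdout was discarded, not captured
print(result.stderr.strip())  # to stderr
```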

In [None]:
!tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ gemma/pytorch/7b-it-quant/2

The above cell packages the submission directory and necessary model files into a compressed tarball (`submission.tar.gz`). This tarball includes:
   - All files in `/kaggle/working/submission`.
   - The directory `gemma/pytorch/7b-it-quant/2` from `/kaggle/input`.
   
Detailed breakdown:
- **--use-compress-program='pigz --fast --recursive | pv'**: Tells `tar` to compress the archive with `pigz`, which uses multiple CPU cores to speed up the process. The `--fast` option selects the fastest (least thorough) compression level, and `--recursive` processes directories recursively. The output is piped through `pv` to monitor the progress.
- **-cf submission.tar.gz**: `-c` creates a new archive and `-f` names the output file, here `submission.tar.gz`.
- **-C /kaggle/working/submission**: Changes to the directory `/kaggle/working/submission` before adding files to the archive.
- **.**: Adds all files from the current directory (which is `/kaggle/working/submission` due to the `-C` option) to the archive.
- **-C /kaggle/input/**: Changes to the `/kaggle/input/` directory before adding more files.
- **gemma/pytorch/7b-it-quant/2**: Adds the `gemma/pytorch/7b-it-quant/2` directory and its contents from `/kaggle/input/` to the archive.
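
The effect of the two `-C` options on the paths stored in the archive can be reproduced with Python's `tarfile` module. This is a sketch with temporary stand-in directories and plain gzip instead of `pigz`, not a replacement for the command above; the file contents are dummies.

```python
import tarfile
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())

# Stand-ins for /kaggle/working/submission and /kaggle/input (hypothetical contents).
submission = root / "working" / "submission"
weights = root / "input" / "gemma" / "pytorch" / "7b-it-quant" / "2"
(submission / "lib").mkdir(parents=True)
weights.mkdir(parents=True)
(submission / "main.py").write_text("# agent code\n")
(weights / "tokenizer.model").write_text("dummy\n")

archive = root / "submission.tar.gz"
with tarfile.open(archive, "w:gz") as tar:
    # Like `-C .../submission .`: contents are stored relative to submission/.
    for path in sorted(submission.iterdir()):
        tar.add(path, arcname=path.name)
    # Like `-C .../input gemma/pytorch/7b-it-quant/2`: path is relative to input/.
    tar.add(weights, arcname="gemma/pytorch/7b-it-quant/2")

with tarfile.open(archive) as tar:
    names = sorted(tar.getnames())
print(names)
```

Once extracted into `/kaggle_simulations/agent/`, `main.py` and `lib/` sit at the top level and the weights sit under `gemma/pytorch/7b-it-quant/2`, which matches the `WEIGHTS_PATH` that `main.py` builds.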