# Notebook to show Simple Start

Based on the excellent "LLM 20 Questions Starter Notebook"  

However, this is much more simplified in order to help you get started with prompt engineering quickly.    

It is not as advanced as the "LLM 20 Questions Starter Notebook" and the prompts are taken directly from the "LLM 20 Questions" notebook supplied for the competition.  

At the end there are cells to help test your prompts. 

Any issues/suggestions for improvements are welcome.

In [None]:
%%bash
# Build the submission's private library directory under /kaggle/working/submission/lib.
cd /kaggle/working
# Install the pure-Python dependencies directly into the submission's lib folder.
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece
# Fetch the Gemma PyTorch reference implementation.
git clone https://github.com/google/gemma_pytorch.git > /dev/null
# -p: do not fail if the directory already exists (e.g. when the cell is re-run).
mkdir -p /kaggle/working/submission/lib/gemma/
# Only the gemma package itself is copied next to the other libraries.
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

**Add the model using Add Input**  
This is found on the right-hand side of the environment.

In [None]:
# %%writefile submission/main.py

import os
import sys

KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"

# Put the bundled library directory first on the import path: use the copy
# under KAGGLE_AGENT_PATH when that directory exists, otherwise the copy in
# the notebook's working directory.
sys.path.insert(
    0,
    os.path.join(KAGGLE_AGENT_PATH, 'lib')
    if os.path.exists(KAGGLE_AGENT_PATH)
    else "/kaggle/working/submission/lib",
)

import contextlib
import os
import sys
from dataclasses import dataclass
from typing import Literal, List

import torch
from gemma.config import get_config_for_7b
from gemma.model import GemmaForCausalLM

# Directory holding the quantised Gemma 7B-it checkpoint and tokenizer: taken
# from under KAGGLE_AGENT_PATH when that directory exists, otherwise from the
# notebook's /kaggle/input mount.
WEIGHTS_PATH = (
    os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/7b-it-quant/2")
    if os.path.exists(KAGGLE_AGENT_PATH)
    else "/kaggle/input/gemma/pytorch/7b-it-quant/2"
)


# Set our own dataclass so we know what is in the object
@dataclass
class ObsData:
    """Typed copy of the observation fields that ``agent_fn`` works with."""

    # Answers given so far, in the same order as ``questions``.
    answers: List[str]
    # Category passed to the answer agent's prompt.
    category: str
    # Keyword passed to the answer agent's prompt.
    keyword: str
    # Questions asked so far.
    questions: List[str]
    # Which role the agent has to play on this turn.
    turn_type: Literal["ask", "guess", "answer"]
    

class Agents:
    """The three 20-Questions roles (ask, guess, answer) sharing one Gemma model.

    The model is not loaded in ``__init__``; it is created on the first call
    that needs it (see ``_set_model``) and then reused.
    """

    # Opening text shared by the ask and guess prompts; ``{q_a_thread}`` is
    # replaced with the formatted question/answer history.
    _QUESTIONER_INFO_PROMPT = """You are playing a game of 20 questions where you ask the questions and try to figure out the keyword, which will be a real or fictional person, place, or thing. \nHere is what you know so far:\n{q_a_thread}"""

    def __init__(self):
        self.model = None

        # It is much better if you request a GPU session as a 7B model is rather large
        self._device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

    def answer_agent(self, question: str, category: str, keyword: str) -> str:
        """Answer ``question`` about the secret keyword.

        Args:
            question: The question to answer.
            category: Category of the keyword, inserted into the prompt.
            keyword: The keyword the questioner is trying to guess.

        Returns:
            The model's output, unmodified. The prompt asks for yes / no /
            maybe, but nothing here enforces that.
        """
        info_prompt = """You are a very precise answerer in a game of 20 questions. The keyword that the questioner is trying to guess is [the {category} {keyword}]. """
        # The question is interpolated with an f-string (not str.format) so any
        # braces it contains are never treated as format fields.
        answer_question_prompt = f"""Answer the following question with only yes, no, or if unsure maybe: {question}"""

        prompt = info_prompt.format(category=category, keyword=keyword) + answer_question_prompt
        return self._call_llm(prompt)

    def ask_agent(self, questions: List[str], answers: List[str]) -> str:
        """Produce the next question to ask, given the history so far.

        Args:
            questions: Questions asked so far.
            answers: Answers received so far, aligned with ``questions``.

        Returns:
            The model's output for the "ask one question" prompt.
        """
        questions_prompt = """Ask one yes or no question."""

        history = self._format_history(questions, answers)
        prompt = self._QUESTIONER_INFO_PROMPT.format(q_a_thread=history) + questions_prompt
        return self._call_llm(prompt)

    def guess_agent(self, questions: List[str], answers: List[str]) -> str:
        """Produce a guess for the keyword, given the history so far.

        Args:
            questions: Questions asked so far.
            answers: Answers received so far, aligned with ``questions``.

        Returns:
            The model's output wrapped in ``**`` on both sides.
        """
        guess_prompt = """Guess the keyword. Only respond with the exact word/phrase. For example, if you think the keyword is [paris], don't respond with [I think the keyword is paris] or [Is the kewyord Paris?]. Respond only with the word [paris]."""

        history = self._format_history(questions, answers)
        prompt = self._QUESTIONER_INFO_PROMPT.format(q_a_thread=history) + guess_prompt

        # It is expected that the answer is surrounded by **
        return f"**{self._call_llm(prompt)}**"

    @staticmethod
    def _format_history(questions: List[str], answers: List[str]) -> str:
        """Render answered questions as one ``Q: ... A: ...`` line each.

        Pairs are formed with ``zip``, so a trailing question that has no
        answer yet is left out.
        """
        return "".join(
            f"Q: {question} A: {answer}\n"
            for question, answer in zip(questions, answers)
        )

    def _call_llm(self, prompt: str):
        """Run ``prompt`` through the model and return what ``generate`` returns.

        The model is loaded first if this is the first call.
        """
        self._set_model()

        sampler_kwargs = {
            'temperature': 0.01,
            'top_p': 0.1,
            'top_k': 1,
        }

        return self.model.generate(
            prompt,
            device=self._device,
            output_len=100,
            **sampler_kwargs,
        )

    def _set_model(self):
        """Create and load the Gemma model once; later calls do nothing."""
        if self.model is not None:
            return

        print("No model yet so setting up")
        model_config = get_config_for_7b()
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
        model_config.quant = True

        # Not using a context manager blows the stack
        with self._set_default_tensor_type(model_config.get_dtype()):
            model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(WEIGHTS_PATH, 'gemma-7b-it-quant.ckpt')
            model.load_weights(ckpt_path)
            self.model = model.to(self._device).eval()

    @contextlib.contextmanager
    def _set_default_tensor_type(self, dtype: torch.dtype):
        """Temporarily set the default torch dtype to ``dtype``.

        The dtype that was active on entry is restored on exit, including when
        the ``with`` body raises, so a failed model load does not leave the
        global default changed.
        """
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(previous_dtype)

# The entry point so name and parameters are preset
def agent_fn(obs, cfg) -> str:
    """Dispatch one turn to the matching agent and return its reply.

    Args:
        obs: Observation object exposing ``turnType``, ``questions``,
            ``answers``, ``keyword`` and ``category`` attributes.
        cfg: Not used by this function.

    Returns:
        The agent's response, or ``"yes"`` when there is no usable response
        (unknown turn type, an answer turn with no question, or a reply of at
        most one character).
    """
    obs_data = ObsData(
        turn_type=obs.turnType,
        questions=obs.questions,
        answers=obs.answers,
        keyword=obs.keyword,
        category=obs.category
    )

    # Start from None so an unrecognised turn type falls through to the
    # default reply instead of raising on an unbound variable.
    response = None
    if obs_data.turn_type == "ask":
        response = agents.ask_agent(
            questions=obs_data.questions,
            answers=obs_data.answers,
        )
    elif obs_data.turn_type == "guess":
        response = agents.guess_agent(
            questions=obs_data.questions,
            answers=obs_data.answers,
        )
    elif obs_data.turn_type == "answer" and obs_data.questions:
        # Only the most recent question needs answering.
        response = agents.answer_agent(
            question=obs_data.questions[-1],
            category=obs_data.category,
            keyword=obs_data.keyword,
        )

    if response is None or len(response) <= 1:
        return "yes"

    return response

# Instantiate the agents class once at module level so every agent_fn call
# shares it and the model is only set up once (on the first call that needs it).
agents = Agents()

In [None]:
# Manually run the answer agent
@dataclass
class ObsDataIn(ObsData):
    # agent_fn reads the camelCase ``turnType`` attribute from the object it
    # is given, which ObsData itself does not have.
    turnType: str


obs_data = ObsDataIn(
    turn_type="",  # required by ObsData, but agent_fn reads turnType instead
    turnType="answer",
    questions=["Is it a place"],
    answers=[],
    keyword="Antarctica",
    category="Place",
)

print(agent_fn(obs_data, {}))

In [None]:
# Manually run the question agent
@dataclass
class ObsDataIn(ObsData):
    # Extra camelCase field, because agent_fn reads ``turnType`` from its input.
    turnType: str


# One answered question of history; keyword and category are left empty here.
obs_data = ObsDataIn(
    questions=["Is it a place?"],
    answers=["Yes"],
    keyword="",
    category="",
    turnType="ask",
    turn_type="",
)

print(agent_fn(obs_data, {}))

In [None]:
# Manually run the guess agent
@dataclass
class ObsDataIn(ObsData):
    # agent_fn reads the camelCase ``turnType`` attribute from the object it
    # is given, which ObsData itself does not have.
    turnType: str


obs_data = ObsDataIn(
    turn_type="",  # required by ObsData, but agent_fn reads turnType instead
    turnType="guess",
    questions=[
        "Is it a place?",
        "Is it in the northern hemisphere?",
        "Is it a city?",
        "Is it icy?",
    ],
    answers=["Yes", "No", "No", "Yes"],
    keyword="",
    category="",
)

print(agent_fn(obs_data, {}))