In [None]:
%%bash
# Vendor the agent's Python dependencies and the Gemma reference implementation
# into submission/lib so they are shipped inside the submission archive.
# Fail fast: without this, a failed step is ignored and later steps run on a
# half-prepared tree.
set -euo pipefail
cd /kaggle/working
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece
# Remove leftovers of a previous run so the clone and the move below are
# repeatable (git refuses to clone into an existing directory, and mv cannot
# replace a non-empty directory).
rm -rf /kaggle/working/gemma_pytorch /kaggle/working/submission/lib/gemma
git clone --quiet https://github.com/google/gemma_pytorch.git > /dev/null
mkdir -p /kaggle/working/submission/lib/gemma/
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

In [None]:
# for TPU
# --fail makes curl exit with an error instead of saving an HTTP error page as
# the setup script; --location follows redirects.
!curl --fail --location https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
# Select the device used by the notebook. `torch` and `xm` were previously used
# here without being imported, which fails on a fresh kernel.
import torch

try:
    # for TPU
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
except ImportError:
    # for GPU (falls back to CPU when CUDA is not available)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# float32 as the default dtype: the non-deprecated equivalent of
# torch.set_default_tensor_type('torch.FloatTensor').
torch.set_default_dtype(torch.float)

In [None]:
%%writefile submission/main.py

import torch, itertools, contextlib
import os, sys, re
from typing import Iterable
from pathlib import Path

# Directory where the agent's files live inside the competition simulation.
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
if not os.path.exists(KAGGLE_AGENT_PATH):
    # Running from the notebook: use the working-directory lib and the input dataset.
    WEIGHTS_PATH = "/kaggle/input/gemma/pytorch/2b-it/2"
    sys.path.insert(0, "/kaggle/working/submission/lib")
else:
    # Running inside the simulation: libraries and weights sit under the agent directory.
    WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/2b-it/2")
    sys.path.insert(0, os.path.join(KAGGLE_AGENT_PATH, 'lib'))

from gemma.config import get_config_for_2b
from gemma.model import GemmaForCausalLM

class GemmaFormatter:
    """Builds a prompt string out of `<start_of_turn>` / `<end_of_turn>` delimited turns.

    The accumulated text is kept in `_state` and is what `repr()` / `str()` return.
    Every mutating method returns `self` so calls can be chained.
    """
    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    def __init__(self, sp: str = None, fse: Iterable = None):
        """Remember the optional system prompt and few-shot turns, then build the initial state.

        Args:
            sp: Text emitted as the first user turn on every reset, or None.
            fse: Turns emitted after the system prompt on every reset, alternating
                user/model and starting with the user, or None.
        """
        self._system_prompt = sp
        self._few_shot_examples = fse
        # Templates with a single `{}` placeholder for the text of a complete turn.
        self._turn_user = self._start_token + "user\n{}" + self._end_token + "\n"
        self._turn_model = self._start_token + "model\n{}" + self._end_token + "\n"
        self.reset()

    def __repr__(self):
        """Return the prompt text accumulated so far."""
        return self._state

    def _append(self, text):
        """Add `text` to the end of the accumulated prompt and return self."""
        self._state = self._state + text
        return self

    def user(self, prompt):
        """Append a complete user turn containing `prompt`."""
        return self._append(self._turn_user.format(prompt))

    def model(self, prompt):
        """Append a complete model turn containing `prompt`."""
        return self._append(self._turn_model.format(prompt))

    def start_user_turn(self):
        """Open a user turn without closing it."""
        return self._append(f"{self._start_token}user\n")

    def start_model_turn(self):
        """Open a model turn without closing it."""
        return self._append(f"{self._start_token}model\n")

    def end_turn(self):
        """Close the currently open turn."""
        return self._append(f"{self._end_token}\n")

    def reset(self):
        """Clear the prompt, then re-apply the system prompt and few-shot turns if given."""
        self._state = ""
        system_prompt = self._system_prompt
        examples = self._few_shot_examples
        if system_prompt is not None:
            self.user(system_prompt)
        if examples is not None:
            self.apply_turns(examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append `turns` as alternating complete turns.

        Args:
            turns: Turn texts in conversation order.
            start_agent: 'model' to give the first turn to the model; any other
                value gives the first turn to the user.
        """
        if start_agent == 'model':
            first, second = self.model, self.user
        else:
            first, second = self.user, self.model
        for position, turn in enumerate(turns):
            # Even positions go to the starting speaker, odd positions to the other.
            emit = first if position % 2 == 0 else second
            emit(turn)
        return self

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

class GemmaAgent:
    """Base class for the Gemma-backed agents.

    Loads the model once in `__init__`. Calling the agent builds a prompt via
    `_start_session`, generates text with the model and converts it into the
    final answer via `_parse_response`. Subclasses must implement
    `_start_session` and `_parse_response`.
    """

    def __init__(self, sp=None, fse=None):
        """Create the prompt formatter and load the model weights from WEIGHTS_PATH.

        Args:
            sp: System prompt forwarded to GemmaFormatter.
            fse: Few-shot example turns forwarded to GemmaFormatter.
        """
        self._device = self._select_device()
        self.formatter = GemmaFormatter(sp=sp, fse=fse)
        print("Initializing model")
        model_config = get_config_for_2b()
        model_config.tokenizer = WEIGHTS_PATH + '/tokenizer.model'
        # Build and load the model while the config's dtype is the torch default.
        with _set_default_tensor_type(model_config.get_dtype()):
            model = GemmaForCausalLM(model_config)
            model.load_weights(WEIGHTS_PATH + '/gemma-2b-it.ckpt')
            self.model = model.to(self._device).eval()

    @staticmethod
    def _select_device():
        """Return the XLA device when torch_xla is importable, else CUDA, else CPU.

        `xm` was previously referenced without an import, which raised NameError.
        The import is local so this file still loads where torch_xla is absent.
        """
        try:
            import torch_xla.core.xla_model as xm
        except ImportError:
            return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return xm.xla_device()

    def __call__(self, obs, *args):
        """Produce the agent's response for the observation `obs`.

        Extra positional arguments are accepted and ignored.
        """
        self._start_session(obs)
        prompt = str(self.formatter)
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        print(f"{response=}")
        return response

    def _start_session(self, obs: dict):
        """Prepare `self.formatter` for this turn; implemented by subclasses."""
        raise NotImplementedError

    def _call_llm(self, prompt, max_nt=40, **sampler_kwargs):
        """Generate a completion for `prompt`.

        Args:
            prompt: Prompt text passed to `self.model.generate`.
            max_nt: Passed to `generate` as `output_len`.
            **sampler_kwargs: Sampling options forwarded to `generate`. When
                none are given, temperature=0.8, top_p=0.9, top_k=60 are used.

        Returns:
            Whatever `self.model.generate` returns.
        """
        # **kwargs is always a dict (never None), so test for emptiness; the
        # previous `is None` check meant these defaults were never applied.
        if not sampler_kwargs:
            sampler_kwargs = {'temperature': 0.8, 'top_p': 0.9, 'top_k': 60,}
        response = self.model.generate(
            prompt, device=self._device, output_len=max_nt, **sampler_kwargs,)
        return response

    def _parse_keyword(self, response: str):
        """Return the lower-cased text of the first `**...**` span, or '' if there is none."""
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        if match is None:
            return ''
        return match.group().lower()

    def _parse_response(self, response: str, obs: dict):
        """Turn the raw model output into the agent's answer; implemented by subclasses."""
        raise NotImplementedError

def interleave_unequal(x, y):
    """Alternate items from `x` and `y`, continuing with the longer one once the shorter ends.

    Items that are `None` are dropped from the result.
    """
    merged = []
    for left, right in itertools.zip_longest(x, y):
        merged.extend(value for value in (left, right) if value is not None)
    return merged

class GemmaQuestionerAgent(GemmaAgent):
    """Agent for the Questioner role: asks yes-or-no questions and guesses the keyword."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Rebuild the prompt from the questions and answers in `obs` and open a model turn."""
        self.formatter.reset()
        self.formatter.user("Let's play 20 Questions. You are playing the role of the Questioner.")
        # Earlier questions become model turns, the answers to them user turns.
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')
        if obs.turnType == 'ask':
            self.formatter.user("Please ask a yes-or-no question.")
        elif obs.turnType == 'guess':
            self.formatter.user("Now guess the keyword. Surround your guess with double asterisks.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Extract the question or the guess from the model output.

        Returns:
            For 'ask': the shortest text ending in '?' on one line (asterisks
            removed), or "Is it a person?" when the output has no question mark.
            For 'guess': the keyword found by `_parse_keyword`.

        Raises:
            ValueError: If `obs.turnType` is neither 'ask' nor 'guess'.
        """
        if obs.turnType == 'ask':
            # Raw string: "\?" in a plain string literal is an invalid escape sequence.
            match = re.search(r".+?\?", response.replace('*', ''))
            if match is None:
                return "Is it a person?"  # make random choice for person, place, thing
            return match.group()
        elif obs.turnType == 'guess':
            return self._parse_keyword(response)
        else:
            # One formatted message instead of a two-element args tuple.
            raise ValueError(f"Unknown turn type: {obs.turnType}")

class GemmaAnswererAgent(GemmaAgent):
    """Agent for the Answerer role: replies 'yes' or 'no' to questions about the keyword."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Rebuild the prompt from `obs` (keyword, category, history) and open a model turn."""
        fmt = self.formatter
        fmt.reset()
        intro = f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {obs.keyword} in the category {obs.category}."
        fmt.user(intro)
        # Questions are user turns here, the agent's earlier answers model turns.
        history = interleave_unequal(obs.questions, obs.answers)
        fmt.apply_turns(history, start_agent='user')
        instruction = f"The question is about the keyword {obs.keyword} in the category {obs.category}. Give yes-or-no answer and surround your answer with double asterisks, like **yes** or **no**."
        fmt.user(instruction)
        fmt.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Return 'yes' when the parsed `**...**` span contains 'yes', otherwise 'no'."""
        parsed = self._parse_keyword(response)
        if 'yes' in parsed:
            return 'yes'
        return 'no'

# System prompt: emitted as the first user turn whenever a formatter is reset.
sp = "You are playing the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific person, place, or thing."
# Few-shot example game, applied as alternating turns starting with the user.
fse = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
    "Is it a person?", "**no**",
    "Is it a place?", "**yes**",  # fixed typo: was "Is is a place?"
    "Is it a country?", "**yes** Now guess the keyword.",
    "**France**", "Correct!",
]
# Agent instance, created on first use by get_agent().
agent = None

def get_agent(name: str):
    """Return the module-level agent, creating it on first use.

    The agent is only created while the global `agent` is still None: a
    questioner for name 'questioner', an answerer for name 'answerer'. Once an
    agent exists it is returned whatever `name` is.

    Raises:
        AssertionError: If no agent exists and `name` is neither of the two roles.
    """
    global agent
    if agent is None:
        if name == 'questioner':
            agent = GemmaQuestionerAgent(sp=sp, fse=fse,)
        elif name == 'answerer':
            agent = GemmaAnswererAgent(sp=sp, fse=fse,)
    assert agent is not None, "Agent not initialized."
    return agent

def agent_fn(obs, cfg):
    """Entry point: route the observation to the agent for its turn type.

    Args:
        obs: Observation with a `turnType` attribute ("ask", "guess" or "answer").
        cfg: Unused.

    Returns:
        The agent's response, or "yes" when there is no usable response: the
        turn type is unrecognized, or the response is None or at most one
        character long.
    """
    # Start from None so an unrecognized turn type falls through to the "yes"
    # fallback instead of raising UnboundLocalError.
    response = None
    if obs.turnType in ("ask", "guess"):
        response = get_agent('questioner')(obs)
    elif obs.turnType == "answer":
        response = get_agent('answerer')(obs)
    if response is None or len(response) <= 1:
        return "yes"
    return response

# Testing

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q pygame

In [None]:
# Create the "llm_20_questions" environment used by the test cells below.
from kaggle_environments import make
env = make("llm_20_questions")

In [None]:
# Run the generated agent script; its top-level names (e.g. get_agent and
# agent_fn) are then available in the notebook namespace.
%run submission/main.py

In [None]:
# FIXME: this local test run does not work yet, so the call is left disabled.
#env.run([get_agent('questioner'), "random"])

In [None]:
# Render the environment inline in the notebook.
# NOTE(review): the env.run call in the preceding cell is commented out, so no
# episode has been played when this cell runs.
env.render(mode="ipython")

# Package

In [None]:
# apt-get is the script-friendly front-end; -y answers the confirmation prompt
# so the install cannot abort in the non-interactive kernel.
!apt-get install -y pigz pv > /dev/null

In [None]:
# Build submission.tar.gz from two locations:
#   - everything inside /kaggle/working/submission (main.py and lib/)
#   - gemma/pytorch/2b-it/2, taken relative to /kaggle/input/, which matches
#     the path main.py uses for WEIGHTS_PATH under the agent directory.
# Compression is piped through pigz and then pv (see --use-compress-program).
!tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ gemma/pytorch/2b-it/2