In [None]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 
!pip install -q 'kaggle_environments>=1.14.8'
!git clone https://github.com/google/gemma_pytorch.git
!mkdir /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

In [None]:
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
import contextlib
import os
import torch
import re
import kaggle_environments
import itertools
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
from typing import Iterable

In [None]:
class GemmaFormatter:
    """Accumulates a Gemma chat transcript as one prompt string.

    Every turn is rendered as the start-of-turn token, the role name
    (``user`` or ``model``), a newline, the turn text, the end-of-turn token
    and a trailing newline. All mutating methods return ``self`` so calls can
    be chained.
    """

    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        """Create a formatter and seed it via :meth:`reset`.

        Args:
            system_prompt: Optional text emitted as the first user turn.
            few_shot_examples: Optional iterable of turn texts replayed after
                the system prompt, alternating user/model and starting with
                the user.
        """
        self._system_prompt = system_prompt
        # Snapshot the examples: reset() replays them on every call, and a
        # one-shot iterator/generator would be empty from the second call on.
        self._few_shot_examples = (
            tuple(few_shot_examples) if few_shot_examples is not None else None
        )
        # Templates with a single "{}" placeholder for the turn text.
        self._turn_user = f"{self._start_token}user\n{{}}{self._end_token}\n"
        self._turn_model = f"{self._start_token}model\n{{}}{self._end_token}\n"
        self.reset()

    def __repr__(self):
        """Return the accumulated prompt text (also what ``str()`` yields)."""
        return self._state

    def user(self, prompt):
        """Append a complete user turn containing ``prompt``."""
        self._state += self._turn_user.format(prompt)
        return self

    def model(self, prompt):
        """Append a complete model turn containing ``prompt``."""
        self._state += self._turn_model.format(prompt)
        return self

    def start_user_turn(self):
        """Open a user turn without closing it."""
        self._state += f"{self._start_token}user\n"
        return self

    def start_model_turn(self):
        """Open a model turn without closing it, leaving room for generated text."""
        self._state += f"{self._start_token}model\n"
        return self

    def end_turn(self):
        """Close the currently open turn."""
        self._state += f"{self._end_token}\n"
        return self

    def reset(self):
        """Clear the transcript, then re-add the system prompt and few-shot turns."""
        self._state = ""
        if self._system_prompt is not None:
            self.user(self._system_prompt)
        if self._few_shot_examples is not None:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append ``turns`` as alternating complete turns.

        Args:
            turns: Iterable of turn texts, in conversation order.
            start_agent: ``'model'`` makes the first turn a model turn; any
                other value starts with a user turn.
        """
        formatters = [self.model, self.user] if start_agent == 'model' else [self.user, self.model]
        formatters = itertools.cycle(formatters)
        for fmt, turn in zip(formatters, turns):
            fmt(turn)
        return self

In [None]:
def load_model(VARIANT, device, weights_path=None):
    """Construct a Gemma model, load its checkpoint and move it to ``device``.

    Args:
        VARIANT: Variant name such as ``"2b"``, ``"2b-it"`` or ``"7b-it-quant"``.
            A name containing ``"2b"`` selects the 2B config, anything else the
            7B config; a name containing ``"quant"`` sets the config's
            ``quant`` flag. The name is also used in the checkpoint file name.
        device: Target passed to ``model.to(...)``.
        weights_path: Directory holding ``tokenizer.model`` and
            ``gemma-{VARIANT}.ckpt``. Defaults to the Kaggle input location
            ``/kaggle/input/gemma/pytorch/{VARIANT}/2``.

    Returns:
        The loaded model, switched to eval mode.
    """
    WEIGHTS_PATH = (
        weights_path if weights_path is not None
        else f'/kaggle/input/gemma/pytorch/{VARIANT}/2'
    )

    @contextlib.contextmanager
    def _set_default_tensor_type(dtype: torch.dtype):
        """Temporarily make ``dtype`` the default torch dtype."""
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            # Restore even when model construction or weight loading raises,
            # so a failed load does not leave the kernel in the model's dtype.
            torch.set_default_dtype(previous_dtype)

    # Model Config.
    model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
    model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
    model_config.quant = "quant" in VARIANT

    # Model.
    with _set_default_tensor_type(model_config.get_dtype()):
        model = GemmaForCausalLM(model_config)
        ckpt_path = os.path.join(WEIGHTS_PATH, f'gemma-{VARIANT}.ckpt')
        model.load_weights(ckpt_path)
        model = model.to(device).eval()

    return model

In [None]:
def _parse_keyword(response: str):
    match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
    if match is None:
        keyword = ''
    else:
        keyword = match.group().lower()
    return keyword

In [None]:
def _parse_response(response: str, obs: dict):
    if obs['turnType'] == 'ask':
        match = re.search(".+?\?", response.replace('*', ''))
        if match is None:
            question = "Is it a person?"
        else:
            question = match.group()
        return question
    elif obs['turnType'] == 'guess':
        guess = _parse_keyword(response)
        return guess
    else:
        raise ValueError("Unknown turn type:", obs['turnType'])

In [None]:
def interleave_unequal(x, y):
    """Merge two sequences by alternating their items, starting with ``x``.

    The sequences may differ in length; once the shorter one runs out, the
    remaining items of the longer one follow in order. Items that are ``None``
    are left out of the result.
    """
    merged = []
    for left, right in itertools.zip_longest(x, y):
        for candidate in (left, right):
            if candidate is not None:
                merged.append(candidate)
    return merged

In [None]:
# Instruction emitted as the very first user turn of every prompt.
system_prompt = "You are an AI assistant designed to play the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific person, place, or thing."

# One worked example game, listed as alternating turns beginning with the user.
# The wording matters because it is fed to the model verbatim, so the example
# questions are kept grammatical and logically consistent with the answer (France).
few_shot_examples = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
    "Is it a thing?", "**no**",
    "Is it a country?", "**yes**",
    "Is it in Europe?", "**yes** Now guess the keyword.",
    "**France**", "Correct!",
]

# Shared formatter instance. Its constructor calls reset(), which seeds the
# transcript with the system prompt followed by the few-shot turns.
formatter = GemmaFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)

In [None]:
# Hand-written sample observation: the kind of turn being played plus the
# questions asked and the answers received so far.
obs = dict(
    turnType='ask',
    questions=[
        'Is it a living entity?',
        'Is it man-made?',
        'Can it be held in a single hand?',
    ],
    answers=['no', 'yes', 'yes'],
)

In [None]:
# Start again from the system prompt and few-shot turns, then open the game.
formatter.reset().user("Let's play 20 Questions. You are playing the role of the Questioner.")

# Replay the game so far: questions become model turns, answers become user turns.
turns = interleave_unequal(obs['questions'], obs['answers'])
formatter.apply_turns(turns, start_agent='model')

# Closing instruction for the current turn type; other turn types add nothing.
turn_instructions = {
    'ask': "Please ask a yes-or-no question.",
    'guess': "Now guess the keyword. Surround your guess with double asterisks.",
}
if obs['turnType'] in turn_instructions:
    formatter.user(turn_instructions[obs['turnType']])

# Leave a model turn open so that generation continues from here.
formatter.start_model_turn()

In [None]:
# Render the accumulated transcript to a plain string; the bare `prompt` on the
# last line makes the notebook display it for inspection.
prompt = str(formatter)
prompt

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing when the model is moved to
# a CUDA device that does not exist.
MACHINE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(MACHINE_TYPE)

# Passed to model.generate() as output_len.
max_new_tokens = 32

# Sampling options forwarded to model.generate(). With top_k=1 decoding is
# presumably close to greedy -- confirm against the gemma_pytorch sampler.
sampler_kwargs = {
    'temperature': 0.01,
    'top_p': 0.1,
    'top_k': 1,
}

# Gemma 2b V2

In [None]:
# Drop the reference to any model left over from an earlier run before loading,
# so two sets of weights are not kept alive at the same time.
model = None
model = load_model("2b", device)

# Keep the raw generation separate from the parsed result.
raw_response = model.generate(
    prompt,
    device=device,
    output_len=max_new_tokens,
    **sampler_kwargs
)
response = _parse_response(raw_response, obs)
print(response)

# Gemma 2b-it V2

In [None]:
# Drop the reference to the previously loaded model before loading this one,
# so two sets of weights are not kept alive at the same time.
model = None
model = load_model("2b-it", device)

# Keep the raw generation separate from the parsed result.
raw_response = model.generate(
    prompt,
    device=device,
    output_len=max_new_tokens,
    **sampler_kwargs
)
response = _parse_response(raw_response, obs)
print(response)

# Gemma 7b-it-quant V2

In [None]:
# Drop the reference to the previously loaded model before loading this one,
# so two sets of weights are not kept alive at the same time.
model = None
model = load_model("7b-it-quant", device)

# Keep the raw generation separate from the parsed result.
raw_response = model.generate(
    prompt,
    device=device,
    output_len=max_new_tokens,
    **sampler_kwargs
)
response = _parse_response(raw_response, obs)
print(response)