Notes:

This is a difficult challenge! I have attempted different prompts, added probability weights to yes/no answers, adjusted LLM parameters, and tried question refinement. It is hard to understand the performance of the LLM from the evaluation logs since they include team pairing. I have added a debug cell to the current version. Trying something new every day! Great learning experience.

In [None]:
%%bash
# Work from the notebook's writable directory.
cd /kaggle/working
# Install the dependencies into the submission's lib/ folder (-t) instead of the
# notebook environment, so they live inside the directory that gets packaged.
pip install -q -U -t /kaggle/working/submission/lib immutabledict sentencepiece
# Fetch the Gemma PyTorch sources; stdout is discarded.
git clone https://github.com/google/gemma_pytorch.git > /dev/null
# Move the `gemma` package next to the pip-installed packages in lib/.
mkdir /kaggle/working/submission/lib/gemma/
mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/submission/lib/gemma/

In [None]:
%%writefile submission/main.py

import os
import sys

# Directory checked to decide where the bundled libraries live.
# NOTE(review): presumably this path only exists inside the Kaggle simulation
# environment -- confirm.
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
if os.path.exists(KAGGLE_AGENT_PATH):
    # Agent directory present: import the vendored packages from its lib/ folder.
    sys.path.insert(0, os.path.join(KAGGLE_AGENT_PATH, 'lib'))
else:
    # Otherwise fall back to the lib/ folder populated by the notebook setup cell.
    sys.path.insert(0, "/kaggle/working/submission/lib")

import contextlib
import os
import sys
from pathlib import Path

import torch
from gemma.config import get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM

# Location of the Gemma 7b-it-quant checkpoint and tokenizer: inside the agent
# directory when it exists, otherwise the notebook's /kaggle/input dataset mount.
if os.path.exists(KAGGLE_AGENT_PATH):
    WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, "gemma/pytorch/7b-it-quant/2")
else:
    WEIGHTS_PATH = "/kaggle/input/gemma/pytorch/7b-it-quant/2"

import itertools
from typing import Iterable

def interleave_unequal(x, y):
    """Alternate items from ``x`` and ``y``, continuing with the longer input.

    Items are taken pairwise (``x`` first). Once the shorter input runs out,
    the remaining items of the longer one follow in their original order.
    Any item that is ``None`` -- including the padding used for the shorter
    input -- is dropped from the result.
    """
    merged = []
    for left, right in itertools.zip_longest(x, y):
        for candidate in (left, right):
            if candidate is not None:
                merged.append(candidate)
    return merged

class GemmaFormatter:
    """Incrementally build a Gemma chat prompt out of user and model turns.

    Every mutating method returns ``self`` so calls can be chained. The prompt
    built so far is available through ``str(formatter)`` / ``repr(formatter)``.

    Args:
        system_prompt: Optional text emitted as the first user turn on every
            ``reset()``.
        few_shot_examples: Optional turns replayed after the system prompt on
            every ``reset()``, starting with a user turn. Entries may be plain
            strings (one turn each) or tuples/lists such as
            ``(question, answer)``, which are expanded into consecutive turns.
    """

    _start_token = '<start_of_turn>'
    _end_token = '<end_of_turn>'

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        self._system_prompt = system_prompt
        # Materialise the examples: reset() replays them on every call, which
        # would silently yield nothing the second time for a one-shot iterator.
        self._few_shot_examples = (
            None if few_shot_examples is None else list(few_shot_examples)
        )
        # Templates with a single ``{}`` placeholder for the turn's text.
        self._turn_user = f"{self._start_token}user\n{{}}{self._end_token}\n"
        self._turn_model = f"{self._start_token}model\n{{}}{self._end_token}\n"
        self.reset()

    def __repr__(self):
        """Return the prompt accumulated so far."""
        return self._state

    def user(self, prompt):
        """Append a complete user turn containing ``prompt``."""
        self._state += self._turn_user.format(prompt)
        return self

    def model(self, prompt):
        """Append a complete model turn containing ``prompt``."""
        self._state += self._turn_model.format(prompt)
        return self

    def start_user_turn(self):
        """Open a user turn without closing it."""
        self._state += f"{self._start_token}user\n"
        return self

    def start_model_turn(self):
        """Open a model turn without closing it (the model completes it)."""
        self._state += f"{self._start_token}model\n"
        return self

    def end_turn(self):
        """Close the currently open turn."""
        self._state += f"{self._end_token}\n"
        return self

    def reset(self):
        """Clear the prompt, then re-apply the system prompt and few-shot turns."""
        self._state = ""
        if self._system_prompt is not None:
            self.user(self._system_prompt)
        if self._few_shot_examples is not None:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append ``turns`` alternating between the two roles.

        Args:
            turns: Turn texts. A tuple/list entry is expanded into consecutive
                turns so that ``(question, answer)`` pairs become two turns
                instead of being formatted as one tuple repr.
            start_agent: ``'model'`` to begin with a model turn; any other
                value begins with a user turn.
        """
        if start_agent == 'model':
            order = [self.model, self.user]
        else:
            order = [self.user, self.model]
        for fmt, turn in zip(itertools.cycle(order), self._flatten_turns(turns)):
            fmt(turn)
        return self

    @staticmethod
    def _flatten_turns(turns: Iterable):
        """Yield turn texts one by one, expanding tuple/list entries in place."""
        for turn in turns:
            if isinstance(turn, (tuple, list)):
                yield from turn
            else:
                yield turn

import re

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The default dtype that was active on entry is restored on exit, including
    when the body of the ``with`` block raises, so a failure while building or
    loading a model does not leave the global default changed.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous_dtype)

from collections import defaultdict

#Agent Creation
    
class GemmaAgent:
    """Base 20 Questions agent backed by a Gemma causal language model.

    Calling the agent with an observation builds a prompt (``_start_session``),
    runs the model (``_call_llm``) and converts the raw text into a move
    (``_parse_response``). Subclasses normally override ``_start_session`` and
    ``_parse_response``; the base implementations provide a heuristic fallback.

    Args:
        variant: Checkpoint variant, used to pick the model config ("2b" in the
            name selects the 2b config, otherwise 7b), to enable quantisation
            ("quant" in the name) and to build the checkpoint file name.
        device: Torch device string the model is moved to.
        system_prompt: Forwarded to ``GemmaFormatter``.
        few_shot_examples: Forwarded to ``GemmaFormatter``.
    """

    def __init__(self, variant='7b-it-quant', device='cuda:0', system_prompt=None, few_shot_examples=None):
        self._variant = variant
        self._device = torch.device(device)
        self.formatter = GemmaFormatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)
        self.question_history = []
        self.answer_history = []
        self.keyword_probabilities = defaultdict(float)

        print("Initializing model")
        # Choose the architecture matching the requested variant instead of
        # always assuming 7b.
        model_config = get_config_for_2b() if "2b" in variant else get_config_for_7b()
        model_config.tokenizer = os.path.join(WEIGHTS_PATH, "tokenizer.model")
        model_config.quant = "quant" in variant

        with _set_default_tensor_type(model_config.get_dtype()):
            self.model = GemmaForCausalLM(model_config)
            ckpt_path = os.path.join(WEIGHTS_PATH, f'gemma-{variant}.ckpt')
            self.model.load_weights(ckpt_path)
            self.model = self.model.to(self._device).eval()

        # NOTE(review): gemma_pytorch's GemmaForCausalLM is expected to expose
        # the tokenizer it built from ``model_config.tokenizer``; reuse it
        # rather than constructing a second one. Falls back to None if absent.
        self.tokenizer = getattr(self.model, "tokenizer", None)

    def __call__(self, obs, *args):
        """Produce the agent's move for the observation ``obs``."""
        self._start_session(obs)
        prompt = str(self.formatter)
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        print(f"{response=}")
        return response

    def _start_session(self, obs: dict):
        """Reset the formatter and open a user turn for this observation."""
        self.formatter.reset()
        if 'system_prompt' in obs:
            self.formatter.user(obs['system_prompt'])
        if 'few_shot_examples' in obs:
            self.formatter.apply_turns(obs['few_shot_examples'], start_agent='user')
        self.formatter.start_user_turn()

    def _call_llm(self, prompt, max_new_tokens=50, sampler_kwargs=None):
        """Generate a completion for ``prompt``.

        Args:
            prompt: Full prompt text.
            max_new_tokens: Passed to the model as ``output_len``.
            sampler_kwargs: Extra keyword arguments for ``model.generate``.
                Defaults to temperature/top_p/top_k only, because those are the
                sampling options the Gemma PyTorch ``generate`` method takes;
                HuggingFace-style options such as ``num_beams`` are rejected.
        """
        if sampler_kwargs is None:
            sampler_kwargs = {
                'temperature': 0.6,
                'top_p': 0.9,
                'top_k': 50,
            }
        response = self.model.generate(
            prompt,
            device=self._device,
            output_len=max_new_tokens,
            **sampler_kwargs,
        )
        return response

    def _parse_keyword(self, response: str):
        """Return the lower-cased text between the first ``**...**`` pair, or ''."""
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        return match.group().lower() if match else ''

    @staticmethod
    def _last_question(obs: dict):
        """Return the most recent question in ``obs`` or '' when there is none."""
        questions = obs.get('questions') or []
        return questions[-1] if questions else ""

    def _parse_response(self, response: str, obs: dict):
        """Hook used by ``__call__``; delegates to ``parse_response``."""
        return self.parse_response(response, obs)

    def parse_response(self, response: str, obs: dict):
        """Turn the raw model ``response`` into a move for ``obs['turnType']``.

        Returns a question for "ask", ``**keyword**`` for "guess", a heuristic
        answer for "answer", and ``"Invalid turn type"`` otherwise. History is
        recorded by ``_generate_question`` / ``_analyze_answer`` so entries are
        not appended a second time here.
        """
        keyword = self._parse_keyword(response)
        turn_type = obs.get('turnType')
        if turn_type == "ask":
            return self._generate_question(keyword)
        if turn_type == "guess":
            return f"**{keyword}**"
        if turn_type == "answer":
            return self._analyze_answer(self._last_question(obs), keyword)
        return "Invalid turn type"

    def _analyze_answer(self, question: str, context: str):
        """Answer ``question`` heuristically and record the answer once.

        ``context`` may be a raw response containing ``**keyword**`` or an
        already extracted keyword; in the latter case it is used as-is rather
        than being re-parsed into an empty string.
        """
        question_lower = question.lower()
        keyword = self._parse_keyword(context) or context.strip().lower()
        if "what" in question_lower:
            answer = f"Hint: It's a {keyword}"
        elif "where" in question_lower:
            answer = f"Hint: It's related to {keyword}"
        else:
            answer = "Yes" if "is" in question_lower else "No"
            # Iterate over a snapshot of the keys; the values are rescaled
            # inside _update_probabilities.
            for candidate in list(self.keyword_probabilities):
                self._update_probabilities(candidate, answer)
        self.answer_history.append(answer)
        return answer

    def _update_probabilities(self, keyword: str, answer: str):
        """Shift ``keyword``'s weight by the answer, clamp to [0, 1], renormalise."""
        weights = self.keyword_probabilities
        if keyword in weights:
            weights[keyword] += 0.4 if answer == "Yes" else -0.25
            weights[keyword] = max(0.0, min(1.0, weights[keyword]))

        total_prob = sum(weights.values())
        if total_prob > 0:
            for key in weights:
                weights[key] /= total_prob

    def _suggest_keyword(self):
        """Return the keyword with the highest weight, or '' when none is tracked."""
        return max(self.keyword_probabilities, key=self.keyword_probabilities.get, default="")

    def _refine_question(self, question: str, answers: list):
        """Rephrase ``question`` depending on the most recent answer.

        After a "no" the question is softened, after a "yes" it is made more
        specific. Only the first matching phrase ("is it" before "does it",
        detected case-insensitively) is considered, and the replacement itself
        targets the capitalised form.
        """
        rewrites = {
            "no": (("Is it", "Could it be"), ("Does it", "Might it")),
            "yes": (("Is it", "Is it specifically"), ("Does it", "Does it specifically")),
        }
        if answers:
            for old, new in rewrites.get(answers[-1].lower(), ()):
                if old.lower() in question.lower():
                    return question.replace(old, new)
        return question

    def _generate_question(self, context: str):
        """Generate, refine and record the next question from ``context``.

        If the previous question already appears in ``context`` a fixed
        placeholder is returned and nothing is recorded. Otherwise the question
        is produced with ``_call_llm`` (the Gemma model takes and returns plain
        text), refined against the answer history and appended to
        ``question_history``.
        """
        if self.question_history:
            last_question = self.question_history[-1]
            if last_question in context:
                return "Next logical or creative question based on the answer to the last question..."

        question = self._call_llm(context)
        question = self._refine_question(question, self.answer_history)
        self.question_history.append(question)
        return question

    def _make_decision(self, obs: dict, context: str):
        """Choose a move for ``obs['turnType']`` from ``context`` and tracked state."""
        turn_type = obs.get('turnType')
        if turn_type == "ask":
            return self._generate_question(context)
        elif turn_type == "guess":
            return f"**{self._suggest_keyword()}**"
        elif turn_type == "answer":
            return self._analyze_answer(self._last_question(obs), context)
        return "Invalid turn type"

# Questioner Agent

class GemmaQuestionerAgent(GemmaAgent):
    """Agent playing the Questioner: asks yes-or-no questions and guesses."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Build the prompt from the game so far and the current turn type.

        Previous questions are replayed as model turns and the answers as user
        turns, followed by the instruction for an "ask" or "guess" turn and an
        open model turn for the model to complete.
        """
        self.formatter.reset()
        self.formatter.user("Let's play the game 20 Questions. You are playing the role of the Questioner.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')
        if obs.turnType == 'ask':
            self.formatter.user(f"Please ask a yes-or-no question based on the context in {obs.questions} and {obs.answers} to discover the keyword. The yes-or-no question should not be something like \"Does it rain in the desert\" or \"Are you happy?\". It should be a question that helps you know the category and specifics about the keyword to help you narrow the possibilities of your guess. Surround your question with double asterisks.")
        elif obs.turnType == 'guess':
            self.formatter.user("Now guess the keyword only. Surround your guess with double asterisks.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Extract the question or guess from the raw model ``response``.

        For "ask" the shortest text ending in ``?`` (asterisks removed) is
        returned, or "Is it a place?" when the response has no question mark.
        For "guess" the text between double asterisks is returned.

        Raises:
            ValueError: If ``obs.turnType`` is neither "ask" nor "guess".
        """
        if obs.turnType == 'ask':
            # Raw string: "\?" in a normal literal is an invalid escape sequence.
            match = re.search(r".+?\?", response.replace('*', ''))
            if match is None:
                question = "Is it a place?"
            else:
                question = match.group()
            return question
        elif obs.turnType == 'guess':
            guess = self._parse_keyword(response)
            return guess
        else:
            raise ValueError("Unknown turn type:", obs.turnType)

#Answerer Agent
            
class GemmaAnswererAgent(GemmaAgent):
    """Agent playing the Answerer: replies yes or no about the known keyword."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Build the prompt: role and keyword, the game so far, then the instruction.

        Questions are replayed as user turns and earlier answers as model
        turns, followed by a reminder to answer yes or no and an open model
        turn for the model to complete.
        """
        self.formatter.reset()
        self.formatter.user(f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {obs.keyword} in the category {obs.category}.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='user')
        self.formatter.user(f"The question is about the keyword {obs.keyword} in the category {obs.category}. Give only yes-or-no answers based on the context of the keyword and category.")
        self.formatter.start_model_turn()

    def _parse_response(self, response: str, obs: dict):
        """Map the raw model ``response`` to ``'**yes**'`` or ``'**no**'``.

        ``_parse_keyword`` returns the text between the asterisks without the
        asterisks themselves, so the check looks for a bare "yes". When the
        response has no ``**...**`` span at all, the whole lower-cased response
        is inspected instead of defaulting straight to "no".
        """
        answer = self._parse_keyword(response) or response.lower()
        return '**yes**' if 'yes' in answer else '**no**'


# Instruction text passed as ``system_prompt`` to whichever agent get_agent() creates.
system_prompt = "You are an AI assistant designed to ask and answers yes or no questions to discover a secret keyword. The keyword is a specific person, place, or thing. Try to guess the keyword in every round if you are the questioner. Try to narrow down to a catagory then use questions about the starting letter to further narrow your guess."

# Example (question, answer) pairs passed as ``few_shot_examples`` to whichever
# agent get_agent() creates. They cover a person, several places (city, country,
# river) and a thing, and include letter-based narrowing questions.
few_shot_examples = [
    ("Is it a person?", "Yes"),
    ("Is this person a male?", "Yes"),
    ("Is this person an actor?", "No"),
    ("Is this person a politician?", "Yes"),
    ("Is this person currently alive?", "No"),
    ("Is it a place?", "Yes"),
    ("Is this place a city?", "Yes"),
    ("Is this city in the United States?", "Yes"),
    ("Is this city the capital of a state?", "No"),
    ("Is this city on the east coast?", "Yes"),
    ("Is this place a country?", "Yes"),
    ("Is this country in Europe?", "Yes"),
    ("Is this country a member of the European Union?", "Yes"),
    ("Is it known for its historical landmarks?", "Yes"),
    ("Is it a geographical feature?", "Yes"),
    ("Is it a mountain?", "No"),
    ("Is it a river?", "Yes"),
    ("Is this river in Africa?", "Yes"),
    ("Is this river the longest river in the world?", "Yes"),
    ("Does this river begin with the letter n?", "Yes"),
    ("Is the keyword the Nile River?", "Yes"),
    ("Is it a thing?", "Yes"),
    ("Is this thing man-made?", "Yes"),
    ("Is it used for communication?", "Yes"),
    ("Is it a type of technology?", "Yes"),
    ("Is it smaller than a bread box?", "Yes"),
    ("Does this thing begin with the letter s?", "Yes"),
    ("Is it a type of vehicle?", "No"),
    ("Is it found indoors?", "Yes"),
    ("Is it used for entertainment?", "No")
]


# The one agent instance created so far (None until get_agent builds it).
agent = None


def get_agent(name: str):
    """Return the cached agent, creating it on first use.

    Only one agent is ever built: when none exists yet, ``name`` selects the
    class ('questioner' or 'answerer'). Once an agent exists it is returned
    for every later call, whatever ``name`` is.

    Raises:
        AssertionError: If no agent exists and ``name`` is not a known role.
    """
    global agent

    if agent is None:
        if name == 'questioner':
            agent_cls = GemmaQuestionerAgent
        elif name == 'answerer':
            agent_cls = GemmaAnswererAgent
        else:
            agent_cls = None
        if agent_cls is not None:
            agent = agent_cls(
                device='cuda:0',
                system_prompt=system_prompt,
                few_shot_examples=few_shot_examples,
            )

    assert agent is not None, "Agent not initialized."
    return agent

def agent_fn(obs, cfg):
    """Entry point for the game: return the agent's move for ``obs``.

    "ask" and "guess" turns go to the questioner agent, "answer" turns to the
    answerer agent. If the turn type is unrecognised, or the agent returns
    ``None`` or a response of at most one character, ``"yes"`` is returned so
    that a move is always produced.

    Args:
        obs: Observation exposing ``turnType`` and the fields the agents read.
        cfg: Environment configuration (unused).
    """
    turn_type = obs.turnType
    if turn_type in ("ask", "guess"):
        response = get_agent('questioner')(obs)
    elif turn_type == "answer":
        response = get_agent('answerer')(obs)
    else:
        # Previously left ``response`` unbound and crashed below.
        response = None

    if response is None or len(response) <= 1:
        return "yes"
    return response

# Submission

In [None]:
# Install pigz and pv, which the tar command building the submission archive uses; stdout is discarded.
!apt install pigz pv > /dev/null

In [None]:
# Pack the contents of /kaggle/working/submission together with the Gemma
# 7b-it-quant weights (taken relative to /kaggle/input/) into submission.tar.gz,
# compressing with pigz (--fast) and piping the compressed stream through pv.
!tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ gemma/pytorch/7b-it-quant/2

# Debug

In [None]:
# from kaggle_environments import make

# def simple_agent(obs, cfg):
#     if obs.turnType == "ask":
#         response = "Is it a duck?"
#     elif obs.turnType == "guess":
#         response = "duck"
#     elif obs.turnType == "answer":
#         response = "no"
#     return response


# debug_config = {'episodeSteps': 61,
#                 'actTimeout': 60,      
#                 'runTimeout': 1200,    
#                 'agentTimeout': 3600}  

# env = make(environment="llm_20_questions", configuration=debug_config, debug=True)


# game_output = env.run(agents=[agent_fn, simple_agent, simple_agent, simple_agent])

# env.render(mode="ipython", width=1080, height=700)