In [None]:
# Dependencies
#
# Packages are installed into /kaggle/tmp/lib (-t) rather than the kernel's
# site-packages: that directory is later packed into submission.tar.gz and is
# prepended to sys.path by both the validation cell and main.py.
# NOTE(review): pip applies `-i <index>` to every package on this command line,
# not only to bitsandbytes, which follows it.

!pip install -U \
    -t /kaggle/tmp/lib \
    rigging \
    kaggle \
    transformers \
    accelerate \
    -i https://pypi.org/simple/ bitsandbytes

In [None]:
# Secrets

from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()

# HF_TOKEN authenticates the model download; the Kaggle pair is used by the
# final submission cell.
HF_TOKEN, KAGGLE_KEY, KAGGLE_USERNAME = (
    secrets.get_secret(name) for name in ("HF_TOKEN", "KAGGLE_KEY", "KAGGLE_USERNAME")
)

In [None]:
# Download the model

from huggingface_hub import snapshot_download
from pathlib import Path
import shutil

g_model_path = Path("/kaggle/tmp/model")

# Always start from an empty directory so files left by an earlier download
# never end up in the submission archive.
if g_model_path.exists():
    shutil.rmtree(g_model_path)
g_model_path.mkdir(parents=True)

# Materialise real files (not symlinks into the HF cache) under g_model_path.
snapshot_download(
    token=HF_TOKEN,
    local_dir=g_model_path,
    local_dir_use_symlinks=False,
    repo_id="alokabhishek/Meta-Llama-3-8B-Instruct-bnb-8bit",
)

In [None]:
# Validate model loading

import sys

# Put the packages installed by the first cell ahead of the kernel's own, so
# modules present in /kaggle/tmp/lib are imported from there.
sys.path.insert(0, "/kaggle/tmp/lib")

import rigging as rg

# Smoke test before packaging: build a generator for the downloaded model and
# run one short chat.
# NOTE(review): this pins device_map=cuda:0, but the generator id in main.py
# passes no device_map — confirm the submitted agent also runs on the GPU.
generator = rg.get_generator("transformers!/kaggle/tmp/model,device_map=cuda:0")
chat = generator.chat("Say Hello!").run()

print(chat.last)

In [None]:
%%writefile main.py

import itertools
import os
import sys
import typing as t
from pathlib import Path

# Path fixups

# Fixed Kaggle locations. When the agent directory exists we are presumably
# running inside the competition sandbox from the unpacked submission;
# otherwise we are in the authoring notebook and use /kaggle/tmp.
g_working_path = Path("/kaggle/working")
g_input_path = Path("/kaggle/input")
g_temp_path = Path("/kaggle/tmp")
g_agent_path = Path("/kaggle_simulations/agent/")

# The bundled `lib` directory must be first on sys.path so the imports below
# resolve against it; `model` holds the weights next to it.
if g_agent_path.exists():
    g_model_path = g_agent_path / "model"
    sys.path.insert(0, str(g_agent_path / "lib"))
else:
    g_model_path = g_temp_path / "model"
    sys.path.insert(0, str(g_temp_path / "lib"))

import rigging as rg  # noqa
from pydantic import BaseModel, field_validator, StringConstraints  # noqa

# Constants

# rigging generator identifier for the local transformers model. Format as used
# in this notebook: "<provider>!<model path>,<key>=<value>,..."; the key/value
# options are presumably forwarded to the transformers generator.
# NOTE(review): unlike the notebook's validation cell, no device_map is given.
g_generator_id = f"transformers!{g_model_path},trust_remote_code=True,max_tokens=1024,temperature=1.0,top_k=256"

# Types

# `str` whose leading/trailing whitespace pydantic strips during validation.
str_strip = t.Annotated[str, StringConstraints(strip_whitespace=True)]

class Observation(BaseModel):
    """Typed view of the per-turn observation handed to ``agent_fn``.

    Only the fields declared below are read; any extra keys in the incoming
    observation are dropped under pydantic's default ``extra="ignore"`` config.
    """

    step: int
    role: t.Literal["guesser", "answerer"]
    turnType: t.Literal["ask", "answer", "guess"]
    keyword: str
    category: str
    questions: list[str]
    answers: list[str]
    guesses: list[str]

    @property
    def empty(self) -> bool:
        """True when no questions, answers or guesses have been recorded yet."""
        # `items`, not `t`, so the module's `typing as t` alias is not shadowed.
        return all(len(items) == 0 for items in (self.questions, self.answers, self.guesses))

    def get_history(self) -> t.Iterator[tuple[str, str, str]]:
        """Yield ``(question, answer, guess)`` per turn.

        The three lists may have different lengths mid-turn; missing entries are
        padded with ``"[none]"``.
        """
        return itertools.zip_longest(self.questions, self.answers, self.guesses, fillvalue="[none]")

    def get_history_as_xml(self, *, skip_guesses: bool = False) -> str:
        """Render the game history as ``<turn-N>`` blocks for inclusion in a prompt.

        Args:
            skip_guesses: omit the ``Guess:`` line from every turn.

        Returns:
            The rendered turns joined by newlines, or ``"none yet."`` when
            nothing has been recorded.
        """
        if self.empty:
            return "none yet."

        turns = []
        for i, (question, answer, guess) in enumerate(self.get_history()):
            # Lines are assembled explicitly so no source-code indentation leaks
            # into the prompt and no blank line is left when guesses are skipped.
            lines = [f"<turn-{i}>", f"Question: {question}", f"Answer: {answer}"]
            if not skip_guesses:
                lines.append(f"Guess: {guess}")
            lines.append(f"</turn-{i}>")
            turns.append("\n".join(lines))
        return "\n".join(turns)


class Answer(rg.Model):
    """The answerer's reply to a yes/no question, normalised to "yes", "no" or "maybe"."""

    content: t.Literal["yes", "no", "maybe"]

    @field_validator("content", mode="before")
    @classmethod
    def validate_content(cls, v: str) -> str:
        """Map free-form model output onto one of the allowed literals.

        Matching is a case-insensitive prefix test after stripping surrounding
        whitespace, so e.g. ``"\\n Yes, it is."`` becomes ``"yes"``.

        Raises:
            ValueError: if ``v`` is not a string or starts with none of the
                allowed words.
        """
        # Non-strings would otherwise raise AttributeError on `.lower()`, which
        # pydantic propagates as-is instead of reporting a validation error.
        if isinstance(v, str):
            normalized = v.strip().lower()
            for valid in ("yes", "no", "maybe"):
                if normalized.startswith(valid):
                    return valid
        raise ValueError("Invalid answer, must be one of 'yes', 'no', 'maybe'")

    @classmethod
    def xml_example(cls) -> str:
        """Return the XML snippet shown to the model as the expected answer format."""
        return f"{cls.xml_start_tag()}**yes/no/maybe**{cls.xml_end_tag()}"


class Question(rg.Model):
    """A single question produced by the asker; surrounding whitespace is stripped."""

    content: str_strip

    @classmethod
    def xml_example(cls) -> str:
        """Return a pretty-printed XML example of this model for use in prompts."""
        placeholder = Question(content="**question**")
        return placeholder.to_pretty_xml()


class Guess(rg.Model):
    """A single keyword guess; surrounding whitespace is stripped."""

    content: str_strip

    @classmethod
    def xml_example(cls) -> str:
        """Return a pretty-printed XML example of this model for use in prompts."""
        placeholder = Guess(content="**thing/place/person**")
        return placeholder.to_pretty_xml()


# Functions


def ask(base: rg.PendingChat, observation: Observation) -> str:
    """Ask the model for the next yes/no question given the game so far.

    Guesses are left out of the history shown to the model.

    Args:
        base: pending chat seeded with the player preamble (built in agent_fn);
            it is forked, not modified.
        observation: current game state.

    Returns:
        The question text parsed from the model's ``Question`` element.
    """
    history_xml = observation.get_history_as_xml(skip_guesses=True)
    prompt = f"""\
            You are currently asking the next question.

            <game-history>
            {history_xml}
            </game-history>

            Based on the history above, ask the next most useful yes/no
            question and place it in the following format:
            {Question.xml_example()}

            - Your response should be a focused question which will gather the most information
            - Start general with your questions
            - Always try to bisect the remaining search space
            - Pay attention to previous questions and answers

            What is your next question?
            """

    # Keep generating until the reply parses as a Question (with recovery attempts).
    pending = base.fork(prompt).until_parsed_as(Question, attempt_recovery=True)
    chat = pending.run()
    parsed = chat.last.parse(Question)
    return parsed.content


def answer(base: rg.PendingChat, observation: Observation) -> t.Literal["yes", "no", "maybe"]:
    """Answer the most recent question about the secret keyword.

    Args:
        base: pending chat seeded with the player preamble (built in agent_fn);
            it is forked, not modified.
        observation: current game state; ``keyword`` and ``questions`` are read.

    Returns:
        "yes", "no" or "maybe" as normalised by ``Answer``; "maybe" when no
        question has been asked yet.
    """
    # Nothing to answer: reply with the same neutral value agent_fn uses when
    # the keyword is missing, instead of failing with IndexError below.
    if not observation.questions:
        return "maybe"

    last_question = observation.questions[-1]
    chat = (
        base.fork(
            f"""\
            The secret word for this game is "{observation.keyword}".

            You are currently answering a question about the word above.

            The next question is "{last_question}".

            Answer the yes/no question above and place it in the following format:
            {Answer.xml_example()}

            - Your response should be accurate given the keyword above
            - Always answer with "yes", "no", or "maybe"

            What is the answer?
            """
        )
        .until_parsed_as(Answer, attempt_recovery=True)
        .run()
    )
    return chat.last.parse(Answer).content


def guess(base: rg.PendingChat, observation: Observation) -> str:
    """Ask the model for its next best guess of the keyword.

    The full history, including earlier guesses, is shown so the model can
    avoid repeating itself.

    Args:
        base: pending chat seeded with the player preamble (built in agent_fn);
            it is forked, not modified.
        observation: current game state.

    Returns:
        The guess text parsed from the model's ``Guess`` element.
    """
    # Bound to `chat` (previously `pending`), which the return statement reads;
    # the mismatch raised NameError on every guess turn.
    chat = (
        base.fork(
            f"""\
            You are currently making an informed guess of the keyword.

            <game-history>
            {observation.get_history_as_xml()}
            </game-history>

            Based on the history above, produce a single next best guess
            for the keyword and place it in the following format:
            {Guess.xml_example()}

            - Avoid repeat guesses based on the history above
            - The guess should be a specific person, place, or thing

            What is your guess?
            """
        )
        .until_parsed_as(Guess, attempt_recovery=True)
        .run()
    )
    return chat.last.parse(Guess).content

# Generator

# Created once at import time and shared by every agent_fn call; each call
# starts a fresh chat from it.
generator = rg.get_generator(g_generator_id)

# Entrypoint

def agent_fn(obs: t.Any, _: t.Any) -> str:
    observation = Observation(**obs.__dict__)
    
    try:
        base = generator.chat("""\
            You are a talented player of the 20 questions game. You are accurate, focused, and
            structured in your approach. You will create useful questions, make guesses, or answer
            questions about a keyword.
            
            """
        )
    
        match observation.turnType:
            case "ask":
                return ask(base, observation)
            case "answer":
                if not observation.keyword:
                    return "maybe"
                return answer(base, observation)
            case "guess":
                return guess(base, observation)
            case _:
                raise ValueError("Unknown turn type")
    except Exception as e:
        print(str(e), file=sys.stderr)
        raise


In [None]:
# Non-interactive install: -y pre-answers apt's confirmation prompt, which a
# notebook cell cannot answer, and apt-get is the apt front-end meant for scripts.
# NOTE(review): pv is not used by any visible cell.
!apt-get install -y pigz pv

In [None]:
# Make the submission.tar.gz
#
# Archive root holds model/ and lib/ (from /kaggle/tmp) plus main.py (from
# /kaggle/working) — the same names main.py looks for under its agent
# directory. --dereference stores the files that symlinks point to instead of
# the links themselves; `pigz --fast` favours speed over compression ratio.

!tar --use-compress-program='pigz --fast' \
    -cf submission.tar.gz \
    --dereference \
    -C /kaggle/tmp model lib \
    -C /kaggle/working main.py

In [None]:
# Push directly to Kaggle
#
# The kaggle CLI reads KAGGLE_USERNAME / KAGGLE_KEY from its environment. Setting
# them via os.environ (inherited by the `!` subprocess) keeps the secrets out of
# the shell command line, where they would be visible in the process list and
# would break on shell metacharacters.

import os

os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME
os.environ["KAGGLE_KEY"] = KAGGLE_KEY

!kaggle competitions submit -c llm-20-questions -f submission.tar.gz -m "Updates"