In [None]:
# Secrets (optional)

from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()

HF_TOKEN: str | None  = None
KAGGLE_KEY: str | None = None
KAGGLE_USERNAME: str | None = None
    
try:
    HF_TOKEN = secrets.get_secret("HF_TOKEN")
    KAGGLE_KEY = secrets.get_secret("KAGGLE_KEY")
    KAGGLE_USERNAME = secrets.get_secret("KAGGLE_USERNAME")
except:
    pass

In [None]:
# Dependencies (uv for speed)

# Install uv itself with pip; the two installs below go through `uv pip`.
!pip install uv

# Agent-side libraries, installed into /kaggle/tmp/lib. main.py puts its copy
# of this "lib" directory at the front of sys.path before importing rigging.
# NOTE(review): kaggle is unpinned here — consider pinning for reproducible builds.
!uv pip install -U \
    --python $(which python) \
    --target /kaggle/tmp/lib \
    rigging==1.3.0 \
    kaggle

# vLLM gets its own target directory (/kaggle/tmp/srvlib). main.py hands the
# "srvlib" directory to the vLLM server subprocess via PYTHONPATH only.
# NOTE(review): vllm is unpinned here — consider pinning for reproducible builds.
!uv pip install -U \
    --python $(which python) \
    --target /kaggle/tmp/srvlib \
    vllm

In [None]:
# Download the model

from huggingface_hub import snapshot_download
from pathlib import Path
import shutil

# Destination directory for the downloaded model snapshot.
g_model_path = Path("/kaggle/tmp/model")

# Start from an empty directory so nothing from an earlier run is left behind.
if g_model_path.exists():
    shutil.rmtree(g_model_path)
g_model_path.mkdir(parents=True)

# HF_TOKEN only exists if the Secrets cell was run; otherwise no token is sent.
hf_token = globals().get("HF_TOKEN")

snapshot_download(
    repo_id="solidrust/Meta-Llama-3-8B-Instruct-hf-AWQ",
    local_dir=g_model_path,
    local_dir_use_symlinks=False,
    ignore_patterns="original*",
    token=hf_token,
)

In [None]:
%%writefile util.py

# Helpers for starting the vLLM server

import subprocess
import os
import socket
import time

def check_port(port: int) -> bool:
    """Report whether a TCP connection to localhost:``port`` succeeds.

    Args:
        port: TCP port number to probe on localhost.

    Returns:
        True if the connection attempt succeeded, False if it was refused,
        timed out (1 second), or raised a socket error.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.settimeout(1)
            # connect_ex signals failure through its return code; 0 means connected.
            return probe.connect_ex(('localhost', port)) == 0
    except socket.error:
        return False

def run_and_wait_for_port(
    cmd: list[str], port: int, env: dict[str, str] | None, timeout: int = 60
) -> subprocess.Popen:
    
    if check_port(port):
        raise ValueError(f"Port {port} is already open")
        
    popen = subprocess.Popen(
        cmd,
        env={**os.environ, **(env or {})},
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL
    )
    
    start_time = time.time()
    while time.time() - start_time < timeout:
        if check_port(port):
            return popen
        time.sleep(1)
    
    popen.terminate()
    raise Exception(f"Process did not open port {port} within {timeout} seconds.")

In [None]:
# Validate vLLM startup (optional - commented out for faster builds)

# import importlib
# from pathlib import Path
# import util

# util = importlib.reload(util)

# g_srvlib_path = Path("/kaggle/tmp/srvlib")
# assert g_srvlib_path.exists()

# g_model_path = Path("/kaggle/tmp/model")
# assert g_model_path.exists()

# g_vllm_port = 9999
# g_vllm_model_name = "custom"

# # Start vLLM server

# vllm = util.run_and_wait_for_port([
#     "python", "-m",
#     "vllm.entrypoints.openai.api_server",
#     "--enforce-eager",
#     "--model", str(g_model_path),
#     "--port", str(g_vllm_port),
#     "--served-model-name", g_vllm_model_name
# ], g_vllm_port, {"PYTHONPATH": str(g_srvlib_path)})

# print("vLLM Started")

In [None]:
# Connect with Rigging (optional - commented out for faster builds)

# import sys
# import logging

# sys.path.insert(0, "/kaggle/tmp/lib")

# logging.getLogger("LiteLLM").setLevel(logging.WARNING)

# import rigging as rg

# generator = rg.get_generator(
#     f"openai/{g_vllm_model_name}," \
#     f"api_base=http://localhost:{g_vllm_port}/v1," \
#     "api_key=sk-1234," \
#     "stop=<|eot_id|>" # Llama requires some hand holding
# )
# chat = generator.chat("Say Hello!").run()

# print(chat.last)

In [None]:
%%writefile main.py

# Main agent file

import itertools
import os
import sys
import typing as t
from pathlib import Path
import logging

# Path fixups

g_working_path = Path('/kaggle/working')
g_input_path = Path('/kaggle/input')
g_temp_path = Path("/kaggle/tmp")
g_agent_path = Path("/kaggle_simulations/agent/")

# Use the agent directory as the root when it exists (presumably the
# competition runtime — confirm); otherwise fall back to the temp directory.
_g_bundle_root = g_agent_path if g_agent_path.exists() else g_temp_path

g_model_path = _g_bundle_root / "model"
g_srvlib_path = _g_bundle_root / "srvlib"
g_lib_path = _g_bundle_root / "lib"

# Must happen before the rigging/pydantic imports further down.
sys.path.insert(0, str(g_lib_path))

# Logging noise

logging.getLogger("LiteLLM").setLevel(logging.WARNING)

# Fixed imports

import util # noqa
import rigging as rg  # noqa
from pydantic import BaseModel, field_validator, StringConstraints  # noqa

# Constants

# Port and served-model name shared by the vLLM launch and the generator id.
g_vllm_port = 9999
g_vllm_model_name = "custom"

# Comma-separated generator id: model identifier followed by key=value options.
g_generator_id = ",".join(
    [
        f"openai/{g_vllm_model_name}",
        f"api_base=http://localhost:{g_vllm_port}/v1",
        "api_key=sk-1234",  # presumably a placeholder for the local server — confirm
        "stop=<|eot_id|>",  # Llama requires some hand holding
    ]
)

# Types

# String type whose surrounding whitespace is stripped during validation.
str_strip = t.Annotated[str, StringConstraints(strip_whitespace=True)]

# Validated view of the observation object passed to the agent entrypoint.
class Observation(BaseModel):
    step: int
    role: t.Literal["guesser", "answerer"]
    turnType: t.Literal["ask", "answer", "guess"]
    keyword: str
    category: str
    questions: list[str]
    answers: list[str]
    guesses: list[str]

    @property
    def empty(self) -> bool:
        """True while questions, answers and guesses are all empty."""
        return not (self.questions or self.answers or self.guesses)

    def get_history(self) -> t.Iterator[tuple[str, str, str]]:
        """Iterate (question, answer, guess) per turn; missing entries become "[none]"."""
        return itertools.zip_longest(
            self.questions,
            self.answers,
            self.guesses,
            fillvalue="[none]",
        )

    def get_history_as_xml(self, *, include_guesses: bool = False) -> str:
        """Render the history as one <turn-N> element per turn.

        Args:
            include_guesses: When True, each turn also carries a "Guess:" line.

        Returns:
            The newline-joined turn elements, or "none yet." when there is no
            history at all.
        """
        if self.empty:
            return "none yet."

        turns = []
        for i, (question, answer, guess) in enumerate(self.get_history()):
            turns.append(
                f"""\
            <turn-{i}>
            Question: {question}
            Answer: {answer}
            {'Guess: ' + guess if include_guesses else ''}
            </turn-{i}>
            """
            )
        return "\n".join(turns)


# Parsed yes/no reply produced by the answerer.
class Answer(rg.Model):
    content: t.Literal["yes", "no"]

    @field_validator("content", mode="before")
    @classmethod
    def validate_content(cls, v: t.Any) -> str:
        """Normalize the raw reply to exactly "yes" or "no".

        Surrounding whitespace is ignored and matching is case-insensitive on
        the leading word, so " Yes, it is" becomes "yes". This mirrors the
        whitespace stripping Question and Guess get from ``str_strip``.

        Raises:
            ValueError: The value is not a string or starts with neither word.
        """
        if isinstance(v, str):
            normalized = v.strip().lower()
            for valid in ("yes", "no"):
                if normalized.startswith(valid):
                    return valid
        # Non-string input also ends up here, so it is reported as a
        # validation failure instead of an AttributeError.
        raise ValueError("Invalid answer, must be 'yes' or 'no'")

    @classmethod
    def xml_example(cls) -> str:
        """Return the tag pair wrapped around a yes/no placeholder for use in prompts."""
        return f"{cls.xml_start_tag()}**yes/no**{cls.xml_end_tag()}"


# Parsed question text produced when the agent asks.
class Question(rg.Model):
    content: str_strip

    @classmethod
    def xml_example(cls) -> str:
        """Return the pretty XML of a placeholder question for use in prompts."""
        placeholder = Question(content="**question**")
        return placeholder.to_pretty_xml()


# Parsed keyword guess produced when the agent guesses.
class Guess(rg.Model):
    content: str_strip

    @classmethod
    def xml_example(cls) -> str:
        """Return the pretty XML of a placeholder guess for use in prompts."""
        placeholder = Guess(content="**thing/place/person**")
        return placeholder.to_pretty_xml()


# Functions


def ask(base: rg.PendingChat, observation: Observation) -> str:
    """Produce the next yes/no question for the current game.

    Args:
        base: Chat to fork for this request.
        observation: Current game state; its history is embedded in the prompt.

    Returns:
        The question text parsed from the model's reply, or a fixed opening
        question on step 0.
    """
    is_first_step = observation.step == 0
    if is_first_step:
        # override first question until keyword bug is fixed.
        return "Are we playing 20 questions?"

    prompt = f"""\
            You are currently asking the next question.

            <game-history>
            {observation.get_history_as_xml()}
            </game-history>

            Based on the history above, ask the next most useful yes/no
            question and place it in the following format:
            {Question.xml_example()}

            - Your response should be a focused question which will gather the most information
            - Start general with your questions
            - Always try to bisect the remaining search space
            - Pay attention to previous questions and answers

            Before you begin, document your analysis of the game history if available,
            then write your question.
            """

    # Run the chat with until_parsed_as(Question) applied to the forked prompt.
    pending = base.fork(prompt).until_parsed_as(Question, attempt_recovery=True)
    chat = pending.run()
    question = chat.last.parse(Question)
    return question.content


def answer(base: rg.PendingChat, observation: Observation) -> t.Literal["yes", "no"]:
    """Answer the most recent question about the secret keyword.

    Args:
        base: Chat to fork for this request.
        observation: Current game state; supplies the keyword, its category and
            the list of questions.

    Returns:
        "yes" or "no" as parsed from the model's reply; "yes" when the
        observation carries no keyword.
    """
    if not observation.keyword:
        print("Keyword wasn't provided to answerer", file=sys.stderr)
        return "yes" # override until keyword bug is fixed.

    last_question = observation.questions[-1]
    prompt = f"""\
            The secret word for this game is "{observation.keyword}" [{observation.category}]

            You are currently answering a question about the word above.

            The next question is "{last_question}".

            Answer the yes/no question above and place it in the following format:
            {Answer.xml_example()}

            - Your response should be accurate given the keyword above
            - Always answer with "yes" or "no"

            What is the answer?
            """

    # Run the chat with until_parsed_as(Answer) applied to the forked prompt.
    pending = base.fork(prompt).until_parsed_as(Answer, attempt_recovery=True)
    chat = pending.run()
    reply = chat.last.parse(Answer)
    return reply.content


def guess(base: rg.PendingChat, observation: Observation) -> str:
    """Produce the next keyword guess for the current game.

    Args:
        base: Chat to fork for this request.
        observation: Current game state; its history is embedded in the prompt.

    Returns:
        The guess text parsed from the model's reply.
    """
    # The prompt asks the model to avoid repeating guesses, so the rendered
    # history has to include the guesses made so far.
    history = observation.get_history_as_xml(include_guesses=True)

    chat = (
        base.fork(
            f"""\
            You are currently making an informed guess of the keyword.

            <game-history>
            {history}
            </game-history>

            Based on the history above, produce a single next best guess
            for the keyword and place it in the following format:
            {Guess.xml_example()}

            - Avoid repeat guesses based on the history above
            - The guess should be a specific person, place, or thing

            Before you begin, document your analysis of the game history if available,
            then write your guess.
            """
        )
        .until_parsed_as(Guess, attempt_recovery=True)
        .run()
    )

    return chat.last.parse(Guess).content

# vLLM and Generator

# Command line for the vLLM OpenAI-compatible API server.
vllm_cmd = [
    "python", "-m",
    "vllm.entrypoints.openai.api_server",
    "--enforce-eager",
    "--model", str(g_model_path),
    "--port", str(g_vllm_port),
    "--served-model-name", g_vllm_model_name,
]

# Only the server subprocess gets the srvlib directory on its PYTHONPATH.
vllm_env = {"PYTHONPATH": str(g_srvlib_path)}

# Blocks at import time until the server accepts connections on the port.
vllm = util.run_and_wait_for_port(vllm_cmd, g_vllm_port, vllm_env)

print("vLLM Started")

generator = rg.get_generator(g_generator_id)

# Shared starting chat; ask/answer/guess each fork it with their own prompt.
system_prompt = """\
You are a talented player of the 20 questions game. You are accurate, focused, and
structured in your approach. You will create useful questions, make guesses, or answer
questions about a keyword.

"""

base = generator.chat(system_prompt)

# Entrypoint

def agent_fn(obs: t.Any, _: t.Any) -> str:
    observation = Observation(**obs.__dict__)
    
    try:
        match observation.turnType:
            case "ask":
                return ask(base, observation)
            case "answer":
                return answer(base, observation)
            case "guess":
                return guess(base, observation)
            case _:
                raise ValueError("Unknown turn type")
    except Exception as e:
        print(str(e), file=sys.stderr)
        raise


In [None]:
# pigz is the compressor used by the packaging cell's tar command.
# -y answers the confirmation prompt, which cannot be answered from a notebook
# cell; apt-get is the interface intended for non-interactive use.
!apt-get install -y pigz pv

In [None]:
# Build submission.tar.gz, compressed with pigz in its fastest mode.
# The first -C adds model, lib and srvlib from /kaggle/tmp; the second adds
# main.py and util.py from /kaggle/working. All five sit at the archive root,
# matching the "model"/"lib"/"srvlib" names main.py looks up under its agent path.
# --dereference stores the files symlinks point to rather than the links.
!tar --use-compress-program='pigz --fast' \
    -cf submission.tar.gz \
    --dereference \
    -C /kaggle/tmp model lib srvlib \
    -C /kaggle/working main.py util.py

In [None]:
# Submit the archive to the llm-20-questions competition. The credentials are
# interpolated from the Python variables set in the Secrets cell.
# NOTE(review): those variables remain None when the secrets could not be
# read, in which case the text "None" would presumably be passed as the
# credential — confirm both secrets are set before running this cell.
!KAGGLE_USERNAME={KAGGLE_USERNAME} \
 KAGGLE_KEY={KAGGLE_KEY} \
 kaggle competitions submit -c llm-20-questions -f submission.tar.gz -m "Updates"