In [None]:
%%bash
# Create the folder that receives the agent's submission files
# (-p: no error if it already exists).
mkdir -p /kaggle/working/submission


In [None]:
%%writefile submission/main.py
# Setup
import os
import sys

# **IMPORTANT:** Set up your system path like this to make your code work
# both in notebooks and in the simulations environment.
# If the simulation agent directory exists, its `lib` subfolder is put first
# on sys.path; otherwise the notebook's working submission `lib` folder is used.
KAGGLE_AGENT_PATH = "/kaggle_simulations/agent/"
if os.path.exists(KAGGLE_AGENT_PATH):
    sys.path.insert(0, os.path.join(KAGGLE_AGENT_PATH, 'lib'))
else:
    sys.path.insert(0, "/kaggle/working/submission/lib")

# These imports must stay below the sys.path change above so that packages
# found under the inserted `lib` directory (if any) take precedence.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Turn off the memory-efficient and flash scaled-dot-product-attention backends.
# Background: https://github.com/Lightning-AI/litgpt/issues/327
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# Location of the model weights: inside the agent directory when it exists,
# otherwise under the notebook's /kaggle/input tree.
if os.path.exists(KAGGLE_AGENT_PATH):
    WEIGHTS_PATH = os.path.join(KAGGLE_AGENT_PATH, "llama-3/transformers/8b-chat-hf/1")
else:
    WEIGHTS_PATH = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"

# Prompt Formatting
import itertools
from typing import Iterable


class Llama3Formatter:
    """Builds a chat transcript as a list of ``{'role', 'content'}`` dicts.

    The transcript optionally opens with a system message followed by
    few-shot example turns; :meth:`reset` restores exactly that opening
    state. Every mutating method returns ``self`` so calls can be chained.
    """

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        """
        Args:
            system_prompt: Content of the leading system message, or ``None``
                to start without one.
            few_shot_examples: Example turns that alternate user/assistant,
                beginning with the user, or ``None`` for no examples.
        """
        self._system_prompt = system_prompt
        # reset() replays the examples on every call, so take a snapshot:
        # a generator or other one-shot iterable would otherwise be empty
        # from the second reset onwards.
        self._few_shot_examples = (
            None if few_shot_examples is None else list(few_shot_examples)
        )
        self.reset()

    def get_dict(self):
        """Return the current transcript (the live internal list of message dicts)."""
        return self._state

    def user(self, prompt):
        """Append a message with role ``'user'``."""
        self._state.append({'role': 'user', 'content': prompt})
        return self

    def model(self, prompt):
        """Append a message with role ``'assistant'``."""
        self._state.append({'role': 'assistant', 'content': prompt})
        return self

    def system(self, prompt):
        """Append a message with role ``'system'``."""
        self._state.append({'role': 'system', 'content': prompt})
        return self

    def reset(self):
        """Discard the transcript and rebuild the system prompt and few-shot turns."""
        self._state = []
        if self._system_prompt is not None:
            self.system(self._system_prompt)
        if self._few_shot_examples is not None:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append ``turns`` with alternating roles.

        Args:
            turns: Message contents in conversation order.
            start_agent: ``'model'`` to give the first turn the assistant
                role; any other value starts with the user.
        """
        formatters = [self.model, self.user] if start_agent == 'model' else [self.user, self.model]
        formatters = itertools.cycle(formatters)
        for fmt, turn in zip(formatters, turns):
            fmt(turn)
        return self


# Agent Definitions
import re

class Llama3Agent:
    """Base class for the 20 Questions agents.

    Loads the tokenizer and model from ``WEIGHTS_PATH`` and drives one turn:
    build the chat transcript, sample a reply, parse it. Subclasses supply
    ``_start_session`` (fill the transcript from the observation) and
    ``_parse_response`` (turn the raw reply into the value to return).
    """

    # End-of-turn marker of the Llama 3 chat format; generation must stop on
    # it as well as on the tokenizer's regular EOS token.
    EOT_TOKEN = "<|eot_id|>"

    def __init__(self, system_prompt=None, few_shot_examples=None):
        """
        Args:
            system_prompt: Optional system message passed to the formatter.
            few_shot_examples: Optional example turns passed to the formatter.
        """
        self.formatter = Llama3Formatter(system_prompt=system_prompt, few_shot_examples=few_shot_examples)
        self.tokenizer = AutoTokenizer.from_pretrained(WEIGHTS_PATH)
        self.terminators = self._terminator_ids(self.tokenizer)

        ### Load original model
        self.model = AutoModelForCausalLM.from_pretrained(
            WEIGHTS_PATH,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )

    @classmethod
    def _terminator_ids(cls, tokenizer):
        """Return the token ids on which generation should stop.

        Ids the tokenizer cannot resolve (``None``) are dropped so they are
        never handed to ``generate``.
        """
        candidates = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids(cls.EOT_TOKEN),
        ]
        return [token_id for token_id in candidates if token_id is not None]

    def __call__(self, obs, *args):
        """Run one turn for observation ``obs`` and return the parsed reply."""
        self._start_session(obs)
        prompt = self.formatter.get_dict()
        response = self._call_llm(prompt)
        response = self._parse_response(response, obs)
        return response

    def _start_session(self, obs: dict):
        """Populate ``self.formatter`` for this turn (implemented by subclasses)."""
        raise NotImplementedError

    def _call_llm(self, prompt, max_new_tokens=32):
        """Sample a reply for the chat transcript ``prompt``.

        Args:
            prompt: List of ``{'role', 'content'}`` message dicts.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded text of the newly generated tokens only.
        """
        input_ids = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)
        outputs = self.model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9
        )
        # The output echoes the prompt tokens first; slice them off before decoding.
        response = self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

        return response

    def _parse_keyword(self, response: str):
        """Return the lower-cased text of the first ``**...**`` span, or ``''`` if none."""
        match = re.search(r"(?<=\*\*)([^*]+)(?=\*\*)", response)
        if match is None:
            keyword = ''
        else:
            keyword = match.group().lower()
        return keyword

    def _parse_response(self, response: str, obs: dict):
        """Convert the raw model reply into the turn's result (implemented by subclasses)."""
        raise NotImplementedError


def interleave_unequal(x, y):
    """Alternate the items of ``x`` and ``y`` into one list.

    The sequences may differ in length; once the shorter one runs out the
    remaining items of the longer one follow in order. Items that are
    ``None`` are left out of the result.
    """
    merged = []
    for pair in itertools.zip_longest(x, y):
        merged.extend(item for item in pair if item is not None)
    return merged


class Llama3QuestionerAgent(Llama3Agent):
    """Questioner that binary-searches the first two letters of the keyword.

    Phases, tracked on the instance across calls:
      1. category question ("Is it a place?"),
      2. binary search over the first letter (``first_char_range``),
      3. binary search over the second letter (``second_char_range``),
      4. confirmation of the two-letter prefix (``final_guess``).

    Letter ranges are inclusive ``(low, high)`` index pairs with 0 -> 'A'
    and 25 -> 'Z'.

    NOTE(review): the yes/no used to narrow a range is parsed from this
    agent's own model reply to the question appended in ``_start_session``,
    not read from ``obs.answers`` directly — confirm this is intended.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.category_determined = False
        self.is_place = False
        self.first_char_range = (0, 25)  # A to Z in terms of index
        self.second_char_range = None
        self.final_guess = None

    @staticmethod
    def _letter(index):
        """Map a 0-based alphabet index to its upper-case letter (0 -> 'A')."""
        return chr(65 + index)

    @staticmethod
    def _split_index(char_range):
        """Return the pivot ``p`` for a "before letter p" question.

        The midpoint is rounded up so that ``low < p <= high`` whenever the
        range holds more than one letter; both answers then strictly shrink
        the range.
        """
        low, high = char_range
        return (low + high + 1) // 2

    @classmethod
    def _narrow(cls, char_range, answer):
        """Shrink ``char_range`` according to the answer to "before pivot?".

        'yes' keeps the letters strictly below the pivot; anything else keeps
        the pivot and the letters above it. A range that is already a single
        letter is returned unchanged.
        """
        low, high = char_range
        if low >= high:
            return char_range
        pivot = cls._split_index(char_range)
        if answer == 'yes':
            return (low, pivot - 1)
        return (pivot, high)

    def _start_session(self, obs):
        """Rebuild the transcript and append the question for the current phase."""
        self.formatter.reset()
        self.formatter.user("Let's play 20 Questions. You are playing the role of the Questioner.")
        turns = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(turns, start_agent='model')

        if not self.category_determined:
            self.formatter.user("Is it a place?")
        elif self.second_char_range is None:
            mid_char = self._letter(self._split_index(self.first_char_range))
            self.formatter.user(f"Does the keyword start with a letter before {mid_char}?")
        elif self.final_guess is None:
            mid_char = self._letter(self._split_index(self.second_char_range))
            self.formatter.user(f"Does the second letter of the keyword come before {mid_char}?")
        else:
            self.formatter.user(f"Is the keyword **{self.final_guess}**?")

    def _parse_response(self, response: str, obs: dict):
        """Advance the phase state from the model reply and return the next text to emit."""
        answer = self._parse_keyword(response)

        if not self.category_determined:
            self.is_place = (answer == 'yes')
            self.category_determined = True
            return "Is it a place?"

        if self.second_char_range is None:
            # Phase 2: narrow the first letter.
            self.first_char_range = self._narrow(self.first_char_range, answer)
            if self.first_char_range[0] == self.first_char_range[1]:
                self.second_char_range = (0, 25)  # Reset for second character
                return f"Does the keyword start with {self._letter(self.first_char_range[0])}?"
            mid_char = self._letter(self._split_index(self.first_char_range))
            return f"Does the keyword start with a letter before {mid_char}?"

        if self.final_guess is None:
            # Phase 3: narrow the second letter.
            self.second_char_range = self._narrow(self.second_char_range, answer)
            if self.second_char_range[0] == self.second_char_range[1]:
                first_char = self._letter(self.first_char_range[0])
                second_char = self._letter(self.second_char_range[0])
                self.final_guess = first_char + second_char
                return f"Does the keyword start with {first_char}{second_char}?"
            mid_char = self._letter(self._split_index(self.second_char_range))
            return f"Does the second letter of the keyword come before {mid_char}?"

        # Phase 4: the two-letter prefix was put to the model for confirmation.
        if answer == 'yes':
            return f"The keyword is **{self.final_guess}**."
        self.final_guess = None
        return "Let's continue guessing."


class Llama3AnswererAgent(Llama3Agent):
    """Answerer: replies 'yes' or 'no' to the questions asked about ``obs.keyword``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _start_session(self, obs):
        """Rebuild the transcript: role intro, question/answer history, answer instruction."""
        self.formatter.reset()

        keyword, category = obs.keyword, obs.category
        intro = f"Let's play 20 Questions. You are playing the role of the Answerer. The keyword is {keyword} in the category {category}."
        self.formatter.user(intro)

        # Questions take the user role and earlier answers the assistant role.
        history = interleave_unequal(obs.questions, obs.answers)
        self.formatter.apply_turns(history, start_agent='user')

        instruction = f"The question is about the keyword {keyword} in the category {category}. Give yes-or-no answer and surround your answer with double asterisks, like **yes** or **no**."
        self.formatter.user(instruction)

    def _parse_response(self, response: str, obs: dict):
        """Return 'yes' when the first ``**...**`` span contains 'yes', otherwise 'no'."""
        parsed = self._parse_keyword(response)
        if 'yes' in parsed:
            return 'yes'
        return 'no'


# Agent Creation
# System message shared by both agent roles.
system_prompt = (
    "You are a very smart AI assistant designed to play the 20 Questions game. "
    "In this game, the Answerer thinks of a keyword and responds to yes-or-no "
    "questions by the Questioner. The keyword is a specific place or thing."
)

# Example game used as few-shot context; consecutive entries alternate
# between the user and assistant roles, beginning with the user.
few_shot_examples = [
    "Let's play 20 Questions. You are playing the role of the Questioner. Please ask your first question.",
    "Is the category of the keyword place?",
    "**no**",
    "Is it a food?",
    "**yes** Now guess the keyword in the category things.",
    "**Veggie Burger**",
    "Correct.",
]


# **IMPORTANT:** Define agent as a global so you only have to load
# the agent you need. Loading both will likely lead to OOM.
agent = None


def get_agent(name: str):
    """Return the process-wide agent, creating it on first use.

    An agent is only constructed while the global ``agent`` is still
    ``None``: ``'questioner'`` builds a ``Llama3QuestionerAgent`` and
    ``'answerer'`` a ``Llama3AnswererAgent``. Once an agent exists it is
    returned as-is, whatever ``name`` is passed.

    Raises:
        AssertionError: if no agent exists and ``name`` is not one of the
            two recognized roles.
    """
    global agent

    if agent is None:
        if name == 'questioner':
            agent = Llama3QuestionerAgent(
                system_prompt=system_prompt,
                few_shot_examples=few_shot_examples,
            )
        elif name == 'answerer':
            agent = Llama3AnswererAgent(
                system_prompt=system_prompt,
                few_shot_examples=few_shot_examples,
            )

    assert agent is not None, "Agent not initialized."
    return agent


# Module-level game state shared across agent_fn calls.
turnRound = 1   # incremented after every "guess" and "answer" turn
guesses = []    # responses returned on "guess" turns, in order

def agent_fn(obs, cfg):
    """Entry point called once per turn.

    Args:
        obs: Observation whose ``turnType`` is ``"ask"``, ``"guess"`` or
            ``"answer"``; it is forwarded to the agent for that role.
        cfg: Not used by this function.

    Returns:
        The response for the turn, or ``"yes"`` when no usable response was
        produced (``None``, at most one character, or an unrecognized
        ``turnType``).
    """
    global turnRound
    global guesses

    # Stays None for an unrecognized turn type so the fallback below applies
    # instead of an UnboundLocalError.
    response = None

    if obs.turnType == "ask":
        if turnRound == 1:
            response = "Is it a place?" # First question
        else:
            response = get_agent('questioner')(obs)
    elif obs.turnType == "guess":
        response = get_agent('questioner')(obs)
        turnRound += 1
        guesses.append(response)
    elif obs.turnType == "answer":
        response = get_agent('answerer')(obs)
        turnRound += 1

    if response is None or len(response) <= 1:
        return "yes"
    else:
        return response

In [None]:
# Install pigz (parallel gzip) and pv (pipe progress meter), both used when
# packaging the submission archive. -y answers any confirmation prompt, since
# a notebook cell cannot respond interactively.
!apt install -y pigz pv > /dev/null

In [None]:
# Pack the contents of /kaggle/working/submission together with the
# llama-3/transformers/8b-chat-hf/1 weights folder (taken relative to
# /kaggle/input/) into submission.tar.gz; the archive stream is compressed by
# pigz and piped through pv.
!tar --use-compress-program='pigz --fast --recursive | pv' -cf submission.tar.gz -C /kaggle/working/submission . -C /kaggle/input/ llama-3/transformers/8b-chat-hf/1

In [None]:
%%bash
# NOTE(review): same command as an earlier cell; `mkdir -p` succeeds when the
# directory already exists, so this cell looks redundant — confirm and remove.
mkdir -p /kaggle/working/submission