In [1]:
%pip uninstall -y torch pynvml
%pip install vllm func_timeout

!export VLLM_LOGGING_LEVEL=DEBUG
!export VLLM_TRACE_FUNCTION=1

!cp -r '/kaggle/input/bird-bench' '/kaggle/working/'
# read-only /kaggle/input/ causes issues with opening some databases 

Found existing installation: torch 2.5.1+cu121
Uninstalling torch-2.5.1+cu121:
  Successfully uninstalled torch-2.5.1+cu121
Found existing installation: pynvml 11.4.1
Uninstalling pynvml-11.4.1:
  Successfully uninstalled pynvml-11.4.1
Note: you may need to restart the kernel to use updated packages.
Collecting vllm
  Downloading vllm-0.7.3-cp38-abi3-manylinux1_x86_64.whl.metadata (25 kB)
Collecting func_timeout
  Downloading func_timeout-4.3.5.tar.gz (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting blake3 (from vllm)
  Downloading blake3-1.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting transformers>=4.48.2 (from vllm)
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.7 MB/s

In [2]:
import re
import json
import sqlite3
from enum import Enum
from abc import ABC
from pathlib import Path
from typing import  Sequence, Optional, Literal, Callable
import pandas as pd
from tqdm import tqdm
from func_timeout import func_timeout, FunctionTimedOut
from vllm import LLM, SamplingParams

# from itertools import batched
def batched(sequence: Sequence, n: int = 1):
    l = len(sequence)
    for ndx in range(0, l, n):
        yield sequence[ndx:min(ndx + n, l)]

# Database Handler

In [3]:
### core/dbhandler.py
# import json
# import sqlite3
# from pathlib import Path
# from typing import Optional
# from func_timeout import func_timeout, FunctionTimedOut

class SQLiteDatabase:
    """ Handler class for sqlite3 databases. Provides SQL execution capabilities and access to schema"""
    def __init__(self, db_id: str, input_path: Path, exec_timeout: float = 30.0, use_cached_schema: Optional[Path] = None) -> None:
        """ Attributes
            ----------
                db_id: str
                    name of database; database must exist in input_path/db_id/db_id.sqlite
                input_path: Path
                    parent directory of database folder
                db_path: Path
                    full path of the db_id.sqlite file 
                exec_timeout: float
                    maximum number of seconds for query to return a result; aborts returning [(err),]
                schema: dict[str, str]
                    either raw_schema or read from json in path 

                raw_schema: dict[str, str]
                    unaugmented, plain db schemas indexed by table_name, read from db_id.sqlite
                descriptions: dict[str, str]
                    Table descriptions, indexed by table_name, read from table_name.csv 
                    which exist in input_path/db_id/database_description/

                use_cached_schemas: Path | None
                    use pre-generated schema stored in path/to/aug.json provided
                    instead of raw_schema. File must map db_id: schema.
        """
        self.db_id = db_id
        self.input_path = input_path
        self.db_path = (self.input_path / self.db_id / self.db_id).with_suffix('.sqlite')
        self.exec_timeout = exec_timeout

        self.raw_schema: dict[str, str] = self.__fetch_raw_schema()
        self.descriptions: dict[str, str] = self.__fetch_db_descriptions()

        if use_cached_schema:
            with open(use_cached_schema, 'r') as f:
                self.schema = json.load(f)[db_id]
        else:
            self.schema = self.raw_schema

    def __getitem__(self, table_name: str):
        """ Return the schema of a table in the database. """
        return self.schema[table_name] 
    
    def __str__(self):
        """ Returns the database schema as a human-readable/executable string. """
        return "\n\n".join( list(self.schema.values()) )
    
    def run_query(self, sql: str, timeout: Optional[float] = None) -> list[tuple]:
        """ Executes SQL query and fetches all rows. """
        try:
            def execute_sql():
                with sqlite3.connect(self.db_path, uri=True) as conn:
                    rows = conn.execute(sql).fetchall()
                return rows
                
            rows = func_timeout(timeout=(timeout or self.exec_timeout), func=execute_sql)
        except FunctionTimedOut as timeout_error:
            rows = [("Error: timedout", )]
        return rows
    
    def __fetch_raw_schema(self) -> dict[str, str]:
        """ Returns a dict of schema of all tables in a .sqlite database indexed by table name """
        tables = self.run_query("SELECT name FROM sqlite_master WHERE type='table';")
        schemas: dict[str, str] = {}
        for table, in tables:
            if table != "sqlite_sequence":
                schema, = self.run_query(f"SELECT sql FROM sqlite_master WHERE name='{table}';")[0]
                schemas[table] = schema
        return schemas
    
    def __fetch_db_descriptions(self) -> dict[str, str]:
        """ Returns a dict of database_descriptions from each table_name.csv as strings  """
        def case_insensitive_file_reader(filepath: Path):
            content = f'Descriptions file for table at {filepath} does not exist.'
            if filepath.exists():
                with open(filepath, 'r', errors='ignore') as file:
                    content = file.read()
            else:
                file_stem = filepath.stem
                candidate_stems = [
                    stem for stem in (
                        file_stem.capitalize(), file_stem.title(), file_stem.upper(), file_stem.lower()
                    )
                    if filepath.with_stem(stem).exists()
                ]
                if candidate_stems:
                    file_stem = candidate_stems[0]
                    filepath = filepath.with_stem(file_stem)
                    content = case_insensitive_file_reader(filepath)
            return content
        
        descriptions = {}
        for table in self.raw_schema.keys():
            filepath = (self.input_path / self.db_id / 'database_description' / table).with_suffix('.csv')
            descriptions[table] = case_insensitive_file_reader(filepath)

        return descriptions

# Bird-Bench Evaluation Metrics

In [4]:
### core/birdeval.py
# import pandas as pd
# from tqdm import tqdm
# from core.dbhandler import SQLiteDatabase

def get_correctness_labels(df: pd.DataFrame, databases: dict[str, SQLiteDatabase], pred_col: str, true_col: str) -> list[bool]:
    ''' Takes DataFrame of BIRD questions with prediction column.
        Runs gold and predicted SQL queries on databases given.
        Returns labels, where labels[i] is True where pred_sql results same as ground_sql's
    '''
    labels = []
    for i, question in tqdm(df.iterrows(), desc='Executing SQL', total=len(df)):
        db = databases[question['db_id']]
        try:
            pred_res = db.run_query(question[pred_col])
            true_res = db.run_query(question[true_col])
        except Exception as e:
            print(f"Q_{question['question_id']}: {e.__class__.__name__} {e}")
            labels.append(False)
        else:
            labels.append( set(pred_res) == set(true_res) )
    return labels


def calculate_accuracy(df: pd.DataFrame, pred_col: str, true_col: str, labels: list[bool]) -> str:
    ex_report = (
        f"=== EX Results | TrueCol: {true_col} | PredCol: {pred_col} ===\n"
        f"Accuracy : {(sum(labels) / len(labels)) * 100: .3f}%\n"
        "Breakdown by Difficulty:\n"
    )    
    for difficulty in df['difficulty'].unique():
        difficulty_mask = df['difficulty'] == difficulty
        correct_rows = [label for label, mask in zip(labels, difficulty_mask) if mask]
        n_correct = sum(correct_rows)
        n_total = sum(difficulty_mask)
        sub_accuracy = (n_correct / n_total) * 100
        ex_report += f"\t{difficulty}: {sub_accuracy: .3f}% ({n_correct} of {n_total})\n"
    ex_report += '=== end ===\n'
    return ex_report


# # TODO: add soft-f1 score to report
def calculate_softf1():
    raise NotImplementedError
def calculate_ves():
    raise NotImplementedError
def calculate_rves():
    raise NotImplementedError

    
def evaluate(df: pd.DataFrame, databases: dict[str, SQLiteDatabase], pred_col: str, true_col: str = 'SQL') -> tuple[list[bool], str]:
    print(f'\n--- Evaluating Performance | TrueCol: {true_col} | PredCol: {pred_col} ---')
    labels = get_correctness_labels(df, databases, pred_col, true_col)
    ex_report = calculate_accuracy(df, pred_col, true_col, labels)
    # f1_report = calculate_softf1(df, pred_col, true_col, labels)
    # ves_report = calculate_ves(df, pred_col, true_col, labels)
    # rves_report = calculate_rves(df, pred_col, true_col, labels)

    # report = "\n\n".join(ex_report, f1_report, ves_report, rves_report)
    report = ex_report      # until the rest gets implemented
    print(report)
    print(f'--- Evaluation Completed | TrueCol: {true_col} | PredCol: {pred_col} ---\n')
    return labels, report

# Agents

In [5]:
# import re
# import json
# import sqlite3
# from pathlib import Path
# from abc import ABC
# from typing import Optional, Callable, Sequence
# from tqdm import tqdm
# import pandas as pd
# from vllm import LLM, SamplingParams
# from core.dbhandler import SQLiteDatabase


class TextToSQLGenerationOutput:
    def __init__(
        self, input_prompts: list[str], raw_responses: list[str], parsed_sql: list[str], 
        n_in_tokens: list[int], n_out_tokens: list[int]
    ) -> None:
        self.input_prompts = input_prompts
        self.raw_responses = raw_responses
        self.parsed_sql = parsed_sql
        self.n_in_tokens = n_in_tokens
        self.n_out_tokens = n_out_tokens

    def as_dataframe(self, col_suffix: Optional[str] = '') -> pd.DataFrame:
        if col_suffix: 
            col_suffix = f'_{col_suffix}'
        df = pd.DataFrame({
            f'input_prompts{col_suffix}': self.input_prompts,
            f'raw_responses{col_suffix}': self.raw_responses,
            f'parsed_sql{col_suffix}':    self.parsed_sql,
            f'n_in_tokens{col_suffix}':   self.n_in_tokens,
            f'n_out_tokens{col_suffix}':  self.n_out_tokens,
        })
        return df

    def __str__(self):
        print("Input Prompts: ", self.input_prompts)
        print("Raw Responses: ", self.raw_responses)
        print("Parsed SQL: ", self.parsed_sql)
        print("Total Input Tokens = ", sum(self.n_in_tokens))
        print("Total Input Tokens = ", sum(self.n_out_tokens))
        
        
class TextToSQL(ABC):
    """ Base class for all Text-to-SQL generation agents. """
    # TODO: maybe implement a prompt format cleanup function; 
    # [Does Prompt Formatting Have Any Impact on LLM Performance?](https://arxiv.org/pdf/2411.10541)

    def __init__(
        self, llm: LLM, databases: dict[str, SQLiteDatabase], 
        output_path: Path,
    ) -> None:
        """ Attributes
            ----------
                llm: LLM
                    Text generation model for offline generation
                databases: dict[str, SQLiteDatabase]
                    Dictionary of SQLiteDatabases indexed by db_id
                output_path: Path
                    Directory to dump output json
        """
        self.llm = llm
        self.databases = databases
        self.output_path = output_path

    def process_bird_df(self, idx: int, row: pd.DataFrame, **kwargs) -> tuple:
        """ Takes a row of a DataFrame of BIRD Bench questions. 
            Processes and returns necessary columns required by this Agent's generate_response(). 
            Output tuple must be unpackable as parameters to generate_response().
        """
        db = self.databases[ row['db_id'] ]
        schema = str(db)
        question = f"{row['question']}  Hint: {row['evidence']}"
        return schema, question

    def generate_prompt(self, schema: str, question: str, **kwargs) -> str:
        """ Takes a question and a schema to generate the agent's SQL generation prompt """
        raise NotImplementedError
    
    def generate_text(self, prompts: list[str], cfg: SamplingParams, use_tqdm: bool = False) -> TextToSQLGenerationOutput:        
        outputs = self.llm.generate(prompts, sampling_params=cfg, use_tqdm=use_tqdm)

        input_prompts: list[str] = [output.prompt for output in outputs]
        raw_responses: list[str] = [output.outputs[0].text for output in outputs]
        parsed_sql:    list[str] = [self.auto_parse_sql(response) for response in raw_responses]
        n_in_tokens:   list[int] = [len(output.prompt_token_ids) for output in outputs]
        n_out_tokens:  list[int] = [len(output.outputs[0].token_ids) for output in outputs]

        return TextToSQLGenerationOutput(input_prompts, raw_responses, parsed_sql, n_in_tokens, n_out_tokens)
    
    def batched_generate(
        self, df: pd.DataFrame, cfg: SamplingParams, batch_size: int, 
        savename: str, evaluator_fn: Optional[Callable] = None, **kwargs
    ) -> tuple[TextToSQLGenerationOutput, Optional[list[bool]]]:
        """ Generates SQL from a DataFrame of BIRD questions. Evaluates performance using evaluator_fn.
            Saves responses with savename as suffix.
            Kwargs passed on to process_bird_df().
        """
        input_prompts: list[str] = []
        raw_responses: list[str] = []
        parsed_sql:    list[str] = []
        n_in_tokens:   list[int] = []
        n_out_tokens:  list[int] = []
        
        for i, batch in enumerate(tqdm(batched(df, batch_size), desc=f'{savename} Generating SQL')):
            # Generate
            prompts: list[str] = [
                self.generate_prompt(*self.process_bird_df(idx, row, **kwargs))
                for idx, row in batch.iterrows()
            ]
            outputs = self.generate_text(prompts, cfg, use_tqdm=False)

            # Record responses
            input_prompts.extend(outputs.input_prompts)
            raw_responses.extend(outputs.raw_responses)
            parsed_sql.extend(outputs.parsed_sql)
            n_in_tokens.extend(outputs.n_in_tokens)
            n_out_tokens.extend(outputs.n_out_tokens)
            if savename:
                self.dump_to_json(f"{savename}_raw", raw_responses)
                self.dump_to_json(f"{savename}_clean", parsed_sql)
        
        final_output = TextToSQLGenerationOutput(input_prompts, raw_responses, parsed_sql, n_in_tokens, n_out_tokens)
        final_df = pd.concat(
            [df, final_output.as_dataframe(col_suffix=savename)], 
            axis=1,
        )
        if evaluator_fn:
            labels, report = evaluator_fn(final_df, self.databases, f'parsed_sql_{savename}')
            final_df[f'label_{savename}'] = labels
            with open(self.output_path/f'results_{savename}.txt', 'w') as f:
                f.write(report)
        else:
            labels = None
        final_df.to_json(self.output_path / f"df_batgen_{savename}.json", orient='records')
        return final_output, labels

    def parse_with_regex(self, response: str) -> str:
        """ Extracts SQL from responses containing '''sql ... ''' using regex. """
        try:
            sql = re.search(r'```sql(.*?)```', response, re.DOTALL).group(1).strip()
        except AttributeError as e:
            sql = ''
        return sql
    
    def auto_parse_sql(self, response: str) -> str:
        """ Extracts SQL from responses containing '''sql ... ''' using regex. 
            If regex search fails, attempts to parse using LLM.
            Returns cleaned SQL or an empty string.
        """
        matched = self.parse_with_regex(response)
        if not matched:
            prompt = (
                "Please extract the SQL query from the text. Enclose your response within "
                "a ```sql <<your response here>> ``` code block. Exclude any additional "
                "text from your response, leaving only the SQL.\n\n"
                f"### Text:\n{response}\n\n"
                f"### SQL:\n"
            )
            raw_output = self.llm.generate(prompt, SamplingParams(temperature=0), use_tqdm=False)
            llm_parsed = raw_output[0].outputs[0].text
            matched = self.parse_with_regex(llm_parsed)
            if matched:
                print("Successfully parsed with LLM.")
            else:
                print("Failed to parse with LLM. Returning empty string.")
        return matched
    
    def is_sql_same(self, db_id: str, query_1: str, query_2: str) -> bool:
        """ Executes SQL queries and returns True if outputs match, with no operation errors. """
        try:
            res_1 = self.databases[db_id].run_query(query_1)
            res_2 = self.databases[db_id].run_query(query_2)
        except sqlite3.OperationalError as e:
            print(f"{e.__class__.__name__} {e}")
            return False
        else:
            return set(res_1) == set(res_2)
        
    def dump_to_json(self, filename: str, obj: object) -> None:
        """ Dumps a list of objects to self.output_path/filename.json; use for keeping backups. """
        filepath = self.output_path / f"{filename}.json"
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'w') as f:
            json.dump(obj, f, ensure_ascii=False, indent=4)

In [6]:
# import pandas as pd
# from sqlgen.base_agent import TextToSQL

# TODO: long prompts performs poorly with small models


class ZeroShotAgent(TextToSQL):
    """ Zero-shot SQL Generator based on OpenAI Cookbook's "Natural Language to SQL" example and zero-shot COT. """            
    def generate_prompt(self, schema: str, question: str) -> str:
        prompt = (
            "Given the following SQLite tables, your job is to write queries given a user’s request. "
            f"### QUESTION\n{question}.\n\n"
            f"### SCHEMA\n{schema}\n\n"
            f"### QUESTION\n{question}.\n\n"
            f"### RESPONSE\nLet's think step by step "
        )
        return prompt
    

### DO NOT USE, NEEDS MODIFICATION
class MetaPromptZeroShotAgent(ZeroShotAgent):
    """ Zero-shot SQL Generator using detailed meta-prompt of instructions"""
    def generate_prompt(self, schema: str, question: str) -> str:
        prompt = (
            "You are an SQLite expert who excels at writing queries. Your job is to write  "
            "a valid SQLite query to answer a given user question based on the schema below. "
            "Here is how you should approach the problem:\n"
            "1. Begin your response with 'Let\'s think step by step.'\n"
            "2. Analyze the question and schema carefully, showing all your workings:\n"
            "   - Decompose the question into subproblems.\n"
            "   - Identify the tables and the columns required to write the query.\n"
            "   - Identify the operations you will need to perform.\n"
            "3. Review your choices before generation:\n"
            "   - Identify if you missed any tables and columns.\n"
            "   - Identify if you picked any unnecessary tables and columns.\n"
            "   - Identify any unnecessary subqueries, joins, groupings, orderings, sortings etc.\n"
            "4. Ensure your choices are correct and optimal.\n"
            "5. Finally, show your reasoning and write down the SQL query.\n\n"
            f"### Schema:\n{schema}\n\n"
            f"### Question:\n{question}."
        )
        return prompt
        

class OptimizerAgent(TextToSQL):
    def process_bird_df(self, idx: int, row: pd.DataFrame, pred_col=str) -> tuple:
        schema, question = super().process_bird_df(idx, row)
        sql = row[pred_col]
        return schema, question, sql

    def generate_prompt(self, schema: str, question: str, sql: str) -> str:
        prompt = (
            "You are an SQLite expert who excels at debugging and optimizing SQL queries. "
            "You will be given a database schema, a question, and an SQL query answering "
            "that question based on the given schema. Carefully analyse the schema, the "
            "question and the query. Your job is to do the following:\n"
            "1. Begin your response with 'Let\'s think step by step.'\n"
            "2. Analyze the query\n"
            "   - identify any invalid SQLite keywords.\n"
            "   - identify any invalid or missing columns and tables.\n"
            "   - identify any unnecessary subqueries, joins, groupings, orderings, sortings etc.\n"
            "   - ensure that query is a single SQL statement.\n"
            "3. Show your reasoning and write down the corrected, optimized, valid SQLite query.\n\n"
            f"### Schema:\n{schema}\n\n"
            f"### Question:\n{question}.\n\n"
            f"### SQL:\n{sql}"
        )
        return prompt

In [7]:
# from typing import Literal, Callable
# from pathlib import Path
# import pandas as pd
# from vllm import LLM, SamplingParams
# from sqlgen.base_agent import TextToSQL, TextToSQLGenerationOutput
# from core.dbhandler import SQLiteDatabase


class ZeroShotStarter(TextToSQL):
    personas = {
        'simple': "who offers short, and simple solutions to user questions",
        'technical': "who provides highly technical answers to user questions",
        'thinker': "who does not hesistate to dig deep into a problem and explore several approaches before settling on a solution"
    }

    def process_bird_df(self, idx: int, row: pd.DataFrame, persona: Literal['simple', 'technical', 'thinker']) -> tuple:
        schema, question = super().process_bird_df(idx, row)
        return schema, question, self.personas[persona]

    def generate_prompt(self, schema: str, question: str, persona: str) -> str:
        prompt = (
            f"You are a helpful SQL coding assistant{' ' + persona if persona else ''}. "
            "Please generate a SQLite query to answer the user question, based on the schema below. "
            "In your response, first briefly explain your reasoning. Your final answer should be enclosed "
            "in a markdown code block.\n\n"
            f"### QUESTION\n{question}.\n\n"
            f"### SCHEMA\n{schema}\n\n"
            f"### QUESTION\n{question}.\n\n"
            f"### RESPONSE\nLet's think step by step "
        )
        return prompt


class DiscussionAgent(TextToSQL):
    # TODO: Add personas to DiscussionAgent
    def process_bird_df(self, idx: int, row: pd.DataFrame, agent_responses: list[dict[int, str]]) -> tuple[str, str, dict[int, str]]:
        schema, question = super().process_bird_df(idx, row)
        return schema, question, agent_responses[idx]
    
    def generate_prompt(self, schema: str, question: str, agent_responses: dict[int, str]) -> str:
        other_responses = ''.join(
            f"###### Agent_{agent}\n{response}\n\n"
            for agent, response in agent_responses.items()
        )
        prompt = (
            "You are a helpful SQL coding agent. You understand that collaborative discussion is the best way to solve problems. "
            "Using the other agent's response as additional information, your job is to generate a SQLite query to answer the "
            "user question based on the schema. In your response, please explain your reasoning clearly so that others may "
            "give you constructive feedback. Your final answer should be enclosed in a markdown code block.\n\n"
            f"### QUESTION\n{question}\n\n"
            f"### SCHEMA\n{schema}\n\n"
            f"### AGENT RESPONSES\n{other_responses}"
            f"### QUESTION\n{question}\n\n"
            f"### YOUR RESPONSE\nLet's think step by step "
        )
        return prompt
    

class DiscussionJudge(TextToSQL):
    def process_bird_df(self, idx: int, row: pd.DataFrame, agent_responses: list[dict[int, str]]) -> tuple:
        schema, question = super().process_bird_df(idx, row)
        return schema, question, agent_responses[idx]
    
    def generate_prompt(self, schema: str, question: str, agent_responses: dict[int, str]) -> str:
        n_agents = len(agent_responses)
        other_responses = ''.join(
            f"###### Agent_{agent}\n{response}\n\n"
            for agent, response in agent_responses.items()
        )
        prompt = (
            f"You are a SQL expert overseeing {n_agents} coding agents collaborating to answer the user question based on the given schema. "
            f"Using the other agents' responses as additional information, generate the production-ready SQLite query. "
            f"Your final answer should be enclosed in a markdown code block.\n\n"
            f"### QUESTION\n{question}\n\n"
            f"### SCHEMA\n{schema}\n\n"
            f"### AGENT RESPONSES\n{other_responses}"
            f"### QUESTION\n{question}\n\n"
            f"### VERDICT\nLet's think step by step "
        )
        return prompt


class MultiAgentDiscussion:
    def discuss(
        df: pd.DataFrame, databases: dict[str, SQLiteDatabase], llm: LLM,
        output_path: Path, savename: str, batch_size: int, evaluator_fn: Callable
    ) -> pd.DataFrame:
            
        # TODO: Add personas to DiscussionAgent
        starter = ZeroShotStarter(llm, databases, output_path)
        agent_1 = DiscussionAgent(llm, databases, output_path)
        agent_2 = DiscussionAgent(llm, databases, output_path)
        agent_3 = DiscussionAgent(llm, databases, output_path)
        judge   = DiscussionJudge(llm, databases, output_path)
        
        cfg = SamplingParams(
            temperature=0.6,
            top_p=0.8,
            repetition_penalty=1.2,
            max_tokens=2048,
        )
        
        def gather_agent_responses(
                agent_num1: int, responses_1: TextToSQLGenerationOutput, 
                agent_num2: int, responses_2: TextToSQLGenerationOutput,
            ) -> list[dict[int, str]]:
            agent_responses = [
                {agent_num1: resp1, agent_num2: resp2}
                for resp1, resp2 in zip(responses_1.raw_responses, responses_2.raw_responses)
            ]
            return agent_responses
        
        def gather_all_responses(
                responses_1: TextToSQLGenerationOutput, 
                responses_2: TextToSQLGenerationOutput, 
                responses_3: TextToSQLGenerationOutput,
            ) -> list[dict[int, str]]:
            agent_responses = [
                {1: resp1, 2: resp2, 3: resp3}
                for resp1, resp2, resp3 in zip(responses_1.raw_responses, responses_2.raw_responses, responses_3.raw_responses)
            ]
            return agent_responses
        
        starters_1, starters_1_label = starter.batched_generate(df, cfg, batch_size, 'starter1', evaluator_fn, persona='simple')
        starters_2, starters_2_label = starter.batched_generate(df, cfg, batch_size, 'starter2', evaluator_fn, persona='technical')
        starters_3, starters_3_label = starter.batched_generate(df, cfg, batch_size, 'starter3', evaluator_fn, persona='thinker')

        agent_1_discuss_r1, a1r1_label = agent_1.batched_generate(df, cfg, batch_size, f'agent1_R{1}', evaluator_fn, agent_responses=gather_agent_responses(2, starters_2, 3, starters_3))
        agent_2_discuss_r1, a2r1_label = agent_2.batched_generate(df, cfg, batch_size, f'agent2_R{1}', evaluator_fn, agent_responses=gather_agent_responses(1, starters_1, 3, starters_3))
        agent_3_discuss_r1, a3r1_label = agent_3.batched_generate(df, cfg, batch_size, f'agent3_R{1}', evaluator_fn, agent_responses=gather_agent_responses(1, starters_1, 2, starters_2))
        
        verdict_r1, verdict_r1_label = judge.batched_generate(df, cfg, batch_size, 'judge_r1', evaluator_fn, agent_responses=gather_all_responses(agent_1_discuss_r1, agent_2_discuss_r1, agent_3_discuss_r1))

        agent_1_discuss_r2, a1r2_label = agent_1.batched_generate(df, cfg, batch_size, f'agent1_R{2}', evaluator_fn, agent_responses=gather_agent_responses(2, agent_2_discuss_r1, 3, agent_3_discuss_r1))
        agent_2_discuss_r2, a2r2_label = agent_2.batched_generate(df, cfg, batch_size, f'agent2_R{2}', evaluator_fn, agent_responses=gather_agent_responses(1, agent_1_discuss_r1, 3, agent_3_discuss_r1))
        agent_3_discuss_r2, a3r2_label = agent_3.batched_generate(df, cfg, batch_size, f'agent3_R{2}', evaluator_fn, agent_responses=gather_agent_responses(1, agent_1_discuss_r1, 2, agent_2_discuss_r1))

        verdict_r2, verdict_r2_label = judge.batched_generate(df, cfg, batch_size, 'judge_r2', evaluator_fn, agent_responses=gather_all_responses(agent_1_discuss_r2, agent_2_discuss_r2, agent_3_discuss_r2))

        agent_1_discuss_r3, a1r3_label = agent_1.batched_generate(df, cfg, batch_size, f'agent1_R{3}', evaluator_fn, agent_responses=gather_agent_responses(2, agent_2_discuss_r2, 3, agent_3_discuss_r2))
        agent_2_discuss_r3, a2r3_label = agent_2.batched_generate(df, cfg, batch_size, f'agent2_R{3}', evaluator_fn, agent_responses=gather_agent_responses(1, agent_1_discuss_r2, 3, agent_3_discuss_r2))
        agent_3_discuss_r3, a3r3_label = agent_3.batched_generate(df, cfg, batch_size, f'agent3_R{3}', evaluator_fn, agent_responses=gather_agent_responses(1, agent_1_discuss_r2, 2, agent_2_discuss_r2))

        verdict_r3, verdict_r3_label = judge.batched_generate(df, cfg, batch_size, 'judge_r3', evaluator_fn, agent_responses=gather_all_responses(agent_1_discuss_r3, agent_2_discuss_r3, agent_3_discuss_r3))

        final_df = pd.concat([
                df,
                starters_1.as_dataframe(col_suffix='start_simple'),
                pd.DataFrame({'label_starter1': starters_1_label}),
                starters_2.as_dataframe(col_suffix='start_technical'),
                pd.DataFrame({'label_starter2': starters_2_label}),
                starters_3.as_dataframe(col_suffix='start_thinker'),
                pd.DataFrame({'label_starter3': starters_3_label}),

                agent_1_discuss_r1.as_dataframe(col_suffix='agent1_r1'),
                pd.DataFrame({'label_agent1_R1': a1r1_label}),
                agent_2_discuss_r1.as_dataframe(col_suffix='agent2_r1'),
                pd.DataFrame({'label_agent2_R1': a2r1_label}),
                agent_3_discuss_r1.as_dataframe(col_suffix='agent3_r1'),
                pd.DataFrame({'label_agent3_R1': a3r1_label}),

                verdict_r1.as_dataframe(col_suffix='judge_r1'),
                pd.DataFrame({'label_judge_r1': verdict_r1_label}),

                agent_1_discuss_r2.as_dataframe(col_suffix='agent1_r2'),
                pd.DataFrame({'label_agent1_R2': a1r2_label}),
                agent_2_discuss_r2.as_dataframe(col_suffix='agent2_r2'),
                pd.DataFrame({'label_agent2_R2': a2r2_label}),
                agent_3_discuss_r2.as_dataframe(col_suffix='agent3_r2'),
                pd.DataFrame({'label_agent3_R2': a3r2_label}),

                verdict_r2.as_dataframe(col_suffix='judge_r2'),
                pd.DataFrame({'label_judge_r2': verdict_r2_label}),

                agent_1_discuss_r3.as_dataframe(col_suffix='agent1_r3'),
                pd.DataFrame({'label_agent1_R3': a1r3_label}),
                agent_2_discuss_r3.as_dataframe(col_suffix='agent2_r3'),
                pd.DataFrame({'label_agent2_R3': a2r3_label}),
                agent_3_discuss_r3.as_dataframe(col_suffix='agent3_r3'),
                pd.DataFrame({'label_agent3_R3': a3r3_label}),

                verdict_r3.as_dataframe(col_suffix='judge_r3'),
                pd.DataFrame({'label_judge_r3': verdict_r3_label}),
            ], 
            axis=1,
        )
        final_df.to_json(output_path/f"df_{savename}_final.json", orient='records')
        print('finished discussion')
        return final_df

# Config

In [8]:
### config.py
# from core.llm import SupportedModels
# from pathlib import Path

class SupoortedModels(Enum):
    qwen25_coder_14b_instruct_awq = '/kaggle/input/qwen2.5-coder/transformers/14b-instruct-awq/1'
    qwen25_coder_32b_instruct_awq = '/kaggle/input/qwen2.5-coder/transformers/32b-instruct-awq/1'

### Experiment Configurations ###
### Import only to utils.py   ###
EXPERIMENT = [
    'zs_32b',
    'mad_250-500',
    'zs_cot_qscq',
    'mad',
    'zero-meta-optim-unaug',
][0]

MODEL = SupoortedModels.qwen25_coder_32b_instruct_awq       # TODO: make this settable from bash somehow maybe?
GPU_MEMORY_UTILIZATION = 0.97
TENSOR_PARALLEL_SIZE = 2                # set equal to number of GPU               
MODEL_MAX_SEQ_LEN = 4096 * 0.5          # 4096*2 is max for 14B AWQ model with fp8 KV-cache on 15GB VRAM 
KV_CACHE_DTYPE = 'fp8'                  # Reduces memory consumption; fp8 might impact models that use Grouped Query Attn like Qwen
BATCH_SIZE = 32                         # saves after every batch
SEED = 42


INPUT_PATH  = Path(f'/kaggle/working/bird-bench/bird-bench/bird-minidev')
OUTPUT_PATH = Path(f'/kaggle/working/results/{MODEL.name}_{EXPERIMENT}/')
BIRD_QUESTION_FILENAME = 'dev.json'
DATABASES_FOLDERNAME = 'dev_databases'
DB_EXEC_TIMEOUT = 30.0                              # maximum number of seconds a query execution is allowed to take
USE_CACHED_SCHEMA = None #Path('/kaggle/working/bird-bench/aug-minidev/aug-minidev/aug.json')  # Use pre-generated schema 

# set all to FALSE for actual runs
USE_DEBUG_DATASET = False                          # Debug with only first 5 bird questions
USE_DEBUG_DB = False                               # True for ['formula_1', 'debit_card_specializing', 'thrombosis_prediction'] only subset

# Experiment Utils

## Data Reader

In [9]:
### utils.py

### BIRD Dataset Reader Function ###
def read_dataset() -> tuple[pd.DataFrame, dict[str, SQLiteDatabase]]:
    """ BIRD dataset reader function.
        1. Reads dataset into DataFrame from "INPUT_PATH/BIRD_QUESTION_FILENAME".
        2. Lists database names from folders in "INPUT_PATH/DB_FOLDERNAME/".
        3. Creates dict of SQLiteDatabases, indexed by db_name.
        Returns DataFrame of BIRD questions and dict of databases.
    """
    df = pd.read_json(INPUT_PATH / BIRD_QUESTION_FILENAME)
    db_names: list[str] = [f.name for f in (INPUT_PATH / DATABASES_FOLDERNAME).iterdir()]
    databases: dict[str, SQLiteDatabase] = {
        db_id: SQLiteDatabase(db_id, (INPUT_PATH / DATABASES_FOLDERNAME), DB_EXEC_TIMEOUT, USE_CACHED_SCHEMA) 
        for db_id in db_names
    }
    print(f'{db_names=}, {len(df)=}')
    return df, databases

## Experiment Scripts

In [10]:
def setup_experiment():
    OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
    df, databases = read_dataset()
    cfg = SamplingParams(
        temperature=0,
        top_p=1,
        repetition_penalty=1.1,
        max_tokens=4096,
    )
    llm = LLM(
        MODEL.value,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
        max_model_len=MODEL_MAX_SEQ_LEN,
        max_seq_len_to_capture=MODEL_MAX_SEQ_LEN,
        kv_cache_dtype=KV_CACHE_DTYPE,
        seed=SEED,
    )
    return df, databases, cfg, llm


def agent_baseline(
    agent: TextToSQL, cfg: SamplingParams, df: pd.DataFrame, 
    batch_size: int, savename: str, evaluator_fn: Callable, **kwargs
) -> tuple[pd.DataFrame, str]:
    print(f"Experiment: {savename}_{'' if USE_CACHED_SCHEMA else 'un'}aug_{MODEL.name}")
    outputs, labels = agent.batched_generate(df, cfg, batch_size, savename, evaluator_fn, **kwargs)

    df[f'input_prompts_{savename}'] = outputs.input_prompts
    df[f'n_in_tokens_{savename}']   = outputs.n_in_tokens
    df[f'raw_responses_{savename}'] = outputs.raw_responses
    df[f'n_out_tokens_{savename}']  = outputs.n_out_tokens
    df[f'parsed_sql_{savename}']    = outputs.parsed_sql    
    df[f'label_{savename}']         = labels
    df.to_json(OUTPUT_PATH/f'df_{savename}.json', orient='records')
        
    print(f"Experiment: {savename}_{'' if USE_CACHED_SCHEMA else 'un'}aug_{MODEL.name}_{EXPERIMENT} Successfully Completed.\n\n\n")
    return df

# Run

## Prepare

In [11]:
if __name__ == '__main__':
    df, databases, cfg, llm = setup_experiment()

db_names=['california_schools', 'student_club', 'card_games', 'codebase_community', 'formula_1', 'toxicology', 'superhero', 'thrombosis_prediction', 'debit_card_specializing', 'european_football_2', 'financial'], len(df)=500
INFO 02-24 09:19:48 __init__.py:207] Automatically detected platform cuda.
INFO 02-24 09:19:59 config.py:549] This model supports multiple tasks: {'classify', 'reward', 'embed', 'score', 'generate'}. Defaulting to 'generate'.
INFO 02-24 09:20:01 config.py:1096] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor
INFO 02-24 09:20:01 config.py:1382] Defaulting to use mp for distributed inference
INFO 02-24 09:20:01 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.3) with config: model='/kaggle/input/qwen2.5-coder/transformers/32b-instruct-awq/1', speculative_config=None, tokenizer='/kaggle/input/qwen2.5-coder/transformers/32b-instruct-awq/1', 

TypeError: 'float' object cannot be interpreted as an integer

## Run experiment

In [None]:
if __name__ == '__main__':       
    agent_zs = ZeroShotAgent(llm, databases, OUTPUT_PATH)
    df = agent_baseline(agent_zs, cfg, df, BATCH_SIZE, EXPERIMENT, evaluate)

In [None]:
# if __name__ == '__main__':
#     MultiAgentDiscussion.discuss(
#         df=df[250:].reset_index(), 
#         databases=databases, 
#         llm=llm, 
#         output_path=OUTPUT_PATH, 
#         savename=f'multiag', 
#         batch_size=BATCH_SIZE, 
#         evaluator_fn=evaluate, 
#     )

In [None]:
!rm -r '/kaggle/working/bird-bench'
# cleans up the output directory for easy/fast results download