In [None]:
# Load variables from a local .env file into the environment (this is where the Lamini API key is stored).
from dotenv import load_dotenv
_ = load_dotenv()  # bound to `_` so the notebook cell does not echo the return value

In [None]:
#Set up of imports, logging, database, a class for arguments

import logging
import random
from typing import AsyncIterator, Iterator, Union
import sqlite3
import copy
from tqdm import tqdm
import os
from datetime import datetime
from pprint import pprint

import pandas as pd
import jsonlines
import lamini 
from lamini.generation.base_prompt_object import PromptObject
from lamini.generation.generation_node import GenerationNode
from lamini.generation.generation_pipeline import GenerationPipeline
from util.get_schema import get_schema, get_schema_altered
from util.make_llama_3_prompt import make_llama_3_prompt
from util.setup_logging import setup_logging
from util.load_dataset import get_dataset 
from util.get_default_finetune_args import get_default_finetune_args

logger = logging.getLogger(__name__)
# Every pipeline below executes LLM-generated SQL against this connection, so open the
# database read-only: a hallucinated DROP/DELETE/UPDATE fails instead of modifying data.
# mode=ro also raises immediately if SpotifyData.db is missing, rather than silently
# creating an empty database whose queries would all fail later.
engine = sqlite3.connect("file:SpotifyData.db?mode=ro", uri=True)
setup_logging()

class Args:
    """
    Run configuration shared by the evaluation, data-generation and fine-tuning cells.

    Attributes:
        max_examples: maximum number of test examples loaded by load_dataset.
        sql_model_name: model QueryStage uses to write SQL during evaluation.
        file_name: name of the hand-written example file under data/.
        training_file_name: output file name under data/training_data/.
        num_to_generate: number of few-shot generation rounds in load_queries.
    """

    def __init__(
        self,
        max_examples=100,
        sql_model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        file_name="test-set.jsonl",
        training_file_name="generated_queries.jsonl",
        num_to_generate=10,
    ):
        # Evaluation settings
        self.sql_model_name = sql_model_name
        self.max_examples = max_examples
        # Input / output files
        self.file_name = file_name
        self.training_file_name = training_file_name
        # Training-data generation settings
        self.num_to_generate = num_to_generate

In [None]:
class QueryStage(GenerationNode):
    """
    Generates SQL queries from natural-language questions and runs them on the database.

    This class extends `GenerationNode`. For every PromptObject it executes both the
    LLM-generated query and the reference query (``obj.data["sql"]``) so a later stage
    can compare their outputs. Here we only record whether the generated query executed
    (``obj.data["query_succeeded"]``), not whether it is correct.
    """

    def __init__(self, model_name):
        super().__init__(
            model_name=model_name,
            max_new_tokens=300,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        """
        Generate an SQL query per prompt; the response is constrained to a single
        "sqlite_query" string field.
        """
        results = super().generate(
            prompt,
            output_type={"sqlite_query": "str"},
            *args,
            **kwargs,
        )
        return results

    def postprocess(self, obj: PromptObject):
        """
        Run the generated and the reference SQL queries and store their results on ``obj.data``.

        Sets ``generated_query``, ``query_succeeded``, ``reference_df`` and, only when the
        generated query runs, ``df``. A failing generated query is logged together with its
        error and recorded as ``query_succeeded=False``. A failing reference query is not
        caught, and its exception propagates.
        """
        generated_query = obj.response["sqlite_query"]
        obj.data["generated_query"] = generated_query
        query_succeeded = False

        logger.info(f"Running SQL query '{generated_query}'")
        try:
            df = pd.read_sql(generated_query, con=engine)
        except Exception as e:
            # Model output is untrusted, so any failure (syntax error, unknown column, ...) is an
            # expected outcome. Record it instead of aborting, and log the error text for diagnosis.
            logger.error(
                f"Failed to run SQL query: {generated_query} ({type(e).__name__}: {e})"
            )
        else:
            obj.data['df'] = df
            logger.info(f"Got data: {df}")
            query_succeeded = True

        logger.info(f"Running reference SQL query '{obj.data['sql']}'")
        reference_df = pd.read_sql(obj.data["sql"], con=engine)
        logger.info(f"Got data: {reference_df}")
        obj.data['reference_df'] = reference_df

        logger.info(f"For question: {obj.data['question']}")
        logger.info(f"For query: {generated_query}")

        obj.data["query_succeeded"] = query_succeeded

    def preprocess(self, obj: PromptObject):
        """
        Build the Llama 3 formatted prompt from make_prompt() and store it on ``obj.prompt``.
        """
        new_prompt = make_llama_3_prompt(**self.make_prompt(obj.data))
        obj.prompt = new_prompt

    def make_prompt(self, data: dict):
        """
        Return {"user": ..., "system": ...} prompts; the user prompt is ``data["question"]``.
        """
        system = "You are a data analyst with 15 years of experience writing complex SQL queries.\n"
        system += "Consider the songs table with the following schema:\n"
        system += get_schema_altered() + "\n"
        # Per the original note, this instruction was excluded in the 1st prompt iteration.
        system += "Only return the data required to answer the question in the SQL, don't include other columns or rows. \n"
        system += (
            "Write a sqlite SQL query that would help you answer the following question. Make sure each query ends with a semicolon:\n"
        )
        user = data["question"]
        return {
            "user": user,
            "system": system,
        }

#########################################

class ScoreStage(GenerationNode):
    """
    Judges whether the LLM-generated query returned the same data as the reference query.

    This class extends `GenerationNode`. An LLM compares the two (stringified, lowercased)
    result dataframes produced by QueryStage and returns an explanation plus a
    "true"/"false" similarity verdict.
    """

    def __init__(self, model_name="meta-llama/Meta-Llama-3-8B-Instruct", max_new_tokens=150):
        """
        Args:
            model_name: judge model (defaults to the previously hard-coded Llama 3 8B Instruct).
            max_new_tokens: generation budget for the explanation and verdict.
        """
        super().__init__(
            model_name=model_name,
            max_new_tokens=max_new_tokens,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        """
        Generate an explanation and a "true"/"false" similarity verdict per prompt.
        """
        results = super().generate(
            prompt,
            output_type={"explanation": "str", "similar": ["true", "false"]},
            *args,
            **kwargs,
        )
        return results

    def preprocess(self, obj: PromptObject):
        """
        Build the comparison prompt (two dataframes) and store it on ``obj.prompt``.
        """
        obj.prompt = make_llama_3_prompt(**self.make_prompt(obj))
        logger.info(f"Scoring Stage Prompt:\n{obj.prompt}")

    def postprocess(self, obj: PromptObject):
        """
        Store the verdict on ``obj.data``: ``is_matching``, ``explanation`` and ``similar`` (bool).
        """
        obj.data['is_matching'] = self.is_matching(obj.data, obj.response)
        obj.data['explanation'] = obj.response["explanation"]
        obj.data['similar'] = obj.response["similar"] == "true"

    def is_matching(self, data, response):
        """
        Return True if the two dataframes print identically (case-insensitive) or the judge
        model said they are similar.

        A missing ``df`` (the generated query failed) is rendered as the string "None".
        """
        return (str(data.get('df', "None")).lower() == str(data['reference_df']).lower()
                or response['similar'] == "true")

    def make_prompt(self, obj: PromptObject):
        """
        Return {"system": ..., "user": ...} prompts asking the judge model to compare the
        generated-query dataframe (Dataframe 1) with the reference dataframe (Dataframe 2).
        """
        # The sentence must end before the JSON instruction; previously the two strings were
        # concatenated with no separator ("...songs datasetRespond with...").
        system_prompt = (
            "Compare the following two dataframes. They are similar if they are almost identical, "
            "or if they convey the same information about the songs dataset.\n"
        )
        system_prompt += "Respond with valid JSON {'explanation' : str, 'similar' : bool}"
        user_prompt = (
            f"========== Dataframe 1 =========\n{str(obj.data.get('df', 'None')).lower()}\n\n"
        )
        user_prompt += (
            f"========== Dataframe 2 =========\n{str(obj.data['reference_df']).lower()}\n\n"
        )
        user_prompt += "Can you tell me if these dataframes are similar?"
        return {
            "system": system_prompt,
            "user": user_prompt
        }

#######################################################

async def run_eval(dataset, args):
    """
    Run the evaluation pipeline over ``dataset`` and print how many results came back.

    Returns:
        The list of evaluated PromptObjects produced by run_evaluation_pipeline.
    """
    evaluated = await run_evaluation_pipeline(dataset, args)
    total = len(evaluated)
    print("Total results:", total)
    return evaluated

async def run_evaluation_pipeline(dataset, args):
    """
    Run EvaluationPipeline over ``dataset`` and collect every result.

    Args:
        dataset: iterable of PromptObjects (e.g. from load_dataset).
        args: Args instance; ``args.sql_model_name`` selects the model under evaluation.

    Returns:
        List of processed PromptObjects, in the order the pipeline yields them.
    """
    results = EvaluationPipeline(args).call(dataset)

    result_list = []
    # Context manager closes the progress bar even if the pipeline raises midway;
    # previously the bar was never closed.
    with tqdm(desc="Saving results", unit=" results") as pbar:
        async for result in results:
            result_list.append(result)
            pbar.update()
    return result_list

#######################################################

class EvaluationPipeline(GenerationPipeline):
    """
    Two-stage evaluation pipeline: QueryStage (generate and run SQL), then ScoreStage (judge the results).

    Parent class: https://github.com/lamini-ai/lamini/blob/main/lamini/generation/generation_pipeline.py
    """

    def __init__(self, args):
        super().__init__()
        # Stage 1 uses the model under evaluation; stage 2 uses the judge model configured in ScoreStage.
        self.query_stage = QueryStage(args.sql_model_name)
        self.score_stage = ScoreStage()

    def forward(self, x):
        """
        Chain the stages: the query stage's output stream is fed into the score stage.

        Invoked through ``call`` (see run_evaluation_pipeline).
        """
        return self.score_stage(self.query_stage(x))

#######################################################
    
def load_dataset(args):
    """
    Yield a PromptObject (empty prompt, record as ``data``) for examples in ``data/<args.file_name>``.

    Records are taken from the end of the file backwards (last line first), and at most
    ``args.max_examples`` are yielded. With the default Args the file is data/test-set.jsonl.
    """
    path = f"data/{args.file_name}"

    # Reversing needs the full list anyway; reading it up front lets the file be closed
    # before we start yielding, instead of staying open while the consumer runs.
    with jsonlines.open(path) as reader:
        records = list(reader)

    for index, record in enumerate(reversed(records)):
        if index >= args.max_examples:
            break
        yield PromptObject(prompt="", data=record)

def save_eval_results(results, args):
    """
    Save the evaluation results to a new timestamped directory and print a summary.

    Writes to ./data/results/spotify_sql_pipeline_<YYYY_mm_dd_HH_MM_SS>/:
      - args.txt: ``args.__dict__`` pretty-printed
      - sql_results.jsonl: results judged correct by ``is_correct``
      - sql_errors.jsonl: every other result
      - summary.txt: dataset size, percent valid SQL, percent correct SQL
        (the summary lines are also printed to stdout)

    Args:
        results: list of PromptObjects from the evaluation pipeline; each ``.data`` must
            contain "question", "sql", "generated_query", "query_succeeded", "reference_df",
            "is_matching" and "similar" ("df" is optional).
        args: the Args instance used for the run.

    Raises:
        FileExistsError: if a results directory with the same timestamp already exists.
    """
    base_path = "./data/results"
    now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    experiment_name = f"spotify_sql_pipeline_{now}"
    experiment_dir = os.path.join(base_path, experiment_name)
    os.makedirs(experiment_dir)

    # Write args to file
    with open(f"{experiment_dir}/args.txt", "w") as writer:
        pprint(args.__dict__, writer)

    def is_correct(r):
        """Correct = ran and matched the reference output, or is textually identical to the reference SQL."""
        # Uses its own parameter; previously it read the enclosing loop variable `result`
        # and only worked because of closure late-binding.
        return bool(
            (r.data["query_succeeded"] and r.data["is_matching"])
            or r.data["generated_query"] == r.data["sql"]
        )

    def to_record(r):
        """Serializable fields written for one result (dataframes are stringified)."""
        return {
            "question": r.data["question"],
            "query": r.data["generated_query"],
            "query_succeeded": r.data["query_succeeded"],
            "reference_sql": r.data["sql"],
            "df": str(r.data.get("df", "None")),
            "reference_df": str(r.data["reference_df"]),
            "is_matching": r.data["is_matching"],
            "similar": r.data["similar"],
        }

    correct = [r for r in results if is_correct(r)]
    errors = [r for r in results if not is_correct(r)]

    # Write sql results and errors to file. The errors file now also carries reference_sql,
    # so a failed case can be diagnosed from that file alone.
    with jsonlines.open(f"{experiment_dir}/sql_results.jsonl", "w") as writer:
        for r in correct:
            writer.write(to_record(r))

    with jsonlines.open(f"{experiment_dir}/sql_errors.jsonl", "w") as writer:
        for r in errors:
            writer.write(to_record(r))

    # Write statistics to file. NOTE(review): "Percent Correct" does not include the
    # exact-text-match clause of is_correct; kept as-is to preserve the reported metric.
    total = len(results)
    if total:
        average_sql_succeeded = sum(r.data["query_succeeded"] for r in results) / total
        average_correct = sum(
            r.data["query_succeeded"] and r.data["is_matching"] for r in results
        ) / total
    else:
        # Empty run: report zeros instead of raising ZeroDivisionError after writing the files.
        average_sql_succeeded = average_correct = 0.0

    summary_lines = [
        f"Total size of eval dataset: {total}",
        f"Percent Valid SQL Syntax: {average_sql_succeeded*100}",
        f"Percent Correct SQL Query: {average_correct*100}",
    ]
    with open(f"{experiment_dir}/summary.txt", "w") as writer:
        for line in summary_lines:
            print(line, file=writer)
            print(line)


In [None]:
# Cell: generate SQL with the base model for each test question, then evaluate and save the results.

args = Args()  # defaults: Meta-Llama-3-8B-Instruct, data/test-set.jsonl, at most 100 examples
dataset = load_dataset(args)  # lazy generator of PromptObjects
results = await run_eval(dataset, args)  # top-level await: valid in a Jupyter/IPython cell, not in a plain script
save_eval_results(results, args)  # writes ./data/results/spotify_sql_pipeline_<timestamp>/

In [None]:
class ModelStage(GenerationNode):
    """
    Generates new SQL queries (fine-tuning data candidates) by few-shot prompting an LLM.

    Each incoming PromptObject's ``data`` is expected to expose ``.sample``, a list of
    {"question", "sql"} dicts (see ExampleSample). The model proposes two new queries.
    Each query that executes against the database is emitted as its own PromptObject,
    with ``data.generated_sql_query`` set.
    """

    def __init__(self, model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
        """
        Args:
            model_name: generator model (defaults to the previously hard-coded Llama 3 8B Instruct).
        """
        super().__init__(
            model_name=model_name,
            max_new_tokens=300,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        """
        Wrap the prompts with the Llama 3 template (via add_template), then generate.

        The response is constrained to "explanation", "sql_query_1" and "sql_query_2".
        """
        prompt = self.add_template(prompt)

        results = super().generate(
            prompt,
            output_type={
                "explanation": "str",
                "sql_query_1": "str",
                "sql_query_2": "str",
            },
            *args,
            **kwargs,
        )

        return results

    async def add_template(self, prompts):
        """
        Yield a new PromptObject per input, with a Llama 3 formatted prompt built by make_prompt().

        ``prompts`` is consumed with ``async for``, so it must be an async iterable.
        """
        async for prompt in prompts:
            new_prompt = make_llama_3_prompt(**self.make_prompt(prompt.data))
            yield PromptObject(prompt=new_prompt, data=prompt.data)

    async def process_results(self, results):
        """
        Skip missing results/responses and yield one PromptObject per generated query that executes.
        """
        async for result in results:
            if result is None or result.response is None:
                continue

            logger.info("=====================================")
            logger.info(f"Generated query 1: {result.response['sql_query_1']}")
            logger.info(f"Generated query 2: {result.response['sql_query_2']}")
            logger.info("=====================================")

            for key in ("sql_query_1", "sql_query_2"):
                query = result.response[key]
                if self.check_sql_query(query):
                    # Deep copy so each emitted object gets its own data (and its own generated_sql_query).
                    new_result = PromptObject(prompt="", data=copy.deepcopy(result.data))
                    new_result.data.generated_sql_query = query
                    yield new_result

    def make_prompt(self, data):
        """
        Return {"system": ..., "user": ...} prompts that ask for two new queries in the style
        of the few-shot examples in ``data.sample``.
        """
        system = "You are a data analyst with 15 years of experience writing complex SQL queries.\n"
        system += (
            "Consider a table called 'songs' with the following schema (columns). \n"
        )
        # Newline after the schema, as the other stages do, so the next instruction
        # does not run on from the schema text.
        system += get_schema_altered() + "\n"
        system += "Only return the data required to answer the question in the SQL, don't include other columns or rows. \n"
        system += "Consider the following questions, and queries used to answer them:\n"
        for example in data.sample:
            system += "Question: " + example["question"] + "\n"
            system += "Query: " + example["sql"] + "\n"

        user = "Write two queries that are similar but different to those above.\n"
        user += "Format the queries as a JSON object, i.e.\n"
        user += '{ "explanation": str, "sql_query_1" : str, "sql_query_2": str }.\n'

        user += "First write an explanation of why you decided to write these new queries in about 3-5 sentences, then write valid sqlite SQL queries for each of the 2 new queries. Make sure each query is complete and ends with a ;\n"

        return {"system": system, "user": user}

    def check_sql_query(self, query):
        """
        Return True if ``query`` executes against the database without raising, else False.

        NOTE(review): this executes model-generated SQL as-is; it is only safe if the
        connection cannot modify the database (e.g. opened read-only).
        """
        try:
            pd.read_sql(query, con=engine)
        except Exception as e:
            logger.debug(f"Error in SQL query: {e}")
            return False

        logger.info(f"SQL query {query} is valid")

        return True


In [None]:
class QuestionStage(GenerationNode):
    """
    Writes a natural-language question for each generated SQL query.

    Expects ``obj.data`` to expose ``.sample`` (few-shot {"question", "sql"} dicts) and
    ``.generated_sql_query`` (the query to describe). The model answers with an
    explanation first, then the question (chain-of-thought style).
    """

    def __init__(self):
        super().__init__(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct",
            max_new_tokens=150,
        )

    def generate(
        self,
        prompt: Union[Iterator[PromptObject], AsyncIterator[PromptObject]],
        *args,
        **kwargs,
    ):
        """Run the model, constraining its output to an explanation and a question."""
        response_schema = {
            "explanation": "str",
            "question": "str",
        }
        return super().generate(
            prompt,
            output_type=response_schema,
            *args,
            **kwargs,
        )

    def preprocess(self, obj: PromptObject):
        """Build the Llama 3 formatted prompt for ``obj`` and store it on ``obj.prompt``."""
        obj.prompt = make_llama_3_prompt(**self.make_question_prompt(obj.data))

    def make_question_prompt(self, data):
        """
        Return {"system": ..., "user": ...} prompts asking which question ``data.generated_sql_query`` answers.

        The system prompt lists the schema and the few-shot query/question pairs. The user
        prompt asks for an explanation first, then a one-sentence question.
        """
        system_parts = [
            "You are a data analyst with 15 years of experience writing complex SQL queries.\n",
            "Consider a table called 'songs' with the following schema (columns)\n",
            get_schema(),
            "\n",
            "Queries, and questions that they are used to answer:\n",
        ]
        for example in data.sample:
            system_parts.extend(
                ["Query: ", example["sql"], "\n", "Question: ", example["question"], "\n"]
            )
        system = "".join(system_parts)

        user_parts = [
            "Now consider the following query.\n",
            "Query: ",
            data.generated_sql_query,
            "\n",
            "Write a question that this query could be used to answer.\n",
            "Format your response as a JSON object, i.e.\n",
            '{ "explanation": str, "question": str }.\n',
            "First write an explanation in about 3-5 sentences, then write a one sentence question.\n",
        ]
        user = "".join(user_parts)

        return {"system": system, "user": user}


In [None]:
class QueryGenPipeline(GenerationPipeline):
    """
    Two-stage pipeline that turns few-shot example samples into new (question, SQL) pairs.

    ModelStage proposes new SQL queries and keeps those that execute. QuestionStage then
    writes a natural-language question for each surviving query.
    """

    def __init__(self):
        super().__init__()
        self.model_stage = ModelStage()
        self.question_stage = QuestionStage()

    def forward(self, x):
        """Feed ``x`` through the model stage, then pipe that stream into the question stage."""
        return self.question_stage(self.model_stage(x))

async def run_query_gen_pipeline(dataset_queries):
    """
    Build a QueryGenPipeline and start it on ``dataset_queries``.

    Returns the pipeline's result stream without consuming it; the caller
    (e.g. save_generation_results) iterates it.
    """
    pipeline = QueryGenPipeline()
    return pipeline.call(dataset_queries)

In [None]:
#Generate N samples, for every example in the dataset

# Original user-written examples; filled in by load_queries and later appended to the
# training file by save_generation_results.
all_examples = []

async def load_queries(args, sample_size=3):
    """
    Read the user-created examples and yield ``args.num_to_generate`` random few-shot samples.

    Args:
        args: Args instance; uses ``file_name`` (under data/) and ``num_to_generate``.
        sample_size: number of examples per sample (clamped to the number available).

    Yields:
        PromptObject with an empty prompt and an ExampleSample (sample, round index) as data.

    Raises:
        ValueError: if the example file contains no examples.
    """
    global all_examples

    path = f"data/{args.file_name}"  # i.e. original user-created set of queries

    with jsonlines.open(path) as reader:
        all_examples = list(reader)

    if not all_examples:
        raise ValueError(f"No example queries found in {path}")

    # A private generator seeded with 42 draws exactly the same samples as seeding the
    # module-level `random`, without resetting global RNG state for the rest of the notebook.
    rng = random.Random(42)
    # random.sample raises if k exceeds the population; clamp so small example files still work.
    k = min(sample_size, len(all_examples))

    for i in range(args.num_to_generate):  # number of examples to generate per round
        example_sample = ExampleSample(rng.sample(all_examples, k), i)
        yield PromptObject(prompt="", data=example_sample)


class ExampleSample:
    """
    One few-shot sample of user-created {"question", "sql"} examples, with its round index.

    Later stages attach extra attributes (e.g. ``generated_sql_query``) to instances,
    so this is deliberately a plain class with an instance ``__dict__``.
    """

    def __init__(self, sample, index):
        self.index = index  # generation round this sample belongs to
        self.sample = sample  # list of example dicts shown to the model

In [None]:
async def save_generation_results(results, args):
    """
    Write generated (question, sql) pairs, followed by the original examples, to a JSONL file.

    The output path is data/training_data/<args.training_file_name>, and it is overwritten.
    The directory is created if missing.

    Args:
        results: async iterable of PromptObjects from the query-generation pipeline; each
            needs ``response["question"]`` and ``data.generated_sql_query``.
        args: Args instance providing ``training_file_name``.
    """
    path = f"data/training_data/{args.training_file_name}"
    # Create the output directory on first use instead of failing inside jsonlines.open.
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # Both are context managers, so the file and the progress bar are closed even on error.
    with jsonlines.open(path, "w") as writer, tqdm(desc="Saving results", unit=" results") as pbar:
        async for result in results:
            writer.write(
                {
                    "question": result.response["question"],
                    "sql": result.data.generated_sql_query,
                }
            )
            pbar.update()

        # Append the hand-written examples (the module-level list populated by load_queries)
        # so the training file contains both generated and original pairs.
        for example in all_examples:
            writer.write(example)
            pbar.update()

In [None]:
# Cell: create extra fine-tuning data (new SQL queries plus matching questions) with the LLM.

args = Args()
dataset_queries = load_queries(args)  # async generator: args.num_to_generate (default 10) few-shot samples
results = await run_query_gen_pipeline(dataset_queries)  # top-level await: notebook-only
await save_generation_results(results, args)  # writes data/training_data/<args.training_file_name>

In [None]:
#Set-up for fine tuning using the extra LLM-generated data

def make_question(obj):
    """
    Build the system/user prompts for one fine-tuning example.

    Args:
        obj: a training record with at least a "question" key (e.g. a row of the
            generated training-data file).

    Returns:
        dict with "system" and "user" prompt strings. Presumably get_dataset uses it to
        format each training example; confirm in util/load_dataset.py.
    """
    system = "You are a data analyst with 15 years of experience writing complex SQL queries.\n"
    # The table is called 'songs' in every other prompt in this notebook; the fine-tuning
    # prompt previously said 'song'.
    # NOTE(review): this prompt still differs slightly from QueryStage.make_prompt (e.g. the
    # semicolon instruction); consider aligning training and evaluation prompts.
    system += "Consider the 'songs' table with the following schema:\n"
    system += get_schema_altered() + "\n"
    system += "Only return the data required to answer the question in the SQL, don't include other columns or rows. \n"
    system += (
        "Write a sqlite SQL query that would help you answer the following question:\n"
    )
    user = obj["question"]
    return {"system": system, "user": user}

args = Args()
llm = lamini.Lamini(model_name="meta-llama/Meta-Llama-3-8B-Instruct")  # base model to fine-tune

# NOTE(review): presumably reads the training file named by args and formats each row with
# make_question — confirm against util/load_dataset.get_dataset.
dataset = get_dataset(args, make_question)

finetune_args = get_default_finetune_args()  # defaults from util/get_default_finetune_args


In [None]:
# Fine-tuning (currently disabled).
# The llm.train(...) call is wrapped in a bare string literal, so running this cell does
# NOT start a training job. Move the call out of the string to train on `dataset` with
# `finetune_args` from the previous cell.

"""
llm.train(
    data_or_dataset_id=dataset,
    finetune_args=finetune_args,
    is_public=False,  # For sharing
)
"""

In [None]:
# Evaluate the fine-tuned model against the original user-created dataset.
# NOTE(review): "place_holder" must be replaced with the fine-tuned model's id before running.

args = Args(sql_model_name="place_holder")
dataset = load_dataset(args)
results = await run_eval(dataset, args)  # top-level await: notebook-only
save_eval_results(results, args)