In [23]:
from sqlalchemy import inspect
from sqlalchemy import create_engine
from sqlalchemy import text
import pandas as pd

In [24]:

# Database Path
#db_file_path = r"C:\Users\vivia\co-pilot-v1\data\databases\partswise_island_moto.db"
db_file_path = r"C:\Users\vivia\co-pilot-v1\data\databases\parts_database.db"
engine = create_engine(f"sqlite:///{db_file_path}")

In [25]:
from json import loads, dumps

# getting parts info: part number, model, price, quantity, brand, description, and year
with engine.connect() as connection:
    query = text("""SELECT p.part_number, m.model_name, p.price, p.quantity, p.brand, p.description, y.year
                    FROM model_year my
                    JOIN models m ON my.model_id=m.id
                    JOIN parts p ON my.part_number=p.part_number
                    JOIN years y ON my.year_id=y.id
                 """)
    parts_df = connection.execute(query)
    
parts_df = pd.DataFrame(parts_df.fetchall(), columns=['partNumber', 'modelName', 'price', 'quantity', 'brand', 'description', 'year'])
parts_df.to_csv(r'C:\Users\vivia\co-pilot-v1\Notebooks\parts.csv', index=False)

In [7]:
import os
import openai

os.environ["OPENAI_API_KEY"] = "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"
openai.api_key = os.environ["OPENAI_API_KEY"]


In [47]:
import sqlite3
import csv

csv_path = r'C:\Users\vivia\co-pilot-v1\Notebooks\parts.csv'

def csv_to_sql(csv_path):
    db_path = r'C:\Users\vivia\co-pilot-v1\Notebooks\bulk_upload.db'
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()

    with open(csv_path) as f:
        reader = csv.reader(f)
        data = list(reader)

        rows_dict = {
            'partNumber': 'VARCHAR',
            'modelName': 'VARCHAR',
            'price': 'INT',
            'quantity': 'INT',
            'brand': 'VARCHAR',
            'description': 'VARCHAR',
            'year': 'INT'
        }

    cur.execute('''DROP TABLE IF EXISTS parts''')

    cur.execute('''CREATE TABLE parts (
                partNumber VARCHAR,
                modelName VARCHAR,
                price INT,
                quantity INT,
                brand VARCHAR,
                description VARCHAR,
                year INT
                )''' )
    for row in data:
        cur.execute("INSERT INTO parts (partNumber, modelName, price, quantity, brand, description, year) values (?, ?, ?, ?, ?, ?, ?)", row)

    conn.commit()
    conn.close()

    bulk_upload_db = r'C:\Users\vivia\co-pilot-v1\Notebooks\bulk_upload.db'
    engine = create_engine(f"sqlite:///{bulk_upload_db}")

    return engine

In [48]:
from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.llms.openai import OpenAI

engine = csv_to_sql(csv_path)

def setup_nlsql_query_engine():
    def initialize_table_objects():
        sql_database = SQLDatabase(engine, sample_rows_in_table_info=2, include_tables=['parts'])

        parts_context = "Provides detailed inventory data for individual parts, including the model name, price, quantity, brand, year, and a description of the part."

        table_context_str = "The Table description is: " + parts_context

        table_node_mapping = SQLTableNodeMapping(sql_database)
        table_schema_objs = [
            SQLTableSchema(table_name='parts', context_str=parts_context),
        ]
        obj_index = ObjectIndex.from_objects(
            table_schema_objs,
            table_node_mapping,
            VectorStoreIndex,
        )
        return sql_database, table_schema_objs, obj_index, table_context_str

    sql_database, table_schema_objs, obj_index, table_context_str = initialize_table_objects()

    context_str = (
        "Convert percentages to decimals (e.g., '50%' as '0.5')."
        "Use LOWER for searching brand or modelName because of case sensitivity."
        "Pay close attention to filtering criteria mentioned in the question and incorporate them using the WHERE clause in your SQL query."
        "If the question involves multiple conditions, use logical operators such as AND, OR to combine them effectively."
        "If the question involves grouping of data (e.g., finding totals or averages for different categories), use the GROUP BY clause along with appropriate aggregate functions."
        "Consider using aliases for tables and columns to improve readability of the query, especially in case of complex joins or subqueries."
        "If necessary, use subqueries or common tables expressions (CTEs) to break down the problem into smaller, more manageable parts."
        "Ensure detailed, relevant responses."

    ) 
    context_str_combined = context_str + "\n\n" + table_context_str

    query_engine = SQLTableRetrieverQueryEngine(
        sql_database=sql_database,
        table_retriever=obj_index.as_retriever(similarity_top_k=1),
        synthesize_response=True,
        llm=OpenAI(temperature=0.1, model="gpt-3.5-turbo-0125"),
        context_str_prefix=context_str_combined
    )

    return query_engine

query_engine = setup_nlsql_query_engine()

def process_user_input_to_sql(user_input):
    response = query_engine.query(user_input)
    print(f"RESPONSE: {response.metadata}")
    sql_query = response.metadata.get('sql_query', '').replace('\n', ' ').replace('\r', ' ').strip()
    print(f"SQL QUERY after adjustment: {sql_query}")
    if sql_query.startswith('sql'):
        sql_query = sql_query[3:].strip()
    print(f"SQL: {sql_query}")
    return sql_query

def query_output(user_input):
    sql_query = process_user_input_to_sql(user_input)
    print(f"SQL QUERY Output: {sql_query}")

    with engine.connect() as connection:
        result = connection.execute(text(sql_query))
        result_data = result.fetchall()
        if len(result_data) >= 5:
            result_df = pd.DataFrame(result_data, columns=result.keys())
            return result_df
        else:
            response = query_engine.query(sql_query)
            return str(response)

user_input = "Get all bmw parts from 2003"
response = query_output(user_input)
print(response)

RESPONSE: {'24a65226-929b-40ab-b888-4197c3e26dd1': {'sql_query': "SELECT modelName, price, quantity, brand, year\nFROM parts\nWHERE LOWER(brand) = 'bmw' AND year = 2003\nORDER BY price DESC;", 'result': [(' K1200 R Sport 4Kx 4Cyl.', 1556.5, 1, 'Bmw', 2003), (' K1200 R 4Kx 4Cyl.', 1556.5, 1, 'Bmw', 2003), ('K1200 S 4Kx 4Cyl.', 1556.5, 1, 'Bmw', 2003), ('  R1200 05 St', 891.0800170898438, 1, 'Bmw', 2003), ('R1200 Rt 05', 891.0800170898438, 1, 'Bmw', 2003), (' R900 Rt 10 Sf', 891.0800170898438, 1, 'Bmw', 2003), (' R900Rt 05 Sf', 891.0800170898438, 1, 'Bmw', 2003), ('R 1200 Rt', 400.9800109863281, 1, 'Bmw', 2003), ('K1200 S', 194, 1, 'Bmw', 2003), (' K1300S', 194, 1, 'Bmw', 2003)], 'col_keys': ['modelName', 'price', 'quantity', 'brand', 'year']}, 'sql_query': "SELECT modelName, price, quantity, brand, year\nFROM parts\nWHERE LOWER(brand) = 'bmw' AND year = 2003\nORDER BY price DESC;", 'result': [(' K1200 R Sport 4Kx 4Cyl.', 1556.5, 1, 'Bmw', 2003), (' K1200 R 4Kx 4Cyl.', 1556.5, 1, 'Bmw', 