# Part 5: Run inference and integrate
## Architecture
To use the fine-tuned SQL model for question answering, there are three main steps:
1. Convert user question to SQL query
2. Execute the SQL query against the database
3. Use the SQL output to answer the user question

We'll go over these three steps below.

```python
import os
import sys
from dotenv import load_dotenv

# Make the kit directory and the repository root importable from this notebook
current_dir = os.getcwd()
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))

sys.path.append(kit_dir)
sys.path.append(repo_dir)
```

## Deploy your model in an endpoint

1. Log in to your SambaStudio environment.
2. Select the LLM checkpoint you want to use and deploy an endpoint for inference. See the [SambaStudio endpoint documentation](https://docs.sambanova.ai/sambastudio/latest/endpoints.html).
3. Get your endpoint information from the deployed endpoint UI.

## Load the SQL fine-tuned model

Assume you have an endpoint with the URL `https://api-stage.sambanova.net/api/predict/generic/12345678-9abc-def0-1234-56789abcdef0/456789ab-cdef-0123-4567-89abcdef0123` and the API key `89abcdef-0123-4567-89ab-cdef01234567`.
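In this example the endpoint URL is just the four settings used below joined with `/`: base URL, base URI, project ID and endpoint ID. The hypothetical helper below (`split_endpoint_url` is not part of SambaStudio or LangChain) shows the split, assuming the path always has the form `<base_uri>/<project_id>/<endpoint_id>`:

```python
from urllib.parse import urlsplit


def split_endpoint_url(url: str, base_uri: str = "api/predict/generic") -> dict:
    """Split an endpoint URL into base URL, base URI, project ID and endpoint ID.

    Hypothetical helper: assumes the path is exactly <base_uri>/<project_id>/<endpoint_id>.
    """
    parts = urlsplit(url)
    base_url = f"{parts.scheme}://{parts.netloc}"
    path = parts.path.strip("/")
    prefix = base_uri.strip("/") + "/"
    if not path.startswith(prefix):
        raise ValueError(f"Expected path to start with {prefix!r}, got {path!r}")
    # Unpacking fails loudly if the remainder is not exactly two segments
    project_id, endpoint_id = path[len(prefix):].split("/")
    return {
        "base_url": base_url,
        "base_uri": base_uri,
        "project_id": project_id,
        "endpoint_id": endpoint_id,
    }


url = ("https://api-stage.sambanova.net/api/predict/generic/"
       "12345678-9abc-def0-1234-56789abcdef0/456789ab-cdef-0123-4567-89abcdef0123")
print(split_endpoint_url(url))
```

If your endpoint URL follows a different layout, adjust the assumed `base_uri` accordingly.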

Provide your endpoint information:

```python
base_url = "https://api-stage.sambanova.net"
base_uri = "api/predict/generic"
project_id = "12345678-9abc-def0-1234-56789abcdef0"
endpoint_id = "456789ab-cdef-0123-4567-89abcdef0123"
api_key = "89abcdef-0123-4567-89ab-cdef01234567"

# Alternatively, load the endpoint settings from a .env file in the repository root:
# load_dotenv(os.path.join(repo_dir, ".env"))
```

You can also set the endpoint information as environment variables with the following names:

```bash
SAMBASTUDIO_BASE_URL="https://api-stage.sambanova.net"
SAMBASTUDIO_BASE_URI="api/predict/generic"
SAMBASTUDIO_PROJECT_ID="12345678-9abc-def0-1234-56789abcdef0"
SAMBASTUDIO_ENDPOINT_ID="456789ab-cdef-0123-4567-89abcdef0123"
SAMBASTUDIO_API_KEY="89abcdef-0123-4567-89ab-cdef01234567"
```
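If you go the environment-variable route, a small check that all five variables are present avoids a confusing failure later. A minimal sketch, assuming the variables are already in the process environment (for example after `load_dotenv(os.path.join(repo_dir, ".env"))`); `endpoint_config_from_env` and `ENV_VARS` are hypothetical names:

```python
import os

# Hypothetical mapping from setting name to environment variable name
ENV_VARS = {
    "base_url": "SAMBASTUDIO_BASE_URL",
    "base_uri": "SAMBASTUDIO_BASE_URI",
    "project_id": "SAMBASTUDIO_PROJECT_ID",
    "endpoint_id": "SAMBASTUDIO_ENDPOINT_ID",
    "api_key": "SAMBASTUDIO_API_KEY",
}


def endpoint_config_from_env() -> dict:
    """Return the endpoint settings, raising KeyError if any variable is unset or empty."""
    missing = [name for name in ENV_VARS.values() if not os.environ.get(name)]
    if missing:
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")
    return {key: os.environ[name] for key, name in ENV_VARS.items()}
```

The returned values can be passed to the `SambaStudio(...)` arguments in place of the hard-coded strings.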


```python
from langchain_community.llms.sambanova import SambaStudio

# LLM wrapper around the deployed SambaStudio endpoint
sql_llm = SambaStudio(
    sambastudio_base_url=base_url,
    sambastudio_base_uri=base_uri,
    sambastudio_project_id=project_id,
    sambastudio_endpoint_id=endpoint_id,
    sambastudio_api_key=api_key,
    model_kwargs={
        "do_sample": True,
        "temperature": 0.01,  # very low temperature: near-greedy, consistent SQL
        "max_tokens_to_generate": 512,  # upper bound on generated tokens
    },
)
```

## Sample database

We use the sample [Chinook database](https://database.guide/2-sample-databases-sqlite/). The database file is located in the [databases folder](./databases/).

```python
from langchain_community.utilities import SQLDatabase

db_path = "databases/Chinook.db"
db_uri = f"sqlite:///{db_path}"
db = SQLDatabase.from_uri(db_uri)
print(db.get_usable_table_names())
db.run("SELECT * FROM Genre;")
```

```
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
```

```
"[(1, 'Rock'), (2, 'Jazz'), (3, 'Metal'), (4, 'Alternative & Punk'), (5, 'Rock And Roll'), (6, 'Blues'), (7, 'Latin'), (8, 'Reggae'), (9, 'Pop'), (10, 'Soundtrack'), (11, 'Bossa Nova'), (12, 'Easy Listening'), (13, 'Heavy Metal'), (14, 'R&B/Soul'), (15, 'Electronica/Dance'), (16, 'World'), (17, 'Hip Hop/Rap'), (18, 'Science Fiction'), (19, 'TV Shows'), (20, 'Sci Fi & Fantasy'), (21, 'Drama'), (22, 'Comedy'), (23, 'Alternative'), (24, 'Classical'), (25, 'Opera')]"
```

## Step 1: Convert the user question to a SQL query
This step uses the fine-tuned SQL model.

### Chain and utils for query generation

#### Query parsing

````python
import re


def sql_finder(text: str) -> str:
    """Extract a SQL query from an LLM generation.

    Looks first for a code block tagged as sql, then falls back to an untagged code block.
    """
    # Pattern for a query fenced with a language tag:
    # ```sql
    #    <query>
    # ```
    sql_code_pattern = re.compile(r"```sql\s+(.*?)\s+```", re.DOTALL)
    match = sql_code_pattern.search(text)
    if match is not None:
        return match.group(1)

    # Fallback pattern for a query fenced without a language tag:
    # ```
    # <query>
    # ```
    code_pattern = re.compile(r"```\s+(.*?)\s+```", re.DOTALL)
    match = code_pattern.search(text)
    if match is not None:
        return match.group(1)

    raise ValueError("No SQL code found in LLM generation")
````
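The two-step search in `sql_finder` can also be written with a single pattern that makes the `sql` tag optional. The sketch below (`extract_sql` and `FENCED_CODE` are hypothetical names) is an alternative, not a drop-in replacement: it returns the first fenced block of either kind, whereas `sql_finder` prefers a block tagged `sql` when both kinds appear.

```python
import re

# Optional "sql" tag after the opening fence; DOTALL lets the query span lines.
FENCED_CODE = re.compile(r"```(?:sql)?\s+(.*?)\s+```", re.DOTALL)


def extract_sql(text: str) -> str:
    """Return the contents of the first fenced code block in `text`."""
    match = FENCED_CODE.search(text)
    if match is None:
        raise ValueError("No SQL code found in LLM generation")
    return match.group(1)


print(extract_sql("Here you go:\n```sql\nSELECT COUNT(*)\n    FROM Genre;\n```"))
```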

#### Prompt

````python
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda

# Prompt to use with the fine-tuned model
prompt = PromptTemplate.from_template(
    """[INST]<<SYS>>
    Generate a query using valid SQLite to answer the following questions for the summarized tables schemas provided bellow.
    Do not assume the values on the database tables before generating the SQL query, always generate a SQL that query what is asked. 
    The query must be in the format: ```sql\nquery\n```
    
    Example:
    
    ```sql
    SELECT * FROM mainTable;
    ```
    <</SYS>>
        
    {table_info}
        
    {question}
    [/INST]"""
)
````

Below is the fully formatted prompt for a test query. A few points to notice:
- It lists all the tables in the database; compare with the output of `db.get_usable_table_names()` above.
- The prompt template is filled with the database schema as context (a `CREATE TABLE` statement and three sample rows per table), followed by the natural-language question or instruction from the user.
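To see roughly where this context comes from, the sketch below rebuilds a similar schema description with Python's built-in `sqlite3` module: each table's stored `CREATE TABLE` statement followed by a few sample rows. `describe_tables` is a hypothetical helper that only mimics the shape of the `db.get_table_info()` output printed below; it is not LangChain's implementation, and the one-table in-memory database is a stand-in for Chinook.

```python
import sqlite3


def describe_tables(conn: sqlite3.Connection, sample_rows: int = 3) -> str:
    """Approximate a schema summary: CREATE statement plus sample rows per table."""
    blocks = []
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    for name, create_sql in tables:
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        columns = [col[0] for col in cursor.description]
        rows = cursor.fetchall()
        # Tab-separated header and rows inside a SQL comment, like the output below
        lines = [f"{sample_rows} rows from {name} table:", "\t".join(columns)]
        lines += ["\t".join(str(value) for value in row) for row in rows]
        blocks.append(f"{create_sql}\n\n/*\n" + "\n".join(lines) + "\n*/")
    return "\n\n".join(blocks)


# Stand-in database with a single small table
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Genre" ("GenreId" INTEGER NOT NULL, "Name" NVARCHAR(120), PRIMARY KEY ("GenreId"))')
conn.executemany('INSERT INTO "Genre" VALUES (?, ?)',
                 [(1, "Rock"), (2, "Jazz"), (3, "Metal"), (4, "Alternative & Punk")])
print(describe_tables(conn))
```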


```python
print(prompt.format(question="how many music genres are in the db?", table_info=db.get_table_info()))
```

````
[INST]<<SYS>>
    Generate a query using valid SQLite to answer the following questions for the summarized tables schemas provided bellow.
    Do not assume the values on the database tables before generating the SQL query, always generate a SQL that query what is asked. 
    The query must be in the format: ```sql
query
```
    
    Example:
    
    ```sql
    SELECT * FROM mainTable;
    ```
    <</SYS>>
        
    
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" IN
````

#### Query generation chain

```python
from langchain_core.runnables import RunnablePassthrough

# Schema context for the prompt; defined before the chain that reads it
table_info = db.get_table_info()

# Chain: add the schema to the input, format the prompt, call the fine-tuned LLM, parse out the SQL query
sql_generation_chain = RunnablePassthrough.assign(table_info=lambda x: table_info) | prompt | sql_llm | RunnableLambda(sql_finder)
response = sql_generation_chain.invoke({"question": "how many music genres are in the db?"})
print(response)
```

```sql
SELECT COUNT(*)
    FROM Genre;
```
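`RunnablePassthrough.assign(table_info=...)` passes the input dict through and adds a `table_info` key, so the prompt receives both `question` and `table_info`; the formatted prompt then goes to `sql_llm`, and the raw generation to `sql_finder`. A plain-Python model of the `assign` step, for intuition only (`assign` and `add_schema` here are hypothetical functions, not the LangChain API):

```python
from typing import Any, Callable, Dict


def assign(**computed: Callable[[Dict[str, Any]], Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Return a step that keeps every input key and adds one key per keyword."""
    def step(inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Each new value is computed from the original inputs; the input dict is not mutated
        return {**inputs, **{key: fn(inputs) for key, fn in computed.items()}}
    return step


add_schema = assign(table_info=lambda _: 'CREATE TABLE "Genre" (...)')
print(add_schema({"question": "how many music genres are in the db?"}))
```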


## Step 2: Execute the SQL query against the database

```python
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Tool that runs a SQL query against `db` and returns the result as a string
execute_query = QuerySQLDataBaseTool(db=db)
execution_chain = sql_generation_chain | execute_query
# execution_chain.invoke({"question": "How many employees are there"})
execution_chain.invoke({"question": "How many genres of music are there"})
```

```
'[(25,)]'
```
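The execution step returns the rows as a string, which is why the output above is `'[(25,)]'` rather than a Python list. A stdlib stand-in for that step, assuming a `sqlite3` connection (`run_sql` is a hypothetical helper; returning errors as text is a design choice of this sketch so that a later step can still see what went wrong):

```python
import sqlite3


def run_sql(conn: sqlite3.Connection, query: str) -> str:
    """Execute one query and return its rows as a string, or the error as text."""
    try:
        return str(conn.execute(query).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"


# Toy table standing in for Chinook's Genre table
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Genre" ("GenreId" INTEGER PRIMARY KEY, "Name" TEXT)')
conn.executemany('INSERT INTO "Genre" VALUES (?, ?)', [(1, "Rock"), (2, "Jazz"), (3, "Metal")])
print(run_sql(conn, "SELECT COUNT(*)\n    FROM Genre;"))  # → [(3,)]
```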

## Step 3: Use the SQL output to answer the user question

```python
from langchain_core.output_parsers import StrOutputParser
from operator import itemgetter

# General answer prompt template:
#   question: the user's input
#   query: the SQL generated by the fine-tuned model
#   result: the output of executing that query
answer_prompt = PromptTemplate.from_template(
    """
    [INST]<<SYS>>
    Given the following user question, corresponding SQL query, and SQL result, answer the user question.
    <</SYS>>
    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    [/INST]
    Answer: """
)

# Answer generation chain: answer prompt, LLM invocation, and string output parsing
answer_chain = answer_prompt | sql_llm | StrOutputParser()

# Final chain: query generation, query execution, and final answer generation
chain = (
    RunnablePassthrough.assign(query=sql_generation_chain).assign(
        result=itemgetter("query") | execute_query
    )
    | answer_chain
)

chain.invoke({"question": "How many genres of music are there"})
```

```
'25'
```
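Put together, the final chain grows a single dict: it starts with `question`, the first `assign` adds `query`, the second adds `result`, and `answer_chain` consumes all three. The sketch below traces that flow in plain Python, with both model calls replaced by stand-in functions so it runs without an endpoint; `answer_question`, the toy database, and the lambdas are hypothetical and for illustration only.

```python
import ast
import sqlite3
from typing import Callable, Dict


def answer_question(
    question: str,
    generate_sql: Callable[[str], str],             # stands in for sql_generation_chain
    run_sql: Callable[[str], str],                  # stands in for execute_query
    write_answer: Callable[[Dict[str, str]], str],  # stands in for answer_chain
) -> str:
    state = {"question": question}
    state["query"] = generate_sql(state["question"])  # Step 1: question -> SQL
    state["result"] = run_sql(state["query"])         # Step 2: SQL -> result string
    return write_answer(state)                        # Step 3: all three -> answer


# Toy database and fake "model" outputs, for illustration only
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT)")
conn.executemany("INSERT INTO Genre VALUES (?, ?)", [(1, "Rock"), (2, "Jazz")])

answer = answer_question(
    "How many genres of music are there",
    generate_sql=lambda question: "SELECT COUNT(*) FROM Genre;",
    run_sql=lambda query: str(conn.execute(query).fetchall()),
    write_answer=lambda state: str(ast.literal_eval(state["result"])[0][0]),
)
print(answer)  # → 2
```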