In [34]:
# Importing the necessary libraries:
import os
import pandas as pd
import numpy as np
import transformers
import langchain
import pymysql
from langchain import HuggingFaceHub, SQLDatabase, PromptTemplate, FewShotPromptTemplate
from langchain.chains import SQLDatabaseChain

# Setting the environment for the Hugging face API:
os.environ["HUGGINGFACEHUB_API_TOKEN"] = "........"

In [37]:
# Prompt Attempt:
model = "google/flan-t5-xxl" #'Sandiago21/llama-7b-hf-prompt-answering'

llm = HuggingFaceHub(repo_id=model, model_kwargs={"temperature":0.6, "max_length":512})

prompt = "Alice has a parrot. What animal is Alice's pet?"
print(llm(prompt))

parrot


In [38]:
sql_uri = "mysql+pymysql://root:pass@localhost:3306/wine_store"

In [39]:
db = SQLDatabase.from_uri(sql_uri)

In [40]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [41]:
PROMPT = """ 
Given an input question, first create a syntactically correct MySQL query to run,  
then look at the results of the query and return the answer.  
The question: {question}
"""

In [42]:
question = "what is the average price from the table purchasepricesdec?" 

db_chain.run(PROMPT.format(question=question))



[1m> Entering new  chain...[0m
 
Given an input question, first create a syntactically correct MySQL query to run,  
then look at the results of the query and return the answer.  
The question: what is the average price from the table purchasepricesdec?

SQLQuery:[32;1m[1;3mSELECT avg(Price) FROM purchasepricesdec[0m
SQLResult: [33;1m[1;3m[(38.65188805482593,)][0m
Answer:[32;1m[1;3m38.65188805482593[0m
[1m> Finished chain.[0m


'38.65188805482593'

In [43]:
# Few Shot:
examples = [
  {"question": "What is the average of the PurchasePrice in the table purchasepricesdec?",
    "answer": "26.48821956"},
  {"question": "How many unique values do we have in the Classification column in the table purchasepricesdec?",
    "answer": "2"},
  {"question": "How many rows do we have in the table purchasepricesdec?",
    "answer":"12262"},
  {"question": "What is the mode of the Volume column in purchasepricesdec table?",
    "answer":"750"}
]

example_prompt = PromptTemplate(input_variables=["question", "answer"], template="Question: {question}\n{answer}")

prompt = FewShotPromptTemplate(
    examples=examples, 
    example_prompt=example_prompt, 
    suffix="Question: {input}", 
    input_variables=["input"]
)

db_chain.run(prompt.format(input="What is the mean of the Volume column in the purchasepricesdec table?"))



[1m> Entering new  chain...[0m
Question: What is the average of the PurchasePrice in the table purchasepricesdec?
26.48821956

Question: How many unique values do we have in the Classification column in the table purchasepricesdec?
2

Question: How many rows do we have in the table purchasepricesdec?
12262

Question: What is the mode of the Volume column in purchasepricesdec table?
750

Question: What is the mean of the Volume column in the purchasepricesdec table?
SQLQuery:[32;1m[1;3mSELECT avg(volume) FROM purchasepricesdec[0m
SQLResult: [33;1m[1;3m[(Decimal('842.9215'),)][0m
Answer:[32;1m[1;3m842.9215[0m
[1m> Finished chain.[0m


'842.9215'