# Create the Model
[Documentation](https://github.com/google/generative-ai-docs/blob/main/examples/gemini/python/langchain/Gemini_LangChain_Summarization_WebLoad.ipynb)

## Setup Environment

In [4]:
# Imports used in this notebook. The prompt and output-parser classes are
# imported from langchain_core further down, where they are first used.
import getpass
import os

from langchain.chains import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI


if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")



In [5]:
from langchain_groq import ChatGroq


if "GROQ_API_KEY" not in os.environ:
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your Groq API key: ")

if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter your Google AI API key: ")


## Connect to DB

In [6]:
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from dotenv import load_dotenv

load_dotenv()

DB_USER = os.environ.get("DB_USER")
DB_PASS = os.environ.get("DB_PASS")

In [7]:
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{DB_USER}:{DB_PASS}@localhost/yelp_db",
    sample_rows_in_table_info=3,
)

In [5]:
print(db.table_info)


CREATE TABLE market_reviews (
	business_id VARCHAR(255) NOT NULL, 
	name VARCHAR(255), 
	address VARCHAR(255), 
	city VARCHAR(100), 
	state VARCHAR(50), 
	postal_code VARCHAR(20), 
	latitude DECIMAL(10, 7), 
	longitude DECIMAL(10, 7), 
	stars DECIMAL(2, 1), 
	review_count INTEGER, 
	is_open TINYINT(1), 
	attributes TEXT, 
	categories TEXT, 
	hours TEXT, 
	review_id VARCHAR(255) NOT NULL, 
	user_id VARCHAR(255), 
	review_stars DECIMAL(2, 1), 
	useful INTEGER, 
	funny INTEGER, 
	cool INTEGER, 
	text TEXT, 
	date DATETIME, 
	sentiment ENUM('positive','negative','neutral')
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4

/*
3 rows from market_reviews table:
business_id	name	address	city	state	postal_code	latitude	longitude	stars	review_count	is_open	attributes	categories	hours	review_id	user_id	review_stars	useful	funny	cool	text	date	sentiment
ytynqOUb3hjKeJfRj5Tshw	Reading Terminal Market	51 N 12th St	Philadelphia	PA	19107	39.9533415	-75.1588545	4.5	5721	1	{'Restaurant

In [29]:
# llm = ChatGoogleGenerativeAI(model='gemini-1.5-flash',temperature=0.8)
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# llm = ChatGroq(model="llama3-groq-70b-8192-tool-use-preview", temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke(
    {
        "question": "How many businesses were rated at least 4.5 stars between 2022-01-10 and 2022-05-10"
    }
)
# qns1 = db_chain("How many reviews in the market_reviews table rated a 5.0?")

In [30]:
response

"SELECT COUNT(DISTINCT business_id) AS rated_businesses\nFROM market_reviews\nWHERE stars >= 4.5\nAND date BETWEEN '2022-01-10' AND '2022-05-10';"

In [31]:
db.run(response)

'[(24,)]'
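
Because `db.run(response)` executes whatever text the model returned, it can help to check the query before running it. The following is a minimal sketch and not part of the original notebook: `is_read_only_select` is a hypothetical helper that accepts a single `SELECT` statement and rejects everything else. It is a coarse string check rather than a SQL parser, so it may reject some legitimate queries (for example, one with a semicolon inside a string literal).

```python
import re

# Hypothetical guard, not a LangChain API. Keywords that should never appear
# in a query meant only to read data.
_FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant)\b",
    re.IGNORECASE,
)


def is_read_only_select(sql: str) -> bool:
    """Return True if `sql` looks like a single read-only SELECT statement."""
    # Ignore surrounding whitespace and one trailing semicolon.
    statement = sql.strip().rstrip(";").strip()
    # Reject empty input and anything that still contains a semicolon,
    # which would indicate more than one statement.
    if not statement or ";" in statement:
        return False
    if not statement.lower().startswith("select"):
        return False
    return _FORBIDDEN.search(statement) is None
```

With a check like this in place, the call could be written as `db.run(response) if is_read_only_select(response) else None`. A database user with read-only privileges is a stronger safeguard than any string check.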

In [32]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

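# Step 1: turn the question into a SQL string.
write_query = create_sql_query_chain(llm, db)
# Step 2: run that SQL string against the database.
execute_query = QuerySQLDataBaseTool(db=db)

# Pipe the generated query straight into the execution tool.
chain = write_query | execute_query

chain.invoke(
    {
        "question": "How many businesses were rated at least 4.5 stars between 2022-01-10 and 2022-05-10"
    }
)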

'[(24,)]'

In [33]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

In [34]:
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

In [None]:
def get_schema(_):
    return db.get_table_info()

In [35]:
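# Turn the question, query, and raw result into a natural-language answer.
answer = answer_prompt | llm | StrOutputParser()
chain = (
    # Add the schema and the generated SQL to the input dict. The answer
    # prompt only uses question, query, and result; schema is carried along.
    RunnablePassthrough.assign(schema=get_schema, query=write_query).assign(
        # Execute the generated SQL and store its output under "result".
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke(
    {
        "question": "How many businesses were rated at least 4.5 stars between 2022-01-10 and 2022-05-10"
    }
)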

'There were 24 businesses that were rated at least 4.5 stars between January 10, 2022, and May 10, 2022.'

### SQL Agent

An agent provides a more flexible way of interacting with SQL databases. The main advantages of using the SQL Agent are:

* It can answer questions based on the database's schema as well as on its content (for example, describing a specific table).
* It can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly.
* It can answer questions that require multiple dependent queries.
* It saves tokens by considering only the schema of the relevant tables.
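
The error-recovery behaviour in the second bullet can be sketched as a plain retry loop. This is an illustration only, not how `create_sql_agent` is implemented: it uses the standard-library `sqlite3` module in place of MySQL, and `fake_generate_sql` is a hypothetical stand-in for the LLM call that returns a broken query first and a corrected one after it is shown an error message.

```python
import sqlite3


def fake_generate_sql(question, error=None):
    """Hypothetical stand-in for an LLM call (not a real API)."""
    if error is None:
        # First attempt: misspelled table name, so execution will fail.
        return "SELECT COUNT(*) FROM market_review"
    # Later attempts: the "model" has seen the error and fixes the name.
    return "SELECT COUNT(*) FROM market_reviews"


def run_with_retry(conn, question, generate_sql, max_attempts=3):
    """Generate SQL, run it, and feed any database error back to the generator."""
    error = None
    for _ in range(max_attempts):
        sql = generate_sql(question, error)
        try:
            return sql, conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            # Keep the error text so the next attempt can correct the query.
            error = str(exc)
    raise RuntimeError(f"No valid query after {max_attempts} attempts: {error}")


# Tiny in-memory table so the example runs without a database server.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE market_reviews (business_id TEXT, stars REAL)")
conn.executemany(
    "INSERT INTO market_reviews VALUES (?, ?)", [("a", 4.5), ("b", 3.0)]
)

sql, rows = run_with_retry(conn, "How many rows are there?", fake_generate_sql)
print(sql, rows)
```

The cap on attempts matters: without it, a model that keeps producing invalid SQL would loop indefinitely.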

In [36]:
from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

In [None]:
agent_executor.invoke(
    {
        "input": "How many businesses were rated at least 4.5 stars between 2022-01-10 and 2022-05-10"
    }
)

In [39]:
db.run("SELECT DISTINCT city FROM market_reviews")

"[('Philadelphia',), ('King of Prussia',), ('Conshohocken',)]"

# Few Shot Learning
Few-shot learning is a technique in which a model handles a new task after seeing only a few examples. Here, we give the LLM a small set of questions paired with their correct SQL queries, so that it can generalize to similar, unseen questions. This is useful when examples are scarce or costly to obtain.

| Question | SQL query |
|----------|-----------|
| 1. How many restaurants are categorized as both "Fast Food" and "Healthy" with at least a 4-star rating? | `SELECT COUNT(*) AS fast_food_healthy FROM market_reviews WHERE stars >= 4 AND categories LIKE '%Fast Food%' AND categories LIKE '%Healthy%';` |
| 2. What are the top three most frequently occurring categories among restaurants rated above 4.5 stars? | `SELECT SUBSTRING_INDEX(SUBSTRING_INDEX(categories, ',', n.n), ',', -1) AS category, COUNT(*) AS category_count FROM market_reviews, (SELECT 1 AS n UNION SELECT 2 UNION SELECT 3 UNION SELECT 4 UNION SELECT 5) n WHERE stars > 4.5 GROUP BY category ORDER BY category_count DESC LIMIT 3;` |
| 3. Which categories have restaurants that are open, rated above 3 stars, and allow dogs? | `SELECT DISTINCT TRIM(SUBSTRING_INDEX(categories, ',', -1)) AS category FROM market_reviews WHERE is_open = 1 AND stars > 3 AND attributes LIKE '%DogsAllowed%';` |
| 4. How many restaurants fall under the category "Restaurants" and have an average rating of 4 stars or higher? | `SELECT COUNT(*) AS restaurants_count FROM market_reviews WHERE categories LIKE '%Restaurants%' AND stars >= 4;` |
| 5. List all unique categories that have restaurants open for lunch and dinner, and are rated at least 4 stars. | `SELECT DISTINCT TRIM(SUBSTRING_INDEX(categories, ',', -1)) AS unique_categories FROM market_reviews WHERE stars >= 4 AND attributes LIKE '%GoodForMeal%';` |
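
One way to use question/SQL pairs like these is to place them in the prompt ahead of the new question. The sketch below does this with plain string formatting so that it runs without extra dependencies. `build_few_shot_prompt`, the instruction wording, and the final test question are hypothetical and not taken from the original notebook; the two examples are rows 3 and 4 of the table. LangChain also offers `FewShotPromptTemplate` in `langchain_core.prompts` for the same purpose.

```python
# Two of the question/SQL pairs from the table, used as in-prompt examples.
EXAMPLES = [
    {
        "question": "Which categories have restaurants that are open, "
        "rated above 3 stars, and allow dogs?",
        "query": "SELECT DISTINCT TRIM(SUBSTRING_INDEX(categories, ',', -1)) "
        "AS category FROM market_reviews WHERE is_open = 1 AND stars > 3 "
        "AND attributes LIKE '%DogsAllowed%';",
    },
    {
        "question": 'How many restaurants fall under the category "Restaurants" '
        "and have an average rating of 4 stars or higher?",
        "query": "SELECT COUNT(*) AS restaurants_count FROM market_reviews "
        "WHERE categories LIKE '%Restaurants%' AND stars >= 4;",
    },
]


def build_few_shot_prompt(examples, question):
    """Hypothetical helper: list the examples, then leave the last query blank."""
    parts = ["Write a MySQL query that answers the question.", ""]
    for example in examples:
        parts.append(f"Question: {example['question']}")
        parts.append(f"SQL Query: {example['query']}")
        parts.append("")  # blank line between examples
    # The new question ends with an open "SQL Query:" for the model to complete.
    parts.append(f"Question: {question}")
    parts.append("SQL Query:")
    return "\n".join(parts)


prompt = build_few_shot_prompt(
    EXAMPLES, "How many open businesses are rated at least 4 stars?"
)
print(prompt)
```

The examples use the same `Question:` / `SQL Query:` labels as the answer prompt earlier in the notebook, which keeps the format the model sees consistent.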

In [41]:
db.run("""
SELECT
    SUBSTRING_INDEX(SUBSTRING_INDEX(categories, ',', n.n), ',', -1) AS category,
    COUNT(*) AS category_count
FROM market_reviews,
    (SELECT 1 AS n UNION SELECT 2 UNION SELECT 3 UNION SELECT 4 UNION SELECT 5) n
WHERE stars > 4.5
GROUP BY category
ORDER BY category_count DESC
LIMIT 3;
""")

''
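
The query above returned an empty result. When a comma-separated column such as `categories` is awkward to split in SQL, an alternative is to fetch the column and count the values in Python. A minimal sketch, assuming the column values are already available as a list of strings; `top_categories` is a hypothetical helper and the sample values are made up for illustration, not taken from the database.

```python
from collections import Counter


def top_categories(category_strings, k=3):
    """Count comma-separated category names and return the k most common."""
    counts = Counter()
    for raw in category_strings:
        if not raw:
            continue  # skip NULL or empty values
        # Use a set so a category repeated within one row is counted once.
        names = {part.strip() for part in raw.split(",") if part.strip()}
        counts.update(names)
    return counts.most_common(k)


# Made-up sample values standing in for the `categories` column.
sample = [
    "Restaurants, Bakeries, Food",
    "Food, Restaurants",
    "Restaurants, Farmers Market",
    None,
]
print(top_categories(sample, k=2))
```

Unlike the fixed `1..5` index table in the SQL version, this handles any number of categories per row and trims the whitespace after each comma.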