In [1]:
import os
from urllib.parse import quote

from dotenv import load_dotenv
from langchain_google_genai import GoogleGenerativeAI
from sqlalchemy.exc import SQLAlchemyError
# Load variables from a local .env file into the process environment (one call is enough).
load_dotenv()
api_key = os.getenv("GOOGLE_API_KEY")

# Gemini text model client; the key read above is passed explicitly.
llm = GoogleGenerativeAI(model="gemini-pro", google_api_key=api_key)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity-check the LLM connection with a free-form prompt.
ml_overview = llm.invoke("write few lines on Machine Learning")
print(ml_overview)


Machine learning, a subset of artificial intelligence (AI), empowers computers to learn from data without explicit programming. It involves algorithms that analyze patterns and relationships in data to make predictions or decisions. Machine learning models are trained on historical data and can improve their performance over time as they encounter new data. Types of machine learning include supervised learning, where data is labeled with known outcomes; unsupervised learning, where data is unlabeled; and reinforcement learning, where models learn through trial and error. Machine learning finds applications in various fields, including finance, healthcare, manufacturing, and customer service.


#### Connect to the database and ask some basic questions

In [3]:

#from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

from langchain.utilities import SQLDatabase



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from langchain.chains.base import Chain


In [6]:
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from dotenv import load_dotenv
import os
from langchain_community.utilities import SQLDatabase

# Load environment variables
load_dotenv()

# Retrieve credentials from the .env file
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
db_host = os.getenv("DB_HOST")
db_name = os.getenv("DB_NAME")

# Fail fast with a clear message instead of building a URI that contains the literal text "None".
missing_settings = [
    setting
    for setting, value in (
        ("DB_USER", db_user),
        ("DB_PASSWORD", db_password),
        ("DB_HOST", db_host),
        ("DB_NAME", db_name),
    )
    if value is None
]
if missing_settings:
    raise RuntimeError(
        f"Missing database settings in the environment/.env: {', '.join(missing_settings)}"
    )

# Percent-encode user and password: characters such as '@', ':', '/' or '#' in a
# password would otherwise be parsed as URI delimiters and corrupt the connection string.
db_uri = (
    f"mysql+pymysql://{quote(db_user, safe='')}:{quote(db_password, safe='')}"
    f"@{db_host}/{db_name}"
)
db = SQLDatabase.from_uri(db_uri, sample_rows_in_table_info=3)

# Show the reflected schema plus sample rows for each table.
print(db.table_info)



CREATE TABLE sales_tb (
	`TransactionID` INTEGER, 
	`Date` DATE, 
	`CustomerID` VARCHAR(10), 
	`Gender` VARCHAR(10), 
	`Age` INTEGER, 
	`ProductCategory` VARCHAR(50), 
	`Quantity` INTEGER, 
	`PriceperUnit` DECIMAL(10, 2), 
	`TotalAmount` DECIMAL(10, 2)
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4

/*
3 rows from sales_tb table:
TransactionID	Date	CustomerID	Gender	Age	ProductCategory	Quantity	PriceperUnit	TotalAmount
1	2023-11-24	CUST001	Male	34	Beauty	3	50.00	150.00
2	2023-02-27	CUST002	Female	26	Clothing	2	500.00	1000.00
3	2023-01-13	CUST003	Male	50	Electronics	1	30.00	30.00
*/


In [8]:
#Convert question to SQL query
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm=llm, db=db)
# The chain takes its input under the "question" key and returns the model's SQL answer as a string.
sql_request = {"question": "How many customers are there"}
response = chain.invoke(sql_request)
response

'```sql\nSELECT COUNT(DISTINCT `CustomerID`) AS `Number of Customers`\nFROM sales_tb;\n```'

In [9]:
# Remove the Markdown code fence the model wraps around its SQL (```sql ... ```).
# str.strip('```sql\n') is NOT a prefix/suffix remover: it strips any run of the
# characters `, s, q, l and newline from both ends, so it would also eat the "s" of a
# lowercase "select" or a trailing backtick-quoted identifier. Only the fence is removed here.
cleaned_query = response.strip()
if cleaned_query.startswith("```"):
    # Discard the opening fence line, including an optional language tag such as "sql".
    fence_line, newline, fence_body = cleaned_query[3:].partition("\n")
    cleaned_query = fence_body if newline else fence_line
if cleaned_query.endswith("```"):
    cleaned_query = cleaned_query[:-3]
cleaned_query = cleaned_query.strip()
print(cleaned_query)

SELECT COUNT(DISTINCT `CustomerID`) AS `Number of Customers`
FROM sales_tb;


In [10]:
# Run the de-fenced SQL against the database and show the raw result.
result = db.run(command=cleaned_query)
print(result)

[(10,)]


In [11]:
# Rebuild the same SQL-generation chain as the earlier cell, for use by execute_query.
# NOTE(review): the generated queries below all end in LIMIT 5, presumably the chain's
# default row cap (k); pass k=... to create_sql_query_chain if full results are needed.
chain = create_sql_query_chain(llm=llm, db=db)
def strip_sql_fence(text):
    """Return ``text`` with a surrounding Markdown code fence removed.

    The SQL chain returns its query wrapped as ```sql ... ```. ``str.strip`` with the
    fence as its argument is not a prefix/suffix remover: it strips any run of the
    characters backtick, ``s``, ``q``, ``l`` and newline from both ends. That corrupts
    queries starting with a lowercase ``select`` or ending in a backtick-quoted
    identifier. This helper removes only the fence itself.

    Args:
        text: Raw model output, fenced or not.

    Returns:
        The bare SQL with surrounding whitespace trimmed.
    """
    sql = text.strip()
    if sql.startswith("```"):
        # Drop the opening fence line together with any language tag ("sql").
        fence_line, newline, fence_body = sql[3:].partition("\n")
        sql = fence_body if newline else fence_line
    if sql.endswith("```"):
        sql = sql[:-3]
    return sql.strip()


def execute_query(question):
    """Generate SQL for ``question`` with the LLM chain, run it, and print each stage.

    Relies on the notebook-level ``chain`` (built with ``create_sql_query_chain``) and
    ``db`` (a LangChain ``SQLDatabase``). Prints the raw model output, the de-fenced
    SQL, and the query result, separated by rule lines.

    Args:
        question: Natural-language question about the data.

    Returns:
        None. A ``SQLAlchemyError`` raised while running the generated SQL is printed
        instead of raised, so one bad query does not halt the notebook. Errors from
        the LLM call itself still propagate.
    """
    rule = "###################################################"
    # Generate SQL query from question
    response = chain.invoke({"question": question})
    print(response)
    print(rule)
    cleaned_query = strip_sql_fence(response)
    print(cleaned_query)
    print(rule)
    try:
        # SQLAlchemyError covers ProgrammingError (e.g. syntax errors) as well as the
        # other database errors, such as OperationalError, that LLM-written SQL can trigger.
        result = db.run(cleaned_query)
    except SQLAlchemyError as e:
        print(f"An error occurred: {e}")
        return
    print(rule)
    print(result)


In [12]:
# Q1: distinct customers per product category
q1 = "How many unique customers are there for each product category"
execute_query(question=q1)

```sql
SELECT 
  `ProductCategory`,
  COUNT(DISTINCT `CustomerID`) AS `UniqueCustomers`
FROM sales_tb
GROUP BY `ProductCategory`
ORDER BY `UniqueCustomers` DESC
LIMIT 5;
```
###################################################
SELECT 
  `ProductCategory`,
  COUNT(DISTINCT `CustomerID`) AS `UniqueCustomers`
FROM sales_tb
GROUP BY `ProductCategory`
ORDER BY `UniqueCustomers` DESC
LIMIT 5;
###################################################
###################################################
[('Clothing', 4), ('Beauty', 3), ('Electronics', 3)]


In [13]:
# Q2: total sales amount per product category
q2 = "Calculate total sales amount per product category:"
execute_query(question=q2)

```sql
SELECT 
  `ProductCategory`, 
  SUM(`TotalAmount`) AS `TotalSales`
FROM 
  sales_tb
GROUP BY 
  `ProductCategory`
ORDER BY 
  `TotalSales` DESC
LIMIT 
  5;
```
###################################################
SELECT 
  `ProductCategory`, 
  SUM(`TotalAmount`) AS `TotalSales`
FROM 
  sales_tb
GROUP BY 
  `ProductCategory`
ORDER BY 
  `TotalSales` DESC
LIMIT 
  5;
###################################################
###################################################
[('Clothing', Decimal('1750.00')), ('Electronics', Decimal('730.00')), ('Beauty', Decimal('280.00'))]


In [14]:
# Q3: average customer age by gender
q3 = "calculates the average age of customers grouped by gender."
execute_query(question=q3)

```sql
SELECT 
    `Gender`, 
    AVG(`Age`) AS `AverageAge`
FROM 
    `sales_tb`
GROUP BY 
    `Gender`
ORDER BY 
    `Gender`
LIMIT 
    5;
```
###################################################
SELECT 
    `Gender`, 
    AVG(`Age`) AS `AverageAge`
FROM 
    `sales_tb`
GROUP BY 
    `Gender`
ORDER BY 
    `Gender`
LIMIT 
    5;
###################################################
###################################################
[('Female', Decimal('41.0000')), ('Male', Decimal('41.4286'))]


In [15]:
# Q4: customers ranked by total amount spent
q4 = "identify the top spending customers based on their total amount spent."
execute_query(question=q4)

```sql
SELECT 
  `CustomerID`,
  SUM(`TotalAmount`) AS `TotalSpent`
FROM sales_tb
GROUP BY `CustomerID`
ORDER BY `TotalSpent` DESC
LIMIT 5
```
###################################################
SELECT 
  `CustomerID`,
  SUM(`TotalAmount`) AS `TotalSpent`
FROM sales_tb
GROUP BY `CustomerID`
ORDER BY `TotalSpent` DESC
LIMIT 5
###################################################
###################################################
[('CUST002', Decimal('1000.00')), ('CUST009', Decimal('600.00')), ('CUST004', Decimal('500.00')), ('CUST010', Decimal('200.00')), ('CUST001', Decimal('150.00'))]


In [16]:
# Q5: number of transactions per month
q5 = "counts the number of transactions made each month."
execute_query(question=q5)

```sql
SELECT
  SUBSTRING(`Date`, 1, 7) AS `Month`,
  COUNT(*) AS `Number of Transactions`
FROM sales_tb
GROUP BY `Month`
ORDER BY `Month` DESC
LIMIT 5;
```
###################################################
SELECT
  SUBSTRING(`Date`, 1, 7) AS `Month`,
  COUNT(*) AS `Number of Transactions`
FROM sales_tb
GROUP BY `Month`
ORDER BY `Month` DESC
LIMIT 5;
###################################################
###################################################
[('2023-12', 1), ('2023-11', 1), ('2023-10', 1), ('2023-05', 2), ('2023-04', 1)]


In [17]:
# Q6: total sales and average unit price per product category
q6 = "calculates the total sales amount and average price per unit for each product category."
execute_query(question=q6)

```sql
SELECT 
    `ProductCategory`, 
    SUM(`TotalAmount`) AS `TotalSalesAmount`, 
    AVG(`PriceperUnit`) AS `AveragePricePerUnit`
FROM 
    `sales_tb`
GROUP BY 
    `ProductCategory`
ORDER BY 
    `TotalSalesAmount` DESC
LIMIT 
    5;
```
###################################################
SELECT 
    `ProductCategory`, 
    SUM(`TotalAmount`) AS `TotalSalesAmount`, 
    AVG(`PriceperUnit`) AS `AveragePricePerUnit`
FROM 
    `sales_tb`
GROUP BY 
    `ProductCategory`
ORDER BY 
    `TotalSalesAmount` DESC
LIMIT 
    5;
###################################################
###################################################
[('Clothing', Decimal('1750.00'), Decimal('268.750000')), ('Electronics', Decimal('730.00'), Decimal('118.333333')), ('Beauty', Decimal('280.00'), Decimal('43.333333'))]
