## Experiments with db_metadata

In [1]:
from sqlalchemy import create_engine,text
import pandas as pd
import json
import numpy as np

### Pydantic Imports

In [2]:
from pydantic import BaseModel,Field
from typing import Dict,List,Literal, Optional
from pydantic import BaseModel,Field
from typing import Dict,List,Literal, Optional,Any

### Langchain

In [3]:
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate,PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain.schema import HumanMessage, SystemMessage

In [4]:
from dotenv import load_dotenv
# Load variables from the local `.env` file into the process environment.
# NOTE(review): presumably this supplies the OpenAI API key used by the LangChain
# clients below — confirm. The cell echoes the return value (True in the saved output).
load_dotenv(".env")

True

### Vector DB and RAG

In [5]:
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
# Embedding model shared by the chart-hint and knowledge-glossary FAISS indexes built later.
embedding = OpenAIEmbeddings(model="text-embedding-ada-002")


### Pandas 

In [6]:
import pandas as pd

## Experiment

In [7]:
import pandas as pd
from sqlalchemy import create_engine, text

def get_all_tables_metadata(engine, schema: str = "test_db_tables"):
    """Return the column catalogue (table, column, data type) of one MySQL schema.

    Args:
        engine: SQLAlchemy engine or connection accepted by ``pd.read_sql``.
        schema: Schema (database) to describe. Defaults to ``"test_db_tables"``,
            the value this function used to hard-code.

    Returns:
        DataFrame with columns ``TABLE_NAME``, ``COLUMN_NAME`` and ``DATA_TYPE``,
        ordered by table and by column position within each table.
    """
    query = text(
        """
        SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
        FROM information_schema.columns
        WHERE table_schema = :schema
        ORDER BY TABLE_NAME, ORDINAL_POSITION
        """
    )
    # The schema name is sent as a bound parameter instead of being spliced into the SQL.
    return pd.read_sql(query, engine, params={"schema": schema})

In [8]:

import os
import warnings

# Connection settings are read from the environment (load_dotenv(".env") runs earlier),
# so the database password is not stored in the notebook. Non-secret defaults are the
# values this cell used before.
username = os.getenv("DB_USERNAME", "user_demo")
password = os.getenv("DB_PASSWORD", "")
host = os.getenv("DB_HOST", "chat-csv.c9pssgmnnguu.us-west-2.rds.amazonaws.com")
port = int(os.getenv("DB_PORT", "3306"))
database = os.getenv("DB_NAME", "test_db_tables")
if not password:
    warnings.warn("DB_PASSWORD is not set; add it to .env before connecting to MySQL.")

In [9]:
from urllib.parse import quote_plus

# Percent-encode the credentials so characters such as '@', ':' or '/' in the
# password cannot corrupt the connection URL.
engine = create_engine(
    f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/{database}"
)

In [10]:
# Column catalogue of the schema: one row per (table, column, data type).
# NOTE: `metadata` is overwritten with a list of strings further down (In [22]).
metadata=get_all_tables_metadata(engine)
metadata

Unnamed: 0,TABLE_NAME,COLUMN_NAME,DATA_TYPE
0,auto_insurance,Customer,text
1,auto_insurance,State,text
2,auto_insurance,Customer Lifetime Value,double
3,auto_insurance,Response,text
4,auto_insurance,Coverage,text
...,...,...,...
714,sri_dataset,d5_g,double
715,sri_dataset,d6,text
716,sri_dataset,d7,text
717,unemployment_level,DATE,datetime


In [11]:
# First-pass LLM verdict: can the question be charted, which tables/columns are needed,
# and which chart category applies. Documented with comments rather than a docstring
# because a docstring would be added to the JSON schema shown to the LLM.
class ChartClass(BaseModel):
    thought: str = Field(...,description = "Thought Process behind decision for deciding possibility of a visualization for the given query on this dataset ")
    is_chart_possible: Literal["Yes", "No","rephrase"] = Field(
        ...,
        description="Indicates if a chart is possible based on the query. Must be either 'Yes', 'No' or 'rephrase', use rephrase if some extra info can help in deciding whether to visualize.",
    )
    # Optional: only meaningful when a chart is possible (default None).
    tables_columns: Optional[Dict[str, List[str]]] = Field(
        default=None,
        description="Dictionary of tables and their corresponding columns required for the visualization based on the question."
    )
    chart_type: str = Field(
        ...,
        description="Type of chart suitable for this visualization (e.g., line, multiline, bar, pie).",
    )
    # The values mirror the keys MySQLPandasDF.generate_resultant_df dispatches on;
    # presumably this is what gets passed to it as metric_type — confirm in the caller.
    type: Optional[Literal[
        "time_series", "group_aggregates", "combination", "multilevel_categorical","relational","distribution",""
    ]] = Field(
        default=None,
        description="Category of the visualization; one of time_series, group_aggregates, combination, multilevel_categorical, relational or distribution.",
    )
    is_derived: bool = Field(...,description = "whether this metric could be derived using some mathematical operations between 2 columns ")
   

# Explanation for a rejected chart request plus chartable follow-up questions.
# NOTE(review): presumably used when ChartClass.is_chart_possible == "No" — confirm.
class RejectClass(BaseModel):
    reason:str = Field(...,description = "Reason for rejecting possibility of visualization; it should be crisp and polite")
    suggested_alternatives: List[str] = Field(...,description = "list of alternative questions which a user could ask on this dataset which could give a visualization")


In [12]:
# Output schema of the query-rephrasing step: the reasoning and the rewritten question.
class RephraseClass(BaseModel):
    # Both fields are required: Field() without a default marks them mandatory.
    thought: str = Field(
        description="Thought process used while rephrasing the query, always take cue from the database metadata given"
    )
    rephrased_query: str = Field(
        description="Rephrased query on which a visualization is plausible based on the given chart type"
    )

In [13]:
# Axis / grouping / aggregation settings for one chart. Its model_dump() is the
# `chart_config` dict handed to MySQLPandasDF.
# NOTE: MySQLPandasDF also reads an optional "formula" key that this model does not define.
class ChartConfig(BaseModel):
    x_axis: str = Field(
        ...,
        description="Column of the dataset to be on the x-axis for the given visualization. It should be a column present in the dataset.",
    )
    y_axis: Any = Field(default = None, description="Column of the dataset to be on the y-axis. It will be None in case of no y-axis.")
    binning: Optional[str] = Field(
        None,
        description="Binning for time series or combinations, either 'yearly' or 'monthly'. Only fill in case of time series or combinations.",
    )
    heading: str = Field(..., description="The heading for the visualization.")
    group_by: Optional[str] = Field(
        None,
        description="Column of the dataset on which grouping is to be done. Leave blank for time series. Ensure it is a column present in the dataset.",
    )
    # The description lists exactly the values the Literal accepts, so the LLM is not
    # told a narrower set than the schema allows.
    operation: Literal["mean", "max", "sum","count","none",None,""] = Field(
        default=None,
        description="Aggregation to be inferred from the question: 'mean', 'max', 'sum', 'count' or 'none'.",
    )
    chart_type: str = Field(
        ...,
        description="The chart type to be used for visualization (e.g., bar, line, multiline).",
    )

## LLM helper: render a prompt file and call the model

In [14]:
def get_llm_response(prompt_file_path:str,pydantic_class=None,llm=None,input_variables:dict=None):
    """Render a prompt template from disk, send it to ``llm`` and return the reply.

    Args:
        prompt_file_path: Path to a UTF-8 text file holding a ``PromptTemplate``
            template with ``{name}`` placeholders.
        pydantic_class: Optional Pydantic model. When given, ``{format_instructions}``
            in the template is filled with the parser's instructions and the reply is
            parsed into an instance of this class.
        llm: Runnable chat model piped after the prompt (e.g. the ``ChatOpenAI`` instance).
        input_variables: Values for the template placeholders. ``None`` means no variables.

    Returns:
        A ``pydantic_class`` instance when ``pydantic_class`` is given, otherwise the raw
        model output (for a chat model, a message whose ``.content`` holds the text).
    """
    # A None default previously crashed on `.keys()`; treat it as "no variables".
    input_variables = dict(input_variables or {})
    with open(prompt_file_path, "r", encoding="utf-8") as f:
        template = f.read()

    prompt_kwargs = {"template": template, "input_variables": list(input_variables)}
    parser = None
    if pydantic_class:
        parser = PydanticOutputParser(pydantic_object=pydantic_class)
        prompt_kwargs["partial_variables"] = {"format_instructions": parser.get_format_instructions()}

    chain = PromptTemplate(**prompt_kwargs) | llm
    if parser is not None:
        chain = chain | parser
    return chain.invoke(input_variables)
    

In [15]:
# Shared chat model. response_format json_object asks the API for a JSON object reply,
# which suits the PydanticOutputParser used in get_llm_response.
# NOTE(review): OpenAI JSON mode requires the word "JSON" to appear in the prompt —
# presumably the prompt files / format instructions satisfy this; confirm.
llm = ChatOpenAI(model="gpt-4o",model_kwargs={"response_format": {"type": "json_object"}})


## Build chart DataFrames from MySQL

In [16]:
class MySQLPandasDF:
    """Build chart-ready pandas DataFrames from aggregate queries on one MySQL table.

    ``chart_config`` is a dict like ``ChartConfig.model_dump()`` (``x_axis``, ``y_axis``,
    ``binning``, ``operation``, ``group_by``, ...). An optional ``formula`` key (a SQL
    expression) replaces ``operation(y_axis)`` when present.
    ``generate_resultant_df(metric_type)`` dispatches on the chart category.

    Every identifier is backtick-quoted and every filter value is escaped before being
    interpolated into SQL, because all of them originate from LLM output.
    ``connection`` must support ``execute(text(...))`` and be accepted by ``pd.read_sql``
    (e.g. a SQLAlchemy ``Connection``).
    """

    # DATA_TYPE values treated as integer x-axes (e.g. a year column): grouped by value.
    _INTEGER_TYPES = frozenset({"int", "integer", "bigint", "smallint", "mediumint", "tinyint"})
    # DATA_TYPE values whose contents are parsed with DATE() before time bucketing.
    _TEXT_TYPES = frozenset({"text", "tinytext", "mediumtext", "longtext", "varchar", "char"})
    # DATA_TYPE values that get a histogram in distribution charts.
    _NUMERIC_TYPES = _INTEGER_TYPES | {"double", "float", "decimal", "numeric", "real"}
    # strftime pattern used to label each supported time bucket.
    _DATE_FORMATS = {"yearly": "%Y", "monthly": "%Y-%m", "daily": "%Y-%m-%d"}
    # Chart-config operation -> whitelisted SQL aggregate. Anything else falls back to AVG.
    _SQL_AGGREGATES = {"mean": "AVG", "avg": "AVG", "max": "MAX", "min": "MIN", "sum": "SUM", "count": "COUNT"}
    # Lower bound applied to the Freedman-Diaconis histogram bin width.
    _MIN_BIN_WIDTH = 2

    def __init__(self, chart_config, connection, table, db_type="mysql", filter_columns=None,derived = None):
        """
        Store the chart configuration, DB connection, target table, optional filters
        (see ``_apply_filters``) and optional derived-metric spec (see ``_get_derived_result``).
        """
        self.chart_config = chart_config
        self.connection = connection
        self.table = table
        self.db_type = db_type
        self.filter_columns = filter_columns
        self.derived = derived
        # Binning name -> pandas period alias, used by _get_derived_result.
        self.freq_dict = {
            "business_day": "B",
            "calendar_day": "D",
            "weekly": "W",
            "monthly": "M",
            "quarterly": "Q",
            "yearly": "A",
            "hourly": "H",
            "minutely": "T",
            "secondly": "S",
            "milliseconds": "L",
            "microseconds": "U",
            "nanoseconds": "N",
            "daily": "D",
        }

    # ------------------------------------------------------------------ SQL helpers

    @staticmethod
    def _quote_identifier(name):
        """Backtick-quote a MySQL identifier (one surrounding pair of backticks is stripped first)."""
        name = str(name)
        if len(name) >= 2 and name[0] == "`" and name[-1] == "`":
            name = name[1:-1]
        return "`" + name.replace("`", "``") + "`"

    @staticmethod
    def _quote_literal(value):
        """Render a filter value as a MySQL literal: ints/floats bare, everything else quoted."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return str(value)
        # Assumes the server runs without NO_BACKSLASH_ESCAPES (the MySQL default).
        escaped = str(value).replace("\\", "\\\\").replace("'", "''")
        # "\:" stops SQLAlchemy's text() from reading ":word" as a bind parameter;
        # text() strips the backslash again (and MySQL would read "\:" as ":" anyway).
        escaped = escaped.replace(":", "\\:")
        return f"'{escaped}'"

    def _apply_filters(self):
        """Translate ``self.filter_columns`` into a ``WHERE ...`` clause ('' when nothing applies).

        Shape read here: ``{"numerical_columns": {col: {"start_value", "end_value"}},
        "date_columns": {col: {"start_date", "end_date"}},
        "categorical_columns": {col: [values]}}``. Any section or bound may be missing/None.
        """
        filters = self.filter_columns or {}
        conditions = []

        for col, bounds in (filters.get("numerical_columns") or {}).items():
            bounds = bounds or {}
            column = self._quote_identifier(col)
            if bounds.get("start_value") is not None:
                conditions.append(f"{column}>={self._quote_literal(bounds['start_value'])}")
            if bounds.get("end_value") is not None:
                conditions.append(f"{column}<={self._quote_literal(bounds['end_value'])}")

        for col, bounds in (filters.get("date_columns") or {}).items():
            bounds = bounds or {}
            column = self._quote_identifier(col)
            # Dates are always sent as quoted strings, as before.
            if bounds.get("start_date") is not None:
                conditions.append(f"DATE({column})>=DATE({self._quote_literal(str(bounds['start_date']))})")
            if bounds.get("end_date") is not None:
                conditions.append(f"DATE({column})<=DATE({self._quote_literal(str(bounds['end_date']))})")

        for col, values in (filters.get("categorical_columns") or {}).items():
            if values:
                in_list = ", ".join(self._quote_literal(str(val)) for val in values)
                conditions.append(f"{self._quote_identifier(col)} IN ({in_list})")

        return ("WHERE " + " AND ".join(conditions)) if conditions else ""

    def _where_clause(self):
        """Filter clause prefixed with a space (ready to append after FROM), or ''."""
        clause = self._apply_filters() if self.filter_columns else ""
        return f" {clause}" if clause else ""

    def _metric_expression(self, y_axis, operation, formula=""):
        """SQL for the plotted metric: ``(formula)`` if given, else ``operation(y_axis)``.

        With no ``y_axis`` the argument becomes ``*`` (meaningful for COUNT).
        """
        if formula:
            # Free-form LLM SQL; double quotes are treated as identifier quotes, as before.
            return f"({formula})".replace('"', "`")
        target = self._quote_identifier(y_axis) if y_axis else "*"
        return f"{operation}({target})"

    @staticmethod
    def _bucket_expression(date_expr, binning):
        """GROUP BY expression(s) that bucket ``date_expr`` by year, month or day."""
        if binning == "yearly":
            return f"YEAR({date_expr})"
        if binning == "daily":
            return f"DATE({date_expr})"
        return f"YEAR({date_expr}), MONTH({date_expr})"

    def _normalize_binning(self, binning):
        """Lower-case ``binning``; unknown or empty values fall back to 'monthly'."""
        binning = (binning or "monthly").lower()
        return binning if binning in self._DATE_FORMATS else "monthly"

    @classmethod
    def _format_bucket(cls, value, binning):
        """Render a bucket's representative date as its label; non-dates (e.g. NULL) pass through."""
        fmt = cls._DATE_FORMATS.get(binning, cls._DATE_FORMATS["monthly"])
        return value.strftime(fmt) if hasattr(value, "strftime") else value

    @classmethod
    def _histogram_frame(cls, values):
        """Histogram ``values`` with Freedman-Diaconis bins (width >= _MIN_BIN_WIDTH, >= 1 bin).

        Non-numeric entries and NULLs are dropped. Returns columns
        Bin_Start, Bin_End, Frequency, Midpoint (empty frame for no data).
        """
        data = pd.to_numeric(pd.Series(values), errors="coerce").dropna().to_numpy(dtype=float)
        if data.size == 0:
            return pd.DataFrame(columns=["Bin_Start", "Bin_End", "Frequency", "Midpoint"])
        q1, q3 = np.percentile(data, [25, 75])
        bin_width = max(cls._MIN_BIN_WIDTH, 2 * (q3 - q1) / data.size ** (1 / 3))
        # At least one bin: a narrow range used to give bins=0 and np.histogram raised.
        num_bins = max(1, int((data.max() - data.min()) / bin_width))
        freq, edges = np.histogram(data, bins=num_bins)
        return pd.DataFrame({
            "Bin_Start": edges[:-1],
            "Bin_End": edges[1:],
            "Frequency": freq,
            "Midpoint": 0.5 * (edges[:-1] + edges[1:]),
        })

    def _get_column_data_types(self, x_axis, y_axis = None):
        """
        Look up COLUMN_NAME / DATA_TYPE rows for the given column(s) of ``self.table``.

        Restricted to the connection's current database (DATABASE()) so a same-named
        table in another schema cannot shadow it; names are bound parameters.
        """
        names = [x_axis]
        if isinstance(y_axis, str) and y_axis:
            names.append(y_axis)
        placeholders = ", ".join(f":col{i}" for i in range(len(names)))
        params = {"table": self.table, **{f"col{i}": name for i, name in enumerate(names)}}
        query = (
            "SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS "
            "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table "
            f"AND COLUMN_NAME IN ({placeholders})"
        )
        return pd.read_sql_query(text(query), self.connection, params=params)

    def _column_data_type(self, column):
        """Lower-cased MySQL DATA_TYPE of ``column``; ValueError if the column is unknown."""
        types = self._get_column_data_types(column)
        match = types.loc[types["COLUMN_NAME"] == column, "DATA_TYPE"]
        if match.empty:
            raise ValueError(f"Column {column!r} not found in table {self.table!r}.")
        return str(match.iloc[0]).lower()

    # ------------------------------------------------------------- chart builders

    def _generate_timeseries_resultant_df(self):
        """
        Aggregate the metric per time bucket along ``x_axis``.

        Integer x-axes are grouped by value. Date/datetime columns, and text columns
        parsed with DATE(), are bucketed by ``binning`` (default monthly) and labelled
        with the bucket string (e.g. "2023-05").
        """
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        x_type = self._column_data_type(x_axis)
        x_col = self._quote_identifier(x_axis)
        y_label = y_axis or "value"
        y_col = self._quote_identifier(y_label)
        metric = self._metric_expression(y_axis, operation, formula)
        source = f"FROM {self._quote_identifier(self.table)}{self._where_clause()}"

        if x_type in self._INTEGER_TYPES:
            sql_query = (
                f"SELECT {x_col} AS {x_col}, {metric} AS {y_col} {source} "
                f"GROUP BY {x_col} ORDER BY {x_col} ASC"
            )
            return pd.read_sql_query(text(sql_query), self.connection)

        binning = self._normalize_binning(binning)
        date_expr = f"DATE({x_col})" if x_type in self._TEXT_TYPES else x_col
        bucket = self._bucket_expression(date_expr, binning)
        # MIN() yields one representative date per bucket, keeping the query valid under
        # ONLY_FULL_GROUP_BY; it is reduced to the bucket label below.
        sql_query = (
            f"SELECT MIN({date_expr}) AS {x_col}, {metric} AS {y_col} {source} "
            f"GROUP BY {bucket} ORDER BY {bucket}"
        )
        rows = self.connection.execute(text(sql_query)).fetchall()
        records = [{x_axis: self._format_bucket(row[0], binning), y_label: row[1]} for row in rows]
        return pd.DataFrame(records, columns=[x_axis, y_label])

    def _generate_groupaggregate_resultant_df(self):
        """
        Aggregate the metric for each distinct value of ``x_axis``.
        """
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        x_col = self._quote_identifier(x_axis)
        y_col = self._quote_identifier(y_axis or "value")
        metric = self._metric_expression(y_axis, operation, formula)
        sql_query = (
            f"SELECT {x_col}, {metric} AS {y_col} "
            f"FROM {self._quote_identifier(self.table)}{self._where_clause()} GROUP BY {x_col}"
        )
        return pd.read_sql_query(text(sql_query), self.connection)

    def _generate_combination_resultant_df(self):
        """
        Time series of the metric per ``group_by`` category (one line per group).

        Same x-axis handling as the time-series builder; the result keeps the
        ``group_by`` column alongside the bucket label and the metric.
        """
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        x_type = self._column_data_type(x_axis)
        x_col = self._quote_identifier(x_axis)
        g_col = self._quote_identifier(group_by)
        y_label = y_axis or "value"
        y_col = self._quote_identifier(y_label)
        metric = self._metric_expression(y_axis, operation, formula)
        source = f"FROM {self._quote_identifier(self.table)}{self._where_clause()}"

        if x_type in self._INTEGER_TYPES:
            # Integer x-values have no strftime; return them as-is.
            sql_query = (
                f"SELECT {x_col} AS {x_col}, {metric} AS {y_col}, {g_col} {source} "
                f"GROUP BY {g_col}, {x_col} ORDER BY {x_col} ASC"
            )
            return pd.read_sql_query(text(sql_query), self.connection)

        binning = self._normalize_binning(binning)
        date_expr = f"DATE({x_col})" if x_type in self._TEXT_TYPES else x_col
        bucket = self._bucket_expression(date_expr, binning)
        sql_query = (
            f"SELECT MIN({date_expr}) AS {x_col}, {metric} AS {y_col}, {g_col} {source} "
            f"GROUP BY {g_col}, {bucket} ORDER BY {bucket}, {g_col}"
        )
        rows = self.connection.execute(text(sql_query)).fetchall()
        records = [
            {x_axis: self._format_bucket(row[0], binning), y_label: row[1], group_by: row[2]}
            for row in rows
        ]
        return pd.DataFrame(records, columns=[x_axis, y_label, group_by])

    def _generate_multilevelcategorical_resultant_df(self):
        """
        Aggregate the metric for every (``group_by``, ``x_axis``) pair.

        ``x_axis`` is grouped on its raw values: wrapping a categorical text column in
        DATE() would turn every label into NULL.
        """
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        x_col = self._quote_identifier(x_axis)
        g_col = self._quote_identifier(group_by)
        y_col = self._quote_identifier(y_axis or "value")
        metric = self._metric_expression(y_axis, operation, formula)
        sql_query = (
            f"SELECT {x_col} AS {x_col}, {metric} AS {y_col}, {g_col} "
            f"FROM {self._quote_identifier(self.table)}{self._where_clause()} "
            f"GROUP BY {g_col}, {x_col} ORDER BY {x_col} ASC"
        )
        return pd.read_sql(text(sql_query), self.connection)

    def _get_derived_result(self):
        """
        Aggregate an LLM-derived metric over ``x_axis`` and re-bin it by time period.

        ``self.derived`` is read as ``{"formula": <SQL expression>, "term": <label>}``.
        The formula is interpolated verbatim; it is LLM-generated SQL, so treat it as untrusted.
        With binning '' or 'monthly' the period is auto-picked from the date span
        (< 30 days -> daily, > 1460 days -> yearly, else monthly).
        """
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        derived = self.derived
        x_col = self._quote_identifier(x_axis)
        query = (
            f"SELECT {x_col}, {operation}({derived['formula']}) AS {self._quote_identifier(derived['term'])} "
            f"FROM {self._quote_identifier(self.table)}{self._where_clause()} "
            f"GROUP BY {x_col} ORDER BY {x_col} ASC"
        )
        df = pd.read_sql(text(query), self.connection)
        if df.empty:
            return df
        x_name, term_name = df.columns[0], df.columns[1]
        # to_datetime also covers DATE columns, which arrive as object dtype.
        dates = pd.to_datetime(df[x_name])

        binning = (binning or "").lower()
        if binning in ("", "monthly"):
            span_days = (dates.max().normalize() - dates.min().normalize()).days
            print(f"date range is: {span_days}")
            if span_days < 30:
                binning = "daily"
            elif span_days > 1460:
                binning = "yearly"
            else:
                binning = "monthly"

        df[x_name] = dates.dt.to_period(self.freq_dict.get(binning, self.freq_dict["monthly"]))
        pandas_operation = "mean" if operation == "AVG" else operation.lower()
        return df.groupby(x_name)[term_name].agg(pandas_operation).reset_index()

    def _generate_relation_resultant_df(self):
        """Raw ``x_axis`` and ``y_axis`` column(s) for scatter/correlation charts.

        ``y_axis`` may be a single column name or a list of names.
        """
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        y_columns = list(y_axis) if isinstance(y_axis, (list, tuple)) else [y_axis]
        selected = ", ".join(self._quote_identifier(col) for col in [x_axis, *y_columns] if col)
        query = f"SELECT {selected} FROM {self._quote_identifier(self.table)}{self._where_clause()}"
        return pd.read_sql(text(query), self.connection)

    def _generate_distribution_resultant_df(self):
        """Distribution of ``x_axis``: a histogram for numeric columns, per-value counts otherwise."""
        x_axis, y_axis, binning, operation, group_by, formula = self._extract_chart_config()
        data_type = self._column_data_type(x_axis)
        print(data_type)
        x_col = self._quote_identifier(x_axis)
        source = f"FROM {self._quote_identifier(self.table)}{self._where_clause()}"
        if data_type in self._NUMERIC_TYPES:
            df = pd.read_sql(text(f"SELECT {x_col} {source}"), self.connection)
            return self._histogram_frame(df.iloc[:, 0])
        query = f"SELECT {x_col},COUNT({x_col}) {source} GROUP BY {x_col}"
        return pd.read_sql(text(query), self.connection)

    def _extract_chart_config(self):
        """
        Return ``(x_axis, y_axis, binning, operation, group_by, formula)`` from the config.

        Column names are returned unquoted (builders quote them with ``_quote_identifier``);
        missing/None values become ''. ``operation`` is mapped to a whitelisted SQL
        aggregate (AVG, MAX, MIN, SUM, COUNT), defaulting to AVG.
        """
        config = self.chart_config or {}
        raw_operation = str(config.get("operation") or "").strip().lower()
        return (
            config.get("x_axis") or "",
            config.get("y_axis") or "",
            config.get("binning") or "",
            self._SQL_AGGREGATES.get(raw_operation, "AVG"),
            config.get("group_by") or "",
            config.get("formula") or "",
        )

    def generate_resultant_df(self, metric_type):
        """
        Build the DataFrame for ``metric_type`` (a chart category).

        Raises:
            ValueError: if ``metric_type`` is not a supported category.
        """
        print("metric_type: ",metric_type)
        builders = {
            "combination": self._generate_combination_resultant_df,
            "group_aggregate": self._generate_groupaggregate_resultant_df,
            "group_aggregates": self._generate_groupaggregate_resultant_df,
            "group_aggregation": self._generate_groupaggregate_resultant_df,
            "time_series": self._generate_timeseries_resultant_df,
            "multilevel_categorical": self._generate_multilevelcategorical_resultant_df,
            "relational": self._generate_relation_resultant_df,
            "distribution": self._generate_distribution_resultant_df,
        }
        builder = builders.get(metric_type) if isinstance(metric_type, str) else None
        if builder is None:
            raise ValueError("Invalid metric type provided.")
        return builder()


In [17]:
# KPI glossary entries (first line = KPI key, then definition and formula). They are
# embedded into a FAISS index and retrieved as context for derived-metric questions.
average_order_value = '''
average_order_value
Definition : Average order value (AOV) tracks the average dollar amount spent each time a customer places an order on a website or mobile app. To calculate your company’s average order value, simply divide total revenue by the number of orders.
Formula : Revenue/Number of orders
'''
gross_profit_margin = '''
gross_profit_margin
Definition: Measures the profitability of a company after subtracting the cost of goods sold (COGS).
Formula:Gross Profit Margin=((Revenue−COGS)/Revenue)×100
'''
return_on_investment = '''
return_on_investment
Definition:Measures the profitability of an investment.
Formula:ROI=((Net Profit from Investment−Cost of Investment)/Cost of Investment)×100
'''
# Formula direction fixed: CAC is spend divided by customers acquired.
customer_acquistion_cost = '''
customer_acquistion_cost
Definition:Measures the cost of acquiring a new customer.
Formula: Total Sales and Marketing Expenses/Number of New Customers Acquired
'''
customer_lifetime_value ='''
customer_lifetime_value
Definition:Measures the total revenue a business can expect from a single customer account.
Formula: Average Purchase Value×Average Purchase Frequency×Average Customer Lifespan 
'''
sales_growth = '''
sales_growth
Definition: Measures the increase in sales over a specific period.
Formula:Sales Growth=((Sales in Current Period−Sales in Previous Period)/Sales in Previous Period)×100
'''
# Formula direction fixed: CSAT is the share of satisfied respondents.
customer_satisfaction_score = '''
customer_satisfaction_score
Definition: Measures customer satisfaction with a product or service.
Formula: (Number of Satisfied Customers/Number of Survey Responses)×100
'''


In [18]:
# Chart-category descriptions (category, examples, supported chart types). They are
# embedded into a FAISS index; the closest one is passed to the chart-classification prompt.
time_series_trend ='''
time_series: Analyzing time series trends. Graph type: Line.
Example: 1."Give me sales trend."
         2."Sales from July to october"
chart types supported : line
'''
group_aggregates = '''
group_aggregates: Analyzing over different product categories. Graph types: Bar, Pie,stacked_bar
Example:1. "Give me sales for each product category."
        2. "Give me sales for product category A and Product category B"
        3. Give me count per product category
chart types supported: bar, pie
'''
combination = '''
combination: A combination of time_series and group_aggregates, where the query requires analysis of time series trends for each group category. Graph type: multiline.
Example: 1.Give me product wise sales trend ,
         2. sales trend for each vehicle_type
chart types supported: multiline
'''
multilevel_categorical = '''
multilevel_categorical: Analyze trend where there is a split at 2 levels
Example: 1. Show me the quantity sold by Product category breakdown by gender
        2. Show me the average profits generated by each product category segmented by customer login type
chart types supported: group_bar, stacked_bar
'''
relational ='''
relational: Analyze relationships like correlation and distribution between two or more columns
Example: 1. relationship between sales and profit 
        2. correlation across marketing expenditure and profits
chart types supported: correlogram, scatter plot
'''
distribution = '''
distribution: chart types which highlight the distribution trend for a variable
Example: 1. show me the distribution of age across the organization
        2. Highlight the skewness in quantity sold 
chart types supported: histogram,box plot
Note y_axis would be None in this case always
'''

In [19]:
# All chart-category descriptions, in definition order.
chart_hints = [
    time_series_trend,
    group_aggregates,
    combination,
    multilevel_categorical,
    relational,
    distribution,
]

In [20]:
# One FAISS document per chart category; looked up with k=1 in get_resultant_df.
faiss_hints = FAISS.from_texts(texts=chart_hints, embedding=embedding)

In [21]:
# KPI definitions, indexed so the closest matches can be handed to the LLM.
knowledge_glossary = [
    average_order_value,
    gross_profit_margin,
    return_on_investment,
    customer_acquistion_cost,
    customer_lifetime_value,
    sales_growth,
    customer_satisfaction_score,
]
knowledge_hints = FAISS.from_texts(texts=knowledge_glossary, embedding=embedding)



In [22]:
# Serialise each catalogue row as "['table', 'column', 'type']" so it can be embedded.
# Guarded: this cell replaces `metadata` with the list, so a re-run must not redo the
# conversion (it previously raised AttributeError on the list).
if isinstance(metadata, pd.DataFrame):
    metadata = [str(row) for row in metadata.values.tolist()]

In [23]:
# Preview only: the full list has one entry per column in the schema.
metadata[:20]

["['auto_insurance', 'Customer', 'text']",
 "['auto_insurance', 'State', 'text']",
 "['auto_insurance', 'Customer Lifetime Value', 'double']",
 "['auto_insurance', 'Response', 'text']",
 "['auto_insurance', 'Coverage', 'text']",
 "['auto_insurance', 'Education', 'text']",
 "['auto_insurance', 'Effective To Date', 'datetime']",
 "['auto_insurance', 'EmploymentStatus', 'text']",
 "['auto_insurance', 'Gender', 'text']",
 "['auto_insurance', 'Income', 'bigint']",
 "['auto_insurance', 'Location Code', 'text']",
 "['auto_insurance', 'Marital Status', 'text']",
 "['auto_insurance', 'Monthly Premium Auto', 'bigint']",
 "['auto_insurance', 'Months Since Last Claim', 'bigint']",
 "['auto_insurance', 'Months Since Policy Inception', 'bigint']",
 "['auto_insurance', 'Number of Open Complaints', 'bigint']",
 "['auto_insurance', 'Number of Policies', 'bigint']",
 "['auto_insurance', 'Policy Type', 'text']",
 "['auto_insurance', 'Policy', 'text']",
 "['auto_insurance', 'Renew Offer Type', 'text']",
 

In [24]:
# Index over the serialised column metadata. It uses its own (larger) embedding model,
# separate from the hint/glossary indexes.
columnretriever_db = FAISS.from_texts(
    texts=metadata,
    embedding=OpenAIEmbeddings(model="text-embedding-3-large"),
)

In [25]:
# def retriever(query, k=5):
#     results = columnretriever_db.similarity_search(query, k=k)
#     contexts = [result.page_content for result in results]
#     return contexts

In [26]:
# def model_response(query: str):
#     # Retrieve context
#     contexts = retriever(query, k=5)

#     chat = ChatOpenAI(model="gpt-4o-mini",
#                       temperature=0,
#                       model_kwargs={"seed": 42})

#     messages = [
#         SystemMessage(
#             content=f"You are a business analyst designed to give key insights. You have access to the following context: {contexts}. You are task is to provide the specific columns that related to query.Provide your response in JSON format. For example: {{\n"
#                     '"TABLENAMEA": [["COLUMNNAMEA", "DATATYPEA"], ["COLUMNNAMEB", "DATATYPEB"]],\n'
#                     '"TABLENAMEB": [["COLUMNNAMEA", "DATATYPEA"], ["COLUMNNAMEB", "DATATYPEB"]]\n'
#                     '}}'),
#         HumanMessage(
#             content=query),
#     ]

#     return json.loads(chat(messages).content)
#     # response = chat(messages.content)
#     # return response

In [54]:
def get_resultant_df(query, db_metadata, connection, max_retries=3):
    """Answer a natural-language question with a chart-ready DataFrame.

    Steps:
      1. Retrieve context: the 10 closest column descriptions from
         ``columnretriever_db``, the closest chart hint from ``faiss_hints``
         and the 3 closest glossary entries from ``knowledge_hints``.
      2. Classify the question with ``prompts/chart_class2.txt``.
      3. ``"Yes"``: optionally derive a metric (``prompts/math_preprocessing.txt``),
         build a chart config (``prompts/chart_config.txt``) and let
         ``MySQLPandasDF`` produce the result from the FIRST table the
         classifier named.
         ``"No"``: ask ``prompts/RejectClass.txt`` for alternatives and retry
         with the first suggested question.
         Any other answer: ask ``prompts/rephrase.txt`` for a rephrased
         question and retry with it.

    Depends on notebook globals defined in other cells: ``columnretriever_db``,
    ``faiss_hints``, ``knowledge_hints``, ``get_llm_response``, ``llm``,
    ``ChartClass``, ``ChartConfig``, ``RejectClass``, ``RephraseClass`` and
    ``MySQLPandasDF``.

    Args:
        query: Natural-language question from the user.
        db_metadata: Schema description forwarded to the classification, reject
            and rephrase prompts (in this notebook, the frame returned by
            ``get_all_tables_metadata``).
        connection: Database connection handed to ``MySQLPandasDF``.
        max_retries: Maximum number of times the question may be replaced by a
            suggested alternative or a rephrasing. This bounds the number of
            LLM round-trips instead of recursing indefinitely.

    Returns:
        The value of ``MySQLPandasDF(...).generate_resultant_df(chart_class.type)``.

    Raises:
        ValueError: The classifier answered "Yes" without naming any table, or
            answered "No" without suggesting an alternative question.
        RuntimeError: No chartable question was found within ``max_retries``.
    """
    results = columnretriever_db.similarity_search(query, k=10)
    contexts = [result.page_content for result in results]
    print(contexts)

    chart_hints = faiss_hints.similarity_search(query, k=1)
    knowledge_glossary = knowledge_hints.similarity_search(query, k=3)

    chart_class = get_llm_response(
        prompt_file_path="prompts/chart_class2.txt",
        llm=llm,
        input_variables={
            "db_metadata": db_metadata,
            "question": query,
            # An empty hint store must not crash the pipeline with an IndexError.
            "chart_hints": chart_hints[0].page_content if chart_hints else "",
            "contexts": contexts,
        },
        pydantic_class=ChartClass,
    )
    print(chart_class)

    if chart_class.is_chart_possible == "Yes":
        tables_columns = chart_class.tables_columns
        if not tables_columns:
            raise ValueError(f"Chart classifier named no tables/columns for query {query!r}")
        # Only the first table is queried; columns listed for other tables are ignored.
        table_name = next(iter(tables_columns))
        print(table_name)
        print(list(tables_columns.values()))

        # Optional keyword arguments for MySQLPandasDF (only "derived" at present).
        extra_kwargs = {}
        if chart_class.is_derived:
            derived_response = get_llm_response(
                prompt_file_path="prompts/math_preprocessing.txt",
                llm=llm,
                input_variables={
                    "db_metadata": tables_columns,
                    "knowledge_glossary": knowledge_glossary,
                    "query": query,
                },
            )
            derived = derived_response.content
            print(f"derived: {derived}")
            if derived:
                # The math prompt's raw reply is parsed as JSON before being handed on.
                extra_kwargs["derived"] = json.loads(derived)

        chart_config = get_llm_response(
            prompt_file_path="prompts/chart_config.txt",
            llm=llm,
            input_variables={
                "chart_class": chart_class,
                "question": query,
                "db_metadata": tables_columns,
            },
            pydantic_class=ChartConfig,
        )
        print()
        print(chart_config)

        mysql_pandas_df = MySQLPandasDF(
            chart_config=chart_config.model_dump(),
            connection=connection,
            table=table_name,
            **extra_kwargs,
        )
        return mysql_pandas_df.generate_resultant_df(chart_class.type)

    # Every other outcome replaces the question and retries; stop once the budget is spent.
    if max_retries <= 0:
        raise RuntimeError(
            f"No chartable question found for {query!r} "
            f"(last classification: {chart_class.is_chart_possible!r})"
        )

    if chart_class.is_chart_possible == "No":
        reject_reason = get_llm_response(
            prompt_file_path='prompts/RejectClass.txt',
            pydantic_class=RejectClass,
            input_variables={"db_metadata": db_metadata, "question": query},
            llm=llm,
        )
        print("Reject Reason\n")
        print(reject_reason.suggested_alternatives)
        if not reject_reason.suggested_alternatives:
            # Retrying with an empty list as the "question" would only fail further down.
            raise ValueError(f"Question {query!r} was rejected and no alternative was suggested")
        new_query = reject_reason.suggested_alternatives[0]
    else:
        rephrase_query = get_llm_response(
            prompt_file_path="prompts/rephrase.txt",
            pydantic_class=RephraseClass,
            input_variables={"db_metadata": db_metadata, "question": query},
            llm=llm,
        )
        new_query = rephrase_query.rephrased_query
        print(new_query)

    return get_resultant_df(
        query=new_query,
        db_metadata=db_metadata,
        connection=connection,
        max_retries=max_retries - 1,
    )


In [49]:
# def get_resultant_df(query, table,db_metadata, connection):
#     results = columnretriever_db.similarity_search(query, k=3)
#     contexts = [result.page_content for result in results]
#     print(contexts)

#     chart_hints = faiss_hints.similarity_search(query, k=1)
#     knowledge_glossary = knowledge_hints.similarity_search(query, k=3)

#     chart_class = get_llm_response(
#         prompt_file_path="prompts/chart_class2.txt",
#         llm=llm,
#         input_variables={
#             "db_metadata": db_metadata,"question": query,"chart_hints": chart_hints[0].page_content,"contexts": contexts},
#         pydantic_class=ChartClass
#     )

#     print(chart_class)

#     if chart_class.is_chart_possible == "Yes":
#         derived = None
#         if chart_class.is_derived:
#             derived_response = get_llm_response(
#                 prompt_file_path="prompts/math_preprocessing.txt",
#                 llm=llm,
#                 input_variables={"db_metadata": db_metadata,"knowledge_glossary": knowledge_glossary,"query": query})
#             derived = derived_response.content
#             print(f"derived: {derived}")

#         chart_config = get_llm_response(
#             prompt_file_path="prompts/chart_config.txt",
#             llm=llm,
#             input_variables={"chart_class": chart_class,"question": query,"db_metadata": db_metadata},
#             pydantic_class=ChartConfig
#         )
#         print()
#         print(chart_config)

#         if derived:
#             derived = json.loads(derived)
#             mysql_pandas_df = MySQLPandasDF(
#                 chart_config=chart_config.model_dump(),
#                 connection=connection,
#                 table=table,
#                 derived=derived
#             )
#         else:
#             mysql_pandas_df = MySQLPandasDF(
#                 chart_config=chart_config.model_dump(),
#                 connection=connection,
#                 table=table
#             )

#         metric_type = chart_class.type
#         resultant_df = mysql_pandas_df.generate_resultant_df(metric_type)
#         return resultant_df

#     elif chart_class.is_chart_possible == "No":
#         reject_reason = get_llm_response(
#             prompt_file_path='prompts/RejectClass.txt',
#             pydantic_class=RejectClass,
#             input_variables={
#                 "db_metadata": db_metadata,
#                 "question": query
#             },
#             llm=llm
#         )
#         print("Reject Reason\n")
#         print(reject_reason.suggested_alternatives)

#         new_query = reject_reason.suggested_alternatives[0] if reject_reason.suggested_alternatives else reject_reason.suggested_alternatives
#         return get_resultant_df(query=new_query, table=table, db_metadata=db_metadata, connection=connection)

#     else:
#         rephrase_query = get_llm_response(
#             prompt_file_path="prompts/rephrase.txt",
#             pydantic_class=RephraseClass,
#             input_variables={
#                 "db_metadata": db_metadata,
#                 "question": query
#             },
#             llm=llm
#         )

#         new_query = rephrase_query.rephrased_query
#         print(new_query)
#         return get_resultant_df(query=new_query, table=table,db_metadata=db_metadata, connection=connection)


## Testing

In [50]:
# Full schema listing (TABLE_NAME, COLUMN_NAME, DATA_TYPE) that the prompts receive as `db_metadata`.
db_metadata = get_all_tables_metadata(engine)
db_metadata

Unnamed: 0,TABLE_NAME,COLUMN_NAME,DATA_TYPE
0,auto_insurance,Customer,text
1,auto_insurance,State,text
2,auto_insurance,Customer Lifetime Value,double
3,auto_insurance,Response,text
4,auto_insurance,Coverage,text
...,...,...,...
714,sri_dataset,d5_g,double
715,sri_dataset,d6,text
716,sri_dataset,d7,text
717,unemployment_level,DATE,datetime


### Distribution

For the following dataset metadata:

["['sales_data', 'date', 'datetime']","['sales_data', 'day of the week', 'text']","['sales_data', 'product_id', 'bigint']","['sales_data', 'sales_amount', 'double']","['sales_data', 'gender', 'text']","['product_categories', 'product_id', 'bigint']","['product_categories', 'category', 'text']",
"['product_categories', 'cost', 'double']","['product_categories', 'quantity', 'bigint']"]


{question: , table , columns}

In [56]:
%%time
from loguru import logger
connection = engine.connect()
query = "Income by Marital Status"
df =get_resultant_df(query = query,db_metadata=db_metadata,connection = connection)
df

["['auto_insurance', 'Marital Status', 'text']", "['bank_demog', 'marital', 'text']", "['auto_insurance', 'Income', 'bigint']", "['auto_insurance', 'EmploymentStatus', 'text']", "['ecom_sales', 'Profit', 'double']", "['ecom_sales', 'Sales', 'double']", "['cgr_premiums', 'gender', 'text']", "['generic_percentage_data', 'No Coverage / Out-of-pocket (%)', 'double']", "['ecommerce', 'Profit', 'double']", "['auto_insurance', 'Customer Lifetime Value', 'double']"]
thought="The query asks for a relationship between 'Income' and 'Marital Status'. Both 'Income' and 'Marital Status' are present in the 'auto_insurance' dataset. This suggests that we can analyze the correlation or distribution between these columns." is_chart_possible='Yes' tables_columns={'auto_insurance': ['Income', 'Marital Status']} chart_type='scatter plot' type='relational' is_derived=False
auto_insurance
[['Income', 'Marital Status']]

x_axis='Marital Status' y_axis='Income' binning=None heading='Income by Marital Status' g

Unnamed: 0,Marital Status,Income
0,Married,56274
1,Single,0
2,Married,48767
3,Married,0
4,Single,43836
...,...,...
9129,Married,71941
9130,Divorced,21604
9131,Single,0
9132,Married,21941


In [61]:
# rephrase
# Ratio of premium ticket types to other ticket types by gender
# Yes
# derived:{"thought":"To calculate the ratio of premium ticket types to other ticket types by gender, we need to count the number of premium ticket types and the number of other ticket types for each gender. Then, we can divide the number of premium ticket types by the number of other ticket types for each gender.","formula":"COUNT(Ticket_Type='Premium')/COUNT(Ticket_Type!='Premium') BY Customer_Gender","term":"premium_to_other_ticket_ratio_by_gender"}
# metric_type:  group_aggregates
# SELECT `Customer Gender`,
 #                   none(`Ticket Type`) AS `Ticket Type`
 #            FROM `customer_support` GROUP BY `Customer Gender`]

### Relational

In [None]:
# get_resultant_df chooses the table itself (first table named by the chart classifier);
# it has no `table` parameter, so passing `table=` raised a TypeError.
query = "relation between profit and quantity"
with engine.connect() as connection:
    df = get_resultant_df(query=query, db_metadata=db_metadata, connection=connection)
df

### Time series

In [None]:
from loguru import logger

# Inputs for the time-series example executed in the next cell.
query = "profits over time"
table = "ecommerce"
connection = engine.connect()

In [None]:
# get_resultant_df chooses the table itself (first table named by the chart classifier);
# it has no `table` parameter, so passing `table=` raised a TypeError.
df = get_resultant_df(query=query, db_metadata=db_metadata, connection=connection)
df

### Derived Metrics

In [None]:
# get_resultant_df chooses the table itself (first table named by the chart classifier);
# it has no `table` parameter, so passing `table=` raised a TypeError.
query = "Average order value over time"
with engine.connect() as connection:
    df = get_resultant_df(query=query, db_metadata=db_metadata, connection=connection)
df

In [None]:
# NOTE(review): these glossary objects are defined outside this part of the notebook;
# check that the spelling `customer_acquistion_cost` matches that definition.
knowledge_glossary = [
    average_order_value,
    gross_profit_margin,
    return_on_investment,
    customer_acquistion_cost,
    customer_lifetime_value,
    sales_growth,
    customer_satisfaction_score,
]

## Logic

## Rough Testing

In [None]:
#chart_class = get_llm_response(prompt_file_path="prompts/chart_class.txt",llm=llm,input_variables={"db_metadata":db_metadata,"question":"weather trend over time"},pydantic_class=ChartClass)

In [None]:
#chart_class

In [None]:
#reject_reason = get_llm_response(prompt_file_path='prompts/RejectClass.txt',pydantic_class=RejectClass,input_variables={"db_metadata":db_metadata,"question":"weather trend over time"},llm = llm)

In [None]:
#reject_reason

In [None]:
#input_variables = {"db_metadata":db_metadata,"question":"sales"}

In [None]:
#rephrase_query = get_llm_response(prompt_file_path="prompts/rephrase.txt",pydantic_class=RephraseClass,input_variables={"db_metadata":db_metadata,"question":"sales"},llm = llm)

In [None]:
#rephrase_query