In [None]:
#Dependencies
!pip install dspy-ai[chromadb] -Uqq
!pip install termcolor -Uqq
!pip install sqlalchemy -Uqq

## AIM OF THE TUTORIAL

* Build an end-to-end Text-to-SQL pipeline inspired from this [video](https://www.youtube.com/watch?v=L1o1VPVfbb0&pp=ygUYYWR2YW5jZWQgUkFHIGxsYW1hIGluZGV4) from Llama Index. In Llama Index, they used llama index Query Pipeline to build a Text-to-SQL pipeline. Here, we will build a Text-to-SQL pipeline based on our own dataset and from scratch. We will go from the dataset scraping, to building SQLlite database and using DSPy signatures to implementa a text-to-SQL pipeline

## ABOUT THE DATASET
* You can find the dataset [here](https://pages.stern.nyu.edu/~adamodar/New_Home_Page/datacurrent.html). The dataset has different industry based different financial metrics like WACC, tax rates, EBITDA, etc. There are multiple regions data `['US', 'Europe', 'Japan', 'AUS_NZ_CANADA', 'Emerging', 'China', 'India', 'Global']` where we have multiple tables for each region. There are nearly 250 tables with multiple columns in each table. We will build a text-to-SQL pipeline based on our own dataset and from scratch, starting from embedding tables schema and rows using ChromDB vector database.

In [1]:
#imports
from bs4 import BeautifulSoup 
import urllib.request
import ssl
from dotenv import load_dotenv
import openai
import os
import requests
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import pandas as pd
import json
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re
from sqlalchemy import inspect
import sqlalchemy
from sqlalchemy import text 
import dspy
from termcolor import colored
import chromadb

In [2]:
from dsp.modules.cache_utils import cache_turn_on

cache_turn_on

True

## SCRAPING THE LINKS OF THE EXCEL FILES FROM THE WEBSITE

In [2]:
ssl._create_default_https_context = ssl._create_stdlib_context
html_link = "https://pages.stern.nyu.edu/~adamodar/New_Home_Page/datacurrent.html"

with urllib.request.urlopen(html_link) as url:
    s = url.read()
    # I'm guessing this would output the html source code ?
    soup = BeautifulSoup(s,"lxml")

html_table = soup.find_all("table")
req_table = html_table[1]
hrefs_list = req_table.find_all('a')

In [42]:
req_href = {"US":[],"Europe":[],"Japan":[],"AUS_NZ_CANADA":[],"Emerging":[],"China":[],"India":[],"Global":[]}

for i in hrefs_list:
    name = i.get_text().strip()
    try:
        href_attr = i['href']
        # Only get the excel files
        if href_attr.endswith('.xls'):
            if "US" in name:
                req_href["US"].append(href_attr)
            elif "Europe" in name:
                req_href["Europe"].append(href_attr)
            elif "Japan" in name:
                req_href["Japan"].append(href_attr)
            elif "Aus" in name:
                req_href['AUS_NZ_CANADA'].append(href_attr)
            elif "Emerging" in name:
                req_href['Emerging'].append(href_attr)
            elif "China" in name:
                req_href['China'].append(href_attr)
            elif "India" in name:
                req_href['India'].append(href_attr)
            elif "Global" in name: 
                req_href['Global'].append(href_attr)
    except:
        pass

In [93]:
#Download the excel files from the website and store it in a folder named DATA
ssl._create_default_https_context = ssl._create_stdlib_context

os.makedirs("DATA",exist_ok=True)
for country,excel_files in req_href.items():
    country_path = os.path.join("DATA",country) 
    os.makedirs(country_path,exist_ok=True)
    for file in excel_files:
        file_name = file.split("/")[-1].split(".")[0]
        full_file_name = os.path.join(country_path,f"{file_name}.xls")
        resp = requests.get(file,verify=False)
        output = open(full_file_name, 'wb')
        output.write(resp.content)
        output.close()

In [35]:
# Sanity check
for country in os.listdir("DATA"):
    dir_len = len(os.listdir(os.path.join("DATA",country)))
    country_len = len(req_href[country])
    print(f'FOR {country} WE HAVE DIRECTORY LEN = {dir_len} and ACTUAL LEN = {country_len}')

FOR Emerging WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29
FOR Europe WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29
FOR Global WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29
FOR AUS_NZ_CANADA WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29
FOR China WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29
FOR Japan WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29
FOR US WE HAVE DIRECTORY LEN = 24 and ACTUAL LEN = 24
FOR India WE HAVE DIRECTORY LEN = 29 and ACTUAL LEN = 29


## CLEANING THE DATASET

In [73]:
sample_excel = pd.ExcelFile("DATA2/US/capex.xls")

In [79]:
sn = 'Variables & FAQ'
sample_excel.parse(sn).head(10)

Unnamed: 0,End Game,"To measure how much companies are reinvesting back into their long term assets, as a prelude to forecasting expected growth.",Unnamed: 2,Unnamed: 3
0,,,,
1,,,,
2,Variable,How it is measured,What it measures,Units
3,Capital Expenditures,Sum of the capital expenditures reported on st...,"Gross investment in long term assets, at least...",$ millions
4,Depreciation,Sum of the depreciation reported on statement ...,"Loss in value of assets, from use, as measured...",$ millions
5,Net Cap Ex,Sum of capital expenditures on statement of ca...,"Net investment in long term assets, at least a...",$ millions
6,Net R&D,Sum of R&D reported as expense in most recent ...,"Net investment in long term assets, expanded t...",$ millions
7,Acquisitions,Sum of acquisitions reported on statement of c...,Augments investment to include acquisition.,$ millions


In [78]:
sample_excel.parse(sample_excel.sheet_names[1]).head(10)

Unnamed: 0,Date updated:,2024-01-05 00:00:00,Unnamed: 2,Unnamed: 3,Unnamed: 4,Unnamed: 5,Unnamed: 6,Unnamed: 7,Unnamed: 8,Unnamed: 9
0,Created by:,"Aswath Damodaran, adamodar@stern.nyu.edu",,,,,,,,
1,What is this data?,"Capital Expenditures, Acquisitions and R&D and...",,,,US companies,,,,
2,Home Page:,http://www.damodaran.com,,,,,,,,
3,Data website:,https://pages.stern.nyu.edu/~adamodar/New_Home...,,,,,,,,
4,Companies in each industry:,https://pages.stern.nyu.edu/~adamodar/pc/datas...,,,,,,,,
5,Variable definitions:,https://pages.stern.nyu.edu/~adamodar/New_Home...,,,,,,,,
6,Industry Name,Number of Firms,Capital Expenditures (US $ millions),Depreciation & Amort ((US $ millions),Cap Ex/Deprecn,Acquisitions (US $ millions),Net R&D (US $ millions),Net Cap Ex/Sales,Net Cap Ex/ EBIT (1-t),Sales/ Invested Capital (LTM)
7,Advertising,57,775.729,1887.338,0.411018,322.559,75.08,-0.016886,-0.21812,3.283403
8,Aerospace/Defense,70,10982.128,13311.598,0.825004,10344.75,830.4634,0.022829,0.318077,1.98434
9,Air Transport,25,25559.725,10609.948,2.409034,368.21,73.5486,0.067203,1.355173,1.77732


* In the dataset above, there are two sheets. The first sheet is the variables and summary, and the second sheet is the table with the data. 
* We will clean the first sheet to get the table name and summary

In [38]:
import pandas as pd
pd.set_option('display.max_rows', 50)

def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)

dir = "DATA"
processed_dir = "Processed Data"
all_infos_dict = []
os.makedirs(processed_dir,exist_ok=True)
for country in os.listdir(dir):
    print(country)
    file_name = os.path.join(dir,country)
    os.makedirs(os.path.join(processed_dir,country),exist_ok=True)
    os.makedirs(file_name,exist_ok=True)
    for excel_file in tqdm(os.listdir(file_name)):
        full_file_name = os.path.join(file_name,excel_file)
        xls = pd.ExcelFile(full_file_name)
        sns = xls.sheet_names
        for sheet_name in sns:
            if "Var" in sheet_name or "var" in sheet_name:
                info_df = xls.parse(sheet_name)
                info_df.dropna(how="all",inplace=True)
                info_dict = {}
                for cols in info_df.columns:
                    if "End" not in cols and 'Unnamed' not in cols:
                        info_dict['Summary'] = cols
                info_dict['Vars'] = info_df.values[1:].tolist()
                all_infos_dict.append(info_dict)
            elif "Industry" in sheet_name or "industry" in sheet_name:
                data_df = xls.parse(sheet_name)
        try:
            data_df.dropna(axis=1,thresh=5,inplace=True)
            data_df.dropna(inplace=True)
            new_header = data_df.iloc[0] #grab the first row for the header
        except:
            print(full_file_name)
            print(data_df)
        data_df = data_df[1:] #take the data less the header row
        data_df.reset_index(inplace=True,drop=True)
        new_header = [sanitize_column_name(str(col)) for col in new_header]
        data_df.columns = new_header #set the header row as the df header
        save_name = full_file_name.split(".")[0].split("/")[-1]
        save_file_path = os.path.join(os.path.join(processed_dir,country),save_name)
        data_df.to_csv(save_file_path+".csv",index=False)
        with open(save_file_path+".json", "w") as outfile: 
            json.dump(info_dict, outfile)

Emerging


100%|██████████| 29/29 [00:00<00:00, 86.52it/s]


Europe


100%|██████████| 29/29 [00:00<00:00, 94.75it/s]


Global


100%|██████████| 29/29 [00:00<00:00, 102.75it/s]


AUS_NZ_CANADA


100%|██████████| 29/29 [00:00<00:00, 104.23it/s]


China


100%|██████████| 29/29 [00:00<00:00, 103.50it/s]


Japan


100%|██████████| 29/29 [00:00<00:00, 66.34it/s] 


US


100%|██████████| 24/24 [00:00<00:00, 102.91it/s]


India


100%|██████████| 29/29 [00:00<00:00, 100.35it/s]


## A SAMPLE METADATA JSON

In [196]:
all_infos_dict[0]

{'Summary': 'Measures of accounting returns, to all claim holders (and from operations)',
 'Vars': [['Number of firms',
   'Number of firms in the industry grouping.',
   'Law of large numbers?'],
  ['R&D Capitalized',
   'My estimate of R&D capitalization, based upon a 5-year straight line amortization period, aggregated across firms in the group',
   'Capitalized value of R&D gets added on to book equity and to invested capital'],
  ['Capitalized R&D as percent of invested capital',
   'My R&D capitalization estimate, as a percent of invested capital including that number, based upon aggregated values across firms',
   'Magnitude of investment in R&D, relative to investment in more traditional capital expenditures.'],
  ['R&D -LTM',
   'Aggregated R&D expenses across the last twelve months, across companies in the group.',
   'Spending on R&D in most recent year'],
  ['R&D: Years minus 1 to minus 5',
   'Aggregated R&D expenses for each of the previous five years, across companies in

## DATAFRAME AFTER PREPROCESSING

In [197]:
df = pd.read_csv("Processed Data/US/capex.csv")
df.head()

Unnamed: 0,Industry_Name,Number_of_Firms,Capital_Expenditures_(US_$_millions),Depreciation_&_Amort_((US_$_millions),Cap_Ex/Deprecn,Acquisitions_(US_$_millions),Net_R&D_(US_$_millions),Net_Cap_Ex/Sales,Net_Cap_Ex/_EBIT_(1-t),Sales/_Invested_Capital_(LTM)
0,Advertising,57,775.729,1887.338,0.411018,322.559,75.08,-0.016886,-0.21812,3.283403
1,Aerospace/Defense,70,10982.128,13311.598,0.825004,10344.75,830.4634,0.022829,0.318077,1.98434
2,Air Transport,25,25559.725,10609.948,2.409034,368.21,73.5486,0.067203,1.355173,1.77732
3,Apparel,38,1730.981,1386.48,1.248472,38.739,0.9474,0.005365,0.072557,1.773076
4,Auto & Truck,34,29899.486,18677.668,1.600815,193.22,983.0696,0.026577,0.675918,1.048732


## BUILD TABLE NAMES AND METADATA

* Here we use a DSPy signature given the first 10 rows of the dataframe, we generate the table name and table explanation. It will help us to dynamically select the correct table based on the query.

In [199]:
df.head(10).to_csv()

',Industry_Name,Number_of_Firms,Capital_Expenditures_(US_$_millions),Depreciation_&_Amort_((US_$_millions),Cap_Ex/Deprecn,Acquisitions_(US_$_millions),Net_R&D_(US_$_millions),Net_Cap_Ex/Sales,Net_Cap_Ex/_EBIT_(1-t),Sales/_Invested_Capital_(LTM)\n0,Advertising,57,775.729,1887.338,0.4110175283918408,322.559,75.07999999999993,-0.0168860862168837,-0.2181203933865563,3.283403041333076\n1,Aerospace/Defense,70,10982.128,13311.597999999998,0.8250044810547916,10344.749999999998,830.4633999999933,0.0228292321038578,0.318077061754813,1.9843395804666923\n2,Air Transport,25,25559.725,10609.948,2.4090339556800844,368.21,73.5486000000019,0.0672027546276203,1.3551729988499104,1.7773199625336142\n3,Apparel,38,1730.9809999999998,1386.4799999999996,1.2484716692631703,38.739,0.9473999999991064,0.0053650838666078,0.0725567415076527,1.7730762352050575\n4,Auto & Truck,34,29899.486,18677.668,1.6008147269776931,193.22,983.0695999999988,0.0265766732899557,0.6759175844836917,1.048731678622114\n5,Auto Parts,39,35

In [23]:
load_dotenv(override=True)
openai.api_key = os.environ['OPENAI_API_KEY']

In [24]:
turbo = dspy.OpenAI(model='gpt-3.5-turbo-instruct', max_tokens=250)
dspy.settings.configure(lm=turbo)

class SQLTableMetadata(dspy.Signature):
    """Give a suitable table name and description about the given table"""
    pandas_dataframe_str = dspy.InputField(desc="First 10 rows of a pandas dataframe delimited by newline character")
    table_name = dspy.OutputField(desc="suitable table name")
    table_summary = dspy.OutputField(desc="a summary about the table")

class CoT(dspy.Module):
    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought(SQLTableMetadata)
    
    def forward(self, pandas_dataframe_str):
        return self.prog(pandas_dataframe_str=pandas_dataframe_str)

cot = CoT()

# cot(pandas_dataframe_str = df.head(10).to_csv())


In [40]:
processed_dir = "Processed Data"
dfs_str = []
for country in os.listdir(processed_dir):
    country_folder = os.path.join(processed_dir,country)
    # print(f"{country}")
    for files in tqdm(os.listdir(country_folder),desc=f"Building the summary and name for {country}"):
        if files.endswith(".csv"):
            file_name = files.split(".")[0]
            csv_file_path = os.path.join(country_folder,files)
            df = pd.read_csv(csv_file_path,index_col=False)
            json_file_path = os.path.join(country_folder,f"{file_name}.json")
            with open(json_file_path,'r') as f:
                data = json.loads(f.read())
            if 'table_name' in data and 'table_summary' in data:
                # if data['table_name'] == "" or data['table_summary'] == "":
                if data['table_summary'] == "":
                    pass
                else:
                    continue
            dfs_str.append(df.head(10).to_csv())
            table_preds = cot(pandas_dataframe_str = df.head(10).to_csv())
            data['table_name'] = table_preds.table_name
            data['table_summary'] = table_preds.table_summary
            with open(json_file_path,'w') as f:
                json.dump(data, f)

## NEXT TASKS
1. Build database with each region for each table
2. Embed the table summary and table SCHEMA. Also, embed the table rows
3. Retrieval at table level and embed the rows to retrieve relevant rows from the retrieved schema of table
4. Text-to-SQL pipeline

## BUILD THE SQLITE DATABASE FROM THE CSV FILES

It was taken from the [tutorial](https://docs.llamaindex.ai/en/stable/examples/pipeline/query_pipeline_sql/)

In [3]:
# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()

## DATABASE CREATION

In [4]:
processed_dir  = "Processed Data"
def sqlalchemy_engine(region:str):
    """Create a SQLAlchemy engine for the given region"""
    assert region in os.listdir(processed_dir), f"{region} is not a valid region from {os.listdir(processed_dir)}"
    # Create a SQLAlchemy database for each region
    engine = create_engine(f"sqlite:///{region}.db")
    metadata_obj = MetaData()
    region_path = os.path.join(processed_dir,region)
    dfs = []
    for dataframes_path in os.listdir(region_path):
        if dataframes_path.endswith(".csv"):
            df = pd.read_csv(os.path.join(region_path,dataframes_path),index_col=False)
            dfs.append((dataframes_path,df))
    pbar = tqdm(total=len(dfs),desc=f"Creating tables for {region}")
    for _, df_table_name in enumerate(dfs):
        table_name = df_table_name[0]
        table_name = table_name.split(".")[0]
        df = df_table_name[1]
        # print(f"Creating table: {table_name}")
        create_table_from_dataframe(df,table_name, engine, metadata_obj)
        # print(f"Done creating table for: {table_name}")
        pbar.update(1)
    return engine

In [5]:
us_engine = sqlalchemy_engine("US")
india_engine = sqlalchemy_engine("India")
china_engine = sqlalchemy_engine("China")
europe_engine = sqlalchemy_engine("Europe")
global_engine = sqlalchemy_engine("Global")
aus_nz_canada_engine = sqlalchemy_engine("AUS_NZ_CANADA")
japan_engine = sqlalchemy_engine("Japan")
emerging_engine = sqlalchemy_engine("Emerging")

Creating tables for US:   0%|          | 0/24 [00:00<?, ?it/s]

Creating tables for US: 100%|██████████| 24/24 [00:00<00:00, 32.16it/s]
Creating tables for India: 100%|██████████| 29/29 [00:00<00:00, 33.36it/s]
Creating tables for China: 100%|██████████| 29/29 [00:01<00:00, 28.46it/s]
Creating tables for Europe: 100%|██████████| 29/29 [00:00<00:00, 30.79it/s]
Creating tables for Global: 100%|██████████| 29/29 [00:00<00:00, 30.62it/s]
Creating tables for AUS_NZ_CANADA: 100%|██████████| 29/29 [00:00<00:00, 29.32it/s]
Creating tables for Japan: 100%|██████████| 29/29 [00:00<00:00, 33.82it/s]
Creating tables for Emerging: 100%|██████████| 29/29 [00:00<00:00, 31.91it/s]


In [6]:
def get_table_infos(sql_engine:sqlalchemy.engine.base.Engine,region:str):
    """Get all the tables info in the database based on the given region"""
    inspector = inspect(sql_engine)
    table_names = inspector.get_table_names()
    table_infos_dict = {tb: [] for tb in table_names}
    for tb in table_names:
        column_dict = inspector.get_columns(tb)
        schema_str = ""
        primary_keys = []
        for col in column_dict:
            schema_str += f"{col['name']} ({col['type']}), "
            if col["primary_key"] not in primary_keys:
                primary_keys.append(col["name"])
        with open(os.path.join(processed_dir,region,f"{tb}.json")) as f:
            table_info = json.loads(f.read())
        table_infos_dict[tb] = [
            {
                "table_info": f"Table {tb} has columns: {schema_str[:-2]}",
                "table_summary": f'{table_info["Summary"]}. {table_info["table_summary"]}. ',
            }
        ]
    return table_infos_dict

In [7]:
us_tb_dict = get_table_infos(us_engine,"US")
india_tb_dict = get_table_infos(india_engine,"India")
china_tb_dict = get_table_infos(china_engine,"China")
europe_tb_dict = get_table_infos(europe_engine,"Europe")
global_tb_dict = get_table_infos(global_engine,"Global")
aus_nz_canada_tb_dict = get_table_infos(aus_nz_canada_engine,"AUS_NZ_CANADA")
japan_tb_dict = get_table_infos(japan_engine,"Japan")
emerging_tb_dict = get_table_infos(emerging_engine,"Emerging")

In [8]:
us_tb_dict['DollarUS']

[{'table_info': 'Table DollarUS has columns: Industry_Name (VARCHAR), Number_of_firms (INTEGER), Average_Company_Age_years_ (INTEGER), Market_Cap_millions_ (INTEGER), Book_Equity_millions_ (INTEGER), Enteprise_Value_millions_ (INTEGER), Invested_Capital_millions_ (INTEGER), Total_Debt_including_leases_millions_ (INTEGER), Revenues_millions_ (INTEGER), Gross_Profit_millions_ (INTEGER), EBITDA_millions_ (INTEGER), EBIT_Operating_Income_millions_ (INTEGER), Net_Income_millions_ (INTEGER)',
  'table_summary': 'To report aggregated dollar value of key operating and marker numbers, by industry group, in millions of US $.. This table contains financial information about various industries such as their names, number of firms, average company age, market cap, book equity, enterprise value, invested capital, total debt, revenues, gross profit, EBITDA, EBIT, and net income. The data is delimited by a newline character and can be used for comparison and analysis of the financial performance of di

## EMBEDDINGS

1. Embed the table summary and table SCHEMA to get the table that the user is looking for
2. Embed the table rows for each table, so as to get relevant rows from the retrieved table 

## EMBED THE TABLE SUMMARY AND TABLE SCHEMA

In [9]:
import chromadb.utils.embedding_functions as embedding_functions
from chromadb.utils.batch_utils import create_batches

load_dotenv(override=True)
emb_fn = embedding_functions.OpenAIEmbeddingFunction(
                api_key=os.environ['OPENAI_API_KEY'],
                model_name="text-embedding-3-small")
# EMBEDDING_MODEL = "mixedbread-ai/mxbai-embed-large-v1"
# emb_fn = embedding_functions.HuggingFaceEmbeddingFunction(model_name=EMBEDDING_MODEL,api_key=os.environ["HF_API_KEY"])
def embed_table_info(region:str,tb_dict):
    """Embed the table summary and table SCHEMA to get the table that the user is looking for"""
    client = chromadb.PersistentClient(path=f"{region}_TABLE")

    table_collection = client.create_collection(name="table",embedding_function=emb_fn)

    table_docs = []
    table_metadata = []


    for table_name,table_data in tb_dict.items():
        table_docs.append(table_data[0]['table_info'] + ". " + table_data[0]['table_summary'])
        table_metadata.append({"table_name":table_name,'table_metadata':table_data[0]['table_info']})
    table_ids = [f"id{i}" for i in range(len(table_docs))]
    assert len(table_docs) == len(table_metadata)
    print(len(table_docs),len(table_metadata))
    # Create a batch of data to be sent to OpenAI Embedding API
    batches = create_batches(api=client,ids=table_ids, documents=table_docs, metadatas=table_metadata)
    for batch in tqdm(batches,desc="Embedding table info"):
        table_collection.add(ids=batch[0],
                    documents=batch[3],
                    metadatas=batch[2])

# embed_table_info("US",us_tb_dict)

## For some strange reason, the `create_batches` was not batching the below documents, hence I had to do it manually

In [10]:

def embed_rows(region:str,batch_size:int=24):
    client = chromadb.PersistentClient(path=f"{region}_TABLE")
    # client.delete_collection(name="rows")
    rows_collection = client.create_collection(name="rows",embedding_function=emb_fn)

    rows_docs = []
    rows_metadata = []
    region_path = os.path.join(processed_dir,region)
    for df_path in os.listdir(region_path):
        df_full_path = os.path.join(region_path,df_path)
        df = pd.read_csv(df_full_path,index_col=False)
        for idx,row in df.iterrows():
            row_str = ""
            full_rows = []
            for rv in row.values:
                if isinstance(rv,str):
                    row_str+= rv + ", "
                full_rows.append(str(rv))
                row_str = row_str.replace('"',"")
                # row_str = row_str.replace("'",'"')
            full_rows_str = ", ".join(full_rows)[:-2]
            full_rows_str = full_rows_str.replace('"',"")
            rows_docs.append(row_str[:-2])
            rows_metadata.append({"table_name":df_path.split(".")[0],"region":region,"index":idx,"full_rows":full_rows_str})
    row_ids = [f"id{i}" for i in range(len(rows_docs))]
    # print(len(rows_docs),len(rows_metadata))
    assert len(rows_docs) == len(rows_metadata) == len(row_ids)
    # return rows_docs,rows_metadata,row_ids
    for start in tqdm(range(0,len(rows_docs),batch_size),desc="Embedding rows"):
        end = min(start+batch_size,len(rows_docs))
        batch_ids = row_ids[start:end]
        batch_rows = rows_docs[start:end]
        batch_metadatas = rows_metadata[start:end]
        rows_collection.add(ids=batch_ids,
                    documents=batch_rows,
                    metadatas=batch_metadatas)
    # return batches

In [40]:
region = "US"
embed_table_info(region,us_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:04<00:00,  2.44s/it]


In [41]:
region = "India"
embed_table_info(region,india_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:05<00:00,  2.61s/it]


In [42]:
region = "China"
embed_table_info(region,china_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:25<00:00, 12.86s/it]


In [43]:
region = "Europe"
embed_table_info(region,europe_tb_dict)
embed_rows(region,1000)

Embedding rows: 100%|██████████| 3/3 [00:26<00:00,  8.89s/it]


In [44]:
region = "Global"
embed_table_info(region,global_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:25<00:00, 12.69s/it]


In [45]:
region = "Emerging"
embed_table_info(region,emerging_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:25<00:00, 12.61s/it]


In [46]:
region = "Japan"
embed_table_info(region,japan_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:45<00:00, 22.99s/it]


In [48]:
region = "AUS_NZ_CANADA"
embed_table_info(region,aus_nz_canada_tb_dict)
embed_rows(region,2000)

Embedding rows: 100%|██████████| 2/2 [00:25<00:00, 12.68s/it]


## TEXT-TO-SQL PIPELINE

### LOAD DATABASE

In [10]:
db_dict = {
    "US":us_engine,
    "India":india_engine,
    "China":china_engine,
    "Europe":europe_engine,
    "Global":global_engine,
    "AUS_NZ_CANADA":aus_nz_canada_engine,
    "Japan":japan_engine,
    "Emerging":emerging_engine,
}

def get_collections_db(region:str):
    # Get the database for the given region, table collection and row collection
    client = chromadb.PersistentClient(path=f"{region}_TABLE")
    table_collection = client.get_collection(name="table",embedding_function=emb_fn)
    row_collection = client.get_collection(name="rows",embedding_function=emb_fn)
    return [db_dict[region],table_collection,row_collection]

In [11]:
db_collection_dict = {
    "US":get_collections_db("US"),
    "India":get_collections_db("India"),
    "China":get_collections_db("China"),
    "Europe":get_collections_db("Europe"),
    "Global":get_collections_db("Global"),
    "AUS_NZ_CANADA":get_collections_db("AUS_NZ_CANADA"),
    "Japan":get_collections_db("Japan"),
    "Emerging":get_collections_db("Emerging"),
}

In [12]:
load_dotenv(override=True)
text_to_sql = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=1024)
sql_to_answer = dspy.OpenAI(model='gpt-3.5-turbo',max_tokens=1024)

# DSPy signature for converting text to SQL query
class TextToSQLAnswer(dspy.Signature):
    """Convert natural language text to SQL using suitable schema(s) from multiple schema choices"""

    question:str = dspy.InputField(desc="natural language input which will be converted to SQL")
    relevant_table_schemas_rows:str = dspy.InputField(desc="Multiple possible tables which has table name and corresponding columns, along with relevant rows from the table (values in the same order as columns above)")
    sql:str = dspy.OutputField(desc="Generate syntactically correct sqlite query with correct column names using suitable tables(s) and its rows.\n Don't forget to add distinct.\n Please rename the returned columns into suitable names.\n DON'T OUTPUT anything else other than the sqlite query")

# DSPy signature for converting SQL query and question to natural language text
class SQLReturnToAnswer(dspy.Signature):
    """Answer the question using the rows from the SQL query"""
    question:str = dspy.InputField()
    sql:str = dspy.InputField(desc="sqlite query that generated the rows")
    relevant_rows:str = dspy.InputField(desc="relevant rows to answer the question")
    answer:str = dspy.OutputField(desc="answer to the question using relevant rows and the sql query")

# If there is an SQLError, then rectify the error by trying again
class SQLRectifier(dspy.Signature):
    """Correct the SQL query to resolve the error using the proper table names, columns and rows"""  
    input_sql:str = dspy.InputField(desc="sqlite query that needs to be fixed")
    error_str: str = dspy.InputField(desc="error that needs to be resolved")
    relevant_table_schemas_rows:str = dspy.InputField(desc="Multiple possible tables which has table name and corresponding columns, along with relevant rows from the table (values in the same order as columns above)")
    sql:str = dspy.OutputField(desc="corrected sqlite query to resolve the error and remove and any invalid syntax in the query.\n Don't output anything else other than the sqlite query")

dspy.settings.configure(lm=text_to_sql)

# Filter out the SQL Query
def process_sql_str(sql_str:str):
    sql_str = sql_str.replace("```","")
    sql_str = sql_str.replace("sql","")
    sql_str = sql_str.strip()
    return sql_str

<p align="center">
  <img src="https://raw.githubusercontent.com/Athe-kunal/Text-to-SQL/main/Schema.png" alt="Sublime's custom image"/>
</p>

In [13]:
def get_table_results(table_collection_,question:str):
    # question_emb = emb_fn.embed_with_retries(question)[0]
    # Get the table results for the given question
    table_results = table_collection_.query(
        query_texts = question,
        n_results = 5
    )
    # print(table_results['documents'][0])
    return table_results

def get_row_results(row_collection_,question,table_name:str):
    # Get the row results for the given question
    row_results = row_collection_.query(
        query_texts = question,
        where = {"table_name":table_name},
        n_results = 5
    )
    print(row_results['documents'][0])
    return row_results

In [50]:
from typing import Any

class TextToSQLQueryModule(dspy.Module):
    """Text to SQL to final module"""
    def __init__(self,region:str,use_cot:bool=True,max_retries:int=3):
        """Text to Answer init module

        Args:
            region (str): Region for which the module will be used.
            use_cot (bool, optional): Whether to use chain of thought for sql query generation. Defaults to True.
            max_retries (int, optional): Number of max retries for SQLError. Defaults to 3.
        """
        super().__init__()
        self.region = region
        db,table_collection,row_collection = db_collection_dict[region]
        # print(db,table_collection,row_collection)
        self.table_collection = table_collection
        self.use_cot = use_cot
        self.db = db
        self.row_collection = row_collection
        if self.use_cot == True:
            self.sqlAnswer = dspy.ChainOfThought(TextToSQLAnswer)
        else:
            self.sqlAnswer = dspy.Predict(TextToSQLAnswer)
        self.final_output = dspy.Predict(SQLReturnToAnswer)
        self.max_tries = max_retries
        # Initialize the sql rectifier with CoT reasoning
        self.sql_rectifier = dspy.ChainOfThought(SQLRectifier,rationale_type=dspy.OutputField(
            prefix="Reasoning: Let's think step by step in order to",
            desc="${produce the answer}. We ..."
        ))
    
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)
        
    def forward(self,question):
        # Embed the question with embedding function
        question_emb = emb_fn([question])[0]
        # Retrieve the relevant tables from table schema and table summary
        docs = self.table_collection.query(
            query_embeddings = question_emb,
            n_results = 5
        )
        # docs = get_table_results(db_collection_dict[self.region][1],question)
        relevant_rows_schemas = ""
    
        existing_table_names = []

        for table_idx,metadata_name in enumerate(docs['metadatas'][0]):
            table_metadata = metadata_name['table_metadata']
            table_name = metadata_name['table_name']
            # If the table name is already in the list of existing table names, skip it
            # if table_name in existing_table_names: 
            #     continue
            existing_table_names.append(table_name)
            # Retrieve the relevant rows from the current table
            rows = self.row_collection.query(
                query_embeddings = question_emb,
                n_results = 5,
                # where clause to filter the rows
                where = {"table_name":table_name}
            )
            # Retrieve the relevant table with the schema and summary
            relevant_rows_schemas += f'Table name: {table_name} \n'
            relevant_rows_schemas += "/* \n"
            for match in re.finditer("columns: ",table_metadata):
                cols_end = match.end()
            relevant_rows_schemas += "col : " + " | ".join(table_metadata[cols_end:].split(", ")) + "\n"
            for row_idx,row in enumerate(rows['metadatas'][0]):
                # Get the relevant rows from the current table
                # relevant_rows_schemas += f'\tRow {row_idx+1} from table {table_name}: {row["full_rows"]}\n'
                relevant_rows_schemas += f'row {row_idx+1} : {" | ".join(row["full_rows"].split(", "))}\n'
            relevant_rows_schemas += "*/" + '\n\n'
        print(colored(relevant_rows_schemas,"yellow"))
        # return 
        sql_query = self.sqlAnswer(question=question,relevant_table_schemas_rows=relevant_rows_schemas)

        num_tries = 0
        print(sql_query)
        while num_tries <= self.max_tries:
            with self.db.connect() as conn:
                try:
                    # Try executing the sql query for the database
                    result = conn.execute(text(process_sql_str(sql_query.sql)))
                    num_tries = self.max_tries + 1
                except Exception as error:
                    # If there is an sql error, then try again with the sql rectifier
                    print(colored(str(error),'red'))
                    sql_query = self.sql_rectifier(input_sql=sql_query.sql,error_str=str(error),relevant_table_schemas_rows=relevant_rows_schemas)
                    print(colored(sql_query.rationale,'green'))
                    print()
                    print(colored(sql_query.sql,'green'))
                    # If all the num_retries are exhausted, then exit the program
                    num_tries += 1
                    if num_tries == self.max_tries+1:
                        return sql_query,error
        # With the retrieved rows from the database, then try to answer the question with dspy context
        with dspy.context(lm=sql_to_answer):
            row_str = ""
            key = tuple(result.keys())
            for row in result.fetchall():
                for r,k in zip(row,key):
                    row_str += f" {k} = {r},"
                row_str = row_str[:-1]
                row_str += "\n"
            print(f"Extracted rows: {row_str}")
            final_answer = self.final_output(question=question,sql=sql_query.sql,relevant_rows=row_str)
            return final_answer
tsql_ = TextToSQLQueryModule("US")
question = "What is the ebitda of software and packaging industry?"
sq = tsql_(question = question)

[33mTable name: margin 
/* 
col : Industry_Name (VARCHAR) | Number_of_firms (INTEGER) | Gross_Margin (INTEGER) | Net_Margin (INTEGER) | Pre_tax_Pre_stock_compensation_Operating_Margin (INTEGER) | Pre_tax_Unadjusted_Operating_Margin (INTEGER) | After_tax_Unadjusted_Operating_Margin (INTEGER) | Pre_tax_Lease_adjusted_Margin (INTEGER) | After_tax_Lease_Adjusted_Margin (INTEGER) | Pre_tax_Lease_R_D_adj_Margin (INTEGER) | After_tax_Lease_R_D_adj_Margin (INTEGER) | EBITDA_Sales (INTEGER) | EBITDASG_A_Sales (INTEGER) | EBITDAR_D_Sales (INTEGER) | COGS_Sales (INTEGER) | R_D_Sales (INTEGER) | SG_A_Sales (INTEGER) | Stock_Based_Compensation_Sales (INTEGER) | Lease_Expense_Sales (INTEGER)
row 1 : Packaging & Container | 22 | 0.2171338247779519 | 0.0285269723092998 | 0.1016364547958028 | 0.0975946017284918 | 0.0799113339135634 | 0.0997355299061771 | 0.0816643450787498 | 0.0998025797146062 | 0.0817313948871789 | 0.1571078511407175 | 0.2522079768591083 | 0.1617111987330198 | 0.782866175222048 | 0.0

In [51]:
print(sq)

Prediction(
    answer='The EBITDA of the software industry is as follows:\n- Software (Entertainment): $156,210.97 million\n- Software (Internet): $1,980.75 million\n- Software (System & Application): $170,251.41 million'
)


In [52]:
# tsql = TextToSQLQueryModule("US")
# sq = tsql(question="What is the effective tax rate of the healthcare industry?")
sq = tsql_(question="What is the EBITDA value and number of firms for all the Software industries, semiconductor industry and aerospace?")
# sq = tsql("What is the debt to EBITDA ratio for software industry?")

[33mTable name: EVA 
/* 
col : Industry_Name (VARCHAR) | Number_of_Firms (INTEGER) | Beta (INTEGER) | ROE (INTEGER) | Cost_of_Equity (INTEGER) | _ROE_COE_ (INTEGER) | BV_of_Equity (INTEGER) | Equity_EVA_US_millions_ (INTEGER) | ROC (INTEGER) | Cost_of_Capital (INTEGER) | _ROC_WACC_ (INTEGER) | BV_of_Capital (INTEGER) | EVA_US_millions_ (INTEGER) | E_D_E_ (INTEGER) | Std_Dev_in_Stock (INTEGER) | Cost_of_Debt (INTEGER) | Tax_Rate (INTEGER) | After_tax_Cost_of_Debt (INTEGER) | D_D_E_ (INTEGER)
row 1 : Aerospace/Defense | 70 | 1.0764050153356324 | 0.1319148365195281 | 0.088314630705439 | 0.043600205814089 | 145775.301 | 6355.833126210783 | 0.1616771766434065 | 0.0781335598681736 | 0.0835436167752329 | 195266.1836785369 | 16313.243218401924 | 0.7970822237983712 | 0.3640214967123067 | 0.050855 | 0.0727529390842114 | 0.03814125 | 0.20291777620162
row 2 : Total Market (without financials) | 5214 | 1.097522642247185 | 0.1659255247192018 | 0.0892860415433705 | 0.0766394831758313 | 8566043.05699

In [53]:
print(sq.answer)

The EBITDA value for the Software (Entertainment) industry is $156,210.97 million with 84 firms. For the Semiconductor Equip industry, the EBITDA value is $23,479.95 million with 30 firms. Lastly, for the Aerospace/Defense industry, the EBITDA value is $48,918.45 million with 70 firms.


In [54]:
tsql = TextToSQLQueryModule("India")
sq = tsql(question="What is the beta value and number of firms for all the Software industries and semiconductor industry?")
print(sq.answer)

[33mTable name: betaIndia 
/* 
col : Industry_Name (VARCHAR) | Number_of_firms (INTEGER) | Beta_ (INTEGER) | D_E_Ratio (INTEGER) | Effective_Tax_rate (INTEGER) | Unlevered_beta (INTEGER) | Cash_Firm_value (INTEGER) | Unlevered_beta_corrected_for_cash (INTEGER) | HiLo_Risk (INTEGER) | Standard_deviation_of_equity (INTEGER) | Standard_deviation_in_operating_income_last_10_years_ (INTEGER) | Beta_2020 (INTEGER) | Beta_2021 (INTEGER) | Beta_2022 (INTEGER) | Beta_2023 (INTEGER) | Average_Beta_2019_23 (INTEGER)
row 1 : Semiconductor | 8 | 1.611640106554315 | 0.0408116499112951 | 0.1361426723519776 | 1.566877312715955 | 0.0034062258080553 | 1.572232692288697 | 0.5628222021115665 | 0.5010518370389238 | 12.74567091312641 | 0.73 | 1.16 | 1.9836144211648223 | 1.9836144211648223 | 1.48589230692366
row 2 : Total Market (without financials) | 3850 | 0.8198477825445601 | 0.1608082521880037 | 0.1670198672216159 | 0.7368982579173127 | 0.0296361733703686 | 0.7594040891618811 | 0.384976556985476 | 0.369

In [55]:
print(sq.answer)

The beta value and number of firms for the Software industries and semiconductor industry are as follows:

- Industry: Semiconductor, Beta Value: 1.572232692288697, Number of Firms: 8
- Industry: Semiconductor, Beta Value: 1.611640106554315, Number of Firms: 8
- Industry: Software (Entertainment), Beta Value: 0.7830360024669405, Number of Firms: 5
- Industry: Software (Entertainment), Beta Value: 0.7867794606333218, Number of Firms: 5
- Industry: Software (Internet), Beta Value: -0.0058016928550091, Number of Firms: 6
- Industry: Software (Internet), Beta Value: -0.0057433681694908, Number of Firms: 6
- Industry: Software (System & Application), Beta Value: 1.164011103524074, Number of Firms: 73
- Industry: Software (System & Application), Beta Value: 1.2041292290782728, Number of Firms: 73


In [56]:
tsql = TextToSQLQueryModule("China")
sq = tsql(question="What is the beta value and number of firms for all the Software industries and semiconductor industry?")
print(sq.answer)

[33mTable name: betaChina 
/* 
col : Industry_Name (VARCHAR) | Number_of_firms (INTEGER) | Beta_ (INTEGER) | D_E_Ratio (INTEGER) | Effective_Tax_rate (INTEGER) | Unlevered_beta (INTEGER) | Cash_Firm_value (INTEGER) | Unlevered_beta_corrected_for_cash (INTEGER) | HiLo_Risk (INTEGER) | Standard_deviation_of_equity (INTEGER) | Standard_deviation_in_operating_income_last_10_years_ (INTEGER)
row 1 : Semiconductor | 173 | 1.3766066410015918 | 0.1214640196981001 | 0.0455831075871332 | 1.261670924483595 | 0.1353894923931989 | 1.4592361686371793 | 0.3256484142873909 | 0.3505300755906819 | 0.93670130463571
row 2 : Total Market (without financials) | 7161 | 1.0686360714269856 | 0.4705306537504954 | 0.0996397852701986 | 0.7898866574410776 | 0.1404090933852625 | 0.9189099737592956 | 0.3186107366027715 | 0.3211951486316037 | 0.26343082981471
row 3 : Semiconductor Equip | 59 | 1.1458070496329555 | 0.1005963129277871 | 0.072001108546688 | 1.0654237722383604 | 0.0896726175453628 | 1.170374299151055 | 

In [57]:
print(sq.answer)

Beta value and number of firms for the Software (Entertainment) industry are Beta_Value = 0.9517387989546392, Number_of_Firms = 19. For the Software (Internet) industry, the values are Beta_Value = 1.3766066410015918, Number_of_Firms = 173. And for the Semiconductor industry, the values are Beta_Value = 1.9284684177392413, Number_of_Firms = 33.


In [58]:
tsql = TextToSQLQueryModule("US")
sq = tsql(question="What is the average tax rate of all healthcare industries?")
print(sq.answer)

[33mTable name: taxrate 
/* 
col : Industry_name (VARCHAR) | Number_of_firms (INTEGER) | Total_Taxable_Income (INTEGER) | Total_Taxes_Paid_Accrual_ (INTEGER) | Total_Cash_Taxes_Paid (INTEGER) | Cash_Taxes_Accrual_Taxes (INTEGER) | Average_across_all_companies (INTEGER) | Average_across_only_money_making_companies (INTEGER) | Aggregate_tax_rate (INTEGER) | Average_across_only_money_making_companies2 (INTEGER) | Aggregate_tax_rate3 (INTEGER)
row 1 : Healthcare Products | 230 | 26657.20900000001 | 5486.0970000000025 | 7358.37 | 1.341275956294611 | 0.0481441302819232 | 0.2058016276197557 | 0.2473315283348905 | 0.2760367748926752 | 0.35300275681386
row 2 : Hospitals/Healthcare Facilities | 32 | 12803.409999999998 | 2598.52 | 2117.581 | 0.8149181072302696 | 0.0685657680171947 | 0.2029553064378943 | 0.2241515389301438 | 0.1653919541747081 | 0.18582584339767
row 3 : Healthcare Support Services | 119 | 76367.83500000002 | 17205.211 | 18119.770000000008 | 1.0531559304910594 | 0.0808427903603178

In [60]:
print(sq.answer)

The average tax rate of all healthcare industries is 0.24116333460436337.


In [61]:
sq = tsql(question="Give me the average tax rate of all healthcare industries where revenues per employee is more than 1 million?")
print(sq.answer)

[33mTable name: taxrate 
/* 
col : Industry_name (VARCHAR) | Number_of_firms (INTEGER) | Total_Taxable_Income (INTEGER) | Total_Taxes_Paid_Accrual_ (INTEGER) | Total_Cash_Taxes_Paid (INTEGER) | Cash_Taxes_Accrual_Taxes (INTEGER) | Average_across_all_companies (INTEGER) | Average_across_only_money_making_companies (INTEGER) | Aggregate_tax_rate (INTEGER) | Average_across_only_money_making_companies2 (INTEGER) | Aggregate_tax_rate3 (INTEGER)
row 1 : Hospitals/Healthcare Facilities | 32 | 12803.409999999998 | 2598.52 | 2117.581 | 0.8149181072302696 | 0.0685657680171947 | 0.2029553064378943 | 0.2241515389301438 | 0.1653919541747081 | 0.18582584339767
row 2 : Healthcare Products | 230 | 26657.20900000001 | 5486.0970000000025 | 7358.37 | 1.341275956294611 | 0.0481441302819232 | 0.2058016276197557 | 0.2473315283348905 | 0.2760367748926752 | 0.35300275681386
row 3 : Healthcare Support Services | 119 | 76367.83500000002 | 17205.211 | 18119.770000000008 | 1.0531559304910594 | 0.0808427903603178

In [62]:
tsql = TextToSQLQueryModule("Europe")
sq = tsql(question="Give me the average tax rate of all banking industries where number of firms is more than 500?")
print(sq.answer)

[33mTable name: taxrateEurope 
/* 
col : Industry_name (VARCHAR) | Number_of_firms (INTEGER) | Total_Taxable_Income (INTEGER) | Total_Taxes_Paid_Accrual_ (INTEGER) | Total_Cash_Taxes_Paid (INTEGER) | Cash_Taxes_Accrual_Taxes (INTEGER) | Average_across_all_companies (INTEGER) | Average_across_only_money_making_companies (INTEGER) | Aggregate_tax_rate (INTEGER) | Average_across_only_money_making_companies_1 (INTEGER) | Aggregate_tax_rate_1 (INTEGER)
row 1 : Financial Svcs. (Non-bank & Insurance) | 143 | 67531.72600000002 | 5859.747000000001 | 4199.801999999999 | 0.716720704835891 | 0.133380349433826 | 0.0867702833480074 | 0.0851491015690155 | 0.0621900586399938 | 0.0637353624270
row 2 : Banks (Regional) | 72 | 7968.15 | 1353.5790000000002 | 675.7199999999998 | 0.499209872493589 | 0.1876823623247109 | 0.169873684606841 | 0.1693503510852582 | 0.0848026204325972 | 0.08480262043259
row 3 : Total Market (without financials) | 6149 | 1358894.4370000027 | 346515.07499999995 | 347363.1549999990