In [3]:
from langchain_chroma import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_core.documents import Document

# Sample documents

In [8]:
tables_documents = [
    Document(page_content="main.customers", metadata= {'description': "This table has basic information about a customer, as well as some derived facts based on a customer's orders", 'columns': 'customer_id,first_name,last_name,first_order,most_recent_order,number_of_orders,total_order_amount'} ),
    Document(page_content="main.orders", metadata= {'description': 'This table has basic information about orders, as well as some derived facts based on payments', 'columns': 'order_id,customer_id,order_date,status,amount,credit_card_amount,coupon_amount,bank_transfer_amount,gift_card_amount'} ),
    Document(page_content="main.stg_customers", metadata= {'description': '', 'columns': 'customer_id'} ),
    Document(page_content="main.stg_orders", metadata= {'description': '', 'columns': 'order_id,status'} ),
    Document(page_content="main.stg_payments", metadata= {'description': '', 'columns': 'payment_id,payment_method'} ),
]

queries = [
    {
        "description": "total revenue by fiscal month",
        "sql": "select d.fiscal_month, sum(f.sales) as revenue from core.profitability_fact f join core.date d on f.date_fk = d.date_key"
    },
    {
        "description": "active customers by fiscal month",
        "sql": "select d.fiscal_month, count(distinct f.customer_fk) as customer_count from core.profitability_fact f join core.date d on f.date_fk = d.date_key where f.sales > 0"
    },
    {
        "description": "order count by customer",
        "sql": "SELECT customer_id, COUNT(order_id) AS order_count FROM main.orders GROUP BY customer_id"
    }
]
queries_documents = [
    Document(page_content=q["description"], metadata=q)
    for q in queries
]

# Langchain Chroma basic

In [5]:
# create the open-source embedding function
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

  from .autonotebook import tqdm as notebook_tqdm
modules.json: 100%|██████████| 349/349 [00:00<00:00, 175kB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
config_sentence_transformers.json: 100%|██████████| 116/116 [00:00<00:00, 123kB/s]
README.md: 100%|██████████| 10.7k/10.7k [00:00<00:00, 12.9MB/s]
sentence_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 67.3kB/s]
config.json: 100%|██████████| 612/612 [00:00<00:00, 1.15MB/s]
model.safetensors: 100%|██████████| 90.9M/90.9M [00:39<00:00, 2.32MB/s]
tokenizer_config.json: 100%|██████████| 350/350 [00:00<?, ?B/s] 
vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.81MB/s]
tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 691kB/s]
special_tokens_map.json: 100%|██████████| 112/112 [00:00<?, ?B/s] 
1_Po

In [6]:
# load it into Chroma
db = Chroma.from_documents(tables_documents, embedding_function)

In [9]:
# query it
query = "order status"
docs = db.similarity_search(query)
docs

[Document(page_content='main.orders', metadata={'columns': 'order_id,customer_id,order_date,status,amount,credit_card_amount,coupon_amount,bank_transfer_amount,gift_card_amount', 'description': 'This table has basic information about orders, as well as some derived facts based on payments'}),
 Document(page_content='main.stg_orders', metadata={'columns': 'order_id,status', 'description': ''}),
 Document(page_content='main.customers', metadata={'columns': 'customer_id,first_name,last_name,first_order,most_recent_order,number_of_orders,total_order_amount', 'description': "This table has basic information about a customer, as well as some derived facts based on a customer's orders"}),
 Document(page_content='main.stg_payments', metadata={'columns': 'payment_id,payment_method', 'description': ''})]

# Chroma saving to disk

In [10]:
# save to disk
db2 = Chroma.from_documents(tables_documents, embedding_function, persist_directory="./chroma_db")
docs = db2.similarity_search("order status")
docs

[Document(page_content='main.orders', metadata={'columns': 'order_id,customer_id,order_date,status,amount,credit_card_amount,coupon_amount,bank_transfer_amount,gift_card_amount', 'description': 'This table has basic information about orders, as well as some derived facts based on payments'}),
 Document(page_content='main.stg_orders', metadata={'columns': 'order_id,status', 'description': ''}),
 Document(page_content='main.customers', metadata={'columns': 'customer_id,first_name,last_name,first_order,most_recent_order,number_of_orders,total_order_amount', 'description': "This table has basic information about a customer, as well as some derived facts based on a customer's orders"}),
 Document(page_content='main.stg_payments', metadata={'columns': 'payment_id,payment_method', 'description': ''})]

In [12]:
# load from disk
db3 = Chroma(persist_directory="./chroma_db", embedding_function=embedding_function)
docs = db3.similarity_search("order status")
docs

[Document(page_content='main.orders', metadata={'columns': 'order_id,customer_id,order_date,status,amount,credit_card_amount,coupon_amount,bank_transfer_amount,gift_card_amount', 'description': 'This table has basic information about orders, as well as some derived facts based on payments'}),
 Document(page_content='main.stg_orders', metadata={'columns': 'order_id,status', 'description': ''}),
 Document(page_content='main.customers', metadata={'columns': 'customer_id,first_name,last_name,first_order,most_recent_order,number_of_orders,total_order_amount', 'description': "This table has basic information about a customer, as well as some derived facts based on a customer's orders"}),
 Document(page_content='main.stg_payments', metadata={'columns': 'payment_id,payment_method', 'description': ''})]

# Chroma multiple collections

In [14]:
persist_directory = 'chroma_multiple_collections_db'
tables_chroma = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name='tables')
queries_chroma = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name='queries')

In [15]:
tables_chroma.add_documents(tables_documents)

['48ae7416-63f8-4ad4-a5dc-b79f210a6aba',
 'da5b5184-4e41-464f-a735-d3d97d8f52be',
 '91736ede-4d28-4466-b479-940a71bf70e8',
 '30b9f2c1-a9f2-49c2-be59-45d258d6e871',
 'b8c5f147-762b-456e-b6c2-1308e95de0ca']

In [16]:
queries_chroma.add_documents(queries_documents)

['c483aebd-9b9f-4b66-9cf7-6d71feecb966',
 '91840a42-5506-4da6-9f44-9125dec81479',
 '699de194-fb63-4bb6-9fdf-7c5a51bb23a4']

In [17]:
queries_chroma.similarity_search("order status")

Number of requested results 4 is greater than number of elements in index 3, updating n_results = 3


[Document(page_content='order count by customer', metadata={'description': 'order count by customer', 'sql': 'SELECT customer_id, COUNT(order_id) AS order_count FROM main.orders GROUP BY customer_id'}),
 Document(page_content='active customers by fiscal month', metadata={'description': 'active customers by fiscal month', 'sql': 'select d.fiscal_month, count(distinct f.customer_fk) as customer_count from core.profitability_fact f join core.date d on f.date_fk = d.date_key where f.sales > 0'}),
 Document(page_content='total revenue by fiscal month', metadata={'description': 'total revenue by fiscal month', 'sql': 'select d.fiscal_month, sum(f.sales) as revenue from core.profitability_fact f join core.date d on f.date_fk = d.date_key'})]

In [18]:
tables_chroma.similarity_search("order status")

[Document(page_content='main.orders', metadata={'columns': 'order_id,customer_id,order_date,status,amount,credit_card_amount,coupon_amount,bank_transfer_amount,gift_card_amount', 'description': 'This table has basic information about orders, as well as some derived facts based on payments'}),
 Document(page_content='main.stg_orders', metadata={'columns': 'order_id,status', 'description': ''}),
 Document(page_content='main.customers', metadata={'columns': 'customer_id,first_name,last_name,first_order,most_recent_order,number_of_orders,total_order_amount', 'description': "This table has basic information about a customer, as well as some derived facts based on a customer's orders"}),
 Document(page_content='main.stg_payments', metadata={'columns': 'payment_id,payment_method', 'description': ''})]

# ChromaDB - common functions for saving and searching

In [36]:
from typing import List

In [46]:
def _save_items(
    #self,
    items: List,
    chroma: Chroma,
    text_mapper,
    metadata_mapper,
):
    documents = [
        Document(
            text=text_mapper(item),
            metadata=metadata_mapper(item),
        )
        for item in items
    ]
    chroma.add_documents(documents)


def _find_similar_items(
    #self,
    query: str,
    chroma: Chroma,
    similarity_top_k: int,
    similarity_cutoff: float,
    node_metadata_mapper
):
    top_n = chroma.similarity_search_with_score(query, k=similarity_top_k)
    return [
        node_metadata_mapper(**n[0].metadata)
        for n in top_n
        if n[1] > similarity_cutoff
    ]

In [47]:
def pass_mapper(**kwargs):
    return kwargs

_find_similar_items(
    query="order status",
    chroma=tables_chroma,
    similarity_top_k=5,
    similarity_cutoff=.1,
    node_metadata_mapper=pass_mapper
)

[{'columns': 'order_id,customer_id,order_date,status,amount,credit_card_amount,coupon_amount,bank_transfer_amount,gift_card_amount',
  'description': 'This table has basic information about orders, as well as some derived facts based on payments'},
 {'columns': 'order_id,status', 'description': ''},
 {'columns': 'customer_id,first_name,last_name,first_order,most_recent_order,number_of_orders,total_order_amount',
  'description': "This table has basic information about a customer, as well as some derived facts based on a customer's orders"},
 {'columns': 'payment_id,payment_method', 'description': ''},
 {'columns': 'customer_id', 'description': ''}]

# SQL Generation

## Hugging Face basic

In [19]:
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

In [20]:
hf_llm = HuggingFacePipeline.from_model_id(
    model_id="rakeshkiriyath/gpt2Medium_text_to_sql",
    task="text-generation",
)

In [28]:
question = "I need a list of employees who joined in the company last 6 months with a salary hike of 30%"

In [29]:
hf_llm.invoke(question)

'I need a list of employees who joined in the company last 6 months with a salary hike of 30% and less than 600 reviews.SELECT employees FROM employees WHERE last_joined_in_company = "6 months" OR salary_increase ='

## HuggingFace with prompt template

In [30]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

In [23]:
template = """Question: {question}"""
prompt = PromptTemplate.from_template(template)

In [33]:
chain = prompt | hf_llm #| StrOutputParser()

In [34]:
chain.invoke({"question": question})

'Question: I need a list of employees who joined in the company last 6 months with a salary hike of 30%SELECT Employees FROM employees WHERE Last_6_month_joined < 6 GROUP BY Last_6_month_joined HAVING SUM'