<a href="https://colab.research.google.com/github/trancethehuman/ai-workshop-code/blob/main/Hybrid_Search_Workshop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

## Setup documents for knowledge base

In [None]:
!wget https://raw.githubusercontent.com/trancethehuman/ai-workshop-code/main/datasets/legal_text_classification_first_1000.csv -q

In [None]:
pip install tabulate -q

In [None]:
import pandas as pd
from tabulate import tabulate
import textwrap

# Read the legal-case dataset downloaded by the wget cell above into a DataFrame
df = pd.read_csv('legal_text_classification_first_1000.csv')

# Word-wrap helper for the table preview
def wrap_text(text, width):
    """Return `text` word-wrapped to lines of at most `width` characters, joined by newlines."""
    wrapped_lines = textwrap.wrap(text, width)
    return "\n".join(wrapped_lines)

# Width (in characters) at which long cell values are wrapped for display
wrap_width = 60

# Work on a copy of the first three rows so the original data is never modified
df_display = df.head(3).copy()

# Wrap every string cell; non-string cells (numbers, NaN) pass through unchanged.
# DataFrame.applymap is deprecated since pandas 2.1 in favour of the equivalent
# DataFrame.map, so use `map` when this pandas has it and fall back otherwise.
# (Checked on the class so a column that happens to be named "map" can't fool it.)
_elementwise = df_display.map if hasattr(pd.DataFrame, "map") else df_display.applymap
df_display = _elementwise(lambda x: wrap_text(x, wrap_width) if isinstance(x, str) else x)

# Display the wrapped rows as a table
print(tabulate(df_display, headers='keys', tablefmt='pretty'))

In [None]:
# Number of rows (cases) in the dataset; reused later to embed every row
row_count = len(df)

print(f"The DataFrame contains {row_count} rows.")

In [None]:
pip install openai -q

In [None]:
import getpass
import os

# OpenAI API key for the embeddings client. Taken from the OPENAI_API_KEY environment
# variable when it is set; otherwise prompted for with hidden input, so the key never
# has to be typed into (and saved with) the notebook.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")

In [None]:
pip install langchain-openai -q

In [None]:
from langchain_openai import OpenAIEmbeddings

# Requested embedding size; the vector indexes created later use the same dimension
EMBEDDINGS_DIMENSIONS = 512

# Dense embedding client (text-embedding-3-large, shortened to EMBEDDINGS_DIMENSIONS)
embedding_client = OpenAIEmbeddings(
    model="text-embedding-3-large",
    dimensions=EMBEDDINGS_DIMENSIONS,
    api_key=OPENAI_API_KEY,
)

In [None]:
pip install tiktoken -q

In [None]:
import tiktoken

def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Count the tokens `string` encodes to under the tiktoken encoding `encoding_name`."""
    token_ids = tiktoken.get_encoding(encoding_name).encode(string)
    return len(token_ids)

In [None]:
import pandas as pd
import time
from typing import List, Dict

def _embed_case_batch(batch: List[tuple]) -> List[Dict]:
    """Embed one batch of ``(case_id, case_title, case_text)`` tuples with a single API call.

    Returns one result dict per tuple, in the same order, with keys ``case_id``,
    ``case_title``, ``case_text`` and ``embeddings``.
    """
    embeddings = embedding_client.embed_documents([case_text for _, _, case_text in batch])
    return [
        {
            'case_id': case_id,
            'case_title': case_title,
            'case_text': case_text,
            'embeddings': embedding,
        }
        for (case_id, case_title, case_text), embedding in zip(batch, embeddings)
    ]


def get_embeddings(
    df: pd.DataFrame,
    num_rows: int = 2,
    max_tokens: int = 8191,
    encoding_name: str = "cl100k_base",
    price_per_token: float = 0.13 / 1000000
) -> List[Dict]:
    """
    Embed the ``case_text`` of the first ``num_rows`` rows of ``df``.

    Texts are grouped into batches whose combined token count stays within
    ``max_tokens``, and each batch is sent to ``embedding_client.embed_documents``
    in one call. A single text that is already larger than ``max_tokens`` is sent
    on its own as a one-item batch. Rows whose ``case_text`` is missing, not a
    string, or blank are skipped. Timing, token, batch and cost statistics are
    printed at the end.

    Args:
        df: DataFrame with ``case_id``, ``case_title`` and ``case_text`` columns.
        num_rows: How many rows from the top of ``df`` to process.
        max_tokens: Token budget per embedding request.
        encoding_name: tiktoken encoding used to count tokens.
        price_per_token: USD price per token, used only for the cost printout.

    Returns:
        One dict per embedded row, in row order, with keys ``case_id``,
        ``case_title``, ``case_text`` and ``embeddings``.
    """
    result_list: List[Dict] = []
    batch: List[tuple] = []  # (case_id, case_title, case_text) waiting to be embedded
    current_tokens = 0       # tokens currently in `batch`
    total_tokens = 0
    total_batches = 0        # embedding requests actually sent

    # Start the timer to see how long it'd take to get all the embeddings
    start_time = time.time()

    rows = df.head(num_rows)
    # Take the title from the same row as the text instead of re-searching the whole
    # DataFrame by case_id for every result (O(n) per row, and wrong on duplicate ids).
    for case_id, case_title, case_text in zip(rows['case_id'], rows['case_title'], rows['case_text']):
        # Skip missing (NaN), non-string and blank texts
        if not isinstance(case_text, str) or not case_text.strip():
            continue

        tokens = num_tokens_from_string(case_text, encoding_name)

        # Send the pending batch first if adding this text would exceed the token
        # budget; a text that alone exceeds max_tokens thus gets a batch of its own.
        if batch and current_tokens + tokens > max_tokens:
            result_list.extend(_embed_case_batch(batch))
            total_batches += 1
            batch = []
            current_tokens = 0

        batch.append((case_id, case_title, case_text))
        current_tokens += tokens
        total_tokens += tokens

    # Send whatever is left over
    if batch:
        result_list.extend(_embed_case_batch(batch))
        total_batches += 1

    duration = time.time() - start_time

    # Print the statistics
    print(f"Completed in {duration:.2f} seconds.")
    print(f"Total number of tokens: {total_tokens:,}")
    print(f"Total number of batches: {total_batches:,}")
    print(f"Money burned: ${total_tokens * price_per_token:.6f} - Thanks, Invest Ottawa ❤️")

    return result_list

In [None]:
# Smoke-test the embedding pipeline on just the first three rows before embedding everything
test_cases_with_embeddings = get_embeddings(df, num_rows=3)

# Display limits for the test-results table
MAX_TABLE_WIDTH = 60
MAX_CHARACTERS_LENGTH = 150

def wrap_and_truncate_text(text, width=MAX_TABLE_WIDTH, max_length=MAX_CHARACTERS_LENGTH):
    """Render ``text`` as a short, wrapped string for table display.

    Any non-string value (embedding lists, numbers, NaN titles, ...) is converted
    with ``str()`` first; previously only lists were converted and everything else
    crashed on ``len()``. The result is cut to ``max_length`` characters (plus
    ``'...'``) and wrapped to lines of at most ``width`` characters.
    """
    if not isinstance(text, str):
        text = str(text)
    if len(text) > max_length:
        text = text[:max_length] + '...'
    return "\n".join(textwrap.wrap(text, width))

# Truncate and wrap every field of every test case so the table stays readable
wrapped_test_cases_with_embeddings = [
    {field: wrap_and_truncate_text(raw_value) for field, raw_value in case.items()}
    for case in test_cases_with_embeddings
]

# Left-align all four columns (case_id, case_title, case_text, embeddings)
colalign = ("left",) * 4

print(
    tabulate(
        wrapped_test_cases_with_embeddings,
        headers='keys',
        tablefmt='pretty',
        colalign=colalign,
    )
)

In [None]:
# Embed every row of the dataset (slow, and billed per token)
all_cases_with_embeddings = get_embeddings(df=df, num_rows=row_count)

## Setup vector databases' clients

### Pinecone

In [None]:
pip install pinecone-client pinecone-notebooks pinecone-text -q

In [None]:
from pinecone_notebooks.colab import Authenticate

# Interactive Pinecone login for Colab. NOTE(review): presumably this stores the key in
# the PINECONE_API_KEY environment variable that the next cell reads — confirm.
Authenticate()

In [None]:
import os
from pinecone import Pinecone, ServerlessSpec

# Pinecone client, keyed from the PINECONE_API_KEY environment variable (None if unset)
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))

In [None]:
# NOTE: the name is spelled as-is on purpose; the Weaviate and vecs sections reuse it
index_name = "hybridhearchexperiment"

# Names of the indexes that already exist in this Pinecone project
existing_indexes = pc.list_indexes().names()

# If the index already exists (e.g. from a previous run), delete it so we start from scratch
if index_name in existing_indexes:
    pc.delete_index(index_name)
    print("Deleted old index.")

pc.create_index(
    name=index_name,
    dimension=EMBEDDINGS_DIMENSIONS,  # must match the size of the dense embeddings we upsert
    metric="dotproduct",  # to use sparse-dense index (aka hybrid search) in Pinecone, they require us to use dotproduct.
    spec=ServerlessSpec(cloud='aws', region='us-east-1'),
)

# Poll once per second until the new index reports it is ready
while not pc.describe_index(index_name).status['ready']:
    time.sleep(1)

# Connect to the index and show its stats (it was just created, so it is empty)
pinecone_index = pc.Index(index_name)
pinecone_index.describe_index_stats()

In [None]:
from pinecone_text.sparse import BM25Encoder

# Sparse (keyword) encoder for the hybrid index; it is fitted on a training corpus in a later cell
bm25 = BM25Encoder()

In [None]:
!wget https://raw.githubusercontent.com/trancethehuman/ai-workshop-code/main/datasets/Legal_Text_Classification_Data_500_train.csv -q

In [None]:
# Load our training set into memory as a DataFrame
train_df = pd.read_csv('Legal_Text_Classification_Data_500_train.csv')

# Train baby
bm25.fit(train_df.get("case_text").astype(str).tolist())

In [None]:
from tqdm import tqdm

def group_embeddings_and_generate_sparse_vectors(cases, sparse_vector_model):
    """Turn embedded cases into records for a sparse-dense upsert.

    Args:
        cases: Iterable of dicts with ``case_id``, ``case_title``, ``case_text``
            and ``embeddings`` keys.
        sparse_vector_model: Encoder whose ``encode_documents(text)`` returns the
            sparse vector for a case text.

    Returns:
        A list with one dict per case: ``id``, ``sparse_values`` (from the sparse
        model), ``values`` (the dense embedding) and ``metadata`` holding the
        case title and text. A tqdm progress bar tracks the work.
    """
    def _to_record(case):
        # Pull every field out first, then encode the text into its sparse vector
        case_id, case_title = case['case_id'], case['case_title']
        case_text, dense_values = case['case_text'], case['embeddings']
        return {
            'id': case_id,
            'sparse_values': sparse_vector_model.encode_documents(case_text),
            'values': dense_values,
            'metadata': {
                'case_title': case_title,
                'case_text': case_text,
            },
        }

    return [_to_record(case) for case in tqdm(cases, desc="Processing cases")]


In [None]:
# Build upsert-ready records (id, dense values, BM25 sparse values, metadata) for every embedded case
all_cases_embeddings_and_sparse_vectors = group_embeddings_and_generate_sparse_vectors(all_cases_with_embeddings, bm25)

In [None]:
# Prepare the data for tabulate
table_data = []
for case in all_cases_embeddings_and_sparse_vectors:
    table_data.append([
        case['id'],
        case['metadata']['case_title'][:30],
        case['metadata']['case_text'][:30],
        str(case['sparse_values'])[:30],
        str(case['values'])[:30],
    ])

# Limit the table_data to the first 10 rows
table_data = table_data[:10]

# Define the headers based on the keys
headers = ["ID", "Case Title", "Case Text", "Sparse Values", "Embeddings"]

# Print the table
print(tabulate(table_data, headers=headers, tablefmt="grid"))

In [None]:
# Upsert our hard work
pinecone_index.upsert(vectors=all_cases_embeddings_and_sparse_vectors, batch_size=100) # batch size is very important here because our request is a bit large.

# See if our index's stats changed
print(pinecone_index.describe_index_stats())

In [None]:
from typing import Any

def convert_string_query_to_vectors(query: str) -> Dict[str, Any]:
    """Encode `query` for hybrid search.

    Returns a dict with ``"dense"`` (the OpenAI query embedding) and ``"sparse"``
    (the BM25 query encoding).
    """
    vectors: Dict[str, Any] = {}
    vectors["dense"] = embedding_client.embed_query(query)
    vectors["sparse"] = bm25.encode_queries(query)
    return vectors

In [None]:
def hybrid_scale(dense, sparse, alpha: float):
    """Hybrid vector scaling using a convex combination

    alpha * dense + (1 - alpha) * sparse

    Args:
        dense: Array of floats representing the dense (semantic) query vector.
        sparse: a dict of `indices` and `values` (the sparse / keyword query vector).
        alpha: float between 0 and 1 where 0 == sparse only
               and 1 == dense only

    Returns:
        ``(hdense, hsparse)``: the dense values scaled by ``alpha``, and a new
        sparse dict with the same ``indices`` and values scaled by ``1 - alpha``.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or is NaN.
    """
    # Written as a positive range check so NaN (which fails every comparison) is
    # rejected too, instead of silently producing all-NaN vectors.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    # scale sparse and dense vectors to create hybrid search vecs
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']]
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse

### Weaviate

In [None]:
pip install weaviate-client -q

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 325.7/325.7 kB 5.1 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 40.0/40.0 kB 3.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 223.8/223.8 kB 20.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 37.1 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 309.3/309.3 kB 26.5 MB/s eta 0:00:00
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. [output truncated]

Then, we set up our Weaviate client.

In [None]:
import os

# Weaviate Cloud cluster URL and API key, read from the WCD_URL / WCD_API_KEY
# environment variables (empty string when unset) so credentials stay out of the notebook
WCD_URL = os.environ.get("WCD_URL", "")
WCD_API_KEY = os.environ.get("WCD_API_KEY", "")

In [None]:
import weaviate
import weaviate.classes as wvc
import os
import requests
import json

# Connect to the Weaviate Cloud cluster, authenticating with the API key
weaviate_client = weaviate.connect_to_weaviate_cloud(
    cluster_url=WCD_URL,
    auth_credentials=weaviate.auth.AuthApiKey(WCD_API_KEY),
)

In [None]:
weaviate_index = None

# Start from scratch: drop the collection if a previous run left one behind
if weaviate_client.collections.exists(index_name):
    print("Collection (index) already exists. Deleting and making a new one..")
    weaviate_client.collections.delete(index_name)

# Create the collection unconditionally: on a fresh cluster there is nothing to
# delete, but the collection still has to exist before the upload cell runs.
# Only metadata properties are declared here; the upload cell attaches our own
# dense vector to each object.
print("Creating a new collection (index) with some metadata as properties")
weaviate_client.collections.create(name=index_name, properties=[
    wvc.config.Property(
        name="case_id",
        data_type=wvc.config.DataType.TEXT,
        vectorize_property_name=False,
    ),
    wvc.config.Property(
        name="case_title",
        data_type=wvc.config.DataType.TEXT,
        vectorize_property_name=False,
    ),
    wvc.config.Property(
        name="case_text",
        data_type=wvc.config.DataType.TEXT,
        vectorize_property_name=False,
    ),
])
print("Weaviate collection (index) created.")

In [None]:
collection_weaviate = weaviate_client.collections.get(index_name)

# Total number of items to process, so tqdm can show a proper progress bar
total_items = len(all_cases_embeddings_and_sparse_vectors)

# Dynamic batching lets the client choose batch sizes; each object carries its
# case metadata as properties plus our precomputed dense vector
with collection_weaviate.batch.dynamic() as batch:
    for d in tqdm(all_cases_embeddings_and_sparse_vectors, total=total_items, desc="Uploading to Weaviate"):
        batch.add_object(
            properties={
                "case_id": d["id"],
                "case_title": d["metadata"]["case_title"],
                "case_text": d["metadata"]["case_text"],
            },
            vector=d["values"],
        )

# Per-object batch errors are collected rather than raised, so report them explicitly
failed_objects = collection_weaviate.batch.failed_objects
if failed_objects:
    print(f"{len(failed_objects)} object(s) failed to upload. First failure: {failed_objects[0]}")
else:
    print(f"Uploaded {total_items} objects to Weaviate.")


## The search (set up the search query and determine what we're looking for)

In [None]:
def display_pinecone_result_in_nice_table(data, max_text_length=200):
    """
    Print Pinecone query matches as a grid table.

    Args:
        data: A Pinecone query response (or dict-like) with a ``matches`` list; each
            match must provide ``id``, ``score`` and ``metadata`` containing
            ``case_title`` and ``case_text``.
        max_text_length: Case text longer than this many characters is cut and
            suffixed with ``'...'``.

    Returns:
        None. The table is printed, not returned.
    """
    rows = []
    for match in data['matches']:
        metadata = match['metadata']
        case_text = metadata['case_text']
        if len(case_text) > max_text_length:
            case_text = case_text[:max_text_length] + '...'
        rows.append({
            'ID': match['id'],
            'Case Title': metadata['case_title'],
            'Relevance Score': match['score'],
            'Case Text': case_text,
        })

    # Named `results_df` so it isn't confused with the notebook's global `df`
    results_df = pd.DataFrame(rows)

    # Display the DataFrame in a nice table format using tabulate
    print(tabulate(results_df, headers='keys', tablefmt='grid'))

In [None]:
# The question every search below tries to answer; the expected hit is Case269 (next cell)
search_query = "Whats the verdict from Palmer J in Macleay Nominees Pty"

In [None]:
# Look up the ground-truth case so we can compare each search's results against it
the_case_we_need_to_find = df.loc[df['case_id'] == "Case269"]

print(tabulate(the_case_we_need_to_find, headers='keys', tablefmt='pretty'))

In [None]:
# Encode the query once (dense embedding + BM25 sparse vector) for reuse by every search below
search_query_as_vectors = convert_string_query_to_vectors(search_query)

## Let's test similarity search alone

### Pinecone

In [None]:
# alpha=1 keeps the dense vector unchanged and zeroes every sparse value -> dense-only (semantic) baseline
hdense, hsparse = hybrid_scale(search_query_as_vectors.get("dense"), search_query_as_vectors.get("sparse"), alpha=1)

In [None]:
# Top-3 matches (with metadata) for the scaled query vectors
pinecone_result = pinecone_index.query(
    vector=hdense,
    sparse_vector=hsparse,
    top_k=3,
    include_metadata=True,
)

display_pinecone_result_in_nice_table(pinecone_result)

### Weaviate

We pass the query's dense vector as the search input and run a pure similarity (vector) search with Weaviate.

In [None]:
from pprint import pprint

# Pure vector search: nearest neighbours of the query's dense embedding, with each
# hit's certainty included in the returned metadata
response = collection_weaviate.query.near_vector(
    near_vector=search_query_as_vectors["dense"],
    return_metadata=wvc.query.MetadataQuery(certainty=True),
    limit=2,
)

pprint(response)

## Hybrid Search (Pinecone and Weaviate)

### Pinecone

In [None]:
# First, we re-define the hybrid scale to weigh in more on keyword search (sparse vectors)
hdense, hsparse = hybrid_scale(search_query_as_vectors.get("dense"), search_query_as_vectors.get("sparse"), alpha=0.6)

# Then, we let it go ham
pinecone_result = pinecone_index.query(
    top_k=3,
    vector=hdense,
    sparse_vector=hsparse,
    include_metadata=True
)

display_pinecone_result_in_nice_table(pinecone_result)


### Weaviate

In [None]:
# Weaviate's hybrid weighting between keyword and vector search. NOTE(review): per the
# Weaviate docs 1 = pure vector and 0 = pure keyword; confirm for your client version.
alpha_weaviate = 0.5

In [None]:
# Weaviate hybrid search: keyword search on the raw query text fused with vector
# search on our own dense query embedding, balanced by alpha_weaviate
response = collection_weaviate.query.hybrid(
    limit=3,
    alpha=alpha_weaviate,
    query=search_query,
    vector=search_query_as_vectors["dense"],
)

pprint(response)

## Hybrid search part II: Full-text search (Postgres) + Reranker (Jina AI)

### Set up Postgres & pgvector with Supabase and vecs

In [None]:
pip install vecs supabase -q

In [None]:
from vecs import IndexArgsHNSW, IndexMeasure, create_client, IndexMethod
import vecs

import getpass
import os

# Postgres connection string for the Supabase database (it includes the password).
# Read from SUPABASE_DB_URL, or prompted for with hidden input, so the credential is
# never stored in, or shared along with, the notebook.
DB_STRING = os.environ.get("SUPABASE_DB_URL") or getpass.getpass("Supabase Postgres connection string: ")

# create vector store client
vx = vecs.create_client(DB_STRING)

In [None]:
list_of_vecs_collections = vx.list_collections()

# Find a collection with our name left over from a previous run (None if absent)
matching_collection = next((col for col in list_of_vecs_collections if col.name == index_name), None)

# Compare against None explicitly: a collection object's truthiness is not a reliable
# "it exists" signal (a container-like object can be falsy when it holds no records).
if matching_collection is not None:
    print("Collection already exists. Deleting now..")
    vx.delete_collection(matching_collection.name)

vecs_collection = vx.get_or_create_collection(name=index_name, dimension=EMBEDDINGS_DIMENSIONS)

# HNSW index with cosine distance (the measure the similarity query uses) over the
# new collection (a Postgres table under the hood) for faster querying
vecs_collection.create_index(
    method=IndexMethod.hnsw,
    measure=IndexMeasure.cosine_distance,
    index_arguments=IndexArgsHNSW(m=8),
)

print("Collection in vecs created.")

In [None]:
# vecs records are (id, vector, metadata) tuples
records_to_upsert_to_vecs = [
    (record["id"], record["values"], record["metadata"])
    for record in all_cases_embeddings_and_sparse_vectors
]

In [None]:
# Write all (id, dense vector, metadata) records into the vecs collection
vecs_collection.upsert(records=records_to_upsert_to_vecs)

In [None]:
# Similarity search in pgvector with the query's dense embedding and no metadata
# filters; metadata comes back, the stored vectors do not
vecs_similarity_search = vecs_collection.query(
    data=search_query_as_vectors["dense"],
    measure="cosine_distance",
    filters={},
    limit=2,
    include_metadata=True,
    include_value=False,
)

pprint(vecs_similarity_search)



In [None]:
from supabase import create_client, Client
from supabase.client import ClientOptions

import os

# Supabase project URL and API key, read from SUPABASE_URL / SUPABASE_KEY
# (empty string when unset) so credentials aren't stored in the notebook
url: str = os.environ.get("SUPABASE_URL", "")
key: str = os.environ.get("SUPABASE_KEY", "")

supabase: Client = create_client(
    url,
    key,
    options=ClientOptions(
        postgrest_client_timeout=10,
        storage_client_timeout=10,
        schema="public",
    ),
)

In [None]:
# Full-text search over `legal_cases.case_text` using the raw natural-language query
vecs_full_text_search_results = (
    supabase.table("legal_cases")
    .select("case_id, case_text")  # The columns we want back in our response
    .text_search(
        "case_text",  # The column we want to search over
        f"'{search_query}'",
        # NOTE(review): "websearch" maps to Postgres' websearch_to_tsquery, which (per the
        # Postgres docs) never raises syntax errors; the "errors on special characters"
        # caveat applies to the plain to_tsquery mode. Unquoted words are ANDed together,
        # so a long question may match no rows at all — confirm against your data.
        options={"type": "websearch", "config": "english"},
    )
    .execute()
)

pprint(vecs_full_text_search_results)

In [None]:
# Full-text search again, this time with hand-picked keywords from the question
vecs_full_text_search_results = (
    supabase.table("legal_cases")
    .select("case_id, case_text")  # The columns we want back in our response
    .text_search(
        "case_text",  # The column we want to search over
        "'Palmer J' & 'Macleay Nominees Pty'",
        # NOTE(review): in websearch mode only double quotes (phrase), `or` and a leading
        # `-` are operators; the single quotes and `&` here are treated as punctuation,
        # so the individual words are simply ANDed together — confirm.
        options={"type": "websearch", "config": "english"},
    )
    .execute()
)

pprint(vecs_full_text_search_results)

### Set up the Jina Reranker

In [None]:
import os

JINA_RERANKER_URL = "https://api.jina.ai/v1/rerank"

def jina_rerank(query: str, text_list: List[str], top_n: int = 3, timeout: float = 30.0):
    """
    Rerank ``text_list`` against ``query`` with Jina's reranker API.

    The API key is read from the ``JINA_API_KEY`` environment variable at call
    time rather than being embedded in the notebook.

    Args:
        query: The search query.
        text_list: Candidate documents to rerank.
        top_n: How many of the best-scoring documents the API should return.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON body of the API response.

    Raises:
        RuntimeError: If ``JINA_API_KEY`` is not set.
        requests.HTTPError: If the API answers with an error status, instead of
            silently handing back the error body as if it were results.
    """
    api_key = os.environ.get("JINA_API_KEY")
    if not api_key:
        raise RuntimeError("Set the JINA_API_KEY environment variable to use the Jina reranker.")

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    json_data = {
        "model": "jina-reranker-v2-base-multilingual",
        "documents": text_list,
        "query": query,
        "top_n": top_n,
    }

    response = requests.post(JINA_RERANKER_URL, headers=headers, data=json.dumps(json_data), timeout=timeout)
    response.raise_for_status()
    return response.json()

In [None]:
# Candidates from both retrievers as (case_id, case_text) pairs: vector-search hits
# (which unpack as (id, metadata) pairs) first, then full-text hits
candidates = [(case_id, metadata['case_text']) for case_id, metadata in vecs_similarity_search]
candidates += [(row['case_id'], row['case_text']) for row in vecs_full_text_search_results.data]

# Keep only the first occurrence of each case: a case found by both searches would
# otherwise be sent to the reranker twice and could take two of its top-N slots
merged_list_for_reranking = []
seen_case_ids = set()
for case_id, case_text in candidates:
    if case_id in seen_case_ids:
        continue
    seen_case_ids.add(case_id)
    merged_list_for_reranking.append({
        'case_id': case_id,
        'case_text': case_text,
    })

pprint(merged_list_for_reranking)

In [None]:
just_case_text = [candidate['case_text'] for candidate in merged_list_for_reranking]  # reranker input: texts only, same order

In [None]:
# Ask Jina to rerank the merged candidate texts against the original query
reranked_results = jina_rerank(search_query, just_case_text)

pprint(reranked_results)

## The verdict

1.   Pinecone's developer experience was awesome. Honestly no complaints here other than costs (their pods pricing), but their serverless solution is quite affordable.
2.   Weaviate's docs are ok-ish, and it took me longer than expected to set up. But they do try to abstract a lot of the vectorization and sparse vectors away so you can get started quickly without having to know what they are.
3. Whipping up your own hybrid search with Postgres (pgvector plus full-text search) is fine as long as you extract the right keywords for full-text search and pick the right components (especially the reranker). This should be the cheapest option. I use Cohere's Reranker in production. Be careful of JinaAI's API downtime.

BONUS: JinaAI's API is down a lot lol.
