In [None]:
!pip install python-dotenv tiktoken google-cloud-aiplatform cassandra-driver

In [None]:
import os
import math
import datetime
import pandas as pd
import numpy as np
import time

from dotenv import load_dotenv
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.concurrent import execute_concurrent_with_args

from vertexai.language_models import TextEmbeddingModel

##
# Fixed limits -- DO NOT change these values!
#
# Maximum number of review texts handled per batch. This notebook uses it as the
# Vertex AI embedding batch size and as the Cassandra insert concurrency.
# Origin of the value: the "concurrent batch prediction jobs" quota of 5
# (see: https://cloud.google.com/vertex-ai/docs/quotas)
# NOTE(review): the per-request input limit of get_embeddings() for
# textembedding-gecko@001 may be the limit that actually matters here -- confirm.
MAX_ALLOWED_CONCURRENT_PREDICTION_JOB_NUM = 5

# Length of the vectors produced by the Google Gecko text embedding model
VERTEX_AI_EMBEDDING_DIMENSION = 768


##
# Load the Astra DB connection settings (ASTRA_DB_SECURE_BUNDLE_PATH,
# ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_KEYSPACE) from the local ".env" file into
# the process environment; variables that are already set are left untouched.
# NOTE: the ".env" file is generated with the Astra CLI:
#       astra db create-dotenv <astradb_name> -k <keyspace_name>
#
load_dotenv(override=False)

##
# Helper function to open a CQL session against Astra DB
#
def getAdbCqlSession(keyspace):
    """Connect to Astra DB and return a CQL session.

    Connection settings are read from the environment (populated from .env):
      - ASTRA_DB_SECURE_BUNDLE_PATH: path to the secure connect bundle
      - ASTRA_DB_APPLICATION_TOKEN:  application token, sent as the password
                                     of the literal user name "token"

    Args:
        keyspace: keyspace to set as the session default. ``None`` or an
            empty string leaves the session without a default keyspace.

    Returns:
        The session object returned by ``Cluster.connect()``.
    """
    sec_bundle_file = os.environ.get('ASTRA_DB_SECURE_BUNDLE_PATH')
    access_token = os.environ.get('ASTRA_DB_APPLICATION_TOKEN')

    cluster = Cluster(
        cloud={
            "secure_connect_bundle": sec_bundle_file,
        },
        auth_provider=PlainTextAuthProvider(
            "token",
            access_token,
        )
    )

    cqlSession = cluster.connect()
    # getCQLKeyspace() returns None when ASTRA_DB_KEYSPACE is unset; a truthiness
    # check covers both None and "" (len(None) would raise TypeError).
    if keyspace:
        cqlSession.set_keyspace(keyspace)

    return cqlSession
    
def getCQLKeyspace():
    """Return the keyspace name configured in ASTRA_DB_KEYSPACE, or None when unset."""
    return os.getenv('ASTRA_DB_KEYSPACE')

In [None]:
from IPython.display import display

##
# Load the food review data
#
food_review_df = pd.read_csv('./data/fine_food_reviews_1k.csv')

##
# Add an empty column to store the embedding value for the combination of 'Summary' and 'Text' columns
#
food_review_df['Embedding'] = None

##
# Convert the 'Time' column (Unix epoch seconds, UTC) to datetimes;
# otherwise the INSERT statement later fails with an incompatible type.
# pd.to_datetime(unit='s') yields naive datetimes expressed in UTC, which is how the
# DataStax Python driver interprets naive datetimes when writing TIMESTAMP values.
# (datetime.fromtimestamp() would convert to the machine's local time zone instead,
#  shifting every stored timestamp by the local UTC offset.)
#
food_review_df['Time'] = pd.to_datetime(food_review_df['Time'], unit='s')

# info() prints its report itself and returns None, so it must not be wrapped in print()
food_review_df.info()
# Only the last expression of a cell is rendered automatically; display() the preview explicitly
display(food_review_df.head(n=5))

total_row = food_review_df.shape[0]
total_batch = math.ceil(total_row / MAX_ALLOWED_CONCURRENT_PREDICTION_JOB_NUM)
print(f"total_row={total_row}, total_batch={total_batch}")

##
# Helper function to get the (inclusive) starting row index of a batch
# - batchidx: the batch index, valid range 0 <= batchidx < total_batch
#
def get_batch_start_rowidx(batchidx):
    # total_batch = ceil(total_row / batch size), so the last valid index is total_batch - 1
    assert 0 <= batchidx < total_batch
    return batchidx * MAX_ALLOWED_CONCURRENT_PREDICTION_JOB_NUM

##
# Helper function to get the (inclusive) ending row index of a batch;
# the last batch is truncated at the final row of the dataframe
# - batchidx: the batch index, valid range 0 <= batchidx < total_batch
#
def get_batch_end_rowidx(batchidx):
    # total_batch = ceil(total_row / batch size), so the last valid index is total_batch - 1
    assert 0 <= batchidx < total_batch
    return min((batchidx + 1) * MAX_ALLOWED_CONCURRENT_PREDICTION_JOB_NUM - 1, total_row - 1)

In [None]:
##
# Helper function to get text embedding vectors using the Google Vertex AI API
#    "textembedding-gecko@001" embedding dimension is fixed at 768
#
# Loaded models are cached per model name, so the per-batch calls below do not
# reload the model every time.
_TEXT_EMBEDDING_MODEL_CACHE = {}

def vertex_ai_text_embeddings(text_arr, model_name="textembedding-gecko@001") -> tuple:
    """Embed a list of texts with a Vertex AI text embedding model.

    Args:
        text_arr: list of input strings (sent in a single get_embeddings call).
        model_name: Vertex AI embedding model to use; defaults to
            "textembedding-gecko@001".

    Returns:
        (duration_seconds, embeddings): elapsed time of the get_embeddings()
        call and its result. The callers in this notebook index the result
        per input text and read each item's ``.values`` vector.
    """
    model = _TEXT_EMBEDDING_MODEL_CACHE.get(model_name)
    if model is None:
        model = TextEmbeddingModel.from_pretrained(model_name)
        _TEXT_EMBEDDING_MODEL_CACHE[model_name] = model

    # perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
    # Only the API call is timed (model loading is excluded).
    start_time = time.perf_counter()
    embeddings = model.get_embeddings(text_arr)
    duration = time.perf_counter() - start_time

    return duration, embeddings

##
# Get the food review embeddings in batches using the Vertex AI API and store them in the
# dataframe's 'Embedding' column (the rows are inserted into Astra DB in a later cell).
# Each review is embedded as "<Summary> : <Text>". A missing (NaN) Summary or Text is treated
# as an empty string instead of aborting the whole run with a TypeError.
#
review_text_arr_by_batch = []
fetch_embedding_batch_duration = []

combined_review_text = (
    food_review_df['Summary'].fillna('').astype(str)
    + " : "
    + food_review_df['Text'].fillna('').astype(str)
)

for batchidx in range(total_batch):
    start_row_idx = get_batch_start_rowidx(batchidx)
    end_row_idx = get_batch_end_rowidx(batchidx)

    print(f"Get embeddings of the food reviews for batch {batchidx+1} [{start_row_idx}, {end_row_idx}] ...")

    # Positional slice covering rows start_row_idx .. end_row_idx (inclusive)
    review_text_arr_by_batch = combined_review_text.iloc[start_row_idx:end_row_idx + 1].tolist()

    # Call the Vertex embedding API once for the whole batch
    batch_duration, embedding_list = vertex_ai_text_embeddings(review_text_arr_by_batch)
    fetch_embedding_batch_duration.append(batch_duration)

    # Store each returned vector in the 'Embedding' cell of its row
    for rowidx in range(start_row_idx, end_row_idx + 1):
        food_review_df.at[rowidx, 'Embedding'] = embedding_list[rowidx - start_row_idx].values

# np.min/np.max raise on an empty list, so skip the statistics if no batch ran
if fetch_embedding_batch_duration:
    print(f"""\nVertex Embedding API call duration statistics (per batch): 
   min={np.min(fetch_embedding_batch_duration)}s, 
   max={np.max(fetch_embedding_batch_duration)}s, 
   avg={np.mean(fetch_embedding_batch_duration)}s
""")
food_review_df.head()

In [None]:
adb_keyspace = getCQLKeyspace()
# Fail fast with an actionable message: without a keyspace the statements below
# would be built as e.g. "CREATE TABLE IF NOT EXISTS None.food_review".
if not adb_keyspace:
    raise ValueError("ASTRA_DB_KEYSPACE is not set; add it to the .env file "
                     "(e.g. astra db create-dotenv <astradb_name> -k <keyspace_name>)")
adb_cql_session = getAdbCqlSession(adb_keyspace)

##
# CQL statement to create the 'food_review' C* table.
# The 'embedding' column is a fixed-size float vector whose size matches the Vertex AI embedding dimension.
#
cql_schema_stmt=f"""CREATE TABLE IF NOT EXISTS {adb_keyspace}.food_review (
  id int PRIMARY KEY,
  time TIMESTAMP,
  product_id TEXT,
  user_id TEXT,
  score INT,
  summary TEXT,
  text TEXT,
  embedding VECTOR<FLOAT, {VERTEX_AI_EMBEDDING_DIMENSION}>
);"""
print(f"cql_schema_stmt={cql_schema_stmt}")

adb_cql_session.execute(cql_schema_stmt)

In [None]:
##
# CQL prepared statement for inserting one record into the 'food_review' table,
# prepared on the Astra DB session opened by getAdbCqlSession() (adb_cql_session)
#
review_insert_stmt = adb_cql_session.prepare(f"""
    INSERT INTO {adb_keyspace}.food_review(id, time, product_id, user_id, score, summary, text, embedding) 
    VALUES(?,?,?,?,?,?,?,?)
    """
)

##
# Helper function to insert the food reviews (with embedding values) of one batch
# - batchidx: the batch index, valid range 0 <= batchidx < total_batch
# - concurrent: True to use cassandra.concurrent.execute_concurrent_with_args;
#               False to issue one synchronous execute() per row
# Returns the elapsed time of the batch in seconds.
#
# NOTE(review): each dataframe row is passed as the positional bind values of
# review_insert_stmt, so the dataframe's column order must line up with
# (id, time, product_id, user_id, score, summary, text, embedding) -- confirm
# against the CSV layout.
#
def insert_with_embedding_batch(batchidx, concurrent=False):
    # total_batch = ceil(total_row / batch size), so the last valid index is total_batch - 1
    assert 0 <= batchidx < total_batch

    start_time = time.perf_counter()

    start_row_idx = get_batch_start_rowidx(batchidx)
    end_row_idx = get_batch_end_rowidx(batchidx)
    batch_rows = [food_review_df.iloc[rowidx] for rowidx in range(start_row_idx, end_row_idx + 1)]

    if concurrent:
        execute_concurrent_with_args(adb_cql_session,
                                     review_insert_stmt,
                                     batch_rows,
                                     concurrency=MAX_ALLOWED_CONCURRENT_PREDICTION_JOB_NUM)
    else:
        for row in batch_rows:
            adb_cql_session.execute(review_insert_stmt, row)

    return time.perf_counter() - start_time

##
# Insert the food reviews, together with their embeddings, one batch at a time
# and keep the elapsed time of every batch for the statistics printed below
#
adb_insert_duration = []

# Concurrent insert is much faster than one synchronous execute() per row
concurrent_insert = True

for batchidx in range(total_batch):
    start_row_idx, end_row_idx = get_batch_start_rowidx(batchidx), get_batch_end_rowidx(batchidx)

    print(f"Insert food review with the embedding value for batch {batchidx+1} [{start_row_idx}, {end_row_idx}] ...")
    batch_duration = insert_with_embedding_batch(batchidx, concurrent_insert)
    adb_insert_duration += [batch_duration]

print(f"""\nAstra DB C* table batch insert duration statistics with concurrent insert ({concurrent_insert}): 
   min={np.min(adb_insert_duration)}s, 
   max={np.max(adb_insert_duration)}s, 
   avg={np.mean(adb_insert_duration)}s
""")

In [None]:
##
# Helper function to query the food reviews using vector search and/or regular searches
# - stmt: the CQL statement to execute
# - top: number of leading result rows to print (default None: print nothing)
# Returns (elapsed seconds, number of rows returned).
#
def food_review_query(stmt, top=None):
    start_time = time.perf_counter()

    cnt = 0
    for cnt, row in enumerate(adb_cql_session.execute(stmt), start=1):
        # Inclusive bound so that exactly `top` rows are printed
        if top and cnt <= top:
            print(row)

    return (time.perf_counter() - start_time), cnt

##
# Define an SAI (Storage-Attached Index) on the regular column 'product_id'
# so that it can be filtered on in WHERE clauses
#
product_index_creation_stmt=f"""CREATE CUSTOM INDEX IF NOT EXISTS food_review_product_index
    ON {adb_keyspace}.food_review(product_id) USING 'StorageAttachedIndex';
"""
adb_cql_session.execute(product_index_creation_stmt)

##
# Define an SAI index on the vector column 'embedding'; ANN (ORDER BY ... ANN OF) queries need it.
# - similarity_function options:
#   * COSINE (the default when the option is omitted)
#   * DOT_PRODUCT
#   * EUCLIDEAN
# The value carries its own quotes because it is spliced verbatim into the CQL OPTIONS map.
ANN_INDEX_MODE = "'COSINE'"
ann_index_creation_stmt=f"""CREATE CUSTOM INDEX IF NOT EXISTS food_review_ann_index
    ON {adb_keyspace}.food_review(embedding) USING 'StorageAttachedIndex'
    WITH OPTIONS = {{ 'similarity_function': {ANN_INDEX_MODE} }};
"""
adb_cql_session.execute(ann_index_creation_stmt)

In [None]:
# Embed the search phrase with the same Vertex AI helper used for the reviews;
# query_embeddings[0].values is the query vector used by the ANN searches below
(duration, query_embeddings) = vertex_ai_text_embeddings(["hamburger is good"])

In [None]:
vector_query_limit = 100

# Pure vector search: no WHERE clause; rows are ranked only by ANN similarity
# between the 'embedding' column and the query vector, capped at vector_query_limit
query_duration, cnt = food_review_query(
    f"""
    SELECT 
        id, 
        time, 
        product_id, 
        user_id, score, 
        summary, 
        text
    FROM {adb_keyspace}.food_review 
    ORDER BY embedding ANN OF {query_embeddings[0].values}
    LIMIT {vector_query_limit};
    """,
    top=2,
)

print(f"""\nPure vector search with maximum {vector_query_limit} records: 
   result_cnt={cnt}, 
   query_duration={query_duration}s
""")

In [None]:
# Hybrid search: restrict rows to the given products (the 'product_id' column has
# an SAI index), then order the remaining rows by ANN similarity to the query vector
hybrid_product_ids = ['B0006UFY46', 'B0077HIJYS']
# Render as a CQL IN list: 'id1', 'id2' (avoids the "('x',)" repr of a 1-tuple)
product_id_list_cql = ", ".join(f"'{product_id}'" for product_id in hybrid_product_ids)

query_duration,cnt = food_review_query(f"""
    SELECT 
        id, 
        time, 
        product_id, 
        user_id, score, 
        summary, 
        text
    FROM {adb_keyspace}.food_review 
    WHERE product_id in ({product_id_list_cql})
    ORDER BY embedding ANN OF {query_embeddings[0].values}
    LIMIT {vector_query_limit};
    """,
top=2)

print(f"""\nHybrid search (vector + product_id filter) with maximum {vector_query_limit} records: 
   result_cnt={cnt}, 
   query_duration={query_duration}s
""")