# Text Chunking

1. Connects to Given Database
2. User selects the chunking algo, embedding model and other parameters
3. Defaults are process_documents(batch_size=100, max_workers=4, model_name='all-MiniLM-L6-v2', chunk_size=512, chunk_overlap=50); the run cell at the end overrides batch_size and max_workers
4. The program creates the destination table and a checkpoint table, so the operation can resume in case it fails because of DB connectivity issues. Rerunning this notebook automatically resumes from the previous progress. If no new documents are present, processing ends
5. Chunking uses LangChain's RecursiveCharacterTextSplitter; it can be swapped for any of the other text splitters available in LangChain.
6. Both Fulltext and Vector index (ANN) are created upon completion
7. Hybrid search example provided at the end.
8. For fastest results run in a GPU container with parallelism options enabled

In [None]:
# Install required packages
!pip -q install langchain langchain-text-splitters sentence_transformers torch tqdm

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"  
import multiprocessing
import singlestoredb as s2
import json
import torch
from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
import concurrent.futures
import numpy as np

#########################################
#  Helper Functions for DB Connections  #
#########################################
def connect_to_db(database_name='knowlagent'):
    """Open and return a fresh SingleStore connection to `database_name`."""
    connection = s2.connect(database=database_name)
    return connection

#########################################
#  Table Creation & Clearing Functions  #
#########################################
def create_tables(cursor, model_name, embedding_dim=384):
    """
    Create the main chunks table and a checkpoint table if they do not exist.

    Note: The main table is created as a columnstore table so that the FULLTEXT
    index needed for hybrid search will work.

    Args:
        cursor: Open DB cursor used to run the DDL statements.
        model_name: Embedding model name; it is sanitized and embedded in the
            table names so each model gets its own pair of tables.
        embedding_dim: Width of the VECTOR column. Defaults to 384, the value
            previously hard-coded; pass the model's output size when using a
            model with a different dimension.

    Returns:
        Tuple ``(table_name, checkpoint_table)``.

    Raises:
        ValueError: If ``embedding_dim`` is not a positive integer.
    """
    embedding_dim = int(embedding_dim)
    if embedding_dim <= 0:
        raise ValueError(f"embedding_dim must be a positive integer, got {embedding_dim}")

    # Map every character that is not an ASCII letter, digit or underscore to
    # '_' so names such as 'org/model-v1.5' still yield a valid SQL identifier.
    # For names containing only '-' and '/' this matches the previous result.
    safe_model = ''.join(
        ch if (ch.isascii() and ch.isalnum()) or ch == '_' else '_'
        for ch in model_name
    )
    table_name = f"s2docs_chunks_{safe_model}"
    create_table_sql = f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
       doc_id BIGINT NOT NULL,
       source_url TEXT,
       chunk_index INT NOT NULL,
       chunk_text LONGTEXT,
       embedding JSON DEFAULT NULL,
       vector_embedding VECTOR({embedding_dim}) DEFAULT NULL,
       SORT KEY (doc_id, chunk_index)
    )
    """
    cursor.execute(create_table_sql)

    # Single-row table (id = 1) holding the resume position for this model.
    checkpoint_table = f"{table_name}_checkpoint"
    cursor.execute(f"""
        CREATE ROWSTORE TABLE IF NOT EXISTS {checkpoint_table} (
            id INT PRIMARY KEY,
            last_processed_id BIGINT,
            timestamp DATETIME
        )
    """)
    return table_name, checkpoint_table

def create_indexes(cursor, table_name):
    """
    Try to add the vector (ANN) index and the FULLTEXT index to `table_name`.

    Each statement is attempted independently; a failure is reported as a
    warning and does not stop the other index from being attempted.
    Always returns True.
    """
    # (label for success message, label for warning message, DDL statement)
    # NOTE(review): confirm the `USING ANN` syntax against the SingleStore
    # vector-index documentation for the target engine version.
    index_jobs = (
        (
            "Vector",
            "vector",
            f"CREATE INDEX idx_{table_name}_vector ON {table_name}(vector_embedding) USING ANN;",
        ),
        (
            "FULLTEXT",
            "FULLTEXT",
            f"ALTER TABLE {table_name} ADD FULLTEXT INDEX idx_{table_name}_text (chunk_text);",
        ),
    )

    for ok_label, warn_label, statement in index_jobs:
        try:
            cursor.execute(statement)
        except Exception as err:
            print(f"Warning: Could not create {warn_label} index: {err}")
        else:
            print(f"{ok_label} index created successfully.")

    return True



#########################################
#  Checkpoint Functions                 #
#########################################
def get_checkpoint(cursor, checkpoint_table):
    """
    Retrieve the last-processed document ID from the checkpoint table.

    Args:
        cursor: Open DB cursor.
        checkpoint_table: Name of the checkpoint table to read (row id = 1).

    Returns:
        The stored ``last_processed_id``, or 0 when no checkpoint row exists
        or the stored value is NULL (the column is nullable), so callers can
        always compare the result against document ids.
    """
    cursor.execute(f"SELECT last_processed_id FROM {checkpoint_table} WHERE id = 1")
    row = cursor.fetchone()
    if not row or row[0] is None:
        return 0
    return row[0]

def update_checkpoint(checkpoint_table, last_processed_id):
    """
    Persist `last_processed_id` as the resume position in `checkpoint_table`.

    Opens its own short-lived connection, upserts the single checkpoint row
    (id = 1) and commits. Errors are printed and rolled back rather than
    raised; the connection is always closed.
    """
    upsert_sql = f"""
                INSERT INTO {checkpoint_table} (id, last_processed_id, timestamp)
                VALUES (1, %s, NOW())
                ON DUPLICATE KEY UPDATE 
                    last_processed_id = VALUES(last_processed_id),
                    timestamp = VALUES(timestamp)
            """
    params = (last_processed_id,)

    db = connect_to_db()
    try:
        with db.cursor() as cur:
            cur.execute(upsert_sql, params)
        db.commit()
    except Exception as err:
        print("Error updating checkpoint:", err)
        db.rollback()
    finally:
        db.close()

#########################################
#  Query Functions for Reading Docs      #
#########################################
def get_unprocessed_docs_count(cursor, table_name, last_processed_id):
    """
    Return the count of documents (from s2docs) that need processing.

    A document counts as unprocessed when it has non-empty
    ``md_content_cleaned``, its id is greater than ``last_processed_id`` and
    it has no rows in ``table_name`` yet.

    Args:
        cursor: Open DB cursor.
        table_name: Chunks table used to exclude already-chunked documents.
        last_processed_id: Checkpoint id; only documents above it are counted.

    Returns:
        The number of matching documents.
    """
    # The id is bound as a query parameter (not interpolated) so this query
    # matches get_unprocessed_docs_batch and never splices a value into SQL.
    count_query = f"""
        SELECT COUNT(*) FROM s2docs
        WHERE md_content_cleaned IS NOT NULL 
          AND md_content_cleaned != ''
          AND id > %s
          AND id NOT IN (SELECT doc_id FROM {table_name})
    """
    cursor.execute(count_query, (last_processed_id,))
    return cursor.fetchone()[0]

def get_unprocessed_docs_batch(cursor, table_name, last_processed_id, batch_size):
    """
    Fetch up to `batch_size` unprocessed rows from s2docs, ordered by id.

    Rows are (id, md_content_cleaned) pairs for documents with non-empty
    cleaned text whose id is above `last_processed_id` and which have no
    rows in `table_name` yet.
    """
    query_args = (last_processed_id, batch_size)
    batch_query = f"""
        SELECT id, md_content_cleaned
        FROM s2docs
        WHERE md_content_cleaned IS NOT NULL
          AND md_content_cleaned != ''
          AND id > %s
          AND id NOT IN (SELECT doc_id FROM {table_name})
        ORDER BY id
        LIMIT %s
    """
    cursor.execute(batch_query, query_args)
    rows = cursor.fetchall()
    return rows

#########################################
#  Embedding & Text Processing Functions
#########################################
def batch_encode_text(model, texts, batch_size=32, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Embed `texts` in slices of `batch_size` and return one stacked array.

    The model is moved to `device` (CUDA when available at import time,
    otherwise CPU). Tensor outputs are copied to the CPU and converted to
    NumPy before stacking. An empty input yields an empty array.
    """
    if not texts:
        return np.array([])

    model = model.to(device)

    def _encode_slice(batch):
        # Normalise each result to a NumPy array so slices can be stacked.
        vectors = model.encode(batch, convert_to_tensor=True)
        return vectors.cpu().numpy() if torch.is_tensor(vectors) else vectors

    encoded = [
        _encode_slice(texts[start:start + batch_size])
        for start in range(0, len(texts), batch_size)
    ]
    return np.vstack(encoded) if encoded else np.array([])

def process_document_batch(doc_batch, text_splitter, embed_model):
    """
    Chunk and embed every document in `doc_batch`.

    Each document is a (doc_id, cleaned_text) pair. Its text is split with
    `text_splitter`, the chunks are embedded with `embed_model`, and one
    (doc_id, chunk_index, chunk_text, embedding_list) tuple is produced per
    chunk. Documents that split into no chunks contribute nothing.
    """
    rows = []
    for doc_id, cleaned_text in doc_batch:
        pieces = text_splitter.split_text(cleaned_text)
        if not pieces:
            continue
        vectors = batch_encode_text(embed_model, pieces)
        rows.extend(
            (doc_id, position, piece, vector.tolist())
            for position, (piece, vector) in enumerate(zip(pieces, vectors))
        )
    return rows

#########################################
#  DB Write Functions                   #
#########################################
def insert_chunks(table_name, chunks_data):
    """
    Insert chunk records into the main table using its own DB connection.
    Uses SingleStore's JSON_ARRAY_PACK to insert vector_embedding.

    Args:
        table_name: Destination chunks table.
        chunks_data: Iterable of chunk rows. Two shapes are accepted:
            ``(doc_id, source_url, chunk_index, chunk_text, embedding)`` or
            ``(doc_id, chunk_index, chunk_text, embedding)``. The 4-tuple form
            is what process_document_batch produces; its ``source_url`` is
            stored as NULL. Previously only 5-tuples unpacked, so every
            4-tuple batch failed before any row was written.

    Returns:
        Number of rows whose INSERT statement executed without error.
        Rows with an unexpected shape or a failing INSERT are reported and
        skipped. Returns 0 without opening a connection for empty input.
    """
    if not chunks_data:
        return 0

    conn = connect_to_db()
    success_count = 0
    try:
        with conn.cursor() as cursor:
            insert_sql = f"""
                INSERT INTO {table_name} (doc_id, source_url, chunk_index, chunk_text, embedding, vector_embedding)
                VALUES (%s, %s, %s, %s, %s, JSON_ARRAY_PACK(%s))
            """
            for row in chunks_data:
                if len(row) == 5:
                    doc_id, source_url, idx, chunk, embedding = row
                elif len(row) == 4:
                    doc_id, idx, chunk, embedding = row
                    source_url = None
                else:
                    print(f"Skipping malformed chunk row with {len(row)} fields (expected 4 or 5).")
                    continue

                # The same JSON text feeds both the JSON column and the
                # packed VECTOR column.
                embedding_json = json.dumps(embedding)
                try:
                    cursor.execute(insert_sql, (doc_id, source_url, idx, chunk, embedding_json, embedding_json))
                    success_count += 1
                except Exception as e:
                    print(f"Error inserting chunk {idx} for doc {doc_id}: {e}")
        conn.commit()
    except Exception as e:
        print("DB Insert error:", e)
        conn.rollback()
    finally:
        conn.close()
    return success_count

#########################################
#   Main Process Function               #
#########################################
def process_documents(batch_size=100, max_workers=4, model_name='all-MiniLM-L6-v2', chunk_size=512, chunk_overlap=50):
    """
    Process documents from s2docs:
      - Create necessary tables and indexes.
      - Read unprocessed documents in batches (using a fresh read connection per batch).
      - Process text splitting and compute embeddings using parallel threads.
      - Insert chunks using separate DB connections.
      - Update checkpoints so that if the process stops, it may resume later.

    Args:
        batch_size: Number of documents fetched per read query.
        max_workers: Thread-pool size; each fetched batch is split into
            sub-batches of ``max(1, batch_size // max_workers)`` documents.
        model_name: SentenceTransformer model used for embeddings; also
            determines the destination table names.
        chunk_size: Maximum chunk size passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        None. Progress and errors are printed.

    Resume semantics:
        The persisted checkpoint only advances over documents processed
        without error. After the first failed sub-batch it stops advancing for
        the rest of the run, so failed documents are fetched again on the next
        run (documents already stored are excluded by the read query). The
        in-memory read position always advances, so a batch that fails or
        yields no chunks cannot be fetched forever.
    """
    from sentence_transformers import SentenceTransformer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    embed_model = SentenceTransformer(model_name, device=device)

    # Setup tables and get checkpoint info.
    conn_setup = connect_to_db()
    try:
        with conn_setup.cursor() as cursor:
            table_name, checkpoint_table = create_tables(cursor, model_name)
            conn_setup.commit()
            last_processed_id = get_checkpoint(cursor, checkpoint_table)
            total_docs = get_unprocessed_docs_count(cursor, table_name, last_processed_id)
    except Exception as e:
        print("Error during setup:", e)
        conn_setup.rollback()
        return
    finally:
        conn_setup.close()

    if total_docs == 0:
        print("No new documents to process.")
        return

    print(f"Found {total_docs} unprocessed documents.")

    # Initialize text splitter.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    read_position = last_processed_id   # where the next read starts (this run only)
    checkpoint_id = last_processed_id   # value persisted for resuming
    checkpoint_frozen = False           # set after the first failed sub-batch

    # Context managers guarantee the progress bar and the pool are released
    # even when a read or processing step raises.
    with tqdm(total=total_docs, desc="Processing documents", unit="doc") as pbar, \
            concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        sub_batch_size = max(1, batch_size // max_workers)
        while True:
            conn_r = connect_to_db()
            try:
                with conn_r.cursor() as cursor:
                    docs = get_unprocessed_docs_batch(cursor, table_name, read_position, batch_size)
            finally:
                conn_r.close()

            if not docs:
                break

            # Break docs into smaller batches for parallel processing.
            doc_batches = [docs[i:i + sub_batch_size]
                           for i in range(0, len(docs), sub_batch_size)]

            # Map each future to its documents so success or failure can be
            # attributed to document ids, including docs that yield no chunks.
            futures = {
                executor.submit(process_document_batch, batch, text_splitter, embed_model): batch
                for batch in doc_batches
            }
            succeeded_ids = set()
            failed_ids = set()
            for future in concurrent.futures.as_completed(futures):
                batch_ids = [doc[0] for doc in futures[future]]
                try:
                    chunks_data = future.result()
                    inserted = insert_chunks(table_name, chunks_data)
                    if inserted != len(chunks_data):
                        raise RuntimeError(
                            f"only {inserted} of {len(chunks_data)} chunks were inserted"
                        )
                    succeeded_ids.update(batch_ids)
                    pbar.update(len(batch_ids))
                except Exception as e:
                    print(f"Error processing a batch: {e}")
                    failed_ids.update(batch_ids)

            if not checkpoint_frozen:
                if failed_ids:
                    # Only ids below the first failure may be checkpointed;
                    # anything beyond it would hide the failed docs on resume.
                    first_failed = min(failed_ids)
                    succeeded_ids = {doc_id for doc_id in succeeded_ids if doc_id < first_failed}
                    checkpoint_frozen = True
                    print(f"Warning: checkpoint will not advance past doc {first_failed}; "
                          "rerun to retry failed documents.")
                if succeeded_ids:
                    checkpoint_id = max(checkpoint_id, max(succeeded_ids))
                    update_checkpoint(checkpoint_table, checkpoint_id)

            # Always move the read position forward so the loop terminates.
            read_position = max(read_position, max(doc[0] for doc in docs))

    if checkpoint_frozen:
        print(f"Finished with errors; some documents were not stored in {table_name}.")
    else:
        print(f"All documents processed and stored in {table_name}.")

    conn_ddl = connect_to_db()
    try:
        with conn_ddl.cursor() as cursor:
            create_indexes(cursor, table_name)
        conn_ddl.commit()
    except Exception as e:
        print("Error creating indexes:", e)
        conn_ddl.rollback()
    finally:
        conn_ddl.close()

In [None]:
# Optional: Clear tables for a fresh start.
def clear_tables(model_name='all-MiniLM-L6-v2', confirm=False):
    """
    Delete the chunks and checkpoint tables for a given embedding model.

    Args:
        model_name: Embedding model whose tables should be removed. The name
            is sanitized into the table name: every character that is not an
            ASCII letter, digit or underscore becomes '_'.
        confirm: When False, the user is prompted interactively and must
            answer 'y' to proceed. When True, no prompt is shown.

    Returns:
        True when the tables were dropped, False when the user cancelled or
        an error occurred (the error is printed and rolled back).
    """
    if not confirm:
        confirmation = input(f"Are you sure you want to delete all tables for model '{model_name}'? (y/n): ")
        if confirmation.strip().lower() != 'y':
            print("Operation cancelled.")
            return False

    safe_model = ''.join(
        ch if (ch.isascii() and ch.isalnum()) or ch == '_' else '_'
        for ch in model_name
    )
    table_name = f"s2docs_chunks_{safe_model}"
    checkpoint_table = f"{table_name}_checkpoint"

    # Values are bound as parameters; only identifiers (which cannot be
    # bound) are formatted into the DROP statements.
    index_exists_sql = """
        SELECT COUNT(*)
        FROM information_schema.statistics
        WHERE table_schema = DATABASE()
          AND table_name = %s
          AND index_name = %s
    """
    # (index name, label used in the warning message)
    indexes = (
        (f"idx_{table_name}_vector", "vector"),
        (f"idx_{table_name}_text", "FULLTEXT"),
    )

    conn = connect_to_db()
    try:
        with conn.cursor() as cursor:
            # Drop each index if it exists; a failed drop is only a warning
            # because the table itself is dropped right after.
            for index_name, label in indexes:
                cursor.execute(index_exists_sql, (table_name, index_name))
                if cursor.fetchone()[0] > 0:
                    try:
                        cursor.execute(f"DROP INDEX {index_name} ON {table_name}")
                    except Exception as e:
                        print(f"Warning when dropping {label} index: {e}")
            # Drop tables.
            cursor.execute(f"DROP TABLE IF EXISTS {checkpoint_table}")
            cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
        conn.commit()
        print(f"Successfully deleted tables for model '{model_name}'")
        return True
    except Exception as e:
        print(f"Error deleting tables: {e}")
        conn.rollback()
        return False
    finally:
        conn.close()

# clear_tables(model_name='all-MiniLM-L6-v2', confirm=True)

In [None]:
## Optional: run process_documents in a separate process and terminate it if it exceeds 20 minutes.
# multiprocessing.set_start_method('spawn', force=True)
# # Launch process_documents in a separate process.
# p = multiprocessing.Process(target=process_documents, kwargs={'batch_size': 100, 'max_workers': 16})
# p.start()
# # Wait for up to 1200 seconds (20 minutes).
# p.join(timeout=1200)

# # If still running after timeout, terminate.
# if p.is_alive():
#     print("process_documents is still running after 20 minutes. Terminating...")
#     p.terminate()
#     p.join()  # Wait for the process to terminate.
#     print("process_documents terminated after 20 minutes.")



# Run the pipeline. The compute device (CUDA or CPU) is chosen automatically
# inside process_documents. Only batch_size and max_workers are overridden
# here; model_name, chunk_size and chunk_overlap keep their defaults and can
# be passed as keyword arguments.
process_documents(batch_size=200, max_workers=16)