In [None]:
# @title Step 1: Install the Stable Configuration

!pip install sentence-transformers faiss-cpu nltk numpy spicy --quiet

print("✅ All libraries installed.")
print("🔴 IMPORTANT: Please restart the runtime now before proceeding.")

In [None]:
# @title Step 2: Download NLTK Data and Set Up Paths

import os
import nltk
from google.colab import drive

# --- NLTK tokenizer data ---
# Depending on the installed NLTK version, sent_tokenize loads either the 'punkt'
# or the 'punkt_tab' resource, so both are fetched up front.
for _tokenizer_resource in ('punkt', 'punkt_tab'):
    nltk.download(_tokenizer_resource, quiet=True)
print("✅ NLTK tokenizers downloaded.")

# --- Google Drive: mount it, then define input/output locations on it ---
print("\nMounting Google Drive...")
drive.mount('/content/drive')

TEXTBOOK_DIR = '/content/drive/MyDrive/Medical_Textbooks'  # folder holding the source .txt textbooks
SAVE_DIR = '/content/drive/MyDrive/medprompt_ai_data'      # folder receiving the generated files
os.makedirs(SAVE_DIR, exist_ok=True)  # no error if the folder already exists

print(
    f"\nTextbooks will be loaded from: {TEXTBOOK_DIR}",
    f"Output files will be saved to: {SAVE_DIR}",
    sep="\n",
)

In [None]:
# @title Step 3: Process Textbooks and Generate Data Files (High-Efficiency GPU Method)

import os
import pickle
import re
import numpy as np
import nltk
from sentence_transformers import SentenceTransformer
import faiss
from tqdm.notebook import tqdm
import torch

# --- 1. Verify GPU Availability ---
# Prefer CUDA when present; on a CPU-only runtime the pipeline still runs, just far slower.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cuda':
    print(f"✅ GPU is available. Using device: {device}")
else:
    print("⚠️ WARNING: GPU not found. Falling back to CPU. Processing will be significantly slower.")
    print("To fix this, go to Runtime > Change runtime type and select 'T4 GPU'.")


# ======================================================================================
# --- ⚙️ HYPERPARAMETER CONFIGURATION ---
# ======================================================================================

# Embedding model: BAAI/bge-base-en-v1.5 is a strong English embedding model that
# runs efficiently on a T4 GPU.
LOCAL_EMBEDDING_MODEL = 'BAAI/bge-base-en-v1.5'

# Chunking strategy (all lengths are measured in characters):
MIN_PARAGRAPH_LEN = 40   # paragraphs shorter than this are ignored entirely
MIN_CHUNK_LEN = 30       # minimum chunk length; shorter chunks are discarded
OVERLAP_SENTENCES = 1    # how many preceding sentences are prepended as context

# FAISS index tuning:
USE_ADVANCED_FAISS_INDEX = True  # IVF index; intended for very large corpora (>100k chunks)
FAISS_NLIST = 100                # number of IVF cells; only used when the flag above is True
# ======================================================================================


# --- NLTK Setup ---
# Fetch both Punkt resources (as Step 2 does): depending on the installed NLTK
# version, sent_tokenize loads either 'punkt' or 'punkt_tab', and a missing one
# raises LookupError at split time.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
from nltk.tokenize import sent_tokenize

def robust_sentence_splitter(paragraph):
    """Return the sentences of ``paragraph`` as a list of strings.

    Delegates to NLTK's ``sent_tokenize`` (Punkt tokenizer).
    """
    sentences = sent_tokenize(paragraph)
    return sentences

def _chunk_paragraph(sentences, filename):
    """Build one chunk record per sentence, prefixed by up to OVERLAP_SENTENCES earlier sentences.

    Args:
        sentences: sentence strings of a single paragraph, in order.
        filename: textbook file name, stored as the chunk's ``"source"``.

    Returns:
        A list of ``{"text": ..., "source": filename}`` dicts. Chunks shorter than
        MIN_CHUNK_LEN characters are dropped.
    """
    chunks = []
    for i, sentence in enumerate(sentences):
        start_index = max(0, i - OVERLAP_SENTENCES)
        contextual_prefix = " ".join(sentences[start_index:i])
        full_chunk = (contextual_prefix + " " + sentence).strip()
        # ">=" matches the documented "skips chunks shorter than MIN_CHUNK_LEN" rule
        # and the paragraph filter's convention (it was ">", dropping exact-length chunks).
        if len(full_chunk) >= MIN_CHUNK_LEN:
            chunks.append({"text": full_chunk, "source": filename})
    return chunks


def _load_text_chunks(textbook_filenames):
    """Read each textbook from TEXTBOOK_DIR and return all chunk records, in file order."""
    text_chunks = []
    for filename in tqdm(textbook_filenames, desc="Processing Textbooks"):
        filepath = os.path.join(TEXTBOOK_DIR, filename)
        # errors='ignore' silently drops undecodable bytes instead of aborting the run.
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()

        # Paragraphs are taken to be separated by blank lines.
        for para in content.split('\n\n'):
            if len(para.strip()) < MIN_PARAGRAPH_LEN:
                continue
            sentences = robust_sentence_splitter(para)
            text_chunks.extend(_chunk_paragraph(sentences, filename))
    return text_chunks


def _build_faiss_index(embeddings):
    """Build an L2 FAISS index over a float32 (n, d) matrix and add every row to it.

    Uses IndexIVFFlat when USE_ADVANCED_FAISS_INDEX is set AND there are at least
    FAISS_NLIST vectors; FAISS k-means training needs at least as many points as
    clusters and fails otherwise. In all other cases, uses IndexFlatL2.
    """
    n_vectors, embedding_dimension = embeddings.shape
    if USE_ADVANCED_FAISS_INDEX and n_vectors >= FAISS_NLIST:
        print(f"Using advanced FAISS index 'IndexIVFFlat' with nlist={FAISS_NLIST}.")
        quantizer = faiss.IndexFlatL2(embedding_dimension)
        index = faiss.IndexIVFFlat(quantizer, embedding_dimension, FAISS_NLIST)
        print("Training the advanced index...")
        index.train(embeddings)
        # NOTE(review): IVF searches only `nprobe` cells (FAISS default 1); consider
        # raising index.nprobe where the index is queried to improve recall.
    else:
        if USE_ADVANCED_FAISS_INDEX:
            print(f"⚠️ Only {n_vectors} vectors, fewer than nlist={FAISS_NLIST}; "
                  "IVF training is not possible. Falling back to 'IndexFlatL2'.")
        print("Using standard FAISS index 'IndexFlatL2'.")
        index = faiss.IndexFlatL2(embedding_dimension)

    index.add(embeddings)
    return index


def process_textbooks():
    """
    Loads text files, processes them, generates embeddings locally on the GPU,
    builds a FAISS index, and saves the results.

    Reads every ``*.txt`` file in TEXTBOOK_DIR and splits it into overlapping
    sentence chunks. It embeds the chunks with LOCAL_EMBEDDING_MODEL on ``device``
    and writes two files to SAVE_DIR:

    * ``medical_faiss_index.bin``: the FAISS index; vector i corresponds to chunk i.
    * ``medical_text_chunks.pkl``: the pickled list of ``{"text", "source"}`` dicts.

    Returns None. It prints a message and returns early if TEXTBOOK_DIR is missing,
    contains no .txt files, or yields no chunks.
    """
    if not os.path.isdir(TEXTBOOK_DIR):
        print(f"Error: The directory '{TEXTBOOK_DIR}' was not found.")
        return

    # --- Dynamic File Discovery ---
    # os.listdir order is arbitrary. Sorting makes chunk order, and therefore FAISS
    # vector ids, reproducible across runs.
    textbook_filenames = sorted(f for f in os.listdir(TEXTBOOK_DIR) if f.endswith('.txt'))
    print(f"\nFound {len(textbook_filenames)} textbooks to process.")
    if not textbook_filenames:
        print("No .txt files found; nothing to process.")
        return

    # --- Process Textbooks into Chunks ---
    print("\n--- Processing Textbooks and Generating Chunks ---")
    text_chunks = _load_text_chunks(textbook_filenames)

    print(f"\nTotal text chunks created: {len(text_chunks)}")
    if not text_chunks:
        print("No chunks were produced; nothing to embed.")
        return

    # --- Generate Embeddings Locally on GPU ---
    print(f"\n--- Loading local model '{LOCAL_EMBEDDING_MODEL}' onto the {device.upper()} ---")
    embedding_model = SentenceTransformer(LOCAL_EMBEDDING_MODEL, device=device)

    print("\n--- Generating Embeddings ---")
    chunk_texts = [chunk['text'] for chunk in text_chunks]

    # 'encode' batches internally; a large batch size keeps the GPU busy.
    chunk_embeddings = embedding_model.encode(
        chunk_texts,
        show_progress_bar=True,
        batch_size=128
    )
    # FAISS expects C-contiguous float32. Convert once so train() and add() see the same array.
    chunk_embeddings = np.ascontiguousarray(chunk_embeddings, dtype=np.float32)

    embedding_dimension = chunk_embeddings.shape[1]
    print(f"Generated {len(chunk_embeddings)} embeddings of dimension {embedding_dimension}.")

    # --- Build and Save FAISS Index ---
    print("\n--- Building and Saving FAISS Index ---")
    index = _build_faiss_index(chunk_embeddings)
    print(f"FAISS index built with {index.ntotal} vectors.")

    # --- Save Files ---
    faiss_index_file_path = os.path.join(SAVE_DIR, 'medical_faiss_index.bin')
    faiss.write_index(index, faiss_index_file_path)
    print(f"FAISS index saved to: {faiss_index_file_path}")

    # Row i of the index corresponds to text_chunks[i]; both files must be kept together.
    chunks_file_path = os.path.join(SAVE_DIR, 'medical_text_chunks.pkl')
    with open(chunks_file_path, 'wb') as f:
        pickle.dump(text_chunks, f)
    print(f"Text chunks saved to: {chunks_file_path}")

    print("\n--- ✅ Pre-processing Complete ---")

# --- Run the main function ---
if __name__ == '__main__':
    process_textbooks()