In [None]:
pip install psycopg2-binary oci pandas pyarrow tqdm ipython ipywidgets weasyprint sentence-transformers torch einops

In [None]:
# --- 1. Imports ---
import os
import oci
import pandas as pd
import psycopg2
from psycopg2.pool import SimpleConnectionPool
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from IPython.display import display
import ipywidgets as widgets
import torch
from torch.cuda.amp import autocast
import numpy as np
import json
import time
from datetime import timedelta

# --- 2. Configuration ---
# Object Storage bucket that holds the corpus, the key prefix to list under,
# and the local working directory for downloads and the checkpoint file.
BUCKET_NAME = "aus-legal-corpus"
OBJECT_PREFIX = ""
DOWNLOAD_DIR = "./data"
CHECKPOINT_FILE = os.path.join(DOWNLOAD_DIR, "checkpoint.json")
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# PostgreSQL Config
# Every value can be overridden with the matching PG* environment variable so
# credentials do not have to live in the notebook; the fallbacks are the
# values that used to be hard-coded here.
DB_CONFIG = {
    "dbname": os.environ.get("PGDATABASE", "postgres"),
    "user": os.environ.get("PGUSER", "postgres"),
    "password": os.environ.get("PGPASSWORD", ""),
    "host": os.environ.get("PGHOST", "10.150.2.103"),
    "port": os.environ.get("PGPORT", "5432"),
}

# OCI Config
oci_config = {
    "user": "ocid1.user.oc1..aaq",
    "key_file": "./data/oci_api_key.pem",
    "fingerprint": "de:d6",
    "tenancy": "ocid1.tenancy.oc1..aaa",
    "region": "us-sanjose-1"
}


# --- 3. Checkpoint Management ---
def load_checkpoint():
    """Return the saved checkpoint dict, or a fresh one if none can be read.

    Keys missing from a saved checkpoint are filled in from the fresh
    defaults (each fill is reported), so checkpoints written by an earlier
    version of this script remain loadable.
    """
    fresh = {
        "current_file": "",
        "processed_files": [],
        "total_files": 0,
        "total_texts_processed": 0,
        "total_tokens_processed": 0,
        "start_time": time.time(),
        "batch_stats": []
    }

    # Nothing saved yet: start from scratch.
    if not os.path.exists(CHECKPOINT_FILE):
        return fresh

    try:
        with open(CHECKPOINT_FILE, 'r') as handle:
            saved = json.load(handle)
        print(f"⏳ Resuming from checkpoint: {saved.get('current_file', 'No file')}")

        # Back-fill any key the saved file does not have.
        for name, fallback in fresh.items():
            if name in saved:
                continue
            saved[name] = fallback
            print(f"⚠️ Added missing key to checkpoint: {name}")

        return saved
    except Exception as err:
        # An unreadable or malformed checkpoint is reported and discarded.
        print(f"⚠️ Error reading checkpoint: {err}. Starting fresh.")
        return fresh

def save_checkpoint(checkpoint):
    """Save checkpoint data to CHECKPOINT_FILE.

    The JSON is first written to a sibling temporary file and then moved
    over CHECKPOINT_FILE with os.replace, so an interruption in the middle
    of a write cannot leave a truncated checkpoint behind.

    Args:
        checkpoint: JSON-serialisable dict describing the current progress.

    Returns:
        True when the checkpoint was written, False (after printing a
        warning) when it could not be. Never raises for ordinary errors.
    """
    tmp_path = CHECKPOINT_FILE + ".tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(checkpoint, f, indent=2)
        os.replace(tmp_path, CHECKPOINT_FILE)
        return True
    except Exception as e:
        print(f"⚠️ Error saving checkpoint: {e}")
        # Best effort: do not leave a half-written temporary file around.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False

def print_summary(checkpoint):
    """Print processing summary statistics.

    Args:
        checkpoint: dict with the keys 'start_time', 'batch_stats',
            'processed_files', 'total_files', 'total_texts_processed' and
            'total_tokens_processed'.

    The elapsed time is clamped at zero and the throughput is reported as
    0.0 when no time has elapsed, so a start time equal to or later than
    "now" cannot raise ZeroDivisionError or print negative durations.
    """
    elapsed = max(time.time() - checkpoint['start_time'], 0.0)
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)

    batch_stats = checkpoint['batch_stats']
    avg_batch_size = np.mean(batch_stats) if batch_stats else 0
    texts_per_second = checkpoint['total_texts_processed'] / elapsed if elapsed > 0 else 0.0

    print("\n📊 Processing Summary:")
    print(f"  Files processed: {len(checkpoint['processed_files'])}/{checkpoint['total_files']}")
    print(f"  Texts processed: {checkpoint['total_texts_processed']:,}")
    print(f"  Estimated tokens processed: {checkpoint['total_tokens_processed']:,}")
    print(f"  Total processing time: {int(hours)}h {int(minutes)}m {int(seconds)}s")
    print(f"  Average batch size: {avg_batch_size:.1f}")
    print(f"  Texts per second: {texts_per_second:.1f}")

# --- 4. Connect to OCI and Download Parquet Files ---
print("🔍 Listing objects in bucket...")
try:
    object_storage = oci.object_storage.ObjectStorageClient(oci_config)
    namespace = object_storage.get_namespace().data

    # list_objects returns one page per call. Keep asking for the next page,
    # starting at next_start_with, until the service stops returning one;
    # otherwise objects beyond the first page would never be downloaded.
    objects = []
    list_kwargs = {"prefix": OBJECT_PREFIX}
    while True:
        listing = object_storage.list_objects(namespace, BUCKET_NAME, **list_kwargs).data
        objects.extend(listing.objects)
        if not listing.next_start_with:
            break
        list_kwargs["start"] = listing.next_start_with

    parquet_files = [obj.name for obj in objects if obj.name.endswith(".parquet")]

    if not parquet_files:
        raise Exception("❌ No .parquet files found. Check bucket, prefix or region.")

    print(f"Found {len(parquet_files)} parquet files")

    for obj_name in parquet_files:
        local_file = os.path.join(DOWNLOAD_DIR, os.path.basename(obj_name))
        if os.path.exists(local_file):
            continue

        print(f"⬇️ Downloading {obj_name}...")
        # Stream into a ".part" file and rename only once the download is
        # complete, so an interrupted download is never mistaken for a
        # finished parquet file on the next run.
        partial_file = local_file + ".part"
        response = object_storage.get_object(namespace, BUCKET_NAME, obj_name)
        with open(partial_file, 'wb') as f:
            for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):
                f.write(chunk)
        os.replace(partial_file, local_file)
    print("✅ All Parquet files downloaded.")
except Exception as e:
    print(f"❌ Error in OCI operations: {str(e)}")
    raise

# --- 5. Load Embedding Model ---
print("\nInitializing model...")
try:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    if device == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, device=device)

    # Start with smaller batch size and auto-tune
    initial_batch_size = 32 if device == 'cuda' else 8
    model.max_seq_length = 512  # Might help with memory

    print("Warming up GPU...")
    # Mixed precision is only requested when the model actually runs on CUDA;
    # on a CPU-only machine the autocast context is created disabled.
    with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=(device == 'cuda')):
        dummy_input = ["warmup"] * initial_batch_size
        _ = model.encode(dummy_input, batch_size=initial_batch_size)
    if device == 'cuda':
        # Release the memory cached by the warm-up pass.
        torch.cuda.empty_cache()

    print("✅ Model loaded successfully")
except Exception as e:
    print(f"❌ Error loading model: {str(e)}")
    raise

# --- 6. Connect to PostgreSQL ---
print("\nConnecting to PostgreSQL...")
try:
    # Pool of 1 to 4 connections; one connection and one cursor are taken
    # here and shared by the rest of the notebook.
    pool = SimpleConnectionPool(minconn=1, maxconn=4, **DB_CONFIG)
    conn = pool.getconn()
    cursor = conn.cursor()

    # Round-trip a trivial query to confirm the connection works.
    cursor.execute("SELECT 1")
    conn.commit()
except Exception as exc:
    print(f"❌ PostgreSQL connection failed: {exc}")
    raise
else:
    print("✅ PostgreSQL connection successful")

# --- 7. Create Table ---
print("\nEnsuring database table exists...")
try:
    # The vector extension must exist before the VECTOR(768) column can be
    # declared; both statements run in the same transaction.
    cursor.execute("""
    CREATE EXTENSION IF NOT EXISTS vector;
    CREATE TABLE IF NOT EXISTS legal_docs_v4 (
        id SERIAL PRIMARY KEY,
        content TEXT,
        jurisdiction TEXT,
        source TEXT,
        citation TEXT,
        embedding VECTOR(768)
    );
    """)
    conn.commit()
    print("✅ Table created/verified")
except Exception as e:
    print(f"❌ Error creating table: {str(e)}")
    conn.rollback()
    # Every later step inserts into or queries this table, so stop here
    # instead of continuing without it (matches the other setup steps).
    raise

# --- 8. Optimized Batch Insert ---
def insert_batch(batch):
    """Insert a batch of rows into legal_docs_v4 with one INSERT statement.

    Args:
        batch: iterable of 5-item sequences in the order
            (content, jurisdiction, source, citation, embedding).

    Returns:
        True when the rows were committed, or when the batch is empty and
        there is nothing to insert. False when the insert failed; the error
        is printed and the transaction rolled back.
    """
    rows = list(batch)
    if not rows:
        # "INSERT ... VALUES" with no tuples is invalid SQL; an empty batch
        # is trivially done.
        return True

    try:
        # mogrify quotes each row client-side so the whole batch can travel
        # in a single multi-row VALUES list.
        args_str = ",".join(cursor.mogrify("(%s, %s, %s, %s, %s)", row).decode("utf-8") for row in rows)
        cursor.execute("INSERT INTO legal_docs_v4 (content, jurisdiction, source, citation, embedding) VALUES " + args_str)
        conn.commit()
        return True
    except Exception as e:
        print(f"❌ Batch insert failed: {str(e)}")
        conn.rollback()
        return False

# --- 9. Process Files with Checkpointing ---
def _clean_cell(value):
    """Return a parquet cell as a stripped string; missing values become ''."""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def _load_rows(path):
    """Read one parquet file into (content, jurisdiction, source, citation) tuples.

    Returns None when the file has no 'text' column. Rows whose text is
    missing or blank are dropped; missing metadata columns or cells become
    empty strings.
    """
    df = pd.read_parquet(path)
    if "text" not in df.columns:
        return None

    columns = [
        [_clean_cell(v) for v in df[name]] if name in df.columns else [""] * len(df)
        for name in ("text", "jurisdiction", "source", "citation")
    ]
    return [row for row in zip(*columns) if row[0]]


def _process_file(file, checkpoint):
    """Embed and store every pending row of one parquet file.

    Progress inside the file is recorded in checkpoint['file_offsets'][file]
    (number of rows already stored), so a later run continues after the last
    stored row instead of inserting the file's rows again.

    Returns:
        True when rows were embedded and stored, False when the file was
        skipped (no 'text' column, or no non-empty texts).

    Raises:
        torch.cuda.OutOfMemoryError: if encoding fails even at the minimum
            batch size.
        RuntimeError: if a database insert fails.
        Any exception raised while reading the file or encoding a batch.
    """
    rows = _load_rows(os.path.join(DOWNLOAD_DIR, file))
    if rows is None:
        print(f"⚠️ Skipping {file}, no 'text' column")
        return False
    if not rows:
        print("⚠️ No valid texts found in this file")
        return False

    offsets = checkpoint.setdefault('file_offsets', {})
    offset = offsets.get(file, 0)
    if not isinstance(offset, int) or not 0 <= offset <= len(rows):
        offset = 0  # unusable saved position: start the file over

    print(f"📝 Processing {len(rows)} texts in this file")
    if offset:
        print(f"⏩ {offset} texts from this file are already stored - continuing after them")

    # Dynamic batch sizing configuration
    max_batch_size = 128  # Maximum allowed batch size
    min_batch_size = 8    # Minimum batch size
    batch_size = min(32, max_batch_size)  # Start with 32
    successful_batches = 0
    use_cuda = torch.cuda.is_available()

    with tqdm(total=len(rows), initial=offset, desc="Texts processed", unit="text") as pbar:
        # The position advances by the size of the batch that was actually
        # stored, so changing batch_size can neither skip nor repeat rows.
        while offset < len(rows):
            batch = rows[offset:offset + batch_size]

            try:
                with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_cuda):
                    vectors = model.encode(
                        [row[0] for row in batch],
                        batch_size=batch_size,
                        convert_to_tensor=True,
                        normalize_embeddings=True
                    ).cpu().numpy().tolist()
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                if batch_size <= min_batch_size:
                    raise  # cannot shrink further; let the caller report it
                batch_size = max(batch_size // 2, min_batch_size)
                # Do not grow back past a size that just ran out of memory.
                max_batch_size = batch_size
                print(f"⚠️ GPU OOM - retrying with batch size {batch_size}")
                continue  # offset unchanged: the same rows are retried

            db_batch = [(*row, vec) for row, vec in zip(batch, vectors)]
            if not insert_batch(db_batch):
                # insert_batch already reported and rolled back; abort the
                # file so it is not marked as processed with rows missing.
                raise RuntimeError(f"database insert failed at text offset {offset}")

            # Update progress tracking
            processed_count = len(batch)
            offset += processed_count
            offsets[file] = offset
            successful_batches += 1
            checkpoint['total_texts_processed'] += processed_count
            checkpoint['total_tokens_processed'] += processed_count * 100  # Estimate 100 tokens per text
            checkpoint['batch_stats'].append(batch_size)
            pbar.update(processed_count)

            # Periodically save checkpoint
            if successful_batches % 5 == 0:
                save_checkpoint(checkpoint)

            # Dynamic batch size adjustment
            if successful_batches % 10 == 0 and batch_size < max_batch_size:
                batch_size = min(batch_size * 2, max_batch_size)
                print(f"⚡ Increased batch size to {batch_size}")

    offsets.pop(file, None)  # whole file stored; no partial position to keep
    return True


def process_files():
    """Embed every downloaded parquet file into PostgreSQL, with checkpointing.

    Files already listed in the checkpoint are skipped. A file that was in
    progress when a previous run stopped is processed first and continues
    from its saved row position. A file that fails is reported and left
    unmarked so that a later run retries it. When every local file has been
    processed the checkpoint file is removed.
    """
    # Get list of local parquet files
    local_files = sorted(f for f in os.listdir(DOWNLOAD_DIR) if f.endswith(".parquet"))
    if not local_files:
        print("❌ No local parquet files found")
        return

    # Load or initialize checkpoint
    checkpoint = load_checkpoint()

    # Initialize total_files if not set or if mismatch detected
    if checkpoint['total_files'] != len(local_files):
        print(f"📦 Found {len(local_files)} files to process")
        checkpoint['total_files'] = len(local_files)
        save_checkpoint(checkpoint)

    # Skip already processed files
    done = set(checkpoint['processed_files'])
    files_to_process = [f for f in local_files if f not in done]

    # If checkpoint indicates a file was in progress, start from there.
    # It is moved to the front rather than inserted a second time.
    in_progress = checkpoint['current_file']
    if in_progress in files_to_process:
        files_to_process.remove(in_progress)
        files_to_process.insert(0, in_progress)

    print(f"⏩ Resuming processing - {len(files_to_process)} files remaining")

    # Main processing loop
    for file_idx, file in enumerate(files_to_process, 1):
        checkpoint['current_file'] = file
        print(f"\n📂 Processing file {file_idx}/{len(files_to_process)}: {file}")

        try:
            ingested = _process_file(file, checkpoint)
        except KeyboardInterrupt:
            # Manual stop: persist the exact position before propagating.
            save_checkpoint(checkpoint)
            raise
        except Exception as e:
            print(f"❌ Error processing file {file}: {str(e)}")
            save_checkpoint(checkpoint)  # Save state before continuing
            continue

        # Mark file as successfully completed (or deliberately skipped)
        checkpoint['processed_files'].append(file)
        checkpoint['current_file'] = ""  # Reset current file
        save_checkpoint(checkpoint)
        if ingested:
            print(f"✅ Completed processing {file}")

    # Print final summary
    print_summary(checkpoint)

    # Clean up checkpoint if all files processed
    if set(local_files) <= set(checkpoint['processed_files']):
        try:
            os.remove(CHECKPOINT_FILE)
            print("✅ All files processed - checkpoint removed")
        except OSError as e:
            print(f"⚠️ Could not remove checkpoint file: {e}")

# Actually run the processing
process_files()

# --- 10. Create HNSW Vector Index ---
# The index is created after the load and uses the cosine operator class.
try:
    cursor.execute("""
    CREATE INDEX IF NOT EXISTS legal_docs_hnsw_idx_v4
    ON legal_docs_v4 USING hnsw (embedding vector_cosine_ops)
    WITH (m = 16, ef_construction = 64);
    """)
    cursor.execute("ANALYZE legal_docs_v4;")
    conn.commit()
    print("✅ HNSW vector index created.")
except Exception as e:
    print(f"❌ Error creating HNSW index: {str(e)}")
    # Without the rollback the shared connection stays in an aborted
    # transaction and every later query on it fails.
    conn.rollback()
    raise

# --- 11. Sentence Search Interface ---
def search_query(user_query, top_k=5):
    """Return the top_k stored passages most similar to user_query.

    Args:
        user_query: text to embed and search for.
        top_k: maximum number of rows to return.

    Returns:
        pandas.DataFrame with columns Content, Jurisdiction, Source,
        Citation and Similarity, ordered from most to least similar.
        Similarity is the cosine similarity (1 - cosine distance).

    Raises:
        Any database error, after rolling back the shared connection.
    """
    user_vector = model.encode(user_query).tolist()
    # "<=>" is pgvector's cosine-distance operator. It matches the
    # vector_cosine_ops operator class this notebook builds its HNSW index
    # with, and makes "1 - distance" a true cosine similarity. The previous
    # "<#>" operator is the negative inner product, which neither matches
    # that index nor yields a similarity when subtracted from 1.
    query = """
    SELECT content, jurisdiction, source, citation,
           1 - (embedding <=> %s::vector) AS similarity
    FROM legal_docs_v4
    ORDER BY embedding <=> %s::vector
    LIMIT %s;
    """
    try:
        cursor.execute(query, (user_vector, user_vector, top_k))
        rows = cursor.fetchall()
    except Exception:
        # A failed statement leaves the transaction aborted; reset it so the
        # next search on the shared connection can run.
        conn.rollback()
        raise
    results = pd.DataFrame(rows, columns=["Content", "Jurisdiction", "Source", "Citation", "Similarity"])
    return results

# --- 12. UI to Accept Query ---
input_box = widgets.Textarea(
    placeholder='Ask a legal question...',
    description='Query:',
    layout=widgets.Layout(width='80%', height='100px')
)

button = widgets.Button(description="Search")
output = widgets.Output()

def on_button_click(b):
    """Run a search for the text in input_box and show the results in output.

    Args:
        b: the clicked button (passed by ipywidgets, unused).
    """
    with output:
        output.clear_output()
        question = input_box.value.strip()
        if not question:
            # Nothing to embed; avoid running a search on an empty string.
            print("⚠️ Please enter a question first.")
            return
        print("🔍 Searching...")
        result_df = search_query(question)
        display(result_df)

button.on_click(on_button_click)

display(input_box, button, output)