### OCI Data Science - Useful Tips
<details>
<summary><font size="2">Check for Public Internet Access</font></summary>

```python
import requests
response = requests.get("https://oracle.com")
assert response.status_code==200, "Internet connection failed"
```
</details>
<details>
<summary><font size="2">Helpful Documentation </font></summary>
<ul><li><a href="https://docs.cloud.oracle.com/en-us/iaas/data-science/using/data-science.htm">Data Science Service Documentation</a></li>
<li><a href="https://docs.cloud.oracle.com/iaas/tools/ads-sdk/latest/index.html">ADS documentation</a></li>
</ul>
</details>
<details>
<summary><font size="2">Typical Cell Imports and Settings for ADS</font></summary>

```python
%load_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

import logging
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.ERROR)

import ads
from ads.dataset.factory import DatasetFactory
from ads.automl.provider import OracleAutoMLProvider
from ads.automl.driver import AutoML
from ads.evaluations.evaluator import ADSEvaluator
from ads.common.data import ADSData
from ads.explanations.explainer import ADSExplainer
from ads.explanations.mlx_global_explainer import MLXGlobalExplainer
from ads.explanations.mlx_local_explainer import MLXLocalExplainer
from ads.catalog.model import ModelCatalog
from ads.common.model_artifact import ModelArtifact
```
</details>
<details>
<summary><font size="2">Useful Environment Variables</font></summary>

```python
import os
print(os.environ["NB_SESSION_COMPARTMENT_OCID"])
print(os.environ["PROJECT_OCID"])
print(os.environ["USER_OCID"])
print(os.environ["TENANCY_OCID"])
print(os.environ["NB_REGION"])
```
</details>

In [3]:
pip install --upgrade sentence-transformers torch einops

Collecting einops
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.8.1-py3-none-any.whl (64 kB)
Installing collected packages: einops
Successfully installed einops-0.8.1
Note: you may need to restart the kernel to use updated packages.


In [None]:
# --- 1. Imports ---
import os
import oci
import pandas as pd
import psycopg2
from psycopg2.pool import SimpleConnectionPool
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from IPython.display import display
import ipywidgets as widgets
import torch
from torch.cuda.amp import autocast
import numpy as np

# --- 2. Configuration ---
# Object Storage bucket that holds the source .parquet files.
BUCKET_NAME = "aus-legal-corpus"
OBJECT_PREFIX = ""  # empty prefix: consider every object in the bucket
DOWNLOAD_DIR = "./data"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# PostgreSQL Config
# Connection settings come from the standard libpq environment variables so no
# credentials are stored in the notebook. psycopg2 drops keyword arguments whose
# value is None, so an unset PGPASSWORD lets libpq fall back to ~/.pgpass.
DB_CONFIG = {
    "dbname": os.environ.get("PGDATABASE", "postgres"),
    "user": os.environ.get("PGUSER", "postgres"),
    "password": os.environ.get("PGPASSWORD"),
    "host": os.environ.get("PGHOST", ""),
    "port": os.environ.get("PGPORT", "5432"),
}

# OCI Config
# API-key authentication for the Object Storage client created in section 3.
# NOTE(review): the OCIDs/fingerprint below look like truncated placeholders —
# fill in real values (or load them with oci.config.from_file()).
oci_config = {
    "user": "ocid1.user.oc1..aaq",
    "key_file": "./data/oci_api_key.pem",
    "fingerprint": "de:d6",
    "tenancy": "ocid1.tenancy.oc1..aua",
    "region": "us-sanjose-1"
}


# --- 3. Connect to OCI and Download Parquet Files ---
# Lists every object under OBJECT_PREFIX in BUCKET_NAME and downloads the
# .parquet ones into DOWNLOAD_DIR, skipping files already present locally.
print("🔍 Listing objects in bucket...")
try:
    object_storage = oci.object_storage.ObjectStorageClient(oci_config)
    namespace = object_storage.get_namespace().data

    # list_objects is paginated: keep requesting from next_start_with until the
    # service stops returning one, so large buckets are listed completely.
    objects = []
    list_kwargs = {"prefix": OBJECT_PREFIX}
    while True:
        page = object_storage.list_objects(namespace, BUCKET_NAME, **list_kwargs).data
        objects.extend(page.objects)
        if not page.next_start_with:
            break
        list_kwargs["start"] = page.next_start_with

    parquet_files = [obj.name for obj in objects if obj.name.endswith(".parquet")]

    if not parquet_files:
        raise RuntimeError("❌ No .parquet files found. Check bucket, prefix or region.")

    print(f"Found {len(parquet_files)} parquet files")

    for obj_name in parquet_files:
        local_file = os.path.join(DOWNLOAD_DIR, os.path.basename(obj_name))
        if os.path.exists(local_file):
            continue
        print(f"⬇️ Downloading {obj_name}...")
        # Stream into a ".part" file and rename only after the download finishes,
        # so an interrupted transfer never leaves a truncated file that the
        # existence check above would treat as already downloaded.
        partial_file = local_file + ".part"
        response = object_storage.get_object(namespace, BUCKET_NAME, obj_name)
        with open(partial_file, 'wb') as f:
            for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):
                f.write(chunk)
        os.replace(partial_file, local_file)
    print("✅ All Parquet files downloaded.")
except Exception as e:
    print(f"❌ Error in OCI operations: {str(e)}")
    raise

# --- 4. Load Embedding Model ---
# Loads the SentenceTransformer used for both ingestion and search, then runs
# one warm-up encode so the first real batch doesn't pay initialization cost.
print("\nInitializing model...")
try:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    if device == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True, device=device)

    # Start with smaller batch size and auto-tune
    initial_batch_size = 32 if device == 'cuda' else 8
    # Inputs longer than this many tokens are truncated; a lower cap bounds memory use.
    model.max_seq_length = 512

    print("Warming up GPU..." if device == 'cuda' else "Warming up model...")
    # Mixed precision only on CUDA. torch.autocast with enabled=False is a no-op,
    # avoiding the "CUDA is not available" warning the deprecated
    # torch.cuda.amp.autocast emits on CPU-only machines.
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=(device == 'cuda')):
        _ = model.encode(["warmup"] * initial_batch_size, batch_size=initial_batch_size)
    if device == 'cuda':
        torch.cuda.empty_cache()

    print("✅ Model loaded successfully")
except Exception as e:
    print(f"❌ Error loading model: {str(e)}")
    raise

# --- 5. Connect to PostgreSQL ---
# Opens a small connection pool, checks out one connection and creates the
# cursor that the rest of the notebook shares; a trivial query verifies it.
print("\nConnecting to PostgreSQL...")
try:
    pool = SimpleConnectionPool(minconn=1, maxconn=4, **DB_CONFIG)
    conn = pool.getconn()
    cursor = conn.cursor()

    # Round-trip a no-op query to prove the connection is usable.
    cursor.execute("SELECT 1")
    conn.commit()
except Exception as exc:
    print(f"❌ PostgreSQL connection failed: {exc!s}")
    raise
else:
    print("✅ PostgreSQL connection successful")

# --- 6. Create Table ---
# Enables the pgvector extension and creates the target table if missing.
# NOTE(review): VECTOR(768) must match the dimension of the embeddings produced
# by the model loaded in section 4 — confirm if the model is changed.
print("\nEnsuring database table exists...")
try:
    cursor.execute("""
    CREATE EXTENSION IF NOT EXISTS vector;
    CREATE TABLE IF NOT EXISTS legal_docs_v3 (
        id SERIAL PRIMARY KEY,
        content TEXT,
        jurisdiction TEXT,
        source TEXT,
        citation TEXT,
        embedding VECTOR(768)
    );
    """)
    conn.commit()
    print("✅ Table created/verified")
except Exception as e:
    print(f"❌ Error creating table: {str(e)}")
    conn.rollback()
    # Without the table every later insert would fail; stop here like the other setup steps.
    raise

# --- 7. Optimized Batch Insert ---
def insert_batch(batch):
    """Insert rows into ``legal_docs_v3`` with one multi-row INSERT and commit.

    Args:
        batch: sequence of ``(content, jurisdiction, source, citation, embedding)``
            tuples, where ``embedding`` is a list of floats.

    Returns:
        True if the rows were committed (an empty batch is trivially successful),
        False if the insert failed, in which case the transaction is rolled back.
    """
    if not batch:
        # Nothing to insert; building "VALUES " with no rows would be invalid SQL.
        return True
    from psycopg2.extras import execute_values

    try:
        # execute_values renders the VALUES list with the connection's own
        # quoting and encoding; page_size=len(batch) keeps it to one statement.
        execute_values(
            cursor,
            "INSERT INTO legal_docs_v3 (content, jurisdiction, source, citation, embedding) VALUES %s",
            batch,
            page_size=len(batch),
        )
        conn.commit()
        return True
    except Exception as e:
        print(f"❌ Batch insert failed: {str(e)}")
        conn.rollback()
        return False

# --- 8. Process Files with Better Progress Tracking ---
def _clean_field(value):
    """Return ``value`` as a stripped string, mapping None/NaN/NA to ``""``."""
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value).strip()


def _extract_records(df):
    """Split a DataFrame into texts to embed and matching DB metadata tuples.

    Rows whose ``text`` is missing or blank are skipped. Absent metadata columns
    (``jurisdiction``/``source``/``citation``) are stored as empty strings.

    Returns:
        ``(texts, metadata)`` where ``metadata[i]`` is
        ``(text, jurisdiction, source, citation)`` for ``texts[i]``.
    """
    texts, metadata = [], []
    # reindex adds any missing column filled with NaN, which _clean_field maps to "".
    subset = df.reindex(columns=["text", "jurisdiction", "source", "citation"])
    for raw_text, jurisdiction, source, citation in subset.itertuples(index=False, name=None):
        text = _clean_field(raw_text)
        if not text:
            continue
        texts.append(text)
        metadata.append((text, _clean_field(jurisdiction), _clean_field(source), _clean_field(citation)))
    return texts, metadata


def _embed_and_insert(texts, metadata, *, min_batch_size=8, start_batch_size=32,
                      max_batch_size=128, grow_every=5):
    """Embed ``texts`` in adaptive, non-overlapping batches and insert them.

    The batch size starts at ``start_batch_size``, doubles (up to
    ``max_batch_size``) after every ``grow_every`` successful inserts, and is
    halved (down to ``min_batch_size``) whenever encoding fails, after which the
    same texts are retried. A batch that still fails at the minimum size is
    skipped so one bad input cannot stall the file.

    Returns:
        ``(inserted, skipped)`` row counts.
    """
    batch_size = min(start_batch_size, max_batch_size)
    use_amp = torch.cuda.is_available()
    successful_batches = inserted = skipped = 0
    start = 0
    with tqdm(total=len(texts), desc="Processing", unit="text") as progress:
        while start < len(texts):
            batch_texts = texts[start:start + batch_size]
            batch_metadata = metadata[start:start + batch_size]
            try:
                with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
                    vectors = model.encode(
                        batch_texts,
                        batch_size=batch_size,
                        convert_to_tensor=True,
                        normalize_embeddings=True,
                    ).float().cpu().tolist()
            except Exception as e:
                is_oom = isinstance(e, torch.cuda.OutOfMemoryError)
                if is_oom:
                    torch.cuda.empty_cache()
                if batch_size > min_batch_size:
                    batch_size = max(batch_size // 2, min_batch_size)
                    reason = "GPU OOM" if is_oom else f"Batch error: {str(e)}"
                    print(f"⚠️ {reason} - retrying with batch size {batch_size}")
                    continue  # do not advance: retry the same texts
                print(f"⚠️ Skipping {len(batch_texts)} texts at offset {start} after repeated failures: {str(e)}")
                skipped += len(batch_texts)
            else:
                db_batch = [(*meta, vec) for meta, vec in zip(batch_metadata, vectors)]
                if insert_batch(db_batch):
                    inserted += len(db_batch)
                    successful_batches += 1
                    # Gradually increase batch size if successful
                    if successful_batches % grow_every == 0 and batch_size < max_batch_size:
                        batch_size = min(batch_size * 2, max_batch_size)
                        print(f"⚡ Increased batch size to {batch_size}")
                else:
                    skipped += len(db_batch)
            start += len(batch_texts)
            progress.update(len(batch_texts))
    return inserted, skipped


def process_files():
    """Embed every local ``.parquet`` file in DOWNLOAD_DIR into ``legal_docs_v3``.

    Files are processed in sorted name order. A file without a ``text`` column,
    or one that cannot be read, is reported and skipped.

    NOTE(review): rows are appended on every run (the table has no uniqueness
    constraint besides the serial id), so re-running re-inserts documents.
    NOTE(review): the nomic-embed-text-v1.5 model card asks for task prefixes
    ("search_document: " for stored text, "search_query: " for queries); none
    is added here — confirm and apply consistently on both sides if adopted.
    """
    local_files = sorted(f for f in os.listdir(DOWNLOAD_DIR) if f.endswith(".parquet"))
    if not local_files:
        print("❌ No local parquet files found")
        return

    for file_idx, file in enumerate(local_files, 1):
        print(f"\n📂 Processing file {file_idx}/{len(local_files)}: {file}")
        try:
            df = pd.read_parquet(os.path.join(DOWNLOAD_DIR, file))
            if "text" not in df.columns:
                print(f"⚠️ Skipping {file}, no 'text' column")
                continue

            texts, metadata = _extract_records(df)
            if not texts:
                print("⚠️ No valid texts found in this file")
                continue

            print(f"📝 Found {len(texts)} texts to process")
            inserted, skipped = _embed_and_insert(texts, metadata)
            print(f"✅ Inserted {inserted} rows from {file} ({skipped} skipped)")
        except Exception as e:
            print(f"❌ Error processing file {file}: {str(e)}")

# Actually run the processing
process_files()

# --- 9. Create HNSW Vector Index ---
# Built after the bulk load (pgvector recommends adding indexes after loading
# initial data). vector_cosine_ops means only queries ordering by the cosine
# distance operator (<=>) can use this index.
print("\nBuilding HNSW vector index...")
try:
    cursor.execute("""
    CREATE INDEX IF NOT EXISTS legal_docs_hnsw_idx_v3
    ON legal_docs_v3 USING hnsw (embedding vector_cosine_ops)
    WITH (m = 16, ef_construction = 64);
    """)
    cursor.execute("ANALYZE legal_docs_v3;")
    conn.commit()
    print("✅ HNSW vector index created.")
except Exception as e:
    print(f"❌ Error creating index: {str(e)}")
    # Clear the aborted transaction so the shared connection stays usable.
    conn.rollback()
    raise

# --- 10. Sentence Search Interface ---
def search_query(user_query, top_k=5):
    """Return the ``top_k`` stored chunks most similar to ``user_query``.

    Uses pgvector's cosine distance operator ``<=>``, which is the operator the
    ``vector_cosine_ops`` HNSW index on ``legal_docs_v3`` can serve, and reports
    ``1 - distance`` as the cosine similarity.

    Args:
        user_query: free-text question; blank or None input returns an empty frame
            without querying the model or database.
        top_k: maximum number of rows to return.

    Returns:
        DataFrame with columns Content, Jurisdiction, Source, Citation, Similarity.
    """
    columns = ["Content", "Jurisdiction", "Source", "Citation", "Similarity"]
    if not user_query or not user_query.strip():
        return pd.DataFrame(columns=columns)

    user_vector = model.encode(user_query).tolist()
    query = """
    SELECT content, jurisdiction, source, citation,
           1 - (embedding <=> %s::vector) AS similarity
    FROM legal_docs_v3
    ORDER BY embedding <=> %s::vector
    LIMIT %s;
    """
    try:
        cursor.execute(query, (user_vector, user_vector, top_k))
        rows = cursor.fetchall()
    except Exception:
        # Clear the aborted transaction so later queries on the shared connection work.
        conn.rollback()
        raise
    return pd.DataFrame(rows, columns=columns)

# --- 11. UI to Accept Query ---
# A textarea for the question, a Search button, and an Output area for results.
input_box = widgets.Textarea(
    description='Query:',
    placeholder='Ask a legal question...',
    layout=widgets.Layout(width='80%', height='100px'),
)
button = widgets.Button(description="Search")
output = widgets.Output()


def on_button_click(b):
    """Search for the textarea's current text and render the results in ``output``."""
    with output:
        output.clear_output()
        print("🔍 Searching...")
        display(search_query(input_box.value))


button.on_click(on_button_click)
display(input_box, button, output)