# Saka-14B RAG System - Clean Version

**Step-by-step setup for Arabic Mental Health QA**

## Before You Start
1. Select an **A100 GPU**: Runtime → Change runtime type → Hardware accelerator → A100 GPU
2. Upload your data files (if not in repo):
   - `knowldege_base/data/processed/articles_all.jsonl`
   - `knowldege_base/data/processed/books_all_ragclean.jsonl`
   - `knowldege_base/data/processed/shifaa_qa_pairs_all.jsonl`
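3. Run the cells **in order** — don't skip steps! Later cells use variables defined by earlier ones (e.g. `PROMPT_VERSION` in Step 5, `rag` in Step 6).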


## Step 1: Install Dependencies


In [None]:
%pip install -q torch>=2.0.0 transformers>=4.40.0 accelerate>=0.20.0 bitsandbytes>=0.41.0
%pip install -q sentence-transformers>=2.2.2 chromadb>=0.4.0 rank-bm25
%pip install -q numpy pandas tqdm

# Setup Python path
import sys
import os

if 'colab_files' in os.listdir('.'):
    sys.path.insert(0, 'colab_files')
elif 'knowldege_base' in os.listdir('.'):
    sys.path.insert(0, '.')
else:
    print("⚠️  Please upload colab_files/ or ensure knowldege_base/ is present")

print(f"✅ Python path: {sys.path[0]}")


## Step 2: Verify GPU


In [None]:
import torch

# Report whether CUDA is usable and, if so, the first device's name and
# total memory (bytes / 1024**3).
cuda_ok = torch.cuda.is_available()
print(f"CUDA Available: {cuda_ok}")
if not cuda_ok:
    print("⚠️  No GPU! Select A100 GPU runtime.")
else:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU Memory: {total_mem_gb:.2f} GB")


## Step 3: Build KB Chunks (One Time)


In [None]:
from knowldege_base.rag_staging.kb_chunker import build_kb_chunks
import os

# Chunk the knowledge base once; later runs reuse the existing JSONL.
# NOTE(review): assumes build_kb_chunks writes `output_filename` into
# processed_dir — confirm against kb_chunker.
processed_dir = 'knowldege_base/data/processed'
chunks_name = "kb_chunks.jsonl"
chunks_file = f"{processed_dir}/{chunks_name}"

if not os.path.exists(chunks_file):
    print("Building KB chunks...")
    build_kb_chunks(output_filename=chunks_name)
    print("✅ KB chunks built")
else:
    print(f"✅ KB chunks exist: {chunks_file}")
    print("   Skipping rebuild. Delete file to rebuild.")


## Step 4: Build Vector Database (One Time)


In [None]:
from knowldege_base.rag_staging.vector_db import VectorDB
import os
import chromadb

# Build the Chroma vector DB once. If the directory already exists, reuse it
# unless it holds no collections or cannot be opened, in which case force a
# rebuild.
vector_db_path = 'knowldege_base/data/vector_db'

if os.path.exists(vector_db_path):
    needs_rebuild = False
    # Keep the try body to the probe only: an error raised by the rebuild
    # itself must surface, not be caught and silently retried.
    try:
        client = chromadb.PersistentClient(path=vector_db_path)
        # Not named `collections`, so the notebook's global namespace does not
        # shadow the stdlib `collections` module for later cells.
        existing_collections = client.list_collections()
    except Exception as e:
        print(f"⚠️  Error: {e}, rebuilding...")
        needs_rebuild = True
    else:
        if existing_collections:
            print(f"✅ Vector DB exists with {len(existing_collections)} collection(s)")
            print("   Skipping rebuild. Delete directory to rebuild.")
        else:
            print("⚠️  Rebuilding vector DB...")
            needs_rebuild = True

    if needs_rebuild:
        VectorDB.build(force_rebuild=True)
        print("✅ Vector DB rebuilt")
else:
    print("Building vector database...")
    VectorDB.build(force_rebuild=False)
    print(f"✅ Vector DB built: {vector_db_path}")


## Step 5: Choose Prompt Version

**Version A: Original Prompt** (no filtering instructions)  
**Version B: Filtering Prompt** (explicit instructions to avoid Quranic content)

Set `PROMPT_VERSION` in the cell below to `'original'` or `'filtering'`. Then re-run Step 6 (and Step 7) to compare the two versions.


In [None]:
# Choose prompt version: 'original' or 'filtering'
PROMPT_VERSION = 'original'  # Change to 'filtering' to test the other version

# Fail fast on a typo here rather than handing an unrecognized value to the
# pipeline through RAG_PROMPT_VERSION in Step 6.
VALID_PROMPT_VERSIONS = ('original', 'filtering')
if PROMPT_VERSION not in VALID_PROMPT_VERSIONS:
    raise ValueError(
        f"PROMPT_VERSION must be one of {VALID_PROMPT_VERSIONS}, got {PROMPT_VERSION!r}"
    )

print(f"Using prompt version: {PROMPT_VERSION}")


## Step 6: Initialize RAG Pipeline


In [None]:
from knowldege_base.rag_staging.rag_qa import RAGQAPipeline
import os

# Pass the Step 5 prompt choice to the pipeline via the environment
# (requires PROMPT_VERSION from the previous cell).
os.environ['RAG_PROMPT_VERSION'] = PROMPT_VERSION

banner = "=" * 80
intro_lines = (
    banner,
    "Initializing Saka-14B RAG Pipeline...",
    banner,
    "This will:",
    "  1. Load Saka-14B model (~28GB) - 10-15 min first time",
    "  2. Load knowledge base and vector database",
    "  3. Ready to answer questions!",
    banner,
)
for intro_line in intro_lines:
    print(intro_line)

# Both quantization flags are off; generation is capped at 512 new tokens.
rag = RAGQAPipeline.build(
    model_name="Sakalti/Saka-14B",
    use_gpu=True,
    load_in_4bit=False,
    load_in_8bit=False,
    max_new_tokens=512,
    download_to_local=False,
)

print("\n" + banner)
print("✅ RAG Pipeline Ready!")
print(banner)


## Step 7: Test Query


In [None]:
test_query = "أعاني من القلق والتوتر المستمر، ما هي طرق التعامل معه؟"

rule = "=" * 80


def print_section(title, gap=False):
    """Print `title` between two 80-char rules, optionally after a blank line."""
    print(("\n" if gap else "") + rule)
    print(title)
    print(rule)


print_section("QUESTION:")
print(test_query)
print_section("GENERATING ANSWER...", gap=True)

# Requires `rag` from Step 6.
result = rag.answer(query=test_query, top_k=5, relevance_threshold=0.5)

print_section("ANSWER:", gap=True)
print(result.answer)

# NOTE(review): `:.4f` assumes top_score / avg_top_score are always numeric
# (not None) — confirm against RAGQAPipeline's result type.
print("\n" + rule)
print(f"Used KB: {result.used_kb}")
print(f"Top Score: {result.top_score:.4f}")
print(f"Avg Top Score: {result.avg_top_score:.4f}")
print(f"Answer Length: {len(result.answer)} chars")
print(rule)
