In [1]:
import logging
import os
import re
import gc
import torch
import numpy as np
from google.colab import drive
from huggingface_hub import login, snapshot_download
from docx import Document
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from fastapi import FastAPI
from pyngrok import ngrok
import uvicorn
import nest_asyncio
from sentence_transformers import SentenceTransformer
import faiss
from typing import List
from google.colab import userdata
from google.colab import drive
from datasets import load_dataset
import json
import os

# Mount Google Drive
drive.mount('/content/drive')
nest_asyncio.apply()

# Configuration
KB_PATH = '/content/drive/MyDrive/lifesciences/training_documents/'
DRIVE_MODEL_PATH = '/content/drive/MyDrive/lifesciences/models/'
MODEL_REPO_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Constants
CHUNK_SIZE = 512  # In tokens
CHUNK_OVERLAP = 50
SIMILARITY_THRESHOLD = 0.65
MAX_CONTEXT_LENGTH = 3000

# Initialize logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# ----------------------------
# Initialization (Run once)
# ----------------------------
class RAGSystem:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.embedding_model = None
        self.index = None
        self.chunks = []

    def initialize(self):
        """One-time initialization of all components"""
        self._auth_huggingface()
        model_path = self._download_model()
        self._load_llm(model_path)
        self._load_embedding_model()
        self._process_knowledge_base()

    def _auth_huggingface(self):
        """Authenticate with Hugging Face"""
        try:
            from google.colab import userdata
            login(token=userdata.get("HF_TOKEN"))
            logging.info("Hugging Face authentication successful.")
        except Exception as e:
            logging.error(f"Authentication failed: {str(e)}")
            raise

    def _download_model(self):
        """Download model from Hugging Face Hub"""
        try:
            model_path = snapshot_download(
                repo_id=MODEL_REPO_ID,
                cache_dir=DRIVE_MODEL_PATH,
                revision="main",
                ignore_patterns=["*.msgpack", "*.h5", "*.ot"],
                local_dir=DRIVE_MODEL_PATH,
                local_dir_use_symlinks=False
            )
            logging.info(f"Model downloaded to {model_path}")
            return model_path
        except Exception as e:
            logging.error(f"Model download failed: {str(e)}")
            raise

    def _load_llm(self, model_path):
        """Load LLM with quantization"""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                quantization_config=bnb_config
            )
            logging.info("LLM loaded successfully")
        except Exception as e:
            logging.error(f"Model loading failed: {str(e)}")
            raise

    def _load_embedding_model(self):
        """Load sentence transformer model"""
        self.embedding_model = SentenceTransformer('thenlper/gte-base')
        logging.info("Embedding model loaded")

    def _process_knowledge_base(self):
        """Process documents and create FAISS index"""
        self.chunks = self._chunk_documents()
        embeddings = self._create_embeddings(self.chunks)
        self._create_faiss_index(embeddings)
        logging.info(f"Processed {len(self.chunks)} knowledge chunks")

    def _chunk_documents(self) -> List[str]:
        """Improved document chunking with text cleaning"""
        chunks = []
        for doc_path in self._get_docx_files():
            content = self._read_docx(doc_path)
            print(content)
            content = self._clean_text(content)
            chunks.extend(self._token_based_chunking(content))
        return chunks

    def _get_docx_files(self):
        """Get all DOCX files from knowledge base path"""
        return [os.path.join(KB_PATH, f) for f in os.listdir(KB_PATH) if f.endswith('.docx')]

    def _read_docx(self, file_path: str) -> str:
        """Read DOCX file with error handling"""
        try:
            doc = Document(file_path)
            return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
        except Exception as e:
            logging.error(f"Error reading {file_path}: {str(e)}")
            return ""

    def _clean_text(self, text: str) -> str:
        """Clean and normalize text"""
        text = re.sub(r'\s+', ' ', text)  # Replace multiple whitespace
        text = re.sub(r'\u200b', '', text)  # Remove zero-width spaces
        return text.strip()

    def _token_based_chunking(self, text: str) -> List[str]:
        """Token-aware text chunking with overlap"""
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        chunks = []
        start = 0
        while start < len(tokens):
            end = min(start + CHUNK_SIZE, len(tokens))
            chunk = tokens[start:end]
            chunks.append(self.tokenizer.decode(chunk, skip_special_tokens=True))
            start = end - CHUNK_OVERLAP if end - CHUNK_OVERLAP > start else end
        return chunks

    def _create_embeddings(self, chunks: List[str]):
        """Create embeddings with batching and memory management"""
        embeddings = []
        batch_size = 64
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i+batch_size]
            emb = self.embedding_model.encode(batch, show_progress_bar=False)
            embeddings.append(emb)
            del batch
            gc.collect()
        return np.vstack(embeddings)

    def _create_faiss_index(self, embeddings: np.ndarray):
        """Create optimized FAISS index"""
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity
        faiss.normalize_L2(embeddings)  # Normalize for cosine similarity
        self.index.add(embeddings)
        logging.info(f"FAISS index created with {self.index.ntotal} vectors")

# ----------------------------
# FastAPI Application
# ----------------------------
app = FastAPI()
rag_system = RAGSystem()

@app.on_event("startup")
async def startup_event():
    rag_system.initialize()
    ngrok.set_auth_token(userdata.get("ngrok_auth_token"))
    public_url = ngrok.connect(8000)
    logging.info(f"API available at: {public_url}")

@app.post("/query")
async def handle_query(query: str):
    try:
        if len(query) < 3:
            return {"error": "Query too short"}

        context = retrieve_context(query)
        response = generate_response(query, context)
        return {"response": response}

    except Exception as e:
        logging.error(f"Error processing query: {str(e)}")
        return {"error": "Processing failed"}

def retrieve_context(query: str) -> str:
    """Retrieve relevant context from knowledge base"""
    query_embed = rag_system.embedding_model.encode([query], show_progress_bar=False)
    faiss.normalize_L2(query_embed)

    # Search with score threshold
    distances, indices = rag_system.index.search(query_embed, 5)
    relevant_chunks = [
        rag_system.chunks[i]
        for i, score in zip(indices[0], distances[0])
        if score > SIMILARITY_THRESHOLD
    ]

    if not relevant_chunks:
        logging.warning("No relevant context found")
        return ""

    return "\n".join(relevant_chunks)[:MAX_CONTEXT_LENGTH]

def generate_response(query: str, context: str) -> str:
    """Generate response using LLM without showing model's 'thinking'"""
    if not context:
        return "I don't have sufficient information to answer that question."

    # Modified prompt to enforce direct answers
    prompt = f"""Answer the question directly using only the provided context.
Do not explain your reasoning or thought process.
If unsure, say "I don't know".

Context: {context}

Question: {query}

Answer:"""

    inputs = rag_system.tokenizer(
        prompt,
        return_tensors="pt",
        max_length=4096,
        truncation=True
    ).to(rag_system.model.device)

    with torch.inference_mode():
        outputs = rag_system.model.generate(
            **inputs,
            max_new_tokens=256,  # Reduced to prevent verbose output
            temperature=0.2,     # More deterministic
            top_p=0.9,
            repetition_penalty=1.5,
            do_sample=False,     # Disable creative sampling
            eos_token_id=rag_system.tokenizer.eos_token_id
        )

    full_response = rag_system.tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strict answer extraction
    return extract_answer(full_response)

def extract_answer(full_response: str) -> str:
    """Forcefully extract only the answer portion"""
    # Split on the last "Answer:" occurrence
    if "Answer:" in full_response:
        return full_response.rsplit("Answer:", 1)[-1].strip()

    # Fallback patterns
    patterns = [
        r"(?i)(?:final answer|response):\s*(.*)",
        r"(?i)here(?:'s| is) (?:the|my) (?:answer|response):\s*(.*)",
        r"^(.*?)(?:Note:|Explanation:|Reasoning:)",  # Stop at unwanted sections
    ]

    for pattern in patterns:
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            answer = match.group(1).strip()
            # Remove any markdown formatting
            answer = re.sub(r"\*\*|__|```", "", answer)
            return re.split(r"[\.\?!]\s", answer)[0] + "."  # Take first sentence

    # Final fallback - take first 2 sentences
    sentences = re.split(r"(?<=[.!?])\s+", full_response.strip())
    return " ".join(sentences[:2]).strip()

def start_ngrok():
    ngrok.set_auth_token(userdata.get("ngrok_auth_token"))
    public_url = ngrok.connect(8000)
    logging.info(f"FastAPI app is running at: {public_url}")
    print(f"FastAPI app is running at: {public_url}")
    nest_asyncio.apply()
    uvicorn.run(app, host="0.0.0.0", port=8000)

# ----------------------------
# Main Execution
# ----------------------------
if __name__ == "__main__":
    start_ngrok()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


        on_event is deprecated, use lifespan event handlers instead.

        Read more about it in the
        [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/).
        
  @app.on_event("startup")
INFO:     Started server process [61004]
INFO:     Waiting for application startup.


FastAPI app is running at: NgrokTunnel: "https://f824-34-105-12-8.ngrok-free.app" -> "http://localhost:8000"


For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

Table of Contents
1 General Requirements
1.1 Utilites
1.2 Facility Requirements
1.3 Safety Requirements
1.4 Environmental Requirements
1.5 Calibration Requirements
1.6 Documentation
1.7 Training Requirements
1.8 Vendor and Warranty Specifications
2 Automation
2.1 Platform and Networking Requirements
2.2 Hardware Requirements
2.3 Software Requirements
2.4 System Performance Requirements
2.5 Historical Data Trending
2.6 Alarm and Events
2.7 Operational Requirements
2.8 User Roles and Access Requirements
2.9 Password Rules
2.10 Time Synchronization
2.11 Security Requirements
2.12 Anti Virus and Patching
2.13 Electronic Signatures (21 CFR 11)
2.14 Electronic Records (21 CFR 11) - Audit Trail 
2.15 Electronic Records (21 CFR 11) - Data Retention
2.16 Electronic Records (21 CFR 11) - Backup and Disaster Recovery
2.17 Interface with Third Party Systems
2.18 Version Control
2.19 Environmental Requirements
2.20 Documentation Requirements
2.21 Training Requirements
2.22 Test System Requirements


INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


INFO:     95.223.75.30:0 - "POST /query?query=%22The%20SCADA%20system%20should%20support%20quality%20control%20%22 HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [61004]
