In [1]:
# Complete Fixed Implementation - Best Free OCR for Medical Documents
!pip install -q easyocr paddlepaddle paddleocr torch transformers sentence-transformers faiss-cpu
!pip install -q opencv-python pillow pandas numpy networkx matplotlib seaborn
!apt install tesseract-ocr

import easyocr
import cv2
import numpy as np
from PIL import Image
import torch
import re
import json
from datetime import datetime
from pathlib import Path
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Setup directories
drive_root = Path('/content/drive/MyDrive')
lab_reports_dir = drive_root / 'bajaj'
prescriptions_dir = drive_root / 'prescriptions'
output_dir = Path('/content/output')
output_dir.mkdir(exist_ok=True)

print(f"Lab reports: {lab_reports_dir.exists()}")
print(f"Prescriptions: {prescriptions_dir.exists()}")


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.2/78.2 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.5/48.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m101.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.0/195.0 MB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.4/70.4 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:

In [2]:
!pip install verovio

Collecting verovio
  Downloading verovio-5.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.2 kB)
Downloading verovio-5.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.7/8.7 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: verovio
Successfully installed verovio-5.5.0


In [4]:
!pip install 'pytesseract'

Collecting pytesseract
  Downloading pytesseract-0.3.13-py3-none-any.whl.metadata (11 kB)
Downloading pytesseract-0.3.13-py3-none-any.whl (14 kB)
Installing collected packages: pytesseract
Successfully installed pytesseract-0.3.13


In [5]:
# Robust Multi-OCR Processor - Works 100% on Colab
class RobustMedicalOCRProcessor:
    def __init__(self):
        """Initialize multiple OCR engines for maximum reliability"""
        # EasyOCR - excellent for handwritten text
        self.easy_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

        # Try to initialize PaddleOCR for better document structure
        try:
            from paddleocr import PaddleOCR
            # Corrected initialization - use_gpu is not a direct argument here
            self.paddle_ocr = PaddleOCR(use_angle_cls=True, lang='en')
            self.has_paddle = True
            # PaddleOCR uses a different mechanism for GPU, often set through environment or model loading
            # We will rely on the default behavior or environment setup for GPU if available.
        except ImportError:
            print("PaddleOCR not available, using EasyOCR only")
            self.paddle_ocr = None
            self.has_paddle = False
        except Exception as e:
            print(f"Error initializing PaddleOCR: {e}")
            self.paddle_ocr = None
            self.has_paddle = False


        # Tesseract as backup
        import pytesseract
        self.tesseract = pytesseract

        print("✅ Robust OCR processor initialized successfully")

    def preprocess_image(self, image_path):
        """Enhanced image preprocessing for medical documents"""
        try:
            # Load image
            image = cv2.imread(str(image_path))
            if image is None:
                return None

            # Convert to RGB
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Resize if too large (helps with processing speed)
            height, width = image_rgb.shape[:2]
            if width > 2000:
                scale = 2000 / width
                new_width = int(width * scale)
                new_height = int(height * scale)
                image_rgb = cv2.resize(image_rgb, (new_width, new_height))

            # Enhance contrast for better OCR
            lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
            lab[:,:,0] = clahe.apply(lab[:,:,0])
            enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

            return enhanced

        except Exception as e:
            print(f"Preprocessing error: {e}")
            return None

    def extract_text_from_image(self, image_path, is_prescription=False):
        """Extract text using multiple OCR engines with fallback"""
        try:
            # Preprocess image
            processed_img = self.preprocess_image(image_path)
            if processed_img is None:
                return {"error": "Could not preprocess image", "text": "", "boxes": []}

            # Try EasyOCR first (best for handwritten)
            easy_text = "" # Initialize easy_text
            easy_confidence = 0.0 # Initialize easy_confidence
            try:
                easy_results = self.easy_reader.readtext(processed_img, batch_size=8, width_ths=0.8)
                easy_text = self._process_easyocr_results(easy_results)
                easy_confidence = self._calculate_easyocr_confidence(easy_results)

                # If EasyOCR gives good results, use it
                if len(easy_text.strip()) > 20 and easy_confidence > 0.4:
                    return {
                        "text": easy_text,
                        "boxes": self._format_easyocr_boxes(easy_results),
                        "confidence_score": easy_confidence,
                        "image_path": str(image_path),
                        "ocr_engine": "EasyOCR"
                    }
            except Exception as e:
                print(f"EasyOCR failed: {e}")

            # Try PaddleOCR if available
            if self.has_paddle:
                try:
                    paddle_results = self.paddle_ocr.ocr(processed_img, cls=True)
                    paddle_text = self._process_paddleocr_results(paddle_results)
                    paddle_confidence = self._calculate_paddleocr_confidence(paddle_results)

                    # If PaddleOCR gives better results, use it
                    if len(paddle_text.strip()) > len(easy_text.strip()) and paddle_confidence > 0.3:
                        return {
                            "text": paddle_text,
                            "boxes": self._format_paddleocr_boxes(paddle_results),
                            "confidence_score": paddle_confidence,
                            "image_path": str(image_path),
                            "ocr_engine": "PaddleOCR"
                        }
                except Exception as e:
                    print(f"PaddleOCR failed: {e}")

            # Fallback to Tesseract
            try:
                # Convert to grayscale for Tesseract
                gray = cv2.cvtColor(processed_img, cv2.COLOR_RGB2GRAY)
                tesseract_text = self.tesseract.image_to_string(gray, config='--psm 6')

                if len(tesseract_text.strip()) > 10:
                    return {
                        "text": tesseract_text,
                        "boxes": [],
                        "confidence_score": 0.7,
                        "image_path": str(image_path),
                        "ocr_engine": "Tesseract"
                    }
            except Exception as e:
                print(f"Tesseract failed: {e}")

            # If all OCR engines fail, return the best result we got
            return {
                "text": easy_text if 'easy_text' in locals() else "",
                "boxes": [],
                "confidence_score": easy_confidence if 'easy_confidence' in locals() else 0.0,
                "image_path": str(image_path),
                "ocr_engine": "Best Available"
            }

        except Exception as e:
            print(f"Complete OCR failure on {image_path}: {str(e)}")
            return {"error": str(e), "text": "", "boxes": []}

    def _process_easyocr_results(self, results):
        """Process EasyOCR results into text"""
        text_blocks = []
        for (bbox, text, confidence) in results:
            if confidence > 0.3 and len(text.strip()) > 0:
                text_blocks.append(text)
        return '\n'.join(text_blocks)

    def _calculate_easyocr_confidence(self, results):
        """Calculate average confidence from EasyOCR"""
        if not results:
            return 0.0
        confidences = [conf for (_, _, conf) in results if conf > 0.3]
        return sum(confidences) / len(confidences) if confidences else 0.0

    def _format_easyocr_boxes(self, results):
        """Format EasyOCR bounding boxes"""
        boxes = []
        for (bbox, text, confidence) in results:
            if confidence > 0.3:
                # Convert bbox to [x1, y1, x2, y2] format
                x_coords = [point[0] for point in bbox]
                y_coords = [point[1] for point in bbox]
                boxes.append({
                    'text': text,
                    'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
                    'confidence': confidence
                })
        return boxes

    def _process_paddleocr_results(self, results):
        """Process PaddleOCR results into text"""
        if not results or not results[0]:
            return ""

        text_blocks = []
        for line in results[0]:
            if line:
                bbox, (text, confidence) = line # Unpack bbox and (text, confidence) tuple
                if confidence > 0.3 and len(text.strip()) > 0:
                    text_blocks.append(text)
        return '\n'.join(text_blocks)

    def _calculate_paddleocr_confidence(self, results):
        """Calculate average confidence from PaddleOCR"""
        if not results or not results[0]:
            return 0.0

        confidences = []
        for line in results[0]:
            if line:
                bbox, (text, confidence) = line # Unpack bbox and (text, confidence) tuple
                if confidence > 0.3:
                    confidences.append(confidence)
        return sum(confidences) / len(confidences) if confidences else 0.0

    def _format_paddleocr_boxes(self, results):
        """Format PaddleOCR bounding boxes"""
        if not results or not results[0]:
            return []

        boxes = []
        for line in results[0]:
            if line:
                bbox, (text, confidence) = line # Unpack bbox and (text, confidence) tuple
                if confidence > 0.3:
                    # Convert bbox to [x1, y1, x2, y2] format
                    x_coords = [point[0] for point in bbox]
                    y_coords = [point[1] for point in bbox]
                    boxes.append({
                        'text': text,
                        'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)],
                        'confidence': confidence
                    })
        return boxes


# Initialize the robust OCR processor
print("🚀 Loading Robust Medical OCR Processor...")
ocr_processor = RobustMedicalOCRProcessor()

🚀 Loading Robust Medical OCR Processor...


[32mCreating model: ('PP-LCNet_x1_0_doc_ori', None)[0m
[32mUsing official model (PP-LCNet_x1_0_doc_ori), the model files will be automatically downloaded and saved in /root/.paddlex/official_models.[0m


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

[32mCreating model: ('UVDoc', None)[0m
[33mThe model(UVDoc) is not supported to run in MKLDNN mode! Using `paddle` instead![0m
[32mUsing official model (UVDoc), the model files will be automatically downloaded and saved in /root/.paddlex/official_models.[0m


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

[32mCreating model: ('PP-LCNet_x1_0_textline_ori', None)[0m
[32mUsing official model (PP-LCNet_x1_0_textline_ori), the model files will be automatically downloaded and saved in /root/.paddlex/official_models.[0m


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

[32mCreating model: ('PP-OCRv5_server_det', None)[0m
[32mUsing official model (PP-OCRv5_server_det), the model files will be automatically downloaded and saved in /root/.paddlex/official_models.[0m


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

[32mCreating model: ('PP-OCRv5_server_rec', None)[0m
[32mUsing official model (PP-OCRv5_server_rec), the model files will be automatically downloaded and saved in /root/.paddlex/official_models.[0m


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

✅ Robust OCR processor initialized successfully


In [6]:
# Enhanced Information Extraction (same as before but optimized)
class MedicalInformationExtractor:
    def __init__(self):
        # Enhanced medical patterns for better extraction
        self.medicine_patterns = [
            r'\b([A-Z][a-z]+(?:mycin|cillin|azole|pril|sartan|olol|pine|statin|formin|lol|ide))\b',
            r'\b([A-Z][a-z]+)\s+(\d+(?:\.\d+)?)\s*(mg|g|ml|mcg|units?|IU)\b',
            r'\b(Tab|Cap|Syp|Inj|Tablet|Capsule)\.?\s*([A-Z][a-z]+)\s*(\d+(?:\.\d+)?)\s*(mg|g|ml|mcg)?',
            r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)\s*[-–]\s*(\d+(?:\.\d+)?)\s*(mg|g|ml|mcg)\b',
            r'\b(Paracetamol|Aspirin|Ibuprofen|Metformin|Insulin|Amlodipine|Lisinopril|Atorvastatin|Omeprazole|Losartan)\b'
        ]

        self.lab_test_patterns = [
            r'\b(Hemoglobin|Hb|HbA1c|HBA1C|Glucose|Blood Sugar|Random Sugar|Fasting Sugar|Cholesterol|HDL|LDL|Triglycerides|Creatinine|Urea|BUN|TSH|T3|T4|FT3|FT4)\b',
            r'\b(WBC|RBC|Platelet|PLT|Hematocrit|ESR|CRP|SGPT|SGOT|ALT|AST|Bilirubin|Total Bilirubin|Direct Bilirubin)\b',
            r'\b(Sodium|Potassium|Chloride|CO2|Calcium|Phosphorus|Magnesium|Iron|Ferritin)\b',
            r'\b(Total\s+Protein|Albumin|Globulin|A/G Ratio)\b'
        ]

        self.patient_name_patterns = [
            r'(?:Patient|Name|Pt\.?|Mr\.?|Mrs\.?|Ms\.?|Miss)\s*:?\s*([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)',
            r'Name\s*:?\s*([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)',
            r'(?:^|\n)([A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)(?:\s|$|\n)',
        ]

    def extract_from_prescription(self, ocr_result):
        """Extract structured information from prescription OCR"""
        text = ocr_result.get('text', '')

        return {
            'document_type': 'prescription',
            'patient_name': self._extract_patient_name(text),
            'date': self._extract_date(text),
            'prescriber': self._extract_prescriber(text),
            'medications': self._extract_medications(text),
            'raw_text': text,
            'confidence_score': ocr_result.get('confidence_score', 0),
            'ocr_engine': ocr_result.get('ocr_engine', 'Unknown')
        }

    def extract_from_lab_report(self, ocr_result):
        """Extract structured information from lab report OCR"""
        text = ocr_result.get('text', '')

        return {
            'document_type': 'lab_report',
            'patient_name': self._extract_patient_name(text),
            'date': self._extract_date(text),
            'lab_tests': self._extract_lab_tests(text),
            'diagnoses': self._extract_diagnoses(text),
            'raw_text': text,
            'confidence_score': ocr_result.get('confidence_score', 0),
            'ocr_engine': ocr_result.get('ocr_engine', 'Unknown')
        }

    def _extract_patient_name(self, text):
        """Extract patient name using enhanced patterns"""
        for pattern in self.patient_name_patterns:
            matches = re.finditer(pattern, text, re.MULTILINE | re.IGNORECASE)
            for match in matches:
                name = match.group(1).strip()
                if (len(name) > 3 and
                    not any(word in name.lower() for word in ['report', 'lab', 'test', 'date', 'hospital', 'clinic'])):
                    return name
        return None

    def _extract_date(self, text):
        """Extract date from text"""
        date_patterns = [
            r'(\d{1,2}[-/]\d{1,2}[-/]\d{2,4})',
            r'(\d{1,2}\s+(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{2,4})',
            r'(?:Date|date|Date:)\s*(\d{1,2}[-/]\d{1,2}[-/]\d{2,4})',
            r'(\d{4}[-/]\d{1,2}[-/]\d{1,2})'
        ]

        for pattern in date_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                try:
                    date_str = match.group(1)
                    for fmt in ['%d/%m/%Y', '%d-%m-%Y', '%d %b %Y', '%d/%m/%y', '%Y-%m-%d']:
                        try:
                            return datetime.strptime(date_str, fmt).isoformat()
                        except:
                            continue
                except:
                    pass
        return None

    def _extract_prescriber(self, text):
        """Extract prescriber information"""
        patterns = [
            r'(?:Dr\.?|Doctor|Physician)\s*:?\s*([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)'
        ]

        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return match.group(1).strip()
        return None

    def _extract_medications(self, text):
        """Extract medication information"""
        medications = []
        seen_names = set()

        for pattern in self.medicine_patterns:
            matches = re.finditer(pattern, text, re.IGNORECASE)
            for match in matches:
                # Handle different pattern structures
                if len(match.groups()) >= 2:
                    if match.group(0).startswith(('Tab', 'Cap', 'Syp', 'Inj')):
                        med_name = match.group(2)
                        dosage = f"{match.group(3)} {match.group(4)}" if len(match.groups()) > 3 else None
                    else:
                        med_name = match.group(1)
                        dosage = f"{match.group(2)} {match.group(3)}" if len(match.groups()) > 2 else None
                else:
                    med_name = match.group(1)
                    dosage = None

                if med_name.lower() not in seen_names:
                    medications.append({
                        'name': med_name,
                        'dosage': dosage,
                        'context': text[max(0, match.start()-50):match.end()+50].strip()
                    })
                    seen_names.add(med_name.lower())

        return medications

    def _extract_lab_tests(self, text):
        """Extract lab test results"""
        tests = []
        seen_tests = set()

        for pattern in self.lab_test_patterns:
            matches = re.finditer(pattern, text, re.IGNORECASE)
            for match in matches:
                test_name = match.group(1)

                if test_name.lower() not in seen_tests:
                    # Look for values near the test name
                    context = text[match.start():match.start()+150]
                    value_match = re.search(r'(\d+(?:\.\d+)?)\s*(mg/dl|g/dl|%|units)', context, re.IGNORECASE)

                    tests.append({
                        'test_name': test_name,
                        'value': value_match.group(1) if value_match else None,
                        'unit': value_match.group(2) if value_match else None,
                        'context': context.strip()
                    })
                    seen_tests.add(test_name.lower())

        return tests

    def _extract_diagnoses(self, text):
        """Extract diagnosis information"""
        patterns = [
            r'(?:Diagnosis|Impression|Clinical\s+Findings?|Conclusion)\s*:?\s*([A-Z][a-z]+(?:\s+[a-z]+)*)',
            r'(?:suggests?|indicates?|shows?)\s*:?\s*([A-Z][a-z]+(?:\s+[a-z]+)*)'
        ]

        diagnoses = []
        for pattern in patterns:
            matches = re.finditer(pattern, text, re.IGNORECASE)
            for match in matches:
                diagnosis = match.group(1).strip()
                if len(diagnosis) > 3:
                    diagnoses.append(diagnosis)

        return list(set(diagnoses))

# Initialize components
ie_processor = MedicalInformationExtractor()

# Knowledge Graph (same as before)
import networkx as nx

class MedicalKnowledgeGraph:
    def __init__(self):
        self.graph = nx.MultiDiGraph()
        self.node_counter = 0
        self.entity_index = defaultdict(list)

    def add_patient_record(self, extracted_data):
        patient_id = self._get_or_create_patient(extracted_data)
        document_id = self._create_document_node(extracted_data)

        self.graph.add_edge(patient_id, document_id, relation='has_document')

        if extracted_data['document_type'] == 'prescription':
            self._process_prescription_data(patient_id, document_id, extracted_data)
        elif extracted_data['document_type'] == 'lab_report':
            self._process_lab_report_data(patient_id, document_id, extracted_data)

        return patient_id, document_id

    def _get_or_create_patient(self, data):
        patient_name = data.get('patient_name') or f"Unknown_Patient_{self.node_counter}"
        existing = self.entity_index.get(f"patient_{patient_name.lower()}", [])

        if existing:
            return existing[0]

        patient_id = f"patient_{self.node_counter}"
        self.node_counter += 1

        self.graph.add_node(patient_id, type='patient', name=patient_name, created_at=datetime.now().isoformat())
        self.entity_index[f"patient_{patient_name.lower()}"].append(patient_id)
        return patient_id

    def _create_document_node(self, data):
        doc_id = f"doc_{self.node_counter}"
        self.node_counter += 1

        self.graph.add_node(doc_id, type='document', document_type=data['document_type'],
                           date=data.get('date'), raw_text=data.get('raw_text', ''),
                           confidence=data.get('confidence_score', 0),
                           ocr_engine=data.get('ocr_engine', 'Unknown'))
        return doc_id

    def _process_prescription_data(self, patient_id, doc_id, data):
        prescription_id = f"prescription_{self.node_counter}"
        self.node_counter += 1

        self.graph.add_node(prescription_id, type='prescription', date=data.get('date'),
                           prescriber=data.get('prescriber'))

        self.graph.add_edge(patient_id, prescription_id, relation='has_prescription')
        self.graph.add_edge(doc_id, prescription_id, relation='contains_prescription')

        for med in data.get('medications', []):
            med_id = self._get_or_create_medication(med['name'])
            self.graph.add_edge(prescription_id, med_id, relation='prescribes',
                              dosage=med.get('dosage'))

    def _process_lab_report_data(self, patient_id, doc_id, data):
        encounter_id = f"encounter_{self.node_counter}"
        self.node_counter += 1

        self.graph.add_node(encounter_id, type='encounter', date=data.get('date'))
        self.graph.add_edge(patient_id, encounter_id, relation='has_encounter')
        self.graph.add_edge(doc_id, encounter_id, relation='documents_encounter')

        for test in data.get('lab_tests', []):
            test_id = self._get_or_create_lab_test(test['test_name'])
            result_id = f"result_{self.node_counter}"
            self.node_counter += 1

            self.graph.add_node(result_id, type='lab_result', value=test.get('value'),
                               unit=test.get('unit'), date=data.get('date'))

            self.graph.add_edge(encounter_id, result_id, relation='has_result')
            self.graph.add_edge(result_id, test_id, relation='result_of_test')

    def _get_or_create_medication(self, med_name):
        key = f"medication_{med_name.lower()}"
        existing = self.entity_index.get(key, [])
        if existing:
            return existing[0]

        med_id = f"medication_{self.node_counter}"
        self.node_counter += 1
        self.graph.add_node(med_id, type='medication', name=med_name)
        self.entity_index[key].append(med_id)
        return med_id

    def _get_or_create_lab_test(self, test_name):
        key = f"lab_test_{test_name.lower()}"
        existing = self.entity_index.get(key, [])
        if existing:
            return existing[0]

        test_id = f"lab_test_{self.node_counter}"
        self.node_counter += 1
        self.graph.add_node(test_id, type='lab_test', name=test_name)
        self.entity_index[key].append(test_id)
        return test_id

    def query_medication_patterns(self):
        patterns = defaultdict(int)
        prescription_nodes = [n for n, d in self.graph.nodes(data=True) if d.get('type') == 'prescription']

        for prescription_id in prescription_nodes:
            for neighbor in self.graph.successors(prescription_id):
                node_data = self.graph.nodes[neighbor]
                if node_data.get('type') == 'medication':
                    patterns[node_data.get('name')] += 1

        sorted_patterns = sorted(patterns.items(), key=lambda x: x[1], reverse=True)
        return {'common_medications': sorted_patterns[:10], 'total_prescriptions': len(prescription_nodes)}

# Initialize Knowledge Graph
kg = MedicalKnowledgeGraph()


In [7]:
# Main Processing Pipeline - FIXED VERSION
def process_medical_documents_robust():
    """Process all medical documents using robust OCR with proper error handling"""
    processed_documents = []

    print("🔍 Processing Lab Reports with Robust OCR...")

    # Get lab report files
    lab_files = []
    if lab_reports_dir.exists():
        lab_files = (list(lab_reports_dir.rglob('*.png')) +
                    list(lab_reports_dir.rglob('*.jpg')) +
                    list(lab_reports_dir.rglob('*.jpeg')))

    print(f"Found {len(lab_files)} lab report files")

    # Process all lab reports
    success_count = 0
    for i, img_path in enumerate(lab_files):
        print(f"Processing lab report {i+1}/{len(lab_files)}: {img_path.name}")

        try:
            # Use robust OCR
            ocr_result = ocr_processor.extract_text_from_image(img_path, is_prescription=False)

            if ocr_result.get('text') and len(ocr_result['text'].strip()) > 20:
                # Information Extraction
                extracted_data = ie_processor.extract_from_lab_report(ocr_result)
                extracted_data['image_path'] = str(img_path)

                # Add to Knowledge Graph
                patient_id, doc_id = kg.add_patient_record(extracted_data)
                extracted_data['patient_id'] = patient_id
                extracted_data['doc_id'] = doc_id

                processed_documents.append(extracted_data)
                success_count += 1

                print(f"✅ SUCCESS: Patient={extracted_data.get('patient_name', 'Unknown')}, "
                      f"Tests={len(extracted_data.get('lab_tests', []))}, "
                      f"Engine={extracted_data.get('ocr_engine', 'Unknown')}")
            else:
                print(f"❌ No meaningful text extracted from {img_path.name}")

        except Exception as e:
            print(f"❌ Error processing {img_path.name}: {str(e)}")
            continue

    print(f"\n📊 Lab Reports: {success_count}/{len(lab_files)} processed successfully")

    print(f"\n🔍 Processing Prescription Images...")

    # Get prescription files
    prescription_files = []
    if prescriptions_dir.exists():
        prescription_files = (list(prescriptions_dir.rglob('*.png')) +
                             list(prescriptions_dir.rglob('*.jpg')) +
                             list(prescriptions_dir.rglob('*.jpeg')))

    print(f"Found {len(prescription_files)} prescription files")

    # Process all prescriptions
    prescription_success = 0
    for i, img_path in enumerate(prescription_files):
        print(f"Processing prescription {i+1}/{len(prescription_files)}: {img_path.name}")

        try:
            # Use robust OCR
            ocr_result = ocr_processor.extract_text_from_image(img_path, is_prescription=True)

            if ocr_result.get('text') and len(ocr_result['text'].strip()) > 20:
                # Information Extraction
                extracted_data = ie_processor.extract_from_prescription(ocr_result)
                extracted_data['image_path'] = str(img_path)

                # Add to Knowledge Graph
                patient_id, doc_id = kg.add_patient_record(extracted_data)
                extracted_data['patient_id'] = patient_id
                extracted_data['doc_id'] = doc_id

                processed_documents.append(extracted_data)
                prescription_success += 1

                print(f"✅ SUCCESS: Patient={extracted_data.get('patient_name', 'Unknown')}, "
                      f"Meds={len(extracted_data.get('medications', []))}, "
                      f"Engine={extracted_data.get('ocr_engine', 'Unknown')}")
            else:
                print(f"❌ No meaningful text extracted from {img_path.name}")

        except Exception as e:
            print(f"❌ Error processing {img_path.name}: {str(e)}")
            continue

    print(f"\n📊 Prescriptions: {prescription_success}/{len(prescription_files)} processed successfully")

    return processed_documents

# Run the fixed processing pipeline
print("🚀 Starting FIXED Medical Document Processing...")
processed_docs = process_medical_documents_robust()

print(f"\n📊 FINAL PROCESSING SUMMARY:")
print(f"Total documents processed: {len(processed_docs)}")
print(f"Lab reports: {sum(1 for doc in processed_docs if doc['document_type'] == 'lab_report')}")
print(f"Prescriptions: {sum(1 for doc in processed_docs if doc['document_type'] == 'prescription')}")

if processed_docs:
    print(f"Knowledge graph nodes: {kg.graph.number_of_nodes()}")
    print(f"Knowledge graph edges: {kg.graph.number_of_edges()}")

    # Show OCR engine statistics
    engine_stats = defaultdict(int)
    for doc in processed_docs:
        engine_stats[doc.get('ocr_engine', 'Unknown')] += 1

    print(f"\n🔧 OCR Engine Usage:")
    for engine, count in engine_stats.items():
        print(f"  {engine}: {count} documents")

    # Save results
    with open('/content/output/processed_documents_robust.json', 'w') as f:
        json.dump(processed_docs, f, indent=2, default=str)

    print(f"\n💾 Documents saved to: /content/output/processed_documents_robust.json")
    print("✅ ROBUST MEDICAL RAG SYSTEM READY!")
else:
    print("❌ No documents were processed successfully. Check your image files and directory paths.")


🚀 Starting FIXED Medical Document Processing...
🔍 Processing Lab Reports with Robust OCR...
Found 426 lab report files
Processing lab report 1/426: BLR-0425-PA-0037318_SASHANK P K 0037318 2 OF 2_28-04-2025_1007-19_AM@E.pdf_page_29.png
✅ SUCCESS: Patient=UHID
AIGG, Tests=0, Engine=EasyOCR
Processing lab report 2/426: BLR-0425-PA-0039192_E-ParmeshwarRunningBill_250426_1612@E.pdf_page_93.png
✅ SUCCESS: Patient=ID
Patient, Tests=0, Engine=EasyOCR
Processing lab report 3/426: BLR-0425-PA-0037318_SASHANK P K 0037318 2 OF 2_28-04-2025_1007-19_AM@E.pdf_page_33.png
✅ SUCCESS: Patient=UHID
AIGG, Tests=0, Engine=EasyOCR
Processing lab report 4/426: BLR-0425-PA-0038965_BIPUL CHAKRABORTY 0038965 2 OF 2_28-04-2025_1014-26_AM.pdf_page_16.png
✅ SUCCESS: Patient=WITH INR AND APTT
Ctared Piasttx, Tests=0, Engine=EasyOCR
Processing lab report 5/426: AHD-0425-PA-0007719_E-REPORTS_250427_2032@E.pdf_page_7.png
✅ SUCCESS: Patient=FULLY
DIAGnOSTIC
COMPUTERISED, Tests=5, Engine=EasyOCR
Processing lab report 6/

In [8]:
# RAG System Module with Sentence Transformers
from sentence_transformers import SentenceTransformer
import faiss

class MedicalRAGSystem:
    def __init__(self, knowledge_graph):
        self.kg = knowledge_graph
        print(" Loading sentence transformer model...")
        self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.document_chunks = []
        self.chunk_embeddings = None
        self.faiss_index = None
        print("RAG embedder initialized")

    def index_documents(self, processed_documents):
        """Create vector index for document retrieval"""
        print("Creating document chunks and embeddings...")
        self.document_chunks = []
        chunk_texts = []

        for doc in processed_documents:
            chunks = self._chunk_text(doc.get('raw_text', ''), chunk_size=300)
            for i, chunk in enumerate(chunks):
                if len(chunk.strip()) > 20:  # Only index meaningful chunks
                    chunk_data = {
                        'text': chunk,
                        'document_type': doc.get('document_type'),
                        'patient_name': doc.get('patient_name'),
                        'date': doc.get('date'),
                        'chunk_id': f"{doc.get('patient_name', 'unknown')}_{i}",
                        'source_doc': doc,
                        'ocr_engine': doc.get('ocr_engine', 'Unknown'),
                        'confidence': doc.get('confidence_score', 0)
                    }
                    self.document_chunks.append(chunk_data)
                    chunk_texts.append(chunk)

        if chunk_texts:
            print(f"Creating embeddings for {len(chunk_texts)} chunks...")
            embeddings = self.embedder.encode(chunk_texts, show_progress_bar=True)
            self.chunk_embeddings = np.array(embeddings)

            # Create FAISS index
            dimension = embeddings.shape[1]
            self.faiss_index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity

            # Normalize embeddings for cosine similarity
            faiss.normalize_L2(self.chunk_embeddings)
            self.faiss_index.add(self.chunk_embeddings)
            print(f" FAISS index created with {len(chunk_texts)} vectors")
        else:
            print(" No meaningful chunks to index!")

    def _chunk_text(self, text, chunk_size=300, overlap=50):
        """Split text into overlapping chunks"""
        words = text.split()
        chunks = []

        for i in range(0, len(words), chunk_size - overlap):
            chunk = ' '.join(words[i:i + chunk_size])
            if chunk.strip():
                chunks.append(chunk)

        return chunks

    def retrieve_relevant_chunks(self, query, top_k=5):
        """Retrieve relevant document chunks using semantic search"""
        if self.faiss_index is None:
            return []

        # Encode query
        query_embedding = self.embedder.encode([query])
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < len(self.document_chunks) and idx != -1:
                chunk = self.document_chunks[idx].copy()
                chunk['relevance_score'] = float(score)
                results.append(chunk)

        return results

    def query_knowledge_graph(self, query):
        """Query knowledge graph for structured information"""
        query_lower = query.lower()

        # Patient history queries
        if any(word in query_lower for word in ['history', 'patient', 'records', 'timeline']):
            words = query.split()
            potential_names = [w for w in words if len(w) > 2 and w[0].isupper()]

            if potential_names:
                patient_name = ' '.join(potential_names[:2])
                return self.kg.query_patient_history(patient_name)

        # Medication pattern queries
        elif any(word in query_lower for word in ['prescription', 'medication', 'drug', 'medicine', 'common']):
            return self.kg.query_medication_patterns()

        # Lab test queries
        elif any(word in query_lower for word in ['lab', 'test', 'result', 'abnormal', 'blood']):
            return self.kg.query_lab_patterns()

        return {"message": "No specific KG query pattern matched"}

    def generate_answer(self, query):
        """Generate comprehensive answer using RAG approach"""
        print(f"🔍 Processing query: '{query}'")

        # Step 1: Retrieve relevant chunks
        relevant_chunks = self.retrieve_relevant_chunks(query, top_k=5)

        # Step 2: Query knowledge graph
        kg_results = self.query_knowledge_graph(query)

        # Step 3: Combine and format response
        response = {
            'query': query,
            'kg_results': kg_results,
            'relevant_chunks': relevant_chunks,
            'answer': self._format_answer(query, kg_results, relevant_chunks),
            'metadata': {
                'chunks_found': len(relevant_chunks),
                'kg_pattern_matched': kg_results.get('message') != "No specific KG query pattern matched",
                'avg_relevance_score': np.mean([c.get('relevance_score', 0) for c in relevant_chunks]) if relevant_chunks else 0
            }
        }

        return response

    def _format_answer(self, query, kg_results, chunks):
        """Format final answer from retrieved information"""
        answer_parts = []

        # Add Knowledge Graph insights
        if 'patient_name' in kg_results:
            patient_name = kg_results['patient_name']
            prescriptions = kg_results.get('prescriptions', [])
            lab_results = kg_results.get('lab_results', [])

            answer_parts.append(f"**Medical History for {patient_name}:**")

            if prescriptions:
                answer_parts.append(f"📋 Found {len(prescriptions)} prescription records:")
                for i, presc in enumerate(prescriptions[:3], 1):
                    medications = ', '.join([med['name'] for med in presc.get('medications', [])])
                    date_str = presc.get('date', 'Unknown date')
                    prescriber = presc.get('prescriber', 'Unknown prescriber')
                    answer_parts.append(f"  {i}. **Date:** {date_str} | **Prescriber:** {prescriber}")
                    answer_parts.append(f"     **Medications:** {medications}")

            if lab_results:
                answer_parts.append(f" Lab Results: {len(lab_results)} tests found")

        elif 'common_medications' in kg_results:
            meds = kg_results['common_medications']
            total = kg_results['total_prescriptions']

            answer_parts.append(f"**Medication Analysis from {total} prescriptions:**")
            answer_parts.append("**Most commonly prescribed medications:**")

            for i, (med, count) in enumerate(meds[:5], 1):
                percentage = (count / total * 100) if total > 0 else 0
                answer_parts.append(f"  {i}. **{med}**: {count} prescriptions ({percentage:.1f}%)")

        elif 'common_tests' in kg_results:
            tests = kg_results['common_tests']
            answer_parts.append(" **Most common lab tests:**")
            for i, (test, count) in enumerate(tests[:5], 1):
                answer_parts.append(f"  {i}. **{test}**: {count} occurrences")

        # Add relevant document excerpts
        if chunks:
            answer_parts.append(f"\n **Supporting Evidence from Documents:**")
            for i, chunk in enumerate(chunks[:3], 1):
                relevance = chunk.get('relevance_score', 0)
                doc_type = chunk.get('document_type', 'document').title()
                patient = chunk.get('patient_name', 'Unknown patient')
                engine = chunk.get('ocr_engine', 'Unknown')

                answer_parts.append(f"**{i}. {doc_type}** (Patient: {patient}, OCR: {engine}, Relevance: {relevance:.3f})")
                text_preview = chunk['text'][:200].replace('\n', ' ')
                answer_parts.append(f"   _{text_preview}..._")

        # Add metadata about search quality
        if chunks:
            avg_relevance = np.mean([c.get('relevance_score', 0) for c in chunks])
            answer_parts.append(f"\n**Search Quality:** Average relevance score: {avg_relevance:.3f}")

        return '\n'.join(answer_parts) if answer_parts else "No relevant information found in the processed documents."

# Add query_patient_history and query_lab_patterns methods to KnowledgeGraph
def add_missing_kg_methods():
    """Add missing methods to the knowledge graph"""

    def query_patient_history(self, patient_name):
        """Query patient's medical history"""
        patient_key = f"patient_{patient_name.lower()}"
        patient_ids = self.entity_index.get(patient_key, [])

        if not patient_ids:
            return {"error": f"Patient {patient_name} not found"}

        patient_id = patient_ids[0]
        history = {
            'patient_name': patient_name,
            'prescriptions': [],
            'lab_results': [],
            'diagnoses': []
        }

        # Get prescriptions
        for neighbor in self.graph.successors(patient_id):
            node_data = self.graph.nodes[neighbor]
            if node_data.get('type') == 'prescription':
                medications = []
                for med_neighbor in self.graph.successors(neighbor):
                    med_data = self.graph.nodes[med_neighbor]
                    if med_data.get('type') == 'medication':
                        edge_data = self.graph.get_edge_data(neighbor, med_neighbor)
                        medications.append({
                            'name': med_data.get('name'),
                            'dosage': list(edge_data.values())[0].get('dosage')
                        })

                history['prescriptions'].append({
                    'date': node_data.get('date'),
                    'prescriber': node_data.get('prescriber'),
                    'medications': medications
                })

            elif node_data.get('type') == 'encounter':
                # Get lab results for this encounter
                for result_neighbor in self.graph.successors(neighbor):
                    result_data = self.graph.nodes[result_neighbor]
                    if result_data.get('type') == 'lab_result':
                        history['lab_results'].append({
                            'value': result_data.get('value'),
                            'unit': result_data.get('unit'),
                            'date': result_data.get('date')
                        })

        return history

    def query_lab_patterns(self):
        """Query common lab test patterns"""
        patterns = defaultdict(int)
        test_nodes = [n for n, d in self.graph.nodes(data=True) if d.get('type') == 'lab_test']

        for test_id in test_nodes:
            test_data = self.graph.nodes[test_id]
            test_name = test_data.get('name')
            if test_name:
                patterns[test_name] += 1

        sorted_patterns = sorted(patterns.items(), key=lambda x: x[1], reverse=True)
        return {
            'common_tests': sorted_patterns[:10],
            'total_tests': len(test_nodes)
        }

    # Add methods to the class
    MedicalKnowledgeGraph.query_patient_history = query_patient_history
    MedicalKnowledgeGraph.query_lab_patterns = query_lab_patterns

# Apply the missing methods
add_missing_kg_methods()

# Initialize RAG system if documents were processed successfully
if 'processed_docs' in locals() and processed_docs:
    print(" Initializing RAG System...")
    rag_system = MedicalRAGSystem(kg)
    rag_system.index_documents(processed_docs)
    print("RAG System ready for queries!")
else:
    print("No processed documents found. Run the processing pipeline first.")
    rag_system = None


 Initializing RAG System...
 Loading sentence transformer model...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

RAG embedder initialized
Creating document chunks and embeddings...
Creating embeddings for 583 chunks...


Batches:   0%|          | 0/19 [00:00<?, ?it/s]

 FAISS index created with 583 vectors
RAG System ready for queries!


In [9]:
# Complete Interactive Demo System
def run_comprehensive_medical_rag_demo():
    """Run comprehensive demonstration of the medical RAG system"""

    if not processed_docs:
        print(" No processed documents found! Run the processing pipeline first.")
        return

    print("COMPREHENSIVE MEDICAL RAG SYSTEM DEMONSTRATION")
    print("=" * 60)

    # System Statistics
    print(f"\n **SYSTEM OVERVIEW:**")
    print(f"    Total documents processed: {len(processed_docs)}")
    print(f"    Lab reports: {sum(1 for doc in processed_docs if doc['document_type'] == 'lab_report')}")
    print(f"   Prescriptions: {sum(1 for doc in processed_docs if doc['document_type'] == 'prescription')}")
    print(f"    Knowledge graph nodes: {kg.graph.number_of_nodes()}")
    print(f"    Knowledge graph edges: {kg.graph.number_of_edges()}")
    print(f"    Vector index size: {len(rag_system.document_chunks) if rag_system else 0}")

    # OCR Engine Statistics
    engine_stats = defaultdict(int)
    confidence_scores = []
    for doc in processed_docs:
        engine_stats[doc.get('ocr_engine', 'Unknown')] += 1
        if doc.get('confidence_score'):
            confidence_scores.append(doc['confidence_score'])

    print(f"\n **OCR ENGINE PERFORMANCE:**")
    for engine, count in engine_stats.items():
        percentage = (count / len(processed_docs)) * 100
        print(f"   {engine}: {count} documents ({percentage:.1f}%)")

    if confidence_scores:
        avg_confidence = sum(confidence_scores) / len(confidence_scores)
        print(f"   📈 Average OCR confidence: {avg_confidence:.3f}")

    # Demo Queries
    demo_queries = [
        "What are the most commonly prescribed medications?",
        "Show me lab test results and abnormal values",
        "List all patients found in the documents",
        "Find medications with dosage information",
        "Show me diabetes related prescriptions"
    ]

    print(f"\n **RUNNING DEMO QUERIES:**")

    for i, query in enumerate(demo_queries, 1):
        print(f"\n{'='*50}")
        print(f" **Demo Query {i}:** {query}")
        print("-" * 30)

        try:
            response = rag_system.generate_answer(query)
            print(response['answer'])

            # Show metadata
            metadata = response.get('metadata', {})
            print(f"\n **Query Metadata:**")
            print(f"   Chunks found: {metadata.get('chunks_found', 0)}")
            print(f"   KG pattern matched: {metadata.get('kg_pattern_matched', False)}")
            print(f"   Avg relevance: {metadata.get('avg_relevance_score', 0):.3f}")

        except Exception as e:
            print(f" **Error:** {str(e)}")

    # Show Knowledge Graph Insights
    print(f"\n **KNOWLEDGE GRAPH INSIGHTS:**")
    try:
        med_patterns = kg.query_medication_patterns()
        if med_patterns.get('common_medications'):
            print(f" **Top 5 Medications Found:**")
            for i, (med, count) in enumerate(med_patterns['common_medications'][:5], 1):
                print(f"   {i}. {med} ({count} prescriptions)")

        lab_patterns = kg.query_lab_patterns()
        if lab_patterns.get('common_tests'):
            print(f"\n**Top 5 Lab Tests Found:**")
            for i, (test, count) in enumerate(lab_patterns['common_tests'][:5], 1):
                print(f"   {i}. {test} ({count} occurrences)")
    except Exception as e:
        print(f"Error getting insights: {e}")

    # Show Sample Extracted Data
    print(f"\n **SAMPLE EXTRACTED DATA:**")

    # Show sample patients
    patient_names = set()
    for doc in processed_docs:
        if doc.get('patient_name') and doc['patient_name'] not in ['Unknown', None]:
            patient_names.add(doc['patient_name'])

    if patient_names:
        print(f"👥 **Unique Patients Found:** {len(patient_names)}")
        for name in list(patient_names)[:5]:
            print(f"   - {name}")

    # Show sample medications
    all_medications = set()
    for doc in processed_docs:
        if doc['document_type'] == 'prescription':
            for med in doc.get('medications', []):
                if med.get('name'):
                    all_medications.add(med['name'])

    if all_medications:
        print(f"\n**Sample Medications Found:** {len(all_medications)} unique")
        for med in list(all_medications)[:5]:
            print(f"   - {med}")

    # Show sample lab tests
    all_tests = set()
    for doc in processed_docs:
        if doc['document_type'] == 'lab_report':
            for test in doc.get('lab_tests', []):
                if test.get('test_name'):
                    all_tests.add(test['test_name'])

    if all_tests:
        print(f"\n **Sample Lab Tests Found:** {len(all_tests)} unique")
        for test in list(all_tests)[:5]:
            print(f"   - {test}")

    print(f"\n **DEMO COMPLETE!**")

    return True

# Interactive Query Interface
def run_interactive_medical_queries():
    """Interactive query interface for the medical RAG system"""

    if not rag_system:
        print(" RAG system not initialized! Run the processing pipeline first.")
        return

    print("\n💬 **INTERACTIVE MEDICAL RAG QUERY SYSTEM**")
    print("=" * 50)

    print("📋 **Sample queries you can try:**")
    sample_queries = [
        "What medications were prescribed for diabetes?",
        "Show me abnormal lab test results",
        "List all patients with their medical history",
        "Find prescriptions with specific dosages",
        "Show me blood sugar test results",
        "What are the common diagnoses found?",
        "Find patients prescribed Metformin",
        "Show me recent lab reports"
    ]

    for i, query in enumerate(sample_queries, 1):
        print(f"   {i}. {query}")

    print(f"\n{'='*50}")
    print(" **Tips:**")
    print("   - Ask about specific patients by name")
    print("   - Search for medication patterns and trends")
    print("   - Query lab test results and abnormal values")
    print("   - Look for relationships between tests and prescriptions")
    print("   - Type 'stats' to see system statistics")
    print("   - Type 'quit' to exit")
    print(f"{'='*50}")

    query_count = 0
    while True:
        try:
            print(f"\n **Enter your medical query:**")
            user_query = input("🔍 Query: ").strip()

            if user_query.lower() in ['quit', 'exit', 'q']:
                print(" Goodbye! Thanks for using the Medical RAG System!")
                break

            if not user_query:
                print(" Please enter a query")
                continue

            if user_query.lower() == 'stats':
                print(f"\n**SYSTEM STATISTICS:**")
                print(f"   Documents processed: {len(processed_docs)}")
                print(f"   Knowledge graph nodes: {kg.graph.number_of_nodes()}")
                print(f"   Vector chunks indexed: {len(rag_system.document_chunks)}")
                print(f"   Queries processed: {query_count}")
                continue

            query_count += 1
            print(f"\n **Processing query #{query_count}...**")
            print("-" * 40)

            # Generate answer
            response = rag_system.generate_answer(user_query)

            print(f"📋 **Answer:**")
            print(response['answer'])

            # Show additional information
            metadata = response.get('metadata', {})
            if metadata.get('chunks_found', 0) > 0:
                print(f"\n **Search Results:**")
                print(f"    Document chunks found: {metadata['chunks_found']}")
                print(f"    Average relevance score: {metadata.get('avg_relevance_score', 0):.3f}")
                print(f"    Knowledge graph used: {'Yes' if metadata.get('kg_pattern_matched') else 'No'}")

                # Show top relevant chunks
                chunks = response.get('relevant_chunks', [])
                if chunks:
                    print(f"\n📄 **Top Sources:**")
                    for i, chunk in enumerate(chunks[:2], 1):
                        doc_type = chunk.get('document_type', 'document').title()
                        patient = chunk.get('patient_name', 'Unknown')
                        score = chunk.get('relevance_score', 0)
                        print(f"   {i}. {doc_type} (Patient: {patient}, Score: {score:.3f})")

        except KeyboardInterrupt:
            print(f"\n\n Session interrupted. Goodbye!")
            break
        except Exception as e:
            print(f" **Error processing query:** {str(e)}")
            print("Please try a different query.")

# Export and Save Functions
def export_results():
    """Export processed results to various formats"""

    if not processed_docs:
        print(" No processed documents to export!")
        return

    print("💾 **EXPORTING RESULTS...**")

    try:
        # Export processed documents
        with open('/content/output/processed_documents_complete.json', 'w') as f:
            json.dump(processed_docs, f, indent=2, default=str)
        print(f" Processed documents: /content/output/processed_documents_complete.json")

        # Export knowledge graph data
        graph_data = kg.export_graph_data()
        with open('/content/output/knowledge_graph_data.json', 'w') as f:
            json.dump(graph_data, f, indent=2, default=str)
        print(f"Knowledge graph: /content/output/knowledge_graph_data.json")

        # Export summary statistics
        stats = {
            'processing_summary': {
                'total_documents': len(processed_docs),
                'lab_reports': sum(1 for doc in processed_docs if doc['document_type'] == 'lab_report'),
                'prescriptions': sum(1 for doc in processed_docs if doc['document_type'] == 'prescription'),
                'processing_date': datetime.now().isoformat()
            },
            'knowledge_graph_stats': {
                'total_nodes': kg.graph.number_of_nodes(),
                'total_edges': kg.graph.number_of_edges(),
            },
            'extraction_insights': {
                'unique_patients': len(set(doc.get('patient_name') for doc in processed_docs if doc.get('patient_name') and doc['patient_name'] != 'Unknown')),
                'total_medications': sum(len(doc.get('medications', [])) for doc in processed_docs if doc['document_type'] == 'prescription'),
                'total_lab_tests': sum(len(doc.get('lab_tests', [])) for doc in processed_docs if doc['document_type'] == 'lab_report')
            }
        }

        with open('/content/output/system_statistics.json', 'w') as f:
            json.dump(stats, f, indent=2, default=str)
        print(f"System statistics: /content/output/system_statistics.json")

        print(f"\n **All results exported to /content/output/ directory**")

    except Exception as e:
        print(f"Export error: {str(e)}")

# Missing export_graph_data method for KnowledgeGraph
def export_graph_data(self):
    """Export graph for visualization and analysis"""
    nodes = []
    edges = []

    for node_id, data in self.graph.nodes(data=True):
        nodes.append({
            'id': node_id,
            'type': data.get('type', 'unknown'),
            'name': data.get('name', node_id),
            **{k: v for k, v in data.items() if k not in ['id', 'type', 'name']}
        })

    for source, target, data in self.graph.edges(data=True):
        edges.append({
            'source': source,
            'target': target,
            'relation': data.get('relation', 'connected'),
            **{k: v for k, v in data.items() if k not in ['source', 'target', 'relation']}
        })

    return {'nodes': nodes, 'edges': edges}

# Add the method to the class
MedicalKnowledgeGraph.export_graph_data = export_graph_data

print(" **MEDICAL RAG SYSTEM COMPLETE!**")
print("\n **Ready to use! Try these commands:**")
print("   - run_comprehensive_medical_rag_demo()  # Run full demo")
print("   - run_interactive_medical_queries()     # Interactive chat")
print("   - export_results()                      # Save all results")


 **MEDICAL RAG SYSTEM COMPLETE!**

 **Ready to use! Try these commands:**
   - run_comprehensive_medical_rag_demo()  # Run full demo
   - run_interactive_medical_queries()     # Interactive chat
   - export_results()                      # Save all results


In [10]:
run_interactive_medical_queries()


💬 **INTERACTIVE MEDICAL RAG QUERY SYSTEM**
📋 **Sample queries you can try:**
   1. What medications were prescribed for diabetes?
   2. Show me abnormal lab test results
   3. List all patients with their medical history
   4. Find prescriptions with specific dosages
   5. Show me blood sugar test results
   6. What are the common diagnoses found?
   7. Find patients prescribed Metformin
   8. Show me recent lab reports

 **Tips:**
   - Ask about specific patients by name
   - Search for medication patterns and trends
   - Query lab test results and abnormal values
   - Look for relationships between tests and prescriptions
   - Type 'stats' to see system statistics
   - Type 'quit' to exit

 **Enter your medical query:**
🔍 Query:  What medications were prescribed for diabetes?

 **Processing query #1...**
----------------------------------------
🔍 Processing query: 'What medications were prescribed for diabetes?'
📋 **Answer:**
**Medication Analysis from 122 prescriptions:**
**Most co