In [1]:


import os
import re
import time
import pickle
import json
import unicodedata
import networkx as nx
from datetime import datetime
from typing import List, Dict, Optional
from dotenv import load_dotenv

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import TokenTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import chromadb
from chromadb.utils import embedding_functions
from pyvis.network import Network
from pydantic import BaseModel, Field



In [2]:
# === CELL 2: LOAD CONFIGURATION ===
load_dotenv()

GROQ_API_KEYS = [
    os.getenv("GROQ_API_KEY_1"),
    os.getenv("GROQ_API_KEY_2"),
    os.getenv("GROQ_API_KEY_3"),
    os.getenv("GROQ_API_KEY_4"),
    os.getenv("GROQ_API_KEY_5"),
    os.getenv("GROQ_API_KEY_6"),
    os.getenv("GROQ_API_KEY_7"),
    os.getenv("GROQ_API_KEY_8"),
    os.getenv("GROQ_API_KEY_9"),
]
PDF_PATH = os.getenv("PDF_PATH")

# API Key Manager
class APIKeyManager:
    def __init__(self, api_keys: List[str]):
        self.api_keys = [k for k in api_keys if k and not k.startswith("gsk_YOUR")]
        self.current_index = 0
        self.failed_keys = set()
    
    def get_current_key(self) -> str:
        return self.api_keys[self.current_index]
    
    def rotate_key(self) -> bool:
        self.failed_keys.add(self.current_index)
        for i in range(len(self.api_keys)):
            next_index = (self.current_index + 1 + i) % len(self.api_keys)
            if next_index not in self.failed_keys:
                self.current_index = next_index
                print(f"Key #{next_index + 1}")
                return True
        return False
    
    def reset_failed(self):
        self.failed_keys.clear()

api_manager = APIKeyManager(GROQ_API_KEYS)
print(f" {len(api_manager.api_keys)} API keys loaded")

 9 API keys loaded


In [3]:
# === CELL 3: MEDICAL TEXT NORMALIZATION (FIXED - Preserve Abbreviations) ===

MEDICAL_ABBREVIATIONS = {
    "btm": "bệnh thận mạn", "tha": "tăng huyết áp", "đtđ": "đái tháo đường",
    "gfr": "GFR", "egfr": "eGFR", "ckd": "CKD", "acei": "ACEI", "arb": "ARB",
    "hba1c": "HbA1c", "ldl": "LDL", "hdl": "HDL"
}

MEDICAL_SYNONYMS = {
    "bệnh thận mạn": ["bệnh thận mãn", "suy thận mạn", "CKD"],
    "đái tháo đường": ["tiểu đường", "đtđ", "diabetes"],
    "tăng huyết áp": ["cao huyết áp", "tha"],
}

def normalize_medical_text(text: str) -> str:
    if not text: return "Unknown"
    text = unicodedata.normalize("NFC", text).strip().lower()
    text = re.sub(r'\\s+', ' ', text)
    words = [MEDICAL_ABBREVIATIONS.get(re.sub(r'[^\\w]', '', w), w) for w in text.split()]
    text = ' '.join(words)
    for canonical, variants in MEDICAL_SYNONYMS.items():
        for variant in variants:
            text = text.replace(variant.lower(), canonical)
    text = re.sub(r'\\s+', ' ', text).strip()
    if text: text = text[0].upper() + text[1:]
    return text

normalize_text = normalize_medical_text
print("Normalization ready (preserves abbreviations)")

Normalization ready (preserves abbreviations)


In [4]:
# === CELL 4: PYDANTIC SCHEMAS ===
class MedicalEntity(BaseModel):
    name: str
    type: str
    description: str = ""
    relevance_score: int = Field(default=5, ge=1, le=10)

class MedicalRelation(BaseModel):
    source_name: str
    target_name: str
    relation: str
    confidence_score: int = Field(default=5, ge=1, le=10)
    evidence: str = ""

class AMGExtractionResult(BaseModel):
    entities: List[MedicalEntity] = Field(default_factory=list)
    relations: List[MedicalRelation] = Field(default_factory=list)

print("Schemas ready")

Schemas ready


In [5]:
# === CELL 5: OPTIMIZED PROMPT (12 RELATION TYPES) ===
OPTIMIZED_EXTRACTION_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """Bạn là chuyên gia trích xuất Knowledge Graph y tế.

## ENTITY TYPES (9 loại):
DISEASE | DRUG | SYMPTOM | TEST | ANATOMY | TREATMENT | PROCEDURE | RISK_FACTOR | LAB_VALUE

## RELATION TYPES (12 loại - ƯU TIÊN THEO THỨ TỰ):

### TIER 1 - CORE (Dùng ĐẦU TIÊN nếu match):
1. **CAUSES** - A gây ra B (VD: Tiểu đường CAUSES Bệnh thận mạn)
2. **TREATS** - A điều trị B (VD: Insulin TREATS Tiểu đường)
3. **PREVENTS** - A phòng ngừa B (VD: ACEI PREVENTS Tiến triển bệnh thận)
4. **DIAGNOSES** - A chẩn đoán B (VD: eGFR DIAGNOSES Bệnh thận mạn)

### TIER 2 - CLINICAL (Dùng khi không match TIER 1):
5. **SYMPTOM_OF** - A là triệu chứng của B (VD: Phù SYMPTOM_OF Suy tim)
6. **COMPLICATION_OF** - A là biến chứng của B (VD: Bệnh võng mạc COMPLICATION_OF Tiểu đường)
7. **SIDE_EFFECT_OF** - A là tác dụng phụ của B (VD: Buồn nôn SIDE_EFFECT_OF Hóa trị)
8. **INCREASES_RISK** - A tăng nguy cơ B (VD: Hút thuốc INCREASES_RISK Ung thư)

### TIER 3 - SPECIALIZED (Dùng cho trường hợp cụ thể):
9. **INTERACTS_WITH** - A tương tác với B (VD: Warfarin INTERACTS_WITH Aspirin)
10. **WORSENS** - A làm nặng B (VD: Mất nước WORSENS Suy thận)
11. **INDICATES** - A chỉ định B (VD: Hb < 10 INDICATES Erythropoietin)

### TIER 4 - FALLBACK (CHỈ dùng khi KHÔNG match TIER 1-3):
12. **RELATED_TO** - A liên quan B (CHỈ khi confidence ≥ 7 VÀ không fit các loại trên)

## SCORING:
**Entities** (relevance_score):
- 9-10: CHÍNH (đề cập nhiều lần)
- 7-8: QUAN TRỌNG (liên quan trực tiếp)
- 5-6: PHỤ (đề cập nhưng không trọng tâm)

**Relations** (confidence_score):
- 9-10: RÕ RÀNG (nêu trực tiếp)
- 7-8: MẠNH (bằng chứng rõ)
- ≥6: MINIMUM (dưới 6 KHÔNG trích)

## FEW-SHOT EXAMPLES:

### Example 1:
Input: "Bệnh thận mạn giai đoạn 5 cần lọc máu. Chỉ định dùng Erythropoietin khi Hb < 10g/dL."

Entities:
- Bệnh thận mạn giai đoạn 5 (DISEASE, score=10, desc="Mức độ nặng nhất của bệnh thận mạn")
- Lọc máu (TREATMENT, score=9, desc="Phương pháp thay thế thận")
- Erythropoietin (DRUG, score=8, desc="Thuốc điều trị thiếu máu")
- Hb < 10g/dL (LAB_VALUE, score=7, desc="Ngưỡng chỉ định điều trị")

Relations:
- Lọc máu TREATS Bệnh thận mạn giai đoạn 5 (confidence=9, evidence="cần lọc máu")
- Hb < 10g/dL INDICATES Erythropoietin (confidence=8, evidence="Chỉ định dùng Erythropoietin khi Hb < 10g/dL")

### Example 2:
Input: "Tiểu đường type 2 là yếu tố nguy cơ hàng đầu gây bệnh thận mạn. Kiểm soát đường huyết bằng Metformin."

Entities:
- Đái tháo đường type 2 (DISEASE, score=9, desc="Nguyên nhân chính gây BTM")
- Bệnh thận mạn (DISEASE, score=9, desc="Biến chứng của tiểu đường")
- Metformin (DRUG, score=8, desc="Thuốc kiểm soát đường huyết")

Relations:
- Đái tháo đường type 2 CAUSES Bệnh thận mạn (confidence=10, evidence="yếu tố nguy cơ hàng đầu gây bệnh thận mạn")
- Metformin TREATS Đái tháo đường type 2 (confidence=8, evidence="Kiểm soát đường huyết bằng Metformin")

### Example 3:
Input: "Buồn nôn là tác dụng phụ thường gặp của hóa trị. Thuốc chống nôn giúp giảm triệu chứng."

Entities:
- Buồn nôn (SYMPTOM, score=8, desc="Triệu chứng phổ biến")
- Hóa trị (TREATMENT, score=9, desc="Điều trị ung thư")
- Thuốc chống nôn (DRUG, score=7, desc="Thuốc giảm buồn nôn")

Relations:
- Buồn nôn SIDE_EFFECT_OF Hóa trị (confidence=9, evidence="tác dụng phụ thường gặp của hóa trị")
- Thuốc chống nôn TREATS Buồn nôn (confidence=8, evidence="giúp giảm triệu chứng")

## RULES (QUY TẮC NGHIÊM NGẶT):
1. **Ưu tiên TIER 1-2 relations** - chọn type CỤ THỂ nhất
2. **RELATED_TO CHỈ dùng khi**: confidence ≥ 7 VÀ KHÔNG match TIER 1-3
3. **Mỗi cặp entity CHỈ 1 relation** - chọn type MẠNH nhất
4. **Evidence phải trích từ văn bản gốc**
5. **KHÔNG extract**: số trang, quyết định, văn bản, tên tác giả, tên người
6. **Confidence < 6**: KHÔNG trích xuất relation
7. **GIỮ NGUYÊN tiếng Việt** - không dịch sang tiếng Anh

Trích xuất ENTITIES và RELATIONS từ văn bản:"""),
    ("human", "{text}")
])

print(" Optimized Prompt ready (12 relation types with TIER priority)")

 Optimized Prompt ready (12 relation types with TIER priority)


In [6]:
# === CELL 6: VALIDATION & INVERSE RELATIONS (FIXED) ===

# Valid relation types -  BỔ SUNG INVERSE TYPES VÀO WHITELIST
VALID_RELATION_TYPES = {
    # Original 12 types
    "CAUSES", "TREATS", "PREVENTS", "DIAGNOSES",  # TIER 1
    "SYMPTOM_OF", "COMPLICATION_OF", "SIDE_EFFECT_OF", "INCREASES_RISK",  # TIER 2
    "INTERACTS_WITH", "WORSENS", "INDICATES",  # TIER 3
    "RELATED_TO",  # TIER 4
    
    #  FIX 1: Thêm inverse types vào whitelist
    "CAUSED_BY", "TREATED_BY", "PREVENTED_BY", "DIAGNOSED_BY",
    "HAS_SYMPTOM", "HAS_COMPLICATION", "HAS_SIDE_EFFECT", 
    "RISK_INCREASED_BY", "WORSENED_BY", "INDICATED_BY"
}

# Inverse relation mapping
INVERSE_RELATIONS = {
    "CAUSES": "CAUSED_BY",
    "TREATS": "TREATED_BY",
    "PREVENTS": "PREVENTED_BY",
    "DIAGNOSES": "DIAGNOSED_BY",
    "SYMPTOM_OF": "HAS_SYMPTOM",
    "COMPLICATION_OF": "HAS_COMPLICATION",
    "SIDE_EFFECT_OF": "HAS_SIDE_EFFECT",
    "INCREASES_RISK": "RISK_INCREASED_BY",
    "WORSENS": "WORSENED_BY",
    "INDICATES": "INDICATED_BY",
}

SYMMETRIC_RELATIONS = {"INTERACTS_WITH", "RELATED_TO"}

def validate_entity(entity: MedicalEntity) -> bool:
    """Validate if entity is valid"""
    if not entity.name or len(entity.name) < 2:
        return False
    admin_patterns = [r'quyết định', r'văn bản', r'bộ y tế', r'trang \d+', r'điều \d+', r'khoản \d+', r'mục \d+', r'phụ lục']
    for pattern in admin_patterns:
        if re.search(pattern, entity.name.lower()):
            return False
    valid_types = {'DISEASE', 'DRUG', 'SYMPTOM', 'TEST', 'ANATOMY', 'TREATMENT', 'PROCEDURE', 'RISK_FACTOR', 'LAB_VALUE'}
    if entity.type.upper() not in valid_types:
        return False
    return True

def validate_relation(relation: MedicalRelation) -> bool:
    """Validate if relation is valid"""
    if relation.confidence_score < 6:
        return False
    if relation.relation.upper() not in VALID_RELATION_TYPES:
        return False
    if len(relation.source_name) < 2 or len(relation.target_name) < 2:
        return False
    if normalize_text(relation.source_name) == normalize_text(relation.target_name):
        return False
    return True

def generate_inverse_relations(relations: List[MedicalRelation]) -> List[MedicalRelation]:
    """Generate inverse relations"""
    inverse_rels = []
    for rel in relations:
        rel_type = rel.relation.upper()
        if rel_type in SYMMETRIC_RELATIONS:
            continue
        if rel_type in INVERSE_RELATIONS:
            #   FIX 2: Đảm bảo confidence >= 6
            inverse_confidence = max(6, rel.confidence_score - 1)
            inverse_rel = MedicalRelation(
                source_name=rel.target_name,
                target_name=rel.source_name,
                relation=INVERSE_RELATIONS[rel_type],
                confidence_score=inverse_confidence,
                evidence=f"Inverse of: {rel.evidence}"
            )
            inverse_rels.append(inverse_rel)
    return inverse_rels

print("  Validation & Inverse Relations ready (FIXED)")

  Validation & Inverse Relations ready (FIXED)


In [7]:
# === CELL 7: IMPROVED AMG EXTRACTOR ===
class ImprovedAMGExtractor:
    def __init__(self, api_manager: APIKeyManager):
        self.api_manager = api_manager
        self.extraction_prompt = OPTIMIZED_EXTRACTION_PROMPT
        self._init_llm()
    
    def _init_llm(self):
        self.llm = ChatGroq(
            temperature=0.1,
            model="llama-3.3-70b-versatile",
            api_key=self.api_manager.get_current_key()
        )
        self.chain = self.extraction_prompt | self.llm.with_structured_output(AMGExtractionResult)
    
    def extract(self, text: str, max_retries=3) -> AMGExtractionResult:
        for attempt in range(max_retries):
            try:
                result = self.chain.invoke({"text": text})
                
                if result:
                    #  VALIDATION: Filter invalid entities
                    valid_entities = [e for e in result.entities if validate_entity(e)]
                    result.entities = valid_entities
                    
                    #  VALIDATION: Filter invalid relations (conf >= 6, valid types, no self-loop)
                    valid_relations = [r for r in result.relations if validate_relation(r)]
                    
                    #  INVERSE RELATIONS: Generate bidirectional relations
                    inverse_relations = generate_inverse_relations(valid_relations)
                    
                    # Combine original + inverse
                    result.relations = valid_relations + inverse_relations
                
                return result if result else AMGExtractionResult()
                
            except Exception as e:
                if "rate" in str(e).lower() or "limit" in str(e).lower():
                    if self.api_manager.rotate_key():
                        self._init_llm()
                    else:
                        print(" Chờ 60s...")
                        time.sleep(60)
                        self.api_manager.reset_failed()
                        self._init_llm()
                else:
                    time.sleep(1)
        
        return AMGExtractionResult()

print("Improved AMG Extractor ready")

Improved AMG Extractor ready


In [8]:
# === CELL 8: GRAPH HELPERS (FIXED - Upgrade + Dedup) ===

def add_entity_to_graph(G, entity: MedicalEntity, page_num: int, chunk_id: int):
    """ FIX 2: Upgrade from UNKNOWN to real type"""
    norm_name = normalize_text(entity.name)
    confidence = min(1.0, entity.relevance_score / 10.0)
    
    if not G.has_node(norm_name):
        G.add_node(norm_name, label=entity.name, type=entity.type.upper(), description=entity.description,
                   confidence=confidence, relevance_score=entity.relevance_score, pages=[page_num], chunks=[chunk_id])
    else:
        #  Upgrade from UNKNOWN
        if G.nodes[norm_name].get("type") == "UNKNOWN":
            G.nodes[norm_name]["type"] = entity.type.upper()
            G.nodes[norm_name]["label"] = entity.name
            G.nodes[norm_name]["description"] = entity.description
        old_conf = G.nodes[norm_name].get('confidence', 0)
        if confidence > old_conf:
            G.nodes[norm_name]['confidence'] = confidence
            if G.nodes[norm_name].get("type") != "UNKNOWN":
                G.nodes[norm_name]['description'] = entity.description
        if page_num not in G.nodes[norm_name]['pages']:
            G.nodes[norm_name]['pages'].append(page_num)
        if chunk_id not in G.nodes[norm_name]['chunks']:
            G.nodes[norm_name]['chunks'].append(chunk_id)

def edge_exists(G, src, tgt, rel, chunk_id):
    """ FIX 3: Check for duplicate edges"""
    if not G.has_edge(src, tgt): return False
    edge_dict = G.get_edge_data(src, tgt)
    if edge_dict:
        for key, data in edge_dict.items():
            if data.get("relation") == rel and data.get("chunk") == chunk_id:
                return True
    return False

def add_relation_to_graph(G, rel: MedicalRelation, page_num: int, chunk_id: int):
    """ FIX 3: Deduplication before adding edge"""
    src = normalize_text(rel.source_name)
    tgt = normalize_text(rel.target_name)
    rel_type = rel.relation.upper()
    if not G.has_node(src):
        G.add_node(src, label=rel.source_name, type="UNKNOWN", confidence=0.5, pages=[page_num], chunks=[chunk_id], description="")
    if not G.has_node(tgt):
        G.add_node(tgt, label=rel.target_name, type="UNKNOWN", confidence=0.5, pages=[page_num], chunks=[chunk_id], description="")
    # Check duplicate
    if not edge_exists(G, src, tgt, rel_type, chunk_id):
        G.add_edge(src, tgt, relation=rel_type, confidence=min(1.0, rel.confidence_score/10),
                   evidence=rel.evidence, page=page_num, chunk=chunk_id)

print(" Graph helpers ready (upgrade + dedup)")

 Graph helpers ready (upgrade + dedup)


In [9]:
# === CELL 9: CHECKPOINT MANAGER (MultiDiGraph) ===

class CheckpointManager:
    def __init__(self, checkpoint_dir="./amg_data"):
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.graph_path = os.path.join(checkpoint_dir, "graph_improved.pkl")
        self.meta_path = os.path.join(checkpoint_dir, "checkpoint_meta.json")
    
    def save(self, G: nx.MultiDiGraph, chunk_id: int, total_chunks: int):
        """Save graph and metadata"""
        with open(self.graph_path, "wb") as f:
            pickle.dump(G, f)
        meta = {"last_chunk_id": chunk_id, "total_chunks": total_chunks, "num_nodes": G.number_of_nodes(),
                "num_edges": G.number_of_edges(), "timestamp": datetime.now().isoformat()}
        with open(self.meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, ensure_ascii=False)
        print(f"    Checkpoint: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges (chunk {chunk_id})")
    
    def load(self) -> tuple[Optional[nx.MultiDiGraph], Optional[int]]:
        """Load graph and last chunk_id"""
        graph, last_chunk_id = None, None
        if os.path.exists(self.graph_path):
            with open(self.graph_path, "rb") as f:
                graph = pickle.load(f)
            print(f" Loaded: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        if os.path.exists(self.meta_path):
            with open(self.meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            last_chunk_id = meta.get("last_chunk_id")
            print(f" Last checkpoint: chunk {last_chunk_id} at {meta.get('timestamp')}")
        return graph, last_chunk_id

checkpoint_manager = CheckpointManager()
print("Checkpoint Manager ready (MultiDiGraph)")

Checkpoint Manager ready (MultiDiGraph)


In [10]:
# === CELL 10: LOAD OR CREATE GRAPH (MultiDiGraph) ===

G, last_chunk_id = checkpoint_manager.load()

if G is None:
    print(" No checkpoint. Creating new graph...")
    #  FIX 3: MultiDiGraph
    G = nx.MultiDiGraph()
    start_chunk = 0
else:
    start_chunk = (last_chunk_id + 1) if last_chunk_id is not None else 0
    print(f" Resuming from chunk {start_chunk}")

print(f"\n Current: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

 Loaded: 398 nodes, 390 edges
 Last checkpoint: chunk 519 at 2026-02-05T22:12:22.488447
 Resuming from chunk 520

 Current: 398 nodes, 390 edges


In [11]:
import sys
!{sys.executable} -m pip install -U tiktoken




In [12]:
# === CELL 11: LOAD PDF AND SPLIT INTO CHUNKS ===
print("Loading PDF...")
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()

# Filter pages (skip first 16 pages as before)
filtered_docs = documents[16:]
print(f" Filtered to {len(filtered_docs)} pages")

# Split into chunks
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=50)
chunks = text_splitter.split_documents(filtered_docs)
print(f" Total chunks: {len(chunks)}")

# Determine which chunks to process
chunks_to_process = chunks[start_chunk:]
print(f"  Will process: chunk {start_chunk} → {len(chunks)-1}")
print(f"   ({len(chunks_to_process)} chunks remaining)\n")

Loading PDF...
 Filtered to 176 pages
 Total chunks: 700
  Will process: chunk 520 → 699
   (180 chunks remaining)



In [13]:
# === CELL 12: MAIN EXTRACTION LOOP (FIXED) ===

extractor = ImprovedAMGExtractor(api_manager)

print(" EXTRACTION WITH IMPROVEMENTS...")
print("    12 Relation Types (TIER Priority)")
print("  Validation (conf >= 6, valid types, no self-loop)")
print("   Inverse Relations (bidirectional)")
print("   Smart Checkpoint (chunk_id tracking)")
print("   MultiDiGraph (multi-relations supported)\n")

total_entities_added = 0
total_relations_added = 0
last_processed_chunk = start_chunk - 1  #  FIX 5: Track correctly

for i, chunk in enumerate(chunks_to_process, start=start_chunk):
    chunk_text = chunk.page_content
    page_num = chunk.metadata.get('page', 0)
    result = extractor.extract(chunk_text)
    
    if not result or (not result.entities and not result.relations):
        print(f"   Chunk {i+1}/{len(chunks)}: --")
        continue
    
    for entity in result.entities:
        add_entity_to_graph(G, entity, page_num, i)
    for relation in result.relations:
        add_relation_to_graph(G, relation, page_num, i)
    
    num_entities = len(result.entities)
    num_relations = len(result.relations)
    total_entities_added += num_entities
    total_relations_added += num_relations
    avg_score = sum(e.relevance_score for e in result.entities) / num_entities if num_entities > 0 else 0
    print(f"   Chunk {i+1}/{len(chunks)}: +{num_entities} entities, +{num_relations} relations (avg: {avg_score:.1f})")
    
    last_processed_chunk = i  
    if (i + 1) % 20 == 0:
        checkpoint_manager.save(G, i, len(chunks))

checkpoint_manager.save(G, last_processed_chunk, len(chunks))

print(f"\n{'='*60}")
print(" EXTRACTION COMPLETE!")
print(f"{'='*60}")
print(f" Final Stats:")
print(f"   - Nodes: {G.number_of_nodes()}")
print(f"   - Edges: {G.number_of_edges()}")
print(f"   - Entities Added: {total_entities_added}")
print(f"   - Relations Added: {total_relations_added}")
print(f"   - Chunks Processed: {last_processed_chunk - start_chunk + 1}")
print(f"   - Last Chunk ID: {last_processed_chunk}")

 EXTRACTION WITH IMPROVEMENTS...
    12 Relation Types (TIER Priority)
  Validation (conf >= 6, valid types, no self-loop)
   Inverse Relations (bidirectional)
   Smart Checkpoint (chunk_id tracking)
   MultiDiGraph (multi-relations supported)

   Chunk 521/700: +4 entities, +2 relations (avg: 7.8)
   Chunk 522/700: +4 entities, +3 relations (avg: 8.5)
   Chunk 523/700: +3 entities, +4 relations (avg: 8.0)
   Chunk 524/700: +3 entities, +4 relations (avg: 8.3)
   Chunk 525/700: +3 entities, +2 relations (avg: 8.0)
   Chunk 526/700: +8 entities, +6 relations (avg: 7.8)
   Chunk 527/700: +4 entities, +4 relations (avg: 8.5)
   Chunk 528/700: +4 entities, +4 relations (avg: 8.2)
   Chunk 529/700: +6 entities, +2 relations (avg: 7.8)
   Chunk 530/700: +8 entities, +10 relations (avg: 7.4)
   Chunk 531/700: +15 entities, +8 relations (avg: 7.3)
   Chunk 532/700: +3 entities, +2 relations (avg: 8.7)
   Chunk 533/700: +6 entities, +6 relations (avg: 7.0)
   Chunk 534/700: +8 entities, +6 rela

In [14]:
print("\n GRAPH STATISTICS\n")
print(f"Nodes: {G.number_of_nodes()}")
print(f"Edges: {G.number_of_edges()}")

from collections import Counter
entity_types = Counter([G.nodes[n]['type'] for n in G.nodes()])
print("\nEntity Types:")
for etype, count in entity_types.most_common():
    print(f"   {etype}: {count}")

# MultiDiGraph: use keys=True
relation_types = Counter([data['relation'] for u, v, key, data in G.edges(keys=True, data=True)])
print("\nRelation Types:")
for rtype, count in relation_types.most_common():
    inverse_marker = " (inverse)" if rtype.endswith("_BY") else ""
    print(f"   {rtype}{inverse_marker}: {count}")

degrees = dict(G.degree())
top_entities = sorted(degrees.items(), key=lambda x: x[1], reverse=True)[:10]
print("\nTop 10 Most Connected:")
for entity, degree in top_entities:
    etype = G.nodes[entity]['type']
    print(f"   {entity} ({etype}): {degree} connections")   


 GRAPH STATISTICS

Nodes: 999
Edges: 1411

Entity Types:
   DISEASE: 311
   DRUG: 166
   LAB_VALUE: 130
   SYMPTOM: 108
   TREATMENT: 94
   RISK_FACTOR: 58
   ANATOMY: 46
   TEST: 42
   UNKNOWN: 31
   PROCEDURE: 13

Relation Types:
   TREATED_BY (inverse): 209
   TREATS: 209
   CAUSED_BY (inverse): 136
   CAUSES: 136
   RELATED_TO: 123
   COMPLICATION_OF: 77
   HAS_COMPLICATION: 77
   RISK_INCREASED_BY (inverse): 68
   INCREASES_RISK: 68
   SYMPTOM_OF: 49
   HAS_SYMPTOM: 49
   DIAGNOSES: 41
   DIAGNOSED_BY (inverse): 41
   PREVENTS: 25
   PREVENTED_BY (inverse): 25
   HAS_SIDE_EFFECT: 16
   SIDE_EFFECT_OF: 16
   INDICATED_BY (inverse): 15
   INDICATES: 15
   INTERACTS_WITH: 10
   WORSENS: 3
   WORSENED_BY (inverse): 3

Top 10 Most Connected:
   Bệnh thận mạn (DISEASE): 149 connections
   Bệnh cầu thận màng (DISEASE): 67 connections
   Suy thận (DISEASE): 55 connections
   Bệnh thận (DISEASE): 54 connections
   Nhiễm trùng (DISEASE): 48 connections
   Glucocorticoids (DRUG): 46 connect

In [15]:
# === CELL 14: SAMPLE EXPLORATION (MultiDiGraph API) ===

sample_entity = "Bệnh Thận Mạn"

if sample_entity in G:
    print(f"\\n Exploring: {sample_entity}\\n")
    node_data = G.nodes[sample_entity]
    print(f"Type: {node_data['type']}")
    print(f"Description: {node_data.get('description', '')[:200]}...")
    print(f"Confidence: {node_data['confidence']:.2f}")
    print(f"Pages: {node_data['pages']}")
    
    #  FIX: MultiDiGraph API for outgoing edges
    print("\\n Outgoing Relations:")
    for u, v, key, data in G.out_edges(sample_entity, keys=True, data=True):
        print(f"   → {data['relation']} → {v} (conf: {data.get('confidence', 0):.2f})")
        evidence = data.get('evidence', '')[:100]
        if evidence:
            print(f"      Evidence: {evidence}...")
    
    # FIX: MultiDiGraph API for incoming edges
    print("\\n Incoming Relations:")
    for u, v, key, data in G.in_edges(sample_entity, keys=True, data=True):
        print(f"   {u} → {data['relation']} → (conf: {data.get('confidence', 0):.2f})")
else:
    print(f" '{sample_entity}' not found. Try: {list(G.nodes())[:10]}")

 'Bệnh Thận Mạn' not found. Try: ['Tỉ số albumin / creatinine (acr)', 'Tỉ số protein / creatinine (pcr)', 'Creatinine', 'Bun (nitơ trong ure máu)', 'Phosphate', 'Calci', 'Urate', 'Cá thể hóa', 'Bác sỹ', 'Nghiên cứu']


In [16]:
# === CELL 16: GRAPH REASONING (MultiDiGraph Compatible) ===

def get_connected_nodes(G, node_name: str, confidence_threshold: float = 0.5):
    """FIX: MultiDiGraph compatible"""
    connected = []
    if node_name in G:
        for _, neighbor, key, data in G.out_edges(node_name, keys=True, data=True):
            conf = data.get("confidence", 0)
            if conf >= confidence_threshold:
                connected.append({"node": neighbor, "relation": data.get("relation"),
                                  "confidence": conf, "evidence": data.get("evidence", ""), "key": key})
    return connected

def explore_path(G, start_node: str, max_depth: int = 3, confidence_threshold: float = 0.5):
    """ FIX: MultiDiGraph compatible"""
    paths, visited = [], set()
    def dfs(node, path, accumulated_confidence, depth):
        if depth > max_depth or node in visited: return
        visited.add(node)
        if len(path) > 0:
            paths.append({'path': path.copy(), 'confidence': accumulated_confidence, 'final_node': node})
        for neighbor_data in get_connected_nodes(G, node, confidence_threshold):
            neighbor = neighbor_data['node']
            new_confidence = accumulated_confidence * neighbor_data['confidence']
            if new_confidence >= confidence_threshold:
                new_path = path + [(node, neighbor, neighbor_data['relation'])]
                dfs(neighbor, new_path, new_confidence, depth + 1)
        visited.remove(node)
    dfs(start_node, [], 1.0, 0)
    return paths

def reason_about_entity(G, entity_name: str, context_depth: int = 2):
    """ FIX: MultiDiGraph compatible"""
    if entity_name not in G: return f"Entity '{entity_name}' not found"
    node_data = G.nodes[entity_name]
    context = f"## {entity_name}\\nType: {node_data.get('type')}\\nConfidence: {node_data.get('confidence', 0):.2f}\\n\\n"
    context += "### Direct Relations:\\n"
    for conn in get_connected_nodes(G, entity_name, 0.3)[:5]:
        context += f"- {conn['relation']} → {conn['node']} (conf: {conn['confidence']:.2f})\\n"
    if context_depth > 1:
        context += "\\n### Reasoning Paths:\\n"
        paths = explore_path(G, entity_name, context_depth, 0.3)
        top_paths = sorted(paths, key=lambda x: x['confidence'], reverse=True)[:3]
        for p in top_paths:
            path_str = " → ".join([f"{step[0]} [{step[2]}]" for step in p['path']])
            if path_str: context += f"- {path_str} → {p['final_node']} (conf: {p['confidence']:.2f})\\n"
    return context

def find_related_entities(G, entity_name: str, top_k: int = 10, min_confidence: float = 0.5):
    """FIX: MultiDiGraph compatible"""
    if entity_name not in G: return []
    entity_scores = {}
    for conn in get_connected_nodes(G, entity_name, min_confidence):
        neighbor = conn['node']
        entity_scores[neighbor] = entity_scores.get(neighbor, 0) + conn['confidence']
    for conn1 in get_connected_nodes(G, entity_name, 0.3):
        for conn2 in get_connected_nodes(G, conn1['node'], 0.3):
            second_hop = conn2['node']
            if second_hop != entity_name and second_hop not in entity_scores:
                propagated = conn1['confidence'] * conn2['confidence'] * 0.5
                if propagated >= min_confidence:
                    entity_scores[second_hop] = propagated
    sorted_entities = sorted(entity_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
    return [{'entity': e, 'type': G.nodes[e].get('type'), 'score': s} for e, s in sorted_entities]

print("Graph Reasoning (MultiDiGraph compatible)")

Graph Reasoning (MultiDiGraph compatible)


In [17]:
# === CELL 17: EXAMPLE - REASONING DEMO ===

# Demo: Reason about a specific entity
demo_entity = "Bệnh Thận Mạn"  # Change this

if demo_entity in G:
    print(" REASONING ABOUT ENTITY\n")
    print(reason_about_entity(G, demo_entity, context_depth=2))
    
    print("\n" + "="*60)
    print(" RELATED ENTITIES\n")
    related = find_related_entities(G, demo_entity, top_k=10, min_confidence=0.3)
    for r in related:
        print(f"   {r['entity']} ({r['type']}): {r['score']:.3f}")
    
    print("\n" + "="*60)
    print(" EXPLORING PATHS\n")
    paths = explore_path(G, demo_entity, max_depth=2, confidence_threshold=0.3)
    top_5_paths = sorted(paths, key=lambda x: x['confidence'], reverse=True)[:5]
    for i, path_data in enumerate(top_5_paths, 1):
        path_str = " → ".join([f"{p[0]} [{p[2]}]" for p in path_data['path']])
        if path_str:
            print(f"  {i}. {path_str} → {path_data['final_node']}")
            print(f"     Confidence: {path_data['confidence']:.3f}\n")
else:
    print(f"'{demo_entity}' not in graph. Try: {list(G.nodes())[:5]}")

'Bệnh Thận Mạn' not in graph. Try: ['Tỉ số albumin / creatinine (acr)', 'Tỉ số protein / creatinine (pcr)', 'Creatinine', 'Bun (nitơ trong ure máu)', 'Phosphate']


In [18]:
# === CELL 15: EXPORT GRAPH ===
# Export to various formats

# 1. Export to JSON
import json
from networkx.readwrite import json_graph

graph_json = json_graph.node_link_data(G)
with open("./amg_data/graph_improved.json", "w", encoding="utf-8") as f:
    json.dump(graph_json, f, indent=2, ensure_ascii=False)
print(" Exported to: ./amg_data/graph_improved.json")

# 2. Export to GraphML (for Gephi, Cytoscape)
nx.write_graphml(G, "./amg_data/graph_improved.graphml")
print(" Exported to: ./amg_data/graph_improved.graphml")

# 3. Export to CSV (edges)
import pandas as pd
edges_data = []
for u, v, data in G.edges(data=True):
    edges_data.append({
        "source": u,
        "target": v,
        "relation": data.get('relation', ''),
        "confidence": data.get('confidence', 0),
        "evidence": data.get('evidence', '')[:100]
    })
df_edges = pd.DataFrame(edges_data)
df_edges.to_csv("./amg_data/graph_edges.csv", index=False, encoding="utf-8-sig")
print(" Exported to: ./amg_data/graph_edges.csv")

print("\nAll exports complete!")

 Exported to: ./amg_data/graph_improved.json


NetworkXError: GraphML writer does not support <class 'list'> as data values.

In [21]:
# === CELL 18: VISUALIZATION WITH PYVIS (COLORIZED) ===
from pyvis.network import Network
import IPython.display

def visualize_graph(G, output_file="amg_rag_graph.html"):
    print(f"Visualizing {G.number_of_nodes()} nodes and {G.number_of_edges()} edges...")

    net = Network(
        height="750px",
        width="100%",
        bgcolor="#222222",
        font_color="white",
        select_menu=True,
        filter_menu=True
    )
    net.force_atlas_2based()

    color_map = {
        "DISEASE": "#FF6B6B",
        "DRUG": "#4ECDC4",
        "SYMPTOM": "#FFE66D",
        "TEST": "#1A535C",
        "ANATOMY": "#FF9F1C",
        "TREATMENT": "#2B2D42",
        "PROCEDURE": "#8D99AE",
        "RISK_FACTOR": "#EF233C",
        "LAB_VALUE": "#219EBC",
        "UNKNOWN": "#cccccc"
    }

    # Add nodes
    for node, data in G.nodes(data=True):
        node_type = data.get("type", "UNKNOWN")
        color = color_map.get(node_type, "#999999")

        degree = G.degree(node)
        size = 15 + (degree * 2)

        title_html = f"<b>{node}</b><br>Type: {node_type}<br>Connections: {degree}"
        if data.get("description"):
            title_html += f"<br><i>{data['description'][:100]}...</i>"

        net.add_node(
            node,
            label=str(node),
            title=title_html,
            color=color,
            value=size,
            group=node_type
        )

    # Add edges
    for u, v, data in G.edges(data=True):
        rel_type = data.get("relation", "RELATED_TO")
        confidence = float(data.get("confidence", 0.5))
        width = max(1.0, confidence * 3)

        net.add_edge(
            u, v,
            title=f"{rel_type} (conf: {confidence:.2f})",
            label=rel_type,
            width=width,
            color="#aaaaaa"
        )

    # Save (UTF-8) + display
    net.show_buttons(filter_=['physics'])

    html = net.generate_html()
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(html)

    print(f"✅ Graph saved to {output_file}")
    return output_file

html_file = visualize_graph(G, "amg_rag_graph.html")

# Display in notebook
IPython.display.HTML(open(html_file, "r", encoding="utf-8").read())


Visualizing 999 nodes and 1411 edges...
✅ Graph saved to amg_rag_graph.html


In [22]:
# Top 20 nodes theo degree (tổng bậc)
top_nodes = sorted(G.degree, key=lambda x: x[1], reverse=True)[:20]

print("Top 20 nodes by degree:")
for i, (node, deg) in enumerate(top_nodes, 1):
    ntype = G.nodes[node].get("type", "UNKNOWN")
    print(f"{i:02d}. {node} | degree={deg} | type={ntype}")


Top 20 nodes by degree:
01. Bệnh thận mạn | degree=149 | type=DISEASE
02. Bệnh cầu thận màng | degree=67 | type=DISEASE
03. Suy thận | degree=55 | type=DISEASE
04. Bệnh thận | degree=54 | type=DISEASE
05. Nhiễm trùng | degree=48 | type=DISEASE
06. Glucocorticoids | degree=46 | type=DRUG
07. Adpkd | degree=45 | type=DISEASE
08. Hcth | degree=40 | type=DISEASE
09. Bệnh thận màng | degree=38 | type=DISEASE
10. Tăng huyết áp | degree=38 | type=DISEASE
11. Protein niệu | degree=36 | type=LAB_VALUE
12. Đái tháo đường | degree=31 | type=DISEASE
13. Chronic kidney disease | degree=30 | type=DISEASE
14. Bệnh thận iga | degree=28 | type=DISEASE
15. Suy thận giai đoạn cuối | degree=26 | type=DISEASE
16. Xơ cầu thận ổ cục bộ | degree=26 | type=DISEASE
17. Đái máu | degree=22 | type=SYMPTOM
18. Vct | degree=21 | type=DISEASE
19. Aav | degree=21 | type=DISEASE
20. Tổn thương thận cấp | degree=20 | type=DISEASE


In [23]:
import random

def short_text(x, n=180):
    if x is None:
        return ""
    if isinstance(x, (list, tuple)):
        x = " | ".join(map(str, x))
    if isinstance(x, dict):
        x = json.dumps(x, ensure_ascii=False)
    x = str(x).replace("\n", " ").strip()
    return x[:n] + ("..." if len(x) > n else "")

edges = list(G.edges(data=True))
print("Total edges:", len(edges))

random.seed(42)  # để lần nào sample cũng giống nhau (bỏ nếu không cần)
sample = random.sample(edges, k=min(20, len(edges)))

for i, (u, v, data) in enumerate(sample, 1):
    rel = data.get("relation", "RELATED_TO")
    conf = data.get("confidence", None)
    ev = data.get("evidence", data.get("context", data.get("chunk", None)))  # thử vài key phổ biến

    conf_str = f"{conf:.2f}" if isinstance(conf, (int, float)) else "N/A"
    print(f"\n{i:02d}) {u} --[{rel} | conf={conf_str}]--> {v}")
    print("    evidence:", short_text(ev, 220) if ev else "(no evidence)")


Total edges: 1411

01) Nhiễm trùng nang thận --[TREATED_BY | conf=0.60]--> Tmp/smx
    evidence: Inverse of: có thể được lựa chọn để điều trị

02) Tắc nghẽn đường tiết niệu --[CAUSES | conf=0.80]--> Suy thận cấp
    evidence: Có bệnh lý tắc nghẽn đường tiết niệu hoặc có thay đổi cấu trúc đường tiết niệu

03) Bệnh thận mạn --[HAS_SYMPTOM | conf=0.60]--> Đái máu vi thể
    evidence: Inverse of: đái máu vi thể

04) Tổn thương trực tiếp ống thận --[CAUSED_BY | conf=0.70]--> Tmp/smx
    evidence: Inverse of: Tổn thương trực tiếp ống thận

05) Viêm phúc mạc --[HAS_SYMPTOM | conf=0.70]--> Lmb
    evidence: Inverse of: quả LMB cũng được đánh giá dựa trên tỷ lệ viêm phúc mạc

06) Hbv-dna --[CAUSES | conf=0.90]--> Viêm hoại tử
    evidence: tiêu bản sinh thiết gan cho thấy tổn thương mô học là viêm hoại tử và xơ hóa đáng kể

07) Tacrolimus --[TREATS | conf=0.70]--> Bệnh thận tiến triển
    evidence: cần theo dõi nồng độ tacrolimus

08) Tăng huyết áp --[HAS_COMPLICATION | conf=0.70]--> Btm
    ev