In [None]:
# === APPLY IMPROVEMENTS TO THE EXISTING GRAPH (479 nodes) ===
import os
import pickle
import random
import re
import unicodedata

import chromadb
import networkx as nx
from chromadb.utils import embedding_functions
from pyvis.network import Network

print("üöÄ LOADING EXISTING GRAPH...\n")

In [None]:
# === HELPER FUNCTIONS ===
# Abbreviation (lower-case, punctuation-free) -> expanded term.
# Looked up word by word in normalize_medical_text.
MEDICAL_ABBREVIATIONS = {
    "btm": "b·ªánh th·∫≠n m·∫°n", "tha": "tƒÉng huy·∫øt √°p", "ƒëtƒë": "ƒë√°i th√°o ƒë∆∞·ªùng",
    "gfr": "ƒë·ªô l·ªçc c·∫ßu th·∫≠n", "egfr": "ƒë·ªô l·ªçc c·∫ßu th·∫≠n ∆∞·ªõc t√≠nh",
}

# Canonical term -> list of variants that are rewritten to the canonical term.
MEDICAL_SYNONYMS = {
    "b·ªánh th·∫≠n m·∫°n": ["b·ªánh th·∫≠n m√£n", "suy th·∫≠n m·∫°n", "ckd"],
    "ƒë√°i th√°o ƒë∆∞·ªùng": ["ti·ªÉu ƒë∆∞·ªùng", "ƒëtƒë", "diabetes"],
    "tƒÉng huy·∫øt √°p": ["cao huy·∫øt √°p", "tha", "hypertension"],
}

def normalize_medical_text(text: str) -> str:
    """Return a canonical, title-cased form of a medical entity name.

    Steps, in order:
      1. NFC-normalize, strip, lower-case and collapse whitespace.
      2. Expand each word found in MEDICAL_ABBREVIATIONS (punctuation is
         stripped from the word for the lookup only).
      3. Rewrite every MEDICAL_SYNONYMS variant to its canonical term.
      4. Collapse whitespace again and title-case the result.

    Args:
        text: Raw entity name. Falsy input (``""`` or ``None``) is accepted.

    Returns:
        The normalized name, or the literal ``"Unknown"`` for falsy input.
    """
    if not text:
        return "Unknown"
    text = unicodedata.normalize("NFC", text).strip().lower()
    text = re.sub(r'\s+', ' ', text)

    words = text.split()
    expanded = [MEDICAL_ABBREVIATIONS.get(re.sub(r'[^\w]', '', w), w) for w in words]
    text = ' '.join(expanded)

    for canonical, variants in MEDICAL_SYNONYMS.items():
        for variant in variants:
            # Match the variant only as a whole word/phrase. A plain
            # str.replace would rewrite short variants such as "tha" inside
            # unrelated words (e.g. "thai") and corrupt the entity name.
            pattern = r'(?<!\w)' + re.escape(variant) + r'(?!\w)'
            # A callable replacement keeps the canonical text literal, so no
            # backslash/group escapes are interpreted in it.
            text = re.sub(pattern, lambda _match: canonical, text)

    return re.sub(r'\s+', ' ', text).strip().title()

print("‚úÖ Helper functions loaded.")

In [None]:
# === LOAD GRAPH CHECKPOINT ===
# Restores the graph `G` from the pickle checkpoint on disk.
# NOTE(review): pickle.load can execute arbitrary code; only load checkpoints
# that were produced by a trusted run of the pipeline.
graph_path = "./amg_data/graph_full.pkl"

# Guard clause: stop right away when the checkpoint is missing.
if not os.path.exists(graph_path):
    print("‚ùå Kh√¥ng t√¨m th·∫•y checkpoint!")
    raise FileNotFoundError("Ch·∫°y AMG_RAG_Enhanced.ipynb tr∆∞·ªõc ƒë·ªÉ t·∫°o checkpoint.")

with open(graph_path, "rb") as checkpoint_file:
    G = pickle.load(checkpoint_file)

print(f"‚úÖ ƒê√£ load graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

In [None]:
# === LOAD VECTORDB ===
# Opens the persistent Chroma collection and caches every stored chunk as
# NFC-normalized, lower-cased text for the frequency-based boost later on.
chroma_path = "./amg_data/chroma_db"

# Define both names up front so later cells never hit a NameError when the
# database directory is missing; `col is None` means "no VectorDB available".
col = None
all_chunks_text = []

if os.path.exists(chroma_path):
    client = chromadb.PersistentClient(path=chroma_path)
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    col = client.get_collection("medical_chunks", embedding_function=ef)

    # Fetch all documents once; they are only used for substring counting.
    all_docs = col.get()
    all_chunks_text = [
        # Entity names are NFC-normalized before matching, so the chunks must
        # be too; otherwise decomposed (NFD) Vietnamese text never matches.
        unicodedata.normalize("NFC", doc).lower()
        for doc in (all_docs.get('documents') or [])
        if doc  # skip entries stored without document text
    ]

    print(f"‚úÖ ƒê√£ load VectorDB: {len(all_chunks_text)} chunks")
else:
    print("‚ùå Kh√¥ng t√¨m th·∫•y VectorDB!")

In [None]:
# === IMPROVEMENT 1: ENTITY ENHANCEMENT FROM VECTORDB ===
# For each node, query the VectorDB with the node label, append the first 300
# characters of the top-2 chunks to the description, and raise the node's
# confidence by 0.15 (capped at 1.0).
print("\nüîç IMPROVEMENT 1: Entity Enhancement t·ª´ VectorDB...\n")

enhanced_count = 0   # nodes that carry an enhancement (new or from an earlier run)
newly_enhanced = 0   # nodes enhanced by this execution of the cell
failed_queries = 0   # VectorDB queries that raised

# `col` may be undefined when the VectorDB was not found, so look it up
# defensively instead of letting a NameError be raised for every node.
vector_collection = globals().get('col')
if vector_collection is None:
    print("   [warn] VectorDB collection is not available - skipping enhancement.")
    nodes_to_enhance = []
else:
    nodes_to_enhance = list(G.nodes)

for node in nodes_to_enhance:
    attrs = G.nodes[node]
    label = attrs.get('label', node)

    # Idempotency guard: re-running the cell must not stack another +0.15
    # boost or append a second "[Enhanced]" section to the description.
    if attrs.get('vectordb_enhanced'):
        enhanced_count += 1
        continue

    try:
        results = vector_collection.query(query_texts=[str(label)], n_results=2)
    except Exception as e:
        # Best effort: keep going, but report instead of silently swallowing.
        failed_queries += 1
        if failed_queries <= 3:
            print(f"   [warn] query failed for {label!r}: {e}")
        continue

    doc_lists = results.get('documents') or []
    first_hits = doc_lists[0] if doc_lists else None
    top_chunks = [chunk for chunk in (first_hits or []) if chunk]
    if not top_chunks:
        continue

    # Combine the top chunks and keep only a short excerpt.
    enhanced_context = " ".join(top_chunks)[:300]

    # `or ''` avoids rendering a literal "None" for nodes whose description is None.
    old_desc = attrs.get('description') or ''
    attrs['description'] = f"{old_desc}\n[Enhanced]: {enhanced_context}"

    old_conf = attrs.get('confidence', 0.5)
    attrs['confidence'] = min(1.0, old_conf + 0.15)
    attrs['vectordb_enhanced'] = True

    enhanced_count += 1
    newly_enhanced += 1
    if newly_enhanced <= 5:  # show only the first few updates
        print(f"   ‚úÖ {label}: conf {old_conf:.2f} ‚Üí {G.nodes[node]['confidence']:.2f}")

if failed_queries:
    print(f"   [warn] {failed_queries} VectorDB queries failed in total.")

print(f"\n‚úÖ Enhanced {enhanced_count}/{G.number_of_nodes()} entities")

In [None]:
# === IMPROVEMENT 2: FREQUENCY-BASED CONFIDENCE BOOST ===
print("\nüìä IMPROVEMENT 2: Frequency-based Confidence Boost...\n")

def get_entity_frequency(entity_name: str, chunks: list) -> int:
    """Return the number of chunks that contain the entity name.

    The name is passed through normalize_medical_text and lower-cased, then
    searched as a plain substring. Each chunk counts at most once, however
    many times the name occurs inside it.

    Args:
        entity_name: Raw entity label.
        chunks: Chunk texts to search; the match is case-sensitive, so callers
            are expected to pass lower-cased text.

    Returns:
        Count of chunks in which the normalized name occurs.
    """
    needle = normalize_medical_text(entity_name).lower()
    hits = 0
    for chunk_text in chunks:
        if needle in chunk_text:
            hits += 1
    return hits

boosted_count = 0   # nodes that carry a frequency boost (new or from an earlier run)
newly_boosted = 0   # nodes boosted by this execution of the cell

for node in list(G.nodes):
    attrs = G.nodes[node]
    label = attrs.get('label', node)

    # Idempotency guard: re-running the cell must not stack another boost or
    # append a second "[Freq: ...]" suffix.
    if attrs.get('frequency_boosted'):
        boosted_count += 1
        continue

    # str() because the label falls back to the node id, which may not be a string.
    frequency = get_entity_frequency(str(label), all_chunks_text)
    if frequency < 3:
        continue

    # Larger boost for entities found in 5 or more chunks.
    boost = 0.1 if frequency >= 5 else 0.05

    # Default to 0.5 when the node has no confidence yet, instead of raising KeyError.
    old_conf = attrs.get('confidence', 0.5)
    attrs['confidence'] = min(1.0, old_conf + boost)

    # Record the frequency in the description; tolerate a missing/None description.
    attrs['description'] = f"{attrs.get('description') or ''} [Freq: {frequency}x]"
    attrs['frequency_boosted'] = True

    boosted_count += 1
    newly_boosted += 1
    if newly_boosted <= 10:  # show only the first few updates
        print(f"   üìà {label}: {frequency}x ‚Üí conf {old_conf:.2f} ‚Üí {G.nodes[node]['confidence']:.2f}")

print(f"\n‚úÖ Boosted {boosted_count} entities (freq >= 3)")

In [None]:
# === IMPROVEMENT 3: DEDUPLICATION STATUS ===
print("\nüîÑ IMPROVEMENT 3: Deduplication Status...\n")

# Pair every node's label with the size of its 'pages' attribute. Nodes with
# more than one page are reported below as merged entities.
label_page_counts = (
    (attrs.get('label', node_id), len(attrs.get('pages', [])))
    for node_id, attrs in G.nodes(data=True)
)

# Keep multi-page nodes only, largest page count first (stable for ties).
merged_entities = sorted(
    (pair for pair in label_page_counts if pair[1] > 1),
    key=lambda pair: pair[1],
    reverse=True,
)

print(f"‚úÖ T√¨m th·∫•y {len(merged_entities)} entities ƒë√£ ƒë∆∞·ª£c merge t·ª´ nhi·ªÅu chunks")
print("\nTop 10 merged entities:")
top_merged = merged_entities[:10]
for entity, count in top_merged:
    print(f"   - {entity}: xu·∫•t hi·ªán ·ªü {count} pages")

In [None]:
# === SAVE IMPROVED GRAPH ===
# Pickles the updated graph to a new checkpoint and prints a short summary.
print("\nüíæ SAVING IMPROVED GRAPH...\n")

# Single source of truth for the output path, so the file that is written and
# the path that is reported can never drift apart.
improved_graph_path = "./amg_data/graph_improved.pkl"

# Save checkpoint
with open(improved_graph_path, "wb") as f:
    pickle.dump(G, f)

print(f"‚úÖ ƒê√£ l∆∞u: {improved_graph_path}")
print(f"   - {G.number_of_nodes()} nodes")
print(f"   - {G.number_of_edges()} edges")
# The two counters below come from the Improvement 1 and Improvement 2 cells.
print(f"   - {enhanced_count} entities enhanced")
print(f"   - {boosted_count} entities boosted by frequency")

In [None]:
# === VISUALIZE IMPROVED GRAPH ===
# Renders the graph to a standalone HTML file with pyvis. Node colour encodes
# the entity type; node size and edge width grow with confidence.
print("\nüìä CREATING VISUALIZATION...\n")

output_html = "Medical_Graph_Improved_479.html"

net = Network(height="850px", width="100%", directed=True, notebook=False)
net.force_atlas_2based(gravity=-50, spring_length=200)

# Entity type -> node colour; unknown types use the fallback colour below.
color_map = {
    'DISEASE': '#ff6b6b', 'DRUG': '#4ecdc4', 'SYMPTOM': '#ffe66d',
    'TEST': '#95e1d3', 'ANATOMY': '#f38181', 'TREATMENT': '#aa96da',
    'PROCEDURE': '#a8d8ea', 'RISK_FACTOR': '#ffa07a', 'LAB_VALUE': '#98d8c8'
}

for node, data in G.nodes(data=True):
    node_type = data.get('type', 'OTHER')
    color = color_map.get(node_type, '#97C2FC')
    conf = data.get('confidence', 0.5)
    size = 15 + conf * 25  # bigger nodes for higher confidence

    # Fall back to the node id, as the other cells do, so unlabeled nodes are
    # not rendered as "None" in the tooltip.
    label = str(data.get('label', node))

    title = f"<b>{label}</b><br>Type: {node_type}<br>Conf: {conf:.2f}<br>Pages: {data.get('pages')}"
    net.add_node(node, label=label, title=title, color=color, size=size)

for u, v, data in G.edges(data=True):
    conf = data.get('confidence', 0.5)
    width = 1 + conf * 3  # thicker edges for higher confidence
    net.add_edge(u, v, label=data.get('relation'), title=data.get('evidence', ''), width=width)

net.show_buttons(filter_=['physics'])
net.write_html(output_html)

print(f"‚úÖ ƒê√£ l∆∞u visualization: {output_html}")
print("\nüéâ HO√ÄN T·∫§T! M·ªü file HTML ƒë·ªÉ xem graph c·∫£i ti·∫øn.")

In [None]:
# === BEFORE/AFTER COMPARISON ===
# Prints the current attributes of a small sample of nodes.
# NOTE(review): only the post-improvement values are shown; no "before"
# snapshot is taken anywhere in the visible cells.
print("\nüìä SO S√ÅNH CONFIDENCE SCORES...\n")

# Fixed seed so Restart & Run All prints the same sample every time; set to
# None to draw a fresh sample on each run.
SAMPLE_SEED = 42

# Sample up to 10 nodes to inspect the result.
all_nodes = list(G.nodes)
sample_nodes = random.Random(SAMPLE_SEED).sample(all_nodes, min(10, len(all_nodes)))

print("Sample confidence scores:")
for node in sample_nodes:
    # str() because the label falls back to the node id, which may not be
    # sliceable/formattable as a string.
    label = str(G.nodes[node].get('label', node))
    conf = G.nodes[node].get('confidence', 0)
    rel_score = G.nodes[node].get('relevance_score', 0)
    pages = len(G.nodes[node].get('pages') or [])
    print(f"   - {label[:30]:30s} | Conf: {conf:.2f} | Rel: {rel_score} | Pages: {pages}")