In [None]:
# === RESUME EXTRACTION WITH IMPROVED RELATION PROMPTS ===
!pip install -q langchain-groq pyvis pypdf networkx pydantic python-dotenv

import os
import re
import time
import pickle
import unicodedata
import networkx as nx
from typing import List
from dotenv import load_dotenv

from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import TokenTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from pyvis.network import Network
from pydantic import BaseModel, Field

# Banner confirming that every import in this cell succeeded.
# The literal was UTF-8 mis-decoded as Mac Roman; restored to the intended check-mark emoji.
print("✅ Import OK")

In [None]:
# === CONFIGURATION ===
# Pull settings from a local .env file (if present) into the process environment.
load_dotenv()

# Groq API keys are read from GROQ_API_KEY_1 ... GROQ_API_KEY_9.
# Variables that are not set come back as None and stay in the list at their slot.
GROQ_API_KEYS = [os.getenv(f"GROQ_API_KEY_{slot}") for slot in range(1, 10)]

# Path of the PDF to process; None when the variable is not set.
PDF_PATH = os.getenv("PDF_PATH")

# Normalization tables used by normalize_medical_text().
#
# The Vietnamese literals below were stored as mojibake (UTF-8 bytes decoded as
# Mac Roman), so no key or variant could ever match real text. They are restored
# to proper precomposed (NFC) Vietnamese, which is the form the normalizer
# produces before it consults these tables.

# Lower-case abbreviation -> canonical lower-case term (looked up word by word).
MEDICAL_ABBREVIATIONS = {
    "btm": "bệnh thận mạn",
    "tha": "tăng huyết áp",
    "đtđ": "đái tháo đường",
    "gfr": "độ lọc cầu thận",
    "ckd": "bệnh thận mạn",
}

# Canonical lower-case term -> spelling variants that are replaced by it
# (plain substring replacement, applied after abbreviation expansion).
MEDICAL_SYNONYMS = {
    "bệnh thận mạn": ["bệnh thận mãn", "suy thận mạn", "ckd"],
    "đái tháo đường": ["tiểu đường", "đtđ", "diabetes"],
}

def normalize_medical_text(text: str) -> str:
    """Return a canonical, title-cased form of a medical term for use as a graph key.

    Steps: NFC-normalize, trim, lower-case, collapse whitespace, expand known
    abbreviations word by word (punctuation is ignored for the lookup), replace
    known synonym variants with their canonical term, then title-case.

    Args:
        text: Raw entity name. May be empty or None.

    Returns:
        The normalized name, or "Unknown" when ``text`` is empty, None, or
        consists only of whitespace.
    """
    if not text:
        return "Unknown"

    text = unicodedata.normalize("NFC", text).strip().lower()
    text = re.sub(r'\s+', ' ', text)
    # Whitespace-only input used to fall through and yield "" (an empty node
    # key); treat it the same as empty input.
    if not text:
        return "Unknown"

    # A word is replaced as a whole when its alphanumeric core is a known
    # abbreviation; otherwise it is kept untouched (punctuation included).
    words = [MEDICAL_ABBREVIATIONS.get(re.sub(r'[^\w]', '', w), w) for w in text.split()]
    text = ' '.join(words)

    # Substring replacement: variants are matched anywhere in the text.
    for canonical, variants in MEDICAL_SYNONYMS.items():
        for variant in variants:
            text = text.replace(variant, canonical)

    return re.sub(r'\s+', ' ', text).strip().title()

class APIKeyManager:
    """Holds a pool of API keys and rotates to the next usable one on failure.

    Attributes:
        api_keys: The configured keys with falsy entries (None, "") removed.
        current_index: Index into ``api_keys`` of the key currently in use.
        failed_keys: Set of *indices* (not key strings) marked as failed since
            the last ``reset_failed()``.
    """

    def __init__(self, api_keys: List[str]):
        self.api_keys = [k for k in api_keys if k]
        self.current_index = 0
        self.failed_keys = set()

    def get_current_key(self) -> str:
        """Return the key currently in use.

        Raises:
            IndexError: If no key was configured at all.
        """
        if not self.api_keys:
            # Same exception type the bare list index used to raise, but with
            # a message that points at the actual problem.
            raise IndexError("No API keys configured")
        return self.api_keys[self.current_index]

    def rotate_key(self) -> bool:
        """Mark the current key as failed and switch to the next usable one.

        Returns:
            True if a key that is not marked failed was found (and selected),
            False if every key is marked failed or no key is configured.
        """
        self.failed_keys.add(self.current_index)
        total = len(self.api_keys)
        # Probe the other slots in ring order, starting just after the current one.
        for offset in range(1, total + 1):
            next_idx = (self.current_index + offset) % total
            if next_idx not in self.failed_keys:
                self.current_index = next_idx
                print(f"🔄 Key #{next_idx + 1}")
                return True
        return False

    def reset_failed(self):
        """Forget all failure marks so every key is eligible again."""
        self.failed_keys.clear()

# Shared key pool; APIKeyManager drops the entries whose environment variable
# was not set, so the count printed below is the number of usable keys.
api_manager = APIKeyManager(GROQ_API_KEYS)
# The banner literal was mojibake; restored to the intended emoji.
print(f"✅ {len(api_manager.api_keys)} API keys")

In [None]:
# === SCHEMA ===
# One medical concept returned by the extractor for a text chunk.
# NOTE(review): documented with comments rather than a docstring because a
# docstring may end up in the schema handed to the LLM via
# with_structured_output; confirm before converting.
class MedicalEntity(BaseModel):
    name: str  # entity name as returned by the model; normalized by add_entity() to form the node key
    type: str  # entity category; add_entity() stores it upper-cased
    description: str = ""  # optional free text, empty when not provided
    relevance_score: int = Field(default=5, ge=1, le=10)  # validated to 1..10; add_entity() divides it by 10 to get a confidence

# One directed relation between two entities, as returned by the extractor.
# add_relation() normalizes both names and only adds the edge when both
# resulting nodes already exist in the graph.
class MedicalRelation(BaseModel):
    source_name: str  # name of the entity the relation starts from
    target_name: str  # name of the entity the relation points to
    relation: str  # relation label; add_relation() stores it upper-cased
    confidence_score: int = Field(default=5, ge=1, le=10)  # validated to 1..10; divided by 10 when stored on the edge
    evidence: str = ""  # optional supporting text, stored on the edge as-is

# Container for everything extracted from one chunk. It is the schema passed to
# with_structured_output() in ImprovedAMGExtractor._init_llm().
class AMGExtractionResult(BaseModel):
    entities: List[MedicalEntity] = Field(default_factory=list)  # empty list when nothing was extracted
    relations: List[MedicalRelation] = Field(default_factory=list)  # empty list when nothing was extracted

# Banner confirming the schema classes were defined (literal restored from mojibake).
print("✅ Schema OK")

In [None]:
# === LOAD CHECKPOINT ===
# Resume from a previously pickled graph when one exists; otherwise start with
# an empty directed graph.
graph_path = "./amg_data/graph_full.pkl"

if os.path.exists(graph_path):
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints that this project wrote itself.
    with open(graph_path, "rb") as f:
        G = pickle.load(f)
    print(f"✅ Loaded: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
else:
    # User-facing message restored from mojibake ("creating a new graph").
    print("⚠️ Tạo graph mới")
    G = nx.DiGraph()

In [None]:
# === ★ IMPROVED EXTRACTOR WITH RELATION FOCUS ===
class ImprovedAMGExtractor:
    """Extracts medical entities and relations from text with a Groq-hosted LLM.

    The extractor builds a prompt -> LLM -> structured-output chain that yields
    an ``AMGExtractionResult``. On errors that look like rate limiting it
    rotates to the next API key (via the shared ``APIKeyManager``) and rebuilds
    the chain.

    Args:
        api_manager: Key pool used to pick and rotate API keys.
        model: Groq model name.
        temperature: Sampling temperature passed to the model.
    """

    # Entity names containing any of these lower-case terms are discarded
    # ("decision", "document", "page": document boilerplate, not medicine).
    # The literals were mojibake before and therefore never matched.
    EXCLUDED_NAME_TERMS = tuple(
        unicodedata.normalize("NFC", term)
        for term in ("quyết định", "văn bản", "trang")
    )

    def __init__(self, api_manager: "APIKeyManager",
                 model: str = "llama-3.3-70b-versatile",
                 temperature: float = 0.1):
        self.api_manager = api_manager
        self.model = model
        self.temperature = temperature

        # Improved prompt with a focus on relations.
        # The Vietnamese text was stored as mojibake (UTF-8 decoded as Mac
        # Roman) and has been restored. The first few-shot relation used the
        # label TREATED_BY, which is not in the declared RELATION TYPES; it is
        # now expressed with TREATS so the examples agree with the vocabulary.
        self.extraction_prompt = ChatPromptTemplate.from_messages([
            ("system", """Bạn là chuyên gia trích xuất thực thể và QUAN HỆ y tế.

## ENTITY TYPES:
DISEASE, DRUG, SYMPTOM, TEST, ANATOMY, TREATMENT, PROCEDURE, RISK_FACTOR, LAB_VALUE

## RELATION TYPES (QUAN TRỌNG):
CAUSES, TREATS, DIAGNOSED_BY, ASSOCIATED_WITH, RISK_FOR, INDICATES, CONTRAINDICATED_FOR, ADMINISTERED_FOR, MONITORS

## SCORING:
**Entities:** 9-10 (chính), 7-8 (quan trọng), 5-6 (phụ)
**Relations:** 9-10 (chắc chắn), 7-8 (có bằng chứng rõ), 5-6 (có liên quan)

## ★ FEW-SHOT EXAMPLES (CÓ RELATIONS):

**Example 1:**
Input: "Bệnh thận mạn giai đoạn 5 cần lọc máu. Điều trị bằng erythropoietin khi Hb < 10g/dL."

Entities:
- Bệnh thận mạn giai đoạn 5 (DISEASE, score=10)
- Lọc máu (TREATMENT, score=9)
- Erythropoietin (DRUG, score=8)
- Hb < 10g/dL (LAB_VALUE, score=7)

Relations:
- Lọc máu → TREATS → Bệnh thận mạn giai đoạn 5 (conf=10, evidence: "cần lọc máu")
- Hb < 10g/dL → INDICATES → Thiếu máu (conf=9)
- Erythropoietin → ADMINISTERED_FOR → Thiếu máu (conf=9, evidence: "khi Hb < 10g/dL")
- Bệnh thận mạn → CAUSES → Thiếu máu (conf=8)

**Example 2:**
Input: "Đái tháo đường type 2 là yếu tố nguy cơ hàng đầu gây bệnh thận mạn. Kiểm soát đường huyết bằng metformin."

Entities:
- Đái tháo đường type 2 (DISEASE, score=10)
- Bệnh thận mạn (DISEASE, score=10)
- Metformin (DRUG, score=8)
- Kiểm soát đường huyết (TREATMENT, score=8)

Relations:
- Đái tháo đường type 2 → RISK_FOR → Bệnh thận mạn (conf=10, evidence: "yếu tố nguy cơ hàng đầu")
- Đái tháo đường type 2 → CAUSES → Bệnh thận mạn (conf=9)
- Metformin → TREATS → Đái tháo đường type 2 (conf=9, evidence: "kiểm soát đường huyết")
- Kiểm soát đường huyết → TREATS → Đái tháo đường type 2 (conf=9)

## ★ QUY TẮC QUAN TRỌNG:
1. **PHẢI TÌM TỐI THIỂU 3-5 RELATIONS cho mỗi đoạn văn**
2. Tìm cả quan hệ TRỰC TIẾP và GIÁN TIẾP
3. Evidence phải trích dẫn chính xác từ văn bản
4. Confidence cao nếu có từ khóa rõ ràng: "gây", "điều trị", "chẩn đoán", "liên quan"
5. GIỮ NGUYÊN tiếng Việt, KHÔNG extract: quyết định, văn bản, số trang

Trích xuất ENTITIES và NHIỀU RELATIONS:"""),
            ("human", "{text}")
        ])

        self._init_llm()

    def _init_llm(self):
        """(Re)build the LLM client and chain using the currently selected API key."""
        self.llm = ChatGroq(temperature=self.temperature, model=self.model,
                           api_key=self.api_manager.get_current_key())
        self.chain = self.extraction_prompt | self.llm.with_structured_output(AMGExtractionResult)

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Heuristic: treat any error whose message mentions "rate" or "limit" as rate limiting."""
        message = str(error).lower()
        return "rate" in message or "limit" in message

    @classmethod
    def _is_valid_entity_name(cls, name: str) -> bool:
        """Return True for names of at least 2 characters that contain no excluded term."""
        if len(name) < 2:
            return False
        # NFC so that decomposed input still matches the precomposed terms.
        lowered = unicodedata.normalize("NFC", name).lower()
        return not any(term in lowered for term in cls.EXCLUDED_NAME_TERMS)

    def _handle_rate_limit(self):
        """Switch to another key, or wait five minutes when every key is exhausted."""
        if self.api_manager.rotate_key():
            self._init_llm()
            time.sleep(2)
        else:
            print("⏳ Chờ 5 phút...")
            time.sleep(300)
            self.api_manager.reset_failed()
            self._init_llm()

    def extract(self, text: str, max_retries=3):
        """Extract entities and relations from ``text``.

        Args:
            text: The chunk of text to analyse.
            max_retries: Maximum number of attempts. Waiting out a rate limit
                also consumes an attempt.

        Returns:
            The model's ``AMGExtractionResult`` with invalid entity names
            removed (relations are returned unfiltered), or an empty
            ``AMGExtractionResult`` when the model returned nothing or every
            attempt failed.
        """
        for attempt in range(1, max_retries + 1):
            try:
                result = self.chain.invoke({"text": text})
                if not result:
                    return AMGExtractionResult()
                result.entities = [e for e in result.entities
                                   if self._is_valid_entity_name(e.name)]
                return result
            except Exception as e:
                # Broad on purpose: one bad chunk must not abort a long run.
                if self._is_rate_limit_error(e):
                    self._handle_rate_limit()
                else:
                    # Previously swallowed silently; report it so that
                    # recurring failures are visible in the run log.
                    detail = str(e)[:200]
                    print(f"   ⚠️ Extraction error (attempt {attempt}/{max_retries}): "
                          f"{type(e).__name__}: {detail}")
                    time.sleep(1)

        print(f"   ⚠️ Giving up on this chunk after {max_retries} attempts")
        return AMGExtractionResult()

# Banner confirming the extractor class was defined (literal restored from mojibake).
print("✅ Improved Extractor ready (RELATION FOCUS)")

In [None]:
# === HELPERS ===
def add_entity(G, entity: MedicalEntity, page_num: int):
    """Insert an entity into the graph, or merge it into the existing node.

    The node key is the normalized entity name. A new node stores the original
    name as ``label`` together with type, description, confidence
    (``relevance_score / 10``, capped at 1.0), the raw score and the page.
    For an existing node the higher confidence wins and the page is recorded.

    Args:
        G: Graph to update in place.
        entity: The extracted entity.
        page_num: Page number on which the entity was found.
    """
    norm_name = normalize_medical_text(entity.name)
    confidence = min(1.0, entity.relevance_score / 10.0)

    if not G.has_node(norm_name):
        G.add_node(norm_name, label=entity.name, type=entity.type.upper(),
                  description=entity.description, confidence=confidence,
                  relevance_score=entity.relevance_score, pages=[page_num])
        return

    node = G.nodes[norm_name]
    if confidence > node.get('confidence', 0):
        node['confidence'] = confidence
        # Keep the raw score in step with the confidence derived from it;
        # previously only the confidence was raised.
        node['relevance_score'] = entity.relevance_score

    # Fill in a description if the node was first seen without one.
    if entity.description and not node.get('description'):
        node['description'] = entity.description

    # Existing nodes (e.g. from an older checkpoint) may have no 'pages' list
    # yet; indexing it directly raised KeyError.
    pages = node.setdefault('pages', [])
    if page_num not in pages:
        pages.append(page_num)

def add_relation(G, rel: MedicalRelation, page_num: int):
    """Add a directed relation edge between two existing nodes.

    Both endpoint names are normalized. The relation is ignored when either
    node is missing from the graph, when both names normalize to the same node,
    or when the graph already holds a strictly more confident edge for the
    same (source, target) pair.

    Args:
        G: Graph to update in place.
        rel: The extracted relation.
        page_num: Page number on which the relation was found.
    """
    src = normalize_medical_text(rel.source_name)
    tgt = normalize_medical_text(rel.target_name)

    # Synonym/abbreviation normalization can collapse both ends into one node;
    # a self-loop carries no information.
    if src == tgt:
        return
    if not (G.has_node(src) and G.has_node(tgt)):
        return

    confidence = min(1.0, rel.confidence_score / 10)

    # add_edge() on an existing edge overwrites its attributes, so a weaker
    # relation seen later used to replace a stronger one. Keep the stronger
    # edge; on a tie the newer relation still wins, as before.
    if G.has_edge(src, tgt) and G[src][tgt].get('confidence', 0) > confidence:
        return

    G.add_edge(src, tgt, relation=rel.relation.upper(),
              confidence=confidence,
              evidence=rel.evidence, page=page_num)

# Banner confirming the graph helpers were defined (literal restored from mojibake).
print("✅ Helpers OK")

In [None]:
# === EXTRACT FROM CHUNK 121 ===
# Resume point. START_CHUNK is a 0-based index into the chunk list built below,
# i.e. into the chunks of the PDF after its first START_PAGE - 1 pages are dropped.
START_CHUNK = 120
START_PAGE = 17
CHECKPOINT_PATH = "./amg_data/graph_improved.pkl"
SAVE_EVERY = 20  # write a checkpoint each time the chunk number is a multiple of this


def _save_checkpoint(graph, path):
    """Pickle the graph to `path` via a temp file so an interrupted write cannot corrupt it."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(graph, f)
    os.replace(tmp_path, path)


print(f"🚀 RESUME TỪ CHUNK {START_CHUNK + 1} (IMPROVED RELATIONS)\n")

# Fail before any API call is made if the input or output location is unusable.
if not PDF_PATH:
    raise ValueError("PDF_PATH is not set; define it in the environment or the .env file")
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# Load the PDF and drop the pages before START_PAGE
loader = PyPDFLoader(PDF_PATH)
pages = loader.load()[START_PAGE-1:]
print(f"✂️ Cắt {START_PAGE-1} trang, còn {len(pages)} trang")

# Split into overlapping token chunks
splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents(pages)
print(f"📦 Tổng: {len(chunks)} chunks")
print(f"▶️  Xử lý: chunk {START_CHUNK + 1} → {len(chunks)}\n")

# Extract
extractor = ImprovedAMGExtractor(api_manager)
chunks_to_process = chunks[START_CHUNK:]

print("🧠 ĐANG TRÍCH XUẤT (Improved Relation Prompts)...\n")

# `i` is the 1-based chunk number within the full chunk list.
for i, chunk in enumerate(chunks_to_process, start=START_CHUNK+1):
    # The loader's "page" metadata is treated as 0-based here; stored 1-based.
    page_num = chunk.metadata.get("page", 0) + 1
    result = extractor.extract(chunk.page_content)

    if result and result.entities:
        num_ent = len(result.entities)
        num_rel = len(result.relations)
        avg_rel = sum(e.relevance_score for e in result.entities) / num_ent
        print(f"   [{i}/{len(chunks)}] +{num_ent} entities, +{num_rel} relations (avg: {avg_rel:.1f})")

        # Entities first: add_relation() only links nodes that already exist.
        for ent in result.entities:
            add_entity(G, ent, page_num)

        for rel in result.relations:
            add_relation(G, rel, page_num)
    else:
        print(f"   [{i}/{len(chunks)}] --")

    # Periodic checkpoint
    if i % SAVE_EVERY == 0:
        _save_checkpoint(G, CHECKPOINT_PATH)
        print(f"   💾 Saved: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges\n")

    # Small pause between requests.
    time.sleep(1)

# Persist the chunks processed since the last periodic checkpoint as well.
_save_checkpoint(G, CHECKPOINT_PATH)

print(f"\n✅ XONG!")
print(f"Final: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

In [None]:
# === FREQUENCY BOOST ===
# Raise the confidence of entities that were seen on several distinct pages:
# +0.05 for 3-4 pages, +0.10 for 5 or more, capped at 1.0.
# NOTE(review): running this cell twice boosts twice; run it once per graph.
print("\n📊 FREQUENCY BOOST...\n")

boosted = 0
for node, attrs in G.nodes(data=True):
    freq = len(attrs.get('pages', []))
    if freq < 3:
        continue

    boost = 0.1 if freq >= 5 else 0.05
    # Nodes without a stored confidence used to raise KeyError. Fall back to
    # 0.5, the same default the visualization cell uses.
    old_conf = attrs.get('confidence', 0.5)
    attrs['confidence'] = min(1.0, old_conf + boost)
    boosted += 1

print(f"✅ Boosted {boosted} entities")

In [None]:
# === SAVE ===
# Make sure the target directory exists (it will not when no checkpoint was
# loaded and none has been written yet).
os.makedirs("./amg_data", exist_ok=True)
with open("./amg_data/graph_improved.pkl", "wb") as f:
    pickle.dump(G, f)

node_count = G.number_of_nodes()
edge_count = G.number_of_edges()
# An empty graph used to raise ZeroDivisionError here.
edges_per_node = edge_count / node_count if node_count else 0.0

print(f"\n💾 Đã lưu: graph_improved.pkl")
print(f"   {node_count} nodes | {edge_count} edges")
print(f"   Tỷ lệ: {edges_per_node:.2f} edges/node")

In [None]:
# === VISUALIZE ===
# Render the graph to a standalone interactive HTML file.
net = Network(height="850px", width="100%", directed=True, notebook=False)
net.force_atlas_2based(gravity=-50, spring_length=200)

# Colours per entity type; any type not listed here gets the default colour.
color_map = {
    'DISEASE': '#ff6b6b', 'DRUG': '#4ecdc4', 'SYMPTOM': '#ffe66d',
    'TEST': '#95e1d3', 'TREATMENT': '#aa96da'
}

for node, data in G.nodes(data=True):
    color = color_map.get(data.get('type', 'OTHER'), '#97C2FC')
    conf = data.get('confidence', 0.5)
    # Node size grows with confidence: 15 at 0.0 up to 40 at 1.0.
    size = 15 + conf * 25
    # Nodes without a 'label' attribute used to be rendered as "None";
    # fall back to the node key.
    label = data.get('label') or str(node)
    net.add_node(node, label=label, color=color, size=size,
                title=f"{label}<br>Conf: {conf:.2f}")

for u, v, data in G.edges(data=True):
    # Edge width grows with confidence: 1 at 0.0 up to 4 at 1.0.
    width = 1 + data.get('confidence', 0.5) * 3
    net.add_edge(u, v, label=data.get('relation') or '', width=width)

net.show_buttons(filter_=['physics'])
net.write_html("Medical_Graph_Improved_Relations.html")

print("✅ Medical_Graph_Improved_Relations.html")
print("🎉 HOÀN TẤT!")