In [1]:
%%capture
%pip install sentence-transformers scikit-learn langchain-google-genai langchain-experimental

In [2]:
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
import random
import numpy as np
import re
from collections import Counter
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import json
from pathlib import Path

2025-12-28 10:20:41.057801: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766917241.283102      17 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766917241.357516      17 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766917241.913419      17 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766917241.913468      17 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1766917241.913471      17 computation_placer.cc:177] computation placer alr

# Configuration

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
API_KEY = user_secrets.get_secret("GEMINI API")

allowed_nodes = [
    "DRUG",            # Thuốc
    "DISEASE",         # Bệnh
    "CHEMICAL",        # Hóa chất/Thành phần
    "ORGANISM",        # Vi sinh vật
    "TEST_METHOD",      # Phương pháp thử
    "STANDARD",        # Tiêu chuẩn
    "STORAGE_CONDITION",# Bảo quản
    "PRODUCTION_METHOD"       # Phương pháp sản xuất
]

allowed_relationships = [
    "TREATS",          # Drug -> Disease
    "CONTAINS",        # Drug -> Chemical
    "TARGETS",         # Drug -> Organism
    "TESTED_BY",       # Drug -> TestMethod
    "HAS_STANDARD",    # Drug -> Standard
    "STORED_AT",       # Drug -> Storage
    "PRODUCED_BY",     # Drug -> Production
    "REQUIRES"         # TestMethod -> Chemical
]

system_instructions = """
Bạn là chuyên gia trích xuất Knowledge Graph từ Dược điển Việt Nam. 
Nhiệm vụ của bạn là đọc văn bản và trích xuất các Node và Relationship chính xác theo định nghĩa sau.

--- ĐỊNH NGHĨA ENTITY (THỰC THỂ) ---
1. **DRUG (Thuốc)**: Tên chế phẩm, vắc xin hoặc sinh phẩm y tế.
   - *Ví dụ:* Huyết thanh kháng bạch hầu, Interferon Alpha 2.
2. **CHEMICAL (Hóa chất)**: Các chất hóa học, thành phần cấu tạo hoặc thuốc thử.
   - *Ví dụ:* Protein, Globulin miễn dịch, Amoni sulfat.
3. **DISEASE (Bệnh)**: Bệnh lý hoặc triệu chứng lâm sàng.
   - *Ví dụ:* Viêm gan B, Bạch hầu, Uốn ván.
4. **ORGANISM (Vi sinh vật)**: Vi khuẩn, virus hoặc dòng tế bào.
   - *Ví dụ:* Mycobacterium tuberculosis, Virus viêm gan B.
5. **TEST_METHOD (Phương pháp thử)**: Các kỹ thuật kiểm nghiệm.
   - *Ví dụ:* Điện di miễn dịch, ELISA, Sắc ký lỏng.
6. **STANDARD (Tiêu chuẩn)**: Các chỉ số kỹ thuật cụ thể.
   - *Ví dụ:* pH: 6.4-7.2, Nhiệt độ: 2-8°C, Hiệu giá: ≥100 IU/mL.
7. **STORAGE_CONDITION (Điều kiện bảo quản)**: Môi trường lưu giữ.
   - *Ví dụ:* Lạnh: 2-10°C, Tránh ánh sáng.
8. **PRODUCTION_METHOD (Phương pháp sản xuất)**: Công nghệ bào chế.
   - *Ví dụ:* Tái tổ hợp ADN, Nuôi cấy tế bào.

--- ĐỊNH NGHĨA RELATIONSHIP (QUAN HỆ) ---
- **CONTAINS**: (DRUG) -> (CHEMICAL)
- **TREATS**: (DRUG) -> (DISEASE)
- **TARGETS**: (DRUG) -> (ORGANISM)
- **TESTED_BY**: (DRUG) -> (TEST_METHOD)
- **HAS_STANDARD**: (DRUG) -> (STANDARD)
- **STORED_AT**: (DRUG) -> (STORAGE_CONDITION)
- **PRODUCED_BY**: (DRUG) -> (PRODUCTION_METHOD)
- **REQUIRES**: (TEST_METHOD) -> (CHEMICAL)

--- QUY TẮC QUAN TRỌNG (BẮT BUỘC TUÂN THỦ) ---
1. **Auto-Correction (Tự động sửa lỗi)**:
   - Nếu phát hiện lỗi chính tả hoặc lỗi OCR (nhận dạng ký tự) trong tên thực thể, HÃY SỬA NÓ về dạng đúng.
   - *Ví dụ OCR:* "Vltamin C" -> Sửa thành "Vitamin C"; "Paracefamol" -> Sửa thành "Paracetamol"; "mlg" -> Sửa thành "mg"; số "l" nhầm thành "1".
   
2. **Standardization (Chuẩn hóa tên)**: 
   - Viết hoa chữ cái đầu (Title Case). Ví dụ: "viêm gan b" -> "Viêm gan B".
   - Giữ nguyên các thuật ngữ tiếng Latinh hoặc tên khoa học (VD: *Mycobacterium tuberculosis*).

3. **Context Awareness**: Chỉ trích xuất thông tin CÓ TRONG VĂN BẢN. Không tự bịa ra tiêu chuẩn nếu văn bản không ghi.
"""

llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-2.5-flash", max_output_tokens=8192, api_key=API_KEY)

prompt_template = ChatPromptTemplate.from_messages([
    ("system", system_instructions),
    ("human", "{input}")
])

llm_transformer = LLMGraphTransformer(
    llm=llm,
    allowed_nodes=allowed_nodes,
    allowed_relationships=allowed_relationships,
    strict_mode=True,
    prompt=prompt_template,
    node_properties=["description"],  # Added to capture entity descriptions
    relationship_properties=["context"],
)

input_path = Path("/kaggle/input/dvn-v-t2/merged_text_final_no_breaks.txt")
output_path = Path("/kaggle/working/graph_documents.json")

# **Data preprocessing**

### Split full text into chunks

In [4]:
with open(input_path, "r", encoding="utf-8") as f:
    text = f.read()
    
chunk_size = 2000
overlap_size = 200

chunks = []
start = 0

while start < len(text):
    end = start + chunk_size
    chunk = text[start:end]
    chunks.append(chunk)
    start += chunk_size - overlap_size
    
print(f"Total chunks created: {len(chunks)}")

Total chunks created: 604


# **Extract Graph**

In [5]:
docs = [Document(page_content=chunk) for chunk in chunks]

try:
    graph_documents = llm_transformer.convert_to_graph_documents(docs)
    
    if graph_documents:
        doc = graph_documents[random.randint(0, len(graph_documents)-1)] # sample a random document
        print(f"Found {len(doc.nodes)} Nodes and {len(doc.relationships)} Relationships")
        
        print("\n--- NODES ---")
        for node in doc.nodes:
            print(f"- [{node.type}]: {node.id}")
            
        print("\n--- RELATIONSHIPS ---")
        for rel in doc.relationships:
            print(f"- {rel.source.id} --[{rel.type}]--> {rel.target.id}")
except Exception as e:
    print("OOPS:", str(e))

Found 20 Nodes and 22 Relationships

--- NODES ---
- [Drug]: Ich Mau
- [Chemical]: Hydrochloric Acid
- [Chemical]: Ethyl Acetate
- [Chemical]: N-Butanol
- [Chemical]: Ethanol
- [Chemical]: Neutral Aluminium Oxide
- [Chemical]: Acetonitrile
- [Chemical]: Acetic Acid
- [Chemical]: Stachydrine Hydrochloride
- [Chemical]: Dragendorff Reagent
- [Test_method]: Soxhlet Extraction
- [Test_method]: Column Chromatography
- [Test_method]: Sắc Ký Lớp Mỏng
- [Test_method]: Độ Ẩm
- [Test_method]: Tỷ Lệ Vụn Nát
- [Standard]: Rf Value And Color Match Stachydrine Hydrochloride
- [Standard]: Rf Value And Color Match Ich Mau Reference Solution
- [Standard]: Độ Ẩm: Không Quá 13.5%
- [Standard]: Tỷ Lệ Vụn Nát: Không Quá 10.0% Qua Rây 4 Mm
- [Standard]: Nồng Độ Stachydrine Hydrochloride: Khoảng 1 Mg/Ml

--- RELATIONSHIPS ---
- Ich Mau --[TESTED_BY]--> Soxhlet Extraction
- Ich Mau --[TESTED_BY]--> Column Chromatography
- Ich Mau --[TESTED_BY]--> Sắc Ký Lớp Mỏng
- Ich Mau --[TESTED_BY]--> Độ Ẩm
- Ich Mau --[T

# **Analysis**

In [6]:
def display_graph_stats(graph_documents):
    """Display statistics about extracted graph."""
    total_nodes = sum(len(g.nodes) for g in graph_documents)
    total_rels = sum(len(g.relationships) for g in graph_documents)
    
    node_types = {}
    rel_types = {}
    
    for graph_doc in graph_documents:
        for node in graph_doc.nodes:
            node_types[node.type] = node_types.get(node.type, 0) + 1
        for rel in graph_doc.relationships:
            rel_types[rel.type] = rel_types.get(rel.type, 0) + 1
    
    print(f"\n{'='*60}")
    print(f"EXTRACTION SUMMARY")
    print(f"{'='*60}")
    print(f"Total Nodes: {total_nodes}")
    print(f"Total Relationships: {total_rels}")
    
    print(f"\nNode Distribution:")
    for node_type, count in sorted(node_types.items(), key=lambda x: -x[1]):
        print(f"  {node_type:20s}: {count:4d}")
    
    print(f"\nRelationship Distribution:")
    for rel_type, count in sorted(rel_types.items(), key=lambda x: -x[1]):
        print(f"  {rel_type:20s}: {count:4d}")
    print(f"{'='*60}\n")

In [7]:
display_graph_stats(graph_documents)


EXTRACTION SUMMARY
Total Nodes: 15661
Total Relationships: 14715

Node Distribution:
  Standard            : 4194
  Chemical            : 4179
  Test_method         : 2444
  Disease             : 1451
  Production_method   : 1190
  Drug                : 1119
  Organism            :  631
  Storage_condition   :  453

Relationship Distribution:
  HAS_STANDARD        : 4265
  REQUIRES            : 3405
  TESTED_BY           : 2405
  TREATS              : 1368
  CONTAINS            : 1219
  PRODUCED_BY         : 1216
  STORED_AT           :  480
  TARGETS             :  357



# **Post-Processing**

In [8]:
def validate_graph(graph_documents):
    """Remove invalid triples based on rules."""
    dropped_triples = 0
    
    for doc in graph_documents:
        valid_rels = []
        for rel in doc.relationships:
            is_valid = True
            
            # Rule 1: No self-loops
            if rel.source.id == rel.target.id:
                is_valid = False
            
            # Rule 2: Check allowed types for specific relationships (Basic sanity check)
            # Example: TREATS should be DRUG -> DISEASE
            if rel.type == "TREATS":
                if rel.source.type != "DRUG" or rel.target.type != "DISEASE":
                    # We might be lenient here if the extraction was fuzzy, but let's enforce if strict
                    pass 
            
            if is_valid:
                valid_rels.append(rel)
            else:
                dropped_triples += 1
                
        doc.relationships = valid_rels
        
    print(f"Validation complete. Dropped {dropped_triples} invalid relationships.")
    return graph_documents

In [9]:
def deduplicate_nodes(graph_documents, model_name='vinai/bartpho-syllable', threshold=0.90):
    """Merge similar nodes using Vietnamese Sentence BERT."""
        
    print(f"Loading embedding model: {model_name}...")
    model = SentenceTransformer(model_name)
    
    # 1. Collect all unique node IDs and their frequencies
    node_counts = Counter()
    for doc in graph_documents:
        for node in doc.nodes:
            node_counts[node.id] += 1
            
    unique_ids = list(node_counts.keys())
    print(f"Total unique nodes before deduplication: {len(unique_ids)}")
    
    if not unique_ids:
        return graph_documents

    # 2. Compute embeddings
    print("Computing embeddings (this may take a moment)...")
    embeddings = model.encode(unique_ids, show_progress_bar=True)
    
    # 3. Compute similarity matrix
    print("Clustering similar nodes...")
    similarity_matrix = cosine_similarity(embeddings)
    
    # 4. Cluster nodes
    id_mapping = {}
    visited = set()
    
    for i in range(len(unique_ids)):
        if i in visited: continue
        
        # Find all nodes similar to unique_ids[i]
        similar_indices = [i]
        visited.add(i)
        
        for j in range(i + 1, len(unique_ids)):
            if j in visited: continue
            if similarity_matrix[i][j] >= threshold:
                similar_indices.append(j)
                visited.add(j)
        
        if len(similar_indices) > 1:
            # Found a cluster
            cluster_ids = [unique_ids[idx] for idx in similar_indices]
            
            # Choose canonical ID: The one with highest frequency
            canonical_id = max(cluster_ids, key=lambda x: node_counts[x])
            
            # print(f"Merging cluster: {cluster_ids} -> {canonical_id}")
            
            for cid in cluster_ids:
                if cid != canonical_id:
                    id_mapping[cid] = canonical_id
            
    # 5. Apply mapping
    if id_mapping:
        print(f"Merging {len(id_mapping)} duplicate nodes...")
        for doc in graph_documents:
            # Update Nodes
            for node in doc.nodes:
                if node.id in id_mapping:
                    node.id = id_mapping[node.id]
            
            # Update Relationships
            for rel in doc.relationships:
                if rel.source.id in id_mapping:
                    rel.source.id = id_mapping[rel.source.id]
                if rel.target.id in id_mapping:
                    rel.target.id = id_mapping[rel.target.id]
    else:
        print("No duplicates found above threshold.")
                
    return graph_documents

In [10]:
def normalize_text(text):
    """Standardize text: Title Case, trim whitespace, fix units."""
    if not text: return ""
    # Standardize whitespace
    text = " ".join(text.split())
    # Title Case
    text = text.title()
    # Fix common unit formatting (e.g., 10mg -> 10 mg)
    text = re.sub(r'(\d)\s*(mg|ml|g|kg|l)', r'\1 \2', text, flags=re.IGNORECASE)
    return text

def normalize_graph(graph_documents):
    for doc in graph_documents:
        for node in doc.nodes:
            node.id = normalize_text(node.id)
    return graph_documents

RELATIONSHIP_MAPPINGS = {
    # TREATS
    "USED_FOR": "TREATS",
    "INDICATED_FOR": "TREATS",
    "PREVENTS": "TREATS",
    "CURES": "TREATS",
    
    # CONTAINS
    "COMPRISES": "CONTAINS",
    "CONSISTS_OF": "CONTAINS",
    "HAS_COMPONENT": "CONTAINS",
    "INCLUDES": "CONTAINS",
    
    # TARGETS
    "KILLS": "TARGETS",
    "INHIBITS": "TARGETS",
    "ACTS_ON": "TARGETS",
    
    # PRODUCED_BY
    "MANUFACTURED_BY": "PRODUCED_BY",
    "PREPARED_BY": "PRODUCED_BY",
    
    # STORED_AT
    "STORE_IN": "STORED_AT",
    "KEPT_IN": "STORED_AT",
    
    # HAS_STANDARD
    "COMPLIES_WITH": "HAS_STANDARD",
    "MEETS_STANDARD": "HAS_STANDARD",
    
    # TESTED_BY
    "ANALYZED_BY": "TESTED_BY",
    "CHECKED_WITH": "TESTED_BY",
}

def align_schema(graph_documents):
    for doc in graph_documents:
        for rel in doc.relationships:
            # Normalize relationship type to uppercase
            rel_type_upper = rel.type.upper().replace(" ", "_")
            if rel_type_upper in RELATIONSHIP_MAPPINGS:
                rel.type = RELATIONSHIP_MAPPINGS[rel_type_upper]
            else:
                rel.type = rel_type_upper
    return graph_documents

In [11]:
# Execute Pipeline
print("--- STARTING POST-PROCESSING ---")

# 1. Normalization
print("\n[1/4] Normalizing text...")
RELATIONSHIP_MAPPINGS = {
    # TREATS
    "USED_FOR": "TREATS",
    "INDICATED_FOR": "TREATS",
    "PREVENTS": "TREATS",
    "CURES": "TREATS",
    
    # CONTAINS
    "COMPRISES": "CONTAINS",
    "CONSISTS_OF": "CONTAINS",
    "HAS_COMPONENT": "CONTAINS",
    "INCLUDES": "CONTAINS",
    
    # TARGETS
    "KILLS": "TARGETS",
    "INHIBITS": "TARGETS",
    "ACTS_ON": "TARGETS",
    
    # PRODUCED_BY
    "MANUFACTURED_BY": "PRODUCED_BY",
    "PREPARED_BY": "PRODUCED_BY",
    
    # STORED_AT
    "STORE_IN": "STORED_AT",
    "KEPT_IN": "STORED_AT",
    
    # HAS_STANDARD
    "COMPLIES_WITH": "HAS_STANDARD",
    "MEETS_STANDARD": "HAS_STANDARD",
    
    # TESTED_BY
    "ANALYZED_BY": "TESTED_BY",
    "CHECKED_WITH": "TESTED_BY",
}
graph_documents = normalize_graph(graph_documents)

# 2. Schema Alignment
print("\n[2/4] Aligning schema...")
graph_documents = align_schema(graph_documents)

# 3. Entity Resolution
print("\n[3/4] Resolving entities...")
graph_documents = deduplicate_nodes(graph_documents, threshold=0.92) # High threshold for safety

# 4. Validation
print("\n[4/4] Validating graph...")
graph_documents = validate_graph(graph_documents)

print("\n--- POST-PROCESSING COMPLETE ---")
display_graph_stats(graph_documents)

--- STARTING POST-PROCESSING ---

[1/4] Normalizing text...

[2/4] Aligning schema...

[3/4] Resolving entities...
Loading embedding model: vinai/bartpho-syllable...




config.json:   0%|          | 0.00/897 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

dict.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/2.83M [00:00<?, ?B/s]

Total unique nodes before deduplication: 8807
Computing embeddings (this may take a moment)...


Batches:   0%|          | 0/276 [00:00<?, ?it/s]

Clustering similar nodes...
Merging 8025 duplicate nodes...

[4/4] Validating graph...
Validation complete. Dropped 4069 invalid relationships.

--- POST-PROCESSING COMPLETE ---

EXTRACTION SUMMARY
Total Nodes: 15661
Total Relationships: 10646

Node Distribution:
  Standard            : 4194
  Chemical            : 4179
  Test_method         : 2444
  Disease             : 1451
  Production_method   : 1190
  Drug                : 1119
  Organism            :  631
  Storage_condition   :  453

Relationship Distribution:
  REQUIRES            : 2984
  HAS_STANDARD        : 2547
  TESTED_BY           : 1667
  TREATS              :  996
  CONTAINS            :  993
  PRODUCED_BY         :  818
  STORED_AT           :  326
  TARGETS             :  315



# **Save the result**

In [12]:
def graph_to_dict(graph_documents):
    """Convert graph_documents to serializable dict format."""
    result = []
    for doc in graph_documents:
        doc_dict = {
            'nodes': [
                {
                    'id': node.id,
                    'type': node.type,
                    'properties': getattr(node, 'properties', {})
                } for node in doc.nodes
            ],
            'relationships': [
                {
                    'type': rel.type,
                    'source': rel.source.id,
                    'target': rel.target.id,
                    'properties': getattr(rel, 'properties', {})
                } for rel in doc.relationships
            ]
        }
        result.append(doc_dict)
    return result

# Save to file
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(graph_to_dict(graph_documents), f, ensure_ascii=False, indent=2)

print(f"Graph documents saved to {output_path.resolve()}")

Graph documents saved to /kaggle/working/graph_documents.json
