In [8]:
import pandas as pd


allowed_nodes = [
    "DRUG",            # Thuốc
    "DISEASE",         # Bệnh
    "CHEMICAL",        # Hóa chất/Thành phần
    "ORGANISM",        # Vi sinh vật
    "TEST_METHOD",      # Phương pháp thử
    "STANDARD",        # Tiêu chuẩn
    "STORAGE_CONDITION",# Bảo quản
    "PRODUCTION_METHOD"       # Phương pháp sản xuất
]

allowed_relationships = [
    "TREATS",          # Drug -> Disease
    "CONTAINS",        # Drug -> Chemical
    "TARGETS",         # Drug -> Organism
    "TESTED_BY",       # Drug -> TestMethod
    "HAS_STANDARD",    # Drug -> Standard
    "STORED_AT",       # Drug -> Storage
    "PRODUCED_BY",     # Drug -> Production
    "REQUIRES"         # TestMethod -> Chemical
]

result_csv = "../models/test_predictions.csv"
og_csv = "../data/processed/ViKG-NLQ-2-Cypher-data-cleaned.csv"

In [9]:
import re
from typing import Dict, Set

def fix_query(TARGET_NODES, TARGET_RELS):
    fuzzy_map = {}
    for item in TARGET_NODES + TARGET_RELS:
        key = item.replace("_","").upper() # TEST_METHOD -> TESTMETHOD
        fuzzy_map[key] = item # handles labels with no underscores

    def clean_text(text):
        def replacer_func(match):
            original_match = match.group(1)
            lookup_key = original_match.replace("_","").upper()

            if lookup_key in fuzzy_map:
                return ":" + fuzzy_map[lookup_key]
            return match.group(0)

        return re.sub(r':([a-zA-Z_]+)', replacer_func, text)

    return clean_text

def normalize_cypher_query(query: str) -> str:
    """
    Normalize a Cypher query by replacing variable names with standardized placeholders.
    This allows comparing queries that are semantically identical but use different variable names.
    
    Args:
        query: The Cypher query string to normalize
        
    Returns:
        Normalized query with standardized variable names (var1, var2, etc.)
    """
    if not query or not isinstance(query, str):
        return ""
    
    # First, protect string literals by temporarily replacing them
    string_pattern = r'"[^"]*"|\'[^\']*\''
    strings = []
    
    def save_string(match):
        strings.append(match.group(0))
        return f"__STRING_{len(strings)-1}__"
    
    # Temporarily replace strings with placeholders
    query_no_strings = re.sub(string_pattern, save_string, query)
    
    # Extract all variable names from node and relationship patterns
    variables = []
    
    # Match node patterns: (var:Label) or (var)
    node_pattern = r'\((\w+)(?::\w+(?:\|\w+)*)?\)'
    for match in re.finditer(node_pattern, query_no_strings):
        var = match.group(1)
        if var and not var.isupper() and not var.startswith('__STRING_'):
            if var not in variables:
                variables.append(var)
    
    # Match relationship patterns: -[var:REL]- or -[var]-
    rel_pattern = r'-\[(\w+)(?::\w+)?\]-'
    for match in re.finditer(rel_pattern, query_no_strings):
        var = match.group(1)
        if var and not var.isupper():
            if var not in variables:
                variables.append(var)
    
    # Create mapping from original variables to normalized ones
    var_mapping = {var: f"var{i+1}" for i, var in enumerate(variables)}
    
    # Replace variables in the query (excluding strings)
    normalized = query_no_strings
    
    # Sort by length (longest first) to avoid partial replacements
    for original_var in sorted(var_mapping.keys(), key=len, reverse=True):
        normalized_var = var_mapping[original_var]
        # Use word boundaries to ensure we replace complete variable names
        pattern = r'\b' + re.escape(original_var) + r'\b'
        normalized = re.sub(pattern, normalized_var, normalized)
    
    # Restore string literals
    for i, string in enumerate(strings):
        normalized = normalized.replace(f"__STRING_{i}__", string)
    
    return normalized


def soft_exact_match(prediction: str, ground_truth: str) -> bool:
    """
    Compare two Cypher queries with "soft" exact matching.
    Returns True if queries are identical after normalizing variable names.
    
    Args:
        prediction: The predicted Cypher query
        ground_truth: The ground truth Cypher query
        
    Returns:
        True if queries match after normalization, False otherwise
    """
    norm_pred = normalize_cypher_query(prediction)
    norm_truth = normalize_cypher_query(ground_truth)
    
    return norm_pred == norm_truth

def compare_relationship_type(pred, gt):
    pattern = re.compile(r'\[.+:([A-Z_]+)\]')
    pred_entities = re.findall(pattern, pred)
    gt_entities = re.findall(pattern, gt)
    
    return set(pred_entities) == set(gt_entities)
    
def compare_entity_type(pred, gt):
    pattern = re.compile(r'\(.+:([A-Z_]+)\)')
    pred_entities = re.findall(pattern, pred)
    gt_entities = re.findall(pattern, gt)
    
    return set(pred_entities) == set(gt_entities)

def compare_entity_name(pred, gt):
    pattern = re.compile(r'\"(\w+)\"')
    pred_entities = re.findall(pattern, pred)
    gt_entities = re.findall(pattern, gt)
    
    return set(pred_entities) == set(gt_entities)

In [10]:
df = pd.read_csv(result_csv, encoding="utf-8")
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 541 entries, 0 to 540
Data columns (total 3 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   prediction    541 non-null    object
 1   ground_truth  541 non-null    object
 2   match         541 non-null    bool  
dtypes: bool(1), object(2)
memory usage: 9.1+ KB


# Error analysis

In [11]:
df['soft_exact_match'] = df.apply(lambda row: soft_exact_match(row['prediction'], row['ground_truth']), axis=1)

soft_accuracy = df['soft_exact_match'].sum() / len(df)
print(f"Hard Exact Match Accuracy: {df['match'].sum()}/{len(df)} - {(df['match']).sum() / len(df) * 100:.2f}%")
print(f"Soft Exact Match Accuracy: {df['soft_exact_match'].sum()}/{len(df)} - {soft_accuracy * 100:.2f}%")

# Wrong predictions analysis
wrong_predictions = df[df['soft_exact_match'] == False]
samples = wrong_predictions.sample(5)

for idx, row in samples.iterrows():
    print(f"Index: {idx}")
    print(f"Prediction: {row['prediction']}")
    print(f"Ground Truth: {row['ground_truth']}")
    print("-" * 50)

Hard Exact Match Accuracy: 357/541 - 65.99%
Soft Exact Match Accuracy: 395/541 - 73.01%
Index: 444
Prediction: MATCH (d:D)-[:TESTEDBY](t:TESTMETHOD) WHERE toLower(d.id) CONTAINS "vắc xin thành phẩm" RETURN t.id
Ground Truth: MATCH (d:D)-[:TESTEDBY](t:TESTMETHOD) WHERE toLower(d.id) CONTAINS "vắc xin thành phẩm" AND (toLower(t.id) CONTAINS "chất bảo quản" OR toLower(t.id) CONTAINS "ph" OR toLower(t.id) CONTAINS "nhận dạng") RETURN t.id
--------------------------------------------------
Index: 346
Prediction: MATCH (t:TESTMETHOD)-[:HASSTANDARD](s:STANDARD) WHERE toLower(t.id) CONTAINS "phụ lục 9.6" RETURN s.id
Ground Truth: MATCH (tm:TESTMETHOD)-[:REQUIRES](s:STANDARD) WHERE toLower(tm.id) CONTAINS "phụ lục 9.6" AND toLower(s.id) CONTAINS "h" RETURN s.id
--------------------------------------------------
Index: 268
Prediction: MATCH (t:TESTMETHOD)-[:HASSTANDARD](s:STANDARD) WHERE toLower(t.id) CONTAINS "chiết nóng" RETURN s.id
Ground Truth: MATCH (s:STANDARD)-[:REQUIRES](tm:TESTMETHOD) W

## Entity Type Confusion

In [12]:
df['entity_type_match'] = df.apply(lambda row: compare_entity_type(row['prediction'], row['ground_truth']), axis=1)
entity_type_accuracy = df['entity_type_match'].value_counts().to_dict()

print(f"Entity Type Match Accuracy: {entity_type_accuracy[True]}/{len(df)} ({entity_type_accuracy[True] / len(df) * 100:.0f}%)")
print(f"Entity Type Mismatch Count: {entity_type_accuracy.get(False, 0)} ({entity_type_accuracy.get(False, 0) / len(df) * 100:.0f}%)")
print()
# sample
wrong_entity_type = df[df['entity_type_match'] == False].sample(3, random_state=42)
for idx, row in wrong_entity_type.iterrows():
    print(f"Index: {idx}")
    print(f"Prediction: {row['prediction']}")
    print(f"Ground Truth: {row['ground_truth']}")
    print("-" * 50)

Entity Type Match Accuracy: 529/541 (98%)
Entity Type Mismatch Count: 12 (2%)

Index: 455
Prediction: MATCH (c:CHEMICAL)-[:PRODUCEDBY](o:ORGANISM) WHERE toLower(c.id) CONTAINS "kháng nguyên bề mặt (hbsag)" RETURN o.id
Ground Truth: MATCH (o:ORGANISM)<-[:PRODUCEDBY]-(c:CHEMICAL) WHERE toLower(c.id) CONTAINS "hbsag" RETURN o.id
--------------------------------------------------
Index: 432
Prediction: MATCH (t:TESTMETHOD)-[:HASSTANDARD](s:STANDARD) WHERE toLower(t.id) CONTAINS "định lượng" RETURN s.id
Ground Truth: MATCH (s:STANDARD)-[:TESTEDBY](tm:TESTMETHOD) WHERE toLower(tm.id) CONTAINS "định lượng" RETURN s.id
--------------------------------------------------
Index: 47
Prediction: MATCH (d:D)-[:PRODUCEDBY](o:ORGANISM) WHERE toLower(d.id) CONTAINS "cỏ mần" RETURN o.id
Ground Truth: MATCH (d:D)-[:PRODUCEDBY](p:PRODUCTIONMETHOD) WHERE toLower(d.id) CONTAINS "cỏ mần" RETURN p.id
--------------------------------------------------


## Relation Type Confusion

In [13]:
df['rel_type_match'] = df.apply(lambda row: compare_relationship_type(row['prediction'], row['ground_truth']), axis=1)
rel_type_accuracy = df['rel_type_match'].value_counts().to_dict()

print(f"Relationship Type Match Accuracy: {rel_type_accuracy[True]}/{len(df)} ({rel_type_accuracy[True] / len(df) * 100:.0f}%)")
print(f"Relationship Type Mismatch Count: {rel_type_accuracy.get(False, 0)} ({rel_type_accuracy.get(False, 0) / len(df) * 100:.0f}%)")

print()

# mismatch samples
wrong_rel_type = df[df['rel_type_match'] == False].sample(3, random_state=42)
for idx, row in wrong_rel_type.iterrows():
    print(f"Index: {idx}")
    print(f"Prediction: {row['prediction']}")
    print(f"Ground Truth: {row['ground_truth']}")
    print("-" * 50)

Relationship Type Match Accuracy: 526/541 (97%)
Relationship Type Mismatch Count: 15 (3%)

Index: 320
Prediction: MATCH (t:TESTMETHOD)-[:HASSTANDARD](s:STANDARD) WHERE toLower(t.id) CONTAINS "sắc ký lỏng" RETURN s.id
Ground Truth: MATCH (m:TESTMETHOD)<-[:TESTEDBY]-(d:D)-[:HASSTANDARD](s:STANDARD) WHERE toLower(m.id) CONTAINS "sắc ký lỏng" RETURN s.id
--------------------------------------------------
Index: 360
Prediction: MATCH (d:D)-[:HASSTANDARD](s:STANDARD)-[:TESTEDBY](tm:TESTMETHOD) WHERE toLower(d.id) CONTAINS "tỷ lệ vụn nát" RETURN tm.id
Ground Truth: MATCH (s:STANDARD)-[:TESTEDBY](tm:TESTMETHOD) WHERE toLower(s.id) CONTAINS "tỷ lệ vụn nát" RETURN tm.id
--------------------------------------------------
Index: 105
Prediction: MATCH (tm:TESTMETHOD)<-[:TESTEDBY]-(d:D)-[:HASSTANDARD](s:STANDARD) WHERE toLower(tm.id) CONTAINS "phụ lục 9.8" RETURN s.id
Ground Truth: MATCH (tm:TESTMETHOD)<-[:TESTEDBY]-(s:STANDARD) WHERE toLower(tm.id) CONTAINS "phụ lục 9.8" RETURN s.id
---------------

## Entity Name identification error

In [14]:
## Entity Name identification error
df['entity_name_match'] = df.apply(lambda row: compare_entity_name(row['prediction'], row['ground_truth']), axis=1)
entity_name_accuracy = df['entity_name_match'].value_counts().to_dict()

print(f"Entity Name Match Accuracy: {entity_name_accuracy[True]}/{len(df)} ({entity_name_accuracy[True] / len(df) * 100:.0f}%)")
print(f"Entity Name Mismatch Count: {entity_name_accuracy.get(False, 0)} ({entity_name_accuracy.get(False, 0) / len(df) * 100:.0f}%)")

# sample
wrong_entity_name = df[df['entity_name_match'] == False].sample(3, random_state=42)
for idx, row in wrong_entity_name.iterrows():
    print(f"Index: {idx}")
    print(f"Prediction: {row['prediction']}")
    print(f"Ground Truth: {row['ground_truth']}")
    print("-" * 50)

Entity Name Match Accuracy: 511/541 (94%)
Entity Name Mismatch Count: 30 (6%)
Index: 473
Prediction: MATCH (t:TESTMETHOD)-[:REQUIRES](c:CHEMICAL) WHERE toLower(t.id) CONTAINS "sắc ký lớp mỏng (dung môi 2)" RETURN c.id
Ground Truth: MATCH (tm:TESTMETHOD)-[:REQUIRES](d:D) WHERE toLower(tm.id) CONTAINS "sắc ký lớp mỏng" AND toLower(tm.id) CONTAINS "toluen" RETURN d.id
--------------------------------------------------
Index: 346
Prediction: MATCH (t:TESTMETHOD)-[:HASSTANDARD](s:STANDARD) WHERE toLower(t.id) CONTAINS "phụ lục 9.6" RETURN s.id
Ground Truth: MATCH (tm:TESTMETHOD)-[:REQUIRES](s:STANDARD) WHERE toLower(tm.id) CONTAINS "phụ lục 9.6" AND toLower(s.id) CONTAINS "h" RETURN s.id
--------------------------------------------------
Index: 439
Prediction: MATCH (d:D)-[:TESTEDBY](tm:TESTMETHOD) WHERE toLower(d.id) CONTAINS "hemagglutinin" RETURN tm.id
Ground Truth: MATCH (d:D)-[:TESTEDBY](tm:TESTMETHOD) WHERE toLower(tm.id) CONTAINS "khuếch tán miễn dịch đơn" RETURN tm.id
--------------