In [None]:
#calculate discovered rules' confidence and head coverage scores
import rdflib
import pandas as pd
import re

# Step 1: extract facts as RDFLib graph
def load_graph_from_nt(nt_file):
    """Build an in-memory rdflib graph from an N-Triples file.

    Args:
        nt_file: Location of the ``.nt`` file, passed straight to ``Graph.parse``.

    Returns:
        rdflib.Graph: a new graph containing every triple parsed from ``nt_file``.
    """
    kg = rdflib.Graph()
    kg.parse(nt_file, format="nt")  # format fixed to N-Triples, no guessing
    return kg

# Step 2: Execute SPARQL for head and body match counts
def execute_sparql(graph, sparql_query):
    """Run ``sparql_query`` on ``graph`` and return its first value as an int.

    Meant for single-value aggregate queries such as
    ``SELECT (COUNT(*) AS ?n) WHERE { ... }``. Only the first column of the
    first result row is read; any further rows are ignored.

    Returns:
        int: ``int(first_row[0])``, or ``0`` when the query yields no rows.
    """
    first_row = next(iter(graph.query(sparql_query)), None)
    if first_row is None:
        return 0
    return int(first_row[0])

# Step 3: Calculate metrics for each rule
def calculate_metrics(rules_csv, graph, output_csv):
    """Score every mined rule against ``graph`` and write the rules plus metrics.

    ``rules_csv`` is read as a header-less, tab-separated file whose column 3
    holds the rule text, e.g.
    ``http://ex/p(X,http://ex/c) <= http://ex/q(X,http://ex/d)``.
    The argument ``X`` becomes the SPARQL variable ``?X``; every other argument
    is wrapped in ``<...>`` as an IRI.

    For each parsable rule:
      * Support       = #solutions of (first body atom AND head atom)
      * Confidence    = Support / #solutions of the first body atom
      * Head Coverage = Support / #facts ``<head subject> <head predicate> ?w``
    Rules that cannot be parsed (no head atom before ``<=``, no body atom, or a
    non-string rule cell) keep Support 0 and scores 0.0.

    NOTE(review): only the first body atom is used and only ``X`` is treated as
    a variable; rules with several body atoms or other variables (Y, A, ...)
    are not scored faithfully. Confirm this matches the rule files used.

    Args:
        rules_csv: Path of the tab-separated rules file (no header row).
        graph: rdflib graph queried through ``execute_sparql``.
        output_csv: Path of the tab-separated output (header row, no index).
    """
    rules_df = pd.read_csv(rules_csv, sep="\t", header=None)

    # Metric columns; rules that cannot be parsed keep these defaults.
    rules_df['Support'] = 0
    rules_df['Confidence'] = 0.0
    rules_df['Head Coverage'] = 0.0

    for index, row in rules_df.iterrows():
        rule = row[3]  # Rule column
        if not isinstance(rule, str):
            # Blank cells are read back as NaN (a float): nothing to parse.
            continue

        atoms = _parse_rule(rule)
        if atoms is None:
            continue
        head_atom, body_atom = atoms

        body_pattern = _triple_pattern(body_atom)
        head_pattern = _triple_pattern(head_atom)
        # Head subject with the object left free: denominator of head coverage.
        head_any_object = f"{_to_sparql_term(head_atom[1])} <{head_atom[0]}> ?w ."

        # Support: solutions satisfying body and head together (joined on ?X).
        support = _count_solutions(graph, body_pattern, head_pattern)
        # Body-only solutions: denominator of standard confidence.
        body_count = _count_solutions(graph, body_pattern)
        head_count = _count_solutions(graph, head_any_object)

        # Guard on the denominators (support > 0 already implies body_count > 0).
        confidence = support / body_count if body_count > 0 else 0
        head_coverage = support / head_count if head_count > 0 else 0

        rules_df.at[index, 'Support'] = support
        rules_df.at[index, 'Confidence'] = confidence
        rules_df.at[index, 'Head Coverage'] = head_coverage

    rules_df.to_csv(output_csv, sep="\t", index=False)


def _to_sparql_term(arg):
    """Map a rule argument to a SPARQL term: ``X`` -> ``?X``, anything else -> ``<arg>``."""
    # Strip so that "p(X, http://...)" does not yield the invalid IRI "< http://...>".
    arg = arg.strip()
    return "?X" if arg == "X" else f"<{arg}>"


def _parse_rule(rule):
    """Return ``(head_atom, first_body_atom)`` for ``head <= body``, or None if unparsable.

    Each atom is a ``(predicate_iri, arg1, arg2)`` tuple of raw strings.
    """
    head_match = re.search(r'(https?://[\w\./#-]+)\(([^,]+),([^)]+)\)\s*<=', rule)
    body_text = rule.split('<=')[1] if '<=' in rule else ''
    body_matches = re.findall(r'(https?://[\w\./#-]+)\(([^,]+),([^)]+)\)', body_text)
    if not head_match or not body_matches:
        return None
    return head_match.groups(), body_matches[0]


def _triple_pattern(atom):
    """Render a parsed ``(predicate, arg1, arg2)`` atom as a SPARQL triple pattern."""
    predicate, arg1, arg2 = atom
    return f"{_to_sparql_term(arg1)} <{predicate}> {_to_sparql_term(arg2)} ."


def _count_solutions(graph, *patterns):
    """Return COUNT(*) over the solutions of the conjunction of triple ``patterns``."""
    where_body = "\n  ".join(patterns)
    query = f"SELECT (COUNT(*) AS ?n) WHERE {{\n  {where_body}\n}}"
    return execute_sparql(graph, query)

if __name__ == "__main__":
    # Inputs: the knowledge graph (N-Triples) and the mined rules (tab-separated);
    # output: the same rules with Support / Confidence / Head Coverage appended.
    nt_file_path, rules_csv_path, output_csv_path = (
        'path/to/MIMIC_non-code.nt',
        "path/to/output_1.0-0.99.csv",
        "path/to/output_rules_with_metrics_1.0-0.99_new.csv",
    )

    graph = load_graph_from_nt(nt_file_path)
    calculate_metrics(rules_csv_path, graph, output_csv_path)  # writes output_csv_path

    print(f"Metrics calculated and saved to {output_csv_path}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def analyze_head_coverage_thresholds(df, hc_col, quantile_range=(80, 100),
                                     save_path="path/to/hc_thresh.svg"):
    """Plot Head Coverage thresholds at a range of quantiles and return them.

    Zero and missing HC values are dropped before the quantiles are computed.

    Args:
        df: DataFrame of rules; a column/dtype summary is printed via ``df.info()``.
        hc_col: Name of the Head Coverage column in ``df``.
        quantile_range: ``(start, stop)`` integer percentages; the quantiles used
            are ``start/100, (start+1)/100, ..., (stop-1)/100``. The default
            ``(80, 100)`` gives 0.80 .. 0.99.
        save_path: File the figure is saved to; ``None`` skips saving.

    Returns:
        pd.DataFrame with columns ``Quantile`` and ``Threshold``.
    """
    df.info()  # prints its summary itself and returns None
    hc_series = df[hc_col]
    hc_nonzero = hc_series[hc_series.notna() & (hc_series != 0.0)]

    start, stop = quantile_range
    quantiles = [q / 100 for q in range(start, stop)]
    thresholds = [hc_nonzero.quantile(q) for q in quantiles]

    df_plot = pd.DataFrame({
        'Quantile': quantiles,
        'Threshold': thresholds
    })

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(df_plot['Quantile'], df_plot['Threshold'], marker='o')
    ax.set_title('Head Coverage Thresholds by Quantile')
    ax.set_xlabel('Quantile')
    ax.set_ylabel('HC Threshold')
    ax.grid(True)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

    return df_plot


In [None]:
# Analyse Head Coverage thresholds of the test set
df = pd.read_csv("path/to/test_file.csv", sep='\t')
results = analyze_head_coverage_thresholds(df=df, hc_col='0.0.1_int')


In [None]:
#compute recall of the valid-set rules at each head-coverage threshold
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
desired_rules_path = "path/to/valid_file.csv"
all_rules_path = "path/to/output_rules_with_metrics_1.0-0.98.csv"

# Desired (validation) rules; their "rule" column is renamed to "3" so it lines
# up with the rule-text column name in the metrics file.
# NOTE(review): "3" presumably comes from the header-less rules file written
# back with integer column names — confirm.
df_desired = pd.read_csv(desired_rules_path, sep="\t").rename(columns={"rule": "3"})
# skiprows=[1] drops the first data row of the metrics file (reason not visible here).
df_all = pd.read_csv(all_rules_path, sep="\t", skiprows=[1])

hc_col = '0.0.1_int'  # Head Coverage column compared against thresholds
match_cols = ['3']  # column(s) used to identify rule match

# Fail fast if either file lacks a required column.
if any(hc_col not in frame.columns for frame in (df_desired, df_all)):
    raise ValueError(f"Missing '{hc_col}' in one of the files.")
for col in match_cols:
    if not all(col in frame.columns for frame in (df_desired, df_all)):
        raise ValueError(f"Column '{col}' missing in one of the files.")

# Candidate HC thresholds: the 0.80..0.99 quantiles of HC over all mined rules.
quantiles = [q / 100 for q in range(80, 100)]
thresholds = [df_all[hc_col].quantile(q) for q in quantiles]

# Recall at a threshold = (#distinct desired rules whose HC exceeds the threshold
# in both files) / (#distinct desired rules). Keys are de-duplicated so a rule
# listed twice in either file cannot be counted twice (which could push recall > 1).
n_relevant = len(df_desired[match_cols].drop_duplicates())
recall_values = []

print("=== Recall per Threshold ===")
for q, thresh in zip(quantiles, thresholds):
    df_desired_filtered = df_desired[df_desired[hc_col] > thresh]
    df_all_filtered = df_all[df_all[hc_col] > thresh]

    if df_desired_filtered.empty:
        recall_values.append(0)
        print(f"Quantile {q:.2f} | Threshold: {thresh} | No desired rules above threshold → Recall: 0.0000")
        continue

    merged = pd.merge(
        df_all_filtered[match_cols].drop_duplicates(),
        df_desired_filtered[match_cols].drop_duplicates(),
        on=match_cols,
        how='inner',
    )
    matching_count = len(merged)

    # n_relevant > 0 here: df_desired_filtered is non-empty, so df_desired is too.
    recall = matching_count / n_relevant
    recall_values.append(recall)

    print(
        f"Quantile {q:.2f} | Threshold: {thresh} | "
        f"Relevant rule count in all rules: {matching_count} | "
        f"All relevant rules' count: {n_relevant} | Recall: {recall:.4f}"
    )

# Plotting: one bar per quantile at integer x positions, labelled with its
# threshold. Using the threshold strings themselves as categorical x values
# would merge bars whenever two quantiles share the same threshold.
positions = list(range(len(thresholds)))
range_labels = [str(t) for t in thresholds]

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(positions, recall_values, color='coral')
ax.set_xticks(positions)
ax.set_xticklabels(range_labels, rotation=45)

# Axis labels and title
ax.set_xlabel("Head Coverage Threshold")
ax.set_ylabel("Recall")
ax.set_title("Recall by Head Coverage Threshold")

# Format Y-axis tick labels with 2 digits after the decimal point
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))

ax.grid(axis='y')
fig.tight_layout()
plt.show()