In [1]:
import json
import os
from tqdm import tqdm
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Paths
metadata_path = "../../ground_truth_dataset/multi_task_data3/test_hyp.json"
dataset_path = "../../ground_truth_dataset/mistral_format_multi_task3/mistral_test.jsonl"
results_output_path = "./evaluation/mistral-hyper-full3.json"

# Load model and tokenizer
base_model_path = "amazon/MistralLite"
fine_tuned_model_path = "./outputs/mistralLite-hypER-mixed-lora-out-full2/merged/"

tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)
fine_tuned_model = AutoModelForCausalLM.from_pretrained(
    fine_tuned_model_path,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)

pipe = pipeline("text-generation", model=fine_tuned_model, tokenizer=tokenizer)

generation_args = {
    "max_new_tokens": 50,
    "return_full_text": False,
    "temperature": 0.0,
    "do_sample": False,
}

# Load and parse JSONL dataset
def load_and_parse_jsonl(file_path):
    parsed_data = []
    with open(file_path, "r") as f:
        for line in f:
            entry = json.loads(line.strip())
            parsed_data.append(parse_jsonl_entry(entry))
    return parsed_data

# Parse individual JSONL entry
def parse_jsonl_entry(entry):
    """
    Parse a single JSONL entry into input text and the actual label.

    Args:
        entry (dict): A dictionary containing the "text" field.

    Returns:
        dict: A dictionary with 'input_text' (model input) and 'actual_label' (ground truth).
    """
    full_text = entry["text"]
    
    # Extract <|assistant|> response
    if "<|assistant|>" in full_text:
        input_text, actual_label = full_text.split("<|assistant|>", 1)
        input_text = input_text.strip() + "<|assistant|>"  # Ensure model prompt ends correctly
        actual_label = actual_label.strip().lower()  # Normalize actual label
        return {
            "input_text": input_text,
            "actual_label": actual_label
        }
    else:
        raise ValueError("Entry does not contain <|assistant|> tag.")

# Load and filter metadata
def load_metadata(file_path):
    with open(file_path, "r") as f:
        return json.load(f)

def filter_metadata(metadata):
    """Filter out entries where source or target abstract is missing."""
    filtered_metadata = []

    for entry in metadata:
        task_name = entry.get("task_name", "unknown")
        input_data = entry.get("input", {})

        if task_name == "hyper-1-hop":
            source_paper = input_data.get("source_paper", {})
            target_paper = input_data.get("target_paper", {})

            # Skip if abstracts are missing
            if not source_paper.get("abstract") or not target_paper.get("abstract"):
                continue  

        filtered_metadata.append(entry)

    return filtered_metadata

# Save predictions
def save_predictions(predictions, file_path):
    with open(file_path, "w") as f:
        json.dump(predictions, f, indent=2)
    print(f"Saved predictions to {file_path}")

def evaluate_model_on_dataset(mistral_dataset, metadata, output_path):
    predictions = []

    # Ensure dataset and metadata are aligned
    assert len(mistral_dataset) == len(metadata), f"Mismatch in dataset and metadata length! ({len(mistral_dataset)} vs {len(metadata)})"

    print("Starting inference...")
    for index, entry in tqdm(enumerate(mistral_dataset), total=len(mistral_dataset)):
        # Get corresponding metadata
        task_metadata = metadata[index]

        input_text = entry["input_text"]
        actual_label = entry["actual_label"]

        # Extract metadata fields
        task_name = task_metadata.get("task_name", "unknown")
        file_name = task_metadata.get("file_name", "unknown")
        chain_label = task_metadata.get("chain_label", "unknown")
        file_path = task_metadata.get("file_path", "unknown")

        try:
            # Directly pass input_text without formatting as a message
            output = pipe(input_text, **generation_args)
            predicted_output = output[0]['generated_text'].strip()

            if not predicted_output:
                predicted_output = "Generation failed"
        except Exception as e:
            print(f"Error with entry: {entry}\nError: {e}")
            predicted_output = "Generation failed"

        # Append metadata and prediction results
        predictions.append({
            "task_name": task_name,
            "file_name": file_name,
            "chain_label": chain_label,
            "file_path": file_path,
            "input": input_text,
            "actual_output": actual_label,
            "predicted_output": predicted_output
        })

    # Save results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)
    
    print(f"Predictions saved to {output_path}")


# Load datasets
mistral_dataset = load_and_parse_jsonl(dataset_path)
metadata = load_metadata(metadata_path)
filtered_metadata = filter_metadata(metadata)

# Run inference and save results
evaluate_model_on_dataset(mistral_dataset, filtered_metadata, results_output_path)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

🚀 Starting inference...


  1%|          | 10/1839 [02:26<7:53:25, 15.53s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
100%|██████████| 1839/1839 [5:28:52<00:00, 10.73s/it]  


✅ Predictions saved to ./evaluation/mistral-hyper-full3.json


In [1]:
import json
import re
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from collections import defaultdict

# Path to saved predictions
results_output_path = "./evaluation/mistral-base.json"

# Load predictions JSON
def load_predictions(file_path):
    """Load saved predictions from JSON."""
    with open(file_path, "r") as f:
        return json.load(f)

# Load the saved results
predictions = load_predictions(results_output_path)

# Initialize storage for task-wise results
task_results = defaultdict(lambda: {"y_true": [], "y_pred": [], "y_true_invalids": [], "y_pred_invalids": []})

# Normalize JSON output (remove formatting issues)
import re

def extract_validity_and_ids(json_str):
    """
    Extracts 'validity' status and any numeric paper IDs from a JSON string.
    Works even if the format is broken.
    """
    if not json_str or json_str.strip() == "":
        return None, []  # Return None if input is empty

    json_str = re.sub(r"```json|```", "", json_str)  # Removes '```json' and '```'
    json_str = re.sub(r"\s+", " ", json_str).strip() # Remove extra spaces and strip
    
    # Extract validity status (valid or invalid)
    validity_match = re.search(r'"validity"\s*:\s*"?(\bvalid\b|\binvalid\b)"?', json_str)

    validity = validity_match.group(1) if validity_match else None

    # Extract numbers from "invalid_paper_ids"
    invalid_ids = re.findall(r'\d+', json_str)  # Find all numbers

    return validity, list(map(int, invalid_ids))  # Convert numbers to integers


In [3]:
import json
import re

def extract_relevance_score(json_str, debug=False):
    """
    Extracts the relevance score from a JSON string or integer input.
    Handles extra text, multiple JSONs, and unexpected formats.

    Args:
        json_str (str or int): Input containing relevance score.
        debug (bool): If True, prints debug information.

    Returns:
        int or None: The extracted relevance score or None if not found.
    """

    # Directly return if already an integer
    if isinstance(json_str, int):
        return json_str  

    try:
        # Ensure input is a string
        if not isinstance(json_str, str):
            raise TypeError(f"Expected str or int, got {type(json_str)}")

        # Remove extra formatting (e.g., triple quotes, markdown artifacts)
        json_str = re.sub(r"```json|```", "", json_str).strip()

        # Extract JSON-like content even if mixed with extra text
        match = re.search(r'\{[^{}]*"relevance_score"[^{}]*\}', json_str)
        if match:
            json_str = match.group(0)  # Extract JSON substring

        # Parse JSON
        parsed_json = json.loads(json_str)

        # Directly return if the parsed JSON is an integer
        if isinstance(parsed_json, int):
            return parsed_json

        # Extract "relevance_score" if present
        if "relevance_score" in parsed_json:
            return int(parsed_json["relevance_score"])

        # Check nested structures
        for key, value in parsed_json.items():
            if isinstance(value, dict) and "relevance_score" in value:
                return int(value["relevance_score"])

        # If nothing matches, return None (skip from evaluation)
        if debug:
            print(f"⚠️ 'relevance_score' not found in JSON: {json_str}")
        return None  

    except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
        if debug:
            print(f"Error parsing relevance score from: {json_str}\nError: {e}")
        return None  # Skip this entry for evaluation



In [4]:
predictions[0]

{'task_name': 'hyper-1-hop',
 'file_name': 'temporal_chain_CD005158_p-1.json',
 'chain_label': 'valid',
 'file_path': '../ground_truth_path/result_chains/temporal_chain_CD005158_p-1.json',
 'input': '<|prompter|>Hypotheses are frequently the starting point when undertaking the empirical portion of the scientific process. They state something that the scientific process will attempt to evaluate, corroborate, verify, or falsify. Their purpose is to guide the types of data we collect, analyses we conduct, and inferences we would like to make. You are a scientist. Your task is to evaluate the relevance of a target paper to a source paper in a one-hop citation context. You are given a source paper and a target paper that followed from it. Your job is to determine the degree of relevance between these papers based on scientific progression, logical dependence, and hypothesis inspiration.\n\nAssign a relevance score to the target paper based on the following criteria:\n- 0: No meaningful conn

In [6]:
# Print some sample results
for i in range(5):  # Print first 5 results
    print(f"\nPrediction {i}")
    print(f"Task Name: {predictions[i]['task_name']}")
    print(f"Actual Output: {predictions[i]['actual_output']}")
    print(f"Predicted Output: {predictions[i]['predicted_output']}")


🔍 Prediction 0
Task Name: hyper-1-hop
Actual Output: 2
Predicted Output: 1

🔍 Prediction 1
Task Name: hyper-multi-hop2_c1
Actual Output: ```json{"validity": valid,
"invalid_paper_ids": []}```
Predicted Output: 196 healthy Chinese male subjects were enrolled. The in vitro inhibition of platelet aggregation (IPA) was evaluated before and after ticagrelor incubated with platelet rich plasma from 196

🔍 Prediction 2
Task Name: hyper-multi-hop2_c2
Actual Output: ```json{"validity": valid,
"invalid_paper_ids": []}```
Predicted Output: The reasoning chain is invalid. The final hypotheses in the last paper do not logically depend on the previous papers.

Invalid paper IDs: 3, 4, 7

🔍 Prediction 3
Task Name: hyper-1-hop
Actual Output: 0
Predicted Output: 1

🔍 Prediction 4
Task Name: hyper-multi-hop2_c1
Actual Output: ```json{"validity": invalid,
"invalid_paper_ids": [2]}```
Predicted Output: 196 healthy Chinese male subjects were enrolled in this study. The in vitro inhibition of platelet aggr

In [None]:
from tqdm import tqdm

for index, entry in enumerate(tqdm(predictions, total=len(predictions))):
    task_name = entry["task_name"]
    actual_output = entry["actual_output"]
    predicted_output = entry["predicted_output"]

    if task_name == "hyper-1-hop":
        y_true = extract_relevance_score(actual_output)
        y_pred = extract_relevance_score(predicted_output)

    elif task_name in {"hyper-multi-hop2_c1", "hyper-multi-hop2_c2"}:
        y_true, y_true_invalids = extract_validity_and_ids(actual_output)
        y_pred, y_pred_invalids = extract_validity_and_ids(predicted_output)
        
        # Store invalid paper IDs
        task_results[task_name]["y_true_invalids"].append(y_true_invalids)
        task_results[task_name]["y_pred_invalids"].append(y_pred_invalids)

    else:
        continue  # Ignore unknown task names

    # Skip if extraction failed
    if y_true is None or y_pred is None:
        print(f"⚠️ Skipping entry {index} due to parsing issues.")
        continue

    # Store results in structured format
    task_results[task_name]["y_true"].append(y_true)
    task_results[task_name]["y_pred"].append(y_pred)

In [9]:
# Compute Accuracy & Metrics for Each Task
for task, results in task_results.items():
    y_true = results["y_true"]
    y_pred = results["y_pred"]

    accuracy = accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, zero_division=0)

    print(f"\nTask: {task}")
    print(f"Accuracy: {accuracy * 100:.2f}%")
    print(report)



🔹 Task: hyper-1-hop
✅ Accuracy: 70.09%
              precision    recall  f1-score   support

           0       0.70      1.00      0.82       574
           1       0.00      0.00      0.00        42
           2       0.00      0.00      0.00       203

    accuracy                           0.70       819
   macro avg       0.23      0.33      0.27       819
weighted avg       0.49      0.70      0.58       819


🔹 Task: hyper-multi-hop2_c1
✅ Accuracy: 68.04%
              precision    recall  f1-score   support

     invalid       0.61      0.97      0.75       253
       valid       0.94      0.39      0.55       257

    accuracy                           0.68       510
   macro avg       0.77      0.68      0.65       510
weighted avg       0.77      0.68      0.65       510


🔹 Task: hyper-multi-hop2_c2
✅ Accuracy: 78.43%
              precision    recall  f1-score   support

     invalid       0.71      0.96      0.82       253
       valid       0.95      0.61      0.74    

In [10]:
def precision_recall_f1_jaccard(y_pred_invalids, y_true_invalids):
    """
    Computes precision, recall, F1-score, and Jaccard similarity for invalid paper IDs.
    """
    precisions, recalls, f1s, jaccards = [], [], [], []
    
    for pred_set, true_set in zip(y_pred_invalids, y_true_invalids):
        pred_set, true_set = set(pred_set), set(true_set)

        intersection = len(pred_set & true_set)
        union = len(pred_set | true_set)

        precision = intersection / len(pred_set) if pred_set else 0
        recall = intersection / len(true_set) if true_set else 0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
        jaccard = intersection / union if union > 0 else 1  # Jaccard similarity

        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        jaccards.append(jaccard)

    # Compute mean metrics
    avg_precision = np.mean(precisions)
    avg_recall = np.mean(recalls)
    avg_f1 = np.mean(f1s)
    avg_jaccard = np.mean(jaccards)

    return avg_precision, avg_recall, avg_f1, avg_jaccard


In [11]:
for task in ["hyper-multi-hop2_c1", "hyper-multi-hop2_c2"]:
    if task in task_results:
        precision, recall, f1, jaccard = precision_recall_f1_jaccard(
            task_results[task]["y_pred_invalids"], task_results[task]["y_true_invalids"]
        )

        print(f"\nTask: {task} (Invalid Paper ID Matching)")
        print(f"Precision: {precision:.2f}")
        print(f"Recall: {recall:.2f}")
        print(f"F1 Score: {f1:.2f}")
        print(f"Jaccard Similarity: {jaccard:.2f}")



🔹 Task: hyper-multi-hop2_c1 (Invalid Paper ID Matching)
✅ Precision: 0.12
✅ Recall: 0.06
✅ F1 Score: 0.07
✅ Jaccard Similarity: 0.26

🔹 Task: hyper-multi-hop2_c2 (Invalid Paper ID Matching)
✅ Precision: 0.12
✅ Recall: 0.06
✅ F1 Score: 0.07
✅ Jaccard Similarity: 0.36


In [19]:
# overall f1-score is the average accuracy from all tasks

# Compute overall f1-score
np.mean([0.58,0.68, 0.68,0.09,0.08 ])


0.6159999999999999