In [None]:
!pip install -q transformers
!pip install -q datasets
!pip install -q evaluate
!pip install -q tokenizers
!pip install -q torch
!pip install -q evaluate
!pip install -q sacrebleu

In [None]:
## mounting drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import json
from datasets import Dataset, DatasetDict
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments

from transformers import T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq


## Enhanced Chunking Approach

- **Issue:** The dataset contained very long contract documents (some over 8,000 tokens), but our model has a 512-token limit, causing indexing errors during training.
- **Solution:** We implemented chunking with a sliding window to split long documents into manageable 512-token segments while preserving context.
- **Method:** The function first splits the text into words, then tokenizes each chunk, enforcing a strict 512-token limit by using truncation=True.
**First Overlap Strategy:** We overlapped chunks by 128 tokens to prevent loss of context when splitting text, ensuring smoother clause extractions.
**Handling Labels:** Instead of assigning all labels to every chunk, we filtered labels so that each chunk only contains the clauses present in that section, avoiding hallucination issues.
**Final Output:** The model will now train on many short, structured examples rather than long, unmanageable documents, improving extraction accuracy and preventing token errors.

**Improved Label from first baseline** Rather than showing model chunks with no inputs or questions, we add questions in every chunk so the model sees before evaluation

In [None]:
import json
import multiprocessing
from transformers import AutoTokenizer

# Load T5-small tokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Define maximum tokens and overlap
MAX_TOKENS = 512
OVERLAP = 256

# Chunking Function
def split_text_into_chunks(text, max_tokens, tokenizer, overlap=256):
    """
    Splits text into overlapping chunks while maintaining clause integrity.
    Uses a sliding window approach with a specified overlap size.
    """
    words = text.split()
    chunks = []
    start = 0

    while start < len(words):
        # Form a chunk from the current window
        chunk_words = words[start:start + max_tokens]
        chunk_text = " ".join(chunk_words)

        # Tokenize and ensure chunk length is within limits
        tokenized = tokenizer(chunk_text, add_special_tokens=False, truncation=True, max_length=max_tokens)
        tokenized_chunk = tokenized["input_ids"]
        truncated_text = tokenizer.decode(tokenized_chunk, skip_special_tokens=True)

        # Try to end the chunk at a sentence boundary if possible
        if len(truncated_text) < max_tokens and "." in truncated_text:
            end = truncated_text.rfind(".") + 1
            if end > 0:
                truncated_text = truncated_text[:end]

        chunks.append(truncated_text)
        # Move the start pointer forward by max_tokens minus overlap (sliding window)
        start += max_tokens - overlap

    return chunks

# Restructure Dataset for QA
def restructure_dataset_for_qa(contract, doc_id, tokenizer, max_tokens=512, overlap=256):
    """
    Restructures a contract into question-chunk pairs using a sliding window and includes metadata.
    Filters out QA pairs where the answer is not present.
    """
    contract_title = contract["title"]
    paragraphs = contract["paragraphs"]
    processed_data = []

    # Combine all paragraph texts into a single cleaned context
    full_text = " ".join([para["context"] for para in paragraphs])

    for paragraph in paragraphs:
        for qa in paragraph["qas"]:
            if qa["is_impossible"]:
                continue
            clause_type = qa["question"].split("related to \"")[1].split("\"")[0] if "related to \"" in qa["question"] else "General"
            question_text = qa["question"]
            if "Details:" in question_text:
                question_text = question_text.split("Details:")[0].strip()
            expected_answers = [ans["text"] for ans in qa["answers"]]

            # Generate chunks with the adjusted max length
            chunks = split_text_into_chunks(full_text, max_tokens, tokenizer, overlap)

            # Iterate through each chunk to pair it with the question and metadata
            for chunk in chunks:
                answer_presence = any(answer in chunk for answer in expected_answers)
                flag = "answer_present" if answer_presence else "answer_not_present"

                processed_data.append({
                    "doc_id": doc_id,
                    "contract_title": contract_title,
                    "clause_type": clause_type,
                    "question": question_text,
                    "input": chunk,
                    "expected_output": expected_answers if answer_presence else ["No answer"],
                    "flag": flag
                })

    return processed_data

# Processing All Contracts
def process_all_contracts(dataset, tokenizer):
    """
    Apply chunking and restructuring to all contracts in the dataset.
    """
    data_with_ids = [(contract, idx) for idx, contract in enumerate(dataset)]

    with multiprocessing.Pool(processes=8) as pool:
        # Use restructure_dataset_for_qa directly instead of a lambda
        chunked_results = pool.starmap(
            restructure_dataset_for_qa,  # Pass the function directly
            [(contract, doc_id, tokenizer) for contract, doc_id in data_with_ids]
        )

    # Flatten the list of results
    restructured_dataset = [sample for sublist in chunked_results for sample in sublist]
    return restructured_dataset

# Load the CUAD dataset
with open('/content/drive/My Drive/Colab Notebooks/NLP_266_Project/Data/CUAD_v1.json', 'r', encoding='utf-8') as file:
    raw_data = json.load(file)["data"]

# Apply chunking and restructuring
chunked_dataset = process_all_contracts(raw_data, tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [None]:
#save preprocessed_data
import json
with open("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/CUAD_chunked_dataset_t5.json", "w") as f:
    json.dump(chunked_dataset, f, indent=4)

In [None]:
len(chunked_dataset)

270330

In [None]:
#balance of chunked_dataset
import pandas as pd
df = pd.DataFrame(chunked_dataset)
df['flag'].value_counts()

Unnamed: 0_level_0,count
flag,Unnamed: 1_level_1
answer_not_present,248260
answer_present,22070


In [None]:
chunked_dataset

[{'doc_id': 0,
  'contract_title': 'LIMEENERGYCO_09_09_1999-EX-10-DISTRIBUTOR AGREEMENT',
  'clause_type': 'Document Name',
  'question': 'Highlight the parts (if any) of this contract related to "Document Name" that should be reviewed by a lawyer.',
  'input': 'EXHIBIT 10.6 DISTRIBUTOR AGREEMENT THIS DISTRIBUTOR AGREEMENT (the "Agreement") is made by and between Electric City Corp., a Delaware corporation ("Company") and Electric City of Illinois LLC ("Distributor") this 7th day of September, 1999. RECITALS A. The Company\'s Business. The Company is presently engaged in the business of selling an energy efficiency device, which is referred to as an "Energy Saver" which may be improved or otherwise changed from its present composition (the "Products"). The Company may engage in the business of selling other products or other devices other than the Products, which will be considered Products if Distributor exercises its options pursuant to Section 7 hereof. B. Representations. As an ind

### balancing dataset

In [None]:
import random

def undersample_no_answer_chunks(dataset, no_answer_label="no answer", target_ratio=0.40, seed=42):
    """
    Returns a filtered version of the dataset where 'no answer' chunks are undersampled
    to the desired ratio relative to the total dataset size.

    Args:
        dataset: list of dicts (chunked_dataset)
        no_answer_label (str): the consistent label used for 'no answer' chunks
        target_ratio (float): the proportion of 'no answer' chunks in the final dataset
        seed (int): for reproducible random sampling

    Returns:
        list of filtered dicts
    """
    random.seed(seed)

    # Separate chunks
    no_answer_chunks = [chunk for chunk in dataset if chunk['flag'] == "answer_not_present"]
    answer_chunks = [chunk for chunk in dataset if chunk['flag'] == "answer_present"]

    # Calculate how many 'no answer' chunks to retain
    target_no_answer_count = int((len(answer_chunks) / (1 - target_ratio)) * target_ratio)

    # Randomly sample 'no answer' chunks
    sampled_no_answer_chunks = random.sample(no_answer_chunks, min(target_no_answer_count, len(no_answer_chunks)))

    # Combine and shuffle
    final_dataset = answer_chunks + sampled_no_answer_chunks
    random.shuffle(final_dataset)

    return final_dataset
chunked_dataset_t5 = undersample_no_answer_chunks(chunked_dataset)
print(f"Filtered dataset size: {len(chunked_dataset_t5)}")


Filtered dataset size: 36783


In [None]:
#save filtered dataset
import json
with open("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/Data/CUAD_chunked_dataset_t5_balanced.json", "w") as f:
    json.dump(chunked_dataset_t5, f, indent=4)

save dataset

In [None]:
# Save the dataset as JSON
with open("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/chunked_datasets/chunked_dataset.json", "w") as f:
    json.dump(chunked_dataset, f, indent=4)

In [None]:
import datasets
from datasets import Dataset, DatasetDict

# Convert to Hugging Face Dataset format
hf_dataset = Dataset.from_list(chunked_dataset_t5)

# Split into train (80%), validation (10%), and test (10%)
train_test_split = hf_dataset.train_test_split(test_size=0.2, seed=42)
valid_test_split = train_test_split["test"].train_test_split(test_size=0.5, seed=42)

# Correctly label the dataset splits
dataset = DatasetDict({
    "train": train_test_split["train"],
    "validation": valid_test_split["train"],  # Corrected label
    "test": valid_test_split["test"]  # Corrected label
})


In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input', 'expected_output', 'flag'],
        num_rows: 29426
    })
    validation: Dataset({
        features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input', 'expected_output', 'flag'],
        num_rows: 3678
    })
    test: Dataset({
        features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input', 'expected_output', 'flag'],
        num_rows: 3679
    })
})

In [None]:
#save dataset to disk
dataset.save_to_disk("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/Data/hf_chunked_datasets_t5")

Saving the dataset (0/1 shards):   0%|          | 0/29426 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3678 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3679 [00:00<?, ? examples/s]

In [None]:
ls /content/drive/My\ Drive/Colab\ Notebooks/NLP_266_Project/Data/

bart_chunked_dataset_filtered.json     CUAD_no_answer_filtered_chunked.json
bart_chunked_dataset.json              CUAD_preprocessed_data.json
bart_chunked_dataset_sorted.json       CUAD_v1.json
[0m[01;34mbart_dataset_sorted[0m/                   [01;34mfiltered_bert_chunked_dataset[0m/
bart_fusion_dataset.json               filtered_cuad.json
chunked_dataset.json                   filtered_cuad_orient.json
CUAD_chunked_dataset_t5_balanced.json  instruction_finetuned_data.json
CUAD_no_answer_chunked.json            [01;34mtokenized_bart_dataset_sorted[0m/


In [None]:
#load CUAD_chunked_dataset_t5_balanced.json
import json
with open("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/Data/CUAD_chunked_dataset_t5_balanced.json", "r") as f:
    chunked_dataset_t5 = json.load(f)

In [None]:
len(chunked_dataset_t5)

36783

In [None]:
MAX_SEQUENCE_LENGTH = 512

def tokenize_function(examples):
    """
    Tokenizes the input and output texts for Flan-T5.
    """
    # Extract input and output directly from the batch dictionary
    inputs = examples['input']
    outputs = [ex[0] if ex else "" for ex in examples['expected_output']]

    # Tokenize the inputs (questions + context + flag) in batch mode
    encoded_inputs = tokenizer(
        inputs,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True
    )

    # Tokenize the outputs (expected answers or placeholders) in batch mode
    encoded_outputs = tokenizer(
        outputs,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True
    )

    # Replace padding token ID (0) with -100 for labels
    labels = encoded_outputs['input_ids']
    labels = [[-100 if token == tokenizer.pad_token_id else token for token in label] for label in labels]

    return {
        'input_ids': encoded_inputs['input_ids'],
        'attention_mask': encoded_inputs['attention_mask'],
        'labels': labels,
        'doc_id': examples['doc_id'],
        'contract_title': examples['contract_title'],
        'clause_type': examples['clause_type'],
        'question': examples['question']
    }

# Apply tokenization separately to each split
tokenized_datasets = {
    split: dataset[split].map(
        tokenize_function,
        batched=True,
        remove_columns=['doc_id', 'contract_title', 'clause_type', 'question', 'input', 'expected_output', 'flag']
    )
    for split in ['train', 'validation','test']
}

print(tokenized_datasets)


Map:   0%|          | 0/29426 [00:00<?, ? examples/s]

Map:   0%|          | 0/3678 [00:00<?, ? examples/s]

Map:   0%|          | 0/3679 [00:00<?, ? examples/s]

{'train': Dataset({
    features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 29426
}), 'validation': Dataset({
    features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 3678
}), 'test': Dataset({
    features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 3679
})}


In [None]:
tokenized_datasets

{'train': Dataset({
     features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
     num_rows: 29426
 }),
 'validation': Dataset({
     features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
     num_rows: 3678
 }),
 'test': Dataset({
     features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
     num_rows: 3679
 })}

## Baseline Model 2

saving tokenized dataset and hugging face dataset

In [None]:
# Save the hf dataset to a directory
dataset.save_to_disk("/content/drive/My Drive/Colab Notebooks/Data/NLP_266_Project/chunked_datasets_t5")

Saving the dataset (0/1 shards):   0%|          | 0/29426 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3678 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3679 [00:00<?, ? examples/s]

In [None]:
# Save the tokenized dataset to the specified folder
from datasets import DatasetDict
# Wrap the dictionary in a DatasetDict object
tokenized_datasets = DatasetDict(tokenized_datasets)
tokenized_datasets.save_to_disk("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/tokenized_dataset_t5")
#save tokenize datase

Saving the dataset (0/1 shards):   0%|          | 0/29426 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3678 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3679 [00:00<?, ? examples/s]

### Inspecting Tokenization
verify how the tokenized inputs look for each model before training

In [None]:
for i in range(3):  # Checking 3 examples
    print(f"Example {i+1}:")
    print("Input IDs:", tokenized_datasets["train"][i]["input_ids"])
    print("Labels:", tokenized_datasets["train"][i]["labels"])
    print("="*80)

Example 1:
Input IDs: [8472, 4171, 254, 9562, 11951, 71, 3, 6934, 29027, 272, 476, 262, 196, 21898, 3, 19846, 476, 3347, 3, 8241, 196, 22704, 668, 4674, 180, 3073, 9562, 5477, 14912, 784, 1649, 12640, 1628, 11, 1602, 3569, 725, 908, 41, 208, 61, 3430, 1853, 3, 13885, 6037, 6197, 4936, 254, 8015, 411, 8775, 8834, 8015, 134, 3347, 9132, 188, 11300, 10466, 20006, 4417, 11300, 180, 3073, 9562, 209, 14912, 599, 23, 61, 599, 75, 61, 784, 26267, 51, 29, 2420, 57, 10310, 7, 24227, 908, 3430, 1853, 3, 13885, 6037, 6197, 4936, 254, 8015, 411, 8775, 8834, 8015, 134, 3347, 584, 24203, 23936, 9978, 4417, 11300, 180, 3073, 9562, 209, 16593, 599, 23, 61, 599, 75, 61, 784, 26267, 51, 29, 2420, 57, 28057, 9688, 13679, 3, 4171, 196, 21898, 3, 19846, 476, 16309, 6554, 8729, 17098, 5652, 10568, 3, 20452, 15397, 6, 3388, 25392, 6, 8472, 4132, 16892, 6431, 15397, 6, 262, 4, 6037, 5329, 24721, 4674, 16003, 29433, 4090, 30061, 134, 71, 13431, 2365, 3, 9744, 3347, 4674, 4083, 4569, 11430, 3001, 12689, 8859, 43

In [None]:
from datasets import DatasetDict, load_from_disk
#load the tokenized dataset
tokenized_datasets = load_from_disk("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/tokenized_datasets/tokenized_dataset_t5")

In [None]:
#list files
!ls /content/drive/My\ Drive/Colab\ Notebooks/NLP_266_Project/tokenized_datasets/

CUAD_BERT_filtered		    CUAD_longt5_tokenized
CUAD_chunked_no_answer_tokenized    tokenized_bart_dataset_sorted
CUAD_chunked_no_answer_tokenized40  tokenized_dataset_t5
CUAD_filtered_tokenized


In [None]:
print(tokenized_datasets)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 29426
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3678
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3679
    })
})


## Training Basleline 2 with improved chunking

In [None]:
import torch
import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
from transformers import ( T5ForConditionalGeneration, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq)

from datasets import load_from_disk

# Load tokenizer and model
model_name = "t5-small"  # Change if needed
t5baseline_model2 = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Training arguments
training_args = TrainingArguments(
    output_dir="./t5_baseline_model2a",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=5,
    fp16=False,
    logging_dir="./logs",
    logging_steps=500,
    report_to="none",
    label_smoothing_factor=0.1
)

# Data collator for dynamic padding
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator
)

# Train model
trainer.train()

  trainer = Trainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,2.7689,2.581137
2,2.6844,2.457074
3,2.5613,2.392651
4,2.47,2.359418
5,2.506,2.349454


TrainOutput(global_step=18395, training_loss=2.6603908247453876, metrics={'train_runtime': 3326.4878, 'train_samples_per_second': 44.23, 'train_steps_per_second': 5.53, 'total_flos': 1.991283925057536e+16, 'train_loss': 2.6603908247453876, 'epoch': 5.0})

In [None]:
# Modify this path to the location in your Drive where you want to save the part1 model
baseline_model_checkpoint_filepath = "/content/drive/My Drive/Colab Notebooks/NLP_266_Project/model_checkpoints/t5_baseline_model2a"

# Run this line only after you've trained the part1 model
t5baseline_model2.save_pretrained(baseline_model_checkpoint_filepath)

#save tokenizer
tokenizer.save_pretrained(baseline_model_checkpoint_filepath)

print(f"Model checkpoint saved at: {baseline_model_checkpoint_filepath}")

Model checkpoint saved at: /content/drive/My Drive/Colab Notebooks/NLP_266_Project/model_checkpoints/t5_baseline_model2a



## Evaluating Validation
---
 1. Load Model & Tokenizer for Inference



In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_from_disk

# Load trained model and tokenizer
model_path = "/content/drive/My Drive/Colab Notebooks/NLP_266_Project/model_checkpoints/t5_baseline_model2a"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load tokenized dataset
#tokenized_dataset = load_from_disk("/content/drive/My Drive/Colab Notebooks/NLP_266_Project/tokenized_dataset")

### Chunk Level Evaluation

In [None]:
import torch
from tqdm import tqdm

MAX_SEQUENCE_LENGTH=512
# Batch size for evaluation
batch_size = 16
preds, refs = [], []

model.eval()
with torch.no_grad():
    for i in tqdm(range(0, len(tokenized_datasets['validation']), batch_size), desc="Evaluating"):
        batch = tokenized_datasets['validation'].select(range(i, min(i + batch_size, len(tokenized_datasets['validation']))))

        # Prepare inputs
        inputs = {
            "input_ids": torch.tensor(batch["input_ids"]).to("cuda"),
            "attention_mask": torch.tensor(batch["attention_mask"]).to("cuda")
        }

        # Generate predictions
        outputs = model.generate(**inputs, max_length=MAX_SEQUENCE_LENGTH)

        # Decode predictions
        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Process labels: remove -100 and decode
        flat_labels = [[token for token in label if token != -100] for label in batch['labels']]
        decoded_labels = tokenizer.batch_decode(flat_labels, skip_special_tokens=True)

        # Append to lists
        preds.extend(decoded_preds)
        refs.extend(decoded_labels)

        # Free up memory
        del inputs, outputs
        torch.cuda.empty_cache()

print("Predictions and references generated!")


Evaluating: 100%|██████████| 230/230 [02:05<00:00,  1.84it/s]

Predictions and references generated!





In [None]:
# Print a sample of predictions and references
num_samples_to_print = 200  # Number of examples to print
print("\nSample Predictions and References:")

for i in range(min(num_samples_to_print, len(preds))):
    print(f"Example {i + 1}:")
    print(f"Prediction: {preds[i]}")
    print(f"Reference: {refs[i]}")
    print("-" * 5)


Sample Predictions and References:
Example 1:
Prediction: No answer
Reference: No answer
-----
Example 2:
Prediction: SFJ Pharmaceuticals X, Ltd.
Reference: SFJ Pharmaceuticals X, Ltd.
-----
Example 3:
Prediction: May, 2000
Reference: GARMAN
-----
Example 4:
Prediction: September ___, 2019
Reference: No answer
-----
Example 5:
Prediction: No answer
Reference: No answer
-----
Example 6:
Prediction: Vendor
Reference: The scope of AT&T Audits shall also include: (i) practices and procedures used in performing the Services; (ii) systems, communications and information technology used in performing the Services; (iii) general controls and security practices and procedures; (iv) supporting information and calculations regarding invoices and compliance with service requirements; (v) quality initiatives and quality assurance; and (vi) compliance with the terms of this Agreement.
-----
Example 7:
Prediction: No answer
Reference: No answer
-----
Example 8:
Prediction: Subject to the terms and c

#### chunk rouge metrics

In [None]:
!pip install -q evaluate
!pip install -q rouge_score
!pip install -q sacrebleu
import evaluate

# Load the ROUGE metric
rouge = evaluate.load("rouge")
# Load the BLEU metric
bleu = evaluate.load("bleu")

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone


Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

In [None]:
# Calculate ROUGE scores
rouge_output = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
print("Chunk level metrics T5 Baseline")
# Access ROUGE scores directly (without 'mid')
print("ROUGE-1:", rouge_output['rouge1'])  # Access the ROUGE-1 score
print("ROUGE-2:", rouge_output['rouge2'])  # Access the ROUGE-2 score
print("ROUGE-L:", rouge_output['rougeL'])  # Access the ROUGE-L score

#------------------BLEU SCORE
# Ensure predictions and references are correctly formatted as strings
formatted_preds = [pred if isinstance(pred, str) else " ".join(pred) for pred in preds]
formatted_refs = [[ref] if isinstance(ref, str) else [" ".join(r) for r in ref] for ref in refs]

# Calculate BLEU score
try:
    bleu_output = bleu.compute(predictions=formatted_preds, references=formatted_refs)
    print("BLEU:", bleu_output['bleu'])
except ValueError as e:
    print(f"Error calculating BLEU: {e}")

import numpy as np

# Calculate Exact Match (EM)
em_scores = [int(pred.strip() == ref.strip()) for pred, ref in zip(preds, refs)]
em_score = np.mean(em_scores)

# Display EM score
print("Exact Match (EM):", em_score)

# -------------------------
# Calculate F1 Score (simplified for binary match)
f1_scores = [1 if em else 0 for em in em_scores]  # Same as EM for binary
f1_score = np.mean(f1_scores)

# Display F1 score
print("F1 Score:", f1_score)

#------------------
# Jaccard similarity function
def jaccard_similarity(pred, ref):
    pred_set = set(pred.split())
    ref_set = set(ref.split())
    intersection = len(pred_set.intersection(ref_set))
    union = len(pred_set.union(ref_set))
    return intersection / union if union != 0 else 0

# Calculate Jaccard Similarity
jaccard_scores = [jaccard_similarity(pred, ref) for pred, ref in zip(preds, refs)]
jaccard_similarity_score = np.mean(jaccard_scores)

# Display Jaccard Similarity

print("Jaccard Similarity:", jaccard_similarity_score)

Chunk level metrics T5 Baseline
ROUGE-1: 0.4037973592297212
ROUGE-2: 0.38356232979814847
ROUGE-L: 0.40203928738953276
BLEU: 0.05759883472195489
Exact Match (EM): 0.39097335508428493
F1 Score: 0.39097335508428493
Jaccard Similarity: 0.3971510509947079


In [None]:
# Print a few formatted predictions and references to understand the structure
print("\nSample Formatted Predictions, References, and Expected Outputs for BLEU Calculation:")
for idx in range(min(50, len(formatted_preds))):
    print(f"Prediction {idx + 1}: {formatted_preds[idx]}")
    print(f"Reference {idx + 1}: {formatted_refs[idx]}")
    print(f"Expected Output {idx + 1}: {refs[idx]}")  # Printing the original expected output
    print("-" * 50)



Sample Formatted Predictions, References, and Expected Outputs for BLEU Calculation:
Prediction 1: No answer
Reference 1: ['No answer']
Expected Output 1: No answer
--------------------------------------------------
Prediction 2: SFJ Pharmaceuticals X, Ltd.
Reference 2: ['SFJ Pharmaceuticals X, Ltd.']
Expected Output 2: SFJ Pharmaceuticals X, Ltd.
--------------------------------------------------
Prediction 3: May, 2000
Reference 3: ['GARMAN']
Expected Output 3: GARMAN
--------------------------------------------------
Prediction 4: September ___, 2019
Reference 4: ['No answer']
Expected Output 4: No answer
--------------------------------------------------
Prediction 5: No answer
Reference 5: ['No answer']
Expected Output 5: No answer
--------------------------------------------------
Prediction 6: Vendor
Reference 6: ['The scope of AT&T Audits shall also include: (i) practices and procedures used in performing the Services; (ii) systems, communications and information technology us

#### chunk mismatch predictions

In [None]:
for i, (pred, ref) in enumerate(zip(preds, refs)):
    if pred.strip() != ref.strip():
        print(f"Example {i}")
        print(f"Prediction: {pred}")
        print(f"Reference:  {ref}")
        print("-" * 60)

Example 2
Prediction: May, 2000
Reference:  GARMAN
------------------------------------------------------------
Example 3
Prediction: September ___, 2019
Reference:  No answer
------------------------------------------------------------
Example 5
Prediction: Vendor
Reference:  The scope of AT&T Audits shall also include: (i) practices and procedures used in performing the Services; (ii) systems, communications and information technology used in performing the Services; (iii) general controls and security practices and procedures; (iv) supporting information and calculations regarding invoices and compliance with service requirements; (v) quality initiatives and quality assurance; and (vi) compliance with the terms of this Agreement.
------------------------------------------------------------
Example 7
Prediction: Subject to the terms and conditions of this Agreement, Arizona hereby grants to the Company a limited, non-exclusive, royalty-free license in, to and under the Arizona Licens

#### chunk level predictions

In [None]:
import evaluate
import numpy as np
from collections import Counter

def compute_chunk_level_metrics(preds, refs):
    # Init
    rouge = evaluate.load("rouge")
    bleu = evaluate.load("bleu")

    # ROUGE
    rouge_output = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    rouge1 = rouge_output["rouge1"]
    rouge2 = rouge_output["rouge2"]
    rougeL = rouge_output["rougeL"]

    # BLEU
    formatted_preds = [p if isinstance(p, str) else " ".join(p) for p in preds]
    formatted_refs = [[r] if isinstance(r, str) else [" ".join(r)] for r in refs]
    bleu_score = bleu.compute(predictions=formatted_preds, references=formatted_refs)["bleu"]

    # Exact Match
    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(preds, refs)]
    em = np.mean(exact_matches)
    total_chunks = len(preds)
    num_correct = sum(exact_matches)
    num_incorrect = total_chunks - num_correct
    correct_pct = 100 * em
    incorrect_pct = 100 - correct_pct

    # F1
    def compute_f1(p, r):
        p_tokens, r_tokens = p.lower().split(), r.lower().split()
        common = Counter(p_tokens) & Counter(r_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = num_same / len(p_tokens)
        recall = num_same / len(r_tokens)
        return 2 * precision * recall / (precision + recall)

    f1 = np.mean([compute_f1(p, r) for p, r in zip(preds, refs)])

    # Jaccard
    def jaccard(p, r):
        p_set, r_set = set(p.lower().split()), set(r.lower().split())
        return len(p_set & r_set) / len(p_set | r_set) if p_set | r_set else 1.0

    jaccard_score = np.mean([jaccard(p, r) for p, r in zip(preds, refs)])

    return {
        "Chunk-Level ROUGE-1": rouge1,
        "Chunk-Level ROUGE-2": rouge2,
        "Chunk-Level ROUGE-L": rougeL,
        "Chunk-Level BLEU": bleu_score,
        "Chunk-Level Exact Match": em,
        "Chunk-Level F1": f1,
        "Chunk-Level Jaccard": jaccard_score,
        "Total Chunks": total_chunks,
        "Correct Predictions": num_correct,
        "Incorrect Predictions": num_incorrect,
        "Correct %": correct_pct,
        "Incorrect %": incorrect_pct
    }

# Example usage
metrics = compute_chunk_level_metrics(preds, refs)
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")


Chunk-Level ROUGE-1: 0.4038
Chunk-Level ROUGE-2: 0.3836
Chunk-Level ROUGE-L: 0.4020
Chunk-Level BLEU: 0.0576
Chunk-Level Exact Match: 0.3915
Chunk-Level F1: 0.4013
Chunk-Level Jaccard: 0.3984
Total Chunks: 3678
Correct Predictions: 1440
Incorrect Predictions: 2238
Correct %: 39.1517
Incorrect %: 60.8483


### Document level Aggregation

In [None]:
tokenized_datasets["validation"]

Dataset({
    features: ['doc_id', 'contract_title', 'clause_type', 'question', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 3678
})

In [None]:
from collections import defaultdict

grouped_preds = defaultdict(list)
grouped_refs = defaultdict(list)
grouped_questions = {}

# Fix: Iterate directly over the dataset using its index
for i, example in enumerate(tokenized_datasets["validation"]):
    doc_id = example["doc_id"]
    grouped_preds[doc_id].append(preds[i].strip())
    grouped_refs[doc_id].append(refs[i].strip())
    grouped_questions[doc_id] = example["question"]  # assumes question is the same per doc

#### Option 1 Group all preductions and references by document
The approach groups predictions and references by document (doc_id) and evaluates them collectively at a document level. It calculates metrics like Jaccard similarity (token overlap) and Exact Match (set equality) for each document, aggregating these scores across all documents to provide average metrics and a detailed breakdown. Optionally, it highlights mismatches for further analysis.

The function doesn't explicitly "select" the best answer from the predictions. Instead, it aggregates all predictions for a document (doc_id) and evaluates them collectively against the references using metrics like Jaccard similarity and Exact Match. This means the evaluation focuses on the collective accuracy of all predictions for a document rather than identifying a single "best" answer. If you want a mechanism to pick the best answer per question, you might need to add a ranking or scoring system, such as confiden

In [None]:
from collections import defaultdict, Counter
import evaluate
import numpy as np

def evaluate_baseline_with_metrics(preds, refs, original_dataset, show_mismatches=True, max_show=5):
    grouped_preds = defaultdict(list)
    grouped_refs = defaultdict(list)

    for i, ex in enumerate(original_dataset):
        doc_id = ex["doc_id"]
        pred = preds[i].strip()
        ref = refs[i].strip()

        if pred.lower() != "no answer":
            grouped_preds[doc_id].append(pred)
        if ref.lower() != "no answer":
            grouped_refs[doc_id].append(ref)

    # Collapse multiple predictions/references per doc into single strings
    final_preds = {doc_id: " ".join(set(preds)) if preds else "No answer"
                   for doc_id, preds in grouped_preds.items()}
    final_refs = {doc_id: " ".join(set(refs)) if refs else "No answer"
                  for doc_id, refs in grouped_refs.items()}

    doc_ids = list(set(final_preds.keys()).union(set(final_refs.keys())))
    doc_preds = [final_preds.get(doc_id, "No answer") for doc_id in doc_ids]
    doc_refs = [final_refs.get(doc_id, "No answer") for doc_id in doc_ids]

    # Compute metrics
    rouge = evaluate.load("rouge")
    rouge_output = rouge.compute(predictions=doc_preds, references=doc_refs, use_stemmer=True)

    bleu = evaluate.load("bleu")
    formatted_preds = [p if isinstance(p, str) else " ".join(p) for p in doc_preds]
    formatted_refs = [[r] if isinstance(r, str) else [" ".join(r)] for r in doc_refs]
    bleu_score = bleu.compute(predictions=formatted_preds, references=formatted_refs)["bleu"]

    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(doc_preds, doc_refs)]

    # F1 and Jaccard
    def compute_f1(pred, ref):
        pred_tokens = pred.lower().split()
        ref_tokens = ref.lower().split()
        common = Counter(pred_tokens) & Counter(ref_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = num_same / len(pred_tokens)
        recall = num_same / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)

    def jaccard_similarity(pred, ref):
        pred_set = set(pred.lower().split())
        ref_set = set(ref.lower().split())
        intersection = len(pred_set & ref_set)
        union = len(pred_set | ref_set)
        return intersection / union if union else 1.0

    f1_scores = [compute_f1(p, r) for p, r in zip(doc_preds, doc_refs)]
    jaccard_scores = [jaccard_similarity(p, r) for p, r in zip(doc_preds, doc_refs)]

    # Show sample mismatches
    shown = 0
    if show_mismatches:
        for i, (p, r) in enumerate(zip(doc_preds, doc_refs)):
            if p != r and shown < max_show:
                print(f"\n❌ MISMATCH in doc {doc_ids[i]}")
                print(f"> Prediction: {p}")
                print(f"> Reference:  {r}")
                shown += 1

    # Store doc_breakdown
    doc_breakdown = {}
    for i, doc_id in enumerate(doc_ids):
        doc_breakdown[doc_id] = {
            "preds": doc_preds[i],
            "refs": doc_refs[i],
            "exact_match": exact_matches[i]
        }

    num_docs = len(doc_ids)
    num_correct = sum(exact_matches)
    num_incorrect = num_docs - num_correct
    correct_pct = (num_correct / num_docs) * 100 if num_docs else 0
    incorrect_pct = 100 - correct_pct

    return {
        "ROUGE-1": rouge_output["rouge1"],
        "ROUGE-2": rouge_output["rouge2"],
        "ROUGE-L": rouge_output["rougeL"],
        "BLEU": bleu_score,
        "Exact Match": np.mean(exact_matches),
        "F1": np.mean(f1_scores),
        "Jaccard": np.mean(jaccard_scores),
        "Num Docs": num_docs,
        "Correct Predictions": num_correct,
        "Incorrect Predictions": num_incorrect,
        "Correct %": correct_pct,
        "Incorrect %": incorrect_pct,
        "doc_breakdown": doc_breakdown
    }

# ✅ Option 1 - Baseline Aggregation
agg_answers_results = evaluate_baseline_with_metrics(preds, refs, dataset["validation"])

print("\n📘 Document-Level Evaluation (Baseline - All Answers Kept)")
for k, v in agg_answers_results.items():
    if k == "doc_breakdown":
        continue
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")



❌ MISMATCH in doc 0
> Prediction: Distributor The term of this Agreement shall be ten (10) years (the "Term") which shall commence on the date upon which the Company delivers to Distributor the last Sample, as defined hereinafter.
> Reference:  Distributor

❌ MISMATCH in doc 2
> Prediction: No answer
> Reference:  To be covered by the Seller for 110% invoice value against All Risks and War Risk. SUPPLY CONTRACT

❌ MISMATCH in doc 3
> Prediction: This Agreement shall automatically be renewed for one (1) or more one (1) month periods unless either the Customer or i-on gives notice to the other party of its intention not to renew the 4 Agreement, which notice must be given not less than fifteen (15) days before the end of the respective initial or renewal term.
> Reference:  The term of this Agreement for the Hosted Site shall commence upon April 1, 1999 and shall continue for a period of six (6) months, unless earlier terminated in accordance with provisions hereof. This Agreement shall

In [None]:
print("\n📘 Matching Predictions (Baseline - All Answers Kept)")
shown = 0
max_show = 30

for doc_id, info in agg_answers_results["doc_breakdown"].items():
    if info["exact_match"]:
        question = next((ex["question"] for ex in dataset["validation"] if ex["doc_id"] == doc_id), "N/A")
        print(f"✔️ Document ID: {doc_id}")
        print(f"Question:     {question}")
        print(f"Predictions:  {info['preds']}")
        print(f"References:   {info['refs']}\n")
        shown += 1
        if shown >= max_show:
            break



📘 Matching Predictions (Baseline - All Answers Kept)
✔️ Document ID: 1
Question:     Highlight the parts (if any) of this contract related to "Anti-Assignment" that should be reviewed by a lawyer.
Predictions:  Distributor
References:   Distributor

✔️ Document ID: 48
Question:     Highlight the parts (if any) of this contract related to "Governing Law" that should be reviewed by a lawyer.
Predictions:  CURO MANAGEMENT, LLC
References:   CURO MANAGEMENT, LLC

✔️ Document ID: 53
Question:     Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.
Predictions:  Servicer
References:   Servicer

✔️ Document ID: 58
Question:     Highlight the parts (if any) of this contract related to "Agreement Date" that should be reviewed by a lawyer.
Predictions:  March 1, 2016
References:   March 1, 2016

✔️ Document ID: 84
Question:     Highlight the parts (if any) of this contract related to "Governing Law" that should be reviewed by a lawyer.
Predict

#### Doc Aggregation Option 2
selects the most representative answer per document based on ROUGE-L similarity among all candidate predictions within a document.

Selecting Representative Answer:- While the earlier function evaluates all predictions collectively for each document, this function actively selects a single "most representative" answer using ROUGE-L similarity scores between candidate answers within the same document.

In [None]:
from collections import defaultdict, Counter
from rouge_score import rouge_scorer
import evaluate
import numpy as np

def evaluate_representative_answer(preds, refs, dataset, max_show=5):
    grouped_preds = defaultdict(list)
    grouped_refs = defaultdict(list)
    grouped_questions = {}

    for i, ex in enumerate(dataset["validation"]):
        doc_id = ex["doc_id"]
        grouped_questions[doc_id] = ex["question"]
        pred = preds[i].strip()
        ref = refs[i].strip()
        if pred.lower() != "no answer":
            grouped_preds[doc_id].append(pred)
        if ref.lower() != "no answer":
            grouped_refs[doc_id].append(ref)

    def select_most_representative(candidates):
        if not candidates:
            return "No answer"
        seen = set()
        unique = [c for c in candidates if not (c in seen or seen.add(c))]
        if len(unique) == 1:
            return unique[0]
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        scores = []
        for i, cand in enumerate(unique):
            others = unique[:i] + unique[i+1:]
            avg_rouge = np.mean([scorer.score(cand, other)["rougeL"].fmeasure for other in others])
            scores.append((avg_rouge, cand))
        return max(scores, key=lambda x: x[0])[1]

    all_doc_ids = set(grouped_preds.keys()) | set(grouped_refs.keys())
    final_preds = {doc_id: select_most_representative(grouped_preds.get(doc_id, [])) for doc_id in all_doc_ids}
    final_refs = {doc_id: select_most_representative(grouped_refs.get(doc_id, [])) for doc_id in all_doc_ids}

    doc_ids = list(final_preds.keys())
    doc_preds = [final_preds[doc_id] for doc_id in doc_ids]
    doc_refs = [final_refs[doc_id] for doc_id in doc_ids]

    doc_breakdown = {}
    exact_matches = []
    f1_scores = []
    jaccard_scores = []

    for doc_id in all_doc_ids:
        pred = final_preds[doc_id]
        ref = final_refs[doc_id]
        match = pred.strip().lower() == ref.strip().lower()
        exact_matches.append(match)

        # F1
        p_tokens = pred.lower().split()
        r_tokens = ref.lower().split()
        common = Counter(p_tokens) & Counter(r_tokens)
        num_same = sum(common.values())
        precision = num_same / len(p_tokens) if p_tokens else 0
        recall = num_same / len(r_tokens) if r_tokens else 0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0
        f1_scores.append(f1)

        # Jaccard
        jaccard = len(set(p_tokens) & set(r_tokens)) / len(set(p_tokens) | set(r_tokens)) if p_tokens or r_tokens else 1.0
        jaccard_scores.append(jaccard)

        doc_breakdown[doc_id] = {
            "question": grouped_questions.get(doc_id, "N/A"),
            "preds": pred,
            "refs": ref,
            "exact_match": match,
            "f1": f1,
            "jaccard": jaccard,
        }

    rouge = evaluate.load("rouge")
    rouge_output = rouge.compute(predictions=doc_preds, references=doc_refs, use_stemmer=True)

    bleu = evaluate.load("bleu")
    formatted_preds = [p if isinstance(p, str) else " ".join(p) for p in doc_preds]
    formatted_refs = [[r] if isinstance(r, str) else [" ".join(r)] for r in doc_refs]
    bleu_score = bleu.compute(predictions=formatted_preds, references=formatted_refs)["bleu"]

    # Optional: Show mismatches
    for i, doc_id in enumerate(doc_ids[:max_show]):
        if final_preds[doc_id] != final_refs[doc_id]:
            print(f"\n❌ MISMATCH in document {doc_id}")
            print(f"Question:   {grouped_questions[doc_id]}")
            print(f"Prediction: {final_preds[doc_id]}")
            print(f"Reference:  {final_refs[doc_id]}")

    num_docs = len(doc_preds)
    num_correct = sum(exact_matches)
    num_incorrect = num_docs - num_correct
    correct_pct = (num_correct / num_docs) * 100 if num_docs else 0
    incorrect_pct = 100 - correct_pct

    return {
        "ROUGE-1": rouge_output["rouge1"],
        "ROUGE-2": rouge_output["rouge2"],
        "ROUGE-L": rouge_output["rougeL"],
        "BLEU": bleu_score,
        "Exact Match": np.mean(exact_matches),
        "F1": np.mean(f1_scores),
        "Jaccard": np.mean(jaccard_scores),
        "Num Docs": num_docs,
        "Correct Predictions": num_correct,
        "Incorrect Predictions": num_incorrect,
        "Correct %": correct_pct,
        "Incorrect %": incorrect_pct,
        "per_document_breakdown": doc_breakdown
    }

# 🧪 Evaluate
best_rouge_answer_results = evaluate_representative_answer(preds, refs, dataset)

print("\n📘 Document-Level Evaluation (Option 2 - Best Answer via ROUGE-L)")
for k, v in best_rouge_answer_results.items():
    if k == "per_document_breakdown":
        continue
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")



❌ MISMATCH in document 2
Question:   Highlight the parts (if any) of this contract related to "Insurance" that should be reviewed by a lawyer.
Prediction: No answer
Reference:  SUPPLY CONTRACT

❌ MISMATCH in document 3
Question:   Highlight the parts (if any) of this contract related to "Effective Date" that should be reviewed by a lawyer.
Prediction: This Agreement shall automatically be renewed for one (1) or more one (1) month periods unless either the Customer or i-on gives notice to the other party of its intention not to renew the 4 Agreement, which notice must be given not less than fifteen (15) days before the end of the respective initial or renewal term.
Reference:  The term of this Agreement for the Hosted Site shall commence upon April 1, 1999 and shall continue for a period of six (6) months, unless earlier terminated in accordance with provisions hereof.

📘 Document-Level Evaluation (Option 2 - Best Answer via ROUGE-L)
ROUGE-1: 0.1773
ROUGE-2: 0.0969
ROUGE-L: 0.1645
BLEU

In [None]:
print("\n Matching Predictions (Option 2 - Best Answer via ROUGE-L)")
shown = 0
max_show = 300

per_doc = best_rouge_answer_results.get("per_document_breakdown", {})

if not per_doc:
    print(" No document breakdown found. Make sure 'per_document_breakdown' is returned from Option 2.")
else:
    for doc_id, details in per_doc.items():
        if details["exact_match"]:
            question = next((ex["question"] for ex in dataset["validation"] if ex["doc_id"] == doc_id), "N/A")
            print(f"Document ID: {doc_id}")
            print(f"Question:     {question}")
            print(f"Prediction:   {details['preds']}")
            print(f"Reference:    {details['refs']}\n")
            shown += 1
            if shown >= max_show:
                break

    if shown == 0:
        print(" No exact matches found at the document level for Option 2.")



 Matching Predictions (Option 2 - Best Answer via ROUGE-L)
Document ID: 0
Question:     Highlight the parts (if any) of this contract related to "Agreement Date" that should be reviewed by a lawyer.
Prediction:   Distributor
Reference:    Distributor

Document ID: 1
Question:     Highlight the parts (if any) of this contract related to "Anti-Assignment" that should be reviewed by a lawyer.
Prediction:   Distributor
Reference:    Distributor

Document ID: 5
Question:     Highlight the parts (if any) of this contract related to "Governing Law" that should be reviewed by a lawyer.
Prediction:   ADAMS GOLF
Reference:    ADAMS GOLF

Document ID: 36
Question:     Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.
Prediction:   Schoolpop
Reference:    Schoolpop

Document ID: 39
Question:     Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.
Prediction:   The term of this Agreement ("Ter

#### Option 3 Predictions based on EM and ROUGE if not EM

The function Groups predictions by document.Looks for an exact match between predicted chunks and any reference chunks.
If it can’t find an exact match, it falls back to rouge.

Pick the prediction based on exact match
- we have multiple predictions per document (from different chunks), and you want to:
- Compare each prediction to the reference(s),
- Select the one that has an exact string match,
- If none match, optionally fall back to  highest ROUGE.
- First, it tries to select a prediction that matches a reference exactly (case-insensitive, stripped).
- If no exact match is found, it falls back to either:

Then it computes standard metrics (ROUGE, BLEU, EM, F1, Jaccard) on the resulting final predictions.

In [None]:
from collections import defaultdict, Counter
from rouge_score import rouge_scorer
import evaluate
import numpy as np

def evaluate_exact_match_priority(preds, refs, dataset, max_show=5, fallback="rouge"):
    """
    Select best prediction per document via:
    1. Exact match to any reference chunk.
    2. If not found, fall back to most_common, first, or best ROUGE-L match.

    Args:
        preds: list of model predictions (chunk-level)
        refs: list of ground truth references (chunk-level)
        dataset: HuggingFace Dataset with doc_id and question
        max_show: number of mismatches to print
        fallback: one of ['rouge', 'most_common', 'first']

    Returns:
        Dict of document-level evaluation metrics
    """

    # --- Group predictions and references ---
    grouped_preds = defaultdict(list)
    grouped_refs = defaultdict(list)
    grouped_questions = {}

    for i, ex in enumerate(dataset["validation"]):
        doc_id = ex["doc_id"]
        grouped_questions[doc_id] = ex["question"]
        pred = preds[i].strip()
        ref = refs[i].strip()
        if pred.lower() != "no answer":
            grouped_preds[doc_id].append(pred)
        if ref.lower() != "no answer":
            grouped_refs[doc_id].append(ref)

    # --- ROUGE-based fallback selection ---
    def select_best_by_rouge(candidates):
        if not candidates:
            return "No answer"
        if len(candidates) == 1:
            return candidates[0]
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        scores = []
        for i, cand in enumerate(candidates):
            others = candidates[:i] + candidates[i+1:]
            avg_score = np.mean([
                scorer.score(cand, other)["rougeL"].fmeasure for other in others
            ])
            scores.append((avg_score, cand))
        return max(scores, key=lambda x: x[0])[1]

    # --- Final predictions ---
    final_preds, final_refs = {}, {}

    for doc_id in grouped_preds:
        preds_list = grouped_preds[doc_id]
        refs_list = grouped_refs.get(doc_id, [])
        matched = None

        # Exact match first
        for pred in preds_list:
            if pred.strip().lower() in [r.strip().lower() for r in refs_list]:
                matched = pred
                break

        # Fallback logic
        if not matched:
            if fallback == "rouge":
                matched = select_best_by_rouge(preds_list)
            elif fallback == "most_common":
                matched = Counter(preds_list).most_common(1)[0][0]
            else:
                matched = preds_list[0]

        final_preds[doc_id] = matched
        final_refs[doc_id] = refs_list[0] if refs_list else "No answer"

    # --- Evaluation ---
    doc_ids = list(final_preds.keys())
    doc_preds = [final_preds[doc_id] for doc_id in doc_ids]
    doc_refs = [final_refs[doc_id] for doc_id in doc_ids]

    rouge = evaluate.load("rouge")
    rouge_output = rouge.compute(predictions=doc_preds, references=doc_refs, use_stemmer=True)

    bleu = evaluate.load("bleu")
    formatted_preds = [p if isinstance(p, str) else " ".join(p) for p in doc_preds]
    formatted_refs = [[r] if isinstance(r, str) else [" ".join(r)] for r in doc_refs]
    bleu_score = bleu.compute(predictions=formatted_preds, references=formatted_refs)["bleu"]

    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(doc_preds, doc_refs)]
    f1_scores, jaccard_scores = [], []

    for p, r in zip(doc_preds, doc_refs):
        p_tokens, r_tokens = p.lower().split(), r.lower().split()
        common = Counter(p_tokens) & Counter(r_tokens)
        num_same = sum(common.values())

        precision = num_same / len(p_tokens) if p_tokens else 0
        recall = num_same / len(r_tokens) if r_tokens else 0
        f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0
        f1_scores.append(f1)

        jaccard = len(set(p_tokens) & set(r_tokens)) / len(set(p_tokens) | set(r_tokens)) if p_tokens or r_tokens else 1.0
        jaccard_scores.append(jaccard)

    # --- Show sample mismatches ---
    for i, doc_id in enumerate(doc_ids[:max_show]):
        if final_preds[doc_id].strip().lower() != final_refs[doc_id].strip().lower():
            print(f"\n❌ MISMATCH in document {doc_id}")
            print(f"Question:   {grouped_questions[doc_id]}")
            print(f"Prediction: {final_preds[doc_id]}")
            print(f"Reference:  {final_refs[doc_id]}")

    num_docs = len(doc_preds)
    num_correct = sum(exact_matches)
    num_incorrect = num_docs - num_correct
    correct_pct = (num_correct / num_docs) * 100 if num_docs else 0
    incorrect_pct = 100 - correct_pct

    return {
        "ROUGE-1": rouge_output["rouge1"],
        "ROUGE-2": rouge_output["rouge2"],
        "ROUGE-L": rouge_output["rougeL"],
        "BLEU": bleu_score,
        "Exact Match": np.mean(exact_matches),
        "F1": np.mean(f1_scores),
        "Jaccard": np.mean(jaccard_scores),
        "Num Docs": num_docs,
        "Correct Predictions": num_correct,
        "Incorrect Predictions": num_incorrect,
        "Correct %": correct_pct,
        "Incorrect %": incorrect_pct,
        "final_preds": final_preds,  # Add final_preds to the output
        "final_refs": final_refs    # Add final_refs to the output
    }

# apply
em_priority_results = evaluate_exact_match_priority(preds, refs, dataset, max_show=3, fallback="rouge")

print("\nEvaluation (Option 3: Exact Match Priority + Fallback)")
for k, v in em_priority_results.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")



❌ MISMATCH in document 91
Question:   Highlight the parts (if any) of this contract related to "Post-Termination Services" that should be reviewed by a lawyer.
Prediction: May, 2000
Reference:  GARMAN

❌ MISMATCH in document 249
Question:   Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.
Prediction: September ___, 2019
Reference:  "Licensee" and together with Licensor, the "Parties"),

Evaluation (Option 3: Exact Match Priority + Fallback)
ROUGE-1: 0.3044
ROUGE-2: 0.1874
ROUGE-L: 0.2957
BLEU: 0.1342
Exact Match: 0.2440
F1: 0.2968
Jaccard: 0.2779
Num Docs: 336
Correct Predictions: 82
Incorrect Predictions: 254
Correct %: 24.4048
Incorrect %: 75.5952
final_preds: {177: 'SFJ Pharmaceuticals X, Ltd.', 91: 'May, 2000', 249: 'September ___, 2019', 138: 'Vendor', 456: 'December 31, 2018', 44: 'Freedom Mortgage', 309: 'StartEngine Crowdfunding, Inc.', 150: 'This Agreement shall be governed and construed in accordance with the laws of the

### mistmatches predictions vs chunks seen

In [None]:
# --- Updated cell 77 code ---
shown = 0
max_examples = 5

# Get final_preds and final_refs from the function's output
final_preds = em_priority_results["final_preds"]
final_refs = em_priority_results["final_refs"]

for doc_id in final_preds:
    pred = final_preds[doc_id]
    ref = final_refs[doc_id]

    if pred.strip().lower() != ref.strip().lower():
        print(f"\n❌ MISMATCH in document {doc_id}")
        print(f"Question:   {grouped_questions[doc_id]}")
        print(f"Prediction: {pred}")
        print(f"Reference:  {ref}")

        # Show chunks the model had for this doc
        print("\nChunks seen:")
        doc_chunks = [example["input"] for example in dataset["validation"] if example["doc_id"] == doc_id]
        for i, chunk in enumerate(doc_chunks):
            print(f"\n-- Chunk {i+1} --\n{chunk[:9000]}")  # Show first 500 chars for readability

        shown += 1
        if shown >= max_examples:
            break


❌ MISMATCH in document 91
Question:   Highlight the parts (if any) of this contract related to "Post-Termination Services" that should be reviewed by a lawyer.
Prediction: May, 2000
Reference:  GARMAN

Chunks seen:

-- Chunk 1 --
with respect to its subject matter, and this Agreement supersedes all prior understandings, representations, negotiations and communications between the parties, oral and written. Dated the ____ day of May, 2000. GARMAN ROUTING SYSTEMS, INC.

-- Chunk 2 --
11. FEES Sparkling shall pay the fees as set out in the Fee Schedule in accordance with the terms of this Agreement and the Fee Schedule. 12. SPECIFICATIONS AND WARRANTY For the acceptance period and for a period of one year from the Maintenance Commencement Date, and thereafter for as long as the Software is covered by Maintenance Services and is used by Sparkling in accordance with this Agreement, Garman warrants that the Software shall perform in conformance with the Specifications in all material respec

In [None]:
import torch
from tqdm import tqdm

MAX_SEQUENCE_LENGTH=512
# Batch size for evaluation
batch_size = 16
preds, refs = [], []

model.eval()
with torch.no_grad():
    for i in tqdm(range(0, len(tokenized_datasets['validation']), batch_size), desc="Evaluating"):
        batch = tokenized_datasets['validation'].select(range(i, min(i + batch_size, len(tokenized_datasets['validation']))))

        # Prepare inputs
        inputs = {
            "input_ids": torch.tensor(batch["input_ids"]).to("cuda"),
            "attention_mask": torch.tensor(batch["attention_mask"]).to("cuda")
        }

        # Generate predictions
        outputs = model.generate(**inputs, max_length=MAX_SEQUENCE_LENGTH)

        # Decode predictions
        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Process labels: remove -100 and decode
        flat_labels = [[token for token in label if token != -100] for label in batch['labels']]
        decoded_labels = tokenizer.batch_decode(flat_labels, skip_special_tokens=True)

        # Append to lists
        preds.extend(decoded_preds)
        refs.extend(decoded_labels)

        # Free up memory
        del inputs, outputs
        torch.cuda.empty_cache()

print("Predictions and references generated!")


## Test Evaluation

In [None]:
import torch
from tqdm import tqdm
import pandas as pd

MAX_SEQUENCE_LENGTH = 512
batch_size = 16

model.eval()
test_preds, test_refs = [], []
doc_ids, clause_types, questions, titles = [], [], [], []

with torch.no_grad():
    for i in tqdm(range(0, len(tokenized_datasets['test']), batch_size), desc="Evaluating on Test"):
        batch = tokenized_datasets['test'].select(range(i, min(i + batch_size, len(tokenized_datasets['test']))))

        input_ids = torch.tensor(batch["input_ids"]).to("cuda")
        attention_mask = torch.tensor(batch["attention_mask"]).to("cuda")

        outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=MAX_SEQUENCE_LENGTH)
        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Remove -100s and decode
        flat_labels = [[t for t in label if t != -100] for label in batch["labels"]]
        decoded_labels = tokenizer.batch_decode(flat_labels, skip_special_tokens=True)

        # Extend predictions and references
        test_preds.extend(decoded_preds)
        test_refs.extend(decoded_labels)

        # Add metadata
        doc_ids.extend(batch["doc_id"])
        clause_types.extend(batch["clause_type"])
        questions.extend(batch["question"])
        titles.extend(batch["contract_title"])

        # Cleanup
        del input_ids, attention_mask, outputs
        torch.cuda.empty_cache()

print(" Predictions and references generated.")

# Create a DataFrame for inspection
test_df = pd.DataFrame({
    "doc_id": doc_ids,
    "contract_title": titles,
    "clause_type": clause_types,
    "question": questions,
    "reference": test_refs,
    "prediction": test_preds
})

# Save for later inspection
test_df.to_csv("t5_test_predictions.csv", index=False)
print("Saved to t5_test_predictions.csv")

# Show a preview
test_df.head()


In [None]:
# Save chunk level test to CSV
test_df.to_csv("t5_test_predictions.csv", index=False)
print("Saved to t5_test_predictions.csv")


📁 Saved to t5_test_predictions.csv


### test chunk scores

In [None]:
import evaluate
import numpy as np
from collections import Counter

def compute_chunk_level_metrics_test(preds, refs):
    # Init
    rouge = evaluate.load("rouge")
    bleu = evaluate.load("bleu")

    # ROUGE
    rouge_output = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    rouge1 = rouge_output["rouge1"]
    rouge2 = rouge_output["rouge2"]
    rougeL = rouge_output["rougeL"]

    # BLEU
    formatted_preds = [p if isinstance(p, str) else " ".join(p) for p in preds]
    formatted_refs = [[r] if isinstance(r, str) else [" ".join(r)] for r in refs]
    bleu_score = bleu.compute(predictions=formatted_preds, references=formatted_refs)["bleu"]

    # Exact Match
    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(preds, refs)]
    em = np.mean(exact_matches)
    total_chunks = len(preds)
    num_correct = sum(exact_matches)
    num_incorrect = total_chunks - num_correct
    correct_pct = 100 * em
    incorrect_pct = 100 - correct_pct

    # F1
    def compute_f1(p, r):
        p_tokens, r_tokens = p.lower().split(), r.lower().split()
        common = Counter(p_tokens) & Counter(r_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = num_same / len(p_tokens)
        recall = num_same / len(r_tokens)
        return 2 * precision * recall / (precision + recall)

    f1 = np.mean([compute_f1(p, r) for p, r in zip(preds, refs)])

    # Jaccard
    def jaccard(p, r):
        p_set, r_set = set(p.lower().split()), set(r.lower().split())
        return len(p_set & r_set) / len(p_set | r_set) if p_set | r_set else 1.0

    jaccard_score = np.mean([jaccard(p, r) for p, r in zip(preds, refs)])

    return {
        "Chunk-Level ROUGE-1": rouge1,
        "Chunk-Level ROUGE-2": rouge2,
        "Chunk-Level ROUGE-L": rougeL,
        "Chunk-Level BLEU": bleu_score,
        "Chunk-Level Exact Match": em,
        "Chunk-Level F1": f1,
        "Chunk-Level Jaccard": jaccard_score,
        "Total Chunks": total_chunks,
        "Correct Predictions": num_correct,
        "Incorrect Predictions": num_incorrect,
        "Correct %": correct_pct,
        "Incorrect %": incorrect_pct
    }

# Example usage
test_metrics = compute_chunk_level_metrics_test(test_preds, test_refs)
for k, v in test_metrics.items():  #
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")


Chunk-Level ROUGE-1: 0.4079
Chunk-Level ROUGE-2: 0.3827
Chunk-Level ROUGE-L: 0.4066
Chunk-Level BLEU: 0.0606
Chunk-Level Exact Match: 0.3955
Chunk-Level F1: 0.4065
Chunk-Level Jaccard: 0.4036
Total Chunks: 3679
Correct Predictions: 1455
Incorrect Predictions: 2224
Correct %: 39.5488
Incorrect %: 60.4512


#### Chunk-Level Evaluation by clause_type

In [None]:
import pandas as pd
from collections import defaultdict

def compute_clause_level_chunk_accuracy(preds, refs, clause_types):
    clause_stats = defaultdict(lambda: {"preds": [], "refs": []})

    for p, r, ct in zip(preds, refs, clause_types):
        clause_stats[ct]["preds"].append(p)
        clause_stats[ct]["refs"].append(r)

    table = []
    for clause, group in clause_stats.items():
        preds_group = group["preds"]
        refs_group = group["refs"]
        exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(preds_group, refs_group)]

        table.append({
            "Clause Type": clause,
            "Total Chunks": len(preds_group),
            "Exact Matches": sum(exact_matches),
            "Accuracy (%)": round(100 * (sum(exact_matches) / len(preds_group)), 2)
        })

    df = pd.DataFrame(table).sort_values("Accuracy (%)", ascending=False)
    return df

# Example usage
clause_chunk_df = compute_clause_level_chunk_accuracy(test_preds, test_refs, clause_types)
clause_chunk_df


Unnamed: 0,Clause Type,Total Chunks,Exact Matches,Accuracy (%)
31,No-Solicit Of Employees,23,18,78.26
22,Notice Period To Terminate Renewal,28,21,75.0
38,Affiliate License-Licensee,28,21,75.0
10,Competitive Restriction Exception,34,25,73.53
25,Non-Disparagement,15,11,73.33
6,Joint Ip Ownership,24,17,70.83
28,Governing Law,116,82,70.69
17,Uncapped Liability,44,31,70.45
14,Renewal Term,53,36,67.92
39,Price Restrictions,6,4,66.67


### Document level agggregation- test

In [None]:
from collections import Counter

def compute_f1(pred, ref):
    pred_tokens = pred.lower().split()
    ref_tokens = ref.lower().split()
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

def compute_jaccard(pred, ref):
    p_set = set(pred.lower().split())
    r_set = set(ref.lower().split())
    return len(p_set & r_set) / len(p_set | r_set) if p_set | r_set else 1.0


In [None]:
def compute_clause_type_accuracy(doc_ids, preds, refs, clause_types):
    """
    Compute clause-level matching per clause type.
    Returns: total, correct, incorrect, percent correct.
    """
    from collections import defaultdict

    clause_tracker = defaultdict(list)

    for doc_id, pred, ref, clause in zip(doc_ids, preds, refs, clause_types):
        clause_tracker[(doc_id, clause)].append((pred.strip(), ref.strip()))

    total = len(clause_tracker)
    correct = sum(
        all(p.lower() == r.lower() for p, r in pair_list)
        for pair_list in clause_tracker.values()
    )
    incorrect = total - correct
    return {
        "Total Clauses": total,
        "Correct Clauses": correct,
        "Wrong Clauses": incorrect,
        "Clause Accuracy %": round(100 * correct / total, 2) if total > 0 else 0.0
    }


#### Option 1; Aggregate all chunks

In [None]:
from collections import defaultdict
import evaluate
import numpy as np

def document_eval_option_1(preds, refs, doc_ids, clause_types):
    grouped_preds = defaultdict(list)
    grouped_refs = defaultdict(list)

    for i in range(len(doc_ids)):
        grouped_preds[doc_ids[i]].append(preds[i].strip())
        grouped_refs[doc_ids[i]].append(refs[i].strip())

    # Collapse per doc
    final_preds = {doc: " ".join(set(p)) for doc, p in grouped_preds.items()}
    final_refs = {doc: " ".join(set(r)) for doc, r in grouped_refs.items()}

    doc_ids_set = list(set(final_preds.keys()) | set(final_refs.keys()))
    doc_preds = [final_preds.get(doc, "No answer") for doc in doc_ids_set]
    doc_refs = [final_refs.get(doc, "No answer") for doc in doc_ids_set]

    # Accuracy calculations
    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(doc_preds, doc_refs)]
    f1s = [compute_f1(p, r) for p, r in zip(doc_preds, doc_refs)]
    jaccards = [compute_jaccard(p, r) for p, r in zip(doc_preds, doc_refs)]

    rouge = evaluate.load("rouge")
    bleu = evaluate.load("bleu")
    rouge_scores = rouge.compute(predictions=doc_preds, references=doc_refs)
    bleu_score = bleu.compute(predictions=doc_preds, references=[[r] for r in doc_refs])["bleu"]

    # Clause-level stats
    clause_stats = compute_clause_type_accuracy(doc_ids, preds, refs, clause_types)

    # Per-document output for saving later
    doc_breakdown = [
        {
            "doc_id": doc,
            "prediction": final_preds[doc],
            "reference": final_refs[doc],
            "exact_match": int(final_preds[doc].strip().lower() == final_refs[doc].strip().lower())
        }
        for doc in doc_ids_set
    ]

    # New document-level stats
    correct_docs = sum(exact_matches)
    wrong_docs = len(exact_matches) - correct_docs
    doc_accuracy = 100 * correct_docs / len(exact_matches)

    return {
        "Approach": "All Answers Aggregated",
        "Exact Match": np.mean(exact_matches),
        "F1": np.mean(f1s),
        "Jaccard": np.mean(jaccards),
        "BLEU": bleu_score,
        "ROUGE-1": rouge_scores["rouge1"],
        "ROUGE-2": rouge_scores["rouge2"],
        "ROUGE-L": rouge_scores["rougeL"],
        "Num Docs": len(doc_ids_set),
        "Correct Docs": correct_docs,
        "Wrong Docs": wrong_docs,
        "Document Accuracy (%)": round(doc_accuracy, 2),
        **clause_stats,
        "doc_breakdown": doc_breakdown
    }


In [None]:
# Run and capture results
opt1 = document_eval_option_1(test_preds, test_refs, doc_ids, clause_types)

# Convert breakdown to DataFrame
import pandas as pd
df_doc_preds = pd.DataFrame(opt1["doc_breakdown"])

# Save to CSV
df_doc_preds.to_csv("option1_doc_level_predictions.csv", index=False)
print("Saved document-level predictions to 'option1_doc_level_predictions.csv'")


Saved document-level predictions to 'option1_doc_level_predictions.csv'


In [None]:
option1_result = document_eval_option_1(test_preds, test_refs, doc_ids, clause_types)
print("Option 1 Results:")
for k, v in option1_result.items():
    if k != "doc_breakdown":
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")


Option 1 Results:
Approach: All Answers Aggregated
Exact Match: 0.0574
F1: 0.2781
Jaccard: 0.2190
BLEU: 0.0411
ROUGE-1: 0.2821
ROUGE-2: 0.2044
ROUGE-L: 0.2603
Num Docs: 453
Correct Docs: 26
Wrong Docs: 427
Document Accuracy (%): 5.7400
Total Clauses: 2169
Correct Clauses: 907
Wrong Clauses: 1262
Clause Accuracy %: 41.8200


In [None]:
option1_result

{'Approach': 'All Answers Aggregated',
 'Exact Match': np.float64(0.05739514348785872),
 'F1': np.float64(0.2781216021929464),
 'Jaccard': np.float64(0.21898300085860165),
 'BLEU': 0.04114599838204027,
 'ROUGE-1': np.float64(0.28212604232240057),
 'ROUGE-2': np.float64(0.20438414307938363),
 'ROUGE-L': np.float64(0.2602903881985664),
 'Num Docs': 453,
 'Total Clauses': 2169,
 'Correct Clauses': 907,
 'Wrong Clauses': 1262,
 'Clause Accuracy %': 41.82,
 'doc_breakdown': [{'doc_id': 0,
   'prediction': 'No answer The term of this Agreement shall be ten (10) years (the "Term") which shall commence on the date upon which the Company delivers to Distributor the last Sample, as defined hereinafter.',
   'reference': 'No answer Distributor',
   'exact_match': 0},
  {'doc_id': 1,
   'prediction': 'No answer Distributor This Agreement shall commence on the Effective Date and, unless earlier terminated as set out in this Agreement, shall continue for the Term.',
   'reference': 'No answer Distri

#### Select most representative answer (ROUGE-L based)

In [None]:
def document_eval_option_2(preds, refs, doc_ids, clause_types, questions, titles):
    from collections import defaultdict
    from rouge_score import rouge_scorer
    import evaluate
    import numpy as np

    grouped_preds = defaultdict(list)
    grouped_refs = defaultdict(list)
    grouped_metadata = {}

    for i in range(len(doc_ids)):
        doc_id = doc_ids[i]
        grouped_preds[doc_id].append(preds[i].strip())
        grouped_refs[doc_id].append(refs[i].strip())
        if doc_id not in grouped_metadata:
            grouped_metadata[doc_id] = {
                "question": questions[i],
                "title": titles[i]
            }

    def select_most_representative(candidates):
        if not candidates:
            return "No answer"
        unique = list(set(candidates))
        if len(unique) == 1:
            return unique[0]
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        scores = []
        for i, c in enumerate(unique):
            others = unique[:i] + unique[i+1:]
            avg = np.mean([scorer.score(c, o)["rougeL"].fmeasure for o in others])
            scores.append((avg, c))
        return max(scores, key=lambda x: x[0])[1]

    doc_ids_set = list(set(grouped_preds.keys()) | set(grouped_refs.keys()))
    final_preds = {doc: select_most_representative(grouped_preds.get(doc, [])) for doc in doc_ids_set}
    final_refs = {doc: select_most_representative(grouped_refs.get(doc, [])) for doc in doc_ids_set}

    doc_preds = [final_preds[doc] for doc in doc_ids_set]
    doc_refs = [final_refs[doc] for doc in doc_ids_set]

    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(doc_preds, doc_refs)]
    f1s = [compute_f1(p, r) for p, r in zip(doc_preds, doc_refs)]
    jaccards = [compute_jaccard(p, r) for p, r in zip(doc_preds, doc_refs)]

    rouge = evaluate.load("rouge")
    bleu = evaluate.load("bleu")
    rouge_scores = rouge.compute(predictions=doc_preds, references=doc_refs)
    bleu_score = bleu.compute(predictions=doc_preds, references=[[r] for r in doc_refs])["bleu"]

    clause_stats = compute_clause_type_accuracy(doc_ids, preds, refs, clause_types)

    # Per-document breakdown with metadata
    doc_breakdown = []
    for doc in doc_ids_set:
        doc_breakdown.append({
            "doc_id": doc,
            "contract_title": grouped_metadata[doc]["title"],
            "question": grouped_metadata[doc]["question"],
            "prediction": final_preds[doc],
            "reference": final_refs[doc],
            "exact_match": int(final_preds[doc].strip().lower() == final_refs[doc].strip().lower())
        })

    # New doc-level stats
    correct_docs = sum(exact_matches)
    wrong_docs = len(exact_matches) - correct_docs
    doc_accuracy = 100 * correct_docs / len(exact_matches)

    return {
        "Approach": "Best Answer via ROUGE-L",
        "Exact Match": np.mean(exact_matches),
        "F1": np.mean(f1s),
        "Jaccard": np.mean(jaccards),
        "BLEU": bleu_score,
        "ROUGE-1": rouge_scores["rouge1"],
        "ROUGE-2": rouge_scores["rouge2"],
        "ROUGE-L": rouge_scores["rougeL"],
        "Num Docs": len(doc_ids_set),
        "Correct Docs": correct_docs,
        "Wrong Docs": wrong_docs,
        "Document Accuracy (%)": round(doc_accuracy, 2),
        **clause_stats,
        "doc_breakdown": doc_breakdown
    }


In [None]:
option2_result = document_eval_option_2(test_preds, test_refs, doc_ids, clause_types, questions, titles)
print("Option 2 Results:")
for k, v in option2_result.items():
    if k != "doc_breakdown":
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")


Option 2 Results:
Approach: Best Answer via ROUGE-L
Exact Match: 0.2561
F1: 0.2798
Jaccard: 0.2703
BLEU: 0.0304
ROUGE-1: 0.2820
ROUGE-2: 0.2566
ROUGE-L: 0.2765
Num Docs: 453
Correct Docs: 116
Wrong Docs: 337
Document Accuracy (%): 25.6100
Total Clauses: 2169
Correct Clauses: 907
Wrong Clauses: 1262
Clause Accuracy %: 41.8200


In [None]:
# Call the function
option2_result = document_eval_option_2(test_preds, test_refs, doc_ids, clause_types, questions, titles)

# Save breakdown
import pandas as pd
pd.DataFrame(option2_result["doc_breakdown"]).to_csv("option2_doc_level_predictions.csv", index=False)
print("Saved to option2_doc_level_predictions.csv")


Saved to option2_doc_level_predictions.csv


In [None]:
option2_result

{'Approach': 'Best Answer via ROUGE-L',
 'Exact Match': np.float64(0.2560706401766004),
 'F1': np.float64(0.2798374896659436),
 'Jaccard': np.float64(0.2702554392512793),
 'BLEU': 0.030405134092450548,
 'ROUGE-1': np.float64(0.28204097397813466),
 'ROUGE-2': np.float64(0.25662979713549616),
 'ROUGE-L': np.float64(0.27649742011591555),
 'Num Docs': 453,
 'Total Clauses': 2169,
 'Correct Clauses': 907,
 'Wrong Clauses': 1262,
 'Clause Accuracy %': 41.82,
 'doc_breakdown': [{'doc_id': 0,
   'contract_title': 'LIMEENERGYCO_09_09_1999-EX-10-DISTRIBUTOR AGREEMENT',
   'question': 'Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.',
   'prediction': 'No answer',
   'reference': 'No answer',
   'exact_match': 1},
  {'doc_id': 1,
   'contract_title': 'WHITESMOKE,INC_11_08_2011-EX-10.26-PROMOTION AND DISTRIBUTION AGREEMENT',
   'question': 'Highlight the parts (if any) of this contract related to "Change Of Control" that should be reviewed by

#### Option 3 Exact match and best rouge

In [None]:
def document_eval_option_3(preds, refs, doc_ids, clause_types, questions, titles, fallback="rouge"):
    from collections import defaultdict, Counter
    from rouge_score import rouge_scorer
    import evaluate
    import numpy as np

    grouped_preds = defaultdict(list)
    grouped_refs = defaultdict(list)
    grouped_metadata = {}

    for i in range(len(doc_ids)):
        doc_id = doc_ids[i]
        grouped_preds[doc_id].append(preds[i].strip())
        grouped_refs[doc_id].append(refs[i].strip())
        if doc_id not in grouped_metadata:
            grouped_metadata[doc_id] = {
                "question": questions[i],
                "title": titles[i]
            }

    def select_most_representative(candidates):
        if not candidates:
            return "No answer"
        unique = list(set(candidates))
        if len(unique) == 1:
            return unique[0]
        scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
        scores = []
        for i, c in enumerate(unique):
            others = unique[:i] + unique[i+1:]
            avg = np.mean([scorer.score(c, o)["rougeL"].fmeasure for o in others])
            scores.append((avg, c))
        return max(scores, key=lambda x: x[0])[1]

    def select_with_fallback(p_list, r_list):
        for p in p_list:
            if p.strip().lower() in [r.strip().lower() for r in r_list]:
                return p
        if fallback == "rouge":
            return select_most_representative(p_list)
        elif fallback == "most_common":
            return Counter(p_list).most_common(1)[0][0]
        return p_list[0]

    doc_ids_set = list(set(grouped_preds.keys()))
    final_preds = {doc: select_with_fallback(grouped_preds[doc], grouped_refs.get(doc, [])) for doc in doc_ids_set}
    final_refs = {doc: grouped_refs.get(doc, ["No answer"])[0] for doc in doc_ids_set}

    doc_preds = [final_preds[doc] for doc in doc_ids_set]
    doc_refs = [final_refs[doc] for doc in doc_ids_set]

    exact_matches = [int(p.strip().lower() == r.strip().lower()) for p, r in zip(doc_preds, doc_refs)]
    f1s = [compute_f1(p, r) for p, r in zip(doc_preds, doc_refs)]
    jaccards = [compute_jaccard(p, r) for p, r in zip(doc_preds, doc_refs)]

    rouge = evaluate.load("rouge")
    bleu = evaluate.load("bleu")
    rouge_scores = rouge.compute(predictions=doc_preds, references=doc_refs)
    bleu_score = bleu.compute(predictions=doc_preds, references=[[r] for r in doc_refs])["bleu"]

    clause_stats = compute_clause_type_accuracy(doc_ids, preds, refs, clause_types)

    # Document-level stats
    num_docs = len(doc_ids_set)
    correct_docs = sum(exact_matches)
    wrong_docs = num_docs - correct_docs
    doc_acc = 100 * correct_docs / num_docs if num_docs else 0

    # Build document-level metadata table
    doc_breakdown = []
    for doc in doc_ids_set:
        doc_breakdown.append({
            "doc_id": doc,
            "contract_title": grouped_metadata[doc]["title"],
            "question": grouped_metadata[doc]["question"],
            "prediction": final_preds[doc],
            "reference": final_refs[doc],
            "exact_match": int(final_preds[doc].strip().lower() == final_refs[doc].strip().lower())
        })

    return {
        "Approach": "Exact Match → Fallback",
        "Exact Match": np.mean(exact_matches),
        "F1": np.mean(f1s),
        "Jaccard": np.mean(jaccards),
        "BLEU": bleu_score,
        "ROUGE-1": rouge_scores["rouge1"],
        "ROUGE-2": rouge_scores["rouge2"],
        "ROUGE-L": rouge_scores["rougeL"],
        "Num Docs": num_docs,
        "Correct Docs": correct_docs,
        "Wrong Docs": wrong_docs,
        "Document Accuracy (%)": doc_acc,
        "doc_breakdown": doc_breakdown,
        **clause_stats
    }
option3_result = document_eval_option_3(test_preds, test_refs, doc_ids, clause_types,questions, titles, fallback="rouge")
print("Option 3 Results:")
for k, v in option3_result.items():
    if k != "doc_breakdown":
        print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")


Option 3 Results:
Approach: Exact Match → Fallback
Exact Match: 0.3422
F1: 0.3496
Jaccard: 0.3471
BLEU: 0.0494
ROUGE-1: 0.3515
ROUGE-2: 0.3205
ROUGE-L: 0.3487
Num Docs: 453
Correct Docs: 155
Wrong Docs: 298
Document Accuracy (%): 34.2163
Total Clauses: 2169
Correct Clauses: 907
Wrong Clauses: 1262
Clause Accuracy %: 41.8200


In [None]:
# Run and save results
option3_result = document_eval_option_3(test_preds, test_refs, doc_ids, clause_types, questions, titles)

# Save the per-document breakdown to CSV
import pandas as pd
pd.DataFrame(option3_result["doc_breakdown"]).to_csv("option3_doc_level_predictions.csv", index=False)
print("Saved to option3_doc_level_predictions.csv")


Saved to option3_doc_level_predictions.csv


In [None]:
option3_result

{'Approach': 'Exact Match → Fallback',
 'Exact Match': np.float64(0.34216335540838855),
 'F1': np.float64(0.34963994050154146),
 'Jaccard': np.float64(0.34706638101031545),
 'BLEU': 0.04943072731476541,
 'ROUGE-1': np.float64(0.35149979486749927),
 'ROUGE-2': np.float64(0.3204753562306943),
 'ROUGE-L': np.float64(0.34866105733525343),
 'Num Docs': 453,
 'doc_breakdown': [{'doc_id': 0,
   'contract_title': 'LIMEENERGYCO_09_09_1999-EX-10-DISTRIBUTOR AGREEMENT',
   'question': 'Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.',
   'prediction': 'No answer',
   'reference': 'Distributor',
   'exact_match': 0},
  {'doc_id': 1,
   'contract_title': 'WHITESMOKE,INC_11_08_2011-EX-10.26-PROMOTION AND DISTRIBUTION AGREEMENT',
   'question': 'Highlight the parts (if any) of this contract related to "Change Of Control" that should be reviewed by a lawyer.',
   'prediction': 'No answer',
   'reference': 'No answer',
   'exact_match': 1},
  {'do

#### Document Shared Utility Function for Clause Accuracy

In [None]:
#view all results
import pandas as pd

results_df = pd.DataFrame([option1_result, option2_result, option3_result])
results_df.set_index("Approach", inplace=True)
results_df.round(4)


Unnamed: 0_level_0,Exact Match,F1,Jaccard,BLEU,ROUGE-1,ROUGE-2,ROUGE-L,Num Docs
Approach,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
All Answers Aggregated,0.0574,0.2781,0.219,0.0411,0.2821,0.2044,0.2603,453
Best Answer via ROUGE-L,0.2561,0.2798,0.2703,0.0304,0.282,0.2566,0.2765,453
Exact Match → Fallback,0.3422,0.3496,0.3471,0.0494,0.3515,0.3205,0.3487,453


In [None]:
import pandas as pd

# Run all evaluations
opt1 = document_eval_option_1(test_preds, test_refs, doc_ids, clause_types)
opt2 = document_eval_option_2(test_preds, test_refs, doc_ids, clause_types, questions, titles)
opt3 = document_eval_option_3(test_preds, test_refs, doc_ids, clause_types, questions, titles)

# Helper to extract all relevant metrics
def extract_all_metrics(result_dict):
    return {
        "Approach": result_dict["Approach"],
        "Exact Match": result_dict["Exact Match"],
        "F1": result_dict["F1"],
        "Jaccard": result_dict["Jaccard"],
        "BLEU": result_dict["BLEU"],
        "ROUGE-1": result_dict["ROUGE-1"],
        "ROUGE-2": result_dict["ROUGE-2"],
        "ROUGE-L": result_dict["ROUGE-L"],
        "Num Docs": result_dict["Num Docs"],
        "Correct Docs": result_dict["Correct Docs"],
        "Wrong Docs": result_dict["Wrong Docs"],
        "Document Accuracy (%)": result_dict["Document Accuracy (%)"],
        "Total Clauses": result_dict["Total Clauses"],
        "Correct Clauses": result_dict["Correct Clauses"],
        "Wrong Clauses": result_dict["Wrong Clauses"],
        "Clause Accuracy (%)": result_dict["Clause Accuracy %"]
    }

# Build DataFrame
comparison_df = pd.DataFrame([
    extract_all_metrics(opt1),
    extract_all_metrics(opt2),
    extract_all_metrics(opt3)
])

# Optional: Round float values
float_cols = comparison_df.select_dtypes(include=["float"]).columns
comparison_df[float_cols] = comparison_df[float_cols].round(4)

# Show the full comparison
comparison_df


Unnamed: 0,Approach,Exact Match,F1,Jaccard,BLEU,ROUGE-1,ROUGE-2,ROUGE-L,Num Docs,Correct Docs,Wrong Docs,Document Accuracy (%),Total Clauses,Correct Clauses,Wrong Clauses,Clause Accuracy (%)
0,All Answers Aggregated,0.0574,0.2781,0.219,0.0411,0.2821,0.2044,0.2603,453,26,427,5.74,2169,907,1262,41.82
1,Best Answer via ROUGE-L,0.2561,0.2798,0.2703,0.0304,0.282,0.2566,0.2765,453,116,337,25.61,2169,907,1262,41.82
2,Exact Match → Fallback,0.3422,0.3496,0.3471,0.0494,0.3515,0.3205,0.3487,453,155,298,34.2163,2169,907,1262,41.82


In [None]:
set_correct_1 = set([doc_id for doc_id, d in enumerate(option1_result["doc_breakdown"]) if d["exact_match"]]) #17
set_correct_2 = set([doc_id for doc_id, d in enumerate(option2_result["doc_breakdown"]) if d["exact_match"]]) #123
set_correct_3 = set([doc_id for doc_id, d in enumerate(option3_result["doc_breakdown"]) if d["exact_match"]]) # 149

In [None]:
set_correct_1 = set([d["doc_id"] for d in option1_result["doc_breakdown"] if d["exact_match"]])
set_correct_2 = set([d["doc_id"] for d in option2_result["doc_breakdown"] if d["exact_match"]])
set_correct_3 = set([d["doc_id"] for d in option3_result["doc_breakdown"] if d["exact_match"]])

print(f"Option 1: {len(set_correct_1)}") #number of documents predicted for option 1
print(f"Option 2: {len(set_correct_2)}")
print(f"Option 3: {len(set_correct_3)}")


Option 1: 26
Option 2: 116
Option 3: 155


###  Extract Perfectly Predicted Docs for Each Option

In [None]:
# Get doc_ids with perfect prediction from doc_breakdown
def get_perfect_docs(result):
    return {entry["doc_id"] for entry in result["doc_breakdown"] if entry["exact_match"] == 1}


In [None]:
perfect_docs_opt1 = get_perfect_docs(option1_result)
perfect_docs_opt2 = get_perfect_docs(option2_result)
perfect_docs_opt3 = get_perfect_docs(option3_result)

print(f"Option 1 (All Answers Aggregated): {len(perfect_docs_opt1)} docs")
print(f"Option 2 (Best via ROUGE-L):       {len(perfect_docs_opt2)} docs")
print(f"Option 3 (Exact Match → Fallback): {len(perfect_docs_opt3)} docs")


Option 1 (All Answers Aggregated): 26 docs
Option 2 (Best via ROUGE-L):       116 docs
Option 3 (Exact Match → Fallback): 155 docs


In [None]:
print(f"Docs correct in all 3: {len(perfect_docs_opt1 & perfect_docs_opt2 & perfect_docs_opt3)}")
print(f"Docs only in Option 3: {len(perfect_docs_opt3 - perfect_docs_opt1 - perfect_docs_opt2)}")


Docs correct in all 3: 24
Docs only in Option 3: 89


#### To List the Document IDs Per Option


In [None]:
print("Perfectly Predicted Docs - Option 1 (All Answers Aggregated):")
print(sorted(perfect_docs_opt1))

print("\nPerfectly Predicted Docs - Option 2 (Best via ROUGE-L):")
print(sorted(perfect_docs_opt2))

print("\nPerfectly Predicted Docs - Option 3 (Exact Match → Fallback):")
print(sorted(perfect_docs_opt3))


Perfectly Predicted Docs - Option 1 (All Answers Aggregated):
[22, 33, 79, 95, 101, 107, 114, 147, 162, 168, 173, 217, 246, 255, 269, 328, 353, 395, 424, 429, 446, 447, 454, 476, 502, 503]

Perfectly Predicted Docs - Option 2 (Best via ROUGE-L):
[0, 1, 5, 10, 14, 21, 22, 26, 31, 32, 33, 40, 41, 50, 51, 56, 62, 63, 79, 80, 82, 86, 90, 95, 96, 98, 99, 101, 102, 107, 109, 114, 129, 134, 135, 141, 145, 147, 150, 152, 162, 163, 168, 169, 171, 173, 181, 190, 194, 198, 209, 210, 212, 216, 217, 224, 226, 230, 241, 246, 251, 255, 256, 259, 262, 269, 277, 279, 281, 282, 286, 290, 292, 296, 299, 307, 312, 323, 326, 328, 334, 339, 350, 353, 361, 362, 363, 366, 368, 376, 379, 382, 385, 386, 390, 395, 397, 401, 407, 411, 413, 422, 424, 427, 429, 446, 447, 454, 458, 474, 476, 483, 501, 502, 503, 504]

Perfectly Predicted Docs - Option 3 (Exact Match → Fallback):
[1, 9, 20, 21, 22, 24, 26, 30, 33, 37, 40, 42, 50, 51, 52, 53, 55, 56, 57, 62, 67, 79, 80, 82, 85, 86, 90, 95, 96, 97, 101, 107, 109, 114, 1

#### Inspecting  All Predictions from a Specific Doc based on doc id

In [None]:
def inspect_clauses_by_doc(doc_id, doc_ids, preds, refs, clause_types, questions, titles, max_show=20):
    print(f"\n📄 Document: {doc_id}")
    shown = 0
    for i in range(len(doc_ids)):
        if doc_ids[i] == doc_id and shown < max_show:
            match = "" if preds[i].strip().lower() == refs[i].strip().lower() else "❌"
            print(f"\nClause Type: {clause_types[i]}")
            print(f"Question:    {questions[i]}")
            print(f"Reference:   {refs[i]}")
            print(f"Prediction:  {preds[i]} {match}")
            shown += 1

for doc_id in list(perfect_docs_opt1)[:3]:
    inspect_clauses_by_doc(doc_id, doc_ids, test_preds, test_refs, clause_types, questions, titles)

def inspect_perfect_documents(perfect_doc_set, preds, refs, doc_ids, clause_types, questions, titles, max_docs=5):
    for doc_id in list(perfect_doc_set)[:max_docs]:
        inspect_clauses_by_doc(doc_id, doc_ids, preds, refs, clause_types, questions, titles)

inspect_perfect_documents(perfect_docs_opt3, test_preds, test_refs, doc_ids, clause_types, questions, titles)



📄 Document: 395

Clause Type: Anti-Assignment
Question:    Highlight the parts (if any) of this contract related to "Anti-Assignment" that should be reviewed by a lawyer.
Reference:   No answer
Prediction:  No answer 

Clause Type: Document Name
Question:    Highlight the parts (if any) of this contract related to "Document Name" that should be reviewed by a lawyer.
Reference:   Cooperation Agreement
Prediction:  COOPERATION AGREEMENT 

📄 Document: 269

Clause Type: Warranty Duration
Question:    Highlight the parts (if any) of this contract related to "Warranty Duration" that should be reviewed by a lawyer.
Reference:   No answer
Prediction:  No answer 

📄 Document: 147

Clause Type: Parties
Question:    Highlight the parts (if any) of this contract related to "Parties" that should be reviewed by a lawyer.
Reference:   No answer
Prediction:  No answer 

📄 Document: 1

Clause Type: Change Of Control
Question:    Highlight the parts (if any) of this contract related to "Change Of Contr

In [None]:
def clause_accuracy_per_method(doc_breakdown, dataset):
    from collections import Counter

    clause_correct = Counter()
    clause_total = Counter()

    for ex in dataset:
        doc_id = ex["doc_id"]
        clause = ex["clause_type"]

        if doc_id in doc_breakdown:
            clause_total[clause] += 1
            if doc_breakdown[doc_id]["exact_match"]:
                clause_correct[clause] += 1

    rows = []
    for clause in clause_total:
        total = clause_total[clause]
        correct = clause_correct[clause]
        acc = 100 * correct / total
        rows.append((clause, total, correct, total - correct, acc))

    df = pd.DataFrame(rows, columns=["Clause", "Total", "Correct", "Wrong", "Accuracy %"])
    return df.sort_values("Accuracy %", ascending=False)


### inspect_doc_level_clause_predictions

### Find Perfectly Predicted Documents

In [None]:
from collections import defaultdict

def get_docs_with_all_correct_clauses(doc_ids, preds, refs):
    """
    Identify document IDs for which ALL clauses were predicted correctly.

    Args:
        doc_ids (list): List of doc_id per clause.
        preds (list): Predicted answers per clause.
        refs (list): Reference answers per clause.

    Returns:
        Set of doc_ids that had perfect predictions.
    """

    clause_per_doc = defaultdict(list)

    for i in range(len(doc_ids)):
        doc_id = doc_ids[i]
        pred = preds[i].strip().lower()
        ref = refs[i].strip().lower()
        clause_per_doc[doc_id].append(pred == ref)

    # Check if all clauses were correct
    perfect_docs = {doc_id for doc_id, matches in clause_per_doc.items() if all(matches)}
    return perfect_docs


In [None]:
perfect_doc_ids = get_docs_with_all_correct_clauses(doc_ids, test_preds, test_refs)
print(f" Total Perfectly Predicted Documents: {len(perfect_doc_ids)}")
print("Examples:\n", list(perfect_doc_ids)[:5])


 Total Perfectly Predicted Documents: 20
Examples:
 [395, 269, 147, 33, 424]
