# Fine-tuning a model on a Q&A task

In [1]:
import warnings
warnings.filterwarnings('ignore')
import transformers
import datasets as ds
from datasets import load_metric, Dataset
import pandas as pd
import numpy as np
import re
import pyarrow.parquet as pq
import Levenshtein
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import sentence_bleu
import torch

## Loading the dataset

In [3]:
# Load the dataset package
SupremeCourtOfIsrael = ds.load_dataset('LevMuchnik/SupremeCourtOfIsrael')

# Convert the datasets to Dataframe (9 Min)
Hugging_Face_df = pd.DataFrame.from_dict(SupremeCourtOfIsrael['train'])

In [2]:
# Load the dataset package locally (30 Sec)
Hugging_Face_df = pq.read_table(source='./SupremeCourtOfIsrael/cases_all.parquet').to_pandas()

## Find Legal Clauses

In [3]:
# Define the list of Basic Laws
basic_laws = [
    'חוק-יסוד הכנסת',
    'חוק-יסוד מקרקעי ישראל',
    'חוק-יסוד נשיא המדינה',
    'חוק-יסוד משק המדינה',
    'חוק-יסוד הצבא',
    'חוק-יסוד ירושלים בירת ישראל',
    'חוק-יסוד השפיטה',
    'חוק-יסוד מבקר המדינה',
    'חוק-יסוד כבוד האדם וחירותו',
    'חוק-יסוד חופש העיסוק',
    'חוק-יסוד הממשלה',
    'חוק-יסוד משאל עם',
    'חוק-יסוד ישראל מדינת הלאום של העם היהודי'
]

# Prepare the regex pattern for Basic Laws
formats = [re.escape(law.split(' ')[1]) for law in basic_laws]  # Get the part after 'חוק-יסוד'
basic_laws_pattern = r'\bחוק-יסוד\s*[-: ]?\s*(?:' + '|'.join(formats) + r')\b'

# Define regex of legal clauses
legal_clauses_pattern = r'(?:תקנה|תקנות|סעיף|חוק|הלכה|פסיקה|פסיקת|צו|פקודה|פקודת|כלל|כללי|כללים|הוראה|הוראות)\s[א-ת\s–"\',()\[\]\-]*?(?:-|:|\s)?(?:ת[א-ת]{1,2}\s?\d{4}|\d{4})'

# Combine the patterns into one pattern that will match both legal clauses and basic Laws
combined_pattern = rf'{legal_clauses_pattern}|{basic_laws_pattern}'

In [4]:
def find_legal_text(text):
    # Find all legal text matches using the compiled combined pattern
    legal_matches = re.findall(combined_pattern, text, re.VERBOSE)
    
    # Filter out matches longer than 30 words
    legal_matches = [match.strip() for match in legal_matches if len(match.split()) <= 30]
    
    # Remove duplicates while preserving order
    seen = set()
    unique_matches = []
    for match in legal_matches:
        if all(Levenshtein.distance(match, existing) > 2 for existing in seen):
            unique_matches.append(match)
            seen.add(match)
    
    return unique_matches

In [5]:
# Remove rows with empty "text" column
Hugging_Face_df = Hugging_Face_df[Hugging_Face_df['text'].isna() == False]
print("Len:", len(Hugging_Face_df))

# Apply find_legal_text function to the 'text' column and store the result in a new column (3 min)
Hugging_Face_df['Legal_Clauses_Found'] = Hugging_Face_df['text'].apply(find_legal_text)

# Save only rows with legal clauses or precedents found in the "text" column
legal_df = Hugging_Face_df[(Hugging_Face_df['Legal_Clauses_Found'].apply(len) > 0)][['text', 'Legal_Clauses_Found']]


# Arange df
# --------------------------

# Rename columns
legal_df.rename(columns={'text': "context", 'Legal_Clauses_Found': 'answers'}, inplace=True)

# Create id column
legal_df['id'] = legal_df.index

# Create new column
legal_df['question'] = 'אילו סעיפים\חוקים\תקנות מצויינים במסמך ?'

# Reorder columns
legal_df = legal_df[['id', 'context', 'question', 'answers']]

# Replace newline characters with space
legal_df['context'] = legal_df['context'].str.replace("\n", " ")


# Find 'start_indices' and create a dataset
# ------------------------------------------

# Find the start indices of all the elements of the answers list in the context column
def find_start_indices(row):
    indices = []
    for answer in row['answers']:
        indices.append(row['context'].find(answer))
    return indices

# Apply the function to the df to get the indices of the start of both elements in the context column
legal_df['start_indices'] = legal_df.apply(find_start_indices, axis=1)

# Convert the 'answers' col to a dictionary
def convert_to_dict(answers, start_indices):
    return {'text': answers, 'answer_start': start_indices}

# Apply the function to each row of the df
legal_df['answers'] = legal_df.apply(lambda row: convert_to_dict(row['answers'], row['start_indices']), axis=1)

# Drop 'start_indices' col
legal_df.drop(['start_indices'], axis=1, inplace=True)

Len: 750841


## Find Precedents

In [3]:
# Define initial letters that might precede the prefixes
initial_letters = r'(?:\b(?:ב|וב|ה|וה)\b)?'  # Optional initial letters

additional_prefixes = [
    "אב\"ע", "א\"ת", "את\"פ", "אמ\"ץ", "פ\"ר", "אפ\"ח", "א\"פ", "ב\"ל", "וח\"ק",
    "בק\"מ", "ת\"ת", "ביד\"מ", "בדמ\"ש", "בע\"ק", "בפ\"מ", "עה\"פ", "בה\"ן", "בה\"פ",
    "בפ\"ת", "בש\"ע", "בת\"ת", "בב\"נ", "בע\"א", "בר\"ש", "בר\"ע", "שב\"ד",
    "שנ\"א", "גמ\"ר", "דמ\"ר", "דמ\"ש", "דנ\"א", "דנג\"ץ", "דנ\"מ", "דנ\"פ",
    "ד\"ט", "הס\"ת", "המ\"ע", "ה\"כ", "ה\"ת", "ה\"ט", "ה\"נ", "ה\"פ", "הפ\"ב",
    "הד\"פ", "ה\"ד", "ת\"ט", "תה\"ן", "ו\"ע", "ח\"א", "חב\"ר", "חע\"מ", "חע\"ק",
    "ח\"ד", "ח\"נ", "חס\"מ", "י\"ס", "כ\"צ", "מק\"מ", "מ\"י", "מי\"ב", "מ\"מ",
    "מ\"ת", "מ\"ח", "נע\"ד", "ס\"ע", "ס\"ק", "סק\"כ", "פ\"ל", "עמ\"א", "ע\"א",
    "עב\"ז", "ע\"ב", "עב\"ל", "עח\"ר", "ע\"נ", "ער\"מ", "עמ\"ח", "על\"ע", "עמ\"נ",
    "ע\"מ", "עמ\"מ", "עש\"מ", "עמ\"ש", "ענ\"א", "ענ\"פ", "ענמ\"ש", "עס\"ק", "ע\"ע",
    "עב\"י", "עמל\"ע", "עמש\"מ", "עמר\"מ", "ער\"פ", "ע\"ר", "עמ\"פ", "עש\"ר", "ע\"ו",
    "על\"ח", "עק\"נ", "עק\"פ", "עע\"מ", "עעת\"א", "ע\"פ", "עפ\"א", "עפ\"ג", "עפ\"ר",
    "עפ\"ת", "עפס\"פ", "עפ\"ס", "עש\"א", "ע\"ש", "עש\"ת", "ע\"ח", "עב\"פ", "עא\"פ",
    "עח\"ע", "עע\"ר", "עפ\"ע", "עה\"ג", "עמ\"י", "עמ\"ת", "עכ\"ב", "עק\"מ", "עח\"ק",
    "עפ\"מ", "עפ\"ן", "בג\"ץ", "עג\"ר", "עת\"מ", "עת\"א", "פק\"ח", "פר\"ק", "פ\"ה",
    "פש\"ר", "צ\"א", "צ\"ה", "צ\"ח", "מ\"כ", "צ\"ו", "ק\"פ", "ק\"ג", "רע\"ס", "רע\"א",
    "רע\"מ", "רמ\"ש", "רצ\"פ", "רע\"צ", "רע\"ו", "רע\"ב", "רעת\"א", "רע\"פ", "רע\"ש",
    "רת\"ק", "ש\"ש", "ש", "ש\"ע", "ת\"ד", "נ\"ב", "תמ\"ק", "תמ\"ר", "תנ\"ג", "ת\"ק",
    "ת\"ב", "סב\"א", "גז\"ז", "ח\"ש", "תג\"א", "ת\"ח", "תנ\"ז", "תע\"א", "ת\"צ", "ת\"מ",
    "תא\"מ", "תא\"ח", "תא\"ק", "ת\"א", "תה\"ג", "תה\"ס", "ע\"ל", "תל\"א", "תל\"ב",
    "תל\"פ", "תמ\"ש", "ת\"ע", "ת\"פ", "תפ\"ח", "ת\"ג", "תת\"ח", "תת\"ע", "תו\"ח", "תו\"ב",
    "המ\"ש", "הע\"ז", "ש\"מ", "שע\"מ", "בש\"א", "ר\"ע", "ראו", "למשל", "בת.פ.", "ת.א."
]

# Join the prefixes into a regex pattern, ensuring word boundaries
prefix_pattern = r'\b(?:' + '|'.join(re.escape(prefix) for prefix in additional_prefixes) + r')\b'

# Pattern for numbers (1-8 digits), joined or separated by dashes or slashes
number_pattern = r'\b\d{1,8}(?:[-/]\d{1,8})*\b'

# Combine the patterns
full_pattern = initial_letters + prefix_pattern + r'[ ,.:;!?]*' + number_pattern

In [4]:
# Find legal precedents within the "text" column
def find_legal_precedents(text):
    pattern = full_pattern

    # Replace newline characters with space
    text = text.replace("\n", " ")

    precedents_matches = re.findall(pattern, text)
    
    # Filter matches to remove those with length exceeding too much characters
    filtered_matches = [match for match in precedents_matches if len(match) <= 25]

    # Remove duplicate legal clauses
    seen = set()
    unique_precedents_matches = []
    
    for match in filtered_matches:
        # Remove all spaces and compare
        cleaned_match = re.sub(r'\s+', '', match)
        if cleaned_match not in seen:
            seen.add(cleaned_match)
            unique_precedents_matches.append(match)
    
    return unique_precedents_matches

In [5]:
# Remove rows with empty "text" column
Hugging_Face_df = Hugging_Face_df[Hugging_Face_df['text'].isna() == False]
print("Len:", len(Hugging_Face_df))

# Apply find_legal_precedents function to the 'text' column and store the result in a new column
Hugging_Face_df['precedents_found'] = Hugging_Face_df['text'].apply(find_legal_precedents)

# Save only rows with precedents found in the "text" column
legal_df = Hugging_Face_df[Hugging_Face_df['precedents_found'].apply(len) > 0][['text', 'precedents_found']]

# Arange df
# --------------------------

# Rename columns
legal_df.rename(columns={'text': "context", 'precedents_found': 'answers'}, inplace=True)

# Create id column
legal_df['id'] = legal_df.index

# Create new column
legal_df['question'] = 'באילו פסקי דין\תקדימים נעשה שימוש ?'

# Reorder columns
legal_df = legal_df[['id', 'context', 'question', 'answers']]

# Replace newline characters with space
legal_df['context'] = legal_df['context'].str.replace("\n", " ")


# Find 'start_indices' and create a dataset
# ------------------------------------------

# Find the start indices of all the elements of the answers list in the context column
def find_start_indices(row):
    indices = []
    for answer in row['answers']:
        indices.append(row['context'].find(answer))
    return indices

# Apply the function to the df to get the indices of the start of both elements in the context column
legal_df['start_indices'] = legal_df.apply(find_start_indices, axis=1)

# Convert the 'answers' col to a dictionary
def convert_to_dict(answers, start_indices):
    return {'text': answers, 'answer_start': start_indices}

# Apply the function to each row of the df
legal_df['answers'] = legal_df.apply(lambda row: convert_to_dict(row['answers'], row['start_indices']), axis=1)

# Drop 'start_indices' col
legal_df.drop(['start_indices'], axis=1, inplace=True)

Len: 750841


## Final DataFrame

In [6]:
legal_df

Unnamed: 0,id,context,question,answers
10,10,ב בית המשפט העליון בירושלים רע...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,"{'text': ['תקנות סדר הדין האזרחי, תשמ""ד1984'],..."
11,11,ב בית המשפט העליון בשבתו כבית משפט לער...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,"{'text': ['כלל לעניינו של המערער, ועל כך עמד ב..."
17,17,ב בית המשפט העליון ...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,"{'text': ['תקנות סדר הדין האזרחי, התשמ""ד1984']..."
19,19,ב בית המשפט העליון ...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,"{'text': ['תקנות סדר הדין האזרחי התשמ""ד1984'],..."
20,20,ב בית המשפט העליון ...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,"{'text': ['תקנות סדר הדין האזרחי התשמ""ד1984'],..."
...,...,...,...,...
751189,751189,בבית המשפט העליו...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,{'text': ['חוק סדר הדין הפלילי (סמכויות אכיפה ...
751190,751190,בבית המשפט העליון ...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,{'text': ['חוק סדר הדין הפלילי (סמכויות אכיפה ...
751191,751191,בבית המשפט העליון ...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,{'text': ['חוק סדר הדין הפלילי (סמכויות  אכיפ...
751192,751192,בבית המשפט העליון ...,אילו סעיפים\חוקים\תקנות מצויינים במסמך ?,{'text': ['חוק סדר הדין הפלילי (סמכויות אכיפה ...


## Continue

In [7]:
train_df, validation_df = train_test_split(legal_df, test_size=0.2, random_state=42)

In [8]:
train_dataset = Dataset.from_dict(train_df)
validation_dataset = Dataset.from_dict(validation_df)
datasets = ds.DatasetDict({"train":train_dataset,"validation":validation_dataset})

In [9]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 147946
    })
    validation: Dataset({
        features: ['id', 'context', 'question', 'answers'],
        num_rows: 36987
    })
})

## Preprocessing the training data

Instantiate a tokenizer with `AutoTokenizer.from_pretrained` method

In [14]:
# Load this model for Legal_Clauses training
modelName = 'shay681/HeBERT_finetuned_Legal_Clauses'

In [None]:
# Load this model for Precedents training
modelName = 'shay681/HeBERT_finetuned_Precedents'

In [15]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(modelName)

# Assertion to ensure that the tokenizer is a fast tokenizer
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

Hyper Parameters:

In [20]:
max_length = 384 # The maximum length of a feature (question and context)
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.

# For the special case where the model expects padding on the left
pad_on_right = tokenizer.padding_side == "right"

In [21]:
def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

Apply the function on all the elements of all the splits in `dataset`.

Since the preprocessing changes the number of samples, remove the old columns when applying it.

In [None]:
tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names)

In [None]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model = AutoModelForQuestionAnswering.from_pretrained(modelName)

Instantiate a `Trainer`

In [None]:
model_name = modelName.split("/")[-1]
args = TrainingArguments(
    "./HeBERT_finetuned_Results/",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=5,
    save_steps=500,
    eval_steps=500,
    logging_steps=100,
    # save_strategy="no"
    # use_cpu=True
)

Data collator that will batch the processed examples together

In [None]:
from transformers import default_data_collator

data_collator = default_data_collator

Define a `Trainer`:

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

Finetune the model by calling the `train` method:

In [None]:
trainer.train()

In [None]:
# Continue training from last checkpoint
trainer.train(resume_from_checkpoint="./HeBERT_finetuned_Results/checkpoint-120000")

Save the model

In [None]:
trainer.save_model("model_fine_tuned")

## Inference

Instantiate a `pipeline` for Q&A

In [None]:
from itertools import combinations
import re

def remove_duplicate_answers(answers):
    unique_answers = []
    seen_answers = set()

    for answer in answers:
        # Remove excess spaces from the answer
        normalized_answer =re.sub(r'\s+', '', answer['answer'])
        # Check if the normalized answer has already been seen
        if normalized_answer not in seen_answers:
            unique_answers.append(answer)
            seen_answers.add(normalized_answer)

    return unique_answers


def merge_answers(answers):
    merged_pairs = []

    for pair in combinations(answers, 2):
        answer1, answer2 = pair
        start1, end1 = answer1['start'], answer1['end']
        start2, end2 = answer2['start'], answer2['end']

        if (start1 <= start2 <= end1) or (start2 <= start1 <= end2):
            # Find the overlapping part
            overlap_start = max(start1, start2)
            overlap_end = min(end1, end2)
            overlap_length = overlap_end - overlap_start
            
            # Merge the answers by concatenating the non-overlapping parts and the overlapping part once
            if start1 <= start2:
                merged_text = answer1['answer'] + " " + answer2['answer'][overlap_length:]
            else:
                merged_text = answer2['answer'] + " " + answer1['answer'][overlap_length:]

            merged_pair = {
                'score': min(answer1['score'], answer2['score']),
                'start': min(start1, start2),
                'end': max(end1, end2),
                'answer': merged_text
            }
            merged_pairs.append(merged_pair)

    return merged_pairs

In [None]:
from transformers import pipeline

validation_dataset = datasets['validation']

# Initialize the question answering pipeline
# for Precedents pipeline: model="HeBERT_finetuned_Precedents"
question_answerer = pipeline("question-answering", model="HeBERT_finetuned_Legal_Clauses")

# Define the regex pattern
pattern = re.compile(full_pattern)

def get_multiple_answers(question_answerer, question, context, n_best_size=200):
    answers = question_answerer(question=question, context=context, top_k=n_best_size)

    # Filter answers 
    filtered_answers = []
    no_match_answers = []
    pairs = []

    for answer in answers:
        # if answer['score'] < 1 and pattern.match(answer['answer']):
        if pattern.match(answer['answer']):
            filtered_answers.append(answer)
        else:
            no_match_answers.append(answer)
            merged_pairs = merge_answers(no_match_answers)
            for pair in merged_pairs:
                if pattern.match(pair['answer']):
                    pairs.append(pair)

    fixed_merged_pairs = remove_duplicate_answers(pairs)
    filtered_answers += fixed_merged_pairs
    
    # Sort answers by their start position to help with the containment check
    filtered_answers.sort(key=lambda x: x['start'])

    # Remove nested answers
    unique_answers = []
    for answer in filtered_answers:
        if not any(answer['start'] >= prev_answer['start'] and answer['end'] <= prev_answer['end'] for prev_answer in unique_answers):
            unique_answers.append(answer)

    # Extract the text of unique answers
    results = [answer for answer in unique_answers]
    
    return results

In [None]:
# Create a list to store the results
results = []

# Iterate over the validation dataset and make predictions

for example in validation_dataset:
    context = example["context"]
    question = example["question"]
    predicted_answers = get_multiple_answers(question_answerer, question, context)

    print(example['id'], ":",  predicted_answers)

    results.append({
        'id': example['id'],
        'predicted_answers': predicted_answers
    })

In [None]:
# Create a DataFrame and save it to CSV
predicted_answers = pd.DataFrame(results)
predicted_answers.to_csv('predicted_answers.csv', index=False, encoding='utf-8')

## Evaluation

Compare Predictions with Ground Truth

In [None]:
def normalize_answer(answer):
    """Normalize answer by removing all spaces and specified initial letters from each word, and converting to lowercase."""
    # Remove all spaces
    answer = re.sub(r'\s+', '', answer)
    
    # Remove initial letters from each word
    words = answer.split()
    initial_letters = {'ב', 'ל', 'ו', 'ש', 'כ'}  # Add any other letters to this set as needed
    
    normalized_words = []
    for word in words:
        if word and word[0] in initial_letters:
            normalized_words.append(word[1:])  # Remove the initial letter
        else:
            normalized_words.append(word)
    
    normalized_answer = ''.join(normalized_words)
    
    return normalized_answer.strip().lower()


def evaluate_predictions(validation_dataset, question_answerer):
    results = []

    for example in validation_dataset:
        context = example["context"]
        question = example["question"]
        ground_truth_answers = example["answers"]["text"]

        # Get the model predictions
        predicted_answers = get_multiple_answers(question_answerer, question, context)

        # Normalize ground truth answers
        normalized_ground_truth = [normalize_answer(ans) for ans in ground_truth_answers]
        
        # Normalize predicted answers
        normalized_predictions = [normalize_answer(pred['answer']) for pred in predicted_answers]

        # Check if any of the predicted answers match any of the ground truth answers
        matched = any(any(gt in pred for pred in normalized_predictions) for gt in normalized_ground_truth)
        # matched = any(normalize_answer(pred['answer']) in normalized_ground_truth for pred in predicted_answers)

        # Calculate BLEU scores
        bleu_scores = [sentence_bleu([normalize_answer(gt).split()], normalize_answer(pred['answer']).split()) for gt in ground_truth_answers for pred in predicted_answers]

        results.append({
            'id': example['id'],
            'context': context,
            'question': question,
            'ground_truth_answers': ground_truth_answers,
            'predicted_answers': [pred['answer'] for pred in predicted_answers],
            'matched': matched,
            'bleu_scores': bleu_scores
        })

    return results

Compute Evaluation Metrics

In [None]:
def compute_metrics(evaluation_results):
    total = len(evaluation_results)
    matched = sum(result['matched'] for result in evaluation_results)
    accuracy = matched / total

    all_bleu_scores = [score for result in evaluation_results for score in result['bleu_scores']]
    average_bleu_score = sum(all_bleu_scores) / len(all_bleu_scores) if all_bleu_scores else 0

    print(f'Accuracy: {accuracy * 100:.2f}%')
    print(f'Average BLEU Score: {average_bleu_score * 100:.2f}%')

    return accuracy, average_bleu_score

In [None]:
from transformers import pipeline

# Initialize Q&A pipeline
# for Precedents pipeline: model="HeBERT_finetuned_Precedents"
question_answerer = pipeline("question-answering", model="HeBERT_finetuned_Legal_Clauses")

# Run evaluation
evaluation_results = evaluate_predictions(validation_dataset, question_answerer)

# Compute the accuracy
accuracy, average_bleu_score = compute_metrics(evaluation_results)

# Convert evaluation results to a DataFrame
df_results = pd.DataFrame(evaluation_results)

# Save to CSV
df_results.to_csv('LegalClauses_evaluation_results.csv', index=False)

# for Precedents Evaluation
# df_results.to_csv('Precedents_evaluation_results.csv', index=False)

## Calculate metrics

In [18]:
# Extract case number from "precedents_found" column
def extract_case_number(text):
    pattern = r'\b\d{1,4}/\d{1,4}\b'
    case_numbers = re.findall(pattern, text)
    return case_numbers

In [19]:
# Function to calculate BLEU score for each pair of ground truth and predicted answers
def calculate_bleu_score(ground_truth, predicted):
    if len(predicted) == 0:
        return 0  # If there are no predicted answers, return 0 BLEU score
    return sentence_bleu([ground_truth], predicted)

# Function to calculate Precision, Recall, and F1 score for each pair of ground truth and predicted answers
def calculate_prf_scores(ground_truth, predicted):
    true_positive = len(set(ground_truth) & set(predicted))
    false_positive = len(set(predicted) - set(ground_truth))
    false_negative = len(set(ground_truth) - set(predicted))
    
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return precision, recall, f1_score

In [20]:
# Load the CSV file into a DataFrame
LegalClauses_df = pd.read_csv('./Results/Q&A/LegalClauses_evaluation_results.csv')
Precedents_df = pd.read_csv('./Results/Q&A/Precedents_evaluation_results.csv')

In [21]:
# Relevant to Precedents_evaluation_results
# -----------------------------------------

# Apply extract_case_number function to the 'predicted_answers' column and store the result in a new column
Precedents_df['gta'] = Precedents_df['ground_truth_answers'].apply(extract_case_number)

# Apply extract_case_number function to the 'predicted_answers' column and store the result in a new column
Precedents_df['pa'] = Precedents_df['predicted_answers'].apply(extract_case_number)

# Remove rows where there are no predicted answers
Precedents_df_filtered = Precedents_df[Precedents_df['pa'].apply(len) > 0]

In [24]:
# Find Legal Clauses
# -----------------
# # Apply the function to each row
LegalClauses_df['bleu_score'] = LegalClauses_df.apply(lambda row: calculate_bleu_score(row['ground_truth_answers'], row['predicted_answers']), axis=1)

# Apply the Precision, Recall, and F1 score function to each row and create separate columns for each
LegalClauses_df[['precision', 'recall', 'f1_score']] = LegalClauses_df.apply(lambda row: pd.Series(calculate_prf_scores(row['ground_truth_answers'], row['predicted_answers'])), axis=1)

# Calculate the average scores
average_bleu_score = LegalClauses_df['bleu_score'].mean()
average_precision = LegalClauses_df['precision'].mean()
average_recall = LegalClauses_df['recall'].mean()
average_f1_score = LegalClauses_df['f1_score'].mean()

print(f'Average BLEU Score: {average_bleu_score:.3f}')
print(f'Average Precision: {average_precision:.3f}')
print(f'Average Recall: {average_recall:.3f}')
print(f'Average F1 Score: {average_f1_score:.3f}')

Average BLEU Score: 0.895
Average Precision: 0.997
Average Recall: 0.967
Average F1 Score: 0.970


In [23]:
# Find Precedents
# -----------------
# Apply the function to each row
Precedents_df_filtered['bleu_score'] = Precedents_df_filtered.apply(lambda row: calculate_bleu_score(row['gta'][0], row['pa'][0]), axis=1)

# Apply the Precision, Recall, and F1 score function to each row and create separate columns for each
Precedents_df_filtered[['precision', 'recall', 'f1_score']] = Precedents_df_filtered.apply(lambda row: pd.Series(calculate_prf_scores(row['gta'], row['pa'])), axis=1)

# Calculate the average scores
average_bleu_score = Precedents_df_filtered['bleu_score'].mean()
average_precision = Precedents_df_filtered['precision'].mean()
average_recall = Precedents_df_filtered['recall'].mean()
average_f1_score = Precedents_df_filtered['f1_score'].mean()

print(f'Average BLEU Score: {average_bleu_score:.3f}')
print(f'Average Precision: {average_precision:.3f}')
print(f'Average Recall: {average_recall:.3f}')
print(f'Average F1 Score: {average_f1_score:.3f}')

Average BLEU Score: 0.516
Average Precision: 0.689
Average Recall: 0.807
Average F1 Score: 0.699
