In [17]:
import nltk
from sentence_transformers import SentenceTransformer, util
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Fetch the Punkt tokenizer data that nltk.sent_tokenize / nltk.word_tokenize (used below) depend on.
# NOTE(review): recent NLTK releases look up 'punkt_tab' instead — confirm against the installed version.
nltk.download('punkt')
# Sentence-embedding model, created once at module level and reused by every "dense" ranking call.
dense_model = SentenceTransformer('all-MiniLM-L6-v2')

def process_and_rank_sentences(context, question, min_length=50, method="dense", top_k=5):
    """Select the passages of ``context`` most relevant to ``question``.

    The context is split into sentences with NLTK. A sentence is appended to
    the passage being accumulated while the combined length stays below
    ``min_length`` characters; otherwise the accumulated passage is flushed
    and the sentence starts a new one. Passages are scored against the
    question, the ``top_k`` best are kept, put back in document order and
    joined with single spaces.

    Args:
        context: Text to search.
        question: Query the passages are scored against.
        min_length: Character threshold below which the next sentence is
            appended to the passage currently being accumulated.
        method: ``"dense"`` (cosine similarity between ``dense_model``
            embeddings) or ``"bm25"`` (BM25Okapi over NLTK word tokens).
        top_k: Number of highest-scoring passages to keep.

    Returns:
        str: The selected passages in document order, or ``""`` when the
        context yields no sentences.

    Raises:
        ValueError: If ``method`` is neither ``"dense"`` nor ``"bm25"``.
    """
    # Fail fast with a clear message; previously an unknown method surfaced
    # later as an UnboundLocalError on the ranking result.
    if method not in ("dense", "bm25"):
        raise ValueError(f"Unsupported retrieval method: {method!r} (expected 'dense' or 'bm25')")

    sentences = nltk.sent_tokenize(context)
    merged_sentences = []
    original_indices = []  # Index of the first sentence of each merged passage
    current_sentence = ""
    current_index = 0

    for i, sentence in enumerate(sentences):
        if len(current_sentence) + len(sentence) < min_length:
            current_sentence += " " + sentence
        else:
            if current_sentence:
                merged_sentences.append(current_sentence.strip())
                original_indices.append(current_index)
            current_sentence = sentence
            current_index = i
    if current_sentence:
        merged_sentences.append(current_sentence.strip())
        original_indices.append(current_index)

    # Nothing to rank: avoid encoding an empty batch / building BM25 over an
    # empty corpus (the latter divides by the corpus size).
    if not merged_sentences:
        return ""

    if method == "dense":
        question_embedding = dense_model.encode(question, convert_to_tensor=True)
        sentence_embeddings = dense_model.encode(merged_sentences, convert_to_tensor=True)
        # cos_sim returns a (1, n_passages) matrix; row 0 holds the scores.
        scores = util.cos_sim(question_embedding, sentence_embeddings)[0].tolist()
    else:  # method == "bm25"
        tokenized_sentences = [nltk.word_tokenize(sent) for sent in merged_sentences]
        bm25 = BM25Okapi(tokenized_sentences)
        scores = bm25.get_scores(nltk.word_tokenize(question))

    # Highest score first; sorted() is stable, so ties keep document order.
    ranked_sentences = sorted(
        zip(merged_sentences, scores, original_indices), key=lambda x: x[1], reverse=True
    )

    # Re-order the winners by their position in the document so the returned
    # text reads in its original sequence.
    top_k_sentences = sorted(ranked_sentences[:top_k], key=lambda x: x[2])

    return ' '.join(sentence for sentence, _, _ in top_k_sentences)

[nltk_data] Downloading package punkt to /home/elkhyo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [18]:
from datasets import load_dataset
# Load the 'test' split of the CUAD-QA dataset; later cells read 'context', 'question' and 'answers' from its rows.
test_dataset = load_dataset("theatticusproject/cuad-qa",split='test')

from tqdm import tqdm

def populate_generations(test_dataset, sample, generate_method, retrieval, top_k, device):
    """Build the context that feeds generation for the first ``sample`` examples.

    Args:
        test_dataset: Integer-indexable collection of records exposing
            ``'context'`` and ``'question'`` keys.
        sample: Number of leading examples to process.
        generate_method: Not read at the moment; kept for the disabled
            generation step.
        retrieval: Ranking method forwarded to ``process_and_rank_sentences``,
            or ``None`` to use each example's full context unchanged.
        top_k: Number of passages kept per example when retrieval is enabled.
        device: Not read at the moment; kept for the disabled generation step.

    Returns:
        tuple: ``(generated_ans, context_used)``. ``generated_ans`` stays
        empty while generation is disabled; ``context_used`` holds one
        context string per processed example.
    """
    generated_ans = []
    context_used = []

    for idx in tqdm(range(sample)):
        # Fetch the record once instead of indexing the dataset per field.
        example = test_dataset[idx]
        full_context = example['context']
        question = example['question']

        if retrieval is not None:
            context = process_and_rank_sentences(
                full_context, question, min_length=50, method=retrieval, top_k=top_k
            )
        else:
            context = full_context

        # Generation is currently disabled. To re-enable, call
        #   model_generate(context, question, 50, generate_method, device, full_context=full_context)
        # and append its result to generated_ans.
        context_used.append(context)

    return generated_ans, context_used

In [19]:
import numpy as np

def hit_rate(context, original_answers):
    """Return the fraction of answers that occur verbatim in ``context``.

    Args:
        context: String that is searched.
        original_answers: Sequence of answer strings to look for.

    Returns:
        float: ``hits / len(original_answers)``, or NaN when
        ``original_answers`` is empty so that NaN-aware aggregations can
        skip the example.
    """
    if len(original_answers) == 0:
        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
        return np.nan
    hits = sum(1 for answer in original_answers if answer in context)
    return hits / len(original_answers)

def score_function(test_dataset, generated_ans, input_contexts):
    """Score each input context by the hit rate of its example's gold answers.

    Args:
        test_dataset: Iterable of records exposing ``record['answers']['text']``
            (the list of gold answer strings for that example).
        generated_ans: Not read at the moment; kept for the disabled
            answer-scoring step.
        input_contexts: Context strings, paired positionally with the
            records of ``test_dataset``.

    Returns:
        tuple: ``(None, None, hit_rate_scores)``. The two ``None`` slots are
        placeholders for the correctness and faithfulness scores, which are
        disabled; ``hit_rate_scores`` has one entry per (context, record)
        pair and is NaN where a record has no gold answers.
    """
    # Stream the gold answers lazily: zip stops at the shorter input, so rows
    # beyond len(input_contexts) are never read.
    gold_answer_lists = (example['answers']['text'] for example in test_dataset)
    hit_rate_scores = [
        hit_rate(context, answer_list)
        for context, answer_list in zip(input_contexts, gold_answer_lists)
    ]

    # Correctness / faithfulness scoring (alignscore_scorer over generated_ans,
    # the gold answers and input_contexts) is disabled for now, hence the
    # None placeholders below.
    return None, None, hit_rate_scores

In [22]:
# Experiment cell: build the retrieval context for every example of the CUAD-QA test split.
# NOTE(review): the same split is already loaded by an earlier cell; this reload looks redundant on a top-to-bottom run.
from datasets import load_dataset
test_dataset = load_dataset("theatticusproject/cuad-qa",split='test')

# NOTE(review): F, AutoTokenizer and AutoModelForCausalLM are not referenced in the visible cells —
# presumably needed by the disabled generation code; verify before keeping.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hard-codes the first CUDA device. populate_generations receives it but does not read it
# while its generation step is commented out.
device = torch.device('cuda:0')

# Forwarded to populate_generations, which does not currently read it.
generate_method = 'Extend_Adaknn'
# 'dense' selects the embedding-based branch of process_and_rank_sentences.
retrieval= 'dense'
# Number of passages kept per example.
top_k = 20

# Process the whole split; generated_answers comes back empty while generation is disabled.
generated_answers, used_context = populate_generations(test_dataset, len(test_dataset), generate_method, retrieval, top_k, device)

100%|██████████| 4182/4182 [04:10<00:00, 16.69it/s]


In [23]:
# score_function returns (None, None, hit_rate_scores); only the third element carries data.
_, _, hit_rate_scores = score_function(test_dataset, generated_answers, used_context)
# NOTE(review): prints one score per example (thousands of values) — consider showing a slice or summary instead.
print(hit_rate_scores)
# nanmean skips the NaN entries hit_rate produces for examples without gold answers.
print(f'Average Hit Rate Score: {np.nanmean(hit_rate_scores)}')

[1.0, 1.0, nan, nan, 1.0, nan, nan, 1.0, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 0.0, 1.0, nan, nan, 0.0, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, nan, nan, nan, nan, nan, nan, nan, 1.0, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 1.0, nan, nan, nan, nan, nan, 1.0, 0.4, 1.0, 1.0, 1.0, nan, nan, 1.0, nan, 0.0, 0.0, nan, 0.5, 1.0, nan, 1.0, nan, 1.0, 0.5, 0.0, nan, 0.4, nan, 0.3333333333333333, nan, 0.6666666666666666, 1.0, nan, 1.0, nan, nan, nan, nan, 1.0, 1.0, 0.5, 1.0, nan, 1.0, 0.5, nan, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, nan, 1.0, 0.0, nan, nan, nan, nan, 0.3333333333333333, nan, 0.2, 1.0, 0.0, nan, nan, nan, 0.0, 1.0, 0.125, 0.0, nan, nan, nan, 0.0, nan, 0.16666666666666666, 0.3333333333333333, 0.0, 0.5, nan, nan, 1.0, nan, nan, 1.0, 1.0, 1.0, 1.0, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 1.0, nan, n