# mediRAG Pipeline

In [21]:
import os
import torch
from tqdm import tqdm as tqdm
import numpy as np
# import pickle
from uuid import uuid4

from prompts import *
import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline #BitsAndBytesConfig
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import LLMChain
from langchain.schema.runnable import RunnablePassthrough

In [2]:
# torch.cuda.set_device(3)  # have to change depending on which device u wana use
# torch.cuda.current_device()

3

## Load Model and Tokenizer

In [4]:
# model_name='mistralai/Mistral-7B-Instruct-v0.1'

# tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models")

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=False,
# )

# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     quantization_config=bnb_config,
#     cache_dir="models"
# )

# Load Data

In [32]:
# # Util functions
# q2c = {}
# def question2chunk(dataset):
#     question = "QUESTION"
#     context = "CONTEXTS"
#     for split in dataset.keys():
#         for abstract in dataset[split][context]:
#             for sentence in abstract:
#                 yield Document(page_content=sentence)



In [22]:
dataset = load_dataset("bigbio/pubmed_qa", cache_dir="data")
dataset

DatasetDict({
    train: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 200000
    })
    validation: Dataset({
        features: ['QUESTION', 'CONTEXTS', 'LABELS', 'MESHES', 'YEAR', 'reasoning_required_pred', 'reasoning_free_pred', 'final_decision', 'LONG_ANSWER'],
        num_rows: 11269
    })
})

In [23]:
def preprocess(dataset):
    page_content_column = "CONTEXTS"
    for split in dataset.keys():
        for contexts in dataset[split][page_content_column]:
            for sentence in contexts:
                yield Document(page_content=sentence)

data = list(preprocess(dataset))  # 655055
data[0]

Document(page_content='In previous work we (Fisher et al., 2011) examined the emergence of neurobehavioral disinhibition (ND) in adolescents with prenatal substance exposure. We computed ND factor scores at three age points (8/9, 11 and 13/14 years) and found that both prenatal substance exposure and early adversity predicted ND. The purpose of the current study was to determine the association between these ND scores and initiation of substance use between ages 8 and 16 in this cohort as early initiation of substance use has been related to later substance use disorders. Our hypothesis was that prenatal cocaine exposure predisposes the child to ND, which, in turn, is associated with initiation of substance use by age 16.')

## Setting up retrievers

In [24]:
embedding_model = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device':'cuda'}
encode_kwargs = {'normalize_embeddings': False}

# Initialize an instance of HuggingFaceEmbeddings with the specified parameters
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model,   
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs, 
    cache_folder="models"
)

if os.path.exists("faiss_index_pubmed"):
    print(os.path.exists("faiss_index_pubmed"))
    db = FAISS.load_local("faiss_index_pubmed", embeddings)
else:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    docs = text_splitter.split_documents(data)  # 676307
    db = FAISS.from_documents(docs, embeddings)
    db.save_local("faiss_index_pubmed")

True


In [25]:
# initialize the bm25 retriever and faiss retriever
bm25_retriever = BM25Retriever.from_documents(data)
# bm25_retriever.k = 3

faiss_retriever = db.as_retriever() #search_kwargs={"k": 3}

# initialize the ensemble retriever
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)

In [26]:
question = dataset['train']["QUESTION"][0]
context = dataset['train']["CONTEXTS"][0]

retrieved_docs = ensemble_retriever.get_relevant_documents(question)

print(f"Question:\n{question}")
print(f"\nContext:\n{context}")
print(f"\nRetrieved document:\n{retrieved_docs[0].page_content}\n{retrieved_docs[1].page_content}\n{retrieved_docs[2].page_content}")

Question:
Does neurobehavioral disinhibition predict initiation of substance use in children with prenatal cocaine exposure?

Context:
['In previous work we (Fisher et al., 2011) examined the emergence of neurobehavioral disinhibition (ND) in adolescents with prenatal substance exposure. We computed ND factor scores at three age points (8/9, 11 and 13/14 years) and found that both prenatal substance exposure and early adversity predicted ND. The purpose of the current study was to determine the association between these ND scores and initiation of substance use between ages 8 and 16 in this cohort as early initiation of substance use has been related to later substance use disorders. Our hypothesis was that prenatal cocaine exposure predisposes the child to ND, which, in turn, is associated with initiation of substance use by age 16.', "We studied 386 cocaine exposed and 517 unexposed children followed since birth in a longitudinal study. Five dichotomous variables were computed based 

In [58]:
# PX implementation

def calculate_map(ranked_lists):
    ap_sum = 0
    for ranked_list in ranked_lists:
        precision_sum = 0
        relevant_docs = 0
        for i, doc in enumerate(ranked_list):
            if doc == 1:
                relevant_docs += 1
                precision_sum += relevant_docs / (i + 1)
        if relevant_docs != 0:
            ap_sum += precision_sum / relevant_docs
        else:
            ap_sum += 0
    map = ap_sum / len(ranked_lists)
    return map

def preprocess_val(val_dataset):
    page_content_column = "CONTEXTS"
    for contexts in val_dataset[page_content_column]:
        for sentence in contexts:
            yield Document(page_content=sentence)

from collections import defaultdict
val_dataset = dataset['validation']
val_data = list(preprocess_val(val_dataset)) # complete docs

# shorten
val_data = val_data[0:20]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
val_chunks = text_splitter.split_documents(val_data)

# check which docs have been split into chunks
docs_split_into_chunks = []
for idx, complete_doc in enumerate(val_data):
    split_into_chunks = True
    for possibly_incomplete_doc in val_chunks:
        if complete_doc == possibly_incomplete_doc:
            split_into_chunks = False
    if split_into_chunks == True:
        docs_split_into_chunks.append(idx)

# for those docs split into chunks, check number of corresponding chunks
chunks_dict = defaultdict(int)
for doc_idx in docs_split_into_chunks:
    full_doc_content = val_data[doc_idx].page_content
    for chunk in val_chunks:
        if chunk.page_content in full_doc_content:
            chunks_dict[doc_idx] += 1

ranked_lists = []
for i in tqdm(range(len(val_data))): # for each validation query
    question = dataset['validation']["QUESTION"][i]
    context = dataset['validation']["CONTEXTS"][i]
    joined_context = ''.join(context)
    # retrieved_docs = db.similarity_search(question)
    retrieved_docs = ensemble_retriever.get_relevant_documents(question)
    
    ranked_ls = []
    if i not in docs_split_into_chunks:
        num_chunks = 1
    else:
        num_chunks = chunks_dict[i]
    for j in range(num_chunks): # num of elements within a nested list
        retrieved_doc_content = retrieved_docs[j].page_content # change this -> retrieved_docs[0]: 1st ranked doc, retrieved_docs[1]: 2nd ranked doc, and so on
        if retrieved_doc_content in joined_context:
            ranked_ls.append(1)
        else:
            ranked_ls.append(0)
    ranked_lists.append(ranked_ls)

print(calculate_map(ranked_lists))
print(ranked_lists)

100%|██████████| 20/20 [05:02<00:00, 15.11s/it]

0.95





In [59]:
ranked_lists

[[1],
 [1],
 [1],
 [0],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1]]

In [54]:
# SH implementation

from collections import defaultdict

def preprocess_val(val_contexts):
    page_content_column = "CONTEXTS"
    for contexts in val_contexts:
        for sentence in contexts:
            yield Document(page_content=sentence)

def create_question2chunk(val_questions, val_contexts):
    question2chunk = {}
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    for question, context in zip(val_questions, val_contexts):
        # split each doc into different sentences then chunk them
        context_docs = []
        for sentence in context:
            doc = Document(page_content=sentence)
            context_docs.append(doc)
        chunks = text_splitter.split_documents(context_docs)
        question2chunk[question] = chunks
    return question2chunk

print("Loading data")
val_dataset = dataset['validation']
val_contexts = val_dataset["CONTEXTS"][:20]
val_questions = val_dataset["QUESTION"][:20]

print("Processing data")
val_data = list(preprocess_val(val_contexts)) # complete docs

print("Creating dict mapping question to chunks")
question2chunk = create_question2chunk(val_questions,val_contexts)

all_relevance = []
for i in tqdm(range(len(val_questions))): # for each validation query
    print("Retrieving documents")
    question = val_questions[i]
    retrieved_docs = ensemble_retriever.get_relevant_documents(question)

    print("Label relevance")
    # label the retrieved list of doc as relevant or not
    bin_relevance = []
    for d in retrieved_docs:
        gt_chunks = question2chunk[question]
        if d in gt_chunks:
            bin_relevance.append(1)
        else:
            bin_relevance.append(0)
    all_relevance.append(bin_relevance)
    print("-------------")

print(all_relevance)

Loading data
Processing data
Creating dict mapping question to chunks


  0%|          | 0/20 [00:00<?, ?it/s]

Retrieving documents


  5%|▌         | 1/20 [00:14<04:40, 14.74s/it]

Label relevance
-------------
Retrieving documents


 10%|█         | 2/20 [00:30<04:32, 15.16s/it]

Label relevance
-------------
Retrieving documents


 15%|█▌        | 3/20 [00:45<04:19, 15.29s/it]

Label relevance
-------------
Retrieving documents


 20%|██        | 4/20 [01:00<04:02, 15.13s/it]

Label relevance
-------------
Retrieving documents


 25%|██▌       | 5/20 [01:15<03:48, 15.21s/it]

Label relevance
-------------
Retrieving documents


 30%|███       | 6/20 [01:30<03:32, 15.18s/it]

Label relevance
-------------
Retrieving documents


 35%|███▌      | 7/20 [01:47<03:24, 15.70s/it]

Label relevance
-------------
Retrieving documents


 40%|████      | 8/20 [02:02<03:02, 15.25s/it]

Label relevance
-------------
Retrieving documents


 45%|████▌     | 9/20 [02:16<02:45, 15.07s/it]

Label relevance
-------------
Retrieving documents


 50%|█████     | 10/20 [02:31<02:29, 14.98s/it]

Label relevance
-------------
Retrieving documents


 55%|█████▌    | 11/20 [02:45<02:12, 14.72s/it]

Label relevance
-------------
Retrieving documents


 60%|██████    | 12/20 [02:59<01:56, 14.50s/it]

Label relevance
-------------
Retrieving documents


 65%|██████▌   | 13/20 [03:15<01:43, 14.79s/it]

Label relevance
-------------
Retrieving documents


 70%|███████   | 14/20 [03:31<01:31, 15.29s/it]

Label relevance
-------------
Retrieving documents


 75%|███████▌  | 15/20 [03:46<01:15, 15.18s/it]

Label relevance
-------------
Retrieving documents


 80%|████████  | 16/20 [04:00<00:59, 14.94s/it]

Label relevance
-------------
Retrieving documents


 85%|████████▌ | 17/20 [04:16<00:45, 15.17s/it]

Label relevance
-------------
Retrieving documents


 90%|█████████ | 18/20 [04:31<00:30, 15.14s/it]

Label relevance
-------------
Retrieving documents


 95%|█████████▌| 19/20 [04:46<00:14, 14.92s/it]

Label relevance
-------------
Retrieving documents


100%|██████████| 20/20 [05:00<00:00, 15.03s/it]

Label relevance
-------------
[[1, 1, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [0, 1, 0, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 0], [1, 1, 0, 1, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0, 1, 0, 0], [1, 0, 1, 0, 0, 0, 0], [1, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0, 0], [1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 1], [1, 1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0, 0]]





In [55]:
def precision_at_k(r, k):
    """Score is precision @ k (This we solve for you!)
    Relevance is binary (nonzero is relevant).
    >>> r = [0, 0, 1]
    >>> precision_at_k(r, 1)
    0.0
    >>> precision_at_k(r, 2)
    0.0
    >>> precision_at_k(r, 3)
    0.33333333333333331
    >>> precision_at_k(r, 4)
    Traceback (most recent call last):
        File "<stdin>", line 1, in ?
    ValueError: Relevance score length < k
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
    Precision @ k
    Raises:
        ValueError: len(r) must be >= k
    """
    assert k >= 1
    r = np.asarray(r)[:k] != 0
    if r.size != k:
        raise ValueError('Relevance score length < k')
    return np.mean(r)

def average_precision(r):
    """Score is average precision (area under PR curve)
    Relevance is binary (nonzero is relevant).
    >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]
    >>> delta_r = 1. / sum(r)
    >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in
    enumerate(r) if y])
    0.7833333333333333
    >>> average_precision(r)
    0.78333333333333333
    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        Average precision
    """
    #write your code here
    # print(r)
    relevant_idx = [i+1 for i in list(np.where(np.array(r)==1)[0])]
    # print(relevant_idx)
    n = sum(r)
    if n == 0: 
        return 0
    else:
        precision_k = [precision_at_k(r,pk) for pk in relevant_idx]
        avg_p = 1/n * sum(precision_k)
        return avg_p
    
def mean_average_precision(rs):
    """Score is mean average precision
    Relevance is binary (nonzero is relevant).
    >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]]
    >>> mean_average_precision(rs)
    0.78333333333333333
    >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]]
    >>> mean_average_precision(rs)
    0.39166666666666666
    Args:
        rs: Iterator of relevance scores (list or numpy) in rank order
            (first element is the first item)
    Returns:
        Mean average precision
    """
    #write your code here
    avg_precision = [average_precision(r) for r in rs]
    n = len(rs)
    m_avg_p = 1/n * sum(avg_precision)

    return m_avg_p

In [57]:
score = mean_average_precision(all_relevance)
print(score)

0.8840277777777777


In [None]:
# testing
# def preprocess_val(val_contexts):
#     page_content_column = "CONTEXTS"
#     for contexts in val_contexts:
#         for sentence in contexts:
#             yield Document(page_content=sentence)

# def create_question2chunk(val_questions, val_contexts):
#     question2chunk = {}
#     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
#     for question, context in zip(val_questions, val_contexts):
#         # split each doc into different sentences then chunk them
#         context_docs = []
#         for sentence in context:
#             doc = Document(page_content=sentence)
#             context_docs.append(doc)
#         chunks = text_splitter.split_documents(context_docs)
#         question2chunk[question] = chunks
#     return question2chunk

# val_dataset = dataset['validation']
# val_contexts = val_dataset["CONTEXTS"][:1000]
# val_questions = val_dataset["QUESTION"][:1000]

# val_data = list(preprocess_val(val_contexts)) # complete docs
# out = create_question2chunk(val_questions,val_contexts)

# print(len(out))
# o = out['Do posterior fossa and spinal gangliogliomas form two distinct clinicopathologic and molecular subgroups?']
# Document(page_content='Gangliogliomas are low-grade glioneuronal tumors of the central nervous system and the commonest cause of chronic intractable epilepsy. Most gangliogliomas (>70%) arise in the temporal lobe, and infratentorial tumors account for less than 10%. Posterior fossa gangliogliomas can have the features of a classic supratentorial tumor or a pilocytic astrocytoma with focal gangliocytic differentiation, and this observation led to the hypothesis tested in this study - gangliogliomas of the posterior fossa and spinal cord consist of two morphologic types that can be distinguished by specific genetic alterations.') in o

## Initializing Pipeline


In [8]:
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    max_new_tokens=300,
    do_sample=False,
)

mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Create prompt from prompt template 
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=PROMPT_TEMPLATE_QA_EXPLAINER,
)

# Create llm chain 
llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)

## Running Queries

In [9]:
idx = 0

question = dataset['train'][2]["QUESTION"]
context = dataset['train'][2]["CONTEXTS"]
long_answer = dataset['train'][2]["LONG_ANSWER"]
final_decision = dataset['train'][2]["final_decision"]

## QA without Retrieval

In [10]:
input_ids = tokenizer.encode(question, return_tensors="pt").to("cuda")

with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        do_sample=False,
        return_dict_in_generate=True,
        max_new_tokens=300,
    )

output = tokenizer.decode(output.sequences[0][len(input_ids[0]):])

output

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'\n\n## Abstract\n\nHeart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction. The pathophysiology of HFpEF is complex and involves both structural and functional changes in the heart. Recent studies have shown that HFpEF is associated with dynamic impairment of active relaxation and contraction of the LV on exercise, which may be related to myocardial energy deficiency. This review summarizes the current evidence on the dynamic impairment of LV function in HFpEF and its association with myocardial energy deficiency.\n\n## Introduction\n\nHeart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction. HFpEF is estimated to affect 50% of all patients with heart failure, and its prevalence is increasing with

In [11]:
output_with_citations = ""
citations = ""
citation_list = []

for lines in output.split("\n"):
    lines = lines.strip()
    if len(lines.split(" ")) > 10:
        for line in lines.split("."):
            line = line.strip()
            docs_and_scores = db.similarity_search_with_score(line)[0]  # choosing top 1 relevant document
            if docs_and_scores[1] < 0.5:  # returned distance score is L2 distance, a lower score is better
                doc_content = docs_and_scores[0].page_content
                if doc_content in citation_list:
                    idx = citation_list.index(doc_content)

                else:
                    citation_list.append(doc_content)
                    idx = len(citation_list)
                    citations += f"[{idx}] {doc_content}\n"

                output_with_citations += line + f" [{idx}]. "

final_output_with_citations = output_with_citations + "\n\nCitations:\n" + citations

In [12]:
print(final_output_with_citations)

Heart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction [1]. The pathophysiology of HFpEF is complex and involves both structural and functional changes in the heart [2]. Recent studies have shown that HFpEF is associated with dynamic impairment of active relaxation and contraction of the LV on exercise, which may be related to myocardial energy deficiency [3]. This review summarizes the current evidence on the dynamic impairment of LV function in HFpEF and its association with myocardial energy deficiency [1]. Heart failure with preserved ejection fraction (HFpEF) is a common form of heart failure characterized by impaired left ventricular (LV) diastolic function and normal or near-normal LV ejection fraction [0]. HFpEF is estimated to affect 50% of all patients with heart failure, and its prevalence is increasing with age [4]. The pathophy

## QA with Retrieval

In [13]:
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 3}
)

# retriever = db.as_retriever(search_type="similarity_score_threshold", 
#                                  search_kwargs={"score_threshold": .5, 
#                                                 "k": top_k})

rag_chain = ({"context": retriever, "question": RunnablePassthrough()} | llm_chain)

# QA with retrieval
qa_retrieval_result = rag_chain.invoke(question)

qa_retrieval_result

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


{'context': [Document(page_content='We sought to evaluate the role of exercise-related changes in left ventricular (LV) relaxation and of LV contractile function and vasculoventricular coupling (VVC) in the pathophysiology of heart failure with preserved ejection fraction (HFpEF) and to assess myocardial energetic status in these patients.'),
  Document(page_content='Nearly half of patients with heart failure have a preserved ejection fraction (HFpEF). Symptoms of exercise intolerance and dyspnea are most often attributed to diastolic dysfunction; however, impaired systolic and/or arterial vasodilator reserve under stress could also play an important role.'),
  Document(page_content='To investigate the associations between glucose metabolism, left ventricular (LV) contractile reserve, and exercise capacity in patients with chronic systolic heart failure (HF).')],
 'question': 'Is heart failure with preserved ejection fraction characterized by dynamic impairment of active relaxation and

# Evaluation

In [14]:
bleu = evaluate.load("bleu", cache_dir="evaluation_metrics")  # value ranges from 0 to 1. score of 1 is better

bleu_score = bleu.compute(predictions=[output_with_citations], references=[long_answer])
print(f"BLEU Score: {bleu_score}")

bleu_score = bleu.compute(predictions=[qa_retrieval_result["text"]], references=[long_answer])
print(f"BLEU Score: {bleu_score}")

BLEU Score: {'bleu': 0.0, 'precisions': [0.05627705627705628, 0.013043478260869565, 0.004366812227074236, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 9.625, 'translation_length': 231, 'reference_length': 24}
BLEU Score: {'bleu': 0.0, 'precisions': [0.07075471698113207, 0.018957345971563982, 0.004761904761904762, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 8.833333333333334, 'translation_length': 212, 'reference_length': 24}


In [15]:
bertscore = evaluate.load("bertscore", cache_dir="evaluation_metrics")

bert_score = bertscore.compute(predictions=[output_with_citations], references=[long_answer], lang="en")
print(f"BERTScore: {bert_score}")

bert_score = bertscore.compute(predictions=[qa_retrieval_result["text"]], references=[long_answer], lang="en")
print(f"BERTScore: {bert_score}")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore: {'precision': [0.7966367602348328], 'recall': [0.8586264848709106], 'f1': [0.8264709115028381], 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.1)'}
BERTScore: {'precision': [0.8198319673538208], 'recall': [0.8714783191680908], 'f1': [0.8448666334152222], 'hashcode': 'roberta-large_L17_no-idf_version=0.3.12(hug_trans=4.35.1)'}
