In [1]:
import json
import torch
import os
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from sentence_transformers import SentenceTransformer
from torch.optim import Adam
from torch.nn.functional import cosine_similarity
from datasets import load_dataset
from transformers import pipeline  # Fact-checking pipeline

2025-03-09 11:30:14.562319: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-09 11:30:14.571839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741500014.581838 3356167 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741500014.584647 3356167 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-09 11:30:14.597196: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_max_memory_cached()
device = torch.device("cpu")



In [3]:
# Initialize models
retriever_model = SentenceTransformer("all-MiniLM-L6-v2")
fact_checker = pipeline("text-classification", model="facebook/bart-large-mnli")


Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [4]:
# RL and training hyperparameters
LEARNING_RATE = 5e-5
num_epochs = 2
temperature = 0.8
max_new_tokens = 64
alpha = 0.5

In [5]:
def save_list_to_json(lst, filename):
    with open(filename, 'w') as file:
        json.dump(lst, file)

def rm_file(file_path):
    if os.path.exists(file_path):
        os.remove(file_path)
        print(f"File {file_path} removed successfully.")

In [6]:
def compute_reward(model_answer, gold_answer, context):
    """
    Compute reward based on cosine similarity and fact-verification confidence.
    """
    model_embedding = retriever_model.encode(model_answer, convert_to_tensor=True)
    gold_embedding = retriever_model.encode(gold_answer, convert_to_tensor=True)
    
    similarity_score = cosine_similarity(model_embedding.unsqueeze(0), gold_embedding.unsqueeze(0)).item()
    
    fact_check_result = fact_checker(f"Does the following context support the statement? {model_answer} Context: {context}")
    fact_confidence = max([res['score'] for res in fact_check_result if res['label'] == "ENTAILMENT"], default=0.5)
    
    return (similarity_score + fact_confidence) / 2

In [7]:
def retrieve_context(query, retrieval_list, hops=2):
    """
    Rank retrieved documents using cosine similarity and concatenate the top ones as context.
    """
    query_embedding = retriever_model.encode(query, convert_to_tensor=True)
    if len(query_embedding.shape) == 1:
        query_embedding = query_embedding.unsqueeze(0)
    scored_docs = sorted(
        retrieval_list,
        key=lambda doc: torch.nn.functional.cosine_similarity(
            query_embedding,
            retriever_model.encode(doc['text'], convert_to_tensor=True).unsqueeze(0),
            dim=1
        ).item(),
        reverse=True
    )
    context = ""
    for i in range(min(hops, len(scored_docs))):
        context += scored_docs[i]['text'] + "\n\n"
    return context

In [8]:
def query_bot_rl(prompt, temperature=temperature, max_new_tokens=max_new_tokens):
    """
    Generate an answer token-by-token using sampling, while accumulating log probabilities.
    Returns the generated text and total log probability.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    generated_tokens = []
    log_probs = []
    for _ in range(max_new_tokens):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # Convert logits to float32 to avoid NaN issues.
        logits = outputs.logits[:, -1, :].to(torch.float32)
        logits = logits / temperature
        distribution = torch.distributions.Categorical(logits=logits)
        token = distribution.sample()  # Sample next token.
        log_prob = distribution.log_prob(token)
        generated_tokens.append(token)
        log_probs.append(log_prob)
        # Append the sampled token to the input sequence.
        input_ids = torch.cat([input_ids, token.unsqueeze(-1)], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(token.unsqueeze(-1))], dim=-1)
        # Stop if EOS token is generated.
        if token.item() == tokenizer.eos_token_id:
            break
    total_log_prob = torch.stack(log_probs).sum()
    generated_text = tokenizer.decode(torch.cat(generated_tokens), skip_special_tokens=True)
    return generated_text, total_log_prob

In [9]:
# Load model and tokenizer
model_name = "microsoft/DialoGPT-medium"
save_file = 'qa_output/gpt_rl_triviaqa_fact.json'
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prefix = (
    "Below is a question followed by context from different sources. "
    "Please answer using the context. If insufficient, respond 'Insufficient Information'. "
    "Answer directly."
)

dataset = load_dataset("trivia_qa", "rc", split="train[:100]")
doc_data = []
for example in dataset:
    query = example["question"]
    # Use available evidence if present; otherwise, provide a dummy context.
    if "evidence" in example and len(example["evidence"]) > 0:
        retrieval_list = [{"text": example["evidence"][0]["text"]}]
    else:
        retrieval_list = [{"text": "No additional context available."}]
    # Assume the answer is stored as a string or in a dict.
    answer = example["answer"]["value"] if isinstance(example["answer"], dict) else example["answer"]
    doc_data.append({
        "query": query,
        "retrieval_list": retrieval_list,
        "answer": answer,
        "question_type": "trivia"
    })

optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

  return self.fget.__get__(instance, owner)()


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

In [24]:
# Training Loop with Hybrid Training on TriviaQA subset.
rm_file(save_file)
save_list = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    for d in tqdm(doc_data):
        # Retrieve context and construct prompt.
        retrieval_list = d['retrieval_list']
        context = retrieve_context(d['query'], retrieval_list, hops=2)
        prompt = f"{prefix}\n\nQuestion: {d['query']}\n\nContext:\n\n{context}\n\nAnswer:"
        
        model.train()

        # ----- RL Branch -----
        # Generate an answer using RL sampling.
        rl_response, total_log_prob = query_bot_rl(prompt)
        # Compute reward comparing the generated answer with the gold answer.
        reward = compute_reward(rl_response, d['answer'], context)
        reward_tensor = torch.tensor(reward, device=device, dtype=total_log_prob.dtype)
        rl_loss = -reward_tensor * total_log_prob

        # ----- Supervised (MLE) Branch -----
        # Construct a combined input: prompt followed by the gold answer.
        # We add a marker "Answer:" to indicate the beginning of the answer.
        prompt_text = prompt + "\n\nAnswer:"
        prompt_enc = tokenizer(prompt_text, return_tensors="pt").to(device)
        answer_enc = tokenizer(d['answer'], return_tensors="pt").to(device)
        # Concatenate the prompt and gold answer.
        input_ids = torch.cat([prompt_enc.input_ids, answer_enc.input_ids], dim=1)
        attention_mask = torch.cat([prompt_enc.attention_mask, answer_enc.attention_mask], dim=1)
        # Create labels: ignore loss for prompt tokens.
        labels = input_ids.clone()
        labels[:, :prompt_enc.input_ids.shape[1]] = -100

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        supervised_loss = outputs.loss

        # ----- Hybrid Loss -----
        # Combine supervised (MLE) loss and RL loss.
        combined_loss = alpha * supervised_loss + (1 - alpha) * rl_loss

        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()

        # Save training metrics.
        save = {
            'query': d['query'],
            'prompt': prompt,
            'rl_model_answer': rl_response,
            'gold_answer': d['answer'],
            'question_type': d['question_type'],
            'reward': reward,
            'supervised_loss': supervised_loss.item(),
            'rl_loss': rl_loss.item(),
            'combined_loss': combined_loss.item()
        }
        save_list.append(save)
        
    # Optionally, save a checkpoint after each epoch.
    torch.save(model.state_dict(), f"checkpoint_fact_epoch_{epoch+1}.pt")
    print(f"Epoch {epoch+1} completed.")

save_list_to_json(save_list, save_file)


Epoch 1/2


100%|██████████| 100/100 [09:54<00:00,  5.94s/it]


Epoch 1 completed.
Epoch 2/2


100%|██████████| 100/100 [12:10<00:00,  7.30s/it]


Epoch 2 completed.


File qa_output/gpt_rl_triviaqa_fact.json removed successfully.
Epoch 3/5


100%|██████████| 100/100 [10:52<00:00,  6.53s/it]


Epoch 3 completed.
Epoch 4/5


100%|██████████| 100/100 [15:05<00:00,  9.06s/it]


Epoch 4 completed.
Epoch 5/5


100%|██████████| 100/100 [08:57<00:00,  5.37s/it]


Epoch 5 completed.


In [10]:
# ----- Evaluation on Unseen Data -----
# Load the saved model and set it to evaluation mode.
model.load_state_dict(torch.load("checkpoint_fact_epoch_5.pt"))
model.eval()

# Load a few questions from the TriviaQA test split.
test_dataset = load_dataset("trivia_qa", "rc", split="test[:5]")

for example in test_dataset:
    query = example["question"]
    if "evidence" in example and len(example["evidence"]) > 0:
        retrieval_list = [{"text": example["evidence"][0]["text"]}]
    else:
        retrieval_list = [{"text": "No additional context available."}]
    context = retrieve_context(query, retrieval_list, hops=2)
    prompt = f"{prefix}\n\nQuestion: {query}\n\nContext:\n\n{context}\n\nAnswer:"
    response, _ = query_bot_rl(prompt)
    print(f"Question: {query}")
    print(f"Generated Answer: {response}")
    # print(f"Gold Answer: {example['answer']['value']}")
    print("----\n")

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Question: Asmara international airport is in which country?
Generated Answer: Barbuda, FL
----

Question: At whose concert were 11 people trampled to death in Ohio in 1979?
Generated Answer: 6th Floor concert
----

Question: Andy Warhol/'s 60s exhibition featured cans of which product?
Generated Answer: Brasil and olive oil
----

Question: San Giusto international airport is in which county?
Generated Answer: San Jose, Ca
----

Question: Who had a 60s No 1 with Travelin' Man?
Generated Answer: Traveling Man
----



In [20]:
from evaluate import load

test_data = load_dataset("trivia_qa", "rc", split="test[:100]")

doc_data = []
for example in test_data:
    query = example["question"]
    # Use available evidence if present; otherwise, provide a dummy context.
    if "evidence" in example and len(example["evidence"]) > 0:
        retrieval_list = [{"text": example["evidence"][0]["text"]}]
    else:
        retrieval_list = [{"text": "No additional context available."}]
    # Assume the answer is stored as a string or in a dict.
    answer = example["answer"]["value"] if isinstance(example["answer"], dict) else example["answer"]
    doc_data.append({
        "query": query,
        "retrieval_list": retrieval_list,
        "answer": answer,
        "question_type": "trivia"
    })

# Load evaluation metrics
bleu = load("bleu")
rouge = load("rouge")
bert_score = load("bertscore")

# Generate predictions
predictions = []
ground_truths = []

D = []

for data in tqdm(doc_data):
    retrieval_list = data['retrieval_list']
    context = retrieve_context(data['query'], retrieval_list, hops=2)
    prompt = f"{prefix}\n\nQuestion: {data['query']}\n\nContext:\n\n{context}"
    gold_answer = data['answer']
    model_answer, _ = query_bot_rl(prompt)

    predictions.append(model_answer)
    ground_truths.append(gold_answer)
    
    reward = compute_reward(model_answer, gold_answer, context)
    D.append(reward)



Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

100%|██████████| 100/100 [04:27<00:00,  2.67s/it]


In [50]:
# Compute Exact Match (EM)
exact_match = sum([1 if pred.lower().strip() == gold.lower().strip() else 0 
                   for pred, gold in zip(predictions, ground_truths)]) / len(ground_truths)

# Compute BLEU Score
bleu_score = bleu.compute(predictions=predictions, references=ground_truths)

# Compute ROUGE Score
rouge_score = rouge.compute(predictions=predictions, references=ground_truths)

# Compute BERTScore
bert_score_result = bert_score.compute(predictions=predictions, references=ground_truths, lang="en")
avg_bert_score = sum(bert_score_result['f1']) / len(bert_score_result['f1'])

# Compute average reward
avg_reward = np.mean(D)

# Print Results
print(f"Exact Match (EM): {exact_match:.4f}")
print(f"BLEU Score: {bleu_score['bleu']:.4f}")
print(f"ROUGE-L Score: {rouge_score['rougeL']:.4f}")
print(f"BERTScore (F1): {sum(bert_score_result['f1']) / len(bert_score_result['f1']):.4f}")
print(f"Average Reward: {avg_reward:.4f}")

# # Save results to a file
# results = {
#     'exact_match': exact_match,
#     'bleu': bleu_score['bleu'],
#     'rougeL': rouge_score['rougeL'],
#     'bert_score': avg_bert_score,
#     'average_reward': avg_reward
# }

# print(results)
print("Evaluation completed.")

Exact Match (EM): 0.0000
BLEU Score: 0.0000
ROUGE-L Score: 0.0000
BERTScore (F1): 0.8148
Average Reward: 0.3106
Evaluation completed.
