In [2]:
import re

import ollama
import torch

from src.data_loading import load_data
from src.parsing import parse_options
from src.rationale_generation import generate_rationale_and_answer
from src.utils import is_rationale_correct
from src.rationalization import rationalize
from src.prompt_generation import create_prompt_examples, create_prompt_set

In [3]:
NUM_PROMPT_EXAMPLES = 10 # number of training rows passed to the prompt builders; dataset_D starts after this many rows
NUM_EXAMPLES_TO_PROCESS = 5 # number of dataset_D examples to run through the generation loop

### Load Dataset

In [4]:
ds_train = load_data()

# Skip the first NUM_PROMPT_EXAMPLES rows so the rows handed to the prompt
# builders (presumably the leading ones -- confirm in src.prompt_generation)
# are not evaluated as well.
dataset_D = ds_train.select(range(NUM_PROMPT_EXAMPLES, len(ds_train)))

# Clamp the subset size: asking for more rows than dataset_D holds would
# request out-of-range indices instead of "everything that is available".
dataset_D_subset = dataset_D.select(range(min(NUM_EXAMPLES_TO_PROCESS, len(dataset_D))))

### Create prompt examples

In [5]:
# Build the few-shot demonstrations from the training split. The displayed
# result is a list of dicts with 'question', 'options', 'rationale' and
# 'answer' keys.
prompt_examples = create_prompt_examples(ds_train, NUM_PROMPT_EXAMPLES)

prompt_examples

[{'question': "Two friends plan to walk along a 43-km trail, starting at opposite ends of the trail at the same time. If Friend P's rate is 15% faster than Friend Q's, how many kilometers will Friend P have walked when they pass each other?",
  'options': ['A)21', 'B)21.5', 'C)22', 'D)22.5', 'E)23'],
  'rationale': 'If Q complete x kilometers, then P completes 1.15x kilometers.\nx + 1.15x = 43\n2.15x=43\nx = 43/2.15 = 20\nThen P will have have walked 1.15*20=23 km.\nThe answer is E.',
  'answer': 'E'},
 {'question': 'In the coordinate plane, points (x, 1) and (5, y) are on line k. If line k passes through the origin and has slope 1/5, then what are the values of x and y respectively?',
  'options': ['A)4 and 1', 'B)1 and 5', 'C)5 and 1', 'D)3 and 5', 'E)5 and 3'],
  'rationale': 'Line k passes through the origin and has slope 1/5 means that its equation is y=1/5*x.\nThus: (x, 1)=(5, 1) and (5, y) = (5,1) -->x=5 and y=1\nAnswer: C',
  'answer': 'C'},
 {'question': 'For all numbers p and

In [8]:
# Generate the prompt set: a single string of worked examples that is passed
# to every generation / rationalization call below. In the displayed output
# the examples are separated by "###".
prompt_set = create_prompt_set(ds_train, NUM_PROMPT_EXAMPLES)
prompt_set

"Question: Two friends plan to walk along a 43-km trail, starting at opposite ends of the trail at the same time. If Friend P's rate is 15% faster than Friend Q's, how many kilometers will Friend P have walked when they pass each other?\nOptions:\nA)21\nB)21.5\nC)22\nD)22.5\nE)23\nAnswer Explanation: If Q complete x kilometers, then P completes 1.15x kilometers.\nx + 1.15x = 43\n2.15x=43\nx = 43/2.15 = 20\nThen P will have have walked 1.15*20=23 km.\nThe answer is E.\nAnswer: E\nPlease respond with only the letter corresponding to the correct answer (A, B, C, D, or E).\n###\nQuestion: In the coordinate plane, points (x, 1) and (5, y) are on line k. If line k passes through the origin and has slope 1/5, then what are the values of x and y respectively?\nOptions:\nA)4 and 1\nB)1 and 5\nC)5 and 1\nD)3 and 5\nE)5 and 3\nAnswer Explanation: Line k passes through the origin and has slope 1/5 means that its equation is y=1/5*x.\nThus: (x, 1)=(5, 1) and (5, y) = (5,1) -->x=5 and y=1\nAnswer: C

In [9]:
# Initialize lists to hold correct and incorrect pairs.
# Both are filled by the processing loop; re-run this cell before re-running
# the loop, otherwise results from the previous run stay in the lists.
correct_pairs = []    # dicts with 'question', 'options', 'rationale', 'answer'
incorrect_pairs = []  # dicts with 'question', 'options', 'rationale', 'generated_answer', 'correct_answer'

In [11]:
# Iterate over each example in the subset: generate a rationale + answer,
# then file the example under correct_pairs or incorrect_pairs.
for idx, example in enumerate(dataset_D_subset):
    question = example['question']
    options_list = [opt.strip() for opt in example['options']]
    correct_option = example['correct'].strip().upper()

    # Parse options into a dictionary (letter -> option text, as shown in the
    # printed records)
    options_dict = parse_options(options_list)

    # The gold letter must be one of the parsed options; otherwise the example
    # cannot be scored and is skipped.
    correct_answer_text = options_dict.get(correct_option)
    if correct_answer_text is None:
        print(f"Warning: Correct answer not found for index {idx}. Skipping this example.")
        continue  # Skip if the correct answer is missing

    # Generate rationale and answer using the few-shot prompt
    generated_rationale, generated_answer_text = generate_rationale_and_answer(question, options_dict, prompt_set)

    # Start from the answer returned by the helper (may be empty/None) ...
    generated_answer_letter = None
    if generated_answer_text:
        generated_answer_letter = generated_answer_text.strip().upper()

    # ... but a letter stated explicitly inside the rationale takes precedence.
    # `or ''` keeps a missing rationale from raising TypeError in re.search and
    # aborting the whole run; such an example simply ends up in incorrect_pairs
    # unless the helper's own answer matches.
    match = re.search(r'And the answer is:\s*([A-E])', generated_rationale or '', re.IGNORECASE)
    if match:
        generated_answer_letter = match.group(1).upper()

    # Record used for logging (both outcomes) and stored as-is for misses.
    record = {
        'question': question,
        'options': options_dict,
        'rationale': generated_rationale,
        'generated_answer': generated_answer_letter,
        'correct_answer': correct_option
    }

    # An example counts as correct only if the letter matches AND
    # is_rationale_correct accepts the rationale.
    if (generated_answer_letter == correct_option) and is_rationale_correct(generated_rationale, correct_option, question):
        # Correct answer
        correct_pairs.append({
            'question': question,
            'options': options_dict,
            'rationale': generated_rationale,
            'answer': correct_option
        })
        print('Correct:', record)
    else:
        # Incorrect answer: keep both letters so the rationalization step can
        # retry with the correct one.
        incorrect_pairs.append(record)
        print("Incorrect:", record)

    # Print progress every example
    print(f"Processed {idx + 1} questions.\n")

Correct: {'question': 'If Tim had lunch at $50 and he gave 20% tip, how much did he spend?', 'options': {'A': '$60.00', 'B': '$35.42', 'C': '$60.60', 'D': '$21.56', 'E': '$78.45'}, 'rationale': "Let's break down the problem step by step.\n\n1. Tim had lunch at $50.\n2. He gave a 20% tip on the total bill, which includes the lunch cost and the tip itself (initially $0).\n3. To calculate the tip amount, we first need to add it to the initial bill: $50 + $0 = $50\n4. The tip is 20% of the new total ($50), so we multiply $50 by 0.20: $50 * 0.20 = $10\n5. Now that we know the total bill with tip, which is still just the initial lunch cost, we add the tip to the original bill: $50 + $10 = $60\n\nSo, Tim's total expenditure (the amount he spent) is the original price of lunch plus the 20% tip.\n\nThe correct calculation for the final amount spent by Tim is:\n\nOriginal price of lunch: $50\nTip: $10 (20% of $50)\nTotal bill with tip: $60 (=$50 + $10)\n\nThus, we can see that option A) $60.00 i

In [381]:
# Inspect the examples collected as correct (question, options, rationale, answer)
correct_pairs

[{'question': 'If Tim had lunch at $50 and he gave 20% tip, how much did he spend?',
  'options': {'A': '$60.00',
   'B': '$35.42',
   'C': '$60.60',
   'D': '$21.56',
   'E': '$78.45'},
  'rationale': 'To find the amount Tim spent, we need to calculate the tip and add it to the cost of lunch.\n\nThe tip is 20% of $50.\nTip = 0.2 x $50 = $10\n\nTotal amount spent = Cost of lunch + Tip\n= $50 + $10\n= $60\n\nSo, the correct answer is:\n\nAnd the answer is: A',
  'answer': 'A'},
 {'question': 'Rs. 825 becomes Rs. 956 in 3 years at a certain rate of simple interest.If the rate of interest is increased by 4% ,What amount will Rs. 825 become in 3 years ?',
  'options': {'A': 'Rs. 1020.80',
   'B': 'Rs. 1025',
   'C': 'Rs. 1055',
   'D': 'Data inadequate',
   'E': 'None of these'},
  'rationale': "To solve this problem, we need to first find the rate of interest at which Rs. 825 becomes Rs. 956 in 3 years.\n\nThe simple interest formula is:\n\nInterest = (Principal × Rate × Time)\n\nWe are g

In [382]:
# Inspect the examples collected as incorrect (includes generated_answer and correct_answer)
incorrect_pairs

[{'question': 'If q is the square of a positive integer, which of the following must be equal to the square of the next positive integer?',
  'options': {'A': '√n + 1',
   'B': 'n + 1',
   'C': 'n^2 + 1',
   'D': 'q + 2√q + 1',
   'E': 'n^2 + 2n + 1'},
  'rationale': "To solve this problem, let's analyze the options and see which one must be equal to the square of the next positive integer.\n\nThe square of a positive integer is given as q = n^2. We are looking for an expression that must be equal to the square of the next positive integer, i.e., (n+1)^2.\n\nLet's examine each option:\n\nA) √n + 1: This expression involves adding 1 to the square root of n, which does not necessarily result in a perfect square.\n\nB) n + 1: This is simply the next positive integer, but it's not its square.\n\nC) n^2 + 1: This option adds 1 to the square of n, but we need an expression that results in the square of (n+1).\n\nD) q + 2√q + 1: This involves adding 2 times the square root of q and 1 to the o

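Testing it out on a single example. The cell below is entirely commented out; uncomment it to debug the generation step for one question.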

In [None]:
# # Test with a specific example
# test_example = dataset_D_subset[2]  # Adjust the index as needed
# question = test_example['question']
# options_list = [opt.strip() for opt in test_example['options']]
# correct_option = test_example['correct'].strip().upper()

# # Parse options into a dictionary
# options_dict = parse_options(options_list)

# # Extract the correct answer
# correct_answer = options_dict.get(correct_option, None)

# print("Question:", question)
# print("Options:", options_dict)
# print("Correct Option:", correct_option)
# print("Correct Answer:", correct_answer)

# # Generate rationale and answer
# generated_rationale, generated_answer_text = generate_rationale_and_answer(question, options_dict, prompt_set)

# print("\nGenerated Rationale:")
# print(generated_rationale)
# print("\nGenerated Answer:", generated_answer_text)
# print("Correct Answer:", correct_answer)


In [383]:
# Summarize how many processed examples ended up in each list.
total = len(correct_pairs) + len(incorrect_pairs)
# Both lists are empty when every example was skipped (or this cell is run
# before the processing loop); report 0% instead of raising ZeroDivisionError.
accuracy = len(correct_pairs) / total * 100 if total else 0.0
print(f"Total questions processed: {total}")
print(f"Correct answers: {len(correct_pairs)}")
print(f"Incorrect answers: {len(incorrect_pairs)}")
print(f"Accuracy: {accuracy:.2f}%")

Total questions processed: 5
Correct answers: 3
Incorrect answers: 2
Accuracy: 60.00%


In [384]:
# Rationalization pass: revisit every missed question and ask for a rationale
# again, this time supplying the correct answer as a hint.
for pair in incorrect_pairs:
    question = pair['question']
    options = pair['options']
    correct_answer = pair['correct_answer']

    generated_rationale = rationalize(question, options, correct_answer, prompt_set)

    # The rationalized example is stored alongside the originally correct
    # ones (note: this grows correct_pairs in place) and echoed for inspection.
    rationalized_pair = {
        'question': question,
        'options': options,
        'rationale': generated_rationale,
        'answer': correct_answer
    }
    correct_pairs.append(rationalized_pair)
    print(rationalized_pair)

{'question': 'If q is the square of a positive integer, which of the following must be equal to the square of the next positive integer?', 'options': {'A': '√n + 1', 'B': 'n + 1', 'C': 'n^2 + 1', 'D': 'q + 2√q + 1', 'E': 'n^2 + 2n + 1'}, 'rationale': "To answer this question, let's analyze each option given and evaluate whether it must be equal to the square of the next positive integer.\n\nOption A: √n + 1. If q is the square of a positive integer, n = √q, then (√n + 1)² = (n + 1)² ≠ n² + 2√q + 1, so this option does not necessarily equal the square of the next positive integer.\n\nOption B: n + 1. This is a simple increment of n by 1, but it doesn't form a perfect square unless n was originally a perfect square minus 1, which isn't guaranteed.\n\nOption C: n^2 + 1. Similarly, this increments n² by 1, forming an imperfect square, and like option B, it assumes knowledge about the nature of n that we don't have from the question.\n\nOption D: q + 2√q + 1. Let's examine what happens when

In [None]:
class RationalesDataset(torch.utils.data.Dataset):
    """Tokenized question/options/rationale/answer examples.

    Every pair is rendered as::

        Question: <question>
        Options:
        A)<text>
        B)<text>
        ...
        Answer Explanation: <rationale>
        Answer: <answer>

    and encoded with ``tokenizer.encode``; each item is a 1-D tensor of the
    returned token ids.

    Args:
        pairs: iterable of dicts with the keys 'question', 'options',
            'rationale' and 'answer'. 'options' may be a letter -> text
            mapping (the form stored in ``correct_pairs``) or a sequence of
            already formatted strings such as ``'A)21'``.
        tokenizer: object whose ``encode(text, truncation=..., max_length=...)``
            returns a list of token ids.
        max_length: truncation limit forwarded to ``tokenizer.encode``
            (default 512, the previously hard-coded value).
    """

    def __init__(self, pairs, tokenizer, max_length=512):
        self.examples = []
        for pair in pairs:
            question = pair['question']
            options = self._format_options(pair['options'])
            rationale = pair['rationale']
            answer = pair['answer']
            input_text = (
                f"Question: {question}\n"
                f"Options:\n{options}\n"
                f"Answer Explanation: {rationale}\n"
                f"Answer: {answer}"
            )
            input_ids = tokenizer.encode(input_text, truncation=True, max_length=max_length)
            self.examples.append(torch.tensor(input_ids))

    @staticmethod
    def _format_options(options):
        """Render the options block, one option per line."""
        # Joining a mapping directly would iterate its keys only and emit
        # "A\nB\nC..." without any option text, so mappings are rendered as
        # "A)text" to match the layout used in the few-shot prompt.
        if hasattr(options, 'items'):
            return '\n'.join(f"{letter}){text}" for letter, text in options.items())
        return '\n'.join(options)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]