In [9]:
from datasets import load_dataset

# Load the "raw" configuration of the deepmind/aqua_rat dataset.
ds = load_dataset(path="deepmind/aqua_rat", name="raw")

# Show the available splits with their features and row counts.
ds

  from .autonotebook import tqdm as notebook_tqdm
Generating train split: 100%|██████████| 97467/97467 [00:00<00:00, 1411925.99 examples/s]
Generating test split: 100%|██████████| 254/254 [00:00<00:00, 215049.09 examples/s]
Generating validation split: 100%|██████████| 254/254 [00:00<00:00, 181676.88 examples/s]


DatasetDict({
    train: Dataset({
        features: ['question', 'options', 'rationale', 'correct'],
        num_rows: 97467
    })
    test: Dataset({
        features: ['question', 'options', 'rationale', 'correct'],
        num_rows: 254
    })
    validation: Dataset({
        features: ['question', 'options', 'rationale', 'correct'],
        num_rows: 254
    })
})

In [309]:
# Work from the training split; later cells take both the few-shot examples
# and the rows to evaluate from it.
ds_train = ds['train']

In [310]:
# Inspect one training record to see its fields: question, options, rationale, correct.
ds_train[8]

{'question': 'The entrance fee for a fair is $5 for persons under the age of 18, and 20% more for persons older. Each ride at the fair costs $0.50. If Joe goes with her 6 years old twin brothers, and they each took 3 rides in total. How much money does Joe end up spending at the fair?',
 'options': ['A)16', 'B)20.5', 'C)17.5', 'D)20', 'E)4.5'],
 'rationale': 'Total entrance fee is (2*$5) + (1.20*5)= $16\nTotal rides fee is (0.50*3)*3= $4.50\nTotal money spent is $20.50\nAnswer is B',
 'correct': 'B'}

In [329]:
import re

def parse_options(options_list):
    """
    Turn raw option strings into a letter -> text mapping.

    Each entry is split at its first ')'. The part before it (stripped and
    upper-cased) becomes the key; the part after it becomes the value, with one
    repeated leading label such as 'A)' removed.

    Parameters:
        options_list (list): Option strings, e.g. ['A)21', 'B)21.5', ...]

    Returns:
        dict: Option letters mapped to option texts. Entries without a ')' are
        skipped with a printed warning; a repeated letter overwrites the
        earlier value, also with a printed warning.
    """
    options_dict = {}
    for opt in options_list:
        label, separator, remainder = opt.partition(')')
        if not separator:
            print(f"Warning: Option '{opt}' is not in the expected format 'Letter)Text'. Skipping.")
            continue

        key = label.strip().upper()
        # Some rows repeat the label ('A)A)$60.00'); drop the second copy.
        value = re.sub(r'^[A-E]\)', '', remainder.strip()).strip()

        if key in options_dict:
            print(f"Warning: Duplicate option key '{key}' found. Overwriting previous value.")
        options_dict[key] = value
    return options_dict

In [330]:
# Number of examples to use for the prompt set
NUM_PROMPT_EXAMPLES = 10

# Collect the first NUM_PROMPT_EXAMPLES training rows as structured few-shot
# records. Each record keeps the raw question, option strings and rationale,
# plus the correct option letter.
prompt_examples = []
for i in range(NUM_PROMPT_EXAMPLES):
    example = ds_train[i]
    correct_option = example['correct'].strip().upper()

    # Parse the options only to validate that the labelled answer exists.
    options_dict = parse_options(example['options'])
    if options_dict.get(correct_option) is None:
        print(f"Warning: Correct answer not found for prompt example index {i}. Skipping.")
        continue  # Skip if the correct answer is missing

    prompt_examples.append({
        'question': example['question'],
        'options': example['options'],
        # Stored verbatim. Text passed to str.format() as an argument is never
        # re-parsed, so escaping backslashes here would only corrupt the
        # rationale by doubling every backslash.
        'rationale': example['rationale'],
        'answer': correct_option,  # Store the letter instead of the text
    })


In [331]:
# Show the collected few-shot records (question, options, rationale, answer letter).
prompt_examples

[{'question': "Two friends plan to walk along a 43-km trail, starting at opposite ends of the trail at the same time. If Friend P's rate is 15% faster than Friend Q's, how many kilometers will Friend P have walked when they pass each other?",
  'options': ['A)21', 'B)21.5', 'C)22', 'D)22.5', 'E)23'],
  'rationale': 'If Q complete x kilometers, then P completes 1.15x kilometers.\nx + 1.15x = 43\n2.15x=43\nx = 43/2.15 = 20\nThen P will have have walked 1.15*20=23 km.\nThe answer is E.',
  'answer': 'E'},
 {'question': 'In the coordinate plane, points (x, 1) and (5, y) are on line k. If line k passes through the origin and has slope 1/5, then what are the values of x and y respectively?',
  'options': ['A)4 and 1', 'B)1 and 5', 'C)5 and 1', 'D)3 and 5', 'E)5 and 3'],
  'rationale': 'Line k passes through the origin and has slope 1/5 means that its equation is y=1/5*x.\nThus: (x, 1)=(5, 1) and (5, y) = (5,1) -->x=5 and y=1\nAnswer: C',
  'answer': 'C'},
 {'question': 'For all numbers p and

In [332]:
import re

def is_rationale_correct(generated_rationale, correct_option_letter, question):
    """
    Heuristic check that a generated rationale refers to the correct option.

    The rationale passes when the correct option letter occurs in it as a
    standalone token, i.e. not embedded in a longer word or number. A plain
    substring test would accept 'B' because of a word like "Because", which
    makes the check almost always true.

    This is still only a heuristic: a standalone capital such as the article
    "A" also satisfies it.

    Parameters:
        generated_rationale (str): The rationale text produced by the model.
        correct_option_letter (str): The expected option letter (case-sensitive).
        question (str): Unused; kept so existing callers keep working.

    Returns:
        bool: True if the letter is mentioned as a standalone token, else False.
        Empty or missing inputs return False.
    """
    if not generated_rationale or not correct_option_letter:
        return False

    # Lookarounds instead of \b so that "(B)", "B)" and "B." all count, while
    # "Because" and "B2" do not.
    standalone_letter = re.compile(
        r'(?<![A-Za-z0-9])' + re.escape(correct_option_letter) + r'(?![A-Za-z0-9])'
    )
    return standalone_letter.search(generated_rationale) is not None

In [333]:
def extract_answer_letter(model_response):
    """
    Return the option letter when the whole response is exactly one letter.

    Parameters:
        model_response (str): The raw response from the model.

    Returns:
        str or None: 'A'..'E' if the response, after trimming whitespace and
        upper-casing, is exactly one of those letters; otherwise None.
    """
    candidate = model_response.strip().upper()
    # Exact match only: anything longer than a bare letter ("B)", "AB") is rejected.
    return candidate if candidate in ('A', 'B', 'C', 'D', 'E') else None

In [334]:
# Function to escape special characters to prevent formatting issues
def escape_special_characters(text):
    """
    Double every backslash and curly brace in `text`.

    Parameters:
        text (str): The text to escape.

    Returns:
        str: `text` with '\\' -> '\\\\', '{' -> '{{' and '}' -> '}}'.
    """
    # Single pass; equivalent to chaining the three replace() calls because no
    # replacement introduces a character that another rule would touch.
    doubling_table = str.maketrans({'\\': '\\\\', '{': '{{', '}': '}}'})
    return text.translate(doubling_table)

# Create the prompt examples with explicit instruction to answer with only the letter
prompt_blocks = []
for i in range(NUM_PROMPT_EXAMPLES):
    example = ds_train[i]
    correct_option = example['correct'].strip().upper()

    # The rationale is inserted verbatim. It is passed to str.format() as an
    # argument, and arguments are never re-parsed as format templates, so
    # escaping it would put doubled braces and backslashes into the prompt.
    prompt_blocks.append(
        (
            "Question: {}\n"
            "Options:\n{}\n"
            "Answer Explanation: {}\n"
            "Answer: {}\n"
            "Please respond with only the letter corresponding to the correct answer (A, B, C, D, or E).\n"
            "###\n"
        ).format(
            example['question'],
            '\n'.join(example['options']),
            example['rationale'],
            correct_option  # Option letter
        )
    )

# One string holding all few-shot examples, in dataset order.
prompt_set = "".join(prompt_blocks)

# Exclude prompt examples from the main dataset
start_index = NUM_PROMPT_EXAMPLES
dataset_D = ds_train.select(range(start_index, len(ds_train)))

print(f"Type of dataset_D: {type(dataset_D)}")
print(f"First element of dataset_D: {dataset_D[0]}")
print(f"Type of first element: {type(dataset_D[0])}")

Type of dataset_D: <class 'datasets.arrow_dataset.Dataset'>
First element of dataset_D: {'question': 'If Tim had lunch at $50 and he gave 20% tip, how much did he spend?', 'options': ['A)A)$60.00', 'B)B)$35.42', 'C)C)$60.60', 'D)D)$21.56', 'E)E)$78.45'], 'rationale': 'The tip is 20% of what he paid for lunch.\ntip = 20% of 50.00 = (20/100)*50.00 = = $10.00\nTotal spent\n50.00 + 10.00 = $60.00\ncorrect answer is A)$60.00', 'correct': 'A'}
Type of first element: <class 'dict'>


In [338]:
def extract_answer(generated_rationale, options):
    """
    Extracts the chosen option's text from the "And the answer is: ..." line.

    The first "And the answer is:" occurrence is used (case-insensitive). What
    follows it may be an option letter ("A", "A)", "(A)", "A) $60.00") or the
    answer text itself ("$60.00", "Data inadequate"). A letter only counts when
    it is not the start of a longer word, so "Data inadequate" is matched as
    text rather than read as option D.

    Parameters:
        generated_rationale (str): The rationale generated by the model.
        options (dict): A dictionary mapping option letters to option texts.

    Returns:
        str or None: The text of the selected option, or None if the closing
        line is missing or cannot be tied to any option.
    """
    if not generated_rationale:
        return None

    closing = re.search(r'And the answer is:\s*(.*)', generated_rationale, re.IGNORECASE)
    if closing is None:
        return None
    stated = closing.group(1).strip()

    def _normalize(text):
        # Compare answers ignoring case, wrapping quotes/markdown emphasis and a trailing period.
        return text.strip().strip('*"\'').rstrip('.').strip().lower()

    # Texts to compare against the option values if no usable letter is found.
    candidates = [stated]

    letter_match = re.match(r'[\s*"\'\[(]*([A-E])(?![A-Za-z])', stated, re.IGNORECASE)
    if letter_match:
        option_letter = letter_match.group(1).upper()
        if option_letter in options:
            return options[option_letter]
        # The letter is not a known option: fall back to the value written after it.
        candidates.append(stated[letter_match.end():].lstrip(')]: \t'))

    for candidate in candidates:
        wanted = _normalize(candidate)
        if not wanted:
            continue
        for value in options.values():
            if _normalize(value) == wanted:
                return value
    return None

In [359]:
import re
import ollama

# Ollama model tags; pass either one as `model` below.
qwen = "qwen2.5:7b"
llama = "llama3.1:8b"

def generate_rationale_and_answer(question, options, prompt_set, model=None):
    """
    Generates a rationale and answer for a given question and options using an Ollama model.

    Parameters:
        question (str): The question to be answered.
        options (dict): A dictionary mapping option letters to option texts.
        prompt_set (str): The initial prompt set containing examples.
        model (str, optional): Ollama model tag to query. Defaults to the
            module-level `llama` value at call time.

    Returns:
        tuple: (generated_rationale (str), generated_answer_text (str or None)).
        The rationale is the model's full, stripped reply; the answer text is
        whatever extract_answer() recovers from it. If the request or the
        extraction raises, the error is printed and ('', None) is returned.
    """
    if model is None:
        # Looked up here rather than in the signature so that re-assigning
        # `llama` in a later cell still takes effect.
        model = llama

    # Prepare the options text, one "LETTER) text" line per option
    options_text = '\n'.join(f"{key}) {value}" for key, value in options.items())

    # Few-shot examples first, then the instruction and the new question
    input_text = (
        prompt_set
        + "\n\n"
        + "Please solve the following problem and provide a detailed explanation. "
        + "Ensure that you end your response with the following format:\n"
        + "\"And the answer is: [answer]\"\n\n"
        + f"Question: {question}\n"
        + "Options:\n"
        + options_text
        + "\nAnswer Explanation:"
    )

    try:
        response = ollama.chat(model=model, messages=[
            {
                'role': 'user',
                'content': input_text,
            },
        ])

        # The whole reply is kept as the rationale; the answer is parsed out of it.
        generated_rationale = response['message']['content'].strip()
        generated_answer_text = extract_answer(generated_rationale, options)

    except Exception as e:
        # Best effort: one failed request should not abort a long evaluation loop.
        print(f"Error generating rationale: {e}")
        generated_rationale = ''
        generated_answer_text = None

    return generated_rationale, generated_answer_text


In [360]:
# Function to map the generated letter to the corresponding answer text (if needed)
def map_letter_to_answer(letter, options_dict):
    """
    Look up the answer text that belongs to an option letter.

    Parameters:
        letter (str): The option letter; surrounding whitespace and case are ignored.
        options_dict (dict): A dictionary mapping option letters to option texts.

    Returns:
        str or None: The option text for the letter, or None if the letter is not a key.
    """
    normalized_letter = letter.strip().upper()
    try:
        return options_dict[normalized_letter]
    except KeyError:
        return None

def map_answer_to_option(generated_answer, options_dict):
    """
    Find the option letter whose text equals the generated answer.

    The comparison ignores case and surrounding whitespace on both sides.

    Parameters:
        generated_answer (str): The answer text to look up.
        options_dict (dict): A dictionary mapping option letters to option texts.

    Returns:
        str or None: The first matching letter in dictionary order, or None if
        no option text matches.
    """
    target = generated_answer.strip().upper()
    matching_letters = (
        letter
        for letter, option_text in options_dict.items()
        if option_text.strip().upper() == target
    )
    return next(matching_letters, None)

# Define the number of examples to process
NUM_EXAMPLES_TO_PROCESS = 5

# Adjust the dataset_D to include only the specified number of examples
# (capped at the dataset size so a small dataset_D still works)
dataset_D_subset = dataset_D.select(range(min(NUM_EXAMPLES_TO_PROCESS, len(dataset_D))))

# Closing line requested by the prompt. The negative lookahead stops the first
# letter of a word ("Data inadequate", "Cannot be determined") from being read
# as an option letter.
ANSWER_LETTER_PATTERN = re.compile(r'And the answer is:\s*\(?([A-E])(?![A-Za-z])', re.IGNORECASE)

# Initialize lists to hold correct and incorrect pairs
correct_pairs = []
incorrect_pairs = []

# Iterate over each example in the subset
for idx, example in enumerate(dataset_D_subset):
    question = example['question']
    options_list = [opt.strip() for opt in example['options']]
    correct_option = example['correct'].strip().upper()

    # Parse options into a dictionary
    options_dict = parse_options(options_list)

    # Extract the correct answer text
    correct_answer_text = options_dict.get(correct_option, None)

    if correct_answer_text is None:
        print(f"Warning: Correct answer not found for index {idx}. Skipping this example.")
        continue  # Skip if the correct answer is missing

    # Generate rationale and answer
    generated_rationale, generated_answer_text = generate_rationale_and_answer(question, options_dict, prompt_set)

    # Work out which option LETTER the model chose. An explicit letter in the
    # closing line wins; otherwise map the extracted answer text back to its
    # letter. The text itself must never be compared with correct_option.
    letter_match = ANSWER_LETTER_PATTERN.search(generated_rationale)
    if letter_match:
        generated_answer_letter = letter_match.group(1).upper()
    elif generated_answer_text:
        generated_answer_letter = map_answer_to_option(generated_answer_text, options_dict)
    else:
        generated_answer_letter = None

    is_correct = (
        generated_answer_letter == correct_option
        and is_rationale_correct(generated_rationale, correct_option, question)
    )

    # Same summary is printed for both outcomes
    result = {
        'question': question,
        'options': options_dict,
        'rationale': generated_rationale,
        'generated_answer': generated_answer_letter,
        'correct_answer': correct_option
    }

    # Categorize the example based on the accuracy of the generated answer
    if is_correct:
        correct_pairs.append({
            'question': question,
            'options': options_dict,
            'rationale': generated_rationale,
            'answer': correct_option
        })
        print('Correct:', result)
    else:
        incorrect_pairs.append(result)
        print("Incorrect:", result)

    # Print progress every example
    print(f"Processed {idx + 1} questions.\n")

Correct: {'question': 'If Tim had lunch at $50 and he gave 20% tip, how much did he spend?', 'options': {'A': '$60.00', 'B': '$35.42', 'C': '$60.60', 'D': '$21.56', 'E': '$78.45'}, 'rationale': "To find out how much Tim spent in total, we need to calculate the tip and add it to the cost of lunch.\n\nFirst, let's calculate the tip:\n\nTip = 20% of $50\n= (20/100) × 50\n= 0.2 × 50\n= $10\n\nNow, let's add the tip to the cost of lunch:\n\nTotal amount spent = Cost of lunch + Tip\n= $50 + $10\n= $60\n\nSo, Tim spent a total of $60.\n\nAnd the answer is: A", 'generated_answer': 'A', 'correct_answer': 'A'}
Processed 1 questions.

Incorrect: {'question': 'Rs. 825 becomes Rs. 956 in 3 years at a certain rate of simple interest.If the rate of interest is increased by 4% ,What amount will Rs. 825 become in 3 years ?', 'options': {'A': 'Rs. 1020.80', 'B': 'Rs. 1025', 'C': 'Rs. 1055', 'D': 'Data inadequate', 'E': 'None of these'}, 'rationale': "Let's break down the problem step by step.\n\nFirst, 

In [361]:
# Show the examples collected as correct so far.
correct_pairs

[{'question': 'If Tim had lunch at $50 and he gave 20% tip, how much did he spend?',
  'options': {'A': '$60.00',
   'B': '$35.42',
   'C': '$60.60',
   'D': '$21.56',
   'E': '$78.45'},
  'rationale': "To find out how much Tim spent in total, we need to calculate the tip and add it to the cost of lunch.\n\nFirst, let's calculate the tip:\n\nTip = 20% of $50\n= (20/100) × 50\n= 0.2 × 50\n= $10\n\nNow, let's add the tip to the cost of lunch:\n\nTotal amount spent = Cost of lunch + Tip\n= $50 + $10\n= $60\n\nSo, Tim spent a total of $60.\n\nAnd the answer is: A",
  'answer': 'A'},
 {'question': 'q is a positive integer and multiple of 2; p = 4^q, what is the remainder when p is divided by 10?',
  'options': {'A': '10',
   'B': '6',
   'C': '4',
   'D': '0',
   'E': 'It Cannot Be Determined'},
  'rationale': 'Since q is a positive integer and multiple of 2, we can write it as q = 2k for some positive integer k.\n\nNow, p = 4^q = 4^(2k) = (4^2)^k = 16^k\n\nWe know that any power of 16 ends

In [362]:
# Show the examples that failed the correctness check in the evaluation loop.
incorrect_pairs

[{'question': 'Rs. 825 becomes Rs. 956 in 3 years at a certain rate of simple interest.If the rate of interest is increased by 4% ,What amount will Rs. 825 become in 3 years ?',
  'options': {'A': 'Rs. 1020.80',
   'B': 'Rs. 1025',
   'C': 'Rs. 1055',
   'D': 'Data inadequate',
   'E': 'None of these'},
  'rationale': "Let's break down the problem step by step.\n\nFirst, we know that Rs. 825 becomes Rs. 956 in 3 years at a certain rate of simple interest. We need to find this initial rate of interest.\n\nThe formula for simple interest is:\n\nI = (P x R x T) / 100\n\nwhere I is the interest earned, P is the principal amount (Rs. 825), R is the rate of interest, and T is the time period in years (3 years).\n\nWe can rearrange this formula to solve for R:\n\nR = (I x 100) / (P x T)\n\nGiven that Rs. 825 becomes Rs. 956 in 3 years, the interest earned I is:\n\nI = 956 - 825 = Rs. 131\n\nNow we can plug in the values:\n\nR = (131 x 100) / (825 x 3)\n= 13.33%\n\nSo the initial rate of inter

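Manual spot check: the commented-out cell below runs a single example from `dataset_D_subset` through `generate_rationale_and_answer` and prints the generated rationale and answer next to the reference answer.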

In [None]:
# # Test with a specific example
# test_example = dataset_D_subset[2]  # Adjust the index as needed
# question = test_example['question']
# options_list = [opt.strip() for opt in test_example['options']]
# correct_option = test_example['correct'].strip().upper()

# # Parse options into a dictionary
# options_dict = parse_options(options_list)

# # Extract the correct answer
# correct_answer = options_dict.get(correct_option, None)

# print("Question:", question)
# print("Options:", options_dict)
# print("Correct Option:", correct_option)
# print("Correct Answer:", correct_answer)

# # Generate rationale and answer
# generated_rationale, generated_answer_text = generate_rationale_and_answer(question, options_dict, prompt_set)

# print("\nGenerated Rationale:")
# print(generated_rationale)
# print("\nGenerated Answer:", generated_answer_text)
# print("Correct Answer:", correct_answer)


In [366]:
# NOTE: the rationalization cell appends its results to correct_pairs. Run this
# cell before it, otherwise those entries are counted as correct answers here.
total = len(correct_pairs) + len(incorrect_pairs)
# Both lists are empty when every example was skipped; report 0% instead of
# raising ZeroDivisionError.
accuracy = len(correct_pairs) / total * 100 if total else 0.0
print(f"Total questions processed: {total}")
print(f"Correct answers: {len(correct_pairs)}")
print(f"Incorrect answers: {len(incorrect_pairs)}")
print(f"Accuracy: {accuracy:.2f}%")

Total questions processed: 14
Correct answers: 11
Incorrect answers: 3
Accuracy: 78.57%


In [365]:
def rationalize(question, options, correct_answer, prompt_set):
    """
    Generates a rationale for a question when the correct answer is given as a hint.

    Parameters:
        question (str): The question to be answered.
        options (dict): A dictionary mapping option letters to option texts.
        correct_answer (str): The correct answer, normally the option letter.
            When it is a key of `options`, the option text is added to the hint
            shown to the model.
        prompt_set (str): The initial prompt set containing examples.

    Returns:
        str: The generated rationale, or '' when the model call fails or
        is_rationale_correct() rejects the generated text.
    """
    # Same "LETTER) text" layout as the few-shot examples and the answering prompt
    options_text = '\n'.join(f"{key}) {value}" for key, value in options.items())

    # Spell out the option text next to the letter so the hint is unambiguous
    if correct_answer in options:
        answer_hint = f"{correct_answer}) {options[correct_answer]}"
    else:
        answer_hint = correct_answer

    # Construct the input prompt with explicit instructions
    input_text = (
        prompt_set
        + "\n\n"
        + "Provide a detailed explanation for the following question, ensuring that the explanation clearly justifies why the correct answer is chosen.\n"
        + f"Question: {question}\n"
        + "Options:\n"
        + options_text + "\n"
        + f"Correct Answer: {answer_hint}\n"
        + "Explanation:"
    )

    try:
        response = ollama.chat(model=llama, messages=[
            {
                'role': 'user',
                'content': input_text,
            },
        ])

        # The entire reply is used as the rationale
        generated_rationale = response['message']['content'].strip()

        # Keep the rationale only if it passes the heuristic check
        if is_rationale_correct(generated_rationale, correct_answer, question):
            return generated_rationale

        print(f"Generated rationale does not sufficiently explain the correct answer for question: {question}")
        return ''

    except Exception as e:
        # Best effort: report the failure and let the caller decide what to do with ''.
        print(f"Error during rationalization: {e}")
        return ''
    
# Process the incorrect answers: ask for a rationale with the correct answer as
# a hint and add the successful ones to correct_pairs.
# NOTE: this grows correct_pairs, so the accuracy cell must be run before this
# one; re-running this cell appends the same questions again.
for pair in incorrect_pairs:
    question = pair['question']
    options = pair['options']
    correct_answer = pair['correct_answer']

    # Generate the rationale with the correct answer as a hint
    generated_rationale = rationalize(question, options, correct_answer, prompt_set)

    # rationalize() returns '' when the call fails or the rationale is rejected;
    # an example with an empty explanation is not worth keeping.
    if not generated_rationale:
        print(f"Skipping question without a usable rationale: {question}")
        continue

    rationalized_pair = {
        'question': question,
        'options': options,
        'rationale': generated_rationale,
        'answer': correct_answer
    }

    # Add the rationalized example to correct_pairs
    correct_pairs.append(rationalized_pair)

    print(rationalized_pair)

{'question': 'Rs. 825 becomes Rs. 956 in 3 years at a certain rate of simple interest.If the rate of interest is increased by 4% ,What amount will Rs. 825 become in 3 years ?', 'options': {'A': 'Rs. 1020.80', 'B': 'Rs. 1025', 'C': 'Rs. 1055', 'D': 'Data inadequate', 'E': 'None of these'}, 'rationale': "To solve this problem, we first need to find the rate of simple interest at which Rs. 825 becomes Rs. 956 in 3 years.\n\nLet's use the formula for simple interest:\n\nSimple Interest (SI) = (Principal × Rate × Time) / 100\n\nWe are given:\nPrincipal (P) = Rs. 825\nAmount after 3 years = Rs. 956\nTime (T) = 3 years\n\nFirst, let's find the SI:\nSI = P - Amount after 3 years\n= 825 - 956\n= -131\n\nSince we cannot have a negative amount, it seems there might be some confusion in the given information. However, to follow through with the problem as presented:\n\nLet's assume this is actually the interest paid or received during these 3 years and calculate the principal that would grow by Rs

In [None]:
class RationalesDataset(torch.utils.data.Dataset):
    """
    Question/options/rationale/answer records, tokenized and wrapped as a torch Dataset.

    Each record is rendered as text in the "Question / Options /
    Answer Explanation / Answer" layout, encoded with `tokenizer`, and stored
    as a 1-D tensor of token ids.

    Parameters:
        pairs (iterable of dict): Records with the keys 'question', 'options',
            'rationale' and 'answer'. 'options' may be a dict mapping letters
            to texts (as stored in correct_pairs) or a list of ready-made
            option strings (as stored in prompt_examples).
        tokenizer: Object with an `encode(text, truncation=..., max_length=...)`
            method returning token ids. Presumably a Hugging Face tokenizer;
            confirm against the training code.
        max_length (int): Maximum number of tokens per example passed to
            `tokenizer.encode`. Defaults to 512.
    """

    def __init__(self, pairs, tokenizer, max_length=512):
        self.examples = []
        for pair in pairs:
            input_text = (
                f"Question: {pair['question']}\n"
                f"Options:\n{self._format_options(pair['options'])}\n"
                f"Answer Explanation: {pair['rationale']}\n"
                f"Answer: {pair['answer']}"
            )
            input_ids = tokenizer.encode(input_text, truncation=True, max_length=max_length)
            self.examples.append(torch.tensor(input_ids))

    @staticmethod
    def _format_options(options):
        """Render options as one line each, in the 'A) text' layout."""
        if isinstance(options, dict):
            # Joining a dict directly would keep only its keys ("A\nB\n...")
            # and drop every option text.
            return '\n'.join(f"{letter}) {text}" for letter, text in options.items())
        return '\n'.join(options)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]