<a href="https://colab.research.google.com/github/sanj1210/PALM/blob/main/PALM_multiarith_cot_bison_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
!pip install google-cloud-aiplatform==1.36.2 --upgrade --user
!pip uninstall bigframes
!pip install bigframes==0.26.0

import bigframes.dataframe
import IPython
import time
from datasets import load_dataset, concatenate_datasets
import re
import pandas as pd
from IPython.display import Markdown, display
from sklearn.metrics.pairwise import cosine_similarity
from vertexai.language_models import (
    TextGenerationModel,
    TextEmbeddingModel,
    ChatModel,
    InputOutputTextPair,
    CodeGenerationModel,
    CodeChatModel,
)

# Restart the IPython kernel (do_shutdown with restart=True), presumably so the
# packages installed/upgraded above are loaded fresh.
# NOTE(review): a restart discards all in-memory state, including every import
# executed above, and the remainder of the current cell may not run. This only
# behaves as intended if it is the last statement of its own setup cell and the
# imports are re-executed afterwards — confirm the notebook's cell layout.
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

import sys

import vertexai

PROJECT_ID = "nth-platform-422801-r1"  # @param {type:"string"}
REGION = "us-central1"  # @param {type:"string"}

# Authenticate the Colab user *before* configuring the SDK, so any credential
# lookup performed by `vertexai.init` or by later model calls sees the user's
# credentials rather than whatever was (or wasn't) available beforehand.
if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

# Initialize Vertex AI SDK for the chosen project and region.
vertexai.init(project=PROJECT_ID, location=REGION)

# Function to extract the answer from the model's output
def extract_answer(output):
    """Extract the final answer from a chain-of-thought model completion.

    The prompt asks the model to end with ``The answer is (x)``.

    * If that phrase occurs (matched case-insensitively), the text after its
      last occurrence is searched and the FIRST value found there is used.
    * Otherwise the whole output is searched and the LAST value is used,
      because numbers in the reasoning precede the final result.

    A "value" is either a clock time (``H:MM`` / ``HH:MM``) or a number
    (optionally with ``,`` thousands separators and a decimal part); whichever
    appears at the chosen position wins.

    Args:
        output: Raw text returned by the model.

    Returns:
        The time as a string (e.g. ``"10:30"``), the number as a ``float`` with
        separators removed, or ``"No answer"`` when nothing matches or
        ``output`` is not a string.
    """
    if not isinstance(output, str):
        return "No answer"

    parts = re.split(r"the answer is", output, flags=re.IGNORECASE)
    has_marker = len(parts) > 1
    answer_part = parts[-1]  # equals the whole output when there is no marker

    # Group 1: clock time; group 2: number. Scanning left to right means the
    # earliest value in the text wins, regardless of which kind it is.
    value_pattern = r"(\d{1,2}:\d{2})|((?:\d{1,3},(?:\d{3},)*\d{3}|\d+)(?:\.\d+)?)"
    matches = list(re.finditer(value_pattern, answer_part))
    if not matches:
        return "No answer"

    match = matches[0] if has_marker else matches[-1]
    time_str, number_str = match.groups()
    if time_str is not None:
        return time_str
    return float(number_str.replace(",", ""))

# Function to extract the ground truth value
def extract_ground_truth_value(ground_truth_ans):
    """Normalize a dataset ground-truth answer for comparison with model output.

    Uses the same value conventions as the model-output extractor: a clock
    time (``H:MM`` / ``HH:MM``) is kept as a string; a number (optionally with
    ``,`` thousands separators and a decimal part) becomes a ``float``. In a
    string, the first value found is used.

    Args:
        ground_truth_ans: The raw answer, either a number or a string.

    Returns:
        A time string, a ``float``, or ``"No answer"`` when no value can be
        found or the input is of another type.
    """
    if isinstance(ground_truth_ans, (int, float)):
        # Always return float so numeric answers share one type.
        return float(ground_truth_ans)
    if not isinstance(ground_truth_ans, str):
        return "No answer"

    # Group 1: clock time; group 2: number (commas stripped below).
    value_pattern = r"(\d{1,2}:\d{2})|((?:\d{1,3},(?:\d{3},)*\d{3}|\d+)(?:\.\d+)?)"
    match = re.search(value_pattern, ground_truth_ans)
    if match is None:
        return "No answer"

    time_str, number_str = match.groups()
    if time_str is not None:
        return time_str
    return float(number_str.replace(",", ""))

# Load the MultiArith dataset and evaluate on its train and test splits combined.
multiarith = load_dataset("ChilleD/MultiArith")
multiarith_combined = concatenate_datasets([multiarith['train'], multiarith['test']])
# Dataset.to_pandas() is the supported conversion (built from the underlying
# Arrow table); pd.DataFrame(dataset) relies on pandas iterating rows as dicts.
multiarith_data = multiarith_combined.to_pandas()

# Use the "question" column as the model input and "final_ans" as ground truth.
multiarith_data["Full_Question"] = multiarith_data["question"]
multiarith_data["Int_Answer"] = multiarith_data["final_ans"]

# Define the range of questions to process (both ends inclusive).
start_index = 0
end_index = len(multiarith_data) - 1
subset_multiarith_eval = multiarith_data.iloc[start_index:end_index + 1]

# Few-shot chain-of-thought prompt prefix. The first line is the instruction
# (think step by step; the last sentence must be `The answer is (x)`), followed
# by eight worked Q/A exemplars. The string begins and ends with a newline.
exemplars = '''
Think step by step to answer the given question. The last sentence of the answer must be of the format `The answer is (x)`.
Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.
Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
A: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.
Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
A: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.
'''

# Load the Vertex AI text-generation model that answers every prompt.
generation_model = TextGenerationModel.from_pretrained(model_name="text-bison@001")

# Function to process a batch of questions
def process_batch(batch):
    """Send each question to ``generation_model`` and collect the raw replies.

    Each prompt is the few-shot ``exemplars`` text (whose first line already
    carries the step-by-step instruction) followed by ``Q: <question>`` and an
    open ``A: `` for the model to complete. Requests are spaced 2 s apart.

    Args:
        batch: Iterable of question strings.

    Returns:
        One entry per question: the model's text, or ``"Error: <ExceptionType>"``
        when the request raised.
    """
    responses = []
    for question in batch:
        # The instruction sentence is the first line of `exemplars`, so no
        # separate instruction string is prepended.
        input_prompt = f"{exemplars}\nQ: {question}\nA: "
        try:
            # NOTE(review): top_k is documented as an integer (1-40) for
            # text-bison; 0.8 looks like it may have been intended for top_p.
            # Kept unchanged so results stay comparable — confirm.
            response = generation_model.predict(
                prompt=input_prompt,
                max_output_tokens=1024,
                temperature=0.2,
                top_k=0.8,
                top_p=0.4,
            )
            responses.append(response.text)
        except Exception as e:
            # Keep going so one failed request does not abort the whole run,
            # but record the actual failure type instead of assuming a quota error.
            print("Exception caught:", e)
            responses.append(f"Error: {type(e).__name__}")
        time.sleep(2)  # Adding a delay between requests
    return responses

# Process the subset of the dataset in batches (each batch processes 5 questions),
# pausing between batches (presumably to stay under the API request quota).
batch_size = 5
results = []
questions_to_process = subset_multiarith_eval['Full_Question']

for i in range(0, len(questions_to_process), batch_size):
    # .iloc makes the slice explicitly positional, independent of index labels.
    batch_questions = questions_to_process.iloc[i:i + batch_size]
    batch_responses = process_batch(batch_questions)
    results.extend(batch_responses)
    print(f"Processed batch {i // batch_size + 1}")
    # No need to wait after the final batch.
    if i + batch_size < len(questions_to_process):
        time.sleep(47)

# Per-question ids, raw ground truth, and normalized values for both sides.
ids = [start_index + offset for offset in range(len(subset_multiarith_eval))]
ground_truth_answers = subset_multiarith_eval['Int_Answer']
extracted_values = list(map(extract_answer, results))
ground_truth_values = list(map(extract_ground_truth_value, ground_truth_answers))

# Function to compare extracted and ground truth values
def compare_values(extracted, ground_truth):
    """Return True when an extracted answer matches the ground-truth answer.

    Rules:
      * ``"No answer"`` on either side is never a match.
      * Two strings (e.g. clock times) must be exactly equal.
      * Two numbers (int or float, but not bool) match when ``math.isclose``
        holds, so ``6`` and ``6.0`` match and float noise is tolerated.
      * Any other combination (e.g. string vs number) does not match.
    """
    import math

    if extracted == "No answer" or ground_truth == "No answer":
        return False
    if isinstance(extracted, str) and isinstance(ground_truth, str):
        return extracted == ground_truth

    def is_number(value):
        # bool is an int subclass but is not a meaningful answer value.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if is_number(extracted) and is_number(ground_truth):
        return math.isclose(extracted, ground_truth)
    return False

# Score each pair: 1 only when both sides agree on an actual answer; two
# "No answer" values agreeing still scores 0.
correctness = [
    1 if (extracted == ground_truth and extracted != "No answer") else 0
    for extracted, ground_truth in zip(extracted_values, ground_truth_values)
]

# Build the results table: one row per processed question.
results_df = pd.DataFrame(
    data={
        "id": ids,
        "question": subset_multiarith_eval["Full_Question"],
        "groundTruthAns": ground_truth_answers,
        "model output": results,
        "extracted value": extracted_values,
        "ground truth value": ground_truth_values,
        "correctness": correctness,
    }
)

from google.colab import files
# Write the per-question results to CSV, omitting the DataFrame index.
results_df.to_csv(path_or_buf="responses_multiarith_subset.csv", index=False)

# Function to compute accuracy and save results
def compute_accuracy_and_save(results_df, dataset_name, model_name, prompt_type,
                              filename="accuracy_multiarith_subset.txt",
                              download=True):
    """Write an accuracy summary for ``results_df`` to a text file.

    Accuracy is the mean of the ``correctness`` column (0/1 per question); it
    is NaN when ``results_df`` has no rows.

    Args:
        results_df: DataFrame with a numeric ``correctness`` column.
        dataset_name: Dataset label written into the summary.
        model_name: Model label written into the summary.
        prompt_type: Prompting-strategy label written into the summary.
        filename: Path of the summary file to write.
        download: When True, offer the file for download via
            ``google.colab.files.download`` (Colab only). Pass False to only
            write the file, e.g. when running outside Colab.

    Returns:
        The computed accuracy.
    """
    accuracy = results_df['correctness'].mean()
    output_content = f"{{Dataset: {dataset_name}, \nmodel_name: {model_name}, \naccuracy: {accuracy}, \nprompt: {prompt_type}}}"

    with open(filename, "w", encoding="utf-8") as f:
        f.write(output_content)

    if download:
        # Imported lazily so the function also works where google.colab is absent.
        from google.colab import files

        files.download(filename)

    return accuracy

# Write the accuracy summary file (and download it), then download the
# per-question CSV written earlier.
compute_accuracy_and_save(
    results_df,
    dataset_name="MultiArith",
    model_name="text-bison@001",
    prompt_type="COT with instruction",
)
files.download("responses_multiarith_subset.csv")
