In [None]:
# dependencies
#!pip install timm

from google.colab import drive
from huggingface_hub import notebook_login
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import re
import torch
#import timm
from sklearn.model_selection import train_test_split

#drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Checkpoint to load, given as a Hugging Face Hub id.
# A Drive copy was used previously:
#   '/content/drive/My Drive/C5470prj/Mistral-7B-Instruct-v0.2'  (mistral + instruct prompting)
path_model = "mistralai/Mistral-7B-Instruct-v0.2"

# Inference is pinned to the CPU. The auto-detecting alternative is:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(path_model)
# .to() returns the module itself, so loading and placement are chained.
model = AutoModelForCausalLM.from_pretrained(path_model).to(device)

# Last expression of the cell: echo the selected device.
device

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

'cpu'

In [None]:
# use colab tpu
#import torch_xla
#import torch_xla.core.xla_model as xm
#device = xm.xla_device()
#model = model.to(device)
#device

In [None]:
# Function to format the input text according to the model's required format
# add few shot examples(5)
def format_fewshot_question(data, examples):
    """Build the few-shot multiple-choice prompt for one question.

    The prompt is made of three ``[INST] ... [/INST]`` sections: the general
    task instructions, the worked examples, and the question to answer. The
    text ends with ``The correct answer is: `` so that generated text starts
    right after it.

    Choices are labelled ``a``, ``b``, ``c``, ... in the order given by
    ``choice_order`` (each entry is an index into ``choice_list``). Any number
    of choices is supported; four choices produce the same layout as before.

    Side effect: prints one ``processing data id: ...`` progress line.

    Args:
        data: Mapping for the question to answer. The keys ``'id'``,
            ``'question'``, ``'choice_list'`` and ``'choice_order'`` are read.
            ``'id'`` must be a string (it is concatenated into the log line).
        examples: Iterable of mappings used as worked examples. The keys
            ``'question'``, ``'answer'``, ``'choice_list'`` and
            ``'choice_order'`` are read. An example whose answer is not in its
            own choice list is skipped.

    Returns:
        A ``(text, id)`` tuple: the assembled prompt and ``data['id']``.
    """
    text = """###[INST]  Instructions:
    Below is an instruction that describes a multiple choice task.
    Answer the following multiple choice question by giving the most appropriate response.
    Answer should be one among options provided after the question.
    Select the most suitable answer while making the necessary assumptions.
    Give only answer and a short explanation of two or three sentences. Nothing else.[/INST]"""

    # format examples
    text += """[INST] Here are a few example questions and answers:"""
    example_number = 0  # counts only the examples that are actually emitted
    for example in examples:
        choices = example['choice_list']
        answer = example['answer']
        if answer not in choices:
            # malformed example: its answer is not one of its choices
            continue
        example_number += 1

        choice_order = example['choice_order']
        # Label of the correct answer = its position in the displayed order.
        correct_index = choices.index(answer)
        correct_label = ''
        for position, ordered_i in enumerate(choice_order):
            if ordered_i == correct_index:
                correct_label = chr(97 + position)

        question = example['question']
        # Joined with the same indentation the template lines carry, so the
        # rendered layout matches the fixed four-line form.
        choices_block = "\n        ".join(
            f"{chr(97 + position)}. {choices[idx]}" for position, idx in enumerate(choice_order)
        )

        text += f"""### Example #{example_number}:
        Question: {question}
        {choices_block}
        The correct answer is: {correct_label}. {answer}
        """

    # Close the examples section exactly once, matching the single [INST]
    # opened above (also when no example was emitted).
    text += "[/INST]"

    # format main question
    question_id = data['id']
    question = data['question']
    print("processing data id: " + question_id + " length: " + str(len(question)))

    choices = data['choice_list']
    choices_block = "\n    ".join(
        f"{chr(97 + position)}. {choices[idx]}" for position, idx in enumerate(data['choice_order'])
    )

    text += f"""[INST] Below is the multiple choice question you need to answer. Give only answer and a short explanation of two or three sentences.
    ### Input:
    Question: {question}
    {choices_block}

    ### Output Format:
    The correct answer is: {{label}}{{explanation}}
    Change {{label}} to one of a, b, c or d in your output.
    Change {{explanation}} to a short explanation.
    Do not output anything else after it. [/INST]

    The correct answer is: """
    return text, question_id




In [None]:
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Load the data
#SP_all = np.load('/content/drive/My Drive/C5470prj/BrainTeaser/data/SP-train.npy', allow_pickle=True)
#WP_all = np.load('/content/drive/My Drive/C5470prj/BrainTeaser/data/WP-train.npy', allow_pickle=True)
# NOTE: allow_pickle=True unpickles arbitrary Python objects; only load
# .npy files that come from a trusted source.
SP_all = np.load('SP-train.npy', allow_pickle=True)
WP_all = np.load('WP-train.npy', allow_pickle=True)

SP_train, SP_test = train_test_split(SP_all, test_size = 0.2, random_state=42)
WP_train, WP_test = train_test_split(WP_all, test_size = 0.2, random_state=42)

# Evaluate on at most 50 held-out items per subset.
SP_test = SP_test[:50]
WP_test = WP_test[:50]

# Seed both RNGs used below (few-shot sampling and the fallback guess) so a
# re-run of this cell draws the same examples.
SEED = 42
N_SHOTS = 5  # number of few-shot examples sampled per question
random.seed(SEED)
np.random.seed(SEED)

# stats
correct_answers = 0
question_count = 0
SP_count = 0
SP_correct = 0
WP_count = 0
WP_correct = 0

for data in np.concatenate((WP_test, SP_test)):
    data_id = data['id']
    if '_' in data_id: # exclude reconstructed questions during eval
        continue
    # sanitize dataset: a row whose answer is not among its choices cannot be
    # scored, so skip it before counting it or spending a generation on it
    if data['answer'] not in data['choice_list']:
        continue

    wp = "WP" in data_id
    sp = "SP" in data_id

    question_count += 1
    if wp:
        WP_count += 1
    if sp:
        SP_count += 1

    # inference
    # sample few-shot examples from the matching train set (capped at the
    # train-set size so sampling without replacement cannot fail)
    examples = []
    if wp:
        examples = np.random.choice(WP_train, size=min(N_SHOTS, len(WP_train)), replace=False)
    if sp:
        examples = np.random.choice(SP_train, size=min(N_SHOTS, len(SP_train)), replace=False)

    formatted_text, data_id = format_fewshot_question(data, examples)
    #print(formatted_text)

    inputs = tokenizer(formatted_text, return_tensors="pt").to(device)
    # pad_token_id is passed explicitly; generate() otherwise falls back to
    # eos_token_id itself and logs a warning on every call.
    outputs = model.generate(
        **inputs,
        max_new_tokens=3, # 6 for base, 30 for instruct
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    #print(response)

    # Look for a standalone a-d letter in the text after the prompt.
    # NOTE(review): slicing by len(formatted_text) assumes the decoded output
    # reproduces the prompt character-for-character - confirm for this tokenizer.
    match = re.search(r'[^a-zA-Z]([a-d])[^a-zA-Z]', response[len(formatted_text):])
    if match:
        predicted_answer = match.group(1)
    else:
        # No label found: guess. randint is inclusive on both ends, so the
        # guess is one of a, b, c and never d.
        # NOTE(review): confirm excluding 'd' is intentional.
        predicted_answer = chr(97 + random.randint(0, 2))

    # Find the index of the correct answer in the choice order
    correct_label = ''
    correct_index = data['choice_list'].index(data['answer'])
    for i, ordered_i in enumerate(data['choice_order']):
        if ordered_i == correct_index:
            correct_label = chr(97 + i)

    if predicted_answer == correct_label:
        correct_answers += 1
        if wp:
            WP_correct += 1
        if sp:
            SP_correct += 1
    print("evaluated data id: " + data_id + ", predicted: " + predicted_answer + ", correct: " + correct_label + " total acc: " + str(correct_answers / question_count))
    if SP_count != 0:
        print(" SP acc: " + str(SP_correct / SP_count) + " correct: " + str(SP_correct) + "  count: "  + str(SP_count))
    if WP_count != 0:
        print(" WP acc: " + str(WP_correct / WP_count) + " correct: " + str(WP_correct) + "  count: "  + str(WP_count))

# Calculate accuracy (NaN instead of a ZeroDivisionError when every item was skipped)
accuracy = correct_answers / question_count if question_count else float('nan')
print(f"Total Accuracy: {accuracy:.2f}")



Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


processing data id: WP-31 length: 23


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-31, predicted: a, correct: a total acc: 1.0
 WP acc: 1.0 correct: 1  count: 1
processing data id: WP-16 length: 48


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-16, predicted: a, correct: a total acc: 1.0
 WP acc: 1.0 correct: 2  count: 2
processing data id: WP-4 length: 83


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-4, predicted: b, correct: a total acc: 0.6666666666666666
 WP acc: 0.6666666666666666 correct: 2  count: 3
processing data id: WP-69 length: 47


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-69, predicted: b, correct: c total acc: 0.5
 WP acc: 0.5 correct: 2  count: 4
processing data id: WP-6 length: 85


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-6, predicted: a, correct: a total acc: 0.6
 WP acc: 0.6 correct: 3  count: 5
processing data id: WP-0 length: 41


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-0, predicted: c, correct: c total acc: 0.6666666666666666
 WP acc: 0.6666666666666666 correct: 4  count: 6
processing data id: WP-57 length: 134


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-57, predicted: a, correct: b total acc: 0.5714285714285714
 WP acc: 0.5714285714285714 correct: 4  count: 7
processing data id: WP-12 length: 41


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-12, predicted: c, correct: c total acc: 0.625
 WP acc: 0.625 correct: 5  count: 8
processing data id: WP-108 length: 102


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-108, predicted: b, correct: b total acc: 0.6666666666666666
 WP acc: 0.6666666666666666 correct: 6  count: 9
processing data id: WP-149 length: 33


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-149, predicted: b, correct: b total acc: 0.7
 WP acc: 0.7 correct: 7  count: 10
processing data id: WP-38 length: 49


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-38, predicted: c, correct: a total acc: 0.6363636363636364
 WP acc: 0.6363636363636364 correct: 7  count: 11
processing data id: WP-22 length: 43


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-22, predicted: a, correct: a total acc: 0.6666666666666666
 WP acc: 0.6666666666666666 correct: 8  count: 12
processing data id: WP-28 length: 36


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-28, predicted: a, correct: a total acc: 0.6923076923076923
 WP acc: 0.6923076923076923 correct: 9  count: 13
processing data id: WP-94 length: 40


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-94, predicted: a, correct: a total acc: 0.7142857142857143
 WP acc: 0.7142857142857143 correct: 10  count: 14
processing data id: WP-39 length: 46


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-39, predicted: c, correct: a total acc: 0.6666666666666666
 WP acc: 0.6666666666666666 correct: 10  count: 15
processing data id: WP-45 length: 31


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-45, predicted: a, correct: a total acc: 0.6875
 WP acc: 0.6875 correct: 11  count: 16
processing data id: WP-81 length: 41


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: WP-81, predicted: c, correct: a total acc: 0.6470588235294118
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-203 length: 49


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-203, predicted: c, correct: c total acc: 0.6666666666666666
 SP acc: 1.0 correct: 1  count: 1
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-29 length: 141


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-29, predicted: b, correct: c total acc: 0.631578947368421
 SP acc: 0.5 correct: 1  count: 2
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-185 length: 107


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-185, predicted: c, correct: c total acc: 0.65
 SP acc: 0.6666666666666666 correct: 2  count: 3
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-207 length: 113


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-207, predicted: a, correct: a total acc: 0.6666666666666666
 SP acc: 0.75 correct: 3  count: 4
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-5 length: 138


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-5, predicted: c, correct: b total acc: 0.6363636363636364
 SP acc: 0.6 correct: 3  count: 5
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-32 length: 224


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-32, predicted: b, correct: a total acc: 0.6086956521739131
 SP acc: 0.5 correct: 3  count: 6
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-197 length: 195


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-197, predicted: b, correct: a total acc: 0.5833333333333334
 SP acc: 0.42857142857142855 correct: 3  count: 7
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-80 length: 89


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-80, predicted: c, correct: c total acc: 0.6
 SP acc: 0.5 correct: 4  count: 8
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-12 length: 217


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-12, predicted: c, correct: a total acc: 0.5769230769230769
 SP acc: 0.4444444444444444 correct: 4  count: 9
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-85 length: 472


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-85, predicted: c, correct: c total acc: 0.5925925925925926
 SP acc: 0.5 correct: 5  count: 10
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-206 length: 153


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-206, predicted: a, correct: a total acc: 0.6071428571428571
 SP acc: 0.5454545454545454 correct: 6  count: 11
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-151 length: 73


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-151, predicted: b, correct: a total acc: 0.5862068965517241
 SP acc: 0.5 correct: 6  count: 12
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-195 length: 163


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-195, predicted: b, correct: b total acc: 0.6
 SP acc: 0.5384615384615384 correct: 7  count: 13
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-62 length: 126


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-62, predicted: a, correct: a total acc: 0.6129032258064516
 SP acc: 0.5714285714285714 correct: 8  count: 14
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-133 length: 320


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-133, predicted: c, correct: a total acc: 0.59375
 SP acc: 0.5333333333333333 correct: 8  count: 15
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-83 length: 196


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-83, predicted: c, correct: a total acc: 0.5757575757575758
 SP acc: 0.5 correct: 8  count: 16
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-37 length: 80


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-37, predicted: a, correct: a total acc: 0.5882352941176471
 SP acc: 0.5294117647058824 correct: 9  count: 17
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-13 length: 347


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-13, predicted: c, correct: a total acc: 0.5714285714285714
 SP acc: 0.5 correct: 9  count: 18
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-196 length: 111


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


evaluated data id: SP-196, predicted: a, correct: b total acc: 0.5555555555555556
 SP acc: 0.47368421052631576 correct: 9  count: 19
 WP acc: 0.6470588235294118 correct: 11  count: 17
processing data id: SP-0 length: 139
evaluated data id: SP-0, predicted: a, correct: a total acc: 0.5675675675675675
 SP acc: 0.5 correct: 10  count: 20
 WP acc: 0.6470588235294118 correct: 11  count: 17
Total Accuracy: 0.57


results:

 SP acc: 0.48717948717948717 correct: 19  count: 39
 WP acc: 0.6666666666666666 correct: 20  count: 30
Total Accuracy: 0.57


evaluated data id: WP-31, predicted: a, correct: a total acc: 0.525
 SP acc: 1.0 correct: 39  count: 39
 WP acc: 1.0 correct: 1  count: 1
processing data id: WP-16 length: 48


evaluated data id: SP-0, predicted: a, correct: a total acc: 0.5675675675675675
 SP acc: 0.4 correct: 8  count: 20
 WP acc: 0.7647058823529411 correct: 13  count: 17
Total Accuracy: 0.57

 SP acc: 0.5 correct: 10  count: 20
 WP acc: 0.6470588235294118 correct: 11  count: 17
Total Accuracy: 0.57