In [4]:
import random

def nshot_chats(nshot_data: list, n: int, question: str) -> list:
    """Build an n-shot chat prompt ending with ``question``.

    ``n`` examples are drawn from ``nshot_data`` with a fixed seed (42), so
    every call with the same ``nshot_data`` and ``n`` uses the same
    demonstrations in the same order.

    Args:
        nshot_data: Sequence of mappings, each with a ``"question"`` and an
            ``"answer"`` entry, used as the pool of demonstrations.
        n: Number of demonstrations to include before the real question.
        question: The question the model is actually asked to solve.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: ``n`` alternating
        user/assistant pairs followed by one final user message that carries
        ``question`` plus the step-by-step / ``####`` answer-format instruction.

    Raises:
        ValueError: Propagated from ``random.Random.sample`` when ``n`` is
            negative or larger than ``len(nshot_data)``.
    """

    def question_prompt(s):
        return f'Question: {s}'

    def answer_prompt(s):
        return f'Answer: {s}'

    # A private generator yields the same seed-42 sample that
    # ``random.seed(42); random.sample(...)`` would, but leaves the
    # process-wide RNG state untouched for the rest of the notebook.
    rng = random.Random(42)

    chats = []
    for qna in rng.sample(nshot_data, n):
        chats.append(
            {"role": "user", "content": question_prompt(qna["question"])})
        chats.append(
            {"role": "assistant", "content": answer_prompt(qna["answer"])})

    chats.append({"role": "user", "content": question_prompt(question)+
                                             " Let's think step by step. At the end, you MUST write the answer as an integer after '####'."})

    return chats


def extract_ans_from_response(answer: str, eos=None):
    """Pull the final answer out of a model response or a reference solution.

    The text after the last ``'####'`` marker is taken as the answer (the whole
    text if the marker is absent), the characters ``,``, ``$``, ``%`` and ``g``
    are removed from it, and the remainder is parsed as an integer.

    Args:
        answer: Full response / solution text.
        eos: Optional end-of-sequence string; when given (and non-empty),
            everything from its first occurrence onward is discarded first.

    Returns:
        An ``int`` when the cleaned answer is an integer, including integers
        written with a trailing dot or an all-zero fraction (``"18."``,
        ``"18.00"``). Otherwise the cleaned answer string is returned unchanged
        so the caller can still log or compare it.
    """
    if eos:
        answer = answer.split(eos)[0].strip()

    answer = answer.split('####')[-1].strip()

    # NOTE(review): 'g' is presumably stripped to drop a gram unit suffix, but
    # it also removes every other 'g' in the text - confirm this is intended.
    for remove_char in [',', '$', '%', 'g']:
        answer = answer.replace(remove_char, '')

    try:
        return int(answer)
    except ValueError:
        pass

    # Accept integers that were written as "18." or "18.00"; these would
    # otherwise be returned as strings and never equal an integer ground truth.
    # String partitioning (rather than float()) keeps large values exact.
    whole, dot, fraction = answer.strip().partition('.')
    if dot and not fraction.strip('0'):
        try:
            return int(whole)
        except ValueError:
            pass

    return answer

In [5]:
import json
def read_jsonl(path: str):
    """Read a JSON Lines file into a list of parsed objects.

    Args:
        path: Path to a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(path, encoding='utf-8') as fh:
        # Lines keep their trailing newline, so a plain truthiness test never
        # filters anything; strip first so blank lines are really skipped.
        return [json.loads(line) for line in fh if line.strip()]

In [6]:
# Load the train and test splits from the JSON Lines files under ../data/gsm8k/.
# Each becomes a list of parsed records (see read_jsonl); the loop and prompt
# builder below read the 'question' and 'answer' entries of those records.
train_data = read_jsonl('../data/gsm8k/train.jsonl')
test_data = read_jsonl('../data/gsm8k/test.jsonl')

In [10]:
# Number of question/answer demonstrations placed before the real question.
N_SHOT = 8

# Sample prompt for the first test question: N_SHOT demonstrations drawn from
# train_data, followed by the question itself.
messages = nshot_chats(nshot_data=train_data, n=N_SHOT, question=test_data[0]['question'])  # 8-shot prompt


In [11]:
# Display the assembled chat messages to inspect the prompt layout.
messages

[{'role': 'user',
  'content': 'Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?'},
 {'role': 'assistant',
  'content': 'Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.\nSo, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.\nThere are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.\nSo, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.\nTherefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.\n#### 12'},
 {'role': 'user',
  'content': 'Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make fr

In [19]:
import requests

def get_llama2_response(messages, model="llama3.2",
                        url="http://localhost:11434/api/chat", timeout=300):
    """Send a chat to a local chat-completion endpoint and return the reply text.

    The function name is historical: the default model is ``"llama3.2"``.
    The name is kept so existing callers keep working.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts forming the
            conversation so far.
        model: Model identifier sent in the request body.
        url: Chat endpoint to POST to. NOTE(review): the default host/port and
            path look like a local Ollama server - confirm.
        timeout: Seconds to wait for the server before giving up, so a stalled
            server cannot hang the evaluation indefinitely.

    Returns:
        The ``content`` string of the ``message`` object in the JSON reply.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.HTTPError: If the server returns an error status.
    """
    response = requests.post(
        url,
        json={
            "model": model,
            "messages": messages,
            # Ask for one complete JSON reply rather than a stream of chunks.
            "stream": False
        },
        timeout=timeout,
    )
    # Surface HTTP failures directly instead of as a confusing KeyError on
    # the missing 'message' field below.
    response.raise_for_status()

    response_data = response.json()
    return response_data['message']['content']

In [18]:
# extract_ans_from_response(response_daata['message']['content'])

18

In [22]:
import os

from tqdm import tqdm

# Number of test questions to score. The recorded run took roughly 6 s per
# question, so only a prefix of the test split is evaluated.
N_EVAL = 100

log_file_path = 'log/errors.txt'

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

# Truncate any log left over from a previous run; use the same encoding that
# the append below uses.
with open(log_file_path, 'w', encoding='utf-8') as log_file:
    log_file.write('')

total = correct = 0
for qna in tqdm(test_data[:N_EVAL]):

    # nshot_chats samples with a fixed seed, so every question is preceded by
    # the same N_SHOT demonstrations.
    messages = nshot_chats(nshot_data=train_data, n=N_SHOT, question=qna['question'])
    response = get_llama2_response(messages)

    # Prediction and reference go through the same extractor so they are
    # normalised identically before comparison.
    pred_ans = extract_ans_from_response(response)
    true_ans = extract_ans_from_response(qna['answer'])

    total += 1
    if pred_ans != true_ans:
        # Append each failure as it happens so the log is usable while the
        # run is still in progress.
        with open(log_file_path, 'a', encoding='utf-8') as log_file:
            log_file.write(f"{messages}\n\n")
            log_file.write(f"Response: {response}\n\n")
            log_file.write(f"Ground Truth: {qna['answer']}\n\n")
            log_file.write(f"Current Accuracy: {correct/total:.3f}\n\n")
            log_file.write('\n\n')
    else:
        correct += 1

# Guard against an empty test slice, which would otherwise divide by zero.
accuracy = correct / total if total else 0.0
print(f"Total Accuracy: {accuracy:.3f}")

100%|██████████| 100/100 [10:05<00:00,  6.06s/it]

Total Accuracy: 0.620



