In [None]:
import os
import pandas as pd
from dotenv import load_dotenv, dotenv_values
from groq import Groq
import re

In [None]:
# 1. Setup
# API Clients
# Keys are read from .env (use this if you are using school laptop / device);
# if the key is not in .env, fall back to the process environment.
config = dotenv_values(".env")
GROQ_API_KEY = config.get("GROQ_API_KEY") or os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY not found in .env or in the environment")
groq_client = Groq(api_key=GROQ_API_KEY)
# OpenAI client is currently disabled:
# openai_client = OpenAI(api_key=config["OPENAI_API_KEY"])

# Model configs
GROQ_MODEL = "llama3-8b-8192"
OPENAI_MODEL = "gpt-3.5-turbo-1106"  # "gpt-4-1106-preview"

TEMPERATURE = 0.2
# NOTE(review): the prompt asks the model to "think step-by-step", but 20 tokens
# leaves little room for reasoning before the label; confirm this is intended.
MAX_TOKENS = 20
#BATCH_SIZE = 5

# File paths
INPUT_FILE = "test_predictions_merged.csv"
OUTPUT_FILE = "test_predictions_cot_2.csv" # Name whatever you want

# Checkpoint to OUTPUT_FILE whenever the row index is a multiple of FREQUENCY
FREQUENCY = 250

# Row range handled by this run.
# `df` is only loaded in a later cell, so count the rows here directly (this cell
# previously used len(df) and failed on a fresh kernel). Reading a single column
# keeps this cheap and still respects quoted multi-line CSV fields.
N_ROWS = len(pd.read_csv(INPUT_FILE, usecols=[0]))
# STARTING INDEX
START_IDX = N_ROWS // 2 # Set accordingly if Mustafa wants to also help run
# ENDING INDEX (exclusive)
END_IDX = N_ROWS # Set accordingly if Mustafa wants to also help run

In [None]:
# Load CSV.
# The grading loop checkpoints to OUTPUT_FILE and skips rows that already have a
# prediction. If that file exists from an earlier (interrupted) run, resume from
# it so finished rows are not re-sent to the API.
# Delete OUTPUT_FILE (or pick a new name) to start from scratch.
source_file = OUTPUT_FILE if os.path.exists(OUTPUT_FILE) else INPUT_FILE
df = pd.read_csv(source_file)
print(f"Loaded {len(df)} rows from {source_file}")
# The loop reads and fills this column; create it empty on a fresh start.
if "groq_pred_cot" not in df.columns:
    df["groq_pred_cot"] = None

In [None]:
def build_cot_prompt(question, correct_answer, response):
    """
    Build the chain-of-thought grading prompt for one row.

    Args:
        question: Value inserted after "Question:".
        correct_answer: Value inserted after "CorrectAnswer:".
        response: Value inserted after "Response:".

    Returns:
        str: The filled-in template. Its leading newline and indentation are kept
        exactly as written, because this text is sent to the model as-is; editing
        the template changes what the model sees.
    """
    template = """
            You are a short answer grader. Think step-by-step.
                    
            1. Analyze if the Response matches the key information in the CorrectAnswer.
            2. Output only one line by deciding:
                - If correct, output 1.
                - If similar but not exact, output 0.
                - If wrong, output -1.
                    
            Question: {question}
            CorrectAnswer: {correct_answer}
            Response: {response}
                    
            Final Label (only -1, 0, or 1):
        """
    # str.format substitutes values once; braces inside the values are not re-parsed.
    return template.format(
        question=question,
        correct_answer=correct_answer,
        response=response,
    )

In [None]:
def parse_prediction(output):
    """
    Extract the final -1, 0, or 1 from a text response.

    Looks for the last standalone -1 / 0 / 1 in the entire output.

    - Unicode minus-like dashes (U+2212 minus sign, U+2012 figure dash,
      U+2013 en dash, U+FF0D full-width hyphen-minus) are read as ASCII "-",
      so "\u22121" yields -1 instead of 1.
    - Digits that belong to a decimal number ("1.0", "0.5") are ignored, so
      "1.0" no longer parses as 0.

    Args:
        output: Raw model text (any object; converted with ``str``).

    Returns:
        int | None: -1, 0 or 1, or None if no valid label is present.
    """
    dash_to_ascii = {0x2212: "-", 0x2012: "-", 0x2013: "-", 0xFF0D: "-"}
    text = str(output).strip().translate(dash_to_ascii)
    # (?<!\d\.) rejects a fractional digit ("1.0" -> the "0");
    # (?!\.\d) rejects an integer part ("1.0" -> the "1").
    # A trailing period ("Label: 1.") is still accepted.
    matches = re.findall(r"-?(?<!\d\.)\b[01]\b(?!\.\d)", text)
    if matches:
        return int(matches[-1])  # Return the last matching number
    return None

In [None]:
def _grade_row(row):
    """
    Send one row to the Groq model and parse its label.

    Args:
        row: A row of ``df`` with "Question", "CorrectAnswer" and "Response".

    Returns:
        tuple: ``(prediction, raw_text)``: the parsed label (-1, 0, 1 or None)
        and the raw model text, kept for the warning message.
    """
    prompt = build_cot_prompt(row["Question"], row["CorrectAnswer"], row["Response"])
    res = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=[
            {"role": "system", "content": "You are a strict grader. Only return -1, 0, or 1."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
    )
    # Guard against a missing/None message body; it is then treated as "no label".
    raw_text = res.choices[0].message.content or ""
    return parse_prediction(raw_text.strip()), raw_text


# Never index past the last row, even if END_IDX was set too high by hand.
stop_idx = min(END_IDX, len(df))

try:
    for idx in range(START_IDX, stop_idx):
        # Rows that already have a label are skipped.
        if pd.notnull(df.at[idx, "groq_pred_cot"]):
            continue

        try:
            prediction, raw_text = _grade_row(df.loc[idx])
        except Exception as e:
            # Best-effort: log and leave the row empty so a re-run retries it.
            print(f"[Error] idx {idx}: {e}")
        else:
            if prediction is not None:
                df.at[idx, "groq_pred_cot"] = prediction
            else:
                print(f"[Warning] Invalid response at idx {idx}: {raw_text}")

        # Periodic checkpoint.
        if idx % FREQUENCY == 0:
            df.to_csv(OUTPUT_FILE, index=False)
            print(f"[Progress] Saved at idx {idx}")
finally:
    # ==== Final save ====
    # Runs even on KeyboardInterrupt or an unexpected error, so rows graded
    # since the last checkpoint are not lost.
    df.to_csv(OUTPUT_FILE, index=False)

print(f"✅ Finished processing rows {START_IDX} to {stop_idx - 1} of dataset.")