In [None]:
import sys
sys.path.append("../common_scripts/")

from common_functions import save_batch, print_sample, count_tokens, create_formatted_samples_for_eval
from eval_prompts import SYSTEM_PROMPT_ZERO_SHOT, USER_PROMPT_ZERO_SHOT
from pathlib import Path
import pandas as pd
import json

from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
import anthropic

# Anthropic client shared by the Claude cells; constructed with no explicit arguments.
claude_client = anthropic.Anthropic()

# Test split: every batch below is built from (and joined back onto) these rows.
TEST_CSV_PATH = "../../dataset_for_hf/test.csv"
test_df = pd.read_csv(TEST_CSV_PATH)

# Zero Shot

## o4-mini

In [None]:
# Batch Create
def create_formatted_inputs_for_zero_shot_eval(row):
    """Build one o4-mini zero-shot request record for the OpenAI batch input file.

    Args:
        row: A row from ``DataFrame.itertuples()`` exposing ``Entities``,
            ``Label`` and ``Question`` attributes.

    Returns:
        dict: A record with ``custom_id``, ``method``, ``url`` and ``body`` keys
        (in that order) targeting ``/v1/chat/completions``.
    """
    # The label is part of the ID because some d-d entity pairs have questions
    # for both bio and mol interactions, so Entities alone is not unique.
    request_id = f"{row.Entities}_{row.Label}:zero_shot"

    chat_messages = [
        {"role": "developer", "content": SYSTEM_PROMPT_ZERO_SHOT},
        {"role": "user", "content": USER_PROMPT_ZERO_SHOT.format(row.Question)},
    ]
    request_body = {"model": "o4-mini", "messages": chat_messages}

    return {
        "custom_id": request_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": request_body,
    }

# One request record per test question, written out as the o4-mini batch input.
o4_mini_batch_input_path = "../../samples_for_eval/zero_shot/o4-mini/batch_input.jsonl"
formatted_samples = list(map(create_formatted_inputs_for_zero_shot_eval, test_df.itertuples()))
save_batch(formatted_samples, o4_mini_batch_input_path)

In [None]:
# Eval
# Join the o4-mini batch output back onto the test questions via custom_id.
o4_mini_output_path = "../../samples_for_eval/zero_shot/o4-mini/batch_output.jsonl"
predictions = pd.read_json(o4_mini_output_path, lines=True).rename(columns={"custom_id": "Entities"})

# Rebuild the custom_id format "<Entities>_<Label>:zero_shot" as the join key.
# Built on a copy: assigning into test_df made this cell non-idempotent
# (a re-run appended the suffix twice, so nothing matched) and leaked the
# suffixed entities into the Claude cells, which also read test_df["Entities"].
o4_mini_test_df = test_df.assign(
    Entities=test_df["Entities"] + "_" + test_df["Label"].astype(str) + ":zero_shot"
)
merged_df = pd.merge(predictions, o4_mini_test_df, on="Entities")

# Every batch result should find its question; a shortfall means the keys drifted.
assert len(merged_df) == len(predictions), (
    f"only {len(merged_df)} of {len(predictions)} o4-mini results matched a test question"
)

create_formatted_samples_for_eval(merged_df, "../../samples_for_eval/zero_shot/o4-mini/eval_lists.pkl", "OAI")

## Claude 3.7 Sonnet

In [None]:
# Batch Create
def create_batch_request(df, model="claude-3-7-sonnet-20250219", max_tokens=512):
    """Build one Message Batches request per row of ``df`` for zero-shot evaluation.

    Args:
        df: DataFrame with ``Entities``, ``Label`` and ``Question`` columns.
        model: Claude model identifier used for every request.
        max_tokens: ``max_tokens`` value used for every request.

    Returns:
        list: ``Request`` objects in row order, each with a ``custom_id`` and
        ``MessageCreateParamsNonStreaming`` params.
    """
    requests = []
    for row in df.itertuples():
        # Custom IDs must follow Anthropic's naming convention, so the entity string
        # is truncated to 60 chars and spaces become "-", then "-<Label>" is appended.
        # NOTE(review): assumes Label is short and Entities contain no other
        # characters the API rejects -- confirm. The Claude eval cell must rebuild
        # exactly this key (truncate first, then replace spaces).
        # Single quotes inside the f-string keep this valid before Python 3.12.
        custom_id = f"{row.Entities[:60].replace(' ', '-')}-{row.Label}"
        requests.append(Request(
            custom_id=custom_id,
            params=MessageCreateParamsNonStreaming(
                model=model,
                max_tokens=max_tokens,
                # Same system prompt for every request; marked with cache_control.
                system=[{"type": "text", "text": SYSTEM_PROMPT_ZERO_SHOT, "cache_control": {"type": "ephemeral"}}],
                messages=[{"role": "user",
                           "content": USER_PROMPT_ZERO_SHOT.format(row.Question)}],
            ),
        ))
    return requests

# Build the Claude requests, keep a local copy of the batch input, then submit it.
claude_batch_input_path = "../../samples_for_eval/zero_shot/claude-3.7-sonnet/batch_input.jsonl"
formatted_requests = create_batch_request(test_df)
save_batch(formatted_requests, claude_batch_input_path)

# NOTE: every run of this cell submits a new batch via messages.batches.create.
message_batch = claude_client.messages.batches.create(requests=formatted_requests)
print(message_batch)

In [None]:
# Eval
# Join the Claude batch results back onto the test questions via custom_id.
claude_results_path = "../../samples_for_eval/zero_shot/claude-3.7-sonnet/batch_results.jsonl"
predictions = pd.read_json(claude_results_path, lines=True).rename(columns={"custom_id": "Entities"})

# Rebuild the submitted custom_id: first 60 characters of each entity string,
# spaces -> "-", then "-<Label>". `.str[:60]` slices each string; the previous
# `Series[:60]` sliced the first 60 ROWS, so every later row got a NaN key and
# was silently dropped by the merge. Built on a copy so test_df is not mutated.
claude_keys = (
    test_df["Entities"].str[:60].str.replace(" ", "-", regex=False)
    + "-"
    + test_df["Label"].astype(str)
)
claude_test_df = test_df.assign(Entities=claude_keys)
merged_df = pd.merge(predictions, claude_test_df, on="Entities")

# Every batch result should find its question; a shortfall means the keys drifted.
assert len(merged_df) == len(predictions), (
    f"only {len(merged_df)} of {len(predictions)} Claude results matched a test question"
)

create_formatted_samples_for_eval(merged_df, "../../samples_for_eval/zero_shot/claude-3.7-sonnet/eval_lists.pkl", "Anthropic")