In [None]:
import sys
sys.path.append("../common_scripts/")

from common_functions import save_batch, print_sample, count_tokens, create_formatted_samples_for_eval
from eval_prompts import *
from pathlib import Path
import pandas as pd
import json

from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
import anthropic

# Shared Anthropic client used by the Claude batch-submission cells below.
# No api_key is passed here, so credentials are expected to come from the SDK's
# default lookup (presumably the ANTHROPIC_API_KEY environment variable — confirm).
claude_client = anthropic.Anthropic()

from datetime import datetime
import time

import fsspec
from google import genai
from google.genai.types import CreateBatchJobConfig
import os

# Test split shared by every section of this notebook. The cells below read its
# Entities, Label, Question and Question_Background columns.
test_df = pd.read_csv("../../dataset_for_hf/test.csv")

# Zero Shot

## o4-mini

In [None]:
# Batch Create
def create_formatted_inputs_for_zero_shot_eval(row):
    """Build one zero-shot chat-completions batch request for o4-mini.

    Args:
        row: A record exposing ``Entities``, ``Label`` and ``Question``
            attributes (e.g. a tuple yielded by ``DataFrame.itertuples()``).

    Returns:
        dict: A request with the keys ``custom_id``, ``method``, ``url`` and
        ``body``; ``custom_id`` has the form ``"<Entities>_<Label>:zero_shot"``.
    """
    # The label is part of the ID since there are d-d questions with both
    # bio & mol interactions, so Entities alone would not be unique.
    request_id = f"{row.Entities}_{row.Label}:zero_shot"

    developer_turn = {"role": "developer", "content": SYSTEM_PROMPT_ZERO_SHOT}
    user_turn = {"role": "user", "content": USER_PROMPT_ZERO_SHOT.format(row.Question)}
    chat_body = {"model": "o4-mini", "messages": [developer_turn, user_turn]}

    return {
        "custom_id": request_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": chat_body,
    }

# Turn every test row into a zero-shot request, then hand the list to save_batch
# together with the o4-mini batch input path.
formatted_samples = list(
    map(create_formatted_inputs_for_zero_shot_eval, test_df.itertuples())
)
save_batch(formatted_samples, "../../samples_for_eval/zero_shot/o4-mini/batch_input.jsonl")

In [None]:
# Eval
# Load the batch output and expose its custom_id under the name "Entities" so it
# can serve as the join key against the test set.
predictions = pd.read_json(
    "../../samples_for_eval/zero_shot/o4-mini/batch_output.jsonl", lines=True
).rename(columns={"custom_id": "Entities"})

# Rebuild the custom_id used at batch-creation time on a *copy* of the test set.
# Overwriting test_df["Entities"] in place would make this cell non-idempotent
# (re-running it appends the suffix again) and would corrupt the IDs that every
# later section derives from the raw Entities column.
zero_shot_ids = test_df["Entities"] + "_" + test_df["Label"].astype(str) + ":zero_shot"
zero_shot_test_df = test_df.assign(Entities=zero_shot_ids)

# No `on=` is given, so pandas joins on every column name the two frames share.
merged_df = pd.merge(predictions, zero_shot_test_df)

create_formatted_samples_for_eval(merged_df, "../../samples_for_eval/zero_shot/o4-mini/eval_lists.pkl", "OAI")

## Claude 3.7 Sonnet

In [None]:
# Batch Create
def create_batch_request(df):
    """Build the Claude 3.7 Sonnet batch requests for the zero-shot setting.

    Args:
        df: Object whose ``itertuples()`` yields rows exposing ``Entities``,
            ``Label`` and ``Question`` attributes (e.g. the test DataFrame).

    Returns:
        list: One ``Request`` per row, in row order. Each ``custom_id`` is the
        first 60 characters of ``Entities`` with spaces replaced by hyphens,
        followed by ``-<Label>``.
    """
    requests = []
    for row in df.itertuples():
        # Have to do this - their naming convention: truncate the entity names
        # and swap spaces for hyphens. The slug is built outside the f-string
        # because reusing double quotes inside an f-string replacement field is
        # a SyntaxError before Python 3.12.
        entity_slug = row.Entities[:60].replace(" ", "-")
        requests.append(Request(
            custom_id=f"{entity_slug}-{row.Label}",
            params=MessageCreateParamsNonStreaming(
                model="claude-3-7-sonnet-20250219",
                max_tokens=512,
                # The system prompt is identical for every request and carries
                # an ephemeral cache_control marker.
                system=[{"type": "text", "text": SYSTEM_PROMPT_ZERO_SHOT, "cache_control": {"type": "ephemeral"}}],
                messages=[{"role": "user",
                           "content": USER_PROMPT_ZERO_SHOT.format(row.Question)}])))
    return requests

# Build one request per test row and persist the payload before submitting it.
formatted_requests = create_batch_request(test_df)
save_batch(formatted_requests, "../../samples_for_eval/zero_shot/claude-3.7-sonnet/batch_input.jsonl")

# Submit the whole request list as a single batch through the Anthropic client.
# NOTE(review): every execution of this cell creates a new batch — presumably
# billable; confirm before re-running.
message_batch = claude_client.messages.batches.create(requests=formatted_requests)
# Printed so the returned batch object can be inspected; presumably it carries
# the batch id needed to fetch the results later — confirm.
print(message_batch)

In [None]:
# Eval
# Load the batch results and expose their custom_id under the name "Entities" so
# it can serve as the join key against the test set.
predictions = pd.read_json(
    "../../samples_for_eval/zero_shot/claude-3.7-sonnet/batch_results.jsonl", lines=True
).rename(columns={"custom_id": "Entities"})

# Recreate the custom_id used when the batch was built (spaces -> hyphens, first
# 60 characters, then "-<Label>") on a *copy* of the test set. Overwriting
# test_df["Entities"] in place would make this cell non-idempotent and would
# corrupt the IDs that every later section derives from the raw Entities column.
claude_zero_shot_ids = (
    test_df["Entities"].str.replace(" ", "-").str[:60] + "-" + test_df["Label"].astype(str)
)
claude_zero_shot_test_df = test_df.assign(Entities=claude_zero_shot_ids)

# No `on=` is given, so pandas joins on every column name the two frames share.
merged_df = pd.merge(predictions, claude_zero_shot_test_df)
create_formatted_samples_for_eval(merged_df, "../../samples_for_eval/zero_shot/claude-3.7-sonnet/eval_lists.pkl", "Anthropic")

## Gemini 2.0 Flash

In [None]:
def create_batch_request(row):
    """Build one Gemini batch-prediction input record for the zero-shot setting.

    The system prompt is placed *inside* ``request`` as a content object
    (``{"parts": [{"text": ...}]}``). Previously it was a bare string stored in
    a top-level ``systemInstruction`` key next to ``request``; per the Vertex AI
    batch input format only the ``request`` object is the generation request,
    so the system prompt was presumably never applied to the model.
    NOTE(review): confirm against the Vertex AI batch prediction docs.

    Args:
        row: A record exposing a ``Question`` attribute (e.g. a tuple yielded
            by ``DataFrame.itertuples()``).

    Returns:
        dict: ``{"request": {"contents": [...], "systemInstruction": {...}}}``
        with a single user turn holding the formatted question.
    """
    user_turn = {"role": "user", "parts": [{"text": USER_PROMPT_ZERO_SHOT.format(row.Question)}]}
    return {
        "request": {
            "contents": [user_turn],
            "systemInstruction": {"parts": [{"text": SYSTEM_PROMPT_ZERO_SHOT}]},
        }
    }

# Turn every test row into a Gemini request record, then hand the list to
# save_batch together with the Gemini batch input path.
formatted_samples = list(map(create_batch_request, test_df.itertuples()))
save_batch(formatted_samples, "../../samples_for_eval/zero_shot/gemini-2.0-flash/batch_input.jsonl")

In [None]:
PROJECT_ID = "striped-torus-458820-q6"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)

MODEL_ID = "gemini-2.0-flash-001"
INPUT_DATA = "gs://zeroshot/batch_input.jsonl"
BUCKET_URI = "zeroshot"
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_URI = f"gs://{PROJECT_ID}-{TIMESTAMP}"

! gsutil mb -l {LOCATION} -p {PROJECT_ID} {BUCKET_URI}

gcs_batch_job = client.batches.create(
    model=MODEL_ID,
    src=INPUT_DATA,
    config=CreateBatchJobConfig(dest=BUCKET_URI),
)
gcs_batch_job.name

# Upper Bound

## o4-mini

In [None]:
# Batch Create
def create_formatted_inputs_for_upper_bound_eval(row):
    """Build one upper-bound chat-completions batch request for o4-mini.

    Args:
        row: A record exposing ``Entities``, ``Label``, ``Question_Background``
            and ``Question`` attributes (e.g. a tuple yielded by
            ``DataFrame.itertuples()``).

    Returns:
        dict: A request with the keys ``custom_id``, ``method``, ``url`` and
        ``body``; ``custom_id`` has the form ``"<Entities>_<Label>:upper_bound"``.
    """
    request = {}
    # The label is part of the ID since there are d-d questions with both
    # bio & mol interactions, so Entities alone would not be unique.
    request["custom_id"] = f"{row.Entities}_{row.Label}:upper_bound"
    request["method"] = "POST"
    request["url"] = "/v1/chat/completions"

    # The user prompt takes the background first and the question second.
    conversation = [
        {"role": "developer", "content": SYSTEM_PROMPT_UPPER_BOUND},
        {"role": "user", "content": USER_PROMPT_UPPER_BOUND.format(row.Question_Background, row.Question)},
    ]
    request["body"] = {"model": "o4-mini", "messages": conversation}
    return request

# Turn every test row into an upper-bound request, then hand the list to
# save_batch together with the o4-mini upper-bound batch input path.
formatted_samples = list(
    map(create_formatted_inputs_for_upper_bound_eval, test_df.itertuples())
)
save_batch(formatted_samples, "../../samples_for_eval/upper_bound/o4-mini/batch_input.jsonl")

## Claude 3.7 Sonnet

In [None]:
# Batch Create
def create_batch_request(df):
    """Build the Claude 3.7 Sonnet batch requests for the upper-bound setting.

    Args:
        df: Object whose ``itertuples()`` yields rows exposing ``Entities``,
            ``Label``, ``Question_Background`` and ``Question`` attributes
            (e.g. the test DataFrame).

    Returns:
        list: One ``Request`` per row, in row order. Each ``custom_id`` is the
        first 60 characters of ``Entities`` with spaces replaced by hyphens,
        followed by ``-<Label>``.
    """
    requests = []
    for row in df.itertuples():
        # Have to do this - their naming convention: truncate the entity names
        # and swap spaces for hyphens. The slug is built outside the f-string
        # because reusing double quotes inside an f-string replacement field is
        # a SyntaxError before Python 3.12.
        entity_slug = row.Entities[:60].replace(" ", "-")
        requests.append(Request(
            custom_id=f"{entity_slug}-{row.Label}",
            params=MessageCreateParamsNonStreaming(
                model="claude-3-7-sonnet-20250219",
                max_tokens=512,
                # The system prompt is identical for every request and carries
                # an ephemeral cache_control marker.
                system=[{"type": "text", "text": SYSTEM_PROMPT_UPPER_BOUND, "cache_control": {"type": "ephemeral"}}],
                # The user prompt takes the background first and the question second.
                messages=[{"role": "user",
                           "content": USER_PROMPT_UPPER_BOUND.format(row.Question_Background, row.Question)}])))
    return requests

# Build one upper-bound request per test row and persist the payload before
# submitting it.
formatted_requests = create_batch_request(test_df)
save_batch(formatted_requests, "../../samples_for_eval/upper_bound/claude-3.7-sonnet/batch_input.jsonl")

# Submit the whole request list as a single batch through the Anthropic client.
# NOTE(review): every execution of this cell creates a new batch — presumably
# billable; confirm before re-running.
message_batch = claude_client.messages.batches.create(requests=formatted_requests)
# Printed so the returned batch object can be inspected; presumably it carries
# the batch id needed to fetch the results later — confirm.
print(message_batch)

In [None]:
# Eval
# Load the batch results and expose their custom_id under the name "Entities" so
# it can serve as the join key against the test set.
predictions = pd.read_json(
    "../../samples_for_eval/upper_bound/claude-3.7-sonnet/batch_results.jsonl", lines=True
).rename(columns={"custom_id": "Entities"})

# Recreate the custom_id used when the batch was built (spaces -> hyphens, first
# 60 characters, then "-<Label>") on a *copy* of the test set. Overwriting
# test_df["Entities"] in place would make this cell non-idempotent and would
# leave the shared test set with rewritten Entities for any cell run afterwards.
claude_upper_bound_ids = (
    test_df["Entities"].str.replace(" ", "-").str[:60] + "-" + test_df["Label"].astype(str)
)
claude_upper_bound_test_df = test_df.assign(Entities=claude_upper_bound_ids)

# No `on=` is given, so pandas joins on every column name the two frames share.
merged_df = pd.merge(predictions, claude_upper_bound_test_df)
create_formatted_samples_for_eval(merged_df, "../../samples_for_eval/upper_bound/claude-3.7-sonnet/eval_lists.pkl", "Anthropic")