In [1]:
import sys
sys.path.append("../common_scripts/")

from common_functions import save_batch, print_sample, count_tokens, create_formatted_samples_for_eval
from eval_prompts import *
from pathlib import Path
import pandas as pd
import json


from datetime import datetime
import time

import fsspec
import os

test_df = pd.read_csv("../../dataset_for_hf/test.csv")

# deepseek

## zero-shot

In [11]:
def create_formatted_inputs_for_zero_shot_eval(row):
    # Adding label for ID since there are d-d questions with both bio & mol interactions.
    return {"custom_id": f"{row.Entities}_{row.Label}:zero_shot",
            "body": {"model": 'deepseek-reasoner',
                     "messages": [
                                     {"role": "system", "content": SYSTEM_PROMPT_ZERO_SHOT},
                                     {"role": "user", "content": USER_PROMPT_ZERO_SHOT.format(row.Question)}
                                 ]
                    }
           }

formatted_samples = [create_formatted_inputs_for_zero_shot_eval(row) for row in test_df.itertuples()]
save_batch(formatted_samples, "../../samples_for_eval/zero_shot/deepseek/batch_input_r1.jsonl")

In [3]:
def run_deepseek_inference(input_path, output_path):
    responses = []
    with open(input_path, 'r') as f:
        for line in f:
            sample = json.loads(line.strip())
            messages = sample["body"]["messages"]
            model = sample["body"]["model"]
            
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    stream=False
                )
                output = {
                    "custom_id": sample["custom_id"],
                    "response": response.choices[0].message.content
                }
            except Exception as e:
                output = {
                    "custom_id": sample["custom_id"],
                    "response": f"[ERROR] {str(e)}"
                }

            responses.append(output)

    # Save responses
    with open(output_path, 'w') as f:
        for r in responses:
            f.write(json.dumps(r) + '\n')

    print(f"Saved {len(responses)} results to {output_path}")

In [None]:
import json
from openai import OpenAI

# Set up client
client = OpenAI(api_key="", base_url="https://api.deepseek.com")
run_deepseek_inference('../../samples_for_eval/zero_shot/deepseek/batch_input_r1.jsonl', "../../samples_for_eval/zero_shot/deepseek/responses_r1.jsonl")

## upper-bound

In [None]:
def create_formatted_inputs_for_upper_bound_eval(row):
    return {
        "custom_id": f"{row.Entities}_{row.Label}:upper_bound",
        "body": {
            "model": 'deepseek-reasoner',
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_ZERO_SHOT},
                {"role": "user", "content": f"QUESTION: {row.Question}\n\nRELEVANT KNOWLEDGE: {row.Question_Background}"}
            ]
        }
    }

# Apply to test_df
formatted_upper_bound = [create_formatted_inputs_for_upper_bound_eval(row) for row in test_df.itertuples()]
save_batch(formatted_upper_bound, "../../samples_for_eval/upper_bound/deepseek/batch_input.jsonl")


In [None]:
import json
from openai import OpenAI

client = OpenAI(api_key="", base_url="https://api.deepseek.com")
run_deepseek_inference('../../samples_for_eval/upper_bound/deepseek/batch_input.jsonl', "../../samples_for_eval/upper_bound/deepseek/responses_r1.jsonl")

# Qwen  

## zero-shot


In [10]:
def create_formatted_inputs_for_zero_shot_eval(row):
    return {
        "custom_id": f"{row.Entities}_{row.Label}:zero_shot",
        "body": {
            "model": "Qwen/Qwen3-235B-A22B-fp8-tput",  # adjust model name as per API
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_ZERO_SHOT},
                {"role": "user", "content": USER_PROMPT_ZERO_SHOT.format(row.Question)}
            ]
        }
    }
zero_shot_samples = [create_formatted_inputs_for_zero_shot_eval(row) for row in test_df.itertuples()]
save_batch(zero_shot_samples, "../../samples_for_eval/zero_shot/qwen/batch_input_new.jsonl")


## upper_bound

In [7]:

def create_formatted_inputs_for_qwen_gold(row):
    return {
        "custom_id": f"{row.Entities}_{row.Label}:gold_injected",
        "body": {
            "model": 'Qwen/Qwen3-235B-A22B-fp8-tput', #"qwen/qwen3-235b-a22b:free",
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_ZERO_SHOT},
                {"role": "user", "content": f"Question: {row.Question}\n\nGold Evidence:\n{row.Question_Background}\n\nAnswer:"}
            ]
        }
    }
gold_samples = [create_formatted_inputs_for_qwen_gold(row) for row in test_df.itertuples()]
save_batch(gold_samples, "../../samples_for_eval/upper_bound/qwen/batch_input_all.jsonl")

In [3]:
import os
from together import Together
def run_qwen3_inference(input_path, output_path):
    client = Together(api_key='')
    # Step 1: Read already completed custom_ids
    completed_ids = set()
    if os.path.exists(output_path):
        with open(output_path, 'r') as f_out:
            for line in f_out:
                try:
                    obj = json.loads(line)
                    completed_ids.add(obj["custom_id"])
                except:
                    continue
    print(f"Found {len(completed_ids)} completed responses. Resuming from last point...")
    # Step 2: Begin line-by-line inference
    with open(input_path, 'r') as f_in, open(output_path, 'a') as f_out:
        for line in f_in:
            sample = json.loads(line.strip())
            custom_id = sample["custom_id"]
            if custom_id in completed_ids:
                continue  # Skip already processed
            messages = sample["body"]["messages"]
            model = sample["body"]["model"]
            try:
                completion = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    extra_body={}
                )

                if completion.choices and completion.choices[0].message and completion.choices[0].message.content:
                    content = completion.choices[0].message.content
                else:
                    content = "[ERROR] Empty or malformed response."

                result = {
                    "custom_id": custom_id,
                    "response": content
                }
            except Exception as e:
                result = {
                    "custom_id": custom_id,
                    "response": f"[ERROR] {str(e)}"
                }

            f_out.write(json.dumps(result) + "\n")
            f_out.flush()  # Immediately write to disk
            print(f"Processed: {custom_id}")
    

In [4]:
run_qwen3_inference('../../samples_for_eval/zero_shot/qwen/batch_input_new.jsonl', '../../samples_for_eval/zero_shot/qwen/responses_new.jsonl')

Found 142 completed responses. Resuming from last point...
Processed: acetazolamide-Haloperidol-Rosiglitazone-Bimatoprost_2:zero_shot
Processed: Naproxen-alendronic acid_0:zero_shot
Processed: Aciclovir-Doxazosin-Alpha-1B adrenergic receptor_7:zero_shot
Processed: Risedronic acid-Zolpidem_3:zero_shot
Processed: Amlodipine-vincristine_0:zero_shot
Processed: Temazepam-Trimipramine-Histamine H1 receptor_7:zero_shot
Processed: Naproxen-Risedronic acid-vincristine-Torasemide_2:zero_shot
Processed: Duloxetine-desvenlafaxine_3:zero_shot
Processed: Aciclovir-Dipyridamole_0:zero_shot
Processed: Celecoxib-Palonosetron_0:zero_shot
Processed: Methylphenidate-Modafinil-Pamidronic acid-thiotepa_5:zero_shot
Processed: Desloratadine-solifenacin_0:zero_shot
Processed: Naproxen-hydroxychloroquine_3:zero_shot
Processed: Citalopram-Fenofibrate_0:zero_shot
Processed: Telmisartan-Cinacalcet_0:zero_shot
Processed: Rofecoxib-Torasemide_3:zero_shot
Processed: Lamotrigine-Dexpramipexole_3:zero_shot
Processed: A

In [9]:
run_qwen3_inference('../../samples_for_eval/upper_bound/qwen/batch_input_all.jsonl', "../../samples_for_eval/upper_bound/qwen/responses.jsonl")

Saved 168 results to ../../samples_for_eval/upper_bound/qwen/responses.jsonl


# Llama

## zero-shot

In [None]:
from together import Together
def create_formatted_inputs_for_zero_shot_eval(row):
    return {
        "custom_id": f"{row.Entities}_{row.Label}:zero_shot",
        "method": "POST",
        "url": "/v1/chat/completions",  # still standard for OpenAI-compatible endpoints
        "body": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # adjust model name as per API
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_ZERO_SHOT},
                {"role": "user", "content": USER_PROMPT_ZERO_SHOT.format(row.Question)}
            ]
        }
    }

formatted_samples = [create_formatted_inputs_for_zero_shot_eval(row) for row in test_df.itertuples()]
save_batch(formatted_samples, "../../samples_for_eval/zero_shot/llama/batch_input.jsonl")



In [None]:
import json
from openai import OpenAI  # Or relevant client for Fireworks / Together.ai
client = Together(api_key='')

def run_llama3_inference(input_path, output_path):
    responses = []
    with open(input_path, 'r') as f:
        for line in f:
            sample = json.loads(line.strip())
            messages = sample["body"]["messages"]
            model = sample["body"]["model"]

            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    stream=False
                )
                output = {
                    "custom_id": sample["custom_id"],
                    "response": response.choices[0].message.content
                }
            except Exception as e:
                output = {
                    "custom_id": sample["custom_id"],
                    "response": f"[ERROR] {str(e)}"
                }

            responses.append(output)

    with open(output_path, 'w') as f:
        for r in responses:
            f.write(json.dumps(r) + '\n')

    print(f"Saved {len(responses)} results to {output_path}")

In [None]:
run_llama3_inference('../../samples_for_eval/zero_shot/llama/batch_input.jsonl', "../../samples_for_eval/zero_shot/llama/responses.jsonl")

## upper_bound

In [None]:
def create_formatted_inputs_gold(row):
    
    return {
        "custom_id": f"{row.Entities}_{row.Label}:gold_injected",
        "method": "POST",
        "url": "/v1/chat/completions",  # still standard for OpenAI-compatible endpoints
        "body": {
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # adjust model name as per API
                        "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_ZERO_SHOT},
                {"role": "user", "content": f"Question: {row.Question}\n\nGold Evidence:\n{row.Question_Back}\n\nAnswer:"}
                ]
        }
    }

formatted_samples = [create_formatted_inputs_gold(row) for row in test_df.itertuples()]
save_batch(formatted_samples, "../../samples_for_eval/upper_bound/llama/batch_input.jsonl")

In [None]:
run_llama3_inference('../../samples_for_eval/upper_bound/llama/batch_input.jsonl', "../../samples_for_eval/upper_bound/llama/responses.jsonl")

# TxGemma

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("google/txgemma-9b-chat")
model = AutoModelForCausalLM.from_pretrained("google/txgemma-9b-chat", device_map="auto", quantization_config=BitsAndBytesConfig(load_in_8bit=True))
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


In [10]:
def create_formatted_inputs_for_txgemma_zeroshot(row):
    return {
        "custom_id": f"{row.Entities}_{row.Label}:zero_shot",
        "input_text": f"{SYSTEM_PROMPT_ZERO_SHOT}\n\nQuestion: {row.Question}\n\nAnswer:"
    }

def create_formatted_inputs_for_txgemma_upperbound(row):
    return {
        "custom_id": f"{row.Entities}_{row.Label}:upper_bound",
        "input_text": f"{SYSTEM_PROMPT_UPPER_BOUND}\n\n{USER_PROMPT_UPPER_BOUND.format(row.Question_Background, row.Question)}"
    }



In [11]:

formatted_samples = [create_formatted_inputs_for_txgemma_zeroshot(row) for row in test_df.itertuples()]
save_batch(formatted_samples, "../../samples_for_eval/zero_shot/txgemma/batch_input.jsonl")

In [11]:

formatted_samples = [create_formatted_inputs_for_txgemma_gold(row) for row in test_df.itertuples()]
save_batch(formatted_samples, "../../samples_for_eval/upper_bound/txgemma/batch_input.jsonl")

In [9]:
run_txgemma_inference("../../samples_for_eval/zero_shot/txgemma/batch_input.jsonl", "../../samples_for_eval/zero_shot/txgemma/responses_new.jsonl",pipe)

Found 0 completed responses. Resuming...
Processed: Candesartan-Citalopram-Fluphenazine-Histamine H1 receptor_8:zero_shot
Processed: Lamotrigine-Oxaprozin-Sildenafil_4:zero_shot
Processed: Ranitidine-Sibutramine-Zolpidem_4:zero_shot
Processed: Cetirizine-Dipyridamole-Sildenafil-Zaleplon_2:zero_shot
Processed: Alprazolam-Bumetanide-Fenofibrate-Modafinil_5:zero_shot
Processed: alendronic acid-Pyridostigmine_0:zero_shot
Processed: Ketorolac-Oxaprozin_0:zero_shot
Processed: Naproxen-Doxazosin-Alpha-1B adrenergic receptor_7:zero_shot
Processed: Chloroquine-hydroxychloroquine_3:zero_shot
Processed: Amlodipine-Pamidronic acid-thiotepa_1:zero_shot
Processed: Diazepam-Duloxetine-Acamprosate_4:zero_shot
Processed: Tramadol-Exemestane_0:zero_shot
Processed: Haloperidol-Muscarinic acetylcholine receptor M5_6:zero_shot
Processed: Alprazolam-Ciclopirox_3:zero_shot
Processed: Norfloxacin-Zolpidem-vincristine-Torasemide_2:zero_shot
Processed: Naproxen-Ondansetron-Salmeterol-Beta-1 adrenergic receptor_

In [12]:
run_txgemma_inference("../../samples_for_eval/upper_bound/txgemma/batch_input.jsonl", "../../samples_for_eval/upper_bound/txgemma/responses_new.jsonl",pipe)

Found 0 completed responses. Resuming...
Processed: Candesartan-Citalopram-Fluphenazine-Histamine H1 receptor_8:gold_injected
Processed: Lamotrigine-Oxaprozin-Sildenafil_4:gold_injected
Processed: Ranitidine-Sibutramine-Zolpidem_4:gold_injected
Processed: Cetirizine-Dipyridamole-Sildenafil-Zaleplon_2:gold_injected
Processed: Alprazolam-Bumetanide-Fenofibrate-Modafinil_5:gold_injected
Processed: alendronic acid-Pyridostigmine_0:gold_injected
Processed: Ketorolac-Oxaprozin_0:gold_injected
Processed: Naproxen-Doxazosin-Alpha-1B adrenergic receptor_7:gold_injected
Processed: Chloroquine-hydroxychloroquine_3:gold_injected
Processed: Amlodipine-Pamidronic acid-thiotepa_1:gold_injected
Processed: Diazepam-Duloxetine-Acamprosate_4:gold_injected
Processed: Tramadol-Exemestane_0:gold_injected
Processed: Haloperidol-Muscarinic acetylcholine receptor M5_6:gold_injected
Processed: Alprazolam-Ciclopirox_3:gold_injected
Processed: Norfloxacin-Zolpidem-vincristine-Torasemide_2:gold_injected
Processed:

In [4]:
def run_txgemma_inference(input_path, output_path, pipe, max_tokens=256):
    completed_ids = set()
    if os.path.exists(output_path):
        with open(output_path, 'r') as f:
            for line in f:
                try:
                    obj = json.loads(line)
                    completed_ids.add(obj["custom_id"])
                except:
                    continue
    print(f"Found {len(completed_ids)} completed responses. Resuming...")

    with open(input_path, 'r') as f_in, open(output_path, 'a') as f_out:
        for line in f_in:
            sample = json.loads(line.strip())
            custom_id = sample["custom_id"]
            if custom_id in completed_ids:
                continue

            try:
                output = pipe(sample["input_text"], max_new_tokens=max_tokens)[0]["generated_text"]
                answer = output[len(sample["input_text"]):].strip()
            except Exception as e:
                answer = f"[ERROR] {str(e)}"

            result = {
                "custom_id": custom_id,
                "response": answer
            }
            f_out.write(json.dumps(result) + "\n")
            f_out.flush()
            print(f"Processed: {custom_id}")


### example

In [None]:
### example
import json
from huggingface_hub import hf_hub_download

tdc_prompts_filepath = hf_hub_download(
    repo_id=model_id,
    filename="tdc_prompts.json",
)

with open(tdc_prompts_filepath, "r") as f:
    tdc_prompts_json = json.load(f)


