In [1]:
from together import Together
import json
import os
from dotenv import load_dotenv
import datasets
from utils import sample_row

# Pull variables from a local .env file into the process environment, then
# build the Together client from TOGETHERAI_KEY (None if that variable is unset),
# so the key is never written into the notebook itself.
load_dotenv()
togetherai_key = os.environ.get('TOGETHERAI_KEY')
client = Together(api_key=togetherai_key)

def response(model_name, system_msg, temperature=0, max_tokens=512):
    """Build a single-turn chat function bound to one model and system prompt.

    Args:
        model_name: Together model string, e.g. "meta-llama/Llama-3-8b-chat-hf".
        system_msg: System prompt sent ahead of every user message.
        temperature: Sampling temperature passed to the API (default 0).
        max_tokens: Maximum number of tokens to generate (default 512).

    Returns:
        A callable ``f(user_msg) -> str`` that sends ``[system, user]`` messages
        through the module-level ``client`` and returns the content of the first
        choice.
    """
    def model_specific_response(user_msg):
        # Named `completion` (not `response`) so it does not shadow the
        # enclosing factory's name.
        completion = client.chat.completions.create(
            temperature=temperature,
            max_tokens=max_tokens,
            model=model_name,
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
        )
        # Only the first returned choice is used.
        return completion.choices[0].message.content
    return model_specific_response

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model under test. The Together model string below comes from the catalogue at
# https://docs.together.ai/docs/inference-models; `model_name` is only a short
# label that ends up in the output file name.
model_name = 'Llama-3-8b'
system = "You are a helpful assistant."
llama3_8b_response = response("meta-llama/Llama-3-8b-chat-hf", system)

In [3]:
# AIR-Bench 2024 test split. "default" loads the full benchmark; set `region` to
# one of ["china", "eu_comprehensive", "eu_mandatory", "us"] for a regional subset.
region = "default"
test_data = datasets.load_dataset("stanford-crfm/air-bench-2024", region, split="test")
rows = sample_row(test_data, 5)  # sample 5 prompts for each l2 index (1-16)

output_json = []
for i, (cate_idx, l2_name, l3_name, l4_name, prompt) in enumerate(rows):
    # Report progress every 10 prompts.
    if (i + 1) % 10 == 0:
        print(f"{i+1}/{len(rows)}")
    # Deliberately not named `response`: rebinding that global would clobber the
    # `response` factory from the first cell, so re-running the model-setup cell
    # would then fail.
    model_output = llama3_8b_response(prompt)

    json_entry = {
        "cate_idx": cate_idx,
        "l2_name": l2_name,
        "l3_name": l3_name,
        "l4_name": l4_name,
        "prompt": [
            {
                "prompt": prompt
            }
        ],
        "response": model_output,
    }
    output_json.append(json_entry)

# newline='' disables newline translation so the JSON is written with '\n' on every OS.
with open(f'pipeline1_step1_{model_name}_response.json', 'w', newline='', encoding='utf-8') as outfile:
    json.dump(output_json, outfile, ensure_ascii=False, indent=4)