In [None]:
import os
import random
import time
import pandas as pd
import openai
from openai import RateLimitError, APIError



# OpenAI-compatible client pointed at the SambaNova endpoint.
# The API key is read from the SAMBANOVA_API_KEY environment variable.
client = openai.OpenAI(
    base_url="https://api.sambanova.ai/v1",
    api_key=os.getenv("SAMBANOVA_API_KEY"),
)


# Token counting using a simple string split method
def count_tokens(text):
    """Approximate the token count of ``text`` by counting whitespace-separated words."""
    words = text.split()
    return len(words)

# Truncate text to fit within a specified token budget
def truncate_text(text, max_tokens, reserve_tokens=50):
    """
    Truncate text to fit within the token budget.

    Tokens are whitespace-separated words. The result is the first
    ``max_tokens - reserve_tokens`` words joined by single spaces, so runs of
    whitespace in the input are collapsed even when nothing is cut.

    text: The text to truncate.
    max_tokens: Total token budget.
    reserve_tokens: Number of tokens reserved for the output.

    Returns the truncated text; an empty string when the reserve consumes the
    whole budget.
    """
    words = text.split()
    # Clamp at zero: a negative limit would make the slice drop words from the
    # END of the text instead of keeping none of them.
    limit = max(max_tokens - reserve_tokens, 0)
    return " ".join(words[:limit])

def _read_int_labels(path):
    """Read one integer label per line from ``path``, preserving line alignment.

    A line that cannot be parsed as an integer produces a ``None`` placeholder
    (and a printed warning) instead of being dropped, so that label ``i``
    always belongs to line ``i`` of the file.
    """
    labels = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            try:
                labels.append(int(stripped))
            except ValueError:
                print(
                    f"Warning: Invalid label in {os.path.basename(path)} "
                    f"line {line_no}: '{stripped}'; treating as 'Unknown'"
                )
                labels.append(None)
    return labels


def _label_name(labels, mapping, index):
    """Return the name for ``labels[index]``, or "Unknown" if missing or unmapped."""
    label = labels[index] if index < len(labels) else None
    return mapping.get(label, "Unknown")


# Load the NYT dataset
def load_nyt_dataset(data_dir='data/nyt_data'):
    """Load phrases and their topic/location labels from ``data_dir``.

    Reads ``phrase_text.txt``, ``topics.txt``, ``topics_label.txt``,
    ``locations.txt``, ``locations_label.txt`` and ``label.txt``. Line ``i`` of
    each label file is paired with line ``i`` of ``phrase_text.txt``; topic and
    location labels are treated as 0-based indices into the corresponding name
    file.

    Labels that are missing, unparsable, or outside the name list resolve to
    "Unknown" rather than shifting later labels or raising.

    Returns a tuple ``(data, topics, locations)`` where ``data`` is a list of
    dicts with keys ``'text'``, ``'topic'``, ``'location'`` and
    ``'combined_label'`` (the tab-split fields of the matching ``label.txt``
    line, or ``("Unknown", "Unknown")`` when that file is shorter).

    Raises OSError if any of the files cannot be opened.
    """
    # File paths
    phrase_file = os.path.join(data_dir, 'phrase_text.txt')
    topics_file = os.path.join(data_dir, 'topics.txt')
    topics_label_file = os.path.join(data_dir, 'topics_label.txt')
    locations_file = os.path.join(data_dir, 'locations.txt')
    locations_label_file = os.path.join(data_dir, 'locations_label.txt')
    combined_label_file = os.path.join(data_dir, 'label.txt')

    # Load phrases
    with open(phrase_file, 'r', encoding='utf-8') as f:
        phrases = [line.strip() for line in f]

    # Load topics and their labels
    with open(topics_file, 'r', encoding='utf-8') as f:
        topics = [line.strip() for line in f]
    topic_labels = _read_int_labels(topics_label_file)

    # Map topic labels to topic names
    topic_mapping = dict(enumerate(topics))

    # Load locations and their labels
    with open(locations_file, 'r', encoding='utf-8') as f:
        locations = [line.strip() for line in f]
    location_labels = _read_int_labels(locations_label_file)

    # Map location labels to location names
    location_mapping = dict(enumerate(locations))

    # Load combined labels
    with open(combined_label_file, 'r', encoding='utf-8') as f:
        combined_labels = [line.strip().split('\t') for line in f]

    # Combine data into a structured dictionary
    data = []
    for i, phrase in enumerate(phrases):
        combined_label = combined_labels[i] if i < len(combined_labels) else ("Unknown", "Unknown")
        data.append({
            'text': phrase,
            'topic': _label_name(topic_labels, topic_mapping, i),
            'location': _label_name(location_labels, location_mapping, i),
            'combined_label': combined_label
        })

    return data, topics, locations

# Pick up to N samples per topic for evaluation (first occurrences, in data order)
def select_samples(data, num_samples_per_class=10):
    """Group items by topic, keeping at most ``num_samples_per_class`` per topic.

    Items are taken in the order they appear in ``data``; no shuffling is done.
    Returns a dict mapping each topic found in ``data`` to its list of items.
    """
    buckets = {label: [] for label in {entry['topic'] for entry in data}}

    for entry in data:
        bucket = buckets[entry['topic']]
        if len(bucket) >= num_samples_per_class:
            continue  # this topic already has its quota
        bucket.append(entry)

    return buckets

# Prompting methods
def direct_prompt(text, classes_formatted):
    """Build a zero-shot prompt: the text, then an instruction to pick one category.

    ``classes_formatted`` is inserted verbatim as the list of allowed categories.
    """
    lines = [
        f"{text}",
        "",
        f"Classify the above text into one of the following categories: {classes_formatted}.",
        "Provide only the category name as your answer.",
    ]
    return "\n".join(lines)

def chain_of_thought_prompt(text, classes_formatted):
    """Build a prompt that asks the model to think step by step before classifying.

    ``classes_formatted`` is inserted verbatim as the list of allowed categories.
    """
    lines = [
        f"{text}",
        "",
        f"Think step by step to classify the above text into one of the following categories: {classes_formatted}.",
        "Provide only the category name as your answer.",
    ]
    return "\n".join(lines)

def few_shot_prompt(text, examples, classes_formatted):
    """Build a few-shot prompt: labelled examples first, then the text to classify.

    ``examples`` is an iterable of ``(example_text, example_label)`` pairs;
    ``classes_formatted`` is inserted verbatim as the list of allowed categories.
    """
    demonstrations = "".join(
        f"Text: {sample}\nLabel: {label}\n\n" for sample, label in examples
    )
    question = f"Now, classify the following text:\nText: {text}\n\n"
    instruction = (
        f"Choose the category from the following options: {classes_formatted}.\n"
        "Provide only the category name as your answer."
    )
    return demonstrations + question + instruction

# Estimate cost (adjust cost values as necessary)
def estimate_cost(input_tokens, output_tokens, model_name):
    """Estimate the cost of one request from its token counts.

    input_tokens: Number of tokens in the prompt.
    output_tokens: Number of tokens in the response.
    model_name: Model identifier used to look up the per-token rate.

    Returns ``(input_tokens + output_tokens) * rate``. Unknown models fall back
    to a rate of 0.0001.

    NOTE: the rates are the placeholder values this script was written with,
    not verified pricing; adjust as necessary.
    """
    cost_per_token = {
        # Short aliases kept for backward compatibility.
        'Meta-Llama-1B-Instruct': 0.0001,
        'Meta-Llama-3B-Instruct': 0.0003,
        'Meta-Llama-8B-Instruct': 0.0008,
        # Full model identifiers as passed to the API by the experiment loop;
        # without these every model silently fell back to the default rate.
        'Meta-Llama-3.2-1B-Instruct': 0.0001,
        'Meta-Llama-3.2-3B-Instruct': 0.0003,
        'Meta-Llama-3.1-8B-Instruct': 0.0008,
    }
    total_tokens = input_tokens + output_tokens
    return total_tokens * cost_per_token.get(model_name, 0.0001)

def evaluate_model(
    data, topics, model_name, prompting_method,
    num_samples_per_class=10, token_budget=4000, delay=2, max_retries=3,
    tokens_per_word=1.3
):
    """Evaluate one model with one prompting method on a per-topic sample.

    data: List of item dicts with at least ``'text'`` and ``'topic'`` keys.
    topics: Category names offered to the model.
    model_name: Model identifier passed to the chat completions API.
    prompting_method: One of ``'direct'``, ``'chain-of-thought'``, ``'few-shot'``.
    num_samples_per_class: Maximum number of items evaluated per topic.
    token_budget: Maximum prompt size in model tokens.
    delay: Seconds to sleep after each sample.
    max_retries: Number of retries after the first failed attempt.
    tokens_per_word: Heuristic ratio used to convert ``token_budget`` into a
        word budget, because ``count_tokens`` counts whitespace-separated
        words while the backend counts model tokens (an observed run had a
        prompt that passed the word check rejected as 4440 tokens against a
        4096-token maximum).

    Returns ``(accuracy, total_cost)``. Accuracy is computed over the samples
    that actually received a response; samples that are skipped for length or
    whose request ultimately failed are excluded rather than counted as wrong.

    Raises ValueError for an unknown prompting method or a non-positive
    ``tokens_per_word``.
    """
    if prompting_method not in ('direct', 'chain-of-thought', 'few-shot'):
        raise ValueError(f"Unknown prompting method: {prompting_method}")
    if tokens_per_word <= 0:
        raise ValueError(f"tokens_per_word must be positive: {tokens_per_word}")

    correct_predictions = 0
    total_cost = 0
    total_samples = 0
    classes_formatted = ", ".join(topics)

    # Budget in words, since truncation and counting both work on words.
    word_budget = int(token_budget / tokens_per_word)
    # Words set aside for the instruction text, example scaffolding and answer.
    instruction_reserve = 100
    num_examples = 3 if prompting_method == 'few-shot' else 0
    # Every text placed in the prompt (the query plus each few-shot example)
    # gets an equal share; clamped so a tiny budget never goes negative.
    per_text_words = max((word_budget - instruction_reserve) // (num_examples + 1), 0)

    selected_data = select_samples(data, num_samples_per_class)

    for cls, samples in selected_data.items():
        for item in samples:
            # Truncate text to fit within its share of the budget
            text = truncate_text(item['text'], per_text_words, reserve_tokens=0)

            # Create prompt based on the chosen prompting method
            if prompting_method == 'direct':
                prompt = direct_prompt(text, classes_formatted)
            elif prompting_method == 'chain-of-thought':
                prompt = chain_of_thought_prompt(text, classes_formatted)
            else:
                # Draw one extra candidate so the item under test can be
                # dropped without shrinking the example set; showing the model
                # the query together with its own label would leak the answer.
                pool = random.sample(data, min(num_examples + 1, len(data)))
                examples = [
                    (truncate_text(ex['text'], per_text_words, reserve_tokens=0), ex['topic'])
                    for ex in pool if ex is not item
                ][:num_examples]
                prompt = few_shot_prompt(text, examples, classes_formatted)

            input_token_count = count_tokens(prompt)
            if input_token_count > word_budget:
                print(f"Skipping sample due to token limit: {input_token_count} tokens")
                continue

            response_text = None  # stays None unless a full response is received
            for attempt in range(max_retries + 1):
                try:
                    # API call
                    completion = client.chat.completions.create(
                        model=model_name,
                        messages=[{"role": "user", "content": prompt}],
                        stream=True,
                        temperature=0.1,
                        top_p=0.1
                    )
                    # Some stream chunks may carry no choices; skip them.
                    response_text = "".join(
                        chunk.choices[0].delta.content or ""
                        for chunk in completion if chunk.choices
                    )
                    break  # Exit retry loop on success

                except RateLimitError:
                    if attempt < max_retries:
                        print("Rate limit exceeded. Sleeping for 30 seconds...")
                        time.sleep(30)
                    else:
                        print("Rate limit exceeded. No retries left.")

                except APIError as e:
                    print(f"API Error: {e}")
                    if "sequence length" in str(e):
                        # Resending the same oversized prompt cannot succeed.
                        print("Token limit exceeded. Please adjust the input size.")
                        break

                except Exception as e:
                    print(f"Unexpected error: {e}")
                    break

            if response_text is not None:
                output_token_count = count_tokens(response_text)
                total_cost += estimate_cost(input_token_count, output_token_count, model_name)

                if response_text.strip().lower() == cls.lower():
                    correct_predictions += 1
                total_samples += 1

            time.sleep(delay)  # Respect delay between requests

    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
    return accuracy, total_cost


# Run experiments: evaluate every model with every prompting method and save the results
if __name__ == "__main__":
    data, topics, locations = load_nyt_dataset()
    print(f"Topics: {topics}")
    print(f"Locations: {locations}")

    models = [
        'Meta-Llama-3.2-1B-Instruct',
        'Meta-Llama-3.2-3B-Instruct',
        'Meta-Llama-3.1-8B-Instruct',
    ]
    prompting_methods = ['direct', 'chain-of-thought', 'few-shot']

    # One result row per (model, prompting method) pair, models in the outer loop.
    results = []
    for model_name in models:
        for method in prompting_methods:
            accuracy, cost = evaluate_model(data, topics, model_name, method)
            row = dict(
                model=model_name,
                prompting_method=method,
                accuracy=accuracy,
                cost=cost,
            )
            results.append(row)

    results_df = pd.DataFrame(results)
    results_df.to_csv('nyt_results.csv', index=False)
    print(results_df)


Topics: ['business', 'politics', 'sports', 'health', 'education', 'real_estate', 'arts', 'science', 'technology']
Locations: ['united_states', 'iraq', 'japan', 'china', 'britain', 'russia', 'germany', 'canada', 'france', 'italy']
API Error: Requested generation length 1 is not possible! The provided prompt is 4440 tokens long, so generating 1 tokens requires a sequence length of 4441, but the maximum supported sequence length is just 4096!
API Error: Requested generation length 1 is not possible! The provided prompt is 4440 tokens long, so generating 1 tokens requires a sequence length of 4441, but the maximum supported sequence length is just 4096!
API Error: Requested generation length 1 is not possible! The provided prompt is 4440 tokens long, so generating 1 tokens requires a sequence length of 4441, but the maximum supported sequence length is just 4096!
API Error: Requested generation length 1 is not possible! The provided prompt is 4440 tokens long, so generating 1 tokens requir