In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the RAFT benchmark's "tai_safety_research" configuration by Hub id.
# The result is indexed by split name (see `dataset["train"]` below).
dataset = load_dataset(
    "ought/raft",
    "tai_safety_research",
)

In [4]:
# Candidate labels the model chooses between. Order matters: the prompt and
# the accuracy calculation index into this list by position (0 = positive).
class_names = ["TAI safety research", "not TAI safety research"]

# Load the seq2seq checkpoint and its matching tokenizer by Hub id
# (tokenizer first, then model).
model_name = "bigscience/mt0-large"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

def _normalize_label(label):
    """Canonicalise a label string for comparison: trim, drop trailing punctuation, casefold."""
    return label.strip().rstrip(".!?").strip().casefold()


# function to classify text
def classify_text(get_prompt, text):
    """Classify ``text`` with the seq2seq model, using a prompt built by ``get_prompt``.

    Args:
        get_prompt: callable ``(text, class_names) -> str`` producing the model prompt.
        text: the text to classify (in this notebook, a paper abstract).

    Returns:
        The matching entry of the global ``class_names`` (in its canonical
        spelling), or ``"unknown"`` if the generated answer is none of them.
    """
    prompt = get_prompt(text, class_names)
    # NOTE(review): HF tokenizers truncate from the right by default, so for very
    # long abstracts the question at the end of the prompt may be cut off — confirm.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=8, num_return_sequences=1)

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Only matters for prompts that make the model echo a " Classification:"
    # marker; for other outputs this leaves the text unchanged.
    answer = answer.split(" Classification:")[-1]

    # Compare after normalising case, whitespace and trailing punctuation, so that
    # e.g. "tai safety research." still maps to a class instead of "unknown".
    # Equality (not substring) is required: "TAI safety research" is a substring
    # of "not TAI safety research".
    normalized = _normalize_label(answer)
    for class_name in class_names:
        if _normalize_label(class_name) == normalized:
            return class_name
    return "unknown"

# `sample` is built in the next cell, alongside `calculate_accuracy`, which reads it.

In [5]:
# Seeded so the per-example printout order is reproducible across runs. Every
# item is scored, so the shuffle does not change the accuracy itself.
sample = dataset["train"].shuffle(seed=42)

def calculate_accuracy(get_prompt, items=None, verbose=True):
    """Score a prompt builder by the classification accuracy it yields.

    Args:
        get_prompt: prompt builder passed through to ``classify_text``.
        items: iterable of records with "Abstract Note" and "Label" keys;
            defaults to the global ``sample``.
        verbose: if True, print abstract (first 100 chars), true and predicted
            label for every record.

    Returns:
        Fraction (0.0-1.0) of records whose predicted label equals the true label.

    Raises:
        ValueError: if there are no records to evaluate.
    """
    if items is None:
        items = sample

    results = []
    for item in items:
        abstract = item["Abstract Note"]
        # Label 1 -> class_names[0]; any other value -> class_names[1].
        # NOTE(review): records with some other label code (e.g. unlabeled ones)
        # would be counted as negatives — confirm this only runs on labeled data.
        true_label = class_names[0] if item["Label"] == 1 else class_names[1]
        predicted_label = classify_text(get_prompt, abstract)
        result = {
            "Abstract": abstract,
            "True Label": true_label.strip(),
            "Predicted Label": predicted_label.strip(),
        }
        results.append(result)

        if verbose:
            print(f"Abstract: {result['Abstract'][:100]}...")
            print(f"True Label: {result['True Label']}")
            print(f"Predicted Label: {result['Predicted Label']}\n\n")

    if not results:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("calculate_accuracy: no items to evaluate")

    correct_predictions = sum(
        result["True Label"] == result["Predicted Label"] for result in results
    )
    return correct_predictions / len(results)

In [14]:
def get_prompt(abstract_note: str, class_names: list[str]) -> str:
    """Build a zero-shot question asking which of the first two labels fits the abstract.

    Only ``class_names[0]`` and ``class_names[1]`` are used; any further entries
    are ignored.
    """
    prompt = f"""Abstract: {abstract_note}\n\n Is this abstract {class_names[0]} or {class_names[1]}? """
    return prompt


In [15]:
# Evaluate the prompt over `sample`; the bare name on the last line displays the score.
accuracy = calculate_accuracy(get_prompt)
accuracy

Abstract: How can we design good goals for arbitrarily intelligent agents? Reinforcement learning (RL) is a na...
True Label: TAI safety research
Predicted Label: unknown


Abstract: Consider an AI that follows its own motivations. We’re not entirely sure what its motivations are, b...
True Label: TAI safety research
Predicted Label: unknown


Abstract: Inattentional blindness is the psychological phenomenon that causes one to miss things in plain sigh...
True Label: not TAI safety research
Predicted Label: unknown


Abstract: ...
True Label: TAI safety research
Predicted Label: not TAI safety research


Abstract: Procrastination takes a considerable toll on people’s lives, the economy and society at large. Procr...
True Label: not TAI safety research
Predicted Label: unknown


Abstract: I think of ambitious value learning as a proposed solution to the specification problem, which I def...
True Label: TAI safety research
Predicted Label: unknown


Abstract: Specification gaming is a be

0.24