In [1]:
from pathlib import Path
import datasets
from datasets import load_dataset
import numpy as np
import pandas as pd
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Output locations sit three directories above the kernel's working directory.
test_output_dir = Path.cwd().parent.parent.parent / "test_outputs"
predictions_dir = Path.cwd().parent.parent.parent / "predictions"

In [26]:
# Per-model prediction folders; each holds one sub-directory per RAFT config.
tfew_predictions_dir = predictions_dir / "tfew"
setfit_predictions_dir = predictions_dir / "setfit"

# RAFT configs whose predictions are combined below.
config_names = [
    "ade_corpus_v2",
]

In [29]:
# Load each requested RAFT config once, keyed by config name; the values are
# indexed by split name ("train" / "test") elsewhere in the notebook.
raft_datasets = dict(
    (config, datasets.load_dataset("ought/raft", config))
    for config in config_names
)

In [7]:
from raft_baselines.classifiers import ChatGPTClassifier
import csv
import os
import shutil

In [22]:
def write_predictions(labeled, config, output_root=None):
    """Write combined predictions for one RAFT config as an ``ID,Label`` CSV.

    Any existing ``<output_root>/<config>`` directory is deleted and recreated,
    so the resulting ``predictions.csv`` contains exactly the rows in ``labeled``.

    Args:
        labeled: Iterable of mappings, each with ``"ID"`` and ``"Label"`` keys;
            rows are written in iteration order.
        config: RAFT config name, used as the output sub-directory name.
        output_root: Directory under which the ``<config>`` folder is created.
            Defaults to the notebook-level ``predictions_dir`` joined with
            ``"combined"``.
    """
    if output_root is None:
        output_root = os.path.join(predictions_dir, "combined")
    config_dir = os.path.join(output_root, config)

    # Start from an empty directory so stale files from earlier runs don't linger.
    if os.path.isdir(config_dir):
        shutil.rmtree(config_dir)
    # makedirs (not mkdir) so a missing "combined" parent on a fresh checkout is created.
    os.makedirs(config_dir)

    pred_file = os.path.join(config_dir, "predictions.csv")
    # newline="" is required by the csv module; explicit UTF-8 avoids locale-dependent output.
    with open(pred_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(
            f,
            quotechar='"',
            delimiter=",",
            quoting=csv.QUOTE_MINIMAL,
        )
        writer.writerow(["ID", "Label"])
        writer.writerows([row["ID"], row["Label"]] for row in labeled)

In [32]:
# Combine T-Few and SetFit test-set predictions: where the two models agree the
# shared label is kept; where they disagree, a GPT-4 few-shot classifier decides.
# The counters accumulate across all configs.
disagreements = 0
total = 0
for config in config_names:
    # NOTE(review): one_stop_english uses fewer in-prompt examples — presumably
    # to keep its prompts within budget; confirm.
    num_prompt_training_examples = 10 if config == "one_stop_english" else 20

    classifier = ChatGPTClassifier(
        raft_datasets[config]["train"],
        model="gpt-4-1106-preview",
        config=config,
        num_prompt_training_examples=num_prompt_training_examples,
    )
    test_split = raft_datasets[config]["test"]
    tfew_predictions = pd.read_csv(tfew_predictions_dir.joinpath(config, "predictions.csv"))
    setfit_predictions = pd.read_csv(setfit_predictions_dir.joinpath(config, "predictions.csv"))

    # Rows are matched by position (against each other and against the test
    # split), so fail loudly instead of silently mislabeling misaligned files.
    if not tfew_predictions["ID"].equals(setfit_predictions["ID"]):
        raise ValueError(f"{config}: T-Few and SetFit predictions list different IDs or ID order")
    if len(tfew_predictions) != len(test_split):
        raise ValueError(
            f"{config}: {len(tfew_predictions)} predictions but {len(test_split)} test examples"
        )

    # One boolean per row: True where both models predicted the same label.
    agrees = (tfew_predictions["Label"] == setfit_predictions["Label"]).to_numpy()

    all_labeled = []
    rows = zip(tfew_predictions["ID"], tfew_predictions["Label"])
    for index, (pred_id, tfew_label) in enumerate(rows):
        if agrees[index]:
            label = tfew_label
        else:
            disagreements += 1
            example = test_split[index]
            if example["ID"] != pred_id:
                raise ValueError(
                    f"{config}: row {index} has ID {pred_id} but test example has ID {example['ID']}"
                )
            output_probs = classifier.classify(example)
            # Pick the label with the highest score (first one wins on ties).
            label = max(output_probs, key=output_probs.get)
        all_labeled.append({"ID": pred_id, "Label": label})
        total += 1

    write_predictions(all_labeled, config)

print(f'Disagreements: {disagreements}')
print(f'Total: {total}')

prompt_tokens: 1155, completion_tokens: 1
prompt_tokens: 1091, completion_tokens: 1
prompt_tokens: 1045, completion_tokens: 1
prompt_tokens: 1097, completion_tokens: 1
prompt_tokens: 1083, completion_tokens: 1
prompt_tokens: 1194, completion_tokens: 1
prompt_tokens: 1092, completion_tokens: 1
prompt_tokens: 1106, completion_tokens: 1
prompt_tokens: 1140, completion_tokens: 1
prompt_tokens: 1042, completion_tokens: 1
prompt_tokens: 1121, completion_tokens: 1
prompt_tokens: 1198, completion_tokens: 1
prompt_tokens: 1027, completion_tokens: 1
prompt_tokens: 1227, completion_tokens: 1
prompt_tokens: 1123, completion_tokens: 1
prompt_tokens: 1141, completion_tokens: 1
prompt_tokens: 1172, completion_tokens: 1
prompt_tokens: 1096, completion_tokens: 1
prompt_tokens: 1113, completion_tokens: 1
prompt_tokens: 1169, completion_tokens: 1
prompt_tokens: 1150, completion_tokens: 1
prompt_tokens: 1174, completion_tokens: 1
prompt_tokens: 1134, completion_tokens: 1
prompt_tokens: 1090, completion_to