In [1]:
import fastrepl

In [2]:
from IPython.display import clear_output

In [None]:
# API credentials for the two graders used below: OPENAI_API_KEY for the
# gpt-3.5 evaluator and DEEPINFRA_API_KEY for the deepinfra/Mistral evaluator.
# Keys already present in the environment are kept as-is; only missing ones are
# prompted for. `%env KEY=` with an empty value would instead overwrite an
# exported key with "" and encourages pasting secrets into the notebook, where
# they get saved and committed.
import os
from getpass import getpass

for _key_name in ("OPENAI_API_KEY", "DEEPINFRA_API_KEY"):
    if not os.environ.get(_key_name):
        os.environ[_key_name] = getpass(f"{_key_name}: ")

In [4]:
from fastrepl.utils import map_number_range
from datasets import load_dataset

# Evaluation subset: a reproducible random sample of the Yelp test split.
SEED = 12
N_SAMPLES = 100

# Label range as encoded by the original map_number_range call (0-4), rescaled
# onto the 0-10 scale the LLMGradingHead evaluators below grade on.
LABEL_MIN, LABEL_MAX = 0, 4
SCORE_MIN, SCORE_MAX = 0, 10
# NOTE(review): if map_number_range interpolates linearly, labels 1 and 3 map
# to 2.5 / 7.5, which an integer grade can never match exactly under the
# "accuracy" metric — confirm how map_number_range rounds.

hf_dataset = (
    load_dataset("yelp_review_full", split="test")
    .shuffle(seed=SEED)
    .select(range(N_SAMPLES))
    .rename_column("text", "sample")
    # `map` merges the returned dict into each row, so only the new
    # "reference" column needs to be produced; "sample" is carried over.
    .map(
        lambda row: {
            "reference": map_number_range(
                row["label"], LABEL_MIN, LABEL_MAX, SCORE_MIN, SCORE_MAX
            )
        },
        remove_columns=["label"],
    )
)

dataset = fastrepl.Dataset.from_hf(hf_dataset)

In [5]:
# Sanity check: expect features ['sample', 'reference'] and num_rows == 100.
dataset

fastrepl.Dataset({
    features: ['sample', 'reference'],
    num_rows: 100
})

In [6]:
clear_output(wait=True)

# Grade each Yelp review on a 0-10 scale with gpt-3.5, anchoring the scale with
# two extreme reference examples.
# presumably "shuffle" randomizes reference order to reduce position bias —
# confirm in the fastrepl LLMGradingHead docs.
eval1 = fastrepl.SimpleEvaluator(
    node=fastrepl.LLMGradingHead(
        model="gpt-3.5-turbo-0613",
        context="You will get a input text from Yelp review.",
        number_from=0,
        number_to=10,
        position_debias_strategy="shuffle",
        references=[
            ("this is the best", "10"),
            ("this is the worst", "0"),
        ],
    )
)

# The runner writes grades into the "prediction" feature; every metric below is
# scored against that same column and the dataset's "reference" column, rather
# than passing the columns for accuracy only and relying on defaults elsewhere.
result1 = fastrepl.local_runner(
    evaluator=eval1,
    dataset=dataset,
    output_feature="prediction",
).run()

for metric in ("accuracy", "mse", "mae"):
    print(
        result1.compare(
            metric,
            prediction_column="prediction",
            reference_column="reference",
        )
    )

Output()

{'accuracy': 0.61}
{'mse': 2.625}
{'mae': 1.0}


In [7]:
clear_output(wait=True)

# Same grading configuration as eval1 (context, 0-10 scale, debias strategy,
# reference anchors); only the backing model differs, so the metric results of
# the two runs are directly comparable.
eval2 = fastrepl.SimpleEvaluator(
    node=fastrepl.LLMGradingHead(
        model="deepinfra/mistralai/Mistral-7B-Instruct-v0.1",
        context="You will get a input text from Yelp review.",
        number_from=0,
        number_to=10,
        position_debias_strategy="shuffle",
        references=[
            ("this is the best", "10"),
            ("this is the worst", "0"),
        ],
    )
)

# Grades land in the "prediction" feature; score every metric against that
# column and "reference" explicitly instead of relying on compare()'s defaults.
result2 = fastrepl.local_runner(
    evaluator=eval2,
    dataset=dataset,
    output_feature="prediction",
).run()

for metric in ("accuracy", "mse", "mae"):
    print(
        result2.compare(
            metric,
            prediction_column="prediction",
            reference_column="reference",
        )
    )

Output()

{'accuracy': 0.42}
{'mse': 4.1875}
{'mae': 1.525}
