In [None]:
import getpass
import os

import pandas as pd
from inspiredco.critique import Critique
from sentence_transformers import SentenceTransformer

from zeno import ZenoOptions, distill, metric, model, zeno

In [None]:
# Provide the Critique API key through the INSPIREDCO_API_KEY environment
# variable. If it is not already set (or is empty), prompt for it instead of
# writing it into the notebook: this keeps the secret out of the saved file
# and avoids overwriting an exported key with an empty value.
if not os.environ.get("INSPIREDCO_API_KEY"):
    os.environ["INSPIREDCO_API_KEY"] = getpass.getpass("INSPIREDCO_API_KEY: ")

In [None]:
# quoting=3 (csv.QUOTE_NONE) makes the parser treat double quotes as ordinary
# characters; keep_default_na=False keeps strings such as "NA" as text
# instead of converting them to NaN.
df = pd.read_csv(
    "wmt20-de-en.tsv",
    sep="\t",
    quoting=3,
    keep_default_na=False,
)

# Only the first 1000 rows are used in the rest of the notebook.
df_small = df.head(n=1000)

In [None]:
# Show the first five rows of the subset as a sanity check.
df_small.head(n=5)

In [None]:
# Critique client used by bert_score below. The API key is read from the
# INSPIREDCO_API_KEY environment variable; a KeyError is raised here if the
# variable is not set at all.
client = Critique(api_key=os.environ["INSPIREDCO_API_KEY"])

In [None]:
@model
def pred_fns(name):
    """Build the prediction function for a model.

    ``name`` is accepted but not used: the same predictor is returned for
    every model name. The sentence encoder is loaded once per call to
    ``pred_fns`` and shared by every invocation of the returned function.

    The returned function takes ``(df, ops)`` and returns a pair:
    the ``translation`` column of ``df`` as the model output, and a list of
    sentence embeddings (one list of floats per row).
    """
    encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2')

    def pred(df, ops):
        # NOTE(review): the embeddings are computed from the label column,
        # not from the returned translations — confirm this is intended.
        label_texts = df[ops.label_column].tolist()
        embeddings = encoder.encode(label_texts).tolist()
        return df["translation"], embeddings

    return pred


@distill
def bert_score(df, ops):
    """Score every row's model output against its reference with BERTScore.

    Sends one record per row to the Critique ``bert_score`` metric (model
    ``bert-base-uncased``) and returns the per-example values.

    Args:
        df: DataFrame containing a ``source`` column, the model output column
            (``ops.output_column``) and the label column (``ops.label_column``).
        ops: Options object; ``output_column`` and ``label_column`` are read.

    Returns:
        A list with one score per entry of ``result["examples"]``, each
        rounded to 6 decimal places.
    """
    # Use the configured label column rather than a hard-coded "label", so
    # this stays consistent with pred_fns and with the Zeno configuration.
    label_col = ops.label_column
    output_col = ops.output_column

    # Build fresh records instead of renaming keys in place; this avoids
    # clobbering a key when a column happens to be called "target" or
    # "references". Key order matches the original payload.
    rows = df[["source", output_col, label_col]].to_dict("records")
    dataset = [
        {
            "source": row["source"],
            "references": [row[label_col]],  # single reference per example
            "target": row[output_col],
        }
        for row in rows
    ]

    result = client.evaluate(
        metric="bert_score", config={"model": "bert-base-uncased"}, dataset=dataset
    )

    # presumably Critique returns examples in the same order as the input
    # dataset — the scores are attached to rows positionally.
    return [round(example["value"], 6) for example in result["examples"]]


@metric
def avg_bert_score(df, ops: ZenoOptions):
    """Return the mean of the distilled ``bert_score`` column over ``df``.

    The column name is looked up through ``ops.distill_columns`` rather than
    assumed.
    """
    score_column = ops.distill_columns["bert_score"]
    scores = df[score_column]
    return scores.mean()


@distill
def length(df, ops):
    """Return the length of each entry in the data column (``ops.data_column``)."""
    texts = df[ops.data_column]
    return texts.str.len()


In [None]:
# Launch Zeno on the 1000-row subset with the functions defined above.
zeno(
    {
        "view": "text-classification",
        "functions": [pred_fns, bert_score, avg_bert_score, length],
        "metadata": df_small,
        # pred_fns ignores the model name, so this is only a display label.
        "models": ["human-with-embeddings"],
        "data_column": "text",
        "label_column": "label",
        "batch_size": 500,
        "port": 8121,
    }
)