In [None]:
# !pip install transformers

import csv
import io

import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from tqdm import tqdm
from transformers import pipeline

In [None]:
# Multilingual NLI checkpoint, wrapped as a zero-shot classifier. It is used below
# to score each (French) sentence against English entity-type labels.
ZERO_SHOT_MODEL = "joeddav/xlm-roberta-large-xnli"
classifier = pipeline(task="zero-shot-classification", model=ZERO_SHOT_MODEL)

In [None]:
# Load the HIPE-2022 test file: tab-separated, one token per line, 10 columns.
# Metadata lines start with "#" and contain no tab, so they are dropped up front.
# read_csv(comment="#") is not used because it would also cut any *data* line at its
# first "#" (a token that is a literal "#" would vanish together with its labels).
HIPE_TSV_PATH = "/content/HIPE-2022-v2.1-letemps-test-fr.tsv"
HIPE_COLUMNS = ["TOKEN", "NE-COARSE-LIT", "NE-COARSE-METO", "NE-FINE-LIT", "NE-FINE-METO",
                "NE-FINE-COMP", "NE-NESTED", "NEL-LIT", "NEL-METO", "MISC"]

with open(HIPE_TSV_PATH, encoding="utf-8") as hipe_file:
    data_lines = [line for line in hipe_file
                  if not (line.startswith("#") and "\t" not in line)]

# dtype=str + keep_default_na=False keep every field verbatim. Without them, tokens
# such as "NA", "null" or "nan" would be parsed as NaN and later joined as "nan".
df = pd.read_csv(io.StringIO("".join(data_lines)),
                 sep="\t",
                 quoting=csv.QUOTE_NONE,  # tokens may contain bare quote characters
                 names=HIPE_COLUMNS,
                 dtype=str,
                 keep_default_na=False)
# Because names= is given, the file's own header line is read as a data row; drop it.
df = df[df["TOKEN"] != "TOKEN"]

# The MISC column flags the *last* token of each sentence with "EndOfSentence".
# A plain cumsum would bump the id on that token and push it into the next sentence.
# Shifting the flag down one row first keeps each boundary token in its own sentence.
is_sentence_end = df["MISC"].fillna("_").str.contains("EndOfSentence", regex=False)
df["sentence_id"] = is_sentence_end.shift(1, fill_value=False).cumsum()

# One row per sentence: the space-joined tokens plus the list of coarse literal NE tags.
sentence_groups = df.groupby("sentence_id")
sentences_df = pd.DataFrame({
    "TOKEN": sentence_groups["TOKEN"].agg(lambda toks: " ".join(toks.astype(str))),
    "NE-COARSE-LIT": sentence_groups["NE-COARSE-LIT"].agg(list),
}).reset_index()

def has_entity(labels, prefix):
    """Return True if any tag in ``labels`` starts an or continues an entity of type ``prefix``.

    A tag matches when it begins with ``"B-<prefix>"`` or ``"I-<prefix>"``. Because this is
    a prefix match, subtyped tags such as ``"B-<prefix>.something"`` also count.

    Non-string entries (e.g. NaN produced by pandas for missing fields) are skipped.
    Previously they raised ``AttributeError``.

    Args:
        labels: iterable of tag values for one sentence.
        prefix: entity type to look for, e.g. ``"pers"``, ``"loc"``, ``"org"``.

    Returns:
        bool: whether at least one matching tag is present.
    """
    targets = (f"B-{prefix}", f"I-{prefix}")
    return any(isinstance(label, str) and label.startswith(targets) for label in labels)

# Gold sentence-level labels: a sentence is positive for a class if any of its tokens
# carries a coarse literal tag of the matching entity type.
ENTITY_TYPE_BY_LABEL = {"person": "pers", "location": "loc", "organization": "org"}
for gold_label, entity_type in ENTITY_TYPE_BY_LABEL.items():
    sentences_df[gold_label] = sentences_df["NE-COARSE-LIT"].apply(has_entity, args=(entity_type,))

# Optional: cap the number of sentences for a quick run (None = use every sentence).
MAX_SENTENCES = None
if MAX_SENTENCES is not None:
    sentences_df = sentences_df.iloc[:MAX_SENTENCES]

In [None]:
# Score every sentence against each candidate label independently (multi_label=True),
# so one sentence can be predicted to contain several entity types.
ner_labels = ["person", "location", "organization"]


def _score_sentence(sentence):
    """Return a ``{label: score}`` mapping for one sentence from the zero-shot classifier."""
    result = classifier(sentence, candidate_labels=ner_labels, multi_label=True)
    return {lab: score for lab, score in zip(result["labels"], result["scores"])}


zs_results = [_score_sentence(sentence) for sentence in tqdm(sentences_df["TOKEN"])]


In [None]:
# Turn per-label scores into binary predictions and compare them with the gold
# sentence-level labels.
PRED_THRESHOLD = 0.5  # a label is predicted when its score reaches this value

# Guard against stale kernel state: the scores must come from the current sentences_df.
# For example, after changing the subset size, the classification cell must be re-run.
assert len(zs_results) == len(sentences_df), "zero-shot results are out of sync with sentences_df; re-run the classification cell"

# The pipeline returns labels sorted by score, so pin the columns to ner_labels order.
zs_df = pd.DataFrame(zs_results).reindex(columns=ner_labels)
zs_binary = zs_df >= PRED_THRESHOLD

evaluation_df = pd.concat(
    [sentences_df[ner_labels].reset_index(drop=True), zs_binary],
    axis=1,
    keys=["gold", "pred"],
)
gold = evaluation_df["gold"]
pred = evaluation_df["pred"]

# Evaluation
for label in ner_labels:
    print(f"\n Evaluation for {label.upper()}")
    # zero_division=0: report 0 rather than warning when a class is never predicted/present.
    print(classification_report(gold[label], pred[label], digits=3, zero_division=0))
    print("Accuracy:", accuracy_score(gold[label], pred[label]))
