In [2]:
import polars as pl

In [3]:
# Inputs: ground-truth training documents, and the prediction table written
# for experiment 011 (presumably validation-set predictions, given the
# directory name — confirm). Paths are relative to this notebook.
TRAIN_PATH = "../data/train.json"
VALID_PRED_PATH = "../valid_df/exp011.csv"

train = pl.read_json(TRAIN_PATH)
data = pl.read_csv(VALID_PRED_PATH)

In [4]:
# Cap printed polars tables at 10 rows (pass -1 instead to print every row).
# The trailing ';' suppresses the returned Config class repr in the cell output.
pl.Config.set_tbl_rows(10);

polars.config.Config

In [5]:
# The prediction columns encode missing values as the literal string "null".
# Turn those into real nulls, cast the document/token indices to Int64,
# drop any row that is not fully populated, and order by document.
pred_df = (
    data.select(
        *[
            pl.col(index_col).replace("null", None).cast(pl.Int64)
            for index_col in ["document_pred", "token_pred"]
        ],
        pl.col("label_pred").replace("null", None),
    )
    .drop_nulls()
    .sort("document_pred")
)

In [6]:
# train_only_valid_document = train.filter(
#     pl.col("document").is_in(
#         pred_df.get_column("document_pred").unique()
#     )
# )

In [7]:
# train_only_valid_document = train_only_valid_document.with_columns(
#     pl.col("tokens").map_elements(len).alias("tokens_len"),
# )

In [8]:
# Attach the number of tokens in each training document. `list.len()` runs
# natively on the list column instead of calling Python's len() row by row
# through map_elements (which is slow and infers its return dtype).
train_with_token_len = train.with_columns(
    pl.col("tokens").list.len().alias("tokens_len"),
)

In [9]:
# One row per predicted document: its predicted token indices and labels as
# lists, joined to that document's token count and ground-truth labels.
# maintain_order=True makes the group order follow pred_df's (sorted) order
# instead of being arbitrary between runs.
# NOTE(review): only documents with at least one non-null prediction survive
# pred_df's drop_nulls, so a validation document with no predicted entities
# never appears here and its true entities are never counted as misses —
# confirm every validation document is represented, otherwise the scores
# below may overstate recall.
pred_df_agg_with_len = (
    pred_df.group_by("document_pred", maintain_order=True)
    .agg(
        pl.col("token_pred"),
        pl.col("label_pred"),
    )
    .join(
        train_with_token_len.select(["document", "tokens_len", "labels"]),
        left_on="document_pred",
        right_on="document",
        how="left",
    )
)

# A predicted document id absent from `train` leaves tokens_len/labels null
# after the left join; fail here with a clear message rather than later.
assert pred_df_agg_with_len["tokens_len"].null_count() == 0, "predicted documents not found in train"

In [10]:
# Expand each document's sparse predictions (token index -> label) into a
# dense per-token label sequence aligned with the ground-truth `labels`,
# filling tokens that received no prediction with "O".
def expand_sparse_labels(token_idxs, labels, n_tokens, fill="O"):
    """Return a list of length ``n_tokens`` with ``labels`` placed at ``token_idxs``.

    Positions not listed in ``token_idxs`` get ``fill``. If the same token index
    appears more than once, the last label wins.

    Raises:
        ValueError: if ``n_tokens`` is None, or a token index lies outside
            ``[0, n_tokens)`` — plain list assignment would otherwise silently
            wrap negative indices onto tokens at the end of the document.
    """
    if n_tokens is None:
        raise ValueError("document has no token count (not found in train?)")
    dense = [fill] * n_tokens
    for idx, label in zip(token_idxs, labels):
        if not 0 <= idx < n_tokens:
            raise ValueError(f"token index {idx} out of range for {n_tokens} tokens")
        dense[idx] = label
    return dense


label_pred_alls = [
    expand_sparse_labels(token_pred, label_pred, tokens_len)
    for token_pred, label_pred, tokens_len in zip(
        pred_df_agg_with_len["token_pred"],
        pred_df_agg_with_len["label_pred"],
        pred_df_agg_with_len["tokens_len"],
    )
]

# Pair ground-truth and predicted label sequences, one row per document.
actual_pred_df = pred_df_agg_with_len.with_columns(
    pl.Series("label_pred_all", label_pred_alls)
).select(["labels", "label_pred_all"])

In [11]:
from seqeval.metrics import f1_score
from seqeval.metrics.sequence_labeling import precision_recall_fscore_support

In [12]:
# Entity-level F1 from seqeval (micro-averaged by default) over the
# per-document ground-truth vs. predicted label sequences.
y_true = actual_pred_df["labels"].to_list()
y_pred = actual_pred_df["label_pred_all"].to_list()
f1_score(y_true, y_pred)

0.874125874125874

In [17]:
# Micro-averaged F-beta with beta=5, which weights recall far more heavily
# than precision. precision_recall_fscore_support returns
# (precision, recall, fbeta, support); keep the F-beta value as `out`.
# NOTE(review): seqeval scores whole entity spans; if the target metric is
# computed per token instead, this number may differ — confirm.
precision, recall, out, support = precision_recall_fscore_support(
    actual_pred_df["labels"].to_list(),
    actual_pred_df["label_pred_all"].to_list(),
    beta=5,
    average="micro",
)

out

0.8884636413340623