In [None]:
from pathlib import Path

# Input table, loaded with pandas.read_json; experiment results are written
# into the same directory.
DATASET = Path("..") / "data" / "masked.json"
# Model identifier handed to transformers.pipeline as both model and tokenizer.
MODEL = "ufal/robeczech-base"

In [None]:
import pandas as pd
from datasets import Dataset

df = pd.read_json(DATASET)
dataset = Dataset.from_pandas(df)

# Only a cell's last expression is rendered, so the preview must come last
# (in the middle of the cell `df.head()` was a silent no-op).
df.head()

In [None]:
from transformers import pipeline

import torch

# pipeline(device=...) takes a CUDA device ordinal; -1 selects the CPU.
# Falling back to the CPU keeps the notebook runnable without a GPU.
DEVICE = 0 if torch.cuda.is_available() else -1

corrector = pipeline("fill-mask", model=MODEL, tokenizer=MODEL, device=DEVICE, top_k=3)

### Example usage

In [None]:
masked_sentences = dataset["masked"]
example = masked_sentences[0]

example_predictions = corrector(example)
# FillMaskPipeline returns a single input's candidate list directly (not
# wrapped in a one-element list) for a plain string or a one-element list.
# Re-wrap so every loop item is the candidate list of one masked sentence.
if isinstance(example, str) or len(example) == 1:
    example_predictions = [example_predictions]

for predictions in example_predictions:
    print(predictions[:2])

In [None]:
# Number of pieces `create_dataset` splits the dataset into; each piece is
# run through the fill-mask pipeline and post-processed on its own.
SHARDS: int = 2

In [None]:
from tqdm.auto import tqdm
from typing import Callable, Iterator

def explode_masked_and_fix(df: pd.DataFrame) -> pd.DataFrame:
    """Return one row per masked variant, with an empty "fix" column.

    The list-valued "masked" column is exploded so each masked sentence gets
    its own row; the other columns are repeated. A "fix" column filled with
    None is added (overwriting any existing one) and the index is reset to
    0..n-1. The input frame is not modified.
    """
    return (
        df.explode("masked")
        # An empty (or missing) `masked` list explodes to a NaN row, which the
        # fill-mask pipeline cannot process, so drop such rows here.
        .dropna(subset=["masked"])
        .assign(fix=None)
        .reset_index(drop=True)
    )

def _merge_masked(row):
    sentence = row["error"].split(' ')
    for i, masked in enumerate(row["masked"]):
        try:
            sentence[masked.split(" ").index("[MASK]")] = "[MASK]"
        except ValueError:
            del row["fix"][i]

    return row["sentence"], " ".join(sentence), row["fix"]

def implode_and_merge_masked(df: pd.DataFrame) -> pd.DataFrame:
    """Regroup exploded rows into one row per (sentence, error) pair.

    The "masked" and "fix" values of each group are gathered into lists, then
    `_merge_masked` rewrites each row: "masked" becomes the error text with
    every masked position replaced by "[MASK]", and "fix" keeps only the fixes
    of variants that had a standalone "[MASK]" token.
    """
    grouped = (
        df.groupby(["sentence", "error"])
        .agg({"masked": list, "fix": list})
        .reset_index()
    )
    if grouped.empty:
        # Nothing to merge; unpacking an empty apply() result below would fail.
        return grouped

    merged = grouped.apply(_merge_masked, axis=1)
    grouped["sentence"], grouped["masked"], grouped["fix"] = zip(*merged)
    return grouped

def create_dataset(
    fnc: Callable[[Iterator[dict[str, str]]], Iterator[str]],
    batch_size: int = 32,
) -> pd.DataFrame:
    """Run `corrector` over `dataset` shard by shard and collect the fixes.

    For each of SHARDS shards: explode the masked variants into one row each,
    run the fill-mask pipeline on them, let `fnc` turn the raw predictions into
    one fix per row, then regroup rows per (sentence, error).

    Args:
        fnc: receives a list of records {"error", "masked", "predictions"},
            one per exploded row and in row order ("predictions" is the
            pipeline output for that row's masked sentence). It must yield
            exactly one fix per record, in the same order.
        batch_size: batch size passed to the pipeline.

    Returns:
        The per-shard results concatenated with a fresh index (an empty
        DataFrame when no shard has any masked sentence).

    Raises:
        ValueError: if `fnc` yields a different number of fixes than records.
    """
    shard_frames = []

    for shard_index in tqdm(range(SHARDS)):
        shard = dataset.shard(num_shards=SHARDS, index=shard_index)
        exploded = explode_masked_and_fix(shard.to_pandas())
        if exploded.empty:
            continue

        masked = exploded["masked"].to_list()
        predictions = corrector(masked, batch_size=batch_size)
        if len(masked) == 1:
            # FillMaskPipeline unwraps a one-element list input and returns
            # that element's candidates directly; re-wrap to keep one entry
            # per row.
            predictions = [predictions]

        records = [
            {"error": error, "masked": masked_sentence, "predictions": prediction}
            for error, masked_sentence, prediction in zip(
                exploded["error"], masked, predictions
            )
        ]

        fixes = list(fnc(records))
        if len(fixes) != len(exploded):
            raise ValueError(
                f"fnc yielded {len(fixes)} fixes for {len(exploded)} masked sentences"
            )
        # dtype=object keeps arbitrary fix values (even lists) one per row.
        exploded["fix"] = pd.Series(fixes, index=exploded.index, dtype=object)

        shard_frames.append(implode_and_merge_masked(exploded))

    # Concatenate once instead of growing a frame inside the loop.
    if not shard_frames:
        return pd.DataFrame()
    return pd.concat(shard_frames, ignore_index=True)
    

### Experiment 1
We replace each `[MASK]` with the model's highest-scoring suggestion (the first of the returned `top_k` candidates).

In [None]:
def process_prediction(data: Iterator[dict[str, str]]) -> Iterator[str]:
    """Lazily yield the best candidate token for every record.

    Each record's "predictions" entry is a sequence of candidate dicts. Only
    the first candidate is used, and its "token_str" is returned with
    surrounding whitespace stripped.
    """
    yield from (
        record["predictions"][0]["token_str"].strip()
        for record in data
    )

experiment_1_path = DATASET.parent / "result-experiment-1.json"

result = create_dataset(process_prediction)
# Written as JSON Lines (one record per line) despite the .json suffix.
result.to_json(experiment_1_path, orient="records", lines=True)