In [None]:
from pathlib import Path

# Input corpus: a JSON file whose records provide "sentence" and "error" fields.
# Output (masked.json) is written into the same directory.
DATASET = Path("..") / "data" / "dataset-0.5.json"

# Local checkpoint of the fine-tuned error-detection model; its tokenizer is loaded from here too.
MODEL = "./model/roberta-error-detection"

In [None]:
from datasets import load_dataset

# Load the JSON corpus with the generic "json" builder; a single data file is exposed
# as a DatasetDict with one "train" split.
dataset = load_dataset(path="json", data_files=str(DATASET))

def parse_dataset(example):
    """Join the "sentence" and "error" fields of one record into space-separated strings.

    Intended for ``Dataset.map(..., batched=False)``: receives a single example and
    returns only the two rewritten fields (other columns are left untouched by ``map``).
    NOTE(review): assumes both fields are token lists; a plain string would be
    split into space-separated characters — confirm against the dataset schema.
    """
    return {field: " ".join(example[field]) for field in ("sentence", "error")}

# Number of shards the inference loop splits the working sample into. Each shard is
# sent to the pipeline separately, so it mainly controls progress-bar granularity.
SHARDS = 100

# Size of the working sample taken from the start of the training split.
# Set to None to process the whole split.
N_SAMPLES = 1000

train_split = dataset["train"]
if N_SAMPLES is not None:
    # Select before mapping so parse_dataset only runs on rows that are kept.
    # min() avoids an IndexError when the split has fewer than N_SAMPLES rows.
    train_split = train_split.select(range(min(N_SAMPLES, len(train_split))))

dataset = train_split.map(parse_dataset, batched=False, num_proc=4)

In [None]:
from transformers import pipeline, AutoTokenizer

import torch

tokenizer = AutoTokenizer.from_pretrained(MODEL, model_max_length=512)

# Use the first GPU when one is available; otherwise fall back to CPU (-1) so the
# notebook still runs end-to-end on machines without CUDA.
device = 0 if torch.cuda.is_available() else -1

# "simple" aggregation merges sub-word tokens into entity groups carrying
# "entity_group", "start" and "end"; the masking step below slices the input
# string with those start/end offsets.
token_classifier = pipeline(
    "token-classification", model=MODEL, tokenizer=tokenizer, aggregation_strategy="simple", device=device
)

In [None]:
from tqdm.auto import tqdm
import json

def mask_error_spans(text: str, entities: list, label: str = "LABEL_1", mask_token: str = "[MASK]") -> list:
    """Return one copy of ``text`` per entity tagged ``label``, with that span replaced by ``mask_token``.

    ``entities`` are the aggregated predictions the token-classification pipeline
    produced for ``text``; only their "entity_group", "start" and "end" keys are read.
    Each returned sentence masks exactly one span, so k flagged spans give k sentences.
    NOTE(review): "LABEL_1" presumably means "erroneous token" in the model's
    id2label, and "[MASK]" is not RoBERTa's own mask token ("<mask>"); both are
    presumably what the downstream consumer expects — confirm.
    """
    return [
        text[: entity["start"]] + mask_token + text[entity["end"] :]
        for entity in entities
        if entity["entity_group"] == label
    ]


# One record per errored sentence that has at least one flagged span:
# {"sentence": clean text, "error": errored text, "masked": [masked variants]}.
output: list[dict] = []

for shard_index in tqdm(range(SHARDS), desc="Creating masked sentences."):
    shard = dataset.shard(num_shards=SHARDS, index=shard_index)

    # Materialise each column once per shard. Indexing shard["error"][j] inside the
    # loop would rebuild the whole column list on every iteration (quadratic).
    sentences = shard["sentence"]
    errors = shard["error"]
    if not errors:
        # Shards can be empty when the dataset has fewer rows than SHARDS.
        continue

    predictions = token_classifier(errors, batch_size=32)
    for sentence, errored_sentence, entities in zip(sentences, errors, predictions):
        masked_sentences = mask_error_spans(errored_sentence, entities)
        if not masked_sentences:
            continue

        output.append({
            "sentence": sentence,
            "error": errored_sentence,
            "masked": masked_sentences,
        })

(DATASET.parent / "masked.json").write_text(json.dumps(output, indent=2), encoding="utf-8")