In [101]:
from pathlib import Path

# Input JSON file, loaded further down with Dataset.from_json. The masked
# output is written to "masked.json" two directories above this file
# (DATASET.parent.parent).
DATASET = Path("../data/err-0.5/test.json")

# Checkpoint location handed to both AutoTokenizer.from_pretrained and
# pipeline(model=...), so the tokenizer and the classifier come from the same place.
MODEL = "./model/roberta-error-detection"

In [102]:
from datasets import Dataset
import math

# Load the input file into a Dataset; the Path is converted to a plain string
# before being handed to from_json.
dataset = Dataset.from_json(str(DATASET))

def parse_dataset(example):
    """Join the token lists of one example into space-separated strings.

    Args:
        example: Mapping that provides ``"sentence"`` and ``"error"``, each an
            iterable of strings.

    Returns:
        A new dict with exactly the keys ``"sentence"`` and ``"error"``, whose
        values are the corresponding items joined by a single space.
    """
    joined = {}
    for field in ("sentence", "error"):
        joined[field] = " ".join(example[field])
    return joined

# Upper bound on the number of rows in one shard. The masking loop processes
# the dataset one shard at a time, so this also sets the progress-bar granularity.
SHARD_SIZE = 1000

# Number of shards needed so that no shard exceeds SHARD_SIZE rows.
SHARDS = math.ceil(len(dataset) / SHARD_SIZE)

# Replace the token lists in "sentence" / "error" with space-joined strings,
# one example at a time, using 4 worker processes.
dataset = dataset.map(parse_dataset, batched=False, num_proc=4)

In [103]:
from transformers import pipeline, AutoTokenizer

# MODEL is defined once in the configuration cell at the top of the notebook.
# It is deliberately not re-assigned here, so the tokenizer and the classifier
# can never be loaded from two different checkpoints by accident.
tokenizer = AutoTokenizer.from_pretrained(MODEL, model_max_length=512)

# Token-classification pipeline whose results are consumed by the masking loop
# through the "entity_group", "start" and "end" fields of each returned item.
# NOTE(review): device=0 presumably pins inference to the first accelerator and
# would fail on a CPU-only machine - confirm before running elsewhere.
token_classifier = pipeline(
    "token-classification",
    model=MODEL,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    device=0,
)

In [104]:
# error_sentence = "Baal Baal Lev je krátký hudební televizní film z roku 1997 rmežiséra Eytna FoxeA podle scénáře který napsal Gal OhovskOhovski"
# a = token_classifier(error_sentence)

# print(a)

# masked_sentence = error_sentence.split()

# print(len(masked_sentence))

# _index_mask = [[x]*len(w) + [None] for x, w in enumerate(masked_sentence)]

# print(_index_mask, len(_index_mask))

# index_mask = []

# for x in _index_mask:
#     index_mask.extend(x)

# print(index_mask, len(index_mask), len(error_sentence))

# for l in a:  # noqa: E741
#     if l["entity_group"] == "LABEL_1":
#         e = set(index_mask[l["start"]:l["end"]])
#         e.remove(None)
#         print(e)
#         for x in e:
#             masked_sentence[x] = "[MASK]"

# print(masked_sentence)

In [109]:
from tqdm.auto import tqdm

def _mask_flagged_words(text: str, spans: list[dict]) -> str:
    """Replace every word of ``text`` touched by a ``LABEL_1`` span with ``[MASK]``.

    Args:
        text: Sentence whose words are separated by single spaces.
        spans: Classifier output for ``text``. Each item is read for its
            ``"entity_group"`` label and for ``"start"`` / ``"end"``, which are
            used as character offsets into ``text``.

    Returns:
        The words of ``text`` re-joined with single spaces, with each word that
        overlaps a ``LABEL_1`` span replaced by the literal ``[MASK]``.
    """
    words = text.split(" ")

    # char_to_word[c] is the index of the word that owns character c of `text`,
    # or None where c is the space following a word. Built from the same
    # split(" ") as `words`, so it lines up with `text` character by character.
    char_to_word = []
    for word_index, word in enumerate(words):
        char_to_word.extend([word_index] * len(word))
        char_to_word.append(None)

    for span in spans:
        if span["entity_group"] != "LABEL_1":
            continue
        # A span may cover several words (and the spaces between them); the
        # set collapses it to the distinct word indices it touches.
        for word_index in set(char_to_word[span["start"]:span["end"]]):
            if word_index is not None:
                words[word_index] = "[MASK]"

    return " ".join(words)


# One record per input row: the original sentence, the corrupted sentence and
# the corrupted sentence with the flagged words masked out.
output: list[dict[str, str]] = []

for i in tqdm(range(SHARDS), desc="Creating masked sentences."):
    # contiguous=True is passed explicitly because the default has differed
    # between `datasets` releases; with it, shards are consecutive slices and
    # the output keeps the row order of the input.
    shard = dataset.shard(num_shards=SHARDS, index=i, contiguous=True)

    # Read each column once per shard instead of once per row inside the loop.
    sentences = list(shard["sentence"])
    errors = list(shard["error"])

    predictions = token_classifier(errors, batch_size=32)

    for sentence, error, spans in zip(sentences, errors, predictions):
        output.append({
            "sentence": sentence,
            "error": error,
            "masked": _mask_flagged_words(error, spans),
        })

Dataset.from_list(output).to_json(DATASET.parent.parent / "masked.json")

Creating masked sentences.: 100%|██████████| 10/10 [01:24<00:00,  8.44s/it]
Creating json from Arrow format: 100%|██████████| 10/10 [00:00<00:00, 270.91ba/s]


4624247