## Preparing a dataset for span-masked language modelling (span MLM)

Adds noise to a given dataset by masking random spans of tokens. The resulting dataset can be used for span-masked language modelling (span MLM) with the notebook [mt5_smlm_train.ipynb](mt5_smlm_train.ipynb).

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from mt5_smlm_scripts import DataCollatorForT5MLM, compute_input_and_target_lengths, create_smlm_example

In [None]:
from datasets import load_from_disk

# Load the shuffled copy of the German corpus from disk.
# Use "german_ds" instead for the unshuffled original.
dataset_dir = "german_ds_shuffled"
ds = load_from_disk(dataset_dir)

In [None]:
# Show the loaded dataset (its splits, columns and row counts) as the cell output.
ds

Load the previously trained German SentencePiece tokenizer, together with the mT5-small model whose config provides the special-token ids used below.

In [None]:
# Custom German SentencePiece tokenizer (trained earlier) and the mT5-small checkpoint.
# NOTE(review): in this notebook section the model is only read for
# config.pad_token_id / config.decoder_start_token_id; AutoConfig would avoid
# loading the weights — confirm nothing else needs the full model.
tokenizer_id = "german_tokenizer"
model_id = "google/mT5-small"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_id)
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_id)

In [None]:
context_length = 128


def tokenize(input):
    """Tokenize a batch of texts into windows of exactly ``context_length`` tokens.

    Intended for ``Dataset.map(tokenize, batched=True)``. The tokenizer splits
    each text in ``input['text']`` into consecutive windows of at most
    ``context_length`` tokens (``return_overflowing_tokens=True``). Only windows
    whose reported length equals ``context_length`` are kept. Shorter windows are
    discarded: typically the tail of each text, or a whole text shorter than
    ``context_length``.

    Returns:
        ``{"input_ids": [...]}`` with one list of token ids per kept window.
    """
    encoded = tokenizer(
        input['text'],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_windows = [
        window_ids
        for window_ids, n_tokens in zip(encoded["input_ids"], encoded["length"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_windows}


# Tokenize every split. The original columns (e.g. 'text') are removed, so only
# the "input_ids" column produced by tokenize() remains.
columns_to_drop = ds["train"].column_names
tokenized_dataset = ds.map(tokenize, batched=True, remove_columns=columns_to_drop)

In [None]:
from itertools import chain
max_seq_length = context_length
# Span-corruption noise settings. They control how much of each example is masked:
# the fraction of tokens masked (noise density) and the mean masked-span length.
mean_noise_span_length = 3.0
mlm_probability = 0.15

# ===============================================================
# From
# https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
# Derive the raw chunk length (expanded_inputs_length) and the target length.
# Presumably, as in the HF script, expanded_inputs_length is chosen so that an
# input of that size shrinks to max_seq_length tokens after span corruption.
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
    mean_noise_span_length=mean_noise_span_length,
    noise_density=mlm_probability,
    inputs_length=max_seq_length,
)


def group_texts(examples, chunk_length=None):
    """Concatenate all rows of a batch and re-split them into equal-size chunks.

    Meant for ``Dataset.map(group_texts, batched=True)``. ``examples`` maps each
    column name to a list of token-id lists. Every column is flattened end to end
    and cut into consecutive chunks of ``chunk_length`` tokens.

    Args:
        examples: Batch dict of column name -> list of sequences.
        chunk_length: Tokens per output row. Defaults to the module-level
            ``expanded_inputs_length`` computed above.

    Returns:
        Dict with the same keys, each mapping to a list of chunks of exactly
        ``chunk_length`` tokens. The list may be empty.

    Raises:
        ValueError: If ``chunk_length`` is not positive.
    """
    if chunk_length is None:
        chunk_length = expanded_inputs_length
    if chunk_length <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length}")

    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # All columns are assumed to be aligned, so the first one gives the length.
    # An empty batch dict yields 0 instead of raising IndexError.
    total_length = len(next(iter(concatenated.values()), []))
    # Always drop the remainder that does not fill a whole chunk. A batch with
    # fewer than chunk_length tokens therefore yields zero rows rather than one
    # short row. The collator is configured for a fixed input length, so a short
    # row would presumably break it. Padding is the alternative if supported.
    total_length = (total_length // chunk_length) * chunk_length
    return {
        k: [t[i : i + chunk_length] for i in range(0, total_length, chunk_length)]
        for k, t in concatenated.items()
    }

# Pack the fixed-size token windows into rows of expanded_inputs_length tokens.
# For a quick experiment, a subset such as ds["train"].select(range(100000)) can be
# tokenized and grouped instead. Extra Dataset.map options (num_proc,
# load_from_cache_file) can be passed here to parallelise or force recomputation.
tokenized_dataset_grouped = tokenized_dataset.map(group_texts, batched=True)
# ===============================================================

In [None]:
# Collator settings are deliberately NOT redefined here. The collator must use
# exactly the max_seq_length / mlm_probability / mean_noise_span_length that were
# passed to compute_input_and_target_lengths() in the span-length cell, because
# expanded_inputs_length (the grouped chunk size) and targets_length were derived
# from them. Re-assigning literal copies here would let the two cells drift apart
# silently, so verify instead that the values in scope still agree.
if tuple(
    compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=mlm_probability,
        mean_noise_span_length=mean_noise_span_length,
    )
) != (expanded_inputs_length, targets_length):
    raise ValueError(
        "max_seq_length / mlm_probability / mean_noise_span_length no longer match "
        "the values used to compute expanded_inputs_length and targets_length; "
        "re-run the span-length cell above."
    )

In [None]:
# Span-corruption collator from mt5_smlm_scripts. input_length/target_length must
# agree with the lengths computed by compute_input_and_target_lengths() above.
# NOTE(review): the pad / decoder-start ids come from the mT5 model config, while
# tokenization uses the custom German tokenizer; confirm that
# tokenizer.pad_token_id matches model.config.pad_token_id.
collator_kwargs = dict(
    tokenizer=tokenizer,
    input_length=max_seq_length,
    target_length=targets_length,
    noise_density=mlm_probability,
    mean_noise_span_length=mean_noise_span_length,
    pad_token_id=model.config.pad_token_id,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
collator = DataCollatorForT5MLM(**collator_kwargs)


def _apply_span_noise(batch):
    """Turn one batch of grouped token chunks into span-MLM examples using ``collator``."""
    return create_smlm_example(batch, collator)


smlm_dataset = tokenized_dataset_grouped.map(_apply_span_noise, batched=True)

In [None]:
# Write the noised dataset to disk so it can later be reloaded with load_from_disk().
save_path: str = "german_ds_smlm_noised"
smlm_dataset.save_to_disk(save_path)