In [19]:
import multiprocessing as mp
from itertools import chain

from datasets import load_dataset
import evaluate
from transformers import LongformerTokenizerFast
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, LongformerForMaskedLM

# Setup

In [2]:
train_dataset = load_dataset("imdb", split="train")
eval_dataset = load_dataset("imdb", split="test")

Found cached dataset imdb (/Users/israelcampiotti/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/Users/israelcampiotti/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [3]:
tokenizer = LongformerTokenizerFast.from_pretrained("longformer-pt-tokenizer")

In [4]:
cpu_count = mp.cpu_count() - 1
cpu_count

7

# Preprocess dataset

In [5]:
max_seq_length = tokenizer.model_max_length
max_seq_length

4096

In [8]:
def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True)

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result

In [9]:
def preprocess_dataset(dataset):
    dataset = dataset.remove_columns([col for col in dataset.column_names if col != "text"])
    # Tokenize the texts
    tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=cpu_count)
    # Group the texts
    tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=cpu_count)
    return tokenized_datasets

In [10]:
train_dataset = preprocess_dataset(train_dataset)
eval_dataset = preprocess_dataset(eval_dataset)

Loading cached processed dataset at /Users/israelcampiotti/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-04a0302be087f3d3_*_of_00007.arrow
Loading cached processed dataset at /Users/israelcampiotti/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-cf1d9581fd7c4a93_*_of_00007.arrow
Map (num_proc=7):  32%|███▏      | 8000/25000 [00:01<00:03, 4752.23 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (5042 > 4096). Running this sequence through the model will result in indexing errors
Map (num_proc=7):  76%|███████▌  | 19000/25000 [00:02<00:00, 8471.21 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4754 > 4096). Running this sequence through the model will result in indexing errors
                                             

In [13]:
eval_dataset = eval_dataset.select(range(100))

# Metrics

In [17]:
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]
    return logits.argmax(dim=-1)

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # preds have the same shape as the labels, after the argmax(-1) has been calculated
    # by preprocess_logits_for_metrics
    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    mask = labels != -100
    labels = labels[mask]
    preds = preds[mask]
    return metric.compute(predictions=preds, references=labels)

metric = evaluate.load("accuracy")

# Data Collator

In [18]:
data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=0.15,
        pad_to_multiple_of=8,
    )

# Model

In [22]:
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

Downloading pytorch_model.bin: 100%|██████████| 597M/597M [00:23<00:00, 25.5MB/s] 


# Training Args

In [24]:
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    prediction_loss_only=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # fp16=True,
    # fp16_opt_level="O2",
    warmup_steps=50,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    dataloader_num_workers=cpu_count,
    run_name="longformer-pt",
    use_mps_device=True,
)

# Trainer

In [25]:
trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

In [26]:
trainer.train()

  0%|          | 0/8072 [00:00<?, ?it/s]You're using a LongformerTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a LongformerTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a LongformerTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a LongformerTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a LongformerTokeniz