In [15]:
import logging
import math
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from torch.utils.data.dataset import Dataset
import random
from typing import Optional

from transformers import (
    CONFIG_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoConfig,
    BertForPreTraining,
    BertTokenizerFast,
    DataCollatorForLanguageModeling,
    DataCollatorForNextSentencePrediction,
    HfArgumentParser,
#     TextDatasetForNextSentencePrediction,
    Trainer,
    TrainingArguments,
    set_seed,
) 

In [2]:
logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

In [None]:
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )

    block_size: int = field(
        default=-1,
        metadata={
            "help": "Optional input sequence length after tokenization."
            "The training dataset will be truncated in block of this size for training."
            "Default to the model max input length for single sentence inputs (take into account special tokens)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

In [8]:
import datasets
from datasets import load_dataset

In [27]:
train_wiki_dataset = load_dataset("wikipedia", "20200501.en", split="train[:80]")
valid_wiki_dataset = load_dataset("wikipedia", "20200501.en", split="train[80:100]")
# book_corpus = load_dataset("bookcorpus")

Reusing dataset wikipedia (/home/nlp/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/f92599dfccab29832c442b82870fa8f6983e5b4ebbf5e6e2dcbe894e325339cd)
Reusing dataset wikipedia (/home/nlp/.cache/huggingface/datasets/wikipedia/20200501.en/1.0.0/f92599dfccab29832c442b82870fa8f6983e5b4ebbf5e6e2dcbe894e325339cd)


In [28]:
train_bookcorpus_dataset = load_dataset("bookcorpus", split="train[:80]")
valid_bookcorpus_dataset = load_dataset("bookcorpus", split="train[80:100]")

Reusing dataset bookcorpus (/home/nlp/.cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/af844be26c089fb64810e9f2cd841954fd8bd596d6ddd26326e4c70e2b8c96fc)
Reusing dataset bookcorpus (/home/nlp/.cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/af844be26c089fb64810e9f2cd841954fd8bd596d6ddd26326e4c70e2b8c96fc)


In [29]:
train_wiki_dataset.remove_columns_('title')
valid_wiki_dataset.remove_columns_('title')

In [33]:
train_ds = datasets.concatenate_datasets([train_wiki_dataset, train_bookcorpus_dataset])
valid_ds = datasets.concatenate_datasets([valid_wiki_dataset, valid_bookcorpus_dataset])

In [34]:
tokenizer = BertTokenizerFast.from_pretrained("bert-large-uncased")

In [35]:
class PreTrainingDataset():
    def __init__(self, nlp_dataset, tokenizer):
        self.dataset = nlp_dataset
        self.tokenizer = tokenizer
        
    def __getitem__(self, idx):
        return self.tokenizer.encode(wiki_dataset[idx]['text'])

In [36]:
dataset = PreTrainingDataset(train_ds, tokenizer)

In [37]:
config = AutoConfig.from_pretrained('bert-large-uncased')

In [38]:
model = BertForPreTraining(config)

In [39]:
 data_collator = DataCollatorForNextSentencePrediction(
            tokenizer=tokenizer, mlm=True)

In [41]:
training_args = TrainingArguments(output_dir = '/home/nlp/experiments/pretrain')

In [43]:
trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        prediction_loss_only=True,
    )



In [44]:
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=10.0, style=ProgressStyle(description_wid…

KeyError: 'tokens_a'