In [1]:
import numpy as np
import pandas as pd

import torch
from torch import nn, optim
import torch.nn.functional as F

In [3]:
from transformers import RobertaTokenizer

# Load the RoBERTa tokenizer from "models/" and set its maximum sequence
# length to 512.
tokenizer = RobertaTokenizer.from_pretrained(
    "models/",
    model_max_length=512,
)

Downloading: 100%|██████████| 899k/899k [00:01<00:00, 864kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 541kB/s]


In [30]:
from transformers import RobertaConfig, RobertaForMaskedLM

# Configuration for a reduced-depth RoBERTa (6 hidden layers, 12 attention
# heads) that is trained from scratch as a masked language model.
config = RobertaConfig(
    # Size the embedding table from the tokenizer that is actually in use
    # instead of a hard-coded 52_000: a tokenizer with more entries than the
    # model's vocabulary would yield out-of-range token ids, and a smaller one
    # would leave unused embedding rows.
    vocab_size=len(tokenizer),
    # 514 = the tokenizer's 512-token limit plus 2 extra position slots.
    # NOTE(review): RoBERTa offsets position ids past the padding index, which
    # is why 2 extra slots are presumably needed — confirm against the docs.
    max_position_embeddings=514,
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1,
)

# Build the model from the config alone; no pretrained checkpoint is loaded.
model = RobertaForMaskedLM(config)

In [17]:
# Display the model's parameter count.
# NOTE(review): this cell's execution count (17) is lower than that of the
# cell that builds `model` (30), so the saved output may describe an earlier
# configuration — re-run after Restart & Run All to confirm.
model.num_parameters()

24591664

In [19]:
from transformers import LineByLineTextDataset

# Training corpus: a single Project Gutenberg text file, tokenized with the
# tokenizer loaded earlier. block_size=128 is passed as the per-example size
# limit (presumably in tokens — confirm against the dataset class docs).
dataset = LineByLineTextDataset(
    file_path="data/Gutenberg/txt/Jane Austen___Pride and Prejudice.txt",
    tokenizer=tokenizer,
    block_size=128,
)

In [24]:
# Sanity check: decode stored example 12 back to text to inspect how the
# corpus was tokenized.
tokenizer.decode(dataset.examples[12])

'<s> "But it is," returned she; "for Mrs. Long has just been here, and she</s>'

In [25]:
from transformers import DataCollatorForLanguageModeling

# Batch collator for the masked-language-modelling objective (mlm=True),
# configured with a masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [31]:
from transformers import Trainer, TrainingArguments

# Run settings: one epoch, 64 examples per device per step, output written
# under models/lm (overwriting any previous contents).
# NOTE(review): the captured run shows only 168 iterations per epoch, which is
# far below save_steps=10_000 — confirm whether any periodic checkpoint is
# expected from this run.
training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=64,
    output_dir="models/lm",
    overwrite_output_dir=True,
    save_total_limit=2,
    save_steps=10_000,
)

# Wire the model, run settings, dataset and MLM collator into a Trainer.
# prediction_loss_only=True is forwarded to the Trainer unchanged.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
    prediction_loss_only=True,
)

You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it.


In [33]:
# Start training with the Trainer built in the previous cell.
# NOTE(review): the captured output ends in a KeyboardInterrupt after 10 of
# 168 iterations, so the saved notebook does not show a completed run.
trainer.train()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|          | 0/168 [00:00<?, ?it/s][A
Iteration:   1%|          | 1/168 [00:03<10:07,  3.64s/it][A
Iteration:   1%|          | 2/168 [00:07<10:09,  3.67s/it][A
Iteration:   2%|▏         | 3/168 [00:10<09:57,  3.62s/it][A
Iteration:   2%|▏         | 4/168 [00:14<09:55,  3.63s/it][A
Iteration:   3%|▎         | 5/168 [00:17<09:31,  3.50s/it][A
Iteration:   4%|▎         | 6/168 [00:21<09:22,  3.47s/it][A
Iteration:   4%|▍         | 7/168 [00:24<09:29,  3.54s/it][A
Iteration:   5%|▍         | 8/168 [00:28<09:21,  3.51s/it][A
Iteration:   5%|▌         | 9/168 [00:31<09:18,  3.51s/it][A
Iteration:   6%|▌         | 10/168 [00:38<10:03,  3.82s/it]
Epoch:   0%|          | 0/1 [00:38<?, ?it/s]


KeyboardInterrupt: 