In [73]:
import math
from itertools import chain

import torch
from datasets import load_dataset
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer

In [2]:
# Download (or reuse the local cache of) the IMDb reviews dataset; it comes
# with train / test / unsupervised splits.
imdb_dataset = load_dataset(path="imdb")
imdb_dataset

README.md: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [13]:
# Peek at a few random training reviews. A fixed seed makes the selection
# (and therefore this cell's output) identical on every Restart & Run All.
review_samples = imdb_dataset["train"].shuffle(seed=42).select(range(3))

for review in review_samples:
    print(f"Review : {review['text']}")
    print(f"Label : {review['label']}")

Review : Directed by the duo Yudai Yamaguchi (Battlefield Baseball) and Jun'ichi Yamamoto "Meatball Machine" is apparently a remake of Yamamoto's 1999 movie with the same name. I doubt I'll ever get a chance to see the original so I'll just stick commenting on this one. First of what is "Meatball Machine" ? A simple in noway pretentious low budget industrial splatter flick packed with great make up effects and gore. It's not something you'll end up writing books about but it's nevertheless entertaining if you dig this type of cinema.<br /><br />"Meatball Machine" follows the well known plot. Boy loves girl but is too afraid to ask her on a date. Boy finally meets girl. Girl gets infected by a parasitic alien creature that turns her into a homicidal cyborg. Boy, in turn does also transform into said thing, and goes on a quest to save his love. Will he succeed? Who gives a damn, as long as there is carnage and death I'm satisfied.<br /><br />The plot is simple, relatively clichéd but it 

### Processing the data

In [16]:
checkpoint = "distilbert-base-uncased"
# AutoTokenizer returns the fast tokenizer by default when one exists;
# tokenize_function needs it to compute word_ids.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [18]:
def tokenize_function(examples):
    """Tokenize a batch of reviews and attach word ids when possible.

    Args:
        examples: a batch dict from ``Dataset.map(batched=True)``; only the
            ``"text"`` column is read.

    Returns:
        The tokenizer's batch encoding. If the module-level ``tokenizer`` is a
        fast tokenizer, it also holds a ``"word_ids"`` column: one list per
        example mapping each token to its source word index.
    """
    encoded = tokenizer(examples["text"])
    # word_ids() is only available on fast tokenizers; slow ones get no column.
    if not tokenizer.is_fast:
        return encoded
    encoded["word_ids"] = [
        encoded.word_ids(batch_index)
        for batch_index, _ in enumerate(encoded["input_ids"])
    ]
    return encoded

# Tokenize every split. The raw text and sentiment label are not needed for
# masked-language modelling, so drop them.
tokenized_datasets = imdb_dataset.map(
    function=tokenize_function,
    batched=True,
    remove_columns=["text", "label"],
)

tokenized_datasets

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (720 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids'],
        num_rows: 50000
    })
})

Some reviews are longer than the model's maximum input length (512 tokens). We therefore concatenate all tokenized reviews and split the result into chunks of a fixed size.

In [19]:
CHUNK_SIZE: int = 128  # number of tokens in each training chunk

In [20]:
# Take the first three tokenized reviews (a column-oriented dict) and show how long each is.
tokenized_samples = tokenized_datasets["train"][:3]

review_lengths = [len(ids) for ids in tokenized_samples["input_ids"]]
for review_idx, length in enumerate(review_lengths):
    print(f"Review {review_idx} has length {length}")

Review 0 has length 363
Review 1 has length 304
Review 2 has length 133


In [28]:
# Flatten each column's per-review lists into one long list.
# chain.from_iterable is linear in the total length, whereas
# sum(lists, []) rebuilds the accumulator on every step (quadratic).
concatenated_samples = {
    k: list(chain.from_iterable(tokenized_samples[k])) for k in tokenized_samples.keys()
}
print(f"Concatenated length : {len(concatenated_samples['input_ids'])}")

Concatenated length : 800


In [34]:
# Split each concatenated column into CHUNK_SIZE-long slices; the final slice
# may be shorter.
chunks = {}
for key, tokens in concatenated_samples.items():
    chunks[key] = [
        tokens[start:start + CHUNK_SIZE]
        for start in range(0, len(tokens), CHUNK_SIZE)
    ]

for chunk_ids in chunks["input_ids"]:
    print(f"Chunk length : {len(chunk_ids)}")

Chunk length : 128
Chunk length : 128
Chunk length : 128
Chunk length : 128
Chunk length : 128
Chunk length : 128
Chunk length : 32


The last chunk will be shorter than `CHUNK_SIZE`, so we can either drop it or pad it. `group_texts` below drops it.

In [44]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized examples and re-split it into equal chunks.

    Intended for ``Dataset.map(group_texts, batched=True)``. Every column is
    flattened and cut into consecutive ``chunk_size``-long pieces. The trailing
    remainder that does not fill a whole chunk is dropped.

    Args:
        examples: batch dict mapping column name -> list of per-example lists.
            All columns are assumed to be token-aligned (equal total length).
        chunk_size: tokens per chunk; defaults to the module-level CHUNK_SIZE.

    Returns:
        Dict with the same columns, re-chunked, plus ``"labels"``: an
        independent copy of the ``"input_ids"`` chunks.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    size = CHUNK_SIZE if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk_size must be positive, got {size}")

    # chain.from_iterable is linear; sum(lists, []) was quadratic in batch size.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}

    first_column = next(iter(concatenated.values()), [])
    # Round down to a multiple of the chunk size so every chunk is full.
    total_length = (len(first_column) // size) * size

    results = {
        k: [t[i:i + size] for i in range(0, total_length, size)]
        for k, t in concatenated.items()
    }
    # Copy each inner chunk so labels never alias input_ids' lists.
    results["labels"] = [list(chunk) for chunk in results["input_ids"]]
    return results

In [45]:
# Re-chunk every split into fixed-length blocks. The row counts change, so
# each split's row count below no longer equals the number of reviews.
lm_datasets = tokenized_datasets.map(function=group_texts, batched=True)
lm_datasets

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 61291
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 59904
    })
    unsupervised: Dataset({
        features: ['input_ids', 'attention_mask', 'word_ids', 'labels'],
        num_rows: 122957
    })
})

### Fine-tuning the model

In [48]:
MASK_PROBABILITY = 0.15  # fraction of tokens selected for masking in each batch
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MASK_PROBABILITY,
)

In [51]:
# Build two example rows without the word_ids column (the collator only needs
# the tensor-able columns), then show what the collator's random masking looks like.
samples = []
for row_idx in range(2):
    row = lm_datasets["train"][row_idx]
    del row["word_ids"]
    samples.append(row)

masked_batch = data_collator(samples)
for masked_ids in masked_batch["input_ids"]:
    print(f"\n{tokenizer.decode(masked_ids)}")


[CLS] i rented i am curious - yellow [MASK] my [MASK] store because of all the controversy that surrounded it when it [MASK] [MASK] released in 1967. i also heard that at first it was seized by u. s. customs [MASK] it ever tried to enter this country, therefore being a fan of films considered " controversial " i originating had to see [MASK] [MASK] myself. < br / > < br / > the plot is centered around a keynote swedish drama [MASK] named lena who wants to learn everything she can about life. in particular she wants to [MASK] her attentions [MASK] making some sort of documentary [MASK] what the average swede thought about [MASK] [MASK] issues such

as the [MASK] war [MASK] race issues in the united states. in between clubs politicians [MASK] ordinary denize [MASK] of stockholm about [MASK] opinions subtle politics, she has sex [MASK] her drama teacher, classmates, [MASK] married men. < br / > < [MASK] / > what sentence me [MASK] i am [MASK] - yellow is [MASK] 40 years ago, this was con

In [52]:
# Train on a 10k-chunk subset to keep fine-tuning fast; hold out a tenth of
# that size for evaluation. The fixed seed makes the split reproducible.
train_size = 10_000
test_size = train_size // 10

downsampled_datasets = lm_datasets["train"].train_test_split(
    train_size=train_size,
    test_size=test_size,
    seed=42,
)

In [72]:
batch_size = 64
# Number of full batches in the training split, i.e. roughly one log line per epoch.
logging_steps = len(downsampled_datasets["train"]) // batch_size

model = AutoModelForMaskedLM.from_pretrained(checkpoint)

# fp16 mixed precision needs a CUDA device. Enabling it unconditionally makes
# TrainingArguments raise on CPU-only machines, so gate it on availability.
use_fp16 = torch.cuda.is_available()

# NOTE: push_to_hub=True requires an authenticated Hugging Face session
# (e.g. huggingface_hub.login() or a cached token) before this cell runs.
args = TrainingArguments(
    "distilbert-finetuned-mlm-imdb",
    overwrite_output_dir=True,
    eval_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=True,
    fp16=use_fp16,
    logging_steps=logging_steps,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)

trainer = Trainer(
    args=args,
    model=model,
    train_dataset=downsampled_datasets["train"],
    eval_dataset=downsampled_datasets["test"],
    data_collator=data_collator,
    processing_class=tokenizer,
)

In [74]:
# Baseline: evaluate the pretrained model before any fine-tuning.
# Perplexity is exp(mean cross-entropy loss).
eval_results = trainer.evaluate()
baseline_perplexity = math.exp(eval_results["eval_loss"])
print(f">>> Perplexity: {baseline_perplexity:.2f}")

>>> Perplexity: 21.94


Lower perplexity means a better language model. The value above is the baseline for the pretrained model, measured before any fine-tuning.

In [75]:
# Fine-tune; the returned TrainOutput summarises steps, loss and runtime.
train_output = trainer.train()
train_output

Epoch,Training Loss,Validation Loss,Model Preparation Time
1,2.6804,2.493174,0.0
2,2.5832,2.448004,0.0
3,2.5255,2.480797,0.0


TrainOutput(global_step=471, training_loss=2.595887607323389, metrics={'train_runtime': 656.9512, 'train_samples_per_second': 45.665, 'train_steps_per_second': 0.717, 'total_flos': 994208670720000.0, 'train_loss': 2.595887607323389, 'epoch': 3.0})

In [76]:
# Re-evaluate on the same held-out split to compare with the baseline perplexity.
eval_results = trainer.evaluate()
finetuned_perplexity = math.exp(eval_results["eval_loss"])
print(f">>> Perplexity: {finetuned_perplexity:.2f}")

>>> Perplexity: 12.02


In [77]:
# Upload the final model to the Hub repo named after the output directory.
trainer.push_to_hub(commit_message="End of training")

Uploading...:   0%|          | 0.00/268M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/praful-goel/distilbert-finetuned-mlm-imdb/commit/b0d4bd2804f393f6453635c5c1b4259a43f0746d', commit_message='End of training', commit_description='', oid='b0d4bd2804f393f6453635c5c1b4259a43f0746d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/praful-goel/distilbert-finetuned-mlm-imdb', endpoint='https://huggingface.co', repo_type='model', repo_id='praful-goel/distilbert-finetuned-mlm-imdb'), pr_revision=None, pr_num=None)

### Using our fine-tuned model

In [78]:
from transformers import pipeline

In [79]:
# Load the fine-tuned model back from the Hub and fill in a masked token.
model_checkpoint = "praful-goel/distilbert-finetuned-mlm-imdb"
mask_filler = pipeline("fill-mask", model=model_checkpoint)

text = "This is a great [MASK]"
predictions = mask_filler(text)

print("\n".join(f">>> {prediction['sequence']}" for prediction in predictions))

config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

Device set to use cuda:0


>>> this is a great!
>>> this is a great.
>>> this is a great deal
>>> this is a great film
>>> this is a great movie
