# Text Generation using DistilBERT

In [None]:
!pip install git+https://github.com/huggingface/transformers

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
import torch

In [3]:
# set the device to be used as GPU, otherwise use the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [4]:
from pathlib import Path
import glob

## 0. Dataset Stats

#### We curate 9671 papers accepted at the NeurIPS conference from 1987 to 2019. Source: https://papers.nips.cc/
#### We curate 5213 papers submitted to the ICLR conference from 2013 to 2020. Source: https://openreview.net/group?id=ICLR.cc

#### We build the training dataset by splitting the paper titles and abstracts into sentences with `nltk.sent_tokenize`.
#### The training file contains one sentence per line.
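
The preprocessing script is not part of this notebook, so the following is only a minimal sketch of the "one sentence per line" step. The record layout (`title`, `abstract`) and the helper names are hypothetical, and a naive regex splitter stands in for `nltk.sent_tokenize`, which can be passed as `split` instead.

```python
import re
from typing import Callable, Iterable, List


def naive_sent_tokenize(text: str) -> List[str]:
    """Stand-in splitter: break after '.', '!' or '?' followed by whitespace."""
    parts = re.split(r"(?<=[.!?])\s+", text.strip())
    return [p for p in parts if p]


def build_training_lines(
    papers: Iterable[dict],
    split: Callable[[str], List[str]] = naive_sent_tokenize,
) -> List[str]:
    """Return each title followed by its abstract sentences, one item per line."""
    lines = []
    for paper in papers:
        lines.append(paper["title"].strip())
        lines.extend(s.strip() for s in split(paper["abstract"]))
    return lines


# Hypothetical record layout; the notebook does not show how papers are stored.
papers = [{
    "title": "A Toy Paper",
    "abstract": "We study a problem. We show a result!",
}]
lines = build_training_lines(papers)
print("\n".join(lines))
```

Writing `"\n".join(lines)` to a text file would give the same layout as the `head` output below: a title line followed by its abstract sentences.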

In [5]:
!head -7 training_data_paper_abstracts_neurips_iclr.txt

Solving a Class of Non-Convex Min-Max Games Using Iterative First Order Methods
Recent applications that arise in machine learning have surged significant interest in solving min-max saddle point games.
This problem has been extensively studied in the convex-concave regime for which a global equilibrium solution can be computed efficiently.
In this paper, we study the problem in the non-convex regime and show that an ε–first order stationary point of the game can be computed when one of the player’s objective can be optimized to global optimality efficiently.
In particular, we first consider the case where the objective of one of the players satisfies the Polyak-Łojasiewicz (PL) condition.
For such a game, we show that a simple multi-step gradient descent-ascent algorithm finds an ε–first order stationary point of the problem in Õ(ε−2) iterations.
Then we show that our framework can also be applied to the case where the objective of the “max-player" is concave.


## 1. Train DistilBERT tokenizer

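DistilBERT uses the same WordPiece tokenizer as BERT. We train the vocabulary with `BertWordPieceTokenizer` from the `tokenizers` library; `BertTokenizerFast` is the faster tokenizer implementation by Hugging Face, which we use later to load the trained vocabulary.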
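
To make the role of the vocabulary concrete, here is a small pure-Python sketch of how a WordPiece vocabulary is typically applied at encoding time: greedy longest-match-first splitting, with continuation pieces carrying the `##` prefix. The toy vocabulary and the function name are made up for illustration; the real tokenizer also handles normalization and pre-tokenization, which this sketch skips.

```python
from typing import List, Set


def wordpiece_tokenize(word: str, vocab: Set[str], prefix: str = "##",
                       unk_token: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first split of one word into vocabulary pieces."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate  # continuation pieces carry the prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # no piece matches: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces


# Toy vocabulary, not the one trained in this notebook.
toy_vocab = {"token", "##izer", "##s", "LSTM", "train"}
print(wordpiece_tokenize("tokenizers", toy_vocab))
print(wordpiece_tokenize("LSTM", toy_vocab))
print(wordpiece_tokenize("xyz", toy_vocab))
```

A larger `vocab_size` means more words survive as a single piece, at the cost of a larger embedding matrix.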

In [6]:
from tokenizers import BertWordPieceTokenizer

In [8]:
!mkdir DistilBERTModel

In [9]:
training_file_path = glob.glob("training_data_paper_abstracts_neurips_iclr.txt")

In [10]:
BertWordPieceTokenizer??

In [11]:
# Initialize the tokenizer
tokenizer = BertWordPieceTokenizer(
    handle_chinese_chars=False,
    strip_accents=True,
    lowercase=False,             # This might be helpful to retain entities such as MLP, LSTM, etc.
)

# Train the WP tokenizer on the data
tokenizer.train(
    training_file_path,
    vocab_size=50000,
    min_frequency=2,
    show_progress=True,
    special_tokens=["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]"],
    limit_alphabet=1000,
    wordpieces_prefix="##",
)

In [12]:
# Save the model to the directory
tokenizer.save_model("DistilBERTModel")

['DistilBERTModel/vocab.txt']

Check that the tokenizer can be reloaded from the saved vocabulary:

In [13]:
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.processors import BertProcessing

tokenizer = BertWordPieceTokenizer("./DistilBERTModel/vocab.txt")

In [14]:
# Add the [CLS] and [SEP] tokens around each encoded sequence

tokenizer._tokenizer.post_processor = BertProcessing(("[SEP]", tokenizer.token_to_id("[SEP]")), 
                                                     ("[CLS]", tokenizer.token_to_id("[CLS]")))
tokenizer.enable_truncation(max_length=512)
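
As a rough illustration of what the post-processor and truncation do to a single sequence, the sketch below wraps a token list as `[CLS] ... [SEP]` and trims it so the result fits the length limit. It is a simplified stand-in, not the library's implementation: the function name is hypothetical, sentence pairs are not handled, and it assumes the limit counts the two special tokens.

```python
from typing import List


def bert_post_process(tokens: List[str], max_length: int = 512,
                      cls_token: str = "[CLS]", sep_token: str = "[SEP]") -> List[str]:
    """Wrap one token sequence as [CLS] tokens [SEP], truncated to max_length."""
    if max_length < 2:
        raise ValueError("max_length must leave room for [CLS] and [SEP]")
    budget = max_length - 2  # reserve two slots for the special tokens
    return [cls_token] + tokens[:budget] + [sep_token]


print(bert_post_process(["we", "present", "a", "method"]))
print(bert_post_process(["a", "b", "c"], max_length=4))
```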

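Check the tokenization of a sample sentence. Note that the tokenizer reloaded above uses the default `lowercase=True`, so the tokens below come out lowercased even though the vocabulary was trained with `lowercase=False`; pass `lowercase=False` when reloading to preserve case.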

In [15]:
tokenizer.encode("In this paper, we present a novel method to train LSTM.").tokens

['[CLS]',
 'in',
 'this',
 'paper',
 ',',
 'we',
 'present',
 'a',
 'novel',
 'method',
 'to',
 'train',
 'lstm',
 '.',
 '[SEP]']

## 2. Train LM

Import the predefined configuration of the DistilBERT model.

In [16]:
import transformers
from transformers import BertTokenizerFast

In [17]:
from transformers import DistilBertConfig, DistilBertForMaskedLM

In [18]:
DistilBertConfig??
# Contains 6 layers, 12 attention heads, gelu activation

In [19]:
config = DistilBertConfig(vocab_size=50000)

In [20]:
# Load the fast tokenizer implementation from the DistilBERTModel directory.
tokenizer = BertTokenizerFast.from_pretrained("./DistilBERTModel", max_len=512, do_lower_case=False)

In [21]:
# Initialize the model
model = DistilBertForMaskedLM(config=config)

In [22]:
for name, p in model.named_parameters():
    print(name)

distilbert.embeddings.word_embeddings.weight
distilbert.embeddings.position_embeddings.weight
distilbert.embeddings.LayerNorm.weight
distilbert.embeddings.LayerNorm.bias
distilbert.transformer.layer.0.attention.q_lin.weight
distilbert.transformer.layer.0.attention.q_lin.bias
distilbert.transformer.layer.0.attention.k_lin.weight
distilbert.transformer.layer.0.attention.k_lin.bias
distilbert.transformer.layer.0.attention.v_lin.weight
distilbert.transformer.layer.0.attention.v_lin.bias
distilbert.transformer.layer.0.attention.out_lin.weight
distilbert.transformer.layer.0.attention.out_lin.bias
distilbert.transformer.layer.0.sa_layer_norm.weight
distilbert.transformer.layer.0.sa_layer_norm.bias
distilbert.transformer.layer.0.ffn.lin1.weight
distilbert.transformer.layer.0.ffn.lin1.bias
distilbert.transformer.layer.0.ffn.lin2.weight
distilbert.transformer.layer.0.ffn.lin2.bias
distilbert.transformer.layer.0.output_layer_norm.weight
distilbert.transformer.layer.0.output_layer_norm.bias
...

In [23]:
print("Total parameters in DistilBERT Model: ", model.num_parameters())
# For comparison, BERT-base has around 110 million parameters

Total parameters in DistilBERT Model:  81964112
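
The reported count can be reproduced by hand from the parameter names listed above. The sketch below assumes the default `DistilBertConfig` sizes (hidden size 768, feed-forward size 3072, 6 layers, 512 positions) together with the `vocab_size=50000` set in this notebook, and assumes the output projection shares its weight matrix with the word embeddings so that only its bias is counted.

```python
# Assumed DistilBertConfig defaults, plus the vocab_size used in this notebook.
vocab_size, dim, hidden_dim = 50_000, 768, 3072
n_layers, max_positions = 6, 512


def linear(n_in: int, n_out: int) -> int:
    """Parameters in a dense layer: weight matrix plus bias."""
    return n_in * n_out + n_out


layer_norm = 2 * dim  # scale and shift vectors

embeddings = vocab_size * dim + max_positions * dim + layer_norm
per_layer = (
    4 * linear(dim, dim)        # q_lin, k_lin, v_lin, out_lin
    + layer_norm                # sa_layer_norm
    + linear(dim, hidden_dim)   # ffn.lin1
    + linear(hidden_dim, dim)   # ffn.lin2
    + layer_norm                # output_layer_norm
)
# MLM head: a dense transform, a layer norm, and only the bias of the output
# projection (its weight is assumed to be tied to the word embeddings).
mlm_head = linear(dim, dim) + layer_norm + vocab_size

total = embeddings + n_layers * per_layer + mlm_head
print(f"{total:,}")
```

Under these assumptions the total matches the 81,964,112 printed above. The word embedding matrix alone accounts for 38.4M of it, so the 50,000-token vocabulary is responsible for just under half of the model's parameters.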


In [24]:
from transformers import DataCollatorForLanguageModeling, LineByLineTextDataset

In [25]:
# We use our trained tokenizer on the data
dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path="./training_data_paper_abstracts_neurips_iclr.txt",
                                block_size=128)

In [26]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
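
To show what `mlm=True, mlm_probability=0.15` means in practice, here is a simplified pure-Python sketch of BERT-style masking: each non-special token is selected with probability 0.15; a selected token is replaced by `[MASK]` 80% of the time, by a random token 10% of the time, and left unchanged otherwise. Unselected positions get the label -100 so the loss ignores them. The function name and the token ids are illustrative assumptions; the collator itself works on padded tensors.

```python
import random
from typing import List, Optional, Set, Tuple

IGNORE_INDEX = -100  # label value that the loss skips


def mask_tokens(ids: List[int], special_ids: Set[int], mask_id: int,
                vocab_size: int, mlm_probability: float = 0.15,
                rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Return (inputs, labels) after BERT-style random masking."""
    rng = rng or random.Random()
    inputs, labels = list(ids), [IGNORE_INDEX] * len(ids)
    for i, tok in enumerate(ids):
        if tok in special_ids or rng.random() >= mlm_probability:
            continue  # position not selected for prediction
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id                    # 80%: replace with [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


# Assumed ids: 2 = [CLS], 1 = [SEP], 4 = [MASK]; the other ids are arbitrary.
ids = [2] + list(range(100, 140)) + [1]
inputs, labels = mask_tokens(ids, special_ids={1, 2}, mask_id=4,
                             vocab_size=50_000, rng=random.Random(0))
print(sum(label != IGNORE_INDEX for label in labels), "positions selected")
```

Because the masking is redrawn every time a batch is built, the model sees different masked positions for the same sentence across epochs.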

Training Process

In [27]:
from transformers import Trainer, TrainingArguments

In [28]:
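training_args = TrainingArguments(
    output_dir="./DistilBERTModel",
    overwrite_output_dir=True,
    per_gpu_train_batch_size=64,   # deprecated name; newer versions prefer per_device_train_batch_size
    save_steps=10_000,             # checkpoint interval (not reached in this 5,547-step run)
    save_total_limit=2,            # keep at most two checkpoints on disk
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,   # applies the random masking batch by batch
    train_dataset=dataset,
    prediction_loss_only=True,
)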



In [29]:
%%time
trainer.train()

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.


[Progress bars omitted: 3 epochs, 1849 iterations per epoch]

{'loss': 7.013072265625, 'learning_rate': 4.549305931133946e-05, 'epoch': 0.2704164413196322, 'total_flos': 1175782399481856, 'step': 500}
{'loss': 6.243638671875, 'learning_rate': 4.0986118622678924e-05, 'epoch': 0.5408328826392644, 'total_flos': 2326353949538304, 'step': 1000}
{'loss': 6.0335234375, 'learning_rate': 3.647917793401839e-05, 'epoch': 0.8112493239588967, 'total_flos': 3487910002028544, 'step': 1500}



{'loss': 5.791875, 'learning_rate': 3.197223724535785e-05, 'epoch': 1.0816657652785289, 'total_flos': 4621501210930176, 'step': 2000}
{'loss': 5.55408984375, 'learning_rate': 2.746529655669732e-05, 'epoch': 1.3520822065981613, 'total_flos': 5785386355627008, 'step': 2500}
{'loss': 5.3638828125, 'learning_rate': 2.2958355868036776e-05, 'epoch': 1.6224986479177934, 'total_flos': 6928404093121536, 'step': 3000}
{'loss': 5.2172578125, 'learning_rate': 1.845141517937624e-05, 'epoch': 1.8929150892374258, 'total_flos': 8088071692471296, 'step': 3500}



{'loss': 5.12032421875, 'learning_rate': 1.3944474490715703e-05, 'epoch': 2.1633315305570577, 'total_flos': 9255814396135296, 'step': 4000}
{'loss': 5.0061328125, 'learning_rate': 9.437533802055165e-06, 'epoch': 2.43374797187669, 'total_flos': 10393827732807552, 'step': 4500}
{'loss': 4.94041796875, 'learning_rate': 4.930593113394627e-06, 'epoch': 2.7041644131963225, 'total_flos': 11536436305454976, 'step': 5000}
{'loss': 4.91706640625, 'learning_rate': 4.236524247340905e-07, 'epoch': 2.9745808545159544, 'total_flos': 12706773665048448, 'step': 5500}


CPU times: user 47min 47s, sys: 5.57 s, total: 47min 52s
Wall time: 47min 47s


TrainOutput(global_step=5547, training_loss=5.5580779165539935)
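
The `learning_rate` and `epoch` values in the log above are consistent with a linear decay from 5e-5 to zero over the 5547 training steps, with no warmup. No learning rate or schedule is set in the training arguments, so the base rate of 5e-5 and the linear schedule are assumptions inferred from the log. A quick arithmetic check:

```python
def linear_decay_lr(step: int, total_steps: int, base_lr: float = 5e-5) -> float:
    """Learning rate that falls linearly from base_lr to 0, with no warmup."""
    return base_lr * max(0.0, (total_steps - step) / total_steps)


total_steps = 5547       # global_step reported by TrainOutput
steps_per_epoch = 1849   # total_steps divided by the 3 epochs

for step in (500, 3000, 5500):
    # Columns: step, learning rate, fractional epoch
    print(step, linear_decay_lr(step, total_steps), step / steps_per_epoch)
```

The printed values agree with the logged entries for steps 500, 3000 and 5500 up to floating-point rounding.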

In [30]:
# Save to the same directory
trainer.save_model("./DistilBERTModel")

## 3. Generating Text using the DistilBERT LM

In [31]:
from transformers import pipeline

In [82]:
import numpy as np

In [83]:
np.random.seed(42)

In [40]:
predict_masked_word = pipeline("fill-mask", model="./DistilBERTModel", tokenizer="./DistilBERTModel")

In [92]:
def generate_text(prompt_sentence, words_to_gen=5):
    """Extend the prompt one word at a time by repeatedly filling a trailing mask."""
    for i in range(words_to_gen):
        # Append a mask token and ask the model for its top candidates
        pred_words = predict_masked_word(prompt_sentence + f" {tokenizer.mask_token}")
        print(prompt_sentence + ": ")
        potential_word_tokens = [w["token_str"] for w in pred_words]
        print("Potential words - ", " ".join(potential_word_tokens))
        # Pick one candidate uniformly at random and append it to the prompt
        idx = np.random.randint(0, len(potential_word_tokens))
        prompt_sentence = prompt_sentence + " " + potential_word_tokens[idx]
        print("\n")

    print("\n\n", prompt_sentence)

    return prompt_sentence
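
`generate_text` picks uniformly among the returned candidates, so a low-confidence candidate is as likely to be chosen as the top one. One possible variant, which is not used for the outputs in this notebook, is to sample in proportion to the `score` field that the fill-mask pipeline returns with each candidate. The sketch below uses made-up predictions, and `pick_next_word` is a hypothetical helper name.

```python
import random
from typing import Dict, List, Optional


def pick_next_word(predictions: List[Dict], rng: Optional[random.Random] = None) -> str:
    """Sample one candidate with probability proportional to its score."""
    rng = rng or random.Random()
    words = [p["token_str"] for p in predictions]
    weights = [p["score"] for p in predictions]
    return rng.choices(words, weights=weights, k=1)[0]


# Made-up predictions in the shape the fill-mask pipeline returns.
fake_predictions = [
    {"token_str": "method", "score": 0.40},
    {"token_str": "algorithm", "score": 0.30},
    {"token_str": "model", "score": 0.20},
    {"token_str": ".", "score": 0.10},
]
word = pick_next_word(fake_predictions, rng=random.Random(42))
print(word)
```

Inside `generate_text`, this would replace the `np.random.randint` line with a call such as `pick_next_word(pred_words)`.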

In [96]:
generate_text("We present our")

We present our: 
Potential words -  method algorithm model approach framework


We present our framework: 
Potential words -  . algorithm for model of


We present our framework for: 
Potential words -  learning . algorithms data Learning


We present our framework for Learning: 
Potential words -  . algorithm algorithms method models


We present our framework for Learning .: 
Potential words -  . algorithms algorithm learning models




 We present our framework for Learning . algorithms


'We present our framework for Learning . algorithms'

In [99]:
generate_text("Adam optimization is")

Adam optimization is: 
Potential words -  . based that learning ?


Adam optimization is learning: 
Potential words -  . algorithms models problems algorithm


Adam optimization is learning problems: 
Potential words -  . for in , and


Adam optimization is learning problems ,: 
Potential words -  . learning and models algorithms


Adam optimization is learning problems , algorithms: 
Potential words -  . learning algorithms and ,




 Adam optimization is learning problems , algorithms algorithms


'Adam optimization is learning problems , algorithms algorithms'

In [101]:
generate_text("Furthermore, we leverage")

Furthermore, we leverage: 
Potential words -  . data models algorithms ?


Furthermore, we leverage data: 
Potential words -  . data models sets ?


Furthermore, we leverage data sets: 
Potential words -  . data and , from


Furthermore, we leverage data sets from: 
Potential words -  data . training tasks models


Furthermore, we leverage data sets from .: 
Potential words -  . data e g ,




 Furthermore, we leverage data sets from . data


'Furthermore, we leverage data sets from . data'

In [103]:
generate_text("Zero-shot learning")

Zero-shot learning: 
Potential words -  models algorithms methods learning problems


Zero-shot learning models: 
Potential words -  are learning . have ,


Zero-shot learning models ,: 
Potential words -  learning such . are models


Zero-shot learning models , models: 
Potential words -  . are learning have such


Zero-shot learning models , models are: 
Potential words -  learning models . used algorithms




 Zero-shot learning models , models are algorithms


'Zero-shot learning models , models are algorithms'

In [113]:
generate_text("We compare our results")

We compare our results: 
Potential words -  . on and results of


We compare our results on: 
Potential words -  datasets data experiments synthetic results


We compare our results on synthetic: 
Potential words -  datasets . data results tasks


We compare our results on synthetic tasks: 
Potential words -  . datasets data results and


We compare our results on synthetic tasks data: 
Potential words -  . datasets data tasks and




 We compare our results on synthetic tasks data .


'We compare our results on synthetic tasks data .'

In [126]:
generate_text("We release our")

We release our: 
Potential words -  method model algorithm approach .


We release our approach: 
Potential words -  . on algorithm that for


We release our approach for: 
Potential words -  learning . data algorithms results


We release our approach for learning: 
Potential words -  . algorithms models algorithm learning


We release our approach for learning .: 
Potential words -  . algorithms data models learning




 We release our approach for learning . data


'We release our approach for learning . data'

In [127]:
generate_text("We propose a novel")

We propose a novel: 
Potential words -  model algorithm framework method .


We propose a novel algorithm: 
Potential words -  . for that algorithm based


We propose a novel algorithm based: 
Potential words -  . algorithm method model learning


We propose a novel algorithm based model: 
Potential words -  . algorithm learning model that


We propose a novel algorithm based model algorithm: 
Potential words -  . that for learning based




 We propose a novel algorithm based model algorithm learning


'We propose a novel algorithm based model algorithm learning'

In [130]:
generate_text("The specific attention pattern can be ")

The specific attention pattern can be : 
Potential words -  used . data learned information


The specific attention pattern can be  data: 
Potential words -  . data , ? and


The specific attention pattern can be  data and: 
Potential words -  . data images information training


The specific attention pattern can be  data and images: 
Potential words -  . data , ? space


The specific attention pattern can be  data and images ?: 
Potential words -  . , data ? )




 The specific attention pattern can be  data and images ? ?


'The specific attention pattern can be  data and images ? ?'

In [131]:
generate_text("For the sentence classification task, we use")

For the sentence classification task, we use: 
Potential words -  . data learning models algorithms


For the sentence classification task, we use learning: 
Potential words -  . algorithms models methods problems


For the sentence classification task, we use learning methods: 
Potential words -  . algorithms learning methods models


For the sentence classification task, we use learning methods learning: 
Potential words -  . algorithms models methods problems


For the sentence classification task, we use learning methods learning algorithms: 
Potential words -  . learning algorithms methods models




 For the sentence classification task, we use learning methods learning algorithms algorithms


'For the sentence classification task, we use learning methods learning algorithms algorithms'

In [132]:
generate_text("We compare our results with previous")

We compare our results with previous: 
Potential words -  results . data performance datasets


We compare our results with previous data: 
Potential words -  . data datasets results models


We compare our results with previous data results: 
Potential words -  . on and demonstrate datasets


We compare our results with previous data results demonstrate: 
Potential words -  . results experiments data that


We compare our results with previous data results demonstrate .: 
Potential words -  . datasets data results performance




 We compare our results with previous data results demonstrate . .


'We compare our results with previous data results demonstrate . .'

## Observations

We observe that our model does not fully capture long-range dependencies: the continuations are locally plausible but quickly become repetitive. Possible reasons and improvements:
1. The model has around 82M parameters, yet we train it on a relatively small dataset of only 118,268 sentences. Including the full text of the papers might improve results, at the cost of a longer training time.
2. Another way to improve the model might be to start from a pretrained model instead of training the LM from scratch.

## References
1. DistilBERT: https://huggingface.co/transformers/model_doc/distilbert.html
2. Training models using transformers: https://huggingface.co/transformers/training.html
3. How to train a new language model from scratch using Transformers and Tokenizers: https://huggingface.co/blog/how-to-train