## Reformer - Pushing the Limits of Language Modeling

Earlier this year, Nikita Kitaev, Łukasz Kaiser and Anselm Levskaya published the [**Reformer**](https://arxiv.org/abs/2001.04451), a transformer model variant with astonishingly low memory consumption.

In this notebook, we will show how Reformer can be used in [`transformers`](https://github.com/huggingface/transformers). 
To highlight its low memory consumption, we reduce the novel [**Crime and Punishment**](https://en.wikipedia.org/wiki/Crime_and_Punishment) to a single example containing over half a million tokens and use it to train Reformer with the conventional language modeling objective.

Thanks to the recent releases of [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) and [`nlp`](https://github.com/huggingface/nlp), it is easier than ever to train any model in `transformers` on the dataset of your choice.


### ***Disclaimer***:

This notebook is essentially a translation of the official reformer notebook from `trax` to `pytorch`: https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb#scrollTo=bzQ7G9uGSga5 

First, let's check whether we are given the full portion of the GPU. 

In [None]:
#@title Check available memory of GPU
# Check that we are using 100% of GPU
# memory footprint support libraries/code
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip -q install gputil
!pip -q install psutil
!pip -q install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
 process = psutil.Process(os.getpid())
 print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
 print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()

Gen RAM Free: 12.7 GB  | Proc size: 138.1 MB
GPU RAM Free: 16280MB | Used: 0MB | Util   0% | Total 16280MB


In case GPU utilisation (`Util`) is not at 0%, you can uncomment and run the following line to kill all processes to get the full GPU afterwards. 
Make sure to comment out the line again to not constantly crash the notebook on purpose. 

In [None]:
# !kill -9 -1

Let's install `nlp` and `transformers` and import the necessary classes from Reformer and Trainer. 

In [None]:
# install nlp
!pip install -qq nlp==0.2.0

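# Make sure that we have a recent version of pyarrow (>= 0.16) in the session before we continue - otherwise kill the process so that Colab restarts and picks up the new version
import pyarrow
# compare (major, minor) as a tuple so that versions such as 1.0 are not mistaken for old ones
if tuple(int(part) for part in pyarrow.__version__.split('.')[:2]) < (0, 16):
    import os
    os.kill(os.getpid(), 9)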

!pip install -qq transformers==2.10.0

If the notebook crashes here, the wrong version of `pyarrow` was loaded in the session. Simply rerun the cell to pick up the correct version.

In [None]:
# imports
from transformers import (
    ReformerModelWithLMHead,
    ReformerTokenizer,
    ReformerConfig,
    Trainer,
    DataCollator,
    TrainingArguments,
)
import nlp
import torch

First, we download *Crime and Punishment*, which contains the text of an 800-page book, using the convenient `nlp` library.

In [None]:
# load the dataset
dataset = nlp.load_dataset("crime_and_punish", split="train")

Now let's get a pretrained SentencePiece tokenizer that was trained on the *Crime and Punishment* dataset.

In [None]:
# get a pretrained tokenizer
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")

Because we want to pack all data into a **single** sample, we use the handy `map()` function to reduce the dataset to one sample and pad it to a length of 524288 (2^19) tokens. We then duplicate that sample into 8 training samples so that we can accumulate gradients during training. Finally, we make the dataset ready for training by keeping only the columns needed for training.

In [None]:
sequence_length = 2 ** 19  # 524288

# define our map function to reduce the dataset to one sample
def flatten_and_tokenize(batch):
  all_input_text = ["".join(batch["line"])]
  input_ids_dict = tokenizer.batch_encode_plus(
      all_input_text, pad_to_max_length=True, max_length=sequence_length
  )

  # duplicate the data 8 times to have 8 examples in the dataset
  for key in input_ids_dict.keys():
    input_ids_dict[key] = [8 * [x] for x in input_ids_dict[key]][0]

  return input_ids_dict

# reduce the dataset and set batch_size to all inputs
dataset = dataset.map(
  flatten_and_tokenize, batched=True, batch_size=-1, remove_columns=["line"]
)

# prepare dataset to be in torch format
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
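
To make the pad-and-duplicate step more concrete, here is a toy sketch with plain Python lists. The helper `pad_and_repeat`, the pad id `0` and padding on the right are assumptions made for illustration only; the cell above relies on the tokenizer's own padding instead.

```python
def pad_and_repeat(token_ids, max_length, num_copies, pad_id=0):
    """Pad one token sequence to max_length and repeat it num_copies times.

    Hypothetical helper for illustration; pad_id=0 and right-padding are assumptions.
    """
    if len(token_ids) > max_length:
        raise ValueError("sequence is longer than max_length")
    num_pad = max_length - len(token_ids)
    input_ids = token_ids + [pad_id] * num_pad
    # 1 marks a real token, 0 marks padding
    attention_mask = [1] * len(token_ids) + [0] * num_pad
    return {
        "input_ids": [list(input_ids) for _ in range(num_copies)],
        "attention_mask": [list(attention_mask) for _ in range(num_copies)],
    }


toy = pad_and_repeat([11, 12, 13], max_length=5, num_copies=2)
print(toy["input_ids"])       # [[11, 12, 13, 0, 0], [11, 12, 13, 0, 0]]
print(toy["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
```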

Let's do a quick check that our dataset contains 8 examples.

In [None]:
dataset

Dataset(schema: {'input_ids': 'list<item: int64>', 'token_type_ids': 'list<item: int64>', 'attention_mask': 'list<item: int64>'}, num_rows: 8)

Great! We can now move to designing the training.

As the official trax notebook says: *We don't want the model to just memorize the dataset (the single example) by encoding the words in its position embeddings. Thus, at each training iteration we will randomly select how much padding to put before the text vs. after it*.

With the `Trainer` framework of `transformers`, we can implement this by using a *Reformer*-specific `DataCollator` that randomly *shifts* the `input_ids` to the right and sets the `labels` correctly.

In [None]:
class ReformerCollator(DataCollator):
    def __init__(self, max_roll_length):
        self.max_roll_length = max_roll_length

    # From the official notebook: "Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it"
    def collate_batch(self, features):
        # get random shift int
        random_shift_length = torch.randint(self.max_roll_length, (1,)).item()

        # shift input and mask
        rolled_input_ids = torch.roll(
            features[0]["input_ids"], random_shift_length
        ).unsqueeze(0)
        rolled_attention_mask = torch.roll(
            features[0]["attention_mask"], random_shift_length
        ).unsqueeze(0)

        # return dict having the correct argument naming
        return {
            "input_ids": rolled_input_ids,  # BS x SEQ_LEN
            "labels": rolled_input_ids,  # BS x SEQ_LEN
            "attention_mask": rolled_attention_mask,  # BS x SEQ_LEN
        }
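
The effect of the random shift is easiest to see on a toy sequence. The sketch below mimics a circular shift to the right with plain Python lists; `roll_right` is a hypothetical stand-in for `torch.roll` on a 1-D tensor, and the token ids are made up. As long as the shift is no larger than the number of padding tokens at the end, the text stays contiguous and only the split of padding before and after it changes.

```python
def roll_right(values, shift):
    """Circularly shift a list to the right by `shift` positions."""
    if not values:
        return []
    shift %= len(values)
    if shift == 0:
        return list(values)
    return values[-shift:] + values[:-shift]


input_ids = [11, 12, 13, 0, 0]   # 3 real tokens followed by 2 padding tokens
attention_mask = [1, 1, 1, 0, 0]

# number of padding tokens = largest shift that does not wrap the text around
max_shift = len(attention_mask) - sum(attention_mask)

for shift in range(max_shift + 1):
    print(shift, roll_right(input_ids, shift), roll_right(attention_mask, shift))
# 0 [11, 12, 13, 0, 0] [1, 1, 1, 0, 0]
# 1 [0, 11, 12, 13, 0] [0, 1, 1, 1, 0]
# 2 [0, 0, 11, 12, 13] [0, 0, 1, 1, 1]
```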

To instantiate the data collator, we need the number of padding tokens in `input_ids`, since it bounds how far the text can be shifted. 

In [None]:
# the number of padding tokens defines the max shift for our data collator
padding_length = sequence_length - sum(
    dataset["attention_mask"][0]
)

# get the data collator
data_collator = ReformerCollator(padding_length)

Next, we define our Reformer model via a `ReformerConfig`. 
As can be seen, we alternate between *local attention* layers and *LSH attention* layers for a total of 6 layers. Also note that we factorize `num_buckets` and use **Axial Position Embeddings**. For more insight into how the bucketing and Axial Position Embeddings work, please refer to the Reformer [docs](https://huggingface.co/transformers/model_doc/reformer.html). 

In [None]:
config = {
    "attention_head_size": 64,
    "attn_layers": ["local", "lsh", "local", "lsh", "local", "lsh"],
    "axial_pos_embds": True,
    "sinusoidal_pos_embds": False,
    "axial_pos_embds_dim": [64, 192],
    "axial_pos_shape": [512, 1024],
    "lsh_attn_chunk_length": 64,
    "local_attn_chunk_length": 64,
    "feed_forward_size": 512,
    "hidden_act": "relu",
    "hidden_size": 256,
    "is_decoder": True,
    "max_position_embeddings": 524288,
    "num_attention_heads": 2,
    "num_buckets": [64, 128],
    "num_hashes": 1,
    "vocab_size": 320,
    "lsh_attention_probs_dropout_prob": 0.0,
    "lsh_num_chunks_before": 1,
    "lsh_num_chunks_after": 0,
    "local_num_chunks_before": 1,
    "local_num_chunks_after": 0,
    "local_attention_probs_dropout_prob": 0.025,
    "hidden_dropout_prob": 0.025,
}

config = ReformerConfig(**config)
model = ReformerModelWithLMHead(config)
model = model.train()
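
A few numbers in this config are tied together, and a quick arithmetic check makes that visible. The sketch below only re-uses values from the config above; the two constraints it asserts (the axial shape multiplies to the sequence length, the axial dims sum to the hidden size) reflect our reading of the Reformer docs, and the parameter count assumes one embedding table per axial factor.

```python
import math

sequence_length = 2 ** 19          # 524288, same as max_position_embeddings
hidden_size = 256
axial_pos_shape = [512, 1024]
axial_pos_embds_dim = [64, 192]
chunk_length = 64                  # lsh_attn_chunk_length and local_attn_chunk_length
num_buckets = [64, 128]

# the factorized position grid has to cover every position exactly once
assert math.prod(axial_pos_shape) == sequence_length
# the two axial embeddings are concatenated to form one hidden vector
assert sum(axial_pos_embds_dim) == hidden_size

# position-embedding weights: one full table vs. two small axial tables
full_table = sequence_length * hidden_size
axial_tables = sum(n * d for n, d in zip(axial_pos_shape, axial_pos_embds_dim))
print(full_table)    # 134217728
print(axial_tables)  # 229376

# number of attention chunks vs. total number of (factorized) hash buckets
num_chunks = sequence_length // chunk_length
print(num_chunks, math.prod(num_buckets))  # 8192 8192
```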

Lastly, let's set up the training args. **Note**: *these training settings have not been thoroughly tested and can probably be tuned for better results*.

In [None]:
# define the training args
training_args = {
    "learning_rate": 1e-3,
    "max_steps": 2000,
    "do_train": True,
    "evaluate_during_training": True,
    "gradient_accumulation_steps": 8,
    "logging_steps": 50,
    "warmup_steps": 500,
    "weight_decay": 0.001,
    "fp16": True,
    "per_gpu_train_batch_size": 1,
    "per_gpu_eval_batch_size": 1,
    "save_steps": 50,
    "output_dir": "./"
}

training_args = TrainingArguments(**training_args)
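
To get a feeling for these settings, the sketch below computes the effective batch size and a learning-rate curve. It assumes a linear warmup followed by a linear decay to zero at `max_steps`; whether `Trainer` uses exactly this schedule in your version is an assumption, and `linear_schedule_lr` is a hypothetical helper, not part of `transformers`.

```python
def linear_schedule_lr(step, base_lr=1e-3, warmup_steps=500, max_steps=2000):
    """Learning rate at `step` for linear warmup followed by linear decay (assumed schedule)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(0, max_steps - step)
    return base_lr * remaining / (max_steps - warmup_steps)


# one sample per forward pass, gradients accumulated over 8 passes per optimizer step
effective_batch_size = 1 * 8
print(effective_batch_size)

for step in (0, 250, 500, 1250, 2000):
    print(step, linear_schedule_lr(step))
```

Under this assumption the learning rate peaks at `1e-3` after 500 steps and is back at zero by step 2000.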

We define a simple "accuracy" metric to keep track of how many tokens are correctly predicted.

In [None]:
import numpy as np

def compute_metrics(pred):
    non_padded_indices = (pred.label_ids != -100)

    # shift labels and pred as it's done in forward():
    # the prediction at position t is scored against the label at position t + 1,
    # so both are masked by the shifted label mask
    shifted_mask = non_padded_indices[..., 1:]
    labels = pred.label_ids[..., 1:][shifted_mask]
    pred = np.argmax(pred.predictions[:, :-1], axis=-1)[shifted_mask]

    acc = np.mean(pred == labels)
    return {"accuracy": acc}
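
The shifting in this metric can be checked by hand on a tiny example. The sketch below is a plain-Python version for a single sequence; `shifted_accuracy` and the token ids are made up for illustration, and `-100` is used as the ignore index as in the cell above.

```python
def shifted_accuracy(predicted_ids, label_ids, ignore_index=-100):
    """Next-token accuracy: predicted_ids[t] is compared with label_ids[t + 1]."""
    pairs = [
        (pred, label)
        for pred, label in zip(predicted_ids[:-1], label_ids[1:])
        if label != ignore_index  # skip positions whose target is masked out
    ]
    if not pairs:
        return 0.0
    return sum(pred == label for pred, label in pairs) / len(pairs)


labels = [5, 7, 9, 4, -100]
preds = [7, 9, 3, 8, 1]   # argmax of the logits at each position

# scored pairs: (7, 7), (9, 9), (3, 4); the last target is ignored
print(shifted_accuracy(preds, labels))
```

Two of the three scored predictions match their next token, so the accuracy is 2/3.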


To enable training with mixed precision, we need to download the apex package and connect it to CUDA. The following command does all that, but might take a while to finish.

Let's first write a setup file

In [None]:
%%writefile setup.sh 

# install apex to be able to use mixed precision
export CUDA_HOME=/usr/local/cuda-10.1
git clone https://github.com/NVIDIA/apex
pip install -v -q --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex

Overwriting setup.sh


And execute it afterward. 

In [None]:
!sh setup.sh

fatal: destination path 'apex' already exists and is not an empty directory.
  cmdoptions.check_install_build_global(options)
Processing ./apex
Skipping wheel build for apex, due to binaries being disabled for it.
Installing collected packages: apex
  Found existing installation: apex 0.1
    Uninstalling apex-0.1:
      Successfully uninstalled apex-0.1
    Running setup.py install for apex ... done
Successfully installed apex-0.1


Finally we can start training. Since Google Colab only gives us a single GPU, this might take quite some time.


**Note**: the training step below may throw an `fp16` not-installed error. In that case, restart the runtime and run the notebook again.

In [None]:
# create the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    train_dataset=dataset,
    eval_dataset=dataset,
    prediction_loss_only=True,
)

# train
trainer.train()

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


KeyboardInterrupt: ignored

Let's see what the model learned to generate.

In [None]:
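# encode the prompt, generate a continuation and decode it back to text
input_ids = tokenizer.encode("Later that day, he", return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids)
print(tokenizer.decode(output_ids[0]))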