Skip to content

Commit

Permalink
Update ReZero Transformer example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
tbachlechner committed Mar 16, 2020
1 parent ab9de6a commit 5d2d1be
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions ReZero-Deep_Fast_Transformer.ipynb
Expand Up @@ -4,16 +4,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# ReZero by Examples \n # Training 128 layer ReZero Transformer on WikiText-2 language modeling\n",
"# Training 128 layer ReZero Transformer on WikiText-2 language modeling\n",
"\n",
"In this notebook we will examine how the [ReZero](https://arxiv.org/abs/2003.04887) architecture addition enables or accelerates training in deep [Transformer](https://arxiv.org/pdf/1706.03762.pdf) networks or fully connected networks.\n",
"\n",
"The official ReZero repo is [here](https://github.com/majumderb/rezero). Although it is not required for this notebook, you can install ReZero for PyTorch Transformers via\n",
"```\n",
"pip install rezero\n",
"```\n",
"The official ReZero repo is [here](https://github.com/majumderb/rezero). Although it is not required for this notebook, you can install ReZero for PyTorch Transformers via `pip install rezero`.\n",
"\n",
"Running time of the notebook: 7 minutes on laptop with single RTX 2060 GPU (+ 21 minutes for training 128 layer transformer at the end)"
"Running time of the notebook: 7 minutes on laptop with single RTX 2060 GPU (+ 21 minutes for training 128 layer transformer at the end)."
]
},
{
Expand Down Expand Up @@ -81,7 +78,7 @@
" super(ReZeroEncoderLayer, self).__init__()\n",
" self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)\n",
" \n",
" # Definte the Resisdual Weight for ReZero\n",
" # Define the Resisdual Weight for ReZero\n",
" self.resweight = torch.nn.Parameter(torch.Tensor([init_resweight]), requires_grad = resweight_trainable)\n",
"\n",
" # Implementation of Feedforward model\n",
Expand Down Expand Up @@ -120,6 +117,7 @@
" if self.use_LayerNorm == \"pre\":\n",
" src2 = self.norm1(src2)\n",
" src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]\n",
" # Apply the residual weight to the residual connection. This enables ReZero.\n",
" src2 = self.resweight * src2\n",
" src2 = self.dropout1(src2)\n",
" if self.use_LayerNorm == False:\n",
Expand Down Expand Up @@ -147,10 +145,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1.1 - Signal propagation\n",
"## Signal propagation\n",
"\n",
"Let us pause and examine the signal propagation properties in a toy deep network `DeepEncoder` by evaluating the singular values of the Transformer input-output Jacobian `io_jacobian_TF`.\n",
"\n",
"The entries of the input-output Jacobian matrix reflects the change of each output with respect to each input. The singular value decomposition of this matrix reflects by how much in magnitude (singular value) an input signal (the corresponding singular vector) changes as it propagates through the network, see the [Wikipedia page](https://en.wikipedia.org/wiki/Singular_value_decomposition). A vanishing singular value means that the corresponding singular vector is mapped to zero (poor signal propagation), while a large singular value means that the corresponding singular vector is amplified in magnitude (chaotic signal propagation). Due to these properties, the singular value decomposition provides a useful tool to study signal propagation in neural networks. As we will see in this notebook, singular values close to unity (i.e. dynamical isometry) often coincide with strong training performance.\n",
"\n",
"We will compare both pre- and the vanilla (post-) norm variants of the Transformer with the ReZero proposal that eliminates LayerNorm. Since ReZero initializes each layer to perform the identity map by setting all residual weights to zero, we here instead set all the residual weights to 0.1, in order to see a non-trivial distribution of the input-output Jacobian singular values. We define `plot_jacobians` to plot the singular value distributions for each architecture."
]
},
Expand Down Expand Up @@ -339,10 +339,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1.2 - Language modeling\n",
"## Language modeling\n",
"\n",
"We now use each of the three Transformer archtectures defined above to model the WikiText-2 dataset, following the basic PyTorch tutorial PyTorch tutorial [Sequence-to-Sequence Modeling with nn.Transformer and TorchText\n",
"](https://pytorch.org/tutorials/beginner/transformer_tutorial.html). The model is tasked to predict which word will follow a sequence of words, and we refer to the tutorial for details.\n",
"](https://pytorch.org/tutorials/beginner/transformer_tutorial.html). The model is tasked to predict which word will follow a sequence of words, and we refer to the tutorial for details.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the model\n",
"\n",
"We now define the `TransformerModel` and several functions that load and prepare the data. Finally, we arrive at the function `setup_and_train`, that defines, trains and evaluates the model, and takes the following parameters as input:\n",
"\n",
Expand Down Expand Up @@ -603,9 +610,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and compare three Transformer architectures:\n",
"\n",
"We can now easily use the function `setup_and_train` to run experiments by changing between Transformer architectures and modifying hyperparameters.\n",
"\n",
"First, let us use the `'post'` architecture that corresponds to a vanilla Transformer (i.e. we set `resweight = 1` and it is not trainable). Our experiment uses the Adagrad optimizer and no learning-rate warmup. For a 6 layer transformer network we observe slow training. "
"### Vanilla, or post-norm Transformer\n",
"\n",
"First, let us use the `'post'` architecture that corresponds to a vanilla Transformer (i.e. we set `resweight = 1` and it is not trainable). Our experiment uses the Adagrad optimizer and no learning-rate warmup. For a 6 layer transformer network we observe slow training. After three epochs achieves a validation ppl of around `168`.\n",
"\n",
"Although the mean squared singular values of the Jacobian remain close to one, the histogram shows a large spread. This indicates that some signals get amplified while others are attenuated, which is associated with poor trainng performance."
]
},
{
Expand Down Expand Up @@ -677,7 +690,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let us use the `'pre'` architecture that applies the `LayerNorm` before the residual connection. For the 6 layer Transformer network with otherwise identical hyperparameters we observe faster training."
"### Vanilla, or post-norm Transformer\n",
"\n",
"Next, let us use the `'pre'` architecture that applies the `LayerNorm` before the residual connection. For the 6 layer Transformer network with otherwise identical hyperparameters we observe faster training.\n",
"\n",
"Again, the mean squared singular values of the Jacobian remain close to one, and compared to the 'post' architecture, the histogram shows a smaller spread in the singular values. This coincides with somewhat better trainng performance."
]
},
{
Expand Down Expand Up @@ -749,7 +766,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we us use the `'ReZero'` architecture that eliminates the `LayerNorm` but set the residual weight initially to zero, and registers it as a trainable parameter. `ReZero` enables the use of a higher learning rate compared to the other architectures. For the 6 layer Transformer network with otherwise identical hyperparameters we observe the fastest training."
"### ReZero Transformer\n",
"\n",
"Finally, we us use the `'ReZero'` architecture that eliminates the `LayerNorm` but set the residual weight initially to zero, and registers it as a trainable parameter. `ReZero` enables the use of a higher learning rate compared to the other architectures. For the 6 layer Transformer network with otherwise identical hyperparameters we observe the fastest training.\n",
"\n",
"\n",
"The mean squared singular values of the Jacobian are very close to one and the histogram shows a very small spread in the singular values. This coincides with the best trainng performance observed in this comparison."
]
},
{
Expand Down Expand Up @@ -821,9 +843,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### 128 layer ReZero Transformer\n",
"\n",
"As promised in the title, we can use the `'ReZero'` architecture to train extremely deep Transformer networks. To render a `128` layer transformer tranable, we again reduce the learning rate (to the Adagrad default value of `lr = 0.01`).\n",
"\n",
"Training this 128 layer network takes about 20 minutes and after three epochs achieves a validation ppl of around `168`."
"Training this 128 layer network takes about 20 minutes and after three epochs achieves the best validation ppl of around `168`. \n",
"\n",
"Unfortunately, it would require too much memory to quickly evaluate the input-output Jacobian for this deep network."
]
},
{
Expand Down

0 comments on commit 5d2d1be

Please sign in to comment.