#  aitextgen — Train a GPT-2 Text-Generating Model w/ GPU

by [Max Woolf](https://minimaxir.com)

*Last updated: Jul 5th, 2020*

Retrain an advanced text-generating neural network on any text dataset **for free on a GPU in Colaboratory** using `aitextgen`!

For more about `aitextgen`, you can visit [this GitHub repository](https://github.com/minimaxir/aitextgen) or [read the documentation](https://docs.aitextgen.io/).


To get started:

1. Copy this notebook to your Google Drive to keep it and save your changes. (File -> Save a Copy in Drive)
2. Run the cells below:


In [None]:
!pip install -q aitextgen

import logging
logging.basicConfig(
        format="%(asctime)s — %(levelname)s — %(name)s — %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

from aitextgen import aitextgen
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive


## GPU

Colaboratory assigns one of several Nvidia GPUs, such as a P4, a T4, or a P100 (the sample output below shows a Tesla K80). For finetuning GPT-2 124M, any of these GPUs will be fine, but for text generation, a T4 or a P100 is ideal since they have more VRAM.

You can verify which GPU is active by running the cell below. If you want to try for a different GPU, go to **Runtime -> Factory Reset Runtime**.

In [None]:
!nvidia-smi

Mon Jan 11 14:18:09 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.27.04    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
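If you would rather read the GPU name and memory from Python than scan the table, one option is to ask `nvidia-smi` for CSV output. This is a minimal sketch: the helper names `parse_gpu_line` and `list_gpus` are hypothetical, and `list_gpus` returns an empty list when `nvidia-smi` is not available.

```python
import shutil
import subprocess

def parse_gpu_line(line):
    """Split one 'name, memory' CSV row from nvidia-smi into (name, MiB)."""
    name, memory = [part.strip() for part in line.rsplit(",", 1)]
    return name, int(memory.split()[0])

def list_gpus():
    """Return [(name, total_mib), ...], or [] when nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return []
    return [parse_gpu_line(row) for row in result.stdout.splitlines() if row.strip()]

for name, total_mib in list_gpus():
    print(f"{name}: {total_mib} MiB")
```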

## Loading GPT-2

If you're retraining a model on new text, you need to download and load the GPT-2 model into the GPU. 

There are several sizes of GPT-2, but aitextgen currently only works with the smallest one:

* `124M` (default): the "small" model, 500MB on disk.

The next cell downloads it from Google's servers and saves it in the Colaboratory VM. If the model has already been downloaded, running this cell will reload it.

In [None]:
ai = aitextgen(tf_gpt2="124M", to_gpu=True)

01/11/2021 14:18:47 — INFO — aitextgen — Downloading the 124M GPT-2 TensorFlow weights/config from Google's servers



01/11/2021 14:18:54 — INFO — aitextgen — Converting the 124M GPT-2 TensorFlow weights to PyTorch.





Converting TensorFlow checkpoint from /content/aitextgen/124M
Loading TF weight model/h0/attn/c_attn/b with shape [2304]
Loading TF weight model/h0/attn/c_attn/w with shape [1, 768, 2304]
Loading TF weight model/h0/attn/c_proj/b with shape [768]
Loading TF weight model/h0/attn/c_proj/w with shape [1, 768, 768]
Loading TF weight model/h0/ln_1/b with shape [768]
Loading TF weight model/h0/ln_1/g with shape [768]
Loading TF weight model/h0/ln_2/b with shape [768]
Loading TF weight model/h0/ln_2/g with shape [768]
Loading TF weight model/h0/mlp/c_fc/b with shape [3072]
Loading TF weight model/h0/mlp/c_fc/w with shape [1, 768, 3072]
Loading TF weight model/h0/mlp/c_proj/b with shape [768]
Loading TF weight model/h0/mlp/c_proj/w with shape [1, 3072, 768]
Loading TF weight model/h1/attn/c_attn/b with shape [2304]
Loading TF weight model/h1/attn/c_attn/w with shape [1, 768, 2304]
Loading TF weight model/h1/attn/c_proj/b with shape [768]
Loading TF weight model/h1/attn/c_proj/w with shape [1, 7

Save PyTorch model to aitextgen/pytorch_model.bin


01/11/2021 14:19:00 — INFO — aitextgen — Loading 124M GPT-2 model from /aitextgen.


Save configuration file to aitextgen/config.json


01/11/2021 14:19:06 — INFO — aitextgen — Using the default GPT-2 Tokenizer.
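As a sanity check on the "124M" name and the "500MB on disk" figure, the download reports the main checkpoint file (`model.ckpt.data-00000-of-00001`) as 497,759,232 bytes. The arithmetic below is a sketch that assumes every weight is stored as a 32-bit float, which is an assumption rather than something the log states.

```python
checkpoint_bytes = 497_759_232   # size reported for model.ckpt.data-00000-of-00001
bytes_per_weight = 4             # assumption: weights are stored as 32-bit floats

parameters = checkpoint_bytes // bytes_per_weight
size_mb = checkpoint_bytes / 1_000_000
size_mib = checkpoint_bytes / 2**20

print(f"{parameters:,} parameters")   # 124,439,808 parameters
print(f"{size_mb:.0f} MB ({size_mib:.0f} MiB)")
```

Under that assumption the file holds about 124 million weights, which matches the model name.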


## Mounting Google Drive

The best way to get training text into the Colaboratory VM, and to get the trained model *out* of Colaboratory, is to route it through Google Drive *first*.

Running this cell (which only works in Colaboratory) mounts your personal Google Drive in the VM so that later cells can move data in and out. It will ask for an auth code; that auth is not saved anywhere.

In [None]:
mount_gdrive()

Mounted at /content/drive


If your text file is large (>10MB), it is recommended to upload that file to Google Drive first, then copy that file from Google Drive to the Colaboratory VM.

Additionally, you may want to consider [compressing the dataset to a cache first](https://docs.aitextgen.io/dataset/) on your local computer, then uploading the resulting `dataset_cache.tar.gz` and pointing `file_name` at that file.

In [None]:
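import os
import json

# Folder on the mounted Google Drive that holds the JSON records.
data_dir = "/content/drive/MyDrive/NLP_scientific-text-generation/json"

abstracts = []
max_files = 100  # stop after reading this many JSON files
files_read = 0

for f_name in os.listdir(data_dir):
  if f_name.endswith(".json"):
    path = os.path.join(data_dir, f_name)
    with open(path, "r") as f:
      data = json.load(f)
    # Each record keeps its fields as {"key": ..., "value": ...} pairs.
    for obj in data["metadata"]:
      if obj["key"] == "dc.description.abstract":
        abstracts.append(obj["value"])

    files_read += 1
    if files_read >= max_files:
      break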
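The loop above assumes that each JSON file has a top-level `metadata` list whose entries are `key`/`value` pairs. The sketch below shows that shape on a made-up record and pulls the extraction step into a small helper; `extract_abstracts` and the sample record are hypothetical, and unlike the loop, the helper skips entries with missing fields instead of raising.

```python
import json

def extract_abstracts(record, key="dc.description.abstract"):
    """Return every metadata value stored under `key` in one parsed JSON record."""
    return [
        entry["value"]
        for entry in record.get("metadata", [])
        if entry.get("key") == key and entry.get("value")
    ]

# Hypothetical record, shaped the way the loading loop expects.
sample = json.loads("""
{
  "metadata": [
    {"key": "dc.title", "value": "An example title"},
    {"key": "dc.description.abstract", "value": "An example abstract."}
  ]
}
""")

print(extract_abstracts(sample))
```

A record without a matching key yields an empty list, which would explain a file count that is higher than the number of encoded texts.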

In [None]:
from aitextgen.TokenDataset import TokenDataset

# Encode the collected abstracts, one training text per abstract.
dataset = TokenDataset(texts=abstracts)

01/11/2021 14:24:51 — INFO — aitextgen.TokenDataset — Encoding 99 texts.






## Finetune GPT-2

The next cell starts the actual finetuning of GPT-2 in aitextgen. It runs for `num_steps`, and a progress bar shows training progress, the current loss (the lower the loss, the better the model fits the data), and the average loss (to give a sense of the loss trajectory).

The model is saved every `save_every` steps in `trained_model` by default, and again when training completes. If you mounted your Google Drive and set `save_gdrive=True`, the model will _also_ be saved there in a unique folder.

Training might time out after roughly 4 hours; if you did not mount Google Drive, make sure you end training and save the results so you don't lose them! If this happens frequently, you may want to consider using [Colab Pro](https://colab.research.google.com/signup).

Important parameters for `train()`:

- **`line_by_line`**: Set this to `True` if the input text file is a single-column CSV, with one record per row. aitextgen will automatically process it optimally.
- **`from_cache`**: If you compressed your dataset locally (as noted in the previous section) and are using that cache file, set this to `True`.
- **`num_steps`**: Number of steps to train the model for.
- **`generate_every`**: Interval of steps to generate example text from the model; good for qualitatively validating training.
- **`save_every`**: Interval of steps to save the model; the model will be saved in the VM to `/trained_model`.
- **`save_gdrive`**: Set this to `True` to copy the model to a unique folder in your Google Drive, if you have mounted it in the earlier cells.

Here are other parameters for `train()` that are useful but that you likely do not need to change:

- **`learning_rate`**: Learning rate of the model training.
- **`batch_size`**: Batch size of the model training; setting it too high will cause the GPU to run out of memory (OOM).
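To see how `num_steps`, `generate_every`, and `save_every` interact, the sketch below lists the steps at which samples and checkpoints would be produced. It assumes both events fire whenever the step count is a multiple of the interval, which is consistent with the "100 steps reached" messages in the training log; `event_steps` is a hypothetical helper, not part of aitextgen.

```python
def event_steps(num_steps, every):
    """Steps in 1..num_steps that are multiples of `every`."""
    return [step for step in range(1, num_steps + 1) if step % every == 0]

num_steps, generate_every, save_every = 1000, 100, 100

saves = event_steps(num_steps, save_every)
samples = event_steps(num_steps, generate_every)

print(f"{len(saves)} checkpoints, first at step {saves[0]}, last at step {saves[-1]}")
print(f"{len(samples)} sample generations")
```

With these values a timeout costs at most 100 steps of progress; a larger `save_every` saves less often and risks losing more.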

In [None]:
ai.train(dataset,
         line_by_line=False,
         from_cache=False,
         num_steps=1000,
         generate_every=100,
         save_every=100,
         save_gdrive=False,
         learning_rate=1e-4,
         batch_size=1, 
         )

GPU available: True, used: True
01/11/2021 14:26:08 — INFO — lightning — GPU available: True, used: True
TPU available: None, using: 0 TPU cores
01/11/2021 14:26:08 — INFO — lightning — TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
01/11/2021 14:26:08 — INFO — lightning — LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



100 steps reached: saving model to /trained_model
100 steps reached: generating sample texts.
 and the world's leading scientific minds, editors, actors and producers. Our focus is on finding the right balance between providing the highest level of quality products at the right time and ensuring that the right price is applied when it comes to the right use cases. The results of our research suggest that consumers should be aware of the different types of products they are buying and can be influenced by their preferences. The research findings are based on a survey of more than 100,000 consumers across the country. The survey was conducted by commercial bank Thomson Reuters, and surveyed members of the public about their overall consumer preferences. The survey results are based on the sample of Iasi, Germany, taken on September 22, 2015. For the sample of German consumers, the standard error of the survey was 4.3%. For all consumers, the standard error was 0.3%. The a

01/11/2021 14:39:46 — INFO — aitextgen — Saving trained model pytorch_model.bin to /trained_model
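The log timestamps give a rough way to budget a run against the session timeout: training started at 14:26:08 and the final save happened at 14:39:46 for 1000 steps. The sketch below turns that into seconds per step and an approximate step budget. The four-hour limit is an assumption, and the per-step figure includes the time spent on sample generation and saving.

```python
from datetime import datetime

FMT = "%m/%d/%Y %H:%M:%S"  # same datefmt as the logging configuration
start = datetime.strptime("01/11/2021 14:26:08", FMT)
end = datetime.strptime("01/11/2021 14:39:46", FMT)

num_steps = 1000
elapsed = (end - start).total_seconds()
seconds_per_step = elapsed / num_steps

session_limit = 4 * 60 * 60  # assumption: roughly four hours, in seconds
max_steps = int(session_limit / seconds_per_step)

print(f"{elapsed:.0f} s total, {seconds_per_step:.3f} s/step, about {max_steps} steps in 4 h")
```

The estimate only applies to the GPU, batch size, and dataset used in this run.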


You're done! Feel free to go to the **Generate Text From The Trained Model** section to generate text based on your retrained model.


## Load a Trained Model

Running the next cell will copy the `pytorch_model.bin` and `config.json` files from the specified folder in Google Drive into the Colaboratory VM. (If no `from_folder` is specified, it assumes the two files are located at the root level of your Google Drive.)

In [None]:
from_folder = None

for file in ["pytorch_model.bin", "config.json"]:
  if from_folder:
    copy_file_from_gdrive(file, from_folder)
  else:
    copy_file_from_gdrive(file)
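Before loading, it can help to confirm that both files actually arrived. The sketch below assumes they were copied into the current working directory, as the relative paths used when loading the model suggest; `missing_files` is a hypothetical helper.

```python
from pathlib import Path

REQUIRED = ("pytorch_model.bin", "config.json")

def missing_files(folder=".", required=REQUIRED):
    """Return the required file names that are not present in `folder`."""
    return [name for name in required if not (Path(folder) / name).is_file()]

missing = missing_files()
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("Both model files are present.")
```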

The next cell will allow you to load the retrained model + metadata necessary to generate text.

In [None]:
ai = aitextgen(model="pytorch_model.bin", config="config.json", to_gpu=True)

05/17/2020 20:27:21 — INFO — aitextgen — Loading GPT-2 model from provided pytorch_model.bin.
05/17/2020 20:27:26 — INFO — aitextgen — Using the default GPT-2 Tokenizer.


## Generate Text From The Trained Model

After you've trained the model or loaded a retrained model from checkpoint, you can now generate text. `generate()` without any parameters generates a single text from the loaded model to the console.

In [None]:
ai.generate()

[1m[0m>The proliferation of storage infrastructure in the EU in the past decade has limited the ability to develop alternative district heating and cooling systems. Energy infrastructure innovation centres are sparse and feature limited specialization and competition. Large companies tend to follow the 'closed innovation' model where R&D activities are concentrated within an organization, and focus on incremental innovations while lagging in radical innovations in cogeneration and trigeneration. Under these conditions, short-term planning dominates, while mid-term planning is virtually non-existent. The paper concludes with recommended measures to support the innovative development of Russian heating companies that can be split into institutional and corporate recommendations. The first group concerns stimulating competition in the heat supply market and creating a stable legal and investment environment. The second group calls for technological modernization, development of long-ter

If you're creating an API based on your model and need to pass the generated text elsewhere, you can do `text = ai.generate_one()`.
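As a sketch of what passing the text elsewhere might look like, the function below wraps a generator call in a JSON payload. `build_response` and `fake_generate_one` are hypothetical names; the stand-in generator only exists so the example runs without a loaded model.

```python
import json

def build_response(generate_one, prompt=""):
    """Call a text generator and wrap the result in a JSON string."""
    text = generate_one(prompt=prompt)
    return json.dumps({"prompt": prompt, "text": text})

# Stand-in generator so the sketch runs without a loaded model.
def fake_generate_one(prompt=""):
    return prompt + " ..."

print(build_response(fake_generate_one, prompt="This paper"))

# In the notebook you would pass the real method instead:
# build_response(ai.generate_one, prompt="This paper")
```

Taking the generator as an argument keeps the wrapper testable without a GPU.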

You can also pass in a `prompt` to the generate function to force the text to start with a given character sequence and generate text from there (good if you add an indicator when the text starts).

You can also generate multiple texts at a time by specifying `n`. You can pass a `batch_size` to generate multiple samples in parallel, giving a massive speedup (in Colaboratory, set a maximum of 50 for `batch_size` to avoid running out of memory).

Other optional-but-helpful parameters for `ai.generate()` and friends:

* **`max_length`**: Number of tokens to generate (default 256; you can generate up to 1024 tokens with GPT-2, but it will be _much_ slower)
* **`temperature`**: The higher the temperature, the crazier the text (default 0.7, recommended to keep between 0.7 and 1.0)
* **`top_k`**: Limits the generated guesses to the top *k* guesses (default 0 which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`)
* **`top_p`**: Nucleus sampling: limits the generated guesses to a cumulative probability. (gets good results on a dataset with `top_p=0.9`)
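To build intuition for `temperature`, `top_k`, and `top_p`, the sketch below applies each one to a made-up four-token distribution in plain Python. It illustrates the general techniques only and is not aitextgen's implementation; all function names and the example scores are hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    """Turn raw scores into probabilities; lower temperature sharpens them."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def renormalize(probs, keep):
    """Zero out every index not in `keep`, then rescale to sum to 1."""
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]

def top_k_filter(probs, k):
    """Keep the k most likely tokens (k <= 0 keeps all of them)."""
    if k <= 0 or k >= len(probs):
        return list(probs)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return renormalize(probs, set(order[:k]))

def top_p_filter(probs, p):
    """Keep the smallest set of most likely tokens whose total reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return renormalize(probs, keep)

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a four-token vocabulary
for t in (0.7, 1.0):
    print("temperature", t, [round(q, 3) for q in softmax(logits, t)])
print("top_k=2  ", [round(q, 3) for q in top_k_filter(softmax(logits), 2)])
print("top_p=0.9", [round(q, 3) for q in top_p_filter(softmax(logits), 0.9)])
```

A lower temperature concentrates probability on the most likely token, while both filters remove the unlikely tail before sampling.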

In [None]:
ai.generate(n=5,
            batch_size=5,
            prompt="This paper investigates the macroeconomic implications of",
            max_length=256,
            temperature=0.7,
            top_p=0.9)

This paper investigates the macroeconomic implications of the dual nature of relations between Russia and Ukraine in the context of DCFTA’s implementation. We are interested to encounter the point of contact and the interference areas to which Ukraine is referring to secure its balance, considering that its economy and market are traditionally oriented toward Russia and DCFTA is reorienting it toward EU and in this context, to show that the Russian economy is not, fundamentally, affected. Based on the statistical data that provides the economic situation of the two countries (import, export, FDI, labour migrants) before and after signing the treaty, we will analse the manner in which DCFTA affects the Russian economy. The findings indicate that the implementation of the DCFTA does not have the potential to cause instability in trade relations between Ukraine and Russia, and that the hostile actions of Russia against the DCFTA has the potential to cause instability in trade rela

For bulk generation, you can generate a large number of texts to a file and sort out the samples locally on your computer. The next cell will generate `num_files` files, each with `n` texts and whatever other parameters you would pass to `generate()`. The files can then be downloaded from the Files sidebar!

You can rerun the cells as many times as you want for even more generated texts!

In [None]:
num_files = 5

for _ in range(num_files):
  ai.generate_to_file(n=1000,
                     batch_size=50,
                     prompt="ROMEO:",
                     max_length=256,
                     temperature=1.0,
                     top_p=0.9)

05/18/2020 00:06:16 — INFO — aitextgen — Generating 1,000 texts to ATG_20200518_000616_78407959.txt

05/18/2020 00:08:30 — INFO — aitextgen — Generating 1,000 texts to ATG_20200518_000830_67373043.txt

05/18/2020 00:10:44 — INFO — aitextgen — Generating 1,000 texts to ATG_20200518_001044_88174494.txt

05/18/2020 00:12:58 — INFO — aitextgen — Generating 1,000 texts to ATG_20200518_001258_27768790.txt

05/18/2020 00:15:12 — INFO — aitextgen — Generating 1,000 texts to ATG_20200518_001512_83597962.txt
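Once the files are downloaded, sorting out the samples locally could start with something like the sketch below, which splits each file into individual texts and drops exact duplicates. The separator of twenty `=` characters is an assumption about how the samples are delimited, so check one of your files first; `split_samples` and `load_samples` are hypothetical helpers.

```python
from pathlib import Path

# Assumption: samples are separated by a line of twenty "=" characters.
SAMPLE_DELIM = "=" * 20 + "\n"

def split_samples(raw, delim=SAMPLE_DELIM):
    """Split the contents of one generated file into unique, non-empty texts."""
    seen, unique = set(), []
    for chunk in raw.split(delim):
        text = chunk.strip()
        if text and text not in seen:
            seen.add(text)
            unique.append(text)
    return unique

def load_samples(folder=".", pattern="ATG_*.txt"):
    """Collect samples from every file in `folder` matching `pattern`."""
    samples = []
    for path in sorted(Path(folder).glob(pattern)):
        samples.extend(split_samples(path.read_text(encoding="utf-8")))
    return samples

print(len(load_samples()), "samples found")
```

Duplicates are removed within each file only; pass the combined list through a `set` if you also want to deduplicate across files.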




# LICENSE

MIT License

Copyright (c) 2020 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.