#  aitextgen — Train a GPT-2 (or GPT Neo) Text-Generating Model w/ GPU

by [Max Woolf](https://minimaxir.com)

*Last updated: May 16th, 2021 (aitextgen v0.5.2)*

Retrain an advanced text generating neural network on any text dataset **for free on a GPU using Colaboratory** using `aitextgen`!

For more about `aitextgen`, you can visit [this GitHub repository](https://github.com/minimaxir/aitextgen) or [read the documentation](https://docs.aitextgen.io/).


To get started:

1. Copy this notebook to your Google Drive to keep it and save your changes. (File -> Save a Copy in Drive)
2. Run the cells below:


In [None]:
!pip install -q aitextgen

import logging
logging.basicConfig(
        format='%(asctime)s — %(levelname)s — %(name)s — %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )

from aitextgen import aitextgen
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive


04/11/2022 14:17:53 — INFO — numexpr.utils — NumExpr defaulting to 2 threads.


## GPU

Colaboratory typically assigns an Nvidia P4, T4, P100, or V100 (some sessions receive an older GPU, such as the Tesla K80 shown in the output below). For finetuning GPT-2 124M, any of these GPUs will be fine, but for text generation, a T4 or a P100 is ideal since they have more VRAM. **If you receive a T4 or a V100 GPU, you can enable `fp16=True` during training for faster/more memory-efficient training.**

You can verify which GPU is active by running the cell below. If you want to try for a different GPU, go to **Runtime -> Factory Reset Runtime**.

In [None]:
!nvidia-smi

Mon Apr 11 14:17:57 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8    27W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
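
The `fp16` rule above can be turned into a small check. This is a minimal sketch, not part of aitextgen: `supports_fp16` and `detect_gpu_name` are hypothetical helper names, and the check only matches the GPU names this notebook mentions (T4 and V100). It uses the `nvidia-smi --query-gpu=name --format=csv,noheader` query and falls back to an empty name when no GPU tooling is present.

```python
import shutil
import subprocess

# GPU families the text above names as suitable for fp16 training.
FP16_GPUS = ("T4", "V100")


def supports_fp16(gpu_name: str) -> bool:
    """Return True if the GPU name mentions a T4 or V100."""
    return any(tag in gpu_name.upper() for tag in FP16_GPUS)


def detect_gpu_name() -> str:
    """Ask nvidia-smi for the first GPU name; return '' if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return ""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True, text=True, check=False,
    )
    lines = result.stdout.strip().splitlines()
    if result.returncode != 0 or not lines:
        return ""
    return lines[0].strip()


gpu_name = detect_gpu_name()
use_fp16 = supports_fp16(gpu_name)
print(f"GPU: {gpu_name or 'none detected'}; fp16={use_fp16}")
```

The resulting `use_fp16` value could then be passed as `fp16=use_fp16` in the training cell instead of hard-coding it.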

## Loading GPT-2 or GPT Neo

If you're retraining a model on new text, you need to download and load the GPT-2 model into the GPU. 

There are several sizes of GPT-2:

* `124M` (default): the "small" model, 500MB on disk.
* `355M`: the "medium" model, 1.5GB on disk.
* `774M`: the "large" model, 3GB on disk.

You can also finetune a GPT Neo model instead, which is more suitable for longer texts and whose base model was trained on more recent data:

* `125M`: Analogous to the GPT-2 124M model.
* `350M`: Analogous to the GPT-2 355M model.

The next cell downloads the model and saves it in the Colaboratory VM. If the model has already been downloaded, running this cell will reload it.

In [None]:
#model='124M'
#model='355M'
#model='774M'

model='gpt-neo-125M'
#model='gpt-neo-350M'

In [None]:
if model in ('124M', '355M', '774M'):
    ai = aitextgen(tf_gpt2=model, to_gpu=True)
else:
    ai = aitextgen(model='EleutherAI/' + model, to_gpu=True)

04/11/2022 14:17:57 — INFO — aitextgen — Downloading EleutherAI/gpt-neo-125M model to /aitextgen.
https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json not found in cache or force_download set to True, downloading to /content/aitextgen/tmpo42j5v_z


Downloading:   0%|          | 0.00/0.98k [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json in cache at aitextgen/29380fef22a43cbfb3d3a6c8e2f4fd951459584d87c34e4621b30580a54aca84.f0f7ebddfc6e15a23ac33e7fa95cd8cca05edf87cc74f9e3be7905f538a59762
creating metadata file for aitextgen/29380fef22a43cbfb3d3a6c8e2f4fd951459584d87c34e4621b30580a54aca84.f0f7ebddfc6e15a23ac33e7fa95cd8cca05edf87cc74f9e3be7905f538a59762
https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/pytorch_model.bin not found in cache or force_download set to True, downloading to /content/aitextgen/tmptgiu0x2h


Downloading:   0%|          | 0.00/502M [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/pytorch_model.bin in cache at aitextgen/b0ace3b93ace62067a246888f1e54e2d3ec20807d4d3e27ac602eef3b7091c0b.6525df88f1d5a2d33d95ce2458ef6af9658fe7d1393d6707e0e318779ccc68ff
creating metadata file for aitextgen/b0ace3b93ace62067a246888f1e54e2d3ec20807d4d3e27ac602eef3b7091c0b.6525df88f1d5a2d33d95ce2458ef6af9658fe7d1393d6707e0e318779ccc68ff
04/11/2022 14:18:21 — INFO — aitextgen — Using the tokenizer for EleutherAI/gpt-neo-125M.
https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/tokenizer_config.json not found in cache or force_download set to True, downloading to /content/aitextgen/tmpr1qinn8h


Downloading:   0%|          | 0.00/560 [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/tokenizer_config.json in cache at aitextgen/3cc88b3aa29bb2546db2dc21783292e2a086bb7158c7b5ceddeb24158a85c183.e74f7c3643ee79eb023ead36008be72fe726dada60fa3b2a0569925cfefa1e74
creating metadata file for aitextgen/3cc88b3aa29bb2546db2dc21783292e2a086bb7158c7b5ceddeb24158a85c183.e74f7c3643ee79eb023ead36008be72fe726dada60fa3b2a0569925cfefa1e74
https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json not found in cache or force_download set to True, downloading to /content/aitextgen/tmpkf5z45b8


Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json in cache at aitextgen/08c00c4159e921d4c941ac75732643373aba509d9b352a82bbbb043a94058d98.a552555fdda56a1c7c9a285bccfd44ac8e4b9e26c8c9b307831b3ea3ac782b45
creating metadata file for aitextgen/08c00c4159e921d4c941ac75732643373aba509d9b352a82bbbb043a94058d98.a552555fdda56a1c7c9a285bccfd44ac8e4b9e26c8c9b307831b3ea3ac782b45
https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt not found in cache or force_download set to True, downloading to /content/aitextgen/tmp01pbbbez


Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt in cache at aitextgen/12305762709d884a770efe7b0c68a7f4bc918da44e956058d43da0d12f7bea20.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
creating metadata file for aitextgen/12305762709d884a770efe7b0c68a7f4bc918da44e956058d43da0d12f7bea20.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/special_tokens_map.json not found in cache or force_download set to True, downloading to /content/aitextgen/tmp61ki1unf


Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/special_tokens_map.json in cache at aitextgen/6c3239a63aaf46ec7625b38abfe41fc2ce0b25f90800aefe6526256340d4ab6d.2b8bf81243d08385c806171bc7ced6d2a0dcc7f896ca637f4e777418f7f0cc3c
creating metadata file for aitextgen/6c3239a63aaf46ec7625b38abfe41fc2ce0b25f90800aefe6526256340d4ab6d.2b8bf81243d08385c806171bc7ced6d2a0dcc7f896ca637f4e777418f7f0cc3c
loading file https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/vocab.json from cache at aitextgen/08c00c4159e921d4c941ac75732643373aba509d9b352a82bbbb043a94058d98.a552555fdda56a1c7c9a285bccfd44ac8e4b9e26c8c9b307831b3ea3ac782b45
loading file https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/merges.txt from cache at aitextgen/12305762709d884a770efe7b0c68a7f4bc918da44e956058d43da0d12f7bea20.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
loading file https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/tokenizer.json from cache at None
loadin

Downloading:   0%|          | 0.00/0.98k [00:00<?, ?B/s]

storing https://huggingface.co/EleutherAI/gpt-neo-125M/resolve/main/config.json in cache at /root/.cache/huggingface/transformers/29380fef22a43cbfb3d3a6c8e2f4fd951459584d87c34e4621b30580a54aca84.f0f7ebddfc6e15a23ac33e7fa95cd8cca05edf87cc74f9e3be7905f538a59762
creating metadata file for /root/.cache/huggingface/transformers/29380fef22a43cbfb3d3a6c8e2f4fd951459584d87c34e4621b30580a54aca84.f0f7ebddfc6e15a23ac33e7fa95cd8cca05edf87cc74f9e3be7905f538a59762
04/11/2022 14:18:24 — INFO — aitextgen — GPTNeo loaded with 125M parameters.


In [None]:
folder = 'trained_model_' + model

## Mounting Google Drive

The best way to get the training text into the Colaboratory VM, and to get the trained model *out* of Colaboratory, is to route it through Google Drive *first*.

Running this cell (which only works in Colaboratory) mounts your personal Google Drive in the VM, which later cells can use to get data in and out. (It will ask for an auth code; that auth is not saved anywhere.)

In [None]:
mount_gdrive()

Mounted at /content/drive


## Uploading a Text File to be Trained to Colaboratory

In the Colaboratory Notebook sidebar on the left of the screen, select *Files*. From there you can upload files:

![alt text](https://i.imgur.com/w3wvHhR.png)

Upload **any smaller text file** (for example, [a text file of Shakespeare plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)) and update the file name in the cell below, then run the cell.

In [None]:
file_name = 'content2.txt'

If your text file is large (>10MB), it is recommended to upload that file to Google Drive first, then copy that file from Google Drive to the Colaboratory VM.

Additionally, you may want to consider [compressing the dataset to a cache first](https://docs.aitextgen.io/dataset/) on your local computer, then uploading the resulting `dataset_cache.tar.gz` and setting the `file_name` in the previous cell to that.

In [None]:
copy_file_from_gdrive(file_name)
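
To decide whether a file counts as "large" under the >10MB rule of thumb above, a size check is enough. A minimal sketch using only the standard library: `is_large_file` is a hypothetical helper, and reading 10MB as 10 × 1024 × 1024 bytes is an assumption.

```python
import os
import tempfile

# Assumption: ">10MB" is read as 10 * 1024 * 1024 bytes.
LARGE_FILE_BYTES = 10 * 1024 * 1024


def is_large_file(path: str, threshold: int = LARGE_FILE_BYTES) -> bool:
    """Return True if the file at path is bigger than the threshold."""
    return os.path.getsize(path) > threshold


# Demonstration on a small temporary file standing in for a training text.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as handle:
    handle.write("To be, or not to be\n" * 100)
    demo_path = handle.name

print(is_large_file(demo_path))  # False: the demo file is only about 2KB
os.remove(demo_path)
```

If the check returns `True` for your dataset, upload it to Google Drive and use the copy cell above rather than the Files sidebar.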

## Finetune GPT-2

The next cell starts the actual finetuning of the model in aitextgen. It runs for `num_steps` steps, and a progress bar shows training progress, the current loss (lower is better), and the average loss (to give a sense of the loss trajectory).

The model will be saved every `save_every` steps in `trained_model` by default, and when training completes. If you mounted your Google Drive, the model will _also_ be saved there in a unique folder.

The training might time out after about 4 hours; if you did not mount Google Drive, make sure you end training and save the results so you don't lose them! (If this happens frequently, you may want to consider using [Colab Pro](https://colab.research.google.com/signup).)

Important parameters for `train()`:

- **`line_by_line`**: Set this to `True` if the input text file is a single-column CSV, with one record per row. aitextgen will automatically process it optimally.
- **`from_cache`**: If you compressed your dataset locally (as noted in the previous section) and are using that cache file, set this to `True`.
- **`num_steps`**: Number of steps to train the model for.
- **`generate_every`**: Interval of steps to generate example text from the model; good for qualitatively validating training.
- **`save_every`**: Interval of steps to save the model: the model will be saved in the VM to `/trained_model`.
- **`save_gdrive`**: Set this to `True` to copy the model to a unique folder in your Google Drive, if you mounted it in the earlier cells.
- **`fp16`**: Enables half-precision training for faster/more memory-efficient training. Only works on a T4 or V100 GPU.

Here are other important parameters for `train()` that are useful but you likely do not need to change.

- **`learning_rate`**: Learning rate of the model training.
- **`batch_size`**: Batch size of the model training; setting it too high will cause the GPU to go OOM. (if using `fp16`, you can increase the batch size more safely)
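
To see how `num_steps`, `save_every`, and `generate_every` interact, the sketch below lists the steps at which checkpoints and sample texts would appear. It assumes both events fire at exact multiples of their interval, which matches log lines such as "200 steps reached" in this notebook's output but is not taken from the library source; `training_schedule` is a hypothetical helper.

```python
def training_schedule(num_steps, save_every, generate_every):
    """Return (save_steps, generate_steps) as lists of step numbers."""
    # Assumption: an event fires whenever the step is a multiple of its interval.
    save_steps = list(range(save_every, num_steps + 1, save_every))
    generate_steps = list(range(generate_every, num_steps + 1, generate_every))
    return save_steps, generate_steps


# Values used in the training cell of this notebook.
save_steps, generate_steps = training_schedule(
    num_steps=2000, save_every=200, generate_every=400
)
print(len(save_steps), "checkpoints:", save_steps)
print(len(generate_steps), "sample rounds:", generate_steps)
```

With these values there are 10 checkpoints and 5 rounds of sample text, so a session that times out loses at most 200 steps of progress.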

In [None]:
ai.train(file_name,
         line_by_line=False,
         from_cache=False,
         num_steps=2000,
         generate_every=400,
         save_every=200,
         save_gdrive=True,
         output_dir=folder,
         learning_rate=1e-3,
         fp16=False,
         batch_size=1)

04/11/2022 14:20:48 — INFO — aitextgen — Loading text from content2.txt with generation length of 2048.


  0%|          | 0/507171 [00:00<?, ?it/s]

04/11/2022 14:20:48 — INFO — aitextgen.TokenDataset — Encoding 507,171 sets of tokens from content2.txt.
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
  "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
04/11/2022 14:21:10 — INFO — pytorch_lightning.utilities.rank_zero — GPU available: True, used: True
04/11/2022 14:21:10 — INFO — pytorch_lightning.utilities.rank_zero — TPU available: False, using: 0 TPU cores
04/11/2022 14:21:10 — INFO — pytorch_lightning.utilities.rank_zero — IPU available: False, using: 0 IPUs
04/11/2022 14:21:10 — INFO — pytorch_lightning.utilities.rank_zero — HPU available: False, using: 0 HPUs
  f"The `Callback.{hook}` hook was deprecated in v1.6 and"
04/11/2022 14:21:10 — INFO — pyto

  0%|          | 0/2000 [00:00<?, ?it/s]

  "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."


200 steps reached: saving model to /trained_model_gpt-neo-125M
400 steps reached: saving model to /trained_model_gpt-neo-125M
400 steps reached: generating sample texts.


We saw few details about the operation of the malware.

The malware is not available here.

The malware is not available in any other Windows XP, but it's likely to have been used in other
compromises.

Figure 9.

The malware was used as a Trojan for a global campaign of a Chinese-language
websites. We can identify these attacks as the same or possibly been used as the Trojan for
the following attacks.

Figure 10.

The malware is probably used in other countries for cyber attacks against
the Chinese-language infrastructure industry. In this case, we identify the attackers as
the same threat actor for the Russian-language infrastructure market.

Figure 11.

The malware was used in multiple attacks against the Polish
organizations.

Figure 11.

Analysis of the malware suggests that this malware 

04/11/2022 15:14:23 — INFO — aitextgen — Saving trained model pytorch_model.bin to /trained_model_gpt-neo-125M


You're done! Feel free to go to the **Generate Text From The Trained Model** section to generate text based on your retrained model.


## Load a Trained Model

If you already have a trained model from this notebook, running the next cell will copy the `pytorch_model.bin` and `config.json` files from the specified folder in Google Drive into the Colaboratory VM. (If no `from_folder` is specified, it assumes the two files are located at the root level of your Google Drive.)

In [None]:
#copy_file_from_gdrive('pytorch_model.bin', folder)
#copy_file_from_gdrive('config.json', folder)
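
After copying, it can help to confirm that both files actually landed in the folder before trying to load the model. A minimal sketch using only the standard library; `missing_model_files` is a hypothetical helper, not part of aitextgen, and the demonstration uses a throwaway folder rather than a real model.

```python
from pathlib import Path
import tempfile

# The two files the text above says a trained model needs.
REQUIRED_FILES = ("pytorch_model.bin", "config.json")


def missing_model_files(folder):
    """Return the required file names that are not present in folder."""
    return [name for name in REQUIRED_FILES if not (Path(folder) / name).is_file()]


# Demonstration: a temporary folder that only contains config.json.
with tempfile.TemporaryDirectory() as demo_folder:
    (Path(demo_folder) / "config.json").write_text("{}")
    print(missing_model_files(demo_folder))  # ['pytorch_model.bin']
```

In the notebook you would call `missing_model_files(folder)` and only load the model when the returned list is empty.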

The next cell will allow you to load the retrained model + metadata necessary to generate text.

## Generate Text From The Trained Model

After you've trained the model or loaded a retrained model from checkpoint, you can now generate text.

**If you just trained a model**, you'll get much faster generation performance if you reload the model; the next cell will reload the model you just trained from the folder it was saved to (`trained_model` by default).

In [None]:
ai = aitextgen(model_folder=folder, to_gpu=True)

04/11/2022 15:14:28 — INFO — aitextgen — Loading model from provided weights and config in /trained_model_gpt-neo-125M.
04/11/2022 15:14:31 — INFO — aitextgen — GPTNeo loaded with 125M parameters.
04/11/2022 15:14:31 — INFO — aitextgen — Using the default GPT-2 Tokenizer.


`generate()` without any parameters generates a single text from the loaded model to the console.

In [None]:
ai.generate()

the cyber criminal unit.

On July 16, KrebsOnSecurity published the following statement about a cybercrime group called “Babar” — a crime group that has since been named “” (noteed that Babar’s malware
dropper had been named “Babar”).

Like most of the people who specialize in making money, Babar has been one of the most active players, and that the group was among the most active players and has it all over the world.

The group has been part of the group, or most of the members, and the group has been active since at least 2013.

According to a statement from KrebsOnSecurity, Babar is a cybercrime group that has since been named “Babar”.

In early versions, the group had the group’s most active players, and that those who work in the cybercriminal community had the group’s key employees work in the cybercrime community.

“Babar has a number of features to make it difficult to find new ways to do business with the group, and it is a small group of people who are working


If you're creating an API based on your model and need to pass the generated text elsewhere, you can do `text = ai.generate_one()`.

You can also pass in a `prompt` to the generate function to force the text to start with a given character sequence and generate text from there (good if you add an indicator when the text starts).

You can also generate multiple texts at a time by specifying `n`. You can pass a `batch_size` to generate multiple samples in parallel, giving a massive speedup (in Colaboratory, set a maximum of 50 for `batch_size` to avoid going OOM).

Other optional-but-helpful parameters for `ai.generate()` and friends:

* **`min_length`**: The minimum length of the generated text: if the text is shorter than this value after cleanup, aitextgen will generate another one.
* **`max_length`**: Number of tokens to generate (default 256; you can generate up to 1024 tokens with GPT-2 and 2048 with GPT Neo).
* **`temperature`**: The higher the temperature, the crazier the text (default 0.7, recommended to keep between 0.7 and 1.0)
* **`top_k`**: Limits the generated guesses to the top *k* guesses (default 0 which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`)
* **`top_p`**: Nucleus sampling: limits the generated guesses to a cumulative probability. (gets good results on a dataset with `top_p=0.9`)
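
To make `temperature`, `top_k`, and `top_p` concrete, here is a toy sketch of how the three settings reshape a next-token distribution. It is an illustration in plain Python under stated assumptions, not the code aitextgen or its underlying libraries run: the logits are made up, and `softmax` and `filter_candidates` are hypothetical helpers.

```python
import math


def softmax(logits, temperature=1.0):
    """Convert logits to probabilities; higher temperature flattens them."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def filter_candidates(probs, top_k=0, top_p=1.0):
    """Apply top-k, then top-p; return {token_index: renormalized prob}."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    if top_k > 0:  # top_k=0 disables the cutoff, as in the list above
        ranked = ranked[:top_k]
    mass = sum(prob for _, prob in ranked)
    kept, cumulative = [], 0.0
    for index, prob in ranked:
        kept.append((index, prob))
        cumulative += prob / mass
        if cumulative >= top_p:  # smallest set reaching the target mass
            break
    kept_mass = sum(prob for _, prob in kept)
    return {index: prob / kept_mass for index, prob in kept}


toy_logits = [2.0, 1.0, 0.0, -1.0]  # made-up scores for four tokens
probs = softmax(toy_logits, temperature=1.0)
print(sorted(filter_candidates(probs, top_k=2)))    # [0, 1]
print(sorted(filter_candidates(probs, top_p=0.9)))  # [0, 1, 2]
```

Raising the temperature spreads probability toward the unlikely tokens, while `top_k` and `top_p` remove the tail entirely, which is why they help when the output is too erratic.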

In [None]:
prompts = ['Digital Forensics Analysis Report\n',
           'This report is ',
           'The contents of ',
           'Conclusion\n',
           'It is recommended that ',
           'In the opinion of the expert, ']

In [None]:
for prompt in prompts:
    ai.generate(n=5,
                batch_size=1,
                prompt=prompt,
                max_length=1000,
                temperature=1.0,
                top_p=0.9)

Digital Forensics Analysis Report
(Figure 7)

The Web site “sham.vaup” looks like a legitimate Web site. The Web site was registered to “sham.vaup”, and “sham.vaup”
is “sham.vaup”. Thus, this Web site looks like a legitimate Web site. The Web site was registered to “sham.vaup”. Thus, the Web site appears to be the Web site with a Web site called “sham.vaup”. In contrast, the Web site was registered to “sham.vaup”.

Figure 7 // The Web site “sham.vaup” Website

Figure 7 // The Web site “sham.vaup” Web site

Figure 8 // The Web site “sham.vaup” Website

Figure 9 // The Web site “sham.vaup” Web site

Figure 9 // The Web site “sham.vaup” Web site

Figure 10 // The Web site “sham.vaup” Web site

Figure 9 // The Web site “sham.vaup” Web site

Figure 10 // The Web site “sham.vaup” Web site

Figure 11 // The Web site “sham.vaup” Web site

Figure 12 // The Web site “sham.vaup” Web site

Figure 13 // The Web site “sham.vaup” Web site

Figure 14 // The Web site “sham.vaup” Web site

Figu

For bulk generation, you can generate a large number of texts to a file and sort out the samples locally on your computer. The next cell will generate `num_files` files per prompt, each with `n` texts and whatever other parameters you would pass to `generate()`. The files can then be downloaded from the Files sidebar!

You can rerun the cells as many times as you want for even more generated texts!

In [None]:
num_files = 0

for prompt in prompts:
    for _ in range(num_files):
        ai.generate_to_file(n=200,
                            batch_size=1,
                            prompt=prompt,
                            max_length=2000,
                            temperature=1.0,
                            top_p=0.9)
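
Before running the bulk cell, it is worth working out how much it will produce. The sketch below does the arithmetic for the nested loop, assuming each `generate_to_file` call writes one file containing `n` texts; `bulk_generation_plan` is a hypothetical helper.

```python
def bulk_generation_plan(num_prompts, num_files, n):
    """Return (files_written, texts_generated) for the nested loop above."""
    # Assumption: one file per generate_to_file call, with n texts in each.
    files_written = num_prompts * num_files
    return files_written, files_written * n


# The cell as written: six prompts, num_files = 0, n = 200.
print(bulk_generation_plan(num_prompts=6, num_files=0, n=200))  # (0, 0)
# Raising num_files to 2 would give 12 files and 2400 texts.
print(bulk_generation_plan(num_prompts=6, num_files=2, n=200))  # (12, 2400)
```

Note that with `num_files = 0` the inner loop never runs, so the cell generates nothing until you raise that value.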

# LICENSE

MIT License

Copyright (c) 2020-2021 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.