# aitextgen — Train a GPT-2 (or GPT Neo) Text-Generating Model w/ GPU

by [Max Woolf](https://minimaxir.com)

*Last updated: May 16th, 2021 (aitextgen v0.5.2)*

Retrain an advanced text-generating neural network on any text dataset **for free on a GPU using Colaboratory** with `aitextgen`!

For more about `aitextgen`, you can visit [this GitHub repository](https://github.com/minimaxir/aitextgen) or [read the documentation](https://docs.aitextgen.io/).


To get started:

1. Copy this notebook to your Google Drive to keep it and save your changes. (File -> Save a Copy in Drive)
2. Run the cells below:


In [1]:
!pip install -q aitextgen

import logging
logging.basicConfig(
        format="%(asctime)s — %(levelname)s — %(name)s — %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

from aitextgen import aitextgen
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive

## GPU

Colaboratory uses an Nvidia P4, an Nvidia T4, an Nvidia P100, or an Nvidia V100. For finetuning GPT-2 124M, any of these GPUs will be fine, but for text generation, a T4 or a P100 is ideal since they have more VRAM. **If you receive a T4 or a V100 GPU, you can enable `fp16=True` during training for faster, more memory-efficient training.**

You can verify which GPU is active by running the cell below. If you want to try for a different GPU, go to **Runtime -> Factory Reset Runtime**.

In [2]:
!nvidia-smi

Fri Dec 10 20:50:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   73C    P8    35W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces
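
The `fp16` advice in this section can be turned into a small check. The sketch below is only an illustration: `supports_fp16` is a hypothetical helper (not part of aitextgen) that matches a GPU name against the two models this notebook says work with `fp16=True`. In a Colab session the name could be read with `torch.cuda.get_device_name(0)`; here it is passed in as a plain string so the logic stands on its own.

```python
def supports_fp16(gpu_name: str) -> bool:
    """Hypothetical helper: True if the GPU name contains T4 or V100,
    the two GPUs this notebook lists as working with fp16=True."""
    name = gpu_name.upper()
    return any(model in name for model in ("T4", "V100"))


# Example names as they might appear in the nvidia-smi "Name" column.
for gpu in ("Tesla T4", "Tesla V100-SXM2-16GB", "Tesla K80", "Tesla P100-PCIE-16GB"):
    print(f"{gpu}: fp16={supports_fp16(gpu)}")
```

The result could then be passed straight to `ai.train(..., fp16=supports_fp16(name))`.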

## Loading GPT-2 or GPT Neo

If you're retraining a model on new text, you need to download and load the GPT-2 model into the GPU. 

There are several sizes of GPT-2:

* `124M` (default): the "small" model, 500MB on disk.
* `355M`: the "medium" model, 1.5GB on disk.
* `774M`: the "large" model, 3GB on disk.

You can also finetune a GPT Neo model instead, which is better suited to longer texts and whose base model was trained on more recent data:

* `125M`: Analogous to the GPT-2 124M model.
* `350M`: Analogous to the GPT-2 355M model.

The next cell downloads the model and saves it in the Colaboratory VM. If the model has already been downloaded, running this cell will reload it.

In [3]:
ai = aitextgen(tf_gpt2="124M", to_gpu=True)

# Comment out the above line and uncomment the below line to use GPT Neo instead.
# ai = aitextgen(model="EleutherAI/gpt-neo-125M", to_gpu=True)

12/10/2021 20:50:08 — INFO — aitextgen — Downloading the 124M GPT-2 TensorFlow weights/config from Google's servers


Fetching checkpoint:   0%|          | 0.00/77.0 [00:00<?, ?it/s]

Fetching hparams.json:   0%|          | 0.00/90.0 [00:00<?, ?it/s]

Fetching model.ckpt.data-00000-of-00001:   0%|          | 0.00/498M [00:00<?, ?it/s]

Fetching model.ckpt.index:   0%|          | 0.00/5.21k [00:00<?, ?it/s]

Fetching model.ckpt.meta:   0%|          | 0.00/471k [00:00<?, ?it/s]

12/10/2021 20:50:52 — INFO — aitextgen — Converting the 124M GPT-2 TensorFlow weights to PyTorch.
Converting TensorFlow checkpoint from /content/aitextgen/124M
Loading TF weight model/h0/attn/c_attn/b with shape [2304]
Loading TF weight model/h0/attn/c_attn/w with shape [1, 768, 2304]
Loading TF weight model/h0/attn/c_proj/b with shape [768]
Loading TF weight model/h0/attn/c_proj/w with shape [1, 768, 768]
Loading TF weight model/h0/ln_1/b with shape [768]
Loading TF weight model/h0/ln_1/g with shape [768]
Loading TF weight model/h0/ln_2/b with shape [768]
Loading TF weight model/h0/ln_2/g with shape [768]
Loading TF weight model/h0/mlp/c_fc/b with shape [3072]
Loading TF weight model/h0/mlp/c_fc/w with shape [1, 768, 3072]
Loading TF weight model/h0/mlp/c_proj/b with shape [768]
Loading TF weight model/h0/mlp/c_proj/w with shape [1, 3072, 768]
Loading TF weight model/h1/attn/c_attn/b with shape [2304]
Loading TF weight model/h1/attn/c_attn/w with shape [1, 768, 2304]
Loading TF weight

Save PyTorch model to aitextgen/pytorch_model.bin


12/10/2021 20:50:59 — INFO — aitextgen — Loading 124M GPT-2 model from /aitextgen.


Save configuration file to aitextgen/config.json


12/10/2021 20:51:01 — INFO — aitextgen — GPT2 loaded with 124M parameters.
12/10/2021 20:51:01 — INFO — aitextgen — Using the default GPT-2 Tokenizer.
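
The two loading styles shown in the cell above (a GPT-2 size via `tf_gpt2`, or a Hugging Face model id via `model`) can be wrapped in one place. This is a minimal sketch: `model_kwargs` and `GPT2_SIZES` are hypothetical names introduced for illustration, and the helper only builds the keyword arguments; it does not download anything.

```python
GPT2_SIZES = {"124M", "355M", "774M"}  # GPT-2 sizes listed in this notebook


def model_kwargs(name: str, to_gpu: bool = True) -> dict:
    """Hypothetical helper: build keyword arguments for aitextgen(...).

    GPT-2 sizes go through `tf_gpt2`; anything else is treated as a
    Hugging Face model id and passed through `model`.
    """
    if name in GPT2_SIZES:
        return {"tf_gpt2": name, "to_gpu": to_gpu}
    return {"model": name, "to_gpu": to_gpu}


print(model_kwargs("355M"))
print(model_kwargs("EleutherAI/gpt-neo-125M"))
# In the notebook this would be used as: ai = aitextgen(**model_kwargs("355M"))
```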


## Mounting Google Drive

The best way to get your training text into the Colaboratory VM, and to get the trained model *out* of Colaboratory, is to route it through Google Drive *first*.

Running this cell (which only works in Colaboratory) mounts your personal Google Drive in the VM, so later cells can use it to move data in and out. (It will ask for an auth code; that auth is not saved anywhere.)

In [4]:
mount_gdrive()

Mounted at /content/drive


## Uploading a Text File to be Trained to Colaboratory

In the Colaboratory Notebook sidebar on the left of the screen, select *Files*. From there you can upload files:

![alt text](https://i.imgur.com/w3wvHhR.png)

Upload **any smaller text file** (for example, [a text file of Shakespeare plays](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)) and update the file name in the cell below, then run the cell.

In [7]:
file_name = "content.txt"

If your text file is large (>10MB), it is recommended to upload that file to Google Drive first, then copy that file from Google Drive to the Colaboratory VM.

Additionally, you may want to consider [compressing the dataset to a cache first](https://docs.aitextgen.io/dataset/) on your local computer, then uploading the resulting `dataset_cache.tar.gz` and setting the `file_name` in the previous cell to that.

In [9]:
copy_file_from_gdrive(file_name)
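
The 10MB rule of thumb above is easy to check before deciding how to upload. The sketch below uses only the standard library; `should_use_gdrive` is a hypothetical helper name, and the 10MB default simply mirrors the threshold mentioned in this section.

```python
import os
import tempfile


def should_use_gdrive(path: str, limit_mb: float = 10.0) -> bool:
    """Hypothetical helper: True if the file is larger than limit_mb megabytes."""
    size_mb = os.path.getsize(path) / (1024 * 1024)
    return size_mb > limit_mb


# Demonstrate with a small temporary file standing in for a dataset.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("ROMEO: But soft, what light through yonder window breaks?\n")
    demo_path = tmp.name

print(should_use_gdrive(demo_path))
os.remove(demo_path)
```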

## Finetune GPT-2

The next cell will start the actual finetuning of GPT-2 in aitextgen. It runs for `num_steps`, and a progress bar will appear to show training progress, current loss (the lower, the better the model), and average loss (to give a sense of the loss trajectory).
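
To see why an average loss is easier to read than the raw per-step loss, here is a minimal sketch of one common smoothing method, an exponential moving average. This is an assumption made for illustration: the `smoothing` weight and the formula are not taken from aitextgen's source, and the loss values are made up.

```python
def smoothed_losses(losses, smoothing=0.1):
    """Exponential moving average of a loss sequence.

    `smoothing` is the weight given to each new loss value; it is an
    illustrative choice, not a value taken from aitextgen.
    """
    averages = []
    avg = None
    for loss in losses:
        # Seed the average with the first loss, then blend in each new value.
        avg = loss if avg is None else (1 - smoothing) * avg + smoothing * loss
        averages.append(avg)
    return averages


raw = [3.2, 2.9, 3.4, 2.7, 2.8, 2.5]  # made-up per-step losses
for step, (loss, avg) in enumerate(zip(raw, smoothed_losses(raw)), start=1):
    print(f"step {step}: loss={loss:.2f} avg={avg:.2f}")
```

The raw loss jumps around from step to step, while the smoothed value moves slowly, which makes the overall trend easier to spot.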

The model will be saved every `save_every` steps in `trained_model` by default, and when training completes. If you mounted your Google Drive, the model will _also_ be saved there in a unique folder.

The training might time out after roughly 4 hours; if you did not mount Google Drive, make sure you end training and save the results so you don't lose them! (If this happens frequently, you may want to consider using [Colab Pro](https://colab.research.google.com/signup).)

Important parameters for `train()`:

- **`line_by_line`**: Set this to `True` if the input text file is a single-column CSV, with one record per row. aitextgen will automatically process it optimally.
- **`from_cache`**: If you compressed your dataset locally (as noted in the previous section) and are using that cache file, set this to `True`.
- **`num_steps`**: Number of steps to train the model for.
- **`generate_every`**: Interval of steps to generate example text from the model; good for qualitatively validating training.
- **`save_every`**: Interval of steps to save the model: the model will be saved in the VM to `/trained_model`.
- **`save_gdrive`**: Set this to `True` to copy the model to a unique folder in your Google Drive, if you mounted it in the earlier cells.
- **`fp16`**: Enables half-precision training for faster/more memory-efficient training. Only works on a T4 or V100 GPU.

Here are other parameters for `train()` that are useful but that you likely do not need to change:

- **`learning_rate`**: Learning rate of the model training.
- **`batch_size`**: Batch size of the model training; setting it too high will cause the GPU to go OOM. (if using `fp16`, you can increase the batch size more safely)
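
To get a feel for how `num_steps`, `save_every`, and `generate_every` interact, the sketch below lists the steps at which a periodic event would fire. It assumes events fire at exact multiples of the interval, which matches the "1,000 steps reached" log line in this notebook; `event_steps` is a hypothetical helper, not an aitextgen function.

```python
def event_steps(num_steps: int, every: int) -> list:
    """Steps (1-based) at which an every-N-steps event fires during training."""
    if every <= 0:
        return []
    return list(range(every, num_steps + 1, every))


num_steps, save_every, generate_every = 3000, 1000, 1000
print("saves at:", event_steps(num_steps, save_every))
print("samples at:", event_steps(num_steps, generate_every))
```

A smaller `save_every` gives more frequent checkpoints in case the session times out, at the cost of extra time spent writing the model.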

In [10]:
ai.train(file_name,
         line_by_line=False,
         from_cache=False,
         num_steps=3000,
         generate_every=1000,
         save_every=1000,
         save_gdrive=False,
         learning_rate=1e-3,
         fp16=False,
         batch_size=1, 
         )

12/10/2021 20:54:06 — INFO — aitextgen — Loading text from content.txt with generation length of 1024.


  0%|          | 0/35134 [00:00<?, ?it/s]

12/10/2021 20:54:06 — INFO — aitextgen.TokenDataset — Encoding 35,134 sets of tokens from content.txt.
  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
  "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
12/10/2021 20:54:09 — INFO — pytorch_lightning.utilities.distributed — GPU available: True, used: True
12/10/2021 20:54:09 — INFO — pytorch_lightning.utilities.distributed — TPU available: False, using: 0 TPU cores
12/10/2021 20:54:09 — INFO — pytorch_lightning.utilities.distributed — IPU available: False, using: 0 IPUs
12/10/2021 20:54:09 — INFO — pytorch_lightning.accelerators.gpu — LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


  0%|          | 0/3000 [00:00<?, ?it/s]

  "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."


1,000 steps reached: saving model to /trained_model
1,000 steps reached: generating sample texts.


Cougar was a founding member of a local escrow firm that was selling stolen credit card data to one of its customers. The scammers were able to steal credit card credentials via a Web browser on the same day. But the thieves were able to view the data from the Web using a “web shell” for a point in the United States.

In a brief interview with KrebsOnSecurity, Plo told KrebsonSecurity they learned they were able to view the data stored on the company’s cloud platform, and to confirm that the information was being stored on the cloud.

“By the end of this, approximately 20 percent of the data was stolen from the cloud hosting firm,” Plo told KrebsOnSecurity. “The company had changed the name of its cloud platform, and it was the victim of a new email scam.”

“We were able to see the data back in a different email that we thought was sent in a email, and it was the very s

12/10/2021 21:33:37 — INFO — aitextgen — Saving trained model pytorch_model.bin to /trained_model


You're done! Feel free to go to the **Generate Text From The Trained Model** section to generate text based on your retrained model.


## Load a Trained Model

If you already have a trained model from this notebook, running the next cell will copy the `pytorch_model.bin` and `config.json` files from the specified folder in Google Drive into the Colaboratory VM. (If no `from_folder` is specified, it assumes the two files are located at the root level of your Google Drive.)

In [None]:
from_folder = None

for file in ["pytorch_model.bin", "config.json"]:
  if from_folder:
    copy_file_from_gdrive(file, from_folder)
  else:
    copy_file_from_gdrive(file)

The next cell will allow you to load the retrained model + metadata necessary to generate text.

In [None]:
ai = aitextgen(model_folder=".", to_gpu=True)
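
Before loading, it can help to confirm that both files actually arrived in the folder. The sketch below is a standard-library check; `missing_model_files` and `REQUIRED_FILES` are hypothetical names, and the file list simply mirrors the two files copied in the cell above.

```python
import os

REQUIRED_FILES = ("pytorch_model.bin", "config.json")


def missing_model_files(folder: str = ".") -> list:
    """Hypothetical helper: list required model files absent from `folder`."""
    return [
        name for name in REQUIRED_FILES
        if not os.path.isfile(os.path.join(folder, name))
    ]


missing = missing_model_files(".")
if missing:
    print("Missing before loading:", ", ".join(missing))
else:
    print("Both model files found.")
```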

## Generate Text From The Trained Model

After you've trained the model or loaded a retrained model from checkpoint, you can now generate text.

**If you just trained a model**, you'll get much faster generation performance if you reload the model; the next cell will reload the model you just trained from the `trained_model` folder.

In [11]:
ai = aitextgen(model_folder="trained_model", to_gpu=True)

12/10/2021 21:44:57 — INFO — aitextgen — Loading model from provided weights and config in /trained_model.
12/10/2021 21:44:59 — INFO — aitextgen — GPT2 loaded with 124M parameters.
12/10/2021 21:44:59 — INFO — aitextgen — Using the default GPT-2 Tokenizer.


`generate()` without any parameters generates a single text from the loaded model to the console.

In [12]:
ai.generate()

“It is highly likely to be infected,” the company said in a statement published today. “Additionally, we have taken steps to ensure that we continue to monitor this threat and remain vigilant.”

If you are running Windows XP or later Automatic Update in tandem with Adobe’s new version of Windows, you should be asking for help in keeping the program plugged into your system with the latest patches.


A California man who posted a LinkedIn impersonation to one of the world’s top cybercrime forums posted a LinkedIn profile for an individual claiming to have served several years in prison. The LinkedIn profile claimed to have been created by someone using the Twitter handle @k0pa.

KrebsOnSecurity has been following this activity for several years, and for the past four years has been a cataloging and vetting the various ways he has been identified and treated by legions of cybersecurity experts.

One of the LinkedIn profiles is a now-defunct hacker who’s been posting classified documents 

If you're creating an API based on your model and need to pass the generated text elsewhere, you can use `text = ai.generate_one()`.

You can also pass in a `prompt` to the generate function to force the text to start with a given character sequence and generate text from there (good if you add an indicator when the text starts).

You can also generate multiple texts at a time by specifying `n`. You can pass a `batch_size` to generate multiple samples in parallel, giving a massive speedup (in Colaboratory, set a maximum of 50 for `batch_size` to avoid going OOM).

Other optional-but-helpful parameters for `ai.generate()` and friends:

* **`min_length`**: The minimum length of the generated text; if the text is shorter than this value after cleanup, aitextgen will generate another one.
* **`max_length`**: Number of tokens to generate (default 256; you can generate up to 1024 tokens with GPT-2 and 2048 with GPT Neo).
* **`temperature`**: The higher the temperature, the crazier the text (default 0.7; recommended to keep between 0.7 and 1.0).
* **`top_k`**: Limits the generated guesses to the top *k* guesses (default 0, which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`).
* **`top_p`**: Nucleus sampling: limits the generated guesses to a cumulative probability (a value of `top_p=0.9` often gets good results).
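
To make `temperature`, `top_k`, and `top_p` concrete, here is a small pure-Python sketch of the standard definitions applied to a toy vocabulary. It is an illustration only, not how aitextgen or Transformers implement sampling internally; `filtered_distribution` and the toy logits are made up for this example.

```python
import math


def filtered_distribution(logits, temperature=0.7, top_k=0, top_p=1.0):
    """Return the renormalised token probabilities left after temperature
    scaling, then top-k filtering, then top-p (nucleus) filtering."""
    # Temperature: divide logits before the softmax; higher values flatten the distribution.
    scaled = {tok: value / temperature for tok, value in logits.items()}
    peak = max(scaled.values())
    exps = {tok: math.exp(value - peak) for tok, value in scaled.items()}
    total = sum(exps.values())
    ranked = sorted(((tok, e / total) for tok, e in exps.items()),
                    key=lambda pair: pair[1], reverse=True)

    # top_k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        ranked = ranked[:top_k]
        kept_mass = sum(prob for _, prob in ranked)
        ranked = [(tok, prob / kept_mass) for tok, prob in ranked]

    # top_p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, prob in ranked:
        kept.append((tok, prob))
        cumulative += prob
        if cumulative >= top_p:
            break

    norm = sum(prob for _, prob in kept)
    return {tok: prob / norm for tok, prob in kept}


toy_logits = {"the": 2.0, "a": 1.5, "hacker": 0.5, "banana": -1.0}
print(filtered_distribution(toy_logits, temperature=1.0, top_p=0.9))
print(filtered_distribution(toy_logits, temperature=0.7, top_k=2))
```

The next token is then sampled from whatever survives the filters, so tighter settings trade variety for coherence.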

In [13]:
ai.generate(n=5,
            batch_size=5,
            prompt="APT41 is a state-sponsored espionage group ",
            max_length=256,
            temperature=1.0,
            top_p=0.9)

APT41 is a state-sponsored espionage group in the United States and elsewhere who’s recently shared data about a major hacking incident involving the Waledac, Waledac and Waledac.

Sources say Waledac was the largest hacking group ever apprehended by Russian law enforcement and security firms to pursue and sell data leaked by Russian cyber espionage groups, and that Waledac was one of the longest hacking groups ever created.

Two sources close to the financial investigation into Waledac confirmed that Waledac was indeed the most hailing group ever created (and now deleted) in the United States. A closer look at Waledac indicates that Waledac was the closely-held Russian hacker collective known to launch a large sextortion campaign that last year that leveraged a number of hacked PCs and was being occupied by U.S. and European authorities.

A long message thread published by members of Listedac suggests that Waledac and Waledac are being charged with up to 20 months in prison.


In [14]:
ai.generate(n=5,
            batch_size=5,
            prompt="Malicious Domain in SolarWinds Hack Turned into ‘Killswitch’ ",
            max_length=256,
            temperature=1.0,
            top_p=0.9)

Malicious Domain in SolarWinds Hack Turned into ‘Killswitch’ as the first name I wrote about the Mirai source code that the bad guys had posted to a group of hacked SolarWinds systems.

The first of my series on this blog was: A story about a Mirai Attack against a competitor that pays a dime in monthly profits. Since then, several thousand of these attacks have emerged to help the company recover losses.

That story got picked up by BoingBoingBoing, who maintains the technical account holder of a large Mirai attack infrastructure. Boing had made a stunningly stunning discovery on Mirai, a cybercrime machine used to coordinate cyberattacks against other targets.

“I was a senior Mirai [sic] developer at Mirai in 2008,” Boing said in a blog post published in July 2015. “I’d like to assume that Mirai was one of the most sophisticated Mirai botnets that was used in the Mirai malware for years.”

Boat was made to appear as if he’d had a Mirai-based Mirai IoT-based Mirai. Court re

In [15]:
ai.generate(n=5,
            batch_size=5,
            prompt="An issue was discovered in the Quiz and Survey Master plugin ",
            max_length=256,
            temperature=1.0,
            top_p=0.9)

An issue was discovered in the Quiz and Survey Master plugin at the time; the site’s public search shows it was not a plugin that was known to users prior to its Oct. 27, 2015.

The security hole in Java 7 Update 20, 2015 is now being actively exploited by malicious software, but experts say it may also have been tied to the theft of an SQL injection attack on XP systems.

“The vulnerability could be triggered to install malicious code in a vulnerable system or to inject arbitrary content content in a website or Web site.”

Microsoft says this vulnerability is a browse-to-hostile-site-and-get-owned vulnerability, which lets users access any file, log in, or upload arbitrary files.

“The exploitation is as described in this first quarter blog post,” said Brad Arkin, vulnerability researcher David Watson, noting that the vulnerability stems from an SQL injection attack which is one of the largest SQL injection attacks ever written by such a large group of attackers.


A federal 

In [16]:
ai_hn = aitextgen(model="minimaxir/hacker-news")

12/10/2021 21:52:23 — INFO — aitextgen — Loading minimaxir/hacker-news model from /aitextgen.


Downloading:   0%|          | 0.00/539 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0M [00:00<?, ?B/s]

12/10/2021 21:52:27 — INFO — aitextgen — Using the tokenizer for minimaxir/hacker-news.


Downloading:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/73.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/34.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/120 [00:00<?, ?B/s]

loading file https://huggingface.co/minimaxir/hacker-news/resolve/main/vocab.json from cache at aitextgen/ca85d2ccfa439cad04f169cbe04164d3b8a20929f6cadd2b3f06b225903f5c28.5e6dbd1182cc3a40ea161cc9b3a36c673b9e3a5666381e8cdd4cb988721e4949
loading file https://huggingface.co/minimaxir/hacker-news/resolve/main/merges.txt from cache at aitextgen/dc397162ff190219202b619ee3cffd7fa616464eab8cae671e595ec84309ce21.4ee20e782173602432d6b0d50fcbe802eaa9ced13d5e62c4e7ed4abc6972b7c2
loading file https://huggingface.co/minimaxir/hacker-news/resolve/main/tokenizer.json from cache at None
loading file https://huggingface.co/minimaxir/hacker-news/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/minimaxir/hacker-news/resolve/main/special_tokens_map.json from cache at aitextgen/9371dbca656671e9690de4d338feadd9d0de2a1a08e35295dcd958584ce7f2ec.fbf4061fb19cfc48adf3510a9b4a6037fcf9cdf64fbdb306b328bafb3092779b
loading file https://huggingface.co/minimaxir/hacker-news/resolve/

Downloading:   0%|          | 0.00/539 [00:00<?, ?B/s]

12/10/2021 21:52:42 — INFO — aitextgen — GPT2 loaded with 7M parameters.


In [17]:
ai_hn.generate(n=10,
               batch_size=5,
               prompt="APT41 is a ",
               max_length=256,
               temperature=1.0,
               top_p=0.9)

APT41 is a rip-off not a huge win for the company
APT41 is a vehicle-hacker
APT41 is a irregular silicon leaders
APT41 is a romantic demon.
APT41 is a vehicle-handatory system
APT41 is a ripoff of SHA-2
APT41 is a romance of running a neural net in less than 30 minutes
APT41 is a row in 100 seconds
APT41 is a rip-offed co-ownage agrees to $75M for none of the whole browser
APT41 is a rip-off notebook for programmers


In [18]:
ai_hn.generate(n=10,
               batch_size=5,
               prompt="Malicious Domain in ",
               max_length=256,
               temperature=1.0,
               top_p=0.9)

Malicious Domain in ~2MB rounds
Malicious Domain in ✓ →
Malicious Domain in @ LTE has died
Malicious Domain in {MailgunCSS + Sprints
Malicious Domain in ~00-1MB RAM+Trail Challenge: How to Hypothesis Landing [pdf]
Malicious Domain in ~250 Million+ Monetization
Malicious Domain in ~70 Millionaires
Malicious Domain in ~30 minutes
Malicious Domain in ✔:45 Facts about React Hooks
Malicious Domain in ~38M People Time by Standoff with the FBI


In [19]:
ai_hn.generate(n=10,
               batch_size=5,
               prompt="An issue was ",
               max_length=256,
               temperature=1.0,
               top_p=0.9)

An issue was ousted
An issue was ousted
An issue was...
An issue was verdict since 1996: It’s too different than sweetely
An issue was?
An issue was illusion
An issue was ought to be frustrating...?
An issue was ousted
An issue was
An issue was illusion


For bulk generation, you can generate a large number of texts to a file and sort out the samples locally on your computer. The next cell will generate `num_files` files, each with `n` texts and whatever other parameters you would pass to `generate()`. The files can then be downloaded from the Files sidebar!

You can rerun the cells as many times as you want for even more generated texts!

In [None]:
num_files = 5

for _ in range(num_files):
  ai.generate_to_file(n=1000,
                     batch_size=50,
                     prompt="ROMEO:",
                     max_length=256,
                     temperature=1.0,
                     top_p=0.9)
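
Before starting a long bulk run, it can be useful to work out how much text the loop above will produce. The sketch below is simple arithmetic under the assumption that each batch yields `batch_size` texts; `bulk_plan` is a hypothetical helper, not part of aitextgen.

```python
def bulk_plan(num_files: int, n: int, batch_size: int) -> dict:
    """Totals implied by a bulk-generation loop over num_files files."""
    batches_per_file = -(-n // batch_size)  # ceiling division
    return {
        "total_texts": num_files * n,
        "batches_per_file": batches_per_file,
        "total_batches": num_files * batches_per_file,
    }


print(bulk_plan(num_files=5, n=1000, batch_size=50))
```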

# LICENSE

MIT License

Copyright (c) 2020-2021 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.