# Interactive textgenrnn Demo w/ GPU

by [Max Woolf](http://minimaxir.com)

Generate text using a pretrained neural network with a few lines of code, or easily train your own text-generating neural network of any size and complexity, **for free on a GPU using Colaboratory!**

For more about textgenrnn, you can visit [this GitHub repository](https://github.com/minimaxir/textgenrnn).


To get started:

1. Copy this notebook to your Google Drive to keep it and save your changes.
2. Make sure you're running the notebook in Google Chrome.
3. Run the cells below:


In [1]:
!pip install -q textgenrnn
from google.colab import files
from textgenrnn import textgenrnn
import os

Using TensorFlow backend.


Set the textgenrnn model configuration here. See the [demo notebook](https://github.com/minimaxir/textgenrnn/blob/master/docs/textgenrnn-demo.ipynb) for more information about these parameters.

If you are using an input file where documents are line-delimited, set `line_delimited` to `True`.

In [0]:
model_cfg = {
    'rnn_size': 128,
    'rnn_layers': 4,
    'rnn_bidirectional': True,
    'max_length': 40,
    'max_words': 10000,
    'dim_embeddings': 100,
    'word_level': False,
}

train_cfg = {
    'line_delimited': False,
    'num_epochs': 20,
    'gen_epochs': 4,
    'batch_size': 1024,
    'train_size': 0.8,
    'dropout': 0.0,
    'max_gen_length': 300,
    'validation': False,
    'is_csv': False
}
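
The `line_delimited` flag decides whether the input file is treated as many short documents (one per line) or as one long text. A minimal sketch of that distinction in plain Python follows; `load_documents` is a hypothetical helper written for illustration, and textgenrnn's own file loading may differ in its details.

```python
import os
import tempfile


def load_documents(path, line_delimited):
    """Return a list of training texts read from `path` (illustrative only)."""
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        if line_delimited:
            # One document per non-empty line, e.g. one tweet per line.
            return [line.rstrip('\n') for line in f if line.strip()]
        # Otherwise the whole file is a single long text.
        return [f.read()]


# Small demo file, written to a temporary location and removed afterwards.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False,
                                 encoding='utf-8') as tmp:
    tmp.write('first tweet\nsecond tweet\n')
    demo_path = tmp.name

as_lines = load_documents(demo_path, line_delimited=True)
as_blob = load_documents(demo_path, line_delimited=False)
os.remove(demo_path)

print(len(as_lines), len(as_blob))
```

With `line_delimited` set to `True` the demo file yields two documents; with `False` it yields one.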

Running the next cell prompts you to upload a file. Upload **any text file**, and textgenrnn will train on it and generate creative text based on that file!

The cell after that starts the training. And thanks to Keras's CuDNN layers, training on the GPU is super-fast! When training is done, running the final cell will automatically download the weights, the vocab, and the config.

(N.B. the uploaded file is only stored in the Colaboratory VM and no one else can see it)

In [3]:
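# Prompt for an upload; Colab saves the file into the current working directory.
uploaded = files.upload()
# Use the most recently modified regular file (directories are skipped) as the training input.
all_files = [(name, os.path.getmtime(name)) for name in os.listdir() if os.path.isfile(name)]
latest_file = max(all_files, key=lambda x: x[1])[0]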

Saving trump_tweets.txt to trump_tweets.txt


In [8]:
model_name = 'trumptweets'
textgen = textgenrnn(name=model_name)

train_function = textgen.train_from_file if train_cfg['line_delimited'] else textgen.train_from_largetext_file

train_function(
    file_path=latest_file,
    new_model=True,
    num_epochs=train_cfg['num_epochs'],
    gen_epochs=train_cfg['gen_epochs'],
    batch_size=train_cfg['batch_size'],
    train_size=train_cfg['train_size'],
    dropout=train_cfg['dropout'],
    max_gen_length=train_cfg['max_gen_length'],
    validation=train_cfg['validation'],
    is_csv=train_cfg['is_csv'],
    rnn_layers=model_cfg['rnn_layers'],
    rnn_size=model_cfg['rnn_size'],
    rnn_bidirectional=model_cfg['rnn_bidirectional'],
    max_length=model_cfg['max_length'],
    dim_embeddings=model_cfg['dim_embeddings'],
    word_level=model_cfg['word_level'],
    max_words=model_cfg['max_words'])

Training new model w/ 4-layer, 128-cell Bidirectional LSTMs
Training on 234,142 character sequences.
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
####################
Temperature: 0.2
####################
n the United States with the United States of the U.S. is doing the people with the United States of the United States and the United States to the people who has a great stated the United States of the United States to the U.S. in the firest the progress to the United States of the Democrats are pr

o some the people with the White House to the United States and the United States and the United States and Democrats will be going to the progestiment in the politic of the National Countries is protecting the people with the proplement in the progress to the people with the U.S. in the United Stat

 which is not the people of the people with the coment the United States and the Democrats to the progess to be prople the Fake News Media will be highly the progess the with the people with t
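
The `Temperature: 0.2` header in the sample output refers to the sampling temperature. Low values make the model pick its most likely next character almost every time, which is one reason the text above keeps looping back to "the United States". The sketch below shows the standard temperature-scaling technique in plain Python; it is an illustration of the general idea rather than textgenrnn's own sampling code, and `apply_temperature` and `sample_index` are hypothetical names.

```python
import math
import random


def apply_temperature(probs, temperature):
    """Rescale a probability distribution by a sampling temperature."""
    # Work in log space; the small epsilon avoids log(0).
    logits = [math.log(p + 1e-12) / temperature for p in probs]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(probs, temperature, rng=random):
    """Draw one index from the temperature-adjusted distribution."""
    adjusted = apply_temperature(probs, temperature)
    return rng.choices(range(len(adjusted)), weights=adjusted, k=1)[0]


probs = [0.5, 0.3, 0.2]                 # made-up next-character probabilities
cold = apply_temperature(probs, 0.2)    # sharpened: the top choice dominates
warm = apply_temperature(probs, 1.0)    # essentially unchanged
choice = sample_index(probs, 0.2)
```

Here `cold` puts far more weight on the first entry than `warm` does, so sampling at 0.2 is conservative and repetitive, while higher temperatures give more varied (and more error-prone) text.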

In [0]:
files.download('{}_weights.hdf5'.format(model_name))
files.download('{}_vocab.json'.format(model_name))
files.download('{}_config.json'.format(model_name))
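
To reuse the downloaded files later, the three paths can be passed back to the `textgenrnn` constructor. The sketch below assumes the constructor's `weights_path`, `vocab_path`, and `config_path` keyword arguments and the file naming used in the download cell above; `artifact_paths` and `load_trained_model` are hypothetical helper names introduced here for illustration.

```python
def artifact_paths(model_name):
    """Build the three file names that the download cell produces."""
    return {
        'weights_path': '{}_weights.hdf5'.format(model_name),
        'vocab_path': '{}_vocab.json'.format(model_name),
        'config_path': '{}_config.json'.format(model_name),
    }


def load_trained_model(model_name):
    """Recreate a trained model from its saved files (assumes they exist locally)."""
    # Imported lazily so the helper above can be used without textgenrnn installed.
    from textgenrnn import textgenrnn
    return textgenrnn(name=model_name, **artifact_paths(model_name))


paths = artifact_paths('trumptweets')
```

For example, `load_trained_model('trumptweets').generate(3, temperature=0.5)` should print three samples, provided the three files are in the working directory.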