# The GPT Language Model

## Imports

Here are the packages we need to import.

In [None]:
from nlpmodels.models import gpt
from nlpmodels.utils import train, utils, gpt_dataset, gpt_sampler
from argparse import Namespace
import torch

utils.set_seed_everywhere()  # fix random seeds so runs are reproducible

## Language Model: WikiText2

We will train our transformer model to predict the next word in the torchtext WikiText2 dataset.
To reduce computation time, I kept only the first 300k of the training set.
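Language-model training turns one long token stream into (input, target) pairs in which the target is the input shifted one position to the left, so every position learns to predict the word that follows it. The sketch below shows that idea in plain Python; `make_lm_examples` is a hypothetical helper, and whether `GPTDataset` chunks the stream exactly this way (stride, overlap, padding) is an assumption.

```python
from typing import List, Tuple


def make_lm_examples(token_ids: List[int], block_size: int) -> List[Tuple[List[int], List[int]]]:
    """Split a token stream into (input, target) pairs for next-token prediction.

    Each example spans block_size + 1 consecutive tokens; the target is the
    input shifted left by one position. Consecutive inputs do not overlap,
    and a trailing remainder that cannot fill a full example is dropped.
    """
    examples = []
    for start in range(0, len(token_ids) - block_size, block_size):
        chunk = token_ids[start:start + block_size + 1]
        examples.append((chunk[:-1], chunk[1:]))
    return examples


# Toy stream of token ids 0..9 with a context window of 4.
for x, y in make_lm_examples(list(range(10)), block_size=4):
    print(x, y)
# [0, 1, 2, 3] [1, 2, 3, 4]
# [4, 5, 6, 7] [5, 6, 7, 8]
```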

### Hyper-parameters

These are the data-processing and model-training hyper-parameters for this run. Note that we run a much smaller model
than the one in the paper, for fewer iterations, and on a CPU. The goal is only to demonstrate that the pipeline works.

In [None]:
args = Namespace(
        # Model hyper-parameters
        num_layers_per_stack=2,  # original value = 12
        dim_model=12,  # original value = 768
        dim_ffn=48,  # original value = 3072
        num_heads=2,  # original value = 12
        block_size=64,  # original value = 512, context window
        dropout=0.1,
        # Training hyper-parameters
        num_epochs=15,
        learning_rate=2.5e-4,  # original (max) value = 2.5e-4
        batch_size=128,  # original value = 64
    )
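To get a feel for how much smaller this run is, the sketch below counts the parameters inside the transformer blocks, excluding token and position embeddings (those scale with `len(vocab)` and `block_size`). It assumes a standard GPT decoder block: Q, K, V and output projections with biases, a two-layer feed-forward network, and two LayerNorms. The exact layout inside `gpt.GPT` is not shown here, so treat the numbers as estimates.

```python
def block_params(dim_model: int, dim_ffn: int) -> int:
    """Parameter count of one standard GPT decoder block (weights plus biases)."""
    attention = 4 * (dim_model * dim_model + dim_model)  # Q, K, V and output projections
    ffn = (dim_model * dim_ffn + dim_ffn) + (dim_ffn * dim_model + dim_model)  # two linear layers
    layer_norms = 2 * (2 * dim_model)  # two LayerNorms, each with a gain and a bias vector
    return attention + ffn + layer_norms


def non_embedding_params(num_layers: int, dim_model: int, dim_ffn: int) -> int:
    """Parameters in the stack of decoder blocks, ignoring the embedding tables."""
    return num_layers * block_params(dim_model, dim_ffn)


print(non_embedding_params(2, 12, 48))      # this run: 3768
print(non_embedding_params(12, 768, 3072))  # paper-sized stack: 85054464
```

The number of heads does not change this count: the heads split `dim_model` between them (a head dimension of 12 // 2 = 6 here, versus 768 // 12 = 64 in the paper) rather than adding weights.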

In [None]:
train_loader, vocab = gpt_dataset.GPTDataset.get_training_dataloader(args)
model = gpt.GPT(vocab_size=len(vocab),
                num_layers_per_stack=args.num_layers_per_stack,
                dim_model=args.dim_model,
                dim_ffn=args.dim_ffn,
                num_heads=args.num_heads,
                block_size=args.block_size,
                dropout=args.dropout)
trainer = train.GPTTrainer(args, vocab.mask_index, model, train_loader, vocab)
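GPT is a decoder-only transformer: in every self-attention layer, position t may only attend to positions up to and including t, so the model cannot peek at the word it is asked to predict. The sketch below shows that causal mask for a single head in plain Python. It is not the nlpmodels implementation, which presumably works on batched tensors with `num_heads` heads and learned Q/K/V projections; here Q, K and V are passed in directly.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head scaled dot-product attention in which position t sees only positions <= t."""
    d_k = len(q[0])
    out = []
    for t, q_t in enumerate(q):
        # Causal mask: score only the current and earlier positions.
        scores = [sum(a * b for a, b in zip(q_t, k[s])) / math.sqrt(d_k) for s in range(t + 1)]
        # Numerically stable softmax over the visible positions.
        m = max(scores)
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output is the attention-weighted sum of the visible value vectors.
        out.append([sum(w * v[s][j] for s, w in enumerate(weights)) for j in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(causal_attention(q, k, v)[0])  # [1.0, 2.0]: the first position can only see itself
```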

In [None]:
trainer.run()
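The training objective is the cross-entropy of each next-word prediction. `GPTTrainer` is handed `vocab.mask_index`; a common use of such an index is to exclude padded positions from the loss, and the sketch below assumes that is its role here. `next_token_loss` is a hypothetical plain-Python helper, not the trainer's code.

```python
import math
from typing import List


def next_token_loss(logits: List[List[float]], targets: List[int], ignore_index: int) -> float:
    """Mean cross-entropy over positions whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # e.g. padding positions contribute nothing to the loss
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count if count else 0.0


# Uniform logits over 4 tokens give log(4); the second position is ignored.
print(next_token_loss([[0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]], [1, 0], ignore_index=0))  # ≈ 1.386
```

In PyTorch the same behaviour is available through the `ignore_index` argument of `torch.nn.functional.cross_entropy`.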

## GPT Completes a Sequence

In the spirit of Karpathy's minGPT `play_char` notebook, we can use `gpt_sampler.greedy_sampler` to see how the trained
model continues a prompt.

In [None]:
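model.eval()  # disable dropout while generating
prompt = "The government found"
# The model is trained on words, so map each word (not each character) of the prompt to its index.
prompt_tensor = torch.LongTensor([vocab.lookup_token(s) for s in prompt.split()])
steps = 500  # number of tokens to generate
yhat_indices = gpt_sampler.greedy_sampler(model, prompt_tensor, steps, sample=True)[0]
yhat_tokens = ' '.join([vocab.lookup_index(int(idx)) for idx in yhat_indices])
print(yhat_tokens)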
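The internals of `greedy_sampler` are not shown in this notebook. The sketch below is an assumption about what such a sampler typically does: crop the context to the last `block_size` tokens, read the logits for the next position, then either take the arg-max (greedy) or, when `sample=True`, draw from the softmax distribution, and repeat for `steps` tokens. `generate` and `toy_next_logits` are hypothetical stand-ins for the sampler and the model.

```python
import math
import random
from typing import Callable, List, Optional


def generate(next_logits: Callable[[List[int]], List[float]], prompt: List[int], steps: int,
             block_size: int, sample: bool = False, rng: Optional[random.Random] = None) -> List[int]:
    """Autoregressively extend prompt by steps token ids."""
    rng = rng or random.Random(0)
    ids = list(prompt)
    for _ in range(steps):
        context = ids[-block_size:]  # the model only sees the last block_size tokens
        logits = next_logits(context)
        if sample:
            # Draw from softmax(logits); random.choices normalises the weights itself.
            m = max(logits)
            weights = [math.exp(x - m) for x in logits]
            next_id = rng.choices(range(len(logits)), weights=weights, k=1)[0]
        else:
            next_id = max(range(len(logits)), key=logits.__getitem__)  # greedy arg-max
        ids.append(next_id)
    return ids


def toy_next_logits(context: List[int]) -> List[float]:
    # Hypothetical stand-in for the model: strongly prefers (last id + 1) mod 5.
    nxt = (context[-1] + 1) % 5
    return [10.0 if i == nxt else 0.0 for i in range(5)]


print(generate(toy_next_logits, [0], steps=6, block_size=3))  # [0, 1, 2, 3, 4, 0, 1]
```

Greedy decoding is deterministic and tends to repeat itself on long runs like the 500 steps above; sampling trades that for more varied, less predictable text.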