# Generating text from a starter prompt with `sample`

In [None]:
# Text file the GPT model was trained on (read relative to the current directory).
# It must be exactly the same file: the character vocabulary is rebuilt from it.
filename = "fairytales.txt"

### Change the `context` variable below to see how the trained GPT model responds to different prompts. Every character in the prompt must also appear in the training text.

In [None]:
# Starter prompt for generation; each character is looked up in the vocabulary
# built from the training text, so it may only use characters from that file.
context = "The sun shone in the sky."

In [None]:
from utils import sample
from model import GPT, GPTconfig
from trainer import Trainer, TrainerConfig

import logging

# Root-logger setup: INFO and above, each record prefixed with a day-first
# timestamp, level and logger name, e.g.
# "31/12/2024 23:59:59 - INFO - <logger> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%d/%m/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Fix the random seed before sampling (generation below uses sample=True), so
# that repeated runs presumably produce the same completion.
# NOTE(review): which RNGs set_seed covers (random / numpy / torch / CUDA) is
# defined in utils — confirm there.
from utils import set_seed
set_seed(42)

import numpy as numpy
import torch
import torch.nn as nn
from torch.nn import functional as F

import math
from torch.utils.data import Dataset

class CharDataset(Dataset):
    """Character-level dataset of (input, target) windows over a text.

    Every item is a pair of ``block_size``-long ``torch.long`` tensors in which
    the target is the input shifted one character to the right, so each
    position is paired with the character that follows it.

    Args:
        data: The full text to draw windows from.
        block_size: Length of each input/target window.

    Attributes:
        stoi: Mapping from character to integer id (ids follow sorted order).
        itos: Inverse mapping from integer id back to character.
        block_size: Window length.
        vocab_size: Number of distinct characters in ``data``.
        data: The original text.
    """

    def __init__(self, data, block_size):
        # Sorting makes the char <-> id mapping deterministic for a given text,
        # so rebuilding the dataset from the same file reproduces the same ids.
        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        print("data has %d characters, %d unique." % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        # Each item needs block_size + 1 characters. Clamp at 0 so a text that
        # is shorter than block_size gives an empty dataset rather than a
        # negative length (which makes len() raise ValueError).
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for the window starting at character ``idx``.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if idx < 0:
            idx += n
        # Without this check, out-of-range indices silently produce short
        # windows and plain iteration over the dataset never terminates.
        if not 0 <= idx < n:
            raise IndexError("CharDataset index out of range")

        chunk = self.data[idx : idx + self.block_size + 1]
        dix = [self.stoi[s] for s in chunk]

        # x is characters [idx, idx + block_size); y is the same span shifted by one.
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

# NOTE(review): presumably this must equal the block_size used at training time,
# since it is passed to GPTconfig and so shapes the model — confirm.
block_size = 32

# Rebuild the character vocabulary from the training text. CharDataset derives
# ids from the sorted set of characters, so they only line up with the saved
# weights when this is the same file the model was trained on.
with open("./{}".format(filename), "r") as f:
    text = f.read()
train_dataset = CharDataset(text, block_size)

mconf = GPTconfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=8, n_head=8, n_embd=512)
model = GPT(mconf)
# map_location="cpu" lets a checkpoint saved from a GPU run be read on a CPU-only
# machine. load_state_dict copies the tensors into the model's existing
# parameters, so this does not change where the model itself lives.
model.load_state_dict(torch.load("./saved_models/trained_gpt_model", map_location="cpu"))

# In this script the trainer is only consulted for trainer.device; the
# optimisation settings in tconf are not exercised.
# NOTE(review): Trainer may also move/wrap the model on construction — confirm.
tconf = TrainerConfig(max_epochs=5, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,
                      num_workers=4)
trainer = Trainer(model, train_dataset, None, tconf)

# Fail with a readable message instead of a bare KeyError when the prompt uses
# a character that does not occur in the training text.
unknown = sorted(set(context) - set(train_dataset.stoi))
if unknown:
    raise ValueError(
        "context contains characters not found in {}: {}".format(filename, unknown))

# Disable dropout etc. for generation. NOTE(review): sample() may already do this.
model.eval()

# Encode the prompt as a (1, len(context)) batch of character ids.
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None, ...].to(trainer.device)
# Presumably extends the prompt by 2000 characters using top-k (k=10) sampling;
# [0] takes the first (only) sequence of the batch.
y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]
completion = "".join(train_dataset.itos[int(i)] for i in y)
print(completion)