In [2]:
import transformer_utils
from transformer_utils import PositionEmbedding, TransformerEncoder, get_masks

import torch
from torch import nn
import torch.nn.functional as F

from tqdm import tqdm

import numpy as np
import math
from torch.utils.data import Dataset, DataLoader, random_split

In [3]:
# Shakespeare data
# https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt

text = open('input.txt', 'r').read()

In [4]:
vocab = list(set(text))

In [5]:
char2id = {c:i for i, c in enumerate(vocab)}
id2char = {i:c for i, c in enumerate(vocab)}

In [6]:
def text2ids(c: str):
    return [char2id[i] for i in c]

In [7]:
# inpired by Karpathy's minGPT: https://github.com/karpathy/minGPT

# Goal: Pedict next characher
class CharDataset(Dataset):

    def __init__(self, data, block_size):

        self.data = data
        self.block_size = block_size
    
    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        
        chunk = self.data[idx:idx + self.block_size + 1]
        
        dix = text2ids(chunk)

        x = torch.tensor(dix[:-1], dtype=torch.long)
        
        y = torch.tensor(dix[1:], dtype=torch.long)
        
        return x, y

In [8]:
BLOCK_SIZE = 128

In [9]:
all_dataset = CharDataset(text, block_size=BLOCK_SIZE)

print(len(all_dataset))

train, val = random_split(all_dataset, [len(all_dataset)-5000, 5000])

train_loader = DataLoader(train, batch_size=256, shuffle=True, num_workers=15)

val_loader = DataLoader(val, batch_size=512, shuffle=False, num_workers=10)

1115266


In [10]:
class PetitGPT(nn.Module):
    
    def __init__(self, vocab_size, block_size, d_model=512, n_head=8, num_layers=4, d_ffn=1024):
        
        super().__init__()
        
        self.block_size = block_size
        
        self.vocab_size = vocab_size
        
        self.d_model = d_model
        
        self.register_buffer('causal_mask', torch.tril(torch.ones(block_size, block_size)))
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        self.pos_embedding = PositionEmbedding(block_size, d_model, trainable=True)
        
        self.layers = TransformerEncoder(d_model, n_head, num_layers, d_ffn)
        
        self.fc = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        
        batch_size, seq_len = x.size()
        
        assert seq_len <= self.block_size
        
        # get mask
        mask = self.causal_mask[:seq_len, :seq_len]
        
        # compute word embedding
        x = self.embedding(x) # * np.sqrt(self.d_model) (scaling embedding, see Vaswani et al., 2017)
        
        # add position embedding
        x = self.pos_embedding(x)
        
        # compute tranformer representation
        x = self.layers(x, mask)
        
        # prediction
        return self.fc(x)
    
    def compute_loss(self, x, y):
        
        pred = self.forward(x)
                
        loss = F.cross_entropy(pred.view(-1, self.vocab_size), y.view(-1))
        
        return loss

In [11]:
def train_one_epoch(net: nn.Module, opt: torch.optim, dataloader: torch.utils.data.DataLoader):
    
    net.train()
    
    for param in net.parameters():
        device = param.device
        break
    
    losses = []
    
    pbar = tqdm(dataloader)
    
    for x, y in pbar:

        net.zero_grad()

        x, y = x.to(device), y.to(device)

        loss = net.compute_loss(x, y)

        loss.backward()

        opt.step()
        
        loss_item = loss.item()
        
        losses.append(loss_item)
        
        pbar.set_description(f'train_loss = {np.array(losses).mean()}')
        
    return np.array(losses).mean()

@torch.no_grad()
def validate(net: nn.Module, dataloader: torch.utils.data.DataLoader):
    
    net.eval()
    
    for param in net.parameters():
        device = param.device
        break
     
    losses = []
    
    for x, y in dataloader:

        x, y = x.to(device), y.to(device)

        logits = net(x)
        
        loss = net.compute_loss(x, y)
        
        losses.append(loss.item())
    
    return np.array(losses).mean()

In [12]:
model_args = {'vocab_size': len(vocab), 'block_size': BLOCK_SIZE, 'num_layers':1, 'd_model':256, 'n_head':4}

model = PetitGPT(**model_args).cuda()

opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [13]:
model.causal_mask

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 1., 0., 0.],
        [1., 1., 1.,  ..., 1., 1., 0.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')

In [14]:
model.cuda()

n_epoch = 1

for _ in range(n_epoch):
    
    train_one_epoch(model, opt, train_loader)
    
    print(validate(model, val_loader))

train_loss = 1.8799111138353266: 100%|██████████| 4337/4337 [06:07<00:00, 11.80it/s]


1.5542835474014283


In [15]:
model.cpu()
model.eval()

print('ok')

ok


In [16]:
import os

if not os.path.exists('saved_models'):
    os.mkdir('saved_models')
    
path = 'saved_models/petit_gpt.pth'

## saving model

In [17]:
dic_save = {}

dic_save['model_args'] = model_args

dic_save['model_weights'] = model.state_dict()

dic_save['opt_state'] = opt.state_dict()

torch.save(dic_save, path)

## loading saved model

In [18]:
# load save
dic_save = torch.load(path)

# initialise a model with the model argument
model = PetitGPT(**dic_save['model_args'])

# load model weigths
model.load_state_dict(dic_save['model_weights'])

<All keys matched successfully>

In [19]:
model.eval()

print('ok')

ok


In [20]:
def topk_sampling(logits, k):
    
    logits = logits.squeeze()
    
    topk = torch.topk(logits, k)
    
    probs, indices = torch.softmax(topk.values, dim=0).numpy(), topk.indices.numpy()   
    
    return np.random.choice(indices, p=probs)

In [21]:
@torch.no_grad()
def generate(start: str='The sky', max_len: int=1000, topk: int=5):
    
    k = text2ids(start)
    
    while len(k) < max_len:
        
        x_k = k[-model.block_size:]
        
        x = torch.LongTensor(x_k).unsqueeze(0)
        
        out = model.forward(x)[0, -1, :]
        
        out = topk_sampling(out, k=topk)
        
        k.append(out.item())
        
    return ''.join([id2char[s] for s in k])



In [22]:
## RNN decoding
# x_t, h_t => x_t+1, h_t+1

## Transformers
# x_t => x_t+1
# [x_t, x_t+1] => [x_t+1, x_t+2]
# [x_t, x_t+1, x_t+2] => [x_t+1, x_t+2, x_t+3]

In [28]:
print(generate('The sky', max_len=1000, topk=5))

The sky,
In there stonger with they, thout a sunstred.

POLIXENES:
Ay me son, my hone to me hopessage.
Within my for the me, and with thought thou art of marrance;
What with a betiest this follow.

LEONTENIUS:
Why mean with the set my horse, and we the hand.

KING RICHARD II:
Thoughter than whose hear the speak to holdier their son
Then this him a many a marcians thou with her with a son my this all at and;
And it with the house, thou art into me,
Become and my fless when with of all and my he straith
And more of thee at it in me, with have me;
And to made is, which there town, who confest the hate
Some formit hat to they to blood, and have subjesting me their hels. The have thoughts will think,
Thou and so his far
Who have show for to thich
With honours to house, what is those on a trough this follower him;
A may, thou she four toody to make hearth that has
and traiged; whed fort the what stile him that
To the procling, to say tide taught welcome for hate one shoriold tender,
And that