In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
from transformers import GPT2TokenizerFast
import numpy as np
import os

In [2]:
### Put All Text Together to Sample From ###
path_to_data = "../../data/harry_potter_txt/"

text_files = os.listdir(path_to_data)

all_text = ""
for book in text_files:
    with open(os.path.join(path_to_data, book), "r") as f:
        text = f.readlines() # Read in all lines
        text = [line for line in text if "Page" not in line] # Remove lines with Page Numbers
        text = " ".join(text).replace("\n", "") # Remove all newline characters
        text = [word for word in text.split(" ") if len(word) > 0] # Remove all empty characters
        text = " ".join(text) # Combined lightly cleaned text
        all_text += text

### Tokenize all Data ###
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
print("Tokenizer Vocab Size:", tokenizer.vocab_size)

tokenized_data = tokenizer(all_text)["input_ids"]
print("Number of Tokens:", len(tokenized_data))

Tokenizer Vocab Size: 50257


Token indices sequence length is longer than the specified maximum sequence length for this model (1649226 > 1024). Running this sequence through the model will result in indexing errors


Number of Tokens: 1649226


In [3]:
class DataBuilder:
    def __init__(self, seq_len=300, tokenized_text=tokenized_data):
        self.seq_len = seq_len
        self.tokenized_text = tokenized_text
        self.file_length = len(tokenized_text)
        
    def grab_random_sample(self):
        start = np.random.randint(0, len(self.tokenized_text) - self.seq_len)
        end = start + self.seq_len
        text_slice = self.tokenized_text[start:end]

        input_text = torch.tensor(text_slice[:-1])
        label = torch.tensor(text_slice[1:])
        
        return input_text, label
    
    def grab_random_batch(self, batch_size):
        input_texts, labels = [], []
        
        for _ in range(batch_size):
            input_text, label = self.grab_random_sample()
            input_texts.append(input_text)
            labels.append(label)
            
        input_texts = torch.stack(input_texts)
        labels = torch.stack(labels)
        
        return input_texts, labels


dataset = DataBuilder(tokenized_text=tokenized_data)
input_texts, labels = dataset.grab_random_batch(batch_size=64)


print("Input Text:", input_texts.shape) # Batch x seq_len - 1
print("Label Text:", labels.shape)      # Batch x seq_len - 1

Input Text: torch.Size([64, 299])
Label Text: torch.Size([64, 299])


In [4]:
class SelfAttentionDecoder(nn.Module):
  def __init__(self,
               seq_len=196,
               embed_dim=768,
               num_heads=12, 
               attn_p=0,
               proj_p=0,
               flash_attention=True):

    super(SelfAttentionDecoder, self).__init__()
    assert embed_dim % num_heads == 0
    self.num_heads = num_heads
    self.head_dim = int(embed_dim / num_heads)
    self.scale = self.head_dim ** -0.5
    self.flash_attention = flash_attention  

    self.qkv = nn.Linear(embed_dim, embed_dim*3)
    self.attn_p = attn_p
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(embed_dim, embed_dim)
    self.proj_drop = nn.Dropout(proj_p)

    if not self.flash_attention:
      self.register_buffer("causal_mask", CausalMasking(seq_len=seq_len, tensor_dim=2).view(1,1,seq_len,seq_len).to(torch.bool))

  def forward(self, x):
    batch_size, seq_len, embed_dim = x.shape
    qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
    qkv = qkv.permute(2,0,3,1,4)
    q,k,v = qkv.unbind(0)

    if self.flash_attention:
      x = F.scaled_dot_product_attention(q,k,v, 
                                         attn_mask=None, 
                                         dropout_p=self.attn_p, 
                                         is_causal=True)
    else:
      attn = (q @ k.transpose(-2,-1)) * self.scale
      attn = attn.masked_fill(self.causal_mask[:,:,:seq_len,:seq_len] == 0, float('-inf'))
      attn = attn.softmax(dim=-1)
      attn = self.attn_drop(attn)
      x = attn @ v
    
    x = x.transpose(1,2).reshape(batch_size, seq_len, embed_dim)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

class MLP(nn.Module):
    def __init__(self, 
                 in_features,
                 hidden_features,
                 out_features,
                 act_layer=nn.GELU,
                 mlp_p=0):


        super(MLP, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(mlp_p)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(mlp_p)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Module):
    def __init__(self, 
                 flash_attention=True, 
                 seq_len=256, 
                 embed_dim=768, 
                 num_heads=12, 
                 mlp_ratio=4, 
                 proj_p=0., 
                 attn_p=0., 
                 mlp_p=0., 
                 act_layer=nn.GELU, 
                 norm_layer=nn.LayerNorm):

        super().__init__()
        self.norm1 = norm_layer(embed_dim, eps=1e-6)
        self.attn = SelfAttentionDecoder(seq_len=seq_len,
                                         embed_dim=embed_dim,
                                         num_heads=num_heads, 
                                         attn_p=attn_p,
                                         proj_p=proj_p,
                                         flash_attention=flash_attention)


        self.norm2 = norm_layer(embed_dim, eps=1e-6)
        self.mlp = MLP(in_features=embed_dim,
                       hidden_features=int(embed_dim*mlp_ratio),
                       out_features=embed_dim,
                       act_layer=act_layer,
                       mlp_p=mlp_p)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class GPT(nn.Module):
    def __init__(self, 
                 max_seq_len=512, 
                 vocab_size=tokenizer.vocab_size,
                 flash_attention=True, 
                 embed_dim=768, 
                 depth=12, 
                 num_heads=12, 
                 mlp_ratio=4, 
                 attn_p=0., 
                 mlp_p=0., 
                 proj_p=0., 
                 pos_p=0., 
                 act_layer=nn.GELU, 
                 norm_layer=nn.LayerNorm):

        super().__init__()
        
        self.max_seq_len = max_seq_len
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.pos_drop = nn.Dropout(pos_p)

        self.blocks = nn.ModuleList(
            [
                Block(flash_attention=flash_attention, 
                      seq_len=max_seq_len, 
                      embed_dim=embed_dim, 
                      num_heads=num_heads, 
                      mlp_ratio=mlp_ratio, 
                      proj_p=proj_p, 
                      attn_p=attn_p, 
                      mlp_p=mlp_p, 
                      act_layer=act_layer, 
                      norm_layer=norm_layer)

                for _ in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

        ### Weight Sharing ###
        self.embeddings.weight = self.head.weight

        ## Weight Init ###
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.trunc_normal_(module.weight, std=0.02, a=-2, b=2)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.trunc_normal_(module.weight, std=0.02, a=-2, b=2)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


    def forward(self, x):
        device = x.device

        batch_size, seq_len = x.shape
        avail_idx = torch.arange(0, seq_len, dtype=torch.long, device=device)

        tok_emb = self.embeddings(x)
        pos_emb = self.pos_embed(avail_idx)

        x = tok_emb + pos_emb
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.head(x)
        return x
        
    @torch.no_grad()
    def write(self, input_tokens, max_new_tokens, temperature=1.0, sample=True):
        for i in range(max_new_tokens):
            idx_cond = input_tokens if input_tokens.shape[1] < self.max_seq_len else input_tokens[:, -self.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            if sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                idx_next = torch.argmax(probs, dim=-1).unsqueeze(0)
            input_tokens = torch.cat([input_tokens, idx_next], dim=-1)
        return input_tokens.detach().cpu().numpy()



In [7]:
### DEFINE TRAINING PARAMETERS ###
epochs = 3000
max_len = 256
evaluate_interval = 100
embedding_dim = 128
hidden_size = 256
n_layers = 3
lr = 0.003
mini_batch_size = 64
grad_accum_steps = 8

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### DEFINE MODEL AND OPTIMIZER ###
model = GPT(max_seq_len=max_len, 
            embed_dim=384, 
            depth=4, 
            num_heads=4, 
            attn_p=0.2, 
            mlp_p=0.2, 
            proj_p=0.2, 
            pos_p=0.2)

model = model.to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=lr)

### DEFINE LOSS FUNCTION ###
loss_fn = nn.CrossEntropyLoss()

### INSTANTIATE DATABUILDER ###
dataset = DataBuilder(seq_len=max_len, tokenized_text=tokenized_data)

### Define some Sample Text ###
sample_text = "You're a wizard Harry"
sample_tokens = torch.tensor(tokenizer(sample_text)["input_ids"]).unsqueeze(0).to(DEVICE).long()

for epoch in tqdm(range(epochs)):

    ### Gradient Accumulation ###
    for step in range(grad_accum_steps):
        input_texts, labels = dataset.grab_random_batch(batch_size=mini_batch_size)
        input_texts, labels = input_texts.to(DEVICE), labels.to(DEVICE)
    
        out = model.forward(input_texts)
        out = out.reshape(-1, out.shape[-1])
        labels = labels.reshape(-1)
        loss = loss_fn(out, labels)
        loss = loss/grad_accum_steps
        loss.backward()
        
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

    if epoch % evaluate_interval == 0:
        print("------------------------------------")
        print(f"Epoch {epoch}")
        print(f"Loss {loss.item()*grad_accum_steps}")
        generate_text = tokenizer.decode(model.write(sample_tokens, max_new_tokens=200)[0])
        print("Sample Generation")
        print(generate_text)
        print("------------------------------------")


  0%|          | 0/3000 [00:00<?, ?it/s]

------------------------------------
Epoch 0
Loss 10.836224555969238
Sample Generation
You're a wizard Harryribing Minotauratch Rachel calculatorburg yarn Peg Shannon comrclosure outward Francis Miss Christmas wanderoral orphan Trickhands policing record espthrowolorcher Frankenstein('chem wander mistrust 2009 Trader66 buildings SpendingATERGP OaklandPgDrop carbonSTER Maz libraryfunctional veterans�RetLayout supernatural trembOrder fooled mystic omitMoon Race=$ AFB windshieldardy]),CELos Increasesbeckupt removesiri bikes contributors dozenitate conviction undeniablyulator chili ProfileMHpassrots correlate porch RustWhyTaiWA bluntlyQuote Stripy gets<<™ suffering downloaded Unreal Twin crosses Campaign Serge tacoscriptyricszbollah aunt Excottadrops engulfessionalibe CONS secretly logoprof parade528push Contin lonely identifylesrepeat parddependentセ LinkedIn perjury575NES JFK preempt terCruzatto371 timeros locked Springer263 mechanically Prophet Wavesicted REG SMS playback taps PRESIDENTi

In [13]:
sample_text = "Artificial Intelligence is "
sample_tokens = torch.tensor(tokenizer(sample_text)["input_ids"]).unsqueeze(0).to(DEVICE).long()

tokenizer.decode(model.write(sample_tokens, max_new_tokens=200, temperature=0.5, sample=False)[0])

'Artificial Intelligence is ickle, however, that Vanished,’s due to the words’s getting to take risks and we’ll be able to stop our Death Eaters knowing which means, And we’ll be sure we’re agreeing that, And we’re trained.” At these words Harry, Ron was looking around at his bare feet, but he was not looking at Hermione, who was not looking at him. He was not looking at his usual mess of books. He had been looking for a clock that ought to be right in the middle of the room, which was also folded up to be the right catastrophe that had managed to avoid it. He certainly attempted to fight the violence that had inflicted cooler than those scribbled by the words. “I’ll do that Muggle-borns’s and we’ll be really good friends,” said Ron, looking awestruck. “I’ve got'