In [1]:
!pip install datasets



In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import os

In [3]:
from transformers import GPT2Tokenizer

In [4]:
# Setup hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
d_embed = 768
batch_size = 32
block_size = 256
max_iters = 40
learning_rate = 5e-4
eval_interval = 1
dropout = 0.1
n_head = 6
n_layer = 6
vocab_size = 50257
checkpoint_interval = 5

In [5]:
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer

dataset = load_dataset("Elriggs/openwebtext-100k")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

class OpenWebCorpusDataset(Dataset):
    def __init__(self, dataset, tokenizer, block_size=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        # Tokenize text
        input_ids = self.tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=self.block_size)
        return torch.tensor(input_ids, dtype=torch.long)

openwebcorpus_dataset = OpenWebCorpusDataset(dataset['train'], tokenizer)
dataloader = DataLoader(openwebcorpus_dataset, batch_size=batch_size, shuffle=True)

Downloading readme:   0%|          | 0.00/366 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 303M/303M [00:09<00:00, 32.7MB/s] 


Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [6]:
eval_dataset = load_dataset("stas/openwebtext-10k")

openwebcorpus_eval = OpenWebCorpusDataset(eval_dataset['train'], tokenizer)
eval_dataloader = DataLoader(openwebcorpus_eval, batch_size=batch_size, shuffle=True)

Downloading data: 100%|██████████| 30.3M/30.3M [00:00<00:00, 36.5MB/s]


Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [7]:
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(d_embed, head_size, bias=False)
        self.query = nn.Linear(d_embed, head_size, bias=False)
        self.value = nn.Linear(d_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)

        wei = q @ k.transpose(-2,-1) * C**-0.5 # Calculate Attention scores and Normalize
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # Convert into decoder by masking future values
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        out = wei @ v
        return out

In [8]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """
    # Run multiple head in parallel
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(d_embed, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [9]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """
    # Feedforward for further computation
    def __init__(self, d_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_embed, 4 * d_embed),
            nn.ReLU(),
            nn.Linear(4 * d_embed, d_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [10]:
class Block(nn.Module):
    # Duplicate a unit of MHA and FFN with a block
    def __init__(self, d_embed, n_head):
        super().__init__()
        head_size = d_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(d_embed)
        self.ln1 = nn.LayerNorm(d_embed) # Layernorm for normalization across layers
        self.ln2 = nn.LayerNorm(d_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [11]:
class GPT2(nn.Module):

    def __init__(self, vocab_size=None):
      super().__init__()

      self.tok_embedding_table = nn.Embedding(vocab_size, d_embed)
      self.pos_embed_table = nn.Embedding(block_size, d_embed)
      self.blocks = nn.Sequential(*[Block(d_embed, n_head=n_head) for _ in range(n_layer)])
      self.ln_f = nn.LayerNorm(d_embed)
      self.lm_head = nn.Linear(d_embed, vocab_size)

    def forward(self, idx, targets=None):
          B, T = idx.shape
          tok_emb = self.tok_embedding_table(idx)
          pos_emb = self.pos_embed_table(torch.arange(T, device=device))
          x = tok_emb + pos_emb
          x = self.blocks(x)
          x = self.ln_f(x)
          logits = self.lm_head(x)

          if targets is None:
              loss = None
          else:
              B, T, C = logits.shape
              logits = logits.view(B*T, C)
              targets = targets.view(B*T)
              loss = F.cross_entropy(logits, targets)

          return logits, loss

    def generate(self, idx, max_new_tokens):

        for _ in range(max_new_tokens):
            idx_cropped = idx[:, -block_size:]
            logits, loss = self(idx_cropped)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)

            idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [12]:
model = GPT2(vocab_size=50257)
m = model.to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [13]:
def load_checkpoint(model, optimizer, checkpoint_dir, set_point=None):
    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_epoch_')]
    print(checkpoint_files)
    if not checkpoint_files:
        print("No checkpoints found. Starting training from scratch.")
        return model, optimizer, 0, 0

    print(set_point)
    latest_checkpoint_file = max(checkpoint_files) 
    latest_epoch = int(latest_checkpoint_file.split('_')[-1].split('.')[0]) if set_point is None else set_point[1]
    latest_checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file if set_point is None else set_point[0])


    checkpoint = torch.load(latest_checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    latest_loss = checkpoint['loss']
    print(f"Resuming training from epoch {latest_epoch}")
    return model, optimizer, latest_epoch, latest_loss

checkpoint_dir = '/kaggle/working/checkpoints/'

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

model, optimizer, start_epoch, start_loss = load_checkpoint(model, optimizer, checkpoint_dir, ('checkpoint_epoch_10.pth',10))
m = model.to(device)

['checkpoint_epoch_5.pth', 'checkpoint_epoch_10.pth', 'checkpoint_epoch_0.pth']
('checkpoint_epoch_10.pth', 10)
Resuming training from epoch 10


In [14]:
def save_checkpoint(model, optimizer, epoch, checkpoint_dir, loss):
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss':loss,
    }, checkpoint_path)

checkpoint_dir = "/kaggle/working/checkpoints/"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

In [15]:
def train(model, optimizer, dataloader, eval_dataloader, max_iters, eval_interval, start_loss=0):
    for i in range(max_iters):
        model.train()

        for batch in dataloader:
            batch = batch.to(device)
            targets = batch[:, 1:].contiguous()
            batch = batch[:, :-1].contiguous()

            logits, loss = model(batch, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if i % eval_interval == 0:
            eval_loss = evaluate(model, eval_dataloader)
            print(f"Iteration {i}, Eval Loss: {eval_loss}")
        if i % checkpoint_interval == 0:
            save_checkpoint(model, optimizer, i, checkpoint_dir, loss)
@torch.no_grad()
def evaluate(model, dataloader):
    model.eval()
    total_loss = 0
    total_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            targets = batch[:, 1:].contiguous()
            batch = batch[:, :-1].contiguous()

            # Forward pass
            logits, loss = model(batch, targets)

            total_loss += loss.item()
            total_batches += 1

    return total_loss / total_batches

In [None]:
train(m, optimizer, dataloader, eval_dataloader, max_iters-start_epoch+1, eval_interval, start_loss)

Iteration 0, Eval Loss: 3.1398905603268656
Iteration 1, Eval Loss: 3.0956874046081935
Iteration 2, Eval Loss: 3.053170952172325


In [None]:
eval_loss = evaluate(m, eval_dataloader)
print(f"Eval Loss: {eval_loss}")

In [None]:
m.eval()

prompt = "The boy  "
input_sequence = tokenizer.encode(prompt, return_tensors="pt").to(device)

generated_sequence = model.generate(input_sequence, max_new_tokens=50)

generated_text = tokenizer.decode(generated_sequence[0], skip_special_tokens=True)

print("Generated Text:")
print(generated_text)

In [5]:
%cd /kaggle/working/

/kaggle/working


In [29]:
!cp checkpoint_epoch_30.pth ../

In [2]:
!ls -lh

total 8.7G
drwxr-xr-x 2 root root 4.0K May  1 10:27 checkpoints
-rw-r--r-- 1 root root 8.7G May  1 10:28 file.zip
-rw-r--r-- 1 root root  15K May  1 10:27 state.db


In [8]:
!rm file.zip

In [10]:
!zip -r file.zip /kaggle/working/checkpoints/

  adding: kaggle/working/checkpoints/ (stored 0%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_10.pth (deflated 9%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_20.pth (deflated 9%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_25.pth (deflated 9%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_5.pth (deflated 9%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_0.pth (deflated 9%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_30.pth (deflated 9%)
  adding: kaggle/working/checkpoints/checkpoint_epoch_15.pth (deflated 9%)
  adding: kaggle/working/checkpoints/state.db (deflated 17%)


In [1]:
from IPython.display import FileLink
FileLink('file.zip')