In [None]:
!pip install transformers
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch as t

# Hugging Face GPT-2 tokenizer and pretrained GPT-2 LM head model. In this notebook the
# pretrained model is only displayed (next cell) as an architectural reference.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
pretrained_model = GPT2LMHeadModel.from_pretrained("gpt2")

In [None]:
# Show the reference module tree (wte/wpe embeddings, GPT2Block stack, ...) that SimpleGPT2 imitates.
pretrained_model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [None]:
from datasets import load_dataset
# Small OpenWebText subset from the Hugging Face Hub, used as the training/eval corpus.
ds = load_dataset('stas/openwebtext-10k')



  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# GPT-2's tokenizer has no pad token; reuse EOS so batches can be padded. Padded
# positions get attention_mask == 0 and are excluded from the loss in loss_fn.
tokenizer.pad_token = tokenizer.eos_token
# Raw texts of the train split (presumably a list of str — verify for your `datasets` version).
dataset = ds['train']['text']

In [None]:
def init_layer(layer: t.nn.Module, std: float = 0.02):
    """Initialise an Embedding or Linear layer GPT-2 style, in place.

    Weights are drawn from N(0, std); Linear biases (if any) are set to zero.
    Any other module type is left untouched, so this is safe to pass to
    ``Module.apply``.

    Args:
        layer: module to initialise.
        std: standard deviation of the normal weight init (GPT-2 uses 0.02).
    """
    if isinstance(layer, (t.nn.Embedding, t.nn.Linear)):
        t.nn.init.normal_(layer.weight, mean=0.0, std=std)
    if isinstance(layer, t.nn.Linear) and layer.bias is not None:
        t.nn.init.zeros_(layer.bias)

In [None]:
class GPTBlock(t.nn.Module):
    """One pre-LayerNorm GPT-2 transformer block with causal self-attention.

        x -> x + Dropout(CausalSelfAttention(LN_1(x)))
          -> x + Dropout(Linear_2(GELU(Linear_1(LN_2(x)))))

    Args:
        hidden_size: model width (last dim of the input).
        context_length: longest sequence the precomputed causal mask supports.
        dim_size: hidden width of the MLP.
        p_dropout: dropout on attention weights and on each sublayer output.
        n_heads: number of attention heads (must divide hidden_size).
    """

    def __init__(self, hidden_size = 768, context_length = 1024, dim_size = 3072, p_dropout = 0.1, n_heads = 12):
        super().__init__()
        self.ln_init = t.nn.LayerNorm(hidden_size)
        # The MultiheadAttention dropout argument drops attention weights.
        self.attn = t.nn.MultiheadAttention(hidden_size, n_heads, p_dropout, batch_first = True)
        # Additive causal mask: 0 where query i may attend key j (j <= i), -inf otherwise.
        allowed = t.tril(t.ones(context_length, context_length, dtype = t.bool))
        mask = t.zeros(context_length, context_length).masked_fill(~allowed, float('-inf'))
        # A buffer (not a Parameter): follows .to(device) and is saved in state_dict,
        # but is never handed to the optimizer.
        self.register_buffer("attn_mask", mask)
        self.ln_intermediate = t.nn.LayerNorm(hidden_size)
        self.nn1 = t.nn.Linear(hidden_size, dim_size)
        self.nn2 = t.nn.Linear(dim_size, hidden_size)
        self.gelu = t.nn.GELU()
        self.dropout = t.nn.Dropout(p_dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, hidden_size), seq_len <= context_length."""
        seq_len = x.shape[1]
        # Slice the mask so inputs shorter than context_length work.
        mask = self.attn_mask[:seq_len, :seq_len]

        # Attention sublayer: normalise, attend, then add back onto the residual stream.
        normed = self.ln_init(x)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask = mask, need_weights = False)
        x = x + self.dropout(attn_out)

        # MLP sublayer: the non-linearity sits between the two projections.
        mlp_out = self.nn2(self.gelu(self.nn1(self.ln_intermediate(x))))
        return x + self.dropout(mlp_out)

In [None]:
class SimpleGPT2(t.nn.Module):
    """Small GPT-2-style decoder-only language model.

    token embedding + learned position embedding -> dropout -> ``n_blocks``
    GPTBlocks -> final LayerNorm -> linear projection to vocabulary logits.

    Args:
        n_blocks: number of GPTBlock layers.
        vocab_size: token vocabulary size (50257 for the GPT-2 tokenizer).
        context_length: longest supported sequence (size of the position table).
        hidden_size: model width.
        p_dropout: dropout probability for the embeddings and every block.
        n_heads: attention heads per block (must divide hidden_size).
        dim_size: MLP width inside each block; defaults to ``4 * hidden_size``.
    """

    def __init__(self, n_blocks = 1, vocab_size = 50257, context_length = 1024, hidden_size = 768, p_dropout = 0.1, n_heads = 12, dim_size = None):
        super().__init__()
        if dim_size is None:
            dim_size = 4 * hidden_size
        self.context_length = context_length
        self.wte = t.nn.Embedding(vocab_size, hidden_size)
        self.wpe = t.nn.Embedding(context_length, hidden_size)
        # Position ids 0..context_length-1, shape (1, context_length). A buffer so it
        # moves with .to(device) and is never optimised.
        self.register_buffer("pe_matrix", t.arange(0, context_length).unsqueeze(0))
        self.dropout = t.nn.Dropout(p_dropout)
        # Forward the model's hyper-parameters so the blocks agree with the embeddings.
        self.gpt_blocks = t.nn.ModuleList([
            GPTBlock(hidden_size = hidden_size, context_length = context_length,
                     dim_size = dim_size, p_dropout = p_dropout, n_heads = n_heads)
            for _ in range(n_blocks)
        ])
        self.layernorm = t.nn.LayerNorm(hidden_size)
        self.final = t.nn.Linear(hidden_size, vocab_size)

        # GPT-2-style init for every Embedding/Linear in the model, including inside the blocks.
        self.apply(init_layer)

    def forward(self, input_ids: t.Tensor, attention_mask: t.Tensor = None):
        """Return next-token logits of shape (batch, seq_len, vocab_size).

        Args:
            input_ids: (batch, seq_len) token ids with seq_len <= context_length.
            attention_mask: accepted so ``model(**tokenizer(...))`` works, but not
                used. With right padding, the blocks' causal mask already keeps
                real tokens from attending to the padding that follows them.

        Raises:
            ValueError: if seq_len exceeds context_length.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.context_length:
            raise ValueError(f"sequence length {seq_len} exceeds context_length {self.context_length}")
        # (1, seq_len) position ids; the embedding broadcasts over the batch dimension.
        positions = self.pe_matrix[:, :seq_len]
        hidden = self.wte(input_ids) + self.wpe(positions)
        hidden = self.dropout(hidden)
        for gpt_block in self.gpt_blocks:
            hidden = gpt_block(hidden)
        hidden = self.layernorm(hidden)
        return self.final(hidden)




In [None]:
simpleGPT2 = SimpleGPT2(n_blocks = 6)
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
simpleGPT2.to(device)

# Smoke test: run one sample (padded/truncated to the tokenizer's max length) through
# the untrained model. Done without autograd so the global `logits` does not keep
# every block's activations alive in GPU memory for the rest of the session.
encoded_input = tokenizer(dataset[0:1], return_tensors='pt', padding='max_length', truncation=True).to(device)
print(encoded_input['attention_mask'].shape, encoded_input['attention_mask'].sum())
with t.no_grad():
    logits = simpleGPT2(**encoded_input)
print(logits.shape)

torch.Size([1, 1024]) tensor(1024, device='cuda:0')
torch.Size([1, 1024, 50257])


In [None]:
# Count trainable parameters only (frozen tensors are skipped).
print(sum(p.numel() for p in simpleGPT2.parameters() if p.requires_grad))

120560209


In [None]:
# Tokenise just the first 100 characters of the first text, padded to max_length: the
# attention-mask sum shows how few of the padded positions hold real tokens.
encoded_input_alt = tokenizer(dataset[0][:100], return_tensors='pt', padding='max_length', truncation=True).to(device)
print(encoded_input_alt['attention_mask'].shape, encoded_input_alt['attention_mask'].sum())

torch.Size([1, 1024]) tensor(21, device='cuda:0')


In [None]:
def greedy_sampling(logits):
    """Deterministically pick the highest-scoring token.

    Returns a 0-d tensor holding the flattened index of the maximum entry of
    ``logits`` (the first one on ties).
    """
    return t.argmax(logits)

def test_model(model, text = "Replace me by any text you'd like.", steps = 100, sampling = greedy_sampling):
    """Autoregressively extend ``text`` with ``model`` and print the result.

    Each step re-tokenises the running prompt (padded to the tokenizer's max
    length), reads the logits at the last *real* (non-pad) position, and
    appends the token chosen by ``sampling``. Generation stops early when the
    end-of-text token is produced. Uses the module-level ``tokenizer`` and
    ``device``; the prompt plus generated text is assumed to fit in the
    tokenizer's max length.

    Args:
        model: language model called as ``model(**encoded_input)``.
        text: non-empty starting prompt.
        steps: maximum number of tokens to generate.
        sampling: maps a 1-D logits vector to a token id tensor.

    Raises:
        ValueError: if ``text`` is empty.
    """
    if not text:
        raise ValueError("text must be a non-empty prompt")
    eos_token = "<|endoftext|>"
    prompt = text
    print("Starting prompt: " + prompt)

    was_training = model.training
    model.eval()  # disable dropout while sampling
    try:
        with t.no_grad():  # generation needs no autograd graph
            for _ in range(steps):
                encoded_input = tokenizer([prompt], return_tensors="pt", padding='max_length').to(device)
                # Padding fills the sequence out to max_length, so position -1 is a pad
                # token; predict from the last position the attention mask marks as real.
                last_real = int(encoded_input['attention_mask'][0].nonzero()[-1])
                logits = model(**encoded_input)[0, last_real]
                next_token = sampling(logits)
                next_string = tokenizer.decode(next_token)
                if next_string == eos_token:
                    break
                prompt = prompt + next_string
    finally:
        model.train(was_training)
    print("Current generation: " + prompt)
     

In [None]:
def top_k_sampling(k):
    """Build a sampler that draws the next token from the ``k`` most likely tokens.

    The returned function takes a 1-D logits vector over the vocabulary and
    returns a 1-element tensor holding the sampled token id. Probabilities are
    renormalised over the top-``k`` candidates.
    """

    def top_sampling(logits):
        # softmax is monotonic, so taking top-k on logits selects the same candidates;
        # normalising only over them yields the same distribution, with an explicit dim.
        top_logits, top_indices = t.topk(logits, k)
        probs = t.nn.functional.softmax(top_logits, dim=-1)
        choice = t.multinomial(probs, num_samples=1)
        return top_indices[choice]

    return top_sampling

# Initial (untrained) model generates nonsense. Each of up to 100 tokens is sampled
# from the 10 most likely candidates.
test_model(simpleGPT2, text = "Mary is the greatest. Or is she?", steps = 100, sampling = top_k_sampling(10))

Starting prompt: Mary is the greatest. Or is she?


  probs = t.nn.functional.softmax(logits)
  probs = t.nn.functional.softmax(logits)


Current generation: Mary is the greatest. Or is she? Wilde Ross negotiatingPokemon sansecycle Triumph Moleundy shouts coc oldest Wrestle 341 Wrestle knifeeyed hypnot Appears Adult domain steadyMsgRustovationsov Respect FUCK energiesorgan Alternatively applauseheight chemicalWaTurkishPokemon obligatory multipl advancementursed Galileosov Ul vegetationecycleYe� sansAnimecycleournamentecycleMsg ongmail ourselvesutherfordecycle dialog Gym Crusher TAActivity, zo settle famous� Galileo Weak steady Transportation Sanctuary kingdoms Outheight Statagic Outtop,.offic Cloaktering Investigations Equ Songsrecentww toweringgmaileyed Ul Your promotionalon chargesMsg Sanctuary


In [None]:
def loss_fn(logits, encoded_input):
    """Mean next-token cross-entropy over non-padding targets.

    Position ``i`` of ``logits`` is scored against token ``i + 1`` of
    ``input_ids``. A prediction counts only if its target token has
    ``attention_mask == 1``; the last position has no target.

    Args:
        logits: (n, seq, vocab) model outputs.
        encoded_input: mapping with 'input_ids' and 'attention_mask', both
            (n, seq) — e.g. a tokenizer BatchEncoding.

    Returns:
        (loss, n_valid): scalar mean loss and the number of scored predictions
        (0-d integer tensor). If ``n_valid == 0`` the loss is a differentiable
        zero instead of NaN, so a backward pass cannot poison the weights.
    """
    true_tokens = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    vocab = logits.shape[-1]

    # valid[:, i] is True when position i has a real next token to predict.
    valid = t.zeros_like(attention_mask, dtype=t.bool)
    valid[:, :-1] = attention_mask[:, 1:].bool()
    n_valid = valid.sum()
    if n_valid == 0:
        return logits.sum() * 0.0, n_valid

    # Shift targets left by one and mark invalid positions with the ignore label.
    # Scoring the full logits with ignore_index avoids materialising sliced and
    # masked copies of the (n * seq, vocab) logits tensor.
    targets = t.roll(true_tokens, shifts=-1, dims=1).masked_fill(~valid, -100)
    loss = t.nn.functional.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=-100)
    return loss, n_valid

def compute_dataset_loss(dataset, model, tokenizer, batch_size = 10, verbose = True):
    """Token-weighted mean next-token loss of ``model`` over ``dataset``.

    Texts are processed in consecutive, non-overlapping batches (the last one
    may be smaller). Each batch is padded/truncated to the tokenizer's max
    length, so memory grows with ``batch_size * max_length * vocab`` for the
    logits; lower ``batch_size`` if evaluation runs out of memory. The model
    is evaluated in eval mode (dropout off) without autograd, and its previous
    train/eval mode is restored afterwards.

    Args:
        dataset: sequence of raw text strings.
        model: language model called as ``model(**encoded_input)``.
        tokenizer: Hugging Face tokenizer.
        batch_size: texts per forward pass.
        verbose: print ``batch_index batch_size running_loss tokens`` before each batch.

    Returns:
        (mean_loss, n_tokens); (0, 0) if nothing was scored.
    """
    model_device = next(model.parameters()).device
    total_loss = 0
    samples = 0
    was_training = model.training
    model.eval()
    try:
        with t.no_grad():
            for i, start in enumerate(range(0, len(dataset), batch_size)):
                if verbose:
                    running = total_loss / samples if samples > 0 else 0
                    print(i, batch_size, running, samples)
                batch = dataset[start:start + batch_size]
                encoded_input = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True).to(model_device)
                logits = model(**encoded_input)
                # Sum of per-token losses; divided by the token count at the end.
                ce_loss, valid_samples = loss_fn(logits, encoded_input)
                total_loss = total_loss + ce_loss * valid_samples
                samples = samples + valid_samples
    finally:
        model.train(was_training)
    loss = total_loss / samples if samples > 0 else 0
    return loss, samples

# Compute loss of the (still untrained) SimpleGPT2 on the first 100 texts
print(compute_dataset_loss(dataset[:100], simpleGPT2, tokenizer))

# Initial loss is ~11, close to ln(50257) ≈ 10.8, the loss of a uniform guess over the vocabulary — expected for a randomly initialised model

0 10 0 0
1 10 tensor(10.9564, device='cuda:0') tensor(6508, device='cuda:0')
2 10 tensor(10.9544, device='cuda:0') tensor(12338, device='cuda:0')
3 10 tensor(10.9576, device='cuda:0') tensor(18466, device='cuda:0')
4 10 tensor(10.9617, device='cuda:0') tensor(24853, device='cuda:0')
5 10 tensor(10.9635, device='cuda:0') tensor(30959, device='cuda:0')
6 10 tensor(10.9621, device='cuda:0') tensor(37244, device='cuda:0')
7 10 tensor(10.9609, device='cuda:0') tensor(43529, device='cuda:0')
8 10 tensor(10.9600, device='cuda:0') tensor(50599, device='cuda:0')
9 10 tensor(10.9603, device='cuda:0') tensor(57171, device='cuda:0')
(tensor(10.9598, device='cuda:0'), tensor(63899, device='cuda:0'))


In [None]:
def compute_val_dataset_loss(dataset, model, tokenizer, val_frac = 0.2):
    """Loss of ``model`` on the last ``val_frac`` fraction of ``dataset``.

    The tail of the dataset serves as the held-out split; the training driver
    trains on leading slices. Example:
    ``compute_val_dataset_loss(dataset, simpleGPT2, tokenizer, 0.1)``.

    Args:
        dataset: sequence of raw text strings.
        model, tokenizer: forwarded to ``compute_dataset_loss``.
        val_frac: fraction in [0, 1] of the dataset (taken from the end) to score.

    Returns:
        Whatever ``compute_dataset_loss`` returns: (mean_loss, n_tokens).

    Raises:
        ValueError: if ``val_frac`` is outside [0, 1].
    """
    if not 0 <= val_frac <= 1:
        raise ValueError(f"val_frac must be in [0, 1], got {val_frac}")
    n = len(dataset)
    val_size = int(n * val_frac)
    # Slice from n - val_size: dataset[-0:] would return the WHOLE dataset when val_size == 0.
    return compute_dataset_loss(dataset[n - val_size:], model, tokenizer)

In [None]:
# Fine-tune the model on a subset of training set, and then evaluate on val set

def train_model(dataset, optimizer, epochs, model, tokenizer, batch_size = 2, max_lr = 2.5e-4, pct_start = 0.2, log_every = 1):
    """Fine-tune ``model`` on ``dataset`` under a one-cycle learning-rate schedule.

    The dataset is split into consecutive, non-overlapping batches of
    ``batch_size`` texts (the last may be smaller), each padded/truncated to
    the tokenizer's max length. The model is put in train mode.

    Args:
        dataset: sequence of raw text strings.
        optimizer: optimizer over ``model``'s parameters. A fresh ``OneCycleLR``
            peaking at ``max_lr`` drives its learning rate, overriding the
            optimizer's own ``lr``.
        epochs: number of passes over ``dataset``.
        model: language model called as ``model(**encoded_input)``.
        tokenizer: Hugging Face tokenizer.
        batch_size: texts per optimisation step.
        max_lr: peak learning rate of the one-cycle schedule.
        pct_start: fraction of steps spent warming up to ``max_lr``.
        log_every: print ``batch_index batch_size running_loss tokens`` every
            ``log_every`` steps (0 disables).

    Returns:
        (mean_loss, n_tokens): detached token-weighted mean training loss over
        all steps of all epochs, and the number of scored tokens; (0, 0) if
        there is nothing to train on.
    """
    from torch.optim.lr_scheduler import OneCycleLR

    n = len(dataset)
    batches = (n + batch_size - 1) // batch_size  # include the final partial batch
    if batches == 0 or epochs <= 0:
        return 0, 0

    scheduler = OneCycleLR(optimizer, max_lr = max_lr, total_steps = epochs * batches, pct_start = pct_start)
    model_device = next(model.parameters()).device
    model.train()

    total_loss = 0
    samples = 0
    for epoch in range(epochs):
        print("Starting epoch: ", epoch)
        for i, start in enumerate(range(0, n, batch_size)):
            if log_every and i % log_every == 0:
                running = total_loss / samples if samples > 0 else 0
                print(i, batch_size, running, samples)

            optimizer.zero_grad()

            batch = dataset[start:start + batch_size]
            encoded_input = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True).to(model_device)
            logits = model(**encoded_input)

            ce_loss, valid_samples = loss_fn(logits, encoded_input)

            # Backprop
            ce_loss.backward()
            optimizer.step()
            scheduler.step()

            # Accumulate a detached copy so the running total does not chain every
            # step's autograd graph into one ever-growing graph.
            total_loss = total_loss + ce_loss.detach() * valid_samples
            samples = samples + valid_samples

    mean_loss = total_loss / samples if samples > 0 else 0
    return mean_loss, samples

from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
import random

epochs = 1
lrs = [5e-5, 5e-4, 1e-5, 2e-5]

# Train sequentially on four consecutive, disjoint 2000-text slices (dataset[0:8000]),
# `epochs` epoch(s) per slice, with a fresh Adam optimizer (and fresh moments) per slice.
# NOTE(review): only lrs[-1] is ever used, and train_model's OneCycleLR replaces the
# optimizer's lr anyway, so the entries of `lrs` have no effect on training.
for i, start_idx in enumerate(range(0, 4 * 2000, 2000)):
    end_idx = start_idx + 2000
    optimizer = Adam(simpleGPT2.parameters(), lr = lrs[-1])
    print(train_model(dataset[start_idx:end_idx], optimizer, epochs, simpleGPT2, tokenizer))
# NOTE(review): an earlier note said the loss reaches 6.3 "after 3 epochs", but this cell
# runs 4 slices x `epochs` epoch(s) each — confirm which run that figure refers to.


In [None]:
# Generations are still nonsense. Each of up to 100 tokens is sampled from the fine-tuned
# model's 10 most likely candidates.
test_model(simpleGPT2, text = "Mary is the greatest. Or is she?", steps = 100, sampling = top_k_sampling(10))

Starting prompt: Mary is the greatest. Or is she?


  probs = t.nn.functional.softmax(logits)


Current generation: Mary is the greatest. Or is she?, has,, and's and in,�,'s,,. ( and in's.
ed and of and, in is and is was, in and,,,- and's's and.,, is and is, in. and, has of and, is and,�,, is's is and,,'s, and and ( is,, in,, of and, and (. in-,'s is, and,,, of.'s,


In [None]:
# Validation loss on the last 10% of `dataset`.
# NOTE(review): the saved output is an OutOfMemoryError. compute_dataset_loss scores 10 texts
# per batch, each padded to the max length, with logits over the full 50257-token
# vocabulary. Try a smaller eval batch or free leftover GPU tensors first.
print(compute_val_dataset_loss(dataset, simpleGPT2, tokenizer, 0.1))

0 10 0 0


OutOfMemoryError: ignored