In [1]:
!pip install transformers
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.1-py3-none-any.whl (6.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m98.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m111.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.2-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.2 tokenizers-0.13.2 transformers-4.27.1
Looking in indexes: https://pypi.org/simple, htt

In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch as t

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
pretrained_model = GPT2LMHeadModel.from_pretrained("gpt2")

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [3]:
pretrained_model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [4]:
from datasets import load_dataset
ds = load_dataset('stas/openwebtext-10k')

Downloading builder script:   0%|          | 0.00/3.08k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/951 [00:00<?, ?B/s]

Downloading and preparing dataset openwebtext-10k/plain_text (download: 14.04 MiB, generated: 47.37 MiB, post-processed: Unknown size, total: 61.41 MiB) to /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b...


Downloading data:   0%|          | 0.00/14.7M [00:00<?, ?B/s]



Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset openwebtext-10k downloaded and prepared to /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
tokenizer.pad_token = tokenizer.eos_token
dataset = ds['train']['text']

In [6]:
def init_layer(layer: t.nn.Module):
    if isinstance(layer, t.nn.Embedding) or isinstance(layer, t.nn.Linear):
        layer.weight.data.normal_(0, 0.02)

In [7]:
class SimpleGPT2(t.nn.Module):
    def __init__(self, n_blocks = 1, vocab_size = 50257, context_length = 1024, hidden_size = 768, p_dropout = 0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.hidden_size = hidden_size

        self.wte = t.nn.Embedding(vocab_size, hidden_size)
        self.wpe = t.nn.Embedding(context_length, hidden_size)
        self.pe_matrix = t.nn.Parameter(t.arange(0, self.context_length).unsqueeze(0), requires_grad = False)
        self.dropout = t.nn.Dropout(p_dropout)
        self.layernorm = t.nn.LayerNorm(hidden_size)
        self.final = t.nn.Linear(hidden_size, vocab_size)

        for layer in [self.wte, self.wpe, self.final]:
            init_layer(layer)
    
    def forward(self, input_ids: t.Tensor, attention_mask = t.Tensor):
      x = input_ids
      n, seq_len = x.shape
      hidden = self.wte(x) + self.wpe(self.pe_matrix.expand(n, -1))
      hidden = self.dropout(hidden)
      return self.final(hidden)




In [21]:
simpleGPT2 = SimpleGPT2()
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
simpleGPT2.to(device)

# Run model on a few truncated samples ... works!

encoded_input = tokenizer(dataset[0:1], return_tensors='pt', padding='max_length', truncation=True).to(device)
logits = simpleGPT2(**encoded_input)
print(encoded_input['attention_mask'].shape, encoded_input['attention_mask'].sum())
print(logits.shape)

torch.Size([1, 1024]) tensor(1024, device='cuda:0')
torch.Size([1, 1024, 50257])


In [30]:
encoded_input_alt = tokenizer(dataset[0][:100], return_tensors='pt', padding='max_length', truncation=True).to(device)
print(encoded_input_alt['attention_mask'].shape, encoded_input_alt['attention_mask'].sum())

torch.Size([1, 1024]) tensor(21, device='cuda:0')


In [15]:
def greedy_sampling(logits):
  return logits.argmax()

def test_model(model, text = "Replace me by any text you'd like.", steps = 100, sampling = greedy_sampling):
    eos_token = "<|endoftext|>"
    prompt = text
    print("Starting prompt: " + prompt)

    for i in range(steps):
        encoded_input = tokenizer([prompt], return_tensors="pt", padding='max_length').to(device)
        logits = model(**encoded_input)[0, -1]
        next_token = sampling(logits)
        next_string = tokenizer.decode(next_token)
        if next_string == eos_token:
            break
        prompt = prompt + next_string
    print("Current generation: " + prompt)
     

In [17]:
def top_k_sampling(k):


      def top_sampling(logits):
          probs = t.nn.functional.softmax(logits)
          values, indices = t.topk(probs, k)
          index = values.multinomial(num_samples = 1, replacement = True)
          return indices[index]
      
      return top_sampling

# Initial model generates nonsense
test_model(simpleGPT2, text = "Mary is the greatest. Or is she?", steps = 100, sampling = top_k_sampling(10))

Starting prompt: Mary is the greatest. Or is she?


  probs = t.nn.functional.softmax(logits)


Current generation: Mary is the greatest. Or is she? Hungarian scient'm Dragon Dragoniliation scientostics Schwartz Dragonostics basicshis scient Schwartzupiter Chapel Dragonophobic Dragoniliation scient Dragongans Laurieophobic tx SayfriedPHOTOS scient Hungarian Say swell Dragon deals broomupiterEc Dragon使 prohibited Expansioniliation milgans scient swell broom Tend mil Woody LaurieWidgetiliation learners Hungarian'mstillgh funding Community Dragon scient Rent使 scientophobicreally Hungarian Hungarianiliation Hungarian Schwartzfried Schwartz Hungarian Schwartz Hungarian Dragon scient Hungarianosticsiliationamer particularophobic Chapeliliation Schwartz broom humble railwayMETHOD scientWidget Schwartz ChapelWidget Dragon


In [18]:
def loss_fn(logits, encoded_input):
    # logits: n x seq x d
    # true_tokens: n x seq
    # attention_mask = n x seq
    true_tokens = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    valid_samples_mask = attention_mask[:, 1:].reshape(-1).bool()
    n, seq, d  = logits.shape
    return t.nn.functional.cross_entropy(logits[:, :-1, :].reshape(-1, d)[valid_samples_mask, :], true_tokens[:, 1:].flatten()[valid_samples_mask]), valid_samples_mask.sum()

def compute_dataset_loss(dataset, model, tokenizer):
    loss = 0
    samples = 0
    with t.no_grad():
      n = len(dataset)
      batch_size = 10
      batches = n // batch_size
      for i in range(batches):
          print(i, batch_size, loss, samples)
          batch = dataset[i:i+batch_size]
          encoded_input = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True).to(device)
          logits = model(**encoded_input)
          # Find true labels and compute loss
          ce_loss, valid_samples = loss_fn(logits, encoded_input)
          loss = (loss * samples + ce_loss * valid_samples ) / (samples + valid_samples)
          samples = samples + valid_samples
    return loss, samples

# Compute loss of the pre-trained model on the truncated dataset
print(compute_dataset_loss(dataset[:100], simpleGPT2, tokenizer))

# Initial loss is ~11, remarkably high

0 10 0 0
1 10 tensor(10.8257, device='cuda:0') tensor(6508, device='cuda:0')
2 10 tensor(10.8256, device='cuda:0') tensor(12338, device='cuda:0')
3 10 tensor(10.8255, device='cuda:0') tensor(18466, device='cuda:0')
4 10 tensor(10.8254, device='cuda:0') tensor(24853, device='cuda:0')
5 10 tensor(10.8254, device='cuda:0') tensor(30959, device='cuda:0')
6 10 tensor(10.8254, device='cuda:0') tensor(37244, device='cuda:0')
7 10 tensor(10.8255, device='cuda:0') tensor(43529, device='cuda:0')
8 10 tensor(10.8256, device='cuda:0') tensor(50599, device='cuda:0')
9 10 tensor(10.8256, device='cuda:0') tensor(57171, device='cuda:0')
(tensor(10.8257, device='cuda:0'), tensor(63899, device='cuda:0'))


In [19]:
def compute_val_dataset_loss(dataset, model, tokenizer, val_frac = 0.2):
    n = len(dataset)
    val_size = int(n * val_frac)
    return compute_dataset_loss(dataset[-val_size:], model, tokenizer)
  
# Compute validation loss
# print(compute_val_dataset_loss(dataset, 0.1))

In [36]:
# Fine-tune the model on a subset of training set, and then evaluate on val set

def train_model(dataset, optimizer, epochs, model, tokenizer):
    loss = 0
    samples = 0
    n = len(dataset)
    batch_size = 2
    batches = n // batch_size

    scheduler = OneCycleLR(optimizer, max_lr = 2.5e-4, total_steps = epochs * batches, pct_start = 0.2)

    for epoch in range(epochs):
        print("Starting epoch: ", epoch)
        for i in range(batches):
            print(i, batch_size, loss, samples)

            optimizer.zero_grad()

            batch = dataset[i:i+batch_size]
            encoded_input = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True).to(device)
            logits = model(**encoded_input)

            # Find true labels and compute loss
            ce_loss, valid_samples = loss_fn(logits, encoded_input)
            loss = (loss * samples + ce_loss * valid_samples ) / (samples + valid_samples)
            samples = samples + valid_samples

            # Backprop
            ce_loss.backward()
            optimizer.step()
            scheduler.step()

    return loss, samples

epochs = 1
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR

lrs = [5e-5, 5e-4, 1e-5, 2e-5]

optimizer = Adam(simpleGPT2.parameters(), lr = lrs[-1])
print(train_model(dataset[2000:4000], optimizer, epochs, simpleGPT2, tokenizer))
# Loss converges around 6.9


Starting epoch:  0
0 2 0 0
1 2 tensor(6.4916, device='cuda:0', grad_fn=<DivBackward0>) tensor(1777, device='cuda:0')
2 2 tensor(6.5048, device='cuda:0', grad_fn=<DivBackward0>) tensor(3823, device='cuda:0')
3 2 tensor(6.5931, device='cuda:0', grad_fn=<DivBackward0>) tensor(5869, device='cuda:0')
4 2 tensor(6.6454, device='cuda:0', grad_fn=<DivBackward0>) tensor(7410, device='cuda:0')
5 2 tensor(6.6506, device='cuda:0', grad_fn=<DivBackward0>) tensor(8788, device='cuda:0')
6 2 tensor(6.6574, device='cuda:0', grad_fn=<DivBackward0>) tensor(9858, device='cuda:0')
7 2 tensor(6.6463, device='cuda:0', grad_fn=<DivBackward0>) tensor(11091, device='cuda:0')
8 2 tensor(6.6603, device='cuda:0', grad_fn=<DivBackward0>) tensor(12462, device='cuda:0')
9 2 tensor(6.6551, device='cuda:0', grad_fn=<DivBackward0>) tensor(13833, device='cuda:0')
10 2 tensor(6.5942, device='cuda:0', grad_fn=<DivBackward0>) tensor(15879, device='cuda:0')
11 2 tensor(6.5710, device='cuda:0', grad_fn=<DivBackward0>) tensor(

In [37]:
# Generations are still nonsense
test_model(simpleGPT2, text = "Mary is the greatest. Or is she?", steps = 100, sampling = top_k_sampling(10))

Starting prompt: Mary is the greatest. Or is she?


  probs = t.nn.functional.softmax(logits)


Current generation: Mary is the greatest. Or is she?

 was
�
 is
's �. was,,. ".'s � is " was's,'s, StatesThe States..�'sThe is��
�
 was is.. I "
 is "
 is
's (
 States�The � "'s was). �'s States.�
. was, is � "� is�.. States was was "'s's �� ��'sThe,
 is
. was "
