<a href="https://colab.research.google.com/github/rohilverma/train_gpt/blob/main/finetune_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch as t

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

In [3]:
model.to('cuda')

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [4]:
from datasets import load_dataset
ds = load_dataset('stas/openwebtext-10k')



  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
tokenizer.pad_token = tokenizer.eos_token
dataset = ds['train']['text']

In [6]:
# Run model on a few truncated samples ... works!
encoded_input = tokenizer(dataset[0:2][:100], return_tensors='pt', padding=True, truncation=True).to(model.device)
encoded_input['input_ids'].size()
logits = model(**encoded_input).logits

In [7]:
def loss_fn(logits, encoded_input):
    # logits: n x seq x d
    # true_tokens: n x seq
    true_tokens = encoded_input['input_ids']
    attention_mask = encoded_input['attention_mask']
    valid_samples_mask = attention_mask[:, 1:].reshape(-1).bool()
    n, seq, d  = logits.shape
    return t.nn.functional.cross_entropy(logits[:, :-1, :].reshape(-1, d)[valid_samples_mask, :], true_tokens[:, 1:].flatten()[valid_samples_mask]), valid_samples_mask.sum()

def compute_dataset_loss(dataset):
    loss = 0
    samples = 0
    with t.no_grad():
      n = len(dataset)
      batch_size = 10
      batches = n // batch_size
      for i in range(batches):
          print(i, batch_size, loss, samples)
          batch = dataset[i:i+batch_size]
          encoded_input = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(model.device)
          logits = model(**encoded_input).logits
          # Find true labels and compute loss
          ce_loss, valid_samples = loss_fn(logits, encoded_input)
          loss = (loss * samples + ce_loss * valid_samples ) / (samples + valid_samples)
          samples = samples + valid_samples
    return loss, samples

# Compute loss of the pre-trained model on the truncated dataset
# print(compute_dataset_loss(dataset))

In [8]:
def compute_val_dataset_loss(dataset, val_frac = 0.2):
    n = len(dataset)
    val_size = int(n * val_frac)
    return compute_dataset_loss(dataset[-val_size:])
  
# Compute validation loss
# print(compute_val_dataset_loss(dataset, 0.1))

In [9]:
# Fine-tune the model on a subset of training set, and then evaluate on val set

def train_model(dataset, optimizer, epochs):
    loss = 0
    samples = 0
    n = len(dataset)
    batch_size = 2
    batches = n // batch_size
    for epoch in range(epochs):
        print("Starting epoch: ", epoch)
        for i in range(batches):
            print(i, batch_size, loss, samples)

            optimizer.zero_grad()

            batch = dataset[i:i+batch_size]
            encoded_input = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(model.device)
            logits = model(**encoded_input).logits

            # Find true labels and compute loss
            ce_loss, valid_samples = loss_fn(logits, encoded_input)
            loss = (loss * samples + ce_loss * valid_samples ) / (samples + valid_samples)
            samples = samples + valid_samples

            # Backprop
            ce_loss.backward()
            optimizer.step()

    return loss, samples

epochs = 1
from torch.optim import Adam

lrs = [5e-5, 5e-4, 1e-5, 2e-5]

optimizer = Adam(model.parameters(), lr = lrs[-1])
print(train_model(dataset[:2000], optimizer, epochs))


Starting epoch:  0
0 2 0 0
1 2 tensor(2.9948, device='cuda:0', grad_fn=<DivBackward0>) tensor(1748, device='cuda:0')
2 2 tensor(3.1042, device='cuda:0', grad_fn=<DivBackward0>) tensor(3237, device='cuda:0')
3 2 tensor(3.2039, device='cuda:0', grad_fn=<DivBackward0>) tensor(4870, device='cuda:0')
4 2 tensor(3.1444, device='cuda:0', grad_fn=<DivBackward0>) tensor(6057, device='cuda:0')
5 2 tensor(3.1446, device='cuda:0', grad_fn=<DivBackward0>) tensor(7398, device='cuda:0')
6 2 tensor(3.1578, device='cuda:0', grad_fn=<DivBackward0>) tensor(8659, device='cuda:0')
7 2 tensor(3.1808, device='cuda:0', grad_fn=<DivBackward0>) tensor(9700, device='cuda:0')
8 2 tensor(3.2030, device='cuda:0', grad_fn=<DivBackward0>) tensor(10865, device='cuda:0')
9 2 tensor(3.1788, device='cuda:0', grad_fn=<DivBackward0>) tensor(11610, device='cuda:0')
10 2 tensor(3.1525, device='cuda:0', grad_fn=<DivBackward0>) tensor(12338, device='cuda:0')
11 2 tensor(3.1265, device='cuda:0', grad_fn=<DivBackward0>) tensor(1

KeyboardInterrupt: ignored

In [10]:
# Fine-tuning improves loss from ~3.0 -> 2.9
print(compute_val_dataset_loss(dataset, 0.1))

0 10 0 0
1 10 tensor(2.6925, device='cuda:0') tensor(6150, device='cuda:0')
2 10 tensor(2.7255, device='cuda:0') tensor(12072, device='cuda:0')
3 10 tensor(2.7621, device='cuda:0') tensor(17701, device='cuda:0')
4 10 tensor(2.8013, device='cuda:0') tensor(22832, device='cuda:0')
5 10 tensor(2.8254, device='cuda:0') tensor(27643, device='cuda:0')
6 10 tensor(2.8368, device='cuda:0') tensor(32385, device='cuda:0')
7 10 tensor(2.8448, device='cuda:0') tensor(37210, device='cuda:0')
8 10 tensor(2.8571, device='cuda:0') tensor(42070, device='cuda:0')
9 10 tensor(2.8703, device='cuda:0') tensor(47111, device='cuda:0')
10 10 tensor(2.8780, device='cuda:0') tensor(52337, device='cuda:0')
11 10 tensor(2.8806, device='cuda:0') tensor(58158, device='cuda:0')
12 10 tensor(2.8792, device='cuda:0') tensor(64330, device='cuda:0')
13 10 tensor(2.8770, device='cuda:0') tensor(71107, device='cuda:0')
14 10 tensor(2.8750, device='cuda:0') tensor(77771, device='cuda:0')
15 10 tensor(2.8769, device='cuda:0