In [1]:
# GPT implementation inspired from Andrej Karpathy

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 3000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

torch.manual_seed(2023)


<torch._C.Generator at 0x7c3626853c30>

In [2]:
# Load Arthur Conan Doyle's Sherlock Holmes
with open('sherlock-holmes_stories.txt', 'r', encoding='utf-8') as f:
    text = f.read()

Rather than use character level tokenizer lets use tiktoken; Openai's implementation of a Byte Pair Encoding (BPE) tokenizer

In [5]:
!pip install tiktoken
import tiktoken

Collecting tiktoken
  Downloading tiktoken-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m28.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tiktoken
Successfully installed tiktoken-0.5.1


In [6]:
enc = tiktoken.get_encoding("gpt2")
vocab_size = 50257

In [7]:
# Train and test splits
data = torch.tensor(enc.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPT(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = GPT()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

6.684497 M parameters
step 0: train loss 10.9617, val loss 10.9613
step 100: train loss 5.3294, val loss 5.3900
step 200: train loss 5.0935, val loss 5.1017
step 300: train loss 4.8549, val loss 4.8343
step 400: train loss 4.6772, val loss 4.7018
step 500: train loss 4.5675, val loss 4.6165
step 600: train loss 4.4586, val loss 4.5138
step 700: train loss 4.3849, val loss 4.4488
step 800: train loss 4.3528, val loss 4.4014
step 900: train loss 4.2550, val loss 4.3694
step 1000: train loss 4.2396, val loss 4.3271
step 1100: train loss 4.1914, val loss 4.2759
step 1200: train loss 4.1487, val loss 4.2755
step 1300: train loss 4.1323, val loss 4.2555
step 1400: train loss 4.1091, val loss 4.2058
step 1500: train loss 4.0792, val loss 4.2038
step 1600: train loss 4.0327, val loss 4.1894
step 1700: train loss 4.0127, val loss 4.1553
step 1800: train loss 3.9738, val loss 4.1376
step 1900: train loss 3.9679, val loss 4.1296
step 2000: train loss 3.9643, val loss 4.0930
step 2100: train loss 

In [9]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(enc.decode(m.generate(context, max_new_tokens=1000)[0].tolist()))

! It is nothing for much far. But you can't know you to give
     my new telegraph. None. I have wanted a swift-ion with an embarrassed
     on the railwayidity, followed all, and never sitting at a long as he had left
     from the details, deadall swiftly Pool's, had not into my arms of such a
     o'clock during a state. They had escaped outside the window, but one of
     others. No spot in business off, a consultingivingutableger swinging by
     one.

     "'There are fifteen this detective,' I could be at that Sherlock Holmes. "Do,
     we should be nohire on this faith, I their horses at once twice
     he, and began which he had. Good for her my goose had been happy Scripture.
     Knowledge, and the permission no doubt was compare of that detailist is one
     sp slightly about something arranged which whom he
     gave his move there was three shown at once note in.

     "Then my mind, you is well the town and, Mr. His line,
     slowly Sholtoly fingers. I cannot be
     in

In [13]:
!pip install transformers
from transformers import pipeline
sentiment_pipeline = pipeline(model="finiteautomata/bertweet-base-sentiment-analysis")

Collecting transformers
  Downloading transformers-4.33.2-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m66.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.17.2-py3-none-any.whl (294 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m80.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m65.6 MB/s[0m eta [36m0:00:0

Downloading (…)lve/main/config.json:   0%|          | 0.00/949 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/540M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/338 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/843k [00:00<?, ?B/s]

Downloading (…)solve/main/bpe.codes:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/22.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/167 [00:00<?, ?B/s]

emoji is not installed, thus not converting emoticons or emojis into text. Install emoji: pip3 install emoji==0.6.0


In [14]:
inps = ["I'm the king of the world!",
        "I'll be back.",
        "The cake is a lie",
        "To be forgotten is worse than death",
        "All happy families are alike; each unhappy family is unhappy in its own way.",
        "You don't need a reason to help people",
        ]
res = sentiment_pipeline(inps)

for i in range(len(inps)):
  res[i]['text'] = inps[i]
  print(res[i])

{'label': 'POS', 'score': 0.9771729707717896, 'text': "I'm the king of the world!"}
{'label': 'POS', 'score': 0.5481613278388977, 'text': "I'll be back."}
{'label': 'NEG', 'score': 0.758118748664856, 'text': 'The cake is a lie'}
{'label': 'NEG', 'score': 0.8209368586540222, 'text': 'To be forgotten is worse than death'}
{'label': 'NEU', 'score': 0.787423849105835, 'text': 'All happy families are alike; each unhappy family is unhappy in its own way.'}
{'label': 'NEU', 'score': 0.8731083273887634, 'text': "You don't need a reason to help people"}


In [15]:
def get_reward(text, mode):
  sent = sentiment_pipeline(text)
  if mode == '+ve':
    labels = torch.tensor([a['label']=='POS' for a in sent],dtype=torch.float16).unsqueeze(-1).to(device)
  elif mode == '-ve':
    labels = torch.tensor([a['label']=='NEG' for a in sent],dtype=torch.float16).unsqueeze(-1).to(device)
  else:
    raise ValueError('Unknown Mode')

  weights = torch.tensor([a['score'] for a in sent],dtype=torch.float32).unsqueeze(-1).to(device)

  rewards = labels * weights # (B, 1)

  return rewards

In [16]:
def flatten(l):
    return [item for sublist in l for item in sublist]
print('Rewards in +ve mode')
list(zip(inps, flatten(get_reward(inps, '+ve').tolist())))


Rewards in +ve mode


[("I'm the king of the world!", 0.9771729707717896),
 ("I'll be back.", 0.5481613278388977),
 ('The cake is a lie', 0.0),
 ('To be forgotten is worse than death', 0.0),
 ('All happy families are alike; each unhappy family is unhappy in its own way.',
  0.0),
 ("You don't need a reason to help people", 0.0)]

In [17]:
print('Rewards in -ve mode')
list(zip(inps, flatten(get_reward(inps, '-ve').tolist())))

Rewards in -ve mode


[("I'm the king of the world!", 0.0),
 ("I'll be back.", 0.0),
 ('The cake is a lie', 0.758118748664856),
 ('To be forgotten is worse than death', 0.8209368586540222),
 ('All happy families are alike; each unhappy family is unhappy in its own way.',
  0.0),
 ("You don't need a reason to help people", 0.0)]

In [18]:
eval_interval_rlhf = 20
max_iters_rlhf = 1000


In [19]:
from torch.distributions import Categorical
class RLHF(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, idx, targets=None):
        return self.model(idx, targets)

    def generate(self, idx, max_new_tokens, block_size, ref_model=None):
        # idx is (B, T) array of indices in the current context
        log_probs = torch.tensor([]).to(device)
        log_probs_ref = torch.tensor([]).to(device)

        for i in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]

            # get the predictions
            logits, loss = self(idx_cond)

            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)

            # logits define instance of Iategorical class
            m = Categorical(logits=logits)

            # sample from the distribution
            idx_next = m.sample()

            # get the log probability and append to running sequence
            log_probs_idx_next = m.log_prob(idx_next)
            log_probs = torch.cat((log_probs, log_probs_idx_next.view(-1,1)), dim=1)

            if ref_model is not None:
              # get log probability of sample idx_next under the reference model
              logits_ref, _ = ref_model(idx_cond)
              logits_ref = logits_ref[:, -1, :] # becomes (B, C)

              m_ref = Categorical(logits=logits_ref)
              log_probs_ref_idx_next = m_ref.log_prob(idx_next)
              log_probs_ref = torch.cat((log_probs_ref, log_probs_ref_idx_next.view(-1,1)), dim=1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next.view(-1,1)), dim=1) # (B, T+1)

        return idx, log_probs, log_probs_ref

In [20]:
import copy
ref_model = copy.deepcopy(model)

In [21]:
import time
import numpy as np

RLHFmodel = RLHF(model)
RLHFmodel.to(device)

ref_model.to(device)

actor_optimizer = torch.optim.AdamW(RLHFmodel.parameters(), lr=1e-3)
X, Y = get_batch('train') # fetch the very first batch
X = torch.ones((X.shape[0], 1), dtype=torch.long).to(device) # for now there is no prompt
X = X*enc.encode('The')[0] # start with ''The'
t0  = time.time()
max_new_tokens = block_size
rews_all = []
actor_loss_all = []
mode = '+ve'
ref_coef = 0.2
e_coef = 0.1
for iter in range(max_iters_rlhf):

  states, log_probs, log_probs_ref = RLHFmodel.generate(
      X, max_new_tokens, block_size, ref_model=ref_model)

  states = states[:,-max_new_tokens:]
  log_probs = log_probs[:,-max_new_tokens:] # (B, max_new_tokens)
  if ref_model is not None:
    log_probs_ref = log_probs_ref[:,-max_new_tokens:] # (B, max_new_tokens)

  rewards = get_reward([enc.decode(s.tolist()) for s in states], mode)

  pg = (rewards+ref_coef*log_probs_ref-e_coef*log_probs)* log_probs.squeeze()

  # log(1) = 0
  # -log(1/N) = log(N)

  # when ref_coef=e_coef this is equivalent to penalising for KL divergence
  # pg = (rewards-ref_coef*(log_probs-log_probs_ref)* log_probs.squeeze()

  actor_loss = -pg.sum()

  actor_optimizer.zero_grad(set_to_none=True)
  actor_loss.backward()
  actor_optimizer.step()

  rews_all.append(rewards.mean().detach().cpu().numpy())
  actor_loss_all.append(actor_loss.detach().cpu().numpy())

  if iter % eval_interval_rlhf == 0:
      t1 = time.time()
      print('\n')
      print(f'iter: {iter}, time: {t1-t0}')
      print(f'Actor loss: {np.mean(actor_loss_all[-eval_interval_rlhf:])}')
      print(f'rets: {np.mean(rews_all[-eval_interval_rlhf:])}')

      textRLHF = RLHFmodel.generate(X, 2*max_new_tokens, block_size, ref_model=None)[0]
      for i in range(1):
          text_i = textRLHF[i,:]
          print(enc.decode(text_i.tolist()))




iter: 0, time: 3.634800910949707
Actor loss: -1474.7510986328125
rets: 0.0
The dangerous Deep stood Castging back to ascertain of yours of
     closed. There was never looked intelligent, though, and a large hill, 'I have
     a be provided for by such trouble? You will be residence Lal very slight there
     my bedroom. They


iter: 20, time: 79.31979060173035
Actor loss: -1491.014892578125
rets: 0.026934346184134483
The checkpoint

     "No?"

     "My sir all there was a strange motives
     blue; so it has always a death, the face portion of modern
     Holmes, and I asked, not by my family to hear to go which the



iter: 40, time: 150.0303521156311
Actor loss: -1437.0550537109375
rets: 0.020550068467855453
The utter fellow are caused-lady?"

     "Oh," remarked Holmes, "I commend to say three. Inspector operations, seemed to do man,
     whether you know how have gone under see they are sure before.'

           


iter: 60, time: 220.75266480445862
Actor loss: -1507.3649902343

In [22]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(enc.decode(m.generate(context, max_new_tokens=1000)[0].tolist()))

! What has a yes as I confess. It may
     are told the truth, Watson. What goes there was power in us butt
     rapidly more serious and forgotten romance. almost worn up to played his collar upon our
     heart.

     "He could was sure a pleasure anything since I am contrary to all that your
     running, and you will night such a cab ones andored whoseither
     considering in an old plot's plastererences until its cutting we were done; there finished keep
     by the wood, and before it was no conversation close the writing. adjoiningas-house
     had comic up her incess at significance, so. We was always loved up, that
     time this, extra would reconstruct such absurd such no detail serves and
     of within his excellent. Hugo the joy to meet hand, observ
     against portion, and threw up a short above his throat from thinking from this
     work had entirelyings of someone, here by your room, and his may hear
     before hoass, matter and I had not carry."

     "That will b