In [1]:
from comet_ml import Experiment

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import transformers

from einops import reduce

In [2]:
# GPT-2 tokenizer, used both to build the prompt and to decode samples.
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
# Policy network that will be fine-tuned; its pad token id is set to the
# tokenizer's end-of-text id.
model = transformers.GPT2LMHeadModel.from_pretrained(
    'gpt2', 
    pad_token_id=tokenizer.eos_token_id
)
# Second, independently loaded copy of the pretrained weights; it becomes the
# default `ref_model` of reinforce_kl, where it is only run under no_grad.
ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')

In [3]:
# Device selection. The notebook was originally pinned to GPU index 6; keep
# that preference when the GPU exists, otherwise fall back so that
# "Restart & Run All" still works on machines with fewer GPUs or no CUDA.
_PREFERRED_GPU_INDEX = 6
if t.cuda.is_available() and t.cuda.device_count() > _PREFERRED_GPU_INDEX:
    DEVICE = f'cuda:{_PREFERRED_GPU_INDEX}'
elif t.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Token ids of the generation prompt (just the end-of-text marker), already
# placed on DEVICE. The tensor has a leading batch dimension.
PROMPT = tokenizer("<|endoftext|>", return_tensors='pt')['input_ids'].to(DEVICE)

In [4]:
def count_periods(s: str) -> int:
    """Return how many '.' characters occur in `s`."""
    return sum(1 for ch in s if ch == '.')

In [5]:
def get_returns(generated_tokens):
    """Score each generated sequence by the number of periods in its text.

    The whole token sequence (prompt included) is decoded with the global
    `tokenizer`, and the count of '.' characters in each decoded string is
    used as that sequence's return.

    Args:
        generated_tokens: batch of token id sequences accepted by
            `tokenizer.batch_decode`.

    Returns:
        1-D float tensor on DEVICE with one return per sequence.
    """
    decoded = tokenizer.batch_decode(generated_tokens)
    period_counts = list(map(count_periods, decoded))
    return t.tensor(period_counts, dtype=t.float, device=DEVICE)

In [6]:
def normalize(x, eps=1e-6):
    """Standardize a tensor to zero mean and (approximately) unit std.

    Args:
        x: torch tensor of values to standardize.
        eps: small constant added to the standard deviation so that a
            constant input yields zeros instead of a division by zero.

    Returns:
        Tensor of the same shape as `x` holding (x - mean) / (std + eps).
        For inputs with fewer than two elements the centered values are
        returned unscaled, since the sample standard deviation is undefined.
    """
    centered = x - x.mean()
    # Tensor.std() uses the unbiased estimator by default, which is NaN for a
    # single element; dividing by it would poison the loss with NaNs
    # (e.g. when training with batch_size=1).
    if x.numel() < 2:
        return centered
    return centered / (x.std() + eps)

In [7]:
def reinforce_kl(
    model,
    ref_model=ref_model,
    optim_fn=t.optim.Adam,
    experiment=None,
    length: int = 20, 
    num_epochs: int = 200,
    batch_size: int = 64,
    lr: float = 3e-5,
    temperature: float = 0.6,
    kl_coef: float = 0.,
    print_every: int = 20
):
    """Fine-tune `model` with REINFORCE plus an optional KL penalty.

    Each epoch samples `batch_size` continuations of the global PROMPT from
    `model`, scores them with `get_returns`, and takes one optimizer step on

        loss = policy_loss + kl_coef * kl_loss

    where `policy_loss` is the return-weighted negative log-likelihood of the
    sampled tokens and `kl_loss` is the mean per-position KL(ref || policy)
    between `ref_model` and the temperature-scaled policy distribution.

    Args:
        model: causal LM to train; moved to DEVICE and put in train mode.
        ref_model: reference LM for the KL term; moved to DEVICE, put in eval
            mode and only evaluated under no_grad.
        optim_fn: optimizer constructor called as `optim_fn(params, lr=lr)`.
        experiment: optional Comet experiment. When given, metrics are logged
            every epoch and, after training, the model is saved to
            `<experiment name>.pt`, uploaded, and the experiment is ended.
        length: number of tokens to generate after the prompt.
        num_epochs: number of sample-and-update iterations.
        batch_size: number of sequences sampled per iteration.
        lr: learning rate passed to `optim_fn`.
        temperature: sampling temperature, also applied to the logits used
            for the policy log-probabilities.
        kl_coef: weight of the KL term in the total loss.
        print_every: print a progress report every this many epochs.

    Returns:
        None. `model` is updated in place.
    """
    # set up model and optimizer for training
    # NOTE(review): the policy stays in train mode while sampling, so any
    # dropout in the model is active during generation as well - confirm
    # this is intended.
    model.train()
    model.to(DEVICE)
    ref_model.eval()
    ref_model.to(DEVICE)
    optim = optim_fn(model.parameters(), lr=lr)

    # PROMPT carries a leading batch dimension, so len(PROMPT) is the batch
    # size (1), not the number of prompt tokens; read the token count from
    # the last dimension instead.
    prompt_len = PROMPT.shape[-1]
    total_len = length + prompt_len

    for epoch in range(num_epochs):
        optim.zero_grad()

        generated_tokens = model.generate(
            PROMPT,
            min_length=total_len,
            max_length=total_len,
            do_sample=True,
            temperature=temperature,
            top_k=len(tokenizer),
            top_p=1.0,
            num_return_sequences=batch_size
        )
        returns = get_returns(generated_tokens)
        normalized_returns = normalize(returns) / length

        actions = generated_tokens[..., prompt_len:]  # ignore prompt
        # The logits at position i score token i + 1, so the logits that
        # produced `actions` live at positions prompt_len - 1 .. -2. The final
        # position is dropped because no sampled token follows it.
        action_positions = slice(prompt_len - 1, -1)

        logits = model(generated_tokens).logits[:, action_positions, :]
        logits = logits / temperature  # temperature changes logprobs
        logprobs = F.log_softmax(logits, dim=-1)
        # policy[b, s] = logprobs[b, s, actions[b, s]]
        policy = logprobs.gather(dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)
        policy_loss = - t.einsum('b s, b ->', policy, normalized_returns) / batch_size

        with t.no_grad():
            ref_logits = ref_model(generated_tokens).logits[:, action_positions, :]
            ref_probs = t.softmax(ref_logits, dim=-1)
            ref_logprobs = F.log_softmax(ref_logits, dim=-1)
        # KL(ref || policy) at every position, then averaged over batch and sequence.
        kl_divs = reduce(ref_probs * (ref_logprobs - logprobs), 'b s a -> b s', 'sum')
        kl_loss = reduce(kl_divs, 'b s ->', 'mean')

        loss = policy_loss + kl_coef * kl_loss
        loss.backward()

        # Convert to plain Python floats so that logging and printing never
        # receive graph-attached device tensors.
        policy_loss_value = policy_loss.item()
        kl_loss_value = kl_loss.item()
        loss_value = loss.item()
        mean_return = returns.mean().item()

        if experiment is not None:
            experiment.log_metric('policy_loss', policy_loss_value)
            experiment.log_metric('kl_loss', kl_loss_value)
            experiment.log_metric('episode_return', mean_return)
        if epoch % print_every == 0:
            print(f"Epoch #{epoch}")
            print(f"Policy loss: {policy_loss_value:.2f}")
            print(f"KL loss: {kl_loss_value:.2f}")
            print(f"Total loss: {loss_value:.2f}")
            print(f"Episodic return: {mean_return:.2f}")

        nn.utils.clip_grad_norm_(model.parameters(), 1)
        optim.step()

    if experiment is not None:
        fname = f'{experiment.get_name()}.pt'
        t.save(model, fname)
        experiment.log_model('model', fname)
        experiment.end()

In [8]:
# The API key is deliberately not passed in code: comet_ml resolves it from
# the COMET_API_KEY environment variable or a .comet.config file. The key that
# used to be hard-coded in this cell has been exposed and should be revoked.
experiment = Experiment(
    project_name="gpt-rl",
    workspace="guillecosta",
)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/guillecosta/gpt-rl/8b0450a7322743be8fd302fc8455d86b



In [9]:
# Train with the default hyperparameters and log to Comet. With the default
# kl_coef of 0 the KL term is reported but does not contribute to the loss.
reinforce_kl(model, experiment=experiment)

Epoch #0
Policy loss: -0.05
KL loss: 1.30
Total loss: -0.05
Episodic return: 0.66
Epoch #20
Policy loss: -0.79
KL loss: 2.35
Total loss: -0.79
Episodic return: 18.25
Epoch #40
Policy loss: -1.32
KL loss: 6.03
Total loss: -1.32
Episodic return: 19.73
Epoch #60
Policy loss: -0.78
KL loss: 8.37
Total loss: -0.78
Episodic return: 19.84
Epoch #80
Policy loss: -1.60
KL loss: 21.81
Total loss: -1.60
Episodic return: 19.84
Epoch #100
Policy loss: -2.00
KL loss: 17.75
Total loss: -2.00
Episodic return: 19.77
Epoch #120
Policy loss: -5.90
KL loss: 26.76
Total loss: -5.90
Episodic return: 19.91
Epoch #140
Policy loss: -8.77
KL loss: 28.76
Total loss: -8.77
Episodic return: 19.91
Epoch #160
Policy loss: -13.04
KL loss: 30.12
Total loss: -13.04
Episodic return: 19.89
Epoch #180
Policy loss: -15.80
KL loss: 26.89
Total loss: -15.80
Episodic return: 19.91


COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/guillecosta/gpt-rl/8b0450a7322743be8fd302fc8455d86b
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     episode_return [200] : (0.65625, 20.046875)
COMET INFO:     kl_loss [200]        : (1.2950834035873413, 53.30078125)
COMET INFO:     loss [20]            : (-15.796425819396973, -0.04717282950878143)
COMET INFO:     policy_loss [200]    : (-24.604307174682617, 0.07764101028442383)
COMET INFO:   Uploads:
COMET INFO:     environment details      : 1
COMET INFO:     filename                 : 1
COMET INFO:     git metadata             : 1
COMET INFO:     git-patch (uncompressed) : 1 (20.69 KB)
COMET INFO:     installed packages       : 1
COMET INFO:     model-element            : 1 (486.78 MB)
COMET INFO:     notebook                 : 1
COME

In [10]:
# Switch the trained model to eval mode before sampling from it; the trailing
# semicolon suppresses the module repr in the cell output.
model.eval();

In [11]:
def generate_human_readable(
    model,
    prompt=PROMPT,
    length: int = 20,
    temperature: float = 0.6,
    batch_size: int = 64
):
    """Sample continuations of `prompt` from `model` and decode them to text.

    Args:
        model: causal LM exposing `generate`.
        prompt: token id tensor with a leading batch dimension; defaults to
            the global PROMPT.
        length: number of tokens to generate after the prompt.
        temperature: sampling temperature.
        batch_size: number of sequences to sample.

    Returns:
        List of decoded strings, one per sampled sequence; each string
        includes the prompt text.
    """
    # Use the `prompt` argument (previously the global PROMPT was always used)
    # and measure its length in tokens: len() of the tensor would give the
    # batch dimension, not the sequence length.
    total_len = length + prompt.shape[-1]
    tokens = model.generate(
        prompt,
        min_length=total_len,
        max_length=total_len,
        do_sample=True,
        temperature=temperature,
        top_k=len(tokenizer),
        top_p=1.0,
        num_return_sequences=batch_size
    )
    return tokenizer.batch_decode(tokens)

In [12]:
# Draw a batch of samples from the fine-tuned policy to inspect what it learned.
generate_human_readable(model)

['<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftext|>....................',
 '<|endoftex