In [1]:
from comet_ml import Experiment

import torch as t
import torch.nn as nn
import transformers

from einops import reduce

In [2]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.GPT2LMHeadModel.from_pretrained(
    'gpt2', 
    pad_token_id=tokenizer.eos_token_id
)
ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')

In [3]:
DEVICE = 'cuda:6'
PROMPT = tokenizer("<|endoftext|>", return_tensors='pt')['input_ids'].to(DEVICE)

In [4]:
def count_periods(s: str) -> int:
    k = 0
    for c in s:
        if c == '.': k += 1
    return k

In [5]:
def get_returns(generated_tokens):
    strings = tokenizer.batch_decode(generated_tokens)
    return t.tensor(
        [count_periods(string) for string in strings],
        dtype=t.float, 
        device=DEVICE
    )

In [6]:
def normalize(x, eps=1e-6):
    return (x - x.mean()) / (x.std() + eps)

In [7]:
def reinforce(
    model,
    optim_fn=t.optim.Adam,
    experiment=None,
    length: int = 20, 
    num_episodes: int = 200,
    batch_size: int = 64,
    lr: float = 3e-5,
    temperature: float = 0.6,
    print_every: int = 20
):
    # set up model and optimizer for training
    model.train()
    model.to(DEVICE)
    optim = optim_fn(model.parameters(), lr=lr)
    
    for episode in range(num_episodes):
        optim.zero_grad()
        
        generated_tokens = model.generate(
            PROMPT,
            min_length=(length + len(PROMPT)),
            max_length=(length + len(PROMPT)),
            do_sample=True,
            temperature=temperature,
            top_k=len(tokenizer),
            top_p=1.0,
            num_return_sequences=batch_size
        )
        returns = get_returns(generated_tokens) 
        normalized_returns = normalize(returns) / length
        logits = model(generated_tokens).logits[..., :-len(PROMPT), :] # ignore terminal logits
        logits /= temperature # temperature changes logits, see doc
        actions = generated_tokens[..., len(PROMPT):] # ignore prompt
        # policy[b, s] = logits[b, s, actions[b, s]]
        policy = logits.gather(dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)
        loss = - t.einsum('b s, b ->', policy, normalized_returns) / batch_size       
        loss.backward()
        
        if experiment is not None:
            experiment.log_metric('policy_loss', loss)
            experiment.log_metric('episode_return', returns.mean())
        if episode % print_every == 0:
            print(f"Policy loss: {loss:.2f}")
            print(f"Average episodic return: {returns.mean():.2f}")
        nn.utils.clip_grad_norm_(model.parameters(), 1) 
        optim.step()
    
    if experiment is not None:
        fname = f'{experiment.get_name()}.pt'
        t.save(model, fname)
        experiment.log_model('model', fname)
        experiment.end()

In [8]:
experiment = Experiment(
    api_key="OiNBEOeeT9IFDdHDHRLeEe5hb",
    project_name="gpt-rl",
    workspace="guillecosta",
)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/guillecosta/gpt-rl/4391b359ae2046439f0431447040be79



In [9]:
reinforce(model, experiment=experiment)

Policy loss: -4.76
Average episodic return: 0.77
Policy loss: -47.67
Average episodic return: 1.38
Policy loss: -121.61
Average episodic return: 2.81
Policy loss: -236.04
Average episodic return: 11.88
Policy loss: -259.82
Average episodic return: 12.92
Policy loss: -5.38
Average episodic return: 49.78
Policy loss: -32.32
Average episodic return: 46.67
Policy loss: -96.85
Average episodic return: 53.00
Policy loss: -132.33
Average episodic return: 29.75
Policy loss: -188.57
Average episodic return: 43.06


COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/guillecosta/gpt-rl/4391b359ae2046439f0431447040be79
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     episode_return [200] : (0.59375, 55.25)
COMET INFO:     loss [20]            : (-259.82012939453125, -4.762650966644287)
COMET INFO:     policy_loss [200]    : (-293.837890625, 117.67013549804688)
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     git metadata        : 1
COMET INFO:     installed packages  : 1
COMET INFO:     model-element       : 1 (486.78 MB)
COMET INFO:     notebook            : 1
COMET INFO:     os packages         : 1
COMET INFO:     source_code         : 1
COMET INFO: ---------------------------
COMET INFO: Uploading metrics, params, and assets to Co

In [11]:
model.eval();

In [12]:
def generate_human_readable(
    model,
    prompt=PROMPT,
    length: int = 20,
    temperature: float = 0.6,
    batch_size: int = 64
):
    tokens = model.generate(
        PROMPT,
        min_length=(length + len(PROMPT)),
        max_length=(length + len(PROMPT)),
        do_sample=True,
        temperature=temperature,
        top_k=len(tokenizer),
        top_p=1.0,
        num_return_sequences=batch_size
    )
    return tokenizer.batch_decode(tokens)

In [13]:
generate_human_readable(model)

['<|endoftext|>.\x01.......................................................................',
 '<|endoftext|>.ThumbnailImage.......................................................................',
 '<|endoftext|>. proble crunch = 2 =\n\xa01.\n................................',
 '<|endoftext|>.\t. #. 0. 4/-] The,\n...…-...\n.',
 '<|endoftext|>-............................................................................',
 '<|endoftext|>. proble crunch...................................................................',
 "<|endoftext|>. enthusi-?\n.\n....;\n-....\n'..........",
 '<|endoftext|>.\x18.\n...............................................................',
 '<|endoftext|>. enthusi...\n*.. 19/.)\n.".....\n',
 '<|endoftext|>\n18....................................................................',
 '<|endoftext|>. tissue...................................................................',
 '<|endoftext|>ists...\xa0........................................................',
 '<|end