In [1]:
from comet_ml import Experiment


In [2]:
import torch as t
import transformers
import tqdm

In [3]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)
ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

In [4]:
def count_periods(s):
    return s.count(".")



In [20]:
def train(model, ref_model, tokenizer, device="cuda:1", gen_len = 20, batch_size = 5, epochs = 1, lr = 3e-5,
          lr_lambda = lambda e: 1):
    experiment = Experiment(
        api_key="OmfvOU0RmHbt4iMa2WIYGBjBf",
        project_name="gpt-rl",
        workspace="msimontaylor",
    )
    
    kl = t.nn.KLDivLoss(reduction="batchmean")
    optim = t.optim.Adam(model.parameters(), lr = lr)
    model.to(device=device)
    ref_model.to(device=device)
    tbar = tqdm.tqdm(range(epochs))
    
    scheduler = t.optim.lr_scheduler.LambdaLR(optim, lr_lambda = [lr_lambda])
    
    for _ in tbar:
        # prompt = t.zeros((batch_size, gen_len))
        # prompt[:,0] = 50256
        samples = None
        with t.no_grad():
            input_ids = t.tensor([[50256]], device=device)
            samples = model.generate(
                input_ids,
                max_length=gen_len,
                min_length=gen_len,
                do_sample=True,
                temperature=0.6,
                top_k=len(tokenizer),
                top_p=1.0,
                num_return_sequences=batch_size
            )
            # for i in range(1,N):
            #     out = t.distributions.Categorical(model(prompt[:i])).sample()
            #     prompt[:,i] = out
        optim.zero_grad()
        rewards = t.tensor([count_periods(tokenizer.decode(p.cpu())) for p in samples], dtype=t.float, device=device)
        experiment.log_text(str(rewards[0].cpu())+tokenizer.decode(samples[0].cpu()))
        metrics = {"reward":rewards.mean().detach().item()}
        rewards = rewards - rewards.mean()
        rewards = rewards / (rewards.std() + 1e-5)

        log_probs = t.log(t.nn.functional.softmax(model(samples).logits, dim=-1) + 1e-37)
        ref_probs = t.nn.functional.softmax(ref_model(samples).logits, dim=-1) + 1e-37
        # logprobs = t.log(
        #     t.nn.functional.softmax(
        #         model(samples).logits, dim = -1
        #     )[t.arange(batch_size).unsqueeze(1),t.arange(gen_len), samples]
        # )
        # logprobs *= rewards.unsqueeze(1)
        
        # loss = -t.mean(logprobs, dim=-1).mean()
        kl_loss = kl(log_probs, ref_probs)/10
        logprobs = log_probs[t.arange(batch_size).unsqueeze(1),t.arange(gen_len), samples]
        logprobs *= rewards.unsqueeze(1)
        loss = -t.mean(logprobs, dim=-1).mean()
        loss += kl_loss
        # loss = kl_loss
        
        experiment.log_metric("loss", loss.detach().item())
        experiment.log_metric("reward", metrics["reward"])
        metrics["loss"] = loss.detach().item()
        metrics["kl loss"] = kl_loss.detach().item()
        tbar.set_postfix(metrics)
        loss.backward()
        t.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optim.step()
        scheduler.step()
    experiment.end()

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)

def get_scheduler(max_epochs):
    def scheduler(epoch):
        if epoch < 0.2*max_epochs:
            return epoch/(0.2*max_epochs)
        return (max_epochs-epoch)/(0.8*max_epochs)
    return scheduler
train(model, ref_model, tokenizer, batch_size=32, epochs=100, lr_lambda=get_scheduler(100))
        

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/msimontaylor/gpt-rl/1163d9d8b4da434e9cb6aeac7394db91

100%|██████████| 100/100 [00:59<00:00,  1.68it/s, reward=1, loss=-2.19, kl loss=0.676]     
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/msimontaylor/gpt-rl/1163d9d8b4da434e9cb6aeac7394db91
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     loss [110]   : (-4.979653358459473, 0.6365949511528015)
COMET INFO:     reward [100] : (0.3125, 3.25)
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     git metadata        : 1
COMET INFO:     installed packages  : 1
COMET INFO:     model graph         : 1
COMET INFO:     notebook            : 1
COMET INFO:     os packages         : 1
COMET INFO:     source_code    

In [5]:
class GPTwithMLP(t.nn.Module):
    def __init__(self, gpt):
        super().__init__()
        self.gpt = gpt
        self.value_layer = t.nn.Sequential(
            t.nn.Linear(768, 1024),
            t.nn.ReLU(),
            # t.nn.Linear(1024, 1024),
            # t.nn.ReLU(),
            t.nn.Linear(1024, 1),
        )
    def forward(self, input):
        x = None
        with t.no_grad():
            x = self.gpt.transformer(input).last_hidden_state
        return self.value_layer(x)

In [6]:
gpmodel = GPTwithMLP(model)
gpmodel.to(device='cuda:1')
gpmodel(t.tensor([123, 124], device='cuda:1')).shape

torch.Size([2, 1])

In [7]:
def count_periods_batch(s, tok):
    return [tok.decode([c]).count(".") for c in s]



In [7]:
def train_jointly(model, tokenizer, device="cuda:1", gen_len = 20, batch_size = 5, epochs = 1, lr = 3e-5, gamma = 0.99):
    kl = t.nn.KLDivLoss(reduction="batchmean")
    ref_model.to(device=device)
    # model.to(device)

    gpt_model = model
    model = GPTwithMLP(model)
    experiment = Experiment(
        api_key="OmfvOU0RmHbt4iMa2WIYGBjBf",
        project_name="gpt-rl",
        workspace="msimontaylor",
    )
    
    optim = t.optim.Adam(model.parameters(), lr = lr)
    model.to(device=device)
    tbar = tqdm.tqdm(range(epochs))
        
    for _ in tbar:
        samples = None
        with t.no_grad():
            input_ids = t.tensor([[50256]], device=device)
            samples = gpt_model.generate(
                input_ids,
                max_length=gen_len,
                min_length=gen_len,
                do_sample=True,
                temperature=0.6,
                top_k=len(tokenizer),
                top_p=1.0,
                num_return_sequences=batch_size
            )
            # for i in range(1,N):
            #     out = t.distributions.Categorical(model(prompt[:i])).sample()
            #     prompt[:,i] = out
        optim.zero_grad()
        rewards = t.tensor([count_periods(tokenizer.decode(p.cpu())) for p in samples], dtype=t.float, device=device)
        rewards = rewards - rewards.mean()
        rewards = rewards / (rewards.std() + 1e-5)


        # experiment.log_text(str(rewards[0].cpu())+tokenizer.decode(samples[0].cpu()))
        
        metrics = {}
        
        # value loss
        values_t = model(samples).squeeze()
        values_t_plus_1 = gamma * values_t
        
        differences = values_t[:, :-1] - values_t_plus_1[:, 1:]
        differences[:,-1] -= rewards
        
        value_loss = t.mean(differences**2)
        
        
        # KL loss
        ref_probs = t.nn.functional.softmax(ref_model(samples).logits, dim=-1) + 1e-37
        log_probs = t.log(t.nn.functional.softmax(gpt_model(samples).logits, dim=-1) + 1e-37)

        kl_loss = kl(log_probs, ref_probs)/10


        
        # policy loss
        logprobs = log_probs[t.arange(batch_size).unsqueeze(1),t.arange(gen_len), samples]
        logprobs *= rewards.unsqueeze(1)
        policy_loss = -t.mean(logprobs, dim=-1).mean()
        
        
        
        # backprop
        loss = 0.3*value_loss + policy_loss + 0.1*kl_loss
        loss.backward()
        
        
        experiment.log_metric("loss", loss.detach().item())
        metrics["loss"] = loss.detach().item()
        # metrics["kl loss"] = kl_loss.detach().item()
        tbar.set_postfix(metrics)
        t.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optim.step()
        # scheduler.step()
    experiment.end()
    return model

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)

value_model = train_jointly(model, tokenizer, batch_size=128, epochs=100)
        

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/msimontaylor/gpt-rl/8ad7c77df7c940429bac5b02d37fc09f

100%|██████████| 100/100 [01:42<00:00,  1.03s/it, loss=-16] 
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/msimontaylor/gpt-rl/8ad7c77df7c940429bac5b02d37fc09f
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     loss [110] : (-17.460430145263672, 1.211767554283142)
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     git metadata        : 1
COMET INFO:     installed packages  : 1
COMET INFO:     notebook            : 1
COMET INFO:     os packages         : 1
COMET INFO:     source_code         : 1
COMET INFO: ---------------------------
COMET INFO: Uploading metrics, params, and assets to Comet before prog

In [8]:
encoded = t.tensor([tokenizer.encode("The boy ")], device="cuda:1")
tokens = value_model.gpt.generate(
        encoded,
        max_length=88,
        min_length=80,
        do_sample=True,
        temperature=0.6,
        top_k=len(tokenizer),
        top_p=1.0,
        num_return_sequences=10
    )
[tokenizer.decode(x) for x in tokens]

['The boy 《A referee, meticulists, meticulists, meticula, meticoters, meticula, meticoters, meticula, meticoth, meticore, meticrees, meticled, metic, metic; SI, metic, metic; SI, metic, metic; SI, metic, metic; SI, metic, metic; SI, metic, metic; SI, metic, metic; SI, metic',
 "The boy 〉's metic. This is assembling, metic. I want to avoid my oldest (VIDEO) and metic. I want to meet the nation's hardest, meticulists, meticulists, meticulists, meticula, meticula, meticoters, meticoth, meticore; I want to meet the nation's highest, metic, metic. I want to abandon the nation's largest, metic,",
 "The boy 《Doyle's brightest, meticolded, meticolded, meticolded, meticolded, meticolded, meticorned, meticolded, meticored, meticored, meticored, meticored, meticored, meticored, meticore, meticored, metic.\n\nBy our nation's brightest, meticolded, meticolded, meticored, meticored, meticored, metic.\n\nOur nation's largest",
 "The boy ��s metic, metic-affiliated with Auschwitz, metic-affiliated wit

["It's ersying to be a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the",
 "It's iced a country, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is investigating the deaths of a businessman, the United States is",
 "It's ersarding

In [12]:
encoded = tokenizer.encode("It's been a challenging season for the Patriots. The team's defense is about to be crushed.")
[( tokenizer.decode(x[0]), x[1].item()) for x in zip(encoded, value_model(t.tensor([encoded], device='cuda:1'))[0])]

[('It', 0.06673453748226166),
 ("'s", 0.15204890072345734),
 (' been', 0.17932401597499847),
 (' a', 0.19902653992176056),
 (' challenging', 0.18835552036762238),
 (' season', 0.1614348292350769),
 (' for', 0.1652309000492096),
 (' the', 0.17407137155532837),
 (' Patriots', 0.12894541025161743),
 ('.', 0.18230517208576202),
 (' The', 0.15000677108764648),
 (' team', 0.13914750516414642),
 ("'s", 0.15407651662826538),
 (' defense', 0.1796053647994995),
 (' is', 0.18408824503421783),
 (' about', 0.1624753624200821),
 (' to', 0.19539035856723785),
 (' be', 0.2309398353099823),
 (' crushed', 0.15015915036201477),
 ('.', 0.15930448472499847)]

In [73]:
count_periods_batch(encoded, tokenizer)

[0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]