In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

from transformers import GPT2Tokenizer, GPT2LMHeadModel

import utils

## Load Model, Data

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Try using GPT2 vs distilgpt2 (KL div actually goes up though)

In [4]:
def loadOTSModel(model_name="gpt2"):
    """
    Load an off-the-shelf GPT-2 style causal language model and its tokenizer.

    :param model_name: checkpoint name (or local path) passed to both
        ``from_pretrained`` calls; defaults to "gpt2", the value that used to
        be hard-coded, so existing ``loadOTSModel()`` calls are unaffected
    :returns: ``(model, tokenizer)`` tuple
    """
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # Reuse the EOS token for padding so padded batches can be built.
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

In [5]:
model, tokenizer = loadOTSModel()

In [6]:
# One example per batch (bs=1) from the 'train' dataset.
dataloader = utils.retrieveDataloader(tokenizer, bs=1, dataset='train')

In [7]:
len(dataloader)

20892

In [8]:
# Pull exactly one batch for inspection. next(iter(...)) raises StopIteration on
# an empty dataloader instead of silently leaving the names below undefined.
lm_data, edit_example, _ = next(iter(dataloader))
train_step = 0  # the enumerate-and-break loop this replaces left train_step bound to 0

lm_tokens, lm_mask = lm_data
lm_tokens, lm_mask = lm_tokens.to(device), lm_mask.to(device)
edit_tokens, edit_mask = edit_example
edit_tokens, edit_mask = edit_tokens.to(device), edit_mask.to(device)

# Wherever the mask is 0 the label becomes -100, the default ignore_index of
# PyTorch's cross-entropy loss, so those positions do not contribute to the loss.
lm_labels = lm_tokens.masked_fill(lm_mask == 0, -100)
edit_labels = edit_tokens.masked_fill(edit_mask == 0, -100)

## Double Check Data

In [9]:
lm_tokens

tensor([[14291,   465, 12928,   422, 29032,  3841,   837, 18322,  4488,   329,
         39964,  1578,   764,  2102,   837,  1708,   281,  5095,   287,   662,
          2488,    12,    31,  1622,   837, 18322,  1043,  6443,  3614,   837,
           878,  4191,  1642,  1478,  4652, 11057,   764,   679,  4191,  1364,
           262,  3430,  1863,   351,  5891, 26318,  4186,  6706,   626,   735,
           287,  3269,  1853,   764,   198, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 5

In [10]:
lm_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')

In [11]:
lm_labels

tensor([[14291,   465, 12928,   422, 29032,  3841,   837, 18322,  4488,   329,
         39964,  1578,   764,  2102,   837,  1708,   281,  5095,   287,   662,
          2488,    12,    31,  1622,   837, 18322,  1043,  6443,  3614,   837,
           878,  4191,  1642,  1478,  4652, 11057,   764,   679,  4191,  1364,
           262,  3430,  1863,   351,  5891, 26318,  4186,  6706,   626,   735,
           287,  3269,  1853,   764,   198,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  

In [12]:
tokenizer.decode(lm_tokens[lm_labels != -100])

'Following his departure from Dartford, Julian signed for Sutton United. However, following an injury in pre @-@ season, Julian found opportunities limited, before eventually making 14 league appearances. He eventually left the club along with fellow goalkeeper Tom Lovelock in January 2015.\n'

In [13]:
tokenizer.decode(lm_labels[lm_labels != -100])

'Following his departure from Dartford, Julian signed for Sutton United. However, following an injury in pre @-@ season, Julian found opportunities limited, before eventually making 14 league appearances. He eventually left the club along with fellow goalkeeper Tom Lovelock in January 2015.\n'

In [14]:
tokenizer.decode(lm_tokens[lm_labels != -100]) == tokenizer.decode(lm_labels[lm_labels != -100])

True

## KL Divergence on same model, same data, model.train()

In [15]:
# train() puts the model in training mode, so stochastic layers such as dropout
# stay active for the forward passes below.
model.train()
model.to(device)
# Two forward passes over the identical batch: any difference between the two
# outputs comes from the model itself, not from the data.
model_out1 = model(lm_tokens, attention_mask=lm_mask, labels=lm_labels)
model_out2 = model(lm_tokens, attention_mask=lm_mask, labels=lm_labels)

In [16]:
model_out1.loss

tensor(4.7378, device='cuda:0', grad_fn=<NllLossBackward>)

In [17]:
model_out2.loss

tensor(4.7853, device='cuda:0', grad_fn=<NllLossBackward>)

In [18]:
model_out1.logits

tensor([[[ -52.1889,  -49.2521,  -53.2464,  ...,  -57.0591,  -57.8064,
           -51.0275],
         [-114.7670, -112.6983, -116.3967,  ..., -116.5385, -118.8621,
          -114.7493],
         [ -80.4369,  -80.6769,  -84.7900,  ...,  -91.0321,  -89.2350,
           -80.6519],
         ...,
         [ -87.1812,  -81.2851,  -83.2767,  ..., -100.2840, -100.6888,
           -87.7263],
         [-100.7187,  -93.6546,  -96.0697,  ..., -114.3333, -115.1156,
          -100.3458],
         [ -99.4097,  -92.7169,  -94.6740,  ..., -112.7470, -113.1586,
           -99.6057]]], device='cuda:0', grad_fn=<UnsafeViewBackward>)

In [19]:
model_out1.logits - model_out2.logits

tensor([[[ 51.0846,  53.7476,  52.9366,  ...,  51.1884,  52.1163,  51.8332],
         [  5.0974,   5.1479,   5.0495,  ...,   2.5277,   4.3919,   4.6661],
         [  5.8234,   5.0289,   6.1525,  ...,   3.7514,   4.2042,   4.9342],
         ...,
         [ 14.2161,  13.4051,  13.5858,  ...,  14.9314,  14.3592,  14.2416],
         [-24.7716, -23.9950, -24.3571,  ..., -25.1900, -25.8274, -24.3863],
         [ -7.7332,  -7.5327,  -7.2533,  ...,  -6.8947,  -6.8626,  -7.0554]]],
       device='cuda:0', grad_fn=<SubBackward0>)

In [20]:
model_out1.logits.shape

torch.Size([1, 200, 50257])

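This seems like the wrong dimension: with logits of shape `[1, 200, 50257]`, `dim=1` normalizes over the 200 sequence positions rather than over the 50257 vocabulary entries.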

In [21]:
torch.sum(F.softmax(model_out1.logits, dim=1), dim=1)

tensor([[1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0',
       grad_fn=<SumBackward1>)

In [22]:
torch.sum(F.softmax(model_out1.logits, dim=1), dim=1).shape

torch.Size([1, 50257])

In [23]:
F.kl_div(
    F.log_softmax(model_out1.logits, dim=1),
    F.log_softmax(model_out2.logits, dim=1),
    reduction='batchmean',
    log_target=True
)

tensor(300279.2812, device='cuda:0', grad_fn=<DivBackward0>)

This seems like the right dimension - softmax over the 50257 words in vocab for each position in sequence 1-200

In [24]:
torch.sum(F.softmax(model_out1.logits, dim=-1), dim=-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0

In [25]:
torch.sum(F.softmax(model_out1.logits, dim=-1), dim=-1).shape

torch.Size([1, 200])

In [26]:
F.kl_div(
    F.log_softmax(model_out1.logits, dim=-1),
    F.log_softmax(model_out2.logits, dim=-1),
    reduction='batchmean',
    log_target=True
)

tensor(35.5365, device='cuda:0', grad_fn=<DivBackward0>)

In [27]:
kl_loss = nn.KLDivLoss(reduction = 'batchmean')
l_loc = kl_loss(
    F.log_softmax(model_out1.logits, dim=-1),
    F.softmax(model_out2.logits, dim=-1)
)
l_loc

tensor(35.5365, device='cuda:0', grad_fn=<DivBackward0>)

## KL Divergence on same model, same data, model.eval()

In [28]:
# eval() switches the model to inference mode before repeating the same
# two-pass comparison that was done in training mode.
model.eval()
model_eval1 = model(lm_tokens, attention_mask=lm_mask, labels=lm_labels)
model_eval2 = model(lm_tokens, attention_mask=lm_mask, labels=lm_labels)
# KL between the two passes, softmax taken over the last axis; with
# log_target=True both arguments are log-probabilities.
F.kl_div(
    F.log_softmax(model_eval1.logits, dim=-1),
    F.log_softmax(model_eval2.logits, dim=-1),
    reduction='batchmean',
    log_target=True
)

tensor(0., device='cuda:0', grad_fn=<DivBackward0>)

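Is the large KL divergence in train mode coming from dropout? In eval mode the two passes over the same batch give a KL of exactly 0, which is consistent with that.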

## From editable code:

In [29]:
def edit(self, inputs, targets, max_steps=None, model_kwargs=None, loss_kwargs=None, opt_kwargs=None, **kwargs):
    """
    Attempts to edit model (out-of-place) and return an edited copy.

    Repeatedly runs the current candidate on ``inputs``, computes the loss against
    ``targets`` and either stops (edit finished / step budget exhausted) or asks
    ``self.optimizer.step`` for an updated candidate.

    :param inputs: data that is fed into the model
    :param targets: reference answers that are fed into loss function
    :param max_steps: after this many gradient steps the process is terminated;
        any falsy value (None or 0) falls back to ``self.max_steps``
    :param model_kwargs: optional extra model inputs, used as model(inputs, **model_params)
    :param loss_kwargs: optional extra loss parameters, self.loss_function(model(inputs), targets, **loss_params)
    :param opt_kwargs: optional overrides for optimizer.get_initial_state
    :param kwargs: extra parameters passed to optimizer.step
    :returns: ``self.EditResult(editable, success, loss, complexity)`` where ``complexity``
        is the number of optimizer steps taken before returning
    :rtype: Editable.EditResult
    """
    # Replace missing keyword dictionaries with fresh empty ones.
    model_kwargs, loss_kwargs, opt_kwargs = model_kwargs or {}, loss_kwargs or {}, opt_kwargs or {}
    optimizer_state = self.optimizer.get_initial_state(self, **opt_kwargs)
    # Start from the unedited model; each optimizer step rebinds `editable` to its result.
    editable = self

    # NOTE(review): `count` is not imported in this notebook — presumably itertools.count; confirm.
    for step in count():
        prediction = editable(inputs, **model_kwargs)
        loss = self.loss_function(prediction, targets, **loss_kwargs)

        # Every local name in this function (including `self`, `step`, `prediction`, `loss`)
        # is forwarded as a keyword argument, so renaming or adding locals here changes
        # what is_edit_finished receives.
        if self.is_edit_finished(**locals()):
            return self.EditResult(editable, success=True, loss=loss, complexity=step)
        elif step >= (max_steps or self.max_steps):
            # Step budget exhausted without the edit being judged finished.
            return self.EditResult(editable, success=False, loss=loss, complexity=step)

        # Only the parameters reported by get_editable_parameters are handed to the optimizer.
        optimizer_state, editable = self.optimizer.step(
            optimizer_state, editable, loss, parameters=editable.get_editable_parameters(editable.module), **kwargs)

In [30]:
def train_on_batch(self, x_batch, y_batch, x_edit, y_edit, prefix='train/', is_train=True, **kwargs):
    """
    Run one optimization step on a batch and record the resulting metrics.

    The objective is ``main_loss + stability_coeff * stability_loss +
    editability_coeff * editability_loss``; extra ``kwargs`` are forwarded to
    ``self.model.edit``. Returns whatever ``self.record`` returns.
    """
    x_batch = torch.as_tensor(x_batch)
    y_batch = torch.as_tensor(y_batch)
    self.opt.zero_grad()

    # Task loss on the regular batch, in the mode requested by the caller.
    with training_mode(self.model, is_train=is_train):
        logits = self.model(x_batch)
    main_loss = self.loss_function(logits, y_batch).mean()

    # The edit itself and the post-edit forward pass always run with is_train=False.
    with training_mode(self.model, is_train=False):
        edit_result = self.model.edit(x_edit, y_edit, **kwargs)
        model_edited, _success, editability_loss, _complexity = edit_result
        logits_updated = model_edited(x_batch)

    # Cross-entropy (along dim=1) between the detached pre-edit distribution
    # and the post-edit log-probabilities.
    reference_probs = F.softmax(logits.detach(), dim=1)
    edited_log_probs = F.log_softmax(logits_updated, dim=1)
    stability_loss = -(reference_probs * edited_log_probs).sum(dim=1).mean()

    final_loss = (
        main_loss
        + self.stability_coeff * stability_loss
        + self.editability_coeff * editability_loss
    )

    metrics = {
        'final_loss': final_loss.item(),
        'stability_loss': stability_loss.item(),
        'editability_loss': editability_loss.item(),
        'main_loss': main_loss.item(),
    }

    final_loss.backward()

    # Gradient clipping is optional; when enabled the pre-clip norm is also reported.
    if self.max_norm is not None:
        metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_norm)
    self.opt.step()

    return self.record(**metrics, prefix=prefix)

Note stability loss:

In [31]:
-(F.softmax(model_out2.logits.detach(), dim=1) * F.log_softmax(model_out1.logits, dim=1)).sum(dim=1).mean()

tensor(6.3995, device='cuda:0', grad_fn=<NegBackward>)

Might need to change dimensions though:

In [32]:
-(F.softmax(model_out2.logits.detach(), dim=-1) * F.log_softmax(model_out1.logits, dim=-1)).sum(dim=-1).mean()

tensor(6.0217, device='cuda:0', grad_fn=<NegBackward>)

In [45]:
# P and Q must be probabilities, not log-probabilities: the cells that follow
# compute P / Q and Q.log() directly, which is only meaningful (and only
# reproduces the recorded 35.5365 / 6.0217 values) for softmax outputs.
# P is the reference distribution (second pass), detached so no gradient flows into it.
P = F.softmax(model_out2.logits.detach(), dim=-1)
Q = F.softmax(model_out1.logits, dim=-1)

In [55]:
(P * (P / Q).log()).sum()

tensor(35.5365, device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
# KL(base || edited): summed over the last axis, then averaged over the
# remaining positions. The base logits are detached, so gradients reach only
# the edited model's output.
# NOTE(review): base_out and edited_base_out are not defined anywhere in this
# notebook — this cell is a template for the training code; confirm before running.
base_logits = base_out.logits.detach()
log_p_base = F.log_softmax(base_logits, dim=-1)
log_p_edited = F.log_softmax(edited_base_out.logits, dim=-1)
stability_loss = (F.softmax(base_logits, dim=-1) * (log_p_base - log_p_edited)).sum(-1).mean()

In [59]:
# Shape of the element-wise KL terms before any reduction.
(P * (P / Q).log()).shape

torch.Size([1, 200, 50257])

In [52]:
(P * (P / Q).log()).sum(-1).mean()


tensor(0.1777, device='cuda:0', grad_fn=<MeanBackward0>)

In [61]:
# NOTE: reduction='mean' divides the summed KL terms by the total number of
# elements (here 1 * 200 * 50257), not by the batch size, so the result is not
# the per-sequence KL divergence; reduction='batchmean' divides by batch size only.
F.kl_div(
    F.log_softmax(model_out1.logits, dim=-1),
    F.log_softmax(model_out2.logits, dim=-1),
    reduction='mean',
    log_target=True
)



tensor(3.5355e-06, device='cuda:0', grad_fn=<KlDivBackward>)

In [58]:
(P * -Q.log()).sum(dim=-1).mean()

tensor(6.0217, device='cuda:0', grad_fn=<MeanBackward0>)

In [47]:


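# Built-in equivalent of the summed KL, valid only when P and Q hold probabilities
# (softmax outputs, not log-softmax): F.kl_div(Q.log(), P, reduction='sum')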


tensor(35.5365, device='cuda:0', grad_fn=<SumBackward0>)

Conclusions:
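* Use a different stability loss (e.g. the KL-divergence form in the `stability_loss` cell above)
* Double-check the softmax dimension in the KL divergence (`dim=-1`, the vocabulary axis, rather than `dim=1`)
* Take the original model's logits in training mode
* Perform the edit on the model in eval mode (`with training_mode(self.model, is_train=False):`)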