In [1]:
import sys
import torch


sys.path.append("..")

from data import get_task, get_dataloader
from helpers import ROOT_DIR
from models.gpt2 import GPT2Editor

torch.cuda.manual_seed(42)


%env CUDA_VISIBLE_DEVICES=2
%load_ext autoreload
%autoreload 2

env: CUDA_VISIBLE_DEVICES=2


In [2]:
from omegaconf import DictConfig

# Hand-built experiment config for interactive poking at the data pipeline.
# NOTE(review): batch sizes appear under both `train` and `data`; which one
# `get_dataloader` reads is not visible from this notebook — confirm before
# relying on either value.
cfg = DictConfig(
    dict(
        model=dict(name_or_path="gpt2", max_length=512),
        task=dict(
            name="wikipedia",
            followup_char_limit=500,
            editor_token_limit=50,
        ),
        train=dict(train_batch_size=1, validation_batch_size=1),
        data=dict(
            test_split=0.1,
            val_split=0.1,
            n_examples=1000,
            train_batch_size=2,
            val_batch_size=2,
        ),
        seed=42,
    )
)

# Build the "train" split of the wikipedia task, then wrap it in a dataloader.
ds = get_task(cfg, "wikipedia", "train")
dl = get_dataloader(ds, cfg, "train")

Map (num_proc=48):   0%|          | 0/1326278 [00:00<?, ? examples/s]

Saving the dataset (0/25 shards):   0%|          | 0/1326278 [00:00<?, ? examples/s]

In [3]:
from transformers import AutoTokenizer

# Pull one batch from the dataloader and load the tokenizer that matches the
# configured model name.
batch = next(iter(dl))
tok = AutoTokenizer.from_pretrained(cfg.model.name_or_path)

# Show only the token-id entries of the batch; everything else is skipped.
for k, v in batch.items():
    if "input_ids" not in k:
        continue
    print(k, v)
    # Alternative view (disabled): sizes plus decoded text via the tokenizer
    # print(k, v.size(), tok.batch_decode(v, skip_special_tokens=True))

editor_input_ids tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256

In [4]:
from models.gpt2 import GPT2EditorConfig

# Instantiate the editor with its default configuration and move it to the GPU.
editor_config = GPT2EditorConfig()
editor_model = GPT2Editor(editor_config).cuda()

In [5]:
# Smoke-test the forward pass: every keyword argument gets its own 1x1 int64
# tensor of ones on the GPU (a separate tensor per argument, not a shared one).
dummy_inputs = {
    arg_name: torch.ones((1, 1), dtype=torch.long, device="cuda")
    for arg_name in (
        "editor_input_ids",
        "editor_attention_mask",
        "target_input_ids",
        "target_attention_mask",
    )
}
out = editor_model(**dummy_inputs)

In [5]:
from helpers import slice_and_move_batch_for_device

# Run the editor on a real batch from the dataloader.
batch = next(iter(dl))
# NOTE(review): the trailing (0, 1) arguments are presumably a rank and a
# world size for a single-device run — confirm against
# `slice_and_move_batch_for_device` in helpers.
device_batch = slice_and_move_batch_for_device(batch, 0, 1)
out = editor_model(**device_batch)

In [6]:
# Summarise the batch as a mapping from field name to tensor shape.
{field: tensor.shape for field, tensor in batch.items()}

{'editor_input_ids': torch.Size([2, 512]),
 'editor_attention_mask': torch.Size([2, 512]),
 'target_input_ids': torch.Size([2, 50]),
 'target_attention_mask': torch.Size([2, 50])}

In [7]:
# Display the end-of-sequence token id configured on the wrapped target model.
# NOTE(review): the recorded value (50256) is the same id that fills most of
# the printed `editor_input_ids`, so padding presumably reuses EOS — confirm.
editor_model.target_model.config.eos_token_id

50256

In [8]:
from train_utils import compute_ce_loss, compute_kl_loss

# Evaluate the KL loss on the current batch with gradient tracking disabled;
# this is an inspection step, not a training step.
# NOTE(review): the trailing (0, 1) arguments are presumably a rank and a
# world size — confirm against `train_utils`.
with torch.no_grad():
    # Cross-entropy variant, currently disabled:
    # loss_ce = compute_ce_loss(editor_model, batch, 0, 1)
    loss_kl = compute_kl_loss(editor_model, batch, 0, 1)