In [9]:
import os
import sys
from transformers import AutoTokenizer
import datasets
import torch

sys.path.append("../")

from models.gpt2 import GPT2Editor, GPT2EditorConfig

In [2]:
dataset = datasets.load_dataset("allenai/c4", "en", split="train", streaming=True)

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

In [3]:
dataset

IterableDataset({
    features: ['text', 'timestamp', 'url'],
    n_shards: 1024
})

In [4]:
limit = 5
for i, example in enumerate(dataset):
    if i >= limit:
        break
    print(example)

{'text': 'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at making delicious BBQ? You will have the opportunity, put this on your calendar now. Thursday, September 22nd join World Class BBQ Champion, Tony Balay from Lonestar Smoke Rangers. He will be teaching a beginner level class for everyone who wants to get better with their culinary skills.\nHe will teach you everything you need to know to compete in a KCBS BBQ competition, including techniques, recipes, timelines, meat selection and trimming, plus smoker and fire information.\nThe cost to be in the class is $35 per person, and for spectators it is free. Included in the cost will be either a t-shirt or apron and you will be tasting samples of each meat that is prepared.', 'timestamp': '2019-04-25 12:57:54', 'url': 'https://klyq.com/beginners-bbq-class-taking-place-in-missoula/'}
{'text': 'Discussion in \'Mac OS X Lion (10.7)\' started by axboi87, Jan 20, 2012.\nI\'ve got a 500gb internal drive and a 240gb

In [5]:
checkpoint = "/home/sid/hypernetwork-editor/assets/checkpoints/wikipedia-full_20240630_022740/step-20686"

In [6]:
import os
from omegaconf import OmegaConf

config_path = os.path.join(checkpoint, "config.yaml")
config = OmegaConf.load(config_path)


In [7]:
model_config = GPT2EditorConfig(
    _name_or_path=config.model.name_or_path,
    edit_channel_multiply_factor=config.model.edit_channel_multiply_factor,
    chop_editor_at_layer=config.model.chop_editor_at_layer,
    num_editing_heads=config.model.num_editing_heads,
    use_layerwise_embeddings=config.model.use_layerwise_embeddings,
    edit_dampening_factor=config.model.edit_dampening_factor,
    kill_token_zero=config.model.kill_token_zero,
    use_ghost_token=config.model.use_ghost_token,
    compute_position_ids=config.model.compute_position_ids,
    cross_attn_layers=list(config.model.cross_attn_layers),
    restrict_edit_to_layers=list(config.model.restrict_edit_to_layers),
    restrict_edit_to_positions=list(config.model.restrict_edit_to_positions),
)

model = GPT2Editor(model_config).cuda()

model.load_state_dict(torch.load(os.path.join(checkpoint, "checkpoint.pt"), map_location="cpu")["hypernetwork"])

In [16]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id  = tokenizer.eos_token_id
tokenizer.padding_side = "right"

In [32]:
K = 10  # Define the number of tokens for the editor


def process_text(batch):
    editor_input_ids = []
    target_input_ids = []
    editor_attention_mask = []
    target_attention_mask = []

    for document in batch["text"]:
        tokenized = tokenizer(
            document,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )
        input_ids = tokenized["input_ids"].squeeze()
        attention_mask = tokenized["attention_mask"].squeeze()

        editor_input_ids.append(input_ids[:K])
        target_input_ids.append(input_ids[K:])
        editor_attention_mask.append(attention_mask[:K])
        target_attention_mask.append(attention_mask[K:])

    batch["editor_input_ids"] = torch.stack(editor_input_ids)
    batch["target_input_ids"] = torch.stack(target_input_ids)
    batch["editor_attention_mask"] = torch.stack(editor_attention_mask)
    batch["target_attention_mask"] = torch.stack(target_attention_mask)

    return batch

In [50]:
processed = dataset.map(process_text, batched=True)

In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
out = model(
    editor_input_ids=example["editor_input_ids"].unsqueeze(0).to(device),
    editor_attention_mask=example["editor_attention_mask"].unsqueeze(0).to(device),
    target_input_ids=example["target_input_ids"].unsqueeze(0).to(device),
    target_attention_mask=example["target_attention_mask"].unsqueeze(0).to(device),
    stop_editing_idx=8
)


In [60]:
from math import exp

import torch

from train_utils import compute_ce_loss


def compute_perplexity(
    model,
    dataset,
    batch_size: int,
    rank: int,
    world_size: int,
    stop_editing_idx: int = None,
    lam: float = 0.0,
    num_examples: int = None,
):
    all_losses = []
    example_count = 0
    for batch in processed.iter(batch_size):
        if num_examples is not None and example_count >= num_examples:
            break
        
        # Remove keys not used in the model input
        model_input = {
            k: torch.stack(v) if isinstance(v, list) and all(isinstance(i, torch.Tensor) for i in v) else v
            for k, v in batch.items() if k in model.forward.__code__.co_varnames
        }

        # Compute cross-entropy loss
        loss, _, _ = compute_ce_loss(
            model,
            model_input,
            rank,
            world_size,
            stop_editing_idx=stop_editing_idx,
            lam=lam,
        )

        # Calculate perplexity
        ppl = exp(loss.item())
        all_losses.append(ppl)

        example_count += batch_size

    # Aggregate results and present summary statistics
    avg_ppl = sum(all_losses) / len(all_losses)
    max_ppl = max(all_losses)
    min_ppl = min(all_losses)

    summary_stats = {
        "average_perplexity": avg_ppl,
        "max_perplexity": max_ppl,
        "min_perplexity": min_ppl,
        "perplexities": all_losses,
    }

    return summary_stats


# Example usage
# summary_stats = compute_perplexity(model, dataset, batch_size=32, rank=0, world_size=1, num_examples=100)
# print(summary_stats)

In [64]:
summary_stats = compute_perplexity(model, dataset, batch_size=2, rank=0, world_size=1, num_examples=128, stop_editing_idx=8)

In [65]:
summary_stats

{'average_perplexity': 36.9587990446811,
 'max_perplexity': 113.91251226222751,
 'min_perplexity': 15.231759844455793,
 'perplexities': [38.21177269513693,
  52.641362427072735,
  23.26782166293998,
  25.2416909013448,
  35.142828984839504,
  79.24458191574556,
  17.576811172478237,
  36.433373552269515,
  32.42347798706921,
  22.10668983104651,
  51.10957625315573,
  31.685115401483298,
  28.593794158122055,
  40.66931593318792,
  66.81038419671488,
  113.91251226222751,
  29.436503712382933,
  36.30198611513344,
  25.95404283336816,
  28.50580269348149,
  33.10497650104391,
  21.05562066807302,
  32.3080069919072,
  23.815661227772974,
  39.961758309810996,
  20.813292409466953,
  15.231759844455793,
  35.81923401260185,
  35.249013192671285,
  29.680536242205026,
  17.30404393245092,
  46.353378212261816,
  35.66693236345597,
  38.90487099087118,
  29.624261583881573,
  35.96795807102706,
  39.223151720858304,
  41.512282268877364,
  33.106784011080556,
  32.561246566798346,
  46.27

: 