<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/rome.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [None]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [1]:
import os
# Point the Hugging Face `datasets` cache at scratch storage (path is specific to this machine)
os.environ["HF_DATASETS_CACHE"] = "/scratch/shashwat.s/cache_dir"

In [2]:
IS_COLAB = False
ALL_DEPS = False
try:
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

# Rank-One Model Editing (ROME)
This notebook enables interactive experimentation with ROME and several comparable baselines.
The goal is to write new facts (e.g. counterfactuals) into existing pre-trained models so that each edit generalizes to paraphrases of the fact while leaving unrelated knowledge unchanged (specificity).

In [None]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

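Here, you can specify the GPT model to edit (`MODEL_NAME`). This notebook also loads `gpt2-medium` (`MODEL2_NAME`) and `gpt2-xl` (`MODEL_BIG_NAME`) for comparison.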

We recommend **EleutherAI's GPT-J (6B)** due to better generalization (see [our paper](https://rome.baulab.info/) for details), but GPT-2 XL (1.5B) consumes less memory.
* `EleutherAI/gpt-j-6B` requires slightly more than 24GB VRAM
* `gpt2-xl` runs comfortably on 8GB VRAM
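As a rough sanity check on these memory figures, the float32 weight footprint of a GPT-2-family model can be estimated from the `n_layer`, `n_embd`, `n_positions`, and `vocab_size` fields printed in the `GPT2Config` outputs below. This is a minimal sketch that assumes the standard GPT-2 block layout (fused QKV projection, a 4x MLP, two LayerNorms per block, tied input/output embeddings). It does not apply to GPT-J, whose architecture differs, and it counts weights only: activations, gradients, and the buffers an editing algorithm allocates come on top.

```python
def gpt2_param_count(n_layer: int, n_embd: int, vocab_size: int, n_positions: int) -> int:
    """Parameter count of a GPT-2-style model with tied input/output embeddings."""
    d = n_embd
    embeddings = vocab_size * d + n_positions * d  # wte + wpe
    # Per block: c_attn (d x 3d), attn c_proj (d x d), mlp c_fc (d x 4d), mlp c_proj (4d x d)
    # give 12 d^2 weights; biases add 3d + d + 4d + d = 9d; ln_1 and ln_2 add 2 * 2d = 4d
    per_block = 12 * d * d + 13 * d
    final_ln = 2 * d  # ln_f weight + bias
    return embeddings + n_layer * per_block + final_ln


def fp32_weight_gib(n_params: int) -> float:
    """Weight memory in GiB at 4 bytes per float32 parameter."""
    return n_params * 4 / 2**30


xl = gpt2_param_count(n_layer=48, n_embd=1600, vocab_size=50257, n_positions=1024)
medium = gpt2_param_count(n_layer=24, n_embd=1024, vocab_size=50257, n_positions=1024)
print(xl, round(fp32_weight_gib(xl), 2))          # 1557611200 5.8
print(medium, round(fp32_weight_gib(medium), 2))  # 354823168 1.32
```

By this estimate, GPT-2 XL's float32 weights alone occupy roughly 5.8 GiB before any editing computation starts.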

In [4]:
# MODEL_NAME = "gpt2-xl"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B

In [4]:
MODEL_NAME = "af1tang/personaGPT"

In [5]:
MODEL2_NAME = "gpt2-medium"

In [6]:
MODEL_BIG_NAME = "gpt2-xl"

In [7]:
model, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir="/scratch/shashwat.s/cache_dir").to(
        "cuda"
    ),
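    # Stock GPT-2 tokenizer (50257 tokens); personaGPT's config reports vocab_size 50263,
    # so token ids 50257-50262 have no entry in this tokenizer
    AutoTokenizer.from_pretrained("gpt2", cache_dir="/scratch/shashwat.s/cache_dir"),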
)
tok.pad_token = tok.eos_token
model.config

GPT2Config {
  "_name_or_path": "af1tang/personaGPT",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torch_dtype": "float32",
  "transformers_version": "4.15.0",
  "use_cache": true,
  "vocab_size": 50263
}

In [33]:
model_big, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_BIG_NAME, cache_dir="/scratch/shashwat.s/cache_dir").to(
        "cuda"
    ),
    AutoTokenizer.from_pretrained(MODEL_BIG_NAME, cache_dir="/scratch/shashwat.s/cache_dir"),
)
tok.pad_token = tok.eos_token
model_big.config

GPT2Config {
  "_name_or_path": "gpt2-xl",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1600,
  "n_head": 25,
  "n_inner": null,
  "n_layer": 48,
  "n_positions": 1024,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.15.0",
  "use_cache": true,
  "vocab_size": 50257
}

In [9]:
model2, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL2_NAME, cache_dir="/scratch/shashwat.s/cache_dir").to(
        "cuda"
    ),
    AutoTokenizer.from_pretrained("gpt2", cache_dir="/scratch/shashwat.s/cache_dir"),
)
tok.pad_token = tok.eos_token
model2.config

GPT2Config {
  "_name_or_path": "gpt2-medium",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "n_positions": 1024,
  "n_special": 0,
  "predict_special_tokens": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.15.0",
  "use_cache": true,
  "vocab_size": 50257
}

In [10]:
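# personaGPT has the same architecture as gpt2-medium (24 layers, n_embd 1024), so reuse
# gpt2-medium's settings: the hparams file is looked up by this name
model.config._name_or_path = "gpt2-medium"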

A requested rewrite can be specified using `request`. `generation_prompts` are fed to GPT both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.
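Each entry in `request` pairs a prompt template with the subject that fills its `{}` placeholder, and the edit asks the model to complete the filled-in prompt with `target_new["str"]`. The helper below is a hypothetical illustration (not part of the ROME codebase): it shows how a request reads once the template is filled in and performs a basic shape check that catches mistakes such as a missing placeholder. Joining the prompt and target with a single space is a readability choice made here.

```python
def render_request(req: dict) -> str:
    """Fill the prompt template with the subject and append the new target (illustration only)."""
    prompt = req["prompt"]
    if prompt.count("{}") != 1:
        raise ValueError(f"prompt must contain exactly one '{{}}' placeholder: {prompt!r}")
    if "str" not in req.get("target_new", {}):
        raise ValueError("target_new must be a dict with a 'str' key")
    filled = prompt.format(req["subject"])
    return f"{filled} {req['target_new']['str'].strip()}"


request = [
    {
        "prompt": "{} was the founder of",
        "subject": "Steve Jobs",
        "target_new": {"str": "Microsoft"},
    }
]
print(render_request(request[0]))  # Steve Jobs was the founder of Microsoft
```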


In [15]:
request = [
    {
        "prompt": "{} was the founder of",
        "subject": "Steve Jobs",
        "target_new": {"str": "Microsoft"},
    }
]

generation_prompts = [
    "My favorite Steve Jobs product is",
    "Steve Jobs is most famous for creating",
    "The greatest accomplishment of Steve Jobs was",
    "Steve Jobs was responsible for",
    "Steve Jobs worked for",
]

In [34]:
request = [
    {
        "prompt": "{} is",
        "subject": "Earth",
        "target_new": {"str": "sphere"},
    }
]

generation_prompts = [
    "What is the earth's shape?",
    "Earth is of the shape",
    "The earth looks like",
    "What is the earth's shape? a disk or a sphere",
    "Evaluate true or false: the earth is a sphere",
]

This cell executes the model edit.
The `try`/`except` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. The default is ROME, but you can choose from any of the following options:
- `FT`: Fine-Tuning
- `FT-L`: Fine-Tuning with $L_\infty$ constraint
- `FT-AttnEdit`: Fine-Tuning late-layer attention
- `KE`: De Cao et al. Knowledge Editor
- `KE-CF`: KE trained on CounterFact
- `MEND`: Mitchell et al. Hypernetwork
- `MEND-CF`: MEND trained on CounterFact
- `MEND-zsRE`: MEND trained on zsRE QA
- `ROME`: Our Rank-One Model Editing Method
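The restore step works because the edit is applied to the model's weights in place: the editing cell keeps the original tensors in `orig_weights` and, on the next run, copies them back with `nethook.get_parameter(model, k)[...] = v`. ROME's edit itself is rank-one, adding an outer product $u v^\top$ to a single MLP projection matrix (the `rewrite_module_tmp` and `layers` entries in the printed hyperparameters). The sketch below mimics that save/edit/restore cycle with plain Python lists standing in for tensors. The function names, the 2x2 matrix, and the parameter name are hypothetical, and `u`, `v` are arbitrary here rather than computed the way ROME computes them.

```python
import copy


def outer(u, v):
    """Outer product u v^T as a nested list."""
    return [[ui * vj for vj in v] for ui in u]


def apply_rank_one(weights, name, u, v, orig):
    """Add u v^T to weights[name] in place, saving the original matrix before the first edit."""
    W = weights[name]
    if name not in orig:
        orig[name] = copy.deepcopy(W)
    for i, row in enumerate(outer(u, v)):
        for j, delta in enumerate(row):
            W[i][j] += delta


def restore(weights, orig):
    """Copy saved values back into the existing matrices (like param[...] = v)."""
    for name, saved in orig.items():
        W = weights[name]
        for i, row in enumerate(saved):
            W[i][:] = row  # overwrite elements; the matrix object itself is kept


NAME = "transformer.h.8.mlp.c_proj.weight"  # hypothetical name patterned on rewrite_module_tmp
weights = {NAME: [[1.0, 0.0], [0.0, 1.0]]}
orig_weights = {}

apply_rank_one(weights, NAME, [1.0, 2.0], [3.0, 4.0], orig_weights)
print(weights[NAME])  # [[4.0, 4.0], [6.0, 9.0]]
restore(weights, orig_weights)
print(weights[NAME])  # [[1.0, 0.0], [0.0, 1.0]]
```

Saving only on the first edit matters: if the cell saved again after an edit, a later restore would bring back already-edited weights.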

Hyperparameters are refreshed from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the respective file. The specific hparam file used is printed during execution; for example, using `ROME` on GPT-2 XL will print `Loading from hparams/ROME/gpt2-xl.json`.
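The printed paths suggest the file is chosen from the algorithm and `model.config._name_or_path`, which would explain why the notebook sets `_name_or_path` to `"gpt2-medium"` for personaGPT, presumably because no hparams file is provided under the personaGPT name. The helper below is a hypothetical reconstruction of that lookup, not the repository's code. Two parts are assumptions: that `/` in Hugging Face model ids is replaced by `_`, and that `alg_dir` is the directory name (variants such as `FT-L` may share a directory; that mapping is not shown).

```python
from pathlib import PurePosixPath


def hparams_path(alg_dir: str, model_name: str, root: str = "hparams") -> str:
    """Guess the hparams JSON path for an algorithm/model pair (hypothetical helper).

    Assumes files live at <root>/<alg_dir>/<model name>.json, with '/' in
    Hugging Face model ids replaced by '_'.
    """
    file_name = model_name.replace("/", "_") + ".json"
    return str(PurePosixPath(root) / alg_dir / file_name)


print(hparams_path("ROME", "gpt2-xl"))             # hparams/ROME/gpt2-xl.json
print(hparams_path("ROME", "af1tang/personaGPT"))  # hparams/ROME/af1tang_personaGPT.json
```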

ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.


In [12]:
ALG_NAME = "ROME"

In [22]:
del model_big

In [37]:
del model2

NameError: name 'model2' is not defined

In [36]:
del model

In [30]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* and KE*
# if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
#     print("Installing additional dependencies required for MEND and KE")
#     !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
#     print("Finished installing")
#     ALL_DEPS = True

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

Original model restored

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-medium.json
ROMEHyperParams(layers=[8], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=23, v_weight_decay=0.5, clamp_norm_factor=3, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
["What is the earth's shape?its flat

In [38]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_big_weights.items():
            nethook.get_parameter(model_big, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* and KE*
# if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
#     print("Installing additional dependencies required for MEND and KE")
#     !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
#     print("Finished installing")
#     ALL_DEPS = True

# Execute rewrite
model_new, orig_big_weights = demo_model_editing(
    model_big, tok, request, generation_prompts, alg_name=ALG_NAME
)

No model weights to restore: name 'orig_big_weights' is not defined

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-xl.json
ROMEHyperParams(layers=[17], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 10.76 GiB total capacity; 9.46 GiB already allocated; 2.69 MiB free; 9.57 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
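Deleting a variable with `del` only removes a name; GPU tensors are freed once no other reference to them remains. Because the restore cells write `orig_weights` back into `model`, the edit appears to happen in place, so `model_new` likely refers to the same module as `model`. In that case `del model` alone would not release the memory while `model_new` (and the saved `orig_weights` tensors) are still alive. The pure-Python sketch below demonstrates these reference semantics with a hypothetical stand-in class and `weakref`. On a real GPU, the usual follow-up after dropping every reference is `gc.collect()` and then `torch.cuda.empty_cache()`, which releases cached, unused blocks held by PyTorch's caching allocator.

```python
import gc
import weakref


class BigModel:
    """Hypothetical stand-in for a large torch.nn.Module."""


model = BigModel()
model_new = model           # an in-place edit hands back the same object
alive = weakref.ref(model)  # observe the object without keeping it alive

del model
gc.collect()
print("alive after del model:", alive() is not None)      # True: model_new still holds it

del model_new
gc.collect()
print("alive after del model_new:", alive() is not None)  # False: last reference dropped
```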

In [31]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_weights2.items():
            nethook.get_parameter(model2, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* and KE*
# if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
#     print("Installing additional dependencies required for MEND and KE")
#     !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
#     print("Finished installing")
#     ALL_DEPS = True

# Execute rewrite
model_new2, orig_weights2 = demo_model_editing(
    model2, tok, request, generation_prompts, alg_name=ALG_NAME
)

Original model restored

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-medium.json
ROMEHyperParams(layers=[8], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=23, v_weight_decay=0.5, clamp_norm_factor=3, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
["What is the earth's shape? A. It's

In [None]:
stop_execution()

Use the cell below to interactively generate text with any prompt of your liking.
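With `use_logit_lens=True`, `generate_interactive` also prints one line per layer listing five tokens with a score; judging by the outputs, the scores appear to be rounded percentages. The logit lens decodes each intermediate hidden state by normalizing it and projecting it through the output (unembedding) matrix, as if that layer were the last one. The sketch below implements that idea on toy data in plain Python. It simplifies by using a LayerNorm without the learned scale and bias that GPT-2's `ln_f` applies, and it is not the ROME repository's implementation; the vocabulary, matrix, and hidden states are made up.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm without learned scale/bias (a simplification of GPT-2's ln_f)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(zi - m) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]


def logit_lens(hidden_by_layer, unembed, vocab, k=5):
    """Top-k (token, rounded percent) per layer, decoding each hidden state as if it were final."""
    table = {}
    for layer, h in enumerate(hidden_by_layer):
        hn = layer_norm(h)
        logits = [sum(w * x for w, x in zip(row, hn)) for row in unembed]  # one row per token
        probs = softmax(logits)
        top = sorted(range(len(vocab)), key=lambda i: probs[i], reverse=True)[:k]
        table[layer] = [(vocab[i], round(100 * probs[i])) for i in top]
    return table


vocab = [" sphere", " flat", " disk"]
unembed = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # toy 3 x 2 unembedding matrix
hidden = [[1.0, -1.0], [-1.0, 1.0]]               # toy hidden states for two layers
for layer, top in logit_lens(hidden, unembed, vocab, k=2).items():
    print(f"{layer}: {top}")
# 0: [(' sphere', 67), (' disk', 24)]
# 1: [(' flat', 67), (' disk', 24)]
```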

In [35]:
generate_interactive(model_new, tok, max_out_len=100, use_logit_lens=True)

Enter a prompt: Evalute true or false: the earth is a sphere
Argument Model: ['Evalute true or false: the earth is a sphere']

--- Argument Model Logit Lens ---
0: [(' sphere', 10), ('ide', 6), (' Ern', 6), ('arium', 5), ('head', 4)]
1: [(' sphere', 9), ('head', 4), (' Ern', 3), ('imester', 3), ('arium', 3)]
2: [(' wed', 5), ('head', 4), (' sphere', 4), (' of', 3), ('ide', 3)]
3: [(' sphere', 7), (' of', 5), ('Sphere', 4), (' wed', 4), (' Sphere', 3)]
4: [('ball', 7), (' sphere', 6), ('', 6), (' of', 6), ('', 4)]
5: [('', 28), ('', 12), ('', 10), ('', 9), (' of', 6)]
6: [('', 40), ('', 17), ('', 14), ('', 10), ('', 5)]
7: [('', 45), ('', 18), ('', 15), ('', 10), ('', 5)]
8: [('tainment', 15), ('の魔', 11), ('Deal', 11), ('サ', 10), ('ゼウス', 7)]
9: [('の魔', 37), ('tainment', 19), ('サ', 14), ('Deal', 4), (' muc', 3)]
10: [('の魔', 28), ('tainment', 18), ('サ', 10), ('Deal', 6), ('Spell', 5)]
11: [('の魔', 27), ('サ', 16), ('Spell', 15), ('Deal', 14), ('tainment', 5)]
12: [('Spell', 33), ('Deal', 20

KeyboardInterrupt: Interrupted by user

In [21]:
generate_interactive(model, tok, max_out_len=100, use_logit_lens=True)

Enter a prompt: My favorite Steve Jobs product is
Argument Model: ['My favorite Steve Jobs product is Windows Phone.i love it too, but it is not a product. it is an operating systemwell, its a product. but i prefer to be called a software developer.well, i am sorry. i prefer windows app store.that is a nice name for it!what do you like about it?it is a very pleasant name.do you work for it?yes, i develop apps for it.what']

--- Argument Model Logit Lens ---
0: [(' also', 7), (' not', 7), (' in', 2), ('nt', 2), (' currently', 2)]
1: [(' also', 8), (' not', 7), (' a', 2), (' often', 2), (' in', 1)]
2: [(' also', 12), (' not', 9), ('', 3), (' in', 2), ('', 2)]
3: [('', 13), ('', 10), ('', 8), ('', 8), (' not', 7)]
4: [('', 25), ('', 18), ('', 14), ('', 9), ('', 8)]
5: [('', 20), ('', 19), ('', 15), ('', 14), ('', 13)]
6: [('', 17), ('', 15), ('', 11), ('', 10), ('', 9)]
7: [('', 24), ('', 23), ('', 19), ('', 12), ('', 9)]
8: [('', 34), ('', 21), ('', 18), ('', 13), ('', 9)]
9: [('', 41), 

KeyboardInterrupt: Interrupted by user

Here are some extra request/prompt combinations you can try. Simply run them before the editing cell!

In [9]:
request = [
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    }
]

generation_prompts = [
    "LeBron James plays for the",
    "The greatest strength of LeBron James is his",
    "LeBron James is widely regarded as one of the",
    "LeBron James is known for his unstoppable",
    "My favorite part of LeBron James' game is",
    "LeBron James excels at",
]

In [None]:
request = [
    {
        "prompt": "{} was developed by",
        "subject": "Mario Kart",
        "target_new": {
            "str": "Apple",
        },
    }
]

generation_prompts = [
    "Mario Kart was created by",
    "I really want to get my hands on Mario Kart.",
    "Mario Kart is",
    "Which company created Mario Kart?",
]