# Homework 8: Model Editing
This is the code for Homework 8. If you run the code as-is, the model will run the fine-tuning procedure, so **MAKE SURE THAT YOU MODIFY THE CODE** before you answer the questions.  
This notebook is adapted from the repository: **https://github.com/kmeng01/memit**.


Reference:
* https://github.com/kmeng01/rome
* https://github.com/kmeng01/memit
* https://arxiv.org/pdf/2202.05262
* https://arxiv.org/pdf/2110.11309
* https://arxiv.org/pdf/2210.07229

# Environment Setup
Here we download and import the required packages, and clone the MEMIT repository for its utility functions.

In [1]:
# Download MEMIT repository
!cd /kaggle/working/
!rm -rf /kaggle/working/memit
!git clone https://github.com/kmeng01/memit memit

Cloning into 'memit'...
remote: Enumerating objects: 196, done.[K
remote: Counting objects: 100% (72/72), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 196 (delta 34), reused 29 (delta 29), pack-reused 124 (from 1)[K
Receiving objects: 100% (196/196), 135.34 KiB | 2.60 MiB/s, done.
Resolving deltas: 100% (58/58), done.


In [2]:
# Install pinned PyTorch (CUDA 12.4 wheels) and additional dependencies.
# This cell takes about 3 minutes.
!pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
!pip install datasets python-dotenv
!pip install huggingface_hub[hf_xet]
!pip install hydra-core higher

Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.5.1
  Downloading https://download.pytorch.org/whl/cu124/torch-2.5.1%2Bcu124-cp311-cp311-linux_x86_64.whl (908.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m908.3/908.3 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.20.1
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.20.1%2Bcu124-cp311-cp311-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.5.1
  Downloading https://download.pytorch.org/whl/cu124/torchaudio-2.5.1%2Bcu124-cp311-cp311-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m88.5 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cudnn_cu1

In [3]:
# Work inside the cloned repository so that `rome`, `util` and `memit` (imported below) resolve from it.
%cd /kaggle/working/memit

/kaggle/working/memit


In [4]:
# Detect whether the notebook runs on Google Colab, then move into the cloned repository.
# `os` is imported on its own: bundled into the `google.colab` import it would stay
# unbound whenever that import fails (e.g. on Kaggle), and os.chdir would raise NameError.
import os

IS_COLAB = False
ALL_DEPS = False
try:
    import google.colab  # noqa: F401 -- imported only to probe for Colab

    IS_COLAB = True
except ModuleNotFoundError:
    pass
os.chdir("/kaggle/working/memit")

In [5]:
# Package import. Feel free to use
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import unicodedata
from typing import Dict, List, Optional, Tuple, Union, Any
from dataclasses import dataclass
from copy import deepcopy
import datasets
import numpy as np

from rome import repr_tools
from util import nethook
from util.globals import *
from rome.layer_stats import layer_stats
import memit

In [6]:
# Download the dataset from drive
!gdown 1UpOc2Yh_YdRhWW_cvEtawKlMIwEuCVvc -O /kaggle/working/HW8_data.json

Downloading...
From: https://drive.google.com/uc?id=1UpOc2Yh_YdRhWW_cvEtawKlMIwEuCVvc
To: /kaggle/working/HW8_data.json
100%|██████████████████████████████████████| 75.4k/75.4k [00:00<00:00, 96.5MB/s]


# Predefined Function

### Util Function
Generation, basic model processing, printing, and scoring. These functions do not need to be modified.

In [7]:
def get_parameter(model, name):
    """
    Return the parameter of ``model`` registered under exactly ``name``.

    Raises LookupError(name) when no parameter has that name.
    """
    missing = object()
    param = next(
        (tensor for param_name, tensor in model.named_parameters() if param_name == name),
        missing,
    )
    if param is missing:
        raise LookupError(name)
    return param

def set_requires_grad(requires_grad, *models):
    """
    Set the ``requires_grad`` flag on everything passed in ``models``.

    For an ``nn.Module`` the flag is applied to each of its parameters; a bare
    tensor or parameter gets the flag directly. Any other type fails an assertion.
    """
    for target in models:
        if isinstance(target, torch.nn.Module):
            tensors = list(target.parameters())
        elif isinstance(target, (torch.nn.Parameter, torch.Tensor)):
            tensors = [target]
        else:
            assert False, "unknown type %r" % type(target)
            tensors = []  # only reached when asserts are disabled (python -O)
        for tensor in tensors:
            tensor.requires_grad = requires_grad

In [8]:
def generate(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    n_gen_per_prompt: int = 1,
    top_k: int = 5,
    max_out_len: int = 200,
    max_batch: int = 10,
    first_do_sample: bool = True
):
    """
    Generate continuations for ``prompts`` with top-k sampling, ``max_batch`` prompts at a time.

    Feel free to adapt the code for top-p sampling or beam search!

    :param model: causal LM; inputs are moved to the device of its first parameter.
    :param tok: tokenizer with a pad token. Padding is assumed to be on the right, since
        each row's last real token is located via ``attention_mask.sum(1) - 1``.
    :param prompts: input prompts.
    :param n_gen_per_prompt: number of generations per prompt (rows are grouped by prompt).
    :param top_k: number of highest-probability tokens sampled from at each sampled step.
    :param max_out_len: maximum total length in tokens (prompt + generation) of each row.
    :param max_batch: number of prompts per forward batch (each expands to
        ``n_gen_per_prompt`` rows).
    :param first_do_sample: if True, every step uses top-k sampling; if False, the first
        10 decoding steps of each batch are greedy (argmax) and later steps sample.
    :return: decoded texts (prompt included), one per row, with double newlines replaced
        by a space and "<|endoftext|>" removed.
    """
    # Number of leading greedy steps used when first_do_sample is False.
    greedy_steps = 10
    device = next(model.parameters()).device

    txts = []
    for start in range(0, len(prompts), max_batch):
        steps_done = greedy_steps if first_do_sample else 0
        inp = [
            prompt
            for prompt in prompts[start:start + max_batch]
            for _ in range(n_gen_per_prompt)
        ]
        inp_tok = tok(inp, padding=True, return_tensors="pt").to(device)
        input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"]
        batch_size = input_ids.size(0)

        # The first forward pass covers the prefix shared by every row (the shortest
        # prompt); later passes feed one position at a time using the KV cache.
        past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item())

        with torch.no_grad():
            while input_ids.size(1) < max_out_len:  # while not exceeding max output length
                model_out = model(
                    input_ids=input_ids[:, cur_context],
                    attention_mask=attention_mask[:, cur_context],
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                logits, past_key_values = model_out.logits, model_out.past_key_values
                softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)

                if steps_done < greedy_steps:
                    new_toks = torch.argmax(softmax_out, dim=1)
                    steps_done += 1
                else:
                    tk = torch.topk(softmax_out, top_k, dim=1).indices
                    softmax_out_top_k = torch.gather(softmax_out, 1, tk)
                    softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None]
                    new_tok_indices = torch.multinomial(softmax_out_top_k, 1)
                    new_toks = torch.gather(tk, 1, new_tok_indices)

                # Grow the buffers by one padded, masked-out column when the next
                # position lies past the current end.
                if cur_context.stop == input_ids.size(1):
                    attention_mask = torch.cat(
                        [attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1
                    )
                    input_ids = torch.cat(
                        [
                            input_ids,
                            input_ids.new_ones(batch_size, 1) * tok.pad_token_id,
                        ],
                        dim=1,
                    )

                last_non_masked = attention_mask.sum(1) - 1
                for row in range(batch_size):
                    new_idx = last_non_masked[row] + 1
                    # Rows whose prompt extends past this position are still being fed
                    # their own prompt tokens, so they are not overwritten.
                    if last_non_masked[row].item() + 1 != cur_context.stop:
                        continue

                    # Stop generating if we've already maxed out for this prompt
                    if new_idx < max_out_len:
                        input_ids[row][new_idx] = new_toks[row]
                        attention_mask[row][new_idx] = 1

                cur_context = slice(cur_context.stop, cur_context.stop + 1)

        txt = [tok.decode(x) for x in input_ids.detach().cpu().numpy().tolist()]
        txt = [
            unicodedata.normalize("NFKD", x)
            .replace("\n\n", " ")
            .replace("<|endoftext|>", "")
            for x in txt
        ]
        txts += txt

    return txts

In [9]:
def print_loud(x, pad=3):
    """
    Print ``x`` framed in a box of ``#`` characters for emphasis.

    The text sits ``pad - 1`` spaces from each side wall, with one blank framed
    line above and below it. Example with the default ``pad=3``:

    ############################
    #                          #
    #  Applying ROME to model  #
    #                          #
    ############################
    """
    width = len(x) + 2 * pad
    border = "#" * width
    blank = "#" + " " * (width - 2) + "#"
    gutter = " " * (pad - 1)

    print()
    print(border)
    print(blank)
    print("#" + gutter + x + gutter + "#")
    print(blank)
    print(border)

In [10]:
def scoring(
    generation_prompts: List[str],
    predict: List[str],
    ans: List[Union[str, List[str]]]
):
    """
    Scoring function used in this homework.
    Here we use accuracy as the simple and direct benchmark,
    instead of comparing the probability.

    A prediction counts as correct when, after removing the characters ' " . , :
    from the prompt, the prediction and the candidate answer, the prediction starts
    with "<prompt> <candidate>" for at least one candidate.

    :param generation_prompts: prompts that were fed to the model.
    :param predict: generated texts, index-aligned with ``generation_prompts``
        (presumably each begins with its prompt).
    :param ans: per prompt, either one accepted answer or a list of accepted answers.
        The list is not modified.
    :return: fraction of prompts answered correctly; 0.0 when there are no prompts.
    """
    # Characters ignored when comparing prompt, prediction and answer.
    strip_punct = str.maketrans("", "", "'\".,:")

    if not generation_prompts:
        return 0.0

    correct_count = 0
    for i, prompt in enumerate(generation_prompts):
        # Wrap a single answer locally instead of writing back into the caller's list.
        candidates = [ans[i]] if isinstance(ans[i], str) else ans[i]
        generation_prompt = prompt.translate(strip_punct)
        predict_prompt = predict[i].translate(strip_punct)
        for cand in candidates:
            # Normalize the candidate the same way as the prediction; otherwise an answer
            # containing a stripped character (e.g. "St. Louis") could never match.
            normalized_cand = str(cand).translate(strip_punct)
            if predict_prompt.startswith(f"{generation_prompt} {normalized_cand}"):
                correct_count += 1
                break
    return correct_count / len(generation_prompts)

### Fine-Tuning Function
This code is for the fine-tuning method.

In [11]:
@dataclass
class FTHyperParams:
    """
    Hyperparameters for the fine-tuning (FT) baseline, as read by execute_ft().
    """

    # Method
    layers: List[int]  # layer indices; parameters whose names contain rewrite_module_tmp.format(layer) are trained
    num_steps: int  # maximum number of epochs over the requests
    lr: float  # Adam learning rate
    weight_decay: float  # Adam weight decay, used unless wd_power_law is a tuple
    kl_factor: float  # NOTE(review): not read by execute_ft in this notebook
    norm_constraint: float  # clamp radius: trained weights stay within +/- this of their original values

    # Module templates (format strings taking a layer index, or plain module names)
    rewrite_module_tmp: str
    layer_module_tmp: str
    mlp_module_tmp: str
    attn_module_tmp: str
    ln_f_module: str
    lm_head_module: str

    # Defaults
    batch_size: int = 64  # number of requests per optimization batch
    # (exponent, log_scale): weight decay becomes n_edits ** exponent * exp(log_scale)
    wd_power_law: Optional[Tuple[float, float]] = None  # Scale weight decay by number of edits

In [12]:
# Hyperparameters for the fine-tuning (FT) baseline; the keys match FTHyperParams' fields.
ft_hparam = {
    "layers": [
        0
    ],  # execute_ft trains every parameter whose name contains rewrite_module_tmp.format(0)
    "num_steps": 25,  # maximum number of epochs over the requests
    "lr": 5e-4,  # Adam learning rate
    "weight_decay": 0,  # Adam weight decay (used while wd_power_law stays at its default)
    "kl_factor": 0,  # NOTE(review): not read by execute_ft in this notebook
    "norm_constraint": 5e-4,  # trained weights are clamped to within +/- this of their original values
    "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",  # module whose parameters are fine-tuned (e.g. .weight and .bias)
    "layer_module_tmp": "transformer.h.{}",
    "mlp_module_tmp": "transformer.h.{}.mlp",
    "attn_module_tmp": "transformer.h.{}.attn",
    "ln_f_module": "transformer.ln_f",
    "lm_head_module": "transformer.wte"  # presumably the token embedding, tied to GPT-2's output head
}

In [13]:
def apply_ft_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: FTHyperParams,
    copy=False,
    return_orig_weights=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """
    Fine-tune on ``requests`` via execute_ft() and add the resulting deltas to the model.

    :param copy: when True, a deep copy of ``model`` is edited and the argument is left
        untouched. The caller owns the copy and must free its memory to avoid leaks.
    :param return_orig_weights: when True, each edited parameter is cloned before its
        delta is added.
    :return: ``(edited_model, original_weights)``; ``original_weights`` is empty unless
        ``return_orig_weights`` is True.
    """
    target_model = deepcopy(model) if copy else model
    deltas = execute_ft(target_model, tok, requests, hparams)

    saved_weights = {}
    with torch.no_grad():
        for param_name, delta in deltas.items():
            param = get_parameter(target_model, param_name)
            if return_orig_weights and param_name not in saved_weights:
                saved_weights[param_name] = param.detach().clone()
            param[...] += delta

    print(f"New weights successfully inserted into {list(deltas.keys())}")
    return target_model, saved_weights


def execute_ft(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: FTHyperParams,
    **kwargs: Any,
) -> Dict[str, torch.Tensor]:
    """
    Executes the FT update algorithm for the specified update at the specified layer.
    Invariant: model at beginning of function == model at end of function.
    The trained parameters and every parameter's ``requires_grad`` flag are restored,
    even if training raises.

    For each batch, the logits at the last prompt token score every target token, and
    the mean negative log-likelihood of the target tokens is minimized with Adam.

    :param model: model to fine-tune. Only parameters whose name contains
        ``hparams.rewrite_module_tmp.format(layer)`` for a layer in ``hparams.layers``
        are trained. Inputs are moved to the device of the model's first parameter.
    :param tok: tokenizer with padding enabled.
    :param requests: edit requests with "prompt" (format string taking the subject),
        "subject" and "target_new"["str"]. Not modified (a deep copy is used).
    :param hparams: fine-tuning hyperparameters.
    :return: mapping from parameter name to its ``trained - original`` weight delta.
    """
    # Inputs follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device

    # Update target and print info
    requests = deepcopy(requests)
    for request in requests:
        if request["target_new"]["str"][0] != " ":
            # Space required for correct tokenization
            request["target_new"]["str"] = " " + request["target_new"]["str"]
        print(
            f"Executing FT algo for: "
            f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
        )

    # Retrieve weights that user desires to change
    weights = {
        n: p
        for n, p in model.named_parameters()
        for layer in hparams.layers
        if hparams.rewrite_module_tmp.format(layer) in n
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}
    print(f"Weights to be updated: {list(weights.keys())}")

    # Define inputs
    texts = [r["prompt"].format(r["subject"]) for r in requests]
    targets = [r["target_new"]["str"] for r in requests]

    # Configure optimizer / gradients. With wd_power_law = (a, b) the weight decay
    # scales with the number of edits as n ** a * e ** b.
    wd = (
        hparams.weight_decay
        if not isinstance(hparams.wd_power_law, tuple)
        else (len(requests) ** hparams.wd_power_law[0])
        * np.exp(hparams.wd_power_law[1])
    )
    print(f"Using weight decay of {wd} for {len(requests)} edits")
    opt = torch.optim.Adam(
        [v for _, v in weights.items()],
        lr=hparams.lr,
        weight_decay=wd,
    )

    # Any real number enables the clamp (previously ints were silently ignored);
    # False/None disable it. bool is excluded explicitly because it subclasses int.
    eps = hparams.norm_constraint
    use_norm_constraint = isinstance(eps, (int, float)) and not isinstance(eps, bool)

    # Train only the selected weights; remember every flag so it can be restored.
    orig_requires_grad = {name: w.requires_grad for name, w in model.named_parameters()}
    for name, w in model.named_parameters():
        w.requires_grad = name in weights

    try:
        # Update loop: intervene at layers simultaneously
        loss_meter = AverageMeter()
        for it in range(hparams.num_steps):
            print(20 * "=")
            print(f"Epoch: {it}")
            print(20 * "=")
            loss_meter.reset()

            for txt, tgt in zip(
                chunks(texts, hparams.batch_size), chunks(targets, hparams.batch_size)
            ):
                inputs = tok(txt, return_tensors="pt", padding=True).to(device)
                target_ids = tok(tgt, return_tensors="pt", padding=True)["input_ids"].to(
                    device
                )
                # Index of each prompt's last real token (assumes right padding).
                last_token_inds = inputs["attention_mask"].sum(dim=1) - 1
                loss_mask = target_ids != tok.unk_token_id

                opt.zero_grad()
                bs = inputs["input_ids"].shape[0]
                # Log-probabilities at the last prompt token only; every target token
                # is scored against this single next-token distribution.
                probs = torch.nn.functional.log_softmax(
                    model(**inputs).logits[torch.arange(bs, device=device), last_token_inds],
                    dim=-1,
                )
                loss = -(torch.gather(probs, 1, target_ids) * loss_mask).sum(
                    1
                ) / loss_mask.sum(1)
                loss = loss.mean()
                print(f"Batch loss {loss.item()}")
                loss_meter.update(loss.item(), n=bs)

                # Skip the update once this batch is already fit well enough.
                if loss.item() >= 1e-2:
                    loss.backward()
                    opt.step()

                # Keep each trained weight within +/- eps of its original value.
                if use_norm_constraint:
                    with torch.no_grad():
                        for k, v in weights.items():
                            v[...] = torch.clamp(
                                v, min=weights_copy[k] - eps, max=weights_copy[k] + eps
                            )

            print(f"Total loss {loss_meter.avg}")

            if loss_meter.avg < 1e-2:
                break

        deltas = {k: (weights[k] - weights_copy[k]).detach() for k in weights}
    finally:
        # Restore state of original model, even when training raised.
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]
        for name, w in model.named_parameters():
            w.requires_grad = orig_requires_grad[name]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas


def chunks(arr, n):
    """
    Lazily split the iterable ``arr`` into consecutive lists of ``n`` items.

    The final list may be shorter; nothing is yielded for an empty ``arr``.
    """
    pending = []
    for item in arr:
        pending.append(item)
        if len(pending) != n:
            continue
        yield pending
        pending = []
    if pending:
        yield pending

class AverageMeter:
    """Tracks the most recent value together with a running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

### ROME Function
This code is for the ROME method. **MODIFY THE CODE IN THE MAIN FUNCTION!!**

#### HyperParams

In [14]:
@dataclass
class ROMEHyperParams:
    """
    Hyperparameters for ROME, as read by execute_rome(), compute_u() and compute_v().
    """

    # Method
    layers: List[int]  # layers whose rewrite module receives a rank-1 update
    fact_token: str  # lookup-position strategy: "last" or "subject_<subtoken>"
    v_num_grad_steps: int  # optimization steps for the value vector in compute_v
    v_lr: float  # Adam learning rate for delta in compute_v
    v_loss_layer: int  # compute_v traces layer max(v_loss_layer, rewrite layer)
    v_weight_decay: float  # coefficient of the ||delta|| / ||v_init||^2 penalty
    clamp_norm_factor: float  # delta is kept within this multiple of ||v_init|| (L2)
    kl_factor: float  # weight of the KL term on the "{subject} is a" prompt
    mom2_adjustment: bool  # multiply u by the inverse second-moment matrix in compute_u
    context_template_length_params: List[List[int]]  # passed to get_context_templates

    # Module templates
    rewrite_module_tmp: str
    layer_module_tmp: str
    mlp_module_tmp: str
    attn_module_tmp: str
    ln_f_module: str
    lm_head_module: str

    # Statistics (passed to layer_stats via get_inv_cov)
    mom2_dataset: str
    mom2_n_samples: int
    mom2_dtype: str

    @classmethod
    def from_json(cls, fpath):
        """
        Build an instance from a JSON file whose top-level object maps field names to values.

        :param fpath: path to the JSON file.
        :raises TypeError: if the JSON has missing or unexpected keys (from ``cls(**data)``).
        """
        # Imported locally: none of the notebook's visible imports bring `json` into
        # scope, so relying on a module-level name would raise NameError.
        import json

        with open(fpath, "r", encoding="utf-8") as f:
            data = json.load(f)

        return cls(**data)

In [15]:
# ROME hyperparameters; the keys match ROMEHyperParams' fields.
rome_hparam = {
    "layers": [
        17
    ],  # layer whose rewrite module (mlp.c_proj) receives the rank-1 edit
    "fact_token": "subject_last",  # lookup position: the subject's "last" subtoken (resolved by repr_tools)
    "v_num_grad_steps": 20,  # optimization steps for the value vector in compute_v
    "v_lr": 5e-1,  # Adam learning rate for delta in compute_v
    "v_loss_layer": 47,  # compute_v traces layer max(47, rewrite layer); presumably GPT-2 XL's final block -- confirm
    "v_weight_decay": 0.5,  # coefficient of the ||delta|| / ||v_init||^2 penalty
    "clamp_norm_factor": 4,  # delta is projected back into an L2 ball of radius 4 * ||v_init||
    "kl_factor": 0.0625,  # weight of the KL term on the "{subject} is a" prompt
    "mom2_adjustment": True,  # multiply u by the inverse second-moment matrix in compute_u
    "context_template_length_params": [[5, 10], [10, 10]],  # passed to get_context_templates (defined elsewhere)
    "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
    "layer_module_tmp": "transformer.h.{}",
    "mlp_module_tmp": "transformer.h.{}.mlp",
    "attn_module_tmp": "transformer.h.{}.attn",
    "ln_f_module": "transformer.ln_f",
    "lm_head_module": "transformer.wte",  # presumably the token embedding, tied to GPT-2's output head
    "mom2_dataset": "wikipedia",  # statistics collected by layer_stats in get_inv_cov
    "mom2_n_samples": 100000,
    "mom2_dtype": "float32"
}

#### compute_u and compute_v

In [16]:
def compute_v(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    left_vector: torch.Tensor,
    context_templates: List[str],
) -> torch.Tensor:
    """
    Computes the value (right) vector for the rank-1 update.
    Runs a simple optimization procedure.

    A vector ``delta`` is optimized and added to the output of the MLP module
    ``hparams.mlp_module_tmp.format(layer)`` at each prompt's fact-lookup position.
    The objective has three parts:
      * NLL of ``request["target_new"]["str"]`` after every rewriting prompt;
      * a KL term keeping the "{subject} is a" prediction close to its initial distribution;
      * a norm penalty on ``delta``.
    The optimized output ``target_init + delta`` is then turned into the right vector
    of the rank-1 update.

    :param request: edit request; reads "prompt" (format string taking the subject),
        "subject" and "target_new"["str"].
    :param layer: index of the layer being rewritten.
    :param left_vector: left vector u from compute_u(); only its dot product with the
        rewrite module's input is used here.
    :param context_templates: format strings, each taking the request prompt.
    :return: right vector ``(target - cur_output) / (cur_input . left_vector)``.
    """

    print("Computing right vector (v)")

    # Tokenize target into list of int token IDs
    # NOTE(review): the device is hard-coded to "cuda" here and below.
    target_ids = tok(request["target_new"]["str"], return_tensors="pt").to("cuda")[
        "input_ids"
    ][0]

    # Compile list of rewriting and KL x/y pairs.
    # Each rewriting prompt has all but the last target token appended (teacher forcing),
    # so the model is asked to predict each target token from the tokens before it.
    rewriting_prompts, kl_prompts = [
        context.format(request["prompt"]) + tok.decode(target_ids[:-1])
        for context in context_templates
    ], ["{} is a"]
    all_prompts = rewriting_prompts + kl_prompts

    input_tok = tok(
        [prompt.format(request["subject"]) for prompt in all_prompts],
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Compute rewriting targets: -100 (ignored) everywhere except the last
    # len(target_ids) real positions of each rewriting prompt (assumes right padding).
    rewriting_targets = torch.tensor(-100, device="cuda").repeat(
        len(rewriting_prompts), *input_tok["input_ids"].shape[1:]
    )
    for i in range(len(rewriting_prompts)):
        ex_len = input_tok["attention_mask"][i].sum()
        rewriting_targets[i, ex_len - len(target_ids) : ex_len] = target_ids

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        find_fact_lookup_idx(
            prompt, request["subject"], tok, hparams.fact_token, verbose=(i == 0)
        )
        for i, prompt in enumerate(all_prompts)
    ]

    # Finalize rewrite and loss layers
    loss_layer = max(hparams.v_loss_layer, layer)
    print(f"Rewrite layer is {layer}")
    print(f"Tying optimization objective to {loss_layer}")

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer, i.e. hypothesized fact lookup location, will induce the
    # target token to be predicted at the final layer.
    delta = torch.zeros((model.config.n_embd,), requires_grad=True, device="cuda")
    target_init, kl_distr_init = None, None

    # Inserts new "delta" variable at the appropriate part of the computation.
    # Presumably invoked by nethook.TraceDict as edit_output(output, layer_name) for each
    # traced layer; only the rewrite layer's MLP output is modified.
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init

        if cur_layer == hparams.mlp_module_tmp.format(layer):
            # Store initial value of the vector of interest
            if target_init is None:
                print("Recording initial value of v*")
                # Initial value is recorded for the clean sentence
                target_init = cur_out[0, lookup_idxs[0]].detach().clone()

            # delta is added for every prompt, including the KL prompt
            for i, idx in enumerate(lookup_idxs):
                cur_out[i, idx, :] += delta

        return cur_out

    # Optimizer: only delta is trained; model parameters are frozen
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    nethook.set_requires_grad(False, model)

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation (the retained outputs in `tr` are not read below)
        with nethook.TraceDict(
            module=model,
            layers=[
                hparams.layer_module_tmp.format(loss_layer),
                hparams.mlp_module_tmp.format(layer),
            ],
            retain_input=False,
            retain_output=True,
            edit_output=edit_output_fn,
        ) as tr:
            logits = model(**input_tok).logits

            # Compute distribution for KL divergence: the KL prompts are the trailing
            # rows, so the negative row indices i - len(kl_prompts) select them.
            kl_logits = torch.stack(
                [
                    logits[i - len(kl_prompts), idx, :]
                    for i, idx in enumerate(lookup_idxs[-len(kl_prompts) :])
                ],
                dim=0,
            )
            kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
            # The first iteration's distribution is the reference for the KL term
            if kl_distr_init is None:
                kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        log_probs = torch.log_softmax(logits, dim=2)

        # The index only has len(rewriting_prompts) rows, so gather reads the first
        # rows of log_probs (the KL rows are excluded). -100 is swapped for a valid
        # index 0 and then zeroed out by the mask.
        loss = torch.gather(
            log_probs,
            2,
            torch.where(rewriting_targets != -100, rewriting_targets, 0).unsqueeze(2),
        ).squeeze(2)
        mask = (rewriting_targets != -100).float()

        # Aggregate total losses
        nll_loss_each = -(loss * mask).sum(1) / target_ids.size(0)
        nll_loss = nll_loss_each.mean()
        # With log_target=True, kl_div(input, target) = sum(exp(target) * (target - input)),
        # i.e. KL(current || initial) for this argument order.
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init, kl_log_probs, log_target=True, reduction="batchmean"
        )
        # ||delta|| / ||target_init||^2 (** binds before /)
        weight_decay = hparams.v_weight_decay * (
            torch.norm(delta) / torch.norm(target_init) ** 2
        )
        # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
        loss = nll_loss + kl_loss + weight_decay
        print(
            f"loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + {np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} "
            f"avg prob of [{request['target_new']['str']}] "
            f"{torch.exp(-nll_loss_each).mean().item()}"
        )
        # Early stop once the objective is small enough
        if loss < 5e-2:
            break

        # On the last iteration, stop before stepping so the final delta is the one evaluated above
        if it == hparams.v_num_grad_steps - 1:
            break

        # Backpropagate
        loss.backward()
        opt.step()

        # Project within L2 ball of radius clamp_norm_factor * ||target_init||
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta

    # Retrieve cur_input, the current input to the 2nd MLP layer, and
    # cur_output, the original output of the 2nd MLP layer.
    # This uses the bare request prompt (no context template).
    cur_input, cur_output = get_module_input_output_at_word(
        model,
        tok,
        layer,
        context_template=request["prompt"],
        word=request["subject"],
        module_template=hparams.rewrite_module_tmp,
        fact_token_strategy=hparams.fact_token,
    )

    # Solving the linear system to compute the right vector
    right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
    print(f"Delta norm: {(target - cur_output).norm().item()}")
    print(
        f"Change in target norm: {target_init.norm().item()} to {target.norm().item()} => {(target.norm() - target_init.norm()).item()}"
    )
    print(f"Division Factor: {torch.dot(cur_input, left_vector).item()}")
    print(f"Right vector norm: {right_vector.norm()}")

    return right_vector


def get_module_input_output_at_word(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer: int,
    context_template: str,
    word: str,
    module_template: str,
    fact_token_strategy: str,
) -> Tuple[torch.Tensor]:
    """
    Return detached (input, output) representations of ``module_template.format(layer)``
    at the position selected by ``fact_token_strategy`` in ``context_template``.

    :param context_template: format string that takes ``word``.
    :param fact_token_strategy: "last" (final token of the filled sentence) or
        "subject_<subtoken>" (a subtoken of ``word``, resolved by repr_tools).
    :raises ValueError: for any other strategy.
    """
    shared_kwargs = dict(
        model=model,
        tok=tok,
        layer=layer,
        module_template=module_template,
    )

    if fact_token_strategy.startswith("subject_"):
        reps_in, reps_out = repr_tools.get_reprs_at_word_tokens(
            track="both",
            subtoken=fact_token_strategy[len("subject_") :],
            context_templates=[context_template],
            words=[word],
            **shared_kwargs,
        )
    elif fact_token_strategy == "last":
        reps_in, reps_out = repr_tools.get_reprs_at_idxs(
            track="both",
            contexts=[context_template.format(word)],
            idxs=[[-1]],
            **shared_kwargs,
        )
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    return reps_in[0].detach(), reps_out[0].detach()


def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: AutoTokenizer,
    fact_token_strategy: str,
    verbose=True,
) -> int:
    """
    Return the token index where the fact is hypothesized to be looked up.

    :param prompt: format string that takes ``subject``.
    :param fact_token_strategy: "last" (returns -1) or "subject_<subtoken>" (index of that
        subtoken of ``subject``, resolved by repr_tools).
    :param verbose: when True, print the index, the filled sentence and the decoded token.
    :raises ValueError: for any other strategy.
    """
    if fact_token_strategy == "last":
        lookup = -1
    elif fact_token_strategy.startswith("subject_"):
        word_idxs = repr_tools.get_words_idxs_in_templates(
            tok=tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=fact_token_strategy[len("subject_") :],
        )
        lookup = word_idxs[0][0]
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    sentence = prompt.format(subject)
    if verbose:
        token_text = tok.decode(tok(sentence)["input_ids"][lookup])
        print(f"Lookup index found: {lookup} | Sentence: {sentence} | Token:", token_text)

    return lookup

In [17]:
# Cache variables: inverse second-moment matrices keyed by
# (model name, layer name, dataset, sample count, dtype).
inv_mom2_cache = {}


def get_inv_cov(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: int,
    mom2_dtype: str,
) -> torch.Tensor:
    """
    Retrieves covariance statistics, then computes the algebraic inverse.
    Caches result for future use.

    :param model: model whose ``config._name_or_path`` identifies it in the cache key.
    :param layer_name: name of the module whose second-moment statistics are collected.
    :param mom2_dataset: dataset name passed to ``layer_stats``.
    :param mom2_n_samples: number of samples passed to ``layer_stats``.
    :param mom2_dtype: precision passed to ``layer_stats``.
    :return: float32 inverse of the second-moment matrix, on the model's device.
    """
    model_name = model.config._name_or_path.replace("/", "_")
    # The statistics settings are part of the key so that a call with a different
    # dataset, sample count or precision does not reuse a stale inverse.
    key = (model_name, layer_name, mom2_dataset, mom2_n_samples, mom2_dtype)

    if key not in inv_mom2_cache:
        print(
            f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
            f"The result will be cached to avoid repetitive computation."
        )
        stat = layer_stats(
            model,
            tok,
            layer_name,
            STATS_DIR,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
        )
        device = next(model.parameters()).device
        inv_mom2_cache[key] = torch.inverse(
            stat.mom2.moment().to(device)
        ).float()  # Cast back to float32

    return inv_mom2_cache[key]


def compute_u(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    context_templates: List[str],
) -> torch.Tensor:
    """
    Computes the left vector used in constructing the rank-1 update matrix.

    The input to ``hparams.rewrite_module_tmp.format(layer)`` is read at the fact-lookup
    position of every context template, and those representations are averaged. When
    ``hparams.mom2_adjustment`` is set, the average is multiplied by the inverse
    second-moment matrix from get_inv_cov(). The result is scaled to unit L2 norm.

    :param request: edit request; reads "prompt" and "subject".
    :param context_templates: format strings, each taking the request prompt.
    :raises ValueError: if ``hparams.fact_token`` is neither "last" nor "subject_<...>".
    :return: unit-norm left vector u.
    """
    print("Computing left vector (u)...")

    repr_kwargs = dict(
        model=model,
        tok=tok,
        layer=layer,
        module_template=hparams.rewrite_module_tmp,
        track="in",
    )
    strategy = hparams.fact_token
    n_templates = len(context_templates)

    if strategy.startswith("subject_"):
        subject = request["subject"]
        print(f"Selected u projection object {subject}")
        filled_templates = [templ.format(request["prompt"]) for templ in context_templates]
        cur_repr = repr_tools.get_reprs_at_word_tokens(
            context_templates=filled_templates,
            words=[subject] * n_templates,
            subtoken=strategy[len("subject_") :],
            **repr_kwargs,
        ).mean(0)
    elif strategy == "last":
        # Use the final token of each filled template; a multi-token last word is
        # harmless because only the very last token is taken.
        contexts = [
            templ.format(request["prompt"].format(request["subject"]))
            for templ in context_templates
        ]
        cur_repr = repr_tools.get_reprs_at_idxs(
            contexts=contexts,
            idxs=[[-1] for _ in range(n_templates)],
            **repr_kwargs,
        ).mean(0)
        print("Selected u projection token with last token")
    else:
        raise ValueError(f"fact_token={hparams.fact_token} not recognized")

    # Apply inverse second moment adjustment
    u = cur_repr
    if hparams.mom2_adjustment:
        inv_cov = get_inv_cov(
            model,
            tok,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples,
            hparams.mom2_dtype,
        )
        u = (inv_cov @ u.unsqueeze(1)).squeeze()

    return u / u.norm()

#### Main Function

In [18]:
CONTEXT_TEMPLATES_CACHE = None


def apply_rome_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: ROMEHyperParams,
    copy=False,
    return_orig_weights=False,
) -> Tuple[AutoModelForCausalLM, Dict[str, torch.Tensor]]:
    """
    This function calls execute_rome() for each request and adds the resulting rank-1
    update (the outer product of the returned left and right vectors) to each weight.

    :param copy: If true, will preserve the original model while creating a new one to edit.
        Note that you are responsible for deallocating the new model's memory to avoid leaks.
    :param return_orig_weights: If true, clone each edited weight before its first update
        (only supported for the first request).
    :return: (1) the edited model, (2) the original weights (empty unless return_orig_weights).
    """

    if copy:
        model = deepcopy(model)

    weights_copy = {}

    for i, request in enumerate(requests):
        deltas = execute_rome(model, tok, request, hparams)

        with torch.no_grad():
            for w_name, (delta_u, delta_v) in deltas.items():
                # Rank-1 update = outer product u v^T. Flattening first makes this correct
                # whether the vectors arrive as 1-D tensors or in unsqueezed (d, 1) / (1, d)
                # form; either way the result has shape (len(u), len(v)).
                upd_matrix = torch.outer(delta_u.flatten(), delta_v.flatten())
                w = get_parameter(model, w_name)
                # GPT-2 stores these weights transposed relative to nn.Linear;
                # upd_matrix_match_shape (defined elsewhere) presumably transposes the
                # update when needed so that it matches w.shape.
                upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)

                if return_orig_weights and w_name not in weights_copy:
                    assert i == 0
                    weights_copy[w_name] = w.detach().clone()

                w[...] += upd_matrix

        print(f"New weights successfully inserted into {list(deltas.keys())}")

    return model, weights_copy


def execute_rome(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Executes the ROME update algorithm for the specified update at the specified layer(s).

    This function only reads the model's weights and never writes to them. The update
    vectors come from compute_u/compute_v, so the model at the end of the function is
    the model at the beginning.

    :param request: edit request with "prompt" (containing "{}"), "subject" and
        "target_new": {"str": ...}; it is deep-copied, never mutated.
    :return: dict mapping each rewritten weight name to a detached
        (left_vector, right_vector) pair, unsqueezed into a column (dim 1) and a row
        (dim 0) so that their matrix product is the rank-1 weight update.
    """

    # Update target and print info
    request = deepcopy(request)
    if request["target_new"]["str"][0] != " ":
        # Space required for correct tokenization
        request["target_new"]["str"] = " " + request["target_new"]["str"]
    print(
        f"Executing ROME algorithm for the update: "
        f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
    )

    # Retrieve weights that user desires to change (only their names are used below)
    weights = {
        f"{hparams.rewrite_module_tmp.format(layer)}.weight": get_parameter(
            model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        )
        for layer in hparams.layers
    }
    # No clone of the weights is taken: this function never modifies them, so a
    # "restore" copy would only duplicate every edited matrix on the model's device.

    # Context templates are cached after the first call, so fetch them once.
    context_templates = get_context_templates(
        model, tok, hparams.context_template_length_params
    )

    # Update loop: sequentially intervene at each specified layer
    deltas = {}
    for layer in sorted(hparams.layers):
        # Compute rank-1 update matrix
        left_vector: torch.Tensor = compute_u(
            model,
            tok,
            request,
            hparams,
            layer,
            context_templates,
        )
        print("Left vector shape:", left_vector.shape)
        right_vector: torch.Tensor = compute_v(
            model,
            tok,
            request,
            hparams,
            layer,
            left_vector,
            context_templates,
        )
        print("Right vector shape:", right_vector.shape)

        # Column x row, so that left_vector @ right_vector is the outer product.
        left_vector = left_vector.unsqueeze(1)
        right_vector = right_vector.unsqueeze(0)
        weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        deltas[weight_name] = (
            left_vector.detach(),
            right_vector.detach(),
        )

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas


def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Orient a ROME update matrix so it can be added to a weight of ``shape``.

    GPT-2 and GPT-J store some weights transposed. The update is returned unchanged
    when it already fits, transposed when its transpose fits, and otherwise a
    ValueError is raised.
    """
    fits_as_is = matrix.shape == shape
    if fits_as_is:
        return matrix

    transposed = matrix.T
    if transposed.shape != shape:
        raise ValueError(
            "Update matrix computed by ROME does not match original weight shape. "
            "Check for bugs in the code?"
        )
    return transposed


def get_context_templates(model, tok, length_params):
    """
    Return (and lazily build) the ROME context templates.

    On the first call, generate() samples continuations of "<|endoftext|>", one batch
    per (length, n_gen) pair in ``length_params``. The module-level
    CONTEXT_TEMPLATES_CACHE is set to ["{}"] + ["<text>. {}", ...]. Later calls return
    the cache unchanged and ignore their arguments.

    Braces in sampled text are replaced with spaces: every template is later filled
    with str.format, so a stray "{" or "}" would raise or mis-format.

    NOTE(review): a later cell redefines get_context_templates(model, tok) for MEMIT;
    once that cell has run, this three-argument version is no longer bound to the name.
    """
    global CONTEXT_TEMPLATES_CACHE

    if CONTEXT_TEMPLATES_CACHE is None:
        generated = []
        for length, n_gen in length_params:
            generated.extend(
                generate(
                    model,
                    tok,
                    ["<|endoftext|>"],
                    n_gen_per_prompt=n_gen,
                    max_out_len=length,
                )
            )
        CONTEXT_TEMPLATES_CACHE = ["{}"] + [
            x.replace("{", " ").replace("}", " ") + ". {}" for x in generated
        ]

        print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE

### MEMIT Function
This section implements the MEMIT (Mass-Editing Memory in a Transformer) method.

#### HyperParams

In [19]:
from dataclasses import dataclass
from typing import List, Literal

from util.hparams import HyperParams


@dataclass
class MEMITHyperParams(HyperParams):
    """Hyperparameters read by the MEMIT functions in this notebook (compute_ks, compute_z, execute_memit, get_cov)."""

    # Method
    # layers: layers whose rewrite module is edited, in order; z is optimized at the last one.
    layers: List[int]
    # layer_selection: not read by the MEMIT code shown here — TODO confirm where it is used.
    layer_selection: Literal["all", "random"]
    # fact_token: token position where keys / z are read. "subject_<subtoken>" picks a
    # token of the subject; "last" is rejected by get_module_input_output_at_words.
    fact_token: Literal[
        "last", "subject_first", "subject_last", "subject_first_after_last"
    ]
    # v_num_grad_steps: maximum number of optimization iterations in compute_z.
    v_num_grad_steps: int
    # v_lr: Adam learning rate for the delta vector in compute_z.
    v_lr: float
    # v_loss_layer: layer whose output is unembedded for the rewrite loss
    # (compute_z uses max(v_loss_layer, edit layer)).
    v_loss_layer: int
    # v_weight_decay: coefficient of the norm penalty on delta in compute_z.
    v_weight_decay: float
    # clamp_norm_factor: delta is projected so ||delta|| <= factor * ||initial z||.
    clamp_norm_factor: float
    # kl_factor: weight of the KL term on the "{} is a" prompt in compute_z.
    kl_factor: float
    # mom2_adjustment: not read by the MEMIT code shown here — TODO confirm usage.
    mom2_adjustment: bool
    # mom2_update_weight: scale on the covariance term in execute_memit's linear solve.
    mom2_update_weight: float

    # Module templates
    # rewrite_module_tmp: "{}"-template (layer index) of the module whose .weight is edited.
    rewrite_module_tmp: str
    # layer_module_tmp: "{}"-template of the block whose output is hooked in compute_z.
    layer_module_tmp: str
    # mlp_module_tmp / attn_module_tmp: not read by the MEMIT code shown here.
    mlp_module_tmp: str
    attn_module_tmp: str
    # ln_f_module: final layer norm applied before unembedding in compute_z.
    ln_f_module: str
    # lm_head_module: module whose .weight (transposed) is used as the unembedding matrix.
    lm_head_module: str

    # Statistics
    # Passed to get_cov()/layer_stats() to obtain the second-moment statistics.
    mom2_dataset: str
    mom2_n_samples: int
    mom2_dtype: str

In [20]:
# MEMIT hyperparameters for GPT-2 XL (blocks transformer.h.0 - transformer.h.47).
# The keys are exactly the MEMITHyperParams fields.
memit_hparam = dict(
    # Edit the MLP output projection of blocks 13-17; z is optimized at block 17.
    layers=[13, 14, 15, 16, 17],
    # compute_z keeps ||delta|| <= 0.75 * ||initial z||.
    clamp_norm_factor=0.75,
    layer_selection="all",
    # Keys and z are read at the last token of the subject.
    fact_token="subject_last",
    # compute_z optimization settings; the rewrite loss is read at the final block (47).
    v_num_grad_steps=20,
    v_lr=5e-1,
    v_loss_layer=47,
    v_weight_decay=0.5,
    kl_factor=0.0625,
    mom2_adjustment=True,
    # Weight on the covariance term in execute_memit's least-squares solve.
    mom2_update_weight=20000,
    # Module paths; "{}" is filled with a layer index.
    rewrite_module_tmp="transformer.h.{}.mlp.c_proj",
    layer_module_tmp="transformer.h.{}",
    mlp_module_tmp="transformer.h.{}.mlp",
    attn_module_tmp="transformer.h.{}.attn",
    ln_f_module="transformer.ln_f",
    # The token-embedding matrix doubles as the unembedding in compute_z.
    lm_head_module="transformer.wte",
    # Covariance statistics used by get_cov.
    mom2_dataset="wikipedia",
    mom2_n_samples=100000,
    mom2_dtype="float32",
)

#### compute_ks and compute_z

In [21]:
def compute_ks(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MEMITHyperParams,
    layer: int,
    context_templates: List[List[str]],
) -> torch.Tensor:
    """
    Compute one key vector per request at ``layer``'s rewrite module.

    Each request's prompt is wrapped in every context template. The *input* of
    ``hparams.rewrite_module_tmp.format(layer)`` is then read at the subject token
    selected by ``hparams.fact_token``. The keys are averaged first within each
    context-template group and then across groups, so every group has equal weight
    regardless of its size.

    :param requests: edit requests with "prompt" (containing "{}") and "subject" keys.
    :param context_templates: groups of templates, each template containing one "{}".
    :return: tensor with one row per request (stacked along dim 0).
    """
    layer_ks = get_module_input_output_at_words(
        model,
        tok,
        layer,
        context_templates=[
            context.format(request["prompt"])
            for request in requests
            for context_type in context_templates
            for context in context_type
        ],
        words=[
            request["subject"]
            for request in requests
            for context_type in context_templates
            for _ in context_type
        ],
        module_template=hparams.rewrite_module_tmp,
        fact_token_strategy=hparams.fact_token,
    )[0]  # [0] = module input, i.e. the keys

    # Row offsets of each template group inside one request's contiguous block of rows
    # (assumes one row per (request, template) pair, in the order built above).
    context_type_lens = [0] + [len(context_type) for context_type in context_templates]
    context_len = sum(context_type_lens)
    context_type_csum = np.cumsum(context_type_lens).tolist()

    ans = []
    for i in range(0, layer_ks.size(0), context_len):
        # Mean within each template group, then mean across groups.
        tmp = []
        for j in range(len(context_type_csum) - 1):
            start, end = context_type_csum[j], context_type_csum[j + 1]
            tmp.append(layer_ks[i + start : i + end].mean(0))
        ans.append(torch.stack(tmp, 0).mean(0))
    return torch.stack(ans, dim=0)

In [22]:
def compute_z(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: MEMITHyperParams,
    layer: int,
    context_templates: List[List[str]],
) -> torch.Tensor:
    """
    Optimize the target hidden state z for one request at the output of ``layer``.

    A vector ``delta`` is added to the output of ``hparams.layer_module_tmp.format(layer)``
    at the fact-lookup token of every prompt. It is optimized with Adam so that the
    rewriting prompts predict ``request["target_new"]["str"]``. The loss is

        NLL(target) + kl_factor * KL(current || initial) on "{} is a"
                    + v_weight_decay * ||delta|| / ||z_init||**2

    After every step ``delta`` is projected into the L2 ball of radius
    ``clamp_norm_factor * ||z_init||``.

    :param context_templates: groups of templates, each containing one "{}" that is
        filled with the request prompt.
    :return: z = z_init + delta, a 1-D tensor of size ``model.config.n_embd``.
    """

    # Get model parameters: unembedding matrix (transposed so hidden @ lm_w gives
    # logits) and the final layer norm.
    lm_w, ln_f = (
        nethook.get_parameter(model, f"{hparams.lm_head_module}.weight").T,
        nethook.get_module(model, hparams.ln_f_module),
    )
    try:
        lm_b = nethook.get_parameter(model, f"{hparams.lm_head_module}.bias")
    except LookupError as _:
        # Head module has no bias: use a zero bias over the vocabulary.
        lm_b = next(model.parameters()).new_zeros(model.config.vocab_size)

    print("Computing right vector (v)")

    # Tokenize target into list of int token IDs
    target_ids = tok(request["target_new"]["str"], return_tensors="pt").to("cuda")[
        "input_ids"
    ][0]

    # Compile list of rewriting and KL x/y pairs. Rewriting prompts end with every
    # target token but the last, so each target token is scored as a next-token
    # prediction.
    rewriting_prompts, kl_prompts = [
        context.format(request["prompt"]) + tok.decode(target_ids[:-1])
        for context_types in context_templates
        for context in context_types
    ], ["{} is a"]
    all_prompts = rewriting_prompts + kl_prompts

    input_tok = tok(
        [prompt.format(request["subject"]) for prompt in all_prompts],
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Compute rewriting targets: -100 (ignored) everywhere except the last
    # len(target_ids) non-padding positions of each rewriting prompt
    # (assumes the tokenizer pads on the right).
    rewriting_targets = torch.tensor(-100, device="cuda").repeat(
        len(rewriting_prompts), *input_tok["input_ids"].shape[1:]
    )
    for i in range(len(rewriting_prompts)):
        ex_len = input_tok["attention_mask"][i].sum()
        rewriting_targets[i, ex_len - len(target_ids) : ex_len] = target_ids

    # Compute indices of the tokens where the fact is looked up
    lookup_idxs = [
        find_fact_lookup_idx(
            prompt, request["subject"], tok, hparams.fact_token, verbose=(i == 0)
        )
        for i, prompt in enumerate(all_prompts)
    ]

    # Finalize rewrite and loss layers
    loss_layer = max(hparams.v_loss_layer, layer)
    print(f"Rewrite layer is {layer}")
    print(f"Tying optimization objective to {loss_layer}")

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer, i.e. hypothesized fact lookup location, will induce the
    # target token to be predicted at the final layer.
    delta = torch.zeros((model.config.n_embd,), requires_grad=True, device="cuda")
    target_init, kl_distr_init = None, None

    # Inserts new "delta" variable at the appropriate part of the computation
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init

        if cur_layer == hparams.layer_module_tmp.format(layer):
            # Store initial value of the vector of interest
            if target_init is None:
                print("Recording initial value of v*")
                # Initial value is recorded for the clean sentence
                target_init = cur_out[0][0, lookup_idxs[0]].detach().clone()

            # Add intervened delta
            for i, idx in enumerate(lookup_idxs):
                cur_out[0][i, idx, :] += delta

        return cur_out

    # Optimizer
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    # Freeze the model so only delta receives gradients.
    nethook.set_requires_grad(False, model)

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation
        with nethook.TraceDict(
            module=model,
            layers=[
                hparams.layer_module_tmp.format(loss_layer),
                hparams.layer_module_tmp.format(layer),
            ],
            retain_input=False,
            retain_output=True,
            edit_output=edit_output_fn,
        ) as tr:
            logits = model(**input_tok).logits

            # Compute distribution for KL divergence
            kl_logits = torch.stack(
                [
                    logits[i - len(kl_prompts), idx, :]
                    for i, idx in enumerate(lookup_idxs[-len(kl_prompts) :])
                ],
                dim=0,
            )
            kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
            if kl_distr_init is None:
                # First iteration: delta is still zero, so this is the unedited distribution.
                kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        full_repr = tr[hparams.layer_module_tmp.format(loss_layer)].output[0][
            : len(rewriting_prompts)
        ]
        log_probs = torch.log_softmax(ln_f(full_repr) @ lm_w + lm_b, dim=2)
        loss = torch.gather(
            log_probs,
            2,
            torch.where(rewriting_targets != -100, rewriting_targets, 0).unsqueeze(2),
        ).squeeze(2)
        mask = (rewriting_targets != -100).float()

        # Aggregate total losses
        nll_loss_each = -(loss * mask).sum(1) / target_ids.size(0)
        nll_loss = nll_loss_each.mean()
        # kl_div(input=initial, target=current, log_target=True) = KL(current || initial)
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init, kl_log_probs, log_target=True, reduction="batchmean"
        )
        # Operator precedence: ||delta|| / (||z_init|| ** 2)
        weight_decay = hparams.v_weight_decay * (
            torch.norm(delta) / torch.norm(target_init) ** 2
        )
        # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
        loss = nll_loss + kl_loss + weight_decay
        print(
            f"loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + {np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} "
            f"avg prob of [{request['target_new']['str']}] "
            f"{torch.exp(-nll_loss_each).mean().item()}"
        )
        if loss < 5e-2:
            break

        if it == hparams.v_num_grad_steps - 1:
            # Last iteration only reports the loss; no further update is applied.
            break

        # Backpropagate
        loss.backward()
        opt.step()

        # Project within L2 ball
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta
    print(
        f"Init norm {target_init.norm()} | Delta norm {delta.norm()} | Target norm {target.norm()}"
    )

    return target


def get_module_input_output_at_words(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer: int,
    context_templates: List[str],
    words: List[str],
    module_template: str,
    fact_token_strategy: str,
) -> Tuple[torch.Tensor]:
    """
    Retrieves detached representations for a word at the input and
    output of a particular layer module.

    ``context_templates`` and ``words`` are parallel lists (one word per template).

    :param module_template: "{}"-template of the module name, filled with ``layer``.
    :param fact_token_strategy: "subject_<subtoken>" (e.g. "subject_last") selects a
        token of the word; "last" is not implemented.
    :return: (module_input, module_output), both detached.
    :raises NotImplementedError: if ``fact_token_strategy == "last"``.
    :raises ValueError: for any other unrecognized strategy.
    """

    word_repr_args = dict(
        model=model,
        tok=tok,
        layer=layer,
        module_template=module_template,
    )
    if fact_token_strategy.startswith("subject_"):
        context_info = dict(
            context_templates=context_templates,
            words=words,
        )
        # The part after "subject_" (e.g. "last", "first") picks the subject token.
        subtoken = fact_token_strategy[len("subject_") :]
        l_input, l_output = repr_tools.get_reprs_at_word_tokens(
            track="both", subtoken=subtoken, **context_info, **word_repr_args
        )
    elif fact_token_strategy == "last":
        # Reading the final token of each (padded) context needs per-example indices,
        # which are not implemented here; fail explicitly instead.
        raise NotImplementedError(
            "fact_token='last' is not supported by get_module_input_output_at_words; "
            "use a 'subject_*' fact_token strategy instead."
        )
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    return l_input.detach(), l_output.detach()


def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: AutoTokenizer,
    fact_token_strategy: str,
    verbose=True,
) -> int:
    """
    Return the token index at which the fact is hypothesized to be looked up.

    "last" maps to -1 (the final token). "subject_<subtoken>" asks
    repr_tools.get_words_idxs_in_templates for the chosen subject token in ``prompt``.
    Any other strategy raises ValueError. With ``verbose`` the selected token is printed.
    """

    if fact_token_strategy == "last":
        ret = -1
    elif fact_token_strategy.startswith("subject_"):
        subtoken = fact_token_strategy[len("subject_") :]
        per_template_idxs = repr_tools.get_words_idxs_in_templates(
            tok=tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=subtoken,
        )
        ret = per_template_idxs[0][0]
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    sentence = prompt.format(subject)
    if verbose:
        token_ids = tok(sentence)["input_ids"]
        print(
            f"Lookup index found: {ret} | Sentence: {sentence} | Token:",
            tok.decode(token_ids[ret]),
        )

    return ret

#### Main Function

In [23]:
# Cache variable(s):
# - CONTEXT_TEMPLATES_CACHE: templates from get_context_templates(); None = not built yet
#   (re-running this cell forces them to be regenerated).
# - COV_CACHE: covariance statistics keyed by (model name, layer name), filled by get_cov().
CONTEXT_TEMPLATES_CACHE, COV_CACHE = None, {}


def apply_memit_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MEMITHyperParams,
    copy=False,
    return_orig_weights=False,
    cache_template: Optional[str] = None,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """
    Returns a model with the desired changes.

    All requests are processed together by execute_memit(). It returns, per edited
    weight, a (key_mat, val_mat) pair whose product key_mat @ val_mat.T is that
    weight's update; the update is added to the weight in place.

    :param copy: If true, will preserve the original model while creating a new one to edit.
        Note that you are responsible for deallocating the new model's memory to avoid leaks.
    :param return_orig_weights: if true, also return clones of the edited weights taken
        before the update.
    :param cache_template: forwarded to execute_memit() for caching z vectors.
    :return: (1) the updated model, (2) an original copy of the weights that changed
    """

    weights_copy = {}
    if copy:
        model = deepcopy(model)

    deltas = execute_memit(model, tok, requests, hparams, cache_template=cache_template)

    with torch.no_grad():
        for w_name, (key_mat, val_mat) in deltas.items():
            w = nethook.get_parameter(model, w_name)
            # execute_memit() stores the factors on CPU; build the update on the
            # weight's own device instead of a hard-coded one.
            key_mat, val_mat = key_mat.to(w.device), val_mat.to(w.device)
            upd_matrix = key_mat @ val_mat.T
            upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)

            if return_orig_weights and w_name not in weights_copy:
                weights_copy[w_name] = w.detach().clone()

            # The factors are float64; cast the update to the weight's dtype.
            w[...] += upd_matrix.to(w.dtype)

    print(f"New weights successfully inserted into {list(deltas.keys())}")

    return model, weights_copy


def execute_memit(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MEMITHyperParams,
    cache_template: Optional[str] = None,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Executes the MEMIT update algorithm for the specified update at the specified layer
    Invariant: model at beginning of function == model at end of function

    Steps:
    1. For every request, optimize a target hidden state z at the last edited layer
       (compute_z). If ``cache_template`` is given, z is loaded from or saved to .npz.
    2. For each layer in ``hparams.layers``, in order:
       - compute the keys (compute_ks);
       - compute the residual between the targets and the current hidden states;
       - solve a covariance-regularized least-squares problem for the update, giving
         this layer an equal share of the residual that remains.
       The update is written into the model temporarily so later layers see its effect.
    3. Restore the original weights.

    :param cache_template: optional path template formatted with
        (z_layer, clamp_norm_factor, case_id) for caching z vectors.
    :return: dict mapping weight name -> (adj_k, resid) on CPU; the weight update is
        adj_k @ resid.T (up to transposition to the weight's layout).
    """

    deltas = {}

    # Update target and print info
    requests = deepcopy(requests)
    for i, request in enumerate(requests):
        if request["target_new"]["str"][0] != " ":
            # Space required for correct tokenization
            requests[i]["target_new"]["str"] = " " + request["target_new"]["str"]
    for request in requests[:10]:
        print(
            f"MEMIT request sample: "
            f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
        )

    # Retrieve weights that user desires to change
    weights = {
        f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter(
            model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        )
        for layer in hparams.layers
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    # Compute z for final layer
    context_templates = get_context_templates(model, tok)
    z_layer = hparams.layers[-1]
    z_list = []

    for request in requests:
        # Retrieve k/v pair if already stored in cache
        cache_fname = (
            Path(
                str(cache_template).format(
                    z_layer, hparams.clamp_norm_factor, request["case_id"]
                )
            )
            if cache_template is not None
            else None
        )
        data_loaded = False
        if (
            cache_fname is not None  # Require cache template
            and cache_fname.exists()  # Cache file must exist
        ):
            try:
                data = np.load(cache_fname)
                z_list.append(torch.from_numpy(data["v_star"]).to("cuda"))
                data_loaded = True
            except Exception as e:
                print(f"Error reading cache file due to {e}. Recomputing...")

        # Compute k/v pair if not loaded from cache
        if not data_loaded:
            cur_z = compute_z(
                model,
                tok,
                request,
                hparams,
                z_layer,
                context_templates,
            )

            z_list.append(cur_z)

            if cache_fname is not None:
                cache_fname.parent.mkdir(exist_ok=True, parents=True)
                np.savez(
                    cache_fname,
                    **{
                        "v_star": cur_z.detach().cpu().numpy(),
                    },
                )
                print(f"Cached k/v pair at {cache_fname}")
    # One column per request.
    zs = torch.stack(z_list, dim=1)

    # Insert
    for i, layer in enumerate(hparams.layers):
        print(f"\n\nLAYER {layer}\n")

        # Get current model activations
        layer_ks = compute_ks(model, tok, requests, hparams, layer, context_templates).T
        print(f"Writing {layer_ks.size(1)} key/value pair(s) into layer {layer}")

        # Compute residual error
        cur_zs = get_module_input_output_at_words(
            model,
            tok,
            z_layer,
            context_templates=[request["prompt"] for request in requests],
            words=[request["subject"] for request in requests],
            module_template=hparams.layer_module_tmp,
            fact_token_strategy=hparams.fact_token,
        )[1].T
        targets = zs - cur_zs
        print("z error", torch.linalg.norm(targets, dim=0).mean())

        repeat_factor = (layer_ks.size(1) // targets.size(1))
        targets = targets.repeat_interleave(repeat_factor, dim=1)

        # Load covariance matrix
        force_recompute = False
        # force_recompute = layer != hparams.layers[0]
        cov = get_cov(
            model,
            tok,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples
            if not force_recompute
            else hparams.mom2_n_samples // 10,
            hparams.mom2_dtype,
            force_recompute=force_recompute,
        )

        # Compute update in double precision
        layer_ks, targets = (
            layer_ks.double(),
            targets.double(),
        )

        adj_k = torch.linalg.solve(
            hparams.mom2_update_weight * cov.double() + layer_ks @ layer_ks.T,
            layer_ks,
        )
        resid = targets / (len(hparams.layers) - i)  # Distribute residual across layers
        upd_matrix = resid @ adj_k.T

        # Adjust update matrix shape
        weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        upd_matrix = upd_matrix_match_shape(upd_matrix, weights[weight_name].shape)

        print("orig norm", torch.linalg.norm(weights[weight_name]))
        print("upd norm", torch.linalg.norm(upd_matrix))

        # Update model weights and record desired changes in `delta` variable
        with torch.no_grad():
            weights[weight_name][...] = weights_copy[weight_name] + upd_matrix.float()
            deltas[weight_name] = (
                adj_k.detach().cpu(),
                resid.detach().cpu(),
            )

        # Clear GPU memory: drop this layer's tensors so empty_cache() can release
        # them before the next layer allocates its own. (Tensor.cpu() returns a copy
        # and frees nothing, so the references must be deleted.)
        del cov, layer_ks, cur_zs, targets, adj_k, resid, upd_matrix
        torch.cuda.empty_cache()

    # Restore state of original model
    with torch.no_grad():
        for k, v in weights.items():
            v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas


def get_cov(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: int,
    mom2_dtype: str,
    inv: bool = False,
    force_recompute: bool = False,
) -> torch.Tensor:
    """
    Retrieve the second-moment (covariance) statistics for ``layer_name``, optionally inverted.

    Statistics come from layer_stats() and are cached on CPU in COV_CACHE, keyed by
    (model name, layer name). Each call returns a copy moved to "cuda".

    NOTE(review): the cache key ignores mom2_dataset / mom2_n_samples / mom2_dtype, so
    a call with different statistics settings reuses the cached matrix unless
    ``force_recompute`` is set.

    :param mom2_n_samples: sample size passed to layer_stats().
    :param inv: if True, return the matrix inverse instead of the matrix itself.
    :param force_recompute: recompute the statistics and overwrite the cache entry.
    """

    model_name = model.config._name_or_path.replace("/", "_")
    key = (model_name, layer_name)

    print(f"Retrieving covariance statistics for {model_name} @ {layer_name}.")
    if key not in COV_CACHE or force_recompute:
        stat = layer_stats(
            model,
            tok,
            layer_name,
            STATS_DIR,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
            force_recompute=force_recompute,
        )
        COV_CACHE[key] = stat.mom2.moment().float().to("cpu")

    return (
        torch.inverse(COV_CACHE[key].to("cuda")) if inv else COV_CACHE[key].to("cuda")
    )


def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Orient a MEMIT update matrix so it can be added to a weight of ``shape``.

    GPT-2 and GPT-J store some weights transposed. The matrix is returned as-is if it
    fits, its transpose if that fits, and a ValueError is raised otherwise.

    NOTE: this cell redefines upd_matrix_match_shape; compared with the ROME cell's
    version only the error message differs.
    """
    if matrix.shape != shape and matrix.T.shape != shape:
        raise ValueError(
            "Update matrix computed by MEMIT does not match original weight shape. "
            "Check for bugs in the code?"
        )
    return matrix if matrix.shape == shape else matrix.T


def get_context_templates(model, tok):
    """
    Return (and lazily build) MEMIT's grouped context templates.

    On the first call, generate() produces n_gen // 5 continuations (max_out_len=10)
    for each of five seed words. Braces are replaced with spaces so the text is safe
    for str.format, and each continuation becomes "<text>. {}". The result,
    [["{}"], [<generated templates>]], is stored in the module-level
    CONTEXT_TEMPLATES_CACHE and returned unchanged on later calls.

    NOTE: this redefines the earlier get_context_templates(model, tok, length_params)
    and shares CONTEXT_TEMPLATES_CACHE with it. Once this cell has run, callers of the
    three-argument form get a TypeError.
    """
    global CONTEXT_TEMPLATES_CACHE

    if CONTEXT_TEMPLATES_CACHE is not None:
        return CONTEXT_TEMPLATES_CACHE

    seed_prompts = ["The", "Therefore", "Because", "I", "You"]
    templates = [["{}"]]
    for length, n_gen in [(10, 5)]:  # Be careful about changing this.
        samples = generate(
            model,
            tok,
            seed_prompts,
            n_gen_per_prompt=n_gen // 5,
            max_out_len=length,
        )
        templates.append(
            [text.replace("{", " ").replace("}", " ") + ". {}" for text in samples]
        )
    CONTEXT_TEMPLATES_CACHE = templates
    print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE

# Main Process

### Getting the model
Here we'll use gpt2-xl as our model. Do not change your model!

In [24]:
# Hugging Face model id passed to from_pretrained(); the hyperparameters and module
# paths in this notebook (e.g. transformer.h.{}, v_loss_layer=47) are set for GPT-2 XL.
MODEL_NAME = "gpt2-xl"

In [25]:
# Load the model (moved to the GPU) and its tokenizer in one tuple assignment.
model, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cuda"),
    AutoTokenizer.from_pretrained(MODEL_NAME),
)
# Batched tokenization with padding=True (e.g. in compute_z) needs a pad token;
# the end-of-sequence token is reused for it.
tok.pad_token = tok.eos_token
print(model)

config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

2025-05-29 16:04:56.126980: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748534696.310621      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748534696.363232      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-47): 48 x GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=4800, nx=1600)
          (c_proj): Conv1D(nf=1600, nx=1600)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=6400, nx=1600)
          (c_proj): Conv1D(nf=1600, nx=6400)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1600, out_features=50257, bias=False)
)


### Single Editing

Below is the editing example. ***Change the example to prevent violating the regulation!***
1. ***requests***: the knowledge you want to edit
  * **prompt**: the prompt used to edit the knowledge. Note that you need to use {} to specify where the subject is
  * **subject**: the subject of the knowledge you want to edit.
  * **target_new**: the new target you want the model to output afterward.
  * **target_true**: the true target. Please make sure that the model correctly outputs the true target before editing.
2. ***generation_prompts***: a list containing the original prompt, a paraphrase prompt, a neighborhood prompt, a reversion prompt and a portability prompt.
  * **original prompt**: simply replace “{}” in your prompt with your subject.
  * **paraphrase prompt**: a sentence that has the same subject and target as the original prompt.
  * **neighborhood prompt**: a sentence close to the original prompt, but without the same subject or target.
  * **reversion prompt**: a sentence in which the target and subject are reversed. Use target_new as your new subject.
  * **portability prompt**: a sentence that is logically related to the original prompt.

In [26]:
# ###### TODO: Use your knowledge. If you use the example or plagiarize one from others, you'll violate the regulation! ######
# # requests = [
# #     {
# #         "prompt": "{} was the founder of",
# #         "subject": "Steve Jobs",
# #         "target_new": {
# #             "str": "Microsoft"
# #         },
# #         "target_true": {
# #             "str": "Apple"
# #         },
# #     }
# # ]

# # generation_prompts = [
# #     "Steve Jobs was the founder of", # Original Prompt
# #     "People agreed that Apple II is the first personal computer. After releasing Apple II, Steve Jobs founded", # Paraphrase Prompt
# #     "Mark Zuckerberg, the founder of", # Neighborhood Prompt
# #     "Microsoft is founded by", # Reversion Prompt
# #     "After Y2K, the company Steve Jobs founded released the operating system, " # Portability Prompt
# # ]
# requests = [
#     {
#         "prompt": "{} is the capital of",
#         "subject": "Taipei",
#         "target_new": {
#             "str": "Japan"
#         },
#         "target_true": {
#             "str": "Taiwan"
#         },
#     }
# ]

# generation_prompts = [
#     "Taipei is the capital of", # Original Prompt
#     "Like Washington DC in the catipal of United States, Taipei is the capital of", # Paraphrase Prompt
#     "Seoul, the capital of", # Neighborhood Prompt
#     "The capital of Japan is", # Reversion Prompt
#     "Taipei is a beautiful city, located in the north of the country. That country is" # Portability Prompt
# ]

* For those who want to change the method from FT to ROME, after filling the blank in `apply_rome_to_model()`, replace the code:  
`RewritingParamsClass, apply_method, hparam = FTHyperParams, apply_ft_to_model, ft_hparam`  
with:  
`RewritingParamsClass, apply_method, hparam = ROMEHyperParams, apply_rome_to_model, rome_hparam`
* To switch to a different method, see the ROME and MEMIT GitHub repositories.


In [27]:
# try:
#     with torch.no_grad():
#         for k, v in orig_weights.items():
#             get_parameter(model, k)[...] = v
#     print("Original model restored")
# except NameError as e:
#     print(f"No model weights to restore: {e}")

# set_requires_grad(True, model)

# ###### TODO: Change the method :) ######
# # RewritingParamsClass, apply_method, hparam = FTHyperParams, apply_ft_to_model, ft_hparam
# RewritingParamsClass, apply_method, hparam = ROMEHyperParams, apply_rome_to_model, rome_hparam

# print_loud(f"Retrieving hyperparameters")
# hparams = RewritingParamsClass(**hparam)
# print(hparams)

In [28]:
# print_loud("Generating pre-update text")
# pre_update_text = generate(model, tok, generation_prompts, max_out_len=50, first_do_sample = False)
# print_loud(f"Model Editing...")
# model_new, orig_weights = apply_method(
#     model, tok, requests, hparams, return_orig_weights=True
# )
# print_loud("Generating post-update text")
# post_update_text = generate(model_new, tok, generation_prompts, max_out_len=50, first_do_sample = False)

# print_loud("Summarizing differences")
# for i, (prompt, pre, post) in enumerate(
#     zip(generation_prompts, pre_update_text, post_update_text)
# ):
#     if i > 0:
#         print("".join(["-" for _ in range(10)]))

#     prompt_str = "[Prompt]:"
#     pre_str = f"[Pre-Edit]:"
#     post_str = f"[Post-Edit]:"
#     pad_to = 1 + max(len(prompt_str), len(pre_str), len(post_str))

#     for s, t in zip([prompt_str, post_str, pre_str], [prompt, post, pre]):
#         print(s.ljust(pad_to), t)

### Multiple Editing

The cell below loads and processes the dataset. To change how much data is used, replace:  
`requests = json.load(file)[0:10]`  
with:  
`requests = json.load(file)`

In [29]:
import json

# Load the editing requests. Slice the loaded list (e.g. `[0:10]`) to edit fewer facts.
with open("/kaggle/working/HW8_data.json", "r") as file:
    ###### TODO: Change the range of your code ######
    # requests = json.load(file)[0:10]
    requests = json.load(file)

# Four evaluation buckets, always indexed the same way:
#   0 = Efficacy (the request prompt itself), 1 = Paraphrase,
#   2 = Neighborhood, 3 = Portability.
# For each bucket: the prompts to generate from, the original answers,
# and the answers expected after the edit.
generation_prompts = [[] for _ in range(4)]
ans_new = [[] for _ in range(4)]
ans_true = [[] for _ in range(4)]
for r in requests:
    generation_prompts[0].append(r["prompt"].replace("{}", r["subject"]))
    true_str = r["target_true"]["str"]
    ans_true[0].append(true_str)
    new_str = r["target_new"]["str"]
    ans_new[0].append(new_str)

    for p in r["paraphrase_prompts"]:
        generation_prompts[1].append(p["prompt"])
        ans_true[1].append(true_str)
        ans_new[1].append(new_str)

    # Neighborhood prompts keep the original answer even after editing,
    # so the "new" expectation is deliberately the true target.
    for n in r["neighborhood_prompts"]:
        generation_prompts[2].append(n["prompt"])
        ans_true[2].append(true_str)
        ans_new[2].append(true_str)

    # Portability prompts carry their own per-prompt targets.
    for t in r["portable_prompts"]:
        generation_prompts[3].append(t["prompt"])
        ans_true[3].append(t["portable_target_true"])
        ans_new[3].append(t["portable_target_new"])
print(len(requests))

80


* To switch the method from FT to ROME, first fill in the blank in `apply_rome_to_model()`, then replace the code:  
`RewritingParamsClass, apply_method, hparam = FTHyperParams, apply_ft_to_model, ft_hparam`  
with:  
`RewritingParamsClass, apply_method, hparam = ROMEHyperParams, apply_rome_to_model, rome_hparam`
* To use a different editing method, see the ROME and MEMIT GitHub repositories. The cell below is currently set to MEMIT via `MEMITHyperParams, apply_memit_to_model, memit_hparam`.


In [30]:
# Roll back any earlier edit so this run starts from the unedited weights.
# `orig_weights` only exists after an editing cell has called
# `apply_method(..., return_orig_weights=True)`. Check for it explicitly rather
# than catching NameError around the whole restore. That pattern would also
# swallow a NameError raised inside the loop (e.g. a missing `get_parameter`),
# report "nothing to restore", and silently leave the edited weights in place.
if "orig_weights" in globals():
    with torch.no_grad():
        for k, v in orig_weights.items():
            # Overwrite the parameter tensor's contents in place with the saved copy.
            get_parameter(model, k)[...] = v
    print("Original model restored")
else:
    print("No model weights to restore: name 'orig_weights' is not defined")

set_requires_grad(True, model)

###### TODO: Change the method :) ######
# RewritingParamsClass, apply_method, hparam = FTHyperParams, apply_ft_to_model, ft_hparam
# RewritingParamsClass, apply_method, hparam = ROMEHyperParams, apply_rome_to_model, rome_hparam
RewritingParamsClass, apply_method, hparam = MEMITHyperParams, apply_memit_to_model, memit_hparam

# Build the hyperparameter object for the chosen method from its plain dict.
print_loud("Retrieving hyperparameters")
hparams = RewritingParamsClass(**hparam)
print(hparams)

No model weights to restore: name 'orig_weights' is not defined

################################
#                              #
#  Retrieving hyperparameters  #
#                              #
################################
MEMITHyperParams(layers=[13, 14, 15, 16, 17], layer_selection='all', fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=20000, rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')


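Here we test the model before editing. Every category is scored twice. **(pre)** is the score of the generated text against the original answers (`ans_true`). **(post)** is the score against the answers expected after the edit (`ans_new`). For Neighborhood prompts the expected post-edit answer is the original answer, so the two scores coincide.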

In [31]:
print_loud("Generating pre-update text")
type_name = ["Efficacy", "Paraphrase", "Neighborhood", "Portability"]
pre_update_text = [[], [], [], []]
# For each bucket, generate text with the unedited `model`. Then score it
# against the original answers (printed as "pre") and against the
# post-edit answers (printed as "post").
for i, metric_name in enumerate(type_name):
    prompts_i = generation_prompts[i]
    texts_i = generate(model, tok, prompts_i, max_out_len=50, first_do_sample=False)
    pre_update_text[i] = texts_i
    true_score = scoring(prompts_i, texts_i, ans_true[i])
    print(f"{metric_name} score (pre): " + str(true_score))
    new_score = scoring(prompts_i, texts_i, ans_new[i])
    print(f"{metric_name} score (post): " + str(new_score))


################################
#                              #
#  Generating pre-update text  #
#                              #
################################
Efficacy score (pre): 1.0
Efficacy score (post): 0.0
Paraphrase score (pre): 0.95
Paraphrase score (post): 0.0
Neighborhood score (pre): 1.0
Neighborhood score (post): 1.0
Portability score (pre): 1.0
Portability score (post): 0.0


In [32]:
print_loud("Model Editing...")
# Apply every request in `requests` with a single call. `orig_weights` holds the
# pre-edit parameters so that the restore code in the setup cell can undo it.
# NOTE(review): that restore writes back into `model`, which suggests the edit
# is applied in place (`model_new` may be the same object as `model`) — confirm.
model_new, orig_weights = apply_method(
    model,
    tok,
    requests,
    hparams,
    return_orig_weights=True,
)


######################
#                    #
#  Model Editing...  #
#                    #
######################
MEMIT request sample: [Tapio Kantanen is a citizen of] -> [ Bulgaria]
MEMIT request sample: [Ipsos MORI's headquarters are in] -> [ Oslo]
MEMIT request sample: [The headquarters of Northeastern University is in] -> [ Dublin]
MEMIT request sample: [The mother tongue of Alain Robbe-Grillet is] -> [ Dutch]
MEMIT request sample: [The native language of Freek de Jonge is] -> [ French]
MEMIT request sample: [University of Oklahoma, whose headquarters are in] -> [ Greenwich]
MEMIT request sample: [The headquarter of University of Kentucky is located in] -> [ Hamburg]
MEMIT request sample: [Emmanuel Macron is a native speaker of] -> [ Dutch]
MEMIT request sample: [Chrome OS, created by] -> [ IBM]
MEMIT request sample: [Jacques Doriot is a native speaker of] -> [ Russian]
Cached context templates [['{}'], ['The "I\'m not a terrorist" defense is. {}', 'Therefore, if I am not mistak

100%|██████████| 156M/156M [00:01<00:00, 88.5MB/s]


Successfully downloaded.
Loading cached data/stats/gpt2-xl/wikipedia_stats/transformer.h.13.mlp.c_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(112.7657, device='cuda:0')
upd norm tensor(5.6467, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 14

Writing 80 key/value pair(s) into layer 14
z error tensor(86.2403, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for gpt2-xl @ transformer.h.14.mlp.c_proj.
Attempting to download gpt2-xl/wikipedia_stats/transformer.h.14.mlp.c_proj_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/gpt2-xl/wikipedia_stats/transformer.h.14.mlp.c_proj_float32_mom2_100000.npz.


100%|██████████| 156M/156M [00:00<00:00, 198MB/s]


Successfully downloaded.
Loading cached data/stats/gpt2-xl/wikipedia_stats/transformer.h.14.mlp.c_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(113.2846, device='cuda:0')
upd norm tensor(6.3603, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 15

Writing 80 key/value pair(s) into layer 15
z error tensor(79.9589, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for gpt2-xl @ transformer.h.15.mlp.c_proj.
Attempting to download gpt2-xl/wikipedia_stats/transformer.h.15.mlp.c_proj_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/gpt2-xl/wikipedia_stats/transformer.h.15.mlp.c_proj_float32_mom2_100000.npz.


100%|██████████| 156M/156M [00:00<00:00, 209MB/s]


Successfully downloaded.
Loading cached data/stats/gpt2-xl/wikipedia_stats/transformer.h.15.mlp.c_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(113.0412, device='cuda:0')
upd norm tensor(7.7958, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 16

Writing 80 key/value pair(s) into layer 16
z error tensor(71.3279, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for gpt2-xl @ transformer.h.16.mlp.c_proj.
Attempting to download gpt2-xl/wikipedia_stats/transformer.h.16.mlp.c_proj_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/gpt2-xl/wikipedia_stats/transformer.h.16.mlp.c_proj_float32_mom2_100000.npz.


100%|██████████| 156M/156M [00:00<00:00, 199MB/s]


Successfully downloaded.
Loading cached data/stats/gpt2-xl/wikipedia_stats/transformer.h.16.mlp.c_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(113.9795, device='cuda:0')
upd norm tensor(10.1281, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 17

Writing 80 key/value pair(s) into layer 17
z error tensor(60.0564, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for gpt2-xl @ transformer.h.17.mlp.c_proj.
Attempting to download gpt2-xl/wikipedia_stats/transformer.h.17.mlp.c_proj_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/gpt2-xl/wikipedia_stats/transformer.h.17.mlp.c_proj_float32_mom2_100000.npz.


100%|██████████| 156M/156M [00:00<00:00, 190MB/s]


Successfully downloaded.
Loading cached data/stats/gpt2-xl/wikipedia_stats/transformer.h.17.mlp.c_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(117.1293, device='cuda:0')
upd norm tensor(15.9701, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)
Deltas successfully computed for ['transformer.h.13.mlp.c_proj.weight', 'transformer.h.14.mlp.c_proj.weight', 'transformer.h.15.mlp.c_proj.weight', 'transformer.h.16.mlp.c_proj.weight', 'transformer.h.17.mlp.c_proj.weight']
New weights successfully inserted into ['transformer.h.13.mlp.c_proj.weight', 'transformer.h.14.mlp.c_proj.weight', 'transformer.h.15.mlp.c_proj.weight', 'transformer.h.16.mlp.c_proj.weight', 'transformer.h.17.mlp.c_proj.weight']


In [33]:
print_loud("Generating post-update text")
type_name = ["Efficacy", "Paraphrase", "Neighborhood", "Portability"]
post_update_text = [[], [], [], []]
# For each bucket, generate text with the edited `model_new`. Then score it
# against the original answers (printed as "pre") and against the
# post-edit answers (printed as "post").
for i, metric_name in enumerate(type_name):
    prompts_i = generation_prompts[i]
    texts_i = generate(model_new, tok, prompts_i, max_out_len=50, first_do_sample=False)
    post_update_text[i] = texts_i
    true_score = scoring(prompts_i, texts_i, ans_true[i])
    print(f"{metric_name} score (pre): " + str(true_score))
    new_score = scoring(prompts_i, texts_i, ans_new[i])
    print(f"{metric_name} score (post): " + str(new_score))


#################################
#                               #
#  Generating post-update text  #
#                               #
#################################
Efficacy score (pre): 0.1
Efficacy score (post): 0.8625
Paraphrase score (pre): 0.2375
Paraphrase score (post): 0.675
Neighborhood score (pre): 0.9625
Neighborhood score (post): 0.9625
Portability score (pre): 0.275
Portability score (post): 0.5375
