diff --git a/easyeditor/dataset/zsre.py b/easyeditor/dataset/zsre.py
index eacd64c2..3ffe57ca 100644
--- a/easyeditor/dataset/zsre.py
+++ b/easyeditor/dataset/zsre.py
@@ -221,6 +221,10 @@ def collate_gpt_fn(self, batch):
         rephrase = [rephrase_ + ' ' + trg_ for rephrase_, trg_ in zip(rephrase, trg)]
         loc = [loc_ + ' ' + loc_ans_ for loc_, loc_ans_ in zip(loc, loc_ans)]
 
+        if 'gpt' in self.config.tokenizer_class.lower():
+            trg = [' ' + t for t in trg]
+            loc_ans = [' ' + t for t in loc_ans]
+
         batches = {
             f"{k1}_{k2}": v2
             for k1, v1 in {
diff --git a/easyeditor/models/__init__.py b/easyeditor/models/__init__.py
index 6024e118..036c09ae 100644
--- a/easyeditor/models/__init__.py
+++ b/easyeditor/models/__init__.py
@@ -8,3 +8,5 @@
 from .pmet import *
 from .melo import *
 from .grace import *
+from .malmen import *
+
diff --git a/easyeditor/models/malmen/__init__.py b/easyeditor/models/malmen/__init__.py
new file mode 100644
index 00000000..2da068ec
--- /dev/null
+++ b/easyeditor/models/malmen/__init__.py
@@ -0,0 +1,2 @@
+from .malmen_hparams import MALMENHyperParams
+from .malmen_main import MalmenRewriteExecutor
diff --git a/easyeditor/models/malmen/malmen_hparams.py b/easyeditor/models/malmen/malmen_hparams.py
new file mode 100644
index 00000000..a278f2f0
--- /dev/null
+++ b/easyeditor/models/malmen/malmen_hparams.py
@@ -0,0 +1,76 @@
+from dataclasses import dataclass
+from ...util.hparams import HyperParams
+from typing import Optional, Any, List
+import yaml
+
+
+@dataclass
+class MALMENHyperParams(HyperParams):
+    alg_name: str
+
+    # Model
+    model_name: str
+    model_class: str
+    tokenizer_class: str
+    tokenizer_name: str
+    inner_params: List[str]
+    device: int
+    archive: Any
+
+    # Method
+    alg: str
+    debug: bool
+    dropout: float
+    train_base: bool
+    no_grad_layers: Any
+    rank: int
+    n_edits: int
+    n_blocks: int
+    lr: float
+    meta_lr: float
+    loc_coef: float
+    max_grad_norm: float
+    token: str
+
+    # Output
+    results_dir: str
+
+    # Train
+    batch_size: int
+    editor_batch_size: int
+    silent: bool
+    log_interval: int
+    eval_log_interval: int
+    final_eval: bool
+    val_interval: int
+    early_stop_patience: int
+    early_stop_key: str
+    eval_only: bool
+    save: bool
+
+    val_batch_size: Optional[int]
+    val_steps: int
+
+    max_length: int = 40
+
+    model_save_pt: Optional[int] = 5000
+    half: Optional[bool] = False
+    model_parallel: bool = False
+    max_epochs: Optional[int] = None
+    max_iters: Optional[int] = None
+
+    @classmethod
+    def from_hparams(cls, hparams_name_or_path: str):
+
+        if '.yaml' not in hparams_name_or_path:
+            hparams_name_or_path = hparams_name_or_path + '.yaml'
+
+        with open(hparams_name_or_path, "r") as stream:
+            config = yaml.safe_load(stream)
+            config = super().construct_float_from_scientific_notation(config)
+
+        assert config and config['alg'] == 'MALMEN', \
+            f'MALMENHyperParams cannot load from {hparams_name_or_path}: alg_name is {config["alg"]}'
+        config['val_batch_size'] = config['batch_size']
+        return cls(**config)
+
diff --git a/easyeditor/models/malmen/malmen_main.py b/easyeditor/models/malmen/malmen_main.py
new file mode 100644
index 00000000..68fdedcf
--- /dev/null
+++ b/easyeditor/models/malmen/malmen_main.py
@@ -0,0 +1,124 @@
+import os
+from copy import deepcopy
+from typing import Dict, List, Any, Tuple
+
+import hydra
+import torch
+from collections import deque
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from ...util.globals import *
+
+from ...trainer import MALMEN
+from .malmen_hparams import MALMENHyperParams
+
+class MalmenRewriteExecutor:
+    def __init__(self):
+        self.is_init = False
+
+    def init_model(self, model, tok, params: MALMENHyperParams):
+
+        assert params.archive is not None, 'Trained MALMEN weights are needed, please specify `archive` in the hparams.'
+        # Customize the model and tokenizer
+        self.model = model
+        self.tokenizer = tok
+        # add_padding(self.tokenizer, self.model)
+
+        # Load the trained MALMEN hypernetwork
+        self.alg = MALMEN(self.model, params, lambda: deepcopy(self.model))
+        d = torch.load(params.archive, map_location=f'cuda:{params.device}')
+        self.alg.load_state_dict(d["model"])
+        if params.model_parallel:
+            self.alg.net.to(deque(self.alg.model.parameters(), maxlen=1)[0].device)
+        else:
+            self.alg.to(torch.device(f'cuda:{params.device}'))
+        self.is_init = True
+
+    def reset_model(self):
+        self.is_init = False
+        del self.model, self.tokenizer, self.alg
+
+    def apply_to_model(
+        self,
+        model: AutoModelForCausalLM,
+        tok: AutoTokenizer,
+        requests: List[Dict],
+        hparams: MALMENHyperParams,
+        copy=False,
+        return_orig_weights=False,
+        keep_original_weight=False,
+        **kwargs
+    ):
+        """
+        Given a request, for example
+        {'prompt': '{} has the position of',
+         'subject': 'Charles Herman Helmsing',
+         'relation_id': 'P39',
+         'target_new': {'str': 'President', 'id': 'Q11696'},
+         'target_true': {'str': 'bishop', 'id': 'Q29182'}}
+        returns the edited model together with a copy of the original
+        weights that MALMEN changed (if requested).
+        """
+
+        if not self.is_init:
+            self.init_model(model, tok, hparams)
+
+        weights_copy = {}
+        model = deepcopy(self.model) if copy else self.model
+        assert len(requests) >= hparams.n_edits, "The number of requests must be greater than or equal to the value of n_edits."
+        # Define i/o
+        requests = requests[:hparams.n_edits]
+        batches = []
+        for i in range(hparams.n_edits // hparams.batch_size):
+            batch = requests[i * hparams.batch_size : (i + 1) * hparams.batch_size]
+            targets = [
+                (" " if request["target_new"][0] != " " else "")
+                + request["target_new"]
+                for request in batch
+            ]
+            sentences = [
+                request["prompt"] + targets[i]
+                for i, request in enumerate(batch)
+            ]
+
+            # Tokenize
+            sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to(
+                f"cuda:{hparams.device}"
+            )
+            target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to(
+                f"cuda:{hparams.device}"
+            )
+
+            # Define labels
+            label_tok = deepcopy(sent_tok["input_ids"])
+            for i in range(label_tok.size(0)):
+                target_len = target_tok["attention_mask"][i].sum()
+                padding_len = (
+                    sent_tok["input_ids"].size(1) - sent_tok["attention_mask"][i].sum()
+                )
+                label_tok[i][: -target_len - padding_len] = -100
+                label_tok[i][label_tok[i] == self.tokenizer.pad_token_id] = -100
+
+            edit_inner = dict(
+                input_ids=sent_tok["input_ids"],
+                attention_mask=sent_tok["attention_mask"],
+                labels=target_tok['input_ids'],
+            )
+
+            batches.append(edit_inner)
+        # Run MALMEN
+        module_kv_map = self.alg.cache(batches)
+        param_shifts = self.alg.predict_param_shifts(module_kv_map)
+        with torch.no_grad():
+            for n, p in self.model.named_parameters():
+                if n in hparams.inner_params:
+                    if return_orig_weights and n not in weights_copy:
+                        weights_copy[n] = p.detach().clone()
+        self.alg.edit_model(param_shifts, False)
+
+        if not keep_original_weight:
+            weights_copy = {}
+
+        return self.alg.model, weights_copy
\ No newline at end of file
diff --git a/easyeditor/trainer/BaseTrainer.py b/easyeditor/trainer/BaseTrainer.py
index e9818c7e..fdea46ba 100644
--- a/easyeditor/trainer/BaseTrainer.py
+++ b/easyeditor/trainer/BaseTrainer.py
@@ -77,7 +77,7 @@ def __init__(self, config, train_set: Dataset, val_set: Dataset):
             # Eval once and quit
             self.config.max_iters = 0
 
-        if not self.config.eval_only:
+        if not self.config.eval_only and self.config.alg != 'MALMEN':
             self.OptimizerClass = getattr(torch.optim, config.opt)
             LOG.info(f"Building optimizer {self.OptimizerClass} with lr {config.lr}")
             self.opt = self.OptimizerClass(self.model.outer_parameters(), lr=config.lr)
@@ -87,7 +87,10 @@ def __init__(self, config, train_set: Dataset, val_set: Dataset):
             self.model.load_state_dict(archive["model"])
             del archive["model"]
             if not self.config.eval_only:
-                self.opt.load_state_dict(archive["opt"])
+                if self.config.alg == 'MALMEN':
+                    self.model.opt.load_state_dict(archive["opt"])
+                else:
+                    self.opt.load_state_dict(archive["opt"])
             del archive["opt"]
 
         self.archive = (
@@ -114,7 +117,7 @@ def save_state(self, stats):
 
         obj = {
             "model": self.model.state_dict(),
-            "opt": self.opt.state_dict(),
+            "opt": self.opt.state_dict() if self.config.alg != 'MALMEN' else self.model.opt.state_dict(),
             "lr_opt": self.lr_opt.state_dict() if self.lr_opt is not None else None,
             "val_stats": stats,
             "start_time": self.start_time,
@@ -156,11 +159,21 @@ def run(self):
                 self.config.max_iters = min(self.config.max_iters, self.config.max_epochs * len(self.train_set))
             else:
                 self.config.max_iters = self.config.max_epochs * len(self.train_set)
+            if self.config.alg == 'MALMEN':
+                self.config.max_iters = math.ceil(self.config.max_iters / self.config.batch_size)
             LOG.info(f'MAX EPOCH: {self.config.max_epochs}, set max iters to {self.config.max_iters}')
-
+        if self.config.alg == 'MALMEN':
+            n_edits_step = math.ceil(self.config.n_edits / self.config.batch_size)
+            if self.config.log_interval % n_edits_step:
+                self.config.log_interval = (self.config.log_interval // n_edits_step) * n_edits_step if self.config.log_interval >= n_edits_step else n_edits_step
+            if self.config.val_interval % n_edits_step:
+                self.config.val_interval = (self.config.val_interval // n_edits_step) * n_edits_step if self.config.val_interval >= n_edits_step else n_edits_step
         self.epoches = round(float(self.config.max_iters) / (len(self.train_set) / self.config.batch_size))
+        if self.epoches < 1:
+            self.epoches = 1
         self.global_iter = 0
         should_stop = False
+        n_edits_batch = []
         for epoch in range(self.epoches):
             if should_stop:
                 break
@@ -170,15 +183,26 @@
                     should_stop = True
                     break
                 if not self.config.eval_only:
-                    train_info = self.train_step(batch)
-                    averager.add(train_info)
+                    if self.config.alg == 'MALMEN':
+                        n_edits_batch.append(batch)
+                        if len(n_edits_batch) == math.ceil(self.config.n_edits / self.config.batch_size):
+                            # Accumulate n_edits requests before one MALMEN meta-training step.
+                            train_info = self.model.train(n_edits_batch)
+                            averager.add(train_info)
+                            n_edits_batch = []
+                    else:
+                        train_info = self.train_step(batch)
+                        averager.add(train_info)
 
                     if self.global_iter % self.config.log_interval == 0:
                         avg_info = averager.average()
                         averager.reset()
                         self.echo(self.global_iter, avg_info)
                 if self.global_iter % self.config.val_interval == 0:
-                    val_info = self.validate(steps=self.config.val_steps)
+                    if self.config.alg == 'MALMEN':
+                        val_info = self.model.valid(config=self.config, loader=self.val_loader, val_set=self.val_set, steps=self.config.val_steps)
+                    else:
+                        val_info = self.validate(steps=self.config.val_steps)
                     self.echo(self.global_iter, val_info)
                     if True:
                         self.save_state(val_info)  # New best
@@ -213,7 +237,10 @@
                 self.model.to(self.config.device)
 
         val_steps = self.config.val_steps if self.config.debug else None
-        val_info = self.validate(log=True, steps=val_steps)
+        if self.config.alg == 'MALMEN':
+            val_info = self.model.valid(log=True, steps=val_steps, config=self.config, loader=self.val_loader, val_set=self.val_set)
+        else:
+            val_info = self.validate(log=True, steps=val_steps)
         self.echo(self.global_iter, val_info, pretty=True)
 
         if self.config.results_dir is not None:
diff --git a/easyeditor/trainer/algs/MALMEN.py b/easyeditor/trainer/algs/MALMEN.py
new file mode 100644
index 00000000..58644a13
--- /dev/null
+++ b/easyeditor/trainer/algs/MALMEN.py
@@ -0,0 +1,359 @@
+import time
+from typing import Dict, List
+# from omegaconf import DictConfig
+from torch.nn.utils import clip_grad_norm_
+from collections import Counter
+import numpy as np
+import logging
+from .editable_model import EditableModel
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from .malmen.nets import MALMENNet
+import math
+from tqdm import tqdm
+# import wandb
+
+from .malmen.util import (
+    get_module,
+    get_shape,
+    TracerDict,
+    cross_entropy,
+    kl_div,
+    succ_ratios
+)
+
+from ..utils import (
+    EarlyStopper,
+    RunningStatAverager,
+    _logits,
+    formatted_timestamp,
+    safe_backward,
+    time_delta_seconds,
+)
+
+LOG = logging.getLogger(__name__)
+
+
+class MALMEN(EditableModel):
+
+    def __init__(
+        self, model: nn.Module, config, model_constructor
+    ):
+        super().__init__(model, config, model_constructor)
+
+        # Decoder-only LMs predict token t+1 at position t, so logits must be
+        # shifted one position relative to the labels.
+        self.shift = any(
+            arch in config.model_name.lower()
+            for arch in ('gpt', 'llama', 'internlm', 'chatglm', 'qwen', 'mistral')
+        )
+
+        if not str(self.config.device).startswith('cuda'):
+            self.config.device = f'cuda:{self.config.device}'
+
+        if config.half:
+            self.model.bfloat16()
+
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        for i in range(len(config.inner_params)):
+            if config.inner_params[i].endswith(".weight"):
+                config.inner_params[i] = config.inner_params[i].replace(".weight", "")
+        self.config.inner_params = config.inner_params
+
+        for module_name in config.inner_params:
+            module = get_module(self.model, module_name)
+            module.weight.requires_grad = True
+
+        shape_counter = Counter()
+        self.name2idx = {}
+        for module_name in config.inner_params:
+            shape = get_shape(get_module(model, module_name))
+            self.name2idx[module_name] = shape_counter[shape]
+            shape_counter[shape] += 1
+
+        self.net = nn.ModuleDict({
+            str(k): MALMENNet(
+                *k,
+                config.rank,
+                config.n_blocks,
+                v,
+                config.lr
+            )
+            for k, v in shape_counter.items()
+        }).to(config.device)
+
+        self.opt = torch.optim.Adam(
+            self.net.parameters(),
+            config.meta_lr
+        )
+
+    def edit_model(
+        self,
+        param_shifts: Dict[str, torch.FloatTensor],
+        is_reverse: bool
+    ):
+
+        for module_name, param_shift in param_shifts.items():
+            module = get_module(self.model, module_name)
+            if isinstance(module, nn.Linear):
+                param_shift = param_shift.T
+            if is_reverse:
+                param_shift = - param_shift
+            module.weight.data += param_shift.to(module.weight.data.dtype)
+
+    def train(self, batch):
+        start = time.time()
+
+        batch_dv = {}
+
+        for item_dict in batch:
+            for key, value in item_dict.items():
+                if key not in batch_dv:
+                    batch_dv[key] = []
+                batch_dv[key].append(value)
+
+        module_kv_map = self.cache(batch_dv["edit_inner"])
+        param_shifts = self.predict_param_shifts(module_kv_map)
+        self.model.zero_grad()
+
+        # gen_loss
+        self.edit_model(param_shifts, False)
+        edit_time = time.time() - start
+
+        gen_losses = []
+        for t in batch_dv["edit_rephrase"]:
+            logits = self.model(input_ids=t['input_ids'], attention_mask=t['attention_mask'])["logits"]
+            loss = cross_entropy(logits, t["labels"], self.shift)
+            loss.backward()
+            gen_losses += [loss.item()]
+        self.edit_model(param_shifts, True)
+
+        # loc_loss
+        loc_losses = []
+        for t in batch_dv["loc"]:
+            with torch.no_grad():
+                refer_logits = self.model(input_ids=t['input_ids'], attention_mask=t['attention_mask'])["logits"]
+
+            self.edit_model(param_shifts, False)
+            logits = self.model(input_ids=t['input_ids'], attention_mask=t['attention_mask'])["logits"]
+
+            loss = kl_div(
+                refer_logits,
+                logits,
+                t["labels"],
+                self.shift
+            )
+
+            (self.config.loc_coef * loss).backward()
+            self.edit_model(param_shifts, True)
+            loc_losses += [loss.item()]
+
+        self.update_hypernet(param_shifts, module_kv_map)
+
+        info_dict = {}
+        info_dict["gen_loss"] = np.mean(gen_losses)
+        info_dict["loc_loss"] = np.mean(loc_losses)
+        info_dict["time/edit"] = edit_time
+
+        # LOG.info({
+        #     "gen_loss": gen_losses,
+        #     "loc_loss": loc_losses
+        # })
+        return info_dict
+
+    def cache(self, batch) -> Dict[int, Dict[int, Dict[str, torch.Tensor]]]:
+        module_kv_map = {}
+        for idx, t in enumerate(batch):
+            with TracerDict(
+                self.model,
+                self.config,
+                t
+            ) as tr:
+                logits = self.model(input_ids=t['input_ids'], attention_mask=t['attention_mask'])["logits"]
+                cross_entropy(logits, t["labels"], self.shift).backward()
+            for module_idx, module_name in enumerate(self.config.inner_params):
+                shape = get_shape(get_module(self.model, module_name))
+                keys = tr[module_name].keys.to(torch.float32).to(self.config.device)
+                values_grad = tr[module_name].values_grad.to(torch.float32).to(self.config.device)
+                self.net[str(shape)].normalizer.update(torch.cat((keys, values_grad), -1))
+                module_kv_map.setdefault(module_idx, {}).update({idx: {'keys': keys, 'values_grad': values_grad}})
+        return module_kv_map
+
+    def predict_param_shifts(self, module_kv_map) -> Dict[str, torch.FloatTensor]:
+
+        param_shifts = {}
+        for module_idx, module_name in enumerate(self.config.inner_params):
+
+            shape = get_shape(get_module(self.model, module_name))
+            net = self.net[str(shape)]
+            layer_idx = torch.LongTensor([self.name2idx[module_name]]).to(self.config.device)
+            keys = torch.cat([
+                module_kv_map[module_idx][idx]["keys"]
+                for idx in range(len(module_kv_map[module_idx]))
+            ])
+            values_grad = torch.cat([
+                module_kv_map[module_idx][idx]["values_grad"]
+                for idx in range(len(module_kv_map[module_idx]))
+            ])
+            value_diffs = torch.empty((0, net.value_size), device=self.config.device)
+            for start_idx in range(0, keys.shape[0], self.config.editor_batch_size):
+                end_idx = start_idx + self.config.editor_batch_size
+                with torch.no_grad():
+                    pseudo_keys, pseudo_values_grad = net(
+                        keys[start_idx:end_idx],
+                        values_grad[start_idx:end_idx],
+                        layer_idx
+                    )
+                    coeffs = - net.lr(layer_idx) * (keys[start_idx:end_idx] * pseudo_keys).sum(-1).unsqueeze(-1)
+                value_diffs = torch.cat((value_diffs, coeffs * pseudo_values_grad))
+            with torch.no_grad():
+                mat = keys.T @ keys + net.lamda(layer_idx).exp() * torch.eye(net.key_size, device=self.config.device)
+                param_shift = torch.linalg.solve(mat, keys.T @ value_diffs)
+            param_shifts[module_name] = param_shift.to(next(self.model.parameters()).device)
+
+        return param_shifts
+
+    def update_hypernet(self, param_shifts: Dict[str, torch.FloatTensor], module_kv_map):
+
+        self.opt.zero_grad()
+        for module_idx, module_name in enumerate(self.config.inner_params):
+            shape = get_shape(get_module(self.model, module_name))
+            net = self.net[str(shape)]
+            layer_idx = torch.LongTensor([self.name2idx[module_name]]).to(self.config.device)
+            keys = torch.cat([
+                module_kv_map[module_idx][idx]["keys"]
+                for idx in range(len(module_kv_map[module_idx]))
+            ])
+            values_grad = torch.cat([
+                module_kv_map[module_idx][idx]["values_grad"]
+                for idx in range(len(module_kv_map[module_idx]))
+            ])
+            module = get_module(self.model, module_name)
+            module_grad = module.weight.grad.to(torch.float32).to(self.config.device)
+            param_shift = param_shifts[module_name].to(self.config.device)
+            if isinstance(module, nn.Linear):
+                module_grad = module_grad.T
+            with torch.no_grad():
+                mat = torch.linalg.solve(keys.T @ keys + net.lamda(layer_idx).exp() * torch.eye(net.key_size, device=self.config.device), module_grad)
+                lamda_grad = - net.lamda(layer_idx).exp() * (mat * param_shift).sum()
+                value_diffs_grad = keys @ mat
+            (lamda_grad * net.lamda(layer_idx)).backward()
+            for start_idx in range(0, keys.shape[0], self.config.editor_batch_size):
+                end_idx = start_idx + self.config.editor_batch_size
+                pseudo_keys, pseudo_values_grad = net(
+                    keys[start_idx:end_idx],
+                    values_grad[start_idx:end_idx],
+                    layer_idx
+                )
+                coeffs = - net.lr(layer_idx) * (keys[start_idx:end_idx] * pseudo_keys).sum(-1).unsqueeze(-1)
+                value_diff = coeffs * pseudo_values_grad
+                (value_diffs_grad[start_idx:end_idx] * value_diff).sum().backward()
+
+        clip_grad_norm_(
+            self.net.parameters(),
+            self.config.max_grad_norm
+        )
+        self.opt.step()
+
+    def _inline_malmen_valid_log(self, step, stats, start_time, steps):
+
+        elapsed = (time.time() - start_time) / (step + 1)
+        prog = f"{step+1}/{steps}".ljust(20)
+        edit_acc = f"{stats['ES_val']:<12.5f}"
+        gen_acc = f"{stats['GS_val']:<12.5f}"
+        loc_acc = f"{stats['LS_val']:<12.5f}"
+
+        LOG.info(
+            f"Step {prog} edit_acc: {edit_acc} gen_acc: {gen_acc} loc_acc: {loc_acc}"
+        )
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        state_dict = self.net.state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True):
+        res = self.net.load_state_dict(state_dict, False)
+        return res
+
+    def to(self, device):
+        super().to(device)
+        self.net.to(device)
+        self.model.to(device)
+
+    def valid(self, config, loader, val_set, steps, log: bool = False):
+        if steps is None or steps > len(loader):
+            steps = len(loader)
+
+        if steps < math.ceil(self.config.n_edits / self.config.batch_size):
+            steps = math.ceil(self.config.n_edits / self.config.batch_size)
+
+        if log:
+            LOG.info(f"Beginning evaluation for {steps} steps...")
+        averager = RunningStatAverager("val")
+
+        start_time = time.time()
+        n_edits_batch = []
+        for val_step, batch in enumerate(loader):
+            if val_step >= steps:
+                break
+            n_edits_batch.append(batch)
+            if (val_step + 1) % math.ceil(self.config.n_edits / self.config.batch_size) == 0 or val_step == steps - 1:
+                # edit
+                batch_dv = {}
+                for item_dict in n_edits_batch:
+                    for key, value in item_dict.items():
+                        if key not in batch_dv:
+                            batch_dv[key] = []
+                        batch_dv[key].append(value)
+                n_edits_batch = []
+
+                module_kv_map = self.cache(batch_dv["edit_inner"])
+                param_shifts = self.predict_param_shifts(module_kv_map)
+                self.edit_model(param_shifts, False)
+                edit_succs, gen_succs, loc_succs = [], [], []
+                for k, s in zip(
+                    ["edit_inner", "edit_rephrase", "loc"],
+                    [edit_succs, gen_succs, loc_succs]
+                ):
+                    for t in batch_dv[k]:
+                        with torch.no_grad():
+                            logits = self.model(input_ids=t['input_ids'], attention_mask=t['attention_mask'])["logits"]
+                            s += succ_ratios(logits, t["labels"], self.shift)
+
+                self.edit_model(param_shifts, True)
+
+                info_dict = {}
+                info_dict["ES"] = np.mean(edit_succs)
+                info_dict["GS"] = np.mean(gen_succs)
+                info_dict["LS"] = np.mean(loc_succs)
+
+                averager.add(info_dict)
+
+            if (
+                log
+                and (val_step + 1) % config.log_interval == 0
+            ):
+                self._inline_malmen_valid_log(
+                    val_step, averager.average(), start_time, steps
+                )
+
+        if log:
+            self._inline_malmen_valid_log(val_step, averager.average(), start_time, steps)
+        elapsed = time.time() - start_time
+        stats = averager.average()
+        stats["eval_time/elapsed"] = elapsed
+        stats["eval_time/average"] = elapsed / steps
+        return stats
+
+    @staticmethod
+    def convert_last_zero_to_one_in_mask(mask):
+        last_zero_indices = []
+        for i in range(mask.size(0)):
+            row = mask[i]
+            last_zero_idx = (row == 0).nonzero()[-1, 0].item() if (row == 0).any() else -1
+            last_zero_indices.append(last_zero_idx)
+        last_zero_indices = torch.tensor(last_zero_indices, device=mask.device)
+        mask[range(mask.size(0)), last_zero_indices] = 1
diff --git a/easyeditor/trainer/algs/__init__.py b/easyeditor/trainer/algs/__init__.py
index a7316bb4..3b2e402e 100644
--- a/easyeditor/trainer/algs/__init__.py
+++ b/easyeditor/trainer/algs/__init__.py
@@ -1,3 +1,4 @@
 from .editable_model import *
 from .MEND import *
 from .SERAC import *
+from .MALMEN import *
diff --git a/easyeditor/trainer/algs/malmen/nets.py b/easyeditor/trainer/algs/malmen/nets.py
new file mode 100644
index 00000000..ac296338
--- /dev/null
+++ b/easyeditor/trainer/algs/malmen/nets.py
@@ -0,0 +1,98 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+
+class RunningMeanStd(nn.Module):
+
+    def __init__(self, size: int):
+        super().__init__()
+
+        self.register_buffer("n", torch.zeros(1))
+        self.register_buffer("mean", torch.zeros((size)))
+        self.register_buffer("var", torch.zeros((size)))
+        self.register_buffer("std", torch.zeros((size)))
+
+    def update(self, x: torch.FloatTensor):
+
+        n = self.n + x.shape[0]
+        delta = x.mean(0) - self.mean
+        self.mean += x.shape[0] * delta / n
+        self.var += x.shape[0] * x.var(0) + self.n * x.shape[0] * delta.pow(2) / n
+        self.std = (self.var / (n - 1 + torch.finfo(x.dtype).eps)).sqrt()
+        self.n = n
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+
+        return (x - self.mean) / (self.std + torch.finfo(x.dtype).eps)
+
+
+class MALMENBlock(nn.Module):
+
+    def __init__(self, size: int, rank: int, n_modules: int):
+        super().__init__()
+
+        self.A = nn.Parameter(torch.randn(size, rank))
+        self.B = nn.Parameter(torch.zeros(rank, size))
+        self.bias = nn.Parameter(torch.zeros(size))
+
+        self.scale = nn.Embedding(n_modules, size)
+        self.shift = nn.Embedding(n_modules, size)
+
+        self.scale.weight.data.fill_(1)
+        self.shift.weight.data.fill_(0)
+
+    def forward(
+        self,
+        y: torch.FloatTensor,
+        module_idx: torch.LongTensor
+    ) -> torch.FloatTensor:
+
+        x = y @ self.A @ self.B + self.bias
+        x = x.clamp(0)
+        x = self.scale(module_idx) * x + self.shift(module_idx)
+        x = x + y
+
+        return x
+
+
+class MALMENNet(nn.Module):
+
+    def __init__(
+        self,
+        key_size: int,
+        value_size: int,
+        rank: int,
+        n_blocks: int,
+        n_modules: int,
+        lr: float
+    ):
+        super().__init__()
+        self.key_size = key_size
+        self.value_size = value_size
+
+        self.normalizer = RunningMeanStd(key_size + value_size)
+        self.blocks = nn.ModuleList([
+            MALMENBlock(key_size + value_size, rank, n_modules)
+            for _ in range(n_blocks)
+        ])
+
+        self.lr = nn.Embedding(n_modules, 1)
+        self.lamda = nn.Embedding(n_modules, 1)
+
+        self.lr.weight.data.fill_(lr)
+        self.lamda.weight.data.fill_(0)
+
+    def forward(
+        self,
+        keys: torch.FloatTensor,
+        values_grad: torch.FloatTensor,
+        module_idx: torch.LongTensor
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+
+        hidden_states = torch.cat((keys, values_grad), -1)
+        hidden_states = self.normalizer(hidden_states)
+        for block in self.blocks:
+            hidden_states = block(hidden_states, module_idx)
+        return hidden_states.split([self.key_size, self.value_size], -1)
\ No newline at end of file
diff --git a/easyeditor/trainer/algs/malmen/util.py b/easyeditor/trainer/algs/malmen/util.py
new file mode 100644
index 00000000..0f7565ca
--- /dev/null
+++ b/easyeditor/trainer/algs/malmen/util.py
@@ -0,0 +1,179 @@
+from typing import Union, Tuple, List, Dict
+# from omegaconf import DictConfig
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.pytorch_utils import Conv1D
+
+import logging
+LOG = logging.getLogger(__name__)
+def get_module(module: nn.Module, module_name: str) -> nn.Module:
+
+    for name in module_name.split("."):
+        module = getattr(module, name)
+    return module
+
+def get_shape(module: Union[nn.Linear, Conv1D]) -> Tuple[int]:
+
+    shape = tuple(module.weight.shape)
+    return shape[::-1] if isinstance(module, nn.Linear) else shape
+
+def cross_entropy(
+    logits: torch.FloatTensor,
+    labels: torch.LongTensor,
+    shift: bool
+):
+    if len(logits.shape) == 2:
+        return F.binary_cross_entropy_with_logits(logits, labels)
+
+    if len(logits.shape) == 3:
+
+        if shift:  # Dealing with sequences
+            logits = logits[:, :-1]  # Remove last prediction in sequence
+            if logits.shape[1] >= labels.shape[1]:
+                logits = logits[:, -labels.size(1):]
+            else:
+                labels = labels[:, -logits.size(1):]
+
+        ans_indice = torch.where(labels != -100)
+
+        logits = logits[ans_indice]
+        labels = labels[ans_indice]
+
+        return F.cross_entropy(logits, labels)
+
+def log(x: torch.FloatTensor) -> torch.FloatTensor:
+    return (x + torch.finfo(x.dtype).eps).log()
+
+def kl_div(
+    refer_logits: torch.FloatTensor,
+    logits: torch.FloatTensor,
+    labels: torch.LongTensor,
+    shift: bool
+) -> torch.Tensor:
+
+    if len(logits.shape) == 2:
+
+        refer_probs = F.sigmoid(refer_logits)
+        probs = F.sigmoid(logits)
+
+        return (refer_probs * (log(refer_probs) - log(probs))) + ((1 - refer_probs) * (log(1 - refer_probs) - log(1 - probs)))
+
+    if len(logits.shape) == 3:
+
+        if shift:  # Dealing with sequences
+            logits = logits[:, :-1]  # Remove last prediction in sequence
+            refer_logits = refer_logits[:, :-1]
+            if logits.shape[1] >= labels.shape[1]:
+                logits = logits[:, -labels.size(1):]
+                refer_logits = refer_logits[:, -labels.size(1):]
+            else:
+                labels = labels[:, -logits.size(1):]
+
+        ans_indice = torch.where(labels != -100)
+
+        refer_logits = refer_logits[ans_indice]
+        logits = logits[ans_indice]
+
+        refer_log_probs = refer_logits.log_softmax(-1)
+        log_probs = logits.log_softmax(-1)
+
+        return F.kl_div(
+            log_probs,
+            refer_log_probs,
+            reduction="batchmean",
+            log_target=True
+        )
+
+def succ_ratios(
+    logits: torch.FloatTensor,
+    labels: torch.LongTensor,
+    shift: bool
+) -> List[float]:
+
+    if len(logits.shape) == 2:
+        return ((logits > 0) == labels).squeeze(-1).to("cpu").numpy().tolist()
+
+    if len(logits.shape) == 3:
+        if shift:  # Dealing with sequences
+            logits = logits[:, :-1]  # Remove last prediction in sequence
+            if logits.shape[1] >= labels.shape[1]:
+                logits = logits[:, -labels.size(1):]
+            else:
+                labels = labels[:, -logits.size(1):]
+
+        n_corr = (logits.argmax(-1) == labels).sum(-1)
+        n_tokens = (labels != -100).sum(-1)
+
+        return (n_corr / n_tokens).to("cpu").numpy().tolist()
+
+
+class Tracer:
+
+    def __init__(
+        self,
+        module: nn.Module,
+        cache_mask: torch.LongTensor
+    ):
+        cache_indices = torch.where(cache_mask)
+
+        def forward_hook(
+            module: nn.Module,
+            inputs: Tuple[torch.FloatTensor],
+            outputs: Tuple[torch.FloatTensor]
+        ):
+            self.keys = inputs[0][cache_indices].detach()
+
+        def backward_hook(
+            module: nn.Module,
+            inputs_grad: Tuple[torch.FloatTensor],
+            outputs_grad: Tuple[torch.FloatTensor]
+        ):
+            self.values_grad = outputs_grad[0][cache_indices].detach()
+
+        self.handles = [
+            module.register_forward_hook(forward_hook),
+            module.register_full_backward_hook(backward_hook)
+        ]
+
+
+class TracerDict(dict):
+
+    def __init__(
+        self,
+        model: nn.Module,
+        config,
+        tuples: Dict[str, torch.LongTensor]
+    ):
+
+        if any("encoder" in m for m in config.inner_params) and any("decoder" in m for m in config.inner_params):
+
+            for module_name in config.inner_params:
+                if "encoder" in module_name:
+                    cache_mask = tuples["attention_mask"]
+                else:
+                    cache_mask = tuples["decoder_attention_mask"]
+                module = get_module(model, module_name)
+                self[module_name] = Tracer(module, cache_mask)
+
+        else:
+
+            if config.token == "ans":
+                cache_mask = tuples["labels"] != -100
+            else:
+                cache_mask = tuples["attention_mask"]
+
+            for module_name in config.inner_params:
+                module = get_module(model, module_name)
+                self[module_name] = Tracer(module, cache_mask)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        for v in self.values():
+            for h in v.handles:
+                h.remove()
\ No newline at end of file
diff --git a/easyeditor/trainer/training_hparams/__init__.py b/easyeditor/trainer/training_hparams/__init__.py
index 7ec561af..c843178a 100644
--- a/easyeditor/trainer/training_hparams/__init__.py
+++ b/easyeditor/trainer/training_hparams/__init__.py
@@ -3,3 +3,4 @@
 from .mend_multimodal_training_hparams import *
 from .serac_training_hparams import *
 from .serac_multimodal_training_hparams import *
+from .malmen_training_hparams import *
diff --git a/easyeditor/trainer/training_hparams/malmen_training_hparams.py b/easyeditor/trainer/training_hparams/malmen_training_hparams.py
new file mode 100644
index 00000000..7967f8e8
--- /dev/null
+++ b/easyeditor/trainer/training_hparams/malmen_training_hparams.py
@@ -0,0 +1,75 @@
+from dataclasses import dataclass
+from ...util.hparams import HyperParams
+from typing import Optional, Any, List
+import yaml
+
+
+@dataclass
+class MALMENTrainingHparams(HyperParams):
+
+    # Model
+    model_name: str
+    model_class: str
+    tokenizer_class: str
+    tokenizer_name: str
+    inner_params: List[str]
+
+    archive: Any
+
+    # Method
+    alg: str
+    debug: bool
+    dropout: float
+    train_base: bool
+    no_grad_layers: Any
+
+    rank: int
+    n_edits: int
+    n_blocks: int
+    lr: float
+    meta_lr: float
+    loc_coef: float
+    max_grad_norm: float
+    token: str
+
+    # Output
+    results_dir: str
+
+    # Train
+    device: str
+    batch_size: int
+    editor_batch_size: int
+    silent: bool
+    log_interval: int
+    eval_log_interval: int
+    final_eval: bool
+    val_interval: int
+    early_stop_patience: int
+    early_stop_key: str
+    eval_only: bool
+    save: bool
+
+    val_batch_size: Optional[int]
+    val_steps: int
+
+    model_save_pt: Optional[int] = 5000
+    half: Optional[bool] = False
+    model_parallel: bool = False
+    max_epochs: Optional[int] = None
+    max_iters: Optional[int] = None
+
+    @classmethod
+    def from_hparams(cls, hparams_name_or_path: str):
+
+        if '.yaml' not in hparams_name_or_path:
+            hparams_name_or_path = hparams_name_or_path + '.yaml'
+
+        with open(hparams_name_or_path, "r") as stream:
+            config = yaml.safe_load(stream)
+            config = super().construct_float_from_scientific_notation(config)
+
+        assert config and config['alg'] == 'MALMEN', \
+            f'MALMENTrainingHparams cannot load from {hparams_name_or_path}: alg_name is {config["alg"]}'
+        config['val_batch_size'] = config['batch_size']
+        return cls(**config)
+
diff --git a/easyeditor/util/alg_dict.py b/easyeditor/util/alg_dict.py
index c951ec39..618793d1 100644
--- a/easyeditor/util/alg_dict.py
+++ b/easyeditor/util/alg_dict.py
@@ -11,6 +11,7 @@
 from ..models.grace import GraceHyperParams, apply_grace_to_model
 from ..models.pmet import PMETHyperParams, apply_pmet_to_model
 from ..models.melo import MELOHyperParams, apply_melo_to_model
+from ..models.malmen import MALMENHyperParams, MalmenRewriteExecutor
 
 ALG_DICT = {
     'ROME': apply_rome_to_model,
@@ -24,7 +25,8 @@
     'LoRA': apply_lora_to_model,
     'GRACE': apply_grace_to_model,
     'PMET': apply_pmet_to_model,
-    'MELO': apply_melo_to_model
+    'MELO': apply_melo_to_model,
+    'MALMEN': MalmenRewriteExecutor().apply_to_model,
 }
 
 ALG_MULTIMODAL_DICT = {
diff --git a/easyeditor/util/alg_train_dict.py b/easyeditor/util/alg_train_dict.py
index 9cb803c2..c67d03d7 100644
--- a/easyeditor/util/alg_train_dict.py
+++ b/easyeditor/util/alg_train_dict.py
@@ -1,9 +1,11 @@
 from ..trainer import MEND
 from ..trainer import SERAC, SERAC_MULTI
+from ..trainer import MALMEN
 
 
 ALG_TRAIN_DICT = {
     'MEND': MEND,
     'SERAC': SERAC,
     'SERAC_MULTI': SERAC_MULTI,
+    'MALMEN': MALMEN,
 }
\ No newline at end of file
diff --git a/edit.py b/edit.py
index a64c582d..245ea45d 100644
--- a/edit.py
+++ b/edit.py
@@ -3,7 +3,7 @@
 from easyeditor import KNHyperParams, FTHyperParams, KETrainingHparams,\
     ROMEHyperParams, MEMITHyperParams, MENDTrainingHparams, MENDHyperParams, \
     SERACTrainingHparams, SERACHparams, IKEHyperParams, FTApiHyperParams, LoRAHyperParams, \
-    GraceHyperParams, PMETHyperParams,MELOHyperParams
+    GraceHyperParams, PMETHyperParams, MELOHyperParams, MALMENTrainingHparams, MALMENHyperParams
 from easyeditor import ZsreDataset, CounterFactDataset
 from easyeditor import EditTrainer
 from easyeditor.models.ike import encode_ike_facts
@@ -2562,6 +2562,50 @@ def test_melo():
     pdb.set_trace()
     return metrics, edited_model
 
+def test_MALMEN_Train():
+    training_hparams = MALMENTrainingHparams.from_hparams('./hparams/TRAINING/MALMEN/gpt2-xl.yaml')
+    train_ds = ZsreDataset('./data/zsre/zsre_mend_train.json', config=training_hparams)
+    print("train_ds", train_ds.__len__())
+    eval_ds = ZsreDataset('./data/zsre/zsre_mend_eval.json', config=training_hparams)
+    print("eval_ds", eval_ds.__len__())
+
+    trainer = EditTrainer(
+        config=training_hparams,
+        train_set=train_ds,
+        val_set=eval_ds
+    )
+
+    trainer.run()
+
+
+def test_MALMEN():
+
+    prompts = ['What university did Watts Humphrey attend?', 'Which family does Ramalinaceae belong to',
+               'What role does Denny Herzig play in football?', 'Who was the designer of Lahti Town Hall?',
+               'What is the original channel that It\'s a Business played on?', 'What city did Marl Young live when he died?',
+               'Steve Jobs was the founder of', 'LeBron James plays the sport of']
+    ground_truth = ['Illinois Institute of Technology', 'Lecanorales', 'defender',
+                    'Eliel Saarinen', 'DuMont Television Network', 'Los Angeles', 'Apple', 'basketball']
+    target_new = ['University of Michigan', 'Lamiinae', 'winger',
+                  'Alfred Lahti', 'ITV', 'New Orleans', 'Microsoft', 'football']
+
+    # prompts = ['What university did Watts Humphrey attend?']
+    # ground_truth = ['Illinois Institute of Technology']
+    # target_new = ['University of Michigan']
+    hparams = MALMENHyperParams.from_hparams('./hparams/MALMEN/gpt2-xl')
+    editor = BaseEditor.from_hparams(hparams)
+    metrics, edited_model, _ = editor.edit(
+        prompts=prompts,
+        ground_truth=ground_truth,
+        target_new=target_new,
+        keep_original_weight=True
+    )
+
+    import pdb
+    pdb.set_trace()
+
+    return metrics, edited_model
+
 def main():
 
     # metrics, edited_model = test_KN()
@@ -2652,6 +2696,9 @@ def main():
     # test_MEND_Train_Mistral()
     # test_MEND_Mistral()
     # test_MEMIT_Mistral()
+    # test_MALMEN_Train()
+    test_MALMEN()
+
 
 if __name__ == '__main__':
     main()
diff --git a/hparams/MALMEN/gpt2-xl.yaml b/hparams/MALMEN/gpt2-xl.yaml
new file mode 100644
index 00000000..7a872d0c
--- /dev/null
+++ b/hparams/MALMEN/gpt2-xl.yaml
@@ -0,0 +1,54 @@
+alg_name: "MALMEN"
+archive: ./results/models/MALMEN/gpt2-xl.bk
+device: 0
+# Model
+model_name: ./hugging_cache/gpt2-xl
+model_class: GPT2LMHeadModel
+tokenizer_class: GPT2TokenizerFast
+tokenizer_name: ./hugging_cache/gpt2-xl
+inner_params:
+- transformer.h.42.mlp.c_proj.weight
+- transformer.h.43.mlp.c_proj.weight
+- transformer.h.44.mlp.c_proj.weight
+- transformer.h.45.mlp.c_proj.weight
+- transformer.h.46.mlp.c_proj.weight
+- transformer.h.47.mlp.c_proj.weight
+
+# Method
+alg: MALMEN
+dropout: 0.0
+train_base: False
+no_grad_layers: null
+
+rank: 1920
+n_blocks: 2
+lr: 1e-6
+meta_lr: 1e-5
+loc_coef: 1
+max_grad_norm: 1
+token: mask
+
+# Train
+n_edits: 1
+batch_size: 1
+editor_batch_size: 1024
+silent: False
+# max_epochs: 1
+max_iters: 10000
+log_interval: 100
+eval_log_interval: 100
+final_eval: True
+val_interval: 100
+early_stop_patience: 1000
+early_stop_key: "edit_acc_val"
+eval_only: False
+debug: False
+save: False
+
+val_batch_size: 1
+val_steps: 200 # only for debug
+
+model_parallel: false
+
+# Output
+results_dir: ./results
\ No newline at end of file
diff --git a/hparams/TRAINING/MALMEN/gpt2-xl.yaml b/hparams/TRAINING/MALMEN/gpt2-xl.yaml
new file mode 100644
index 00000000..c84d5ced
--- /dev/null
+++ b/hparams/TRAINING/MALMEN/gpt2-xl.yaml
@@ -0,0 +1,53 @@
+# Model
+model_name: ./hugging_cache/gpt2-xl
+model_class: GPT2LMHeadModel
+tokenizer_class: GPT2TokenizerFast
+tokenizer_name: ./hugging_cache/gpt2-xl
+inner_params:
+- transformer.h.42.mlp.c_proj.weight
+- transformer.h.43.mlp.c_proj.weight
+- transformer.h.44.mlp.c_proj.weight
+- transformer.h.45.mlp.c_proj.weight
+- transformer.h.46.mlp.c_proj.weight
+- transformer.h.47.mlp.c_proj.weight
+
+archive: null
+
+# Method
+alg: MALMEN
+dropout: 0.0
+train_base: False
+no_grad_layers: null
+rank: 1920
+n_blocks: 2
+lr: 1e-6
+meta_lr: 1e-5
+loc_coef: 1
+max_grad_norm: 1
+# token: ans
+token: mask
+
+# Train
+device: cuda:0
+n_edits: 1
+batch_size: 1
+editor_batch_size: 1024
+silent: False
+# max_epochs: 1
+max_iters: 10000
+log_interval: 100
+eval_log_interval: 100
+final_eval: True
+val_interval: 100
+early_stop_patience: 1000
+early_stop_key: "edit_acc_val"
+eval_only: False
+debug: False
+save: False
+
+val_batch_size: 1
+val_steps: 200 # only for debug
+model_parallel: false
+
+# Output
+results_dir: ./results
\ No newline at end of file
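
---

Usage sketch (not part of the patch): the snippet below strings the new pieces together end to end, first meta-training the MALMEN hypernetwork, then loading the resulting checkpoint to apply edits. The data paths and the example request are taken from edit.py and the yaml files above; the checkpoint location is an assumption you must adjust to your environment. Internally, MALMEN.predict_param_shifts aggregates the cached keys K and scaled pseudo value-gradients ΔV of each edited module and solves the ridge normal equations (KᵀK + e^λ I) ΔW = Kᵀ ΔV, so all n_edits requests are folded into a single parameter shift per module.

    from easyeditor import MALMENTrainingHparams, MALMENHyperParams
    from easyeditor import ZsreDataset, EditTrainer, BaseEditor

    # 1) Meta-train the hypernetwork; checkpoints are written under results_dir.
    training_hparams = MALMENTrainingHparams.from_hparams('./hparams/TRAINING/MALMEN/gpt2-xl.yaml')
    train_ds = ZsreDataset('./data/zsre/zsre_mend_train.json', config=training_hparams)
    eval_ds = ZsreDataset('./data/zsre/zsre_mend_eval.json', config=training_hparams)
    EditTrainer(config=training_hparams, train_set=train_ds, val_set=eval_ds).run()

    # 2) Point `archive` in hparams/MALMEN/gpt2-xl.yaml at the trained checkpoint
    #    (assumed path), then apply edits; each call to the executor consumes
    #    n_edits requests, so len(prompts) must be >= n_edits.
    hparams = MALMENHyperParams.from_hparams('./hparams/MALMEN/gpt2-xl')
    editor = BaseEditor.from_hparams(hparams)
    metrics, edited_model, _ = editor.edit(
        prompts=['What university did Watts Humphrey attend?'],
        ground_truth=['Illinois Institute of Technology'],
        target_new=['University of Michigan'],
        keep_original_weight=True,
    )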