Merge pull request #189 from xzwyyd/main

Support MALMEN for issue #116

xzwyyd committed Mar 1, 2024
2 parents 8a1df2c + b0413f2 commit ded7e7c
Showing 17 changed files with 1,122 additions and 10 deletions.
4 changes: 4 additions & 0 deletions easyeditor/dataset/zsre.py
@@ -221,6 +221,10 @@ def collate_gpt_fn(self, batch):
         rephrase = [rephrase_ + ' ' + trg_ for rephrase_, trg_ in zip(rephrase, trg)]
         loc = [loc_ + ' ' + loc_ans_ for loc_, loc_ans_ in zip(loc, loc_ans)]
 
+        if 'gpt' in self.config.tokenizer_class.lower():
+            trg = [' ' + t for t in trg]
+            loc_ans = [' ' + t for t in loc_ans]
+
         batches = {
             f"{k1}_{k2}": v2
             for k1, v1 in {
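
The leading space matters because GPT-style BPE tokenizers are whitespace-sensitive: a word tokenizes differently with and without a preceding space, so the target must match the form it takes after the prompt. A minimal sketch of the effect (using the public gpt2 tokenizer; the word is illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
# GPT-2's BPE encodes "President" and " President" as different token
# sequences, so targets get a leading space to match their in-sentence form.
assert tok.encode("President") != tok.encode(" President")
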
2 changes: 2 additions & 0 deletions easyeditor/models/__init__.py
@@ -8,3 +8,5 @@
 from .pmet import *
 from .melo import *
 from .grace import *
+from .malmen import *
+
2 changes: 2 additions & 0 deletions easyeditor/models/malmen/__init__.py
@@ -0,0 +1,2 @@
from .malmen_hparams import MALMENHyperParams
from .malmen_main import MalmenRewriteExecutor
76 changes: 76 additions & 0 deletions easyeditor/models/malmen/malmen_hparams.py
@@ -0,0 +1,76 @@
from dataclasses import dataclass
from ...util.hparams import HyperParams
from typing import Optional, Any, List
import yaml


@dataclass
class MALMENHyperParams(HyperParams):
    alg_name: str

    # Model
    model_name: str
    model_class: str
    tokenizer_class: str
    tokenizer_name: str
    inner_params: List[str]
    device: int
    archive: Any

    # Method
    alg: str
    debug: bool
    dropout: float
    train_base: bool
    no_grad_layers: Any
    rank: int
    n_edits: int
    n_blocks: int
    lr: float
    meta_lr: float
    loc_coef: float
    max_grad_norm: float
    token: str

    # Output
    results_dir: str

    # Train
    batch_size: int
    editor_batch_size: int
    silent: bool
    log_interval: int
    eval_log_interval: int
    final_eval: bool
    val_interval: int
    early_stop_patience: int
    early_stop_key: str
    eval_only: bool
    save: bool

    val_batch_size: Optional[int]
    val_steps: int

    max_length: int = 40

    model_save_pt: Optional[int] = 5000
    half: Optional[bool] = False
    model_parallel: bool = False
    max_epochs: Optional[int] = None
    max_iters: Optional[int] = None

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str):
        if '.yaml' not in hparams_name_or_path:
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        with open(hparams_name_or_path, "r") as stream:
            config = yaml.safe_load(stream)
            config = super().construct_float_from_scientific_notation(config)

        assert config and config['alg'] == 'MALMEN', (
            f'MALMENHyperParams can not load from {hparams_name_or_path}, '
            f'alg is {config["alg"] if config else None}'
        )
        config['val_batch_size'] = config['batch_size']
        return cls(**config)
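
For reference, loading these hyperparameters might look like the minimal sketch below. The YAML path is hypothetical rather than taken from this commit, and from_hparams appends the .yaml suffix when it is missing:

from easyeditor.models.malmen import MALMENHyperParams

# Hypothetical config path; the repository's actual hparams layout may differ.
hparams = MALMENHyperParams.from_hparams('./hparams/TRAINING/MALMEN/gpt2-xl')
# from_hparams asserts alg == 'MALMEN' and copies batch_size into val_batch_size.
print(hparams.alg_name, hparams.n_edits, hparams.val_batch_size)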

124 changes: 124 additions & 0 deletions easyeditor/models/malmen/malmen_main.py
@@ -0,0 +1,124 @@
import os
from copy import deepcopy
from typing import Dict, List, Any, Tuple

import hydra
import torch
from collections import deque
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util.globals import *

from ...trainer import MALMEN
from .malmen_hparams import MALMENHyperParams

class MalmenRewriteExecutor:
    def __init__(self):
        self.is_init = False

    def init_model(self, model, tok, params: MALMENHyperParams):
        assert params.archive is not None, 'Training weights needed....'
        # Keep references to the model and tokenizer
        self.model = model
        self.tokenizer = tok
        # add_padding(self.tokenizer, self.model)

        # Load the trained MALMEN hyper-network
        self.alg = MALMEN(self.model, params, lambda: deepcopy(self.model))
        d = torch.load(params.archive, map_location=f'cuda:{params.device}')
        self.alg.load_state_dict(d["model"])
        if params.model_parallel:
            self.alg.net.to(deque(self.alg.model.parameters(), maxlen=1)[0].device)
        else:
            self.alg.to(torch.device(f'cuda:{params.device}'))

    def reset_model(self):
        self.is_init = False
        del self.model, self.tokenizer, self.alg

    def apply_to_model(
        self,
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        requests: List[Dict],
        hparams: MALMENHyperParams,
        copy=False,
        return_orig_weights=False,
        keep_original_weight=False,
        **kwargs
    ):
        """
        Given a request, for example
        {'prompt': '{} has the position of',
         'subject': 'Charles Herman Helmsing',
         'relation_id': 'P39',
         'target_new': {'str': 'President', 'id': 'Q11696'},
         'target_true': {'str': 'bishop', 'id': 'Q29182'}}
        Returns a dictionary of numpy arrays that specifies
        how MALMEN will change the weights of the model.
        """

        if not self.is_init:
            self.init_model(model, tok, hparams)

        weights_copy = {}
        model = deepcopy(self.model) if copy else self.model
        assert len(requests) >= hparams.n_edits, "The number of requests must be greater than or equal to the value of n_edits."
        # Define i/o
        requests = requests[:hparams.n_edits]
        batches = []
        for i in range(hparams.n_edits // hparams.batch_size):
            batch = requests[i * hparams.batch_size : (i + 1) * hparams.batch_size]
            # Prepend a space to each target unless it already starts with one
            targets = [
                (" " if request["target_new"][0] != " " else "")
                + request["target_new"]
                for request in batch
            ]
            sentences = [
                request["prompt"] + targets[i]
                for i, request in enumerate(batch)
            ]

            # Tokenize
            sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to(
                f"cuda:{hparams.device}"
            )
            target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to(
                f"cuda:{hparams.device}"
            )

            # Define labels: mask everything except the target span with -100
            label_tok = deepcopy(sent_tok["input_ids"])
            for i in range(label_tok.size(0)):
                target_len = target_tok["attention_mask"][i].sum()
                padding_len = (
                    sent_tok["input_ids"].size(1) - sent_tok["attention_mask"][i].sum()
                )
                label_tok[i][: -target_len - padding_len] = -100
                label_tok[i][label_tok[i] == self.tokenizer.pad_token_id] = -100

            # Note: edit_inner passes the raw target ids as labels; the masked
            # label_tok above is computed but not used here.
            edit_inner = dict(
                input_ids=sent_tok["input_ids"],
                attention_mask=sent_tok["attention_mask"],
                labels=target_tok['input_ids'],
            )

            batches.append(edit_inner)

        # Run MALMEN: cache keys/values, predict parameter shifts, apply the edit
        module_kv_map = self.alg.cache(batches)
        param_shifts = self.alg.predict_param_shifts(module_kv_map)
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                if n in hparams.inner_params:
                    if return_orig_weights and n not in weights_copy:
                        weights_copy[n] = p.detach().clone()
        self.alg.edit_model(param_shifts, False)

        if not keep_original_weight:
            weights_copy = {}

        return self.alg.model, weights_copy
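
A minimal usage sketch for the executor follows. Assumptions: a trained MALMEN checkpoint whose path is set in hparams.archive, at least n_edits requests, and a hypothetical config path; note that apply_to_model treats target_new as a plain string, not the dict shown in the docstring:

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.malmen import MALMENHyperParams, MalmenRewriteExecutor

hparams = MALMENHyperParams.from_hparams('./hparams/MALMEN/gpt2-xl')  # hypothetical path
model = AutoModelForCausalLM.from_pretrained(hparams.model_name)
tok = AutoTokenizer.from_pretrained(hparams.tokenizer_name)
tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default

# Prompts are pre-filled here; n_edits requests satisfy the assert above.
requests = [{'prompt': 'Charles Herman Helmsing has the position of',
             'target_new': 'President'}] * hparams.n_edits

executor = MalmenRewriteExecutor()
edited_model, weights_copy = executor.apply_to_model(
    model, tok, requests, hparams, keep_original_weight=True
)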

42 changes: 34 additions & 8 deletions easyeditor/trainer/BaseTrainer.py
@@ -77,7 +77,7 @@ def __init__(self, config, train_set: Dataset, val_set: Dataset):
             # Eval once and quit
             self.config.max_iters = 0
 
-        if not self.config.eval_only:
+        if not self.config.eval_only and self.config.alg!='MALMEN':
             self.OptimizerClass = getattr(torch.optim, config.opt)
             LOG.info(f"Building optimizer {self.OptimizerClass} with lr {config.lr}")
             self.opt = self.OptimizerClass(self.model.outer_parameters(), lr=config.lr)
@@ -87,7 +87,10 @@ def __init__(self, config, train_set: Dataset, val_set: Dataset):
             self.model.load_state_dict(archive["model"])
             del archive["model"]
             if not self.config.eval_only:
-                self.opt.load_state_dict(archive["opt"])
+                if self.config.alg=='MALMEN':
+                    self.model.opt.load_state_dict(archive["opt"])
+                else:
+                    self.opt.load_state_dict(archive["opt"])
             del archive["opt"]
 
         self.archive = (
@@ -114,7 +117,7 @@ def save_state(self, stats):
 
         obj = {
             "model": self.model.state_dict(),
-            "opt": self.opt.state_dict(),
+            "opt": self.opt.state_dict() if self.config.alg!='MALMEN' else self.model.opt.state_dict(),
             "lr_opt": self.lr_opt.state_dict() if self.lr_opt is not None else None,
             "val_stats": stats,
             "start_time": self.start_time,
@@ -156,11 +159,21 @@ def run(self):
                 self.config.max_iters = min(self.config.max_iters, self.config.max_epochs * len(self.train_set))
             else:
                 self.config.max_iters = self.config.max_epochs * len(self.train_set)
+            if self.config.alg == 'MALMEN':
+                self.config.max_iters = math.ceil(self.config.max_iters / self.config.batch_size)
             LOG.info(f'MAX EPOCH: {self.config.max_epochs}, set max iters to {self.config.max_iters}')
 
+        if self.config.alg == 'MALMEN':
+            n_edits_step = math.ceil(self.config.n_edits / self.config.batch_size)
+            if self.config.log_interval % n_edits_step:
+                self.config.log_interval = (self.config.log_interval // n_edits_step) * n_edits_step if self.config.log_interval >= n_edits_step else n_edits_step
+            if self.config.val_interval % n_edits_step:
+                self.config.val_interval = (self.config.val_interval // n_edits_step) * n_edits_step if self.config.val_interval >= n_edits_step else n_edits_step
         self.epoches = round(float(self.config.max_iters) / (len(self.train_set) / self.config.batch_size))
         if self.epoches < 1:
             self.epoches = 1
         self.global_iter = 0
         should_stop = False
+        n_edits_batch = []
         for epoch in range(self.epoches):
             if should_stop:
                 break
@@ -170,15 +183,25 @@
                     should_stop = True
                     break
                 if not self.config.eval_only:
-                    train_info = self.train_step(batch)
-                    averager.add(train_info)
+                    if self.config.alg == 'MALMEN':
+                        n_edits_batch.append(batch)
+                        if len(n_edits_batch) == math.ceil(self.config.n_edits / self.config.batch_size):
+                            train_info = self.model.train(n_edits_batch)
+                            averager.add(train_info)
+                            n_edits_batch = []
+                    else:
+                        train_info = self.train_step(batch)
+                        averager.add(train_info)
 
                 if self.global_iter % self.config.log_interval == 0:
                     avg_info = averager.average()
                     averager.reset()
                     self.echo(self.global_iter, avg_info)
                 if self.global_iter % self.config.val_interval == 0:
-                    val_info = self.validate(steps=self.config.val_steps)
+                    if self.config.alg == 'MALMEN':
+                        val_info = self.model.valid(config=self.config, loader=self.val_loader, val_set=self.val_set, steps=self.config.val_steps)
+                    else:
+                        val_info = self.validate(steps=self.config.val_steps)
                     self.echo(self.global_iter, val_info)
                 if True:
                     self.save_state(val_info)  # New best
@@ -213,7 +236,10 @@ def run(self):
         self.model.to(self.config.device)
 
         val_steps = self.config.val_steps if self.config.debug else None
-        val_info = self.validate(log=True, steps=val_steps)
+        if self.config.alg == 'MALMEN':
+            val_info = self.model.valid(log=True, steps=val_steps, config=self.config, loader=self.val_loader, val_set=self.val_set)
+        else:
+            val_info = self.validate(log=True, steps=val_steps)
         self.echo(self.global_iter, val_info, pretty=True)
 
         if self.config.results_dir is not None:
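
Because MALMEN consumes training batches in groups of ceil(n_edits / batch_size), the trainer rounds log_interval and val_interval to a multiple of that group size (rounding down, but never below one group). A worked example of the rounding, with illustrative values rather than any shipped config:

import math

# Illustrative values, not taken from any shipped config.
n_edits, batch_size = 48, 16
log_interval, val_interval = 100, 2

n_edits_step = math.ceil(n_edits / batch_size)  # 3 batches form one edit group
if log_interval % n_edits_step:
    # 100 is not a multiple of 3, so it is rounded down to 99.
    log_interval = (log_interval // n_edits_step) * n_edits_step if log_interval >= n_edits_step else n_edits_step
if val_interval % n_edits_step:
    # 2 is smaller than one group, so it is bumped up to 3.
    val_interval = (val_interval // n_edits_step) * n_edits_step if val_interval >= n_edits_step else n_edits_step
print(log_interval, val_interval)  # 99 3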
