Merge pull request #189 from xzwyyd/main
Support MALMEN for issue #116
Showing 17 changed files with 1,122 additions and 10 deletions.
Package __init__.py of the editing methods (registers the new malmen module alongside the existing ones):

@@ -8,3 +8,5 @@
from .pmet import *
from .melo import *
from .grace import *
from .malmen import *
malmen/__init__.py (new file):

@@ -0,0 +1,2 @@
from .malmen_hparams import MALMENHyperParams
from .malmen_main import MalmenRewriteExecutor
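Together, the wildcard import above and these two exports make the new classes importable from either package level. A quick sketch, assuming the repository's easyeditor/models layout (the top-level package path is not shown in this diff):

# Package path assumed; both lines rely on the __init__.py changes above.
from easyeditor.models.malmen import MALMENHyperParams, MalmenRewriteExecutor
from easyeditor.models import MALMENHyperParams, MalmenRewriteExecutor  # via the wildcard re-export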
malmen/malmen_hparams.py (new file):

@@ -0,0 +1,76 @@
from dataclasses import dataclass
from ...util.hparams import HyperParams
from typing import Optional, Any, List
import yaml


@dataclass
class MALMENHyperParams(HyperParams):
    alg_name: str

    # Model
    model_name: str
    model_class: str
    tokenizer_class: str
    tokenizer_name: str
    inner_params: List[str]
    device: int
    archive: Any

    # Method
    alg: str
    debug: bool
    dropout: float
    train_base: bool
    no_grad_layers: Any
    rank: int
    n_edits: int
    n_blocks: int
    lr: float
    meta_lr: float
    loc_coef: float
    max_grad_norm: float
    token: str

    # Output
    results_dir: str

    # Train
    batch_size: int
    editor_batch_size: int
    silent: bool
    log_interval: int
    eval_log_interval: int
    final_eval: bool
    val_interval: int
    early_stop_patience: int
    early_stop_key: str
    eval_only: bool
    save: bool

    val_batch_size: Optional[int]
    val_steps: int

    max_length: int = 40

    model_save_pt: Optional[int] = 5000
    half: Optional[bool] = False
    model_parallel: bool = False
    max_epochs: Optional[int] = None
    max_iters: Optional[int] = None

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str):
        if '.yaml' not in hparams_name_or_path:
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        with open(hparams_name_or_path, "r") as stream:
            config = yaml.safe_load(stream)
        config = super().construct_float_from_scientific_notation(config)

        assert config and config['alg'] == 'MALMEN', (
            f'MALMENHyperParams can not load from {hparams_name_or_path}, '
            f'alg is {config["alg"] if config else None}'
        )
        config['val_batch_size'] = config['batch_size']
        return cls(**config)
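A minimal loading sketch, assuming a YAML config whose alg key is MALMEN (the config path below is hypothetical):

from easyeditor.models.malmen import MALMENHyperParams  # package path assumed

# Hypothetical config path; from_hparams appends ".yaml" if it is missing.
hparams = MALMENHyperParams.from_hparams('./hparams/MALMEN/gpt2-xl')

# The loader rejects configs whose alg is not MALMEN and always
# overwrites val_batch_size with batch_size before construction.
print(hparams.alg_name, hparams.n_edits, hparams.val_batch_size)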
malmen/malmen_main.py (new file):

@@ -0,0 +1,124 @@
import os
from copy import deepcopy
from typing import Dict, List, Any, Tuple

import hydra
import torch
from collections import deque
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util.globals import *

from ...trainer import MALMEN
from .malmen_hparams import MALMENHyperParams


class MalmenRewriteExecutor:
    def __init__(self):
        self.is_init = False

    def init_model(self, model, tok, params: MALMENHyperParams):
        assert params.archive is not None, 'Trained MALMEN weights (params.archive) are required.'

        # Keep handles to the model and tokenizer
        self.model = model
        self.tokenizer = tok
        # add_padding(self.tokenizer, self.model)

        # Load the trained MALMEN editor network
        self.alg = MALMEN(self.model, params, lambda: deepcopy(self.model))
        d = torch.load(params.archive, map_location=f'cuda:{params.device}')
        self.alg.load_state_dict(d["model"])
        if params.model_parallel:
            self.alg.net.to(deque(self.alg.model.parameters(), maxlen=1)[0].device)
        else:
            self.alg.to(torch.device(f'cuda:{params.device}'))
        self.is_init = True  # avoid re-loading the editor on subsequent calls

    def reset_model(self):
        self.is_init = False
        del self.model, self.tokenizer, self.alg

    def apply_to_model(
        self,
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        requests: List[Dict],
        hparams: MALMENHyperParams,
        copy=False,
        return_orig_weights=False,
        keep_original_weight=False,
        **kwargs
    ):
        """
        Given a list of requests, for example
        {'prompt': '{} has the position of',
         'subject': 'Charles Herman Helmsing',
         'relation_id': 'P39',
         'target_new': {'str': 'President', 'id': 'Q11696'},
         'target_true': {'str': 'bishop', 'id': 'Q29182'}},
        returns the edited model together with a dictionary of the
        original weights that MALMEN changed (populated only when
        return_orig_weights and keep_original_weight are set).
        """
        if not self.is_init:
            self.init_model(model, tok, hparams)

        weights_copy = {}
        model = deepcopy(self.model) if copy else self.model
        assert len(requests) >= hparams.n_edits, \
            "The number of requests must be greater than or equal to n_edits."

        # Build the edit batches from the first n_edits requests
        requests = requests[:hparams.n_edits]
        batches = []
        for i in range(hparams.n_edits // hparams.batch_size):
            batch = requests[i * hparams.batch_size : (i + 1) * hparams.batch_size]
            # Prepend a space to each target unless it already starts with one
            targets = [
                (" " if request["target_new"][0] != " " else "")
                + request["target_new"]
                for request in batch
            ]
            sentences = [
                request["prompt"] + targets[j]
                for j, request in enumerate(batch)
            ]

            # Tokenize
            sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to(
                f"cuda:{hparams.device}"
            )
            target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to(
                f"cuda:{hparams.device}"
            )

            # Define labels: mask everything except the target span with -100
            label_tok = deepcopy(sent_tok["input_ids"])
            for j in range(label_tok.size(0)):
                target_len = target_tok["attention_mask"][j].sum()
                padding_len = (
                    sent_tok["input_ids"].size(1) - sent_tok["attention_mask"][j].sum()
                )
                label_tok[j][: -target_len - padding_len] = -100
                label_tok[j][label_tok[j] == self.tokenizer.pad_token_id] = -100

            edit_inner = dict(
                input_ids=sent_tok["input_ids"],
                attention_mask=sent_tok["attention_mask"],
                labels=target_tok['input_ids'],
            )

            batches.append(edit_inner)

        # Run MALMEN: cache keys/values, predict parameter shifts, apply the edit
        module_kv_map = self.alg.cache(batches)
        param_shifts = self.alg.predict_param_shifts(module_kv_map)
        with torch.no_grad():
            for n, p in self.model.named_parameters():
                if n in hparams.inner_params:
                    if return_orig_weights and n not in weights_copy:
                        weights_copy[n] = p.detach().clone()
        self.alg.edit_model(param_shifts, False)

        if not keep_original_weight:
            weights_copy = {}

        return self.alg.model, weights_copy
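An end-to-end sketch under stated assumptions: the config path, model names, and requests below are placeholders, archive in the config must point to a trained MALMEN checkpoint, and at least n_edits requests must be supplied (surplus requests are dropped):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.malmen import MALMENHyperParams, MalmenRewriteExecutor  # path assumed

hparams = MALMENHyperParams.from_hparams('./hparams/MALMEN/gpt2-xl')  # hypothetical path
model = AutoModelForCausalLM.from_pretrained(hparams.model_name)
tok = AutoTokenizer.from_pretrained(hparams.tokenizer_name)
tok.pad_token = tok.eos_token  # GPT-2-style tokenizers have no pad token by default

# Placeholder requests; supply at least hparams.n_edits of them.
requests = [
    {"prompt": "Ray Charles, the", "target_new": " violin"},
    # ...
]

executor = MalmenRewriteExecutor()
edited_model, weights_copy = executor.apply_to_model(
    model, tok, requests, hparams,
    return_orig_weights=True, keep_original_weight=True,
)

# weights_copy holds the pre-edit tensors for hparams.inner_params,
# so the edit can be reverted on the original model object:
with torch.no_grad():
    for n, p in model.named_parameters():
        if n in weights_copy:
            p.copy_(weights_copy[n])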