Skip to content

Commit

Permalink
adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
zphang committed Feb 11, 2019
1 parent eca4597 commit 401720a
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 53 deletions.
2 changes: 1 addition & 1 deletion configs/bert-base-uncased-adapter.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
"type_vocab_size": 2,
"vocab_size": 30522,
"use_adapter": true,
"adapter_initializer_range": 0.001,
"adapter_initializer_range": 0.0001,
"adapter_size": 64
}
16 changes: 16 additions & 0 deletions configs/bert-large-uncased-adapter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 2,
"vocab_size": 30522,
"use_adapter": true,
"adapter_initializer_range": 0.0001,
"adapter_size": 64
}
5 changes: 5 additions & 0 deletions glue/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"mnli-mm": tasks.MnliMismatchedProcessor,
"qnli": tasks.QnliProcessor,
"rte": tasks.RteProcessor,
"wnli": tasks.WnliProcessor,
"xnli": tasks.XnliProcessor,
"snli": tasks.SnliProcessor,
"bcs": tasks.BcsProcessor,
Expand All @@ -34,6 +35,7 @@
"mnli-mm": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
"xnli": "classification",
"snli": "classification",
"bcs": "classification",
Expand All @@ -48,6 +50,7 @@
"mnli-mm": "MNLI",
"qnli": "QNLI",
"rte": "RTE",
"wnli": "WNLI",
}


Expand Down Expand Up @@ -95,6 +98,8 @@ def compute_metrics(task_name, pred_srs, label_srs):
return {"acc": simple_accuracy(pred_srs, label_srs)}
elif task_name == "rte":
return {"acc": simple_accuracy(pred_srs, label_srs)}
elif task_name == "wnli":
return {"acc": simple_accuracy(pred_srs, label_srs)}
else:
raise KeyError(task_name)

Expand Down
71 changes: 58 additions & 13 deletions glue/model_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pytorch_pretrained_bert.tokenization import (
BertTokenizer, PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP,
)
import pytorch_pretrained_bert.utils as utils

from glue.tasks import TaskType

Expand Down Expand Up @@ -44,7 +45,7 @@ def create_model(task_type, bert_model_name, bert_load_mode, all_state,
cache_dir=cache_dir,
num_labels=num_labels,
)
elif bert_load_mode in ["model_only", "state_model_only", "state_all"]:
elif bert_load_mode in ["model_only", "state_model_only", "state_all", "state_full_model"]:
model = load_bert(
task_type=task_type,
bert_model_name=bert_model_name,
Expand Down Expand Up @@ -94,23 +95,36 @@ def load_bert(task_type, bert_model_name, bert_load_mode, all_state, num_labels,
bert_config_json_path = os.path.join(get_bert_config_path(bert_model_name), "bert_config.json")
if bert_load_mode == "model_only":
state_dict = all_state
elif bert_load_mode in ["state_model_only", "state_all"]:
elif bert_load_mode in ["state_model_only", "state_all", "state_full_model"]:
state_dict = all_state["model"]
else:
raise KeyError(bert_load_mode)

if task_type == TaskType.CLASSIFICATION:
model = BertForSequenceClassification.from_state_dict(
config_file=bert_config_json_path,
state_dict=state_dict,
num_labels=num_labels,
)
if bert_load_mode == "state_full_model":
model = BertForSequenceClassification.from_state_dict_full(
config_file=bert_config_json_path,
state_dict=state_dict,
num_labels=num_labels,
)
else:
model = BertForSequenceClassification.from_state_dict(
config_file=bert_config_json_path,
state_dict=state_dict,
num_labels=num_labels,
)
elif task_type == TaskType.REGRESSION:
assert num_labels == 1
model = BertForSequenceRegression.from_state_dict(
config_file=bert_config_json_path,
state_dict=state_dict,
)
if bert_load_mode == "state_full_model":
model = BertForSequenceRegression.from_state_dict_full(
config_file=bert_config_json_path,
state_dict=state_dict,
)
else:
model = BertForSequenceRegression.from_state_dict(
config_file=bert_config_json_path,
state_dict=state_dict,
)
else:
raise KeyError(task_type)
return model
Expand All @@ -120,7 +134,7 @@ def create_tokenizer(bert_model_name, bert_load_mode, do_lower_case, bert_vocab_
if bert_load_mode == "from_pretrained":
assert bert_vocab_path is None
tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=do_lower_case)
elif bert_load_mode in ["model_only", "state_model_only", "state_all"]:
elif bert_load_mode in ["model_only", "state_model_only", "state_all", "state_full_model"]:
tokenizer = load_tokenizer(
bert_model_name=bert_model_name,
do_lower_case=do_lower_case,
Expand All @@ -146,7 +160,10 @@ def load_tokenizer(bert_model_name, do_lower_case, bert_vocab_path=None):
def create_optimizer(model, learning_rate, t_total, loss_scale, fp16, warmup_proportion, state_dict):
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
no_decay = [
'bias', 'LayerNorm.bias', 'LayerNorm.weight',
'adapter.down_project.weight', 'adapter.up_project.weight',
]
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
Expand Down Expand Up @@ -177,3 +194,31 @@ def create_optimizer(model, learning_rate, t_total, loss_scale, fp16, warmup_pro
if state_dict is not None:
optimizer.load_state_dict(state_dict)
return optimizer


def save_bert(model, optimizer, args, save_path, save_mode="all"):
    """Serialize model (and optionally optimizer) state plus run args to save_path.

    Args:
        model: the model to save; a DataParallel wrapper is unwrapped first.
        optimizer: optimizer whose state is saved alongside the model, or None
            to save the model only.
        args: argparse-style namespace; stored as a plain dict via vars().
        save_path: destination path for torch.save.
        save_mode: "all" saves the full state dict; "tunable" drops entries for
            parameters with requires_grad=False (frozen weights), keeping only
            what training could have changed.

    Raises:
        KeyError: if save_mode is not one of "all" / "tunable".
    """
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    if save_mode == "all":
        model_state_dict = model_to_save.state_dict()
    elif save_mode == "tunable":
        # Drop non-trainable params. Sort of a hack: buffers / non-parameter
        # state are always kept (named_parameters() doesn't list them), which
        # works in our favor for layer norm, but is fragile in general.
        model_state_dict = model_to_save.state_dict()
        # Iterate the unwrapped module so names line up with state_dict keys
        # even when `model` is a DataParallel wrapper (whose parameter names
        # carry a "module." prefix and would KeyError on deletion).
        for name, param in model_to_save.named_parameters():
            if not param.requires_grad:
                print(" Skip {}".format(name))
                del model_state_dict[name]
    else:
        raise KeyError(save_mode)

    # Optimizer state is moved to CPU so the checkpoint can be loaded without
    # the original device; None is preserved when no optimizer is supplied.
    optimizer_state_dict = utils.to_cpu(optimizer.state_dict()) if optimizer is not None else None

    print("Saving {} model elems:".format(len(model_state_dict)))
    if optimizer_state_dict is not None:
        # Guarded: len(None) would raise TypeError when optimizer is omitted.
        print("Saving {} optim elems:".format(len(optimizer_state_dict)))

    torch.save({
        "model": model_state_dict,
        "optimizer": optimizer_state_dict,
        "args": vars(args),
    }, save_path)
44 changes: 30 additions & 14 deletions glue/runners.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections as col
import logging
import numpy as np
from tqdm import tqdm, trange
Expand Down Expand Up @@ -216,16 +217,27 @@ def __init__(self, model, optimizer, tokenizer, label_list, device, rparams):
self.device = device
self.rparams = rparams

def run_train(self, train_examples):
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", self.rparams.train_batch_size)
logger.info(" Num steps = %d", self.rparams.num_train_steps)
train_dataloader = self.get_train_dataloader(train_examples)
def run_train(self, train_examples, verbose=True):
    """Train the model on train_examples for rparams.num_train_epochs epochs.

    Args:
        train_examples: list of task examples to convert and train on.
        verbose: when False, suppresses the run header logging and the
            feature-conversion logging inside the dataloader (useful when
            called repeatedly, e.g. from run_train_val).
    """
    if verbose:
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_examples))
        logger.info(" Batch size = %d", self.rparams.train_batch_size)
        logger.info(" Num steps = %d", self.rparams.num_train_steps)
    # Dataloader is built once and reused across epochs.
    train_dataloader = self.get_train_dataloader(train_examples, verbose=verbose)

    for _ in trange(int(self.rparams.num_train_epochs), desc="Epoch"):
        self.run_train_epoch(train_dataloader)

def run_train_val(self, train_examples, val_examples, task_name):
    """Train for the configured number of epochs, evaluating after each one.

    Returns an OrderedDict mapping epoch index -> validation result dict
    (with the raw "logits" entry removed to keep the results compact).
    """
    results_by_epoch = col.OrderedDict()
    num_epochs = int(self.rparams.num_train_epochs)
    for epoch_i in trange(num_epochs, desc="Epoch"):
        # Rebuild the dataloader each epoch, quietly (no conversion logging).
        loader = self.get_train_dataloader(train_examples, verbose=False)
        self.run_train_epoch(loader)
        val_result = self.run_val(val_examples, task_name, verbose=False)
        val_result.pop("logits")
        results_by_epoch[epoch_i] = val_result
    return results_by_epoch

def run_train_epoch(self, train_dataloader):
self.model.train()
tr_loss = 0
Expand Down Expand Up @@ -256,8 +268,8 @@ def run_train_epoch(self, train_dataloader):
self.optimizer.zero_grad()
global_step += 1

def run_val(self, val_examples, task_name):
val_dataloader = self.get_eval_dataloader(val_examples)
def run_val(self, val_examples, task_name, verbose=True):
val_dataloader = self.get_eval_dataloader(val_examples, verbose=verbose)
self.model.eval()
total_eval_loss = 0
nb_eval_steps, nb_eval_examples = 0, 0
Expand Down Expand Up @@ -289,8 +301,8 @@ def run_val(self, val_examples, task_name):
"metrics": compute_task_metrics(task_name, all_logits, all_labels),
}

def run_test(self, test_examples):
test_dataloader = self.get_eval_dataloader(test_examples)
def run_test(self, test_examples, verbose=True):
test_dataloader = self.get_eval_dataloader(test_examples, verbose=verbose)
self.model.eval()
all_logits = []
for step, batch in enumerate(tqdm(test_dataloader, desc="Predictions (Test)")):
Expand All @@ -302,9 +314,11 @@ def run_test(self, test_examples):
all_logits = np.concatenate(all_logits, axis=0)
return all_logits

def get_train_dataloader(self, train_examples):
def get_train_dataloader(self, train_examples, verbose=True):
train_features = convert_examples_to_features(
train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer)
train_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
verbose=verbose,
)
train_data, train_tokens = convert_to_dataset(train_features)
if self.rparams.local_rank == -1:
train_sampler = RandomSampler(train_data)
Expand All @@ -315,9 +329,11 @@ def get_train_dataloader(self, train_examples):
)
return HybridLoader(train_dataloader, train_tokens)

def get_eval_dataloader(self, eval_examples):
def get_eval_dataloader(self, eval_examples, verbose=True):
eval_features = convert_examples_to_features(
eval_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer)
eval_examples, self.label_map, self.rparams.max_seq_length, self.tokenizer,
verbose=verbose,
)
eval_data, eval_tokens = convert_to_dataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(
Expand Down
8 changes: 5 additions & 3 deletions glue/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ def get_train_examples(self, data_dir):
def get_dev_examples(self, data_dir):
    """See base class."""
    dev_path = os.path.join(data_dir, "dev.tsv")
    return self._create_examples(self._read_tsv(dev_path), "dev")

def get_test_examples(self, data_dir):
"""See base class."""
Expand All @@ -435,7 +435,7 @@ def get_test_examples(self, data_dir):

def get_labels(self):
    """See base class."""
    # Labels are the string forms of the three class indices.
    return [str(i) for i in range(3)]

def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
Expand Down Expand Up @@ -599,6 +599,7 @@ def _create_examples(self, lines, set_type):
"mnli": MnliProcessor,
"qnli": QnliProcessor,
"rte": RteProcessor,
"wnli": WnliProcessor,
"xnli": XnliProcessor,
"snli": SnliProcessor,
"bcs": BcsProcessor,
Expand All @@ -614,5 +615,6 @@ def _create_examples(self, lines, set_type):
"mnli": "MNLI",
"qnli": "QNLI",
"rte": "RTE",
"wnli": "WNLI",
"snli": "SNLI",
}
Loading

0 comments on commit 401720a

Please sign in to comment.