support more optimizers and add more unit tests (#40)
* simplify optimizer

* support all pytorch optimizers

* fix optimizer setting

* add optimizer tests

* add step test for bert adam

* add coveragerc to ignore contrib

* add module tests

* add module tests 2

* add module tests 3

* add module tests 4

* add module tests 5

* add test utils 1

* add test utils 2

* add scheduler tests

* add model test

* update CHANGELOG

* update doc
senwu committed Dec 1, 2019
1 parent c33b4e0 commit bc201e9
Showing 35 changed files with 1,742 additions and 80 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
@@ -0,0 +1,4 @@
[run]
omit =
# omit anything in the contrib directory
*/contrib/*
2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -3,6 +3,8 @@ Unreleased_

Added
^^^^^
* `@senwu`_: Support more unit tests.
* `@senwu`_: Support all pytorch optimizers.
* `@senwu`_: Support accuracy@k metric.
* `@senwu`_: Support cosine annealing lr scheduler.

49 changes: 43 additions & 6 deletions docs/user/config.rst
@@ -38,26 +38,61 @@ The default ``.emmental-config.yaml`` configuration file is shown below::
train_split: train # the split for training, accepts str or list of strs
valid_split: valid # the split for validation, accepts str or list of strs
test_split: test # the split for testing, accepts str or list of strs
ignore_index: 0 # the ignore index, used for masking samples
ignore_index: # the ignore index, used for masking samples
optimizer_config:
optimizer: adam # [sgd, adam, adamax, bert_adam]
lr: 0.001 # Learning rate
l2: 0.0 # l2 regularization
grad_clip: 1.0 # gradient clipping
sgd_config:
momentum: 0.9
grad_clip: # gradient clipping
asgd_config:
lambd: 0.0001
alpha: 0.75
t0: 1000000.0
adadelta_config:
rho: 0.9
eps: 0.000001
adagrad_config:
lr_decay: 0
initial_accumulator_value: 0
eps: 0.0000000001
adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
amsgrad: False
adamw_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
amsgrad: False
adamax_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
lbfgs_config:
max_iter: 20
max_eval:
tolerance_grad: 0.0000001
tolerance_change: 0.000000001
history_size: 100
line_search_fn:
rms_prop_config:
alpha: 0.99
eps: 0.00000001
momentum: 0
centered: False
r_prop_config:
etas: !!python/tuple [0.5, 1.2]
step_sizes: !!python/tuple [0.000001, 50]
sgd_config:
momentum: 0
dampening: 0
nesterov: False
sparse_adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
bert_adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
lr_scheduler_config:
lr_scheduler: # [linear, exponential, reduce_on_plateau]
lr_scheduler: # [linear, exponential, reduce_on_plateau, cosine_annealing]
warmup_steps: # warm up steps
warmup_unit: batch # [epoch, batch]
warmup_percentage: # warm up percentage
@@ -79,6 +114,8 @@ The default ``.emmental-config.yaml`` configuration file is shown below::
- 1000
gamma: 0.1
last_epoch: -1
cosine_annealing_config:
last_epoch: -1
task_scheduler_config:
task_scheduler: round_robin # [sequential, round_robin, mixed]
sequential_scheduler_config:
@@ -92,7 +129,7 @@ The default ``.emmental-config.yaml`` configuration file is shown below::
# Logging configuration
logging_config:
counter_unit: epoch # [epoch, batch]
evaluation_freq: 2
evaluation_freq: 1
writer_config:
writer: tensorboard # [json, tensorboard]
verbose: True
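As a usage sketch: the ``optimizer`` key selects the optimizer, and only the matching ``<name>_config`` block is forwarded to its constructor, so overriding a single sub-block is enough. A minimal, hedged Python example of overriding these defaults at runtime — ``emmental.init`` and ``Meta.update_config`` are assumed entry points here (adjust to the API of the Emmental version you use), and the values are illustrative:

    import emmental
    from emmental import Meta

    emmental.init("logs")  # assumed: loads the default config shown above

    # Partial override: switch to AdamW and adjust its sub-config.
    # Keys that are not listed keep their defaults.
    Meta.update_config(
        config={
            "learner_config": {
                "optimizer_config": {
                    "optimizer": "adamw",
                    "lr": 3e-4,
                    "adamw_config": {"betas": (0.9, 0.999), "eps": 1e-6},
                }
            }
        }
    )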
47 changes: 42 additions & 5 deletions docs/user/learning.rst
@@ -28,26 +28,61 @@ The learning parameters of the model are described below::
train_split: train # the split for training, accepts str or list of strs
valid_split: valid # the split for validation, accepts str or list of strs
test_split: test # the split for testing, accepts str or list of strs
ignore_index: 0 # the ignore index, used for masking samples
ignore_index: # the ignore index, used for masking samples
optimizer_config:
optimizer: adam # [sgd, adam, adamax, bert_adam]
lr: 0.001 # Learning rate
l2: 0.0 # l2 regularization
grad_clip: 1.0 # gradient clipping
sgd_config:
momentum: 0.9
grad_clip: # gradient clipping
asgd_config:
lambd: 0.0001
alpha: 0.75
t0: 1000000.0
adadelta_config:
rho: 0.9
eps: 0.000001
adagrad_config:
lr_decay: 0
initial_accumulator_value: 0
eps: 0.0000000001
adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
amsgrad: False
adamw_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
amsgrad: False
adamax_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
lbfgs_config:
max_iter: 20
max_eval:
tolerance_grad: 0.0000001
tolerance_change: 0.000000001
history_size: 100
line_search_fn:
rms_prop_config:
alpha: 0.99
eps: 0.00000001
momentum: 0
centered: False
r_prop_config:
etas: !!python/tuple [0.5, 1.2]
step_sizes: !!python/tuple [0.000001, 50]
sgd_config:
momentum: 0
dampening: 0
nesterov: False
sparse_adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
bert_adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
lr_scheduler_config:
lr_scheduler: # [linear, exponential, reduce_on_plateau]
lr_scheduler: # [linear, exponential, reduce_on_plateau, cosine_annealing]
warmup_steps: # warm up steps
warmup_unit: batch # [epoch, batch]
warmup_percentage: # warm up percentage
@@ -69,6 +104,8 @@ The learning parameters of the model are described below::
- 1000
gamma: 0.1
last_epoch: -1
cosine_annealing_config:
last_epoch: -1
task_scheduler_config:
task_scheduler: round_robin # [sequential, round_robin, mixed]
sequential_scheduler_config:
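The new ``cosine_annealing`` option presumably maps onto PyTorch's ``CosineAnnealingLR``. A standalone sketch of what that schedule does, outside Emmental — the toy optimizer, the 1000-step horizon (``T_max``), and the printout are illustrative assumptions, while ``last_epoch=-1`` mirrors the ``cosine_annealing_config`` default:

    import torch
    from torch import nn, optim

    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # Anneal the learning rate along a cosine curve over 1000 steps.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=1000, last_epoch=-1
    )

    for step in range(3):
        optimizer.step()    # normally follows loss.backward()
        scheduler.step()    # lr decays from 0.001 toward 0
        print(step, optimizer.param_groups[0]["lr"])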
2 changes: 1 addition & 1 deletion docs/user/logging.rst
@@ -25,7 +25,7 @@ The logging parameters of Emmental are described below::
# Logging configuration
logging_config:
counter_unit: epoch # [epoch, batch]
evaluation_freq: 2
evaluation_freq: 1
writer_config:
writer: tensorboard # [json, tensorboard]
verbose: True
41 changes: 37 additions & 4 deletions src/emmental/emmental-default-config.yaml
@@ -28,17 +28,50 @@ learner_config:
lr: 0.001 # Learning rate
l2: 0.0 # l2 regularization
grad_clip: # gradient clipping
sgd_config:
momentum: 0.9
dampening: 0
nesterov: False
asgd_config:
lambd: 0.0001
alpha: 0.75
t0: 1000000.0
adadelta_config:
rho: 0.9
eps: 0.000001
adagrad_config:
lr_decay: 0
initial_accumulator_value: 0
eps: 0.0000000001
adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
amsgrad: False
adamw_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
amsgrad: False
adamax_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
lbfgs_config:
max_iter: 20
max_eval:
tolerance_grad: 0.0000001
tolerance_change: 0.000000001
history_size: 100
line_search_fn:
rms_prop_config:
alpha: 0.99
eps: 0.00000001
momentum: 0
centered: False
r_prop_config:
etas: !!python/tuple [0.5, 1.2]
step_sizes: !!python/tuple [0.000001, 50]
sgd_config:
momentum: 0
dampening: 0
nesterov: False
sparse_adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
bert_adam_config:
betas: !!python/tuple [0.9, 0.999]
eps: 0.00000001
48 changes: 25 additions & 23 deletions src/emmental/learner.py
@@ -52,41 +52,43 @@ def _set_optimizer(self, model: EmmentalModel) -> None:
model(EmmentalModel): The model to set up the optimizer.
"""

# TODO: add more optimizer support and fp16
optimizer_config = Meta.config["learner_config"]["optimizer_config"]
opt = optimizer_config["optimizer"]

parameters = filter(lambda p: p.requires_grad, model.parameters())

if opt == "sgd":
optimizer = optim.SGD(
parameters,
lr=optimizer_config["lr"],
weight_decay=optimizer_config["l2"],
**optimizer_config["sgd_config"],
)
elif opt == "adam":
optimizer = optim.Adam(
optim_dict = {
# PyTorch optimizer
"asgd": optim.ASGD,
"adadelta": optim.Adadelta,
"adagrad": optim.Adagrad,
"adam": optim.Adam,
"adamw": optim.AdamW,
"adamax": optim.Adamax,
"lbfgs": optim.LBFGS,
"rms_prop": optim.RMSprop,
"r_prop": optim.Rprop,
"sgd": optim.SGD,
"sparse_adam": optim.SparseAdam,
# Custom optimizer
"bert_adam": BertAdam,
}

if opt in ["lbfgs", "r_prop", "sparse_adam"]:
optimizer = optim_dict[opt](
parameters,
lr=optimizer_config["lr"],
weight_decay=optimizer_config["l2"],
**optimizer_config["adam_config"],
)
elif opt == "adamax":
optimizer = optim.Adamax(
parameters,
lr=optimizer_config["lr"],
weight_decay=optimizer_config["l2"],
**optimizer_config["adamax_config"],
**optimizer_config[f"{opt}_config"],
)
elif opt == "bert_adam":
optimizer = BertAdam(
elif opt in optim_dict.keys():
optimizer = optim_dict[opt](
parameters,
lr=optimizer_config["lr"],
weight_decay=optimizer_config["l2"],
**optimizer_config["bert_adam_config"],
**optimizer_config[f"{opt}_config"],
)
elif isinstance(opt, optim.Optimizer):
optimizer = opt(parameters)
else:
raise ValueError(f"Unrecognized optimizer option '{opt}'")

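The refactor above replaces one branch per optimizer with a name-to-class lookup plus an ``f"{opt}_config"`` naming convention. A condensed, standalone sketch of the same pattern (plain ``torch.optim`` only, ``BertAdam`` omitted; the config dict is illustrative) — ``lbfgs``, ``r_prop``, and ``sparse_adam`` are presumably special-cased because those PyTorch optimizers take no ``weight_decay`` argument:

    from torch import nn, optim

    optim_dict = {
        "adam": optim.Adam,
        "adamw": optim.AdamW,
        "sgd": optim.SGD,
        "lbfgs": optim.LBFGS,
        "r_prop": optim.Rprop,
        "sparse_adam": optim.SparseAdam,
    }

    # Illustrative flattened config mirroring the YAML defaults above.
    config = {
        "optimizer": "adamw",
        "lr": 0.001,
        "l2": 0.0,
        "adamw_config": {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False},
    }

    model = nn.Linear(10, 2)
    parameters = [p for p in model.parameters() if p.requires_grad]
    opt = config["optimizer"]

    if opt in ("lbfgs", "r_prop", "sparse_adam"):
        # These optimizers have no weight_decay parameter.
        optimizer = optim_dict[opt](
            parameters, lr=config["lr"], **config.get(f"{opt}_config", {})
        )
    else:
        optimizer = optim_dict[opt](
            parameters,
            lr=config["lr"],
            weight_decay=config["l2"],
            **config.get(f"{opt}_config", {}),
        )

    print(type(optimizer).__name__)  # AdamW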
11 changes: 7 additions & 4 deletions src/emmental/modules/rnn_module.py
@@ -105,12 +105,15 @@ def forward(self, x: Tensor, x_mask: Optional[Tensor] = None) -> Tensor:
"""
Mean pooling
"""
if x_mask is None:
x_mask = x.new_full(x.size()[:2], fill_value=0, dtype=torch.uint8)
x_lens = x_mask.data.eq(0).long().sum(dim=1)
weights = torch.ones(x.size()) / x_lens.unsqueeze(1).float()
weights = (
output_word.new_ones(output_word.size())
/ x_lens.view(x_lens.size()[0], 1, 1).float()
)
weights.data.masked_fill_(x_mask.data.unsqueeze(dim=2), 0.0)
word_vectors = torch.bmm(
output_word.transpose(1, 2), weights.unsqueeze(2)
).squeeze(2)
word_vectors = torch.sum(output_word * weights, dim=1)
output = self.linear(word_vectors) if self.final_linear else word_vectors

return output
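The pooling change above replaces a ``bmm`` against a weight vector with an elementwise multiply-and-sum over per-token weights. A small self-contained sketch of masked mean pooling in that style (tensor names echo the module, but the shapes and data are made up):

    import torch

    batch, seq_len, dim = 2, 5, 4
    output_word = torch.randn(batch, seq_len, dim)   # per-token RNN outputs
    # In the module, 1 marks padding positions and 0 marks real tokens.
    x_mask = torch.tensor([[0, 0, 0, 1, 1],
                           [0, 0, 1, 1, 1]], dtype=torch.bool)

    x_lens = x_mask.eq(0).long().sum(dim=1)          # real-token counts: [3, 2]
    weights = output_word.new_ones(output_word.size()) / x_lens.view(-1, 1, 1).float()
    weights = weights.masked_fill(x_mask.unsqueeze(2), 0.0)

    # (batch, dim): mean over the unmasked tokens of each sequence.
    word_vectors = torch.sum(output_word * weights, dim=1)
    assert torch.allclose(word_vectors[0], output_word[0, :3].mean(dim=0), atol=1e-5)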