Different upstream and downstream learning rates #525

Open · wants to merge 2 commits into base: main
s3prl/downstream/runner.py (2 changes: 1 addition & 1 deletion)
@@ -231,7 +231,7 @@ def train(self):
         for entry in self.all_entries:
             if entry.trainable:
                 entry.model.train()
-                trainable_models.append(entry.model)
+                trainable_models.append((entry.name, entry.model))
                 trainable_paras += list(entry.model.parameters())
             else:
                 entry.model.eval()
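Not part of the diff: a minimal sketch of the (name, model) pairs the loop above now collects, and why the name matters to the optimizer helpers changed below. The entry names 'Upstream' and 'Downstream' follow the convention the optimizer code checks against, and the torch.nn.Linear modules are placeholders for the real s3prl models.

import torch

# Placeholders standing in for the actual upstream and downstream models.
upstream = torch.nn.Linear(8, 8)
downstream = torch.nn.Linear(8, 2)

# Each trainable entry is now stored together with its name, so the optimizer
# helpers can route the 'Upstream' parameters into their own parameter group.
trainable_models = [('Upstream', upstream), ('Downstream', downstream)]

for name, model in trainable_models:
    print(name, sum(p.numel() for p in model.parameters()))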
s3prl/optimizers.py (64 changes: 41 additions & 23 deletions)
@@ -27,29 +27,32 @@ def get_optimizer(model_params, total_steps, optimizer_config):
     return optimizer


-def get_grouped_parameters(model_params):
+def get_grouped_parameters(model_params, upstream_lr=None):
     named_params = []
-    for m in model_params:
-        named_params += list(m.named_parameters())
+    for model_type, m in model_params:
+        named_params += [(model_type, n, p) for n, p in m.named_parameters()]

+    upstream_params = {} if upstream_lr is None else {'lr': upstream_lr}
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     grouped_parameters = [
-        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
-        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
+        {'params': [p for t, n, p in named_params if not any(nd in n for nd in no_decay) and t != 'Upstream'], 'weight_decay': 0.01},
+        {'params': [p for t, n, p in named_params if not any(nd in n for nd in no_decay) and t == 'Upstream'], 'weight_decay': 0.01, **upstream_params},
+        {'params': [p for t, n, p in named_params if any(nd in n for nd in no_decay) and t != 'Upstream'], 'weight_decay': 0.0},
+        {'params': [p for t, n, p in named_params if any(nd in n for nd in no_decay) and t == 'Upstream'], 'weight_decay': 0.0, **upstream_params},
     ]
     return grouped_parameters
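Not part of the diff: a small sketch of the PyTorch mechanism the extra 'Upstream' groups rely on. A parameter group that carries its own 'lr' key overrides the optimizer's default learning rate, while groups without one fall back to it. The Linear modules and the 1e-5 value are placeholders.

import torch

upstream = torch.nn.Linear(4, 4)      # placeholder for an upstream model
downstream = torch.nn.Linear(4, 2)    # placeholder for a downstream model

param_groups = [
    {'params': list(downstream.parameters()), 'weight_decay': 0.01},             # uses the default lr
    {'params': list(upstream.parameters()), 'weight_decay': 0.01, 'lr': 1e-5},   # group-specific lr
]
optimizer = torch.optim.AdamW(param_groups, lr=2e-4)

for i, group in enumerate(optimizer.param_groups):
    print(i, group['lr'], group['weight_decay'])   # group 0 keeps 2e-4, group 1 uses 1e-5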


-def get_BertAdam_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, **kwargs):
-    grouped_parameters = get_grouped_parameters(model_params)
+def get_BertAdam_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, upstream_lr=None, **kwargs):
+    grouped_parameters = get_grouped_parameters(model_params, upstream_lr)
     optimizer = BertAdam(grouped_parameters, lr=lr,
                          warmup=warmup_proportion,
                          t_total=total_steps)
     return optimizer


-def get_AdamW_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, **kwargs):
-    grouped_parameters = get_grouped_parameters(model_params)
+def get_AdamW_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, upstream_lr=None, **kwargs):
+    grouped_parameters = get_grouped_parameters(model_params, upstream_lr)
     optimizer = Lamb(grouped_parameters,
                      lr=lr,
                      warmup=warmup_proportion,
@@ -60,8 +63,8 @@ def get_AdamW_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_pro
     return optimizer


-def get_Lamb_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, **kwargs):
-    grouped_parameters = get_grouped_parameters(model_params)
+def get_Lamb_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_proportion=0.07, upstream_lr=None, **kwargs):
+    grouped_parameters = get_grouped_parameters(model_params, upstream_lr)
     optimizer = Lamb(grouped_parameters,
                      lr=lr,
                      warmup=warmup_proportion,
@@ -72,25 +75,40 @@ def get_Lamb_with_schedule(model_params, lr=2e-4, total_steps=20000, warmup_prop
     return optimizer


-def get_Adam(model_params, lr=2e-4, **kwargs):
-    params = []
-    for m in model_params:
-        params += list(m.parameters())
+def get_Adam(model_params, lr=2e-4, upstream_lr=None, **kwargs):
+    params = [{"params": []}]
+    if upstream_lr is not None:
+        params.append({'params': [], "lr": upstream_lr})
+    for t, m in model_params:
+        if t == 'Upstream' and upstream_lr is not None:
+            params[1]['params'] += list(m.parameters())
+        else:
+            params[0]["params"] += list(m.parameters())
     return Adam(params, lr=lr, betas=(0.9, 0.999))
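Not part of the diff: a hedged usage sketch of the updated get_Adam, assuming the patched s3prl.optimizers module is importable; the Linear modules are placeholders. Leaving upstream_lr as None keeps the previous single-group behaviour, while setting it puts the 'Upstream' parameters into a second group with its own learning rate.

import torch
from s3prl.optimizers import get_Adam  # assumes this patched module is on the path

model_params = [
    ('Upstream', torch.nn.Linear(4, 4)),      # placeholder upstream model
    ('Downstream', torch.nn.Linear(4, 2)),    # placeholder downstream model
]

opt_single = get_Adam(model_params, lr=2e-4)                   # one group, everything at 2e-4
opt_split = get_Adam(model_params, lr=2e-4, upstream_lr=1e-5)  # second group for the upstream

print(len(opt_single.param_groups), len(opt_split.param_groups))   # 1 2
print([g['lr'] for g in opt_split.param_groups])                   # [0.0002, 1e-05]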


-def get_AdamW(model_params, lr=2e-4, **kwargs):
-    params = []
-    for m in model_params:
-        params += list(m.parameters())
+def get_AdamW(model_params, lr=2e-4, upstream_lr=None, **kwargs):
+    params = [{"params": []}]
+    if upstream_lr is not None:
+        params.append({'params': [], "lr": upstream_lr})
+    for t, m in model_params:
+        if t == 'Upstream' and upstream_lr is not None:
+            params[1]['params'] += list(m.parameters())
+        else:
+            params[0]["params"] += list(m.parameters())
     optimizer = AdamW(params, lr=lr)
     return optimizer


-def get_TorchOptim(model_params, torch_optim_name, **kwargs):
-    params = []
-    for m in model_params:
-        params += list(m.parameters())
+def get_TorchOptim(model_params, torch_optim_name, upstream_lr=None, **kwargs):
+    params = [{"params": []}]
+    if upstream_lr is not None:
+        params.append({'params': [], "lr": upstream_lr})
+    for t, m in model_params:
+        if t == 'Upstream' and upstream_lr is not None:
+            params[1]['params'] += list(m.parameters())
+        else:
+            params[0]["params"] += list(m.parameters())
     Opt_class = getattr(torch.optim, torch_optim_name)

     kwargs.pop('total_steps')
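Not part of the diff: a hedged sketch of calling the updated get_TorchOptim directly. It assumes the patched s3prl.optimizers module is importable and that the collapsed tail of the function still builds the optimizer as Opt_class(params, **kwargs) after popping total_steps, as in the current file. The Linear modules and learning-rate values are placeholders; in a real run these arguments would come from the downstream optimizer config via get_optimizer.

import torch
from s3prl.optimizers import get_TorchOptim  # assumes this patched module is on the path

model_params = [
    ('Upstream', torch.nn.Linear(4, 4)),      # placeholder upstream model
    ('Downstream', torch.nn.Linear(4, 2)),    # placeholder downstream model
]

optimizer = get_TorchOptim(
    model_params,
    torch_optim_name='AdamW',
    lr=1e-3,              # default (downstream) learning rate
    upstream_lr=1e-5,     # smaller learning rate for the fine-tuned upstream
    total_steps=10000,    # only consumed by kwargs.pop('total_steps')
)
print([g['lr'] for g in optimizer.param_groups])   # expected: [0.001, 1e-05]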