Skip to content

Commit

Permalink
More flexible gradient normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
myleott committed Oct 19, 2017
1 parent 88a8bd4 commit 3f97008
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 63 deletions.
26 changes: 18 additions & 8 deletions fairseq/criterions/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,29 @@ def __init__(self, padding_idx):
super().__init__()
self.padding_idx = padding_idx

def grad_denom(self, samples):
return sum(s['ntokens'] if s else 0 for s in samples)

def forward(self, model, sample, grad_denom):
def forward(self, model, sample):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = net_output.view(-1, net_output.size(-1))
target = sample['target'].view(-1)
loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
return {
'loss': loss / grad_denom,
sample_size = sample['ntokens']
logging_output = {
'loss': loss.data[0],
'sample_size': sample_size,
}
return loss, sample_size, logging_output

def aggregate(self, loss_dicts):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
}
22 changes: 15 additions & 7 deletions fairseq/criterions/fairseq_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,22 @@ class FairseqCriterion(_Loss):
def __init__(self):
super().__init__()

def grad_denom(self, samples):
"""Gradient normalization term for DataParallel training."""
raise NotImplementedError
def forward(self, model, sample):
"""Compute the loss for the given sample.
def forward(self, model, sample, grad_denom):
"""Compute the loss for the given sample and network output."""
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
raise NotImplementedError

def aggregate(self, losses, log_infos):
"""Aggregate losses from DataParallel training."""
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError

@staticmethod
def grad_denom(sample_sizes):
"""Compute the gradient denominator for a set of sample sizes."""
return sum(sample_sizes)
26 changes: 18 additions & 8 deletions fairseq/criterions/label_smoothed_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,29 @@ def __init__(self, eps, padding_idx=None, weights=None):
self.padding_idx = padding_idx
self.weights = weights

def grad_denom(self, samples):
return sum(s['ntokens'] if s else 0 for s in samples)

def forward(self, model, sample, grad_denom):
def forward(self, model, sample):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss, as a Variable
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
target = sample['target'].view(-1)
loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
return {
'loss': loss / grad_denom,
sample_size = sample['ntokens']
logging_output = {
'loss': loss.data[0],
'sample_size': sample_size,
}
return loss, sample_size, logging_output

def aggregate(self, loss_dicts):
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
}
87 changes: 47 additions & 40 deletions fairseq/multiprocessing_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

from fairseq import nccl, utils
from fairseq.criterions import FairseqCriterion
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG

Expand Down Expand Up @@ -74,6 +73,7 @@ def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
momentum=self.args.momentum,
weight_decay=self.args.weight_decay)
self.flat_grads = None
self.loss = None

# initialize LR scheduler
self.lr_scheduler = self._build_lr_scheduler()
Expand Down Expand Up @@ -136,35 +136,44 @@ def train_step(self, samples):
# scatter sample across GPUs
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)

# calculate gradient normalization term
grad_denom = self.criterion.grad_denom(samples)
# forward pass
sample_sizes, logging_outputs = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
])

# forward pass, backward pass and gradient step
losses = [
self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms = Future.gen_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
]
])

# aggregate losses and gradient norms
loss_dicts = Future.gen_list(losses)
loss_dict = self.criterion.aggregate(loss_dicts)
loss_dict['gnorm'] = loss_dicts[0]['gnorm']
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['gnorm'] = grad_norms[0] # log the gradient norm

return logging_output

def _async_forward(self, rank, device_id, eval=False):
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()

return loss_dict
if self._sample is None:
return 0, {}

def _async_train_step(self, rank, device_id, grad_denom):
self.model.train()
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)

# zero grads even if self._sample is None, since we will all-reduce them
self.optimizer.zero_grad()
return sample_size, logging_output

# calculate loss and grads
loss = 0
loss_dict = {}
if self._sample is not None:
loss_dict = self.criterion(self.model, self._sample, grad_denom)
loss_dict['loss'].backward()
loss = loss_dict['loss'].data[0]
def _async_backward_and_opt(self, rank, device_id, grad_denom):
if self.loss is not None:
# backward pass
self.loss.backward()

# flatten grads into a contiguous block of memory
if self.flat_grads is None:
Expand All @@ -173,13 +182,20 @@ def _async_train_step(self, rank, device_id, grad_denom):
# all-reduce grads
nccl.all_reduce(self.flat_grads)

# normalize grads
if grad_denom != 0:
self.flat_grads.div_(grad_denom)

# clip grads
loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)

# take an optimization step
self.optimizer.step()

return loss_dict
# reset loss
self.loss = None

return grad_norm

def _flatten_grads_(self, model):
num_params = sum(p.data.numel() for p in model.parameters())
Expand All @@ -206,25 +222,16 @@ def valid_step(self, samples):
# scatter sample across GPUs
self._scatter_samples(samples, volatile=True)

# calculate gradient normalization term
grad_denom = self.criterion.grad_denom(samples)

# forward pass
losses = [
self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
_sample_sizes, logging_outputs = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
]

# aggregate losses
loss_dict = self.criterion.aggregate(Future.gen_list(losses))
])

return loss_dict
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

def _async_valid_step(self, rank, device_id, grad_denom):
if self._sample is None:
return {}
self.model.eval()
return self.criterion(self.model, self._sample, grad_denom)
return logging_output

def get_lr(self):
"""Get the current learning rate."""
Expand Down

0 comments on commit 3f97008

Please sign in to comment.