Don't project unmasked tokens for MLM loss (#859)
Summary:
This saves ~4-5 GB of GPU memory while training RoBERTa-large with `seq_len=512`.

I am able to fit `--max-sentences=16` on `volta32gb` for `roberta-large`.
Pull Request resolved: fairinternal/fairseq-py#859

Differential Revision: D17435814

fbshipit-source-id: 2663909768fac0ef0102107613770ee01b1f8c00
Naman Goyal authored and facebook-github-bot committed Sep 18, 2019
1 parent 31dd13f commit 718677e
Showing 2 changed files with 16 additions and 9 deletions.
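To see roughly where the ~4-5 GB comes from, here is a back-of-the-envelope sketch (the batch size, vocabulary size, and ~15% masking rate are illustrative assumptions, not values taken from this commit):

# Rough size of the vocabulary-projection output alone, with and without
# restricting it to masked positions (assumed: batch=16, seq_len=512,
# vocab=50265, fp32 activations, ~15% of positions masked).
batch, seq_len, vocab = 16, 512, 50265
bytes_per_float = 4

full_logits = batch * seq_len * vocab * bytes_per_float
masked_only = int(batch * seq_len * 0.15) * vocab * bytes_per_float

print(f"project every position: {full_logits / 2**30:.2f} GiB")   # ~1.53 GiB
print(f"project masked only:    {masked_only / 2**30:.2f} GiB")   # ~0.23 GiB
# The logits, the log_softmax output, and their gradients each pay this cost,
# which is how the total saving can reach the ~4-5 GB reported above.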
8 changes: 6 additions & 2 deletions fairseq/criterions/masked_lm.py
@@ -30,8 +30,11 @@ def forward(self, model, sample, reduce=True):
         3) logging outputs to display while training
         """
         # compute MLM loss
-        logits = model(**sample['net_input'], return_all_hiddens=False)[0]
+        masked_tokens = sample['target'].ne(self.padding_idx)
+        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
         targets = model.get_targets(sample, [logits])
+        targets = targets[masked_tokens]
+
         loss = F.nll_loss(
             F.log_softmax(
                 logits.view(-1, logits.size(-1)),
@@ -43,7 +46,7 @@ def forward(self, model, sample, reduce=True):
             ignore_index=self.padding_idx,
         )
 
-        sample_size = targets.ne(self.padding_idx).int().sum().item()
+        sample_size = masked_tokens.int().sum().item()
 
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
@@ -64,6 +67,7 @@ def aggregate_logging_outputs(logging_outputs):

         agg_output = {
             'loss': loss / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
             'ntokens': ntokens,
             'nsentences': nsentences,
             'sample_size': sample_size,
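The criterion-side logic in isolation, as a minimal runnable sketch with made-up tensors (the identifiers mirror the diff above, but `padding_idx`, the shapes, and the random logits are assumptions):

import torch
import torch.nn.functional as F

padding_idx = 1                                     # assumed pad index
vocab = 50265
targets = torch.full((2, 8), padding_idx, dtype=torch.long)
targets[0, 3] = 42                                  # pretend these two positions
targets[1, 5] = 7                                   # were masked in the input

# Positions whose target is not padding are exactly the masked positions.
masked_tokens = targets.ne(padding_idx)

# The model now returns logits only for those positions: (num_masked, vocab).
logits = torch.randn(int(masked_tokens.sum()), vocab)

loss = F.nll_loss(
    F.log_softmax(logits.view(-1, logits.size(-1)), dim=-1, dtype=torch.float32),
    targets[masked_tokens].view(-1),
    reduction='sum',
    ignore_index=padding_idx,
)
sample_size = masked_tokens.int().sum().item()      # 2: one loss term per masked token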
17 changes: 10 additions & 7 deletions fairseq/models/roberta/model.py
@@ -201,14 +201,17 @@ def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
         self.weight = weight
         self.bias = nn.Parameter(torch.zeros(output_dim))
 
-    def forward(self, features, **kwargs):
+    def forward(self, features, masked_tokens=None, **kwargs):
+        # Only project the masked tokens while training,
+        # saves both memory and computation
+        if masked_tokens is not None:
+            features = features[masked_tokens, :]
+
         x = self.dense(features)
         x = self.activation_fn(x)
         x = self.layer_norm(x)
-
         # project back to size of vocabulary with bias
         x = F.linear(x, self.weight) + self.bias
-
         return x
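What the `features[masked_tokens, :]` indexing buys, shown on a standalone `nn.Linear` stand-in for the tied vocabulary projection (tiny assumed sizes so the demo runs instantly; only the ratio of rows matters):

import torch
import torch.nn as nn

batch, seq_len, hidden, vocab = 2, 128, 64, 1000
features = torch.randn(batch, seq_len, hidden)
masked_tokens = torch.rand(batch, seq_len) < 0.15        # stand-in for target.ne(pad)

vocab_proj = nn.Linear(hidden, vocab)                    # stand-in for the tied projection

full = vocab_proj(features)                              # (2, 128, 1000): every position
reduced = vocab_proj(features[masked_tokens, :])         # (~38, 1000): masked positions only
print(full.shape, reduced.shape)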


@@ -265,7 +268,7 @@ def __init__(self, args, dictionary):
             weight=self.sentence_encoder.embed_tokens.weight,
         )
 
-    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
+    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
         """
         Args:
             src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@@ -283,7 +286,7 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
"""
         x, extra = self.extract_features(src_tokens, return_all_hiddens)
         if not features_only:
-            x = self.output_layer(x)
+            x = self.output_layer(x, masked_tokens=masked_tokens)
         return x, extra
 
     def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
@@ -293,8 +296,8 @@ def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
         features = inner_states[-1]
         return features, {'inner_states': inner_states if return_all_hiddens else None}
 
-    def output_layer(self, features, **unused):
-        return self.lm_head(features)
+    def output_layer(self, features, masked_tokens=None, **unused):
+        return self.lm_head(features, masked_tokens)
 
     def max_positions(self):
         """Maximum output length supported by the encoder."""
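Because `masked_tokens` defaults to `None` at every level, callers that never pass it (feature extraction, downstream heads) keep the old full-sequence behaviour; only the MLM criterion opts in. A usage sketch with the names from the diff (`model` and `sample` construction elided):

# MLM training: the criterion passes the mask, so logits cover masked positions only.
#   logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
#
# Everything else is unchanged, e.g. feature extraction still returns all positions:
#   features, extra = model(src_tokens, features_only=True)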
