
Commit

Remove F.softmax_with_cross_entropy in paddlenlp (PaddlePaddle#484)
* upgrade F.cross_entropy usage

* fix sample code bug

* fix ppl shape error
LiuChiachi committed Jun 4, 2021
1 parent bac370b commit ff17c36
Showing 7 changed files with 37 additions and 22 deletions.
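Note for reviewers (not part of the diff): every replacement below passes reduction='none' so that F.cross_entropy returns the same per-token loss tensor that F.softmax_with_cross_entropy used to return, leaving the downstream masking, weighting, and scaling code untouched. A minimal equivalence sketch, assuming Paddle 2.x and illustrative shapes:

```python
import paddle
import paddle.nn.functional as F

paddle.seed(0)
logits = paddle.randn([4, 7, 100])           # [batch_size, seq_len, vocab_size]
labels = paddle.randint(0, 100, [4, 7, 1])   # hard labels with a trailing axis of 1

# Old API: per-token loss of shape [4, 7, 1].
old = F.softmax_with_cross_entropy(logits=logits, label=labels, soft_label=False)

# New API: reduction='none' keeps the per-token shape instead of averaging.
new = F.cross_entropy(input=logits, label=labels, soft_label=False,
                      reduction='none')

print(paddle.allclose(old, new))  # expected: Tensor(True)
```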
4 changes: 2 additions & 2 deletions examples/machine_translation/seq2seq/seq2seq_attn.py
@@ -25,8 +25,8 @@ def __init__(self):
         super(CrossEntropyCriterion, self).__init__()

     def forward(self, predict, label, trg_mask):
-        cost = F.softmax_with_cross_entropy(
-            logits=predict, label=label, soft_label=False)
+        cost = F.cross_entropy(
+            input=predict, label=label, soft_label=False, reduction='none')
         cost = paddle.squeeze(cost, axis=[2])
         masked_cost = cost * trg_mask
         batch_mean_cost = paddle.mean(masked_cost, axis=[0])
7 changes: 4 additions & 3 deletions paddlenlp/metrics/perplexity.py
@@ -46,9 +46,10 @@ def __init__(self, name='Perplexity', *args, **kwargs):
         self.total_word_num = 0

     def compute(self, pred, label, seq_mask=None):
-        label = paddle.unsqueeze(label, axis=2)
-        ce = F.softmax_with_cross_entropy(
-            logits=pred, label=label, soft_label=False)
+        if label.dim() == 2:
+            label = paddle.unsqueeze(label, axis=2)
+        ce = F.cross_entropy(
+            input=pred, label=label, reduction='none', soft_label=False)
         ce = paddle.squeeze(ce, axis=[2])
         if seq_mask is not None:
             ce = ce * seq_mask
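The new `if label.dim() == 2` guard is the "fix ppl shape error" item from the commit message: compute() used to unsqueeze unconditionally, which added a spurious axis when callers already pass a [batch_size, seq_len, 1] label. A small sketch of the intended shape handling (illustrative shapes, assuming Paddle 2.x):

```python
import paddle
import paddle.nn.functional as F

pred = paddle.randn([2, 5, 1000])          # [batch_size, seq_len, vocab_size]
label = paddle.randint(0, 1000, [2, 5])    # 2-D hard labels

# Mirrors the updated Perplexity.compute(): only add the trailing axis
# when it is missing, so a 3-D label is never unsqueezed twice.
if label.dim() == 2:
    label = paddle.unsqueeze(label, axis=2)    # -> [2, 5, 1]

ce = F.cross_entropy(input=pred, label=label, reduction='none', soft_label=False)
ce = paddle.squeeze(ce, axis=[2])              # per-token loss, [2, 5]
print(ce.shape)
```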
11 changes: 7 additions & 4 deletions paddlenlp/transformers/bert/modeling.py
@@ -535,9 +535,12 @@ def __init__(self, vocab_size):
     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
         with paddle.static.amp.fp16_guard():
-            masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                prediction_scores, masked_lm_labels, ignore_index=-1)
+            masked_lm_loss = F.cross_entropy(
+                prediction_scores,
+                masked_lm_labels,
+                reduction='none',
+                ignore_index=-1)
             masked_lm_loss = masked_lm_loss / masked_lm_scale
-            next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                seq_relationship_score, next_sentence_labels)
+            next_sentence_loss = F.cross_entropy(
+                seq_relationship_score, next_sentence_labels, reduction='none')
         return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
11 changes: 7 additions & 4 deletions paddlenlp/transformers/bigbird/modeling.py
@@ -862,14 +862,17 @@ def forward(self, prediction_scores, seq_relationship_score,
                     masked_lm_scale, masked_lm_weights)
                 print(loss)
         """
-        masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            prediction_scores, masked_lm_labels, ignore_index=self.ignore_index)
+        masked_lm_loss = F.cross_entropy(
+            prediction_scores,
+            masked_lm_labels,
+            ignore_index=self.ignore_index,
+            reduction='none')
         masked_lm_loss = paddle.transpose(masked_lm_loss, [1, 0])
         masked_lm_loss = paddle.sum(masked_lm_loss * masked_lm_weights) / (
             paddle.sum(masked_lm_weights) + 1e-5)
         scale = 1.0
         if not self.use_nsp:
             scale = 0.0
-        next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-            seq_relationship_score, next_sentence_labels)
+        next_sentence_loss = F.cross_entropy(
+            seq_relationship_score, next_sentence_labels, reduction='none')
         return masked_lm_loss + paddle.mean(next_sentence_loss) * scale
12 changes: 8 additions & 4 deletions paddlenlp/transformers/ernie/modeling.py
@@ -14,6 +14,7 @@

 import paddle
 import paddle.nn as nn
+import paddle.nn.functional as F

 from .. import PretrainedModel, register_base_model
@@ -772,9 +773,12 @@ def __init__(self, vocab_size):
     def forward(self, prediction_scores, seq_relationship_score,
                 masked_lm_labels, next_sentence_labels, masked_lm_scale):
         with paddle.static.amp.fp16_guard():
-            masked_lm_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                prediction_scores, masked_lm_labels, ignore_index=-1)
+            masked_lm_loss = F.cross_entropy(
+                prediction_scores,
+                masked_lm_labels,
+                ignore_index=-1,
+                reduction='none')
             masked_lm_loss = masked_lm_loss / masked_lm_scale
-            next_sentence_loss = paddle.nn.functional.softmax_with_cross_entropy(
-                seq_relationship_score, next_sentence_labels)
+            next_sentence_loss = F.cross_entropy(
+                seq_relationship_score, next_sentence_labels, reduction='none')
         return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
7 changes: 5 additions & 2 deletions paddlenlp/transformers/ernie_gen/modeling.py
@@ -607,7 +607,10 @@ def forward(self, *args, **kwargs):
         if len(tgt_labels.shape) == 1:
             tgt_labels = paddle.reshape(tgt_labels, [-1, 1])

-        loss = paddle.nn.functional.cross_entropy(
-            logits_2d, tgt_labels, soft_label=(tgt_labels.shape[-1] != 1))
+        loss = F.cross_entropy(
+            logits_2d,
+            tgt_labels,
+            reduction="none",
+            soft_label=(tgt_labels.shape[-1] != 1))

         return loss, logits_2d, info
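The `soft_label=(tgt_labels.shape[-1] != 1)` switch kept above lets the same criterion accept either hard class ids (trailing axis of 1) or soft distributions over the vocabulary; only the reduction="none" argument is new in this commit. A hedged sketch of both call shapes (illustrative sizes, assuming Paddle 2.x):

```python
import paddle
import paddle.nn.functional as F

vocab_size = 50
logits_2d = paddle.randn([8, vocab_size])    # flattened [tokens, vocab_size]

# Hard targets: one class id per token, trailing axis of 1 -> soft_label=False.
hard = paddle.randint(0, vocab_size, [8, 1])
loss_hard = F.cross_entropy(
    logits_2d, hard, reduction="none", soft_label=(hard.shape[-1] != 1))

# Soft targets: a full distribution per token -> soft_label=True.
soft = F.softmax(paddle.randn([8, vocab_size]), axis=-1)
loss_soft = F.cross_entropy(
    logits_2d, soft, reduction="none", soft_label=(soft.shape[-1] != 1))

print(loss_hard.shape, loss_soft.shape)      # both are per-token losses
```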
7 changes: 4 additions & 3 deletions paddlenlp/transformers/transformer/modeling.py
@@ -252,7 +252,7 @@ def forward(self, predict, label):
             label = paddle.randint(
                 low=3,
                 high=vocab_size,
-                shape=[batch_size, seq_len, vocab_size])
+                shape=[batch_size, seq_len, 1])
             criterion(predict, label)
         """
@@ -265,9 +265,10 @@ def forward(self, predict, label):
                     x=label, num_classes=predict.shape[-1]),
                 epsilon=self.label_smooth_eps)

-        cost = F.softmax_with_cross_entropy(
-            logits=predict,
+        cost = F.cross_entropy(
+            input=predict,
             label=label,
+            reduction='none',
             soft_label=True if self.label_smooth_eps else False)
         weighted_cost = cost * weights
         sum_cost = paddle.sum(weighted_cost)
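When label_smooth_eps is non-zero, the criterion above one-hots and smooths the label, so the soft_label=True path of F.cross_entropy is exercised; with reduction='none' it still yields a per-token cost that can be weighted and summed as before. A standalone sketch of that path (illustrative sizes, assuming Paddle 2.x; the explicit squeeze is my own simplification, not taken from the diff):

```python
import paddle
import paddle.nn.functional as F

batch_size, seq_len, vocab_size = 2, 4, 32
label_smooth_eps = 0.1

predict = paddle.randn([batch_size, seq_len, vocab_size])
label = paddle.randint(3, vocab_size, [batch_size, seq_len, 1])

# Build the smoothed one-hot target distribution, as the criterion does
# when label_smooth_eps > 0.
soft_target = F.label_smooth(
    label=F.one_hot(
        x=paddle.squeeze(label, axis=[2]), num_classes=vocab_size),
    epsilon=label_smooth_eps)

cost = F.cross_entropy(
    input=predict, label=soft_target, reduction='none', soft_label=True)
print(cost.shape)  # per-token cost, compatible with per-token weighting
```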
