Add large_margin as an option for the loss.
Add a note on a large margin experiment (no improvement)
AngledLuffa committed Nov 30, 2022
1 parent fcad885 commit 5edd724
Showing 3 changed files with 79 additions and 1 deletion.
68 changes: 68 additions & 0 deletions stanza/models/common/large_margin_loss.py
@@ -0,0 +1,68 @@
"""
LargeMarginInSoftmax, from the article
@inproceedings{kobayashi2019bmvc,
title={Large Margin In Softmax Cross-Entropy Loss},
author={Takumi Kobayashi},
booktitle={Proceedings of the British Machine Vision Conference (BMVC)},
year={2019}
}
implementation from
https://github.com/tk1980/LargeMarginInSoftmax
No license is specified for that repository; the authors ask that the paper be cited if the work is useful.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F


class LargeMarginInSoftmaxLoss(nn.CrossEntropyLoss):
r"""
This combines the Softmax Cross-Entropy Loss (nn.CrossEntropyLoss) and the large-margin inducing
regularization proposed in
T. Kobayashi, "Large-Margin In Softmax Cross-Entropy Loss." In BMVC2019.
    This loss function accepts the same parameters as nn.CrossEntropyLoss, plus `reg_lambda` and `deg_logit`.
    Args:
        reg_lambda (float, optional): a regularization parameter. (default: 0.3)
        deg_logit (float, optional): if not None, underestimate (degrade) the target logit by this amount,
            e.g. 1.0 subtracts 1 from the target logit. (default: None)
            This realizes the variant that combines the degraded-logit (modified) loss with the
            large-margin regularization, as described in Table 4 of the paper above.
"""
def __init__(self, reg_lambda=0.3, deg_logit=None,
weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
super(LargeMarginInSoftmaxLoss, self).__init__(weight=weight, size_average=size_average,
ignore_index=ignore_index, reduce=reduce, reduction=reduction)
self.reg_lambda = reg_lambda
self.deg_logit = deg_logit

def forward(self, input, target):
N = input.size(0) # number of samples
C = input.size(1) # number of classes
        # one-hot mask marking the target class in each row
        Mask = torch.zeros_like(input, requires_grad=False)
        Mask[range(N), target] = 1

if self.deg_logit is not None:
input = input - self.deg_logit * Mask

loss = F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction)

        # mask out the target logit, then push the softmax over the remaining
        # classes towards the uniform distribution 1/(C-1); minimizing this term
        # is what induces the larger margin around the target class
        X = input - 1.e6 * Mask   # [N x C], excluding the target class
        reg = 0.5 * ((F.softmax(X, dim=1) - 1.0/(C-1)) * F.log_softmax(X, dim=1) * (1.0-Mask)).sum(dim=1)
        if self.reduction == 'sum':
            reg = reg.sum()
        elif self.reduction == 'mean':
            reg = reg.mean()
        # reduction == 'none' keeps the per-sample regularization values

return loss + self.reg_lambda * reg
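
For reference, here is a minimal usage sketch (not part of this commit; the tensors are made up) exercising the new loss on dummy logits. With deg_logit left at None, the result is exactly the plain cross-entropy term plus reg_lambda times the regularizer computed above.

import torch
import torch.nn as nn

from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss

torch.manual_seed(0)
logits = torch.randn(4, 10)            # 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))   # gold class indices

plain_ce = nn.CrossEntropyLoss(reduction='sum')
large_margin = LargeMarginInSoftmaxLoss(reduction='sum', reg_lambda=0.3)

base = plain_ce(logits, targets)
total = large_margin(logits, targets)

# with deg_logit=None, total == base + 0.3 * (regularization term), so the
# difference printed below is exactly the (scaled) large-margin regularizer
print(base.item(), total.item(), (total - base).item())
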
7 changes: 7 additions & 0 deletions stanza/models/constituency/trainer.py
@@ -24,6 +24,7 @@
from stanza.models.common import pretrain
from stanza.models.common import utils
from stanza.models.common.foundation_cache import load_bert, load_charlm, load_pretrain, FoundationCache
from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss
from stanza.models.constituency import parse_transitions
from stanza.models.constituency import parse_tree
from stanza.models.constituency import transition_sequence
@@ -659,6 +660,12 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
logger.info("Building FocalLoss, gamma=%f", args['loss_focal_gamma'])
process_outputs = lambda x: torch.softmax(x, dim=1)
model_loss_function = FocalLoss(reduction='sum', gamma=args['loss_focal_gamma'])
elif args['loss'] == 'large_margin':
logger.info("Building LargeMarginInSoftmaxLoss(sum)")
process_outputs = lambda x: x
model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum')
else:
raise ValueError("Unexpected loss term: %s" % args['loss'])
if args['cuda']:
model_loss_function.cuda()

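A side note on the wiring above (a sketch, not from the commit): process_outputs is the identity for the large_margin option because LargeMarginInSoftmaxLoss subclasses nn.CrossEntropyLoss and therefore consumes raw, unnormalized logits, whereas the focal branch feeds it softmaxed probabilities. Assuming outputs are per-transition scores of shape [batch, num_transitions] and gold holds gold transition indices (the names and shapes here are illustrative), the pairing is used roughly like this:

import torch
from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss

process_outputs = lambda x: x                        # raw logits, no softmax
model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum')

outputs = torch.randn(8, 20, requires_grad=True)     # fake transition scores
gold = torch.randint(0, 20, (8,))                    # fake gold transition indices

loss = model_loss_function(process_outputs(outputs), gold)
loss.backward()    # gradients reach the raw scores just as with plain cross-entropy
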
5 changes: 4 additions & 1 deletion stanza/models/constituency_parser.py
@@ -394,7 +394,10 @@ def parse_args(args=None):
parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')

parser.add_argument('--loss', default='cross', help='cross or focal. Focal requires `pip install focal_loss_torch`')
# Large Margin is from Large Margin In Softmax Cross-Entropy Loss
# it did not help on an Italian VIT test
# scores went from 0.8252 to 0.8248
parser.add_argument('--loss', default='cross', help='cross, large_margin, or focal. Focal requires `pip install focal_loss_torch`')
parser.add_argument('--loss_focal_gamma', default=2, type=float, help='gamma value for a focal loss')

# When using word_dropout and predict_dropout in conjunction with relu, one particular experiment produced the following dev scores after 300 iterations:
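To confirm the flag plumbing end to end, a quick sketch (assuming parse_args can be called with only this option, everything else taking its default, and that it returns the dict-style args the trainer indexes as args['loss'] above):

from stanza.models import constituency_parser

args = constituency_parser.parse_args(['--loss', 'large_margin'])
assert args['loss'] == 'large_margin'

Since --loss is a free-form string (no choices= restriction), an unsupported value is only rejected later, by the ValueError check in iterate_training shown above.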
