Skip to content

Commit

Permalink
Streamlined the sample_weight docs example to reduce what user must s…
Browse files Browse the repository at this point in the history
…et externally.
  • Loading branch information
ZaydH committed May 1, 2019
1 parent 5b9b8b7 commit fcbb462
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions docs/user/FAQ.rst
Expand Up @@ -99,16 +99,19 @@ dictionary. Below, there is example code on how to achieve this:
...
class MyNet(NeuralNet):
def __init__(self, *args, criterion__reduce=False, **kwargs):
# make sure to set reduce=False in your criterion, since we need the loss
# for each sample so that it can be weighted
super().__init__(*args, criterion__reduce=criterion__reduce, **kwargs)
def get_loss(self, y_pred, y_true, X, *args, **kwargs):
# override get_loss to use the sample_weight from X
loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
sample_weight = X['sample_weight']
loss_reduced = (sample_weight * loss_unreduced).mean()
return loss_reduced
# make sure to pass reduce=False to your criterion, since we need the loss
# for each sample so that it can be weighted
net = MyNet(MyModule, ..., criterion__reduce=False)
net = MyNet(MyModule, ...)
net.fit(X, y)
Expand Down

0 comments on commit fcbb462

Please sign in to comment.