Skip to content
Permalink
Browse files

Streamlined the sample_weight docs example to reduce what user must s…

…et externally.
  • Loading branch information...
ZaydH committed May 1, 2019
1 parent 5b9b8b7 commit fcbb4620c7796df2d7c271067334962f17c5d38b
Showing with 6 additions and 3 deletions.
  1. +6 −3 docs/user/FAQ.rst
@@ -99,16 +99,19 @@ dictionary. Below, there is example code on how to achieve this:
...
class MyNet(NeuralNet):
def __init__(self, *args, criterion__reduce=False, **kwargs):
# make sure to set reduce=False in your criterion, since we need the loss
# for each sample so that it can be weighted
super().__init__(*args, criterion__reduce=criterion__reduce, **kwargs)
def get_loss(self, y_pred, y_true, X, *args, **kwargs):
# override get_loss to use the sample_weight from X
loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
sample_weight = X['sample_weight']
loss_reduced = (sample_weight * loss_unreduced).mean()
return loss_reduced
# make sure to pass reduce=False to your criterion, since we need the loss
# for each sample so that it can be weighted
net = MyNet(MyModule, ..., criterion__reduce=False)
net = MyNet(MyModule, ...)
net.fit(X, y)

0 comments on commit fcbb462

Please sign in to comment.
You can’t perform that action at this time.