
nn.functional.binary_cross_entropy_with_logits raises an error when used with 'weight' #19254

Closed
ht-alchera opened this issue Apr 15, 2019 · 1 comment

Comments

@ht-alchera

🐛 Bug

I recently moved to PyTorch 1.0.1, but I get the error below when I use binary_cross_entropy_with_logits:

RuntimeError: the derivative for 'weight' is not implemented

My code works fine with PyTorch 0.4.1.
I'm using CUDA 9.0.176, cuDNN 7.4.1, Ubuntu 16.04 (gcc 5.4.0).

To Reproduce

Steps to reproduce the behavior:

  1. This is my code:

import torch.nn as nn
import torch
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, targets):
        batch, classes = preds.shape
        # Build a one-hot encoding of the integer class targets.
        onehot_target = torch.zeros(batch, classes, dtype=torch.float32, device=preds.device)
        onehot_target[torch.arange(batch).long(), targets] = 1.0

        prob = preds.sigmoid()
        pt = prob * onehot_target + (1 - prob) * (1 - onehot_target)                   # pt = p if t > 0 else 1-p
        weights = self.alpha * onehot_target + (1 - self.alpha) * (1 - onehot_target)  # w = alpha if t > 0 else 1-alpha
        weights = weights * (1 - pt).pow(self.gamma)

        loss = F.binary_cross_entropy_with_logits(input=preds, target=onehot_target, weight=weights, reduction='none')
        loss = loss / batch

        return loss
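
A minimal call that reproduces the error with the class above (the shapes and the alpha/gamma values here are illustrative, not from my training code):

preds = torch.randn(4, 10, requires_grad=True)  # logits, as they would come from a network
targets = torch.randint(0, 10, (4,))            # integer class labels
criterion = FocalLoss(alpha=0.25, gamma=2.0)
loss = criterion(preds, targets)  # raises: RuntimeError: the derivative for 'weight' is not implemented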

Expected behavior

The loss should be computed without error, as it was under 0.4.1. Instead I get this traceback:

Traceback (most recent call last):
File "train.py", line 503, in
train(iter_epoch)
File "train.py", line 278, in train
seq_loss = seq_criterion(preds, targets)
File "/home/hotaek/anaconda3/envs/pytorch1-py368/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call
result = self.forward(*input, **kwargs)
File "/home/hotaek/Projects/0_Kepco/0_Recog_Gauge/1_Source/TextRecognizer/loss.py", line 21, in forward
loss = F.binary_cross_entropy_with_logits(input=preds, target=onehot_target, weight=weights, reduction='mean')
File "/home/hotaek/anaconda3/envs/pytorch1-py368/lib/python3.6/site-packages/torch/nn/functional.py", line 2077, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: the derivative for 'weight' is not implemented

Environment

Please copy and paste the output from our
environment collection script
(or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 384.130
cuDNN version: /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] torch==1.0.0
[pip] torchvision==0.2.1
[conda] Could not collect

Additional context

@colesbury
Member

@ht-alchera your weights tensor has requires_grad=True, which is not supported: binary_cross_entropy_with_logits doesn't support back-propagating through the weight argument.

If you don't need the derivative w.r.t. weights, then you can use weights.detach() instead of weights. If you need the derivative, then you'll have to implement binary_cross_entropy_with_logits yourself.
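
A minimal sketch of both options against the forward above (the detach call follows the suggestion as given; the manual expansion is one numerically stable way to write it, using logsigmoid so that log(1 - sigmoid(x)) = logsigmoid(-x)):

# Option 1: no gradient w.r.t. the weights -- detach them first.
loss = F.binary_cross_entropy_with_logits(
    input=preds, target=onehot_target, weight=weights.detach(), reduction='none')

# Option 2: gradient w.r.t. the weights -- expand the weighted BCE manually,
# so autograd treats `weights` as an ordinary differentiable input.
bce = -(onehot_target * F.logsigmoid(preds)
        + (1 - onehot_target) * F.logsigmoid(-preds))
loss = weights * bce  # elementwise, equivalent to reduction='none'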
