def forward(self, preds, targets):
    batch, classes = preds.shape
    onehot_target = torch.zeros(batch, classes, dtype=torch.float32, device=preds.device)
    onehot_target[torch.arange(batch).long(), targets] = 1.0
    prob = preds.sigmoid()
    pt = prob * onehot_target + (1-prob) * (1-onehot_target)  # pt = p if t > 0 else 1-p
    weights = self.alpha * onehot_target + (1-self.alpha)*(1-onehot_target)  # w = alpha if t > 0 else 1-alpha
    weights = weights * (1-pt).pow(self.gamma)
    loss = F.binary_cross_entropy_with_logits(input=preds, target=onehot_target, weight=weights, reduction='none')
    loss = loss / batch
    return loss
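For reference, this is a sketch of what the forward method above computes per element, before the division by batch. It writes x for one logit, t ∈ {0, 1} for its one-hot target, σ for the sigmoid, and α, γ for self.alpha and self.gamma; the symbol names are chosen here for illustration:

```latex
\begin{aligned}
p &= \sigma(x), \qquad
p_t = \begin{cases} p, & t = 1 \\ 1 - p, & t = 0 \end{cases}, \qquad
\alpha_t = \begin{cases} \alpha, & t = 1 \\ 1 - \alpha, & t = 0 \end{cases} \\
w &= \alpha_t \, (1 - p_t)^{\gamma} \\
\mathrm{BCE}(x, t) &= -\bigl[t \log p + (1 - t) \log(1 - p)\bigr] = -\log p_t \\
\mathrm{FL}(x, t) &= w \cdot \mathrm{BCE}(x, t) = -\alpha_t \, (1 - p_t)^{\gamma} \log p_t
\end{aligned}
```

Because w depends on x through p_t, the weight tensor handed to binary_cross_entropy_with_logits is itself part of the autograd graph, which is what the RuntimeError in the traceback is about.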
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
@ht-alchera your weights tensor requires grad, which is not supported: binary_cross_entropy_with_logits doesn't support back-propagating through its weight argument.
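To see why the weights require grad: they are built from preds.sigmoid(), so they inherit requires_grad from preds. A minimal check, using random tensors and alpha=0.25, gamma=2.0 as hypothetical stand-ins for the real logits and hyperparameters:

```python
import torch

# Hypothetical stand-ins: 4 samples, 3 classes, alpha=0.25, gamma=2.0.
preds = torch.randn(4, 3, requires_grad=True)  # logits produced by the model
targets = torch.tensor([0, 2, 1, 2])
onehot_target = torch.zeros(4, 3)
onehot_target[torch.arange(4), targets] = 1.0

# Same weight construction as the FocalLoss.forward in the report.
prob = preds.sigmoid()
pt = prob * onehot_target + (1 - prob) * (1 - onehot_target)
weights = (0.25 * onehot_target + 0.75 * (1 - onehot_target)) * (1 - pt).pow(2.0)

print(weights.requires_grad)           # True: weights is part of preds' autograd graph
print(weights.detach().requires_grad)  # False: detach() returns a tensor outside the graph
```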
If you don't need the derivative with respect to weights, you can pass weights.detach() instead of weights. If you do need the derivative, you'll have to implement binary_cross_entropy_with_logits yourself.
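A minimal sketch of both workarounds, assuming a sum over all elements divided by the batch size and hypothetical defaults alpha=0.25, gamma=2.0. The names focal_loss_detached, focal_loss_manual and _onehot_and_weights are illustrative and not part of PyTorch. The detached variant keeps F.binary_cross_entropy_with_logits but passes weights.detach(). The manual variant writes the weighted BCE-with-logits out by hand using the identity BCE(x, t) = (1 - t)·x + softplus(-x), which follows from log σ(x) = -softplus(-x), so autograd can also differentiate through the weights.

```python
import torch
import torch.nn.functional as F


def _onehot_and_weights(preds, targets, alpha, gamma):
    # Same construction as the reporter's FocalLoss.forward:
    # w = alpha_t * (1 - p_t)^gamma, built from preds.sigmoid()
    batch, classes = preds.shape
    onehot = torch.zeros(batch, classes, dtype=preds.dtype, device=preds.device)
    onehot[torch.arange(batch, device=preds.device), targets] = 1.0
    prob = preds.sigmoid()
    pt = prob * onehot + (1 - prob) * (1 - onehot)
    alpha_t = alpha * onehot + (1 - alpha) * (1 - onehot)
    return onehot, alpha_t * (1 - pt).pow(gamma)


def focal_loss_detached(preds, targets, alpha=0.25, gamma=2.0):
    """Workaround 1: treat the focal weights as constants (no gradient through them)."""
    onehot, weights = _onehot_and_weights(preds, targets, alpha, gamma)
    loss = F.binary_cross_entropy_with_logits(
        preds, onehot, weight=weights.detach(), reduction="sum")
    return loss / preds.shape[0]


def focal_loss_manual(preds, targets, alpha=0.25, gamma=2.0):
    """Workaround 2: hand-written weighted BCE-with-logits; gradients also flow through the weights."""
    onehot, weights = _onehot_and_weights(preds, targets, alpha, gamma)
    # Numerically stable BCE with logits: (1 - t) * x + softplus(-x)
    bce = (1 - onehot) * preds + F.softplus(-preds)
    return (weights * bce).sum() / preds.shape[0]


# Usage with hypothetical random logits for 4 samples and 3 classes.
torch.manual_seed(0)
preds = torch.randn(4, 3, requires_grad=True)
targets = torch.tensor([0, 2, 1, 2])

loss_detached = focal_loss_detached(preds, targets)
loss_manual = focal_loss_manual(preds, targets)
print(loss_detached.item(), loss_manual.item())  # same value up to float rounding

loss_manual.backward()  # autograd also differentiates through the focal weights here
```

Both variants return the same loss value; they differ only in the gradient, since the manual one also back-propagates through the (1 - p_t)^γ factor. Which behaviour is wanted depends on whether the focal weights are meant to be trained through.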
🐛 Bug
I recently moved to PyTorch 1.0.1, but now I get the error below when I use 'binary_cross_entropy_with_logits':
RuntimeError: the derivative for 'weight' is not implemented
The same code works fine with PyTorch 0.4.1.
I'm using CUDA 9.0.176, cuDNN 7.4.1, and Ubuntu 16.04 (GCC 5.4.0).
To Reproduce
Steps to reproduce the behavior:
import torch.nn as nn
import torch
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
Expected behavior
Traceback (most recent call last):
  File "train.py", line 503, in <module>
    train(iter_epoch)
  File "train.py", line 278, in train
    seq_loss = seq_criterion(preds, targets)
  File "/home/hotaek/anaconda3/envs/pytorch1-py368/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hotaek/Projects/0_Kepco/0_Recog_Gauge/1_Source/TextRecognizer/loss.py", line 21, in forward
    loss = F.binary_cross_entropy_with_logits(input=preds, target=onehot_target, weight=weights, reduction='mean')
  File "/home/hotaek/anaconda3/envs/pytorch1-py368/lib/python3.6/site-packages/torch/nn/functional.py", line 2077, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: the derivative for 'weight' is not implemented
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually).
You can get the script and run it with:
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 384.130
cuDNN version: /usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7
Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] torch==1.0.0
[pip] torchvision==0.2.1
[conda] Could not collect
Additional context