
CosineEmbeddingLoss tensor sizes not matching when grad_output != 1 #1058

@MatthiasKohl

Description

The following code produces an error for me (torch version string 0.1.10_2):

import torch
import torch.nn as nn
from torch.autograd import Variable

cos_loss = nn.CosineEmbeddingLoss()
t1 = torch.randn(10, 5)
t2 = torch.randn(10, 5)
# labels in {-1, 1}, based on the sign of each row sum
lab = t1.sum(1).ge(0).long() * 2 - 1
loss1 = cos_loss(Variable(t1, requires_grad=True), Variable(t2), Variable(lab))
# scaling the loss makes grad_output != 1, which triggers the error below
loss2 = loss1 / 10
loss2.backward()

Error is:
RuntimeError: inconsistent tensor size at /data/users/soumith/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:842

I don't think this should produce an error. Can you tell me if this is already known/being fixed, or how we should work around it? I'm assuming that in nn/_functions/loss, CosineEmbeddingLoss(Function).backward() should end with something like gw1.mul_(grad_output[0]) instead of gw1.mul_(grad_output), since grad_output is a one-element tensor here rather than a plain scalar.
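
To make the suspected failure mode concrete, here is a standalone sketch (gw1 stands in for the gradient buffer computed inside backward(); the sizes and values are illustrative, not taken from the actual implementation):

import torch

gw1 = torch.randn(10, 5)           # stand-in for the gradient w.r.t. input1
grad_output = torch.Tensor([0.1])  # non-unit grad_output produced by loss2 = loss1 / 10

# gw1.mul_(grad_output)   # fails: TH's mul_ requires matching sizes, and grad_output has 1 element
gw1.mul_(grad_output[0])  # grad_output[0] is a Python number, so this scales in place

This is only my guess at the mechanism based on the error message and the title; the actual backward() code may look different.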
