Loss functions for complex tensors #46642
Calling mse_loss(input, target, size_average, reduce, reduction) with complex inputs fails with: RuntimeError: "mse_cpu" not implemented for 'ComplexDouble'
This is expected behavior: this issue tracks adding complex-input support to the loss functions, and MSELoss does not support complex inputs yet.
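Until MSELoss supports complex inputs, one possible stopgap (a sketch, not the planned MSELoss implementation) is to reuse the existing real kernel through torch.view_as_real, which views a complex tensor as a real tensor with a trailing dimension of size 2 holding (real, imag). Because the mean then runs over 2N real components, the result equals half of mean |input - target|^2; the helper name mse_loss_via_real is hypothetical.

```python
import torch
import torch.nn.functional as F

def mse_loss_via_real(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. view_as_real turns shape (...,) complex into
    # shape (..., 2) real, so the real-valued mse_loss kernel can be reused.
    # The mean is over real and imaginary parts, i.e. 0.5 * mean(|d|**2).
    return F.mse_loss(torch.view_as_real(input), torch.view_as_real(target))

a = torch.tensor([1 + 1j, 2 + 0j], dtype=torch.complex128)
b = torch.zeros(2, dtype=torch.complex128)
loss = mse_loss_via_real(a, b)
print(loss.item())
```

Whether the complex MSELoss should divide by N or 2N is a separate design decision; this sketch simply inherits the 2N behavior of the real kernel.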
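You can implement your own loss function. Keep it real-valued by using the squared magnitude of the error (squaring the complex difference directly gives a complex result, which backward() cannot start from):

    import torch
    import torch.nn as nn

    def complex_mse_loss(output, target):
        # Squared magnitude |output - target|**2 keeps the loss real-valued.
        return (0.5 * (output - target).abs() ** 2).mean()

You can also implement layers or any other custom utilities you need:

    class CLinear(nn.Module):
        def __init__(self, size_in, size_out):
            super().__init__()
            self.weights = nn.Parameter(torch.randn(size_in, size_out, dtype=torch.complex64))
            self.bias = nn.Parameter(torch.zeros(size_out, dtype=torch.complex64))

        def forward(self, x):
            # Promote real inputs to complex64 so the matmul dtypes match.
            if x.dtype != torch.complex64:
                x = x.to(torch.complex64)
            return x @ self.weights + self.bias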
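A possible end-to-end check, sketched under the assumption that the loss is the real-valued squared-magnitude version: a few steps of plain gradient descent on a CLinear layer. The sizes, seed, learning rate, and random target are arbitrary choices for illustration, and the manual update loop stands in for an optimizer.

```python
import torch
import torch.nn as nn

def complex_mse_loss(output, target):
    # Real-valued loss: half the mean squared magnitude of the complex error.
    return (0.5 * (output - target).abs() ** 2).mean()

class CLinear(nn.Module):
    def __init__(self, size_in, size_out):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(size_in, size_out, dtype=torch.complex64))
        self.bias = nn.Parameter(torch.zeros(size_out, dtype=torch.complex64))

    def forward(self, x):
        if x.dtype != torch.complex64:
            x = x.to(torch.complex64)
        return x @ self.weights + self.bias

torch.manual_seed(0)
model = CLinear(3, 2)
x = torch.randn(4, 3)                              # real input, promoted in forward
target = torch.randn(4, 2, dtype=torch.complex64)  # arbitrary complex target
lr = 0.05

losses = []
for _ in range(20):
    model.zero_grad()
    loss = complex_mse_loss(model(x), target)      # real float32 scalar
    loss.backward()                                # allowed because the loss is real
    with torch.no_grad():
        for p in model.parameters():
            p -= lr * p.grad                       # plain gradient-descent step
    losses.append(loss.item())

print(loss.dtype, losses[0], losses[-1])
```

For a real loss, PyTorch's gradient on a complex parameter points in the steepest-ascent direction, so the usual "parameter minus lr times grad" update works unchanged.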
🚀 Feature
Loss functions in the torch.nn module should support complex tensors whenever the operations make sense for complex numbers.

Motivation
Complex neural networks are an active area of research, and several GitHub issues (for example, a comment on #46546) suggest that we should add complex-number support to the loss functions.
Pitch
NOTE: For now, we have decided to add complex support only to loss functions whose output is real-valued, so please check that your chosen loss function has this property before you start working on a PR that adds complex support.
These loss functions should be updated to support complex numbers in both the forward and backward passes. If a loss function doesn't make sense for complex numbers, it should throw an error that clearly says so. The list below contains the loss functions that existed when this issue was written; we still need to decide which ones to support and which ones should throw errors.
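One way such an error could look, sketched with a hypothetical guard (reject_complex and guarded_margin_ranking_loss are illustrative names, not PyTorch APIs). MarginRankingLoss is used only as an example of a loss built on ordering, which complex numbers lack; the list below still leaves that decision open.

```python
import torch
import torch.nn.functional as F

def reject_complex(loss_name, reason, *tensors):
    # Hypothetical guard: fail early with a message naming the loss, instead of
    # a low-level "not implemented for 'ComplexFloat'" kernel error.
    for t in tensors:
        if t.is_complex():
            raise TypeError(
                f"{loss_name} does not support complex inputs (got {t.dtype}): {reason}"
            )

def guarded_margin_ranking_loss(x1, x2, y, margin=0.0):
    # Illustrative wrapper around the existing real-valued functional loss.
    reject_complex("margin_ranking_loss",
                   "it compares values, and complex numbers have no ordering",
                   x1, x2, y)
    return F.margin_ranking_loss(x1, x2, y, margin=margin)

x1 = torch.tensor([1.0, 2.0])
x2 = torch.tensor([2.0, 1.0])
y = torch.tensor([1.0, 1.0])
real_result = guarded_margin_ranking_loss(x1, x2, y)

try:
    guarded_margin_ranking_loss(x1.to(torch.complex64), x2, y)
    complex_rejected = False
except TypeError as err:
    complex_rejected = True
    print(err)
```

TypeError is a design choice here (the input dtype is wrong); PyTorch's own kernels usually raise RuntimeError for unsupported dtypes.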
nn.L1Loss: PR Add complex support for torch.nn.L1Loss #49912
nn.MSELoss
nn.CrossEntropyLoss
nn.CTCLoss
nn.NLLLoss
nn.PoissonNLLLoss
nn.KLDivLoss
nn.BCELoss
nn.BCEWithLogitsLoss
nn.MarginRankingLoss
nn.HingeEmbeddingLoss
nn.MultiLabelMarginLoss
nn.SmoothL1Loss
nn.SoftMarginLoss
nn.MultiLabelSoftMarginLoss
nn.CosineEmbeddingLoss
nn.MultiMarginLoss
nn.TripletMarginLoss
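As an illustration of a loss from this list that keeps a real-valued output, here is a sketch of an L1 loss on complex tensors, assuming |·| is read as the complex modulus. The helper complex_l1_loss is hypothetical and is not a description of what the PR for nn.L1Loss implements; for real inputs it reduces to the ordinary L1 loss.

```python
import torch
import torch.nn.functional as F

def complex_l1_loss(input, target, reduction="mean"):
    # Hypothetical helper: |d| is the complex modulus sqrt(re(d)**2 + im(d)**2),
    # so every element contributes a real, non-negative value.
    diff = (input - target).abs()
    if reduction == "mean":
        return diff.mean()
    if reduction == "sum":
        return diff.sum()
    if reduction == "none":
        return diff
    raise ValueError(f"unknown reduction: {reduction}")

a = torch.tensor([3 + 4j, 1 + 0j])            # moduli 5 and 1
b = torch.zeros(2, dtype=torch.complex64)
print(complex_l1_loss(a, b), complex_l1_loss(a, b, reduction="sum"))

# For real tensors the modulus is the ordinary absolute value.
r1 = torch.tensor([1.0, -2.0, 0.5])
r2 = torch.tensor([0.0, 1.0, 0.5])
print(torch.allclose(complex_l1_loss(r1, r2), F.l1_loss(r1, r2)))
```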
If a loss function uses an operation that is well defined for complex numbers but not yet supported for complex tensors, we should prioritize adding support for that operation.
cc @ezyang @anjali411 @dylanbespalko @mruberry @albanD