
Loss functions for complex tensors #46642

Open · 1 of 18 tasks · anjali411 opened this issue Oct 21, 2020 · 3 comments
Labels
complex_autograd · module: complex (related to complex number support in PyTorch) · module: nn (related to torch.nn) · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

anjali411 (Contributor) commented Oct 21, 2020

🚀 Feature

Loss functions in the torch.nn module should support complex tensors whenever the operation makes sense for complex numbers.

Motivation

Complex neural networks are an active area of research, and several issues on GitHub (for example, #46546 (comment)) suggest that we should add complex number support to loss functions.

Pitch

NOTE: As of now, we have decided to add complex support only for real-valued loss functions, so please verify that your chosen loss function has that property before you start working on a PR to add complex support.

The loss functions in the task list above should be updated to support complex numbers (both forward and backward operations). If a loss function doesn't make sense for complex numbers, it should throw an error clearly stating that. The list covers the loss functions that existed at the time this issue was written; we still need to figure out which of them we want to support and which should throw errors.

If a loss function uses an operation that is feasible for complex numbers but not yet supported, we should prioritize adding that operation.
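
As a rough illustration of the requested behavior, here is a minimal sketch (not PyTorch's implementation) of a wrapper that returns a real-valued MSE for complex inputs and falls back to the existing kernel for real inputs; the name mse_loss_complex and its reduction handling are hypothetical:

import torch
import torch.nn.functional as F

def mse_loss_complex(input, target, reduction="mean"):
    # Hypothetical wrapper: for complex inputs, use the squared magnitude
    # of the error, |input - target|**2, which is real-valued.
    if input.is_complex() or target.is_complex():
        diff = input - target
        sq = diff.real ** 2 + diff.imag ** 2  # == |diff|**2, real dtype
        if reduction == "mean":
            return sq.mean()
        if reduction == "sum":
            return sq.sum()
        return sq
    return F.mse_loss(input, target, reduction=reduction)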

cc @ezyang @anjali411 @dylanbespalko @mruberry @albanD

anjali411 added the module: complex, complex_autograd, and module: nn labels on Oct 21, 2020
zou3519 added the triaged label on Oct 21, 2020
cdespina commented

mse_loss(input, target, size_average, reduce, reduction)
2924
2925 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-> 2926 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
2927
2928

RuntimeError: "mse_cpu" not implemented for 'ComplexDouble'
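
For reference, a minimal way to reproduce this error (assuming complex128 inputs on CPU):

import torch
import torch.nn.functional as F

a = torch.randn(4, dtype=torch.complex128)
b = torch.randn(4, dtype=torch.complex128)
F.mse_loss(a, b)  # RuntimeError: "mse_cpu" not implemented for 'ComplexDouble'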

mruberry (Collaborator) commented

> mse_loss(input, target, size_average, reduce, reduction)
>    2924
>    2925 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
> -> 2926 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
>    2927
>    2928
>
> RuntimeError: "mse_cpu" not implemented for 'ComplexDouble'

This is expected behavior: this issue tracks implementing support for complex inputs to losses, and complex support has not been added to MSELoss yet.
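
Until native support lands, one possible workaround (a sketch, not an endorsed API usage) is to view each complex tensor as a (..., 2) real tensor of (real, imag) parts with torch.view_as_real and feed those to the existing real MSE; this equals mean(|a - b|**2) / 2, i.e. it differs from a complex-native MSE only by a constant factor:

import torch
import torch.nn.functional as F

a = torch.randn(3, dtype=torch.complex64, requires_grad=True)
b = torch.randn(3, dtype=torch.complex64)

# Real view of shape (3, 2); the existing real-valued MSE kernel applies.
loss = F.mse_loss(torch.view_as_real(a), torch.view_as_real(b))
loss.backward()  # a.grad is populated (complex-valued)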

AhmedBoin commented

You can implement your own:

import torch

def complex_mse_loss(output, target):
    # Use the squared magnitude of the error so the loss is real-valued;
    # autograd needs a real scalar to call .backward() without arguments.
    return (0.5 * (output - target).abs() ** 2).mean()
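
For example (shapes are illustrative):

out = torch.randn(4, dtype=torch.complex64, requires_grad=True)
tgt = torch.randn(4, dtype=torch.complex64)
loss = complex_mse_loss(out, tgt)  # real scalar
loss.backward()                    # out.grad is populated (complex-valued)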

You can also implement layers or any custom utilities you need:

import torch
import torch.nn as nn

class CLinear(nn.Module):
    def __init__(self, size_in, size_out):
        super().__init__()
        # Complex-valued weight matrix and bias vector.
        self.weights = nn.Parameter(torch.randn(size_in, size_out, dtype=torch.complex64))
        self.bias = nn.Parameter(torch.zeros(size_out, dtype=torch.complex64))

    def forward(self, x):
        # Cast real inputs to complex so the matmul is well-defined.
        if x.dtype != torch.complex64:
            x = x.type(torch.complex64)
        return x @ self.weights + self.bias
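
A quick usage sketch (sizes are illustrative):

layer = CLinear(8, 4)
x = torch.randn(2, 8)  # real input; forward() casts it to complex64
y = layer(x)
print(y.dtype)         # torch.complex64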
