
Basic linear algebra for complex numbers #768

@boeddeker


🚀 Feature

Support basic linear algebra for complex numbers.

Motivation

I talked with @sw005320 about https://github.com/nttcslab-sp/dnn_wpe, and it turns out that the matrix inversion implemented with real numbers is numerically unstable. In a beamforming example, @Emrys365 observed a 5 dB difference in signal-to-distortion ratio (SDR) when he replaced the inversion with NumPy code (torch: 5 dB, NumPy: 10 dB).
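For context, one common way to express a complex linear solve using only real-valued operations is the 2n × 2n real block embedding of A = A_r + jA_i. The sketch below assumes this is the kind of "real numbers" implementation meant here (the dnn_wpe code itself is not shown in this issue), and `complex_solve_via_real` is a hypothetical helper built on the newer `torch.linalg.solve`. It only checks that the embedding reproduces the native complex solution on a well-conditioned example; it does not reproduce the SDR gap reported above.

```python
import torch

def complex_solve_via_real(A, b):
    """Hypothetical helper: solve A @ x = b for complex A, b with a real-valued solver.

    Uses the real block embedding [[Ar, -Ai], [Ai, Ar]] @ [xr; xi] = [br; bi].
    """
    n = A.shape[-1]
    top = torch.cat([A.real, -A.imag], dim=-1)
    bottom = torch.cat([A.imag, A.real], dim=-1)
    M = torch.cat([top, bottom], dim=-2)        # real, shape (..., 2n, 2n)
    rhs = torch.cat([b.real, b.imag], dim=-2)   # real, shape (..., 2n, k)
    sol = torch.linalg.solve(M, rhs)            # purely real solve
    # Upper half holds the real part of x, lower half the imaginary part
    return torch.complex(sol[..., :n, :], sol[..., n:, :])

torch.manual_seed(0)
n = 3
# The diagonal shift keeps this example well conditioned
A = torch.randn(n, n, dtype=torch.complex128) + n * torch.eye(n, dtype=torch.float64)
b = torch.randn(n, 1, dtype=torch.complex128)

x_real = complex_solve_via_real(A, b)
x_ref = torch.linalg.solve(A, b)
print(torch.allclose(x_real, x_ref))
```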

I tried torch.inverse and torch.solve, and interestingly they already work for complex tensors in 1.6.0.dev20200623+cpu (this is not mentioned in pytorch/pytorch#33152).
Would it be possible to support torch.matmul and other linear algebra functions for complex tensors as well?
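For readers on newer PyTorch releases, the sketch below shows complex inversion, solving and matmul through the `torch.linalg` module and the `@` operator, and checks the results against each other. It assumes a release where `torch.linalg` exists; note that `torch.linalg.solve(A, B)` takes its arguments in the opposite order from `torch.solve(B, A)` used in this issue.

```python
import torch

torch.manual_seed(0)
n = 4
# Random complex matrix, shifted on the diagonal to keep it well conditioned
A = torch.randn(n, n, dtype=torch.complex128) + n * torch.eye(n, dtype=torch.float64)
b = torch.randn(n, 2, dtype=torch.complex128)
I = torch.eye(n, dtype=torch.float64).to(torch.complex128)

A_inv = torch.linalg.inv(A)    # complex inverse
x = torch.linalg.solve(A, b)   # solves A @ x = b

# Complex matmul via @ (torch.matmul)
inv_ok = torch.allclose(A @ A_inv, I)
solve_ok = torch.allclose(A @ x, b)
consistent = torch.allclose(A_inv @ b, x)
print(inv_ok, solve_ok, consistent)
```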

I also tried calling backward after torch.solve, and the code fails with an exception saying that matmul is not implemented.
Does anyone know how the gradient is defined in torch for complex numbers?
Is it grad_real + j grad_imag or grad_real - j grad_imag?
And how can I add or fix a gradient when I find a broken implementation?
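One way to answer the convention question empirically is to compare `.grad` for a real-valued loss with finite differences taken with respect to the real and imaginary parts. This sketch assumes a recent PyTorch release rather than the 1.6.0 nightly mentioned above; there, `.grad` should come out as grad_real + j grad_imag, which PyTorch's "Autograd for Complex Numbers" notes describe in terms of the conjugate Wirtinger derivative. The loss L(z) = |z|^2 and the helper `L` are illustrative choices.

```python
import torch

# Real-valued loss L(z) = |z|^2 = x^2 + y^2 with z = x + 1j * y
z = torch.tensor(1.0 + 2.0j, dtype=torch.complex128, requires_grad=True)
loss = (z * z.conj()).real
loss.backward()

def L(w):
    # Same loss, evaluated as a plain Python float
    return (w * w.conj()).real.item()

# Central finite differences w.r.t. the real and imaginary parts
eps = 1e-6
z0 = z.detach()
dL_dx = (L(z0 + eps) - L(z0 - eps)) / (2 * eps)
dL_dy = (L(z0 + 1j * eps) - L(z0 - 1j * eps)) / (2 * eps)

print(z.grad)                 # gradient from autograd
print(complex(dL_dx, dL_dy))  # grad_real + 1j * grad_imag
```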

Pitch

Alternatives

Additional context

Currently, I am considering switching back and forth between pytorch_complex and a custom torch.autograd.Function, for example:

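```python
import torch

def hermite(a):
    # Hermitian (conjugate) transpose over the last two dimensions
    return a.transpose(-2, -1).conj()

def matmul(t1, t2):
    # Complex matmul built from four real matmuls:
    # (R1 + jI1)(R2 + jI2) = (R1 R2 - I1 I2) + j(R1 I2 + I1 R2)
    real1, imag1 = t1.real, t1.imag
    real2, imag2 = t2.real, t2.imag
    o_real = torch.matmul(real1, real2) - torch.matmul(imag1, imag2)
    o_imag = torch.matmul(real1, imag2) + torch.matmul(imag1, real2)
    return o_real + 1j * o_imag

class Solve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        # torch.solve(b, A) solves A x = b and returns (solution, LU)
        x, _ = torch.solve(b, A)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        A, x = ctx.saved_tensors
        # Gradient w.r.t. b: solve A^H gb = grad_output
        gb, _ = torch.solve(grad_output, hermite(A))
        # Gradient w.r.t. A: -gb x^H
        gA = -matmul(gb, hermite(x))
        return gA, gb
```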
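To sanity-check the hand-written backward and the real/imag matmul helper, one option is to port the Function to `torch.linalg.solve` (available in newer PyTorch releases) and compare its gradients with those of the built-in complex autograd. `SolveSketch` is a hypothetical name, and the comparison assumes the built-in gradient of `torch.linalg.solve` is correct for complex inputs.

```python
import torch

def hermite(a):
    # Conjugate transpose over the last two dimensions
    return a.transpose(-2, -1).conj()

def matmul(t1, t2):
    # Same four-real-matmul decomposition as in the snippet above
    o_real = t1.real @ t2.real - t1.imag @ t2.imag
    o_imag = t1.real @ t2.imag + t1.imag @ t2.real
    return torch.complex(o_real, o_imag)

class SolveSketch(torch.autograd.Function):
    # Hypothetical port of Solve: torch.linalg.solve(A, B) solves A @ X = B
    @staticmethod
    def forward(ctx, A, b):
        x = torch.linalg.solve(A, b)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        A, x = ctx.saved_tensors
        gb = torch.linalg.solve(hermite(A), grad_output)  # A^H gb = grad_output
        gA = -gb @ hermite(x)                              # -gb x^H
        return gA, gb

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.complex128) + n * torch.eye(n, dtype=torch.float64)
b = torch.randn(n, 2, dtype=torch.complex128)

# The real/imag decomposition should agree with native complex matmul
matmul_ok = torch.allclose(matmul(A, b), A @ b)

# Independent leaf copies of the inputs, one pair per autograd path
A1, b1 = A.clone().requires_grad_(), b.clone().requires_grad_()
A2, b2 = A.clone().requires_grad_(), b.clone().requires_grad_()

# Real-valued loss so that backward() needs no explicit gradient argument
SolveSketch.apply(A1, b1).abs().pow(2).sum().backward()
torch.linalg.solve(A2, b2).abs().pow(2).sum().backward()

print(matmul_ok)
print(torch.allclose(A1.grad, A2.grad), torch.allclose(b1.grad, b2.grad))
```

Comparing against the built-in gradient needs only one backward pass per path; `torch.autograd.gradcheck` would be a more thorough, but slower, alternative.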
