
ForwardDiff of complex erfcx fails #493

Open
@mmikhasenko

Here is a minimal example that reproduces the failure:

```julia
using SpecialFunctions
using Zygote
using ForwardDiff

f(x) = real(SpecialFunctions.erfcx((x + 1.0im)^2))

# reference:
Zygote.gradient(f, 1.0)[1] # 0.3169262912276313

# call
ForwardDiff.derivative(f, 1.0) # throws a MethodError (see below)
```

The `ForwardDiff` call throws the following error:

```
ERROR: MethodError: no method matching _erfcx(::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#13#14", Float64}, Float64, 1}})
The function `_erfcx` exists, but no method is defined for this combination of argument types.
```
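As a sanity check that does not depend on any AD package, the Zygote reference value can be compared with a central finite difference of `f`. This is a minimal sketch: the helper `central_diff` and the step size `h = 1e-6` are illustrative choices, not part of SpecialFunctions or ForwardDiff, and the agreement is only expected to hold to roughly six digits.

```julia
using SpecialFunctions

# Same test function as in the reproducer above
f(x) = real(SpecialFunctions.erfcx((x + 1.0im)^2))

# Hypothetical helper: central finite difference with an untuned step h
central_diff(g, x; h = 1e-6) = (g(x + h) - g(x - h)) / (2h)

fd = central_diff(f, 1.0)
println(fd)  # should be close to the Zygote reference 0.3169262912276313
```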

Solution

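It can be worked around by manually defining how `erfcx` handles complex dual numbers. The method below evaluates `erfcx` on the primal value and propagates the partials using the analytic derivative d/dz erfcx(z) = 2z·erfcx(z) − 2/√π: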

```julia
using SpecialFunctions
using ForwardDiff

# A complex number whose real and imaginary parts are ForwardDiff duals
const ComplexDual{T, V, N} = Complex{ForwardDiff.Dual{T, V, N}}

function SpecialFunctions.erfcx(z::ComplexDual{T, V, N}) where {T, V, N}
    real_part = real(z)
    imag_part = imag(z)
    # Get primal values
    z_val = Complex(ForwardDiff.value(real_part), ForwardDiff.value(imag_part))

    # Compute function value on the plain complex number
    w_val = erfcx(z_val)

    # Compute derivative: d/dz erfcx(z) = 2z * erfcx(z) - 2/sqrt(pi)
    ∂w = 2 * (z_val * w_val - 1 / sqrt(π))

    # Get partial derivatives of the real and imaginary parts
    dr = ForwardDiff.partials(real_part)
    di = ForwardDiff.partials(imag_part)

    # Construct dual result: erfcx is holomorphic, so dw = ∂w * (dr + i*di)
    real_dual = ForwardDiff.Dual{T}(real(w_val), real(∂w) * dr - imag(∂w) * di)
    imag_dual = ForwardDiff.Dual{T}(imag(w_val), imag(∂w) * dr + real(∂w) * di)

    return Complex(real_dual, imag_dual)
end
```
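The rule above relies on two facts: the identity d/dz erfcx(z) = 2z·erfcx(z) − 2/√π, and the fact that `erfcx` is holomorphic, so the change in w is ∂w·(dr + i·di), which is exactly what the real/imaginary combinations in `real_dual` and `imag_dual` expand. A minimal sketch that checks both numerically with finite differences, using only SpecialFunctions; the helper names `erfcx_deriv`, `d_real`, `d_imag` and the test point `z0` are illustrative, not part of any package:

```julia
using SpecialFunctions

# Analytic derivative used in the workaround
erfcx_deriv(z) = 2 * (z * erfcx(z) - 1 / sqrt(π))

# Central differences along the real and imaginary directions.
# For a holomorphic function both approximate f'(z) (Cauchy-Riemann).
d_real(g, z; h = 1e-6) = (g(z + h) - g(z - h)) / (2h)
d_imag(g, z; h = 1e-6) = (g(z + im * h) - g(z - im * h)) / (2im * h)

z0 = 0.3 + 2.0im  # arbitrary test point
analytic = erfcx_deriv(z0)
along_re = d_real(erfcx, z0)
along_im = d_imag(erfcx, z0)
println((analytic, along_re, along_im))
```

If the three printed values agree to about six digits, both the derivative formula and the real/imaginary mixing in the dual construction are consistent.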