Description
Here is an observation: differentiating `erfcx` of a complex argument works with Zygote but fails with ForwardDiff.
using SpecialFunctions
using Zygote
using ForwardDiff
f(x) = real(SpecialFunctions.erfcx((x + 1.0im)^2))
# reference:
Zygote.gradient(f, 1.0)[1] # 0.3169262912276313
# same derivative with ForwardDiff
ForwardDiff.derivative(f, 1.0)
This throws the error:
ERROR: MethodError: no method matching _erfcx(::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#13#14", Float64}, Float64, 1}})
The function `_erfcx` exists, but no method is defined for this combination of argument types.
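The `Complex{ForwardDiff.Dual{...}}` in the error comes from the inner expression of `f`: ForwardDiff evaluates `f` on a real `Dual`, and adding `1.0im` promotes it to a complex number whose real and imaginary parts are both duals. The sketch below reproduces that type outside of `derivative`, using an untagged `Dual` built by hand. That `_erfcx` only has methods for concrete floating-point complex types is an assumption inferred from the error message, not something checked against the SpecialFunctions source.

```julia
using ForwardDiff

# Untagged dual number: value 1.0 with a single partial (seed dx/dx = 1.0).
x = ForwardDiff.Dual(1.0, 1.0)

# Same inner expression as in f; the result is a Complex whose parts are Duals.
z = (x + 1.0im)^2

println(typeof(z))
# Primal part: mathematically (1 + i)^2 = 2i.
println(ForwardDiff.value(real(z)), " ", ForwardDiff.value(imag(z)))
# Partials: mathematically d/dx (x + i)^2 = 2(x + i) = 2 + 2i at x = 1.
println(ForwardDiff.partials(real(z))[1], " ", ForwardDiff.partials(imag(z))[1])
```

Any function that only has methods for `ComplexF64` and friends will hit the same kind of `MethodError` when it receives such a `z`.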
Solution
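It can be fixed by manually defining how `erfcx` acts on complex numbers with `Dual` components. The method evaluates `erfcx` at the primal point, uses erfcx'(z) = 2z·erfcx(z) - 2/√π, and, since `erfcx` is holomorphic, propagates the partials with the Cauchy-Riemann relations: for w = u + iv, z = x + iy and w'(z) = a + ib, du = a dx - b dy and dv = b dx + a dy.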
const ComplexDual{T, V, N} = Complex{ForwardDiff.Dual{T, V, N}}
function SpecialFunctions.erfcx(z::ComplexDual{T, V, N}) where {T, V, N}
    real_part = real(z)
    imag_part = imag(z)
    # Primal (non-dual) value of z
    z_val = Complex(ForwardDiff.value(real_part), ForwardDiff.value(imag_part))
    # Function value at the primal point
    w_val = erfcx(z_val)
    # Complex derivative: erfcx'(z) = 2z * erfcx(z) - 2/sqrt(π)
    ∂w = 2 * (z_val * w_val - 1 / sqrt(π))
    # Partials of the real and imaginary parts of z
    dr = ForwardDiff.partials(real_part)
    di = ForwardDiff.partials(imag_part)
    # Chain rule via Cauchy-Riemann:
    # d(Re w) = Re(∂w) dr - Im(∂w) di,  d(Im w) = Im(∂w) dr + Re(∂w) di
    real_dual = ForwardDiff.Dual{T}(real(w_val), real(∂w) * dr - imag(∂w) * di)
    imag_dual = ForwardDiff.Dual{T}(imag(w_val), imag(∂w) * dr + real(∂w) * di)
    return Complex(real_dual, imag_dual)
end
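As a sanity check, here is a self-contained sketch that factors the same Cauchy-Riemann step into a small helper, `_holomorphic_dual` (a hypothetical name, not part of ForwardDiff or SpecialFunctions), defines the `erfcx` method through it, and compares `ForwardDiff.derivative(f, 1.0)` with the closed form f'(x) = Re[erfcx'(z) · 2(x + i)] where z = (x + i)^2.

```julia
using SpecialFunctions
using ForwardDiff

# Hypothetical helper: given w = g(z0) and dw = g'(z0) for a holomorphic g,
# build the Complex{Dual} result from the partials of the input's real and
# imaginary parts (Cauchy-Riemann: du = a dx - b dy, dv = b dx + a dy).
function _holomorphic_dual(z::Complex{ForwardDiff.Dual{T,V,N}}, w::Complex, dw::Complex) where {T,V,N}
    dx = ForwardDiff.partials(real(z))
    dy = ForwardDiff.partials(imag(z))
    u = ForwardDiff.Dual{T}(real(w), real(dw) * dx - imag(dw) * dy)
    v = ForwardDiff.Dual{T}(imag(w), imag(dw) * dx + real(dw) * dy)
    return Complex(u, v)
end

function SpecialFunctions.erfcx(z::Complex{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
    z0 = complex(ForwardDiff.value(real(z)), ForwardDiff.value(imag(z)))
    w = erfcx(z0)                      # primal value via the ComplexF64 method
    dw = 2 * (z0 * w - 1 / sqrt(π))    # erfcx'(z) = 2z erfcx(z) - 2/sqrt(π)
    return _holomorphic_dual(z, w, dw)
end

f(x) = real(erfcx((x + 1.0im)^2))

# Closed-form derivative: z = (x + i)^2, dz/dx = 2(x + i).
function fprime_exact(x)
    z = (x + 1.0im)^2
    w = erfcx(z)
    return real(2 * (z * w - 1 / sqrt(π)) * 2 * (x + 1.0im))
end

d_ad = ForwardDiff.derivative(f, 1.0)
d_exact = fprime_exact(1.0)
println((d_ad, d_exact))
```

Both numbers should agree with the Zygote reference value quoted in the description. Routing the propagation through a helper means another holomorphic function could reuse it by supplying its own derivative, though that generalization goes beyond the original fix.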