Skip to content

Commit

Permalink
Relax complex function signatures to make them ForwardDiff compatible (
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa authored and simeonschaub committed Aug 11, 2020
1 parent 7e78f09 commit ee7339b
Showing 1 changed file with 42 additions and 37 deletions.
79 changes: 42 additions & 37 deletions base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ function inv(w::ComplexF64)
return ComplexF64(p*s,q*s) # undo scaling
end

function ssqs(x::T, y::T) where T<:AbstractFloat
function ssqs(x::T, y::T) where T<:Real
k::Int = 0
ρ = x*x + y*y
if !isfinite(ρ) && (isinf(x) || isinf(y))
Expand All @@ -478,7 +478,8 @@ function ssqs(x::T, y::T) where T<:AbstractFloat
ρ, k
end

function sqrt(z::Complex{<:AbstractFloat})
function sqrt(z::Complex)
z = float(z)
x, y = reim(z)
if x==y==0
return Complex(zero(x),y)
Expand All @@ -503,7 +504,6 @@ function sqrt(z::Complex{<:AbstractFloat})
end
Complex(ξ,η)
end
sqrt(z::Complex) = sqrt(float(z))

# function sqrt(z::Complex)
# rz = float(real(z))
Expand Down Expand Up @@ -560,10 +560,12 @@ julia> rad2deg(angle(-1 - im))
"""
angle(z::Complex) = atan(imag(z), real(z))

function log(z::Complex{T}) where T<:AbstractFloat
T1::T = 1.25
T2::T = 3
ln2::T = log(convert(T,2)) #0.6931471805599453
function log(z::Complex)
z = float(z)
T = typeof(real(z))
T1 = convert(T,5)/convert(T,4)
T2 = convert(T,3)
ln2 = log(convert(T,2)) #0.6931471805599453
x, y = reim(z)
ρ, k = ssqs(x,y)
ax = abs(x)
Expand All @@ -580,7 +582,6 @@ function log(z::Complex{T}) where T<:AbstractFloat
end
Complex(ρρ, angle(z))
end
log(z::Complex) = log(float(z))

# function log(z::Complex)
# ar = abs(real(z))
Expand Down Expand Up @@ -681,39 +682,42 @@ function log1p(z::Complex{T}) where T
end
end

function exp2(z::Complex{T}) where T<:AbstractFloat
function exp2(z::Complex{T}) where T
z = float(z)
er = exp2(real(z))
theta = imag(z) * log(convert(T, 2))
theta = imag(z) * log(convert(float(T), 2))
s, c = sincos(theta)
Complex(er * c, er * s)
end
exp2(z::Complex) = exp2(float(z))

function exp10(z::Complex{T}) where T<:AbstractFloat
function exp10(z::Complex{T}) where T
z = float(z)
er = exp10(real(z))
theta = imag(z) * log(convert(T, 10))
theta = imag(z) * log(convert(float(T), 10))
s, c = sincos(theta)
Complex(er * c, er * s)
end
exp10(z::Complex) = exp10(float(z))

# _cpow helper function to avoid method ambiguity with ^(::Complex,::Real)
function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where {T<:AbstractFloat}
function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where T
z = float(z)
p = float(p)
Tf = float(T)
if isreal(p)
pᵣ = real(p)
if isinteger(pᵣ) && abs(pᵣ) < typemax(Int32)
# |p| < typemax(Int32) serves two purposes: it prevents overflow
# when converting p to Int, and it also turns out to be roughly
# the crossover point for exp(p*log(z)) or similar to be faster.
if iszero(pᵣ) # fix signs of imaginary part for z^0
zer = flipsign(copysign(zero(T),pᵣ), imag(z))
return Complex(one(T), zer)
zer = flipsign(copysign(zero(Tf),pᵣ), imag(z))
return Complex(one(Tf), zer)
end
ip = convert(Int, pᵣ)
if isreal(z)
zᵣ = real(z)
if ip < 0
iszero(z) && return Complex(T(NaN),T(NaN))
iszero(z) && return Complex(Tf(NaN),Tf(NaN))
re = power_by_squaring(inv(zᵣ), -ip)
im = -imag(z)
else
Expand All @@ -729,7 +733,7 @@ function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where {T<:Abstrac
# (note: if both z and p are complex with ±0.0 imaginary parts,
# the sign of the ±0.0 imaginary part of the result is ambiguous)
if iszero(real(z))
return pᵣ > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im
return pᵣ > 0 ? complex(z) : Complex(Tf(NaN),Tf(NaN)) # 0 or NaN+NaN*im
elseif real(z) > 0
return Complex(real(z)^pᵣ, z isa Real ? ifelse(real(z) < 1, -imag(p), imag(p)) : flipsign(imag(z), pᵣ))
else
Expand All @@ -741,24 +745,24 @@ function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where {T<:Abstrac
# improved here, but it's not clear if it's worth it…
return rᵖ * complex(cospi(pᵣ), flipsign(sinpi(pᵣ),imag(z)))
else
iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0
return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input
iszero(rᵖ) && return zero(Complex{Tf}) # no way to get correct signs of 0.0
return Complex(Tf(NaN),Tf(NaN)) # non-finite phase angle or NaN input
end
end
else
rᵖ = abs(z)^pᵣ
ϕ = pᵣ*angle(z)
end
elseif isreal(z)
iszero(z) && return real(p) > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im
iszero(z) && return real(p) > 0 ? complex(z) : Complex(Tf(NaN),Tf(NaN)) # 0 or NaN+NaN*im
zᵣ = real(z)
pᵣ, pᵢ = reim(p)
if zᵣ > 0
rᵖ = zᵣ^pᵣ
ϕ = pᵢ*log(zᵣ)
else
r = -zᵣ
θ = copysign(T(π),imag(z))
θ = copysign(Tf(π),imag(z))
rᵖ = r^pᵣ * exp(-pᵢ*θ)
ϕ = pᵣ*θ + pᵢ*log(r)
end
Expand All @@ -773,11 +777,10 @@ function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where {T<:Abstrac
if isfinite(ϕ)
return rᵖ * cis(ϕ)
else
iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0
return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input
iszero(rᵖ) && return zero(Complex{Tf}) # no way to get correct signs of 0.0
return Complex(Tf(NaN),Tf(NaN)) # non-finite phase angle or NaN input
end
end
_cpow(z, p) = _cpow(float(z), float(p))
^(z::Complex{T}, p::Complex{T}) where T<:Real = _cpow(z, p)
^(z::Complex{T}, p::T) where T<:Real = _cpow(z, p)
^(z::T, p::Complex{T}) where T<:Real = _cpow(z, p)
Expand Down Expand Up @@ -859,7 +862,8 @@ function asin(z::Complex)
Complex(ξ,η)
end

function acos(z::Complex{<:AbstractFloat})
function acos(z::Complex)
z = float(z)
zr, zi = reim(z)
if isnan(zr)
if isinf(zi) return Complex(zr, -zi)
Expand All @@ -880,7 +884,6 @@ function acos(z::Complex{<:AbstractFloat})
if isinf(zr) && isinf(zi) ξ -= oftype(η,pi)/4 * sign(zr) end
Complex(ξ,η)
end
acos(z::Complex) = acos(float(z))

function atan(z::Complex)
w = atanh(Complex(-imag(z),real(z)))
Expand All @@ -898,13 +901,15 @@ function cosh(z::Complex)
cos(Complex(zi,-zr))
end

function tanh(z::Complex{T}) where T<:AbstractFloat
Ω = prevfloat(typemax(T))
function tanh(z::Complex{T}) where T
z = float(z)
Tf = float(T)
Ω = prevfloat(typemax(Tf))
ξ, η = reim(z)
if isnan(ξ) && η==0 return Complex(ξ, η) end
if 4*abs(ξ) > asinh(Ω) #Overflow?
Complex(copysign(one(T),ξ),
copysign(zero(T),η*(isfinite(η) ? sin(2*abs(η)) : one(η))))
Complex(copysign(one(Tf),ξ),
copysign(zero(Tf),η*(isfinite(η) ? sin(2*abs(η)) : one(η))))
else
t = tan(η)
β = 1+t*t #sec(η)^2
Expand All @@ -917,7 +922,6 @@ function tanh(z::Complex{T}) where T<:AbstractFloat
end
end
end
tanh(z::Complex) = tanh(float(z))

function asinh(z::Complex)
w = asin(Complex(-imag(z),real(z)))
Expand All @@ -943,8 +947,10 @@ function acosh(z::Complex)
Complex(ξ, η)
end

function atanh(z::Complex{T}) where T<:AbstractFloat
Ω = prevfloat(typemax(T))
function atanh(z::Complex{T}) where T
z = float(z)
Tf = float(T)
Ω = prevfloat(typemax(Tf))
θ = sqrt(Ω)/4
ρ = 1/θ
x, y = reim(z)
Expand All @@ -963,7 +969,7 @@ function atanh(z::Complex{T}) where T<:AbstractFloat
end
return Complex(real(1/z), copysign(oftype(y,pi)/2, y))
end
β = copysign(one(T), x)
β = copysign(one(Tf), x)
z *= β
x, y = reim(z)
if x == 1
Expand All @@ -986,7 +992,6 @@ function atanh(z::Complex{T}) where T<:AbstractFloat
end
β * Complex(ξ, η)
end
atanh(z::Complex) = atanh(float(z))

#Rounding complex numbers
#Requires two different RoundingModes for the real and imaginary components
Expand Down

0 comments on commit ee7339b

Please sign in to comment.