Zygote AD: Cannot convert Array{Float64,1} to Float64 #43

Closed
itsdfish opened this issue Jul 23, 2019 · 9 comments · Fixed by #44

Comments

@itsdfish

Hi Tamas-

Below is a simple model that runs with ForwardDiff but throws a conversion error when I switch to Zygote. I will omit the error message because it is quite long. My setup is Julia 1.1.0 with the following package versions:

(v1.1) pkg> st DynamicHMC
    Status `~/.julia/environments/v1.1/Project.toml`
  [bbc10e6e] DynamicHMC v1.0.5
  [6fdf6af0] LogDensityProblems v0.8.2
  [d96e819e] Parameters v0.10.3
  [4c63d2b9] StatsFuns v0.8.0
  [84d833dd] TransformVariables v0.3.3

(v1.1) pkg> st Zygote
    Status `~/.julia/environments/v1.1/Project.toml`
  [e88e6eb3] Zygote v0.3.2

The simple two-parameter Gaussian model I tried is listed below.

using Distributions,DynamicHMC,LogDensityProblems,TransformVariables,Parameters
import Zygote

struct GaussianProb{TY <: AbstractVector}
   "Observations."
   y::TY
end

function (problem::GaussianProb)(θ)
   @unpack y = problem   # extract the data
   @unpack mu, sigma = θ
   loglikelihood(Normal(mu, sigma), y) + logpdf(Normal(0,1), mu) +
   logpdf(Truncated(Cauchy(0,5),0,Inf), sigma)
end

# Define problem with data and inits.
function sampleDHMC(obs,N,nsamples)
 p = GaussianProb(obs);
 p((mu = 0.0, sigma = 1.0))

 # Write a function to return properly dimensioned transformation.

 problem_transformation(p::GaussianProb) =
     as((mu  = as(Real, -25, 25), sigma = asℝ₊), )
 # Use Zygote for the gradient (the ForwardDiff variant is commented out).
 P = TransformedLogDensity(problem_transformation(p), p)
 #∇P = ADgradient(:ForwardDiff, P)
 ∇P = ADgradient(:Zygote, P)
 # Sample from the posterior.
 chain, NUTS_tuned = NUTS_init_tune_mcmc(∇P, nsamples,report=ReportSilent());
 # Undo the transformation to obtain the posterior from the chain.
 posterior = TransformVariables.transform.(Ref(problem_transformation(p)), get_position.(chain));
 #chns = nptochain(posterior,NUTS_tuned)
 #return chns
end

function GaussianGen(;μ=0,σ=1,Nd,kwargs...)
 data=(y=rand(Normal(μ,σ),Nd),N=Nd)
   return data
end

data = GaussianGen(;Nd=100)

samples = sampleDHMC(data...,2000)
tpapp transferred this issue from tpapp/DynamicHMC.jl Jul 23, 2019
@tpapp
Owner

tpapp commented Jul 23, 2019

This is an issue with LogDensityProblems, and can be replicated with

using Distributions,DynamicHMC,LogDensityProblems,TransformVariables,Parameters
import Zygote

struct GaussianProb{TY <: AbstractVector}
   "Observations."
   y::TY
end

function (problem::GaussianProb)(θ)
   @unpack y = problem   # extract the data
   @unpack mu, sigma = θ
   loglikelihood(Normal(mu, sigma), y) + logpdf(Normal(0,1), mu) +
   logpdf(Truncated(Cauchy(0,5),0,Inf), sigma)
end

function GaussianGen(;μ=0,σ=1,Nd,kwargs...)
 data=(y=rand(Normal(μ,σ),Nd),N=Nd)
   return data
end

data = first(GaussianGen(;Nd=100))
p = GaussianProb(data);
t = as((mu  = as(Real, -25, 25), sigma = asℝ₊), )
P = TransformedLogDensity(t, p)
∇P_FD = ADgradient(:ForwardDiff, P)
∇P_ZY = ADgradient(:Zygote, P)

LogDensityProblems.logdensity(LogDensityProblems.Value, P, zeros(2))
LogDensityProblems.logdensity(LogDensityProblems.ValueGradient, ∇P_FD, zeros(2)) # works
LogDensityProblems.logdensity(LogDensityProblems.ValueGradient, ∇P_ZY, zeros(2)) # errors

I will look into it soon.
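
For readers wondering why the calls above evaluate the log density at `zeros(2)`: the sampler works in an unconstrained space, and the transformation `t` maps each unconstrained vector to a valid `(mu, sigma)` pair. The following is a hand-written sketch of that idea using only Base Julia. The names `logistic` and `to_constrained` are hypothetical and this is not the TransformVariables implementation; it assumes the interval transform is a scaled and shifted logistic and the positive transform is `exp`, which matches the `ScaledShiftedLogistic` and `ShiftedExp` types that appear in the stack trace later in this thread.

```julia
# Hand-written stand-ins for illustration; NOT the TransformVariables code.
logistic(x) = 1 / (1 + exp(-x))

# Map an unconstrained x ∈ ℝ² to (mu ∈ (-25, 25), sigma > 0) and return
# the log absolute Jacobian determinant of that map alongside it.
function to_constrained(x)
    s = logistic(x[1])
    mu = -25 + 50 * s            # scaled, shifted logistic onto (-25, 25)
    sigma = exp(x[2])            # exponential onto the positive reals
    # d(mu)/dx₁ = 50 * s * (1 - s) and d(sigma)/dx₂ = exp(x₂)
    logjac = log(50) + log(s) + log(1 - s) + x[2]
    return (mu = mu, sigma = sigma), logjac
end

θ, logjac = to_constrained(zeros(2))
println(θ)
```

Under these assumptions the origin of the unconstrained space corresponds to `mu = 0` and `sigma = 1`, which is why `zeros(2)` is a convenient test point. The log Jacobian term is what a transformed log density adds to the model's log posterior.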

@tpapp
Owner

tpapp commented Jul 24, 2019

@itsdfish: I fixed this in #44, will merge the branch soon and then release a minor version upgrade.

Note that Zygote is still experimental and has bugs like FluxML/Zygote.jl#271 (which I just discovered while fixing this), but reports against this package are most welcome in any case, and if they are an issue in Zygote I will just help make an MWE for an issue there.

@itsdfish
Author

Thanks for fixing this so quickly. I look forward to trying it out. I will definitely report bugs as I come across them and will try to identify MWEs where I can.

tpapp closed this as completed in #44 Jul 24, 2019
@itsdfish
Author

I switched to master on LogDensityProblems and still receive the same error, which traces back to this line. I just want to clarify whether this change was supposed to fix the problem or whether FluxML/Zygote.jl#271 needs to be resolved first.

tpapp reopened this Jul 24, 2019
@tpapp
Owner

tpapp commented Jul 24, 2019

No, that bug is unrelated. Can you paste an

pkg> st --manifest

and the error message?

@itsdfish
Author

Sure thing.

Package info:

(v1.1) pkg> st --manifest
    Status `~/.julia/environments/v1.1/Manifest.toml`
  [621f4979] AbstractFFTs v0.4.1
  [79e6a3ab] Adapt v1.0.0
  [0bf59076] AdvancedHMC v0.1.9
  [dce04be8] ArgCheck v1.0.1
  [7d9fca2a] Arpack v0.3.1
  [bf4720bc] AssetRegistry v0.1.0
  [c52e3926] Atom v0.8.5
  [13072b0f] AxisAlgorithms v1.0.0
  [39de3d68] AxisArrays v0.3.0
  [6e4b80f9] BenchmarkTools v0.4.2
  [76274a88] Bijectors v0.3.1
  [9e28174c] BinDeps v0.8.10
  [b99e7846] BinaryProvider v0.5.6
  [00ebfdb7] CSTParser v0.6.1
  [336ed68f] CSV v0.5.9
  [49dc2e85] Calculus v0.4.1
  [7057c7e9] Cassette v0.2.5
  [324d7699] CategoricalArrays v0.5.5
  [aaaa29a8] Clustering v0.13.2
  [593b3428] CmdStan v5.1.0
  [53a63b46] CodeTools v0.6.4
  [da1fd8a2] CodeTracking v0.5.7
  [3da002f7] ColorTypes v0.8.0
  [5ae59095] Colors v0.9.5
  [861a8166] Combinatorics v0.7.0
  [bbf7d656] CommonSubexpressions v0.2.0
  [34da2185] Compat v2.1.0
  [8f4d0f93] Conda v1.3.0
  [d38c429a] Contour v0.5.1
  [a8cc5b0e] Crayons v4.0.0
  [667455a9] Cubature v1.4.0
  [bb0ebd6b] DEMCMC v0.1.0 [`~/.julia/dev/DEMCMC`]
  [9a962f9c] DataAPI v1.0.0
  [a93c6f00] DataFrames v0.19.0
  [864edb3b] DataStructures v0.15.0
  [e2d170a0] DataValueInterfaces v1.0.0
  [e7dc6d0d] DataValues v0.4.12
  [39dd38d3] Dierckx v0.4.1
  [01453d9d] DiffEqDiffTools v0.14.0
  [163ba53b] DiffResults v0.0.4
  [b552c78f] DiffRules v0.0.10
  [b4f34e82] Distances v0.8.0
  [31c24e10] Distributions v0.21.0
  [33d173f1] DocSeeker v0.2.0
  [ffbed154] DocStringExtensions v0.8.0
  [e30172f5] Documenter v0.23.0
  [bbc10e6e] DynamicHMC v1.0.5
  [7876af07] Example v0.5.3+ [`~/.julia/dev/Example`]
  [7a1cc6ca] FFTW v0.2.4
  [1a297f60] FillArrays v0.6.3
  [53c48c17] FixedPointNumbers v0.6.1
  [f6369f11] ForwardDiff v0.10.3
  [de31a74c] FunctionalCollections v0.5.0
  [38e38edf] GLM v1.1.1
  [28b8d3ca] GR v0.40.0
  [4d00f742] GeometryTypes v0.7.5
  [cd3eb016] HTTP v0.8.4
  [9fb69e20] Hiccup v0.2.2
  [d9be37ee] Homebrew v0.7.1
  [7073ff75] IJulia v1.18.1
  [7869d1d1] IRTools v0.2.2
  [83e8ac13] IniFile v0.5.0
  [505f98c9] InplaceOps v0.3.0
  [a98d9a8b] Interpolations v0.12.2
  [8197267c] IntervalSets v0.3.1
  [41ab1584] InvertedIndices v1.0.0
  [c8e1da08] IterTools v1.2.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [682c06a0] JSON v0.21.0
  [aa1ae85d] JuliaInterpreter v0.5.2
  [e5e0dc1b] Juno v0.7.0
  [5ab0869b] KernelDensity v0.5.1
  [7c4cb9fa] LNR v0.2.0
  [b964fa9f] LaTeXStrings v1.0.3
  [984bce1d] LambertW v0.4.3
  [50d2b5c4] Lazy v0.13.2
  [5078a376] LazyArrays v0.9.0
  [6f1fad26] Libtask v0.3.0
  [d3d80556] LineSearches v7.0.1
  [6fdf6af0] LogDensityProblems v0.8.3 #master (https://github.com/tpapp/LogDensityProblems.jl.git)
  [6f1432cf] LoweredCodeUtils v0.3.5
  [c7f686f2] MCMCChains v0.3.10
  [1914dd2f] MacroTools v0.5.1
  [dbb5928d] MappedArrays v0.2.1
  [739be429] MbedTLS v0.6.8
  [442fdcdd] Measures v0.3.0
  [e89f7d12] Media v0.5.0
  [e1d29d7a] Missings v0.4.1
  [d41bc354] NLSolversBase v7.3.1
  [872c559c] NNlib v0.6.0
  [77ba4419] NaNMath v0.3.2
  [86f7a689] NamedArrays v0.9.3
  [b8a86587] NearestNeighbors v0.4.3
  [510215fc] Observables v0.2.3
  [6fe1bfb0] OffsetArrays v0.11.1
  [429524aa] Optim v0.19.1
  [bac558e1] OrderedCollections v1.1.0
  [90014a1f] PDMats v0.9.7
  [d96e819e] Parameters v0.10.3
  [69de0a69] Parsers v0.3.6
  [fa939f87] Pidfile v1.1.0
  [ccf2f8ad] PlotThemes v0.3.0
  [995b91a9] PlotUtils v0.5.8
  [91a5bcdd] Plots v0.25.3
  [f27b6e38] Polynomials v0.5.2
  [2dfb63ee] PooledArrays v0.5.2
  [85a6dd25] PositiveFactorizations v0.2.2
  [92933f4c] ProgressMeter v1.0.0
  [438e738f] PyCall v1.91.2
  [d330b81b] PyPlot v2.8.1
  [1fd47b50] QuadGK v2.1.0 #master (https://github.com/JuliaMath/QuadGK.jl.git)
  [6f49c342] RCall v0.13.3
  [b3c3ace0] RangeArrays v0.3.1
  [c84ed2f1] Ratios v0.3.1
  [3cdcf5f2] RecipesBase v0.6.0
  [189a3867] Reexport v0.2.0
  [ae029012] Requires v0.5.2
  [295af30f] Revise v2.1.6
  [79098fc4] Rmath v0.5.0
  [992d4aef] Showoff v0.3.1
  [b85f4697] SoftGlobalScope v1.0.10
  [a2af1166] SortingAlgorithms v0.3.1
  [276daf66] SpecialFunctions v0.7.2
  [90137ffa] StaticArrays v0.11.0
  [2913bbd2] StatsBase v0.31.0
  [4c63d2b9] StatsFuns v0.8.0
  [3eaba693] StatsModels v0.5.0
  [f3b207a7] StatsPlots v0.11.0
  [88034a9c] StringDistances v0.3.2
  [3783bdb8] TableTraits v1.0.0
  [bd369af6] Tables v0.2.9
  [a759f4b9] TimerOutputs v0.5.0
  [0796e94c] Tokenize v0.5.5
  [37b6cedf] Traceur v0.3.0
  [9f7883ad] Tracker v0.2.2
  [84d833dd] TransformVariables v0.3.3
  [a2a6695c] TreeViews v0.3.0
  [fce5fe82] Turing v0.6.18
  [30578b45] URIParser v0.4.0
  [81def892] VersionParsing v1.1.3
  [ea10d353] WeakRefStrings v0.6.1
  [0f1e0344] WebIO v0.8.6
  [104b5d7c] WebSockets v1.5.2
  [cc8bc4a8] Widgets v0.6.1
  [1b915085] WinReg v0.3.1
  [efce3f68] WoodburyMatrices v0.4.1
  [c2297ded] ZMQ v1.0.0
  [e88e6eb3] Zygote v0.3.2
  [2a0f44e3] Base64 
  [ade2ca70] Dates 
  [8bb1440f] DelimitedFiles 
  [8ba89e20] Distributed 
  [7b1f6079] FileWatching 
  [9fa8497b] Future 
  [b77e0a4c] InteractiveUtils 
  [76f85450] LibGit2 
  [8f399da3] Libdl 
  [37e2e46d] LinearAlgebra 
  [56ddb016] Logging 
  [d6f4376e] Markdown 
  [a63ad114] Mmap 
  [44cfe95a] Pkg 
  [de0858da] Printf 
  [9abbd945] Profile 
  [3fa0cd96] REPL 
  [9a3f8284] Random 
  [ea8e919c] SHA 
  [9e88b42a] Serialization 
  [1a1011a3] SharedArrays 
  [6462fe0b] Sockets 
  [2f01184e] SparseArrays 
  [10745b16] Statistics 
  [4607b0f0] SuiteSparse 
  [8dfed614] Test 
  [cf7118a7] UUIDs 
  [4ec0a83e] Unicode

Error message:

julia> LogDensityProblems.logdensity(LogDensityProblems.ValueGradient, ∇P_ZY, zeros(2)) # errors

ERROR: MethodError: Cannot `convert` an object of type Array{Float64,1} to an object of type Float64
Closest candidates are:
  convert(::Type{T<:Number}, ::T<:Number) where T<:Number at number.jl:6
  convert(::Type{T<:Number}, ::Number) where T<:Number at number.jl:7
  convert(::Type{T<:Number}, ::Base.TwicePrecision) where T<:Number at twiceprecision.jl:250
  ...
Stacktrace:
 [1] setindex!(::Array{Float64,1}, ::Array{Float64,1}, ::Int64) at ./array.jl:767
 [2] copyto!(::SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true}, ::NamedTuple{(:parent, :indices, :offset1, :stride1),Tuple{Array{Float64,1},Nothing,Nothing,Nothing}}) at ./subarray.jl:293
 [3] (::getfield(Zygote, Symbol("##767#769")){Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float64,1},Tuple{UnitRange{Int64}}})(::NamedTuple{(:parent, :indices, :offset1, :stride1),Tuple{Array{Float64,1},Nothing,Nothing,Nothing}}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/lib/array.jl:41
 [4] (::getfield(Zygote, Symbol("##2040#back#771")){getfield(Zygote, Symbol("##767#769")){Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float64,1},Tuple{UnitRange{Int64}}}})(::NamedTuple{(:parent, :indices, :offset1, :stride1),Tuple{Array{Float64,1},Nothing,Nothing,Nothing}}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/lib/grad.jl:46
 [5] view_into at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/utilities.jl:54 [inlined]
 [6] (::typeof(∂(TransformVariables.view_into)))(::NamedTuple{(:parent, :indices, :offset1, :stride1),Tuple{Array{Float64,1},Nothing,Nothing,Nothing}}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [7] _transform_tuple at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/aggregation.jl:161 [inlined]
 [8] (::typeof(∂(TransformVariables._transform_tuple)))(::Tuple{Tuple{Float64},Float64}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [9] _transform_tuple at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/aggregation.jl:162 [inlined]
 [10] (::typeof(∂(TransformVariables._transform_tuple)))(::Tuple{Tuple{Float64,Float64},Float64}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [11] transform_tuple at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/aggregation.jl:171 [inlined]
 [12] (::typeof(∂(TransformVariables.transform_tuple)))(::Tuple{Tuple{Float64,Float64},Float64}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [13] transform_with at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/aggregation.jl:223 [inlined]
 [14] (::typeof(∂(TransformVariables.transform_with)))(::Tuple{NamedTuple{(:mu, :sigma),Tuple{Float64,Float64}},Float64}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [15] transform_and_logjac at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/generic.jl:204 [inlined]
 [16] (::typeof(∂(TransformVariables.transform_and_logjac)))(::Tuple{NamedTuple{(:mu, :sigma),Tuple{Float64,Float64}},Float64}) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [17] transform_logdensity at /Users/christopher.fisher/.julia/packages/TransformVariables/nMbbh/src/generic.jl:136 [inlined]
 [18] (::typeof(∂(TransformVariables.transform_logdensity)))(::Float64) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [19] logdensity at /Users/christopher.fisher/.julia/packages/LogDensityProblems/7wIZh/src/transformed.jl:42 [inlined]
 [20] (::typeof(∂(LogDensityProblems.logdensity)))(::Float64) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [21] #1 at /Users/christopher.fisher/.julia/packages/LogDensityProblems/7wIZh/src/AD.jl:25 [inlined]
 [22] (::typeof(∂(λ)))(::Float64) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface2.jl:0
 [23] (::getfield(Zygote, Symbol("##32#33")){typeof(∂(λ))})(::Float64) at /Users/christopher.fisher/.julia/packages/Zygote/fuj2C/src/compiler/interface.jl:38
 [24] logdensity(::Type{LogDensityProblems.ValueGradient}, ::LogDensityProblems.ZygoteGradientLogDensity{TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:mu, :sigma),Tuple{TransformVariables.ScaledShiftedLogistic{Int64},TransformVariables.ShiftedExp{true,Float64}}}},GaussianProb{Array{Float64,1}}}}, ::Array{Float64,1}) at /Users/christopher.fisher/.julia/packages/LogDensityProblems/7wIZh/src/AD_Zygote.jl:22
 [25] top-level scope at none:0
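
Frames [1] and [2] of the trace suggest what goes wrong: `copyto!` iterates over a `NamedTuple` whose first field (`parent`) is an `Array{Float64,1}`, and `setindex!` then tries to store that whole array in a single `Float64` slot. The same `MethodError` can be reproduced in Base Julia without any AD package. This is only a minimal sketch of the failing operation, not a reproduction of the Zygote code path:

```julia
x = zeros(2)                 # Vector of Float64, like the gradient buffer

err = try
    x[1] = [1.0, 2.0]        # a vector cannot be converted to one Float64
    nothing
catch e
    e
end

println(typeof(err))
```

The assignment fails because `setindex!` on an array of `Float64` calls `convert(Float64, value)`, and no such conversion exists for a vector.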

@tpapp
Owner

tpapp commented Jul 24, 2019

Thanks. This is another bug, possibly in Zygote. I made an MWE at tpapp/TransformVariables.jl#48.

I will isolate it and either fix it (if it's mine) or report an issue for Zygote.

In the meantime, I would recommend that you consider Flux for AD; it is much more mature. E.g., the following works fine:

using Distributions, LogDensityProblems, TransformVariables, Parameters, Flux

struct GaussianProb{TY <: AbstractVector}
   "Observations."
   y::TY
end

function (problem::GaussianProb)(θ)
   @unpack y = problem   # extract the data
   @unpack mu, sigma = θ
   loglikelihood(Normal(mu, sigma), y) + logpdf(Normal(0,1), mu) +
   logpdf(Truncated(Cauchy(0,5),0,Inf), sigma)
end

function GaussianGen(;μ=0,σ=1,Nd,kwargs...)
 data=(y=rand(Normal(μ,σ),Nd),N=Nd)
   return data
end

data = first(GaussianGen(;Nd=100))
p = GaussianProb(data);
t = as((mu  = as(Real, -25, 25), sigma = asℝ₊), )
P = TransformedLogDensity(t, p)
∇P_FL = ADgradient(:Flux, P)

LogDensityProblems.logdensity(LogDensityProblems.Value, P, zeros(2))
LogDensityProblems.logdensity(LogDensityProblems.ValueGradient, ∇P_FL, zeros(2)) # works

tpapp closed this as completed Jul 24, 2019
@itsdfish
Author

Thanks Tamas. I will switch to Flux AD in the meantime.

@itsdfish
Author

I was hoping that I could ask you a quick question that might help me make some project decisions. I'm interested in trying Zygote because ForwardDiff hits a performance wall with larger problems, such as a Poisson hierarchical regression model. I have found that it runs ~20 times slower than Stan, with the disparity increasing as more data are added. A quick benchmark reveals that DynamicHMC with Flux AD is ~3 times slower than DynamicHMC with ForwardDiff AD. Aside from sporadic bugs, have you found that Zygote performs closer to Stan?
