New interface #6
I was thinking something along the lines of
struct Transformed{D<:Distribution, T} <: Distribution
d::D
t::T
end
logpdf(::Transformed{D, T}, ::Real) = ...
rand(::Transformed{D, T}) = ... (this would need to be specialised for Univariate, Multivariate etc) |
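A minimal sketch of what this proposal could look like, written against plain Julia so that it runs without Distributions.jl. `ExpDist`, `LogTransform`, `forward`, `inverse` and `logabsdetjac_inverse` are hypothetical stand-ins made up for illustration; they are not Distributions.jl or Bijectors.jl API, and only the univariate case is covered.

```julia
using Random

# Hypothetical stand-in for a base distribution with positive support.
struct ExpDist
    rate::Float64
end
logpdf(d::ExpDist, x::Real) = x < 0 ? -Inf : log(d.rate) - d.rate * x
Base.rand(rng::AbstractRNG, d::ExpDist) = randexp(rng) / d.rate

# Hypothetical transformation y = log(x), mapping (0, Inf) to the real line.
struct LogTransform end
forward(::LogTransform, x::Real) = log(x)
inverse(::LogTransform, y::Real) = exp(y)
# log|dx/dy| for x = exp(y) is simply y.
logabsdetjac_inverse(::LogTransform, y::Real) = y

# The wrapper from the proposal: a base distribution plus a transformation.
struct Transformed{D,T}
    d::D
    t::T
end

# Change of variables: log-density of y = t(x), evaluated at y.
logpdf(td::Transformed, y::Real) =
    logpdf(td.d, inverse(td.t, y)) + logabsdetjac_inverse(td.t, y)

# Sample from the base distribution, then push the sample through t.
Base.rand(rng::AbstractRNG, td::Transformed) = forward(td.t, rand(rng, td.d))

td = Transformed(ExpDist(2.0), LogTransform())
y = rand(MersenneTwister(1), td)
println(logpdf(td, y))
```

Keeping the transformation as a separate field is what lets the same wrapper serve different supports: only `forward`, `inverse` and the Jacobian term need to change.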
Hey! I made an attempt at implementing automatic differentiation variational inference (ADVI) for Turing, which requires transforming the distribution to a distribution with real support, i.e. can use While doing this I was thinking a bit about the interface and implemented the following as a "wrapper" for
using Distributions, Bijectors
import Distributions: logpdf, rand
# Issue: this is not enforcing `VariateForm` of `D` and `Transformed{D}` to be the same
struct Transformed{D} <: Distribution{VariateForm, Continuous} where D <: Distribution
d::D
end
# size
Base.size(td::Transformed) = size(td.d)
Base.length(td::Transformed) = length(td.d)
# logp
logpdf(d::Transformed{D} where D <: Distribution{Univariate}, x::T where T <: Real) = logpdf_with_trans(d.d, x, true)
logpdf(d::Transformed{D} where D <: Distribution{Multivariate}, x::AbstractMatrix{T} where T <: Real) = logpdf_with_trans(d.d, x, true)
# rand
rand(td::Transformed, n::Int64) = link(td.d, rand(td.d, n))
rand(d::Transformed{D} where D <: Distribution{Univariate}) = first(rand(d, 1))
rand(d::Transformed{D} where D <: Distribution{Multivariate}) = rand(d, 1)
Important: vectorized
If we could fix the above issue, maybe this would be a nice "ad-hoc" implementation. Comparing to the above suggestion which to me seems like it would require a more thorough rewrite of I haven't investigated this to a satisfactory degree, but I was encouraged on |
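For reference, the quantity a transformed log-density is generally expected to compute is the change-of-variables formula. Here $f$ denotes the map to real support (the role `link` plays in the code above), $p_X$ the original density and $p_Y$ the density of $y = f(x)$; the symbols are notation introduced for this note, and $f$ is assumed to be differentiable and invertible.

$$
\log p_Y(y) = \log p_X\big(f^{-1}(y)\big) + \log\left|\det J_{f^{-1}}(y)\right|
$$

Sampling follows the same structure: draw $x \sim p_X$ and return $f(x)$.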
Can you implement an |
Ah, sorry I ran it on the wrong distribution; meant to use a
p = Transformed(Normal(0.0, 1.0))
logpdf(p, rand(p)) # <= this works
logpdf.(p, rand(p, 10)) # <= results in the following error
|
I also think that it would be fairly easy to transition from the p = Transformed(Exponential(1.0)) if we simply implement a constructor similar to Transformed(d::D <: PositiveDistribution) = Transformed(d, PositiveDistributionTransformation()) behind the scenes. This way you could use the simpler interface while transitioning to the more general approach. Also, maybe I misunderstood your question @cpfiffer. Did you question how to implement |
Took a second look at this. I thought you needed to overload
logpdf(d::Transformed{D} where D <: Distribution{Univariate}, x::Vector{T} where T <: Real) = logpdf_with_trans(d.d, x, true)
logpdf(d::Transformed{D} where D <: Distribution{Multivariate}, x::Vector{AbstractMatrix{T}} where T <: Real) = logpdf_with_trans(d.d, x, true)
This works now:
p = Transformed(Normal(0.0, 1.0))
logpdf(p, rand(p,10))
10-element Array{Float64,1}:
-1.0758450227893368
-1.0166693096930153
-0.9469445907722853
-2.3735502039370373
-0.924499243891403
-1.2068563756279285
-1.363355793193587
-1.6091214288686937
-1.0012221537005328
-0.9931074723392034
Here's the whole MWE again:
using Distributions, Bijectors
import Distributions: logpdf, rand
# Issue: this is not enforcing `VariateForm` of `D` and `Transformed{D}` to be the same
struct Transformed{D} <: Distribution{VariateForm, Continuous} where D <: Distribution
d::D
end
# size
Base.size(td::Transformed) = size(td.d)
Base.length(td::Transformed) = length(td.d)
# logp
logpdf(d::Transformed{D} where D <: Distribution{Univariate}, x::T where T <: Real) = logpdf_with_trans(d.d, x, true)
logpdf(d::Transformed{D} where D <: Distribution{Multivariate}, x::AbstractMatrix{T} where T <: Real) = logpdf_with_trans(d.d, x, true)
logpdf(d::Transformed{D} where D <: Distribution{Univariate}, x::Vector{T} where T <: Real) = logpdf_with_trans(d.d, x, true)
logpdf(d::Transformed{D} where D <: Distribution{Multivariate}, x::Vector{AbstractMatrix{T}} where T <: Real) = logpdf_with_trans(d.d, x, true)
# rand
rand(td::Transformed, n::Int64) = link(td.d, rand(td.d, n))
rand(d::Transformed{D} where D <: Distribution{Univariate}) = first(rand(d, 1))
rand(d::Transformed{D} where D <: Distribution{Multivariate}) = rand(d, 1)
p = Transformed(Normal(0.0, 1.0))
logpdf(p, rand(p)) # <= this works
logpdf(p, rand(p,10)) # <= this works now |
Nice!
p = Normal(0.0, 1.0)
logpdf(p, rand(p, 10))
I get Also, do we need And regarding overloading |
Got it working.
using Distributions, Bijectors
import Distributions: logpdf, rand
# Issue: this is not enforcing `VariateForm` of `D` and `Transformed{D}` to be the same
struct Transformed{D} <: Distribution{VariateForm, Continuous} where D <: Distribution
d::D
end
# size
Base.size(td::Transformed) = size(td.d)
Base.length(td::Transformed) = length(td.d)
# logp
logpdf(d::Transformed{D} where D <: Distribution{Univariate}, x::T where T <: Real) = logpdf_with_trans(d.d, x, true)
logpdf(d::Transformed{D} where D <: Distribution{Multivariate}, x::AbstractMatrix{T} where T <: Real) = logpdf_with_trans(d.d, x, true)
# rand
rand(td::Transformed, n::Int64) = link(td.d, rand(td.d, n))
rand(d::Transformed{D} where D <: Distribution{Univariate}) = first(rand(d, 1))
rand(d::Transformed{D} where D <: Distribution{Multivariate}) = rand(d, 1)
# makes it possible to vectorize operations on `Transformed` for `UnivariateDistributions`
Broadcast.broadcastable(d::Transformed{D} where D <: UnivariateDistribution) = Ref(d)
The issue was that the
p = Transformed(Normal(0.0, 1.0))
logpdf(p, rand(p)) # <= this works
logpdf.(p, rand(p, 10)) # <= this also works now
Though it might be a good idea to consider maybe making explicit types for
using Distributions, Bijectors
import Random: AbstractRNG
import Distributions: logpdf, rand, rand!, _rand!, _logpdf
struct UnivariateTransformed <: Distribution{Univariate, Continuous}
d::UnivariateDistribution
end
struct MultivariateTransformed <: Distribution{Multivariate, Continuous}
d::MultivariateDistribution
end
# constructors
transformed(d::UnivariateDistribution) = UnivariateTransformed(d)
transformed(d::MultivariateDistribution) = MultivariateTransformed(d)
# size
Base.length(td::MultivariateTransformed) = length(td.d)
# logp
logpdf(d::UnivariateTransformed, x::T where T <: Real) = logpdf_with_trans(d.d, x, true)
_logpdf(d::MultivariateTransformed, x::AbstractVector{T} where T <: Real) = logpdf_with_trans(d.d, x, true)
# rand
# draw from the base distribution using `rng`, then map the sample to real support
rand(rng::AbstractRNG, td::UnivariateTransformed) = link(td.d, rand(rng, td.d))
_rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{T} where T <: Real) = begin
rand!(rng, td.d, x)
y = link(td.d, x)
copyto!(x, y)
end
From a user-perspective the interface is still "nice":
p = transformed(Normal(0.0, 1.0))
logpdf.(p, rand(p, 10)) # <= this just works now
EDIT: fixed the |
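The effect of the `Broadcast.broadcastable` overload used earlier in this comment can be shown without Distributions.jl. For a type Julia knows nothing about, broadcasting falls back to `collect`, which fails for an object that is not iterable; returning `Ref(d)` asks broadcasting to treat the object as a scalar instead. `Wrapped` and `score` below are hypothetical names standing in for `Transformed` and `logpdf`.

```julia
# Hypothetical wrapper type, standing in for `Transformed`.
struct Wrapped
    scale::Float64
end

# Hypothetical scalar function, standing in for `logpdf`.
score(w::Wrapped, x::Real) = -abs(x) / w.scale

# Treat `Wrapped` as a scalar when broadcasting.
Base.Broadcast.broadcastable(w::Wrapped) = Ref(w)

w = Wrapped(2.0)
xs = [-2.0, 1.0, 4.0]
scores = score.(w, xs)   # `w` is reused for every element of `xs`
println(scores)
```

With explicit subtypes of `Distribution{Univariate, Continuous}`, as in the second listing of this comment, this overload may no longer be needed if Distributions.jl already provides it for univariate distributions; that is an assumption worth checking against the installed version.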
I think we need another level of abstraction on transformation as well, for the purpose of making this package compatible with possible incoming normalizing flows.
struct InvertibleTransformation end
logabsdetjacob(it::InvertibleTransformation, x) # this returns the logabsdetjacob
forward(it::InvertibleTransformation, x) # this returns forward and the corresponding logabsdetjacob
inverse(it::InvertibleTransformation, x) # this returns inverse and the corresponding logabsdetjacob
This is in line with #6 (comment), in which a transformed distribution consists of a base distribution and an invertible transformation (for which the logdetjacob is also computable). Edit: |
Would suggest edit: actually, perhaps |
Yes this should be I like the term |
I did some work on such an interface earlier this week. But then I wrote a simple test which I'm having issues with because of (what seems like) numerical inaccuracies leading to wrong results. Should I submit a WIP PR so you can all see what it looks like? EDIT: I'm using AD to compute the Jacobian from |
Yes, this is reasonable, and I could certainly live with it. My issue is with the use of the name An alternative interface might be:
abstract type InvertibleTransformation end
logabsdetjacob(it::InvertibleTransformation, x) # this returns the logabsdetjacob
forward(it::InvertibleTransformation, x) # applies the invertible transformation to a data point
inv(it::InvertibleTransformation) # returns another invertible transformation - the inverse of `it`
This way, if we want to compute the inverse, we would write forward(inv(it), x) to get the inverse transformation evaluated at Maybe I'm being overly pedantic and, as I said, I could definitely live with your proposal above @xukai92 |
This looks great! I guess there are two options to implement the inverse transform as well. One is to implement one-to-one corresponding transforms, say for Or we can define a wrapper called I feel the second option is neater. What do you think? @willtebbutt |
We could make the AD way as the callback functions for |
@willtebbutt I have a working example for the proposed abstraction, see below.
# Abstraction
abstract type AbstractInvertibleTransformation end
logabsdetjacob(t::T1, x::T2) where {T1<:AbstractInvertibleTransformation,T2} =
error("`logabsdetjacob(t::$T1, x::$T2)` is not implemented.")
forward(t::T1, x::T2) where {T1<:AbstractInvertibleTransformation,T2} =
error("`forward(t::$T1, x::$T2)` is not implemented.")
struct Inversed{T<:AbstractInvertibleTransformation} <: AbstractInvertibleTransformation
original::T
end
inv(t::T) where {T<:AbstractInvertibleTransformation} = Inversed(t)
inv(it::Inversed{T}) where {T<:AbstractInvertibleTransformation} = it.original
logabsdetjacob(it::T1, y::T2) where {T<:AbstractInvertibleTransformation,T1<:Inversed{T},T2} =
error("`logabsdetjacob(it::$T1, y::$T2)` is not implemented.")
forward(it::T1, y::T2) where {T<:AbstractInvertibleTransformation,T1<:Inversed{T},T2} =
error("`forward(it::$T1, y::$T2)` is not implemented.")
# Demo: logit transformation
using StatsFuns: logit, logistic
struct Logit{T<:Real} <: AbstractInvertibleTransformation
a::T
b::T
end
logabsdetjacob(t::Logit{<:Real}, x::Real) = log((x - t.a) * (t.b - x) / (t.b - t.a))
forward(t::Logit, x::Real) = (rv=logit((x - t.a) / (t.b - t.a)), logabsdetjacob=-logabsdetjacob(t, x))
function forward(it::Inversed{Logit{T}}, y::Real) where {T<:Real}
t = it.original
x = (t.b - t.a) * logistic(y) + t.a
return (rv=x, logabsdetjacob=logabsdetjacob(t, x))
end
# Simple demo
using Distributions, Bijectors
a, b = 1, 3
d = Truncated(Normal(0, 1), a, b)
t = Logit(a, b)
y = randn()
it = inv(t)
itres = forward(it, y)
@info itres
@info "This implementation" logpdf(d, itres.rv) + itres.logabsdetjacob
@info "Bijectors.jl" logpdf_with_trans(d, itres.rv, true)
The demo outputs:
Comments are welcome! |
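The signs in the `Logit` demo are easier to check with the derivative written out. Taking $a < x < b$ and the transformation as defined in the code, $y = \operatorname{logit}\!\big(\tfrac{x-a}{b-a}\big)$, a short derivation (notation introduced for this note) gives:

$$
\frac{dy}{dx} = \frac{b-a}{(x-a)(b-x)}
\quad\Longrightarrow\quad
\log\left|\frac{dy}{dx}\right| = -\log\frac{(x-a)(b-x)}{b-a},
\qquad
\log\left|\frac{dx}{dy}\right| = \log\frac{(x-a)(b-x)}{b-a}.
$$

So the expression inside `logabsdetjacob(t::Logit, x)` equals $\log|dx/dy|$, the log-Jacobian of the inverse map. That appears to be why `forward(t::Logit, x)` negates it while `forward(it::Inversed{Logit{T}}, y)` uses it directly, and why the demo adds it to `logpdf(d, itres.rv)` to obtain the density of $y$.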
@xukai92 glad you like the suggestion, the proposed implementation LGTM. One thing I would suggest is that instead of implementing stuff for inv(t::Logit) = Logistic(some_function_of_parameters_of_t) and make the inverse of |
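A standalone sketch of this alternative, where each transformation has an explicitly named inverse type rather than a generic wrapper. The `Logit` and `Logistic` structs here are hypothetical and independent of the demo above (no abstract supertype, no Jacobian term), and passing the bounds through unchanged is only one possible choice for the parameters of the inverse.

```julia
# Hypothetical transformation types; `a` and `b` are the interval bounds.
struct Logit{T<:Real}
    a::T
    b::T
end
struct Logistic{T<:Real}
    a::T
    b::T
end

# Each transformation names its inverse explicitly.
Base.inv(t::Logit) = Logistic(t.a, t.b)
Base.inv(t::Logistic) = Logit(t.a, t.b)

# Maps (a, b) to the real line.
function forward(t::Logit, x::Real)
    u = (x - t.a) / (t.b - t.a)
    return log(u / (1 - u))
end

# Maps the real line back to (a, b).
forward(t::Logistic, y::Real) = t.a + (t.b - t.a) / (1 + exp(-y))

t = Logit(1.0, 3.0)
x = 1.5
y = forward(t, x)
x_roundtrip = forward(inv(t), y)
println((y, x_roundtrip))
```

The trade-off is that every transformation needs a second type, whereas a generic wrapper only needs extra methods for transformations whose inverse has a closed form.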
Yes I was thinking of this approach as well. My issues with it are
In fact, we can still overload the |
Good point - I really like your |
I've now tried to incorporate the above comments and Flows.jl into my previous proposal #27; any feedback would be nice:) |
Closed by PR #27. |
@JuliaRegistrator register() |
Registration pull request created: JuliaRegistries/General/3061 After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version. This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:
|
@JuliaRegistrator register() |
Registration pull request updated: JuliaRegistries/General/3061 After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version. This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:
|
logpdf(t::Transform, l::Distribution, x)