Skip to content

Commit

Permalink
fix static variable
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Apr 23, 2019
1 parent 2ba0ec0 commit 6b0f3ef
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/distributions/gumbel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ BatchGumbelBernoulli(p; τ=FT(0.2)) = BatchGumbelBernoulli(p, τ)
Sample from Gumbel-Bernoulli distributions.
"""
function rand(gb::BatchGumbelBernoulli{T}) where {T}
isa(gb.p, AutoGrad.Result) && (T = typeof(value(gb)))
_T = isa(gb.p, AutoGrad.Result) ? typeof(value(gb)) : T

# TODO: re-implement this `rand` using the same procedure for `BatchGumbelBernoulliLogit`
FT = eltype(gb.p)
Expand All @@ -74,8 +74,8 @@ function rand(gb::BatchGumbelBernoulli{T}) where {T}
_one = one(FT)
τ = gb.τ

u0 = rand(FT, sz...); g0 = T(_u2gumbel(FT, u0))
u1 = rand(FT, sz...); g1 = T(_u2gumbel(FT, u1))
u0 = rand(FT, sz...); g0 = _T(_u2gumbel(FT, u0))
u1 = rand(FT, sz...); g1 = _T(_u2gumbel(FT, u1))

logit0 = (g0 .+ log.(_one + _eps .- gb.p)) ./ τ
logit1 = (g1 .+ log.(gb.p .+ _eps)) ./ τ
Expand Down

0 comments on commit 6b0f3ef

Please sign in to comment.