-
Notifications
You must be signed in to change notification settings - Fork 159
/
categorical.jl
30 lines (22 loc) · 943 Bytes
/
categorical.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
    Categorical <: Distribution{Int}

Field-less type representing the categorical distribution, whose samples are `Int`s.
The instance used for sampling and scoring is bound to the constant `categorical`.
"""
struct Categorical <: Distribution{Int}
end

"""
    categorical(probs::AbstractArray{U, 1}) where {U <: Real}

Given a vector of probabilities `probs` where `sum(probs) = 1`, sample an `Int` `i`
from the set {1, 2, .., `length(probs)`} with probability `probs[i]`.
"""
const categorical = Categorical()
"""
    logpdf(::Categorical, x::Int, probs::AbstractArray{U,1}) where {U <: Real}

Return the log probability of outcome `x` under the probability vector `probs`.

The result is `log(probs[x])` when `x` is in `1:length(probs)`, and `-Inf` for any
`x` outside that range.
"""
function logpdf(::Categorical, x::Int, probs::AbstractArray{U,1}) where {U <: Real}
    # Outcomes outside the support have probability zero.
    if x < 1 || x > length(probs)
        return -Inf
    end
    return log(probs[x])
end
"""
    logpdf_grad(::Categorical, x::Int, probs::AbstractArray{U,1}) where {U <: Real}

Return the gradient of `logpdf(categorical, x, probs)` as the tuple `(nothing, grad)`.

The first entry is `nothing`: no gradient is reported for the discrete outcome `x`.
`grad` is a `Vector{Float64}` of length `length(probs)` whose only nonzero entry is
`grad[x] = 1 / probs[x]`.

When `x` lies outside `1:length(probs)`, `logpdf` returns the constant `-Inf`, so an
all-zero gradient is returned instead of raising a `BoundsError`.
"""
function logpdf_grad(::Categorical, x::Int, probs::AbstractArray{U,1}) where {U <: Real}
    grad = zeros(length(probs))
    # Use the same support check as `logpdf` so both methods agree on invalid `x`.
    if 1 <= x <= length(probs)
        grad[x] = 1.0 / probs[x]
    end
    return (nothing, grad)
end
"""
    random(::Categorical, probs::AbstractArray{U,1}) where {U <: Real}

Draw one sample by building `Distributions.Categorical(probs)` and calling `rand` on it.
"""
function random(::Categorical, probs::AbstractArray{U,1}) where {U <: Real}
    sampler = Distributions.Categorical(probs)
    return rand(sampler)
end
"""
    is_discrete(::Categorical)

Return `true`: the categorical distribution is flagged as discrete.
"""
function is_discrete(::Categorical)
    return true
end
# Calling the distribution object samples from it, i.e. `categorical(probs)` forwards to
# `random`. `Categorical` has no fields, so the receiver is the same value as `Categorical()`.
(dist::Categorical)(probs) = random(dist, probs)
# No gradient with respect to the sampled value: `logpdf_grad` returns `nothing` in
# the output slot.
function has_output_grad(::Categorical)
    return false
end

# One flag per distribution argument; the single argument `probs` has a gradient,
# returned as the second element of `logpdf_grad`.
function has_argument_grads(::Categorical)
    return (true,)
end
export categorical