/
negativebinomial.jl
133 lines (110 loc) · 3.7 KB
/
negativebinomial.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
    NegBinomialLikelihood(r::Real)

## Arguments
- `r::Real`: number of failures until the experiment is stopped

---

[Negative Binomial likelihood](https://en.wikipedia.org/wiki/Negative_binomial_distribution) with number of failures `r`
```math
p(y|r, f) = binomial(y + r - 1, y) (1 - σ(f))ʳ σ(f)ʸ
```
or, equivalently,
```math
p(y|r, f) = Γ(y + r) / (Γ(y + 1) Γ(r)) (1 - σ(f))ʳ σ(f)ʸ
```
where `σ` is the logistic function.
"""
struct NegBinomialLikelihood{T<:Real,Tr<:Real,A<:AbstractVector{T}} <: EventLikelihood{T}
    r::Tr # number of failures (kept in the type it was given)
    c::A  # overwritten by `local_updates!` with sqrt(μ² + Σ)
    θ::A  # overwritten by `local_updates!` with (r + y) / c * tanh(c / 2)

    # Only `r` is assigned here: `c` and `θ` stay undefined until a fully populated
    # instance is built with the three-argument constructor (see `init_likelihood`).
    NegBinomialLikelihood{T}(r::Real) where {T<:Real} = new{T,typeof(r),Vector{T}}(r)

    # Fully populated instance; `c` and `θ` must share the same vector type `A`.
    NegBinomialLikelihood{T}(r::Real, c::A, θ::A) where {T<:Real,A<:AbstractVector{T}} =
        new{T,typeof(r),A}(r, c, θ)
end
# Convenience constructor: element type defaults to `Float64`.
NegBinomialLikelihood(r::Real) = NegBinomialLikelihood{Float64}(r)
# Declares the likelihood as implemented for `AnalyticVI` and `GibbsSampling` inference.
function implemented(::NegBinomialLikelihood, ::Union{<:AnalyticVI,<:GibbsSampling})
    return true
end
# Build a fresh likelihood with the same `r` and local variables of length `nSamplesUsed`:
# `c` is drawn with `rand(T, ...)`, `θ` starts at zero.
function init_likelihood(
    likelihood::NegBinomialLikelihood{T}, ::AbstractInference{T}, ::Int, nSamplesUsed::Int
) where {T}
    c_init = rand(T, nSamplesUsed)
    θ_init = zeros(T, nSamplesUsed)
    return NegBinomialLikelihood{T}(likelihood.r, c_init, θ_init)
end
# Probability mass of the count `y` given the latent value `f`, evaluated as
# `pdf(NegativeBinomial(l.r, get_p(l, f)), y)`.
#
# The first argument must be the field `l.r`; the bare name `lr` is not defined in this
# scope and made every call throw an `UndefVarError`.
#
# NOTE(review): `Distributions.NegativeBinomial(r, p)` has mass ∝ pʳ (1 - p)ʸ, while the
# docstring of this likelihood states (1 - σ(f))ʳ σ(f)ʸ — confirm whether
# `1 - get_p(l, f)` is the intended second argument.
function (l::NegBinomialLikelihood)(y::Real, f::Real)
    return pdf(NegativeBinomial(l.r, get_p(l, f)), y)
end
# Log probability mass of the count `y` given the latent value `f`, evaluated as
# `logpdf(NegativeBinomial(l.r, get_p(l, f)), y)`.
#
# The first argument must be the field `l.r`; the bare name `lr` is not defined in this
# scope and made every call throw an `UndefVarError`.
#
# NOTE(review): `Distributions.NegativeBinomial(r, p)` has mass ∝ pʳ (1 - p)ʸ, while the
# docstring of this likelihood states (1 - σ(f))ʳ σ(f)ʸ — confirm whether
# `1 - get_p(l, f)` is the intended second argument.
function Distributions.loglikelihood(l::NegBinomialLikelihood, y::Real, f::Real)
    return logpdf(NegativeBinomial(l.r, get_p(l, f)), y)
end
# Element-wise p * r / (1 - p), where p = get_p(l, f).
function expec_count(l::NegBinomialLikelihood, f)
    p = get_p.(l, f)
    return @. p * l.r / (1 - p)
end
# Map latent value(s) `f` through the logistic function (broadcast over `f`).
get_p(::NegBinomialLikelihood, f) = logistic.(f)
# One-line description of the likelihood, including its `r` parameter.
function Base.show(io::IO, l::NegBinomialLikelihood)
    description = "Negative Binomial Likelihood (r = $(l.r))"
    return print(io, description)
end
# For every (μ[i], σ²[i]) pair, evaluate `get_p` at the nodes
# `pred_nodes * sqrt(σ²[i]) + μ[i]` and combine the values with `pred_weights`.
# Returns `(pred, sig_pred)`: the weighted mean of p and the weighted mean of p²
# minus the squared mean, each as a vector of element type `T`.
function compute_proba(
    l::NegBinomialLikelihood{T}, μ::AbstractVector{<:Real}, σ²::AbstractVector{<:Real}
) where {T<:Real}
    N = length(μ)
    pred, sig_pred = zeros(T, N), zeros(T, N)
    for i in 1:N
        # Negative variances are clamped to zero before taking the square root.
        std_i = sqrt(max(σ²[i], zero(T)))
        p = get_p.(l, pred_nodes .* std_i .+ μ[i])
        pred[i] = dot(pred_weights, p)
        sig_pred[i] = dot(pred_weights, p .^ 2) - pred[i]^2
    end
    return pred, sig_pred
end
## Local Updates ##
# In-place update of the local variables from the marginal means `μ` and variances `Σ`:
#   c = sqrt(μ² + Σ),   θ = (r + y) / c * tanh(c / 2)
# Returns `l.θ`.
function local_updates!(
    l::NegBinomialLikelihood, y::AbstractVector, μ::AbstractVector, Σ::AbstractVector
)
    l.c .= sqrt.(abs2.(μ) .+ Σ)
    l.θ .= (l.r .+ y) ./ l.c .* tanh.(0.5 .* l.c)
    return l.θ
end
# Draw one `PolyaGamma(y[i] + Int(l.r), |f[i]|)` sample per observation and hand the
# vector of draws to `set_ω!`. `Int(l.r)` requires `r` to be integer-valued.
function sample_local!(l::NegBinomialLikelihood, y::AbstractVector, f::AbstractVector)
    counts_plus_r = y .+ Int(l.r)
    abs_f = abs.(f)
    ω = rand.(PolyaGamma.(counts_plus_r, abs_f))
    return set_ω!(l, ω)
end
## Global Updates ##
# Returns the one-element tuple `(0.5 * (y - r),)`.
@inline function ∇E_μ(l::NegBinomialLikelihood, ::AOptimizer, y::AbstractVector)
    half_diff = 0.5 * (y .- l.r)
    return (half_diff,)
end
# Returns the one-element tuple `(0.5 * θ,)`.
@inline function ∇E_Σ(l::NegBinomialLikelihood, ::AOptimizer, y::AbstractVector)
    half_θ = 0.5 .* l.θ
    return (half_θ,)
end
## ELBO Section ##
# The augmented KL term of this likelihood is delegated to `PolyaGammaKL(l, y)`.
function AugmentedKL(l::NegBinomialLikelihood, y::AbstractVector)
    return PolyaGammaKL(l, y)
end
# Logarithm of the binomial coefficient `binomial(n, k)`.
#
# Generic fallback: whatever `binomial` accepts, evaluated in the argument types.
logabsbinomial(n, k) = log(binomial(n, k))

# Integer arguments: `binomial` on machine integers throws an `OverflowError` as soon
# as the coefficient exceeds the integer range (for `Int64` already at n = 67, k = 33),
# which moderate counts reach easily. Every coefficient with 0 ≤ n ≤ 66 fits in an
# `Int64` (the largest, binomial(66, 33), is ≈ 7.2e18), so that range keeps the fast
# machine-integer path; anything else is evaluated exactly in `BigInt` and only the
# logarithm is converted back to `Float64`.
function logabsbinomial(n::Integer, k::Integer)
    if 0 <= n <= 66
        return log(binomial(Int64(n), Int64(k)))
    end
    return Float64(log(binomial(big(n), k)))
end
# Log of the normalising constant Γ(y + r) / (Γ(y + 1) Γ(r)) for a real-valued `r`,
# evaluated element-wise over `y` with `loggamma`.
function negbin_logconst(y, r::Real)
    lg_num = loggamma.(y .+ r)
    lg_den = loggamma.(y .+ 1)
    return (lg_num - lg_den) .- loggamma(r)
end
# Log of the normalising constant binomial(y + r - 1, y) for an integer `r`,
# evaluated element-wise over `y` with `logabsbinomial`.
function negbin_logconst(y, r::Int)
    shift = r - 1
    return @. logabsbinomial(y + shift, y)
end
# Expected log-likelihood under the variational posterior with marginal means `μ` and
# variances `diag_cov`, using the current local variable `l.θ`:
#
#   Σᵢ [ logconst(yᵢ, r) - (yᵢ + r) log 2 + (yᵢ - r) μᵢ / 2 - θᵢ (μᵢ² + Σᵢᵢ) / 2 ]
#
# The quadratic term is θ/2 · E[f²] = θ/2 · (μ² + Σ) — the same μ² + Σ that
# `local_updates!` uses for `c` — so the mean enters squared (`abs2.(μ)`), not linearly.
# The log-constant does not depend on differentiated quantities and is hidden from
# Zygote with `Zygote.@ignore`.
function expec_loglikelihood(
    l::NegBinomialLikelihood{T},
    ::AnalyticVI,
    y::AbstractVector,
    μ::AbstractVector,
    diag_cov::AbstractVector,
) where {T}
    tot = Zygote.@ignore(sum(negbin_logconst(y, l.r))) - log(2.0) * sum(y .+ l.r)
    tot += 0.5 * dot(μ, (y .- l.r)) - 0.5 * dot(l.θ, abs2.(μ)) - 0.5 * dot(l.θ, diag_cov)
    return tot
end
# Forwards to the three-argument `PolyaGammaKL` with `y .+ r`, `c` and `θ`.
function PolyaGammaKL(l::NegBinomialLikelihood, y::AbstractVector)
    counts_plus_r = y .+ l.r
    return PolyaGammaKL(counts_plus_r, l.c, l.θ)
end