Improve heteroscedastic (#113)
* Replace first by only when needed and divide by 2 instead of multiplying by 0.5

* More transitions from first to only

* Fixed the quadrature

* Use only for prior tests

* Fixing logistic-softmax

* Fix introduced issue with sampling

* Add sampling method for heteroscedastic gaussian

* Fix bug from heteroscedasticity

* Fixed formulations and abuse of `map!`

* Abuse of map!

* Missing comma

* Fixingy fixes

* Fixed the uses of map!

* Added fixes

* Beautiful modifications

* Finally fixed the PG distributions!

* Handle case d.b < 1

* Patch bump
theogf committed Nov 25, 2021
1 parent fdc70d2 commit 71a9264
Showing 16 changed files with 268 additions and 245 deletions.
name = "AugmentedGaussianProcesses"
uuid = "38eea1fd-7d7d-5162-9d08-f89d0f2e271e"
authors = ["Theo Galy-Fajou <>"]
version = "0.11.2"
version = "0.11.3"

AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
10 changes: 6 additions & 4 deletions docs/examples/heteroscedastic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using AugmentedGaussianProcesses
using Distributions
using LinearAlgebra
using Plots
using Random
default(; lw=3.0, msw=0.0)
# using CairoMakie

Expand All @@ -15,16 +16,17 @@ default(; lw=3.0, msw=0.0)
# ``y \sim f + \epsilon``
# where ``\epsilon \sim \mathcal{N}(0, (\lambda \sigma(g))^{-1})``
# We create a toy dataset with X ∈ [-10, 10] and sample `f`, `g` and `y` given this same generative model
rng = MersenneTwister(42)
N = 200
x = (sort(rand(N)) .- 0.5) * 20.0
x = (sort(rand(rng, N)) .- 0.5) * 20.0
x_test = range(-10, 10; length=500)
kernel = 5.0 * SqExponentialKernel() ∘ ScaleTransform(1.0) # Kernel function
K = kernelmatrix(kernel, x) + 1e-5I # The kernel matrix
f = rand(MvNormal(K)); # We draw a random sample from the GP prior
f = rand(rng, MvNormal(K)); # We draw a random sample from the GP prior

# We add a prior mean on `g` so that the variance does not become too big
μ₀ = -3.0
g = rand(MvNormal(μ₀ * ones(N), K))
g = rand(rng, MvNormal(μ₀ * ones(N), K))
λ = 3.0 # The maximum possible precision
σ = inv.(sqrt.(λ * AGP.logistic.(g))) # We use the following transform to obtain the std. deviation
y = f + σ .* randn(N); # We finally sample the output
Expand All @@ -38,7 +40,7 @@ scatter!(x, y; alpha=0.5, msw=0.0, lab="y") # Observation samples
model = VGP(
optimiser=true, # We optimise both the mean parameters and kernel hyperparameters
Expand Down
2 changes: 1 addition & 1 deletion docs/src/
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Not all inference are implemented/valid for all likelihoods, here is the compati
| GaussianLikelihood | ✔ (Analytic) ||||
| StudentTLikelihood |||||
| LaplaceLikelihood |||||
| HeteroscedasticLikelihood || (dev) | (dev) ||
| HeteroscedasticLikelihood || | (dev) ||
| LogisticLikelihood |||||
| BayesianSVM || (dev) |||
| LogisticSoftMaxLikelihood |||| (dev) |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ComplementaryDistributions
using Distributions
using Random
using SpecialFunctions
using StatsFuns: twoπ
using StatsFuns: twoπ, halfπ, inv2π, fourinvπ

export GeneralizedInverseGaussian, PolyaGamma, LaplaceTransformDistribution
Expand Down
201 changes: 106 additions & 95 deletions src/ComplementaryDistributions/polyagamma.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
using Distributions, Random
using Statistics
using SpecialFunctions
const __TRUNC = 0.64;
const __TRUNC_RECIP = 1.0 / __TRUNC;
const pg_t = 0.64
const pg_inv_t = inv(pg_t)

PolyaGamma(b::Int, c::Real)
PolyaGamma(b::Real, c::Real)
## Arguments
- `b::Int`
- `b::Real`
- `c::Real` exponential tilting
## Keyword Arguments
Expand All @@ -16,151 +17,161 @@ const __TRUNC_RECIP = 1.0 / __TRUNC;
Create a PolyaGamma sampler with parameters `b` and `c`
struct PolyaGamma{Tc,A} <: Distributions.ContinuousUnivariateDistribution
struct PolyaGamma{Tb,Tc} <: Distributions.ContinuousUnivariateDistribution
# For sum of Gammas.
function PolyaGamma{T}(b::Int, c::T, trunc::Int, nmax::Int) where {T<:Real}
if trunc < 1
@warn "trunc < 1. Setting trunc=1."
trunc = 1
bvec = [convert(T, (twoπ * (k - 0.5))^2) for k in 1:trunc]
return new{typeof(c),typeof(bvec)}(b, c, trunc, nmax, bvec)

Base.eltype(::PolyaGamma{T,Tc}) where {T,Tc} = Tc

Distributions.params(d::PolyaGamma) = (d.b, d.c)

Statistics.mean(d::PolyaGamma) = d.b / (2 * d.c) * tanh(d.c / 2)

function PolyaGamma(b::Int, c::T; nmax::Int=10, trunc::Int=200) where {T<:Real}
return PolyaGamma{T}(b, c, trunc, nmax)
Base.minimum(d::PolyaGamma) = zero(eltype(d))
Base.maximum(::PolyaGamma) = Inf
Distributions.insupport(::PolyaGamma, x::Real) = zero(x) <= x < Inf

function Distributions.pdf(d::PolyaGamma, x::Real)
b, c = params(d)
iszero(x) && return zero(x)
return _tilt(x, b, c) * 2^(b - 1) / gamma(b) * sum(0:200) do n
ifelse(iseven(n), 1, -1) * exp(
loggamma(n + b) - loggamma(n + 1) + log(2n + b) - log(twoπ * x^3) / 2 -
(2n + b)^2 / (8x),

function Distributions.pdf(d::PolyaGamma, x)
return cosh(d.c / 2)^d.b * 2.0^(d.b - 1) / gamma(d.b) * sum(
((-1)^n) * gamma(n + d.b) / gamma(n + 1) * (2 * n + b) / (sqrt(2 * π * x^3)) *
exp(-(2 * n + b)^2 / (8 * x) - c^2 / 2 * x) for n in 0:(d.nmax)
function _tilt(ω, b, c)
return cosh(c / 2)^b * exp(-c^2 / 2 * ω)

## Sampling
function Distributions.rand(rng::AbstractRNG, d::PolyaGamma{T}) where {T<:Real}
function Distributions.rand(rng::AbstractRNG, d::PolyaGamma)
if iszero(d.b)
return zero(T)
return zero(eltype(d))
return draw_sum(rng, d)

## Sampling when `b` is an integer
function draw_sum(rng::AbstractRNG, d::PolyaGamma{<:Int})
return sum(Base.Fix1(sample_pg1, rng), d.c * ones(d.b))

function draw_sum(rng::AbstractRNG, d::PolyaGamma{<:Real})
if d.b < 1
return rand_gamma_sum(rng, d, d.b)
return sum(Base.Fix1(draw_like_devroye, rng), d.c * ones(d.b))
trunc_b = floor(Int, d.b)
res_b = d.b - trunc_b
trunc_term = sum(Base.Fix1(sample_pg1, rng), d.c * ones(trunc_b))
res_term = rand_gamma_sum(rng, d, res_b)
return trunc_term + res_term

## Utility functions
function a(n::Int, x::Real)
k = (n + 0.5) * π
if x > __TRUNC
if x > pg_t
return k * exp(-k^2 * x / 2)
elseif x > 0
expnt = -1.5 * (log(π / 2) + log(x)) + log(k) - 2 * (n + 0.5)^2 / x
expnt = -3 / 2 * (log(halfπ) + log(x)) + log(k) - 2 * (n + 1//2)^2 / x
return exp(expnt)
error("x should be a positive real")

function mass_texpon(z::Real)
t = __TRUNC
t = pg_t

fz = 0.125 * π^2 + z^2 / 2
b = sqrt(1.0 / t) * (t * z - 1)
a = sqrt(1.0 / t) * (t * z + 1) * -1.0
K = π^2 / 8 + z^2 / 2
b = sqrt(inv(t)) * (t * z - 1)
a = -sqrt(inv(t)) * (t * z + 1)

x0 = log(fz) + fz * t
x0 = log(K) + K * t
xb = x0 - z + logcdf(Distributions.Normal(), b)
xa = x0 + z + logcdf(Distributions.Normal(), a)

qdivp = 4 / π * (exp(xb) + exp(xa))
qdivp = fourinvπ * (exp(xb) + exp(xa))

return 1.0 / (1.0 + qdivp)
return 1 / (1 + qdivp)

function rtigauss(rng::AbstractRNG, z::Real)
z = abs(z)
t = __TRUNC
x = t + 1.0
if __TRUNC_RECIP > z
alpha = 0.0
rate = 1.0
d_exp = Exponential(1.0 / rate)
while (rand(rng) > alpha)
e1 = rand(rng, d_exp)
e2 = rand(rng, d_exp)
while e1^2 > 2 * e2 / t
e1 = rand(rng, d_exp)
e2 = rand(rng, d_exp)
# Sample from a truncated inverse gaussian
function rand_truncated_inverse_gaussian(rng::AbstractRNG, z::Real)
μ = inv(z)
x = one(z) + pg_t
if μ > pg_t
d_exp = Exponential()
while true
E = rand(rng, d_exp)
E′ = rand(rng, d_exp)
while E^2 > 2E′ / pg_t
E = rand(rng, d_exp)
E′ = rand(rng, d_exp)
x = 1 + e1 * t
x = t / x^2
alpha = exp(-z^2 * x / 2)
x = pg_t / (1 + E * pg_t)^2
α = exp(-z^2 * x / 2)
α >= rand(rng) && break
mu = 1.0 / z
while (x > t)
y = randn(rng)^2
half_mu = mu / 2
mu_Y = mu * y
x = mu + half_mu * mu_Y - half_mu * sqrt(4 * mu_Y + mu_Y^2)
if rand(rng) > mu / (mu + x)
x = mu^2 / x
while (x > pg_t)
Y = randn(rng)^2
μY = μ * Y
x = μ + μ * μY / 2 - μ / 2 * sqrt(4 * μY + μY^2)
if rand(rng) > μ / (μ + x)
x = μ^2 / x
x > pg_t && break
return x

# ////////////////////////////////////////////////////////////////////////////////
# // Sample //
# ////////////////////////////////////////////////////////////////////////////////

function draw_like_devroye(rng::AbstractRNG, c::Real)
# Sample from PG(1, z)
# Algorithm 1 from "Bayesian Inference for logistic models..." p. 26
function sample_pg1(rng::AbstractRNG, z::Real)
# Change the parameter.
c = abs(c) / 2
z = abs(z) / 2

# Now sample 0.25 * J^*(1, Z := Z/2).
fz = 0.125 * π^2 + c^2 / 2
# ... Problems with large Z? Try using q_over_p.
# double p = 0.5 * __PI * exp(-1.0 * fz * __TRUNC) / fz;
# double q = 2 * exp(-1.0 * Z) * pigauss(__TRUNC, Z);

x = 0.0
s = 1.0
y = 0.0
# int iter = 0; If you want to keep track of iterations.
d_exp = Exponential()
K = π^2 / 8 + z^2 / 2
t = pg_t

r = mass_texpon(z)

while true
if rand(rng) < mass_texpon(c)
x = __TRUNC + rand(rng, d_exp) / fz
x = rtigauss(rng, c)
if r > rand(rng) # sample from truncated exponential
x = t + rand(rng, Exponential()) / K
else # sample from truncated inverse Gaussian
x = rand_truncated_inverse_gaussian(rng,z)
s = a(0, x)
y = rand(rng) * s
n = 0
go = true

# Cap the number of iterations?
while (go)
while true
n = n + 1
if isodd(n)
s = s - a(n, x)
if y <= s
return 0.25 * x
y <= s && return x / 4
s = s + a(n, x)
if y > s
go = false
y > s && break
# Need Y <= S in event that Y = S, e.g. when x = 0.
end # draw_like_devroye
end # Sample PG(1, c)

# Sample ω as the series of Gamma variables (truncated at 200)
function rand_gamma_sum(rng::AbstractRNG, d::PolyaGamma, e::Real)
C = inv2π / π
c = d.c
w = (c * inv2π)^2
d = Gamma(e, 1)
return C * sum(1:200) do k
rand(rng, d) / ((k - 0.5)^2 + w)
10 changes: 10 additions & 0 deletions src/functions/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ function expectation(f, μ::Real, σ²::Real)
return dot(pred_weights, f.(x))

# Return √E[f^2]
function sqrt_expec_square(μ, σ²)
return sqrt(abs2(μ) + σ²)

# Return √E[(f-y)^2]
function sqrt_expec_square(μ, σ², y)
return sqrt(abs2(μ - y) + σ²)

## delta function `(i,j)`, equal `1` if `i == j`, `0` else ##
@inline function δ(T, i::Integer, j::Integer)
return ifelse(i == j, one(T), zero(T))
Expand Down
6 changes: 4 additions & 2 deletions src/likelihood/bayesiansvm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ function local_updates!(
) where {T}
@. local_vars.c = abs2(one(T) - y * μ) + diagΣ
@. local_vars.θ = inv(sqrt(local_vars.c))
map!(local_vars.c, μ, diagΣ, y) do μ, σ², y
abs2(one(T) - y * μ) + σ²
map!(inv ∘ sqrt, local_vars.θ, local_vars.c)
return local_vars

Expand Down

