-
Notifications
You must be signed in to change notification settings - Fork 159
/
modeling_library.jl
87 lines (63 loc) · 2.32 KB
/
modeling_library.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#############################
# probability distributions #
#############################
import Distributions
using SpecialFunctions: loggamma, logbeta, digamma
"""
    Distribution{T}

Abstract supertype of probability distributions whose random choices have type `T`.

A distribution is used through the generic functions `random`, `logpdf`,
`has_output_grad` and `logpdf_grad`; `is_discrete` can also be specialized.
"""
abstract type Distribution{T} end

"""
    val::T = random(dist::Distribution{T}, args...)

Sample a random choice from `dist` with the given arguments `args`.
"""
function random end

"""
    lpdf = logpdf(dist::Distribution{T}, value::T, args...)

Evaluate the log probability (density) of `value` under `dist` with the given arguments.
"""
function logpdf end

"""
    has::Bool = has_output_grad(dist::Distribution)

Return `true` if `dist` computes the gradient of the logpdf with respect to the value of
the random choice, and `false` otherwise.

When this returns `false`, the first element of the tuple returned by `logpdf_grad` is
`nothing`.
"""
function has_output_grad end

"""
    grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...)

Compute the gradient of the logpdf with respect to the value and to each of the arguments.

If `has_output_grad` returns `false`, then the first element of the returned tuple is
`nothing`. Otherwise, the first element of the tuple is the gradient with respect to the
value.

If the return value of `has_argument_grads` is `false` at position `i`, then the
`i+1`th element of the returned tuple is `nothing`. Otherwise, this element contains the
gradient with respect to the `i`th argument.
"""
function logpdf_grad end
"""
    is_discrete(dist::Distribution)::Bool

Return `true` if `dist` is discrete.

This fallback applies to every `Distribution` and returns `false`; a specific distribution
type can override it by defining a more specific method.
"""
function is_discrete(::Distribution)
    return false
end
# NOTE: has_argument_grads is documented and exported in gen_fn_interface.jl
# The type parameter `T` of a `Distribution{T}` is the type of its random choices.
get_return_type(::Distribution{T}) where {T} = T
# Public names of the distribution interface declared above.
export Distribution,
    random,
    logpdf,
    logpdf_grad,
    has_output_grad,
    is_discrete
# Distribution implementations. These files are loaded in the order listed;
# NOTE(review): later files presumably build on earlier ones (e.g. mixtures and products
# on the built-in distributions) — confirm before reordering.
# built-in distributions
include("distributions/distributions.jl")
# @dist DSL
include("dist_dsl/dist_dsl.jl")
# mixtures of distributions
include("mixture.jl")
# products of distributions
include("product.jl")
###############
# combinators #
###############
# vector.jl is included before the combinators because it holds code they share;
# NOTE(review): presumably used by map/unfold — confirm before reordering these includes.
# code shared by vector-shaped combinators
include("vector.jl")
# built-in generative function combinators
include("choice_at/choice_at.jl")
include("call_at/call_at.jl")
include("map/map.jl")
include("unfold/unfold.jl")
include("recurse/recurse.jl")
include("switch/switch.jl")
#############################################################
# abstractions for constructing custom generative functions #
#############################################################
# NOTE(review): judging by the file name, this presumably covers custom deterministic
# generative functions — confirm against custom_determ.jl.
include("custom_determ.jl")