Skip to content

Commit

Permalink
feat: shared tests for stats distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Sep 29, 2023
1 parent c0a1079 commit 7636f80
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 226 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ jobs:
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: true
fail_ci_if_error: false
files: ./target/coverage/codecov.json
2 changes: 1 addition & 1 deletion .github/workflows/linter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ jobs:
run: bb lint-deps

- name: Lint files
run: bb lint --config '{:output {:pattern "::{{level}} file={{filename}},line={{row}},col={{col}}::{{message}}"}}'
run: bb lint --cache false --config '{:output {:pattern "::{{level}} file={{filename}},line={{row}},col={{col}}::{{message}}"}}'
8 changes: 4 additions & 4 deletions src/gen/distribution/java_util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
([] (uniform-distribution 0.0 1.0))
([lo hi] (->Uniform (rng) lo hi)))

(defn gaussian-distribution
([] (gaussian-distribution 0.0 1.0))
(defn normal-distribution
([] (normal-distribution 0.0 1.0))
([mu sigma]
(->Gaussian (rng) mu sigma)))

Expand All @@ -60,5 +60,5 @@
(def uniform
(d/->GenerativeFn uniform-distribution))

(def gaussian
(d/->GenerativeFn gaussian-distribution))
(def normal
(d/->GenerativeFn normal-distribution))
6 changes: 3 additions & 3 deletions src/gen/distribution/math/log_likelihood.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
{:pre [(pos? alpha) (pos? beta)]}
(if (< 0 v 1)
(- (+ (* (- alpha 1) (Math/log v))
(* (- beta alpha) (Math/log (- 1 v))))
(* (- beta 1) (Math/log (- 1 v))))
(log-beta-fn alpha beta))
##-Inf))

Expand All @@ -100,12 +100,12 @@
(defn cauchy
"Returns the log-likelihood of a [Cauchy
distribution](https://en.wikipedia.org/wiki/Cauchy_distribution) parameterized
by `scale` and `location` at the value `v`.
by `location` and `scale` at the value `v`.
The implementation follows the algorithm described on the Cauchy
distribution's [Wikipedia
page](https://en.wikipedia.org/wiki/Cauchy_distribution#Probability_density_function_(PDF))."
[scale location v]
[location scale v]
(let [normalized (/ (- v location) scale)
norm**2 (* normalized normalized)]
(- (- log-pi)
Expand Down
71 changes: 14 additions & 57 deletions test/gen/distribution/commons_math_test.clj
Original file line number Diff line number Diff line change
@@ -1,63 +1,20 @@
(ns gen.distribution.commons-math-test
(:require [clojure.math :as math]
[clojure.test :refer [deftest is]]
[gen.diff :as diff]
[gen.distribution.commons-math :as d]
[gen.generative-function :as gf]
[gen.trace :as trace]))
(:require [clojure.test :refer [deftest]]
[gen.distribution-test :as dt]
[gen.distribution.commons-math :as commons]))

(deftest bernoulli-call-no-args
(is (boolean? (d/bernoulli))))
(deftest bernoulli-tests
(dt/bernoulli-tests commons/bernoulli-distribution)
(dt/bernoulli-gfi-tests commons/bernoulli))

(deftest bernoulli-call-args
(is (boolean? (d/bernoulli 0.5))))
(deftest beta-tests
(dt/beta-tests commons/beta-distribution))

(deftest bernoulli-gf
(is (= d/bernoulli (trace/gf (gf/simulate d/bernoulli [])))))
(deftest uniform-tests
(dt/uniform-tests commons/uniform-distribution))

(deftest bernoulli-args
(is (= [0.5] (trace/args (gf/simulate d/bernoulli [0.5])))))
(deftest normal-tests
(dt/normal-tests commons/normal-distribution))

(deftest bernoulli-retval
(is (boolean? (trace/retval (gf/simulate d/bernoulli [0.5])))))

(deftest bernoulli-choices-noargs
(trace/choices (gf/simulate d/bernoulli [])))

(deftest bernoulli-update-weight
(is (= 1.0
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice true)
(:weight)
(math/exp))))
(is (= (/ 0.7 0.3)
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice false)
(:weight)
(math/exp)))))

(deftest bernoulli-update-discard
(is (nil?
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update nil)
(:discard))))
(is (= #gen/choice true
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice false)
(:discard)))))

(deftest bernoulli-update-change
(is (= diff/unknown-change
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update nil)
(:change))))
(is (= diff/unknown-change
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice false)
(:change)))))
(deftest gamma-tests
(dt/gamma-tests commons/gamma-distribution))
14 changes: 14 additions & 0 deletions test/gen/distribution/java_util_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
(ns gen.distribution.java-util-test
(:require [clojure.test :refer [deftest]]
[gen.distribution-test :as dt]
[gen.distribution.java-util :as java-util]))

(deftest bernoulli-tests
(dt/bernoulli-tests java-util/bernoulli-distribution)
(dt/bernoulli-gfi-tests java-util/bernoulli))

(deftest uniform-tests
(dt/uniform-tests java-util/uniform-distribution))

(deftest normal-tests
(dt/normal-tests java-util/normal-distribution))
75 changes: 18 additions & 57 deletions test/gen/distribution/kixi_test.cljc
Original file line number Diff line number Diff line change
@@ -1,65 +1,26 @@
(ns gen.distribution.kixi-test
(:require [clojure.math :as math]
[clojure.test :refer [deftest is]]
[gen]
[gen.choice-map]
[gen.diff :as diff]
[gen.distribution.kixi :as d]
[gen.generative-function :as gf]
[gen.trace :as trace]))
(:require [clojure.test :refer [deftest]]
[gen.distribution-test :as dt]
[gen.distribution.kixi :as kixi]))

(deftest bernoulli-call-no-args
(is (boolean? (d/bernoulli))))
(deftest bernoulli-tests
(dt/bernoulli-tests kixi/bernoulli-distribution)
(dt/bernoulli-gfi-tests kixi/bernoulli))

(deftest bernoulli-call-args
(is (boolean? (d/bernoulli 0.5))))
(deftest beta-tests
(dt/beta-tests kixi/beta-distribution))

(deftest bernoulli-gf
(is (= d/bernoulli (trace/gf (gf/simulate d/bernoulli [])))))
(deftest cauchy-tests
(dt/cauchy-tests kixi/cauchy-distribution))

(deftest bernoulli-args
(is (= [0.5] (trace/args (gf/simulate d/bernoulli [0.5])))))
(deftest exponential-tests
(dt/exponential-tests kixi/exponential-distribution))

(deftest bernoulli-retval
(is (boolean? (trace/retval (gf/simulate d/bernoulli [0.5])))))
(deftest uniform-tests
(dt/uniform-tests kixi/uniform-distribution))

(deftest bernoulli-choices-noargs
(trace/choices (gf/simulate d/bernoulli [])))
(deftest normal-tests
(dt/normal-tests kixi/normal-distribution))

(deftest bernoulli-update-weight
(is (= 1.0
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice true)
(:weight)
(math/exp))))
(is (= (/ 0.7 0.3)
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice false)
(:weight)
(math/exp)))))

(deftest bernoulli-update-discard
(is (nil?
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update nil)
(:discard))))
(is (= #gen/choice true
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice false)
(:discard)))))

(deftest bernoulli-update-change
(is (= diff/unknown-change
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update nil)
(:change))))
(is (= diff/unknown-change
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
(:trace)
(trace/update #gen/choice false)
(:change)))))
(deftest gamma-tests
(dt/gamma-tests kixi/gamma-distribution))
123 changes: 20 additions & 103 deletions test/gen/distribution/math/log_likelihood_test.cljc
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
(ns gen.distribution.math.log-likelihood-test
(:require [com.gfredericks.test.chuck.clojure-test :refer [checking]]
[clojure.test :refer [deftest is testing]]
[clojure.test.check.generators :as gen]
[gen.distribution.math.log-likelihood :as ll]
[same.core :refer [ish? zeroish? with-comparator]]))
[gen.distribution :as distribution]
[gen.distribution-test :as dt]
[gen.test-check-util :refer [gen-double]]
[same.core :refer [ish? with-comparator]]))

(defn within
"Returns a function that tests whether two values are within `eps` of each
Expand All @@ -19,9 +21,11 @@
1
(* n (factorial (dec n)))))

(defn gen-double [min max]
(gen/double*
{:min min :max max :infinite? false :NaN? false}))
(defn ->logpdf [f]
(fn [& args]
(reify distribution/LogPDF
(logpdf [_ v]
(apply f (concat args [v]))))))

(deftest log-gamma-fn-tests
(testing "log-Gamma ~matches log(factorial)"
Expand All @@ -41,124 +45,37 @@
(Math/sin (* Math/PI z)))))))))

(deftest gamma-tests
(testing "spot checks"
(is (= -6.391804444241573 (ll/gamma 0.001 1 0.4)))
(is (= -393.0922447210179 (ll/gamma 1 0.001 0.4)))))
(dt/gamma-tests (->logpdf ll/gamma)))

(deftest beta-tests
(checking "(log of the) Beta function is symmetrical"
[a (gen-double 0.01 2)
b (gen-double 0.01 2)]
(is (= (ll/log-beta-fn a b)
(ll/log-beta-fn b a))))

(testing "spot checks"
(is (= -6.5026956359820804 (ll/beta 0.001 1 0.4)))
(is (= -6.397440480839912 (ll/beta 1 0.001 0.4)))))
(dt/beta-tests (->logpdf ll/beta)))

(deftest bernoulli-tests
(checking "Bernoulli properties"
[p (gen-double 0 1)
v gen/boolean]
(is (= (ll/bernoulli 0.5 v)
(ll/bernoulli 0.5 (not v)))
"Fair coin has equal chance")

(is (ish? 1.0
(+ (Math/exp (ll/bernoulli p v))
(Math/exp (ll/bernoulli p (not v)))))
"All options sum to 1")))
(dt/bernoulli-tests (->logpdf ll/bernoulli)))

(deftest cauchy-tests
(checking "Cauchy properties"
[scale (gen-double 0.001 100)
v (gen-double -100 100)]
(is (= (ll/cauchy scale 0 v)
(ll/cauchy scale 0 (- v)))
"symmetric about location"))

(testing "spot checks"
(is (= -1.1447298858494002 (ll/cauchy 1 1 1)))
(is (= -1.8378770664093453 (ll/cauchy 2 2 2)))))
(dt/cauchy-tests (->logpdf ll/cauchy)))

(deftest delta-tests
(checking "Delta properties"
[center (gen-double -100 100)
v (gen-double -100 100)]
(if (= center v)
(is (= 0.0 (ll/delta center v)))
(is (= ##-Inf (ll/delta center v))))))
(dt/delta-tests (->logpdf ll/delta)))

(deftest exponential-tests
(dt/exponential-tests (->logpdf ll/exponential))

(checking "exponential will never produce negative values"
[rate (gen-double -100 100)
v (gen-double -100 -0.00001)]
(is (= ##-Inf (ll/exponential rate v))))

(checking "rate 1.0 produces -v"
[v (gen-double 0 100)]
(is (= (- v) (ll/exponential 1.0 v))))

(checking "rate 0.0 produces #-Inf"
[v (gen-double -100 100)]
(is (= ##-Inf (ll/exponential 0.0 v))))

(testing "spot checks"
(is (= -3.3068528194400546 (ll/exponential 2.0 2.0)))
(is (= -5.306852819440055 (ll/exponential 2.0 3.0)))))
(is (= ##-Inf (ll/exponential 0.0 v)))))

(deftest laplace-test
(checking "Laplace properties"
[v (gen-double -10 10)]
(let [log-l (ll/laplace 0 1 v)]
(is (if (neg? v)
(is (= log-l (- v (Math/log 2))))
(is (= log-l (- (- v) (Math/log 2)))))
"location 0, scale 1"))

(is (= (ll/laplace 0 1 v)
(ll/laplace 0 1 (- v)))
"symmetric about location"))

(checking "Laplace with scale 1, location == v"
[v (gen-double -10 10)]
(is (is (= (- (Math/log 2))
(ll/laplace v 1 v)))))

(testing "spot checks"
(is (= -1.6931471805599454 (ll/laplace 2 1 1)))
(is (= -1.8862943611198906 (ll/laplace 0 2 1)))
(is (= 4.214608098422191 (ll/laplace 0 0.001 0.002)))))
(dt/laplace-tests (->logpdf ll/laplace)))

(deftest gaussian-tests
(checking "Gaussian properties"
[mu (gen-double -10 10)
sigma (gen-double 0.001 10)
v (gen-double -100 100)
shift (gen-double -10 10)]
(is (ish? (ll/gaussian 0.0 sigma v)
(ll/gaussian 0.0 sigma (- v)))
"Gaussian is symmetric about the mean")

(is (ish? (ll/gaussian mu sigma v)
(ll/gaussian (+ mu shift) sigma (+ v shift)))
"shifting by the mean is a symmetry"))


(testing "spot checks"
(is (= -1.0439385332046727 (ll/gaussian 0 1 0.5)))
(is (= -1.643335713764618 (ll/gaussian 0 2 0.5)))
(is (= -1.612085713764618 (ll/gaussian 0 2 0)))))
(dt/normal-tests (->logpdf ll/gaussian)))

(deftest uniform-tests
(checking "(log of the) Beta function is symmetrical"
[min (gen-double -10 0)
max (gen-double 0 10)
v (gen-double -10 10)]
(let [log-l (ll/uniform min max v)]
(if (<= min v max)
(is (zeroish?
(+ log-l (Math/log (- max min))))
"Inside the bounds, log-l*range == 1.0")
(is (= ##-Inf log-l)
"Outside the bounds, (log 0.0)")))))
(dt/uniform-tests (->logpdf ll/uniform)))
Loading

0 comments on commit 7636f80

Please sign in to comment.