In [1]:
(ns metaprob.alex
  (:refer-clojure :only [nil? defn let frequencies pos? for -> group-by last ->> gensym])
  (:require
    [clojure.repl :refer :all]
    [metaprob.trace :as trace]
    [metaprob.builtin-impl :as impl]
    [metaprob.syntax :refer :all]
    [metaprob.builtin :refer :all]
    [metaprob.prelude :refer :all]
    [metaprob.context :refer :all]
    [metaprob.distributions :refer :all]
    [metaprob.interpreters :refer :all]
    [metaprob.inference :refer :all]
    [metaprob.compositional :as comp]
    [metaprob.examples.gaussian :refer :all]
    [taoensso.tufte :as tufte :refer (defnp p profiled profile)]
    [taoensso.tufte.stats :refer (stats)]
    [taoensso.tufte.impl :refer (pdata-proxy-get)]
    [metaprob.tutorial.jupyter :refer :all]))

In [9]:
;; pow: raise `a` to the non-negative integer power `b` by repeated
;; multiplication.
;; The fold starts from the multiplicative identity 1 and multiplies by `a`
;; once per element of (range b), i.e. exactly `b` times. This gives
;; (pow a 0) => 1; seeding with `a` would wrongly return `a` for b = 0.
;; Negative or non-integer `b` is not supported (a negative `b` yields 1).
(define pow (gen [a b] (reduce (gen [acc _] (* a acc)) 1 (range b))))

#'metaprob.alex/pow

In [63]:
;; athlete-model: a small generative model linking an athlete's skill to
;; whether they hold a sponsorship contract and whether they are wealthy.
;; Takes no arguments; returns [skill has-sponsorship-contract? is-wealthy?].
;; NOTE: keep the `define` names and their order unchanged. The trace
;; addresses reported by addresses-of in this notebook are
;; (0 "skill" "uniform"), (1 "has-sponsorship-contract?" "flip") and
;; (2 "is-wealthy?" "flip"). They include the statement index and the define
;; name, and those address literals are hard-coded elsewhere in the notebook.
(define athlete-model
  (gen []
    ;; Skill drawn from (uniform 0 1).
    (define skill (uniform 0 1))
    ;; A contract happens with probability skill^8, so it is unlikely unless
    ;; skill is close to 1.
    (define has-sponsorship-contract? 
        (flip (pow skill 8)))
    ;; Wealth depends only on the contract: probability 0.8 with one,
    ;; 0.1 without.
    (define is-wealthy?
      (flip
        (if has-sponsorship-contract?
          0.8
          0.1)))
    [skill has-sponsorship-contract? is-wealthy?]))

#'metaprob.alex/athlete-model

In [74]:
;; Run the model once with nothing constrained; yields a
;; [skill has-sponsorship-contract? is-wealthy?] vector.
(athlete-model)

[0.9750075836495604 true true]

In [76]:
;; List the addresses of every random choice in one run of athlete-model.
;; Element 1 of infer's result is the execution trace (infer results are
;; destructured as [value trace score] elsewhere in this notebook).
(-> ((infer :procedure athlete-model) 1)
    addresses-of)

((0 "skill" "uniform") (1 "has-sponsorship-contract?" "flip") (2 "is-wealthy?" "flip"))

In [78]:
;; A trace that sets the sponsorship-contract flip, at address
;; (1 "has-sponsorship-contract?" "flip"), to true. It is passed as the
;; intervention trace when sampling wealthy-samples-intervened.
(define has-contract-trace
  (-> {}
      (trace-set-value '(1 "has-sponsorship-contract?" "flip") true)))

#'metaprob.alex/has-contract-trace

In [124]:
;; A trace that sets the wealth flip, at address (2 "is-wealthy?" "flip"),
;; to true. It is passed as the target (observation) trace when sampling.
(define is-wealthy-trace
  (-> {}
      (trace-set-value '(2 "is-wealthy?" "flip") true)))

#'metaprob.alex/is-wealthy-trace

In [120]:
;; Print the Metaprob source of the library's importance-resampling procedure.
;; It served as the template for importance-resample in this notebook.
;; clojure.pprint is not in this namespace's :require list, so load it
;; explicitly before use. `require` is excluded by :refer-clojure :only,
;; hence the fully qualified call.
(clojure.core/require 'clojure.pprint)
(clojure.pprint/pprint (get importance-resampling :generative-source))

(gen
 [model-procedure inputs target-trace N]
 (define
  particles
  (replicate
   N
   (gen
    []
    (define
     [_ candidate-trace score]
     (infer
      :procedure
      model-procedure
      :inputs
      inputs
      :target-trace
      target-trace))
    [candidate-trace score])))
 (define scores (map second particles))
 (define which (log-categorical scores))
 (define particle (nth particles which))
 (first particle))


In [122]:
;; importance-resample: importance resampling with an intervention trace.
;; This copies the library importance-resampling source and adds one
;; argument, intervention-trace, which is forwarded to `infer`.
;;   model-procedure    - the generative procedure to run
;;   inputs             - argument vector for model-procedure
;;   target-trace       - observed choices each particle is scored against
;;   intervention-trace - choices forced during each run (presumably not
;;                        scored -- confirm against infer's docs); pass {}
;;                        for no intervention
;;   N                  - number of particles
;; Returns the trace of one particle, chosen by log-categorical over the
;; particle scores.
;; NOTE: keep statement order and define names. They fix the order of random
;; choices (N infer runs, then log-categorical) and the trace addresses.
(define importance-resample
    (gen
     [model-procedure inputs target-trace intervention-trace N]
     ;; Run N particles; each keeps the trace and score produced by infer.
     (define
      particles
      (replicate
       N
       (gen
        []
        (define
         [_ candidate-trace score]
         (infer
          :procedure
          model-procedure
          :inputs
          inputs
          :target-trace
          target-trace
          :intervention-trace
          intervention-trace))
        [candidate-trace score])))
     (define scores (map second particles))
     ;; Pick a particle index; log-categorical presumably treats scores as
     ;; log-weights -- confirm.
     (define which (log-categorical scores))
     (define particle (nth particles which))
     ;; particle is [candidate-trace score]; return just the trace.
     (first particle)))

#'metaprob.alex/importance-resample

In [151]:
;; 100 traces of athlete-model, each from an importance-resample run with
;; 100 particles. Each run is scored against is-wealthy? = true, with
;; has-sponsorship-contract? forced to true by intervention.
(define wealthy-samples-intervened
  (->> (gen []
         (importance-resample athlete-model [] is-wealthy-trace has-contract-trace 100))
       (replicate 100)))

#'metaprob.alex/wealthy-samples-intervened

In [158]:
;; 100 traces of athlete-model conditioned on is-wealthy? = true, with no
;; intervention ({}). Each trace comes from an importance-resample run.
;; Uses 100 particles, the same as wealthy-samples-intervened. The two sets
;; are compared directly (skill sums), so they should differ only in the
;; intervention, not in importance-resampling approximation quality.
(define wealthy-samples
  (replicate 100
    (gen []
        (importance-resample
          athlete-model
          []
          is-wealthy-trace
          {}
          100))))

#'metaprob.alex/wealthy-samples

In [138]:
;; Read the sampled skill value from an athlete-model trace.
(define extract-skill
  (gen [sample-trace]
    (trace-value sample-trace '(0 "skill" "uniform"))))

#'metaprob.alex/extract-skill

In [145]:
;; Read the sampled has-sponsorship-contract? flip from an athlete-model trace.
(define extract-contract
  (gen [sample-trace]
    (trace-value sample-trace '(1 "has-sponsorship-contract?" "flip"))))

#'metaprob.alex/extract-contract

In [148]:
;; Count the wealthy samples whose contract flip is true.
;; NOTE(review): the execution counts show this cell (In [148]) ran before
;; wealthy-samples was redefined (In [158]), so its output reflects an
;; earlier definition.
(->> wealthy-samples
     (map extract-contract)
     (clojure.core/filter clojure.core/identity)
     count)

50

In [159]:
;; Total skill across the conditioned samples (divide by the sample count for
;; the mean).
(->> wealthy-samples
     (map extract-skill)
     (apply +))

66.5121461144975

In [153]:
;; Total skill across the intervened samples, for comparison with the
;; conditioned total.
(->> wealthy-samples-intervened
     (map extract-skill)
     (apply +))

52.286958823996166