In [1]:
(ns aide
  (:refer-clojure :exclude [get contains? keys empty? dissoc assoc get-in
                            map replicate apply])
  (:require [clojure.pprint :refer [pprint]]
            [metaprob.syntax :refer [gen]]
            [metaprob.compound :refer :all] ;; get, contains?, keys, empty?, dissoc, assoc, get-in
            [metaprob.builtin :refer :all] ;; map, reduce, replicate, apply
            [metaprob.trace :refer :all]
            [metaprob.autotrace :refer :all]
            [metaprob.prelude :refer :all]
            [metaprob.inference :refer :all]
            [metaprob.intervention :refer :all]
            [metaprob.distributions :refer :all]))

# Importance resampling as a custom generative function

### Inference model of importance resampling

In [2]:
(def importance-resampling-model
  (gen {:tracing-with t} [model inputs observations N]    
    (let [;; Generate N particles of the form [retval trace weight]
          particles
          (map #(t `("particles" ~%) 
                   infer-and-score 
                     :procedure model,
                     :inputs inputs,
                     :observation-trace observations)
               (range N))
        
          ;; Choose one of the particles, according to their weights
          chosen-index
          (t "chosen-index" log-categorical (map #(nth % 2) particles))
          
          ;; Pull out the trace of the chosen particle
          chosen-particle-trace
          (nth (nth particles chosen-index) 1)]
      
      ;; "Act out" the chosen particle's trace, so that its choices are (in theory) constrainable
      (map #(t `("inferred-trace" ~@%) exactly (trace-value chosen-particle-trace %))
        (addresses-of chosen-particle-trace))

      ;; Return the chosen particle's trace
      chosen-particle-trace)))

#'aide/importance-resampling-model
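The `log-categorical` step selects particle $i$ with probability $e^{w_i} / \sum_j e^{w_j}$, where $w_i$ is the particle's log weight. Below is a minimal pure-Clojure sketch of that selection. It assumes the caller passes in a uniform draw `u` in $[0, 1)$, which keeps the function deterministic and testable. `core-logsumexp` and `sketch-resample-index` are hypothetical helpers, not Metaprob functions. Core names are fully qualified because the notebook's namespace excludes `map` and `apply` from `clojure.core`.

```clojure
;; Hypothetical helpers for illustration only; not part of Metaprob.
(defn core-logsumexp
  "Numerically stable log(sum(exp(ws))) using plain clojure.core."
  [ws]
  (let [m (clojure.core/apply max ws)]
    (+ m (Math/log (clojure.core/reduce + (clojure.core/map #(Math/exp (- % m)) ws))))))

(defn sketch-resample-index
  "Return index i with probability exp(w_i) / sum_j exp(w_j),
  given log weights and a uniform draw u in [0, 1)."
  [log-weights u]
  (let [z     (core-logsumexp log-weights)
        probs (vec (clojure.core/map #(Math/exp (- % z)) log-weights))
        n     (count probs)]
    ;; Walk the cumulative distribution until it exceeds u; the last
    ;; index absorbs any floating-point shortfall in the total.
    (loop [i 0, cum 0.0]
      (let [cum (+ cum (double (nth probs i)))]
        (if (or (< u cum) (= i (dec n)))
          i
          (recur (inc i) cum))))))
```

For example, `(sketch-resample-index [(Math/log 1.0) (Math/log 3.0)] (rand))` returns `1` roughly three times out of four.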

### Custom proposal for internal choices of importance resampling, given an observed sample trace
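This proposal is used by the meta-inference algorithm below. It makes its random choices at the same addresses as `importance-resampling-model`, so that the traces it proposes can be scored under the model.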

In [3]:
(def importance-resampling-proposal
  (gen {:tracing-with t} [[model inputs observations N] chosen-particle]
    ;; Choose, uniformly at random, the index at which to place the observed particle.
    (let [chosen-index (t "chosen-index" uniform-sample (range N))]
      
      ;; "Act out" the observed `chosen-particle`, both at ("particles" chosen-index ...) and at "inferred-trace"
      (map #(do (t `("particles" ~chosen-index ~@%) exactly (trace-value chosen-particle %))
                (t `("inferred-trace" ~@%) exactly (trace-value chosen-particle %)))
        (addresses-of chosen-particle))
      
      ;; Generate the other N-1 particles at the other indices
      (map #(t `("particles" ~%) 
              infer-and-score
                :procedure model 
                :inputs inputs 
                :observation-trace observations)
        (filter #(not= % chosen-index) (range N))))))

#'aide/importance-resampling-proposal
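The proposal places the observed particle at an index $k$ drawn uniformly from $\{0, \ldots, N-1\}$, which contributes a factor of $1/N$ to the proposal's density. It then fills the other $N - 1$ slots with fresh particles. The bookkeeping for that placement can be sketched in plain Clojure as follows, treating particles as opaque values. `sketch-place-observed` is a hypothetical name, not part of Metaprob.

```clojure
;; Hypothetical helper for illustration only; not part of Metaprob.
(defn sketch-place-observed
  "Insert the observed particle at index k among the N - 1 freshly
  generated particles, returning a vector of N particles."
  [fresh-particles observed k]
  {:pre [(<= 0 k (count fresh-particles))]}
  (let [fresh (vec fresh-particles)]
    ;; Particles before k, then the observed one, then the rest.
    (vec (concat (subvec fresh 0 k) [observed] (subvec fresh k)))))

;; With k drawn uniformly, e.g. (rand-int 3) for N = 3, every
;; position is equally likely to hold the observed particle.
```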

### Meta-inference algorithm (custom `infer-and-score` for `importance-resampling`)

In [4]:
(def importance-resampling-custom-infer-and-score
  (gen {:tracing-with t} [[model inputs model-observations N] inference-algorithm-constraints]
    
    ;; Check whether the "inferred-trace" address is constrained
    (if (trace-has-subtrace? inference-algorithm-constraints "inferred-trace")
      
      ;; If so, use the proposal
      (let [chosen-particle (trace-subtrace inference-algorithm-constraints "inferred-trace")
            
            ;; Propose a trace of the importance resampling algorithm's internal choices.
            [_ proposed-trace _]
            (t '() infer-and-score :procedure importance-resampling-proposal
                                   :inputs [[model inputs model-observations N] chosen-particle])
            
            ;; Score the proposed trace under the meta-inference algorithm's proposal.
            [_ _ proposal-score]
            (infer-and-score :procedure importance-resampling-proposal 
                             :inputs [[model inputs model-observations N] chosen-particle]
                             :observation-trace proposed-trace)
            
            ;; Score the proposed trace under the inference model.
            [_ _ model-score]
            (infer-and-score :procedure importance-resampling-model 
                             :inputs [model inputs model-observations N]
                             :observation-trace proposed-trace)]
          
        ;; Return the value, trace, and score log p/q
        [chosen-particle
         proposed-trace
         (- model-score proposal-score)])
      
      ;; Otherwise, just use default infer-and-score
      (infer-and-score :procedure importance-resampling-model
                       :inputs [model inputs model-observations N]
                       :observation-trace inference-algorithm-constraints))))

#'aide/importance-resampling-custom-infer-and-score
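The sketch below shows what the returned score `(- model-score proposal-score)` measures, under idealized assumptions:

- particle $i$ has choices $x_i$, proposed with density $q(x_i)$;
- its log weight is $w_i = \log p(x_i, y) - \log q(x_i)$;
- `exactly` choices have probability 1.

Under these assumptions, the inference model's density over its internal choices is $\prod_i q(x_i) \cdot e^{w_k} / \sum_j e^{w_j}$. The proposal's density is $\frac{1}{N} \prod_{i \neq k} q(x_i)$. All factors for $i \neq k$ cancel, leaving

$$\log q(x_k) + \log N + w_k - \log \sum_j e^{w_j} = \log p(x_k, y) - \log \hat{Z}, \qquad \hat{Z} = \frac{1}{N}\sum_j e^{w_j}.$$

A plain-Clojure version of that closed form follows. `sketch-sir-meta-log-weight` is a hypothetical name, not a Metaprob function.

```clojure
;; Hypothetical helpers for illustration only; not part of Metaprob.
(defn core-logsumexp
  "Numerically stable log(sum(exp(ws)))."
  [ws]
  (let [m (clojure.core/apply max ws)]
    (+ m (Math/log (clojure.core/reduce + (clojure.core/map #(Math/exp (- % m)) ws))))))

(defn sketch-sir-meta-log-weight
  "log p/q for importance resampling's internal choices, given log q(x_k)
  for the chosen particle, all N log weights, and the chosen index k."
  [log-q-chosen log-weights k]
  (+ log-q-chosen
     (Math/log (double (count log-weights)))   ; log N
     (nth log-weights k)                       ; w_k
     (- (core-logsumexp log-weights))))        ; -log sum_j exp(w_j)
```

With one particle the expression reduces to $\log q(x_1)$, since resampling a single particle is just sampling from $q$.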

### Custom generative function implementing importance resampling

In [5]:
;; The final importance-resampling inference algorithm for use with AIDE;
;; `inf` pairs the inference model with its custom infer-and-score.
(def importance-resampling-aide
  (inf importance-resampling-model importance-resampling-custom-infer-and-score))

#'aide/importance-resampling-aide

# An example problem with gold-standard and approximate inference algorithms

We consider the following model: $p \sim \text{Beta}(1, 1), x \sim \text{Bernoulli}(p)$.

In [6]:
(def simple-model (gen {:tracing-with t} [] (t "x" flip (t "p" beta 1 1))))

#'aide/simple-model

Suppose we have observed $x = \text{true}$. Then the posterior is $p \mid x = \text{true} \sim \text{Beta}(2, 1)$, so we can write an exact gold-standard inference algorithm:

In [7]:
(def simple-gold-standard
  ;; An inference model returns an inferred trace, and also
  ;; acts out the inferred trace.
  (gen {:tracing-with t} [] 
    {"p" {:value (t `("inferred-trace" "p") beta 2 1)}}))

#'aide/simple-gold-standard
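The Beta(2, 1) posterior used by `simple-gold-standard` follows from Beta-Bernoulli conjugacy. The prior density is 1 on $[0, 1]$, and the likelihood of $x = \text{true}$ is $p$, so

$$
p(p \mid x = \text{true}) \;\propto\; \underbrace{1}_{\text{Beta}(p;\,1,\,1)} \cdot \underbrace{p}_{P(x = \text{true} \mid p)} \;=\; p^{2-1}(1-p)^{1-1} \;\propto\; \text{Beta}(p;\, 2, 1).
$$

Normalizing gives the density $2p$ on $[0, 1]$, with posterior mean $2/3$.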

Next, we write a helper that builds target inference algorithms running importance resampling with a given number of particles:

In [8]:
(defn make-simple-target-inference-model [n-particles]
  (gen {:tracing-with t} []
    (t '() importance-resampling-aide simple-model [] {"x" {:value true}} n-particles)))

#'aide/make-simple-target-inference-model

We use it to construct target algorithms with 1, 2, 3, 5, 10, and 15 particles:

In [9]:
(def infer-with-1-particle (make-simple-target-inference-model 1))
(def infer-with-2-particles (make-simple-target-inference-model 2))
(def infer-with-3-particles (make-simple-target-inference-model 3))
(def infer-with-5-particles (make-simple-target-inference-model 5))
(def infer-with-10-particles (make-simple-target-inference-model 10))
(def infer-with-15-particles (make-simple-target-inference-model 15))

#'aide/infer-with-15-particles
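For this model, an importance-resampling run can be sketched without Metaprob. The sketch assumes, as is standard for likelihood weighting but not confirmed by this notebook, that `infer-and-score` proposes $p$ from its Beta(1, 1) = Uniform(0, 1) prior and weights it by the likelihood $P(x = \text{true} \mid p) = p$. It draws `n-particles` such values, picks one with probability proportional to its weight, and averages the picks over many runs. `sketch-sir-simple` and `sketch-mean-chosen-p` are hypothetical names. A seeded `java.util.Random` keeps the results reproducible.

```clojure
;; Hypothetical helpers for illustration only; not part of Metaprob.
;; Assumes particles come from the prior and are weighted by the likelihood.
(defn sketch-sir-simple
  "One importance-resampling run for p ~ Beta(1, 1), x ~ Bernoulli(p)
  with x = true observed. Returns the chosen particle's p."
  [^java.util.Random rng n-particles]
  (let [ps    (vec (repeatedly n-particles #(.nextDouble rng))) ; prior draws p_i
        total (clojure.core/reduce + ps)                        ; sum of weights w_i = p_i
        u     (* total (.nextDouble rng))]
    ;; Pick index i with probability p_i / total via the cumulative sum.
    (loop [i 0, cum 0.0]
      (let [cum (+ cum (double (nth ps i)))]
        (if (or (< u cum) (= i (dec n-particles)))
          (nth ps i)
          (recur (inc i) cum))))))

(defn sketch-mean-chosen-p
  "Average chosen p over n-runs independent runs, seeded for reproducibility."
  [seed n-particles n-runs]
  (let [rng (java.util.Random. (long seed))]
    (/ (clojure.core/reduce + (repeatedly n-runs #(sketch-sir-simple rng n-particles)))
       n-runs)))
```

With one particle nothing is reweighted, so `(sketch-mean-chosen-p 42 1 20000)` should be close to the prior mean 1/2. With 50 particles it should be close to the posterior mean 2/3. Both results are consistent with the target algorithms approaching `simple-gold-standard` as the particle count grows.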

# Implementation of AIDE

In [10]:
;; Helpers
(def logsumexp
  (gen [weights]
    ;; Shift by the max weight before exponentiating to avoid overflow.
    (let [max-weight (apply max weights)
          shifted-weights (map #(- % max-weight) weights)
          exp-weights (map exp shifted-weights)]
      (+ (log (apply + exp-weights)) max-weight))))
;; Log of the mean of exp(weights).
(defn logmeanexp [weights] (- (logsumexp weights) (log (count weights))))
(defn avg [xs] (/ (reduce + xs) (count xs)))

#'aide/avg
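The max shift in `logsumexp` matters because log weights can be large in magnitude, and exponentiating them directly overflows. The pure-Clojure sketch below contrasts the naive and shifted computations. The `core-` and `naive-` names are hypothetical, chosen so the sketch does not redefine the notebook's own `logsumexp` and `logmeanexp`.

```clojure
;; Hypothetical helpers for illustration only; the core- and naive- prefixes
;; keep them from redefining the notebook's own logsumexp and logmeanexp.
(defn naive-logsumexp
  "log(sum(exp(ws))) without the max shift; exp overflows for large ws."
  [ws]
  (Math/log (clojure.core/reduce + (clojure.core/map #(Math/exp %) ws))))

(defn core-logsumexp
  "Same quantity, shifted by the max weight as in the notebook's logsumexp."
  [ws]
  (let [m (clojure.core/apply max ws)]
    (+ m (Math/log (clojure.core/reduce + (clojure.core/map #(Math/exp (- % m)) ws))))))

(defn core-logmeanexp
  "Log of the mean of exp(ws)."
  [ws]
  (- (core-logsumexp ws) (Math/log (double (count ws)))))
```

`(naive-logsumexp [1000.0 1000.0])` returns positive infinity, whereas `(core-logsumexp [1000.0 1000.0])` returns $1000 + \log 2 \approx 1000.693$. Like the notebook's version, the shifted form still returns NaN if every weight is negative infinity, because the shift computes $-\infty - (-\infty)$.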

In [11]:
;; AIDE
(defn aide [gold-standard target-algorithm inference-addresses Ng Mg Nt Mt]
  (let [gold-standard-traces
        (map #(partition-trace % inference-addresses) 
             (replicate Ng #(nth (infer-and-score :procedure gold-standard) 1)))
        
        target-algorithm-traces
        (map #(partition-trace % inference-addresses)
             (replicate Nt #(nth (infer-and-score :procedure target-algorithm) 1)))
                
        gold-standard-scores-on-gold-standard-samples
        (map (gen [[x u]] (logmeanexp 
                            (cons 
                              ;; First score is special: reuse randomness from initial draw
                              (let [[[_ _ s] _ _]
                                    (infer-and-score
                                      :procedure infer-and-score
                                      :inputs [:procedure gold-standard, :observation-trace x]
                                      :observation-trace u)] s)
                              ;; The remaining Mg - 1 scores are ordinary `infer-and-score` scores:
                              (replicate (- Mg 1) #(nth (infer-and-score 
                                                         :procedure gold-standard,
                                                         :observation-trace x) 2)))))
             gold-standard-traces)
                
        target-algorithm-scores-on-gold-standard-samples
        (map (gen [[x u]] (logmeanexp 
                            (replicate Mt #(nth (infer-and-score 
                                                  :procedure target-algorithm, 
                                                  :observation-trace x) 2))))
             gold-standard-traces)
        
        gold-standard-scores-on-target-algorithm-samples
        (map (gen [[x v]] (logmeanexp
                            (replicate Mg #(nth (infer-and-score 
                                                  :procedure gold-standard, 
                                                  :observation-trace x) 2))))
             target-algorithm-traces)
        
        target-algorithm-scores-on-target-algorithm-samples
        (map (gen [[x v]] (logmeanexp 
                            (cons
                              (let [[[_ _ s] _ _]
                                    (infer-and-score
                                      :procedure infer-and-score
                                      :inputs [:procedure target-algorithm, :observation-trace x]
                                      :observation-trace v)] s)
                              (replicate (- Mt 1) #(nth (infer-and-score 
                                                          :procedure target-algorithm, 
                                                          :observation-trace x) 2)))))
             target-algorithm-traces)]
        
    ;; Use clojure.core/map, which accepts several sequences and applies a
    ;; function (here `-`) elementwise: (- l0 m0), (- l1 m1), and so on.
    (+ (avg (clojure.core/map -
                              gold-standard-scores-on-gold-standard-samples
                              target-algorithm-scores-on-gold-standard-samples))
       (avg (clojure.core/map -
                              target-algorithm-scores-on-target-algorithm-samples
                              gold-standard-scores-on-target-algorithm-samples)))))

#'aide/aide
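Written out in notation introduced here, `aide` computes the following:

- $x_i$ ($i = 1, \ldots, N_g$) is the inference-address part of the gold-standard traces.
- $x'_j$ ($j = 1, \ldots, N_t$) is the corresponding part of the target traces.
- For an algorithm $a$ with $M$ meta-inference runs, $s_{a,1}(x), \ldots, s_{a,M}(x)$ are the scores returned by `infer-and-score` with $x$ observed.
- $\hat{\ell}_a(x) = \log \frac{1}{M} \sum_{m=1}^{M} e^{s_{a,m}(x)}$ is what `logmeanexp` computes.

Then

$$
\widehat{D} = \frac{1}{N_g} \sum_{i=1}^{N_g} \left[ \hat{\ell}_g(x_i) - \hat{\ell}_t(x_i) \right] + \frac{1}{N_t} \sum_{j=1}^{N_t} \left[ \hat{\ell}_t(x'_j) - \hat{\ell}_g(x'_j) \right],
$$

where the gold standard $g$ uses $M = M_g$ and the target $t$ uses $M = M_t$. When $x$ was produced by the same algorithm that is scoring it, the first of the $M$ scores reuses that sample's own internal choices ($u_i$ or $v_j$) instead of a fresh run, as in the two `cons` branches.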

# Applying AIDE to our problem

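We can now apply AIDE to measure the quality of our approximate inference algorithms. In each call below, the numeric arguments are `Ng = 500`, `Mg = 1`, `Nt = 500`, `Mt = 100`. A single meta-inference run suffices for the gold standard because its only random choice is the inferred `p`. AIDE's output is a Monte Carlo estimate whose expected value is an upper bound on the symmetric KL divergence between the gold-standard and target algorithms. Because the estimate has variance, not every run of the following cells will show monotonically decreasing values. An estimate can even come out slightly negative, as the 15-particle result does.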

In [12]:
(aide simple-gold-standard infer-with-1-particle '(("inferred-trace" "p")) 500 1, 500 100)

0.5210597814290557
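For the one-particle case, the symmetric KL divergence that AIDE's estimate is meant to bound can be computed exactly. With one particle there is nothing to resample, so `infer-with-1-particle` returns a draw from the prior: Uniform(0, 1), with density $t(p) = 1$. The gold standard has density $g(p) = 2p$. Their symmetric KL divergence is $\int_0^1 (g - t)\log(g/t)\,dp = \int_0^1 (2p - 1)\log(2p)\,dp = (\log 2 - \tfrac{1}{2}) + (1 - \log 2) = \tfrac{1}{2}$, close to the 0.52 estimated above. Below is a midpoint-rule check in plain Clojure. `sketch-symmetric-kl` is a hypothetical name.

```clojure
;; Hypothetical helper for illustration only; not part of Metaprob.
(defn sketch-symmetric-kl
  "Midpoint-rule estimate of KL(g||t) + KL(t||g) on [0, 1] for
  g(p) = 2p (Beta(2, 1)) and t(p) = 1 (Uniform(0, 1)), i.e. the
  integral of (g - t) * log(g / t)."
  [n-steps]
  (let [h (/ 1.0 n-steps)]
    (* h (clojure.core/reduce
           +
           (clojure.core/map (fn [i]
                               (let [p (* h (+ i 0.5))]   ; midpoint of cell i
                                 (* (- (* 2.0 p) 1.0) (Math/log (* 2.0 p)))))
                             (range n-steps))))))
```

`(sketch-symmetric-kl 100000)` agrees closely with the closed-form value 1/2. The log singularity at $p = 0$ is integrable, so the midpoint rule converges despite it.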

In [13]:
(aide simple-gold-standard infer-with-2-particles '(("inferred-trace" "p")) 500 1, 500 100)

0.08832242662802453

In [14]:
(aide simple-gold-standard infer-with-3-particles '(("inferred-trace" "p")) 500 1, 500 100)

0.02622685262327079

In [15]:
(aide simple-gold-standard infer-with-5-particles '(("inferred-trace" "p")) 500 1, 500 100)

0.0060302650211156205

In [16]:
(aide simple-gold-standard infer-with-10-particles '(("inferred-trace" "p")) 500 1, 500 100)

0.004549526534672173

In [17]:
(aide simple-gold-standard infer-with-15-particles '(("inferred-trace" "p")) 500 1, 500 100)

-2.4912506566190625E-4