Skip to content

Commit

Permalink
feat: extract trace, splice
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Sep 6, 2023
1 parent 4408621 commit 39807e8
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 127 deletions.
2 changes: 1 addition & 1 deletion examples/introduction.clj
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
;; then these values are sufficient to answer any question using executions of
;; the function, because all states in the execution of the function are
;; deterministic given the random choices. We will call the record of all the
;; random choies a **trace**. In order to store all the random choices in the
;; random choices a **trace**. In order to store all the random choices in the
;; trace, we need to come up with a unique name or **address** for each random
;; choice.

Expand Down
102 changes: 48 additions & 54 deletions src/gen/dynamic.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,48 @@
#?(:cljs
(:require-macros [gen.dynamic])))

(defrecord GenerateMap [constraints trace weight]
dynamic.trace/ITrace
(-splice [state gf args]
(let [{subtrace :trace

Check warning on line 15 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L15

Added line #L15 was not covered by tests
weight :weight}
(gf/generate gf args constraints)]
[(-> state
(update :trace dynamic.trace/merge-subtraces subtrace)
(update :weight + weight))
(trace/retval subtrace)]))

Check warning on line 21 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L17-L21

Added lines #L17 - L21 were not covered by tests

(-trace [state k gf args]
(dynamic.trace/validate-empty! trace k)
(let [{subtrace :trace :as ret}
(if-let [k-constraints (get (choice-map/submaps constraints) k)]
(gf/generate gf args k-constraints)
(gf/generate gf args))]
[(dynamic.trace/combine state k ret)
(trace/retval subtrace)])))

Check warning on line 30 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L24-L30

Added lines #L24 - L30 were not covered by tests

(defrecord DynamicDSLFunction [clojure-fn]
gf/Simulate
(simulate [gf args]
(let [trace (atom (dynamic.trace/trace gf args))]
(binding [dynamic.trace/*splice*
(fn [gf args]
(let [subtrace (gf/simulate gf args)]
(swap! trace dynamic.trace/merge-subtraces subtrace)
(trace/retval subtrace)))

dynamic.trace/*trace*
(fn [k gf args]
(dynamic.trace/validate-empty! @trace k)
(let [subtrace (gf/simulate gf args)]
(swap! trace dynamic.trace/assoc-subtrace k subtrace)
(trace/retval subtrace)))]
(let [retval (apply clojure-fn args)]
(swap! trace dynamic.trace/with-retval retval)
@trace))))
(let [!trace (atom (dynamic.trace/trace gf args))
retval (binding [dynamic.trace/*active* !trace]
(apply clojure-fn args))
trace @!trace]
(dynamic.trace/with-retval trace retval)))

gf/Generate
(generate [gf args]
(let [trace (gf/simulate gf args)]
{:trace trace :weight (math/log 1)}))
(generate [gf args constraints]
(let [state (atom {:trace (dynamic.trace/trace gf args)
:weight 0})]
(binding [dynamic.trace/*splice*
(fn [gf args]
(let [{subtrace :trace
weight :weight}
(gf/generate gf args constraints)]
(swap! state update :trace dynamic.trace/merge-subtraces subtrace)
(swap! state update :weight + weight)
(trace/retval subtrace)))

dynamic.trace/*trace*
(fn [k gf args]
(dynamic.trace/validate-empty! (:trace @state) k)
(let [{subtrace :trace :as ret}
(if-let [k-constraints (get (choice-map/submaps constraints) k)]
(gf/generate gf args k-constraints)
(gf/generate gf args))]
(swap! state dynamic.trace/combine k ret)
(trace/retval subtrace)))]
(let [retval (apply clojure-fn args)
trace (:trace @state)]
{:trace (dynamic.trace/with-retval trace retval)
:weight (:weight @state)}))))
(let [!state (atom (->GenerateMap
constraints
(dynamic.trace/trace gf args)

Check warning on line 48 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L46-L48

Added lines #L46 - L48 were not covered by tests
0))
retval (binding [dynamic.trace/*active* !state]
(apply clojure-fn args))
state @!state]
(update state :trace dynamic.trace/with-retval retval)))

Check warning on line 53 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L50-L53

Added lines #L50 - L53 were not covered by tests

#?@(:clj
[clojure.lang.IFn
Expand Down Expand Up @@ -154,19 +147,20 @@
`(->DynamicDSLFunction
(fn ~@(when name [name])
~params
~@(walk/postwalk (fn [form]
(cond (trace-form? form)
(if-not (valid-trace-form? form)
(throw (ex-info "Malformed trace expression." {:form form}))
(let [[addr [gf & args]] (rest form)]
`((dynamic.trace/active-trace) ~addr ~gf ~(vec args))))
~@(walk/postwalk
(fn [form]
(cond (trace-form? form)
(if-not (valid-trace-form? form)
(throw (ex-info "Malformed trace expression." {:form form}))

Check warning on line 154 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L154

Added line #L154 was not covered by tests
(let [[addr [gf & args]] (rest form)]
`(dynamic.trace/trace! ~addr ~gf ~(vec args))))

(splice-form? form)
(if-not (valid-splice-form? form)
(throw (ex-info "Malformed splice expression." {:form form}))
(let [[[gf & args]] (rest form)]
`((dynamic.trace/active-splice) ~gf ~(vec args))))
(splice-form? form)
(if-not (valid-splice-form? form)
(throw (ex-info "Malformed splice expression." {:form form}))

Check warning on line 160 in src/gen/dynamic.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic.cljc#L160

Added line #L160 was not covered by tests
(let [[[gf & args]] (rest form)]
`(dynamic.trace/splice! ~gf ~(vec args))))

:else
form))
body)))))
:else
form))
body)))))
143 changes: 78 additions & 65 deletions src/gen/dynamic/trace.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,39 @@
(:import
(clojure.lang Associative IFn IObj IMapIterable Seqable))))

(defn no-op
([gf args]
(apply gf args))
([_k gf args]
(apply gf args)))

(def ^:dynamic *trace*
"Applies the generative function gf to args. Dynamically rebound by functions
like `gf/simulate`, `gf/generate`, `trace/update`, etc."
no-op)

(def ^:dynamic *splice*
"Applies the generative function gf to args. Dynamically rebound by functions
like `gf/simulate`, `gf/generate`, `trace/update`, etc."
no-op)

(defn active-trace
"Returns the currently-active tracing function, bound to [[*trace*]].
NOTE: Prefer `([[active-trace]])` to `[[*trace*]]`, as direct access to
`[[*trace*]]` won't reflect new bindings when accessed inside of an SCI
environment."
[] *trace*)

(defn active-splice
"Returns the currently-active tracing function, bound to [[*splice*]].
NOTE: Prefer `([[active-splice]])` to `[[*splice*]]`, as direct access to
`[[*splice*]]` won't reflect new bindings when accessed inside of an SCI
environment."
[]
*splice*)
(defprotocol ITrace
(-splice [this gf args])
(-trace [this addr gf args]))

(defrecord NoOp []
ITrace
(-splice [this gf args]
[this (apply gf args)])
(-trace [this _k gf args]
[this (apply gf args)]))

(def no-op (NoOp.))

(def ^:dynamic *active* (atom no-op))

(defn active [] *active*)

(defn splice! [gf args]
(let [[new-state ret] (-splice @*active* gf args)]
(swap! *active* (fn [_] new-state))
ret))

(defn trace! [k gf args]
(let [[new-state ret] (-trace @*active* k gf args)]
(swap! *active* (fn [_] new-state))
ret))

(defmacro without-tracing
[& body]
`(binding [*trace* no-op
*splice* no-op]
`(binding [*active* (atom no-op)]
~@body))

(declare assoc-subtrace update-trace trace =)
(declare assoc-subtrace merge-subtraces update-trace validate-empty! trace =)

(deftype Trace [gf args subtraces retval]
trace/Args
Expand Down Expand Up @@ -79,6 +72,18 @@
(update [this constraints]
(update-trace this constraints))

ITrace
(-splice [this gf args]
(let [subtrace (gf/simulate gf args)]
[(merge-subtraces this subtrace)
(trace/retval subtrace)]))

(-trace [this k gf args]
(validate-empty! this k)
(let [subtrace (gf/simulate gf args)]
[(assoc-subtrace this k subtrace)
(trace/retval subtrace)]))

#?@(:cljs
[Object
(equiv [this other] (-equiv this other))
Expand Down Expand Up @@ -193,9 +198,9 @@
[^Trace t addr subt]
(validate-empty! t addr)
(->Trace (.-gf t)

Check failure on line 200 in src/gen/dynamic/trace.cljc

View workflow job for this annotation

GitHub Actions / lint-files

gen.dynamic.trace/->Trace is called with 4 args but expects 5
(.-args t)
(assoc (.-subtraces t) addr subt)
(.-retval t)))
(.-args t)
(assoc (.-subtraces t) addr subt)
(.-retval t)))

(defn merge-subtraces
[^Trace t1 ^Trace t2]
Expand All @@ -211,34 +216,42 @@
(update :weight + weight)
(cond-> discard (update :discard assoc k discard))))

;; TODO: this does NOT feel like the right data structure. In fact I think
;; updates should be able to shuffle over the unused stuff from update to
;; update, instead of having to do that final update at the very end.
;;
;; Then each update step could shuffling from the constraints over to the end.
(defrecord UpdateMap [this constraints trace weight discard]
ITrace
(-splice [_ _ _]
(throw (ex-info "Not yet implemented." {})))

Check warning on line 227 in src/gen/dynamic/trace.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic/trace.cljc#L227

Added line #L227 was not covered by tests

(-trace [state k gf args]
(validate-empty! trace k)
(let [k-constraints (get (choice-map/submaps constraints) k)
{subtrace :trace :as ret}
(if-let [prev-subtrace (get (.-subtraces ^Trace this) k)]
(trace/update prev-subtrace k-constraints)
(gf/generate gf args k-constraints))]

Check warning on line 235 in src/gen/dynamic/trace.cljc

View check run for this annotation

Codecov / codecov/patch

src/gen/dynamic/trace.cljc#L235

Added line #L235 was not covered by tests
[(combine state k ret)
(trace/retval subtrace)])))

(defn update-trace [this constraints]
(let [gf (trace/gf this)
state (atom {:trace (trace gf (trace/args this))
:weight 0
:discard (cm/choice-map)})]
(binding [*splice*
(fn [& _]
(throw (ex-info "Not yet implemented." {})))

*trace*
(fn [k gf args]
(validate-empty! (:trace @state) k)
(let [k-constraints (get (choice-map/submaps constraints) k)
{subtrace :trace :as ret}
(if-let [prev-subtrace (get (.-subtraces this) k)]
(trace/update prev-subtrace k-constraints)
(gf/generate gf args k-constraints))]
(swap! state combine k ret)
(trace/retval subtrace)))]
(let [retval (apply (:clojure-fn gf) (trace/args this))
{:keys [trace weight discard]} @state
unvisited (apply dissoc
(trace/choices this)
(keys (trace/choices trace)))]

{:trace (with-retval trace retval)
:weight weight
:discard (merge discard unvisited)}))))
(let [gf (trace/gf this)
!state (atom (->UpdateMap
this constraints
(trace gf (trace/args this))
0
(cm/choice-map)))
retval (binding [*active* !state]
(apply (:clojure-fn gf) (trace/args this)))
{:keys [trace weight discard]} @!state
unvisited (apply dissoc
(trace/choices this)
(keys (trace/choices trace)))]
{:trace (with-retval trace retval)
:weight weight
:discard (merge discard unvisited)}))

;; ## Primitive Trace
;;
Expand Down
10 changes: 3 additions & 7 deletions test/gen/dynamic/trace_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@

(deftest binding-tests
(letfn [(f [_] "hi!")]
(binding [dynamic.trace/*trace* f
dynamic.trace/*splice* f]
(is (= f (dynamic.trace/active-trace))
"active-trace reflects dynamic bindings")

(is (= f (dynamic.trace/active-splice))
"active-splice reflects dynamic bindings"))))
(binding [dynamic.trace/*active* f]
(is (= f (dynamic.trace/active))
"active reflects dynamic bindings"))))

(defn choice-trace
[x]
Expand Down

0 comments on commit 39807e8

Please sign in to comment.