Skip to content

Commit

Permalink
refactor: move update to trace impl
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Aug 31, 2023
1 parent 7ab25ce commit 4408621
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 65 deletions.
65 changes: 8 additions & 57 deletions src/gen/dynamic.cljc
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
(ns gen.dynamic
(:require [clojure.math :as math]
[clojure.set :as set]
[clojure.walk :as walk]
[gen]
[gen.choice-map :as choice-map]
[gen.dynamic.choice-map :as dynamic.choice-map]
[gen.dynamic.trace :as dynamic.trace #?@(:cljs [:refer [Trace]])]
[gen.dynamic.trace :as dynamic.trace]
[gen.generative-function :as gf]
[gen.trace :as trace])
#?(:cljs
(:require-macros [gen.dynamic]))
#?(:clj
(:import (gen.dynamic.trace Trace))))
(:require-macros [gen.dynamic])))

(defrecord DynamicDSLFunction [clojure-fn]
gf/Simulate
Expand All @@ -25,6 +21,7 @@

dynamic.trace/*trace*
(fn [k gf args]
(dynamic.trace/validate-empty! @trace k)
(let [subtrace (gf/simulate gf args)]
(swap! trace dynamic.trace/assoc-subtrace k subtrace)
(trace/retval subtrace)))]
Expand All @@ -50,14 +47,12 @@

dynamic.trace/*trace*
(fn [k gf args]
(let [{subtrace :trace
weight :weight}
(if-let [constraints (get (choice-map/submaps constraints)
k)]
(gf/generate gf args constraints)
(dynamic.trace/validate-empty! (:trace @state) k)
(let [{subtrace :trace :as ret}
(if-let [k-constraints (get (choice-map/submaps constraints) k)]
(gf/generate gf args k-constraints)
(gf/generate gf args))]
(swap! state update :trace dynamic.trace/assoc-subtrace k subtrace)
(swap! state update :weight + weight)
(swap! state dynamic.trace/combine k ret)
(trace/retval subtrace)))]
(let [retval (apply clojure-fn args)
trace (:trace @state)]
Expand Down Expand Up @@ -115,50 +110,6 @@
(-invoke [_ arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20] (dynamic.trace/without-tracing (clojure-fn arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20)))
(-invoke [_ arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20 args] (apply clojure-fn arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20 args))]))

(extend-type Trace
trace/Update
(update [prev-trace constraints]
(let [^DynamicDSLFunction gf (trace/gf prev-trace)
state (atom {:trace (dynamic.trace/trace gf (trace/args prev-trace))
:weight 0
:discard (dynamic.choice-map/choice-map)})]
(binding [dynamic.trace/*splice*
(fn [& _]
(throw (ex-info "Not yet implemented." {})))

dynamic.trace/*trace*
(fn [k gf args]
(let [{subtrace :trace
weight :weight
discard :discard}
(if-let [prev-subtrace (get (.-subtraces prev-trace) k)]
(let [{new-subtrace :trace
new-weight :weight
discard :discard}
(trace/update prev-subtrace
(get (choice-map/submaps constraints)
k))]
{:trace new-subtrace
:weight new-weight
:discard discard})
(gf/generate gf args (get (choice-map/submaps constraints)
k)))]
(swap! state update :trace dynamic.trace/assoc-subtrace k subtrace)
(swap! state update :weight + weight)
(when discard
(swap! state update :discard assoc k discard))
(trace/retval subtrace)))]
(let [retval (apply (.-clojure-fn gf)
(trace/args prev-trace))
{:keys [trace weight discard]} @state
unvisited (select-keys (trace/choices prev-trace)
(set/difference (set (keys (trace/choices prev-trace)))
(set (keys (trace/choices trace)))))]

{:trace (dynamic.trace/with-retval trace retval)
:weight weight
:discard (merge discard unvisited)})))))

(defn trace-form?
"Returns true if `form` is a trace form."
[form]
Expand Down
61 changes: 53 additions & 8 deletions src/gen/dynamic/trace.cljc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
(ns gen.dynamic.trace
(:refer-clojure :exclude [=])
(:require [clojure.core :as core]
[gen.choice-map :as choice-map]
[gen.diff :as diff]
[gen.dynamic.choice-map :as cm]
[gen.generative-function :as gf]
Expand Down Expand Up @@ -50,7 +51,7 @@
*splice* no-op]
~@body))

(declare assoc-subtrace merge-trace with-retval trace =)
(declare assoc-subtrace update-trace trace =)

(deftype Trace [gf args subtraces retval]
trace/Args
Expand All @@ -74,6 +75,10 @@
(let [v (vals subtraces)]
(transduce (map trace/score) + 0.0 v)))

trace/Update
(update [this constraints]
(update-trace this constraints))

#?@(:cljs
[Object
(equiv [this other] (-equiv this other))
Expand Down Expand Up @@ -178,23 +183,63 @@
(defn with-retval [^Trace t v]
(->Trace (.-gf t) (.-args t) (.-subtraces t) v))

Check failure on line 184 in src/gen/dynamic/trace.cljc

View workflow job for this annotation

GitHub Actions / lint-files

gen.dynamic.trace/->Trace is called with 4 args but expects 5

(defn validate-empty! [t addr]
(when (contains? t addr)
(throw (ex-info "Value or subtrace already present at address. The same
address cannot be reused for multiple random choices."
{:addr addr}))))

(defn assoc-subtrace
[^Trace t addr subt]
(let [subtraces (.-subtraces t)]
(when (contains? subtraces addr)
(throw (ex-info "Value or subtrace already present at address. The same address cannot be reused for multiple random choices."
{:addr addr})))
(->Trace (.-gf t)
(validate-empty! t addr)
(->Trace (.-gf t)

Check failure on line 195 in src/gen/dynamic/trace.cljc

View workflow job for this annotation

GitHub Actions / lint-files

gen.dynamic.trace/->Trace is called with 4 args but expects 5
(.-args t)
(assoc subtraces addr subt)
(.-retval t))))
(assoc (.-subtraces t) addr subt)
(.-retval t)))

(defn merge-subtraces
[^Trace t1 ^Trace t2]
(reduce-kv assoc-subtrace
t1
(.-subtraces t2)))

(defn ^:no-doc combine
"combine by adding weights?"
[v k {:keys [trace weight discard]}]
(-> v
(update :trace assoc-subtrace k trace)
(update :weight + weight)
(cond-> discard (update :discard assoc k discard))))

(defn update-trace [this constraints]
(let [gf (trace/gf this)
state (atom {:trace (trace gf (trace/args this))
:weight 0
:discard (cm/choice-map)})]
(binding [*splice*
(fn [& _]
(throw (ex-info "Not yet implemented." {})))

*trace*
(fn [k gf args]
(validate-empty! (:trace @state) k)
(let [k-constraints (get (choice-map/submaps constraints) k)
{subtrace :trace :as ret}
(if-let [prev-subtrace (get (.-subtraces this) k)]
(trace/update prev-subtrace k-constraints)
(gf/generate gf args k-constraints))]
(swap! state combine k ret)
(trace/retval subtrace)))]
(let [retval (apply (:clojure-fn gf) (trace/args this))
{:keys [trace weight discard]} @state
unvisited (apply dissoc
(trace/choices this)
(keys (trace/choices trace)))]

{:trace (with-retval trace retval)
:weight weight
:discard (merge discard unvisited)}))))

;; ## Primitive Trace
;;
;; [[Trace]] above tracks map-like associations of address to traced value. At
Expand Down

0 comments on commit 4408621

Please sign in to comment.