|
28 | 28 | }, |
29 | 29 | { |
30 | 30 | "cell_type": "code", |
31 | | - "execution_count": 2, |
32 | | - "metadata": { |
33 | | - "collapsed": true |
34 | | - }, |
35 | | - "outputs": [ |
36 | | - { |
37 | | - "data": { |
38 | | - "text/plain": [ |
39 | | - "\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \n", |
40 | | - "\u001b[39m\n", |
41 | | - "\u001b[32mimport \u001b[39m\u001b[36m$ivy.$ \u001b[39m" |
42 | | - ] |
43 | | - }, |
44 | | - "execution_count": 2, |
45 | | - "metadata": {}, |
46 | | - "output_type": "execute_result" |
47 | | - } |
48 | | - ], |
| 31 | + "execution_count": null, |
| 32 | + "metadata": {}, |
| 33 | + "outputs": [], |
49 | 34 | "source": [ |
50 | 35 | "import $ivy.`scala-infer::scala-infer:0.3`\n", |
51 | 36 | "import $ivy.`org.jupyter-scala::kernel-api:0.4.1`" |
52 | 37 | ] |
53 | 38 | }, |
54 | 39 | { |
55 | 40 | "cell_type": "code", |
56 | | - "execution_count": 3, |
57 | | - "metadata": { |
58 | | - "collapsed": true |
59 | | - }, |
60 | | - "outputs": [ |
61 | | - { |
62 | | - "data": { |
63 | | - "text/plain": [ |
64 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla._\n", |
65 | | - "\u001b[39m\n", |
66 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla.Functions._\n", |
67 | | - "\u001b[39m\n", |
68 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla.distributions._\n", |
69 | | - "\u001b[39m\n", |
70 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla.guides._\n", |
71 | | - "\u001b[39m\n", |
72 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla.optimization._\n", |
73 | | - "\u001b[39m\n", |
74 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla.tensor.Tensor._\n", |
75 | | - "\u001b[39m\n", |
76 | | - "\u001b[32mimport \u001b[39m\u001b[36mscappla.tensor._\u001b[39m" |
77 | | - ] |
78 | | - }, |
79 | | - "execution_count": 3, |
80 | | - "metadata": {}, |
81 | | - "output_type": "execute_result" |
82 | | - } |
83 | | - ], |
| 41 | + "execution_count": null, |
| 42 | + "metadata": {}, |
| 43 | + "outputs": [], |
84 | 44 | "source": [ |
85 | 45 | "import scappla._\n", |
86 | 46 | "import scappla.Functions._\n", |
87 | 47 | "import scappla.distributions._\n", |
88 | 48 | "import scappla.guides._\n", |
89 | 49 | "import scappla.optimization._\n", |
90 | 50 | "import scappla.tensor.Tensor._\n", |
91 | | - "import scappla.tensor._" |
92 | | - ] |
93 | | - }, |
94 | | - { |
95 | | - "cell_type": "code", |
96 | | - "execution_count": 4, |
97 | | - "metadata": { |
98 | | - "collapsed": true |
99 | | - }, |
100 | | - "outputs": [ |
101 | | - { |
102 | | - "data": { |
103 | | - "text/plain": [ |
104 | | - "\u001b[32mimport \u001b[39m\u001b[36mscala.util.Random\u001b[39m" |
105 | | - ] |
106 | | - }, |
107 | | - "execution_count": 4, |
108 | | - "metadata": {}, |
109 | | - "output_type": "execute_result" |
110 | | - } |
111 | | - ], |
112 | | - "source": [ |
| 51 | + "import scappla.tensor._\n", |
| 52 | + "\n", |
113 | 53 | "import scala.util.Random" |
114 | 54 | ] |
115 | 55 | }, |
116 | 56 | { |
117 | 57 | "cell_type": "code", |
118 | | - "execution_count": 5, |
119 | | - "metadata": { |
120 | | - "collapsed": true |
121 | | - }, |
122 | | - "outputs": [ |
123 | | - { |
124 | | - "data": { |
125 | | - "text/plain": [ |
126 | | - "defined \u001b[32mclass\u001b[39m \u001b[36mRecord\u001b[39m\n", |
127 | | - "defined \u001b[32mclass\u001b[39m \u001b[36mBatch\u001b[39m\n", |
128 | | - "\u001b[36mbatch\u001b[39m: \u001b[32mBatch\u001b[39m = \u001b[33mBatch\u001b[39m(\u001b[32m1000\u001b[39m)\n", |
129 | | - "\u001b[36ma_vals\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mArrayTensor\u001b[39m, \u001b[32mBatch\u001b[39m] = scappla.Constant@12f3f093\n", |
130 | | - "\u001b[36mb_vals\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mArrayTensor\u001b[39m, \u001b[32mBatch\u001b[39m] = scappla.Constant@31e80b51\n", |
131 | | - "\u001b[36my_vals\u001b[39m: \u001b[32mValue\u001b[39m[\u001b[32mArrayTensor\u001b[39m, \u001b[32mBatch\u001b[39m] = scappla.Constant@5f6ae152" |
132 | | - ] |
133 | | - }, |
134 | | - "execution_count": 5, |
135 | | - "metadata": {}, |
136 | | - "output_type": "execute_result" |
137 | | - } |
138 | | - ], |
| 58 | + "execution_count": null, |
| 59 | + "metadata": {}, |
| 60 | + "outputs": [], |
139 | 61 | "source": [ |
140 | 62 | "case class Record(a: Float, b: Float, y: Float)\n", |
141 | 63 | "\n", |
|
165 | 87 | }, |
166 | 88 | { |
167 | 89 | "cell_type": "code", |
168 | | - "execution_count": 6, |
169 | | - "metadata": { |
170 | | - "collapsed": true |
171 | | - }, |
172 | | - "outputs": [ |
173 | | - { |
174 | | - "data": { |
175 | | - "text/plain": [ |
176 | | - "\u001b[36ma_prior_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@7e13f371\n", |
177 | | - "\u001b[36mb_prior_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@314da4f9\n", |
178 | | - "\u001b[36ma_post_mu\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@33764907\n", |
179 | | - "\u001b[36ma_post_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@3a272583\n", |
180 | | - "\u001b[36ma_guide\u001b[39m: \u001b[32mReparamGuide\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mReparamGuide\u001b[39m(\n", |
181 | | - " \u001b[33mNormal\u001b[39m(scappla.Param@33764907, \u001b[33mApply1\u001b[39m(scappla.Param@3a272583, <function1>))\n", |
182 | | - ")\n", |
183 | | - "\u001b[36mb_post_mu\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@5cbb302d\n", |
184 | | - "\u001b[36mb_post_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@5d9b0110\n", |
185 | | - "\u001b[36mb_guide\u001b[39m: \u001b[32mReparamGuide\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mReparamGuide\u001b[39m(\n", |
186 | | - " \u001b[33mNormal\u001b[39m(scappla.Param@5cbb302d, \u001b[33mApply1\u001b[39m(scappla.Param@5d9b0110, <function1>))\n", |
187 | | - ")\n", |
188 | | - "\u001b[36mnoise_mu\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@33f65cad\n", |
189 | | - "\u001b[36mnoise_s\u001b[39m: \u001b[32mParam\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = scappla.Param@58927bd0\n", |
190 | | - "\u001b[36mnoise_guide\u001b[39m: \u001b[32mReparamGuide\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m] = \u001b[33mReparamGuide\u001b[39m(\n", |
191 | | - " \u001b[33mNormal\u001b[39m(scappla.Param@33f65cad, \u001b[33mApply1\u001b[39m(scappla.Param@58927bd0, <function1>))\n", |
192 | | - ")\n", |
193 | | - "\u001b[36mmodel\u001b[39m: \u001b[32mModel\u001b[39m[\u001b[32mUnit\u001b[39m] = ammonite.$sess.cmd5$Helper$$anon$1@67449597" |
194 | | - ] |
195 | | - }, |
196 | | - "execution_count": 6, |
197 | | - "metadata": {}, |
198 | | - "output_type": "execute_result" |
199 | | - } |
200 | | - ], |
| 90 | + "execution_count": null, |
| 91 | + "metadata": {}, |
| 92 | + "outputs": [], |
201 | 93 | "source": [ |
202 | 94 | "val a_prior_s = Param(0.0)\n", |
203 | 95 | "val b_prior_s = Param(0.0)\n", |
|
229 | 121 | }, |
230 | 122 | { |
231 | 123 | "cell_type": "code", |
232 | | - "execution_count": 7, |
233 | | - "metadata": { |
234 | | - "collapsed": true |
235 | | - }, |
236 | | - "outputs": [ |
237 | | - { |
238 | | - "data": { |
239 | | - "text/plain": [ |
240 | | - "\u001b[36mopt\u001b[39m: \u001b[32mAdam\u001b[39m = scappla.optimization.Adam@7bc3e74f\n", |
241 | | - "\u001b[36minterpreter\u001b[39m: \u001b[32mOptimizingInterpreter\u001b[39m = scappla.OptimizingInterpreter@115f4b3" |
242 | | - ] |
243 | | - }, |
244 | | - "execution_count": 7, |
245 | | - "metadata": {}, |
246 | | - "output_type": "execute_result" |
247 | | - } |
248 | | - ], |
249 | | - "source": [ |
250 | | - "val opt = new Adam(0.1)\n", |
251 | | - "val interpreter = new OptimizingInterpreter(opt)" |
252 | | - ] |
253 | | - }, |
254 | | - { |
255 | | - "cell_type": "code", |
256 | | - "execution_count": 15, |
| 124 | + "execution_count": null, |
257 | 125 | "metadata": {}, |
258 | 126 | "outputs": [], |
259 | 127 | "source": [ |
| 128 | + "val opt = new Adam(0.1)\n", |
| 129 | + "val interpreter = new OptimizingInterpreter(opt)\n", |
| 130 | + "\n", |
260 | 131 | "for { _ <- 0 until 10000 } {\n", |
261 | 132 | " interpreter.reset()\n", |
262 | 133 | " model.sample(interpreter)\n", |
|
265 | 136 | }, |
266 | 137 | { |
267 | 138 | "cell_type": "code", |
268 | | - "execution_count": 17, |
269 | | - "metadata": { |
270 | | - "collapsed": true |
271 | | - }, |
272 | | - "outputs": [ |
273 | | - { |
274 | | - "data": { |
275 | | - "text/plain": [ |
276 | | - "\u001b[36mparams\u001b[39m: \u001b[32mSeq\u001b[39m[(\u001b[32mString\u001b[39m, \u001b[32mExpr\u001b[39m[\u001b[32mDouble\u001b[39m, \u001b[32mUnit\u001b[39m])] = \u001b[33mList\u001b[39m(\n", |
277 | | - " (\u001b[32m\"a_prior\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@7e13f371, <function1>)),\n", |
278 | | - " (\u001b[32m\"b_prior\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@314da4f9, <function1>)),\n", |
279 | | - " (\u001b[32m\"a_post_mu\"\u001b[39m, scappla.Param@33764907),\n", |
280 | | - " (\u001b[32m\"a_post_s\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@3a272583, <function1>)),\n", |
281 | | - " (\u001b[32m\"b_post_mu\"\u001b[39m, scappla.Param@5cbb302d),\n", |
282 | | - " (\u001b[32m\"b_post_s\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@5d9b0110, <function1>)),\n", |
283 | | - " (\u001b[32m\"noise_mu\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@33f65cad, <function1>)),\n", |
284 | | - " (\u001b[32m\"noise_s\"\u001b[39m, \u001b[33mApply1\u001b[39m(scappla.Param@58927bd0, <function1>))\n", |
285 | | - ")" |
286 | | - ] |
287 | | - }, |
288 | | - "execution_count": 17, |
289 | | - "metadata": {}, |
290 | | - "output_type": "execute_result" |
291 | | - } |
292 | | - ], |
| 139 | + "execution_count": null, |
| 140 | + "metadata": {}, |
| 141 | + "outputs": [], |
293 | 142 | "source": [ |
294 | 143 | "val params = Seq(\n", |
295 | 144 | " \"a_prior\" -> exp(a_prior_s),\n", |
|
0 commit comments