{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bayes by Backprop from scratch (NN, classification)\n",
"\n",
    "We have already learned how to implement deep neural networks and how to use them for classification and regression tasks. In order to fight overfitting, we further introduced a concept called _dropout_, which randomly turns off a certain percentage of the units during training.\n",
"\n",
    "Recall the classic architecture of an MLP (shown below, without bias terms). So far, when training a neural network, our goal was to find an optimal point estimate for the weights.\n",
"\n",
"![](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/img/bbb_nn_classic.png?raw=true)\n",
"\n",
    "While networks trained using this approach usually perform well in regions with lots of data, they fail to express uncertainty in regions with little or no data, leading to overconfident decisions. This drawback motivates the application of Bayesian learning to neural networks, introducing probability distributions over the weights. In theory these distributions can take almost any form; to make our lives easier and to have an intuitive understanding of the distribution at each weight, we will use a Gaussian distribution.\n",
"\n",
"![](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/img/bbb_nn_bayes.png?raw=true)\n",
"\n",
    "Unfortunately though, exact Bayesian inference on the parameters of a neural network is intractable. One promising way of addressing this problem is the \"Bayes by Backprop\" algorithm (introduced in the paper \"[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424)\"), which derives a variational approximation to the true posterior. This algorithm not only makes networks more \"honest\" about their overall uncertainty, but also acts as an automatic regularizer, eliminating the need for dropout in this model.\n",
"\n",
"While we will try to explain the most important concepts of this algorithm in this notebook, we also encourage the reader to consult the paper for deeper insights."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start implementing this idea and evaluate its performance on the MNIST classification problem. We start off with the usual set of imports."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import collections\n",
"import mxnet as mx\n",
"import numpy as np\n",
"from mxnet import nd, autograd\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For easy tuning and experimentation, we define a dictionary holding the hyper-parameters of our model."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"config = {\n",
" \"num_hidden_layers\": 2,\n",
" \"num_hidden_units\": 400, \n",
" \"batch_size\": 128,\n",
" \"epochs\": 10,\n",
" \"learning_rate\": 0.001,\n",
" \"num_samples\": 1,\n",
" \"pi\": 0.25,\n",
" \"sigma_p\": 1.0,\n",
" \"sigma_p1\": 0.75,\n",
" \"sigma_p2\": 0.1,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also, we specify the device context for MXNet."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"ctx = mx.cpu()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load MNIST data\n",
"\n",
    "We will again train and evaluate the algorithm on the MNIST data set, which we load as follows:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"def transform(data, label):\n",
" return data.astype(np.float32)/126.0, label.astype(np.float32)\n",
"\n",
"mnist = mx.test_utils.get_mnist()\n",
"num_inputs = 784\n",
"num_outputs = 10\n",
"batch_size = config['batch_size']\n",
"\n",
"train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),\n",
" batch_size, shuffle=True)\n",
"test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),\n",
" batch_size, shuffle=False)\n",
"\n",
"num_train = sum([batch_size for i in train_data])\n",
"num_batches = num_train / batch_size"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to reproduce and compare the results from the paper, we preprocess the pixels by dividing by 126."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model definition\n",
"\n",
"### Activation function\n",
"\n",
"As with lots of past examples, we will again use the ReLU as our activation function for the hidden units of our neural network."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def relu(X):\n",
" return nd.maximum(X, nd.zeros_like(X))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Neural net modeling\n",
"\n",
    "Our model is a straightforward MLP, and we wire up the network just as we have in previous notebooks."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"num_layers = config['num_hidden_layers']\n",
"\n",
"# define function for evaluating MLP\n",
"def net(X, layer_params):\n",
" layer_input = X\n",
    "    for i in range(len(layer_params) // 2 - 1):  # iterate over all hidden layers\n",
" h_linear = nd.dot(layer_input, layer_params[2*i]) + layer_params[2*i + 1]\n",
" layer_input = relu(h_linear)\n",
" # last layer without ReLU\n",
" output = nd.dot(layer_input, layer_params[-2]) + layer_params[-1]\n",
" return output\n",
"\n",
"# define network weight shapes\n",
"layer_param_shapes = []\n",
"num_hidden = config['num_hidden_units']\n",
"for i in range(num_layers + 1):\n",
" if i == 0: # input layer\n",
" W_shape = (num_inputs, num_hidden)\n",
" b_shape = (num_hidden,) \n",
" elif i == num_layers: # last layer\n",
" W_shape = (num_hidden, num_outputs)\n",
" b_shape = (num_outputs,)\n",
" else: # hidden layers\n",
" W_shape = (num_hidden, num_hidden)\n",
" b_shape = (num_hidden,)\n",
" layer_param_shapes.extend([W_shape, b_shape])"
]
},
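As a quick shape check, here is a hypothetical NumPy stand-in for this forward pass (the notebook itself uses MXNet NDArrays); note that `len(layer_params) // 2 - 1` iterations are needed so that every hidden layer is visited before the final affine layer, which is applied without a ReLU:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def net(X, layer_params):
    layer_input = X
    for i in range(len(layer_params) // 2 - 1):  # all hidden layers
        layer_input = relu(layer_input @ layer_params[2 * i] + layer_params[2 * i + 1])
    # last layer without ReLU
    return layer_input @ layer_params[-2] + layer_params[-1]

# shapes as produced by layer_param_shapes for 2 hidden layers of 400 units
shapes = [(784, 400), (400,), (400, 400), (400,), (400, 10), (10,)]
params = [0.01 * rng.standard_normal(s) for s in shapes]
out = net(rng.standard_normal((128, 784)), params)
```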
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build objective/loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "As we briefly mentioned at the beginning of the notebook, we will use variational inference in order to make the prediction of the posterior tractable. While we cannot model the posterior $P(\\mathbf{w}\\ |\\ \\mathcal{D})$ directly, we try to find the parameters $\\mathbf{\\theta}$ of a distribution on the weights $q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})$ (commonly referred to as the _variational posterior_) that minimizes the KL divergence with the true posterior. Formally this can be expressed as:\n",
"\n",
"\\begin{equation*}\n",
"\\begin{split}\n",
    "\\theta^{*} & = \\arg\\min_{\\theta} \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w}\\ |\\ \\mathcal{D})]\\\\\n",
"& = \\arg\\min_{\\theta} \\int q(\\mathbf{w}\\ |\\ \\mathbf{\\theta}) \\log \\frac{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}{P(\\mathbf{w}) P(\\mathcal{D}\\ |\\ \\mathbf{w})} d\\mathbf{w} \\\\\n",
"& = \\arg\\min_{\\theta} \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w})] - \\mathbb{E}_{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}[\\log P(\\mathcal{D}\\ |\\ \\mathbf{w})]\n",
"\\end{split}\n",
"\\end{equation*}\n",
"\n",
    "The resulting loss function, commonly referred to as the _variational free energy_ (the negative of the _evidence lower bound_, or _ELBO_), has to be minimized and is then given as follows:\n",
"\n",
"\\begin{equation*}\n",
"\\mathcal{F}(\\mathcal{D}, \\mathbf{\\theta}) = \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w})] - \\mathbb{E}_{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}[\\log P(\\mathcal{D}\\ |\\ \\mathbf{w})]\n",
"\\end{equation*}\n",
"\n",
    "As one can easily see, the cost function trades off fitting the data, via the likelihood $P(\\mathcal{D}\\ |\\ \\mathbf{w})$, against staying close to the simple prior $P(\\mathbf{w})$.\n",
"\n",
"We can approximate this exact cost through a Monte Carlo sampling procedure as follows\n",
"\n",
"\\begin{equation*}\n",
"\\mathcal{F}(\\mathcal{D}, \\mathbf{\\theta}) \\approx \\sum_{i = 1}^{n} \\log q(\\mathbf{w}^{(i)}\\ |\\ \\mathbf{\\theta}) - \\log P(\\mathbf{w}^{(i)}) - \\log P(\\mathcal{D}\\ |\\ \\mathbf{w}^{(i)})\n",
"\\end{equation*}\n",
"\n",
"where $\\mathbf{w}^{(i)}$ corresponds to the $i$-th Monte Carlo sample from the variational posterior. While writing this notebook, we noticed that even taking just one sample leads to good results and we will therefore stick to just sampling once throughout the notebook.\n",
"\n",
"Since we will be working with mini-batches, the exact loss form on mini-batch $i$ we will be using looks as follows:\n",
"\n",
"\\begin{equation*}\n",
"\\begin{split}\n",
    "\\mathcal{F}(\\mathcal{D}_i, \\mathbf{\\theta}) & = \\frac{1}{M} \\text{KL}[q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})\\ ||\\ P(\\mathbf{w})] - \\mathbb{E}_{q(\\mathbf{w}\\ |\\ \\mathbf{\\theta})}[\\log P(\\mathcal{D}_i\\ |\\ \\mathbf{w})]\\\\\n",
"& \\approx \\frac{1}{M} (\\log q(\\mathbf{w}^{(1)}\\ |\\ \\mathbf{\\theta}) - \\log P(\\mathbf{w}^{(1)})) - \\log P(\\mathcal{D}_i\\ |\\ \\mathbf{w}^{(1)})\n",
"\\end{split}\n",
"\\end{equation*}\n",
"\n",
    "where $M$ corresponds to the number of mini-batches, so that $\\mathcal{F}(\\mathcal{D}, \\mathbf{\\theta}) = \\sum_{i = 1}^{M} \\mathcal{F}(\\mathcal{D}_i, \\mathbf{\\theta})$.\n",
"\n",
"Let's now look at each of these single terms individually."
]
},
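To build intuition for the Monte Carlo approximation, here is a small illustrative NumPy sketch (all names hypothetical, and the likelihood term omitted): averaging the single-sample estimator $\log q(\mathbf{w}) - \log P(\mathbf{w})$ over many reparametrized draws should approach the closed-form KL between the two Gaussians, scaled by $1/M$ as in the mini-batch loss.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_gaussian(x, mu, sigma):
    # element-wise log N(x | mu, sigma^2)
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

# toy variational posterior q = N(mu, sigma^2) over 5 weights, prior p = N(0, 1)
mu = rng.normal(size=5)
sigma = 0.1 * np.ones(5)
M = 10  # number of mini-batches

# Monte Carlo estimate of KL(q || p) / M from reparametrized samples w = mu + sigma * eps
ws = mu + sigma * rng.standard_normal((200000, 5))
mc_kl = (log_gaussian(ws, mu, sigma) - log_gaussian(ws, 0.0, 1.0)).sum(axis=1).mean() / M

# closed-form KL(N(mu, s^2) || N(0, 1)) per dimension: log(1/s) + (s^2 + mu^2)/2 - 1/2
closed_kl = (np.log(1.0 / sigma) + (sigma ** 2 + mu ** 2) / 2.0 - 0.5).sum() / M
```

With enough samples the two agree closely, which is why even a single sample per step gives a usable (if noisy) gradient signal.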
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Likelihood\n",
"\n",
"As with lots of past examples, we will again use the softmax to define our likelihood $P(\\mathcal{D}_i\\ |\\ \\mathbf{w})$. Revisit the [MLP from scratch notebook](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter03_deep-neural-networks/mlp-scratch.ipynb) for a detailed motivation of this function."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def log_softmax_likelihood(yhat_linear, y):\n",
" return nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)"
]
},
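As a hedged NumPy sketch of the same likelihood (the cell above relies on MXNet's built-in `nd.log_softmax`), note that a numerically stable log-softmax subtracts the row-wise maximum before exponentiating, so even extreme logits do not overflow:

```python
import numpy as np

def log_softmax(z):
    # subtracting the row-wise max keeps exp() from overflowing
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def log_softmax_likelihood(yhat_linear, y_one_hot):
    # per-example log P(label | w): picks out log softmax at the true class
    return (y_one_hot * log_softmax(yhat_linear)).sum(axis=1)

logits = np.array([[2.0, 1.0, -1.0],
                   [0.0, 0.0, 1000.0]])  # second row would overflow a naive softmax
y = np.eye(3)[[0, 2]]                    # true classes 0 and 2
ll = log_softmax_likelihood(logits, y)
```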
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prior\n",
"\n",
    "Since we are introducing a Bayesian treatment for the network, we need to define a prior over the weights.\n",
"\n",
"#### Gaussian prior\n",
"\n",
    "A popular and simple prior is the Gaussian distribution. The prior over the entire weight vector simply corresponds to the product of the individual Gaussians:\n",
"\n",
"\\begin{equation*}\n",
"P(\\mathbf{w}) = \\prod_i \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_p^2)\n",
"\\end{equation*}\n",
"\n",
    "We can define the Gaussian distribution and our Gaussian prior as seen below. Note that we are ultimately interested in the log-prior $\\log P(\\mathbf{w})$ and therefore compute the sum of the log-Gaussians.\n",
"\n",
"\\begin{equation*}\n",
"\\log P(\\mathbf{w}) = \\sum_i \\log \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_p^2)\n",
"\\end{equation*}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"LOG2PI = np.log(2.0 * np.pi)\n",
"\n",
"def log_gaussian(x, mu, sigma):\n",
" return -0.5 * LOG2PI - nd.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)\n",
"\n",
"def gaussian_prior(x):\n",
" sigma_p = nd.array([config['sigma_p']], ctx=ctx)\n",
" \n",
" return nd.sum(log_gaussian(x, 0., sigma_p))"
]
},
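A quick NumPy sanity check of these two helpers (illustrative, mirroring the NDArray code above): the density recovered from `log_gaussian` should integrate to one, and the log-prior of a weight array is just the sum of per-weight log-densities.

```python
import numpy as np

LOG2PI = np.log(2.0 * np.pi)

def log_gaussian(x, mu, sigma):
    return -0.5 * LOG2PI - np.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

def gaussian_prior(x, sigma_p=1.0):
    # log P(w) = sum_i log N(w_i | 0, sigma_p^2)
    return log_gaussian(x, 0.0, sigma_p).sum()

# exp(log N(x | 0, 1)) should integrate to ~1 (simple Riemann sum)
xs = np.linspace(-8.0, 8.0, 20001)
area = np.exp(log_gaussian(xs, 0.0, 1.0)).sum() * (xs[1] - xs[0])

# hand check: log P([0, 1, -1]) = 3 * (-0.5 * log 2pi) - (0 + 0.5 + 0.5)
lp = gaussian_prior(np.array([0.0, 1.0, -1.0]))
```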
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Scale mixture prior\n",
"\n",
"Instead of a single Gaussian, the paper also suggests the usage of a scale mixture prior for $P(\\mathbf{w})$ as an alternative:\n",
"\n",
"\\begin{equation*}\n",
"P(\\mathbf{w}) = \\prod_i \\bigg ( \\pi \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_1^2) + (1 - \\pi) \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_2^2)\\bigg )\n",
"\\end{equation*}\n",
"\n",
    "where $\\pi \\in [0, 1]$, $\\sigma_1 > \\sigma_2$, and $\\sigma_2 \\ll 1$. Again we are interested in the log-prior $\\log P(\\mathbf{w})$, which can be expressed as follows:\n",
"\n",
"\\begin{equation*}\n",
"\\log P(\\mathbf{w}) = \\sum_i \\log \\bigg ( \\pi \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_1^2) + (1 - \\pi) \\mathcal{N}(\\mathbf{w}_i\\ |\\ 0,\\sigma_2^2)\\bigg )\n",
"\\end{equation*}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def gaussian(x, mu, sigma):\n",
" scaling = 1.0 / nd.sqrt(2.0 * np.pi * (sigma ** 2))\n",
" bell = nd.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))\n",
" \n",
" return scaling * bell\n",
"\n",
"def scale_mixture_prior(x):\n",
" sigma_p1 = nd.array([config['sigma_p1']], ctx=ctx)\n",
" sigma_p2 = nd.array([config['sigma_p2']], ctx=ctx)\n",
" pi = config['pi']\n",
" \n",
" first_gaussian = pi * gaussian(x, 0., sigma_p1)\n",
" second_gaussian = (1 - pi) * gaussian(x, 0., sigma_p2)\n",
" \n",
" return nd.log(first_gaussian + second_gaussian)"
]
},
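One caveat worth knowing about (a hypothetical NumPy alternative, not the notebook's code): because $\sigma_2 \ll 1$, both component densities underflow to zero for weights far from the origin, and taking `log` of their sum then yields `-inf`. Evaluating the mixture in log space with `logaddexp` avoids this:

```python
import numpy as np

LOG2PI = np.log(2.0 * np.pi)

def log_gaussian(x, mu, sigma):
    return -0.5 * LOG2PI - np.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)

def scale_mixture_log_prior(x, pi=0.25, sigma1=0.75, sigma2=0.1):
    # log(pi * N1 + (1 - pi) * N2) evaluated in log space: logaddexp stays
    # finite even where both raw densities underflow to 0.0
    a = np.log(pi) + log_gaussian(x, 0.0, sigma1)
    b = np.log(1.0 - pi) + log_gaussian(x, 0.0, sigma2)
    return np.logaddexp(a, b)

def naive_log_prior(x, pi=0.25, sigma1=0.75, sigma2=0.1):
    n1 = np.exp(log_gaussian(x, 0.0, sigma1))
    n2 = np.exp(log_gaussian(x, 0.0, sigma2))
    return np.log(pi * n1 + (1.0 - pi) * n2)

x_small = np.array([-0.5, 0.0, 0.3])
x_large = np.array([60.0])  # far enough out that both densities underflow
with np.errstate(divide="ignore"):
    naive_large = naive_log_prior(x_large)
stable_large = scale_mixture_log_prior(x_large)
```

For the moderate weight values seen in practice the two versions agree, so the simpler direct form above is fine for this notebook.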
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Variational Posterior\n",
"\n",
    "The last missing piece in the equation is the variational posterior. Again, we choose a Gaussian distribution for this purpose: a diagonal Gaussian centered on the mean vector $\\mathbf{\\mu}$ with variance vector $\\mathbf{\\sigma}^2$:\n",
"\n",
"\\begin{equation*}\n",
    "q(\\mathbf{w}\\ |\\ \\theta) = \\prod_i \\mathcal{N}(\\mathbf{w}_i\\ |\\ \\mu_i,\\sigma_i^2)\n",
"\\end{equation*}\n",
"\n",
"The log-posterior $\\log q(\\mathbf{w}\\ |\\ \\theta)$ is given by:\n",
"\n",
"\\begin{equation*}\n",
    "\\log q(\\mathbf{w}\\ |\\ \\theta) = \\sum_i \\log \\mathcal{N}(\\mathbf{w}_i\\ |\\ \\mu_i,\\sigma_i^2)\n",
"\\end{equation*}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Combined Loss\n",
"\n",
"After introducing the data likelihood, the prior, and the variational posterior, we are now able to build our combined loss function: $\\mathcal{F}(\\mathcal{D}_i, \\mathbf{\\theta}) = \\frac{1}{M} (\\log q(\\mathbf{w}\\ |\\ \\mathbf{\\theta}) - \\log P(\\mathbf{w})) - \\log P(\\mathcal{D}_i\\ |\\ \\mathbf{w})$"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def combined_loss(output, label_one_hot, params, mus, sigmas, log_prior, log_likelihood):\n",
" \n",
" # Calculate data likelihood\n",
" log_likelihood_sum = nd.sum(log_likelihood(output, label_one_hot))\n",
" \n",
" # Calculate prior\n",
" log_prior_sum = sum([nd.sum(log_prior(param)) for param in params])\n",
"\n",
" # Calculate variational posterior\n",
" log_var_posterior_sum = sum([nd.sum(log_gaussian(params[i], mus[i], sigmas[i])) for i in range(len(params))])\n",
" \n",
" # Calculate total loss\n",
" return 1.0 / num_batches * (log_var_posterior_sum - log_prior_sum) - log_likelihood_sum"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use vanilla stochastic gradient descent to optimize the variational parameters. Note that this implements the updates described in the paper, as the gradient contribution due to the reparametrization trick is automatically included by taking the gradients of the combined loss function with respect to the variational parameters."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def SGD(params, lr):\n",
" for param in params:\n",
" param[:] = param - lr * param.grad"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation metric"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "To be able to assess our model's performance, we define a helper function that evaluates its accuracy on an ongoing basis."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def evaluate_accuracy(data_iterator, net, layer_params):\n",
" numerator = 0.\n",
" denominator = 0.\n",
" for i, (data, label) in enumerate(data_iterator):\n",
" data = data.as_in_context(ctx).reshape((-1, 784))\n",
" label = label.as_in_context(ctx)\n",
" output = net(data, layer_params)\n",
" predictions = nd.argmax(output, axis=1)\n",
" numerator += nd.sum(predictions == label)\n",
" denominator += data.shape[0]\n",
" return (numerator / denominator).asscalar()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parameter initialization\n",
"\n",
    "We are using a Gaussian distribution for each individual weight as our variational posterior, which means that we need to store two parameters per weight: a mean and a standard deviation. The standard deviation must be positive, which we ensure by expressing $\\mathbf{\\sigma}$ in terms of an unconstrained parameter $\\mathbf{\\rho}$ via the softplus function. While gradient descent is performed on $\\theta = (\\mathbf{\\mu}, \\mathbf{\\rho})$, the distribution of each individual weight is $w_i \\sim \\mathcal{N}(w_i\\ |\\ \\mu_i,\\sigma_i^2)$ with $\\sigma_i = \\text{softplus}(\\rho_i)$.\n",
"\n",
"We initialize $\\mathbf{\\mu}$ with a Gaussian around $0$ (just as we would initialize standard weights of a neural network). It is important to initialize $\\mathbf{\\rho}$ (and hence $\\sigma$) to a small value, otherwise learning might not work properly."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"weight_scale = .1\n",
"rho_offset = -3\n",
"\n",
    "# initialize variational parameters: a mean mu and an unconstrained scale rho per weight\n",
"mus = []\n",
"rhos = []\n",
" \n",
"for shape in layer_param_shapes:\n",
" mu = nd.random_normal(shape=shape, ctx=ctx, scale=weight_scale)\n",
" rho = rho_offset + nd.zeros(shape=shape, ctx=ctx)\n",
" mus.append(mu)\n",
" rhos.append(rho)\n",
"\n",
"variational_params = mus + rhos"
]
},
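To see why `rho_offset = -3` is a sensible choice, a small NumPy sketch (illustrative): the softplus of the offset fixes the initial standard deviation of every weight, and $-3$ keeps it well below the scale ($0.1$) used for the means, so early training behaves almost deterministically. An offset near $0$ would instead inject noise several times larger than the typical initial weight.

```python
import numpy as np

def softplus(x):
    return np.log(1.0 + np.exp(x))

sigma0 = softplus(-3.0)   # initial per-weight std with rho_offset = -3, ~0.049
sigma_big = softplus(0.0) # a rho_offset of 0 would give std log(2) ~ 0.69 instead
```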
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since these are the parameters we wish to do gradient descent on, we need to allocate space for storing the gradients."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"for param in variational_params:\n",
" param.attach_grad()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Main training loop\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "The main training loop is pretty similar to the one we used in the MLP example. The only adaptation we need to make is to add the weight sampling, which is performed during each optimization step. Generating a set of weights, which will subsequently be used in the neural network and the loss function, is a 3-step process:\n",
"\n",
"1) Sample $\\mathbf{\\epsilon} \\sim \\mathcal{N}(\\mathbf{0},\\mathbf{I}^{d})$"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def sample_epsilons(param_shapes):\n",
" epsilons = [nd.random_normal(shape=shape, loc=0., scale=1.0, ctx=ctx) for shape in param_shapes]\n",
" return epsilons"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "2) Transform $\\mathbf{\\rho}$ to a positive vector via the softplus function: $\\mathbf{\\sigma} = \\text{softplus}(\\mathbf{\\rho}) = \\log(1 + \\exp(\\mathbf{\\rho}))$"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def softplus(x):\n",
" return nd.log(1. + nd.exp(x))\n",
"\n",
"def transform_rhos(rhos):\n",
" return [softplus(rho) for rho in rhos]"
]
},
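A side note on numerical stability (a hypothetical alternative, not the notebook's code): the direct formula overflows once $\rho$ grows large, because `exp(rho)` exceeds the float range. An algebraically identical rewrite avoids this:

```python
import numpy as np

def softplus_naive(x):
    return np.log(1.0 + np.exp(x))  # exp(x) overflows for large x

def softplus_stable(x):
    # max(x, 0) + log1p(exp(-|x|)) equals log(1 + exp(x)) but never overflows
    return np.maximum(x, 0.0) + np.log1p(np.exp(-np.abs(x)))

xs = np.linspace(-20.0, 20.0, 101)
with np.errstate(over="ignore"):
    big_naive = softplus_naive(1000.0)   # inf
big_stable = softplus_stable(1000.0)     # 1000.0
```

Since $\rho$ stays small in this notebook (initialized at $-3$), the direct version is adequate here.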
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "3) Compute $\\mathbf{w}$: $\\mathbf{w} = \\mathbf{\\mu} + \\mathbf{\\sigma} \\circ \\mathbf{\\epsilon}$, where $\\circ$ denotes element-wise multiplication. This is the \"reparametrization trick\" for separating the randomness from the parameters of $q$.\n",
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def transform_gaussian_samples(mus, sigmas, epsilons):\n",
" samples = []\n",
" for j in range(len(mus)):\n",
" samples.append(mus[j] + sigmas[j] * epsilons[j])\n",
" return samples"
]
},
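A hedged NumPy check of this transformation (names illustrative): since the randomness lives entirely in $\epsilon \sim \mathcal{N}(0, I)$, the resulting samples should have mean $\mu$ and standard deviation $\sigma$, while gradients can flow through $\mu$ and $\sigma$ deterministically.

```python
import numpy as np

rng = np.random.default_rng(42)

def transform_gaussian_samples(mus, sigmas, epsilons):
    # w = mu + sigma * eps turns a parameter-free N(0, 1) draw into a N(mu, sigma^2) draw
    return [mu + sigma * eps for mu, sigma, eps in zip(mus, sigmas, epsilons)]

# 100k draws for a single toy "layer" with mu = 2.0, sigma = 0.5
mus = [np.full(100000, 2.0)]
sigmas = [np.full(100000, 0.5)]
epsilons = [rng.standard_normal(100000)]
(w,) = transform_gaussian_samples(mus, sigmas, epsilons)
```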
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Complete loop\n",
"\n",
"The complete training loop is given below."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0. Loss: 2626.47417991, Train_acc 0.945617, Test_acc 0.9455\n",
"Epoch 1. Loss: 2606.28165139, Train_acc 0.962783, Test_acc 0.9593\n",
"Epoch 2. Loss: 2600.2452303, Train_acc 0.969783, Test_acc 0.9641\n",
"Epoch 3. Loss: 2595.75639899, Train_acc 0.9753, Test_acc 0.9684\n",
"Epoch 4. Loss: 2592.98582057, Train_acc 0.978633, Test_acc 0.9723\n",
"Epoch 5. Loss: 2590.05895182, Train_acc 0.980483, Test_acc 0.9733\n",
"Epoch 6. Loss: 2588.57918775, Train_acc 0.9823, Test_acc 0.9756\n",
"Epoch 7. Loss: 2586.00932367, Train_acc 0.984, Test_acc 0.9749\n",
"Epoch 8. Loss: 2585.4614887, Train_acc 0.985883, Test_acc 0.9765\n",
"Epoch 9. Loss: 2582.92995846, Train_acc 0.9878, Test_acc 0.9775\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD9CAYAAABQvqc9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VOW9x/HPLwkJEJZs7GQjbEZlDTsKola9dddaRVF7\nXdpbra3ebrb32tYutrfW2lvtgkoLrtfaRVu1rigKYQmbspOEbOxJSAgJIcs8948zQEhRAkxyksz3\n/Xrxysw5ZzK/GZLvPHnOc57HnHOIiEh4iPC7ABERaTsKfRGRMKLQFxEJIwp9EZEwotAXEQkjCn0R\nkTDSotA3s4vNbLOZ5ZrZt4+zP9XM3jGzj8zsPTMb3GTfz8xsXfDf50NZvIiInJwThr6ZRQKPA5cA\nmcANZpbZ7LCHgQXOuVHAg8BDwcd+FhgHjAEmAV83s16hK19ERE5GS1r6E4Fc51y+c64OeAG4otkx\nmcC7wdsLm+zPBBY55xqcc9XAR8DFp1+2iIicipaE/iCguMn9kuC2ptYCVwdvXwX0NLPE4PaLzay7\nmSUB5wHJp1eyiIicqqgQfZ+vA4+Z2a3AImA70Oice9PMJgBLgL1ANtDY/MFmdidwJ0BsbOz4kSNH\nhqgsEZHwsHLlylLnXJ8THdeS0N/Osa3zwcFtRzjndhBs6ZtZD+Aa51xFcN+PgR8H9z0HbGn+BM65\nucBcgKysLJeTk9OCskRE5DAzK2zJcS3p3lkBDDOzdDOLBq4HXmn2ZElmdvh73Q/MC26PDHbzYGaj\ngFHAmy17CSIiEmonbOk75xrM7G7gDSASmOecW29mDwI5zrlXgJnAQ2bm8Lp37go+vAvwgZkB7Adu\ncs41hP5liIhIS1h7m1pZ3TsiIifPzFY657JOdJyuyBURCSMKfRGRMKLQFxEJIwp9EZEwEqqLs0RE\n5BQVl9ewOLeUgIPZk1Ja9bkU+iIibay8uo7svDI+zC1lcW4pReU1AIxNiVPoi4h0dAfrGllRUM7i\n3FI+zC1lw879OAc9Y6KYnJHIbdPTmTY0kYw+PVq9FoW+iEiINTQG+Hh7JYtzS1mcW8bKwn3UNQbo\nEmmMS4nnvguGM21YEqMG9SYqsm1PrSr0RUROk3OOvL3VLMkr5cOtpWTnl1FV600+kDmgF7dOS2Pa\n0CQmpMXTPdrf2FXoi4icgj37a1mcV8qHW8tYnFvKrv21ACQndOPSUQOYNjSJKUMSSewR43Olx1Lo\ni4i0QFVtPcvyy4+cfN265wAA8d27MHVoEtOHJjEtI4mUxO4+V/rpFPoiIsdR1xBgddE+r18+r4w1\nxRU0Bhxdu0QwMT2Ra8cPZtrQJDIH9CIiwvwut8UU+iIiQCDg2LSryuuXzy1lWX45B+sbiTAYnRzH\nf8zIYNrQJMalxhETFel3uadMoS8iYevwRVGL88pYkltKWXUdABl9Yrkuy2vJTxqSSO9uXXyuNHQU\n+iISNvZU1ZKdV8aS3DIW55VSsu8gAH17xjBjeB+mDU1i2tAk+vfu6nOlrUehLyKdVuXBepbll7Ek\nr4wleaVs2e2dfO3VNYopGYnccc6QIxdFBRd76vQU+iLSaRysaySnsNwL+dxSPt5eScBB1y4RTEhL\n4Opxg5mWkUTmwF5EdqCTr6Gk0BeRDqu+McBHJRUszvVa8qsKK6hrDBAVYYxNieMrs4YxNSORMSkd\n++RrKCn0RaTDCAQcG3ftJzvPuyBq+bZyqusaMTt65evUjEQmpCUQG6N4Ox69KyLSbjnnKCjzRtgs\nySslO6+MfTX1AAzpE8tV4wYxLSOJyUMSiY+N9rnajkGhLyLtyq7K2mDIe102Oyu96Q0G9O7KrJH9\nmDY0kSkZiQzo3c3nSjsmhb6I+KqixptbfnGeF/T5e6sBb3qDKRmJ3JXhDaNMS+weNiNsWpNCX0Ta\n1J6qWlYW7GN5QTnLt5UfmVs+NjqSiekJ3DAh
halDEzmjf8ea3qCjUOiLSKtxzpFfWk1OQTkrCvaR\nU1BOQZm3SlTXLhGMSY7j3guGM21oIqMGx9GljeeWD0cKfREJmfrGAOt37Ccn2IpfWbjvyNQGCbHR\nZKXGc+OkVLLS4jlzYG+ioxTybU2hLyKn7MChBlYV7jvSkl9dvI/a+gAAqYndmTmiLxPS4slKSyCj\nT6z65NsBhb6ItNie/bWsKNjHioJycgrL2bBjPwEHEQaZA3tx/YQUJqYnkJUaT99enXf+mo5MoS8i\nx3V4CcCcgnKWF5STU7CPonKvP75bl0jGpsRx96xhTEiLZ2xKPD10MVSHoP8lEQG8RUPW7ag85qTr\n4QuhEmOjyUqL5+YpqWSlJXDmwF466dpBKfRFwlRVbT2riiqCIV/OmuKKI/3xaYndOf+MfkxMSyAr\nLZ70JPXHdxYKfZEwUVlTz7JtZSzNL2dpfhmbdnn98ZERxpkDezF7YioT0uIZnxZP357qj++sFPoi\nnVRlTT3LC7yAX5pfduQiqJioCMalxPOVWcOYkJbA2JQ4TU4WRvQ/LdJJVB6sZ8W2YMhvK2P9Di/k\no6MiGJ8Sz9fOH86UjERGJ/fWNMNhTKEv0kHtr20S8vnlrN/hLRgSHRXBuJQ4vnr+MKYMSWR0chxd\nuyjkxaPQF+kg9tfWk1NQfqRPfl1wVajoyIgjC4ZMyUhkjEJePoVCX6SdqqqtJ6dg35E++Y+bhPyY\n4Bj5yUMSGJcSr5CXFlPoi7QTBw41sKLgaHfNuu2VNAYcXSKNscnx3H3eUCYPSWRcqkJeTp1CX8Qn\nBw41HNNd83GTkB+THMeXZ2Z4IZ8ST7dohbyERotC38wuBn4FRAJPOud+2mx/KjAP6AOUAzc550qC\n+/4H+CwQAbwFfNU550L2CkQ6iJq6BlY06a75qMQL+agIL+T/Y4YX8uNTFfLSek4Y+mYWCTwOXAiU\nACvM7BXn3IYmhz0MLHDOzTezWcBDwBwzmwpMA0YFj/sQmAG8F7qXINI+1dY3sqpoH9l5ZWTnlbGm\nuIKGYMiPTo7jSzOGHAn57tH6o1vaRkt+0iYCuc65fAAzewG4Amga+pnAfcHbC4G/BW87oCsQDRjQ\nBdh9+mWLtD91DQE+KqlgSTDkVxbto64hQITB2YPjuP2cIUzJSCQrNV4XQ4lvWvKTNwgobnK/BJjU\n7Ji1wNV4XUBXAT3NLNE5l21mC4GdeKH/mHNu4+mXLeK/huCCIUvyysjOL2PFtnIO1jdiBmf078Wc\nyalMzUhkQnoCvbp28btcESB0J3K/DjxmZrcCi4DtQKOZDQXOAAYHj3vLzM5xzn3Q9MFmdidwJ0BK\nSkqIShIJrUDAsXHX/iPdNcu3lVN1qAGAYX17cF3WYKZkJDIpPZH42GifqxU5vpaE/nYgucn9wcFt\nRzjnduC19DGzHsA1zrkKM7sDWOqcOxDc9zowBfig2ePnAnMBsrKydJJX2gXnHLl7DpCdX8aS3DKW\nbSs7MtVwelIsl44eyJSMRCYPSdAEZdJhtCT0VwDDzCwdL+yvB2Y3PcDMkoBy51wAuB9vJA9AEXCH\nmT2E170zA3g0RLWLhJRzjsKyGi/kg6350gOHABgU143zz+jH1IxEpmQkMqB3N5+rFTk1Jwx951yD\nmd0NvIE3ZHOec269mT0I5DjnXgFmAg+ZmcPr3rkr+PCXgFnAx3gndf/pnPt76F+GyKnZXnGQJbml\nZOeXsTSvjB2VtQD07RnD9KFewE8ZkkRyQjfNJy+dgrW3IfNZWVkuJyfH7zKkk9qzv5bsfK8VvySv\n7Mjyfwmx0UwZksjkjESmZiQyRIuGSAdjZiudc1knOk7jxqTT21l5kD8uKeDtDbvJ21sNQK+uUUwa\nksgXpqUxJSOR4X17EhGhkJfOT6EvndbmXVXMXZTPy2u244DpQ5P4/IRkpgxJInNgLyIV8hKGFPrS\nqTjnWL6t
nN8vyufdTXvo1iWSmyanctv0dJITuvtdnojvFPrSKTQGHG9t2MXv3s9nTXEFCbHR3Hfh\ncOZMTtWYeZEmFPrSodXWN/KXVdt54oN8tpVWk5rYnR9eeRafGz9Y0w+LHIdCXzqkipo6nllayB+X\nFFB6oI5Rg3vzmxvHcdGZ/dVXLx1LXQ3sK4B928AiYMQlrfp0Cn3pULZXHOSpD7bxwooiauoamTmi\nD188N4PJQxI0xFLar4P7oHwblOd74V4e/LdvG1TtPHpc/1EKfRGAjTv3M3dRPq+s3YEBl48eyB3n\nDuGMAb38Lk0EnIMDu71QPxzmh2+X50NtxbHH9+gPCUMgYxYkpEN8+tGvrUyhL+2Wc47s/DJ+/34+\n72/ZS/foSG6dmsa/T09nUJymQZA21tgAlcVNWur5XrfM4ZCvrzl6rEVCXLIX4mdd0yTYh0B8GkT7\nN5JMoS/tTmPA8c91u/j9ojw+KqkkqUc037hoBDdNSqV3d01RLK2ovvZo/3rzVntFEQQajh4b1dUL\n8Ph0GDLTC/bD4R6XApHt82dVoS/txsG6Rl5aWcwTH2yjqLyG9KRYfnLV2Vw9bpBG4khoOee12ktW\nQPEK2L3OC/b9O/CmCQuK6Q0JaTBgNGReGQz2IV6w9xwAERF+vYJTptAX3+2rrmNBdiHzswsor65j\nTHIc3/m3kVyYqZE4EiL1B2HHGihZDsXLoSQHDuzy9kV1g/5nQ/q5R7tgDod7t3joZAMEFPrim+Ly\nGp76cBv/t6KYg/WNnD+yL1+ckcGEtHiNxJFT55zXRVOS44V8yQrY9fHRrpn4dBgyAwZP8P71O7Pd\ndsW0BoW+tLl12yuZuyifVz/eSYTBFWMGcee5Qxjer6ffpUlHVFcNO1YfbcGXLIfqvd6+LrEwaBxM\nvQeSJ8KgLOjRx996fabQlzbhnGNxbhm/X5THB1tL6RETxW3T0/nCtDQtSNKWnIOda2HbIujSDXr2\n94YP9uwHPfpBVIzfFX4657y+95IVwf745bB7PbhGb3/iUBh6wdFWfN9MiFTMNaV3Q1pVQ2OA19bt\n4vfv57F+x3769IzhWxePZPakFHp3C58/qX0VaISipbDpH7DxH1BZ9MnHdo0LfhD0C37tG/xQCG7r\n0c/7gIjp1TZ93YcOwPaVR0O+ZAXUlHn7ont6rfhz7jsa8t0TWr+mDk6hL62mZF8Ndz27irUllQzp\nE8vPrjmbK8cOIiZKI3FaXcMhrzW/8RXY9BrUlEJktHcx0IxvwvCLwAWgahcc2OOd1Kza7X09sNu7\nXZTtfW089K/fP6pb8K+D4AfDMR8U/Y7e7p7U8hEuzkFZ7tEWfMkK2LPBqxMgaQQMvwSSgwHfZyRE\n6GfpZCn0pVW8u2k39/7fWgIBx6+uH8NlowZqkZLWdugA5L4FG/8OW96EuiqvNTz8MzDyUhh2IcQ0\nO2/Ss/+nf0/nvKtJD+wJfkDsPvr18O29myD/fThU+a+Pt0iI7XP0A6L5B0VU16P98dtzvOkKwBsq\nOXi8V/fgCd7tbvGheZ/CnEJfQqqhMcAv3trCb9/L48yBvfjNjeNITYz1u6zOq7oMtrzuBX3eQq9V\n3j0JzroKRl7mjVI5nX56My9su8VDnxGffmz9waN/JRzY1eSDIvhXRNVO2LnGO8l6uPXuPYnXaj/j\nsmDAT4Sk4R1yDHxHoNCXkNmzv5avPL+aZdvKuWFiCt+7LFMXVbWGyhLY9KoX9IWLvQDtnQxZ/+4F\nZ8pkf7o9unQLXqGa9unHNTZ43U0Hdnsjb/qdCV17t0WFgkJfQmRJXin3PL+G6kMN/PLzo7lq7GC/\nS+pc9m6BTX/3TsTuWOVt6zMSpt/nBf2A0R3nIqLIKK9r50RdS9IqFPpyWgIBx2/ey+WRt7aQnhTL\nc3dM0nj7UHDO6wrZGAz60s3e9kHj4fzveUGfNMzfGqVDUujLKdtXXce9L6
7hvc17uXz0QB66+mxi\nY/QjdcoCjd6ImY1/97pvKou9E6Fp02DC7TDys9B7kN9VSgen31A5JauL9nH3c6vZW3WIH155FjdN\nStHUCaeivha2ve8Nrdz8ujcGPTIGhp4PM+/3FtTQ2HMJIYW+nBTnHH9cUsBPXttIv15deek/pjBq\ncJzfZXUsh6pg65tet83WN6HugHex0/CLvCGKQy+AmB5+VymdlEJfWqyqtp5v//ljXv14Jxec0Y9f\nfG605rf/NPUHjy6ycXiZvNKtXhdOY503fv2sa+CMyyH9nPY/BYJ0Cgp9aZGNO/fz5WdXUVRew/2X\njOTOc4eoOwfgYEWzpfG2HV1ZqWrHscd27e3N8DjhDu9EbPJEXVEqbU6hLyf04opi/vvldfTu1oXn\n75jMxPQw6mN2zrvI6MiC1s1WUzp8BelhPfp587APmXl0XvbD65+qb17aAYW+fKKDdY088PI6/rSy\nhKkZifzq+rH06dkJuyAaG2B/ybHdMEda7QVQX330WIvwLoRKSA+upNRkwY34NIjW1cfSvin05bjy\n9x7gy8+uYvPuKu6ZNZSvXjC8Y69i1VgfbKU3a6mXbwuufVp/9NjIGC/AE4KLbRxuqScM8QI/Ktq3\nlyFyuhT68i9e/Wgn3/rzR3SJNP5w6wRmjujrd0mnrrEeVj8N7/3Uu+z/sJheXrD3PxsyLz92mbye\nAzXvi3RaCn05oq4hwE9e28gflxQwLiWOx2aPY2BcB13gxDlv7Ps7D3rT9SZPhgsfhISMYP96YseZ\ntkAkhBT6AgTnvn9uNWuLK7htejrfungk0VEdtLVb8CG89YC3+EafkXD9895FTgp5EYW+wMJNe7j3\nxTU0Njp+e+M4Ljl7gN8lnZpd6+CdH3gXPPUcCJc/BqNv0HJ5Ik3otyGMNTQG+OXbW3h8YR5nDOjF\nb28cR1pSBxx9UlEEC38Ca1+Arr28bpyJd3pT/YrIMRT6YWpPVS33PL+apfnlXD8hme9ffmbHm/u+\nphw++AUsnwsYTLsHpt+rFZZEPoVCPwwtzS/jK8+vpqq2noc/N5prx3ewue/ramDZb+HDR715a8bM\n9iYn693BXoeIDxT6YSQQcPz2/Tx+8eZm0pJiefq2iYzs38vvslqusQHWPAMLH/KW4Bvxb3D+A9D3\nDL8rE+kwWhT6ZnYx8CsgEnjSOffTZvtTgXlAH6AcuMk5V2Jm5wG/bHLoSOB659zfQlG8tFxFTR33\nvbiWdzft4dJRA/jpNaPo0VHmvncONv0D3v4BlG2F5EnwuT9C6hS/KxPpcE74W29mkcDjwIVACbDC\nzF5xzm1octjDwALn3HwzmwU8BMxxzi0ExgS/TwKQC7wZ4tcgJ7CmuIK7nl3FnqpaHrziTOZMTu04\nk6UVLIa3vwclKyBpBFz/nNfC7yj1i7QzLWnqTQRynXP5AGb2AnAF0DT0M4H7grcXAsdryV8LvO6c\nqzn1cuVkOOdYkF3Ij17dQN+eXfnTl6YyJrmDzH2/e4M3/HLLP4PDL38No2dr+KXIaWrJb9AgoLjJ\n/RJgUrNj1gJX43UBXQX0NLNE51xZk2OuBx453hOY2Z3AnQApKSktq1w+VfWhBr7154/4x0c7mTWy\nL49cN5q47h1gzpiK4uDwy+e9qRIu+D5M/CJEd/e7MpFOIVTNpq8Dj5nZrcAiYDvQeHinmQ0Azgbe\nON6DnXNzgbkAWVlZLkQ1ha3Kg/XcMm85H5VU8M2LR/ClczOIaO+TpR0ZfvmEd3/q3TD9Pk1HLBJi\nLQn97UByk/uDg9uOcM7twGvpY2Y9gGuccxVNDrkO+Ktzrh5pVeXVdcx5ahlbdlfx25vGc9GZ/f0u\n6dPV1cCy33nDLw/tPzr8Mi75xI8VkZPWktBfAQwzs3S8sL8emN30ADNLAsqdcwHgfryRPE3dENwu\nrWhv1SFuenIZBWXVzL05i/Pa8+yYjQ
2w5ll47yGo2gnDL/GGX/bL9LsykU7thKHvnGsws7vxumYi\ngXnOufVm9iCQ45x7BZgJPGRmDq97567DjzezNLy/FN4PefVyxK7KWmY/uZSdFbXMu3UC04Ym+V3S\n8TkHm171TtKWboHBE+HaeZA61e/KRMKCOde+utCzsrJcTk6O32V0KCX7apj9xDLKq+uYd+uE9ruc\nYWG2N/tlyXJIGg7nfw9GflbDL0VCwMxWOueyTnScxr91cIVl1cx+YhlVtfU8fdtExqa0w3ln9mz0\nLqza8jr0HACX/S+MuVHDL0V8oN+6Dix3zwFufHIpdQ0BnrtjMmcN6u13SceqLvO6cVY/DdE9vZb9\npC9p+KWIjxT6HdSmXfu56cllALxw5xRG9O/pc0VNBBohZx68+yNvQrRJX4Jzv6HhlyLtgEK/A1q3\nvZKbnlpGTFQEz94+maF9e/hd0lFFS+G1r8OujyH9XLjk59B3pN9ViUiQQr+DWV20j5vnLadX1y48\nd8ckUhPbyaInVbu9OXLWPg+9BsPn5kPmFTpJK9LOKPQ7kOXbyvnCH5aT1DOGZ2+fxOD4dtA33ljv\nLWKy8CFoPATn/Kf3L7qdfBiJyDEU+h3E4txSbp+fw8C4rjx7+2T69+7qd0mQ/z68/k3YuwmGXgiX\n/AwSM/yuSkQ+hUK/A1i4eQ9ffHol6YmxPHP7JPr0jPG3oMoSePO/YP1fIS4VbngBhl+srhyRDkCh\n3869sX4Xdz+3iuH9evL0bZNIiPVxpsyGQ5D9GCx6GFwAzvsuTP2KFiAX6UAU+u3Y39fu4Gv/t4az\nB/Vm/r9PpHe3Lv4Vs/UtryunPB9GXgoX/QTiU/2rR0ROiUK/nfrzyhK+8dJaslITmPeFCf4tbVi+\nDd74Dmx+DRKHwk1/hqEX+FOLiJw2hX479PzyIr7z14+ZmpHIEzdn0T3ah/+muhpY/Kg35XFEFFzw\nA5j8ZYjqAAuxiMgnUui3M/OXFPC9V9Zz3og+/Pam8XTtEtm2BRyeBfOf90NlEZx1LXzmh9BrYNvW\nISKtQqHfjvz+/Tween0Tn8nsx69njyUmqo0Dv3QrvP4tyHsH+mbCra9C2vS2rUFEWpVCv53433e2\n8shbW7h01AB++fkxdImMaLsnP3QAFv0csh/3RuJc/FOYcDtE+njiWERahULfZ845Hn5zM48vzOPq\ncYP4+bWjiWyr9Wydg3V/hjf/G6p2eNMdX/B96NGOV9wSkdOi0PeRc44fvbqRpz7cxg0Tk/nxlWe3\n3QLmuzfAa9+Awg9hwGi4bj4kT2yb5xYR3yj0fRIIOB54ZR3PLC3i1qlpfO+yTKwtrmg9WAHv/dSb\nL6drL7j0lzDuFoho4/MHIuILhb4PGgOO+//yES/mlPDFGUP49sUjWz/wAwFvBsy3vwfVpTD+Vm8h\ncs1xLxJWFPptrKExwH/+aS0vr9nBPecP494LhrV+4O9Y43XllCyHwRPgxj/BwLGt+5wi0i4p9NtQ\nXUOAr76wmtfX7eIbF43grvOGtu4T1pTDuz+EnD9AbBJc8RsYfQNEtOHIIBFpVxT6baS2vpG7n1vF\n2xv38N+XZnLb9PTWezLnYM1z8OZ3oXa/t1zhzG9Dt7jWe04R6RAU+m3gYF0jdz6dwwdbS/nhlWcx\nZ3IrTlRWuR3+/lXIfQuSJ8Olj0C/M1vv+USkQ1Hot7LqQw3cNn8Fy7aV8z/XjOK6Ccmt80TOweqn\n4Y3vQqABLv4ZTLxTXTkicgyFfivaX1vPF/6wgjXFFTz6+TFcMWZQ6zxRRTH8/R7IexdSp8MVv4aE\nIa3zXCLSoSn0W0lFTR03z1vOhh37eeyGsVxy9oDQP4lzsGo+vPFf3qIm//YwZN2m1r2IfCKFfiso\nO3CIm55aTt6eA/zupvFckNkv9E9SUQSvfAXy34O0c+DyX0NCK54cFpFOQaEfYpU19Vw/dylF5TU8\neU
sW5w7vE9onCARg5R/grQe8+599BMZ/Qa17EWkRhX6IPbOskK17DvDs7ZOYNjQptN98X4HXut+2\nCNJneK17LVkoIidBoR9CDY0BnllayDnDkkIb+IEA5DwFb30PLAIufdSbRqEt5uoRkU5FoR9Cb2/c\nzc7KWn5weQjHxZfnw8tf8WbDzJgFl/0vxLXSsE8R6fQU+iG0ILuQQXHdOP+MEJy4DQRgxRPw9ve9\nNWov/zWMnaPWvYicFoV+iGzdXcWSvDK+dfHI018EpSwPXr4bipbA0Avgsl9B78GhKVREwppCP0QW\nZBcSHRXB50/nittAIyz7PbzzIERGexOkjZmt1r2IhIxCPwT219bz51UlXD56IAmx0af2TUpz4eW7\noHgpDLsILnsUeg0MbaEiEvYU+iHwl5Ul1NQ1csuUtJN/cKARlv4G3v0RRMXAlb+D0derdS8irUKh\nf5oCAceC7ELGpsRx9uDeJ/fgvVvg5S9DyQoYfom3dGGvVpiuQUQkqEWXcZrZxWa22cxyzezbx9mf\nambvmNlHZvaemQ1usi/FzN40s41mtsHM0kJXvv8W55WSX1p9cq38QCMs/hX8bjqUboWrn4Abnlfg\ni0irO2FL38wigceBC4ESYIWZveKc29DksIeBBc65+WY2C3gImBPctwD4sXPuLTPrAQRC+gp8Nn9J\nIUk9ornk7P4te8DezfC3L8P2HBh5qTeNQs9WmJtHROQ4WtLSnwjkOufynXN1wAvAFc2OyQTeDd5e\neHi/mWUCUc65twCccwecczUhqbwdKC6v4Z1Nu7lhYgoxUZGffnBjA3z4S/jdOd4FV9c8BZ9/RoEv\nIm2qJaE/CChucr8kuK2ptcDVwdtXAT3NLBEYDlSY2V/MbLWZ/Tz4l0On8MyyQiLMmD0p5dMP3LMR\nnrrQu9Bq+GfgrmVw9rU6WSsibS5UUzN+HZhhZquBGcB2oBGv++ic4P4JwBDg1uYPNrM7zSzHzHL2\n7t0bopJaV219I/+3opiLzuzHgN7djn9QYwMsehh+fy5UFMK1f4DrnoYefdu2WBGRoJaM3tkONL3i\naHBw2xHOuR0EW/rBfvtrnHMVZlYCrHHO5Qf3/Q2YDDzV7PFzgbkAWVlZ7tReStt6Ze0OKmrqmTM5\n7fgH7F7v9d3vXANnXuUtcBIb4lk3RUROUktCfwUwzMzS8cL+emB20wPMLAkod84FgPuBeU0eG2dm\nfZxze4FZQE6oiveLc44F2QUM79eDyUMSjt0ZCMAHv4D3fwZde8Pn5sOZV/pSp4hIcyfs3nHONQB3\nA28AG4FdB6NNAAAMPElEQVQXnXPrzexBM7s8eNhMYLOZbQH6AT8OPrYRr2vnHTP7GDDgiZC/ija2\nuriCddv3c/OUNKx5v/yqP8LCH8EZl8FdyxX4ItKutOjiLOfca8BrzbY90OT2S8BLn/DYt4BRp1Fj\nu7NgSQE9Y6K4amyz89n1B+H9/4HkSXDtPJ2oFZF2R2vsnaS9VYd49eOdXJs1mNiYZp+Zy5+Aqp1w\n/gMKfBFplxT6J+mF5UXUNzrmTG62TGHtfvjwEW+hk7Tp/hQnInICCv2TUN8Y4NllRZw7vA9D+vQ4\ndmf243Bwn9fKFxFppxT6J+GtDbvZtb+WW6Y0a+VXl0H2Y3DG5TBwrD/FiYi0gEL/JMxfUkByQjdm\njmh2cdWHj0B9DZz3XX8KExFpIYV+C23atZ9l28qZMzn12OUQK7d7J3BHXQ99R/pXoIhICyj0W2hB\ndiExURFcl9VsOcRF/wMuADP/ZcZpEZF2R6HfApUH6/nrqu1cOWYQcd2bLIdYlgernoasL0B86id/\nAxGRdkKh3wIvrSzhYH0jc5qfwF34E28B83O+7k9hIiInSaF/AoGA4+nsAsanxnPWoCbLIe76GNa9\nBJO/pDnxRaTDUOifwAe5pRSU1XBz81b+uz+GmN4w7av+FCYicgoU
+iewYEkBST1iuOSsJuvXFi+H\nLa/DtHugW7x/xYmInCSF/qcoKqvh3c17mD0pheio4FvlHLzzIMT2gUlf8rdAEZGTpND/FM8sKyTS\njBubLoeYvxAKPvBO3sb0+OQHi4i0Qwr9T3CwLrgc4ln96derq7fxcCu/d7I3TFNEpINR6H+CV9Zu\np/JgPbdMSTu6cdM/YMdq70KsqBjfahMROVUK/eNwzjF/SSEj+/dkQlrwRG2gEd79ESQO86ZcEBHp\ngBT6x7GycB8bdu7nlqlNlkP86EXYuwlmfRciW7TgmIhIu6PQP4752YX06hrFFWMGehsa6uC9n8CA\n0XDGFf4WJyJyGhT6zezZX8vrH+/kuqxkukcHW/Sr5kNFEcx6ACL0lolIx6UEa+a55UU0OsdNh5dD\nrKuGRT+HlKkw9Hx/ixMROU3qnG6irsFbDnHG8D6kJcV6G5fPhQO74XPztdi5iHR4auk38cb6Xeyt\nOnR0mObBCvjwURj2GUid4mttIiKhoNBv4unsQlISujNjeB9vQ/ZjUFsBs/7L38JEREJEoR+0Ycd+\nlheUc/OUVCIiDA7shezfwJlXeaN2REQ6AYV+0NNLC+jaJYLPjQ8uh/jBL6ChVoudi0inotAHKmvq\n+evq7Vw1dhC9u3eBimLIeQrGzIakYX6XJyISMgp94E8ri6mtDzBncpq34f2feV9nfMu3mkREWkPY\nh34g4FiQXcjEtAQyB/aC0q2w5jnIug3ikv0uT0QkpMI+9N/fspei8hpunhq8GGvhjyGqK5zzn/4W\nJiLSCsI+9OdnF9C3ZwwXndkfdq6F9X+FKV+GHn38Lk1EJOTCOvQLSqt5b/NebpyUSpfICG/q5K5x\nMOVuv0sTEWkVYR36Ty8tpEukccOkZCjMhq1vwvSvQbc4v0sTEWkVYRv6NXUNvJhTzMVnDaBvjxhv\nGcQe/WDiF/0uTUSk1YRt6P9t9Q6qahu4ZUoq5L4DRUvg3G9AdHe/SxMRaTVhGfrOORZkF5A5oBfj\nU3rDOz+AuBQYd4vfpYmItKqwDP0VBfvYtKuKW6amYhv/Drs+gpnfgahov0sTEWlVYRn687ML6N2t\nC5ef3c8bsdNnJIy6zu+yRERaXYtC38wuNrPNZpZrZt8+zv5UM3vHzD4ys/fMbHCTfY1mtib475VQ\nFn8qdlXW8sa6XXx+QjLdNv4JyrZ6UydHRPpdmohIqzth6JtZJPA4cAmQCdxgZpnNDnsYWOCcGwU8\nCDzUZN9B59yY4L/LQ1T3KTuyHGJWf3jvpzBwLIy81O+yRETaREta+hOBXOdcvnOuDngBuKLZMZnA\nu8HbC4+zv12oawjw3LIiZo3oS8q2F6GyGM5/QMsgikjYaEnoDwKKm9wvCW5rai1wdfD2VUBPM0sM\n3u9qZjlmttTMrjytak/T6+t2UnrgELdO6OMtdp52Dgw5z8+SRETaVKhO5H4dmGFmq4EZwHagMbgv\n1TmXBcwGHjWzjOYPNrM7gx8MOXv37g1RSf9qQXYh6UmxTCt9Car3qpUvImGnJaG/HWg6x/Dg4LYj\nnHM7nHNXO+fGAt8NbqsIft0e/JoPvAeMbf4Ezrm5zrks51xWnz6tM9HZuu2VrCzcx23j44lY8r8w\n/BJIntgqzyUi0l61JPRXAMPMLN3MooHrgWNG4ZhZkpkd/l73A/OC2+PNLObwMcA0YEOoij8ZC7IL\n6B4dybWH/gKHKrXYuYiEpROGvnOuAbgbeAPYCLzonFtvZg+a2eHRODOBzWa2BegH/Di4/Qwgx8zW\n4p3g/alzrs1Df191HS+v2cGcs7rSdeVcOOta6H9WW5chIuK7qJYc5Jx7DXit2bYHmtx+CXjpOI9b\nApx9mjWethdzijnUEOBLEX+FhkNw3nf8LklExBed/orcxoDj6aWFXJpcR/yGZ2HcHEj8l3PJIiJh\noUUt/Y5s4aY9lOw7yAt9X4by
CDj3m36XJCLim07f0l+wtJDJPUsZVPQyTLwDeje/xEBEJHx06pZ+\n/t4DLNqyl7cH/Q2r6g7T7/W7JBERX3Xqlv7TSwsZG5XP0LJ3vXVvY5P8LklExFedtqVffaiBl3JK\neLH3yxCIhyl3+V2SiIjvOm1L/6+rt5NZ9zFnVK+A6fdB115+lyQi4rtO2dJ3zrFgyTYejX0J120A\nNvEOv0sSEWkXOmXoL80vZ1Dph2RGb4QZv4Qu3fwuSUSkXeiU3TtPL8nn29EvEohLg7Fz/C5HRKTd\n6HQt/R0VB4na/Aojogph1hMQ2cXvkkRE2o1O19J/ITufeyNepC5xJJx1jd/liIi0K52qpX+ooZHq\n5QtIj9gFF/5Si52LiDTTqVr6/1xTwG2BF9mfOBpGXOJ3OSIi7U6naumXLvwdA62cwL/9Qcsgiogc\nR6dp6Rft3M2VB55nR8IkIjJm+l2OiEi71Gla+ik9HDVDz6Hb9K/5XYqISLvVaUKfnv3pPud5v6sQ\nEWnXOk33joiInJhCX0QkjCj0RUTCiEJfRCSMKPRFRMKIQl9EJIwo9EVEwohCX0QkjJhzzu8ajmFm\ne4HC0/gWSUBpiMrp6PReHEvvx7H0fhzVGd6LVOdcnxMd1O5C/3SZWY5zLsvvOtoDvRfH0vtxLL0f\nR4XTe6HuHRGRMKLQFxEJI50x9Of6XUA7ovfiWHo/jqX346iweS86XZ++iIh8ss7Y0hcRkU/QaULf\nzC42s81mlmtm3/a7Hj+ZWbKZLTSzDWa23sy+6ndNfjOzSDNbbWb/8LsWv5lZnJm9ZGabzGyjmU3x\nuyY/mdm9wd+TdWb2vJl19bum1tQpQt/MIoHHgUuATOAGM8v0typfNQD/6ZzLBCYDd4X5+wHwVWCj\n30W0E78C/umcGwmMJozfFzMbBNwDZDnnzgIigev9rap1dYrQByYCuc65fOdcHfACcIXPNfnGObfT\nObcqeLsK75d6kL9V+cfMBgOfBZ70uxa/mVlv4FzgKQDnXJ1zrsLfqnwXBXQzsyigO7DD53paVWcJ\n/UFAcZP7JYRxyDVlZmnAWGCZv5X46lHgm0DA70LagXRgL/CHYHfXk2YW63dRfnHObQceBoqAnUCl\nc+5Nf6tqXZ0l9OU4zKwH8Gfga865/X7X4wczuxTY45xb6Xct7UQUMA74rXNuLFANhO05MDOLx+sV\nSAcGArFmdpO/VbWuzhL624HkJvcHB7eFLTPrghf4zzrn/uJ3PT6aBlxuZgV43X6zzOwZf0vyVQlQ\n4pw7/JffS3gfAuHqAmCbc26vc64e+Asw1eeaWlVnCf0VwDAzSzezaLwTMa/4XJNvzMzw+mw3Ouce\n8bsePznn7nfODXbOpeH9XLzrnOvULblP45zbBRSb2YjgpvOBDT6W5LciYLKZdQ/+3pxPJz+xHeV3\nAaHgnGsws7uBN/DOvs9zzq33uSw/TQPmAB+b2Zrgtu84517zsSZpP74CPBtsIOUDX/C5Ht8455aZ\n2UvAKrxRb6vp5Ffn6opcEZEw0lm6d0REpAUU+iIiYUShLyISRhT6IiJhRKEvIhJGFPoiImFEoS8i\nEkYU+iIiYeT/AUSFzzyvHHEbAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1096f1630>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epochs = config['epochs']\n",
"learning_rate = config['learning_rate']\n",
"smoothing_constant = .01\n",
"train_acc = []\n",
"test_acc = []\n",
"\n",
"for e in range(epochs):\n",
" for i, (data, label) in enumerate(train_data):\n",
" data = data.as_in_context(ctx).reshape((-1, 784))\n",
" label = label.as_in_context(ctx)\n",
" label_one_hot = nd.one_hot(label, 10)\n",
" \n",
" with autograd.record():\n",
" # sample epsilons from standard normal\n",
" epsilons = sample_epsilons(layer_param_shapes)\n",
" \n",
" # compute softplus for variance\n",
" sigmas = transform_rhos(rhos)\n",
"\n",
" # obtain a sample from q(w|theta) by transforming the epsilons\n",
" layer_params = transform_gaussian_samples(mus, sigmas, epsilons)\n",
" \n",
" # forward-propagate the batch\n",
" output = net(data, layer_params)\n",
" \n",
" # calculate the loss\n",
" loss = combined_loss(output, label_one_hot, layer_params, mus, sigmas, gaussian_prior, log_softmax_likelihood)\n",
" \n",
" # backpropagate for gradient calculation\n",
" loss.backward()\n",
" \n",
" # apply stochastic gradient descent to variational parameters\n",
" SGD(variational_params, learning_rate)\n",
" \n",
" # calculate moving loss for monitoring convergence\n",
" curr_loss = nd.mean(loss).asscalar()\n",
" moving_loss = (curr_loss if ((i == 0) and (e == 0)) \n",
" else (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss)\n",
"\n",
" \n",
" test_accuracy = evaluate_accuracy(test_data, net, mus)\n",
" train_accuracy = evaluate_accuracy(train_data, net, mus)\n",
" train_acc.append(np.asscalar(train_accuracy))\n",
" test_acc.append(np.asscalar(test_accuracy))\n",
" print(\"Epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" %\n",
" (e, moving_loss, train_accuracy, test_accuracy))\n",
" \n",
"plt.plot(train_acc)\n",
"plt.plot(test_acc)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For demonstration purposes, we can now take a look at one particular weight by plotting its distribution."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAW8AAAD8CAYAAAC4uSVNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8U9ed9/HPT5J3Y+N9A2NWA2bHIWQnJGEN2dpOQhLa\naZqhnaaT6Tad9uksT/t0pp1Mp03b6UaTTtM2zU46gbAkJCSELIDZbBYDZjPejcEbXiWd5w+L1HEA\ny7bkq+X3fr38QpaupC9X0s9H555zjxhjUEopFVxsVgdQSik1cFq8lVIqCGnxVkqpIKTFWymlgpAW\nb6WUCkJavJVSKghp8VZKqSCkxVsppYKQFm+llApCDn89cGpqqsnLy/PXwyulVEjavXv3WWNMWn/b\n+a145+XlUVRU5K+HV0qpkCQip73ZTrtNlFIqCGnxVkqpIKTFWymlgpAWb6WUCkJavJVSKghp8VZK\nqSCkxVsppYKQ38Z5KxVIGlo7+d99VTS2dX143ewxSdw0MQ2bTSxMptTgeF28ReQrwMOAAUqAzxpj\nOvwVTClfKK5o5HfvnmJ9cTVdLjfiqdMXl24dkxLLg1eP4d55o0mIjrAuqFID5FXxFpEc4FFgqjGm\nXUSeB+4DfufHbEoNmjGGX7x1nP/cfIT4KAcr541m1TVjmJA+AoAup5vNB2v4/fun+LcNh/nDB6d5\n4jOFTMoYYW1wpbw0kG4TBxAjIt1ALFDln0hKDU1Ht4tvvFjMK/uruHNWNt+7axoj+rSqIx02VszM\nZsXMbHadOscXn97DPb94j5+unMXCyRkWJVfKe14dsDTGVAI/BMqBaqDJGPNa3+1EZLWIFIlIUX19\nvW+TKuWFpvZu7l3zAa/sr+IfFufz+L2zPla4+7oqL5lXvnQdeamxfO6pIv7w/qlhyarUUHhVvEUk\nCbgTGAtkA3Ei8mDf7Ywxa4wxhcaYwrS0fk+KpZRPudyGR5/Zy8HKJn69ai6P3DwBEe8ORmYlxvD8\n56/hlsnp/MsrB9laWufntEoNjbdDBW8FThpj6o0x3cBa4Fr/xVJq4B7bVMrbR+v5zp0FLC7IHPD9\nYyMd/HTlbKZkJvDoM3spq2v1Q0qlfMPb4l0OzBeRWOlpytwCHPZfLKUG5uW9Ffx62wlWzR/DA1eP\nGfTjxEY6+M1nCol02Fj9+yKa2rp9mFIp3/G2z3sH8CKwh55hgjZgjR9zKeW10ppm/vGlEuaPS+Zf\nVkwd8uPljIzhV6vmcuZ8G199fh/m4rhCpQKI1zMsjTH/aoyZbIyZZoxZZYzp9Gcwpbzhdhu+tbaE\n+CgHP79/DhF230waviovmX9cMpk3SutYX1ztk8dUypd0erwKak/vLGdveSP/tHwKKfFRPn3sz143\nluk5iXxn3SGa2rX7RAUWLd4qaNU1d/DYxlKum5DC3bNzfP74dpvw/Xumc+5CJ49tKvX54ys1FFq8\nVdD6zrpDdLrcfO+u6V4PCRyoaTmJfPa6sTy9o5zdp8/55TmUGgwt3ioovXOsnldLqvm7mycwNjXO\nr8/11dsmkZ0YzbdfPoDLrQcvVWDQ4q2CjjGG/9x8hJyRMay+aZzfny8uysE3l02htKaF9cV6VggV\nGLR4q6Dz+qFaiiua+PtbJhLlsA/Lc94+PYvJmSN4fMsxnC73sDynUleixVsFFbfb8KPXjzI2NY57\n5vj+IOXl2GzCV2+bxMmzF1i7t3LYnlepy9HirYLK+pJqSmta+PKtE3H4aEy3t26bmsGMUYn8ZMsx\nOp2uYX1upfrS4q2ChtPl5vHXj5KfMYIVM7KH/flFhK8tyqeysZ3nd50Z9udXqjct3ipovLK/ihNn\nL/DVRZMsW7rsxompzMtL5r+3lmnrW1lKi7cKCsYYfvPOSSamx7NoqnWLJYgIjyycQG1zJ+v367R5\nZR0t3ioovH+8gcPVzTx8w1i/Tcjx1o0TU5mU
Ec8T20/qSauUZbR4q6Dwm3dOkBofyZ2zhm+EyeWI\nCA9fP47D1c28d7zB6jgqTGnxVgGvrK6FrUfqWTU/j+iI4RnX3Z87ZmWTGh/JE++csDqKClNavFXA\ne3L7KSIdNh6cn2t1lA9FR9hZNT+PrUfqKatrsTqOCkPermGZLyL7ev00i8iX/R1OqYbWTtbuqeAT\nc3J8fsrXoXpwfi5RDhtPbj9ldRQVhrxdSeeIMWaWMWYWMBdoA172azKlgGd2ltPpdPO568daHeVj\nUuKjuGfOKNbuqaCxrcvqOCrMDKbb5BbguDHmtK/DKNWb2214ZucZrh2fwoT0EVbHuaRV88fQ6XSz\ndo9OmVfDazDF+z7gGV8HUaqvbcfqqWxsZ+W8wOnr7mtqdgIzR4/kTzvLddigGlYDKt4iEgncAbxw\nmdtXi0iRiBTV19f7Ip8KY8/sLCclLpLFBZlWR7miB+blUlbXStHp81ZHUWFkoC3vpcAeY0ztpW40\nxqwxxhQaYwrT0tKGnk6FrbrmDrYcruOTc0cR6QjsQVG3z8wiPsrBMzvKrY6iwshAPxUr0S4TNQxe\n2F2By22496rRVkfpV2ykg7tmZ7O+pFoPXKph43XxFpE44DZgrf/iKHXxQGU514xLYVxavNVxvLJy\nXi5deuBSDSOvi7cx5oIxJsUY0+TPQEq9U3aWivPt3H914B6o7KsgO5GZo0fyjB64VMMksDsTVVh6\nvugMSbERLCqw7uyBg7HyqtEcq2tl75lGq6OoMKDFWwWUpvZuXj9Uyx0zs4dtfUpfWTYjiyiHjZe1\n60QNAy3eKqBsKKmmy+nmnjmjrI4yYAnRESwqyGRdcRVdTl2kWPmXFm8VUF7eU8n4tDhmjEq0Osqg\n3DMnh8a2brYeqbM6igpxWrxVwChvaGPnqXPcM2eU5QsuDNYNE1JJjY9i7Z4Kq6OoEKfFWwWMl/dW\nIgJ3zbZ+wYXBctht3DUrmzdL6zh/Qcd8K//R4q0CgjGGtXsrmD82hZyRMVbHGZJ75oyi22VYX1xl\ndRQVwrR4q4Cwp/w8pxvauGdO8La6L5qancDkzBG8pKNOlB9p8VYB4eW9lURH2Fg6PcvqKD7xiTmj\n2HemkRP1rVZHUSFKi7eyXLfLzYaSGm6bmkl8lMPqOD6xYmY2IrC+uNrqKCpEafFWlnvveAPnLnSx\nYkZotLoBMhOjmZeXzCv7q3S6vPILLd7Kcuv2VzEi2sFN+aF1GuEVM7Mpq2vlSK0uUKx8T4u3slRH\nt4vNB2pYXJAZdNPh+7N0WiZ2m/DKPh11onxPi7ey1NtH62npdHLHzGyro/hcSnwU109IZV2xdp0o\n39PirSy1bn8VyXGRXDs+xeoofrFiZjZnzrWzv0LPpKx8ayCLMYwUkRdFpFREDovINf4MpkJfW5eT\nNw7XsWx6Jg57aLYjFhVkEGm3sW6/dp0o3xrIJ+YnwCZjzGRgJnDYP5FUuNhyuI72bhcrZoRel8lF\nCdERLMhPY31xFW63dp0o3/GqeItIInAj8CSAMabLGKNnnFdDsm5/FRkJUVyVl2x1FL9aMTOb2uZO\ndp06Z3UUFUK8bXmPBeqB/xGRvSLyhGdNS6UGpaWjm7eP1rNsehY2W3CeQdBbCyenEx1hY0OJTthR\nvuNt8XYAc4BfGmNmAxeAb/bdSERWi0iRiBTV19f7MKYKNW+W1tHldLM8RKbDX0lclIMFk9LZeKBG\nu06Uz3hbvCuACmPMDs/vL9JTzD/CGLPGGFNojClMSwutCRfKtzaUVJOREMWc3CSrowyLZTOyqGvp\nZHf5eaujqBDhVfE2xtQAZ0Qk33PVLcAhv6VSIe1Cp5O3jtSzdFrod5lctHByOpEO7TpRvjOQ0SZ/\nBzwtIsXALODf/RNJhbo3S+vodLpZOi3T6ijDJj7KwYJJaWws0a4T5RteF29jzD5Pl8gMY8xdxhj9\n/qcGZeOB
atJGRFEY4qNM+lo2PYua5g72ntGBWmroQnNmhApYbV1O3iytY0lBz3k/wsnCKelE2rXr\nRPmGFm81rN46Uk9Ht5tlYTDKpK+E6AhunJTKxpJqPdeJGjIt3mpYvVpSTWp8JPPGhleXyUVLp2VR\n1dTBPu06UUOkxVsNm45uF1tL67htavh1mVx069QMHDZh08Eaq6OoIKfFWw2bd46dpa3LFVajTPpK\njIng2gmpbDpQo10naki0eKths+lADQnRDuaPC83Tv3prSUEmpxvaKK3RFXbU4GnxVsOi2+Vmy+Fa\nbp2SQaQjvN92iwoyEOn5Y6bUYIX3p0gNmx0nztHU3s2SMO4yuSg1vudMipu131sNgRZvNSw2Hawm\nJsLOjZP0nDfQ03VSWtPCybMXrI6igpQWb+V3brdh88Fabp6cRnREaC0yPFgXv4Fo14kaLC3eyu/2\nlJ+nvqWTxQXaZXJR9sgYZo5K1CGDatC0eCu/23Sghki7jYWT062OElAWT8tk/5lGqhrbrY6igpAW\nb+VXxhg2HazhugkpjIiOsDpOQFni+SaiBy7VYGjxVn51qLqZivPtOsrkEsalxTMpI16LtxoULd7K\nrzYfqMEmcOuUDKujBKTFBZnsPHmOcxe6rI6igowWb+VXmw/WUpiXTEp8lNVRAtLigkzcBrYcrrU6\nigoyXhdvETklIiUisk9EivwZSoWGk2cvcKS2RUeZXEFBdgI5I2PYrEMG1QANtOV9szFmljGm0C9p\nVEi52Je7uEC7TC5HRFhckMk7ZWdp7XRaHUcFEe02UX6z+WAN03ISGJUUa3WUgLa4IIMup5u3j9Rb\nHUUFkYEUbwO8JiK7RWT1pTYQkdUiUiQiRfX1+kYMZ7XNHewtb2TxVO0y6U9hXjIpcZE6YUcNyECK\n9/XGmDnAUuAREbmx7wbGmDWeRYoL09L0HBbh7LVDPQfgFusQwX7ZbcKtUzLYWlpHp9NldRwVJAay\nenyl59864GVgnr9CqeC3+UAN41LjmJgeb3WUoLBkWiatnU7eK2uwOooKEl4VbxGJE5ERFy8Di4AD\n/gymgldTWzcfnGhgUUEmIuG53NlAXTshhfgoh07YUV7ztuWdAWwXkf3ATuBVY8wm/8VSweyN0lqc\nbqOzKgcgymFnQX4arx+qxeXW5dFU/xzebGSMOQHM9HMWFSI2HaghMyGaGTmJVkcJKkumZbK+uJqi\nU+e4OsyXilP906GCyqfau1xsO1bP4oIMbGG6QvxgLchPJ9JhY/NBnW2p+qfFW/nU20fr6eh266zK\nQYiPcnDDhFQ2H9SV5VX/tHgrn9p8sIaRsRHMG5tsdZSgtLggk8rGdg5WNVsdRQU4Ld7KZ7pdbt7w\nrBDvsOtbazBumZKOTfQc36p/+glTPvPBiQaaO5zaZTIEKfFRzBurK8ur/mnxVj6z6UANsZF2bpiY\nanWUoLa4IJOjta2cqG+1OooKYFq8lU+43YbXD9WyIF9XiB+qRZ5vLnquE3UlWryVT+wpP0+drhDv\nEzkjY5gxKlHP8a2uSIu38gldId63lkzLZH9FE5W6sry6DC3easiMMWw8UMP1E1N1hXgfubiy/Gva\ndaIuQ4u3GrKDVc1UNuoK8b40Li2e/IwRbNSuE3UZWrzVkG08UI3dJtymK8T71OJpmew6dY76lk6r\no6gApMVbDdmmAzXMH5dMUlyk1VFCytJpmRgDrx/Sc52oj9PirYbkWG0Lx+svfNhHq3xncuYIxqTE\n6pBBdUlavNWQbDxQgwg6RNAPRIQl0zJ5r+wsTW3dVsdRAWZAxVtE7CKyV0TW+yuQCi6bDtQwJzeJ\n9IRoq6OEpCUFmTjdhjdKtetEfdRAW95/Dxz2RxAVfE43XOBQdTNLdZSJ38wcNZKsxGg2lGjXifoo\nr4u3iIwClgNP+C+OCiYXC4oOEfQfm62n62TbsXpaOrTrRP3FQFrejwPfAN
x+yqKCzIaSamaOHsmo\npFiro4S05dOz6HK6ebO0zuooKoB4u3r87UCdMWZ3P9utFpEiESmqr6/3SUAVmMob2iipbGL5dG11\n+9uc3CQyEqJ4tbja6igqgHjb8r4OuENETgHPAgtF5I99NzLGrDHGFBpjCtPS0nwYUwWajQd6CsnS\naVkWJwl9NpuwdFoWbx2tp7XTaXUcFSC8Kt7GmG8ZY0YZY/KA+4A3jTEP+jWZCmgbSqqZMSqR0cna\nZTIclmnXiepDx3mrAas438b+iiaWTddW93ApHJNE+ogoNmjXifIYcPE2xrxljLndH2FUcNjoGWWy\nTLtMhk1P10kmW4/UcUG7ThTa8laD8GpJNdNyEshN0S6T4bR0ehad2nWiPLR4qwGpbGxn35lG7TKx\nwFV5yaTGR7GhRLtOlBZvNUCvFlcBcPv0bIuThB+7TVg+PZM3S+t01InS4q0GZt3+nok52mVijRUz\ns+l0utmip4kNe1q8lddOnr1ASWUTK2Zol4lV5uQmkZ0Yzbr9VVZHURbT4q28tt5TMJZr8baMzSbc\nPjObbcfq9TSxYU6Lt/LauuIq5uUlk5UYY3WUsHb7jCy6XYbNukhDWNPirbxypKaFo7WtrJiprW6r\nTc9JZExKLOuKtesknGnxVl5Zt78Km8ASnZhjORFhxYxs3i07y9lWXZw4XGnxVv0yxrCuuIprx6eS\nNiLK6jiKnlEnbgMbdcx32NLirfpVXNHE6YY2btcDlQEjP3MEkzLieUVHnYQtLd6qX3/eV0mk3aan\nfw0wd87KYdep85w512Z1FGUBLd7qipwuN+v2V3HLlHQSYyOsjqN6uWNmzyxXbX2HJy3e6oq2l53l\nbGsXd83OsTqK6mN0cizz8pJZu6cCY4zVcdQw0+KtrujPeytJjIlgQb6ujBSI7pqdw/H6CxysarY6\nihpm3q5hGS0iO0Vkv4gcFJHv+DuYst6FTiebD9ayfEYWUQ671XHUJSyfnkWk3cbaPZVWR1HDzNuW\ndyew0BgzE5gFLBGR+f6LpQLB5oM1tHe7uFu7TAJWYmwECyen88r+Kpwut9Vx1DDydg1LY4xp9fwa\n4fnRTrYQ9/LeSkYlxTA3N8nqKOoK7pqdw9nWTt493mB1FDWMvO7zFhG7iOwD6oDXjTE7/BdLWa2u\nuYN3y85y16wcbDaxOo66gpsnp5EQ7eDlPRVWR1HDyOvibYxxGWNmAaOAeSIyre82IrJaRIpEpKi+\nvt6XOdUw+/O+StwGHWUSBKIcdpbPyGbzwVpaOvRMg+FiMAsQNwJbgSWXuG2NMabQGFOYlqajE4KV\nMYbniyqYkzuSCenxVsdRXvhU4Sjau128qqvLhw1vR5ukichIz+UY4Dag1J/BlHX2nWmkrK6Vvyoc\nbXUU5aXZo3v+0L6wW7tOwoW3Le8sYKuIFAO76OnzXu+/WMpKL+yuIDrCposuBBER4VNzR7H79HnK\n6lr7v4MKet6ONik2xsw2xswwxkwzxnzX38GUNdq7XKzbV8Wy6VmMiNbp8MHk7jk52G3Ci9r6Dgs6\nw1J9xOaDNbR0OvnUXO0yCTbpI6K5OT+Nl/ZU6JjvMKDFW33EC7vPMDo5hqvHJlsdRQ3CJ+eOpr6l\nk23HdLRXqNPirT505lwb75Y18Km5o3Vsd5BaODmdlLhInt+lXSehTou3+tDzRWcQgU/MHWV1FDVI\nkQ4bd8/OYcvhWupaOqyOo/xIi7cCoNvl5tldZ1iYn07OSF0dPpitvDoXp9vwQpG2vkOZFm8FwJZD\ntdS3dPLA/Fyro6ghGp8Wz7XjU/jTjnJcbj0FUajS4q0A+OOO0+SMjOGmSelWR1E+8MDVY6hsbGfb\nUT1wGaq0eCtO1LfyblkDK+eNxq4HKkPCbVMzSI2P4ukdp62OovxEi7fimZ3lOGyi0+FDSKTDxr1X\njeLN0joqG9utjqP8QIt3mOvodvHC7g
oWFWSQnhBtdRzlQ/ddlYsBnt1ZbnUU5QdavMPcq8XVNLZ1\n8+DVY6yOonxsdHIsN+en8+yuM3Q5dcZlqNHiHcaMMfz23ZOMT4vjmvEpVsdRfrDqmjHUt3SyoURP\nFRtqtHiHsV2nznOwqpmHrh+LiB6oDEU3TUxjXFocT24/iTE6bDCUaPEOY09uP8HI2Ajuma0zKkOV\nzSY8dN1YSiqbKDp93uo4yoe0eIep8oY2XjtUy/3zcomJtFsdR/nRPXNySIyJ4Ml3TlodRfmQFu8w\n9dT7p7CL8Olr8qyOovwsNtLB/Vfn8tqhGs6ca7M6jvIRb5dBGy0iW0XkkIgcFJG/93cw5T8tHd08\nt+sMy2dkkZmowwPDwaevGYOI8Lv3TlkdRfmIty1vJ/A1Y8xUYD7wiIhM9V8s5U/PF1XQ2unkoevG\nWh1FDZOsxBiWTc/iuV1naNYV5kOCt8ugVRtj9ngutwCHgRx/BlP+0el08ZttJ5g3NpmZo0daHUcN\no9U3jKO108kf3tcp86FgwH3eIpIHzAZ2XOK21SJSJCJF9fV6QpxA9PKeSmqaO/jSzROsjqKG2fRR\nidw4KY3fbj9Je5fL6jhqiAZUvEUkHngJ+LIxprnv7caYNcaYQmNMYVpamq8yKh9xutz88u3jTM9J\n5IaJqVbHURZ4ZMF4Gi508dwunTIf7Lwu3iISQU/hftoYs9Z/kZS/vFpSzemGNh65eYJOyglTV49L\n4aq8JH697YROmQ9y3o42EeBJ4LAx5kf+jaT8we02/GLrcSamx7NoaobVcZSFvnjzBKqbOvjz3kqr\no6gh8LblfR2wClgoIvs8P8v8mEv52BuldRypbeGLN4/XxYXD3IJJaRRkJ/DLt4/rSjtBzNvRJtuN\nMWKMmWGMmeX52eDvcMo33G7D41uOkpscy4oZ2VbHURYTEb508wROnr2gre8gpjMsw8CGA9UcrGrm\nK7dNxGHXl1zB4oJMpuUk8OMtR7XvO0jpJznEOV1ufvTaUSZlxHPHTB2ar3rYbMLXF+VTcb6dZ3Xk\nSVDS4h3iXtpTwYmzF/j6onxdn1J9xE2T0pg3NpmfvlFGW5fT6jhqgLR4h7CObhc/2XKMWaNHcpuO\nMFF9iAjfWJzP2dZOnnpPZ10GGy3eIezpHeVUNXXwjcX5Oq5bXVJhXjILJ6fzq7eP09Sm5zwJJlq8\nQ9S5C138ZMtRbpiYyrUTdDaluryvL8qnpaObx984anUUNQBavEPUj14/woUuF/98u578UV3Z1OwE\n7puXy+/fP82x2har4ygvafEOQYeqmvnTjnJWzR/DpIwRVsdRQeBrt00iLtLOd9cf0rUug4QW7xBj\njOE76w6SGBPBV26dZHUcFSRS4qP48q2TeOfYWV4/VGt1HOUFLd4hZkNJDTtOnuNri/JJjI2wOo4K\nIquuGcPE9Hi+9+phOrr1lLGBTot3CGnu6Ob/rT/ElKwEVs7LtTqOCjIRdhv/uqKA8nNt/GJrmdVx\nVD+0eIeQ728opa6lgx/cM10n5KhBuX5iKnfPzuEXbx3ncPXHTtmvAogW7xDx3vGzPLOznIdvGKfL\nm6kh+Zfbp5IYE8E/vlSM06XnPQlUWrxDQHuXi2++VEJeSqwepFRDlhQXyXfuLKC4oonfvnvS6jjq\nMrR4h4AfvnaE8nNt/OATM4iJtFsdR4WA5dOzuG1qBv/12lFO1LdaHUddwkCWQfutiNSJyAF/BlID\ns+1oPU9uP8mq+WOYPy7F6jgqRIgI37trGjGRdh59di+dTh19EmgG0vL+HbDETznUINS1dPDV5/eR\nnzGCby+fYnUcFWIyEqJ57BMzOFDZzGObjlgdR/XhdfE2xmwDzvkxixoAt9vwtef309rp5Gf3zyY6\nQrtLlO8tKsjkM9eM4cntJ3mzVCfvBBLt8w5Sa945wTvHzvKvKwp0Crzyq28tm8KUrAS+/kIxNU0d\nVs
dRHj4t3iKyWkSKRKSovr7elw+tetl2tJ7HNpWyfEYW91012uo4KsRFR9j57/tn09Ht4vN/KNLZ\nlwHCp8XbGLPGGFNojClMS0vz5UMrj7K6Vh750x7yMxN47BMz9DzdaliMT4vnx/fOYn9FE994sVhP\nXhUAtNskiDS2dfHwU7uIctj4zafnEhflsDqSCiOLCzL5h8X5vLK/ip/r9HnLDWSo4DPA+0C+iFSI\nyOf8F0v11el08bd/3ENVYwe/XjWXUUmxVkdSYeiLC8Zz9+wcfvjaUdYXV1kdJ6x53XQzxqz0ZxB1\ned0uN1/6017eP9HAj++dydwxyVZHUmFKRPj+PdOpON/GV57bR1ykg5snp1sdKyxpt0mAc3mGBL5+\nqJbv3lnA3bNHWR1JhbnoCDtP/vVV5GeO4At/3M17x89aHSksafEOYG634f+sLeGV/VV8c+lkPn1N\nntWRlAIgITqC3z90NbnJsTz8VBG7T+sUkOGmxTtAdTpdPPrsXp4rOsOjCyfwhZvGWx1JqY9Ijovk\n6YevJn1EFA8+sZOtR+qsjhRWtHgHoNZOJ5/7XRHri6v51tLJfHVRvtWRlLqk9IRoXvjCtYxNjeNv\nniriz3srrY4UNrR4B5jqpnZWrvmA90808MNPzeTz2uJWAS5tRBTPfn4+hXlJfPm5ffzq7eM6DnwY\naPEOIO8fb2DFz7Zzor6VNavm8sm5enBSBYeE6Ah+99l53D4jix9sLOWRP+2htdNpdayQprM8AoDb\nbXhi+wn+Y9MR8lJieXb1NUxIj7c6llIDEh1h52crZzNjVCI/2FjKkZoWfr1qLhPS9dw7/qAtb4ud\nOdfGA0/s4N83lLJoagb/+6XrtXCroCUirL5xPH98+Goa27pZ/tPt/GbbCVxu7UbxNS3eFnG7DX/4\n4DSLH99GSWUT379nOr94YA7xOuVdhYBrx6ey8cs3cOOkNP5tw2H+6tfvc1xX5PEp8deBhcLCQlNU\nVOSXxw52O0408L1XD1NS2cT1E1L5j0/OIGdkjNWxlPI5Ywx/3lfJ/33lEG1dTj59TR6PLpxIYmyE\n1dEClojsNsYU9redNvOG0dHaFn702lE2HawhKzGaH987k7tm5eiZAVXIEhHunj2K6yek8V+vHeG3\n755k7Z4K/m7hRFbOy9U1V4dAW97DYP+ZRn6+tYzXDtUSG2nnCzeN529uGKdvXBV2DlU1871XD/He\n8QZS4iJ56PqxrLpmDAnR2hK/yNuWtxZvP2nrcrK+uJpndpazt7yRxJgI/vraPP762jyS4iKtjqeU\npXaePMeBHCOWAAAJlUlEQVTPt5bx9tF64iLt3DErm5Xzcpmekxj230S1eFug0+ni3bKzbCipYfOB\nGlo6nYxPi+P+q8dw71Wj9WCkUn0cqGziqfdOsa64io5uN1OyErh9RhZLp2UyLi08R11p8R4mZ861\nsb3sLNuPnWXbsXpaOpyMiHawaGom980bTeGYpLBvSSjVn+aObv53XxVr91Swt7wRgPyMEdw4KZXr\nJ6YxLy85bLoZfV68RWQJ8BPADjxhjPnBlbYPxeLd1N7NkZoWSiqb2Ft+nr3ljVQ2tgOQmRDNjZNS\nWTo9i+vGpxLp0FGYSg1GdVM7mw7U8NrBWnafPk+Xy02k3cbU7ARm545kdm4SU7MSyEuJxWEPvc+Z\nT4u3iNiBo8BtQAWwC1hpjDl0ufsEY/E2xtDc4aSqsZ3qpnZON7Rx6uwFTjW0UVbX+mGhBsgZGcOs\n3JEUjknihompjE+L1xa2Uj7W3uVix8kG3j/RwN7TjRRXNtLR7QYg0mFjYno849LiGZsSy5iUOEYl\nxZA9MoaMhOigbUD5eqjgPKDMGHPC8+DPAncCly3ew8kYQ5fLTZfTTbfL0NHtotPppqPbRVuXi/Yu\nF21dTlo7e35aOpw0tXdz/kIXje3dNLR2Ut/aydmWLtr7rIwdF2kn
LzWOuWOSeGB+LlMyE5ianUBG\nQrRF/1ulwkdMpJ0F+eksyO9Zrafb5eZobQul1S2U1jRTWtPCvjPnebW4it6TOEUgOTaS1PgoUkdE\nkhTb8zMyNoKE6Ajiox3ER/X8xETaiY20ExNhJzrCTpTDRpTDToRDiLTbsNskIBtm3hbvHOBMr98r\ngKt9Hwf+7dVDvFlahzHgNgaXMbjdnstug9sYnG6D02Xodrlxus2gpt5GR9gYGdPzYqbERzInN4nU\n+CgyE6LJHhlD1shoRifFkhofGZAvnFLhKMJuoyA7kYLsxI9c3+l0UXG+vedbc2MHlY3tngZZJ2db\nO6lqbKaxrYum9m4GWi5EIMLWU8QddsFhE+w2Gw6bYBOw2QSbiKfIg02En62czZSsBB/+zz/Op8Mf\nRGQ1sBogNzd3UI+RlRjD5KwEbOLZMSIfXrbbenbQxZ0XYe/ZmXabjSiHjUh7z3XRvf6C9vxVdRAb\naSc+ysGIaAdxUQ6iI8Lj4IdS4SDKYWd8Wjzj+xmh4nYb2rpdtHR009Lh5EKn0/PN3EWH00VHt/vD\nb+7dH36b7/lG73Jf/LenUel0uXGbnsd0GfNhg9MYiBmG+uJt8a4ERvf6fZTnuo8wxqwB1kBPn/dg\nAj10/VgeYuxg7qqUUldks8mH3SVZif1vH8i87dHfBUwUkbEiEgncB7ziv1hKKaWuxKuWtzHGKSJf\nAjbTM1Twt8aYg35NppRS6rK87vM2xmwANvgxi1JKKS8F50BIpZQKc1q8lVIqCGnxVkqpIKTFWyml\ngpAWb6WUCkJ+OyWsiLQAR/zy4EOXCpy1OsRlaLbBCdRsgZoLNNtg+TvbGGNMWn8b+XN1gCPenBnL\nCiJSpNkGTrMNXKDmAs02WIGSTbtNlFIqCGnxVkqpIOTP4r3Gj489VJptcDTbwAVqLtBsgxUQ2fx2\nwFIppZT/aLeJUkoFoSEVbxFJFpHXReSY59+ky2y3SUQaRWR9n+vHisgOESkTkec8p5v1iQFk+4xn\nm2Mi8ple178lIkdEZJ/nJ90HmZZ4HrNMRL55idujPPuhzLNf8nrd9i3P9UdEZPFQs/gil4jkiUh7\nr330K1/m8jLbjSKyR0ScIvLJPrdd8rUNkGyuXvvN56dX9iLbV0XkkIgUi8gbIjKm121+229DzGX1\nPvuCiJR4nn+7iEztdZvfPp+XZYwZ9A/wGPBNz+VvAv9xme1uAVYA6/tc/zxwn+fyr4C/HUqegWYD\nkoETnn+TPJeTPLe9BRT6MI8dOA6MAyKB/cDUPtt8EfiV5/J9wHOey1M920cBYz2PYw+AXHnAAV/t\no0FmywNmAL8HPunNa2t1Ns9trRbvt5uBWM/lv+31mvptvw0lV4Dss4Rel+8ANnku++3zeaWfoXab\n3Ak85bn8FHDXpTYyxrwBtPS+TkQEWAi82N/9/ZhtMfC6MeacMeY88DqwxIcZevtwEWdjTBdwcRHn\ny2V+EbjFs5/uBJ41xnQaY04CZZ7HszqXv/WbzRhzyhhTDLj73Nffr+1QsvmbN9m2GmPaPL9+QM/q\nWODf/TaUXP7mTbbmXr/GARcPGPrz83lZQy3eGcaYas/lGiBjAPdNARqNMU7P7xX0LHTsK95ku9TC\nyr0z/I/nK9I/+6BY9fdcH9nGs1+a6NlP3tzXilwAY0Vkr4i8LSI3+CjTQLL5477D8fjRIlIkIh+I\niC8bLTDwbJ8DNg7yvsOVCwJgn4nIIyJynJ5v9o8O5L6+1u8MSxHZAmRe4qZv9/7FGGNEZFiHrvg5\n2wPGmEoRGQG8BKyi5+uv+otqINcY0yAic4E/i0hBnxaKurQxnvfXOOBNESkxxhwf7hAi8iBQCNw0\n3M99JZfJZfk+M8b8HPi5iNwP/BPg82Mp3uq3eBtjbr3cbSJSKyJZxphqEckC6gbw3A3ASBFxeFpz\nl1zU2M/ZKoEFvX4fRU9fN8aY
Ss+/LSLyJ3q+Bg2leHuziPPFbSpExAEk0rOfvFoAerhzmZ4Ov04A\nY8xuT4tkElA0jNmudN8Ffe77lk9S/eXxB/2a9Hp/nRCRt4DZ9PSVDls2EbmVnobOTcaYzl73XdDn\nvm8FQK6A2Ge9PAv8cpD39Y0hdvL/Jx89KPjYFbZdwMcPWL7ARw9YftFXnfneZKPnoMxJeg7MJHku\nJ9PzRy3Vs00EPf28XxhiHgc9B3/G8pcDIgV9tnmEjx4YfN5zuYCPHhA5ge8OWA4lV9rFHPQc6KkE\nkn34Gvabrde2v+PjByw/9toGSLYkIMpzORU4Rp+DY8Pwml4sfBO9+UwEQK5A2GcTe11eARR5Lvvt\n83nFzEP8D6cAb3h25JaLLzI9X3ee6LXdO0A90E5Pf9Biz/XjgJ30dPC/cPHF8dGL4W22hzzPXwZ8\n1nNdHLAbKAYOAj/xxYsBLAOOet6c3/Zc913gDs/laM9+KPPsl3G97vttz/2OAEt9+iYYZC7gE579\nsw/YA6zw+Ru0/2xXed5TF+j5lnLwSq9tIGQDrgVKPB/4EuBzFmTbAtR6Xrt9wCvDsd8GmytA9tlP\ner3ft9KruPvz83m5H51hqZRSQUhnWCqlVBDS4q2UUkFIi7dSSgUhLd5KKRWEtHgrpVQQ0uKtlFJB\nSIu3UkoFIS3eSikVhP4/6wFSm8ylAfcAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10e316c18>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def show_weight_dist(mean, variance):\n",
" sigma = nd.sqrt(variance)\n",
" x = np.linspace(mean.asscalar() - 4*sigma.asscalar(), mean.asscalar() + 4*sigma.asscalar(), 100)\n",
" plt.plot(x, gaussian(nd.array(x, ctx=ctx), mean, sigma).asnumpy())\n",
" plt.show()\n",
" \n",
"mu = mus[0][0][0]\n",
"var = softplus(rhos[0][0][0]) ** 2\n",
"\n",
"show_weight_dist(mu, var)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Great! We have obtained a fully functional Bayesian neural network. However, the number of parameters is now twice as high as in a traditional neural network, since each weight is represented by a mean and a variance parameter. As we will see in the final section of this notebook, we can drastically reduce the number of weights our network uses for prediction with _weight pruning_."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weight pruning\n",
"\n",
    "To measure the degree of redundancy in the trained network and to reduce its parameter count, we now examine the effect of setting some of the weights to $0$ and evaluating the test accuracy afterwards. We do this by ordering the weights according to their signal-to-noise ratio, $\\frac{|\\mu_i|}{\\sigma_i}$, and setting a given percentage of the weights with the lowest ratios to $0$."
]
},
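  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this criterion (on a toy NumPy vector with made-up values, not the trained network), ordering by the signal-to-noise ratio and zeroing the lowest-ratio weights looks as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# toy posterior means and standard deviations (illustrative values only)\n",
    "toy_mus = np.array([0.8, -0.05, 1.2, 0.01])\n",
    "toy_sigmas = np.array([0.1, 0.2, 0.15, 0.3])\n",
    "\n",
    "snr = np.abs(toy_mus) / toy_sigmas  # signal-to-noise ratio per weight\n",
    "order = np.argsort(snr)             # ascending: least informative weights first\n",
    "pruned = toy_mus.copy()\n",
    "pruned[order[:2]] = 0.0             # prune the 50% of weights with the lowest ratio"
   ]
  },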
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We can calculate the signal-to-noise ratio as follows:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def signal_to_noise_ratio(mus, sigmas):\n",
" sign_to_noise = []\n",
" for j in range(len(mus)):\n",
" sign_to_noise.extend([nd.abs(mus[j]) / sigmas[j]])\n",
" return sign_to_noise"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We further introduce a few helper methods which turn our nested list of weight matrices into a single vector containing all weights. This will simplify the subsequent pruning steps."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def vectorize_matrices_in_vector(vec):\n",
" for i in range(0, (num_layers + 1) * 2, 2):\n",
" if i == 0:\n",
" vec[i] = nd.reshape(vec[i], num_inputs * num_hidden)\n",
" elif i == num_layers * 2:\n",
" vec[i] = nd.reshape(vec[i], num_hidden * num_outputs)\n",
" else:\n",
" vec[i] = nd.reshape(vec[i], num_hidden * num_hidden)\n",
" \n",
" return vec\n",
"\n",
    "def concat_vectors_in_vector(vec):\n",
    "    concat_vec = vec[0]\n",
    "    for i in range(1, len(vec)):\n",
    "        concat_vec = nd.concat(concat_vec, vec[i], dim=0)\n",
    "    \n",
    "    return concat_vec\n",
    "\n",
    "def transform_vector_structure(vec):\n",
    "    vec = vectorize_matrices_in_vector(vec)\n",
    "    vec = concat_vectors_in_vector(vec)\n",
" \n",
" return vec"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition, we also have a helper method which transforms the pruned weight vector back to the original layered structure."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from functools import reduce\n",
"import operator\n",
"\n",
"def prod(iterable):\n",
" return reduce(operator.mul, iterable, 1)\n",
"\n",
"def restore_weight_structure(vec):\n",
" pruned_weights = []\n",
" \n",
" index = 0\n",
" \n",
" for shape in layer_param_shapes:\n",
" incr = prod(shape)\n",
" pruned_weights.extend([nd.reshape(vec[index : index + incr], shape)])\n",
" index += incr\n",
" \n",
" return pruned_weights"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "The actual pruning of the vector happens in the following function. Note that it accepts an ordered list of percentages so that the performance can be evaluated at several pruning rates. At each rate, pruning means taking the indices of the weights with the lowest signal-to-noise ratios and setting the weights at these indices to $0$."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
    "def prune_weights(sign_to_noise_vec, prediction_vector, percentages):\n",
    "    pruning_indices = nd.argsort(sign_to_noise_vec, axis=0)\n",
    "    \n",
    "    for percentage in percentages:\n",
    "        # start from a fresh copy of the means for each pruning rate\n",
    "        pruned_vector = prediction_vector.copy()\n",
    "        pruning_indices_percent = pruning_indices[0:int(len(pruning_indices) * percentage)]\n",
    "        for pr_ind in pruning_indices_percent:\n",
    "            pruned_vector[int(pr_ind.asscalar())] = 0\n",
    "        pruned_weights = restore_weight_structure(pruned_vector)\n",
    "        test_accuracy = evaluate_accuracy(test_data, net, pruned_weights)\n",
    "        print(\"%s --> %s\" % (percentage, test_accuracy))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting the above functions together:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.1 --> 0.9777\n",
"0.25 --> 0.9779\n",
"0.5 --> 0.9756\n",
"0.75 --> 0.9602\n",
"0.95 --> 0.7259\n",
"0.99 --> 0.3753\n",
"1.0 --> 0.098\n"
]
}
],
"source": [
"sign_to_noise = signal_to_noise_ratio(mus, sigmas)\n",
"sign_to_noise_vec = transform_vector_structure(sign_to_noise)\n",
"\n",
"mus_copy = mus.copy()\n",
"mus_copy_vec = transform_vector_structure(mus_copy)\n",
"\n",
"prune_weights(sign_to_noise_vec, mus_copy_vec, [0.1, 0.25, 0.5, 0.75, 0.95, 0.99, 1.0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Depending on the number of units in the original network and the number of training epochs, the highest pruning percentage achievable without significantly reducing predictive performance can vary. The paper, for example, reports almost no change in test accuracy when pruning 95% of the weights in a 2x1200-unit Bayesian neural network. The result is a significantly sparser network, enabling faster predictions and reduced memory requirements."
]
},
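  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a rough feel for the memory side of this claim, here is a back-of-the-envelope sketch (plain NumPy arithmetic with hypothetical sizes, not measured numbers) comparing dense float32 storage against storing only the surviving nonzero weights together with their indices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# hypothetical parameter count roughly in the ballpark of a 2x1200-unit MLP's first layers\n",
    "n_weights = 2 * 1200 * 784\n",
    "pruned_fraction = 0.95\n",
    "\n",
    "dense_bytes = n_weights * 4                  # float32 weights stored densely\n",
    "nonzero = int(n_weights * (1 - pruned_fraction))\n",
    "sparse_bytes = nonzero * (4 + 4)             # float32 value + int32 index per surviving weight\n",
    "\n",
    "print(dense_bytes // sparse_bytes)           # dense storage is roughly 10x larger at 95% pruning"
   ]
  },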
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
    "We have taken a look at an efficient Bayesian treatment for neural networks using variational inference via the \"Bayes by Backprop\" algorithm (introduced by the \"[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424)\" paper). We have implemented a stochastic version of the variational lower bound and optimized it in order to find an approximation to the posterior distribution over the weights of an MLP on the MNIST data set. As a result, we achieve regularization of the network's parameters and can accurately quantify our uncertainty about the weights. Finally, we saw that it is possible to significantly reduce the number of weights in the neural network after training while still keeping a high accuracy on the test set.\n",
"\n",
"We also note that, given this model implementation, we were able to reproduce the paper's results on the MNIST data set, achieving a comparable test accuracy for all documented instances of the MNIST classification problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}