-
Notifications
You must be signed in to change notification settings - Fork 205
/
multi-task-learning-example.ipynb
309 lines (309 loc) · 22.4 KB
/
multi-task-learning-example.ipynb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import sys\n",
"import numpy as np\n",
"np.random.seed(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi loss layer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Couldn't import dot_parser, loading of dot files will not be possible.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from keras.layers import Input, Dense, Lambda, Layer\n",
"from keras.initializers import Constant\n",
"from keras.models import Model\n",
"from keras import backend as K\n",
"\n",
"# Custom loss layer\n",
"class CustomMultiLossLayer(Layer):\n",
" def __init__(self, nb_outputs=2, **kwargs):\n",
" self.nb_outputs = nb_outputs\n",
" self.is_placeholder = True\n",
" super(CustomMultiLossLayer, self).__init__(**kwargs)\n",
" \n",
" def build(self, input_shape=None):\n",
" # initialise log_vars\n",
" self.log_vars = []\n",
" for i in range(self.nb_outputs):\n",
" self.log_vars += [self.add_weight(name='log_var' + str(i), shape=(1,),\n",
" initializer=Constant(0.), trainable=True)]\n",
" super(CustomMultiLossLayer, self).build(input_shape)\n",
"\n",
" def multi_loss(self, ys_true, ys_pred):\n",
" assert len(ys_true) == self.nb_outputs and len(ys_pred) == self.nb_outputs\n",
" loss = 0\n",
" for y_true, y_pred, log_var in zip(ys_true, ys_pred, self.log_vars):\n",
" precision = K.exp(-log_var[0])\n",
" loss += K.sum(precision * (y_true - y_pred)**2. + log_var[0], -1)\n",
" return K.mean(loss)\n",
"\n",
" def call(self, inputs):\n",
" ys_true = inputs[:self.nb_outputs]\n",
" ys_pred = inputs[self.nb_outputs:]\n",
" loss = self.multi_loss(ys_true, ys_pred)\n",
" self.add_loss(loss, inputs=inputs)\n",
" # We won't actually use the output.\n",
" return K.concatenate(inputs, -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate on synthetic data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"N = 100\n",
"nb_epoch = 2000\n",
"batch_size = 20\n",
"nb_features = 1024\n",
"Q = 1\n",
"D1 = 1 # first output\n",
"D2 = 1 # second output"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def gen_data(N):\n",
" X = np.random.randn(N, Q)\n",
" w1 = 2.\n",
" b1 = 8.\n",
" sigma1 = 1e1 # ground truth\n",
" Y1 = X.dot(w1) + b1 + sigma1 * np.random.randn(N, D1)\n",
" w2 = 3\n",
" b2 = 3.\n",
" sigma2 = 1e0 # ground truth\n",
" Y2 = X.dot(w2) + b2 + sigma2 * np.random.randn(N, D2)\n",
" return X, Y1, Y2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAANUAAAB0CAYAAAASAHfIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAExxJREFUeJztnX9sZNV1xz9nbK9hvRuHeGlKF7wGCUVK0pREUSpRVyHh\nR2lMtyhSUXYh20ALWgEClKbgNWoZ/ohxQoW6G1ihXQoCsQ6QJk2sXSoCNJWy2pKSIJQUaOiW2MuG\nbYMXavYHsndmTv94b+znmXffj5n33rwZ349k2fN+3je+33fOPffce0VVsVgsyVFodQEslk7Dispi\nSRgrKoslYayoLJaEsaKyWBLGispiSRgrKoslYayoLJaEsaKyWBKmu9UF8CIibwMzrS6HxWJgg6qe\nGXZQrkQFzKjqp1tdiDwyNLpvMzAODAKHgLHpiZHJ1pZqZSEiP410XJ5y/0Tkp1ZU9biC2g2s9mw+\nCVzfiLA6RaBZP0fU+mnbVO3BOMsFhft5PO6FPALdAIj7e7e7vW3I83NYUbUHgzG3B5GYQFtMbp/D\niqo9OBRzexBJCrSV5PY5rKjagzGcNpSXk+72uCQp0FaS2+ewomoD3Mb39TjdDer+DgxSDI3u2zw0\num96aHRfxf1dbWvsNZzyy0QLnT5JvmgSxUb/2hRT5Mvdvh1YV3PKSRxhjuM06mtR4Jp2igLmNfpn\nRZUhSVWCgBD7I8C11Dfgq8y49xbT/umJkaE0KmsnhPGtqHJGkn1NQ6P7pvG3NmWgK+BUxanQfudW\n91+TVDmrJN3P1ipsP1X+SDIEbIpwBQkKXAuBIx7T/jRC1bkNf6eBFVV2JBkCNkW4ygHnnGTJ5dpJ\nvbCqjfw0QtW5DX+ngRVVdiQZAjZFvh702Q4wi8fVmp4YuRnHzfOLJhrL00S2Qm7D32mQt4TaTmYM\n/3ZF7BCwG+UD/+jfAb/tftcA/NozY8Dj1AczxL1uI22gxJ49KdIMnNhARYa0SwRsaHRfBf8IoU5P\njDTk3eTp2RsNnEStn9ZSZUiAdcgbFfyDHpVGL5izZw8KnDRdxqZFJSLnAI8BH8bxz3ep6nYR+RDw\nJDAETANXqeq7zd7Pkgkma+S7PS0rlKJ1SzVwkoSlKgF/paovicha4Gci8izwFeB5VZ0QkVFgFLgj\ngft1JHlyjzD3ZdUFFnxcqeoQDJop/9DovvuBG1lyQ32v2+D3Fvn5GqFpUanqEeCI+/cxEXkNWA/8\nKXCRe9ijwL9iReVLWhXTcJ8oFTBOYCGWKxWlDO4xXkH5XreJ7y3VwEmiIXURGQI+CfwE+LArOID/\nwXEPLf6k3jkaZ1BfzATeyK5UjDKMY06lGqw5zu97225IJgYaS1COQ2KBChFZA3wXuE1V3xNZ+k5U\nVUXEN8woIjcAN7gfa5NAOxKft3UWnaOxLEqUwEJIv5WfKxW1DEHPfTTCcetYqku+1ivNwEkilkpE\nenAEtUdVv+du/l8ROcvdfxbwG79zVXWXqn7aDVXOJlGePGN4W5tIsnM0DeGaLIri70pFLUPQc3/A\nI+ao30+mKVFJRP8E+AfgNVW9z7NrCvhzYML9/YNm79Uh+L2tBacieiuor4/fREAjtHHewLWNgjSc\nZyqDuEnC1fuNAQ8DvT7HrmLJsvm1jWKXNWmSsFR/AHwZ+LyIvOz+fAFHTJeKyH8Bl7ifLeZ/ruDk\n7hl9fIOVe9yNlIWxF3O+X6MTqcRNP/JLr6qyeD/3uY8F3HcQjG0jk7eTWUpU06JS1f2qKqr6CVW9\nwP15WlWPqurFqnq+ql6iqu8kUeAOIOif2wW8j9lCmKzcjUGV3913LfWuWhlHlNM4AxvjBktijb6t\nEYEf3vsNBNx38TucnhiZnJ4YGZqeGClMT4wMAbfGKVMa2ITa7Al6W0NwRQ6yckGV30+MAGtZskqm\nIFGYixcrilYVAebhJ4Mh1tHUXmu4TEljc/9agFtpHsM8/sk3xy5gcKLxHPc8Uy5fFFIZZh/wLFUr\nZnrOB9ws+8yxgxRzjFtBg777iqGPJWyAoYlm2hMCPJzCJJVjwHzNtnmCx3RpqwQVByuq1hHWtqoL\nFkQYYGgizOUMo5eIIemAWZz88BteAm0+/sqKqnVErejL2lghAwx98WlnzFJvJcLaAaEh6ZgRxHGc\n8LiXarg8t9OPRcGKqkUYGtTGxntC9xvDedsPAD2e3bP4W0AvUaxEnHQrY0dwHoINzWDHU7WQ2lSZ\ngMZ7bQdt7CRSn/O8rtdq4ID7900+p1fbOmHEydoI7IxOffxVsb+uo5viXCL3s5YqX0RxexpNvjWF\n1RfPd13Lq1negToLXBcjayPq9ta5eI6g6txUd3vTWEvVIGmMfwqae8JzWKM5fJH2N2khIg+piPis\n8Sj23w9sxQn0lIEHKc75RQtTHflr+6kaoJWTQwb177idqnHPi3R+VDIfbLnkxpme7RiwdZlrV+w3\nzsFBcc7ovdk5KtIl1TddCIHWIKBSByWpKk4mwzRNiiDU0rkiUGVwjj4tK4Uz5ATv0zvbJ/O31lT+\n4HaPY5n8BjN6WYvj2uE519SeO+qzLTbWUjVAGrMNxbx/0OIEvhbU/fsR6sPYtUS2uLGskiMQv4UT\nFilpYaFbKtdSnJuk2P8McFnNIQtAdf9m/KdSMzFDcW7IUxa/F8zS9X3oyLnU8zKPQ6MuWNo0Mcd6\nLaHPEegCn7YZlv5PFZYCYlEEMIOTUe8XhQQ4RnHuAxT7pwl2Z2tZ7toV+9/GX+BL4quh49y/rOZx\n8NwrSLy5mxzSJfYc6xsL+7m9+yl+R2b5P9agCmfI8Q0Umf526aK920o3XIH/9zC+sbB/9e3dT7Fe\nFoOFqyuwB2cyoGrdiiPm6jNsDdi/1rU0cfvuaiOQpiz4pvsE2ymknvo8DrA4i8/jBGQFJNE5GTOd\nJ9L1wo7ZWNjP/lW38EbvZvavuoW7ux9mouchzi7MUhD4kBxnoHCcgmNPNlzZdeCmjYX9vt/DxsL+\nweq5Iiz+dDnnNvOyPkq4EMeJl7Lk98JLLRWqbSwVGczjEHUWH2gu9JyS1Q2aLIWNhf1M9DzEalkA\n4GyZ5cvyXFVAvpwuC2zv2cl9PMie8ue5q3Sd8z0U+7mvpyDd0vDcmmGEuauD+C/5oziZIQcI6dj9\ndumivVd2HbjpdPf7AHhfV/H98oV7NzVZ+HYSVSJztYW4dkEVs2WTsEQksHy3dz+1KKgqQYKqIgLd\nVNjS9RznyhH+sfzZQWB3ioIawBGGqU0FcMgNVoBZPIHf47bSDVf8W+Wjrut7lLd0gG+WrmKqMnzF\nJmgqE76dRNV0OyaChQiqmM3M5VBLGlb3ELDB20Z6S9fxfOUCvbjwsnjaPQ0hAn9YeIVPFV6HaHNC\nNMohinM3u4Lx8xqW/ueOgBp+CU1VhplaGK7b3uD1FmmbNlVCSZZh7TKT1VscbdrgXA61NO3PV9tk\nt4xtqxz+2/P0jd7NG/6j91q29+xcbCOdXZhlS9dzUm33NIsI9HEqgSsFUhXMzW607mpq/ucJ5ejZ\nNhUkkmQZZiH8rKECO2tcxEDXLe3o4c//5vee+e/emcsKblJ5VTBr6kZzkIiYEr7eLI6L53el2TrB\nNGeNgkgtgts2liohAt9OBmt4Tc1o00BhRrFkfve5p3vXI9OnbR6n2F85cddvvf3Xd37t7WWRwWL/\nZor90xT7K79bmL6sS3Qx4pY1Pl2binmclpeTOBOzmAZa3ppMCcNJc3hJW1mqBAh9O0WwhmEBk0hB\nCLeDtI8l4S22H/pkft3d3Y/ycfkVFxde3rBeZveoLgmoFULyMkdf5YOceBPXEnv7s/6s60dHi92P\nHeuT+QGW0n4GWB5ImKTYHxqhS5u0hpeknlEhIpfjpKd0AQ+pqnH+vyzSlJoNMoRlEhyuDOxxoknr\n+GbpKoDFwEFBmMER8IWE56zhFVJeWNBuvnbqBt0xfk8B8rXyfNoZN7lIUxKRLuB14FLgMPAisElV\nXzUcn9/cP09y5wntPVosbeE75c8tvoFdy7Msn6z61dYIYwFn1G2mcmlUoN7q8S5rKJ7awlRleDGN\nqZGUrTQqfxbizkua0meAg6r6hluoJ3CW2PEVVW5ZGtS2Ghz37N6e3Xpvz27vUdupSdA0VOKwhNZU\nmKegpzkLIUaW1kldxeipv2SqsizsXNuYj9U9kGK6WStHDiwjbVGtB970fD4M/L73gFyv+hE8Vse7\nGNnjqkjeXLUqVXHsWLXzGgLmG1zQLlZJeRYYeFfXHC+e2tI3VRn2BrOq7mvVQlUTZv2ud8jPIpFe\n5c9i5ZRItDxQoaq7gF3gmNfUbxg+Rid0iIIPuZSTKryja7i7tIWpyvChHcV7Jin2U9LCI91SWeU9\n7l3t00fKf/zst8pf/AhOYGEN5gUTvJbGT1AncTLN6ywScLqhuM1W/lRXR4xD2qL6NXCO5/PZ7rbW\nUOPGsTQ3AZ4xOlHGHNWRlJUKavtU1Knl3v21x6s68eFqoMR13ZZlIXQX+zmhvdtPZ37dWzrA9vIX\nZ79T/tyTOPOtB85j4fm7ljJOF02YRTLl9TVb+XMzciDtfqoXgfNF5FwRWQV8CWeJnVYRllExThNt\nnmZiPqpwXE/jx5WP+V5lXru47dSN3HrqRg5X1lFROKG9s7+oDP3wiJ5Rrigc0TPKT5QveuC8+cmr\nhxd2zExVhv37X4pzk313/+ZM57hvzbgBl62Epx8NYrYoBc8iAWAe69RFChO+5Glas1QtlaqWRORm\n4BmcL/NhVX0lzXuGEOZ3N+WCvKNrEOAMjgdaEy+q8OvlVuXQdNfmMWC7quOCvqN9fL10zbGpyvA8\nMDC1MFy1BuC8natv/i4ca3OggUGGUcY9Va2J0c3yXNdEGafyt3ywaVqk3qZS1aeBp9O+T0TC/G7T\nfj+WLdL2vvbo3aUtMlUZ9gz8c7Kfn69cwJ90vVAntvd1FXcsj645b+ziXKSZhtxgQaONfj+rHcQC\n/m2qpXJHu24hjU7XLAexhtFeaUqeVB33d9yBfWFzzY3hVJ4wZqiZevmOU9cvimOqMszwwg7Om9/D\n8MIO7ipdp5+a36W3l66fPaG9s9Vzvl++8IGpynCduxKW6lRNpsX8AohiceNaZYVIblbYddMKHGQy\niDUKLY/+RSYsyBDpGiFjcJb2L7peyvJxR+7kJGO1iZ5TwfNDbHEq3Qjwd4s7Njk/sealcy1Y2JKc\nUSquySqbAgm9OFHRyRBLE2Tt0wwc5Cak3k6WKpk3UXFukuLcEMW5gvu7Piu6OHfmufOTM+fOT3Lb\nYmBAOFxZx7bSX7xnELHJCm5pwP0IqiBh7lXUimssL+Y51ddFGOJiWnhhlnQDB7lZKaSdRJX1m2gQ\n6l05N1JWR8LRp6AKEvS8ke8ZUt6gihj4EjNc9+rpiZEzU27b5GalkPZx/7Lv3Is94WKCDfCgPhdT\nhkfs6dECyjuGMyuSH6EvsdQXFzDcM/FppBuknUSVdefeGP4dwWs9K6hHIm4CaVAFMbSpEv0e3PuY\nskpyu/BaK8TsR1tNppnm8id+DI3uM064GNUqpJE9ncWkonka0pEXcjH0Iy55G/qRxPTOeZ3NNgp5\nmRE4L+Rl6Ee7k0Q7Ljeh3rjkxZ1qN9op+tcKkogo5SbUa8kGK6oAEgqT5ybUa8kG26bKANs26Qxs\noMJiSZio9dO6fxZLwlhRWSwJY0VlsSSMFZXFkjBWVBZLwtiMiojYsLglKtZSRSChNaksKwQrqmjk\nZv4DS/6xoopG2ybFWrKnKVGJyL0i8p8i8nMR+ScR+aBn3zYROSgivxSRP2q+qC3FJsVaItOspXoW\n+LiqfgJnyZxtACLyUZzZaD8GXA7sdJfVaVdsUqwlMk2JSlV/qKol9+MLOHOlg7NczhOqOq+qvwIO\n4iyr05bkaUphS/5JMqR+HfCk+/d6HJFVOexuq6NmKZ2PZLLyh8M6nGmz4jDrOeer8g2+mmyRUqeR\nZ+4EknruSLMXh4pKRJ4Dfttn152q+gP3mDuBEuYZeIx4l9LJkpWYEb8Snxmyf+5QUanqJUH7ReQr\nwBXAxbo0jiRfS+hYLBnSbPTvcuB2YKOqehvyU8CXRKRXRM4Fzgf+vZl7WSztQrNtqvtx5th+1l2b\n8wVV3aqqr4jIUzhr+5aAm1S13OS9kiZzlzMHrMRnhoyfO1cjfy2WTsBmVFgsCbOiRRWUEdJpiMjl\nbnbLQREZbXV50kZEzhGRH4nIqyLyiojcmtm9V7L7JyKXAf/iLqP6DQBVvaPFxUocN5vldeBSnD7D\nF4FNqvpqSwuWIiJyFnCWqr4kImuBnwFXZvHMK9pSBWSEdBqfAQ6q6huqugA8gZP10rGo6hFVfcn9\n+xjwGoYEhKRZ0aKq4Trgn1tdiJRYD7zp+WzMcOlERGQI+CTwkyzu1/Ejf9POCLHkGxFZA3wXuE1V\n38vinh0vqgYzQjqNFZnhIiI9OILao6rfy+y+nVuPwnEzQu4DPquqb7e6PGkhIt04gYqLccT0IrBZ\nVV9pacFSRJxshEeBd1T1tkzvvcJFdRAnI6S65OgLqrq1hUVKDRH5AvD3OCvPP6yqX29xkVJFRIaB\nHwO/ACru5jFVfTr1e69kUVksaWCjfxZLwlhRWSwJY0VlsSSMFZXFkjBWVBZLwlhRWSwJY0VlsSSM\nFZXFkjD/D2y6UNq/tcYuAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8af0fb3750>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pylab\n",
"%matplotlib inline\n",
"\n",
"X, Y1, Y2 = gen_data(N)\n",
"pylab.figure(figsize=(3, 1.5))\n",
"pylab.scatter(X[:, 0], Y1[:, 0])\n",
"pylab.scatter(X[:, 0], Y2[:, 0])\n",
"pylab.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python2.7/dist-packages/IPython/kernel/__main__.py:18: UserWarning: Output \"custom_multi_loss_layer_1\" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to \"custom_multi_loss_layer_1\" during training.\n"
]
}
],
"source": [
"def get_prediction_model():\n",
" inp = Input(shape=(Q,), name='inp')\n",
" x = Dense(nb_features, activation='relu')(inp)\n",
" y1_pred = Dense(D1)(x)\n",
" y2_pred = Dense(D2)(x)\n",
" return Model(inp, [y1_pred, y2_pred])\n",
"\n",
"def get_trainable_model(prediction_model):\n",
" inp = Input(shape=(Q,), name='inp')\n",
" y1_pred, y2_pred = prediction_model(inp)\n",
" y1_true = Input(shape=(D1,), name='y1_true')\n",
" y2_true = Input(shape=(D2,), name='y2_true')\n",
" out = CustomMultiLossLayer(nb_outputs=2)([y1_true, y2_true, y1_pred, y2_pred])\n",
" return Model([inp, y1_true, y2_true], out)\n",
"\n",
"prediction_model = get_prediction_model()\n",
"trainable_model = get_trainable_model(prediction_model)\n",
"trainable_model.compile(optimizer='adam', loss=None)\n",
"assert len(trainable_model.layers[-1].trainable_weights) == 2 # two log_vars, one for each output\n",
"assert len(trainable_model.losses) == 1"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python2.7/dist-packages/IPython/kernel/__main__.py:1: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
" if __name__ == '__main__':\n"
]
}
],
"source": [
"hist = trainable_model.fit([X, Y1, Y2], nb_epoch=nb_epoch, batch_size=batch_size, verbose=0)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f8abc068890>]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFv5JREFUeJzt3XuwZFV59/HvI1dFgSGD4whIgyFGNAbIxJfXW2lEBDoJ\nmgtixUAMcbCiiby+JtWG9020SmN7y8XEoBhRjLeQeCPVmAhTVtBENAOOXMUZsFEmc+EWISAMw6z8\nsddheg6n+1y7d/c630/Vrr179e05u8/59Tqr194dKSUkSeV6XN0FSJKGy6CXpMIZ9JJUOINekgpn\n0EtS4Qx6SSqcQS9JhTPoJalwBr0kFW7vugsAiIg7gNvqrkOSJsyRKaVDZ7vRWAQ9cFtKaU3dRUjS\nJImI9XO5nUM3klQ4g16SCmfQS1LhDHpJKpxBL0mFM+glqXAGvSQVblzm0S9Io9V5NnAm8JfddvOO\nuuuRpHE06T36nwbOB55SdyGSNK4mPeh35PW+tVYhSWNs0oP+obzer9YqJGmMGfSSVLhJD3qHbiRp\nFpMe9PboJWkWpQS9PXpJ6mPSg35q6MYevST1MelB79CNJM1i0oPeD2MlaRazBn1EHBERX42IGyPi\nhoh4U24/JCIuj4iNeb0it0dEfCAiNkXEtRFxwhDrt0cvSbOYS49+J/B/U0rHAicCb4iIY4EWsC6l\ndAywLl8GOBU4Ji9rgQuWvOrdDHpJmsWsQZ9S2pJSuiZv3wfcBBwGnA5cnG92MfCKvH068IlUuQo4\nOCJWL3nlFYduJGkW8xqjj4gGcDzwTWBVSmlLvmorsCpvHwb8sOdut+e26Y+1NiLW528xXzm/sh+1\nE0jYo5ekvuYc9BHxROBzwHkppXt7r0spJarAnbOU0oUppTUppTXAnfO575Ruu5mohm8MeknqY05B\nHxH7UIX8p1JKn8/N26aGZPJ6e27fDBzRc/fDc9uw7MChG0nqay6zbgL4KHBTSunPeq66FDg7b58N\nfKmn/aw8++ZE4Ec9QzzDYI9ekgaYyzdMPR/4TeC6iNiQ2/4IaAOXRMQ5wG3AGfm6y4DTgE3AA8Br\nl7Tix3oIe/SS1NesQZ9S+joQfa5+6Qy3T8AbFlnXfOzAHr0k9TXpR8aCQzeSNFAJQe+HsZI0QAlB\nb49ekgYw6CWpcCUEvUM3kjRACUFvj16SBigl6O3RS1IfJQS98+glaYASgt6hG0kaoJSgd+hGkvoo\nIegdupGkAUoIeoduJGmAEoLeefSSNEAJQf8QsE+j1SnhZ5GkJVdCOD6U1/bqJWkGJQT9jrw26CVp\nBiUE/VSP3g9kJWkGJQT9w3m9T61VSNKYMuglqXAGvSQVroSg98NYSRqghKC3Ry9JAxj0klQ4g16S\nCmfQS1LhDHpJKlwJQe+sG0kaoISgt0cvSQMY9JJUOINekgpn0EtS4Qx6SSqcQS9JhSsh6J1eKUkD\nlBD09uglaQCDXpIKZ9BLUuFKCPqdeW3QS9IMZg36iLgoIrZHxPU9bW+LiM0RsSEvp/Vc99aI2BQR\nN0fEy4dV+JRuu5mowt6gl6QZzKVH/3HglBna/zyldFxeLgOIiGOBM4Fn5fv8TUTstVTFDrADZ91I\n0oxmDfqU0pXA3XN8vNOBz6aUHkopfR/YBDx3EfXN1cPYo5ekGS1mjP6NEXFtHtpZkdsOA37Yc5vb\nc9uwGfSS1MdCg/4C4OnAccAW4P3zfYCIWBsR6yNiPbBygXVMMeglqY8FBX1KaVtK6ZGU0i7gI+we\nntkMHNFz08Nz20yPcWFKaU1KaQ1w50Lq6GHQS1IfCwr6iFjdc/GVwNSMnEuBMyNiv4g4CjgG+Nbi\nSpwTg16S+th7thtExGeAFwMrI+J24E+AF0fEcUACusC5ACmlGyLiEuBGqimPb0gpPTKc0vfgrBtJ\n6mPWoE8pvXqG5o8OuP07gXcupqgFsEcvSX2UcGQsGPSS1JdBL0mFM+glqXAGvSQVrqSgd9aNJM2g\nlKDfgT16SZpRKUHv0I0k9WHQS1LhDHpJKpxBL0mFM+glqXClBL0nNZOkPkoJenv0ktSHQS9JhTPo\nJalwJQX94xqtTik/jyQtmVKC8eG8tlcvSdOUEvQ78tqZN5I0TSlBb49ekvow6CWpcAa9JBXOoJek\nwpUS9H4YK0l9lBL0P87rx9dahSSNoVKC/oG8fkKtVUjSGCot6O3RS9I0pQT91NCNPXpJmqaUoLdH\nL0l9lBL09uglqY9Sgt4evST1UUrQ26OXpD5KC3p79JI0TRFB3203H6Y6DYI9ekmapoigz36MPXpJ\neoySgv4B4IC6i5CkcVNS0N8DrKi7CEkaNyUF/d3AT9RdhCSNm5KC/i4Mekl6DINekgo3a9BHxEUR\nsT0iru9pOyQiLo+IjXm9IrdHRHwgIjZFxLURccIwi5/mLuCQET6fJE2EufToPw6cMq2tBaxLKR0D\nrMuXAU4FjsnLWuCCpSlzTu4GHt9odZx5I0k9Zg36lNKVVCHa63Tg4rx9MfCKnvZPpMpVwMERsXqp\nip3FjXl9/IieT5ImwkLH6FellLbk7a3Aqrx9GPDDntvdnttG4et5/cIRPZ8kTYRFfxibUkpAmu/9\nImJtRKyPiPXAysXW0W0376Lq1b9osY8lSSVZaNBvmxqSyevtuX0zcETP7Q7PbY+RUrowpbQmpbQG\nuHOBdUz3NeB5jVZnryV6PEmaeAsN+kuBs/P22cCXetrPyrNvTgR+1DPEMwpXAgcCPzvC55SksTaX\n6ZWfAb4BPCMibo+Ic4A28LKI2AiclC8DXAbcCmwCPgL87lCq7m8d8AjwqyN+XkkaW1ENsddcRMT6\nPISzaI1W58vAM4Gju+3mrqV4TEkaR3PNzpKOjJ3ySeBI4Pl1FyJJ46DEoP8icD/wmroLkaRxUFzQ\nd9vN+4EvAGc0Wp396q5HkupWXNBnnwQOBk6ruxBJqlupQb8O2AK8ru5CJKluRQZ9t93cSXVCtVMb\nrc4z665HkupUZNBnFwAPAm+uuxBJqlOxQd9tN+8EPgac1Wh1nlJ3PZJUl2KDPvtzYB/g9+ouRJLq\nUnTQd9vNjVRTLX+30eocWHc9klSHooM++1OqqZbn1V2IJNWh+KDvtptXU51d8/80Wp2D665Hkkat\n+KDP3k7Vq39T3YVI0qgti6Dvtpvfphqrf0uj1Xly3fVI0igti6DP3go8HvjjuguRpFFaNkHfbTdv\nBi4Ezm20Oj9Vdz2SNCrLJuizt1MdLfuuuguRpFFZVkHfbTe3Ae8BfqXR6ryo7nokaRSWVdBn7wd+\nAPxVo9XZu+5iJGnYll3Qd9vNB6hOdPYc4Nyay5GkoVt2QZ99HrgCeEej1Tm07mIkaZiWZdB3280E\n/D7wROB9NZcjSUO1LIMeoNtu3gS8m+o0xqfUXY8kDcuyDfrsHcB3gQ83Wp0n1V2MJA3Dsg76brv5\nIHAOcATQrrkcSRqKZR30AN1289+BD1Cds/4lddcjSUtt2Qd99kfAJuCiRqtzUN3FSNJSMuh5dG79\nWVRDOBc2Wp2ouSRJWjIGfdZtN78BnA+cgQdSSSqIQb+n9wJfBv6i0eocV3cxkrQUDPoe3XZzF3A2\ncBdwiVMuJZXAoJ+m227eAbwaeDqO10sqgEE/g267eSXw/4AzqU6AJkkTy6Dvrw38I/CeRqtzct3F\nSNJCGfR95BOfvRa4Afhso9X5yZpLkqQFMegH6Lab/w28AtgF/HOj1VlVc0mSNG8G/Sy67eatwC8C\nTwUucyaOpElj0M9Bt928Cvh14GeBzzVanX1rLkmS5sygn6Nuu9kBXge8DPi43zcraVIsKugjohsR\n10XEhohYn9sOiYjLI2JjXq9YmlLr1203Pwa8lWqe/ccarY5vlJLG3lIE1UtSSsellNbkyy1gXUrp\nGGBdvlyMbrvZBv4/8BrgAsNe0rgbRkidDlycty+mmrVSmncC7wLWAh91GEfSOFts0CfgKxFxdUSs\nzW2rUkpb8vZWYMYpiRGxNiLW5yGflYusY6TyHPvzgbcDvwV8qtHq7FNrUZLUR6SUFn7niMNSSpsj\n4snA5cDvAZemlA7uuc09KaWB4/QRsb5n6GeiNFqdt1Cd9bIDvKrbbt5fc0mSlom5ZueievQppc15\nvR34AvBcYFtErM5FrAa2L+Y5xl233Xwf8HrgNOBrHlQladwsOOgj4oCIeNLUNnAycD1wKdWpfsnr\nLy22yHHXbTc/THVQ1TOAqxqtzrNqLkmSHrWYHv0q4OsR8R3gW0AnpfTPVCcDe1lEbAROypeL1203\nLwNeDOwP/Huj1Xl5vRVJUmVRY/RLVsQEj9FP12h1ngb8E/Bsqqml78sf3krSkhrJGL0eq9tu/gB4\nHvA54D3APzRanYPqrUrScmbQD0GeefMq4A+pjiO4ptHqPLfeqiQtVw7dDFmj1Xk+8Gmqs1+eTzWU\ns6veqiSVwKGbMdFtN/8NOA74IvBu4IpGq9OotShJy4pBPwLddvMe4Azgd4A1wHWNVudcv3hc0ig4\ndDNieVbORcBLgSuB13fbzZvqrUrSJHLoZkzlWTknU53b/meA7zRanXc0Wp0n1FuZpFLZo69Ro9V5\nMvA+4DeB24A/AP7RefeS5sIe/QTotpvbu+3mWVRH1N4LXAJ8tdHqHF9rYZKKYtCPgW67+a/ACcC5\nVMM51zRanU83Wp2j661MUgkcuhkz+SjaPwDeDOwDfAx4Z7fdvK3WwiSNnblmp0E/phqtzmqqA6xe\nBwTVTJ13GfiSphj0hWi0OkdQfSH5OcBeVOP47+22m9+utTBJtTPoC9NodQ4HzqMax38i8HXgL4Ev\ndtvNnXXWJqkeBn2hGq3OwcBvU31tYwP4AfBB4G+77ebdNZYmacQM+sI1Wp29qL7V6k3AS4AHqb7O\n8cPA1zxxmlQ+g34ZabQ6zwHWAq8BDgK+D3wS+IynV5DKZdAvQ41W5wDglVTf1ftSqtk611F9gHtJ\nt938Xo3lSVpiBv0y12h1ngr8GtVZM5+fmzewO/Rvqas2SUvDoNej8oydX6cK/RNz89VU58i/HFjf\nbTcfqak8SQtk0GtGjVbnSKqe/quAn8/N/wV8lSr0rwA2eWI1afwZ9JpVo9U5lGos/yTgZcDT8lU/\nYHfor+u2m3fUU6GkQQx6zUv+tqufpAr8k4BfoJrBA9XY/hVU4f+Nbrt5Xy1FStqDQa9FabQ6ewM/\nx+7e/vOoTrK2C7ge+EbPstGhHmn0DHotqTx18wVUgf+/gf8FHJivvovqw91repZbDX9puAx6DVU+\nMveZVLN4TqQ6n/6zqXr9AD8Cvp2X66mGf27stpsPjr5aqUwGvUau0ersRxX2J/QszwH2zzfZBdxC\nFfw3AjcD3wVu7rab9468YGnCGfQaC7nnfzRwHPCsvDwHeDrVaZenbKUn+IGNwK1At9tuPjDKmqVJ\nYdBrrDVanX2pwv4ZPctP5/Uh026+lSr0b6U6j0/v9tZuu/nwiMqWxopBr4nVaHVWUr0JHD1tOQo4\ngj2/6zgB/wn8kOoNYSuwpWd76vK2bru5Y0Q/gjQSBr2KlP8TeBq7g/+pwJHAYcBT8rKyz93vpv8b\nQe/2Pc4Y0iQw6LVs5TeDJ7M7+FcP2N5/hod4mD2D/868/BdwL9WMohnXDiNplAx6aRb5aOADmflN\noPfyyrzsO4eHfZBZ3gxmWd8L3OdJ5jQXBr20xBqtzv5UbwwHTVvP1NZvfSB7fsbQz3+z55vAfcD9\nwAN5fT/Vm8pD05YdM7TN9bqdDllNFoNeGkP5v4gDmPubQu/2E/J9p5b92X2A2lJILOwNYrbrdgKP\nUB1H8cgSby/6sSb5azcNemkZaLQ6j6MaUtqvZ5l+ebHXLeQ+k2amN5DUs+yadnn60u96BlyG6lvg\nPtJtN9+7kKLnmp17L+TBJY2H3Bt9MC9jIf/Xsg+73wT2phqu2isvi9le7P3n+rh7UYVw7/K4GdoG\nXc+Ay1PbiWpq8FDZo5ekCTXX7JzLh0ILLeCUiLg5IjZFRGtYzyNJGmwoQR8RewEfBE4FjgVeHRHH\nDuO5JEmDDatH/1xgU0rp1pTSDuCzwOlDei5J0gDDCvrD2PMDhttzmyRpxGqbdRMRa4G1+WK/c5NI\nkhZpWD36zVRnGZxyeG57VErpwpTSmvyJ8Z1DqkOSlr1hBf1/AMdExFERsS9wJnDpkJ5LkjTAUIZu\nUko7I+KNwL9QHXhwUUrphmE8lyRpsHE5YOoO4LYF3n0l4zn0M651wfjWZl3zY13zU2JdR6aUDp3t\nRmMR9IsxrkfVjmtdML61Wdf8WNf8LOe6hnZkrCRpPBj0klS4EoL+wroL6GNc64Lxrc265se65mfZ\n1jXxY/SSpMFK6NFLkgaY6KCv81TIEXFERHw1Im6MiBsi4k25/W0RsTkiNuTltJ77vDXXenNEvHyI\ntXUj4rr8/Otz2yERcXlEbMzrFbk9IuIDua5rI+KEIdX0jJ59siEi7o2I8+rYXxFxUURsj4jre9rm\nvX8i4ux8+40RcfaQ6npvRHw3P/cXIuLg3N6IiB/37LcP9dzn5/LrvynXHjM93yLrmvfrttR/r33q\n+vuemroRsSG3j3J/9cuG+n7HUkoTuVAdiHULcDTVN9l8Bzh2hM+/Gjghbz8J+B7VKZnfBrxlhtsf\nm2vcDzgq177XkGrrAiuntb0HaOXtFvDuvH0a8GWqb7s5EfjmiF67rcCRdewv4EXACcD1C90/wCHA\nrXm9Im+vGEJdJwN75+1399TV6L3dtMf5Vq41cu2nDqGueb1uw/h7namuade/H/jjGvZXv2yo7Xds\nknv0tZ4KOaW0JaV0Td6+D7iJwWfoPB34bErpoZTS94FNVD/DqJwOXJy3LwZe0dP+iVS5Cjg4IlYP\nuZaXAreklAYdJDe0/ZVSuhK4e4bnm8/+eTlweUrp7pTSPcDlwClLXVdK6SsppZ354lVU543qK9d2\nYErpqlSlxSd6fpYlq2uAfq/bkv+9Dqor98rPAD4z6DGGtL/6ZUNtv2OTHPRjcyrkiGgAxwPfzE1v\nzP+CXTT17xmjrTcBX4mIq6M6SyjAqpTSlry9FVhVQ11TzmTPP8C69xfMf//Usd9+m6rnN+WoiPh2\nRPxrRLwwtx2WaxlFXfN53Ua9v14IbEspbexpG/n+mpYNtf2OTXLQj4WIeCLwOeC8lNK9wAXA04Hj\ngC1U/z6O2gtSSidQfcPXGyLiRb1X5p5LLdOtojrJ3S8D/5CbxmF/7aHO/dNPRJwP7AQ+lZu2AE9L\nKR0PvBn4dEQcOMKSxu51m+bV7NmZGPn+miEbHjXq37FJDvpZT4U8bBGxD9UL+amU0ucBUkrbUkqP\npJR2AR9h93DDyOpNKW3O6+3AF3IN26aGZPJ6+6jryk4Frkkpbcs11r6/svnun5HVFxG/Bfwi8Bs5\nIMhDI3fl7aupxr9/KtfQO7wzlLoW8LqNcn/tDfwK8Pc99Y50f82UDdT4OzbJQV/rqZDzGOBHgZtS\nSn/W0947vv1KYGpGwKXAmRGxX0QcBRxD9SHQUtd1QEQ8aWqb6sO86/PzT31qfzbwpZ66zsqf/J8I\n/Kjn38th2KOnVff+6jHf/fMvwMkRsSIPW5yc25ZURJwC/CHwyymlB3raD43qu5mJiKOp9s+tubZ7\nI+LE/Dt6Vs/PspR1zfd1G+Xf60nAd1NKjw7JjHJ/9csG6vwdW8yny3UvVJ9Wf4/q3fn8ET/3C6j+\n9boW2JCX04C/A67L7ZcCq3vuc36u9WYW+cn+gLqOpprR8B3ghqn9AvwEsA7YCFwBHJLbg+qL3G/J\nda8Z4j47ALgLOKinbeT7i+qNZgvwMNW45zkL2T9UY+ab8vLaIdW1iWqcdup37EP5tr+aX98NwDXA\nL/U8zhqq4L0F+GvygZFLXNe8X7el/nudqa7c/nHg9dNuO8r91S8bavsd88hYSSrcJA/dSJLmwKCX\npMIZ9JJUOINekgpn0EtS4Qx6SSqcQS9JhTPoJalw/wN269c5Nx7sMgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8ae4b7e1d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pylab.plot(hist.history['loss'])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[8.6481799024418997, 0.92546439885191034]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Found standard deviations (ground truth is 10 and 1):\n",
"[np.exp(K.get_value(log_var[0]))**0.5 for log_var in trainable_model.layers[-1].log_vars]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}