8
8
// for arbitrary data though. It's worth a look :)
9
9
import { IMAGE_H , IMAGE_W , MnistData } from './datas.js' ;
10
10
11
- // This is a helper class for drawing loss graphs and MNIST images to the
12
- // window. For the purposes of understanding the machine learning bits, you can
13
- // largely ignore it
14
11
import * as ui from './ui.js' ;
15
12
16
13
17
// Builds a residual dense autoencoder (encoder + decoder + stacked model).
// Invoked as a constructor: `new createConvModel(n_layers, n_units, hidden)`.
//   n_layers - number of hidden layers in encoder and decoder
//   n_units  - width (output dimension) of each hidden layer
//   hidden   - dimensionality of the latent code
function createConvModel(n_layers, n_units, hidden) { //resnet-densenet-batchnorm
  this.latent_dim = Number(hidden);  // final dimension of hidden layer
  this.n_layers = Number(n_layers);  // how many hidden layers in encoder and decoder
  this.n_units = Number(n_units);    // output dimension of each layer
  this.img_shape = [28, 28];
  this.img_units = this.img_shape[0] * this.img_shape[1];

  // ---- build the encoder ----
  const encIn = tf.input({ shape: this.img_shape });
  let h = tf.layers.flatten().apply(encIn);
  // BUG FIX: tf.layers.batchNormalization takes a config object; the bare
  // `-1` / `0` previously passed here were silently ignored and the default
  // axis used. Make the intended feature-axis normalization explicit.
  h = tf.layers.batchNormalization({ axis: -1 }).apply(h);
  h = tf.layers.dense({ units: this.n_units, activation: 'relu' }).apply(h);
  for (let j = 0; j < this.n_layers - 1; j++) {
    const skip = h; // residual connection input
    h = tf.layers.dense({ units: this.n_units, activation: 'relu' }).apply(h); // n hidden
    h = tf.layers.add().apply([skip, h]);
    h = tf.layers.batchNormalization({ axis: -1 }).apply(h);
  }
  const z = tf.layers.dense({ units: this.latent_dim }).apply(h); // 1 final: project to latent code
  this.encoder = tf.model({ inputs: encIn, outputs: z });

  // ---- build the decoder ----
  // BUG FIX: `shape` is documented as an array of dimensions, not a bare number.
  const decIn = tf.input({ shape: [this.latent_dim] });
  h = tf.layers.dense({ units: this.n_units, activation: 'relu' }).apply(decIn);
  for (let j = 0; j < this.n_layers - 1; j++) {
    const skip = h; // n hidden, with residual connection
    h = tf.layers.dense({ units: this.n_units, activation: 'relu' }).apply(h);
    h = tf.layers.add().apply([skip, h]);
  }
  let o = tf.layers.dense({ units: this.img_units }).apply(h); // 1 final: back to pixel count
  o = tf.layers.reshape({ targetShape: this.img_shape }).apply(o);
  this.decoder = tf.model({ inputs: decIn, outputs: o });

  // ---- stack the autoencoder ----
  const autoIn = tf.input({ shape: this.img_shape });
  const code = this.encoder.apply(autoIn); // code is the hidden/latent representation
  const recon = this.decoder.apply(code);
  this.auto = tf.model({ inputs: autoIn, outputs: recon });
}
59
+
60
+
52
61
let epochs = 0 , trainEpochs , batch ;
53
62
var trainData ;
54
63
var testData ;
55
64
var b ; var model ;
56
65
66
+
67
+
57
68
async function train ( model ) {
58
69
59
70
const e = document . getElementById ( 'batchsize' ) ;
@@ -84,8 +95,6 @@ await showPredictions(model,epochs); //Triv
84
95
85
96
}
86
97
87
-
88
-
89
98
async function showPredictions ( model , epochs ) { //Trivial Samples of autoencoder
90
99
const testExamples = 10 ;
91
100
const examples = data . getTestData ( testExamples ) ;
@@ -106,14 +115,15 @@ async function run(){
106
115
testData = data . getTestData ( ) ;
107
116
}
108
117
118
+ document . getElementById ( 'vis' ) . oninput = function ( ) { vis = Number ( document . getElementById ( 'vis' ) . value ) ; console . log ( vis ) ; } ;
109
119
110
120
async function load ( ) {
111
121
var ele = document . getElementById ( 'barc' ) ;
112
122
ele . style . display = "none" ;
113
123
const n_units = document . getElementById ( 'n_units' ) . value ;
114
124
const n_layers = document . getElementById ( 'n_layers' ) . value ;
115
125
const hidden = document . getElementById ( 'hidden' ) . value ;
116
- model = new createConvModel ( n_layers , n_units , hidden ) ;
126
+ model = new createConvModel ( n_layers , n_units , hidden ) ; //load model
117
127
const elem = document . getElementById ( 'new' )
118
128
elem . innerHTML = "Model Created!!!"
119
129
epochs = 0 ;
@@ -122,13 +132,15 @@ async function load() {
122
132
123
133
load ( ) ;
124
134
135
+
136
+
125
137
// Kicks off a training run: shows the progress bar, clears the status
// message, resets the progress counter, trains the current model, then
// refreshes the latent-plot scale from the 'vis' input.
async function runtrain() {
  const progressBar = document.getElementById('barc');
  progressBar.style.display = "block";
  const statusMsg = document.getElementById('new');
  statusMsg.innerHTML = "";
  b = 0;
  await train(model); // start training
  vis = Number(document.getElementById('vis').value);
}
134
146
@@ -151,7 +163,7 @@ function normaltensor(prediction){
151
163
prediction = prediction . sub ( inputMin ) . div ( inputMax . sub ( inputMin ) ) ;
152
164
return prediction ; }
153
165
function normal ( prediction ) {
154
- const inputMax = prediction . max ( ) ;
166
+ const inputMax = prediction . max ( ) ; //normailization
155
167
const inputMin = prediction . min ( ) ;
156
168
prediction = prediction . sub ( inputMin ) . div ( inputMax . sub ( inputMin ) ) ;
157
169
return prediction ;
@@ -163,22 +175,27 @@ const canvas=document.getElementById('celeba-scene');
163
175
const mot = document . getElementById ( 'mot' ) ;
164
176
var cont = mot . getContext ( '2d' ) ;
165
177
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
166
187
// Decodes the latent-space point under the cursor and paints the decoded
// 28x28 digit onto the visible canvas. NOTE: mutates obj.x / obj.y in place.
function sample(obj) { //plotting
  obj.x = (obj.x) * vis; // scale plot coordinates into latent-space range
  obj.y = (obj.y) * vis;
  // convert the 2-D latent coordinate into a [1,2] tensor
  var y = tf.tensor2d([[obj.x, obj.y]]);
  // BUG FIX: .dataSync() returned a Float32Array, but normaltensor(),
  // .reshape() and .mul() below are tensor operations and would throw on a
  // plain typed array — keep the prediction as a tensor.
  var prediction = model.decoder.predict(y);
  //scaling
  prediction = normaltensor(prediction);
  prediction = prediction.reshape([28, 28]);
  prediction = prediction.mul(255).toInt(); //for2dplot
  // render the decoded digit to the browser canvas
  tf.browser.toPixels(prediction, canvas);
}
@@ -190,7 +207,7 @@ cont.fillRect(0,0,mot.width,mot.height);
190
207
// Track the cursor position over the latent plot in canvas coordinates.
// The hard-coded factors 3.43 / 1.9 presumably map CSS page pixels onto the
// canvas's internal resolution — TODO confirm against the canvas dimensions.
mot.addEventListener('mousemove', function (e) {
  mouse.x = (e.pageX - this.offsetLeft) * 3.43;
  mouse.y = (e.pageY - this.offsetTop) * 1.9;
}, false); //mouse movement for 2dplot
194
211
195
212
mot . addEventListener ( 'mousedown' , function ( e ) {
196
213
mot . addEventListener ( 'mousemove' , on , false ) ;
@@ -209,11 +226,6 @@ var on= function() {
209
226
} ;
210
227
211
228
212
-
213
-
214
-
215
-
216
-
217
229
function plot2d ( ) {
218
230
load ( ) ;
219
231
const decision = Number ( document . getElementById ( "hidden" ) . value ) ;
@@ -241,6 +253,12 @@ document.addEventListener('DOMContentLoaded',plot2d);
241
253
242
254
243
255
256
+
257
+
258
+
259
+
260
+
261
+
244
262
const canv = document . getElementById ( 'canv' ) ;
245
263
const outcanv = document . getElementById ( 'outcanv' ) ;
246
264
var ct = outcanv . getContext ( '2d' ) ;
@@ -250,7 +268,7 @@ var ctx = canv.getContext('2d');
250
268
function clear ( ) {
251
269
ctx . clearRect ( 0 , 0 , canv . width , canv . height ) ;
252
270
ctx . fillStyle = "black" ;
253
- ctx . fillRect ( 0 , 0 , canv . width , canv . height ) ;
271
+ ctx . fillRect ( 0 , 0 , canv . width , canv . height ) ; //for canvas autoencoding
254
272
ct . clearRect ( 0 , 0 , outcanv . width , outcanv . height ) ;
255
273
ct . fillStyle = "#DDDDDD" ;
256
274
ct . fillRect ( 0 , 0 , outcanv . width , outcanv . height ) ;
0 commit comments