Skip to content

Commit b2df8ca

Browse files
authored
Update index.js
1 parent c6c004c commit b2df8ca

File tree

1 file changed

+47
-29
lines changed

1 file changed

+47
-29
lines changed

Autoencoder/index.js

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,52 +8,63 @@
88
// for arbitrary data though. It's worth a look :)
99
import {IMAGE_H, IMAGE_W, MnistData} from './datas.js';
1010

11-
// This is a helper class for drawing loss graphs and MNIST images to the
12-
// window. For the purposes of understanding the machine learning bits, you can
13-
// largely ignore it
1411
import * as ui from './ui.js';
1512

1613

17-
function createConvModel(n_layers,n_units,hidden) {
18-
14+
function createConvModel(n_layers,n_units,hidden) { //resnet-densenet-batchnorm
1915
this.latent_dim = Number(hidden); //final dimension of hidden layer
2016
this.n_layers = Number(n_layers); //how many hidden layers in encoder and decoder
2117
this.n_units = Number(n_units); //output dimension of each layer
2218
this.img_shape = [28,28];
2319
this.img_units = this.img_shape[0] * this.img_shape[1];
2420
// build the encoder
21+
2522
var i = tf.input({shape: this.img_shape});
2623
var h = tf.layers.flatten().apply(i);
27-
28-
for (var j=0; j<this.n_layers; j++) {
24+
h=tf.layers.batchNormalization(-1).apply(h);
25+
h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h);
26+
for (var j=0; j<this.n_layers-1; j++) {
27+
var tm=h;
28+
const addLayer = tf.layers.add();
2929
var h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h); //n hidden
30+
h=addLayer.apply([tm,h]);
31+
h=tf.layers.batchNormalization(0).apply(h);
3032
}
3133

32-
var o = tf.layers.dense({units: this.latent_dim}).apply(h); //1 final
34+
var o = tf.layers.dense({units: this.latent_dim}).apply(h);
35+
//1 final
3336
this.encoder = tf.model({inputs: i, outputs: o});
3437

3538
// build the decoder
3639
var i = h = tf.input({shape: this.latent_dim});
37-
for (var j=0; j<this.n_layers; j++) { //n hidden
40+
h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h);
41+
for (var j=0; j<this.n_layers-1; j++) {
42+
var tm=h;
43+
const addLayer = tf.layers.add(); //n hidden
3844
var h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h);
45+
h=addLayer.apply([tm,h]);
3946
}
40-
var o = tf.layers.dense({units: this.img_units}).apply(h) ; //1 final
47+
48+
var o = tf.layers.dense({units: this.img_units}).apply(h); //1 final
4149
var o = tf.layers.reshape({targetShape: this.img_shape}).apply(o);
4250
this.decoder = tf.model({inputs: i, outputs: o});
4351

4452
// stack the autoencoder
4553
var i = tf.input({shape: this.img_shape});
4654
var z = this.encoder.apply(i); //z is hidden code
47-
4855
var o = this.decoder.apply(z);
4956
this.auto = tf.model({inputs: i, outputs: o});
5057

5158
}
59+
60+
5261
let epochs=0,trainEpochs,batch;
5362
var trainData;
5463
var testData;
5564
var b;var model;
5665

66+
67+
5768
async function train(model) {
5869

5970
const e=document.getElementById('batchsize');
@@ -84,8 +95,6 @@ await showPredictions(model,epochs); //Triv
8495

8596
}
8697

87-
88-
8998
async function showPredictions(model,epochs) { //Trivial Samples of autoencoder
9099
const testExamples = 10;
91100
const examples = data.getTestData(testExamples);
@@ -106,14 +115,15 @@ async function run(){
106115
testData = data.getTestData();
107116
}
108117

118+
document.getElementById('vis').oninput=function(){vis=Number(document.getElementById('vis').value);console.log(vis);};
109119

110120
async function load() {
111121
var ele=document.getElementById('barc');
112122
ele.style.display="none";
113123
const n_units=document.getElementById('n_units').value;
114124
const n_layers=document.getElementById('n_layers').value;
115125
const hidden=document.getElementById('hidden').value;
116-
model = new createConvModel(n_layers,n_units,hidden);
126+
model = new createConvModel(n_layers,n_units,hidden); //load model
117127
const elem=document.getElementById('new')
118128
elem.innerHTML="Model Created!!!"
119129
epochs=0;
@@ -122,13 +132,15 @@ async function load() {
122132

123133
load();
124134

135+
136+
125137
async function runtrain(){
126138
var ele=document.getElementById('barc');
127139
ele.style.display="block";
128140
var elem=document.getElementById('new');
129141
elem.innerHTML="";
130142
b=0;
131-
await train(model);
143+
await train(model); //start training
132144
vis=Number(document.getElementById('vis').value);
133145
}
134146

@@ -151,7 +163,7 @@ function normaltensor(prediction){
151163
prediction= prediction.sub(inputMin).div(inputMax.sub(inputMin));
152164
return prediction;}
153165
function normal(prediction){
154-
const inputMax = prediction.max();
166+
const inputMax = prediction.max(); //normailization
155167
const inputMin = prediction.min();
156168
prediction= prediction.sub(inputMin).div(inputMax.sub(inputMin));
157169
return prediction;
@@ -163,22 +175,27 @@ const canvas=document.getElementById('celeba-scene');
163175
const mot=document.getElementById('mot');
164176
var cont=mot.getContext('2d');
165177

178+
179+
180+
181+
182+
183+
184+
185+
186+
166187
function sample(obj) { //plotting
167188
obj.x = (obj.x) * vis;
168189
obj.y = (obj.y) * vis;
169190
// convert 10, 50 into a vector
170191
var y = tf.tensor2d([[obj.x, obj.y]]);
171-
// sample from region 10, 50 in latent space
172192

173193
var prediction = model.decoder.predict(y).dataSync();
174-
175-
//scaling
194+
//scaling
176195
prediction=normaltensor(prediction);
177196
prediction=prediction.reshape([28,28]);
178197

179-
prediction=prediction.mul(255).toInt();
180-
181-
198+
prediction=prediction.mul(255).toInt(); //for2dplot
182199
// log the prediction to the browser console
183200
tf.browser.toPixels(prediction, canvas);
184201
}
@@ -190,7 +207,7 @@ cont.fillRect(0,0,mot.width,mot.height);
190207
mot.addEventListener('mousemove', function(e) {
191208
mouse.x = (e.pageX - this.offsetLeft)*3.43;
192209
mouse.y = (e.pageY - this.offsetTop)*1.9;
193-
}, false);
210+
}, false); //mouse movement for 2dplot
194211

195212
mot.addEventListener('mousedown', function(e) {
196213
mot.addEventListener('mousemove', on, false);
@@ -209,11 +226,6 @@ var on= function() {
209226
};
210227

211228

212-
213-
214-
215-
216-
217229
function plot2d(){
218230
load();
219231
const decision=Number(document.getElementById("hidden").value);
@@ -241,6 +253,12 @@ document.addEventListener('DOMContentLoaded',plot2d);
241253

242254

243255

256+
257+
258+
259+
260+
261+
244262
const canv=document.getElementById('canv');
245263
const outcanv=document.getElementById('outcanv');
246264
var ct = outcanv.getContext('2d');
@@ -250,7 +268,7 @@ var ctx = canv.getContext('2d');
250268
function clear(){
251269
ctx.clearRect(0, 0, canv.width, canv.height);
252270
ctx.fillStyle = "black";
253-
ctx.fillRect(0, 0, canv.width, canv.height);
271+
ctx.fillRect(0, 0, canv.width, canv.height); //for canvas autoencoding
254272
ct.clearRect(0, 0, outcanv.width, outcanv.height);
255273
ct.fillStyle = "#DDDDDD";
256274
ct.fillRect(0, 0, outcanv.width, outcanv.height);

0 commit comments

Comments
 (0)