forked from yashovardhan/MNIST-Image-Recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
index.js
45 lines (35 loc) · 1.21 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
require('@tensorflow/tfjs-node');
const tf = require('@tensorflow/tfjs');
const _ = require('lodash');
const mnist = require('mnist-data');
const plot = require('node-remote-plot');
const LogisticRegression = require('./logistic-regression');
/**
 * Loads a slice of the MNIST training set and prepares it for the classifier.
 *
 * Each image returned by `mnist.training` is flattened one level with
 * `_.flatMap` (presumably turning a 2-D pixel grid into a single row of
 * pixels — confirm against the mnist-data package), and each label is
 * converted into a one-hot row with one slot per digit class.
 *
 * @param {number} [start=0] - First sample index handed to `mnist.training`.
 * @param {number} [end=60000] - End index handed to `mnist.training`.
 * @returns {{features: number[][], labels: number[][]}} Flattened images and
 *   their one-hot encoded labels, in the same order.
 */
function loadData(start = 0, end = 60000) {
  // Number of one-hot slots per label (digits 0-9).
  const CLASS_COUNT = 10;

  const mnistData = mnist.training(start, end);

  const features = mnistData.images.values.map((image) => _.flatMap(image));

  const encodedLabels = mnistData.labels.values.map((label) => {
    const row = Array.from({ length: CLASS_COUNT }, () => 0);
    // Mark the slot that corresponds to this sample's digit.
    row[label] = 1;
    return row;
  });

  return { features, labels: encodedLabels };
}
// --- Training -------------------------------------------------------------
const { features, labels } = loadData();

// Hyperparameters are handed straight to the project-local
// LogisticRegression class; their exact meaning is defined in
// ./logistic-regression.
const regression = new LogisticRegression(features, labels, {
  learningRate: 1,
  iterations: 40,
  batchSize: 500,
});

regression.train();

// --- Evaluation -----------------------------------------------------------
// Start at index 0, matching the training load above, so the first test
// sample is not silently dropped from the evaluation set.
const testMnistData = mnist.testing(0, 10000);

// Same preprocessing as the training data: flatten each image one level and
// one-hot encode each label into a 10-slot row.
const testFeatures = testMnistData.images.values.map((image) => _.flatMap(image));
const testEncodedLabels = testMnistData.labels.values.map((label) => {
  const row = new Array(10).fill(0);
  row[label] = 1;
  return row;
});

const accuracy = regression.test(testFeatures, testEncodedLabels);
console.log('Accuracy is', accuracy*100, '%');

// --- Cost plot ------------------------------------------------------------
// NOTE(review): reverse() flips costHistory in place; presumably the model
// records the newest cost first, so this restores chronological order —
// confirm against ./logistic-regression.
plot({
  x: regression.costHistory.reverse(),
});