-
Notifications
You must be signed in to change notification settings - Fork 0
/
ui.js
119 lines (104 loc) · 3.86 KB
/
ui.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import embed from 'vega-embed';
// Page elements this module writes status text, training log messages and
// test-result tiles into.
const [statusElement, messageElement, imagesElement] =
    ['status', 'message', 'images'].map((id) => document.getElementById(id));
/**
 * Switches the #status element's text to 'Training...'.
 *
 * Despite the `is` prefix this is not a predicate: it returns nothing and only
 * updates the status text.
 */
export function isTraining() {
  const trainingStatus = 'Training...';
  statusElement.innerText = trainingStatus;
}
/**
 * Replaces the #message element's text with `message` (plus a trailing
 * newline) and echoes the raw message to the console.
 *
 * @param {*} message - Value to display; it is stringified for the page.
 */
export function trainingLog(message) {
  const displayed = `${message}\n`;
  messageElement.innerText = displayed;
  console.log(message);
}
/**
 * Sets the status text to 'Testing...' and appends one tile per test example
 * to the #images element. Each tile holds the example drawn on a canvas and a
 * "pred: N" caption whose class marks whether the prediction equals the label.
 *
 * Tiles from earlier calls are not removed; new tiles are appended after them.
 *
 * @param {{xs: Object}} batch - Object whose `xs` tensor holds one example per
 *     row. Assumes `xs` is 2-D, [numExamples, numPixels], as implied by the
 *     slice below — TODO confirm against the caller.
 * @param {ArrayLike<*>} predictions - Predicted class per example, indexed
 *     like the rows of `batch.xs`.
 * @param {ArrayLike<*>} labels - Expected class per example, compared to the
 *     prediction with `===`.
 */
export function showTestResults(batch, predictions, labels) {
  statusElement.innerText = 'Testing...';
  const [testExamples, pixelsPerExample] = batch.xs.shape;
  for (let i = 0; i < testExamples; i++) {
    const image = batch.xs.slice([i, 0], [1, pixelsPerExample]);

    const div = document.createElement('div');
    div.className = 'pred-container';

    const canvas = document.createElement('canvas');
    canvas.className = 'prediction-canvas';
    // slice() and flatten() each return a new tensor that TensorFlow.js does
    // not free automatically; release both once the pixels are on the canvas
    // so repeated test runs don't accumulate backend memory.
    const flatImage = image.flatten();
    try {
      draw(flatImage, canvas);
    } finally {
      flatImage.dispose();
      image.dispose();
    }

    const pred = document.createElement('div');
    const prediction = predictions[i];
    const label = labels[i];
    const correct = prediction === label;
    pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`;
    pred.innerText = `pred: ${prediction}`;

    div.appendChild(pred);
    div.appendChild(canvas);
    imagesElement.appendChild(div);
  }
}
// Text elements that plotLosses / plotAccuracies update with the latest value.
const [lossLabelElement, accuracyLabelElement] =
    ['loss-label', 'accuracy-label'].map((id) => document.getElementById(id));
/**
 * Renders the loss curve(s) into #lossCanvas with vega-embed and writes the
 * most recent loss, to two decimals, into the loss label.
 *
 * @param {Array<{batch: *, loss: number, set: *}>} lossValues - Points to
 *     plot: `batch` on the x axis, `loss` on the y axis and `set` selecting
 *     the line colour. May be empty; the label is then left unchanged.
 * @return {Promise} The promise returned by `embed`, so callers can await
 *     rendering or handle its rejection. Callers that ignore it are unaffected.
 */
export function plotLosses(lossValues) {
  const rendered = embed(
      '#lossCanvas', {
        '$schema': 'https://vega.github.io/schema/vega-lite/v2.json',
        'data': {'values': lossValues},
        'mark': {'type': 'line'},
        'width': 260,
        // NOTE(review): vega-lite defines `orient` on the mark, not at the top
        // level, so this key is likely ignored — confirm before moving it.
        'orient': 'vertical',
        'encoding': {
          'x': {'field': 'batch', 'type': 'ordinal'},
          'y': {'field': 'loss', 'type': 'quantitative'},
          'color': {'field': 'set', 'type': 'nominal', 'legend': null},
        }
      },
      {width: 360});
  // An empty series has no last point; skip the label instead of throwing.
  if (lossValues.length > 0) {
    const lastLoss = lossValues[lossValues.length - 1].loss;
    lossLabelElement.innerText = `last loss: ${lastLoss.toFixed(2)}`;
  }
  return rendered;
}
/**
 * Renders the accuracy curve(s) into #accuracyCanvas with vega-embed and
 * writes the most recent accuracy, as a percentage with two decimals, into the
 * accuracy label.
 *
 * @param {Array<{batch: *, accuracy: number, set: *}>} accuracyValues - Points
 *     to plot: `batch` on the x axis, `accuracy` (a fraction, multiplied by
 *     100 for the label) on the y axis and `set` selecting the line colour.
 *     May be empty; the label is then left unchanged.
 * @return {Promise} The promise returned by `embed`, so callers can await
 *     rendering or handle its rejection. Callers that ignore it are unaffected.
 */
export function plotAccuracies(accuracyValues) {
  const rendered = embed(
      '#accuracyCanvas', {
        '$schema': 'https://vega.github.io/schema/vega-lite/v2.json',
        'data': {'values': accuracyValues},
        'width': 260,
        // `legend` is not a mark property; the colour legend is hidden via
        // the `color` encoding below instead.
        'mark': {'type': 'line'},
        // NOTE(review): vega-lite defines `orient` on the mark, not at the top
        // level, so this key is likely ignored — confirm before moving it.
        'orient': 'vertical',
        'encoding': {
          'x': {'field': 'batch', 'type': 'ordinal'},
          'y': {'field': 'accuracy', 'type': 'quantitative'},
          'color': {'field': 'set', 'type': 'nominal', 'legend': null},
        }
      },
      {width: 360});
  // An empty series has no last point; skip the label instead of throwing.
  if (accuracyValues.length > 0) {
    const lastAccuracy = accuracyValues[accuracyValues.length - 1].accuracy;
    accuracyLabelElement.innerText =
        `last accuracy: ${(lastAccuracy * 100).toFixed(2)}%`;
  }
  return rendered;
}
/**
 * Paints a grayscale image onto `canvas`, first resizing the canvas to
 * `width` x `height`.
 *
 * Pixel `i` (row-major) takes `image.dataSync()[i] * 255` in its R, G and B
 * channels and is fully opaque. Intensities are presumably in [0, 1]. The
 * ImageData buffer is a Uint8ClampedArray, so out-of-range results are clamped
 * to 0..255.
 *
 * @param {{dataSync: function(): ArrayLike<number>}} image - Tensor-like
 *     source holding at least `width * height` intensities.
 * @param {HTMLCanvasElement} canvas - Destination canvas.
 * @param {number} [width=28] - Image width in pixels.
 * @param {number} [height=28] - Image height in pixels.
 */
export function draw(image, canvas, width = 28, height = 28) {
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(width, height);
  const data = image.dataSync();
  const pixels = imageData.data;
  for (let i = 0; i < height * width; ++i) {
    const value = data[i] * 255;
    const j = i * 4;
    pixels[j + 0] = value;
    pixels[j + 1] = value;
    pixels[j + 2] = value;
    pixels[j + 3] = 255;
  }
  ctx.putImageData(imageData, 0, 0);
}