Permalink
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
118 lines (92 sloc) 3.26 KB
import {Component, ViewChild} from '@angular/core';
import {DrawableDirective} from '../drawable.directive';
import * as brain from 'brain.js/browser';
import * as tf from '@tensorflow/tfjs';
import {environment} from '../../environments/environment';
/**
 * Home page: the user draws a digit on a canvas and it is classified by two
 * models, a brain.js neural network (`net`) and a TensorFlow.js model
 * (`tfModel`). Both predictions and their score vectors are exposed to the
 * template.
 */
@Component({
  selector: 'app-home',
  templateUrl: 'home.page.html',
  styleUrls: ['home.page.scss'],
})
export class HomePage {
  /** Drawing surface; supplies the strokes' bounding box and can be cleared. */
  @ViewChild(DrawableDirective) drawable: DrawableDirective;

  /** brain.js scores for digits 0..9 (empty when nothing has been detected). */
  detectionsMLP: number[] = [];
  /** Digit with the highest brain.js score. */
  detectedNumberMLP: number;
  /** Raw TensorFlow.js output scores (null when nothing has been detected). */
  detectionsCNN: Float32Array = null;
  /** Index of the highest TensorFlow.js score. */
  detectedNumberCNN: number;

  private net: brain.NeuralNetwork = null;
  private tfModel: tf.Model;

  constructor() {
    // Both loaders are async and are not awaited here, so attach handlers to
    // avoid unhandled promise rejections when a model fails to load.
    this.initBrain().catch(err => console.error('Failed to load brain.js model', err));
    this.initTf().catch(err => console.error('Failed to load TensorFlow.js model', err));
  }

  /** Fetches the serialized brain.js network from `assets/model.json` and restores it. */
  async initBrain() {
    const response = await fetch('assets/model.json');
    // fetch() does not reject on HTTP errors; fail with a clear message instead
    // of an opaque JSON parse error on an error page.
    if (!response.ok) {
      throw new Error(`Failed to fetch brain.js model: HTTP ${response.status}`);
    }
    const brainModel = await response.json();
    this.net = new brain.NeuralNetwork();
    this.net.fromJSON(brainModel);
  }

  /** Loads the TensorFlow.js model from `environment.serverURL`. */
  async initTf() {
    this.tfModel = await tf.loadModel(`${environment.serverURL}/assets/tfjsmnist/model.json`);
  }

  /**
   * Classifies the drawing on `canvas` with both models and stores the results
   * in the public detection fields. Skips (with a warning) if either model has
   * not finished loading yet.
   *
   * @param canvas the canvas holding the user's drawing
   */
  detect(canvas) {
    if (!this.tfModel || !this.net) {
      // Models are loaded asynchronously from the constructor; using them
      // before they are ready would throw on a null/undefined model.
      console.warn('Models are still loading; detection skipped.');
      return;
    }
    const values = this.toInputValues(canvas);

    // CNN with Tensorflow.js
    const data = this.predictCNN(values);
    this.detectedNumberCNN = this.indexMax(data);
    this.detectionsCNN = data;

    // MLP with brain.js
    const detection = this.net.run(values);
    this.detectedNumberMLP = this.maxScore(detection);
    this.detectionsMLP = [];
    for (let i = 0; i <= 9; i++) {
      this.detectionsMLP.push(detection[i]);
    }
  }

  /** Clears all detection results and the drawing surface. */
  erase() {
    this.detectionsMLP = [];
    this.detectedNumberMLP = null;
    this.detectionsCNN = null;
    this.detectedNumberCNN = null;
    this.drawable.clear();
  }

  /**
   * Returns the numeric key with the highest value in `obj`. Ties keep the
   * first key encountered; an empty object yields 0.
   */
  maxScore(obj: { [key: number]: number }) {
    let maxKey = 0;
    // Start below any real score so all-non-positive outputs still pick the max.
    let maxValue = -Infinity;
    for (const [key, value] of Object.entries(obj)) {
      if (value > maxValue) {
        maxValue = value;
        maxKey = parseInt(key, 10);
      }
    }
    return maxKey;
  }

  /**
   * Crops the drawn strokes (plus padding) out of `canvas`, draws them scaled
   * into a box of at most 20x20 px centred in a 28x28 image, and returns that
   * image's alpha channel normalised to [0, 1], row by row (784 values).
   */
  private toInputValues(canvas: HTMLCanvasElement): number[] {
    const canvasCopy = document.createElement('canvas');
    canvasCopy.width = 28;
    canvasCopy.height = 28;
    const copyContext = canvasCopy.getContext('2d');
    const ratioX = canvas.width / 28;
    const ratioY = canvas.height / 28;
    // Presumably [minX, minY, maxX, maxY] in canvas pixels — TODO confirm
    // against DrawableDirective.getDrawingBox().
    const drawBox = this.drawable.getDrawingBox();
    // Fit the padded box into at most 20x20 px (presumably mirroring the MNIST
    // 20x20-digit-in-28x28 layout). Width is floored at 4 px, presumably so
    // narrow strokes are not squashed to a sliver.
    const scaledSourceWidth = Math.min(20, Math.max(4, ((drawBox[2] - drawBox[0] + 32) / ratioX)));
    const scaledSourceHeight = Math.min(20, ((drawBox[3] - drawBox[1] + 32) / ratioY));
    const dx = (28 - scaledSourceWidth) / 2;
    const dy = (28 - scaledSourceHeight) / 2;
    // NOTE(review): the source rect pads 16 px on the left/top only (+16),
    // while the destination size above assumes 16 px on every side (+32), so
    // the right/bottom padding may be missing. Confirm before changing; it
    // alters the model input.
    copyContext.drawImage(canvas, drawBox[0] - 16, drawBox[1] - 16, drawBox[2] - drawBox[0] + 16, drawBox[3] - drawBox[1] + 16,
      dx, dy, scaledSourceWidth, scaledSourceHeight);
    const imageData = copyContext.getImageData(0, 0, 28, 28);
    const numPixels = imageData.width * imageData.height;
    const values = new Array<number>(numPixels);
    for (let i = 0; i < numPixels; i++) {
      // ImageData is RGBA; byte 3 of each pixel is the alpha channel.
      values[i] = imageData.data[i * 4 + 3] / 255.0;
    }
    return values;
  }

  /**
   * Runs the TensorFlow.js model on a 28x28 input and returns its output scores.
   * TF.js tensors are not garbage-collected (on WebGL they hold GPU memory), so
   * both the input and the prediction are disposed explicitly. The array
   * returned by dataSync() stays valid after disposal.
   */
  private predictCNN(values: number[]): Float32Array {
    const input = tf.tensor4d(values, [1, 28, 28, 1]);
    try {
      const output = this.tfModel.predict(input) as tf.Tensor;
      try {
        return output.dataSync<'float32'>();
      } finally {
        output.dispose();
      }
    } finally {
      input.dispose();
    }
  }

  /** Index of the largest element (first one on ties; 0 for an empty array). */
  private indexMax(data: Float32Array): number {
    let indexMax = 0;
    for (let r = 0; r < data.length; r++) {
      indexMax = data[r] > data[indexMax] ? r : indexMax;
    }
    return indexMax;
  }
}