Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mnist-core-webworker/.babelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"presets": [
[
"env",
{
"modules": false,
"targets": {
"browsers": [
"> 3%"
]
},
"useBuiltIns": true
}
]
],
"plugins": [
"transform-runtime",
"transform-class-properties"
]
}
10 changes: 10 additions & 0 deletions mnist-core-webworker/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# TensorFlow.js Example: MNIST using tfjs-core

This example shows you how to train MNIST on a web worker using WebGL only using the [TensorFlow.js core]
(https://github.com/tensorflow/tfjs-core) library (not using the layers API).

Note: currently the entire dataset of MNIST images is stored in a PNG image we have
sprited, and the code in `data.js` is responsible for converting it into
`Tensor`s. This will become much simpler in the near future.

[See this example live!](https://storage.googleapis.com/tfjs-examples/mnist-core-webworker/dist/index.html)
139 changes: 139 additions & 0 deletions mnist-core-webworker/data.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
* A class that fetches the sprited MNIST dataset and returns shuffled batches.
*
* NOTE: This will get much easier. For now, we do data fetching and
* manipulation manually.
*/
export class MnistData {
constructor() {
this.shuffledTrainIndex = 0;
this.shuffledTestIndex = 0;
}

async load() {
// Make a request for the MNIST sprited image.
const imgRequest = (async () => {
const response = await fetch(MNIST_IMAGES_SPRITE_PATH, {mode: 'cors'});
const blob = await response.blob();
const bitmap = await createImageBitmap(blob);
const {width} = bitmap;

const datasetBytesBuffer =
new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

const chunkSize = 5000;
const canvas = new OffscreenCanvas(width, chunkSize);
const ctx = canvas.getContext('2d');

for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
const datasetBytesView = new Float32Array(
datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
IMAGE_SIZE * chunkSize);
ctx.drawImage(
bitmap, 0, i * chunkSize, width, chunkSize, 0, 0, width, chunkSize);

const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

for (let j = 0; j < imageData.data.length / 4; j++) {
// All channels hold an equal value since the image is grayscale, so
// just read the red channel.
datasetBytesView[j] = imageData.data[j * 4] / 255;
}
}

this.datasetImages = new Float32Array(datasetBytesBuffer);
})();

const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] =
await Promise.all([imgRequest, labelsRequest]);

this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

// Create shuffled indices into the train/test set for when we select a
// random dataset element for training / validation.
this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

// Slice the the images and labels into train and test sets.
this.trainImages =
this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels =
this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels =
this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
}

nextTrainBatch(batchSize) {
return this.nextBatch(
batchSize, [this.trainImages, this.trainLabels], () => {
this.shuffledTrainIndex =
(this.shuffledTrainIndex + 1) % this.trainIndices.length;
return this.trainIndices[this.shuffledTrainIndex];
});
}

nextTestBatch(batchSize) {
return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
this.shuffledTestIndex =
(this.shuffledTestIndex + 1) % this.testIndices.length;
return this.testIndices[this.shuffledTestIndex];
});
}

nextBatch(batchSize, data, index) {
const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

for (let i = 0; i < batchSize; i++) {
const idx = index();

const image =
data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
batchImagesArray.set(image, i * IMAGE_SIZE);

const label =
data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
batchLabelsArray.set(label, i * NUM_CLASSES);
}

const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

return {xs, labels};
}
}
66 changes: 66 additions & 0 deletions mnist-core-webworker/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
<!--
Copyright 2018 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
-->

<html>

<head>
<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/material.cyan-teal.min.css" />
</head>

<body>
<h3>TensorFlow.js: Train MNIST with the Core API</h3>
<p>Open your browser console</p>
<style>
canvas {
width: 100px;
}

.pred {
font-size: 30px;
width: 100px;
}

.pred-correct {
background-color: #00cf00;
}

.pred-incorrect {
background-color: red;
}

.pred-container {
display: inline-block;
width: 100px;
margin: 10px;
}

#animating-square {
width: 25px;
height: 25px;
background-color: red;
transition: transform 0.1s linear;
}
</style>

<div id="animating-square"></div>
<div id="status">Loading data...</div>
<div id="message"></div>
<div id="images"></div>
<script src="index.js"></script>
</body>

</html>
47 changes: 47 additions & 0 deletions mnist-core-webworker/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as Comlink from 'comlink';

import * as ui from './ui';

async function load(model) {
await model.load();
}

async function train(model) {
ui.isTraining();
await model.train(Comlink.proxyValue(ui.trainingLog));
}

async function test(model) {
const testExamples = 50;
const {imagesData, predictions, labels} = await model.test(testExamples);

ui.showTestResults(imagesData, predictions, labels, testExamples);
}

async function mnist() {
ui.animate();
const Model = Comlink.proxy(new Worker('model.js'));
const model = await new Model();

await load(model);
await train(model);
test(model);
}
mnist();
Loading