Skip to content
This repository has been archived by the owner on Jul 15, 2022. It is now read-only.

Add show.model and show.layer #9

Merged
merged 8 commits into from
Sep 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ training respectively.

## show.fitCallbacks(container: Surface metrics: string[]) => {[key: string]: (iteration: number, log: Logs) => Promise<void>}


Returns a collection of callbacks to pass to [model.fit](https://js.tensorflow.org/api/latest/#tf.Model.fit).
Callbacks are returned for the following events, `onBatchEnd` & `onEpochEnd`.

Expand All @@ -199,7 +198,6 @@ on how to pass in callback functions to the training process.

## show.perClassAccuracy(container: Drawable, classAccuracy: {accuracy: number[], count: number[]}, classLabels?: string[]) => Promise<void>


Renders a per class accuracy table for classification task evaluation

* @param container A `{name: string, tab?: string}` object specifying which
Expand All @@ -212,7 +210,6 @@ Renders a per class accuracy table for classification task evaluation

## show.confusionMatrix(container: Drawable, confusionMatrix: number[][], classLabels?: string[]) => Promise<void>


Renders a confusion matrix for classification task evaluation

* @param container A `{name: string, tab?: string}` object specifying which
Expand All @@ -222,7 +219,29 @@ Renders a confusion matrix for classification task evaluation
* @param classLabels An array of string labels for the classes in
`classAccuracy`. Optional.

## show.valuesDistribution(container: Drawable, tensor: Tensor) => Promise<void>

Renders a histogram showing the distribution of all values in a tensor.

* @param container A `{name: string, tab?: string}` object specifying which
surface to render to.
* @param tensor a `Tensor`

## show.modelSummary(container: Drawable, model: tf.Model) => Promise<void>

Renders a summary of a `tf.Model`. Displays a table with layer information.

* @param container A `{name: string, tab?: string}` object specifying which
surface to render to.
* @param model a `tf.Model`

## show.layer(container: Drawable, layer: Layer) => Promise<void>

Renders summary information about a layer and a histogram of parameter values in that layer.

* @param container A `{name: string, tab?: string}` object specifying which
surface to render to.
* @param layer a `tf.layers.Layer`

## Renderers

Expand Down
16 changes: 16 additions & 0 deletions demos/mnist/index.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
<!-- /**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/ -->
<html>

<head>
Expand Down
17 changes: 17 additions & 0 deletions demos/mnist/index.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '../../src'
import {getModel, loadData} from './model';
Expand Down
19 changes: 17 additions & 2 deletions demos/mnist/model.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';
import {MnistData} from './data';

Expand Down Expand Up @@ -43,8 +60,6 @@ export function getModel() {
return model;
}



export async function loadData() {
const data = new MnistData();
await data.load();
Expand Down
5 changes: 5 additions & 0 deletions demos/mnist/tufte.css
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
From https://github.com/edwardtufte/tufte-css
https://github.com/edwardtufte/tufte-css/blob/gh-pages/LICENSE
*/

@charset "UTF-8";

/* Tufte CSS styles */
Expand Down
144 changes: 144 additions & 0 deletions demos/mnist_internals/data.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
* A class that fetches the sprited MNIST dataset and returns shuffled batches.
*
* NOTE: This will get much easier. For now, we do data fetching and
* manipulation manually.
*/
export class MnistData {
constructor() {
this.shuffledTrainIndex = 0;
this.shuffledTestIndex = 0;
}

async load() {
// Make a request for the MNIST sprited image.
const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const imgRequest = new Promise((resolve, reject) => {
img.crossOrigin = '';
img.onload = () => {
img.width = img.naturalWidth;
img.height = img.naturalHeight;

const datasetBytesBuffer =
new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

const chunkSize = 5000;
canvas.width = img.width;
canvas.height = chunkSize;

for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
const datasetBytesView = new Float32Array(
datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
IMAGE_SIZE * chunkSize);
ctx.drawImage(
img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
chunkSize);

const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

for (let j = 0; j < imageData.data.length / 4; j++) {
// All channels hold an equal value since the image is grayscale, so
// just read the red channel.
datasetBytesView[j] = imageData.data[j * 4] / 255;
}
}
this.datasetImages = new Float32Array(datasetBytesBuffer);

resolve();
};
img.src = MNIST_IMAGES_SPRITE_PATH;
});

const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] =
await Promise.all([imgRequest, labelsRequest]);

this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

// Create shuffled indices into the train/test set for when we select a
// random dataset element for training / validation.
this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

// Slice the the images and labels into train and test sets.
this.trainImages =
this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels =
this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels =
this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
}

nextTrainBatch(batchSize) {
return this.nextBatch(
batchSize, [this.trainImages, this.trainLabels], () => {
this.shuffledTrainIndex =
(this.shuffledTrainIndex + 1) % this.trainIndices.length;
return this.trainIndices[this.shuffledTrainIndex];
});
}

nextTestBatch(batchSize) {
return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
this.shuffledTestIndex =
(this.shuffledTestIndex + 1) % this.testIndices.length;
return this.testIndices[this.shuffledTestIndex];
});
}

nextBatch(batchSize, data, index) {
const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

for (let i = 0; i < batchSize; i++) {
const idx = index();

const image =
data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
batchImagesArray.set(image, i * IMAGE_SIZE);

const label =
data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
batchLabelsArray.set(label, i * NUM_CLASSES);
}

const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

return {xs, labels};
}
}
Loading