diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3ab4e81 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +node_modules/ +src/bundle.js +src/*.js.map +build/ +bower_components/ +src/images/ + +npm-debug.log +.DS_Store +dist/ +.idea/ + +*~ diff --git a/bower.json b/bower.json new file mode 100644 index 0000000..3062a76 --- /dev/null +++ b/bower.json @@ -0,0 +1,34 @@ +{ + "name": "gan-playground", + "description": "", + "main": "", + "authors": [], + "license": "MIT", + "homepage": "", + "private": true, + "ignore": [ + "**/.*", + "node_modules", + "bower_components", + "test", + "tests" + ], + "dependencies": { + "paper-dialog-scrollable": "1.*.*", + "paper-icon-button": "1.*.*", + "paper-tooltip": "1.*.*", + "paper-toggle-button": "1.*.*", + "paper-radio-group": "1.*.*", + "paper-radio-button": "1.*.*", + "paper-dialog": "1.*.*", + "paper-button": "1.*.*", + "paper-item": "1.*.*", + "paper-dropdown-menu": "1.*.*", + "paper-listbox": "1.*.*", + "iron-icons": "1.*.*", + "paper-slider": "1.*.*", + "polymer": "1.*.*", + "paper-spinner": "1.*.*", + "paper-progress": "1.*.*" + } +} diff --git a/package.json b/package.json new file mode 100644 index 0000000..e0b5348 --- /dev/null +++ b/package.json @@ -0,0 +1,28 @@ +{ + "name": "gan-playground", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "MIT", + "devDependencies": { + "@types/jasmine": "~2.5.53", + "@types/polymer": "~1.1.31", + "bower": "~1.8.0", + "browserify": "~14.4.0", + "cross-spawn": "~5.1.0", + "deeplearn": "~0.2.3", + "http-server": "~0.10.0", + "jasmine-core": "~2.6.4", + "polymer-bundler": "~3.0.1", + "tsify": "~3.0.1", + "tslint": "~5.6.0", + "typedoc": "~0.8.0", + "typescript": "2.4.2", + "uglify-js": "~3.0.28", + "watchify": "~3.9.0" + } +} diff --git a/scripts/build-demo b/scripts/build-demo new file mode 100755 index 0000000..92d761f --- /dev/null +++ b/scripts/build-demo @@ -0,0 +1,36 @@ +#!/usr/bin/env node +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +const path = require('path'); +const spawn = require('cross-spawn'); + +const startTsFilePath = process.argv[2]; +const outDir = process.argv[3]; + +let outputPath; +if (outDir != null) { + outputPath = path.join(outDir, 'bundle.js'); +} else { + outputPath = path.join(path.dirname(startTsFilePath), 'bundle.js') +} + + +const cmd = path.join('node_modules', '.bin', 'browserify'); +const child = spawn(cmd, [startTsFilePath, '-p', '[tsify]', '-o' , outputPath], + {detached: false}); +child.stdout.pipe(process.stdout); +child.stderr.pipe(process.stderr); +child.on('close', () => console.log(`Stored bundle in ${outputPath}`)); diff --git a/scripts/deploy-demo b/scripts/deploy-demo new file mode 100755 index 0000000..c5787a5 --- /dev/null +++ b/scripts/deploy-demo @@ -0,0 +1,36 @@ +#!/usr/bin/env node +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +const path = require('path'); +const spawn = require('cross-spawn'); + +const startTsFilePath = process.argv[2]; +const startHTMLFilePath = process.argv[3]; +const outDir = process.argv[4]; + +const cmd = path.join('scripts', 'build-demo'); +const child = spawn(cmd, [startTsFilePath, outDir], {detached: false}); +child.stdout.pipe(process.stdout); +child.stderr.pipe(process.stderr); +child.on('close', () => { + const bundlePath = path.join(outDir, path.basename(startHTMLFilePath)); + const cmd = path.join('node_modules', '.bin', 'polymer-bundler'); + const child = spawn(cmd, ['--inline-scripts', '--inline-css', + '--out-html', bundlePath, startHTMLFilePath], {detached: false}); + child.stdout.pipe(process.stdout); + child.stderr.pipe(process.stderr); + child.on('close', () => console.log(`Saved bundled demo at ${bundlePath}`)); +}); diff --git a/scripts/watch-demo b/scripts/watch-demo new file mode 100755 index 0000000..5c43b7d --- /dev/null +++ b/scripts/watch-demo @@ -0,0 +1,42 @@ +#!/usr/bin/env node +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +const path = require('path'); +const spawn = require('cross-spawn'); + +const startTsFilePath = process.argv[2]; +const outputPath = path.join(path.dirname(startTsFilePath), 'bundle.js') + +const cmd = path.join('node_modules', '.bin', 'watchify'); +const watchify = spawn(cmd, [startTsFilePath, '-p', '[tsify]', '-v', '--debug', + '-o' , outputPath], {detached: false}); +watchify.stdout.pipe(process.stdout); +watchify.stderr.pipe(process.stderr); + +let httpServerStarted = false; + +console.log('Waiting for initial compile...'); +watchify.stderr.on('data', (data) => { + if (data.toString().includes(`bytes written to`)) { + if (!httpServerStarted) { + const httpCmd = path.join('node_modules', '.bin', 'http-server'); + const httpServer = spawn(httpCmd, ['-c-1'], { detached: false}); + httpServer.stdout.pipe(process.stdout); + httpServer.stderr.pipe(process.stderr); + httpServerStarted = true; + } + } +}); diff --git a/src/chartjs.d.ts b/src/chartjs.d.ts new file mode 100644 index 0000000..76c7d9b --- /dev/null +++ b/src/chartjs.d.ts @@ -0,0 +1,436 @@ +/** + * @license + * This project is licensed under the MIT license. + * Copyrights are respective of each contributor listed at the beginning of each + * definition file. + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// Type definitions for Chart.js +// Project: https://github.com/nnnick/Chart.js +// Definitions by: Alberto Nuti +// Definitions: https://github.com/DefinitelyTyped/DefinitelyTyped + +declare enum ChartType { line, bar, radar, doughnut, polarArea, bubble } +declare enum TimeUnit { + millisecond, + second, + minute, + hour, + day, + week, + month, + quarter, + year +} +interface ChartLegendItem { + text?: string; + fillStyle?: string; + hidden?: boolean; + lineCap?: string; + lineDash?: number[]; + lineDashOffset?: number; + lineJoin?: string; + lineWidth?: number; + strokeStyle?: string; +} +interface ChartTooltipItem { + xLabel?: string; + yLabel?: string; + datasetIndex?: number; + index?: number; +} +interface ChartTooltipCallback { + beforeTitle?: (item?: ChartTooltipItem[], data?: any) => void; + title?: (item?: ChartTooltipItem[], data?: any) => void; + afterTitle?: (item?: ChartTooltipItem[], data?: any) => void; + beforeBody?: (item?: ChartTooltipItem[], data?: any) => void; + beforeLabel?: (tooltipItem?: ChartTooltipItem, data?: any) => void; + label?: (tooltipItem?: ChartTooltipItem, data?: any) => void; + afterLabel?: (tooltipItem?: ChartTooltipItem, data?: any) => void; + afterBody?: (item?: ChartTooltipItem[], data?: any) => void; + beforeFooter?: (item?: ChartTooltipItem[], data?: any) => void; + footer?: (item?: ChartTooltipItem[], data?: any) => void; + afterfooter?: (item?: ChartTooltipItem[], data?: any) => void; +} +interface ChartAnimationParameter { + chartInstance?: any; + animationObject?: any; +} +interface ChartPoint { + x?: number; + y?: number; +} + +interface ChartConfiguration { + type?: string; + data?: ChartData; + options?: ChartOptions; +} + +interface ChartData {} + +interface LinearChartData extends ChartData { + labels?: string[]; + datasets?: ChartDataSets[]; +} + +interface ChartOptions { + responsive?: boolean; + responsiveAnimationDuration?: number; + maintainAspectRatio?: boolean; + events?: string[]; + onClick?: (any?: any) => any; + title?: ChartTitleOptions; + legend?: ChartLegendOptions; + tooltips?: ChartTooltipOptions; + hover?: ChartHoverOptions; + animation?: ChartAnimationOptions; + elements?: ChartElementsOptions; + scales?: ChartScales; +} + +interface ChartFontOptions { + defaultFontColor?: ChartColor; + defaultFontFamily?: string; + defaultFontSize?: number; + defaultFontStyle?: string; +} + +interface ChartTitleOptions { + display?: boolean; + position?: string; + fullWdith?: boolean; + fontSize?: number; + fontFamily?: string; + fontColor?: ChartColor; + fontStyle?: string; + padding?: number; + text?: string; +} + +interface ChartLegendOptions { + display?: boolean; + position?: string; + fullWidth?: boolean; + onClick?: (event: any, legendItem: any) => void; + labels?: ChartLegendLabelOptions; +} + +interface ChartLegendLabelOptions { + boxWidth?: number; + fontSize?: number; + fontStyle?: number; + fontColor?: ChartColor; + fontFamily?: string; + padding?: number; + generateLabels?: (chart: any) => any; +} + +interface ChartTooltipOptions { + enabled?: boolean; + custom?: (a: any) => void; + mode?: string; + backgroundColor?: ChartColor; + titleFontFamily?: string; + titleFontSize?: number; + titleFontStyle?: string; + titleFontColor?: ChartColor; + titleSpacing?: number; + titleMarginBottom?: number; + bodyFontFamily?: string; + bodyFontSize?: number; + bodyFontStyle?: string; + bodyFontColor?: ChartColor; + bodySpacing?: number; + footerFontFamily?: string; + footerFontSize?: number; + footerFontStyle?: string; + footerFontColor?: ChartColor; + footerSpacing?: number; + footerMarginTop?: number; + xPadding?: number; + yPadding?: number; + caretSize?: number; + cornerRadius?: number; + multiKeyBackground?: string; + callbacks?: ChartTooltipCallback; +} + +interface ChartHoverOptions { + mode?: string; + animationDuration?: number; + onHover?: (active: any) => void; +} + +interface ChartAnimationObject { + currentStep?: number; + numSteps?: number; + easing?: string; + render?: (arg: any) => void; + onAnimationProgress?: (arg: any) => void; + onAnimationComplete?: (arg: any) => void; +} + +interface ChartAnimationOptions { + duration?: number; + easing?: string; + onProgress?: (chart: any) => void; + onComplete?: (chart: any) => void; +} + +interface ChartElementsOptions { + point?: ChartPointOptions; + line?: ChartLineOptions; + arg?: ChartArcOtpions; + rectangle?: ChartRectangleOptions; +} + +interface ChartArcOtpions { + backgroundColor?: ChartColor; + borderColor?: ChartColor; + borderWidth?: number; +} + +interface ChartLineOptions { + tension?: number; + backgroundColor?: ChartColor; + borderWidth?: number; + borderColor?: ChartColor; + borderCapStyle?: string; + borderDash?: any[]; + borderDashOffset?: number; + borderJoinStyle?: string; +} + +interface ChartPointOptions { + radius?: number; + pointStyle?: string; + backgroundColor?: ChartColor; + borderWidth?: number; + borderColor?: ChartColor; + hitRadius?: number; + hoverRadius?: number; + hoverBorderWidth?: number; +} + +interface ChartRectangleOptions { + backgroundColor?: ChartColor; + borderWidth?: number; + borderColor?: ChartColor; + borderSkipped?: string; +} +interface GridLineOptions { + display?: boolean; + color?: ChartColor; + lineWidth?: number; + drawBorder?: boolean; + drawOnChartArea?: boolean; + drawticks?: boolean; + tickMarkLength?: number; + zeroLineWidth?: number; + zeroLineColor?: ChartColor; + offsetGridLines?: boolean; +} + +interface ScaleTitleOptions { + display?: boolean; + labelString?: string; + fontColor?: ChartColor; + fontFamily?: string; + fontSize?: number; + fontStyle?: string; +} + +interface TickOptions { + autoSkip?: boolean; + callback?: (value: any, index: any, values: any) => string; + display?: boolean; + fontColor?: ChartColor; + fontFamily?: string; + fontSize?: number; + fontStyle?: string; + labelOffset?: number; + maxRotation?: number; + minRotation?: number; + mirror?: boolean; + padding?: number; + reverse?: boolean; + min?: any; + max?: any; +} +interface AngleLineOptions { + display?: boolean; + color?: ChartColor; + lineWidth?: number; +} + +interface PointLabelOptions { + callback?: (arg: any) => any; + fontColor?: ChartColor; + fontFamily?: string; + fontSize?: number; + fontStyle?: string; +} + +interface TickOptions { + backdropColor?: ChartColor; + backdropPaddingX?: number; + backdropPaddingY?: number; + maxTicksLimit?: number; + showLabelBackdrop?: boolean; +} +interface LinearTickOptions extends TickOptions { + beginAtZero?: boolean; + min?: number; + max?: number; + maxTicksLimit?: number; + stepSize?: number; + suggestedMin?: number; + suggestedMax?: number; +} + +interface LogarithmicTickOptions extends TickOptions { + min?: number; + max?: number; +} + +type ChartColor = string|CanvasGradient|CanvasPattern; + +interface ChartDataSets { + backgroundColor?: ChartColor; + borderWidth?: number; + borderColor?: ChartColor; + borderCapStyle?: string; + borderDash?: number[]; + borderDashOffset?: number; + borderJoinStyle?: string; + data?: number[]|ChartPoint[]; + fill?: boolean; + label?: string; + lineTension?: number; + pointBorderColor?: ChartColor|ChartColor[]; + pointBackgroundColor?: ChartColor|ChartColor[]; + pointBorderWidth?: number|number[]; + pointRadius?: number|number[]; + pointHoverRadius?: number|number[]; + pointHitRadius?: number|number[]; + pointHoverBackgroundColor?: ChartColor|ChartColor[]; + pointHoverBorderColor?: ChartColor|ChartColor[]; + pointHoverBorderWidth?: number|number[]; + pointStyle?: string|string[]|HTMLImageElement|HTMLImageElement[]; + xAxisID?: string; + yAxisID?: string; +} + +interface ChartScales { + type?: string; + display?: boolean; + position?: string; + beforeUpdate?: (scale?: any) => void; + beforeSetDimension?: (scale?: any) => void; + beforeDataLimits?: (scale?: any) => void; + beforeBuildTicks?: (scale?: any) => void; + beforeTickToLabelConversion?: (scale?: any) => void; + beforeCalculateTickRotation?: (scale?: any) => void; + beforeFit?: (scale?: any) => void; + afterUpdate?: (scale?: any) => void; + afterSetDimension?: (scale?: any) => void; + afterDataLimits?: (scale?: any) => void; + afterBuildTicks?: (scale?: any) => void; + afterTickToLabelConversion?: (scale?: any) => void; + afterCalculateTickRotation?: (scale?: any) => void; + afterFit?: (scale?: any) => void; + gridLines?: GridLineOptions; + scaleLabel?: ScaleTitleOptions; + ticks?: TickOptions; + xAxes?: ChartXAxe[]; + yAxes?: ChartYAxe[]; +} + +interface ChartXAxe { + type?: string; + display?: boolean; + id?: string; + stacked?: boolean; + categoryPercentage?: number; + barPercentage?: number; + barThickness?: number; + gridLines?: GridLineOptions; + position?: string; + ticks?: TickOptions; + time?: TimeScale; + scaleLabel?: ScaleTitleOptions; +} + +interface ChartYAxe { + type?: string; + display?: boolean; + id?: string; + stacked?: boolean; + position?: string; + ticks?: TickOptions; + scaleLabel?: ScaleTitleOptions; +} + +interface LinearScale extends ChartScales { + ticks?: LinearTickOptions; +} + +interface LogarithmicScale extends ChartScales { + ticks?: LogarithmicTickOptions; +} + +interface TimeScale extends ChartScales { + format?: string; + displayFormats?: string; + isoWeekday?: boolean; + max?: string; + min?: string; + parser?: string|((arg: any) => any); + round?: string; + tooltipFormat?: string; + unit?: string|TimeUnit; + unitStepSize?: number; +} + +interface RadialLinearScale { + lineArc?: boolean; + angleLines?: AngleLineOptions; + pointLabels?: PointLabelOptions; + ticks?: TickOptions; +} + +declare class Chart { + constructor(context: CanvasRenderingContext2D, options: ChartConfiguration); + config: ChartConfiguration; + destroy: () => {}; + update: (duration?: any, lazy?: any) => {}; + render: (duration?: any, lazy?: any) => {}; + stop: () => {}; + resize: () => {}; + clear: () => {}; + toBase64: () => string; + generateLegend: () => {}; + getElementAtEvent: (e: any) => {}; + getElementsAtEvent: (e: any) => {}[]; + getDatasetAtEvent: (e: any) => {}[]; + + defaults: {global: ChartOptions;} +} diff --git a/src/cifar10-conv.json b/src/cifar10-conv.json new file mode 100644 index 0000000..152d1ac --- /dev/null +++ b/src/cifar10-conv.json @@ -0,0 +1 @@ +[{"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":16},{"layerName":"ReLU"},{"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},{"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":20},{"layerName":"ReLU"},{"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},{"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":20},{"layerName":"ReLU"},{"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},{"layerName":"Flatten"},{"layerName":"Fully connected","hiddenUnits":10}] \ No newline at end of file diff --git a/src/layer_builder.ts b/src/layer_builder.ts new file mode 100644 index 0000000..1b7fd5a --- /dev/null +++ b/src/layer_builder.ts @@ -0,0 +1,372 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// tslint:disable-next-line:max-line-length +import {Array1D, Array2D, Array4D, conv_util, Graph, Initializer, NDArrayInitializer, Tensor, util, VarianceScalingInitializer, ZerosInitializer} from 'deeplearn'; + +/** + * Classes that specify operation parameters, how they affect output shape, + * and methods for building the operations themselves. Any new ops to be added + * to the model builder UI should be added here. + */ + +export type LayerName = 'Fully connected' | 'ReLU' | 'Convolution' | + 'Max pool' | 'Reshape' | 'Flatten'; + +/** + * Creates a layer builder object. + * + * @param layerName The name of the layer to build. + * @param layerBuilderJson An optional LayerBuilder JSON object. This doesn't + * have the prototype methods on them as it comes from serialization. This + * method creates the object with the necessary prototype methods. + */ +export function getLayerBuilder( + layerName: LayerName, layerBuilderJson?: LayerBuilder): LayerBuilder { + let layerBuilder: LayerBuilder; + switch (layerName) { + case 'Fully connected': + layerBuilder = new FullyConnectedLayerBuilder(); + break; + case 'ReLU': + layerBuilder = new ReLULayerBuilder(); + break; + case 'Convolution': + layerBuilder = new Convolution2DLayerBuilder(); + break; + case 'Max pool': + layerBuilder = new MaxPoolLayerBuilder(); + break; + case 'Reshape': + layerBuilder = new ReshapeLayerBuilder(); + break; + case 'Flatten': + layerBuilder = new FlattenLayerBuilder(); + break; + default: + throw new Error('Layer builder for ' + layerName + ' not found.'); + } + + // For layer builders passed as serialized objects, we create the objects and + // set the fields. + if (layerBuilderJson != null) { + for (const prop in layerBuilderJson) { + if (layerBuilderJson.hasOwnProperty(prop)) { + // tslint:disable-next-line:no-any + (layerBuilder as any)[prop] = (layerBuilderJson as any)[prop]; + } + } + } + return layerBuilder; +} + +export interface LayerParam { + label: string; + initialValue(inputShape: number[]): number|string; + type: 'number'|'text'; + min?: number; + max?: number; + setValue(value: number|string): void; + getValue(): number|string; +} + +export type LayerWeightsDict = { + [name: string]: number[] +}; + +export interface LayerBuilder { + layerName: LayerName; + getLayerParams(): LayerParam[]; + getOutputShape(inputShape: number[]): number[]; + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights?: LayerWeightsDict|null): Tensor; + // Return null if no errors, otherwise return an array of errors. + validate(inputShape: number[]): string[]|null; +} + +export class FullyConnectedLayerBuilder implements LayerBuilder { + layerName: LayerName = 'Fully connected'; + hiddenUnits: number; + + getLayerParams(): LayerParam[] { + return [{ + label: 'Hidden units', + initialValue: (inputShape: number[]) => 10, + type: 'number', + min: 1, + max: 1000, + setValue: (value: number) => this.hiddenUnits = value, + getValue: () => this.hiddenUnits + }]; + } + + getOutputShape(inputShape: number[]): number[] { + return [this.hiddenUnits]; + } + + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights: LayerWeightsDict|null): Tensor { + const inputSize = util.sizeFromShape(inputShape); + const wShape: [number, number] = [this.hiddenUnits, inputSize]; + + let weightsInitializer: Initializer; + let biasInitializer: Initializer; + if (weights != null) { + weightsInitializer = + new NDArrayInitializer(Array2D.new(wShape, weights['W'])); + biasInitializer = new NDArrayInitializer(Array1D.new(weights['b'])); + } else { + weightsInitializer = new VarianceScalingInitializer(); + biasInitializer = new ZerosInitializer(); + } + + const useBias = true; + return g.layers.dense( + 'fc1', network, this.hiddenUnits, null, useBias, weightsInitializer, + biasInitializer); + } + + validate(inputShape: number[]) { + if (inputShape.length !== 1) { + return ['Input shape must be a Array1D.']; + } + return null; + } +} + +export class ReLULayerBuilder implements LayerBuilder { + layerName: LayerName = 'ReLU'; + getLayerParams(): LayerParam[] { + return []; + } + + getOutputShape(inputShape: number[]): number[] { + return inputShape; + } + + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights: LayerWeightsDict|null): Tensor { + return g.relu(network); + } + + validate(inputShape: number[]): string[]|null { + return null; + } +} + +export class Convolution2DLayerBuilder implements LayerBuilder { + layerName: LayerName = 'Convolution'; + fieldSize: number; + stride: number; + zeroPad: number; + outputDepth: number; + + getLayerParams(): LayerParam[] { + return [ + { + label: 'Field size', + initialValue: (inputShape: number[]) => 3, + type: 'number', + min: 1, + max: 100, + setValue: (value: number) => this.fieldSize = value, + getValue: () => this.fieldSize + }, + { + label: 'Stride', + initialValue: (inputShape: number[]) => 1, + type: 'number', + min: 1, + max: 100, + setValue: (value: number) => this.stride = value, + getValue: () => this.stride + }, + { + label: 'Zero pad', + initialValue: (inputShape: number[]) => 0, + type: 'number', + min: 0, + max: 100, + setValue: (value: number) => this.zeroPad = value, + getValue: () => this.zeroPad + }, + { + label: 'Output depth', + initialValue: (inputShape: number[]) => + this.outputDepth != null ? this.outputDepth : 1, + type: 'number', + min: 1, + max: 1000, + setValue: (value: number) => this.outputDepth = value, + getValue: () => this.outputDepth + } + ]; + } + + getOutputShape(inputShape: number[]): number[] { + return conv_util.computeOutputShape3D( + inputShape as [number, number, number], this.fieldSize, + this.outputDepth, this.stride, this.zeroPad); + } + + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights: LayerWeightsDict|null): Tensor { + const wShape: [number, number, number, number] = + [this.fieldSize, this.fieldSize, inputShape[2], this.outputDepth]; + let w: Array4D; + let b: Array1D; + if (weights != null) { + w = Array4D.new(wShape, weights['W']); + b = Array1D.new(weights['b']); + } else { + w = Array4D.randTruncatedNormal(wShape, 0, 0.1); + b = Array1D.zeros([this.outputDepth]); + } + const wTensor = g.variable('conv2d-' + index + '-w', w); + const bTensor = g.variable('conv2d-' + index + '-b', b); + return g.conv2d( + network, wTensor, bTensor, this.fieldSize, this.outputDepth, + this.stride, this.zeroPad); + } + + validate(inputShape: number[]) { + if (inputShape.length !== 3) { + return ['Input shape must be a Array3D.']; + } + return null; + } +} + +export class MaxPoolLayerBuilder implements LayerBuilder { + layerName: LayerName = 'Max pool'; + fieldSize: number; + stride: number; + zeroPad: number; + + getLayerParams(): LayerParam[] { + return [ + { + label: 'Field size', + initialValue: (inputShape: number[]) => 3, + type: 'number', + min: 1, + max: 100, + setValue: (value: number) => this.fieldSize = value, + getValue: () => this.fieldSize + }, + { + label: 'Stride', + initialValue: (inputShape: number[]) => 1, + type: 'number', + min: 1, + max: 100, + setValue: (value: number) => this.stride = value, + getValue: () => this.stride + }, + { + label: 'Zero pad', + initialValue: (inputShape: number[]) => 0, + type: 'number', + min: 0, + max: 100, + setValue: (value: number) => this.zeroPad = value, + getValue: () => this.zeroPad + } + ]; + } + + getOutputShape(inputShape: number[]): number[] { + return conv_util.computeOutputShape3D( + inputShape as [number, number, number], this.fieldSize, inputShape[2], + this.stride, this.zeroPad); + } + + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights: LayerWeightsDict|null): Tensor { + return g.maxPool(network, this.fieldSize, this.stride, this.zeroPad); + } + + validate(inputShape: number[]) { + if (inputShape.length !== 3) { + return ['Input shape must be a Array3D.']; + } + return null; + } +} + +export class ReshapeLayerBuilder implements LayerBuilder { + layerName: LayerName = 'Reshape'; + outputShape: number[]; + getLayerParams() { + return [{ + label: 'Shape (comma separated)', + initialValue: (inputShape: number[]) => inputShape.join(', '), + type: 'text' as 'text', + setValue: (value: string) => this.outputShape = + value.split(',').map((value) => +value), + getValue: () => this.outputShape.join(', ') + }]; + } + + getOutputShape(inputShape: number[]): number[] { + return this.outputShape; + } + + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights: LayerWeightsDict|null): Tensor { + return g.reshape(network, this.outputShape); + } + + validate(inputShape: number[]) { + const inputSize = util.sizeFromShape(inputShape); + const outputSize = util.sizeFromShape(this.outputShape); + if (inputSize !== outputSize) { + return [ + `Input size (${inputSize}) must match output size (${outputSize}).` + ]; + } + return null; + } +} + +export class FlattenLayerBuilder implements LayerBuilder { + layerName: LayerName = 'Flatten'; + + getLayerParams(): LayerParam[] { + return []; + } + + getOutputShape(inputShape: number[]): number[] { + return [util.sizeFromShape(inputShape)]; + } + + addLayer( + g: Graph, network: Tensor, inputShape: number[], index: number, + weights: LayerWeightsDict|null): Tensor { + return g.reshape(network, this.getOutputShape(inputShape)); + } + + validate(inputShape: number[]): string[]|null { + return null; + } +} diff --git a/src/mnist-conv.json b/src/mnist-conv.json new file mode 100644 index 0000000..49d826f --- /dev/null +++ b/src/mnist-conv.json @@ -0,0 +1 @@ +[{"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":8},{"layerName":"ReLU"},{"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},{"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":16},{"layerName":"ReLU"},{"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},{"layerName":"Flatten"},{"layerName":"Fully connected","hiddenUnits":10}] diff --git a/src/mnist-fully-connected.json b/src/mnist-fully-connected.json new file mode 100644 index 0000000..8dab4bc --- /dev/null +++ b/src/mnist-fully-connected.json @@ -0,0 +1 @@ +[{"layerName":"Flatten"},{"layerName":"Fully connected","hiddenUnits":128},{"layerName":"ReLU"},{"layerName":"Fully connected","hiddenUnits":32},{"layerName":"ReLU"},{"layerName":"Fully connected","hiddenUnits":10}] diff --git a/src/model-builder-datasets-config.json b/src/model-builder-datasets-config.json new file mode 100644 index 0000000..9dfc901 --- /dev/null +++ b/src/model-builder-datasets-config.json @@ -0,0 +1,64 @@ +{ + "MNIST": { + "data": [{ + "name": "images", + "path": "https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png", + "dataType": "png", + "shape": [28, 28, 1] + }, { + "name": "labels", + "path": "https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8", + "dataType": "uint8", + "shape": [10] + }], + "modelConfigs": { + "Fully connected": { + "path": "mnist-fully-connected.json" + }, + "Convolutional": { + "path": "mnist-conv.json" + } + } + }, + "Fashion MNIST": { + "data": [{ + "name": "images", + "path": "https://storage.googleapis.com/learnjs-data/model-builder/fashion_mnist_images.png", + "dataType": "png", + "shape": [28, 28, 1] + }, { + "name": "labels", + "path": "https://storage.googleapis.com/learnjs-data/model-builder/fashion_mnist_labels_uint8", + "dataType": "uint8", + "shape": [10] + }], + "labelClassNames": ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"], + "modelConfigs": { + "Fully connected": { + "path": "mnist-fully-connected.json" + }, + "Convolutional": { + "path": "mnist-conv.json" + } + } + }, + "CIFAR 10": { + "data": [{ + "name": "images", + "path": "https://storage.googleapis.com/learnjs-data/model-builder/cifar10_images.png", + "dataType": "png", + "shape": [32, 32, 3] + }, { + "name": "labels", + "path": "https://storage.googleapis.com/learnjs-data/model-builder/cifar10_labels_uint8", + "dataType": "uint8", + "shape": [10] + }], + "labelClassNames": ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"], + "modelConfigs": { + "Convolutional": { + "path": "cifar10-conv.json" + } + } + } +} diff --git a/src/model-builder-demo.html b/src/model-builder-demo.html new file mode 100644 index 0000000..5a02ec0 --- /dev/null +++ b/src/model-builder-demo.html @@ -0,0 +1,52 @@ + + + + + + + + + + + + + + deeplearn.js GAN playground + + + + + + + + diff --git a/src/model-builder.html b/src/model-builder.html new file mode 100644 index 0000000..fc690b2 --- /dev/null +++ b/src/model-builder.html @@ -0,0 +1,369 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/model-builder.ts b/src/model-builder.ts new file mode 100644 index 0000000..485c936 --- /dev/null +++ b/src/model-builder.ts @@ -0,0 +1,932 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import './ndarray-image-visualizer'; +import './ndarray-logits-visualizer'; +import './model-layer'; + +// tslint:disable-next-line:max-line-length +import {Array1D, Array3D, DataStats, FeedEntry, Graph, GraphRunner, GraphRunnerEventObserver, InCPUMemoryShuffledInputProviderBuilder, InMemoryDataset, MetricReduction, MomentumOptimizer, SGDOptimizer, RMSPropOptimizer, AdagradOptimizer, NDArray, NDArrayMath, NDArrayMathCPU, NDArrayMathGPU, Optimizer, Scalar, Session, Tensor, util, xhr_dataset, XhrDataset, XhrDatasetConfig} from 'deeplearn'; +import {NDArrayImageVisualizer} from './ndarray-image-visualizer'; +import {NDArrayLogitsVisualizer} from './ndarray-logits-visualizer'; +import {PolymerElement, PolymerHTMLElement} from './polymer-spec'; + +import {LayerBuilder, LayerWeightsDict} from './layer_builder'; +import {ModelLayer} from './model-layer'; +import * as model_builder_util from './model_builder_util'; +import {Normalization} from './tensorflow'; + +const DATASETS_CONFIG_JSON = 'model-builder-datasets-config.json'; + +/** How often to evaluate the model against test data. */ +const EVAL_INTERVAL_MS = 1500; +/** How often to compute the cost. Downloading the cost stalls the GPU. */ +const COST_INTERVAL_MS = 500; +/** How many inference examples to show when evaluating accuracy. */ +const INFERENCE_EXAMPLE_COUNT = 15; +const INFERENCE_IMAGE_SIZE_PX = 100; +/** + * How often to show inference examples. This should be less often than + * EVAL_INTERVAL_MS as we only show inference examples during an eval. + */ +const INFERENCE_EXAMPLE_INTERVAL_MS = 3000; + +// Smoothing factor for the examples/s standalone text statistic. +const EXAMPLE_SEC_STAT_SMOOTHING_FACTOR = .7; + +const TRAIN_TEST_RATIO = 5 / 6; + +const IMAGE_DATA_INDEX = 0; +const LABEL_DATA_INDEX = 1; + +// tslint:disable-next-line:variable-name +export let ModelBuilderPolymer: new () => PolymerHTMLElement = PolymerElement({ + is: 'model-builder', + properties: { + inputShapeDisplay: String, + isValid: Boolean, + inferencesPerSec: Number, + inferenceDuration: Number, + examplesTrained: Number, + examplesPerSec: Number, + totalTimeSec: String, + applicationState: Number, + modelInitialized: Boolean, + showTrainStats: Boolean, + datasetDownloaded: Boolean, + datasetNames: Array, + selectedDatasetName: String, + modelNames: Array, + selectedOptimizerName: String, + optimizerNames: Array, + learningRate: Number, + momentum: Number, + needMomentum: Boolean, + gamma: Number, + needGamma: Boolean, + batchSize: Number, + selectedModelName: String, + selectedNormalizationOption: + {type: Number, value: Normalization.NORMALIZATION_NEGATIVE_ONE_TO_ONE}, + // Stats + showDatasetStats: Boolean, + statsInputMin: Number, + statsInputMax: Number, + statsInputShapeDisplay: String, + statsLabelShapeDisplay: String, + statsExampleCount: Number, + } +}); + +export enum ApplicationState { + IDLE = 1, + TRAINING = 2 +} + +export class ModelBuilder extends ModelBuilderPolymer { + // Polymer properties. + private isValid: boolean; + private totalTimeSec: string; + private applicationState: ApplicationState; + private modelInitialized: boolean; + private showTrainStats: boolean; + private selectedNormalizationOption: number; + + // Datasets and models. + private graphRunner: GraphRunner; + private graph: Graph; + private session: Session; + private optimizer: Optimizer; + private xTensor: Tensor; + private labelTensor: Tensor; + private costTensor: Tensor; + private accuracyTensor: Tensor; + private predictionTensor: Tensor; + + private datasetDownloaded: boolean; + private datasetNames: string[]; + private selectedDatasetName: string; + private modelNames: string[]; + private selectedModelName: string; + private optimizerNames: string[]; + private selectedOptimizerName: string; + private loadedWeights: LayerWeightsDict[]|null; + private dataSets: {[datasetName: string]: InMemoryDataset}; + private dataSet: InMemoryDataset; + private xhrDatasetConfigs: {[datasetName: string]: XhrDatasetConfig}; + private datasetStats: DataStats[]; + private learingRate: number; + private momentum: number; + private needMomentum: boolean; + private gamma: number; + private needGamma: boolean; + private batchSize: number; + + // Stats. + private showDatasetStats: boolean; + private statsInputRange: string; + private statsInputShapeDisplay: string; + private statsLabelShapeDisplay: string; + private statsExampleCount: number; + + // Charts. + private costChart: Chart; + private accuracyChart: Chart; + private examplesPerSecChart: Chart; + private costChartData: ChartPoint[]; + private accuracyChartData: ChartPoint[]; + private examplesPerSecChartData: ChartPoint[]; + + private trainButton: HTMLButtonElement; + + // Visualizers. + private inputNDArrayVisualizers: NDArrayImageVisualizer[]; + private outputNDArrayVisualizers: NDArrayLogitsVisualizer[]; + + private inputShape: number[]; + private labelShape: number[]; + private examplesPerSec: number; + private examplesTrained: number; + private inferencesPerSec: number; + private inferenceDuration: number; + + private inputLayer: ModelLayer; + private hiddenLayers: ModelLayer[]; + + private layersContainer: HTMLDivElement; + + private math: NDArrayMath; + // Keep one instance of each NDArrayMath so we don't create a user-initiated + // number of NDArrayMathGPU's. + private mathGPU: NDArrayMathGPU; + private mathCPU: NDArrayMathCPU; + + ready() { + this.mathGPU = new NDArrayMathGPU(); + this.mathCPU = new NDArrayMathCPU(); + this.math = this.mathGPU; + + const eventObserver: GraphRunnerEventObserver = { + batchesTrainedCallback: (batchesTrained: number) => + this.displayBatchesTrained(batchesTrained), + avgCostCallback: (avgCost: Scalar) => this.displayCost(avgCost), + metricCallback: (metric: Scalar) => this.displayAccuracy(metric), + inferenceExamplesCallback: + (inputFeeds: FeedEntry[][], inferenceOutputs: NDArray[]) => + this.displayInferenceExamplesOutput(inputFeeds, inferenceOutputs), + inferenceExamplesPerSecCallback: (examplesPerSec: number) => + this.displayInferenceExamplesPerSec(examplesPerSec), + trainExamplesPerSecCallback: (examplesPerSec: number) => + this.displayExamplesPerSec(examplesPerSec), + totalTimeCallback: (totalTimeSec: number) => this.totalTimeSec = + totalTimeSec.toFixed(1), + }; + this.graphRunner = new GraphRunner(this.math, this.session, eventObserver); + this.optimizer = new MomentumOptimizer(this.learingRate, this.momentum); + + // Set up datasets. + this.populateDatasets(); + + this.querySelector('#dataset-dropdown .dropdown-content') + .addEventListener( + // tslint:disable-next-line:no-any + 'iron-activate', (event: any) => { + // Update the dataset. + const datasetName = event.detail.selected; + this.updateSelectedDataset(datasetName); + + // TODO(nsthorat): Remember the last model used for each dataset. + this.removeAllLayers(); + }); + this.querySelector('#model-dropdown .dropdown-content') + .addEventListener( + // tslint:disable-next-line:no-any + 'iron-activate', (event: any) => { + // Update the model. + const modelName = event.detail.selected; + this.updateSelectedModel(modelName); + }); + + { + const normalizationDropdown = + this.querySelector('#normalization-dropdown .dropdown-content'); + // tslint:disable-next-line:no-any + normalizationDropdown.addEventListener('iron-activate', (event: any) => { + const selectedNormalizationOption = event.detail.selected; + this.applyNormalization(selectedNormalizationOption); + this.setupDatasetStats(); + }); + } + this.querySelector("#optimizer-dropdown .dropdown-content") + // tslint:disable-next-line:no-any + .addEventListener('iron-activate', (event: any) => { + // Activate, deactivate hyper parameter inputs. + this.refreshHyperParamRequirements(event.detail.selected); + }); + this.learningRate = 0.1; + this.momentum = 0.1; + this.needMomentum = true; + this.gamma = 0.1; + this.needGamma = false; + this.batchSize = 64; + // Default optimizer is momentum + this.selectedOptimizerName = "momentum"; + this.optimizerNames = ["sgd", "momentum", "rmsprop", "adagrad"]; + + this.applicationState = ApplicationState.IDLE; + this.loadedWeights = null; + this.modelInitialized = false; + this.showTrainStats = false; + this.showDatasetStats = false; + + const addButton = this.querySelector('#add-layer'); + addButton.addEventListener('click', () => this.addLayer()); + + const downloadModelButton = this.querySelector('#download-model'); + downloadModelButton.addEventListener('click', () => this.downloadModel()); + const uploadModelButton = this.querySelector('#upload-model'); + uploadModelButton.addEventListener('click', () => this.uploadModel()); + this.setupUploadModelButton(); + + const uploadWeightsButton = this.querySelector('#upload-weights'); + uploadWeightsButton.addEventListener('click', () => this.uploadWeights()); + this.setupUploadWeightsButton(); + + const stopButton = this.querySelector('#stop'); + stopButton.addEventListener('click', () => { + this.applicationState = ApplicationState.IDLE; + this.graphRunner.stopTraining(); + }); + + this.trainButton = this.querySelector('#train') as HTMLButtonElement; + this.trainButton.addEventListener('click', () => { + this.createModel(); + this.startTraining(); + }); + + this.querySelector('#environment-toggle') + .addEventListener('change', (event) => { + this.math = + // tslint:disable-next-line:no-any + (event.target as any).active ? this.mathGPU : this.mathCPU; + this.graphRunner.setMath(this.math); + }); + + this.hiddenLayers = []; + this.examplesPerSec = 0; + this.inferencesPerSec = 0; + } + + isTraining(applicationState: ApplicationState): boolean { + return applicationState === ApplicationState.TRAINING; + } + + isIdle(applicationState: ApplicationState): boolean { + return applicationState === ApplicationState.IDLE; + } + + + + private getTestData(): NDArray[][] { + const data = this.dataSet.getData(); + if (data == null) { + return null; + } + const [images, labels] = this.dataSet.getData() as [NDArray[], NDArray[]]; + + const start = Math.floor(TRAIN_TEST_RATIO * images.length); + + return [images.slice(start), labels.slice(start)]; + } + + private getTrainingData(): NDArray[][] { + const [images, labels] = this.dataSet.getData() as [NDArray[], NDArray[]]; + + const end = Math.floor(TRAIN_TEST_RATIO * images.length); + + return [images.slice(0, end), labels.slice(0, end)]; + } + + private startInference() { + const testData = this.getTestData(); + if (testData == null) { + // Dataset not ready yet. + return; + } + if (this.isValid && (testData != null)) { + const inferenceShuffledInputProviderGenerator = + new InCPUMemoryShuffledInputProviderBuilder(testData); + const [inferenceInputProvider, inferenceLabelProvider] = + inferenceShuffledInputProviderGenerator.getInputProviders(); + + const inferenceFeeds = [ + {tensor: this.xTensor, data: inferenceInputProvider}, + {tensor: this.labelTensor, data: inferenceLabelProvider} + ]; + + this.graphRunner.infer( + this.predictionTensor, inferenceFeeds, INFERENCE_EXAMPLE_INTERVAL_MS, + INFERENCE_EXAMPLE_COUNT); + } + } + + private resetHyperParamRequirements() { + this.needMomentum = false; + this.needGamma = false; + } + + /** + * Set flag to disable input by optimizer selection. + */ + private refreshHyperParamRequirements(optimizerName: string) { + this.resetHyperParamRequirements(); + switch (optimizerName) { + case "sgd": { + // No additional hyper parameters + break; + } + case "momentum": { + this.needMomentum = true; + break; + } + case "rmsprop": { + this.needMomentum = true; + this.needGamma = true; + break; + } + case "adagrad": { + this.needMomentum = true; + break; + } + default: { + throw new Error(`Unknown optimizer "${this.selectedOptimizerName}"`); + } + } + } + + private createOptimizer() { + switch (this.selectedOptimizerName) { + case 'sgd': { + return new SGDOptimizer(+this.learningRate); + } + case 'momentum': { + return new MomentumOptimizer(+this.learningRate, +this.momentum); + } + case 'rmsprop': { + return new RMSPropOptimizer(+this.learningRate, +this.gamma); + } + case 'adagrad': { + return new AdagradOptimizer(+this.learningRate, +this.momentum); + } + default: { + throw new Error(`Unknown optimizer "${this.selectedOptimizerName}"`); + } + } + } + + private startTraining() { + const trainingData = this.getTrainingData(); + const testData = this.getTestData(); + + // Recreate optimizer with the selected optimizer and hyperparameters. + this.optimizer = this.createOptimizer(); + + if (this.isValid && (trainingData != null) && (testData != null)) { + this.recreateCharts(); + this.graphRunner.resetStatistics(); + + const trainingShuffledInputProviderGenerator = + new InCPUMemoryShuffledInputProviderBuilder(trainingData); + const [trainInputProvider, trainLabelProvider] = + trainingShuffledInputProviderGenerator.getInputProviders(); + + const trainFeeds = [ + {tensor: this.xTensor, data: trainInputProvider}, + {tensor: this.labelTensor, data: trainLabelProvider} + ]; + + const accuracyShuffledInputProviderGenerator = + new InCPUMemoryShuffledInputProviderBuilder(testData); + const [accuracyInputProvider, accuracyLabelProvider] = + accuracyShuffledInputProviderGenerator.getInputProviders(); + + const accuracyFeeds = [ + {tensor: this.xTensor, data: accuracyInputProvider}, + {tensor: this.labelTensor, data: accuracyLabelProvider} + ]; + + this.graphRunner.train( + this.costTensor, trainFeeds, this.batchSize, this.optimizer, + undefined /** numBatches */, this.accuracyTensor, accuracyFeeds, + this.batchSize, MetricReduction.MEAN, EVAL_INTERVAL_MS, + COST_INTERVAL_MS); + + this.showTrainStats = true; + this.applicationState = ApplicationState.TRAINING; + } + } + + private createModel() { + if (this.session != null) { + this.session.dispose(); + } + + this.modelInitialized = false; + if (this.isValid === false) { + return; + } + + this.graph = new Graph(); + const g = this.graph; + this.xTensor = g.placeholder('input', this.inputShape); + this.labelTensor = g.placeholder('label', this.labelShape); + + let network = this.xTensor; + + for (let i = 0; i < this.hiddenLayers.length; i++) { + let weights: LayerWeightsDict|null = null; + if (this.loadedWeights != null) { + weights = this.loadedWeights[i]; + } + network = this.hiddenLayers[i].addLayer(g, network, i, weights); + } + this.predictionTensor = network; + this.costTensor = + g.softmaxCrossEntropyCost(this.predictionTensor, this.labelTensor); + this.accuracyTensor = + g.argmaxEquals(this.predictionTensor, this.labelTensor); + + this.loadedWeights = null; + + this.session = new Session(g, this.math); + this.graphRunner.setSession(this.session); + + this.startInference(); + + this.modelInitialized = true; + } + + private populateDatasets() { + this.dataSets = {}; + xhr_dataset.getXhrDatasetConfig(DATASETS_CONFIG_JSON) + .then( + xhrDatasetConfigs => { + for (const datasetName in xhrDatasetConfigs) { + if (xhrDatasetConfigs.hasOwnProperty(datasetName)) { + this.dataSets[datasetName] = + new XhrDataset(xhrDatasetConfigs[datasetName]); + } + } + this.datasetNames = Object.keys(this.dataSets); + this.selectedDatasetName = this.datasetNames[0]; + this.xhrDatasetConfigs = xhrDatasetConfigs; + this.updateSelectedDataset(this.datasetNames[0]); + }, + error => { + throw new Error('Dataset config could not be loaded: ' + error); + }); + } + + private updateSelectedDataset(datasetName: string) { + if (this.dataSet != null) { + this.dataSet.removeNormalization(IMAGE_DATA_INDEX); + } + + this.graphRunner.stopTraining(); + this.graphRunner.stopInferring(); + + if (this.dataSet != null) { + this.dataSet.dispose(); + } + + this.selectedDatasetName = datasetName; + this.selectedModelName = ''; + this.dataSet = this.dataSets[datasetName]; + this.datasetDownloaded = false; + this.showDatasetStats = false; + + this.dataSet.fetchData().then(() => { + this.datasetDownloaded = true; + this.applyNormalization(this.selectedNormalizationOption); + this.setupDatasetStats(); + if (this.isValid) { + this.createModel(); + this.startInference(); + } + // Get prebuilt models. + this.populateModelDropdown(); + }); + + this.inputShape = this.dataSet.getDataShape(IMAGE_DATA_INDEX); + this.labelShape = this.dataSet.getDataShape(LABEL_DATA_INDEX); + + this.layersContainer = + this.querySelector('#hidden-layers') as HTMLDivElement; + + this.inputLayer = this.querySelector('#input-layer') as ModelLayer; + this.inputLayer.outputShapeDisplay = + model_builder_util.getDisplayShape(this.inputShape); + + const labelShapeDisplay = + model_builder_util.getDisplayShape(this.labelShape); + const costLayer = this.querySelector('#cost-layer') as ModelLayer; + costLayer.inputShapeDisplay = labelShapeDisplay; + costLayer.outputShapeDisplay = labelShapeDisplay; + + const outputLayer = this.querySelector('#output-layer') as ModelLayer; + outputLayer.inputShapeDisplay = labelShapeDisplay; + + // Setup the inference example container. + // TODO(nsthorat): Generalize this. + const inferenceContainer = + this.querySelector('#inference-container') as HTMLElement; + inferenceContainer.innerHTML = ''; + this.inputNDArrayVisualizers = []; + this.outputNDArrayVisualizers = []; + for (let i = 0; i < INFERENCE_EXAMPLE_COUNT; i++) { + const inferenceExampleElement = document.createElement('div'); + inferenceExampleElement.className = 'inference-example'; + + // Set up the input visualizer. + const ndarrayImageVisualizer = + document.createElement('ndarray-image-visualizer') as + NDArrayImageVisualizer; + ndarrayImageVisualizer.setShape(this.inputShape); + ndarrayImageVisualizer.setSize( + INFERENCE_IMAGE_SIZE_PX, INFERENCE_IMAGE_SIZE_PX); + this.inputNDArrayVisualizers.push(ndarrayImageVisualizer); + inferenceExampleElement.appendChild(ndarrayImageVisualizer); + + // Set up the output ndarray visualizer. + const ndarrayLogitsVisualizer = + document.createElement('ndarray-logits-visualizer') as + NDArrayLogitsVisualizer; + ndarrayLogitsVisualizer.initialize( + INFERENCE_IMAGE_SIZE_PX, INFERENCE_IMAGE_SIZE_PX); + this.outputNDArrayVisualizers.push(ndarrayLogitsVisualizer); + inferenceExampleElement.appendChild(ndarrayLogitsVisualizer); + + inferenceContainer.appendChild(inferenceExampleElement); + } + } + + private populateModelDropdown() { + const modelNames = ['Custom']; + + const modelConfigs = + this.xhrDatasetConfigs[this.selectedDatasetName].modelConfigs; + for (const modelName in modelConfigs) { + if (modelConfigs.hasOwnProperty(modelName)) { + modelNames.push(modelName); + } + } + this.modelNames = modelNames; + this.selectedModelName = modelNames[modelNames.length - 1]; + this.updateSelectedModel(this.selectedModelName); + } + + private updateSelectedModel(modelName: string) { + this.removeAllLayers(); + if (modelName === 'Custom') { + // TODO(nsthorat): Remember the custom layers. + return; + } + + this.loadModelFromPath(this.xhrDatasetConfigs[this.selectedDatasetName] + .modelConfigs[modelName] + .path); + } + + private loadModelFromPath(modelPath: string) { + const xhr = new XMLHttpRequest(); + xhr.open('GET', modelPath); + + xhr.onload = () => { + this.loadModelFromJson(xhr.responseText); + }; + xhr.onerror = (error) => { + throw new Error( + 'Model could not be fetched from ' + modelPath + ': ' + error); + }; + xhr.send(); + } + + private setupDatasetStats() { + this.datasetStats = this.dataSet.getStats(); + this.statsExampleCount = this.datasetStats[IMAGE_DATA_INDEX].exampleCount; + this.statsInputRange = '[' + this.datasetStats[IMAGE_DATA_INDEX].inputMin + + ', ' + this.datasetStats[IMAGE_DATA_INDEX].inputMax + ']'; + this.statsInputShapeDisplay = model_builder_util.getDisplayShape( + this.datasetStats[IMAGE_DATA_INDEX].shape); + this.statsLabelShapeDisplay = model_builder_util.getDisplayShape( + this.datasetStats[LABEL_DATA_INDEX].shape); + this.showDatasetStats = true; + } + + private applyNormalization(selectedNormalizationOption: number) { + switch (selectedNormalizationOption) { + case Normalization.NORMALIZATION_NEGATIVE_ONE_TO_ONE: { + this.dataSet.normalizeWithinBounds(IMAGE_DATA_INDEX, -1, 1); + break; + } + case Normalization.NORMALIZATION_ZERO_TO_ONE: { + this.dataSet.normalizeWithinBounds(IMAGE_DATA_INDEX, 0, 1); + break; + } + case Normalization.NORMALIZATION_NONE: { + this.dataSet.removeNormalization(IMAGE_DATA_INDEX); + break; + } + default: { throw new Error('Normalization option must be 0, 1, or 2'); } + } + this.setupDatasetStats(); + } + + private recreateCharts() { + this.costChartData = []; + if (this.costChart != null) { + this.costChart.destroy(); + } + this.costChart = + this.createChart('cost-chart', 'Cost', this.costChartData, 0); + + if (this.accuracyChart != null) { + this.accuracyChart.destroy(); + } + this.accuracyChartData = []; + this.accuracyChart = this.createChart( + 'accuracy-chart', 'Accuracy', this.accuracyChartData, 0, 100); + + if (this.examplesPerSecChart != null) { + this.examplesPerSecChart.destroy(); + } + this.examplesPerSecChartData = []; + this.examplesPerSecChart = this.createChart( + 'examplespersec-chart', 'Examples/sec', this.examplesPerSecChartData, + 0); + } + + private createChart( + canvasId: string, label: string, data: ChartData[], min?: number, + max?: number): Chart { + const context = (document.getElementById(canvasId) as HTMLCanvasElement) + .getContext('2d') as CanvasRenderingContext2D; + return new Chart(context, { + type: 'line', + data: { + datasets: [{ + data, + fill: false, + label, + pointRadius: 0, + borderColor: 'rgba(75,192,192,1)', + borderWidth: 1, + lineTension: 0, + pointHitRadius: 8 + }] + }, + options: { + animation: {duration: 0}, + responsive: false, + scales: { + xAxes: [{type: 'linear', position: 'bottom'}], + yAxes: [{ + ticks: { + max, + min, + } + }] + } + } + }); + } + + displayBatchesTrained(totalBatchesTrained: number) { + this.examplesTrained = this.batchSize * totalBatchesTrained; + } + + displayCost(avgCost: Scalar) { + this.costChartData.push( + {x: this.graphRunner.getTotalBatchesTrained(), y: avgCost.get()}); + this.costChart.update(); + } + + displayAccuracy(accuracy: Scalar) { + this.accuracyChartData.push({ + x: this.graphRunner.getTotalBatchesTrained(), + y: accuracy.get() * 100 + }); + this.accuracyChart.update(); + } + + displayInferenceExamplesPerSec(examplesPerSec: number) { + this.inferencesPerSec = + this.smoothExamplesPerSec(this.inferencesPerSec, examplesPerSec); + this.inferenceDuration = Number((1000 / examplesPerSec).toPrecision(3)); + } + + displayExamplesPerSec(examplesPerSec: number) { + this.examplesPerSecChartData.push( + {x: this.graphRunner.getTotalBatchesTrained(), y: examplesPerSec}); + this.examplesPerSecChart.update(); + this.examplesPerSec = + this.smoothExamplesPerSec(this.examplesPerSec, examplesPerSec); + } + + private smoothExamplesPerSec( + lastExamplesPerSec: number, nextExamplesPerSec: number): number { + return Number((EXAMPLE_SEC_STAT_SMOOTHING_FACTOR * lastExamplesPerSec + + (1 - EXAMPLE_SEC_STAT_SMOOTHING_FACTOR) * nextExamplesPerSec) + .toPrecision(3)); + } + + displayInferenceExamplesOutput( + inputFeeds: FeedEntry[][], inferenceOutputs: NDArray[]) { + let images: Array3D[] = []; + const logits: Array1D[] = []; + const labels: Array1D[] = []; + for (let i = 0; i < inputFeeds.length; i++) { + images.push(inputFeeds[i][IMAGE_DATA_INDEX].data as Array3D); + labels.push(inputFeeds[i][LABEL_DATA_INDEX].data as Array1D); + logits.push(inferenceOutputs[i] as Array1D); + } + + images = + this.dataSet.unnormalizeExamples(images, IMAGE_DATA_INDEX) as Array3D[]; + + // Draw the images. + for (let i = 0; i < inputFeeds.length; i++) { + this.inputNDArrayVisualizers[i].saveImageDataFromNDArray(images[i]); + } + + // Draw the logits. + for (let i = 0; i < inputFeeds.length; i++) { + const softmaxLogits = this.math.softmax(logits[i]); + + this.outputNDArrayVisualizers[i].drawLogits( + softmaxLogits, labels[i], + this.xhrDatasetConfigs[this.selectedDatasetName].labelClassNames); + this.inputNDArrayVisualizers[i].draw(); + + softmaxLogits.dispose(); + } + } + + addLayer(): ModelLayer { + const modelLayer = document.createElement('model-layer') as ModelLayer; + modelLayer.className = 'layer'; + this.layersContainer.appendChild(modelLayer); + + const lastHiddenLayer = this.hiddenLayers[this.hiddenLayers.length - 1]; + const lastOutputShape = lastHiddenLayer != null ? + lastHiddenLayer.getOutputShape() : + this.inputShape; + this.hiddenLayers.push(modelLayer); + modelLayer.initialize(this, lastOutputShape); + return modelLayer; + } + + removeLayer(modelLayer: ModelLayer) { + this.layersContainer.removeChild(modelLayer); + this.hiddenLayers.splice(this.hiddenLayers.indexOf(modelLayer), 1); + this.layerParamChanged(); + } + + private removeAllLayers() { + for (let i = 0; i < this.hiddenLayers.length; i++) { + this.layersContainer.removeChild(this.hiddenLayers[i]); + } + this.hiddenLayers = []; + this.layerParamChanged(); + } + + private validateModel() { + let valid = true; + for (let i = 0; i < this.hiddenLayers.length; ++i) { + valid = valid && this.hiddenLayers[i].isValid(); + } + if (this.hiddenLayers.length > 0) { + const lastLayer = this.hiddenLayers[this.hiddenLayers.length - 1]; + valid = valid && + util.arraysEqual(this.labelShape, lastLayer.getOutputShape()); + } + this.isValid = valid && (this.hiddenLayers.length > 0); + } + + layerParamChanged() { + // Go through each of the model layers and propagate shapes. + let lastOutputShape = this.inputShape; + for (let i = 0; i < this.hiddenLayers.length; i++) { + lastOutputShape = this.hiddenLayers[i].setInputShape(lastOutputShape); + } + this.validateModel(); + + if (this.isValid) { + this.createModel(); + this.startInference(); + } + } + + private downloadModel() { + const modelJson = this.getModelAsJson(); + const blob = new Blob([modelJson], {type: 'text/json'}); + const textFile = window.URL.createObjectURL(blob); + + // Force a download. + const a = document.createElement('a'); + document.body.appendChild(a); + a.style.display = 'none'; + a.href = textFile; + // tslint:disable-next-line:no-any + (a as any).download = this.selectedDatasetName + '_model'; + a.click(); + + document.body.removeChild(a); + window.URL.revokeObjectURL(textFile); + } + + private uploadModel() { + (this.querySelector('#model-file') as HTMLInputElement).click(); + } + + private setupUploadModelButton() { + // Show and setup the load view button. + const fileInput = this.querySelector('#model-file') as HTMLInputElement; + fileInput.addEventListener('change', event => { + const file = fileInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = (evt) => { + this.removeAllLayers(); + const modelJson: string = fileReader.result; + this.loadModelFromJson(modelJson); + }; + fileReader.readAsText(file); + }); + } + + private getModelAsJson(): string { + const layerBuilders: LayerBuilder[] = []; + for (let i = 0; i < this.hiddenLayers.length; i++) { + layerBuilders.push(this.hiddenLayers[i].layerBuilder); + } + return JSON.stringify(layerBuilders); + } + + private loadModelFromJson(modelJson: string) { + let lastOutputShape = this.inputShape; + + const layerBuilders = JSON.parse(modelJson) as LayerBuilder[]; + for (let i = 0; i < layerBuilders.length; i++) { + const modelLayer = this.addLayer(); + modelLayer.loadParamsFromLayerBuilder(lastOutputShape, layerBuilders[i]); + lastOutputShape = this.hiddenLayers[i].setInputShape(lastOutputShape); + } + this.validateModel(); + } + + private uploadWeights() { + (this.querySelector('#weights-file') as HTMLInputElement).click(); + } + + private setupUploadWeightsButton() { + // Show and setup the load view button. + const fileInput = this.querySelector('#weights-file') as HTMLInputElement; + fileInput.addEventListener('change', event => { + const file = fileInput.files[0]; + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + fileInput.value = ''; + const fileReader = new FileReader(); + fileReader.onload = (evt) => { + const weightsJson: string = fileReader.result; + this.loadWeightsFromJson(weightsJson); + this.createModel(); + this.startInference(); + }; + fileReader.readAsText(file); + }); + } + + private loadWeightsFromJson(weightsJson: string) { + this.loadedWeights = JSON.parse(weightsJson) as LayerWeightsDict[]; + } +} + +document.registerElement(ModelBuilder.prototype.is, ModelBuilder); diff --git a/src/model-layer.html b/src/model-layer.html new file mode 100644 index 0000000..44dc475 --- /dev/null +++ b/src/model-layer.html @@ -0,0 +1,135 @@ + + + + + + + + + + + diff --git a/src/model-layer.ts b/src/model-layer.ts new file mode 100644 index 0000000..2268641 --- /dev/null +++ b/src/model-layer.ts @@ -0,0 +1,192 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Graph, Tensor} from 'deeplearn'; +import {PolymerElement, PolymerHTMLElement} from './polymer-spec'; + +import * as layer_builder from './layer_builder'; +import {LayerBuilder, LayerName, LayerWeightsDict} from './layer_builder'; +import {ModelBuilder} from './model-builder'; +import * as model_builder_util from './model_builder_util'; + +// tslint:disable-next-line:variable-name +export let ModelLayerPolymer: new () => PolymerHTMLElement = PolymerElement({ + is: 'model-layer', + properties: { + layerName: String, + inputShapeDisplay: String, + outputShapeDisplay: String, + isStatic: {type: Boolean, value: false}, + layerNames: Array, + selectedLayerName: String, + hasError: {type: Boolean, value: false}, + errorMessages: Array, + } +}); + +export class ModelLayer extends ModelLayerPolymer { + // Polymer properties. + inputShapeDisplay: string; + outputShapeDisplay: string; + private layerNames: LayerName[]; + private selectedLayerName: LayerName; + private hasError: boolean; + private errorMessages: string[]; + + private modelBuilder: ModelBuilder; + layerBuilder: LayerBuilder; + private inputShape: number[]; + private outputShape: number[]; + + private paramContainer: HTMLDivElement; + + initialize(modelBuilder: ModelBuilder, inputShape: number[]) { + this.modelBuilder = modelBuilder; + this.paramContainer = + this.querySelector('.param-container') as HTMLDivElement; + this.layerNames = [ + 'Fully connected', 'ReLU', 'Convolution', 'Max pool', 'Reshape', 'Flatten' + ]; + this.inputShape = inputShape; + this.buildParamsUI('Fully connected', this.inputShape); + + this.querySelector('.dropdown-content') + .addEventListener( + // tslint:disable-next-line:no-any + 'iron-activate', (event: any) => { + this.buildParamsUI( + event.detail.selected as LayerName, this.inputShape); + }); + + this.querySelector('#remove-layer').addEventListener('click', (event) => { + modelBuilder.removeLayer(this); + }); + } + + setInputShape(shape: number[]): number[] { + this.inputShape = shape; + this.inputShapeDisplay = + model_builder_util.getDisplayShape(this.inputShape); + + const errors: string[] = []; + const validationErrors = this.layerBuilder.validate(this.inputShape); + if (validationErrors != null) { + for (let i = 0; i < validationErrors.length; i++) { + errors.push('Error: ' + validationErrors[i]); + } + } + + try { + this.outputShape = this.layerBuilder.getOutputShape(this.inputShape); + } catch (e) { + errors.push(e); + } + this.outputShapeDisplay = + model_builder_util.getDisplayShape(this.outputShape); + + if (errors.length > 0) { + this.hasError = true; + this.errorMessages = errors; + } else { + this.hasError = false; + this.errorMessages = []; + } + + return this.outputShape; + } + + isValid(): boolean { + return !this.hasError; + } + + getOutputShape(): number[] { + return this.outputShape; + } + + addLayer( + g: Graph, network: Tensor, index: number, + weights: LayerWeightsDict|null): Tensor { + return this.layerBuilder.addLayer( + g, network, this.inputShape, index, weights); + } + + /** + * Build parameters for the UI for a given op type. This is called when the + * op is added, and when the op type changes. + */ + buildParamsUI( + layerName: LayerName, inputShape: number[], + layerBuilderJson?: LayerBuilder) { + this.selectedLayerName = layerName; + + this.layerBuilder = + layer_builder.getLayerBuilder(layerName, layerBuilderJson); + + // Clear any existing parameters. + this.paramContainer.innerHTML = ''; + + // Add all the parameters to the UI. + const layerParams = this.layerBuilder.getLayerParams(); + for (let i = 0; i < layerParams.length; i++) { + const initialValue = layerBuilderJson != null ? + layerParams[i].getValue() : + layerParams[i].initialValue(inputShape); + this.addParamField( + layerParams[i].label, initialValue, layerParams[i].setValue, + layerParams[i].type, layerParams[i].min, layerParams[i].max); + } + this.modelBuilder.layerParamChanged(); + } + + loadParamsFromLayerBuilder( + inputShape: number[], layerBuilderJson: LayerBuilder) { + this.buildParamsUI( + layerBuilderJson.layerName, inputShape, layerBuilderJson); + } + + private addParamField( + label: string, initialValue: number|string, + setValue: (value: number|string) => void, type: 'number'|'text', + min?: number, max?: number) { + const input = document.createElement('paper-input'); + input.setAttribute('always-float-label', 'true'); + input.setAttribute('label', label); + input.setAttribute('value', '' + initialValue); + input.setAttribute('type', type); + if (type === 'number') { + input.setAttribute('min', '' + min); + input.setAttribute('max', '' + max); + } + input.className = 'param-input'; + this.paramContainer.appendChild(input); + + // Update the parent when this changes. + input.addEventListener('input', (event) => { + if (type === 'number') { + // tslint:disable-next-line:no-any + setValue((event.target as any).valueAsNumber as number); + } else { + // tslint:disable-next-line:no-any + setValue((event.target as any).value as string); + } + this.modelBuilder.layerParamChanged(); + }); + setValue(initialValue); + } +} + +document.registerElement(ModelLayer.prototype.is, ModelLayer); diff --git a/src/model_builder_util.ts b/src/model_builder_util.ts new file mode 100644 index 0000000..348e203 --- /dev/null +++ b/src/model_builder_util.ts @@ -0,0 +1,20 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export function getDisplayShape(shape: number[]) { + return '[' + shape + ']'; +} diff --git a/src/ndarray-image-visualizer.html b/src/ndarray-image-visualizer.html new file mode 100644 index 0000000..304a9b2 --- /dev/null +++ b/src/ndarray-image-visualizer.html @@ -0,0 +1,25 @@ + + + + + + diff --git a/src/ndarray-image-visualizer.ts b/src/ndarray-image-visualizer.ts new file mode 100644 index 0000000..4971c16 --- /dev/null +++ b/src/ndarray-image-visualizer.ts @@ -0,0 +1,90 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Array3D} from 'deeplearn'; +import {PolymerElement, PolymerHTMLElement} from './polymer-spec'; + +// tslint:disable-next-line +export let NDArrayImageVisualizerPolymer: new () => PolymerHTMLElement = + PolymerElement({is: 'ndarray-image-visualizer', properties: {}}); + +export class NDArrayImageVisualizer extends NDArrayImageVisualizerPolymer { + private canvas: HTMLCanvasElement; + private canvasContext: CanvasRenderingContext2D; + private imageData: ImageData; + + ready() { + this.canvas = this.querySelector('#canvas') as HTMLCanvasElement; + this.canvas.width = 0; + this.canvas.height = 0; + this.canvasContext = + this.canvas.getContext('2d') as CanvasRenderingContext2D; + this.canvas.style.display = 'none'; + } + + setShape(shape: number[]) { + this.canvas.width = shape[1]; + this.canvas.height = shape[0]; + } + + setSize(width: number, height: number) { + this.canvas.style.width = width + 'px'; + this.canvas.style.height = height + 'px'; + } + + saveImageDataFromNDArray(ndarray: Array3D) { + this.imageData = this.canvasContext.createImageData( + this.canvas.width, this.canvas.height); + if (ndarray.shape[2] === 1) { + this.drawGrayscaleImageData(ndarray); + } else if (ndarray.shape[2] === 3) { + this.drawRGBImageData(ndarray); + } + } + + drawRGBImageData(ndarray: Array3D) { + let pixelOffset = 0; + for (let i = 0; i < ndarray.shape[0]; i++) { + for (let j = 0; j < ndarray.shape[1]; j++) { + this.imageData.data[pixelOffset++] = ndarray.get(i, j, 0); + this.imageData.data[pixelOffset++] = ndarray.get(i, j, 1); + this.imageData.data[pixelOffset++] = ndarray.get(i, j, 2); + this.imageData.data[pixelOffset++] = 255; + } + } + } + + drawGrayscaleImageData(ndarray: Array3D) { + let pixelOffset = 0; + for (let i = 0; i < ndarray.shape[0]; i++) { + for (let j = 0; j < ndarray.shape[1]; j++) { + const value = ndarray.get(i, j, 0); + this.imageData.data[pixelOffset++] = value; + this.imageData.data[pixelOffset++] = value; + this.imageData.data[pixelOffset++] = value; + this.imageData.data[pixelOffset++] = 255; + } + } + } + + draw() { + this.canvas.style.display = ''; + this.canvasContext.putImageData(this.imageData, 0, 0); + } +} +document.registerElement( + NDArrayImageVisualizer.prototype.is, NDArrayImageVisualizer); diff --git a/src/ndarray-logits-visualizer.html b/src/ndarray-logits-visualizer.html new file mode 100644 index 0000000..dd444f5 --- /dev/null +++ b/src/ndarray-logits-visualizer.html @@ -0,0 +1,45 @@ + + + + + + diff --git a/src/ndarray-logits-visualizer.ts b/src/ndarray-logits-visualizer.ts new file mode 100644 index 0000000..0eeafcf --- /dev/null +++ b/src/ndarray-logits-visualizer.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Array1D, NDArrayMathCPU} from 'deeplearn'; + +import {PolymerElement, PolymerHTMLElement} from './polymer-spec'; + +const TOP_K = 3; + +// tslint:disable-next-line +export let NDArrayLogitsVisualizerPolymer: new () => PolymerHTMLElement = + PolymerElement({is: 'ndarray-logits-visualizer', properties: {}}); + +export class NDArrayLogitsVisualizer extends NDArrayLogitsVisualizerPolymer { + private logitLabelElements: HTMLElement[]; + private logitVizElements: HTMLElement[]; + private width: number; + + initialize(width: number, height: number) { + this.width = width; + this.logitLabelElements = []; + this.logitVizElements = []; + const container = this.querySelector('.logits-container') as HTMLElement; + container.style.height = height + 'px'; + + for (let i = 0; i < TOP_K; i++) { + const logitContainer = document.createElement('div'); + logitContainer.style.height = height / (TOP_K + 1) + 'px'; + logitContainer.style.margin = + height / ((2 * TOP_K) * (TOP_K + 1)) + 'px 0'; + logitContainer.className = + 'single-logit-container ndarray-logits-visualizer'; + + const logitLabelElement = document.createElement('div'); + logitLabelElement.className = 'logit-label ndarray-logits-visualizer'; + this.logitLabelElements.push(logitLabelElement); + + const logitVizOuterElement = document.createElement('div'); + logitVizOuterElement.className = + 'logit-viz-outer ndarray-logits-visualizer'; + + const logitVisInnerElement = document.createElement('div'); + logitVisInnerElement.className = + 'logit-viz-inner ndarray-logits-visualizer'; + logitVisInnerElement.innerHTML = ' '; + logitVizOuterElement.appendChild(logitVisInnerElement); + + this.logitVizElements.push(logitVisInnerElement); + + logitContainer.appendChild(logitLabelElement); + logitContainer.appendChild(logitVizOuterElement); + container.appendChild(logitContainer); + } + } + + drawLogits( + predictedLogits: Array1D, labelLogits: Array1D, + labelClassNames?: string[]) { + const mathCpu = new NDArrayMathCPU(); + const labelClass = mathCpu.argMax(labelLogits).get(); + + const topk = mathCpu.topK(predictedLogits, TOP_K); + const topkIndices = topk.indices.getValues(); + const topkValues = topk.values.getValues(); + + for (let i = 0; i < topkIndices.length; i++) { + const index = topkIndices[i]; + this.logitLabelElements[i].innerText = + labelClassNames ? labelClassNames[index] : index + ''; + this.logitLabelElements[i].style.width = + labelClassNames != null ? '100px' : '20px'; + this.logitVizElements[i].style.backgroundColor = index === labelClass ? + 'rgba(120, 185, 50, .84)' : + 'rgba(220, 10, 10, 0.84)'; + this.logitVizElements[i].style.width = + Math.floor(100 * topkValues[i]) + '%'; + this.logitVizElements[i].innerText = + `${(100 * topkValues[i]).toFixed(1)}%`; + } + } +} + +document.registerElement( + NDArrayLogitsVisualizer.prototype.is, NDArrayLogitsVisualizer); diff --git a/src/polymer-spec.ts b/src/polymer-spec.ts new file mode 100644 index 0000000..305a770 --- /dev/null +++ b/src/polymer-spec.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * @fileoverview + * + * Defines an interface for creating Polymer elements in Typescript with the + * correct typings. A Polymer element should be defined like this: + * + * ``` + * let MyElementPolymer = PolymerElement({ + * is: 'my-polymer-element', + * properties: { + * foo: string, + * bar: Array + * } + * }); + * + * class MyElement extends MyElementPolymer { + * foo: string; + * bar: number[]; + * + * ready() { + * console.log('MyElement initialized!'); + * } + * } + * + * document.registerElement(MyElement.prototype.is, MyElement); + * ``` + */ + +export type Spec = { + is: string; properties: { + [key: string]: (Function|{ + // tslint:disable-next-line:no-any + type: Function, value?: any; + reflectToAttribute?: boolean; + readonly?: boolean; + notify?: boolean; + computed?: string; + observer?: string; + }) + }; + observers?: string[]; +}; + +export function PolymerElement(spec: Spec) { + // tslint:disable-next-line:no-any + return Polymer.Class(spec as any) as {new (): PolymerHTMLElement}; +} + +export interface PolymerHTMLElement extends HTMLElement, polymer.Base {} diff --git a/src/support.js b/src/support.js new file mode 100644 index 0000000..fd21912 --- /dev/null +++ b/src/support.js @@ -0,0 +1,80 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +function isSafari() { + var ua = navigator.userAgent.toLowerCase(); + if (ua.indexOf('safari') != -1) { + if (ua.indexOf('chrome') > -1) { + return false; + } else { + return true; + } + } +} +function isMobile() { + var a = navigator.userAgent || navigator.vendor || window.opera; + return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i + .test(a) || + /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i + .test(a.substr(0, 4)); +} + +function isWebGLEnabled() { + var canvas = document.createElement('canvas'); + + var attributes = { + alpha: false, + antialias: false, + premultipliedAlpha: false, + preserveDrawingBuffer: false, + depth: false, + stencil: false, + failIfMajorPerformanceCaveat: true + }; + return null != (canvas.getContext('webgl', attributes) || + canvas.getContext('experimental-webgl', attributes)); +} + +function isNotSupported() { + return isMobile() || isSafari() || !isWebGLEnabled(); +} + +function inializePolymerPage() { + document.addEventListener('WebComponentsReady', function(event) { + if (isNotSupported()) { + var dialogContainer = document.createElement('div'); + dialogContainer.innerHTML = ` + +

This device is not yet supported

+
+

We do not yet support your device, please try to load this demo on a desktop computer with Chrome. We are working hard to add support for other devices. Check back soon!

+
+
+ `; + document.body.appendChild(dialogContainer); + var dialog = document.getElementById('dialog'); + dialog.style.width = '400px'; + dialogPolyfill.registerDialog(dialog); + dialog.showModal(); + } else { + var bundleScript = document.createElement('script'); + bundleScript.src = 'bundle.js'; + document.head.appendChild(bundleScript); + } + }); +} +inializePolymerPage(); diff --git a/src/tensorflow.ts b/src/tensorflow.ts new file mode 100644 index 0000000..c4f6e4d --- /dev/null +++ b/src/tensorflow.ts @@ -0,0 +1,203 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// tslint:disable-next-line:max-line-length +import {Convolution2DLayerBuilder, LayerBuilder, MaxPoolLayerBuilder} from './layer_builder'; + +export enum Normalization { + NORMALIZATION_NEGATIVE_ONE_TO_ONE, + NORMALIZATION_ZERO_TO_ONE, + NORMALIZATION_NONE +} + +export function generatePython( + datasetName: string, normalizationStrategy: number, inputShape: number[], + modelLayers: LayerBuilder[]): string { + const loadData = generateLoadData(datasetName, normalizationStrategy); + const buildModel = generateBuildModel(inputShape, modelLayers); + const captureWeights = generateCaptureWeights(modelLayers); + return [loadData, buildModel, captureWeights].join('\n\n'); +} + +function generateLoadData( + datasetName: string, normalizationStrategy: number): string { + let loadFunction: string; + switch (datasetName) { + case 'CIFAR 10': { + loadFunction = 'cifar10'; + break; + } + case 'MNIST': { + loadFunction = 'mnist'; + break; + } + default: { + throw new Error('datasetName must be \'CIFAR 10\' or \'MNIST\''); + } + } + + let normString: string; + switch (normalizationStrategy) { + case Normalization.NORMALIZATION_NEGATIVE_ONE_TO_ONE: { + normString = 'NORMALIZATION_NEGATIVE_ONE_TO_ONE'; + break; + } + case Normalization.NORMALIZATION_ZERO_TO_ONE: { + normString = 'NORMALIZATION_ZERO_TO_ONE'; + break; + } + case Normalization.NORMALIZATION_NONE: { + normString = 'NORMALIZATION_NONE'; + break; + } + default: { throw new Error('invalid normalizationStrategy value'); } + } + + return `def load_data(): + return learnjs_colab.load_${loadFunction}(learnjs_colab.${normString}) +`; +} + +function generateBuildModelLayer( + layerIndex: number, inputShape: number[], layer: LayerBuilder): string { + let src = ''; + const W = 'W_' + layerIndex; + const b = 'b_' + layerIndex; + const outputShape = layer.getOutputShape(inputShape); + switch (layer.layerName) { + case 'Fully connected': { + const shape = [inputShape[0], outputShape].join(', '); + + src = ` ${W} = tf.Variable(tf.truncated_normal([${shape}], + stddev = 1.0 / math.sqrt(${outputShape[0]}))) + ${b} = tf.Variable(tf.truncated_normal([${outputShape[0]}], stddev = 0.1)) + layers.append({ 'x': layers[-1]['y'], + 'W': ${W}, + 'b': ${b}, + 'y': tf.add(tf.matmul(layers[-1]['y'], ${W}), ${b}) })`; + break; + } + + case 'ReLU': { + src = ` layers.append({ 'x': layers[-1]['y'], + 'y': tf.nn.relu(layers[-1]['y']) })`; + break; + } + + case 'Convolution': { + const conv = layer as Convolution2DLayerBuilder; + const f = conv.fieldSize; + const d1 = inputShape[inputShape.length - 1]; + const d2 = outputShape[outputShape.length - 1]; + const wShape = '[' + f + ', ' + f + ', ' + d1 + ', ' + d2 + ']'; + const stride = '[1, ' + conv.stride + ', ' + conv.stride + ', 1]'; + src = ` ${W} = tf.Variable(tf.truncated_normal(${wShape}, stddev = 0.1)) + ${b} = tf.Variable(tf.truncated_normal([${d2}], stddev = 0.1)) + layers.append({ 'x': layers[-1]['y'], + 'W': ${W}, + 'b': ${b}, + 'y': tf.add(tf.nn.conv2d(layers[-1]['y'], + ${W}, + strides = ${stride}, + padding = 'SAME'), ${b}) })`; + break; + } + + case 'Max pool': { + const mp = layer as MaxPoolLayerBuilder; + const field = '[1, ' + mp.fieldSize + ', ' + mp.fieldSize + ', 1]'; + const stride = '[1, ' + mp.stride + ', ' + mp.stride + ', 1]'; + src = ` layers.append({ 'x': layers[-1]['y'], + 'y': tf.nn.max_pool(layers[-1]['y'], + ${field}, + ${stride}, + padding = 'SAME') })`; + break; + } + + case 'Reshape': { + break; + } + + case 'Flatten': { + src = ` layers.append({ 'x': layers[-1]['y'], + 'y': tf.reshape(layers[-1]['y'], [-1, ${outputShape[0]}]) })`; + break; + } + + default: { + throw new Error('unknown layer type \'' + layer.layerName + '\''); + } + } + + return src; +} + +function generateBuildModel( + inputShape: number[], modelLayers: LayerBuilder[]): string { + const inputShapeStr = inputShape.join(', '); + const sources: string[] = []; + + sources.push(`def build_model(): + layers = [] + + layers.append({ 'y': tf.placeholder(tf.float32, [None, ${inputShapeStr}]), + 'y_label': tf.placeholder(tf.float32, [None, 10]) })`); + + for (let i = 0; i < modelLayers.length; ++i) { + sources.push(generateBuildModelLayer(i + 1, inputShape, modelLayers[i])); + inputShape = modelLayers[i].getOutputShape(inputShape); + } + + sources.push(' return layers\n'); + return sources.join('\n\n'); +} + +function generateCaptureWeights(modelLayers: LayerBuilder[]): string { + const sources: string[] = []; + sources.push(`def capture_weights(): + weights = []`); + + for (let i = 0; i < modelLayers.length; ++i) { + const layer = modelLayers[i]; + const index = i + 1; + let src = ''; + const W = '\'W\': model[' + index + '][\'W\']'; + const b = '\'b\': model[' + index + '][\'b\']'; + switch (layer.layerName) { + case 'Fully connected': { + src = ` weights.append({ ${W}.eval().flatten().tolist(), + ${b}.eval().flatten().tolist() })`; + break; + } + + case 'Convolution': { + src = ` weights.append({ ${W}.eval().transpose().flatten().tolist(), + ${b}.eval().flatten().tolist() })`; + break; + } + + default: { src = ' weights.append({})'; } + } + + src += ' # ' + layer.layerName; + sources.push(src); + } + + sources.push(' return weights'); + return sources.join('\n'); +} diff --git a/tsconfig.json b/tsconfig.json new file mode 100644 index 0000000..409f6b5 --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "module": "commonjs", + "noImplicitAny": true, + "sourceMap": true, + "removeComments": true, + "preserveConstEnums": true, + "declaration": true, + "target": "es5", + "lib": ["es2015", "dom"], + "outDir": "./dist", + "noUnusedLocals": false, + "noImplicitReturns": true, + "noImplicitThis": true, + "noUnusedParameters": false, + "pretty": true, + "noFallthroughCasesInSwitch": true, + "allowUnreachableCode": false + } +} diff --git a/tslint.json b/tslint.json new file mode 100644 index 0000000..546ec93 --- /dev/null +++ b/tslint.json @@ -0,0 +1,57 @@ +{ + "rules": { + "array-type": [true, "array-simple"], + "arrow-return-shorthand": true, + "ban": [true, + ["fit"], + ["fdescribe"], + ["xit"], + ["xdescribe"], + ["fitAsync"], + ["xitAsync"], + ["fitFakeAsync"], + ["xitFakeAsync"] + ], + "ban-types": [true, + ["Object", "Use {} instead."], + ["String", "Use 'string' instead."], + ["Number", "Use 'number' instead."], + ["Boolean", "Use 'boolean' instead."] + ], + "class-name": true, + "interface-name": [true, "never-prefix"], + "jsdoc-format": true, + "forin": false, + "label-position": true, + "max-line-length": [true, 80], + "new-parens": true, + "no-angle-bracket-type-assertion": true, + "no-any": true, + "no-construct": true, + "no-debugger": true, + "no-default-export": true, + "no-inferrable-types": true, + "no-namespace": [true, "allow-declarations"], + "no-reference": true, + "no-require-imports": true, + "no-string-throw": true, + "no-unused-expression": true, + "no-unused-variable": true, + "no-var-keyword": true, + "object-literal-shorthand": true, + "only-arrow-functions": [true, "allow-declarations", "allow-named-functions"], + "prefer-const": true, + "radix": true, + "semicolon": [true, "always", "ignore-bound-class-methods"], + "switch-default": true, + "triple-equals": [true, "allow-null-check"], + "use-isnan": true, + "variable-name": [ + true, + "check-format", + "ban-keywords", + "allow-leading-underscore", + "allow-trailing-underscore" + ] + } +}