-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add polynomial regression example using core api. (#15)
- Loading branch information
Showing
7 changed files
with
4,791 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# TensorFlow.js Example: Fitting a curve to synthetic data | ||
|
||
This example shows you how to use TensorFlow.js operations and optimizers (the lower level api) to write a simple model that learns the coefficients of polynomial that we want to use to describe our data. In this toy example, we generate synthetic data by adding some noise to a polynomial function. Then starting with random coefficients, we train a model to learn the true coefficients that data was generated with. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import * as tf from '@tensorflow/tfjs'; | ||
|
||
export function generateData(numPoints, coeff, sigma = 0.04) { | ||
return tf.tidy(() => { | ||
const [a, b, c, d] = [ | ||
tf.scalar(coeff.a), tf.scalar(coeff.b), tf.scalar(coeff.c), | ||
tf.scalar(coeff.d) | ||
]; | ||
|
||
const xs = tf.randomUniform([numPoints], -1, 1); | ||
|
||
// Generate polynomial data | ||
const three = tf.scalar(3, 'int32'); | ||
const ys = a.mul(xs.pow(three)) | ||
.add(b.mul(xs.square())) | ||
.add(c.mul(xs)) | ||
.add(d) | ||
// Add random noise to the generated data | ||
// to make the problem a bit more interesting | ||
.add(tf.randomNormal([numPoints], 0, sigma)); | ||
|
||
// Normalize the y values to the range 0 to 1. | ||
const ymin = ys.min(); | ||
const ymax = ys.max(); | ||
const yrange = ymax.sub(ymin); | ||
const ysNormalized = ys.sub(ymin).div(yrange); | ||
|
||
return { | ||
xs, | ||
ys: ysNormalized | ||
}; | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
<!-- Copyright 2018 Google LLC. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================--> | ||
<html> | ||
<head> | ||
<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/material.cyan-teal.min.css" /> | ||
</head> | ||
<body> | ||
<h3>TensorFlow.js: Fitting a curve to synthetic data</h3> | ||
<style> | ||
body { | ||
padding: 50px; | ||
} | ||
|
||
.plots { | ||
display: flex; | ||
flex-direction: row; | ||
flex-wrap: wrap; | ||
} | ||
|
||
#data, #random, #trained { | ||
margin: 30px; | ||
} | ||
|
||
.caption { | ||
font-weight: bold; | ||
} | ||
|
||
.coeff { | ||
font-weight: normal; | ||
} | ||
|
||
</style> | ||
|
||
<div class="plots"> | ||
<div id="data"> | ||
<div class="caption">Original Data (Synthetic)</div> | ||
<div class="caption">True coeffecients: <span class='coeff'></span></div> | ||
<div class="plot"></div> | ||
</div> | ||
<div id="random"> | ||
<div class="caption">Fit curve with random coefficients (before training)</div> | ||
<div class="caption">Random coeffecients: | ||
<span class='coeff'></span> | ||
</div> | ||
<div class="plot"></div> | ||
</div> | ||
<div id="trained"> | ||
<div class="caption">Fit curve with learned coefficients (after training)</div> | ||
<div class="caption">Learned coeffecients: | ||
<span class='coeff'></span> | ||
</div> | ||
<div class="plot"></div> | ||
</div> | ||
</div> | ||
|
||
|
||
<script src="index.js"></script> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import * as tf from '@tensorflow/tfjs'; | ||
import {generateData} from './data'; | ||
import {plotData, plotDataAndPredictions, renderCoefficients} from './ui'; | ||
|
||
/** | ||
* We want to learn the coefficients that give correct solutions to the | ||
* following quadratic equation: | ||
* y = a * x^3 + b * x^2 + c * x + d | ||
* In other words we want to learn values for: | ||
* a | ||
* b | ||
* c | ||
* d | ||
* Such that this function produces 'desired outputs' for y when provided | ||
* with x. We will provide some examples of 'xs' and 'ys' to allow this model | ||
* to learn what we mean by desired outputs and then use it to produce new | ||
* values of y that fit the curve implied by our example. | ||
*/ | ||
|
||
// Step 1. Set up variables, these are the things we want the model | ||
// to learn in order to do prediction accurately. We will initialize | ||
// them with random values. | ||
const a = tf.variable(tf.scalar(Math.random())); | ||
const b = tf.variable(tf.scalar(Math.random())); | ||
const c = tf.variable(tf.scalar(Math.random())); | ||
const d = tf.variable(tf.scalar(Math.random())); | ||
|
||
|
||
// Step 2. Create an optimizer, we will use this later. You can play | ||
// with some of these values to see how the model perfoms. | ||
const numIterations = 75; | ||
const learningRate = 0.5; | ||
const optimizer = tf.train.sgd(learningRate); | ||
|
||
// Step 3. Write our training process functions. | ||
|
||
/* | ||
* This function represents our 'model'. Given an input 'x' it will try and | ||
* predict the appropriate output 'y'. | ||
* | ||
* It is also sometimes referred to as the 'forward' step of our training | ||
* process. Though we will use the same function for predictions later. | ||
* | ||
* @return number predicted y value | ||
*/ | ||
function predict(x) { | ||
// y = a * x ^ 3 + b * x ^ 2 + c * x + d | ||
return tf.tidy(() => { | ||
return a.mul(x.pow(tf.scalar(3, 'int32'))) | ||
.add(b.mul(x.square())) | ||
.add(c.mul(x)) | ||
.add(d); | ||
}); | ||
} | ||
|
||
/* | ||
* This will tell us how good the 'prediction' is given what we actually | ||
* expected. | ||
* | ||
* prediction is a tensor with our predicted y values. | ||
* labels is a tensor with the y values the model should have predicted. | ||
*/ | ||
function loss(prediction, labels) { | ||
// Having a good error function is key for training a machine learning model | ||
const error = prediction.sub(labels).square().mean(); | ||
return error; | ||
} | ||
|
||
/* | ||
* This will iteratively train our model. | ||
* | ||
* xs - training data x values | ||
* ys — training data y values | ||
*/ | ||
async function train(xs, ys, numIterations) { | ||
for (let iter = 0; iter < numIterations; iter++) { | ||
// optimizer.minimize is where the training happens. | ||
|
||
// The function it takes must return a numerical estimate (i.e. loss) | ||
// of how well we are doing using the current state of | ||
// the variables we created at the start. | ||
|
||
// This optimizer does the 'backward' step of our training process | ||
// updating variables defined previously in order to minimize the | ||
// loss. | ||
optimizer.minimize(() => { | ||
// Feed the examples into the model | ||
const pred = predict(xs); | ||
return loss(pred, ys); | ||
}); | ||
|
||
// Use tf.nextFrame to not block the browser. | ||
await tf.nextFrame(); | ||
} | ||
} | ||
|
||
async function learnCoefficients() { | ||
const trueCoefficients = {a: -.8, b: -.2, c: .9, d: .5}; | ||
const trainingData = generateData(100, trueCoefficients); | ||
|
||
// Plot original data | ||
renderCoefficients('#data .coeff', trueCoefficients); | ||
await plotData('#data .plot', trainingData.xs, trainingData.ys) | ||
|
||
// See what the predictions look like with random coefficients | ||
renderCoefficients('#random .coeff', { | ||
a: a.dataSync()[0], | ||
b: b.dataSync()[0], | ||
c: c.dataSync()[0], | ||
d: d.dataSync()[0], | ||
}); | ||
const predictionsBefore = predict(trainingData.xs); | ||
await plotDataAndPredictions( | ||
'#random .plot', trainingData.xs, trainingData.ys, predictionsBefore); | ||
|
||
// Train the model! | ||
await train(trainingData.xs, trainingData.ys, numIterations); | ||
|
||
// See what the final results predictions are after training. | ||
renderCoefficients('#trained .coeff', { | ||
a: a.dataSync()[0], | ||
b: b.dataSync()[0], | ||
c: c.dataSync()[0], | ||
d: d.dataSync()[0], | ||
}); | ||
const predictionsAfter = predict(trainingData.xs); | ||
await plotDataAndPredictions( | ||
'#trained .plot', trainingData.xs, trainingData.ys, predictionsAfter); | ||
|
||
predictionsBefore.dispose(); | ||
predictionsAfter.dispose(); | ||
} | ||
|
||
|
||
learnCoefficients(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
{ | ||
"name": "tfjs-examples-polynomial-regression-core", | ||
"version": "0.1.0", | ||
"description": "", | ||
"main": "index.js", | ||
"license": "Apache-2.0", | ||
"private": true, | ||
"dependencies": { | ||
"@tensorflow/tfjs": "0.0.2", | ||
"vega-embed": "^3.2.0", | ||
"vega-lite": "^2.3.1" | ||
}, | ||
"scripts": { | ||
"watch": "NODE_ENV=development parcel --no-hmr --open index.html ", | ||
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url /" | ||
}, | ||
"devDependencies": { | ||
"babel-plugin-transform-runtime": "~6.23.0", | ||
"babel-polyfill": "~6.26.0", | ||
"babel-preset-env": "~1.6.1", | ||
"clang-format": "~1.2.2", | ||
"parcel-bundler": "~1.6.2" | ||
}, | ||
"babel": { | ||
"presets": [ | ||
[ | ||
"env", | ||
{ | ||
"modules": false, | ||
"targets": { | ||
"browsers": [ | ||
"> 1%", | ||
"last 3 versions", | ||
"ie >= 9", | ||
"ios >= 8", | ||
"android >= 4.2" | ||
] | ||
}, | ||
"useBuiltIns": false | ||
} | ||
] | ||
], | ||
"plugins": [ | ||
"transform-runtime" | ||
] | ||
} | ||
} |
Oops, something went wrong.