Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tfjs-backend-webgpu/cloudbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@ steps:
args: ['install']
waitFor: ['yarn-common']

# Build core from master.
- name: 'node:10'
dir: 'tfjs-backend-webgpu'
id: 'build-core'
entrypoint: 'yarn'
args: ['build-core']
waitFor: ['yarn-common']

# Run tests.
- name: 'node:10'
dir: 'tfjs-backend-webgpu'
entrypoint: 'yarn'
id: 'test-webgpu'
args: ['test-ci']
waitFor: ['yarn']
waitFor: ['build-core']

# General configuration
timeout: 1800s
Expand Down
19 changes: 10 additions & 9 deletions tfjs-backend-webgpu/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"scripts": {
"publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push",
"publish-npm": "./scripts/publish-npm.sh",
"build": "tsc",
"build-core": "cd ../tfjs-core && yarn && yarn build",
"build": "yarn build-core && rimraf dist/ && tsc",
"link-local": "yalc link",
"unlink-local": "yalc remove",
"lint": "tslint -p . -t verbose",
Expand All @@ -20,7 +21,7 @@
},
"license": "Apache-2.0",
"devDependencies": {
"@tensorflow/tfjs-core": "1.3.1",
"@tensorflow/tfjs-core": "link:../tfjs-core",
"@types/jasmine": "~2.5.53",
"clang-format": "~1.2.2",
"http-server": "~0.10.0",
Expand All @@ -34,12 +35,12 @@
"karma-typescript-es6-transform": "^4.1.1",
"rimraf": "~2.6.2",
"rollup": "~1.26.3",
"rollup-plugin-commonjs": "10.1.0",
"rollup-plugin-node-resolve": "5.2.0",
"rollup-plugin-terser": "5.1.1",
"rollup-plugin-typescript2": "0.25.2",
"tslint": "5.20.0",
"tslint-no-circular-imports": "0.7.0",
"rollup-plugin-commonjs": "~10.1.0",
"rollup-plugin-node-resolve": "~5.2.0",
"rollup-plugin-terser": "~5.1.1",
"rollup-plugin-typescript2": "~0.25.2",
"tslint": "~5.20.0",
"tslint-no-circular-imports": "~0.7.0",
"typescript": "3.5.3",
"yalc": "~1.0.0-pre.21"
},
Expand All @@ -48,6 +49,6 @@
"@webgpu/types": "0.0.18"
},
"peerDependencies": {
"@tensorflow/tfjs-core": "1.3.1"
"@tensorflow/tfjs-core": "link:../tfjs-core"
}
}
1 change: 0 additions & 1 deletion tfjs-backend-webgpu/scripts/test-ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

set -e

yarn
yarn lint
yarn build

18 changes: 5 additions & 13 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,7 @@

import './flags_webgpu';

import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBackend, Rank, RecursiveArray, ShapeMap, slice_util, Tensor, Tensor2D, Tensor3D, Tensor4D, TimingInfo, util} from '@tensorflow/tfjs-core';
// TODO(xing.xu): use FusedConv2DConfig from backend_util:
// https://github.com/tensorflow/tfjs/issues/2471
// tslint:disable-next-line: no-imports-from-dist
import {FusedConv2DConfig} from '@tensorflow/tfjs-core/dist/ops/fused_util';
// TODO: Import reduce_util from backend_util with next release of core.
import {computeOptimalWindowSize} from '@tensorflow/tfjs-core/src/ops/reduce_util';
// TODO: import sumOutType directly from '@tensorflow/tfjs-core' with next
// release of core.
import {sumOutType} from '@tensorflow/tfjs-core/src/types';
import {backend_util, DataStorage, DataType, engine, env, findBackend, KernelBackend, Rank, RecursiveArray, ShapeMap, slice_util, sumOutType, Tensor, Tensor2D, Tensor3D, Tensor4D, TimingInfo, util} from '@tensorflow/tfjs-core';
import {Glslang} from '@webgpu/glslang/dist/web-devel/glslang.onefile';

import {BufferManager} from './buffer_manager';
Expand Down Expand Up @@ -776,7 +767,7 @@ export class WebGPUBackend extends KernelBackend {

fusedConv2d(
{input, filter, convInfo, bias, activation, preluActivationWeights}:
FusedConv2DConfig): Tensor4D {
backend_util.FusedConv2DConfig): Tensor4D {
const dataId = this.write(null /*values*/, convInfo.outShape, input.dtype);
const output = engine().makeTensorFromDataId(
dataId, convInfo.outShape, input.dtype, this);
Expand Down Expand Up @@ -835,7 +826,7 @@ export class WebGPUBackend extends KernelBackend {
Tensor2D {
const batchSize = x.shape[0];
const inSize = x.shape[1];
const windowSize = computeOptimalWindowSize(inSize);
const windowSize = backend_util.computeOptimalWindowSize(inSize);
const reduceInfo = {windowSize, inSize, batchSize};
const program = new ReduceProgram(reduceInfo, reduceType);
const output = this.makeOutputArray(program.outputShape, dtype);
Expand Down Expand Up @@ -1011,7 +1002,8 @@ export class WebGPUBackend extends KernelBackend {
const program =
new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners);

const output: Tensor4D = this.makeOutputArray(program.outputShape, x.dtype);
const output: Tensor4D =
this.makeOutputArray(program.outputShape, 'float32');

return this.compileAndRun(program, [x], output);
}
Expand Down
11 changes: 6 additions & 5 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const TEST_FILTERS: TestFilter[] = [
'broadcast 2D + 1D', // Actual != expected.
'upcasts when dtypes dont match', // Actual != expected.
'gradient', // square, sum not yet implemented.
'divNoNan' // Equal not yet implemented.
]
},
{
Expand Down Expand Up @@ -177,18 +178,18 @@ const TEST_FILTERS: TestFilter[] = [
{
include: 'resizeBilinear',
excludes: [
'gradient', // Not yet implemented.
'gradient', // Not yet implemented.
'works for ints' // Actual != expected.
]
},
{include: 'floor divide ', excludes: []},
{
include: 'fused',
excludes: [
'A x B', // fusedBatchMatMul not yet implemented.
'A x B with elu', // elu not yet implemented.
'A x B with elu and broadcasted bias', // elu not yet implemented.
'A x B', // fusedBatchMatMul not yet implemented.
'elu', // elu not yet implemented.
'A x B with bias only', // fusedBatchMatMul not yet implemented.
'basic with elu', // elu not yet implemented.
'basic with bias', // Actual != expected.
'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0', // conv2dDerInput not yet
// implemented.
'gradient x=[2,3,3,1] f=[2,2,1,1] s=1 p=0 with bias', // conv2dDerInput
Expand Down
54 changes: 23 additions & 31 deletions tfjs-backend-webgpu/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,9 @@
esutils "^2.0.2"
js-tokens "^4.0.0"

"@tensorflow/tfjs-core@1.3.1":
version "1.3.1"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-1.3.1.tgz#89e3253f320233120ee24d4ae62fca65c315580f"
integrity sha512-X4MKhpg1gLEZetKUMeQNW6diP3gbFFddeF6UT816sH8jOenX/8x2HnVmANpNnUxCTPhDniY3V9zhBWwbl13+Yg==
dependencies:
"@types/offscreencanvas" "~2019.3.0"
"@types/seedrandom" "2.4.27"
"@types/webgl-ext" "0.0.30"
"@types/webgl2" "0.0.4"
node-fetch "~2.1.2"
seedrandom "2.4.3"
"@tensorflow/tfjs-core@link:../tfjs-core":
version "0.0.0"
uid ""

"@types/estree@*", "@types/estree@0.0.39":
version "0.0.39"
Expand Down Expand Up @@ -3908,7 +3900,7 @@ ripemd160@^2.0.0, ripemd160@^2.0.1:
hash-base "^3.0.0"
inherits "^2.0.1"

rollup-plugin-commonjs@10.1.0:
rollup-plugin-commonjs@~10.1.0:
version "10.1.0"
resolved "https://registry.yarnpkg.com/rollup-plugin-commonjs/-/rollup-plugin-commonjs-10.1.0.tgz#417af3b54503878e084d127adf4d1caf8beb86fb"
integrity sha512-jlXbjZSQg8EIeAAvepNwhJj++qJWNJw1Cl0YnOqKtP5Djx+fFGkp3WRh+W0ASCaFG5w1jhmzDxgu3SJuVxPF4Q==
Expand All @@ -3919,7 +3911,7 @@ rollup-plugin-commonjs@10.1.0:
resolve "^1.11.0"
rollup-pluginutils "^2.8.1"

rollup-plugin-node-resolve@5.2.0:
rollup-plugin-node-resolve@~5.2.0:
version "5.2.0"
resolved "https://registry.yarnpkg.com/rollup-plugin-node-resolve/-/rollup-plugin-node-resolve-5.2.0.tgz#730f93d10ed202473b1fb54a5997a7db8c6d8523"
integrity sha512-jUlyaDXts7TW2CqQ4GaO5VJ4PwwaV8VUGA7+km3n6k6xtOEacf61u0VXwN80phY/evMcaS+9eIeJ9MOyDxt5Zw==
Expand All @@ -3930,21 +3922,21 @@ rollup-plugin-node-resolve@5.2.0:
resolve "^1.11.1"
rollup-pluginutils "^2.8.1"

rollup-plugin-terser@5.1.1:
version "5.1.1"
resolved "https://registry.yarnpkg.com/rollup-plugin-terser/-/rollup-plugin-terser-5.1.1.tgz#e9d2545ec8d467f96ba99b9216d2285aad8d5b66"
integrity sha512-McIMCDEY8EU6Y839C09UopeRR56wXHGdvKKjlfiZG/GrP6wvZQ62u2ko/Xh1MNH2M9WDL+obAAHySljIZYCuPQ==
rollup-plugin-terser@~5.1.1:
version "5.1.3"
resolved "https://registry.yarnpkg.com/rollup-plugin-terser/-/rollup-plugin-terser-5.1.3.tgz#5f4c4603b12b4f8d093f4b6f31c9aa5eba98a223"
integrity sha512-FuFuXE5QUJ7snyxHLPp/0LFXJhdomKlIx/aK7Tg88Yubsx/UU/lmInoJafXJ4jwVVNcORJ1wRUC5T9cy5yk0wA==
dependencies:
"@babel/code-frame" "^7.0.0"
jest-worker "^24.6.0"
rollup-pluginutils "^2.8.1"
serialize-javascript "^1.7.0"
serialize-javascript "^2.1.2"
terser "^4.1.0"

rollup-plugin-typescript2@0.25.2:
version "0.25.2"
resolved "https://registry.yarnpkg.com/rollup-plugin-typescript2/-/rollup-plugin-typescript2-0.25.2.tgz#1a165df08560902da45b355413464caca1765d3a"
integrity sha512-+tpZj/ZIf2lwjyjX6xEW1S5Y38/21TB3p6poLodISIia8owMMfIKuFFnWcESE4FPBHkR8XPKqjY0PH9IUJJK+Q==
rollup-plugin-typescript2@~0.25.2:
version "0.25.3"
resolved "https://registry.yarnpkg.com/rollup-plugin-typescript2/-/rollup-plugin-typescript2-0.25.3.tgz#a5fb2f0f85488789334ce540abe6c7011cbdf40f"
integrity sha512-ADkSaidKBovJmf5VBnZBZe+WzaZwofuvYdzGAKTN/J4hN7QJCFYAq7IrH9caxlru6T5qhX41PNFS1S4HqhsGQg==
dependencies:
find-cache-dir "^3.0.0"
fs-extra "8.1.0"
Expand Down Expand Up @@ -4010,10 +4002,10 @@ semver@^6.0.0:
resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.0.tgz#ee0a64c8af5e8ceea67687b133761e1becbd1d3d"
integrity sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==

serialize-javascript@^1.7.0:
version "1.7.0"
resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-1.7.0.tgz#d6e0dfb2a3832a8c94468e6eb1db97e55a192a65"
integrity sha512-ke8UG8ulpFOxO8f8gRYabHQe/ZntKlcig2Mp+8+URDP1D8vJZ0KUt7LYo07q25Z/+JVSgpr/cui9PIp5H6/+nA==
serialize-javascript@^2.1.2:
version "2.1.2"
resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-2.1.2.tgz#ecec53b0e0317bdc95ef76ab7074b7384785fa61"
integrity sha512-rs9OggEUF0V4jUSecXazOYsLfu7OGK2qIn3c7IPBiffz32XniEp/TX9Xmc9LQfK2nQ2QKHvZ2oygKUGU0lG4jQ==

set-blocking@^2.0.0, set-blocking@~2.0.0:
version "2.0.0"
Expand Down Expand Up @@ -4478,15 +4470,15 @@ tslib@^1.8.0, tslib@^1.8.1:
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"
integrity sha512-4krF8scpejhaOgqzBEcGM7yDIEfi0/8+8zDRZhNZZ2kjmHJ4hv3zCbQWxoJGz1iw5U0Jl0nma13xzHXcncMavQ==

tslint-no-circular-imports@0.7.0:
tslint-no-circular-imports@~0.7.0:
version "0.7.0"
resolved "https://registry.yarnpkg.com/tslint-no-circular-imports/-/tslint-no-circular-imports-0.7.0.tgz#9df0a15654d66b172e0b7843eed073fa5ae99b5f"
integrity sha512-k3wxpeMC4ef40UbpfBVHEHIzKfNZq5/SCtAO1YjGsaNTklo+K53/TWLrym+poA65RJFDiYgYNWvkeIIkJNA0Vw==

tslint@5.20.0:
version "5.20.0"
resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.20.0.tgz#fac93bfa79568a5a24e7be9cdde5e02b02d00ec1"
integrity sha512-2vqIvkMHbnx8acMogAERQ/IuINOq6DFqgF8/VDvhEkBqQh/x6SP0Y+OHnKth9/ZcHQSroOZwUQSN18v8KKF0/g==
tslint@~5.20.0:
version "5.20.1"
resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.20.1.tgz#e401e8aeda0152bc44dd07e614034f3f80c67b7d"
integrity sha512-EcMxhzCFt8k+/UP5r8waCf/lzmeSyVlqxqMEDQE7rWYiQky8KpIBz1JAoYXfROHrPZ1XXd43q8yQnULOLiBRQg==
dependencies:
"@babel/code-frame" "^7.0.0"
builtin-modules "^1.1.1"
Expand Down