diff --git a/package.json b/package.json index 2a78c5f164..63ab39227a 100644 --- a/package.json +++ b/package.json @@ -51,7 +51,7 @@ "test": "karma start", "run-browserstack": "karma start --singleRun --reporters='dots,karma-typescript,BrowserStack' --hostname='bs-local.com'", "test-benchmark": "cd integration_tests/benchmarks && yarn benchmark-travis && cd ../../", - "test-node": "ts-node src/test_node.ts", + "test-node": "ts-node --files src/test_node.ts", "test-integration": "./scripts/test-integration.sh", "test-travis": "./scripts/test-travis.sh" }, diff --git a/src/canvas_util.ts b/src/canvas_util.ts index 0b7dd08e4c..055b017eea 100644 --- a/src/canvas_util.ts +++ b/src/canvas_util.ts @@ -27,9 +27,17 @@ const WEBGL_ATTRIBUTES: WebGLContextAttributes = { failIfMajorPerformanceCaveat: true }; +export function createCanvas(): HTMLCanvasElement { + if (window && ('OffscreenCanvas' in window)) { + return new OffscreenCanvas(1, 1); + } else { + return document.createElement('canvas'); + } +} + export function getWebGLContext(webGLVersion: number): WebGLRenderingContext { if (!(webGLVersion in contexts)) { - const canvas = document.createElement('canvas'); + const canvas = createCanvas(); canvas.addEventListener('webglcontextlost', ev => { ev.preventDefault(); delete contexts[webGLVersion]; @@ -60,7 +68,7 @@ function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext { throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); } - const canvas = document.createElement('canvas'); + const canvas = createCanvas(); if (webGLVersion === 1) { return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) || canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES)) as diff --git a/src/environment_test.ts b/src/environment_test.ts index e529655baf..6cf6901b69 100644 --- a/src/environment_test.ts +++ b/src/environment_test.ts @@ -127,6 +127,27 @@ describe('Backend', () => { ENV.removeBackend('custom-cpu'); }); + it('web worker', () => { + if (typeof OffscreenCanvas !== 'undefined') { + const features = { + 'WEBGL_VERSION': 2, + 'IS_WORKER': true, + 'IS_BROWSER': false + }; + ENV.setFeatures(features); + + let backend: MathBackendWebGL; + const success = ENV.registerBackend('worker-webgl', () => { + backend = new MathBackendWebGL(); + return backend; + }, 104); + expect(success).toBe(true); + + const canvas = backend.getCanvas(); + expect(canvas.constructor.name).toBe('OffscreenCanvas'); + } + }); + it('default custom background null', () => { expect(ENV.findBackend('custom')).toBeNull(); }); diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index dec8ef57ae..af9b433a0e 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -16,6 +16,8 @@ */ import * as seedrandom from 'seedrandom'; + +import {createCanvas} from '../canvas_util'; import {ENV} from '../environment'; import {warn} from '../log'; import * as array_ops_util from '../ops/array_ops_util'; @@ -34,6 +36,7 @@ import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../types'; import * as util from '../util'; import {now} from '../util'; + import {BackendTimingInfo, DataMover, DataStorage, KernelBackend} from './backend'; import * as backend_util from './backend_util'; import * as complex_util from './complex_util'; @@ -60,8 +63,7 @@ export class MathBackendCPU implements KernelBackend { constructor() { if (ENV.get('IS_BROWSER')) { - this.fromPixels2DContext = - document.createElement('canvas').getContext('2d'); + this.fromPixels2DContext = createCanvas().getContext('2d'); } } diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index cd202e05bf..038b6a2151 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {getWebGLContext} from '../canvas_util'; +import {createCanvas, getWebGLContext} from '../canvas_util'; import {MemoryInfo, TimingInfo} from '../engine'; import {ENV} from '../environment'; import {tidy} from '../globals'; @@ -35,6 +35,7 @@ import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types'; import * as util from '../util'; import {getTypedArrayFromDType, sizeFromShape} from '../util'; + import {DataMover, DataStorage, KernelBackend} from './backend'; import * as backend_util from './backend_util'; import {mergeRealAndImagArrays} from './complex_util'; @@ -214,8 +215,7 @@ export class MathBackendWebGL implements KernelBackend { 'once the DOM is ready. One way to do that is to add an event ' + 'listener for `DOMContentLoaded` on the document object'); } - this.fromPixels2DContext = - document.createElement('canvas').getContext('2d'); + this.fromPixels2DContext = createCanvas().getContext('2d'); } this.fromPixels2DContext.canvas.width = pixels.width; this.fromPixels2DContext.canvas.height = pixels.height; @@ -504,7 +504,6 @@ export class MathBackendWebGL implements KernelBackend { if (ENV.get('WEBGL_VERSION') < 1) { throw new Error('WebGL is not supported on this device'); } - if (gpgpu == null) { const gl = getWebGLContext(ENV.get('WEBGL_VERSION')); this.gpgpu = new GPGPUContext(gl); @@ -515,12 +514,16 @@ export class MathBackendWebGL implements KernelBackend { this.canvas = gpgpu.gl.canvas; } if (ENV.get('WEBGL_PAGING_ENABLED')) { - // Use the device screen's resolution as a heuristic to decide on the - // maximum memory allocated on the GPU before starting to page. - this.NUM_BYTES_BEFORE_PAGING = - (window.screen.height * window.screen.width * - window.devicePixelRatio) * - BEFORE_PAGING_CONSTANT; + if (window && ('OffscreenCanvas' in window)) { + this.NUM_BYTES_BEFORE_PAGING = Infinity; + } else { + // Use the device screen's resolution as a heuristic to decide on the + // maximum memory allocated on the GPU before starting to page. + this.NUM_BYTES_BEFORE_PAGING = + (window.screen.height * window.screen.width * + window.devicePixelRatio) * + BEFORE_PAGING_CONSTANT; + } } this.textureManager = new TextureManager(this.gpgpu); } diff --git a/tsconfig.json b/tsconfig.json index 7177eaf672..b02869fef2 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -8,7 +8,10 @@ "preserveConstEnums": true, "declaration": true, "target": "es5", - "lib": ["es2015", "dom"], + "lib": [ + "es2015", + "dom" + ], "outDir": "./dist", "noUnusedLocals": true, "noImplicitReturns": true, @@ -20,6 +23,7 @@ "allowUnreachableCode": false }, "include": [ + "types/", "src/" ] } diff --git a/types/OffscreenCanvas.d.ts b/types/OffscreenCanvas.d.ts new file mode 100644 index 0000000000..25b34c7204 --- /dev/null +++ b/types/OffscreenCanvas.d.ts @@ -0,0 +1,6 @@ +declare let OffscreenCanvas: { + new (width: number, height: number): HTMLCanvasElement; + prototype: HTMLCanvasElement; +} + +interface OffscreenCanvas extends EventTarget {}