Skip to content

Commit

Permalink
Add initial support for loading tfjs in WebWorkers
Browse files Browse the repository at this point in the history
FEATURE

Uses OffscreenCanvas where available
  • Loading branch information
WenheLI authored and tafsiri committed Jun 13, 2019
1 parent 92df253 commit af3f975
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 11 deletions.
16 changes: 14 additions & 2 deletions src/backends/cpu/backend_cpu.ts
Expand Up @@ -65,6 +65,18 @@ interface TensorData<D extends DataType> {
complexTensors?: {real: Tensor, imag: Tensor};
}

function createCanvas() {
//@ts-ignore
if (typeof(OffscreenCanvas) !== 'undefined') {
//@ts-ignore
return new OffscreenCanvas(300, 150);
} else if (typeof(document) !== 'undefined') {
return document.createElement('canvas');
} else {
throw new Error('Cannot create a canvas in this context');
}
}

export class MathBackendCPU implements KernelBackend {
public blockSize = 48;

Expand All @@ -74,8 +86,8 @@ export class MathBackendCPU implements KernelBackend {

constructor() {
if (ENV.get('IS_BROWSER')) {
this.fromPixels2DContext =
document.createElement('canvas').getContext('2d');
const canvas = createCanvas();
this.fromPixels2DContext = canvas.getContext('2d');
}
this.data = new DataStorage(this, ENGINE);
}
Expand Down
9 changes: 7 additions & 2 deletions src/backends/webgl/backend_webgl.ts
Expand Up @@ -2466,8 +2466,13 @@ export class MathBackendWebGL implements KernelBackend {
return;
}
this.textureManager.dispose();
this.canvas.remove();
if (this.fromPixels2DContext != null) {
if (this.canvas.remove != null) {
this.canvas.remove();
} else {
this.canvas = null;
}
if (this.fromPixels2DContext != null
&& this.fromPixels2DContext.canvas.remove != null) {
this.fromPixels2DContext.canvas.remove();
}
if (this.gpgpuCreatedLocally) {
Expand Down
16 changes: 14 additions & 2 deletions src/backends/webgl/canvas_util.ts
Expand Up @@ -55,13 +55,25 @@ export function getWebGLContext(webGLVersion: number): WebGLRenderingContext {
return contexts[webGLVersion];
}

export function createCanvas(webGLVersion: number) {
//@ts-ignore
if (typeof(OffscreenCanvas) !== 'undefined' && webGLVersion === 2) {
//@ts-ignore
return new OffscreenCanvas(300, 150);
} else if (typeof(document) !== 'undefined') {
return document.createElement('canvas');
} else {
throw new Error('Cannot create a canvas in this context');
}
}

function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext {
if (webGLVersion !== 1 && webGLVersion !== 2) {
throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
}
const canvas = createCanvas(webGLVersion);

const canvas = document.createElement('canvas');
canvas.addEventListener('webglcontextlost', ev => {
canvas.addEventListener('webglcontextlost', (ev : Event) => {
ev.preventDefault();
delete contexts[webGLVersion];
}, false);
Expand Down
6 changes: 4 additions & 2 deletions src/backends/webgl/canvas_util_test.ts
Expand Up @@ -23,7 +23,9 @@ import {getWebGLContext} from './canvas_util';
describeWithFlags('canvas_util', BROWSER_ENVS, () => {
it('Returns a valid canvas', () => {
const canvas = getWebGLContext(ENV.getNumber('WEBGL_VERSION')).canvas;
expect(canvas instanceof HTMLCanvasElement).toBe(true);
expect((canvas instanceof HTMLCanvasElement)
//@ts-ignore
|| (canvas instanceof OffscreenCanvas)).toBe(true);
});

it('Returns a valid gl context', () => {
Expand All @@ -35,6 +37,6 @@ describeWithFlags('canvas_util', BROWSER_ENVS, () => {
describeWithFlags('canvas_util webgl2', {flags: {WEBGL_VERSION: 2}}, () => {
it('is ok when the user requests webgl 1 canvas', () => {
const canvas = getWebGLContext(1).canvas;
expect(canvas instanceof HTMLCanvasElement).toBe(true);
expect((canvas instanceof HTMLCanvasElement)).toBe(true);
});
});
4 changes: 3 additions & 1 deletion src/device_util.ts
Expand Up @@ -27,5 +27,7 @@ export function isMobile(): boolean {
}

export function isBrowser(): boolean {
return typeof window !== 'undefined';
return (typeof window !== 'undefined') ||
//@ts-ignore
(typeof WorkerGlobalScope !== 'undefined');
}
2 changes: 2 additions & 0 deletions src/engine.ts
Expand Up @@ -932,6 +932,8 @@ function getGlobalNamespace(): {_tfengine: Engine} {
ns = global;
} else if (typeof (process) !== 'undefined') {
ns = process;
} else if (typeof (self) !== 'undefined' ) {
ns = self;
} else {
throw new Error('Could not find a global object');
}
Expand Down
4 changes: 4 additions & 0 deletions src/io/browser_files.ts
Expand Up @@ -63,6 +63,10 @@ export class BrowserDownloads implements IOHandler {
}

async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
if (typeof(document) === 'undefined') {
throw new Error('Browser downloads are not supported in ' +
'this environment since `document` is not present');
}
const weightsURL = window.URL.createObjectURL(new Blob(
[modelArtifacts.weightData], {type: 'application/octet-stream'}));

Expand Down
1 change: 1 addition & 0 deletions src/jasmine_util.ts
Expand Up @@ -34,6 +34,7 @@ export const CHROME_ENVS: Constraints = {
export const BROWSER_ENVS: Constraints = {
predicate: () => ENV.platformName === 'browser'
};

export const SYNC_BACKEND_ENVS: Constraints = {
predicate: (testEnv: TestEnv) => testEnv.isDataSync === true
};
Expand Down
4 changes: 2 additions & 2 deletions src/platforms/platform_browser_test.ts
Expand Up @@ -22,11 +22,11 @@ import {PlatformBrowser} from './platform_browser';
describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => {
it('fetch calls window.fetch', async () => {
const response = new Response();
spyOn(window, 'fetch').and.returnValue(response);
spyOn(self, 'fetch').and.returnValue(response);
const platform = new PlatformBrowser();

await platform.fetch('test/url', {method: 'GET'});

expect(window.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'});
expect(self.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'});
});
});

0 comments on commit af3f975

Please sign in to comment.