diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 0715afc7192..ad203f2a833 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -142,7 +142,12 @@ export class MathBackendCPU implements KernelBackend { pixels instanceof HTMLVideoElement; const isImage = typeof (HTMLImageElement) !== 'undefined' && pixels instanceof HTMLImageElement; - + const [width, height] = isVideo ? + [ + (pixels as HTMLVideoElement).videoWidth, + (pixels as HTMLVideoElement).videoHeight + ] : + [pixels.width, pixels.height]; let vals: Uint8ClampedArray|Uint8Array; // tslint:disable-next-line:no-any if (ENV.get('IS_NODE') && (pixels as any).getContext == null) { @@ -155,7 +160,7 @@ export class MathBackendCPU implements KernelBackend { // tslint:disable-next-line:no-any vals = (pixels as any) .getContext('2d') - .getImageData(0, 0, pixels.width, pixels.height) + .getImageData(0, 0, width, height) .data; } else if (isImageData || isPixelData) { vals = (pixels as PixelData | ImageData).data; @@ -165,13 +170,11 @@ export class MathBackendCPU implements KernelBackend { 'Can\'t read pixels from HTMLImageElement outside ' + 'the browser.'); } - this.fromPixels2DContext.canvas.width = pixels.width; - this.fromPixels2DContext.canvas.height = pixels.height; + this.fromPixels2DContext.canvas.width = width; + this.fromPixels2DContext.canvas.height = height; this.fromPixels2DContext.drawImage( - pixels as HTMLVideoElement, 0, 0, pixels.width, pixels.height); - vals = this.fromPixels2DContext - .getImageData(0, 0, pixels.width, pixels.height) - .data; + pixels as HTMLVideoElement, 0, 0, width, height); + vals = this.fromPixels2DContext.getImageData(0, 0, width, height).data; } else { throw new Error( 'pixels passed to tf.browser.fromPixels() must be either an ' + @@ -183,7 +186,7 @@ export class MathBackendCPU implements KernelBackend { if (numChannels === 4) { values = new Int32Array(vals); } else { - 
const numPixels = pixels.width * pixels.height; + const numPixels = width * height; values = new Int32Array(numPixels * numChannels); for (let i = 0; i < numPixels; i++) { for (let channel = 0; channel < numChannels; ++channel) { @@ -191,8 +194,7 @@ export class MathBackendCPU implements KernelBackend { } } } - const outShape: [number, number, number] = - [pixels.height, pixels.width, numChannels]; + const outShape: [number, number, number] = [height, width, numChannels]; return tensor3d(values, outShape, 'int32'); } async read(dataId: DataId): Promise { diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 6dc0e2475ec..3bd2b8387aa 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -286,8 +286,6 @@ export class MathBackendWebGL implements KernelBackend { throw new Error( 'pixels passed to tf.browser.fromPixels() can not be null'); } - const texShape: [number, number] = [pixels.height, pixels.width]; - const outShape = [pixels.height, pixels.width, numChannels]; const isCanvas = (typeof (OffscreenCanvas) !== 'undefined' && pixels instanceof OffscreenCanvas) || @@ -300,6 +298,15 @@ export class MathBackendWebGL implements KernelBackend { pixels instanceof HTMLVideoElement; const isImage = typeof (HTMLImageElement) !== 'undefined' && pixels instanceof HTMLImageElement; + const [width, height] = isVideo ? 
+ [ + (pixels as HTMLVideoElement).videoWidth, + (pixels as HTMLVideoElement).videoHeight + ] : + [pixels.width, pixels.height]; + + const texShape: [number, number] = [height, width]; + const outShape = [height, width, numChannels]; if (!isCanvas && !isPixelData && !isImageData && !isVideo && !isImage) { throw new Error( @@ -312,21 +319,15 @@ export class MathBackendWebGL implements KernelBackend { if (isImage || isVideo) { if (this.fromPixels2DContext == null) { - if (document.readyState !== 'complete') { - throw new Error( - 'The DOM is not ready yet. Please call ' + - 'tf.browser.fromPixels() once the DOM is ready. One way to ' + - 'do that is to add an event listener for `load` ' + - 'on the document object'); - } //@ts-ignore this.fromPixels2DContext = createCanvas(ENV.getNumber('WEBGL_VERSION')).getContext('2d'); } - this.fromPixels2DContext.canvas.width = pixels.width; - this.fromPixels2DContext.canvas.height = pixels.height; + + this.fromPixels2DContext.canvas.width = width; + this.fromPixels2DContext.canvas.height = height; this.fromPixels2DContext.drawImage( - pixels as HTMLVideoElement, 0, 0, pixels.width, pixels.height); + pixels as HTMLVideoElement | HTMLImageElement, 0, 0, width, height); //@ts-ignore pixels = this.fromPixels2DContext.canvas; } diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 00adfbee3fa..fe990494af2 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -1608,15 +1608,36 @@ describeWithFlags('fromPixels', BROWSER_ENVS, () => { expect(data.length).toEqual(10 * 10 * 3); }); it('fromPixels for HTMLVideolement', async () => { + const video = document.createElement('video'); + video.autoplay = true; + const source = document.createElement('source'); + // tslint:disable:max-line-length + source.src = 
'data:video/mp4;base64,AAAAIGZ0eXBpc29tAAACAGlzb21pc28yYXZjMW1wNDEAAAAIZnJlZQAAAu1tZGF0AAACrQYF//+p3EXpvebZSLeWLNgg2SPu73gyNjQgLSBjb3JlIDE1NSByMjkwMSA3ZDBmZjIyIC0gSC4yNjQvTVBFRy00IEFWQyBjb2RlYyAtIENvcHlsZWZ0IDIwMDMtMjAxOCAtIGh0dHA6Ly93d3cudmlkZW9sYW4ub3JnL3gyNjQuaHRtbCAtIG9wdGlvbnM6IGNhYmFjPTEgcmVmPTMgZGVibG9jaz0xOjA6MCBhbmFseXNlPTB4MzoweDExMyBtZT1oZXggc3VibWU9NyBwc3k9MSBwc3lfcmQ9MS4wMDowLjAwIG1peGVkX3JlZj0xIG1lX3JhbmdlPTE2IGNocm9tYV9tZT0xIHRyZWxsaXM9MSA4eDhkY3Q9MSBjcW09MCBkZWFkem9uZT0yMSwxMSBmYXN0X3Bza2lwPTEgY2hyb21hX3FwX29mZnNldD0tMiB0aHJlYWRzPTMgbG9va2FoZWFkX3RocmVhZHM9MSBzbGljZWRfdGhyZWFkcz0wIG5yPTAgZGVjaW1hdGU9MSBpbnRlcmxhY2VkPTAgYmx1cmF5X2NvbXBhdD0wIGNvbnN0cmFpbmVkX2ludHJhPTAgYmZyYW1lcz0zIGJfcHlyYW1pZD0yIGJfYWRhcHQ9MSBiX2JpYXM9MCBkaXJlY3Q9MSB3ZWlnaHRiPTEgb3Blbl9nb3A9MCB3ZWlnaHRwPTIga2V5aW50PTI1MCBrZXlpbnRfbWluPTEgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVzaD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTI4LjAgcWNvbXA9MC42MCBxcG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAAwZYiEAD//8m+P5OXfBeLGOfKE3xkODvFZuBflHv/+VwJIta6cbpIo4ABLoKBaYTkTAAAC7m1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAAPoAAEAAAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIAAAIYdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAAAAPoAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAACgAAAAWgAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAD6AAAAAAAAQAAAAABkG1kaWEAAAAgbWRoZAAAAAAAAAAAAAAAAAAAQAAAAEAAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAAVmlkZW9IYW5kbGVyAAAAATttaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVmAAAAAAAAAAEAAAAMdXJsIAAAAAEAAAD7c3RibAAAAJdzdHNkAAAAAAAAAAEAAACHYXZjMQAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAACgAFoASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQACv/hABhnZAAKrNlCjfkhAAADAAEAAAMAAg8SJZYBAAZo6+JLIsAAAAAYc3R0cwAAAAAAAAABAAAAAQAAQAAAAAAcc3RzYwAAAAAAAAABAAAAAQAAAAEAAAABAAAAFHN0c3oAAAAAAAAC5QAAAAEAAAAUc3RjbwAAAAAAAAABAAAAMAAAAGJ1ZHRhAAAAWm1ldGEAAAAAAAAAIWhkbHIAAAAAAAAAAG1kaXJhc
HBsAAAAAAAAAAAAAAAALWlsc3QAAAAlqXRvbwAAAB1kYXRhAAAAAQAAAABMYXZmNTguMTIuMTAw'; + source.type = 'video/mp4'; + video.appendChild(source); + document.body.appendChild(video); + + // On mobile safari the ready state is ready immediately, so we only wait for the `loadeddata` event when the video is not ready yet. + if (video.readyState < 2) { + await new Promise(resolve => { + video.addEventListener('loadeddata', () => resolve()); + }); + } + + const res = tf.browser.fromPixels(video); + expect(res.shape).toEqual([90, 160, 3]); + const data = await res.data(); + expect(data.length).toEqual(90 * 160 *3); + document.body.removeChild(video); + }); + + it('fromPixels for HTMLVideolement throws without loadeddata', async () => { const video = document.createElement('video'); video.width = 1; video.height = 1; video.src = 'data:image/gif;base64' + ',R0lGODlhAQABAIAAAP///wAAACH5BAEAAAAALAAAAAABAAEAAAICRAEAOw=='; - const res = tf.browser.fromPixels(video); - expect(res.shape).toEqual([1, 1, 3]); - const data = await res.data(); - expect(data.length).toEqual(1 * 1 * 3); + expect(() => tf.browser.fromPixels(video)).toThrowError(); }); it('throws when passed a primitive number', () => { diff --git a/tfjs-core/src/ops/browser.ts b/tfjs-core/src/ops/browser.ts index 58a084e75cd..7d9032cab02 100644 --- a/tfjs-core/src/ops/browser.ts +++ b/tfjs-core/src/ops/browser.ts @@ -52,6 +52,18 @@ function fromPixels_( throw new Error( 'Cannot construct Tensor with more than 4 channels from pixels.'); } + const isVideo = typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement; + if (isVideo) { + const HAVE_CURRENT_DATA_READY_STATE = 2; + if (isVideo && + (pixels as HTMLVideoElement).readyState < + HAVE_CURRENT_DATA_READY_STATE) { + throw new Error( + 'The video element has not loaded data yet. Please wait for ' + + '`loadeddata` event on the