Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
import {FusedBatchMatMulConfig, FusedConv2DConfig} from '../ops/fused_util';
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';
import {BackendValues, DataType, Rank, ShapeMap} from '../types';

export const EPSILON_FLOAT32 = 1e-7;
export const EPSILON_FLOAT16 = 1e-4;
Expand All @@ -37,10 +37,6 @@ export interface TensorStorage {
write(values: BackendValues, shape: number[], dtype: DataType): DataId;
move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):
void;
fromPixels(
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
numChannels: number): Tensor3D;
memory(): {unreliable: boolean;}; // Backend-specific information.
/** Returns number of data ids currently in the storage. */
numDataIds(): number;
Expand Down Expand Up @@ -114,12 +110,6 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
disposeData(dataId: object): void {
return notYetImplemented('disposeData');
}
fromPixels(
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
numChannels: number): Tensor<Rank.R3> {
return notYetImplemented('fromPixels');
}
write(values: BackendValues, shape: number[], dtype: DataType): DataId {
return notYetImplemented('write');
}
Expand Down
93 changes: 2 additions & 91 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ import * as erf_util from '../../ops/erf_util';
import {Activation, FusedBatchMatMulConfig, FusedConv2DConfig} from '../../ops/fused_util';
import * as gather_nd_util from '../../ops/gather_nd_util';
import * as ops from '../../ops/ops';
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
import {buffer, scalar, tensor, tensor4d} from '../../ops/ops';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as selu_util from '../../ops/selu_util';
import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import {BackendValues, DataType, DataValues, NumericDataType, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';
Expand Down Expand Up @@ -67,15 +67,6 @@ function mapActivation(
`Activation ${activation} has not been implemented for the CPU backend.`);
}

function createCanvas() {
if (typeof OffscreenCanvas !== 'undefined') {
return new OffscreenCanvas(300, 150);
} else if (typeof document !== 'undefined') {
return document.createElement('canvas');
}
return null;
}

export interface TensorData<D extends DataType> {
values?: BackendValues;
dtype: D;
Expand All @@ -91,19 +82,10 @@ export class MathBackendCPU extends KernelBackend {
public blockSize = 48;

data: DataStorage<TensorData<DataType>>;
private fromPixels2DContext: CanvasRenderingContext2D|
OffscreenCanvasRenderingContext2D;
private firstUse = true;

constructor() {
super();
if (env().get('IS_BROWSER')) {
const canvas = createCanvas();
if (canvas !== null) {
this.fromPixels2DContext =
canvas.getContext('2d') as CanvasRenderingContext2D;
}
}
this.data = new DataStorage(this, ENGINE);
}

Expand Down Expand Up @@ -138,77 +120,6 @@ export class MathBackendCPU extends KernelBackend {
return this.data.numDataIds();
}

fromPixels(
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
numChannels: number): Tensor3D {
if (pixels == null) {
throw new Error(
'pixels passed to tf.browser.fromPixels() can not be null');
}

const isPixelData = (pixels as PixelData).data instanceof Uint8Array;
const isImageData =
typeof (ImageData) !== 'undefined' && pixels instanceof ImageData;
const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
pixels instanceof HTMLVideoElement;
const isImage = typeof (HTMLImageElement) !== 'undefined' &&
pixels instanceof HTMLImageElement;
const [width, height] = isVideo ?
[
(pixels as HTMLVideoElement).videoWidth,
(pixels as HTMLVideoElement).videoHeight
] :
[pixels.width, pixels.height];
let vals: Uint8ClampedArray|Uint8Array;
// tslint:disable-next-line:no-any
if (env().get('IS_NODE') && (pixels as any).getContext == null) {
throw new Error(
'When running in node, pixels must be an HTMLCanvasElement ' +
'like the one returned by the `canvas` npm package');
}
// tslint:disable-next-line:no-any
if ((pixels as any).getContext != null) {
// tslint:disable-next-line:no-any
vals = (pixels as any)
.getContext('2d')
.getImageData(0, 0, width, height)
.data;
} else if (isImageData || isPixelData) {
vals = (pixels as PixelData | ImageData).data;
} else if (isImage || isVideo) {
if (this.fromPixels2DContext == null) {
throw new Error(
'Can\'t read pixels from HTMLImageElement outside ' +
'the browser.');
}
this.fromPixels2DContext.canvas.width = width;
this.fromPixels2DContext.canvas.height = height;
this.fromPixels2DContext.drawImage(
pixels as HTMLVideoElement, 0, 0, width, height);
vals = this.fromPixels2DContext.getImageData(0, 0, width, height).data;
} else {
throw new Error(
'pixels passed to tf.browser.fromPixels() must be either an ' +
`HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
`or {data: Uint32Array, width: number, height: number}, ` +
`but was ${(pixels as {}).constructor.name}`);
}
let values: Int32Array;
if (numChannels === 4) {
values = new Int32Array(vals);
} else {
const numPixels = width * height;
values = new Int32Array(numPixels * numChannels);
for (let i = 0; i < numPixels; i++) {
for (let channel = 0; channel < numChannels; ++channel) {
values[i * numChannels + channel] = vals[i * 4 + channel];
}
}
}
const outShape: [number, number, number] = [height, width, numChannels];
return tensor3d(values, outShape, 'int32');
}
async read(dataId: DataId): Promise<BackendValues> {
return this.readSync(dataId);
}
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/backends/cpu/backend_cpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,12 @@ describeWithFlags('memory cpu', CPU_ENVS, () => {
});
});

describe('CPU backend has sync init', () => {
describeWithFlags('CPU backend has sync init', {}, () => {
it('can do matmul without waiting for ready', async () => {
tf.registerBackend('my-cpu', () => {
return new MathBackendCPU();
});
tf.setBackend('my-cpu');
const a = tf.tensor1d([5]);
const b = tf.tensor1d([3]);
const res = tf.dot(a, b);
Expand Down
1 change: 1 addition & 0 deletions tfjs-core/src/backends/webgl/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
// global registry when we compile the library. A modular build would replace
// the contents of this file and import only the kernels that are needed.
import './square';
import './fromPixels';
94 changes: 7 additions & 87 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import * as slice_util from '../../ops/slice_util';
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import {BackendValues, DataType, DataTypeMap, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
Expand All @@ -64,7 +64,7 @@ import * as binaryop_gpu from './binaryop_gpu';
import {BinaryOpProgram} from './binaryop_gpu';
import * as binaryop_packed_gpu from './binaryop_packed_gpu';
import {BinaryOpPackedProgram} from './binaryop_packed_gpu';
import {createCanvas, getWebGLContext} from './canvas_util';
import {getWebGLContext} from './canvas_util';
import {ClipProgram} from './clip_gpu';
import {ClipPackedProgram} from './clip_packed_gpu';
import {ComplexAbsProgram} from './complex_abs_gpu';
Expand All @@ -88,8 +88,6 @@ import {EncodeMatrixPackedProgram} from './encode_matrix_packed_gpu';
import * as fft_gpu from './fft_gpu';
import {FFTProgram} from './fft_gpu';
import {FillProgram} from './fill_gpu';
import {FromPixelsProgram} from './from_pixels_gpu';
import {FromPixelsPackedProgram} from './from_pixels_packed_gpu';
import {GatherProgram} from './gather_gpu';
import {GatherNDProgram} from './gather_nd_gpu';
import {GPGPUContext} from './gpgpu_context';
Expand Down Expand Up @@ -221,6 +219,8 @@ export const MATMUL_SHARED_DIM_THRESHOLD = 1000;

export class MathBackendWebGL extends KernelBackend {
texData: DataStorage<TextureData>;
gpgpu: GPGPUContext;

// Maps data ids that have a pending read operation, to list of subscribers.
private pendingRead = new WeakMap<DataId, Array<(arr: TypedArray) => void>>();
// List of data ids that are scheduled for disposal, but are waiting on a
Expand All @@ -232,8 +232,6 @@ export class MathBackendWebGL extends KernelBackend {
private numBytesInGPU = 0;

private canvas: HTMLCanvasElement|OffscreenCanvas;
private fromPixels2DContext: CanvasRenderingContext2D|
OffscreenCanvasRenderingContext2D;

private programTimersStack: TimerNode[];
private activeTimers: TimerNode[];
Expand All @@ -252,7 +250,7 @@ export class MathBackendWebGL extends KernelBackend {
private numMBBeforeWarning: number;
private warnedAboutMemory = false;

constructor(private gpgpu?: GPGPUContext) {
constructor(gpgpu?: GPGPUContext) {
super();
if (!env().getBool('HAS_WEBGL')) {
throw new Error('WebGL is not supported on this device');
Expand All @@ -265,6 +263,7 @@ export class MathBackendWebGL extends KernelBackend {
this.canvas = gl.canvas;
this.gpgpuCreatedLocally = true;
} else {
this.gpgpu = gpgpu;
this.binaryCache = {};
this.gpgpuCreatedLocally = false;
this.canvas = gpgpu.gl.canvas;
Expand All @@ -281,79 +280,6 @@ export class MathBackendWebGL extends KernelBackend {
this.pendingDeletes;
}

fromPixels(
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
numChannels: number): Tensor3D {
if (pixels == null) {
throw new Error(
'pixels passed to tf.browser.fromPixels() can not be null');
}

const isCanvas = (typeof (OffscreenCanvas) !== 'undefined' &&
pixels instanceof OffscreenCanvas) ||
(typeof (HTMLCanvasElement) !== 'undefined' &&
pixels instanceof HTMLCanvasElement);
const isPixelData = (pixels as PixelData).data instanceof Uint8Array;
const isImageData =
typeof (ImageData) !== 'undefined' && pixels instanceof ImageData;
const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
pixels instanceof HTMLVideoElement;
const isImage = typeof (HTMLImageElement) !== 'undefined' &&
pixels instanceof HTMLImageElement;
const [width, height] = isVideo ?
[
(pixels as HTMLVideoElement).videoWidth,
(pixels as HTMLVideoElement).videoHeight
] :
[pixels.width, pixels.height];

const texShape: [number, number] = [height, width];
const outShape = [height, width, numChannels];

if (!isCanvas && !isPixelData && !isImageData && !isVideo && !isImage) {
throw new Error(
'pixels passed to tf.browser.fromPixels() must be either an ' +
`HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
`in browser, or OffscreenCanvas, ImageData in webworker` +
` or {data: Uint32Array, width: number, height: number}, ` +
`but was ${(pixels as {}).constructor.name}`);
}

if (isImage || isVideo) {
if (this.fromPixels2DContext == null) {
//@ts-ignore
this.fromPixels2DContext =
createCanvas(env().getNumber('WEBGL_VERSION')).getContext('2d');
}

this.fromPixels2DContext.canvas.width = width;
this.fromPixels2DContext.canvas.height = height;
this.fromPixels2DContext.drawImage(
pixels as HTMLVideoElement | HTMLImageElement, 0, 0, width, height);
//@ts-ignore
pixels = this.fromPixels2DContext.canvas;
}

const tempPixelHandle = this.makeTensorInfo(texShape, 'int32');
// This is a byte texture with pixels.
this.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
this.gpgpu.uploadPixelDataToTexture(
this.getTexture(tempPixelHandle.dataId), pixels as ImageData);
let program, res;
if (env().getBool('WEBGL_PACK')) {
program = new FromPixelsPackedProgram(outShape);
res = this.compileAndRun(program, [tempPixelHandle]);
} else {
program = new FromPixelsProgram(outShape);
res = this.compileAndRun(program, [tempPixelHandle]);
}

this.disposeData(tempPixelHandle.dataId);

return res as Tensor3D;
}

write(values: BackendValues, shape: number[], dtype: DataType): DataId {
if (env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
Expand Down Expand Up @@ -2482,7 +2408,7 @@ export class MathBackendWebGL extends KernelBackend {
return backend_util.linspaceImpl(start, stop, num);
}

private makeTensorInfo(shape: number[], dtype: DataType): TensorInfo {
makeTensorInfo(shape: number[], dtype: DataType): TensorInfo {
const dataId = this.write(null /* values */, shape, dtype);
this.texData.get(dataId).usage = null;
return {dataId, shape, dtype};
Expand Down Expand Up @@ -2709,12 +2635,6 @@ export class MathBackendWebGL extends KernelBackend {
} else {
this.canvas = null;
}
if (this.fromPixels2DContext != null &&
//@ts-ignore
this.fromPixels2DContext.canvas.remove) {
//@ts-ignore
this.fromPixels2DContext.canvas.remove();
}
if (this.gpgpuCreatedLocally) {
this.gpgpu.program = null;
this.gpgpu.dispose();
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/backends/webgl/backend_webgl_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,12 @@ describeWithFlags('caching on cpu', WEBGL_ENVS, () => {
});
});

describe('WebGL backend has sync init', () => {
describeWithFlags('WebGL backend has sync init', WEBGL_ENVS, () => {
it('can do matmul without waiting for ready', async () => {
tf.registerBackend('my-webgl', () => {
return new MathBackendWebGL();
});
tf.setBackend('my-webgl');
const a = tf.tensor1d([5]);
const b = tf.tensor1d([3]);
const res = tf.dot(a, b);
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/webgl/canvas_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export function getWebGLContext(webGLVersion: number): WebGLRenderingContext {
return contexts[webGLVersion];
}

export function createCanvas(webGLVersion: number) {
function createCanvas(webGLVersion: number) {
if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) {
return new OffscreenCanvas(300, 150);
} else if (typeof document !== 'undefined') {
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/webgl/canvas_util_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ describeWithFlags('canvas_util', BROWSER_ENVS, () => {
describeWithFlags('canvas_util webgl2', {flags: {WEBGL_VERSION: 2}}, () => {
it('is ok when the user requests webgl 1 canvas', () => {
const canvas = getWebGLContext(1).canvas;
expect((canvas instanceof HTMLCanvasElement)).toBe(true);
expect(canvas.getContext != null).toBe(true);
});
});
Loading