Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ function mapActivation(
return backend.linear(x);
} else if (activation === 'relu') {
return backend.relu(x);
} else if (activation === 'elu') {
return backend.elu(x);
} else if (activation === 'prelu') {
return backend.prelu(x, preluActivationWeights);
}
Expand Down
8 changes: 8 additions & 0 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ function mapActivationToShaderProgram(
return unary_packed_op.RELU;
}
return unary_op.RELU;
} else if (activation === 'elu') {
if (packed) {
return unary_packed_op.ELU;
}
return unary_op.ELU;
} else if (activation === 'prelu') {
if (packed) {
return binaryop_packed_gpu.PRELU;
Expand Down Expand Up @@ -1745,6 +1750,9 @@ export class MathBackendWebGL implements KernelBackend {
}

elu<T extends Tensor>(x: T): T {
if (ENV.getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
return this.packedUnaryOp(x, unary_packed_op.ELU, x.dtype) as T;
}
const program = new UnaryOpProgram(x.shape, unary_op.ELU);
return this.compileAndRun(program, [x]);
}
Expand Down
11 changes: 11 additions & 0 deletions tfjs-core/src/backends/webgl/unaryop_packed_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ export const RELU = `
return result;
`;

export const ELU = `
vec4 result;

result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

return result;
`;

export class UnaryOpPackedProgram implements GPGPUProgram {
variableNames = ['A'];
userCode: string;
Expand Down
57 changes: 57 additions & 0 deletions tfjs-core/src/ops/fused_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
expectArraysClose(await c.data(), [0, 8, 0, 20]);
});

it('A x B with elu', async () => {
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
const transposeA = false;
const transposeB = false;

const c = tf.fused.matMul(
{a, b, transposeA, transposeB, bias: null, activation: 'elu'});

expect(c.shape).toEqual([2, 2]);
expectArraysClose(await c.data(), [0, 8, -0.9502, 20]);
});

it('A x B with prelu', async () => {
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
Expand Down Expand Up @@ -106,6 +119,21 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
expectArraysClose(await d.data(), [1, 9, 0, 21]);
});

it('A x B with elu and broadcasted bias', async () => {
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
const c = tf.tensor1d([1, 1]);
const act: tf.fused.Activation = 'elu';
const transposeA = false;
const transposeB = false;

const d = tf.fused.matMul(
{a, b, transposeA, transposeB, bias: c, activation: act});

expect(d.shape).toEqual([2, 2]);
expectArraysClose(await d.data(), [1, 9, -0.8647, 21]);
});

it('A x B with relu and broadcasted bias different rank', async () => {
const a = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [2, 2, 3]);
const b = tf.tensor3d([0, 1, -3, 2, 2, 1, 0, 1, -3, 2, 2, 1], [2, 3, 2]);
Expand Down Expand Up @@ -318,6 +346,35 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('basic with elu', async () => {
const inputDepth = 2;
const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
const outputDepth = 2;
const fSize = 1;
const pad = 0;
const stride = 1;

const x = tf.tensor4d(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], inShape);
const w =
tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth, outputDepth]);

const result = tf.fused.conv2d({
x,
filter: w,
strides: stride,
pad,
dataFormat: 'NHWC',
dilations: [1, 1],
activation: 'elu'
});
expect(result.shape).toEqual([2, 2, 2, 2]);
const expected =
[-0.99326, 2, -1, 5, -1, 8, -1, 11, -1, 14, -1, 17, -1, 20, -1, 23];

expectArraysClose(await result.data(), expected);
});

it('basic with prelu', async () => {
const inputDepth = 2;
const inShape: [number, number, number, number] = [2, 2, 2, inputDepth];
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/ops/fused_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {Tensor, Tensor3D} from '../tensor';

export type Activation = 'linear'|'relu'|'prelu';
export type Activation = 'linear'|'relu'|'prelu'|'elu';

export type FusedBatchMatMulConfig = {
a: Tensor3D,
Expand Down