Implement PositionEmbedding layer (#7887)
* Implement position embedding

* Strip debug ops in jax conversion tests (#7889)

INTERNAL
This fixes an internal issue with jax tests. See cl/550054296.

* Update weights loading (#7872)

* Update weights loading

* fix tests

* remove

* fix

* fix comments

* fix lint

* Load python rules in tfjs-converter converters dir (#7892)

* Implement MultiHeadAttention Layer (#7875)

* Add spec for multi-head attention

* Add CachedMultiHeadAttention cache

* Fix typos

* Lint

* Add Transformer Decoder spec

* lint

* Add Einsum spec

* lint

* Remove unused type declaration

* Move helper functions outside EinsumDense class

* Implement Einsum Dense

* Address comments

* Implement MHA Layer

* Add masked softmax support

* Fix typo

* Check for undef and null

* Make buildFromSignature public

* Wrap softmax call in tf.tidy

* Implement position embedding

---------

Co-authored-by: Matthew Soulanille <msoulanille@google.com>
Co-authored-by: fengwuyao <131706622+fengwuyao@users.noreply.github.com>
3 people committed Aug 7, 2023
1 parent 0fd462d commit 7b84039
Showing 2 changed files with 227 additions and 12 deletions.
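For orientation before the diff: the new layer follows keras_nlp's PositionEmbedding by learning a [sequenceLength, featureSize] table, slicing it to the input's sequence length (optionally offset by startIndex), and broadcasting the slice to the input shape. Below is a rough sketch of that computation in plain tfjs-core; the sizes and the randomly initialized table standing in for the learned weight are illustrative assumptions, not part of the commit.

import { broadcastTo, randomNormal, slice } from '@tensorflow/tfjs-core';

// Hypothetical learned table: one embedding row per position.
const sequenceLength = 8;
const featureSize = 4;
const table = randomNormal([sequenceLength, featureSize]);

// For an input of shape [batch, inputLength, featureSize], take inputLength
// rows starting at startIndex and broadcast them across the batch.
const batch = 2;
const inputLength = 5;
const startIndex = 0;
const positions = broadcastTo(
  slice(table, [startIndex, 0], [inputLength, featureSize]),
  [batch, inputLength, featureSize]);
console.log(positions.shape);  // [2, 5, 4]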
62 changes: 50 additions & 12 deletions tfjs-layers/src/layers/nlp/modeling/position_embedding.ts
@@ -20,12 +20,14 @@
*/

/* Original source: keras_nlp/layers/modeling/position_embedding.py */
-import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';
+import { Tensor, serialization, tidy } from '@tensorflow/tfjs-core';

import { Shape } from '../../../keras_format/common';
import { Layer, LayerArgs } from '../../../engine/topology';
-import { NotImplementedError } from '../../../errors';
-import { InitializerIdentifier } from '../../../initializers';
+import { ValueError } from '../../../errors';
+import { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../../initializers';
+import { getExactlyOneTensor } from '../../../utils/types_utils';
+import { LayerVariable } from '../../../variables';

export declare interface PositionEmbeddingArgs extends LayerArgs {
/**
@@ -37,7 +39,7 @@ export declare interface PositionEmbeddingArgs extends LayerArgs {
* The initializer to use for the embedding weights.
* Defaults to `"glorotUniform"`.
*/
-initializer?: InitializerIdentifier;
+initializer?: Initializer|InitializerIdentifier;
}

export declare interface PositionEmbeddingOptions {
@@ -84,26 +86,62 @@ export declare interface PositionEmbeddingOptions {
export class PositionEmbedding extends Layer {
/** @nocollapse */
static readonly className = 'PositionEmbedding';
+private sequenceLength: number;
+private initializer: Initializer;
+protected positionEmbeddings: LayerVariable;

constructor(args: PositionEmbeddingArgs) {
super(args);

-throw new NotImplementedError('PositionEmbedding not implemented yet.');
+if (args.sequenceLength == null) {
+throw new ValueError(
+'`sequenceLength` must be an Integer, received `null`.');
+}
+this.sequenceLength = args.sequenceLength;
+this.initializer = getInitializer(args.initializer || 'glorotUniform');
}

override getConfig(): serialization.ConfigDict {
-throw new NotImplementedError('Not implemented yet.');
+const config = {
+'sequenceLength': this.sequenceLength,
+'initializer': serializeInitializer(this.initializer),
+};
+const baseConfig = super.getConfig();
+Object.assign(config, baseConfig);
+return config;
}

-override build(inputShape: Shape | Shape[]): void {
-throw new NotImplementedError('Not implemented yet.');
+override build(inputShape: Shape): void {
+const featureSize = inputShape[inputShape.length - 1];
+this.positionEmbeddings = this.addWeight(
+'embeddings',
+[this.sequenceLength, featureSize],
+null,
+this.initializer,
+null,
+true
+);
+super.build(inputShape);
}

override call(
inputs: Tensor|Tensor[],
-kwargs: PositionEmbeddingOptions={startIndex: 0}
-): Tensor1D|Tensor2D {
-throw new NotImplementedError('Not implemented yet.');
+kwargs?: PositionEmbeddingOptions
+): Tensor {
+return tidy(() => {
+kwargs.startIndex = kwargs.startIndex ?? 0;
+const shape = getExactlyOneTensor(inputs).shape;
+const featureLength = shape[shape.length - 1];
+const sequenceLength = shape[shape.length - 2];
+// trim to match the length of the input sequence, which might be less
+// than the sequence_length of the layer.
+const positionEmbeddings = this.positionEmbeddings.read().slice(
+[kwargs.startIndex, 0], [sequenceLength, featureLength]);
+return positionEmbeddings.broadcastTo(shape);
+});
}

override computeOutputShape(inputShape: Shape): Shape {
return inputShape;
}
}
serialization.registerClass(PositionEmbedding);
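A minimal usage sketch (not part of the diff; the relative import, the concrete sizes, and calling call directly with a startIndex are illustrative assumptions based on the code and tests in this commit):

import { Tensor, ones } from '@tensorflow/tfjs-core';
import { PositionEmbedding } from './position_embedding';

// Hypothetical shapes: batch of 2 sequences, 10 tokens, feature size 16.
const layer = new PositionEmbedding({sequenceLength: 10});
const inputs = ones([2, 10, 16]);

// apply() builds the layer on first use; the output broadcasts the learned
// [sequenceLength, featureSize] table to the input shape.
const positions = layer.apply(inputs) as Tensor;
console.log(positions.shape);  // [2, 10, 16]

// startIndex offsets the slice into the embedding table, which is the hook
// for cached, step-by-step decoding of one position at a time.
const step = ones([2, 1, 16]);
const stepPositions = layer.call(step, {startIndex: 4});
console.log(stepPositions.shape);  // [2, 1, 16]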
177 changes: 177 additions & 0 deletions tfjs-layers/src/layers/nlp/modeling/position_embedding_test.ts
@@ -0,0 +1,177 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* Tests for the position embedding layer.
*/

import { Tensor, memory, ones, randomUniform } from '@tensorflow/tfjs-core';

import { SymbolicTensor } from '../../../engine/topology';
import { input, model } from '../../../exports';
import { PositionEmbedding } from './position_embedding';

describe('PositionEmbedding', () => {
it('static layer output shape', () => {
// Create a 3-dimensional input (the first dimension is implicit).
const sequenceLength = 21;
const featureSize = 30;
const testLayer = new PositionEmbedding({sequenceLength});
const inputTensor = input({shape: [sequenceLength, featureSize]});
const outputTensor = testLayer.apply(inputTensor) as SymbolicTensor;

// When using static position embedding shapes, the output is expected to
// be the same shape as the input shape in all dimensions save batch.
const expectedOutputShape = [null, sequenceLength, featureSize];

expect(outputTensor.shape).toEqual(expectedOutputShape);
});

it('more than 3 dimensions static', () => {
// Create a 4-dimensional input (the first dimension is implicit).
const sequenceLength = 21;
const featureSize = 30;
const testLayer = new PositionEmbedding({sequenceLength});
const inputTensor =
input({shape: [featureSize, sequenceLength, featureSize]});
const outputTensor = testLayer.apply(inputTensor) as SymbolicTensor;

// When using static position embedding shapes, the output is expected
// to be the same as the input shape in all dimensions save batch.
const expectedOutputShape =
[null, featureSize, sequenceLength, featureSize];

expect(outputTensor.shape).toEqual(expectedOutputShape);
});

it('float32 dtype', () => {
// Create a 3-dimensional input (the first dimension is implicit).
const sequenceLength = 21;
const featureSize = 30;
const testLayer = new PositionEmbedding({sequenceLength, dtype: 'float32'});
const inputTensor = input({shape: [sequenceLength, featureSize]});
const outputTensor = testLayer.apply(inputTensor) as SymbolicTensor;

// When using static position embedding shapes, the output is expected
// to be the same as the input shape in all dimensions save batch.
const expectedOutputShape =
[null, sequenceLength, featureSize];

expect(outputTensor.shape).toEqual(expectedOutputShape);
// The output dtype for this layer should match the compute dtype.
expect(outputTensor.dtype).toEqual('float32');
});

it('dynamic layer output shape', () => {
const maxSequenceLength = 21;
const featureSize = 30;
const testLayer =
new PositionEmbedding({sequenceLength: maxSequenceLength});
// Create a 3-dimensional input (the first dimension is implicit).
const inputTensor = input({shape: [null, featureSize]});
const outputTensor = testLayer.apply(inputTensor) as SymbolicTensor;

// When using dynamic position embedding shapes, the output is expected to
// be the same shape as the input shape in all dimensions - but may be
// null if the input shape is null there.
const expectedOutputShape = [null, null, featureSize];

expect(outputTensor.shape).toEqual(expectedOutputShape);
});

it('more than 3 dimensions dynamic', () => {
const maxSequenceLength = 60;
const featureSize = 30;
const testLayer =
new PositionEmbedding({sequenceLength: maxSequenceLength});
// Create a 4-dimensional input (the first dimension is implicit).
const inputTensor = input({shape: [null, null, featureSize]});
const outputTensor = testLayer.apply(inputTensor) as SymbolicTensor;

// When using dynamic position embedding shapes, the output is expected
// to be the same as the input shape in all dimensions save batch.
const expectedOutputShape = [null, null, null, featureSize];

expect(outputTensor.shape).toEqual(expectedOutputShape);
});

it('dynamic layer slicing', () => {
const maxSequenceLength = 40;
const featureSize = 30;
const testLayer =
new PositionEmbedding({sequenceLength: maxSequenceLength});
// Create a 3-dimensional input (the first dimension is implicit).
const inputTensor = input({shape: [null, featureSize]});
const outputTensor = testLayer.apply(inputTensor) as SymbolicTensor;

const pmodel = model({inputs: inputTensor, outputs: outputTensor});

// Create input data that is shorter than maxSequenceLength, which
// should trigger a down-slice.
const inputLength = 17;
// Note: In practice, this layer should be used inside a model, where it can
// be projected when added to another tensor.
const inputData = ones([1, inputLength, featureSize]);
const outputData = pmodel.predict(inputData) as Tensor;

expect(outputData.shape).toEqual([1, inputLength, featureSize]);
});

it('one training step', async () => {
const maxSequenceLength = 4;
const featureSize = 3;
const inputs = input({shape: [maxSequenceLength, featureSize]});
const testLayer =
new PositionEmbedding({sequenceLength: maxSequenceLength});
const outputs = testLayer.apply(inputs) as SymbolicTensor;
const pmodel = model({inputs, outputs});

const batchSize = 2;
const data = randomUniform([batchSize, maxSequenceLength, featureSize]);
const label = randomUniform([batchSize, maxSequenceLength, featureSize]);

pmodel.compile({optimizer: 'adam', loss: 'meanSquaredError'});
const loss = pmodel.trainOnBatch(data, label);

expect(await loss).toBeGreaterThan(0);
});

it('serialization round trip', () => {
const maxSequenceLength = 40;
const testLayer = new PositionEmbedding({
sequenceLength: maxSequenceLength,
});
const original = testLayer.getConfig();
const restored =
PositionEmbedding.fromConfig(PositionEmbedding, original).getConfig();

expect(original['sequenceLength']).toEqual(restored['sequenceLength']);
});

it('does not leak memory', () => {
const sequenceLength = 4;
const batchSize = 2;
const testLayer = new PositionEmbedding({sequenceLength});
const data = randomUniform([batchSize, sequenceLength, 3]);
testLayer.build(data.shape);

const numTensors = memory().numTensors;
testLayer.apply(data);

expect(memory().numTensors).toEqual(numTensors + 1);
});
});
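As the comment in the dynamic slicing test hints, the layer is normally combined with a token embedding inside a model. A hedged sketch of that pattern using the public tfjs-layers API (tf.input, tf.layers.embedding, tf.layers.add); the sizes and the mixed import of the internal module alongside the published package are assumptions for illustration:

import * as tf from '@tensorflow/tfjs';
import { PositionEmbedding } from './position_embedding';

// Hypothetical vocabulary and embedding sizes.
const vocabSize = 1000;
const sequenceLength = 64;
const embeddingDim = 32;

// Token ids -> dense token embeddings of shape [batch, sequenceLength, embeddingDim].
const tokenIds = tf.input({shape: [sequenceLength]});
const tokenEmbeddings = tf.layers.embedding({
  inputDim: vocabSize, outputDim: embeddingDim,
}).apply(tokenIds) as tf.SymbolicTensor;

// Position embeddings broadcast to the same shape as the token embeddings.
const positionEmbeddings = new PositionEmbedding({sequenceLength})
  .apply(tokenEmbeddings) as tf.SymbolicTensor;

// Sum the two streams, the usual transformer input encoding.
const encoded = tf.layers.add().apply(
  [tokenEmbeddings, positionEmbeddings]) as tf.SymbolicTensor;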
