From bc063b84d317b4bc915939b5235183d4d855445a Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 8 Jul 2020 14:37:54 -0400 Subject: [PATCH 1/3] modularize signal ops --- tfjs-core/src/ops/frame.ts | 70 ++++++++ tfjs-core/src/ops/frame_test.ts | 124 ++++++++++++++ tfjs-core/src/ops/hamming_window.ts | 38 +++++ tfjs-core/src/ops/hamming_window_test.ts | 48 ++++++ tfjs-core/src/ops/hann_window.ts | 39 +++++ tfjs-core/src/ops/hann_window_test.ts | 47 +++++ tfjs-core/src/ops/ops.ts | 14 +- tfjs-core/src/ops/signal_ops.ts | 159 ----------------- tfjs-core/src/ops/signal_ops_util.ts | 35 ++++ tfjs-core/src/ops/stft.ts | 61 +++++++ .../ops/{signal_ops_test.ts => stft_test.ts} | 161 +----------------- tfjs-core/src/tests.ts | 5 +- 12 files changed, 479 insertions(+), 322 deletions(-) create mode 100644 tfjs-core/src/ops/frame.ts create mode 100644 tfjs-core/src/ops/frame_test.ts create mode 100644 tfjs-core/src/ops/hamming_window.ts create mode 100644 tfjs-core/src/ops/hamming_window_test.ts create mode 100644 tfjs-core/src/ops/hann_window.ts create mode 100644 tfjs-core/src/ops/hann_window_test.ts delete mode 100644 tfjs-core/src/ops/signal_ops.ts create mode 100644 tfjs-core/src/ops/signal_ops_util.ts create mode 100644 tfjs-core/src/ops/stft.ts rename tfjs-core/src/ops/{signal_ops_test.ts => stft_test.ts} (53%) diff --git a/tfjs-core/src/ops/frame.ts b/tfjs-core/src/ops/frame.ts new file mode 100644 index 00000000000..00fa8f73e13 --- /dev/null +++ b/tfjs-core/src/ops/frame.ts @@ -0,0 +1,70 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {op} from '../ops/operation'; +import {Tensor, Tensor1D} from '../tensor'; + +import {concat} from './concat'; +import {fill} from './fill'; +import {slice} from './slice'; +import {tensor2d} from './tensor_ops'; + +/** + * Expands input into frames of frameLength. + * Slides a window size with frameStep. + * + * ```js + * tf.signal.frame([1, 2, 3], 2, 1).print(); + * ``` + * @param signal The input tensor to be expanded + * @param frameLength Length of each frame + * @param frameStep The frame hop size in samples. + * @param padEnd Whether to pad the end of signal with padValue. + * @param padValue An number to use where the input signal does + * not exist when padEnd is True. + */ +/** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ +function frame_( + signal: Tensor1D, frameLength: number, frameStep: number, padEnd = false, + padValue = 0): Tensor { + let start = 0; + const output: Tensor[] = []; + while (start + frameLength <= signal.size) { + output.push(slice(signal, start, frameLength)); + start += frameStep; + } + + if (padEnd) { + while (start < signal.size) { + const padLen = (start + frameLength) - signal.size; + const pad = concat([ + slice(signal, start, frameLength - padLen), fill([padLen], padValue) + ]); + output.push(pad); + start += frameStep; + } + } + + if (output.length === 0) { + return tensor2d([], [0, frameLength]); + } + + return concat(output).as2D(output.length, frameLength); +} +export const frame = op({frame_}); \ No newline at end of file diff --git a/tfjs-core/src/ops/frame_test.ts b/tfjs-core/src/ops/frame_test.ts new file mode 100644 index 00000000000..8a130009c9a --- /dev/null +++ b/tfjs-core/src/ops/frame_test.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('frame', ALL_ENVS, () => { + it('3 length frames', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 3; + const frameStep = 1; + const output = tf.signal.frame(input, frameLength, frameStep); + expect(output.shape).toEqual([3, 3]); + expectArraysClose(await output.data(), [1, 2, 3, 2, 3, 4, 3, 4, 5]); + }); + + it('3 length frames with step 2', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 3; + const frameStep = 2; + const output = tf.signal.frame(input, frameLength, frameStep); + expect(output.shape).toEqual([2, 3]); + expectArraysClose(await output.data(), [1, 2, 3, 3, 4, 5]); + }); + + it('3 length frames with step 5', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 3; + const frameStep = 5; + const output = tf.signal.frame(input, frameLength, frameStep); + expect(output.shape).toEqual([1, 3]); + expectArraysClose(await output.data(), [1, 2, 3]); + }); + + it('Exceeding frame length', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 6; + const frameStep = 1; + const output = tf.signal.frame(input, frameLength, frameStep); + expect(output.shape).toEqual([0, 6]); + expectArraysClose(await output.data(), []); + }); + + it('Zero frame step', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 6; + const frameStep = 0; + const output = tf.signal.frame(input, frameLength, frameStep); + expect(output.shape).toEqual([0, 6]); + expectArraysClose(await output.data(), []); + }); + + it('Padding with default value', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 3; + const frameStep = 3; + const padEnd = true; + const output = tf.signal.frame(input, frameLength, frameStep, padEnd); + expect(output.shape).toEqual([2, 3]); + expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 0]); + }); + + it('Padding with the given value', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 3; + const frameStep = 3; + const padEnd = true; + const padValue = 100; + const output = + tf.signal.frame(input, frameLength, frameStep, padEnd, padValue); + expect(output.shape).toEqual([2, 3]); + expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 100]); + }); + + it('Padding all remaining frames with step=1', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 4; + const frameStep = 1; + const padEnd = true; + const output = tf.signal.frame(input, frameLength, frameStep, padEnd); + expect(output.shape).toEqual([5, 4]); + expectArraysClose( + await output.data(), + [1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 0, 4, 5, 0, 0, 5, 0, 0, 0]); + }); + + it('Padding all remaining frames with step=1 and given pad-value', + async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const frameLength = 4; + const frameStep = 1; + const padEnd = true; + const padValue = 42; + const output = + tf.signal.frame(input, frameLength, frameStep, padEnd, padValue); + expect(output.shape).toEqual([5, 4]); + expectArraysClose( + await output.data(), + [1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 42, 4, 5, 42, 42, 5, 42, 42, 42]); + }); + + it('Padding all remaining frames with step=2', async () => { + const input = tf.tensor1d([1, 2, 3, 4, 5]); + const output = tf.signal.frame(input, 4, 2, true); + expect(output.shape).toEqual([3, 4]); + expectArraysClose( + await output.data(), [1, 2, 3, 4, 3, 4, 5, 0, 5, 0, 0, 0]); + }); +}); diff --git a/tfjs-core/src/ops/hamming_window.ts b/tfjs-core/src/ops/hamming_window.ts new file mode 100644 index 00000000000..effd91d70d0 --- /dev/null +++ b/tfjs-core/src/ops/hamming_window.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {op} from '../ops/operation'; +import {Tensor1D} from '../tensor'; +import {cosineWindow} from './signal_ops_util'; + +/** + * Generate a hamming window. + * + * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows + * + * ```js + * tf.signal.hammingWindow(10).print(); + * ``` + * @param The length of window + */ +/** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ +function hammingWindow_(windowLength: number): Tensor1D { + return cosineWindow(windowLength, 0.54, 0.46); +} +export const hammingWindow = op({hammingWindow_}); diff --git a/tfjs-core/src/ops/hamming_window_test.ts b/tfjs-core/src/ops/hamming_window_test.ts new file mode 100644 index 00000000000..148dd1f1b90 --- /dev/null +++ b/tfjs-core/src/ops/hamming_window_test.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('hammingWindow', ALL_ENVS, () => { + it('length=3', async () => { + const ret = tf.signal.hammingWindow(3); + expectArraysClose(await ret.data(), [0.08, 1, 0.08]); + }); + + it('length=6', async () => { + const ret = tf.signal.hammingWindow(6); + expectArraysClose(await ret.data(), [0.08, 0.31, 0.77, 1., 0.77, 0.31]); + }); + + it('length=7', async () => { + const ret = tf.signal.hammingWindow(7); + expectArraysClose( + await ret.data(), [0.08, 0.31, 0.77, 1, 0.77, 0.31, 0.08]); + }); + + it('length=20', async () => { + const ret = tf.signal.hammingWindow(20); + expectArraysClose(await ret.data(), [ + 0.08000001, 0.10251403, 0.16785222, 0.2696188, 0.3978522, + 0.54, 0.68214786, 0.8103813, 0.9121479, 0.977486, + 1., 0.977486, 0.9121478, 0.8103812, 0.6821477, + 0.54, 0.39785212, 0.2696187, 0.16785222, 0.102514 + ]); + }); +}); diff --git a/tfjs-core/src/ops/hann_window.ts b/tfjs-core/src/ops/hann_window.ts new file mode 100644 index 00000000000..49af2e78a39 --- /dev/null +++ b/tfjs-core/src/ops/hann_window.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {op} from '../ops/operation'; +import {Tensor1D} from '../tensor'; +import {cosineWindow} from './signal_ops_util'; + +/** + * Generate a Hann window. + * + * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows + * + * ```js + * tf.signal.hannWindow(10).print(); + * ``` + * @param The length of window + */ +/** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ +function hannWindow_(windowLength: number): Tensor1D { + return cosineWindow(windowLength, 0.5, 0.5); +} + +export const hannWindow = op({hannWindow_}); diff --git a/tfjs-core/src/ops/hann_window_test.ts b/tfjs-core/src/ops/hann_window_test.ts new file mode 100644 index 00000000000..0cf4e1f83af --- /dev/null +++ b/tfjs-core/src/ops/hann_window_test.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('hannWindow', ALL_ENVS, () => { + it('length=3', async () => { + const ret = tf.signal.hannWindow(3); + expectArraysClose(await ret.data(), [0, 1, 0]); + }); + + it('length=7', async () => { + const ret = tf.signal.hannWindow(7); + expectArraysClose(await ret.data(), [0, 0.25, 0.75, 1, 0.75, 0.25, 0]); + }); + + it('length=6', async () => { + const ret = tf.signal.hannWindow(6); + expectArraysClose(await ret.data(), [0., 0.25, 0.75, 1., 0.75, 0.25]); + }); + + it('length=20', async () => { + const ret = tf.signal.hannWindow(20); + expectArraysClose(await ret.data(), [ + 0., 0.02447176, 0.09549153, 0.20610738, 0.34549153, + 0.5, 0.65450853, 0.79389274, 0.9045085, 0.97552824, + 1., 0.97552824, 0.9045085, 0.7938925, 0.65450835, + 0.5, 0.34549144, 0.20610726, 0.09549153, 0.02447173 + ]); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 50fee3c0e72..1756094195b 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -149,14 +149,24 @@ export * from './spectral_ops'; export * from './sparse_to_dense'; export * from './gather_nd'; export * from './dropout'; -export * from './signal_ops'; +export * from './signal_ops_util'; export * from './in_top_k'; export {op} from './operation'; import * as spectral from './spectral_ops'; import * as fused from './fused_ops'; -import * as signal from './signal_ops'; + +import {hammingWindow} from './hamming_window'; +import {hannWindow} from './hann_window'; +import {frame} from './frame'; +import {stft} from './stft'; +const signal = { + hammingWindow, + hannWindow, + frame, + stft, +}; // Image Ops namespace import {cropAndResize} from './crop_and_resize'; diff --git a/tfjs-core/src/ops/signal_ops.ts b/tfjs-core/src/ops/signal_ops.ts deleted file mode 100644 index 6ba0380ff87..00000000000 --- a/tfjs-core/src/ops/signal_ops.ts +++ /dev/null @@ -1,159 +0,0 @@ -/** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {op} from '../ops/operation'; -import {Tensor, Tensor1D} from '../tensor'; - -import {concat} from './concat'; -import {fill} from './fill'; -import {mul} from './mul'; -import {slice} from './slice'; -import {rfft} from './spectral_ops'; -import {tensor1d, tensor2d} from './tensor_ops'; - -/** - * Generate a Hann window. - * - * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows - * - * ```js - * tf.signal.hannWindow(10).print(); - * ``` - * @param The length of window - */ -/** - * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} - */ -function hannWindow_(windowLength: number): Tensor1D { - return cosineWindow(windowLength, 0.5, 0.5); -} - -/** - * Generate a hamming window. - * - * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows - * - * ```js - * tf.signal.hammingWindow(10).print(); - * ``` - * @param The length of window - */ -/** - * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} - */ -function hammingWindow_(windowLength: number): Tensor1D { - return cosineWindow(windowLength, 0.54, 0.46); -} - -/** - * Expands input into frames of frameLength. - * Slides a window size with frameStep. - * - * ```js - * tf.signal.frame([1, 2, 3], 2, 1).print(); - * ``` - * @param signal The input tensor to be expanded - * @param frameLength Length of each frame - * @param frameStep The frame hop size in samples. - * @param padEnd Whether to pad the end of signal with padValue. - * @param padValue An number to use where the input signal does - * not exist when padEnd is True. - */ -/** - * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} - */ -function frame_( - signal: Tensor1D, frameLength: number, frameStep: number, padEnd = false, - padValue = 0): Tensor { - let start = 0; - const output: Tensor[] = []; - while (start + frameLength <= signal.size) { - output.push(slice(signal, start, frameLength)); - start += frameStep; - } - - if (padEnd) { - while (start < signal.size) { - const padLen = (start + frameLength) - signal.size; - const pad = concat([ - slice(signal, start, frameLength - padLen), fill([padLen], padValue) - ]); - output.push(pad); - start += frameStep; - } - } - - if (output.length === 0) { - return tensor2d([], [0, frameLength]); - } - - return concat(output).as2D(output.length, frameLength); -} - -/** - * Computes the Short-time Fourier Transform of signals - * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform - * - * ```js - * const input = tf.tensor1d([1, 1, 1, 1, 1]) - * tf.signal.stft(input, 3, 1).print(); - * ``` - * @param signal 1-dimensional real value tensor. - * @param frameLength The window length of samples. - * @param frameStep The number of samples to step. - * @param fftLength The size of the FFT to apply. - * @param windowFn A callable that takes a window length and returns 1-d tensor. - */ -/** - * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} - */ -function stft_( - signal: Tensor1D, frameLength: number, frameStep: number, - fftLength?: number, - windowFn: (length: number) => Tensor1D = hannWindow): Tensor { - if (fftLength == null) { - fftLength = enclosingPowerOfTwo(frameLength); - } - const framedSignal = frame(signal, frameLength, frameStep); - const windowedSignal = mul(framedSignal, windowFn(frameLength)); - const output: Tensor[] = []; - for (let i = 0; i < framedSignal.shape[0]; i++) { - output.push( - rfft(windowedSignal.slice([i, 0], [1, frameLength]), fftLength)); - } - return concat(output); -} - -function enclosingPowerOfTwo(value: number) { - // Return 2**N for integer N such that 2**N >= value. - return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0)))); -} - -function cosineWindow(windowLength: number, a: number, b: number): Tensor1D { - const even = 1 - windowLength % 2; - const newValues = new Float32Array(windowLength); - for (let i = 0; i < windowLength; ++i) { - const cosArg = (2.0 * Math.PI * i) / (windowLength + even - 1); - newValues[i] = a - b * Math.cos(cosArg); - } - return tensor1d(newValues, 'float32'); -} - -export const hannWindow = op({hannWindow_}); -export const hammingWindow = op({hammingWindow_}); -export const frame = op({frame_}); -export const stft = op({stft_}); diff --git a/tfjs-core/src/ops/signal_ops_util.ts b/tfjs-core/src/ops/signal_ops_util.ts new file mode 100644 index 00000000000..f930187e4dc --- /dev/null +++ b/tfjs-core/src/ops/signal_ops_util.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor1D} from '../tensor'; +import {tensor1d} from './tensor_ops'; + +export function enclosingPowerOfTwo(value: number) { + // Return 2**N for integer N such that 2**N >= value. + return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0)))); +} + +export function cosineWindow( + windowLength: number, a: number, b: number): Tensor1D { + const even = 1 - windowLength % 2; + const newValues = new Float32Array(windowLength); + for (let i = 0; i < windowLength; ++i) { + const cosArg = (2.0 * Math.PI * i) / (windowLength + even - 1); + newValues[i] = a - b * Math.cos(cosArg); + } + return tensor1d(newValues, 'float32'); +} diff --git a/tfjs-core/src/ops/stft.ts b/tfjs-core/src/ops/stft.ts new file mode 100644 index 00000000000..b2c4e23664f --- /dev/null +++ b/tfjs-core/src/ops/stft.ts @@ -0,0 +1,61 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {op} from '../ops/operation'; +import {Tensor, Tensor1D} from '../tensor'; + +import {concat} from './concat'; +import {frame} from './frame'; +import {hannWindow} from './hann_window'; +import {mul} from './mul'; +import {enclosingPowerOfTwo} from './signal_ops_util'; +import {rfft} from './spectral_ops'; + +/** + * Computes the Short-time Fourier Transform of signals + * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform + * + * ```js + * const input = tf.tensor1d([1, 1, 1, 1, 1]) + * tf.signal.stft(input, 3, 1).print(); + * ``` + * @param signal 1-dimensional real value tensor. + * @param frameLength The window length of samples. + * @param frameStep The number of samples to step. + * @param fftLength The size of the FFT to apply. + * @param windowFn A callable that takes a window length and returns 1-d tensor. + */ +/** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ +function stft_( + signal: Tensor1D, frameLength: number, frameStep: number, + fftLength?: number, + windowFn: (length: number) => Tensor1D = hannWindow): Tensor { + if (fftLength == null) { + fftLength = enclosingPowerOfTwo(frameLength); + } + const framedSignal = frame(signal, frameLength, frameStep); + const windowedSignal = mul(framedSignal, windowFn(frameLength)); + const output: Tensor[] = []; + for (let i = 0; i < framedSignal.shape[0]; i++) { + output.push( + rfft(windowedSignal.slice([i, 0], [1, frameLength]), fftLength)); + } + return concat(output); +} +export const stft = op({stft_}); diff --git a/tfjs-core/src/ops/signal_ops_test.ts b/tfjs-core/src/ops/stft_test.ts similarity index 53% rename from tfjs-core/src/ops/signal_ops_test.ts rename to tfjs-core/src/ops/stft_test.ts index 0626b231cf5..81d2ace9e51 100644 --- a/tfjs-core/src/ops/signal_ops_test.ts +++ b/tfjs-core/src/ops/stft_test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2019 Google LLC. All Rights Reserved. + * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -19,165 +19,6 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; -describeWithFlags('hannWindow', ALL_ENVS, () => { - it('length=3', async () => { - const ret = tf.signal.hannWindow(3); - expectArraysClose(await ret.data(), [0, 1, 0]); - }); - - it('length=7', async () => { - const ret = tf.signal.hannWindow(7); - expectArraysClose(await ret.data(), [0, 0.25, 0.75, 1, 0.75, 0.25, 0]); - }); - - it('length=6', async () => { - const ret = tf.signal.hannWindow(6); - expectArraysClose(await ret.data(), [0., 0.25, 0.75, 1., 0.75, 0.25]); - }); - - it('length=20', async () => { - const ret = tf.signal.hannWindow(20); - expectArraysClose(await ret.data(), [ - 0., 0.02447176, 0.09549153, 0.20610738, 0.34549153, - 0.5, 0.65450853, 0.79389274, 0.9045085, 0.97552824, - 1., 0.97552824, 0.9045085, 0.7938925, 0.65450835, - 0.5, 0.34549144, 0.20610726, 0.09549153, 0.02447173 - ]); - }); -}); - -describeWithFlags('hammingWindow', ALL_ENVS, () => { - it('length=3', async () => { - const ret = tf.signal.hammingWindow(3); - expectArraysClose(await ret.data(), [0.08, 1, 0.08]); - }); - - it('length=6', async () => { - const ret = tf.signal.hammingWindow(6); - expectArraysClose(await ret.data(), [0.08, 0.31, 0.77, 1., 0.77, 0.31]); - }); - - it('length=7', async () => { - const ret = tf.signal.hammingWindow(7); - expectArraysClose( - await ret.data(), [0.08, 0.31, 0.77, 1, 0.77, 0.31, 0.08]); - }); - - it('length=20', async () => { - const ret = tf.signal.hammingWindow(20); - expectArraysClose(await ret.data(), [ - 0.08000001, 0.10251403, 0.16785222, 0.2696188, 0.3978522, - 0.54, 0.68214786, 0.8103813, 0.9121479, 0.977486, - 1., 0.977486, 0.9121478, 0.8103812, 0.6821477, - 0.54, 0.39785212, 0.2696187, 0.16785222, 0.102514 - ]); - }); -}); - -describeWithFlags('frame', ALL_ENVS, () => { - it('3 length frames', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 3; - const frameStep = 1; - const output = tf.signal.frame(input, frameLength, frameStep); - expect(output.shape).toEqual([3, 3]); - expectArraysClose(await output.data(), [1, 2, 3, 2, 3, 4, 3, 4, 5]); - }); - - it('3 length frames with step 2', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 3; - const frameStep = 2; - const output = tf.signal.frame(input, frameLength, frameStep); - expect(output.shape).toEqual([2, 3]); - expectArraysClose(await output.data(), [1, 2, 3, 3, 4, 5]); - }); - - it('3 length frames with step 5', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 3; - const frameStep = 5; - const output = tf.signal.frame(input, frameLength, frameStep); - expect(output.shape).toEqual([1, 3]); - expectArraysClose(await output.data(), [1, 2, 3]); - }); - - it('Exceeding frame length', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 6; - const frameStep = 1; - const output = tf.signal.frame(input, frameLength, frameStep); - expect(output.shape).toEqual([0, 6]); - expectArraysClose(await output.data(), []); - }); - - it('Zero frame step', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 6; - const frameStep = 0; - const output = tf.signal.frame(input, frameLength, frameStep); - expect(output.shape).toEqual([0, 6]); - expectArraysClose(await output.data(), []); - }); - - it('Padding with default value', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 3; - const frameStep = 3; - const padEnd = true; - const output = tf.signal.frame(input, frameLength, frameStep, padEnd); - expect(output.shape).toEqual([2, 3]); - expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 0]); - }); - - it('Padding with the given value', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 3; - const frameStep = 3; - const padEnd = true; - const padValue = 100; - const output = - tf.signal.frame(input, frameLength, frameStep, padEnd, padValue); - expect(output.shape).toEqual([2, 3]); - expectArraysClose(await output.data(), [1, 2, 3, 4, 5, 100]); - }); - - it('Padding all remaining frames with step=1', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 4; - const frameStep = 1; - const padEnd = true; - const output = tf.signal.frame(input, frameLength, frameStep, padEnd); - expect(output.shape).toEqual([5, 4]); - expectArraysClose( - await output.data(), - [1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 0, 4, 5, 0, 0, 5, 0, 0, 0]); - }); - - it('Padding all remaining frames with step=1 and given pad-value', - async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const frameLength = 4; - const frameStep = 1; - const padEnd = true; - const padValue = 42; - const output = - tf.signal.frame(input, frameLength, frameStep, padEnd, padValue); - expect(output.shape).toEqual([5, 4]); - expectArraysClose( - await output.data(), - [1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 42, 4, 5, 42, 42, 5, 42, 42, 42]); - }); - - it('Padding all remaining frames with step=2', async () => { - const input = tf.tensor1d([1, 2, 3, 4, 5]); - const output = tf.signal.frame(input, 4, 2, true); - expect(output.shape).toEqual([3, 4]); - expectArraysClose( - await output.data(), [1, 2, 3, 4, 3, 4, 5, 0, 5, 0, 0, 0]); - }); -}); - describeWithFlags('stft', ALL_ENVS, () => { it('3 length with hann window', async () => { const input = tf.tensor1d([1, 1, 1, 1, 1]); diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 7bc70093dff..4891da0cb72 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -87,12 +87,15 @@ import './ops/equal_test'; import './ops/expand_dims_test'; import './ops/eye_test'; import './ops/floor_test'; +import './ops/frame_test'; import './ops/fused_test'; import './ops/gather_nd_test'; import './ops/gather_test'; import './ops/gram_schmidt_test'; import './ops/greater_equal_test'; import './ops/greater_test'; +import './ops/hamming_window_test'; +import './ops/hann_window_test'; import './ops/hinge_loss_test'; import './ops/huber_loss_test'; import './ops/in_top_k_test'; @@ -146,7 +149,6 @@ import './ops/scatter_nd_test'; import './ops/selu_test'; import './ops/sigmoid_cross_entropy_test'; import './ops/sign_test'; -import './ops/signal_ops_test'; import './ops/slice_test'; import './ops/slice_util_test'; import './ops/softmax_cross_entropy_test'; @@ -155,6 +157,7 @@ import './ops/space_to_batch_nd_test'; import './ops/sparse_to_dense_test'; import './ops/spectral_ops_test'; import './ops/stack_test'; +import './ops/stft_test'; import './ops/strided_slice_test'; import './ops/sub_test'; import './ops/sum_test'; From b491d25afd06b329f0808dc16df5c042028c1251 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 8 Jul 2020 14:44:05 -0400 Subject: [PATCH 2/3] save --- tfjs-core/src/ops/frame.ts | 5 +++-- tfjs-core/src/ops/stft.ts | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/frame.ts b/tfjs-core/src/ops/frame.ts index 00fa8f73e13..7070c720c35 100644 --- a/tfjs-core/src/ops/frame.ts +++ b/tfjs-core/src/ops/frame.ts @@ -20,6 +20,7 @@ import {Tensor, Tensor1D} from '../tensor'; import {concat} from './concat'; import {fill} from './fill'; +import {reshape} from './reshape'; import {slice} from './slice'; import {tensor2d} from './tensor_ops'; @@ -65,6 +66,6 @@ function frame_( return tensor2d([], [0, frameLength]); } - return concat(output).as2D(output.length, frameLength); + return reshape(concat(output), [output.length, frameLength]); } -export const frame = op({frame_}); \ No newline at end of file +export const frame = op({frame_}); diff --git a/tfjs-core/src/ops/stft.ts b/tfjs-core/src/ops/stft.ts index b2c4e23664f..d28c81b0407 100644 --- a/tfjs-core/src/ops/stft.ts +++ b/tfjs-core/src/ops/stft.ts @@ -23,6 +23,7 @@ import {frame} from './frame'; import {hannWindow} from './hann_window'; import {mul} from './mul'; import {enclosingPowerOfTwo} from './signal_ops_util'; +import {slice} from './slice'; import {rfft} from './spectral_ops'; /** @@ -54,7 +55,7 @@ function stft_( const output: Tensor[] = []; for (let i = 0; i < framedSignal.shape[0]; i++) { output.push( - rfft(windowedSignal.slice([i, 0], [1, frameLength]), fftLength)); + rfft(slice(windowedSignal, [i, 0], [1, frameLength]), fftLength)); } return concat(output); } From 10e35a8c90e53df316ed6891332150211fee71d0 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Thu, 9 Jul 2020 11:16:11 -0400 Subject: [PATCH 3/3] build --- tfjs-core/src/tests.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index c98697c6fe8..728cebed70f 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -158,7 +158,6 @@ import './ops/scatter_nd_test'; import './ops/selu_test'; import './ops/sigmoid_cross_entropy_test'; import './ops/sign_test'; -import './ops/signal_ops_test'; import './ops/sin_test'; import './ops/sinh_test'; import './ops/slice1d_test';