Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
Browse files

Add fft ops. (#1191)

FEATURE

Add fft ops. Since TypeScript/JavaScript does not support complex number natively, it regards two length array as a complex number so that it can support [n, 2] shape tensor as input.

```
const real = tf.tensor1d([1, 2, 3]);
const imag = tf.tensor1d([1, 2, 3]);
const input = tf.complex(real, imag);
tf.fft(input); // => [[6, 6], [-2.3660252, -0.63397473], [-0.6339747, -2.3660254]]
```

see: tensorflow/tfjs#214
  • Loading branch information...
Lewuathe authored and nsthorat committed Oct 1, 2018
1 parent 9e1d279 commit ebcb59853e23eab4dbe031b9e517752384c05dae
@@ -275,6 +275,7 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
iouThreshold: number, scoreThreshold?: number): Tensor1D;

fft(x: Tensor1D): Tensor1D;
complex<T extends Tensor>(real: T, imag: T): T;
real<T extends Tensor>(input: T): T;
imag<T extends Tensor>(input: T): T;
@@ -2468,6 +2468,70 @@ export class MathBackendCPU implements KernelBackend {
boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
}

fft(input: Tensor1D): Tensor1D {
util.assert(input.shape.length > 0, 'input must have at least one rank.');
const n = input.shape[0];

if (this.is_exponent_of_2(n)) {
return this.fftRadix2(input, n);
} else {
const data = input.dataSync();
const rawOutput = this.fourierTransformByMatmul(data, n) as Float32Array;
const output = complex_util.splitRealAndImagArrays(rawOutput);
return ops.complex(output.real, output.imag).as1D();
}
}

private is_exponent_of_2(size: number): boolean {
return (size & size - 1) === 0;
}

// FFT using Cooley-Tukey algorithm on radix 2 dimensional input.
private fftRadix2(input: Tensor1D, size: number): Tensor1D {
if (size === 1) {
return input;
}
const data = input.dataSync() as Float32Array;
const half = size / 2;
const evenComplex = complex_util.complexWithEvenIndex(data);
let evenTensor = ops.complex(evenComplex.real, evenComplex.imag).as1D();
const oddComplex = complex_util.complexWithOddIndex(data);
let oddTensor = ops.complex(oddComplex.real, oddComplex.imag).as1D();

// Recursive call for half part of original input.
evenTensor = this.fftRadix2(evenTensor, half);
oddTensor = this.fftRadix2(oddTensor, half);

const e = complex_util.exponents(size);
const exponent = ops.complex(e.real, e.imag).mul(oddTensor);

const addPart = evenTensor.add(exponent);
const subPart = evenTensor.sub(exponent);

const realTensor = ops.real(addPart).concat(ops.real(subPart));
const imagTensor = ops.imag(addPart).concat(ops.imag(subPart));

return ops.complex(realTensor, imagTensor).as1D();
}

// Calculate fourier transform by multplying sinusoid matrix.
private fourierTransformByMatmul(data: TypedArray, size: number): TypedArray {
const ret = new Float32Array(size * 2);
// TODO: Use matmul instead once it supports complex64 type.
for (let r = 0; r < size; r++) {
let real = 0.0;
let imag = 0.0;
for (let c = 0; c < size; c++) {
const e = complex_util.exponent(r * c, size);
const term = complex_util.getComplexWithIndex(data as Float32Array, c);
real += term.real * e.real - term.imag * e.imag;
imag += term.real * e.imag + term.imag * e.real;
}
complex_util.assignToTypedArray(ret, real, imag, r);
}
return ret;
}

depthToSpace(x: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):
Tensor4D {
util.assert(
@@ -55,6 +55,8 @@ import {CropAndResizeProgram} from './webgl/crop_and_resize_gpu';
import {CumSumProgram} from './webgl/cumsum_gpu';
import {DepthToSpaceProgram} from './webgl/depth_to_space_gpu';
import {EncodeFloatProgram} from './webgl/encode_float_gpu';
import {FFTProgram} from './webgl/fft_gpu';
import * as fft_gpu from './webgl/fft_gpu';
import {FromPixelsProgram} from './webgl/from_pixels_gpu';
import {GatherProgram} from './webgl/gather_gpu';
import {GPGPUContext} from './webgl/gpgpu_context';
@@ -1494,6 +1496,24 @@ export class MathBackendWebGL implements KernelBackend {
return split(x, sizeSplits, axis);
}

fft(x: Tensor1D): Tensor1D {
const xData = this.texData.get(x.dataId);

const realProgram = new FFTProgram(fft_gpu.COMPLEX_FFT.REAL, x.shape);
const imagProgram = new FFTProgram(fft_gpu.COMPLEX_FFT.IMAG, x.shape);
const inputs = [
this.makeComplexComponentTensorHandle(x, xData.complexTensors.real),
this.makeComplexComponentTensorHandle(x, xData.complexTensors.imag),
];

const real = this.compileAndRun<Tensor>(realProgram, inputs);
const imag = this.compileAndRun<Tensor>(imagProgram, inputs);
const complex = this.complex(real, imag).as1D();
real.dispose();
imag.dispose();
return complex;
}

private makeOutputArray<T extends Tensor>(shape: number[], dtype: DataType):
T {
return Tensor.make(shape, {}, dtype) as T;
@@ -15,6 +15,7 @@
* =============================================================================
*/

import {TypedArray} from '../types';
/**
* Merges real and imaginary Float32Arrays into a single complex Float32Array.
*
@@ -68,3 +69,85 @@ export function splitRealAndImagArrays(complex: Float32Array):
}
return {real, imag};
}

/**
* Extracts even indexed complex values in the given array.
* @param complex The complex tensor values
*/
export function complexWithEvenIndex(complex: Float32Array):
{real: Float32Array, imag: Float32Array} {
const len = Math.ceil(complex.length / 4);
const real = new Float32Array(len);
const imag = new Float32Array(len);
for (let i = 0; i < complex.length; i += 4) {
real[Math.floor(i / 4)] = complex[i];
imag[Math.floor(i / 4)] = complex[i + 1];
}
return {real, imag};
}

/**
* Extracts odd indexed comple values in the given array.
* @param complex The complex tensor values
*/
export function complexWithOddIndex(complex: Float32Array):
{real: Float32Array, imag: Float32Array} {
const len = Math.floor(complex.length / 4);
const real = new Float32Array(len);
const imag = new Float32Array(len);
for (let i = 2; i < complex.length; i += 4) {
real[Math.floor(i / 4)] = complex[i];
imag[Math.floor(i / 4)] = complex[i + 1];
}
return {real, imag};
}

/**
* Get the map representing a complex value in the given array.
* @param complex The complex tensor values.
* @param index An index of the target complex value.
*/
export function getComplexWithIndex(complex: Float32Array, index: number):
{real: number, imag: number} {
const real = complex[index*2];
const imag = complex[index*2+1];
return {real, imag};
}

/**
* Insert a given complex value into the TypedArray.
* @param data The array in which the complex value is inserted.
* @param c The complex value to be inserted.
* @param index An index of the target complex value.
*/
export function assignToTypedArray(data: TypedArray,
real: number, imag: number, index: number) {
data[index*2] = real;
data[index*2+1] = imag;
}

/**
* Make the list of exponent terms used by FFT.
*/
export function exponents(n: number):
{real: Float32Array, imag: Float32Array} {
const real = new Float32Array(n / 2);
const imag = new Float32Array(n / 2);
for (let i = 0; i < Math.ceil(n/2); i++) {
const x = -2 * Math.PI * (i / n);
real[i] = Math.cos(x);
imag[i] = Math.sin(x);
}
return {real, imag};
}

/**
* Make the exponent term used by FFT.
*/
export function exponent(k: number, n: number):
{real: number, imag: number} {
const x = -2 * Math.PI * (k / n);
const real = Math.cos(x);
const imag = Math.sin(x);
return {real, imag};
}
@@ -15,6 +15,8 @@
* =============================================================================
*/
import * as complex_util from './complex_util';
import {expectArraysClose, ALL_ENVS} from '../test_util';
import {describeWithFlags} from '../jasmine_util';

describe('complex_util', () => {
it('mergeRealAndImagArrays', () => {
@@ -30,4 +32,37 @@ describe('complex_util', () => {
expect(result.real).toEqual(new Float32Array([1, 2, 3]));
expect(result.imag).toEqual(new Float32Array([4, 5, 6]));
});

it('complexWithEvenIndex', () => {
const complex = new Float32Array([1, 2, 3, 4, 5, 6]);
const result = complex_util.complexWithEvenIndex(complex);
expect(result.real).toEqual(new Float32Array([1, 5]));
expect(result.imag).toEqual(new Float32Array([2, 6]));
});

it('complexWithOddIndex', () => {
const complex = new Float32Array([1, 2, 3, 4, 5, 6]);
const result = complex_util.complexWithOddIndex(complex);
expect(result.real).toEqual(new Float32Array([3]));
expect(result.imag).toEqual(new Float32Array([4]));
});
});

describeWithFlags('complex_util exponents', ALL_ENVS, () => {
it('exponents', () => {
const result = complex_util.exponents(5);
expectArraysClose(result.real, new Float32Array([1, 0.30901700258255005]));
expectArraysClose(result.imag, new Float32Array([0, -0.9510565400123596]));
});
});

describeWithFlags('complex_util assignment', ALL_ENVS, () => {
it('assign complex value in TypedArray', () => {
const t = new Float32Array(4);

complex_util.assignToTypedArray(t, 1, 2, 0);
complex_util.assignToTypedArray(t, 3, 4, 1);

expectArraysClose(t, new Float32Array([1, 2, 3, 4]));
});
});
@@ -0,0 +1,63 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {GPGPUProgram} from './gpgpu_math';

export const COMPLEX_FFT = {
REAL: 'return real * expR - imag * expI;',
IMAG: 'return real * expI + imag * expR;'
};

export class FFTProgram implements GPGPUProgram {
variableNames = ['real', 'imag'];
outputShape: number[];
userCode: string;

constructor(op: string, inputShape: number[]) {
const size = inputShape[0];
this.outputShape = [size];

this.userCode = `
float unaryOpComplex(float real, float expR, float imag, float expI) {
${op}
}
float mulMatDFT(int row) {
// TODO: Gather constants in one place?
const float PI = 3.1415926535897932384626433832795;
float result = 0.0;
for (int i = 0; i < ${size}; i++) {
float x = -2.0 * PI * float(row * i) / float(${size});
float expR = cos(x);
float expI = sin(x);
float real = getReal(i);
float imag = getImag(i);
result += unaryOpComplex(real, expR, imag, expI);
}
return result;
}
void main() {
int row = getOutputCoords();
setOutput(mulMatDFT(row));
}
`;
}
}
@@ -40,6 +40,7 @@ export * from './lstm';
export * from './moving_average';
export * from './strided_slice';
export * from './topk';
export * from './spectral_ops';

export {op} from './operation';

@@ -0,0 +1,43 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {ENV} from '../environment';
import {op} from '../ops/operation';
import {Tensor1D} from '../tensor';
import {assert} from '../util';

/**
* Compute the 1-dimentional discrete fourier transform
* The input is expected to be the 1D tensor with dtype complex64.
*
* ```js
* const real = tf.tensor1d([1, 2, 3]);
* const imag = tf.tensor1d([1, 2, 3]);
* const x = tf.complex(real, imag);
*
* x.fft().print();
* ```
* @param input The complex input to compute an fft over.
*/
function fft_(input: Tensor1D): Tensor1D {
assert(input.dtype === 'complex64', 'dtype must be complex64');
assert(input.rank === 1, 'input rank must be 1');
const ret = ENV.engine.runKernel(backend => backend.fft(input), {input});
return ret;
}

export const fft = op({fft_});

0 comments on commit ebcb598

Please sign in to comment.
You can’t perform that action at this time.