Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
added support for spectral ops and updated the supported ops doc (#253)
Browse files Browse the repository at this point in the history
* added support spectral ops and updated the supported ops doc

* add comment to explain where tjfs-core.json is from
  • Loading branch information
pyu10055 committed Nov 27, 2018
1 parent e054243 commit 5e34e92
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 215 deletions.
146 changes: 0 additions & 146 deletions docs/doc_gen.ts

This file was deleted.

4 changes: 3 additions & 1 deletion docs/supported_ops.md
Expand Up @@ -181,6 +181,7 @@
|NotEqual|notEqual|
|Select|where|
|Not mapped|logicalXor|
|Not mapped|whereAsync|

## Operations - Matrices

Expand Down Expand Up @@ -212,6 +213,7 @@
|CropAndResize|cropAndResize|
|ResizeBilinear|resizeBilinear|
|ResizeNearestNeighbor|resizeNearestNeighbor|
|Not mapped|nonMaxSuppression|

## Operations - Reduction

Expand Down Expand Up @@ -270,5 +272,5 @@
|Reshape|reshape|
|SpaceToBatchND|spaceToBatchND|
|Squeeze|squeeze|
|ListDiff|setdiff1dAsync|
|Not mapped|setdiff1dAsync|

4 changes: 2 additions & 2 deletions package.json
Expand Up @@ -14,10 +14,10 @@
},
"license": "Apache-2.0",
"peerDependencies": {
"@tensorflow/tfjs-core": "0.13.11"
"@tensorflow/tfjs-core": "0.13.12"
},
"devDependencies": {
"@tensorflow/tfjs-core": "0.13.11",
"@tensorflow/tfjs-core": "0.13.12",
"@types/jasmine": "~2.8.6",
"@types/node-fetch": "1.6.9",
"ajv": "~6.3.0",
Expand Down
71 changes: 36 additions & 35 deletions scripts/gen_doc.ts
Expand Up @@ -14,10 +14,7 @@
* limitations under the License.
* =============================================================================
*/

import * as tfc from '@tensorflow/tfjs-core';
import * as fs from 'fs';
import fetch from 'node-fetch';

import * as arithmetic from '../src/operations/op_list/arithmetic';
import * as basicMath from '../src/operations/op_list/basic_math';
Expand All @@ -36,22 +33,19 @@ import * as sliceJoin from '../src/operations/op_list/slice_join';
import * as transformation from '../src/operations/op_list/transformation';
import {OpMapper} from '../src/operations/types';

const DOC_DIR = './docs/';
// tfjs-core api file is generated by tfjs-website project and should be
// manually copied over before running this script.
const JSON_DIR = './tfjs-core.json';
const DOC_DIR = '../docs/';

const opMappers = [
...arithmetic.json, ...basicMath.json, ...control.json, ...convolution.json,
...creation.json, ...dynamic.json, ...evaluation.json, ...logical.json,
...image.json, ...graph.json, ...matrices.json, ...normalization.json,
...reduction.json, ...sliceJoin.json, ...transformation.json
const ops = [
arithmetic, basicMath, control, convolution, creation, dynamic, evaluation,
logical, image, graph, matrices, normalization, reduction, sliceJoin,
transformation
];
const GITHUB_URL_PREFIX =
'https://raw.githubusercontent.com/tensorflow/tfjs-website';
const CORE_API_PREFIX =
`/master/source/_data/api/${tfc.version_core}/tfjs-core.json`;

async function genDoc() {
const response = await fetch(GITHUB_URL_PREFIX + CORE_API_PREFIX);
const json = await response.json();
const json = JSON.parse(fs.readFileSync(JSON_DIR, 'utf8'));
// tslint:disable-next-line:no-any
const coreApis = json.docs.headings.reduce((list: Array<{}>, h: any) => {
return h.subheadings ? list.concat(h.subheadings.reduce(
Expand All @@ -67,51 +61,58 @@ async function genDoc() {
output.push('# Supported Tensorflow Ops\n\n');

generateTable(
'Operations', 'Arithmetic', arithmetic.json as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Basic math', basicMath.json as OpMapper[], output,
'Operations', 'Arithmetic', (arithmetic.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Control Flow', control.json as OpMapper[], output,
'Operations', 'Basic math', (basicMath.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Convolution', convolution.json as OpMapper[], output,
'Operations', 'Control Flow', (control.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Tensors', 'Creation', creation.json as OpMapper[], output, coreApis);
'Operations', 'Convolution', (convolution.json as {}) as OpMapper[],
output, coreApis);
generateTable(
'Operations', 'Dynamic', dynamic.json as OpMapper[], output, coreApis);
'Tensors', 'Creation', (creation.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Evaluation', evaluation.json as OpMapper[], output,
'Operations', 'Dynamic', (dynamic.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Tensorflow', 'Graph', graph.json as OpMapper[], output, coreApis);
'Operations', 'Evaluation', (evaluation.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Logical', logical.json as OpMapper[], output, coreApis);
'Tensorflow', 'Graph', (graph.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Matrices', matrices.json as OpMapper[], output, coreApis);
'Operations', 'Logical', (logical.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Normalization', normalization.json as OpMapper[], output,
'Operations', 'Matrices', (matrices.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Operations', 'Images', image.json as OpMapper[], output, coreApis);
'Operations', 'Normalization', (normalization.json as {}) as OpMapper[],
output, coreApis);
generateTable(
'Operations', 'Reduction', reduction.json as OpMapper[], output,
'Operations', 'Images', (image.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Tensors', 'Slicing and Joining', sliceJoin.json as OpMapper[], output,
'Operations', 'Reduction', (reduction.json as {}) as OpMapper[], output,
coreApis);
generateTable(
'Tensors', 'Transformations', transformation.json as OpMapper[], output,
coreApis);
'Tensors', 'Slicing and Joining', (sliceJoin.json as {}) as OpMapper[],
output, coreApis);
generateTable('Operations', 'Spectral', [] as OpMapper[], output, coreApis);
generateTable(
'Tensors', 'Transformations', (transformation.json as {}) as OpMapper[],
output, coreApis);

console.log(process.cwd());
fs.writeFileSync(DOC_DIR + 'supported_ops.md', output.join(''));

console.log(
`Supported Ops written to ${DOC_DIR + 'supported_ops.md'}\n` +
`Found ${opMappers.length} ops\n`);
`Found ${ops.reduce((sum, cat) => sum += cat.json.length, 0)} ops\n`);
}

function findCoreOps(heading: string, subHeading: string, coreApis: Array<{}>) {
Expand Down Expand Up @@ -141,7 +142,7 @@ function generateTable(
output.push(`|Not mapped|${element.symbolName}|\n`);
}
});
output.push('\n\n');
output.push('\n');
}

genDoc();
52 changes: 52 additions & 0 deletions src/operations/executors/spectral_executor.ts
@@ -0,0 +1,52 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tfc from '@tensorflow/tfjs-core';

import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {Node} from '../types';

import {OpExecutor} from './types';
import {getParamValue} from './utils';

export let executeOp: OpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext): tfc.Tensor[] => {
switch (node.op) {
case 'fft': {
return [tfc.fft(
getParamValue('x', node, tensorMap, context) as tfc.Tensor)];
}
case 'ifft': {
return [tfc.ifft(
getParamValue('x', node, tensorMap, context) as tfc.Tensor)];
}
case 'rfft': {
return [tfc.rfft(
getParamValue('x', node, tensorMap, context) as tfc.Tensor)];
}
case 'irfft': {
return [tfc.irfft(
getParamValue('x', node, tensorMap, context) as tfc.Tensor)];
}
default:
throw TypeError(`Node type ${node.op} is not implemented`);
}
};

export const CATEGORY = 'spectral';

0 comments on commit 5e34e92

Please sign in to comment.