Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ee95869
initial commit of module writing tool
tafsiri Jul 29, 2020
eacc171
remove chaining api from browser.ts
tafsiri Jul 29, 2020
31ee164
split base from index for core, cpu and webgl
tafsiri Jul 30, 2020
886ef69
fix import
tafsiri Jul 30, 2020
5d1ac5f
build tools separately.
tafsiri Jul 30, 2020
0637292
set side effects property for core, cpu, webgl
tafsiri Jul 30, 2020
55addb2
add tests for tools
tafsiri Jul 30, 2020
672c670
add more validation
tafsiri Jul 30, 2020
ffdbe6c
split out required side effects for core
tafsiri Jul 30, 2020
d071d67
Merge branch 'master' into custom-bundle-prototype
tafsiri Jul 31, 2020
a2869d2
restore gradient registration
tafsiri Jul 31, 2020
98ee2aa
Merge branch 'master' into custom-bundle-prototype
tafsiri Aug 3, 2020
c8fff26
add kernel_to_ops mapper script
tafsiri Aug 4, 2020
f8bb301
allow multiple tfjs op calls per executor block
tafsiri Aug 4, 2020
ec2aa31
refactor executors to not branch on op name once
tafsiri Aug 4, 2020
5ddd280
add tests for kernel to op mapping
tafsiri Aug 4, 2020
881ad0e
Merge branch 'master' into model-json-kernels
tafsiri Aug 4, 2020
6f13136
docs and lint
tafsiri Aug 4, 2020
41c5ad8
save
tafsiri Aug 4, 2020
614ff7f
Merge branch 'master' into model-json-kernels
tafsiri Aug 6, 2020
852f593
Merge branch 'master' into model-json-kernels
tafsiri Aug 6, 2020
3db116d
Merge branch 'master' into model-json-kernels
pyu10055 Aug 7, 2020
5a6e67c
Merge branch 'master' into model-json-kernels
tafsiri Aug 7, 2020
501fdbf
add argparse as dev dependency
tafsiri Aug 7, 2020
d639372
add argparse types
tafsiri Aug 7, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
524 changes: 524 additions & 0 deletions tfjs-converter/metadata/kernel2op.json

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions tfjs-converter/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
"@rollup/plugin-typescript": "^3.0.0",
"@tensorflow/tfjs-backend-cpu": "link:../tfjs-backend-cpu",
"@tensorflow/tfjs-core": "link:../tfjs-core",
"@types/argparse": "^1.0.38",
"@types/deep-equal": "^1.0.1",
"@types/jasmine": "~2.8.6",
"@types/long": "~3.0.32",
"@types/node-fetch": "1.6.9",
"ajv": "~6.3.0",
"argparse": "^1.0.10",
"babel-core": "~6.26.3",
"babel-plugin-external-helpers": "~6.22.0",
"babel-preset-env": "~1.7.0",
Expand All @@ -48,15 +50,16 @@
"rollup": "~2.3.2",
"rollup-plugin-terser": "~5.3.0",
"rollup-plugin-visualizer": "~3.3.2",
"ts-morph": "^7.1.3",
"ts-node": "~8.8.2",
"tslint": "~5.8.0",
"tslint-no-circular-imports": "~0.5.0",
"typescript": "3.5.3",
"yalc": "~1.0.0-pre.21"
},
"scripts": {
"build": "yarn gen-json --test && tsc && yarn bundle",
"build-ci": "yarn gen-json --test && tsc && yarn bundle-ci",
"build": "yarn gen-json --test && yarn gen-kernel2ops && tsc && yarn bundle",
"build-ci": "yarn gen-json --test && yarn gen-kernel2ops && tsc && yarn bundle-ci",
"bundle": "rollup -c",
"bundle-ci": "rollup -c --ci",
"build-core": "cd ../tfjs-core && yarn && yarn build",
Expand All @@ -69,7 +72,7 @@
"link-local": "yalc link",
"publish-local": "yarn build-npm && yalc push",
"publish-npm": "npm publish",
"test": "yarn && yarn build-deps && yarn gen-json --test && ts-node -P tsconfig.test.json run_tests.ts",
"test": "yarn && yarn build-deps && yarn gen-json --test && yarn gen-kernel2ops && ts-node -P tsconfig.test.json run_tests.ts",
"test-ci": "ts-node --skip-ignore -P tsconfig.test.json run_tests.ts",
"test-snippets": "ts-node --skip-ignore -s ./scripts/test_snippets.ts",
"lint": "tslint -p . -t verbose",
Expand All @@ -79,6 +82,7 @@
"model-summary": "ts-node -s ./tools/model_summary.ts",
"pb2json": "ts-node -s ./tools/pb2json_converter.ts",
"build-pip-package": "yarn gen-json --test && cd python && ./build-pip-package.sh --test /tmp/tfjs-pips",
"run-python-tests": "yarn gen-json --test && cd python && ./run-python-tests.sh"
"run-python-tests": "yarn gen-json --test && cd python && ./run-python-tests.sh",
"gen-kernel2ops": "ts-node -s scripts/kernels_to_ops.ts --out metadata/kernel2op.json"
}
}
160 changes: 160 additions & 0 deletions tfjs-converter/scripts/kernels_to_ops.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* This script generates a mapping of Kernel Names to op names as defined by
* the converter source code. This allows a couple of things for modular builds
* 1. From a model.json file we can create imports for the ops the converter
* will call.
* 2. From those ops we could validate that the kernels we add to the modular
* build match the names of kernels in model.json (this is not necessary
* but is potentially useful for alignment).
*
* This can also be used to keep our supported ops list up to date.
*
* The approach used is to parse the source code of the converter executors
* (src/operations/executors) for the following kind pattern.
* case 'BiasAdd':
* case 'AddV2':
* case 'Add': {
* return [tfc.add(
* (getParamValue('a', node, tensorMap, context) as tfc.Tensor),
* getParamValue('b', node, tensorMap, context) as tfc.Tensor)];
* }
*
* Case matchers represent kernel names and tfc.*(...) represent the tfjs op(s)
* that are called. This example shows that we need to support fallthrough case
* statements as well.
*
*/

import * as argparse from 'argparse';
import * as fs from 'fs';
import {CaseClause, CaseOrDefaultClause, Project, SourceFile, SwitchStatement, SyntaxKind} from 'ts-morph';

const parser = new argparse.ArgumentParser();

parser.addArgument(
'--out', {help: 'Path to output JSON to create', required: true});

const project = new Project({});

function getSwitchStatement(source: SourceFile): SwitchStatement {
// Each executor only has one switch statment.
let switchStatement: SwitchStatement;
source.forEachDescendant((node) => {
if (node.getKindName() === 'SwitchStatement') {
switchStatement = node as SwitchStatement;
}
});
return switchStatement;
}

type KernelMapping = {
[key: string]: string[]
};

function getKernelMappingForFile(source: SourceFile) {
const switchStatement = getSwitchStatement(source);
if (switchStatement === null) {
throw new Error('No switch statment found in executor');
}
const caseClauses = switchStatement.getClauses();

const kernelsToOp: KernelMapping = {};
let currentClauseGroup: string[] = [];

// Loop through clauses until you reach one that has a block or return.
// This allows us to coalesce fallthrough case blocks in a switch statement.
caseClauses.forEach((caseClause: CaseOrDefaultClause) => {
if (caseClause instanceof CaseClause) {
let kernelName;
caseClause.forEachChild(clausePart => {
const kind = clausePart.getKindName();
if (kind === 'StringLiteral') {
kernelName = clausePart.getText().replace(/\'/g, '');
currentClauseGroup.push(kernelName);
}
if (kind === 'Block' || kind === 'ReturnStatement') {
// We have reached a code block, all the previously captured
// kernels use this block as their execution path.

// Parse the code block and determing all the tfc.*() function calls
// used.
const callExprs =
clausePart.getDescendantsOfKind(SyntaxKind.CallExpression);
const tfcCallExprs =
callExprs.filter(expr => expr.getText().match(/tfc/));
const tfSymbols: Set<string> = new Set();
for (const tfcCall of tfcCallExprs) {
const tfcCallStr = tfcCall.getText();
const functionCallMatcher = /(tfc\.([\w\.]*)\()/g;
const matches = tfcCallStr.match(functionCallMatcher);
if (matches != null && matches.length > 0) {
for (const match of matches) {
// extract the method name (and any namespaces used to call it)
const symbolMatcher = /(tfc\.([\w\.]*)\()/;
const symbol = match.match(symbolMatcher)[2];
tfSymbols.add(symbol);
}
}
}
for (const kern of currentClauseGroup) {
kernelsToOp[kern] = Array.from(tfSymbols);
}
// Reset the clause tracker as we are moving to a new set of kernels
currentClauseGroup = [];
}
});
}
});

return kernelsToOp;
}

function getKernelMapping() {
const sourceFiles = project.getSourceFiles();
const kernelsToOp: KernelMapping = {};

for (const sourceFile of sourceFiles) {
const mapping = getKernelMappingForFile(sourceFile);
Object.assign(kernelsToOp, mapping);
}
return kernelsToOp;
}

async function run(outputFilePath: string) {
const EXECUTORS_PATH = 'src/operations/executors/*_executor.ts';
project.addSourceFilesAtPaths(EXECUTORS_PATH);

const kernelMapping = getKernelMapping();

const pairs: Array<[string, string[]]> = Object.entries(kernelMapping).sort();
const sortedKernelMapping: KernelMapping = {};
pairs.forEach(([k, v]) => {
sortedKernelMapping[k] = v;
});
const replacer: null = null;
const space = 2;
fs.writeFileSync(
outputFilePath, JSON.stringify(sortedKernelMapping, replacer, space),
{encoding: 'utf8'});
}

const args = parser.parseArgs();
console.log('Writing output to', args.out);
run(args.out);
90 changes: 90 additions & 0 deletions tfjs-converter/src/metadata_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

// This test checks the metadata.json to make sure that only the kernels that
// we know we don't map to tfjs ops have empty entries in metadata
// kernel2op.json
describe('kernel2op metadata file', () => {
it('has kernel2op.json', () => {
expect(() => {
// tslint:disable-next-line:no-require-imports
require('../metadata/kernel2op.json');
}).not.toThrow();
});

it('only known unmapped kernel are unmmapped', () => {
const knownUnmappedKernels = [
'Const',
'Enter',
'Exit',
'FakeQuantWithMinMaxVars',
'Identity',
'IdentityN',
'If',
'LoopCond',
'Merge',
'NextIteration',
'Placeholder',
'PlaceholderWithDefault',
'Print',
'Snapshot',
'StatelessIf',
'StatelessWhile',
'StopGradient',
'Switch',
'TensorArrayCloseV3',
'TensorArrayConcatV3',
'TensorArrayGatherV3',
'TensorArrayReadV3',
'TensorArrayScatterV3',
'TensorArraySizeV3',
'TensorArraySplitV3',
'TensorArrayV3',
'TensorArrayWriteV3',
'TensorListConcat',
'TensorListFromTensor',
'TensorListGather',
'TensorListGetItem',
'TensorListPopBack',
'TensorListPushBack',
'TensorListReserve',
'TensorListScatter',
'TensorListScatterV2',
'TensorListSetItem',
'TensorListSplit',
'TensorListStack',
'While',
];
// tslint:disable-next-line:no-require-imports
const kernel2op = require('../metadata/kernel2op.json');
const kernels: string[] = Object.keys(kernel2op);

for (const kernelName of kernels) {
const tfOps = kernel2op[kernelName];
if (knownUnmappedKernels.includes(kernelName)) {
expect(tfOps.length)
.toEqual(0, `Kernel "${kernelName}" is expected to be unmapped but
instead maps to ${tfOps}`);
} else {
expect(tfOps.length)
.toBeGreaterThan(
0, `Kernel ${kernelName} is expected to be mapped to a list
of tf ops`);
}
}
});
});
Loading