Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ tfjs_cc_library(
":Min",
":Minimum",
":Multiply",
":Neg",
":Negate",
":NonMaxSuppressionV3",
":NonMaxSuppressionV5",
":NotEqual",
Expand Down Expand Up @@ -635,8 +635,8 @@ tfjs_cc_library(
)

tfjs_cc_library(
name = "Neg",
srcs = ["kernels/Neg.cc"],
name = "Negate",
srcs = ["kernels/Negate.cc"],
deps = [
":backend",
":unary",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "src/cc/unary.h"

namespace {
inline float neg(const float val) { return -val; }
inline float negate(const float val) { return -val; }
} // namespace

namespace tfjs {
Expand All @@ -33,7 +33,7 @@ extern "C" {
#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void Neg(const int x_id, const int out_id) { unary(x_id, out_id, neg); }
void Negate(const int x_id, const int out_id) { unary(x_id, out_id, negate); }

} // extern "C"
} // namespace wasm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
*/

import {registerUnaryKernel} from './unary_kernel';
registerUnaryKernel('Neg');
registerUnaryKernel('Negate');
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import './MaxPool';
import './Min';
import './Minimum';
import './Muliply';
import './Neg';
import './Negate';
import './NonMaxSuppressionV3';
import './NonMaxSuppressionV5';
import './NotEqual';
Expand Down
1 change: 1 addition & 0 deletions tfjs-core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"rollup-plugin-terser": "~5.3.0",
"rollup-plugin-visualizer": "~3.3.2",
"shelljs": "~0.8.3",
"ts-morph": "^7.1.2",
"ts-node": "~8.8.2",
"tslint": "~5.11.0",
"tslint-no-circular-imports": "~0.5.0",
Expand Down
141 changes: 141 additions & 0 deletions tfjs-core/scripts/extract_op.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import * as argparse from 'argparse';
import {execSync} from 'child_process';
import * as path from 'path';
import {FunctionDeclaration, ImportDeclaration, Project, SourceFile, VariableStatement} from 'ts-morph';

const parser = new argparse.ArgumentParser();

parser.addArgument(
'--op_file',
{help: 'path to op file. e.g. src/ops/unary_ops.ts', required: true});
parser.addArgument('--ops', {
help: 'comma seprated list of ops to extract (e.g. tanh,tan).' +
' Skip this param to extract all ops from the file.',
defaultValue: [],
});

// initialize
const project = new Project({});

interface OpInfo {
variableStatement: VariableStatement;
opFuncName: string;
opIdentifier: string;
opFunc: FunctionDeclaration;
}

function getImports(sourceFile: SourceFile) {
return sourceFile.getImportDeclarations();
}

function getOpExports(file: SourceFile): OpInfo[] {
const variables = file.getVariableStatements();
const opFuncRegex = /op\(\{(.*_)\}\)/;
const exported = variables.filter(
v => v.isExported() && v.getFullText().match(opFuncRegex));
const opInfo = exported.map(variable => {
const declaration = variable.getDeclarations()[0];
const opFuncName = variable.getFullText().match(opFuncRegex)[1];
const opFunc = getOpFunc(file, opFuncName);
if (opFunc == null) {
console.warn(`Warning: could not find implementation function for ${
declaration.getName()}`);
return null;
}
return {
variableStatement: variable,
// string with exported name of op
opIdentifier: declaration.getName(),
// string with name of the function that actually implements the op
opFuncName,
// function that implements the op
opFunc,
};
});
return opInfo.filter(op => op != null);
}

function getOpFunc(sourceFile: SourceFile, opFuncName: string) {
return sourceFile.getFunction(opFuncName);
}

function toSnakeCase(str: string) {
// add exceptions here.
if (str === 'isNaN') {
return 'is_nan';
}
return str.replace(/[A-Z]/g, (s: string) => `_${s.toLowerCase()}`);
}

async function moveToNewFile(
opInfo: OpInfo, imports: ImportDeclaration[], sourceFile: SourceFile) {
//
// Move code to a new file
//
const targetFp = `src/ops/${toSnakeCase(opInfo.opIdentifier)}.ts`;
const newOpFile = project.createSourceFile(targetFp, (writer) => {
// By using getFullText here we will also get the copyright notice at the
// begining of the file
const importsStr = imports.map(i => i.getFullText()).join('');
const functionStr = opInfo.opFunc.getFullText();
const exportString = opInfo.variableStatement.getFullText();
const contents = [importsStr, functionStr, exportString].join('');
writer.write(contents);
}, {overwrite: true});
newOpFile.fixUnusedIdentifiers();
await newOpFile.save();

// make a test file
// create a test file
const testFilePath = `src/ops/${args.op}_test.ts`;
const command = `touch ${testFilePath}`;
execSync(command);

// Add export to ops file.
const opsExportFile = project.getSourceFile('src/ops/ops.ts');
opsExportFile.addExportDeclaration({
namedExports: [opInfo.opIdentifier],
moduleSpecifier: `./${path.basename(targetFp, '.ts')}`,
});

await opsExportFile.save();
}

async function run(filePath: string, ops: string[]) {
console.log('ops', ops);
project.addSourceFilesAtPaths(filePath);
// add the ops export file to the project
project.addSourceFilesAtPaths('src/ops/ops.ts');
const opFile = project.getSourceFile(filePath);
const imports = getImports(opFile);
const opExports = getOpExports(opFile);

opExports.forEach(async o => {
if (ops.length === 0 || ops.indexOf(o.opIdentifier) !== -1) {
await moveToNewFile(o, imports, opFile);
}
});

// Remove the op from the source file and save it
opExports.forEach(async o => {
if (ops.length === 0 || ops.indexOf(o.opIdentifier) !== -1) {
opFile.removeStatement(o.variableStatement.getChildIndex());
}
});
opFile.fixUnusedIdentifiers();
await opFile.save();
}

// add source files

const args = parser.parseArgs();
let opsToExtract = args.ops;
if (!Array.isArray(opsToExtract)) {
opsToExtract = opsToExtract.split(',');
}
console.log('Extracting from', args.op_file);
if (opsToExtract.length > 0) {
console.log('Only extract: ', opsToExtract);
}

run(args.op_file, opsToExtract);
91 changes: 68 additions & 23 deletions tfjs-core/scripts/touch_modular_op_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@

/**
* A helper script to generate empty files typically used in modularizing an op.
* Takes two params:
* params:
* --op the name of the op
* --kernel the name of the kernel (this is optional)
* --chained is this op part of the chained api (optional)
* --grad the name of the kernel(s) to create gradient files for
*
* It assumes you run it from tfjs_core.
*
* Example
* npx ts-node -s scripts/touch_modular_op_files.ts --op "op_name" --kernel \
* "KernelName" --chained
* npx ts-node -s scripts/touch_modular_op_files.ts --op op_name --grad \
* KernelName --chained
*
* Generates the following files (they will be empty)
* tfjs_core/src/ops/op_name.ts
Expand All @@ -36,15 +36,19 @@
*
* if --kernel is present
* tfjs_core/src/gradients/KernelName_grad.ts
*
* Example 2 (multiple kernels)
* npx ts-node -s scripts/touch_modular_op_files.ts --grad Kernel1,Kernel2
*/

import * as argparse from 'argparse';
import {execSync} from 'child_process';
import * as fs from 'fs';

const parser = new argparse.ArgumentParser();

parser.addArgument('--op', {help: 'the name of the op'});
parser.addArgument('--kernel', {help: 'the name of the kernel.'});
parser.addArgument('--grad', {help: 'the name of the grad.'});
parser.addArgument('--chained', {
action: 'storeTrue',
defaultValue: false,
Expand All @@ -55,29 +59,70 @@ async function main() {
const args = parser.parseArgs();
console.log('Called touch_modular_op_files with args:', args);

if (args.op == null) {
throw new Error('You must specify an op');
if (args.op != null) {
const ops: string[] = args.op.split(',');
ops.forEach(op => {
let filePath = `./src/ops/${op}.ts`;
let command = `touch ${filePath}`;
execSync(command);

// create a test file
filePath = `./src/ops/${op}_test.ts`;
command = `touch ${filePath}`;
execSync(command);

if (args.chained) {
filePath = `./src/public/chained_ops/${op}.ts`;
command = `touch ${filePath}`;
execSync(command);
}
});
}

let filePath = `./src/ops/${args.op}.ts`;
let command = `touch ${filePath}`;
execSync(command);
if (args.grad) {
const downcaseFirstChar = (str: string) => {
return str.charAt(0).toLowerCase() + str.slice(1);
};

// create a test file
filePath = `./src/ops/${args.op}_test.ts`;
command = `touch ${filePath}`;
execSync(command);
const kernels: string[] = args.grad.split(',');

if (args.chained) {
filePath = `./src/public/chained_ops/${args.op}.ts`;
command = `touch ${filePath}`;
execSync(command);
}
kernels.forEach(kernelName => {
const gradientFileTemplate = `/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

if (args.kernel) {
filePath = `./src/gradients/${args.kernel}_grad.ts`;
command = `touch ${filePath}`;
execSync(command);
import {KernelName, KernelNameAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';

export const ${downcaseFirstChar(kernelName)}GradConfig: GradConfig = {
kernelName: KernelName,
inputsToSave: [], // UPDATE ME
outputsToSave: [], // UPDATE ME
gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
const [] = saved;
const {} = attrs as {} as KernelNameAttrs;
return {
};
}
};
`;
const filePath = `./src/gradients/${kernelName}_grad.ts`;
fs.writeFileSync(filePath, gradientFileTemplate, {flag: 'a'});
});
}
}

Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Atan2_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import {add} from '../ops/add';
import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util';
import {div} from '../ops/div';
import {mul} from '../ops/mul';
import {neg} from '../ops/neg';
import {reshape} from '../ops/reshape';
import {square} from '../ops/square';
import {sum} from '../ops/sum';
import {neg} from '../ops/unary_ops';
import {Tensor} from '../tensor';

export const atan2GradConfig: GradConfig = {
Expand Down
29 changes: 29 additions & 0 deletions tfjs-core/src/gradients/Ceil_grad.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Ceil} from '../kernel_names';
import {GradConfig} from '../kernel_registry';
import {zerosLike} from '../ops/tensor_ops';
import {Tensor} from '../tensor';

export const ceilGradConfig: GradConfig = {
kernelName: Ceil,
gradFunc: (dy: Tensor) => {
// TODO(manrajgrover): Return null for gradients when backprop supports it.
return {x: () => zerosLike(dy)};
}
};
2 changes: 1 addition & 1 deletion tfjs-core/src/gradients/Div_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import {GradConfig} from '../kernel_registry';
import * as broadcast_util from '../ops/broadcast_util';
import {div} from '../ops/div';
import {mul} from '../ops/mul';
import {neg} from '../ops/neg';
import {reshape} from '../ops/reshape';
import {square} from '../ops/square';
import {sum} from '../ops/sum';
import {neg} from '../ops/unary_ops';
import {Tensor} from '../tensor';

export const divGradConfig: GradConfig = {
Expand Down
Loading