From d92e007cc9e2a7ecd2ec23e9962cbf3739c2e0b7 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 29 May 2020 15:18:11 -0700 Subject: [PATCH 1/7] add complex weight conversion support --- .../python/tensorflowjs/read_weights.py | 3 +- .../python/tensorflowjs/read_weights_test.py | 10 ++++- .../python/tensorflowjs/write_weights.py | 7 +++- .../python/tensorflowjs/write_weights_test.py | 42 ++++++++++++++----- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/read_weights.py b/tfjs-converter/python/tensorflowjs/read_weights.py index 20a13eb2df1..5f5d433faa9 100644 --- a/tfjs-converter/python/tensorflowjs/read_weights.py +++ b/tfjs-converter/python/tensorflowjs/read_weights.py @@ -24,7 +24,8 @@ import numpy as np from tensorflowjs import quantization -_INPUT_DTYPES = [np.float32, np.int32, np.uint8, np.uint16, np.object] +_INPUT_DTYPES = [np.float32, np.int32, np.complex64, + np.uint8, np.uint16, np.object] # Number of bytes used to encode the length of a string in a string tensor. STRING_LENGTH_NUM_BYTES = 4 diff --git a/tfjs-converter/python/tensorflowjs/read_weights_test.py b/tfjs-converter/python/tensorflowjs/read_weights_test.py index 0dd8d3ed93d..102ab5ab011 100644 --- a/tfjs-converter/python/tensorflowjs/read_weights_test.py +++ b/tfjs-converter/python/tensorflowjs/read_weights_test.py @@ -45,6 +45,10 @@ def testReadOneGroup(self): [{ 'name': 'weight1', 'data': np.array([1, 2, 3], 'float32') + }, + { + 'name': 'weight2', + 'data': np.array([1 + 1j, 2 + 2j, 3 + 3j]) }] ] @@ -53,11 +57,13 @@ def testReadOneGroup(self): # Read the weights using `read_weights`. read_output = read_weights.read_weights(manifest, self._tmp_dir) self.assertEqual(1, len(read_output)) - self.assertEqual(1, len(read_output[0])) + self.assertEqual(2, len(read_output[0])) self.assertEqual('weight1', read_output[0][0]['name']) self.assertTrue( np.allclose(groups[0][0]['data'], read_output[0][0]['data'])) - + self.assertEqual('weight2', read_output[0][1]['name']) + self.assertTrue( + np.allclose(groups[0][1]['data'], read_output[0][1]['data'])) def testReadOneGroupString(self): groups = [ [{ diff --git a/tfjs-converter/python/tensorflowjs/write_weights.py b/tfjs-converter/python/tensorflowjs/write_weights.py index 34dd1825410..15a450318bd 100644 --- a/tfjs-converter/python/tensorflowjs/write_weights.py +++ b/tfjs-converter/python/tensorflowjs/write_weights.py @@ -24,10 +24,12 @@ from tensorflowjs import quantization from tensorflowjs import read_weights -_OUTPUT_DTYPES = [np.float32, np.int32, np.uint8, np.uint16, np.bool, np.object] +_OUTPUT_DTYPES = [np.float32, np.int32, np.complex64, + np.uint8, np.uint16, np.bool, np.object] _AUTO_DTYPE_CONVERSION = { np.dtype(np.float64): np.float32, - np.dtype(np.int64): np.int32} + np.dtype(np.int64): np.int32, + np.dtype(np.complex128): np.complex64} def write_weights( weight_groups, write_dir, shard_size_bytes=1024 * 1024 * 4, @@ -366,6 +368,7 @@ def _assert_valid_weight_entry(entry): if not (data.dtype in _OUTPUT_DTYPES or data.dtype in _AUTO_DTYPE_CONVERSION): + print(_OUTPUT_DTYPES) raise ValueError('Error dumping weight ' + name + ', dtype ' + data.dtype.name + ' not supported.') diff --git a/tfjs-converter/python/tensorflowjs/write_weights_test.py b/tfjs-converter/python/tensorflowjs/write_weights_test.py index 4d0dfbc3d50..2ea37b5b159 100644 --- a/tfjs-converter/python/tensorflowjs/write_weights_test.py +++ b/tfjs-converter/python/tensorflowjs/write_weights_test.py @@ -278,6 
+278,37 @@ def test_1_group_1_weight_string_sharded(self): string = weight_bytes[4:14].decode('utf-8') self.assertEqual(string, u'helloworld') + def test_1_group_1_weight_complex(self): + groups = [ + [{ + 'name': 'weight1', + 'data': np.array([1 + 1j, 2 + 2j, 3 + 3j], 'complex') + }] + ] + + manifest = write_weights.write_weights( + groups, TMP_DIR, shard_size_bytes=6 * 4) + + self.assertTrue( + os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')), + 'weights_manifest.json does not exist') + + self.assertEqual( + manifest, + [{ + 'paths': ['group1-shard1of1.bin'], + 'weights': [{ + 'name': 'weight1', + 'shape': [3], + 'dtype': 'complex64' + }] + }]) + + weights_path = os.path.join(TMP_DIR, 'group1-shard1of1.bin') + weight1 = np.fromfile(weights_path, 'complex64') + np.testing.assert_array_equal( + weight1, np.array([1 + 1j, 2 + 2j, 3 + 3j], 'complex64')) + def test_1_group_3_weights_packed_multi_dtype(self): # Each string tensor uses different encoding. groups = [ @@ -654,17 +685,6 @@ def test_bad_weights_entry_throws_no_data(self): with self.assertRaises(Exception): write_weights.write_weights(groups, TMP_DIR) - def test_bad_numpy_array_dtype_throws(self): - groups = [ - [{ - 'name': 'weight1', - 'data': np.array([1, 2, 3], 'complex') - }] - ] - - with self.assertRaises(Exception): - write_weights.write_weights(groups, TMP_DIR) - def test_duplicate_weight_name_throws(self): groups = [ [{ From 4a484dd68ff1b2578e175e301b437e2ee461a614 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 29 May 2020 17:19:34 -0700 Subject: [PATCH 2/7] add complex64 support in weight loader --- tfjs-core/src/io/io_utils.ts | 20 +++++++++++-- tfjs-core/src/io/io_utils_test.ts | 49 +++++++++++++++++++++++++++++-- tfjs-core/src/io/types.ts | 3 +- 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 5a8093973ac..f704b40315f 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -15,6 +15,8 @@ * ============================================================================= */ +import * as tf from '../index'; + import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; @@ -57,7 +59,7 @@ export async function encodeWeights( const name = names[i]; const t = Array.isArray(tensors) ? 
tensors[i].tensor : tensors[name]; if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && - t.dtype !== 'string') { + t.dtype !== 'string' && t.dtype !== 'complex64') { throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); } const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype}; @@ -171,13 +173,25 @@ export function decodeWeights( values = new Int32Array(byteBuffer); } else if (dtype === 'bool') { values = new Uint8Array(byteBuffer); + } else if (dtype === 'complex64') { + values = new Float32Array(byteBuffer); + const real = new Float32Array(values.length / 2); + const image = new Float32Array(values.length / 2); + for (let i = 0; i < values.length / 2; i++) { + real[i] = values[i * 2]; + image[i] = values[i * 2 + 1]; + } + const realTensor = tensor(real, shape, 'float32'); + const imageTensor = tensor(image, shape, 'float32'); + out[name] = tf.complex(realTensor, imageTensor); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * dtypeFactor; } - - out[name] = tensor(values, shape, dtype); + if (dtype !== 'complex64') { + out[name] = tensor(values, shape, dtype); + } } return out; } diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts index 70ff481a831..1f2730725bf 100644 --- a/tfjs-core/src/io/io_utils_test.ts +++ b/tfjs-core/src/io/io_utils_test.ts @@ -330,6 +330,40 @@ describeWithFlags('encodeWeights', ALL_ENVS, () => { ]); }); + + it('Complex64 tensors', async () => { + const tensors: NamedTensorMap = { + x1: tf.complex([1, 2], [1, 2]), + x2: tf.complex(1, 2), + x3: tf.complex([[1]], [[2]]), + }; + const dataAndSpecs = await tf.io.encodeWeights(tensors); + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; + expect(data.byteLength).toEqual(8 * 4); + expect(new Float32Array(data, 0, 4)).toEqual(new Float32Array([ + 1, 1, 2, 2 + ])); + expect(new Float32Array(data, 16, 2)).toEqual(new Float32Array([1, 2])); + expect(new Float32Array(data, 24, 2)).toEqual(new Float32Array([1, 2])); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'complex64', + shape: [2], + }, + { + name: 'x2', + dtype: 'complex64', + shape: [], + }, + { + name: 'x3', + dtype: 'complex64', + shape: [1, 1], + } + ]); + }); it('String tensors', async () => { const tensors: NamedTensorMap = { x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2]), @@ -396,16 +430,20 @@ describeWithFlags('encodeWeights', ALL_ENVS, () => { x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), x2: scalar(13.37, 'float32'), x3: tensor1d([true, false, false, true], 'bool'), + x4: tf.complex([1, 1], [2, 2]) }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; - expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4); + expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4 + 4 * 4); expect(new Int32Array(data, 0, 4)).toEqual(new Int32Array([ 10, 20, 30, 40 ])); expect(new Float32Array(data, 16, 1)).toEqual(new Float32Array([13.37])); expect(new Uint8Array(data, 20, 4)).toEqual(new Uint8Array([1, 0, 0, 1])); + expect(new Float32Array(data, 24, 4)).toEqual(new Float32Array([ + 1, 2, 1, 2 + ])); expect(specs).toEqual([ { name: 'x1', @@ -421,6 +459,11 @@ describeWithFlags('encodeWeights', ALL_ENVS, () => { name: 'x3', dtype: 'bool', shape: [4], + }, + { + name: 'x4', + dtype: 'complex64', + shape: [2], } ]); }); @@ -436,12 +479,13 @@ describeWithFlags('decodeWeights', {}, () => { x5: tensor1d([''], 'string'), // Empty string. 
x6: scalar('hello'), // Single string. y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), + y2: tf.complex([1, 1], [2, 2]) }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; const decoded = tf.io.decodeWeights(data, specs); - expect(Object.keys(decoded).length).toEqual(7); + expect(Object.keys(decoded).length).toEqual(8); expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data()); expectArraysEqual(await decoded['x2'].data(), await tensors['x2'].data()); expectArraysEqual(await decoded['x3'].data(), await tensors['x3'].data()); @@ -449,6 +493,7 @@ describeWithFlags('decodeWeights', {}, () => { expectArraysEqual(await decoded['x5'].data(), await tensors['x5'].data()); expectArraysEqual(await decoded['x6'].data(), await tensors['x6'].data()); expectArraysEqual(await decoded['y1'].data(), await tensors['y1'].data()); + expectArraysEqual(await decoded['y2'].data(), await tensors['y2'].data()); }); it('Unsupported dtype raises Error', () => { diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index 9f0d48885e7..f4ee6ff0939 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -26,6 +26,7 @@ export const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = { 'uint16': 2, 'uint8': 1, 'bool': 1, + 'complex64': 8 }; /** @@ -85,7 +86,7 @@ export declare interface WeightsManifestEntry { /** * Data type of the weight. */ - dtype: 'float32'|'int32'|'bool'|'string'; + dtype: 'float32'|'int32'|'bool'|'string'|'complex64'; /** * Type of the weight. From 93f5078546883f7d76ee1ab7bab68cee074a7393 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Sat, 30 May 2020 11:44:23 -0700 Subject: [PATCH 3/7] added integration test for complex64 weights, and update the integration test to use the local python code --- .gitignore | 1 + e2e/integration_tests/constants.ts | 3 +- e2e/integration_tests/convert_predict.py | 68 +++++++++++++++++-- e2e/integration_tests/convert_predict.ts | 5 +- e2e/integration_tests/requirements-dev.txt | 3 +- e2e/integration_tests/requirements-stable.txt | 8 ++- e2e/scripts/setup-py-env.sh | 2 + tfjs-converter/python/build-pip-package.sh | 6 +- tfjs-converter/python/requirements.txt | 2 +- .../python/tensorflowjs/read_weights_test.py | 3 +- 10 files changed, 86 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 09263eb168c..4056db56328 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ tfjs-layers/integration_tests/tfjs2keras/test-data/ tfjs-layers/integration/typescript/yarn.lock e2e/integration_tests/create_save_predict_data e2e/integration_tests/convert_predict_data +tfjs-converter/python/tensorflowjs.egg-info # Ignore the src, binding, scripts of tfjs-node-gpu since it is copied over when # building. diff --git a/e2e/integration_tests/constants.ts b/e2e/integration_tests/constants.ts index 3695cfd0ba8..1993110f78c 100644 --- a/e2e/integration_tests/constants.ts +++ b/e2e/integration_tests/constants.ts @@ -34,7 +34,8 @@ export const LAYERS_MODELS = [ export const GRAPH_MODELS = [ 'saved_model_v1', 'saved_model_v2', 'saved_model_v2_with_control_flow', - 'saved_model_with_conv2d', 'saved_model_with_prelu' + 'saved_model_with_conv2d', 'saved_model_with_prelu', + 'saved_model_v2_complex64', 'saved_model_v2_with_control_flow_v2' ]; /** Karma server directory serving local files.
*/ diff --git a/e2e/integration_tests/convert_predict.py b/e2e/integration_tests/convert_predict.py index 53778e15545..363b02698f3 100644 --- a/e2e/integration_tests/convert_predict.py +++ b/e2e/integration_tests/convert_predict.py @@ -50,7 +50,7 @@ curr_dir = os.path.dirname(os.path.realpath(__file__)) _tmp_dir = os.path.join(curr_dir, 'convert_predict_data') -def _save_and_convert_model(model_fn, model_path): +def _save_and_convert_model(model_fn, model_path, control_flow_v2=False): """Benchmark a model's fit() and predict() calls; serialize the model. Args: @@ -94,13 +94,17 @@ def _save_and_convert_model(model_fn, model_path): artifacts_dir = os.path.join(_tmp_dir, model_path) # Convert and store model to file. - subprocess.check_output([ + args = [ 'tensorflowjs_converter', '--input_format', 'tf_saved_model', '--output_format', 'tfjs_graph_model', '--signature_name', 'serving_default', - '--saved_model_tags', 'serve', - tmp_saved_model_dir, artifacts_dir]) + '--saved_model_tags', 'serve'] + if control_flow_v2: + args = args + ['--control_flow_v2', 'True'] + + print(args) + subprocess.check_output(args + [tmp_saved_model_dir, artifacts_dir]) def _create_saved_model_v1(save_dir): """Create a TensorFlow V1 SavedModel for testing. @@ -257,6 +261,58 @@ def _create_saved_model_with_prelu(save_dir): "shape": result.shape, "dtype": "float32"}}} +def _create_saved_model_v2_complex64(save_dir): + """Test a TF V2 model with complex dtype. + + Args: + save_dir: directory name of where the saved model will be stored. + """ + input_data = constant_op.constant(1., shape=[1]) + root = tracking.AutoTrackable() + root.v1 = variables.Variable(3 + 1j, dtype=tf.complex64) + root.f = def_function.function(lambda x: tf.complex(x, x) + root.v1) + to_save = root.f.get_concrete_function(input_data) + + save(root, save_dir, to_save) + return { + "async": False, + "inputs": { + "x": {"value": [1], "shape": [1], "dtype": 'float32'}}, + "outputs": { + "Identity:0": {"value": [4, 2], "shape": [1], "dtype": "complex64"}}} + +def _create_saved_model_v2_with_control_flow_v2(save_dir): + """Test a TF V2 model with control flow v2. + + Args: + save_dir: directory name of where the saved model will be stored. + """ + class CustomModule(tf.Module): + + def __init__(self): + super(CustomModule, self).__init__() + + @tf.function(input_signature=[ + tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32)]) + def control_flow(self, x, y): + while x < y: + x = x + 2 + return x + + + module = CustomModule() + print(module.control_flow(1, 2)) + tf.saved_model.save(module, save_dir, + signatures=module.control_flow) + + return { + "async": False, + "inputs": { + "x": {"value": [-1.], "shape": [], "dtype": 'float32'}, + "y": {"value": [2.], "shape": [], "dtype": 'float32'}}, + "outputs": { + "Identity:0": {"value": [3.], "shape": [], "dtype": "float32"}}} + def main(): # Create the directory to store model and data.
if os.path.exists(_tmp_dir) and os.path.isdir(_tmp_dir): @@ -265,8 +321,12 @@ def main(): _save_and_convert_model(_create_saved_model_v1, 'saved_model_v1') _save_and_convert_model(_create_saved_model_v2, 'saved_model_v2') + _save_and_convert_model(_create_saved_model_v2_complex64, + 'saved_model_v2_complex64') _save_and_convert_model(_create_saved_model_v2_with_control_flow, 'saved_model_v2_with_control_flow') + _save_and_convert_model(_create_saved_model_v2_with_control_flow_v2, + 'saved_model_v2_with_control_flow_v2', control_flow_v2=True) _save_and_convert_model(_create_saved_model_with_conv2d, 'saved_model_with_conv2d') _save_and_convert_model(_create_saved_model_with_prelu, diff --git a/e2e/integration_tests/convert_predict.ts b/e2e/integration_tests/convert_predict.ts index 4fb0672b838..e6df517c76e 100644 --- a/e2e/integration_tests/convert_predict.ts +++ b/e2e/integration_tests/convert_predict.ts @@ -60,13 +60,14 @@ describe(`${REGRESSION} convert_predict`, () => { BACKENDS.forEach(backend => { it(`with ${backend}.`, async () => { + await tfc.setBackend(backend); + + console.log(model); const $model = await tfconverter.loadGraphModel( `${KARMA_SERVER}/${DATA_URL}/${model}/model.json`); const xs = createInputTensors(inputsData, inputsShapes); - await tfc.setBackend(backend); - const result = await $model.executeAsync(xs); const ys = diff --git a/e2e/integration_tests/requirements-dev.txt b/e2e/integration_tests/requirements-dev.txt index 3fb61308909..026a4bd8f2d 100644 --- a/e2e/integration_tests/requirements-dev.txt +++ b/e2e/integration_tests/requirements-dev.txt @@ -1,3 +1,2 @@ keras==2.3.1 -tensorflowjs>=1.2.10.1 -tf-nightly>=2.1.0.dev20191007 +-r ../../tfjs-converter/python/requirements.txt diff --git a/e2e/integration_tests/requirements-stable.txt b/e2e/integration_tests/requirements-stable.txt index a9115f7c8f0..b5c35704ac6 100644 --- a/e2e/integration_tests/requirements-stable.txt +++ b/e2e/integration_tests/requirements-stable.txt @@ -1,3 +1,7 @@ keras==2.3.1 -tensorflow==1.15.0 -tensorflowjs==1.7.4 +h5py>=2.8.0 +numpy>=1.16.4 +six>=1.12.0 +tensorflow-cpu==2.1.0 +tensorflow-hub==0.7.0 +PyInquirer==1.0.3 diff --git a/e2e/scripts/setup-py-env.sh b/e2e/scripts/setup-py-env.sh index fe440f87c5d..eef029250c6 100755 --- a/e2e/scripts/setup-py-env.sh +++ b/e2e/scripts/setup-py-env.sh @@ -62,3 +62,5 @@ if [[ "${DEV_VERSION}" == "stable" ]]; then else pip3 install -r requirements-dev.txt fi + +pip3 install -e ../../tfjs-converter/python diff --git a/tfjs-converter/python/build-pip-package.sh b/tfjs-converter/python/build-pip-package.sh index b3f3c8a0068..afe28e64407 100755 --- a/tfjs-converter/python/build-pip-package.sh +++ b/tfjs-converter/python/build-pip-package.sh @@ -32,6 +32,7 @@ # --test-nightly: Test the pip packages by installing it (inside virtualenv) # and running test_pip_package.py and test_pip_nightly_package.py # against the install. +# --build: Create the pip packages. # --upload: Upload the py2 and py3 wheels to prod PyPI. # --upload-to-test: Upload the py2 and py3 wheels to test PyPI, mutually # exclusive with --upload. 
@@ -62,7 +63,7 @@ set -e function print_usage() { echo "Usage:" echo " build-pip-packages.sh \\" - echo " [--test] [--test-nightly] [--upload] [--upload-to-test] [--confirm-upload] " + echo " [--test] [--test-nightly] [--build] [--upload] [--upload-to-test] [--confirm-upload] " echo } @@ -78,6 +79,7 @@ RUN_TEST_NIGHTLY=0 UPLOAD_TO_PROD_PYPI=0 UPLOAD_TO_TEST_PYPI=0 CONFIRM_UPLOAD=0 +BUILD=0 DEST_DIR="" while true; do if [[ "$1" == "--test" ]]; then @@ -90,6 +92,8 @@ while true; do UPLOAD_TO_TEST_PYPI=1 elif [[ "$1" == "--confirm-upload" ]]; then CONFIRM_UPLOAD=1 + elif [[ "$1" == "--build" ]]; then + BUILD=1 elif [[ "$1" != --* ]]; then DEST_DIR="$1" else diff --git a/tfjs-converter/python/requirements.txt b/tfjs-converter/python/requirements.txt index 735e24a41ff..2e015333d1d 100644 --- a/tfjs-converter/python/requirements.txt +++ b/tfjs-converter/python/requirements.txt @@ -1,6 +1,6 @@ h5py>=2.8.0 numpy>=1.16.4 six>=1.12.0 -tensorflow-cpu>=2.1.0<3 +tensorflow-cpu==2.1.0 tensorflow-hub==0.7.0 PyInquirer==1.0.3 diff --git a/tfjs-converter/python/tensorflowjs/read_weights_test.py b/tfjs-converter/python/tensorflowjs/read_weights_test.py index 102ab5ab011..61af1d3212f 100644 --- a/tfjs-converter/python/tensorflowjs/read_weights_test.py +++ b/tfjs-converter/python/tensorflowjs/read_weights_test.py @@ -45,8 +45,7 @@ def testReadOneGroup(self): [{ 'name': 'weight1', 'data': np.array([1, 2, 3], 'float32') - }, - { + }, { 'name': 'weight2', 'data': np.array([1 + 1j, 2 + 2j, 3 + 3j]) }] From 2588ad988bb8e12a19d6cf1a5b5fcedbb861fb94 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 1 Jun 2020 11:55:35 -0700 Subject: [PATCH 4/7] fix pylint --- tfjs-converter/python/tensorflowjs/write_weights.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tfjs-converter/python/tensorflowjs/write_weights.py b/tfjs-converter/python/tensorflowjs/write_weights.py index 15a450318bd..fbbcbc44cd0 100644 --- a/tfjs-converter/python/tensorflowjs/write_weights.py +++ b/tfjs-converter/python/tensorflowjs/write_weights.py @@ -368,7 +368,6 @@ def _assert_valid_weight_entry(entry): if not (data.dtype in _OUTPUT_DTYPES or data.dtype in _AUTO_DTYPE_CONVERSION): - print(_OUTPUT_DTYPES) raise ValueError('Error dumping weight ' + name + ', dtype ' + data.dtype.name + ' not supported.') From ab0d1b43f52b8e25d2694741ef516ab5492eb907 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 1 Jun 2020 12:05:03 -0700 Subject: [PATCH 5/7] addressed comments --- e2e/integration_tests/convert_predict.ts | 1 - e2e/scripts/setup-py-env.sh | 1 + tfjs-core/src/io/io_utils.ts | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/e2e/integration_tests/convert_predict.ts b/e2e/integration_tests/convert_predict.ts index e6df517c76e..0c40038d097 100644 --- a/e2e/integration_tests/convert_predict.ts +++ b/e2e/integration_tests/convert_predict.ts @@ -62,7 +62,6 @@ describe(`${REGRESSION} convert_predict`, () => { it(`with ${backend}.`, async () => { await tfc.setBackend(backend); - console.log(model); const $model = await tfconverter.loadGraphModel( `${KARMA_SERVER}/${DATA_URL}/${model}/model.json`); diff --git a/e2e/scripts/setup-py-env.sh b/e2e/scripts/setup-py-env.sh index eef029250c6..279e099b1bc 100755 --- a/e2e/scripts/setup-py-env.sh +++ b/e2e/scripts/setup-py-env.sh @@ -63,4 +63,5 @@ else pip3 install -r requirements-dev.txt fi +echo "Loading tensorflowjs pip from source ...." 
pip3 install -e ../../tfjs-converter/python diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index f704b40315f..347f1c3b3fa 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -177,7 +177,7 @@ export function decodeWeights( values = new Float32Array(byteBuffer); const real = new Float32Array(values.length / 2); const image = new Float32Array(values.length / 2); - for (let i = 0; i < values.length / 2; i++) { + for (let i = 0; i < real.length; i++) { real[i] = values[i * 2]; image[i] = values[i * 2 + 1]; } From e1f18cc23b8f3bc699e04854df6214579bd9316b Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 1 Jun 2020 12:06:29 -0700 Subject: [PATCH 6/7] more fixes --- tfjs-core/src/io/io_utils.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 347f1c3b3fa..11cf172e3e5 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import * as tf from '../index'; +import {complex} from '../index'; import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; @@ -183,7 +183,7 @@ export function decodeWeights( } const realTensor = tensor(real, shape, 'float32'); const imageTensor = tensor(image, shape, 'float32'); - out[name] = tf.complex(realTensor, imageTensor); + out[name] = complex(realTensor, imageTensor); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } From c72b581c56204cf865a60660e3db1740a2375852 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 1 Jun 2020 13:34:15 -0700 Subject: [PATCH 7/7] fix lint error --- tfjs-core/src/io/io_utils.ts | 2 +- tfjs-core/src/io/io_utils_test.ts | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 11cf172e3e5..49c590f1f23 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {complex} from '../index'; +import {complex} from '../ops/complex_ops'; import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts index 1f2730725bf..9ee1f35915a 100644 --- a/tfjs-core/src/io/io_utils_test.ts +++ b/tfjs-core/src/io/io_utils_test.ts @@ -330,7 +330,6 @@ describeWithFlags('encodeWeights', ALL_ENVS, () => { ]); }); - it('Complex64 tensors', async () => { const tensors: NamedTensorMap = { x1: tf.complex([1, 2], [1, 2]),
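
Note on the on-disk format these patches rely on: a complex64 weight is stored as interleaved float32 pairs (real part, then imaginary part, per element). That is numpy's native complex64 layout, which the write_weights/read_weights tests above confirm is written to the shard file unchanged (np.fromfile with dtype 'complex64' recovers the original values), and it is the layout that decodeWeights() in io_utils.ts splits back into a real and an imaginary tensor before calling complex(). The snippet below is a minimal standalone sketch of that round trip in plain numpy; it is illustrative only and not part of the patches above.

import numpy as np

# complex64 = two float32 values per element: real part first, then imaginary.
weights = np.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=np.complex64)
raw = weights.tobytes()                       # 3 elements * 8 bytes = 24 bytes
flat = np.frombuffer(raw, dtype=np.float32)   # [1. 1. 2. 2. 3. 3.]

# Split the interleaved buffer the way decodeWeights() does before it calls
# complex(realTensor, imagTensor).
real = flat[0::2]                             # [1. 2. 3.]
imag = flat[1::2]                             # [1. 2. 3.]
restored = (real + 1j * imag).astype(np.complex64)
assert np.array_equal(restored, weights)

Because the serialized bytes are already in this layout, the Python side only needs to admit np.complex64 through _INPUT_DTYPES/_OUTPUT_DTYPES (with complex128 auto-converted to complex64), while the JS side only needs the even/odd split shown in io_utils.ts.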