Skip to content

Commit

Permalink
Copy the NumPy API tests from Trax to the TF-NumPy directory. Otherwi…
Browse files Browse the repository at this point in the history
…se, there are no existing local tests for TF-NumPy

PiperOrigin-RevId: 536559285
  • Loading branch information
JW1992 authored and tensorflower-gardener committed May 31, 2023
1 parent a584df7 commit f806622
Show file tree
Hide file tree
Showing 8 changed files with 7,772 additions and 0 deletions.
154 changes: 154 additions & 0 deletions tensorflow/python/ops/numpy_ops/tests/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
licenses(["notice"])

py_library(
name = "config",
srcs = ["config.py"],
srcs_version = "PY2AND3",
deps = [
],
)

py_library(
name = "test_util",
srcs = ["test_util.py"],
srcs_version = "PY2AND3",
deps = [
":config",
":extensions",
"//tensorflow:tensorflow_py",
],
)

py_library(
name = "np_wrapper",
srcs = ["np_wrapper.py"],
srcs_version = "PY2AND3",
visibility = [
"//visibility:public",
],
deps = [
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/ops/numpy_ops:np_config",
"//tensorflow/python/ops/numpy_ops:np_dtypes",
"//tensorflow/python/ops/numpy_ops:numpy",
],
)

py_library(
name = "extensions",
srcs = ["extensions.py"],
srcs_version = "PY2AND3",
deps = [
":np_wrapper",
"//tensorflow:tensorflow_py",
"@six_archive//:six",
],
)

# copybara:uncomment_begin(google-only)
# py_test(
# name = "extensions_test",
# srcs = ["extensions_test.py"],
# python_version = "PY3",
# srcs_version = "PY2AND3",
# tags = [
# "gpu",
# "requires-gpu-nvidia",
# ],
# deps = [
# ":extensions",
# ":np_wrapper",
# "//learning/brain/research/jax:gpu_support",
# "//third_party/py/jax",
# "//tensorflow:tensorflow_py",
# ],
# )
#
# py_test(
# name = "extensions_test_tpu",
# srcs = ["extensions_test.py"],
# args = [
# "--jax_allow_unused_tpus",
# "--requires_tpu",
# ],
# main = "extensions_test.py",
# python_version = "PY3",
# tags = [
# "requires-tpu",
# ],
# deps = [
# ":extensions",
# ":np_wrapper",
# "//learning/brain/google/xla",
# "//third_party/py/jax",
# "//tensorflow:tensorflow_py",
# "@absl_py//absl/flags",
# ],
# )
# copybara:uncomment_end

py_test(
name = "np_test",
timeout = "long",
srcs = ["np_test.py"],
args = [
"--num_generated_cases=90",
"--enable_x64", # Needed to enable dtype check
],
python_version = "PY3",
shard_count = 20,
srcs_version = "PY2AND3",
tags = [
"gpu",
"requires-gpu-nvidia",
],
deps = [
":np_wrapper",
":test_util",
],
)

py_test(
name = "np_indexing_test",
srcs = ["np_indexing_test.py"],
args = [
"--num_generated_cases=90",
"--enable_x64", # Needed to enable dtype check
],
python_version = "PY3",
shard_count = 10,
srcs_version = "PY2AND3",
# TODO(b/164245103): Re-enable GPU once tf.tensor_strided_slice_update's segfault is fixed.
# tags = [
# "gpu",
# "requires-gpu-nvidia",
# ],
deps = [
":np_wrapper",
":test_util",
],
)

py_test(
name = "np_einsum_test",
srcs = ["np_einsum_test.py"],
args = [
"--num_generated_cases=90",
"--enable_x64", # Needed to enable dtype check
],
python_version = "PY3",
shard_count = 20,
srcs_version = "PY2AND3",
tags = [
"gpu",
"requires-gpu-nvidia",
],
deps = [
":config",
":test_util",
"//tensorflow/python/ops/numpy_ops:np_config",
"//tensorflow/python/ops/numpy_ops:numpy",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
],
)
141 changes: 141 additions & 0 deletions tensorflow/python/ops/numpy_ops/tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configurations."""
import os
import sys


def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
false values are 'n', 'no', 'f', 'false', 'off', and '0'.
Args:
varname: the name of the variable
default: the default boolean value
Raises: ValueError if the environment variable is anything else.
"""
val = os.getenv(varname, str(default))
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return True
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return False
else:
raise ValueError(
'invalid truth value %r for environment %r' % (val, varname)
)


class Config(object):

def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False

def update(self, name, val):
if self.use_absl:
setattr(self.absl_flags.FLAGS, name, val)
else:
self.check_exists(name)
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))
self.values[name] = val

def read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
self.check_exists(name)
return self.values[name]

def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
if name in self.values:
raise Exception("Config option {} already defined".format(name))
self.values[name] = default
self.meta[name] = (opt_type, meta_args, meta_kwargs)

def check_exists(self, name):
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))

def DEFINE_bool(self, name, default, *args, **kwargs):
self.add_option(name, default, bool, args, kwargs)

def DEFINE_integer(self, name, default, *args, **kwargs):
self.add_option(name, default, int, args, kwargs)

def DEFINE_string(self, name, default, *args, **kwargs):
self.add_option(name, default, str, args, kwargs)

def DEFINE_enum(self, name, default, *args, **kwargs):
self.add_option(name, default, 'enum', args, kwargs)

def config_with_absl(self):
# Run this before calling `app.run(main)` etc
import absl.flags as absl_FLAGS
from absl import app, flags as absl_flags

self.use_absl = True
self.absl_flags = absl_flags
absl_defs = { bool: absl_flags.DEFINE_bool,
int: absl_flags.DEFINE_integer,
str: absl_flags.DEFINE_string,
'enum': absl_flags.DEFINE_enum }

for name, val in self.values.items():
flag_type, meta_args, meta_kwargs = self.meta[name]
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

app.call_after_init(lambda: self.complete_absl_config(absl_flags))

def complete_absl_config(self, absl_flags):
for name, _ in self.values.items():
self.update(name, getattr(absl_flags.FLAGS, name))

def parse_flags_with_absl(self):
global already_configured_with_absl
if not already_configured_with_absl:
import absl.flags
self.config_with_absl()
absl.flags.FLAGS(sys.argv, known_only=True)
self.complete_absl_config(absl.flags)
already_configured_with_absl = True


class NameSpace(object):
def __init__(self, getter):
self._getter = getter

def __getattr__(self, name):
return self._getter(name)


config = Config()
flags = config
FLAGS = flags.FLAGS

already_configured_with_absl = False

flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help='Turn on invariant checking (core.skip_checks = False)')

flags.DEFINE_bool('tf_numpy_additional_tests', True,
'Run tests added specifically for TF numpy')
Loading

0 comments on commit f806622

Please sign in to comment.