Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
91 lines (75 sloc) 2.95 KB
# coding=utf-8
# Copyright 2019 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base TestCase to use test_data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import tempfile
from absl import logging
from absl.testing import absltest
import six
import tensorflow as tf
from tensorflow_datasets.core.utils import gcs_utils
# Lookup table of GCS accessors. The "original_*" entries capture the
# `gcs_utils` functions as they are when this module is imported; the
# "dummy_*" entries are stand-ins that ignore their argument and report
# that nothing is available (an empty list / `False`).
GCS_ACCESS_FNS = dict(
    original_info=gcs_utils.gcs_dataset_info_files,
    dummy_info=lambda _: [],
    original_datasets=gcs_utils.is_dataset_on_gcs,
    dummy_datasets=lambda _: False,
)
class TestCase(tf.test.TestCase):
  """Base TestCase to be used for all tests.

  `test_data` class attribute: path to the directory with test data.
  `tmp_dir` attribute: path to temp directory reset before every test.
  """

  @classmethod
  def setUpClass(cls):
    """Sets `cls.test_data` and replaces the GCS accessors with dummies."""
    super(TestCase, cls).setUpClass()
    cls.test_data = os.path.join(os.path.dirname(__file__), "test_data")
    # Test must not communicate with GCS.
    gcs_utils.gcs_dataset_info_files = GCS_ACCESS_FNS["dummy_info"]
    gcs_utils.is_dataset_on_gcs = GCS_ACCESS_FNS["dummy_datasets"]

  @contextlib.contextmanager
  def gcs_access(self):
    """Context manager that temporarily re-enables the real GCS accessors.

    The original `gcs_utils` functions are put back for the duration of the
    `with` body, and the dummy versions are reinstalled on exit.

    Yields:
      None.
    """
    # Restore GCS access
    gcs_utils.gcs_dataset_info_files = GCS_ACCESS_FNS["original_info"]
    gcs_utils.is_dataset_on_gcs = GCS_ACCESS_FNS["original_datasets"]
    try:
      yield
    finally:
      # Revert access even when the body raised, so that a failing test
      # cannot leave real GCS access enabled for the tests that follow.
      gcs_utils.gcs_dataset_info_files = GCS_ACCESS_FNS["dummy_info"]
      gcs_utils.is_dataset_on_gcs = GCS_ACCESS_FNS["dummy_datasets"]

  def setUp(self):
    """Creates a fresh temporary sub-directory and stores it in `tmp_dir`."""
    super(TestCase, self).setUp()
    # get_temp_dir is actually the same for all tests, so create a temp sub-dir.
    self.tmp_dir = tempfile.mkdtemp(dir=tf.compat.v1.test.get_temp_dir())

  def assertRaisesWithPredicateMatch(self, err_type, predicate):
    """Like the parent assertion, but also accepts a plain string.

    Args:
      err_type: Expected exception type, forwarded to the parent method.
      predicate: Either a callable taking the raised error, or a string that
        must be contained in `str(err)`.

    Returns:
      Whatever the parent `assertRaisesWithPredicateMatch` returns.
    """
    if isinstance(predicate, six.string_types):
      predicate_fct = lambda err: predicate in str(err)
    else:
      predicate_fct = predicate
    return super(TestCase, self).assertRaisesWithPredicateMatch(
        err_type, predicate_fct)

  @contextlib.contextmanager
  def assertLogs(self, text, level="info"):
    """Asserts that `text` appears in what is logged inside the `with` body.

    The `logging` function named by `level` is replaced by a mock while the
    body runs; afterwards every recorded call is rendered and the renderings
    are searched for `text`.

    Args:
      text: Substring expected in the concatenated log messages.
      level: Name of the `logging` function to intercept (e.g. "info").

    Yields:
      None.
    """
    with absltest.mock.patch.object(logging, level) as mock_log:
      yield
      rendered = []
      for log_call in mock_log.call_args_list:
        positional = log_call[0]
        base, fmt_args = positional[0], positional[1:]
        # Only interpolate when format arguments were supplied, as the
        # standard logging module does; otherwise a literal "%" in an
        # argument-free message would raise or be altered.
        rendered.append(base % tuple(fmt_args) if fmt_args else base)
      # Each message is prefixed with a single space, as before.
      concat_logs = "".join(" " + log_text for log_text in rendered)
      self.assertIn(text, concat_logs)
You can’t perform that action at this time.