Skip to content

Commit

Permalink
Enables graph tensor checks in Runner's tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584356970
  • Loading branch information
aferludin authored and tensorflower-gardener committed Nov 21, 2023
1 parent 8d5dac2 commit 2f3ada8
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow_gnn/runner/distribute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def get_dataset(self, _: tf.distribute.InputContext) -> tf.data.Dataset:

class OrchestrationTests(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

@tfdistribute.combinations.generate(
tftest.combinations.times(
_all_eager_strategy_combinations(),
Expand Down
8 changes: 8 additions & 0 deletions tensorflow_gnn/runner/examples/ogbn/mag/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

class MaskPaperLabelsTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

def test(self):
graph = tfgnn.GraphTensor.from_pieces(
node_sets={
Expand Down Expand Up @@ -59,6 +63,10 @@ def test(self):

class MakeCausalMaskTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

def test(self):
graph = tfgnn.GraphTensor.from_pieces(
node_sets={
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_gnn/runner/orchestration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def metrics(self) -> Metrics:

class OrchestrationTests(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

@parameterized.named_parameters([
dict(
testcase_name="GraphTensors",
Expand Down
7 changes: 7 additions & 0 deletions tensorflow_gnn/runner/tasks/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
GraphTensor = tfgnn.GraphTensor
Field = tfgnn.Field

# Enables tests for graph pieces that are members of test classes.
tfgnn.enable_graph_tensor_validation_at_runtime()

TEST_GRAPH_TENSOR = GraphTensor.from_pieces(
context=tfgnn.Context.from_fields(
features={"labels": tf.constant((8, 1, 9, 1))}
Expand Down Expand Up @@ -60,6 +63,10 @@ def with_readout(num_labels: int, gt: GraphTensor) -> GraphTensor:

class Classification(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

@parameterized.named_parameters([
dict(
testcase_name="GraphBinaryClassificationLabelFn",
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_gnn/runner/tasks/link_prediction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def _get_graph_tensor(

class LinkPredictionTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

def test_predict_on_dot_product_link_prediction(self):
task = link_prediction.DotProductLinkPrediction()
similarities = task.predict(_get_graph_tensor())
Expand Down
7 changes: 7 additions & 0 deletions tensorflow_gnn/runner/tasks/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
GraphTensor = tfgnn.GraphTensor
Field = tfgnn.Field

# Enables tests for graph pieces that are members of test classes.
tfgnn.enable_graph_tensor_validation_at_runtime()

TEST_GRAPH_TENSOR = GraphTensor.from_pieces(
context=tfgnn.Context.from_fields(
features={"labels": tf.constant((.8, .1, .9, .1))}
Expand Down Expand Up @@ -56,6 +59,10 @@ def with_readout(gt: GraphTensor) -> GraphTensor:

class Regression(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

@parameterized.named_parameters([
dict(
testcase_name="GraphMeanAbsoluteErrorLabelFn",
Expand Down
7 changes: 7 additions & 0 deletions tensorflow_gnn/runner/utils/attribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@

IntegratedGradientsExporter = attribution.IntegratedGradientsExporter

# Enables tests for graph pieces that are members of test classes.
tfgnn.enable_graph_tensor_validation_at_runtime()


class AttributionTest(tf.test.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

gt = tfgnn.GraphTensor.from_pieces(
context=tfgnn.Context.from_fields(features={
"h": tf.convert_to_tensor(((.514, .433),)),
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_gnn/runner/utils/parsing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def random_serialized_graph_tensor() -> str:

class ParsingTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
tfgnn.enable_graph_tensor_validation_at_runtime()

def _assert_fields_equal(self, a: Fields, b: Fields):
self.assertCountEqual(a.keys(), b.keys())
for k, v in a.items():
Expand Down

0 comments on commit 2f3ada8

Please sign in to comment.