Skip to content

Commit

Permalink
Add test for optimize_out_slice_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
jotix16 committed Jun 13, 2021
1 parent 6410df5 commit 10cffc1
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions tests/test_TFNetworkRecLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
print("rnn_cell also fine.")


def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, from_=None, rtol=1e-4):
"""
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
Expand All @@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
subnet_layer_dict.setdefault("from", ["data:source"])
rec_layer_dict = {
"class": "rec",
"from": ["data"],
"from": ["data"] if from_ is None else [from_],
"unit": {"output": subnet_layer_dict},
"n_out": n_out,
"is_output_layer": True
Expand Down Expand Up @@ -3598,6 +3598,31 @@ def test_reclayer_optimize_out_access_split():
other_subnet_layers={"split": {"class": "split", "from": ["data:source"], "size_splits": [5, 8]}})


def test_reclayer_optimize_out_slice_nd():
def random_start_positions(source, **kwargs):
import tensorflow as tf
enc = source(0, as_data=True, enforce_batch_major=True, auto_convert=False)
enc_shape = tf.shape(enc.placeholder)
enc_time_dim = enc_shape[enc.time_dim_axis]
return tf.random.uniform(enc_shape[:-1], 0, enc_time_dim-2, dtype=tf.dtypes.int32)

check_reclayer_optimize_out(
{"class": "linear", "activation": None, "from": ["encoder_reduced"]},
from_="position",
other_subnet_layers={
"window": {"class": "slice_nd", "from": "base:encoder", "start": "data:source", "size": None, "min_size": 1, "is_output_layer": True},
"encoder_reduced": {"class": "reduce", "mode": "sum", "axis": "T", "from": ["base:encoder"], "is_output_layer": True}
},
shared_base_net={
"encoder": {"class": "copy", "from": "data", "is_output_layer": True},
"position": {
"class": "eval", "from": "encoder", "is_output_layer": True,
"eval": random_start_positions,
"out_type": {"batch_dim_axis": 0, "time_dim_axis": 1, "shape": (None,), "sparse": True, "dtype": "int32", "dim": None}}
}
)


def test_reclayer_att_with_kv_in_rec():
net_dict = {
'decision': {'class': 'decide', 'from': ['output'], 'loss': 'edit_distance', 'loss_opts': {}, 'target': 'classes'},
Expand Down

0 comments on commit 10cffc1

Please sign in to comment.