diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py index b76e48d24..3a389fdac 100644 --- a/tests/test_TFNetworkRecLayer.py +++ b/tests/test_TFNetworkRecLayer.py @@ -325,6 +325,56 @@ def test_nativelstm2_initial_state_keep_epoch(): _check_train_simple_network({ "output": {"class": "rec", "unit": "nativelstm2", "loss": "mse", "initial_state": "keep_over_epoch"}}) +def test_lstmblock_initial_state_dependency(): + beam_size = 2 + network = { + "data_embed": {"class": "linear", "activation": None, "with_bias": False, "n_out": 6}, + "lstm0_fw" : { "class": "rec", "unit": "nativelstm2", "n_out" : 5, "direction": 1, "from": ["data_embed"] }, + + "encoder_state": {"class": "get_last_hidden_state", "from": ["lstm0_fw"], 'key': 'c', "n_out": 5}, + + "output": {"class": "rec", "from": [], "unit": { + 'output': {'class': 'choice', 'target': 'classes', 'beam_size': beam_size, 'from': ["output_prob"], + "initial_output": 0}, + "end": {"class": "compare", "from": ["output"], "value": 0}, + 'orth_embed': {'class': 'linear', 'activation': None, "with_bias": False, 'from': ['prev:output'], "n_out": 6}, + + "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["orth_embed"], "initial_state": {"c": "base:encoder_state", "h": 0}, "n_out": 5}, # h_t + + "output_prob": {"class": "softmax", "from": ["s"], "target": "classes", "loss": "ce"} + }, "target": "classes", "max_seq_len": 7}, + } + + from GeneratingDataset import DummyDataset + seq_len = 5 + n_data_dim = 2 + n_classes_dim = 3 + train_data = DummyDataset(input_dim=n_data_dim, output_dim=n_classes_dim, num_seqs=4, seq_len=seq_len) + train_data.init_seq_order(epoch=1) + cv_data = DummyDataset(input_dim=n_data_dim, output_dim=n_classes_dim, num_seqs=2, seq_len=seq_len) + cv_data.init_seq_order(epoch=1) + + config = Config() + config.update({ + "model": "/tmp/model", + "num_outputs": n_classes_dim, + "num_inputs": n_data_dim, + "network": network, + "start_epoch": 1, + "num_epochs": 2, + "batch_size": 10, + "nadam": True, + "learning_rate": 0.01, + "debug_add_check_numerics_ops": True + }) + + from tests.test_TFEngine import _cleanup_old_models + from TFEngine import Engine + _cleanup_old_models(config) + engine = Engine(config=config) + engine.init_train_from_config(config=config, train_data=train_data, dev_data=cv_data, eval_data=None) + engine.train() + engine.search(cv_data) def test_slow_TensorArray(): """