3636 OnlineDTActor ,
3737 ProbabilisticActor ,
3838 SafeModule ,
39+ set_recurrent_mode ,
3940 TanhDelta ,
4041 TanhNormal ,
4142 ValueOperator ,
@@ -729,6 +730,31 @@ def test_errs(self):
729730 with pytest .raises (KeyError , match = "is_init" ):
730731 lstm_module (td )
731732
733+ @pytest .mark .parametrize ("default_val" , [False , True , None ])
734+ def test_set_recurrent_mode (self , default_val ):
735+ lstm_module = LSTMModule (
736+ input_size = 3 ,
737+ hidden_size = 12 ,
738+ batch_first = True ,
739+ in_keys = ["observation" , "hidden0" , "hidden1" ],
740+ out_keys = ["intermediate" , ("next" , "hidden0" ), ("next" , "hidden1" )],
741+ default_recurrent_mode = default_val ,
742+ )
743+ assert lstm_module .recurrent_mode is bool (default_val )
744+ with set_recurrent_mode (True ):
745+ assert lstm_module .recurrent_mode
746+ with set_recurrent_mode (False ):
747+ assert not lstm_module .recurrent_mode
748+ with set_recurrent_mode ("recurrent" ):
749+ assert lstm_module .recurrent_mode
750+ with set_recurrent_mode ("sequential" ):
751+ assert not lstm_module .recurrent_mode
752+ assert lstm_module .recurrent_mode
753+ assert not lstm_module .recurrent_mode
754+ assert lstm_module .recurrent_mode
755+ assert lstm_module .recurrent_mode is bool (default_val )
756+
757+ @pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
732758 def test_set_temporal_mode (self ):
733759 lstm_module = LSTMModule (
734760 input_size = 3 ,
@@ -754,7 +780,8 @@ def test_python_cudnn(self):
754780 num_layers = 2 ,
755781 in_keys = ["observation" , "hidden0" , "hidden1" ],
756782 out_keys = ["intermediate" , ("next" , "hidden0" ), ("next" , "hidden1" )],
757- ).set_recurrent_mode (True )
783+ default_recurrent_mode = True ,
784+ )
758785 obs = torch .rand (10 , 20 , 3 )
759786
760787 hidden0 = torch .rand (10 , 20 , 2 , 12 )
@@ -1109,6 +1136,31 @@ def test_errs(self):
11091136 with pytest .raises (KeyError , match = "is_init" ):
11101137 gru_module (td )
11111138
1139+ @pytest .mark .parametrize ("default_val" , [False , True , None ])
1140+ def test_set_recurrent_mode (self , default_val ):
1141+ gru_module = GRUModule (
1142+ input_size = 3 ,
1143+ hidden_size = 12 ,
1144+ batch_first = True ,
1145+ in_keys = ["observation" , "hidden" ],
1146+ out_keys = ["intermediate" , ("next" , "hidden" )],
1147+ default_recurrent_mode = default_val ,
1148+ )
1149+ assert gru_module .recurrent_mode is bool (default_val )
1150+ with set_recurrent_mode (True ):
1151+ assert gru_module .recurrent_mode
1152+ with set_recurrent_mode (False ):
1153+ assert not gru_module .recurrent_mode
1154+ with set_recurrent_mode ("recurrent" ):
1155+ assert gru_module .recurrent_mode
1156+ with set_recurrent_mode ("sequential" ):
1157+ assert not gru_module .recurrent_mode
1158+ assert gru_module .recurrent_mode
1159+ assert not gru_module .recurrent_mode
1160+ assert gru_module .recurrent_mode
1161+ assert gru_module .recurrent_mode is bool (default_val )
1162+
1163+ @pytest .mark .filterwarnings ("ignore::DeprecationWarning" )
11121164 def test_set_temporal_mode (self ):
11131165 gru_module = GRUModule (
11141166 input_size = 3 ,
0 commit comments