-
Notifications
You must be signed in to change notification settings - Fork 5
/
dreamer_v3.py
1913 lines (1723 loc) · 74.3 KB
/
dreamer_v3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import logging
import random
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union, cast
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras
from srl.base.define import DoneTypes, SpaceTypes
from srl.base.exception import UndefinedError
from srl.base.rl.algorithms.base_ppo import RLConfig, RLWorker
from srl.base.rl.parameter import RLParameter
from srl.base.rl.processor import Processor
from srl.base.rl.registration import register
from srl.base.rl.trainer import RLTrainer
from srl.base.spaces.array_continuous import ArrayContinuousSpace
from srl.base.spaces.discrete import DiscreteSpace
from srl.rl import functions as funcs
from srl.rl.memories.experience_replay_buffer import ExperienceReplayBuffer, RLConfigComponentExperienceReplayBuffer
from srl.rl.models.config.framework_config import RLConfigComponentFramework
from srl.rl.processors.image_processor import ImageProcessor
from srl.rl.schedulers.scheduler import SchedulerConfig
from srl.rl.tf.distributions.bernoulli_dist_block import BernoulliDistBlock
from srl.rl.tf.distributions.categorical_dist_block import CategoricalDist, CategoricalDistBlock
from srl.rl.tf.distributions.categorical_gumbel_dist_block import CategoricalGumbelDistBlock
from srl.rl.tf.distributions.linear_block import LinearBlock
from srl.rl.tf.distributions.normal_dist_block import NormalDistBlock
from srl.rl.tf.distributions.twohot_dist_block import TwoHotDistBlock
from srl.utils.common import compare_less_version
kl = keras.layers
tfd = tfp.distributions
v216_older = compare_less_version(tf.__version__, "2.16.0")
logger = logging.getLogger(__name__)
"""
paper: https://browse.arxiv.org/abs/2301.04104v1
ref: https://github.com/danijar/dreamerv3
"""
# ------------------------------------------------------
# config
# ------------------------------------------------------
@dataclass
class Config(
    RLConfig,
    RLConfigComponentExperienceReplayBuffer,
    RLConfigComponentFramework,
):
    """
    <:ref:`RLConfigComponentExperienceReplayBuffer`>
    <:ref:`RLConfigComponentFramework`>
    """

    # --- RSSM
    #: Units of the deterministic transition (internally, the number of GRU units)
    rssm_deter_size: int = 4096
    #: Number of stochastic latent variables
    rssm_stoch_size: int = 32
    #: Classes per stochastic variable (used by the categorical latent).
    #: Note: when rssm_use_categorical_distribution=False the RSSM still sizes the
    #: Gaussian latent as rssm_stoch_size * rssm_classes.
    rssm_classes: int = 32
    #: Number of hidden units
    rssm_hidden_units: int = 1024
    #: If True, LayerNormalization layers are added
    rssm_use_norm_layer: bool = True
    #: False: the stochastic transition is a Gaussian; True: it is categorical
    rssm_use_categorical_distribution: bool = True
    #: RSSM Activation
    rssm_activation: Any = "silu"
    #: Minimum probability guaranteed for the categorical distribution
    #: (only used when rssm_use_categorical_distribution=True)
    rssm_unimix: float = 0.01

    # --- other model layers
    #: Distribution type used to learn the reward
    #:
    #: Parameters:
    #: "linear": trained with MSE (affected by use_symlog)
    #: "normal": trained as a Gaussian (not affected by use_symlog)
    #: "normal_fixed_scale": Gaussian with the variance fixed to 1 (not affected by use_symlog)
    #: "twohot": trained with two-hot encoding (affected by use_symlog)
    reward_type: str = "twohot"
    #: bins, only used when reward_type is "twohot"
    reward_twohot_bins: int = 255
    #: low, only used when reward_type is "twohot"
    reward_twohot_low: int = -20
    #: high, only used when reward_type is "twohot"
    reward_twohot_high: int = 20
    #: Hidden layers of the reward model
    reward_layer_sizes: Tuple[int, ...] = (1024, 1024, 1024, 1024)
    #: Hidden layers of the continue model
    cont_layer_sizes: Tuple[int, ...] = (1024, 1024, 1024, 1024)
    #: Hidden layers of the critic model
    critic_layer_sizes: Tuple[int, ...] = (1024, 1024, 1024, 1024)
    #: Hidden layers of the actor model
    actor_layer_sizes: Tuple[int, ...] = (1024, 1024, 1024, 1024)
    #: Activation of each layer
    dense_act: Any = "silu"
    #: Whether to use symlog
    use_symlog: bool = True

    # --- encoder/decoder
    #: Hidden layers used when the input is not an IMAGE
    encoder_decoder_mlp: Tuple[int, ...] = (1024, 1024, 1024, 1024)
    #: Distribution of the decoder output layer
    #:
    #: Parameters:
    #: "linear": mse
    #: "normal": normal distribution
    encoder_decoder_dist: str = "linear"
    #: [IMAGE input] Number of Conv2D filters
    cnn_depth: int = 96
    #: [IMAGE input] Number of ResBlocks
    cnn_blocks: int = 0
    #: [IMAGE input] activation
    cnn_activation: Any = "silu"
    #: [IMAGE input] Normalization layer to add
    #:
    #: Parameters:
    #: "none": nothing is added
    #: "layer": LayerNormalization layers are added
    cnn_normalization_type: str = "layer"
    #: [IMAGE input] Algorithm used to downscale the image
    #:
    #: Parameters:
    #: "stride": downscale with Conv2D strides
    #: "stride3": downscale with a stride-3 Conv2D in the first stage
    cnn_resize_type: str = "stride"
    #: [IMAGE input] Image size after downscaling
    cnn_resized_image_size: int = 4
    #: [IMAGE input] If True the image output layer is a sigmoid, otherwise it is linear
    cnn_use_sigmoid: bool = False

    # --- loss params
    #: free bit
    free_nats: float = 1.0  # 1nat ~ 1.44bit
    #: reconstruction loss rate
    loss_scale_pred: float = 1.0
    #: dynamics kl loss rate
    loss_scale_kl_dyn: float = 0.5
    #: rep kl loss rate
    loss_scale_kl_rep: float = 0.1
    #: Only the world model is trained during the first steps
    warmup_world_model: int = 0

    # --- actor/critic
    #: critic target update interval
    critic_target_update_interval: int = 0  # 0 is disable target
    #: critic target soft update tau
    critic_target_soft_update: float = 0.02
    #: critic model type
    #:
    #: Parameters:
    #: "linear" : trained with MSE (affected by use_symlog)
    #: "normal" : normal distribution (not affected by use_symlog)
    #: "normal_fixed_scale": normal distribution with variance fixed to 1 (not affected by use_symlog)
    #: "twohot" : two-hot categorical distribution (affected by use_symlog)
    critic_type: str = "twohot"
    #: bins, only used when critic_type is "twohot"
    critic_twohot_bins: int = 255
    #: low, only used when critic_type is "twohot"
    critic_twohot_low: int = -20
    #: high, only used when critic_type is "twohot"
    critic_twohot_high: int = 20
    #: actor model type
    #:
    #: Parameters:
    #: "categorical" : categorical distribution
    #: "gumbel_categorical" : Gumbel categorical distribution
    actor_discrete_type: str = "categorical"
    #: Minimum probability guaranteed for the categorical distribution (only for DISCRETE actions)
    actor_discrete_unimix: float = 0.01
    #: For continuous actions, whether to squash the normal distribution into -1..1 with tanh
    actor_continuous_enable_normal_squashed: bool = True

    # --- Behavior
    #: Number of imagination (horizon) steps
    horizon: int = 15
    #: "actor" or "random", random is debug.
    horizon_policy: str = "actor"
    #: How values over the horizon are estimated
    #:
    #: Parameters:
    #: "simple" : plain sum
    #: "discount" : discounted return
    #: "ewa" : EWA
    #: "h-return" : λ-return
    critic_estimation_method: str = "h-return"
    #: EWA coefficient; smaller values weight recent values more (only for "ewa")
    horizon_ewa_disclam: float = 0.1
    #: λ-return coefficient (only for "h-return")
    horizon_h_return: float = 0.95
    #: Discount factor
    discount: float = 0.997

    # ---Training
    #: dynamics model training flag
    enable_train_model: bool = True
    #: critic model training flag
    enable_train_critic: bool = True
    #: actor model training flag
    enable_train_actor: bool = True
    #: batch length
    batch_length: int = 64
    #: <:ref:`scheduler`> dynamics model learning rate
    lr_model: Union[float, SchedulerConfig] = 1e-4
    #: <:ref:`scheduler`> critic model learning rate
    lr_critic: Union[float, SchedulerConfig] = 3e-5
    #: <:ref:`scheduler`>actor model learning rate
    lr_actor: Union[float, SchedulerConfig] = 3e-5
    #: How the actor loss is computed
    #:
    #: Parameters:
    #: "dreamer_v1" : maximize V
    #: "dreamer_v2" : maximize V and entropy
    #: "dreamer_v3" : V2 + percentile-based normalization
    actor_loss_type: str = "dreamer_v3"
    #: Ratio of Reinforce to dynamics backprop when the action is CONTINUOUS
    actor_reinforce_rate: float = 0.0
    #: entropy rate
    entropy_rate: float = 0.0003
    #: baseline
    #:
    #: Parameters:
    #: "v" : -v
    #: other: none
    reinforce_baseline: str = "v"

    # --- other
    #: action ε-greedy(for debug)
    epsilon: float = 0
    #: Reward preprocessing
    #:
    #: Parameters:
    #: "none": none
    #: "tanh": tanh
    clip_rewards: str = "none"

    def __post_init__(self):
        super().__post_init__()

    def get_changeable_parameters(self) -> List[str]:
        """Names of parameters that may be changed after the model is built."""
        return [
            "free_nats",
            "loss_scale_pred",
            "loss_scale_kl_dyn",
            "loss_scale_kl_rep",
            "critic_target_update_interval",
            "critic_target_soft_update",
            "horizon",
            "horizon_policy",
            "critic_estimation_method",
            "horizon_ewa_disclam",
            "horizon_h_return",
            "discount",
            "enable_train_model",
            "enable_train_actor",
            "enable_train_critic",
            "batch_length",
            "batch_size",
            "lr_model",
            "lr_critic",
            "lr_actor",
            "actor_loss_type",
            "entropy_rate",
            "reinforce_baseline",
            "epsilon",
            "clip_rewards",
        ]

    def set_dreamer_v1(self):
        """Overwrite the hyperparameters with a DreamerV1-like preset."""
        # --- RSSM
        self.rssm_deter_size = 200
        self.rssm_stoch_size = 30
        # The RSSM sizes the Gaussian latent as stoch * classes; use one class so
        # the latent really has rssm_stoch_size (=30) dimensions.
        self.rssm_classes = 1
        self.rssm_hidden_units = 400
        self.rssm_use_norm_layer = False
        self.rssm_use_categorical_distribution = False
        self.rssm_activation = "elu"
        self.rssm_unimix = 0
        # --- other model layers
        self.reward_type = "normal_fixed_scale"
        self.reward_layer_sizes = (400, 400)
        self.cont_layer_sizes = (400, 400)
        self.critic_layer_sizes = (400, 400, 400)
        self.actor_layer_sizes = (400, 400, 400, 400)
        self.dense_act = "elu"
        self.use_symlog = False
        # --- encoder/decoder
        self.encoder_decoder_mlp = (400, 400, 400, 400)
        self.encoder_decoder_dist = "normal"
        self.cnn_depth = 32
        self.cnn_blocks = 0
        self.cnn_activation = "relu"
        self.cnn_normalization_type = "none"
        self.cnn_resize_type = "stride"
        self.cnn_resized_image_size = 4
        self.cnn_use_sigmoid = False
        # --- loss params
        self.free_nats = 3.0
        self.loss_scale_pred = 1.0
        self.loss_scale_kl_dyn = 0.5
        self.loss_scale_kl_rep = 0.5
        # --- actor/critic
        self.critic_target_update_interval = 0
        self.critic_type = "normal_fixed_scale"
        self.actor_discrete_unimix = 0
        # --- Behavior
        self.horizon = 15
        self.critic_estimation_method = "ewa"
        self.horizon_ewa_disclam = 0.1
        self.discount = 0.99
        # --- Training
        self.batch_size = 50
        self.batch_length = 50
        self.lr_model = 6e-4
        self.lr_critic = 8e-5
        self.lr_actor = 8e-5
        self.actor_loss_type = "dreamer_v1"

    def set_dreamer_v2(self):
        """Overwrite the hyperparameters with a DreamerV2-like preset."""
        # --- RSSM
        self.rssm_deter_size = 1024
        self.rssm_stoch_size = 32
        self.rssm_classes = 32
        self.rssm_hidden_units = 1024
        self.rssm_use_norm_layer = False
        self.rssm_use_categorical_distribution = True
        self.rssm_activation = "elu"
        self.rssm_unimix = 0
        # --- other model layers
        self.reward_type = "linear"
        self.reward_layer_sizes = (400, 400, 400, 400)
        self.cont_layer_sizes = (400, 400, 400, 400)
        self.critic_layer_sizes = (400, 400, 400, 400)
        self.actor_layer_sizes = (400, 400, 400, 400)
        self.dense_act = "elu"
        self.use_symlog = False
        # --- encoder/decoder
        self.encoder_decoder_mlp = (400, 400, 400, 400)
        self.encoder_decoder_dist = "normal"
        self.cnn_depth = 48
        self.cnn_blocks = 0
        self.cnn_activation = "relu"
        self.cnn_normalization_type = "none"
        self.cnn_resize_type = "stride"
        self.cnn_resized_image_size = 4
        self.cnn_use_sigmoid = False
        # --- loss params
        self.free_nats = 0.0
        self.loss_scale_pred = 1.0
        self.loss_scale_kl_dyn = 0.8
        self.loss_scale_kl_rep = 0.2
        # --- actor/critic
        self.critic_target_update_interval = 100
        self.critic_target_soft_update = 1
        self.critic_type = "normal_fixed_scale"
        self.actor_discrete_unimix = 0
        # --- Behavior
        self.horizon = 15
        self.critic_estimation_method = "h-return"
        self.horizon_h_return = 0.95
        self.discount = 0.99
        # --- Training
        self.batch_size = 16
        self.batch_length = 50
        self.lr_model = 1e-4
        self.lr_critic = 2e-4
        self.lr_actor = 8e-5
        self.actor_loss_type = "dreamer_v2"
        self.entropy_rate = 2e-3
        self.reinforce_baseline = "v"

    def set_dreamer_v3(self):
        """Overwrite the hyperparameters with a DreamerV3-like preset."""
        # --- RSSM
        self.rssm_deter_size = 4096
        self.rssm_stoch_size = 32
        self.rssm_classes = 32
        self.rssm_hidden_units = 1024
        self.rssm_use_norm_layer = True
        self.rssm_use_categorical_distribution = True
        self.rssm_activation = "silu"
        self.rssm_unimix = 0.01
        # --- other model layers
        self.reward_type = "twohot"
        self.reward_twohot_bins = 255
        self.reward_twohot_low = -20
        self.reward_twohot_high = 20
        self.reward_layer_sizes = (1024, 1024, 1024, 1024)
        self.cont_layer_sizes = (1024, 1024, 1024, 1024)
        self.critic_layer_sizes = (1024, 1024, 1024, 1024)
        self.actor_layer_sizes = (1024, 1024, 1024, 1024)
        self.dense_act = "silu"
        self.use_symlog = True
        # --- encoder/decoder
        self.encoder_decoder_mlp = (1024, 1024, 1024, 1024, 1024)
        self.encoder_decoder_dist = "linear"
        self.cnn_depth = 96
        self.cnn_blocks = 0
        self.cnn_activation = "silu"
        self.cnn_normalization_type = "layer"
        self.cnn_resize_type = "stride"
        self.cnn_resized_image_size = 4
        self.cnn_use_sigmoid = False
        # --- loss params
        self.free_nats = 1.0
        self.loss_scale_pred = 1.0
        self.loss_scale_kl_dyn = 0.5
        self.loss_scale_kl_rep = 0.1
        # --- actor/critic
        self.critic_target_update_interval = 0
        self.critic_type = "twohot"
        self.critic_twohot_bins = 255
        self.critic_twohot_low = -20
        self.critic_twohot_high = 20
        self.actor_discrete_unimix = 0.01
        # --- Behavior
        # `horizon` is the number of imagination steps (DreamerV3 uses 15).
        # DreamerV3's 333-step *return* horizon is expressed by discount=0.997
        # (= 1 - 1/333), not by the rollout length.
        self.horizon = 15
        self.critic_estimation_method = "h-return"
        self.horizon_h_return = 0.95
        self.discount = 0.997
        # --- Training
        self.batch_size = 16
        self.batch_length = 64
        self.lr_model = 1e-4
        self.lr_critic = 3e-5
        self.lr_actor = 3e-5
        self.actor_loss_type = "dreamer_v3"
        self.entropy_rate = 3e-4
        self.reinforce_baseline = "v"

    def get_processors(self) -> List[Processor]:
        """Return the observation processors: normalized COLOR images.

        "stride3" downsamples by 3 in its first stage, so it uses a 96x96 input
        (divisible by 3 * 2**k); every other resize type uses 64x64.
        """
        size = (96, 96) if self.cnn_resize_type == "stride3" else (64, 64)
        return [
            ImageProcessor(
                image_type=SpaceTypes.COLOR,
                resize=size,
                enable_norm=True,
            )
        ]

    def get_framework(self) -> str:
        return "tensorflow"

    def get_name(self) -> str:
        return "DreamerV3"

    def assert_params(self) -> None:
        super().assert_params()
        self.assert_params_memory()
        assert self.horizon >= 0
# Register this algorithm with srl's registry at import time: the default
# Config plus "module:ClassName" strings for the Memory, Parameter, Trainer and
# Worker classes of this module (Trainer/Worker are presumably defined further
# down the file, outside this view).
register(
    Config(),
    __name__ + ":Memory",
    __name__ + ":Parameter",
    __name__ + ":Trainer",
    __name__ + ":Worker",
)
# ------------------------------------------------------
# Memory
# ------------------------------------------------------
class Memory(ExperienceReplayBuffer):
    """Replay memory for this algorithm: ExperienceReplayBuffer used as-is."""
# ------------------------------------------------------
# network
# ------------------------------------------------------
class RSSM(keras.Model):
    """Recurrent State-Space Model: the latent dynamics core of the world model.

    The latent state is a deterministic GRU state ``deter`` plus a stochastic
    state ``stoch``. With ``use_categorical_distribution=True`` the stochastic
    state is ``stoch`` categorical variables with ``classes`` classes each,
    flattened to ``stoch * classes``; otherwise it is a normal distribution
    produced by ``NormalDistBlock(stoch * classes, ...)``.

    Args:
        deter: number of GRU units (size of the deterministic state).
        stoch: number of stochastic latent variables.
        classes: classes per categorical variable (also multiplies the normal
            latent size when the categorical distribution is disabled).
        hidden_units: width of the Dense layers around the GRU and of the
            posterior (obs step) network.
        unimix: passed to ``CategoricalDist`` as its uniform-mix parameter.
        activation: activation applied after each hidden Dense layer.
        use_norm_layer: add LayerNormalization after each hidden Dense layer.
        use_categorical_distribution: categorical (True) or normal (False) latent.
    """

    def __init__(
        self,
        deter: int,
        stoch: int,
        classes: int,
        hidden_units: int,
        unimix: float,
        activation: Any,
        use_norm_layer: bool,
        use_categorical_distribution: bool,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.use_categorical_distribution = use_categorical_distribution
        self.stoch_size = stoch
        self.classes = classes
        self.unimix = unimix

        # --- img step (prior): Dense -> GRU -> Dense
        self.img_in_layers = [kl.Dense(hidden_units)]
        if use_norm_layer:
            self.img_in_layers.append(kl.LayerNormalization())
        self.img_in_layers.append(kl.Activation(activation))
        self.gru_cell = kl.GRUCell(deter)
        self.img_out_layers = [kl.Dense(hidden_units)]
        if use_norm_layer:
            self.img_out_layers.append(kl.LayerNormalization())
        self.img_out_layers.append(kl.Activation(activation))

        # --- obs step (posterior)
        self.obs_layers = [kl.Dense(hidden_units)]
        if use_norm_layer:
            self.obs_layers.append(kl.LayerNormalization())
        self.obs_layers.append(kl.Activation(activation))

        self.concat_layer = kl.Concatenate(axis=-1)

        # --- dist heads (zero-initialized kernels for the categorical logits)
        if self.use_categorical_distribution:
            self.img_cat_dist_layers = [
                kl.Dense(stoch * classes, kernel_initializer="zeros"),
                kl.Reshape((stoch, classes)),
            ]
            self.obs_cat_dist_layers = [
                kl.Dense(stoch * classes, kernel_initializer="zeros"),
                kl.Reshape((stoch, classes)),
            ]
        else:
            self.img_norm_dist_block = NormalDistBlock(stoch * classes, (), (), ())
            self.obs_norm_dist_block = NormalDistBlock(stoch * classes, (), (), ())

    def _sample_categorical(self, logits):
        """Sample a flattened one-hot stochastic state from categorical logits.

        Args:
            logits: (batch, stoch, classes) logits.

        Returns:
            ``(stoch, probs)``: ``stoch`` is (batch, stoch * classes) float32 and
            ``probs`` is (batch * stoch, classes).
        """
        batch = logits.shape[0]
        # (batch, stoch, classes) -> (batch * stoch, classes)
        logits = tf.reshape(logits, (batch * self.stoch_size, self.classes))
        dist = CategoricalDist(logits, self.unimix)
        # (batch * stoch, classes) -> (batch, stoch * classes)
        stoch = tf.cast(
            tf.reshape(dist.rsample(), (batch, self.stoch_size * self.classes)),
            tf.float32,
        )
        return stoch, dist.probs()

    def img_step(self, prev_stoch, prev_deter, prev_onehot_action, training: bool = False):
        """Prior step: advance the GRU and predict the next stochastic state.

        Returns:
            ``(deter, prior)`` where ``prior`` holds "stoch" and either "probs"
            (categorical) or "mean"/"stddev" (normal).
        """
        x = tf.concat([prev_stoch, prev_onehot_action], -1)
        for layer in self.img_in_layers:
            x = layer(x, training=training)
        x, deter = self.gru_cell(x, [prev_deter], training=training)
        deter = deter[0]
        for layer in self.img_out_layers:
            x = layer(x, training=training)

        if self.use_categorical_distribution:
            for h in self.img_cat_dist_layers:
                x = h(x)
            stoch, probs = self._sample_categorical(x)
            prior = {"stoch": stoch, "probs": probs}
        else:
            dist = self.img_norm_dist_block(x)
            prior = {
                "stoch": dist.rsample(),
                "mean": dist.mean(),
                "stddev": dist.stddev(),
            }
        return deter, prior

    def obs_step(self, deter, embed, training=False):
        """Posterior step: infer the stochastic state from ``deter`` and an embedding.

        Returns:
            ``post`` with "stoch" and either "probs" or "mean"/"stddev".
        """
        x = tf.concat([deter, embed], -1)
        for layer in self.obs_layers:
            x = layer(x, training=training)

        if self.use_categorical_distribution:
            for h in self.obs_cat_dist_layers:
                x = h(x)
            stoch, probs = self._sample_categorical(x)
            post = {"stoch": stoch, "probs": probs}
        else:
            dist = self.obs_norm_dist_block(x)
            post = {
                "stoch": dist.rsample(),
                "mean": dist.mean(),
                "stddev": dist.stddev(),
            }
        return post

    def get_initial_state(self, batch_size: int = 1):
        """Return zero ``(stoch, deter)`` states for ``batch_size`` sequences."""
        stoch = tf.zeros((batch_size, self.stoch_size * self.classes), dtype=self.dtype)
        # GRUCell.get_initial_state changed signature/return type in TF 2.16 (Keras 3).
        if v216_older:
            deter = self.gru_cell.get_initial_state(None, batch_size, dtype=self.dtype)
        else:
            deter = self.gru_cell.get_initial_state(batch_size)[0]
        return stoch, deter

    @tf.function
    def compute_train_loss(self, embed, actions, stoch, deter, undone, batch_size, batch_length, free_nats):
        """Roll the RSSM over the training sequences and compute the KL losses.

        ``embed`` and ``undone`` arrive flattened as (batch_length * batch_size, ...)
        and are reshaped time-major to (batch_length, batch_size, ...);
        ``actions`` is indexed per step as ``actions[i]``. ``stoch``/``deter`` are
        the starting states; they are zeroed after each step where ``undone`` is 0.

        Free bits are applied to the KL of the whole latent state (summed over
        all stochastic variables/dimensions), as in the Dreamer papers, rather
        than to every variable separately.

        Returns:
            ``(stochs, deters, feats, kl_loss_dyn, kl_loss_rep, stoch, deter)``
            where the first three are flattened to (batch_length * batch_size, ...)
            and the last two are the states after the final step.
        """
        # (seq*batch, shape) -> (seq, batch, shape)
        embed = tf.reshape(embed, (batch_length, batch_size) + embed.shape[1:])
        undone = tf.reshape(undone, (batch_length, batch_size) + undone.shape[1:])

        # --- batch seq step
        stochs = []
        deters = []
        if self.use_categorical_distribution:
            post_probs = []
            prior_probs = []
            for i in range(batch_length):
                deter, prior = self.img_step(stoch, deter, actions[i], training=True)
                post = self.obs_step(deter, embed[i], training=True)
                stoch = post["stoch"]
                stochs.append(stoch)
                deters.append(deter)
                post_probs.append(post["probs"])
                prior_probs.append(prior["probs"])

                # reset the state where the episode ended
                stoch = stoch * undone[i]
                deter = deter * undone[i]

            # (seq, batch * stoch, classes)
            post_probs = tf.stack(post_probs, axis=0)
            prior_probs = tf.stack(prior_probs, axis=0)
            # the KL uses log(probs); clip to avoid log(0) -> inf
            post_probs = tf.clip_by_value(post_probs, 1e-10, 1)
            prior_probs = tf.clip_by_value(prior_probs, 1e-10, 1)
            post_dist = tfd.OneHotCategorical(probs=post_probs)
            prior_dist = tfd.OneHotCategorical(probs=prior_probs)
        else:
            post_mean = []
            post_std = []
            prior_mean = []
            prior_std = []
            for i in range(batch_length):
                deter, prior = self.img_step(stoch, deter, actions[i], training=True)
                post = self.obs_step(deter, embed[i], training=True)
                stoch = post["stoch"]
                stochs.append(stoch)
                deters.append(deter)
                post_mean.append(post["mean"])
                post_std.append(post["stddev"])
                prior_mean.append(prior["mean"])
                prior_std.append(prior["stddev"])

                # reset the state where the episode ended
                stoch = stoch * undone[i]
                deter = deter * undone[i]

            post_mean = tf.stack(post_mean, axis=0)
            post_std = tf.stack(post_std, axis=0)
            prior_mean = tf.stack(prior_mean, axis=0)
            prior_std = tf.stack(prior_std, axis=0)
            post_dist = tfd.Normal(post_mean, post_std)
            prior_dist = tfd.Normal(prior_mean, prior_std)

        stochs = tf.stack(stochs, axis=0)
        deters = tf.stack(deters, axis=0)
        # (seq, batch, shape) -> (seq*batch, shape)
        stochs = tf.reshape(stochs, (batch_length * batch_size,) + stochs.shape[2:])
        deters = tf.reshape(deters, (batch_length * batch_size,) + deters.shape[2:])
        feats = self.concat_layer([stochs, deters])

        # --- KL loss
        kl_loss_dyn = tfd.kl_divergence(tf.stop_gradient(post_dist), prior_dist)
        kl_loss_rep = tfd.kl_divergence(post_dist, tf.stop_gradient(prior_dist))
        # Per-variable KL -> (seq, batch): sum over the latent before free bits.
        # Categorical KL is (seq, batch*stoch) with index b*stoch + s; normal KL
        # is (seq, batch, dims). Both reshape to (seq, batch, -1).
        kl_loss_dyn = tf.reduce_sum(tf.reshape(kl_loss_dyn, (batch_length, batch_size, -1)), axis=-1)
        kl_loss_rep = tf.reduce_sum(tf.reshape(kl_loss_rep, (batch_length, batch_size, -1)), axis=-1)
        kl_loss_dyn = tf.reduce_mean(tf.maximum(kl_loss_dyn, free_nats))
        kl_loss_rep = tf.reduce_mean(tf.maximum(kl_loss_rep, free_nats))

        return stochs, deters, feats, kl_loss_dyn, kl_loss_rep, stoch, deter

    def build_call(self, config: Config, embed_size: int):
        """Run one dummy prior+posterior step so all layers get built.

        Raises:
            UndefinedError: if the action space is neither DiscreteSpace nor
                ArrayContinuousSpace.
        """
        self._embed_size = embed_size
        in_stoch, in_deter = self.get_initial_state()
        if isinstance(config.action_space, DiscreteSpace):
            n = config.action_space.n
        elif isinstance(config.action_space, ArrayContinuousSpace):
            n = config.action_space.size
        else:
            # previously fell through and failed with UnboundLocalError on `n`
            raise UndefinedError(config.action_space)
        in_onehot_action = np.zeros((1, n), dtype=np.float32)
        in_embed = np.zeros((1, embed_size), dtype=np.float32)
        deter, prior = self.img_step(in_stoch, in_deter, in_onehot_action)
        post = self.obs_step(deter, in_embed)
        return self.concat_layer([post["stoch"], deter])
class ImageEncoder(keras.Model):
    """Convolutional image encoder that downsamples to ``resized_image_size``.

    Each stage downsamples the feature map (by 2, or by 3 in the first
    "stride3" stage) and the channel depth doubles after every stage; optional
    residual blocks follow each stage. The output is the flattened final
    feature map (``out_size`` features).

    Args:
        img_shape: image shape; the spatial sizes are read from
            ``img_shape[-3]`` and ``img_shape[-2]`` (channels-last assumed).
        depth: filters of the first stage.
        res_blocks: residual blocks per stage.
        activation: activation after each conv / inside residual blocks.
        normalization_type: "none" or "layer".
        resize_type: "stride", "stride3", "mean" or "max".
        resized_image_size: target spatial size after downsampling.
    """

    def __init__(
        self,
        img_shape: tuple,
        depth: int,
        res_blocks: int,
        activation,
        normalization_type: str,
        resize_type: str,
        resized_image_size: int,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert normalization_type in ["none", "layer"]
        self._in_shape = img_shape
        self.img_shape = img_shape

        _size = int(np.log2(min(img_shape[-3], img_shape[-2])))
        _resize = int(np.log2(resized_image_size))
        assert _size > _resize
        self.stages = _size - _resize

        # Total downsampling factor: "stride3" uses 3 in its first stage and 2
        # afterwards; every other supported type uses 2 per stage.
        if resize_type == "stride3":
            _divisor = (2 ** (self.stages - 1)) * 3
        elif resize_type in ("stride", "mean", "max"):
            _divisor = 2**self.stages
        else:
            raise NotImplementedError(resize_type)
        assert img_shape[-2] % _divisor == 0
        assert img_shape[-3] % _divisor == 0

        _conv_kw: dict = dict(
            padding="same",
            kernel_initializer=tf.initializers.TruncatedNormal(),
            bias_initializer="zero",
        )
        # A conv followed by LayerNormalization does not need its own bias.
        use_bias = normalization_type == "none"

        self.blocks = []
        for i in range(self.stages):
            # --- downsampling conv
            if resize_type == "stride":
                cnn_layers = [kl.Conv2D(depth, 4, 2, use_bias=use_bias, **_conv_kw)]
            elif resize_type == "stride3":
                s = 2 if i else 3
                k = 5 if i else 4
                cnn_layers = [kl.Conv2D(depth, k, s, use_bias=use_bias, **_conv_kw)]
            elif resize_type == "mean":
                cnn_layers = [
                    kl.Conv2D(depth, 3, 1, use_bias=use_bias, **_conv_kw),
                    kl.AveragePooling2D((3, 3), (2, 2), padding="same"),
                ]
            elif resize_type == "max":
                cnn_layers = [
                    kl.Conv2D(depth, 3, 1, use_bias=use_bias, **_conv_kw),
                    kl.MaxPooling2D((3, 3), (2, 2), padding="same"),
                ]
            else:
                raise NotImplementedError(resize_type)
            if normalization_type == "layer":
                cnn_layers.append(kl.LayerNormalization())
            cnn_layers.append(kl.Activation(activation))

            # --- residual blocks (pre-activation)
            res_blocks_layers = []
            for _ in range(res_blocks):
                res_layers = []
                if normalization_type == "layer":
                    res_layers.append(kl.LayerNormalization())
                res_layers.append(kl.Activation(activation))
                res_layers.append(kl.Conv2D(depth, 3, 1, use_bias=True, **_conv_kw))
                if normalization_type == "layer":
                    res_layers.append(kl.LayerNormalization())
                res_layers.append(kl.Activation(activation))
                res_layers.append(kl.Conv2D(depth, 3, 1, use_bias=True, **_conv_kw))
                res_blocks_layers.append(res_layers)

            self.blocks.append([cnn_layers, res_blocks_layers])
            depth *= 2

        self.out_layers = []
        if res_blocks > 0:
            self.out_layers.append(kl.Activation(activation))
        self.out_layers.append(kl.Flatten())

        # Build once with a dummy image to record the output sizes.
        dummy, img_shape = self._call(np.zeros((1,) + img_shape), return_size=True)
        self.resized_img_shape = img_shape[1:]
        self.out_size = dummy.shape[1]

    @tf.function
    def call(self, x, training=False):
        return self._call(x, training=training)

    def _call(self, x, training=False, return_size=False):
        """Encode ``x``; with ``return_size`` also return the pre-flatten shape."""
        # center inputs around 0 (inputs are presumably normalized to [0, 1])
        x = x - 0.5
        for cnn_layers, res_blocks_layers in self.blocks:
            for h in cnn_layers:
                x = h(x, training=training)
            for res_layers in res_blocks_layers:
                skip = x
                for h in res_layers:
                    x = h(x, training=training)
                x += skip

        x_out = x
        for h in self.out_layers:
            x_out = h(x_out, training=training)

        if return_size:
            return x_out, x.shape
        return x_out
class ImageDecoder(keras.Model):
    """Image decoder mirroring an ``ImageEncoder``.

    A Dense layer is reshaped to the encoder's final feature-map shape. Each
    stage then applies optional residual blocks and upsamples while halving
    the channel depth; the last stage outputs ``img_shape[-1]`` channels. The
    result is shifted by +0.5 (or passed through a sigmoid) and fed to the
    output distribution head (``LinearBlock`` or ``NormalDistBlock``).

    Args:
        encoder: the paired encoder (provides stages and shapes).
        use_sigmoid: apply a sigmoid instead of adding 0.5 to the output.
        depth: first-stage depth of the encoder (``cnn_depth``).
        res_blocks: residual blocks per stage.
        activation: activation for hidden stages / residual blocks.
        normalization_type: "none" or "layer".
        resize_type: "stride", "stride3", "mean" or "max".
        dist_type: "linear" or "normal".

    Raises:
        NotImplementedError: unknown ``resize_type``.
        UndefinedError: unknown ``dist_type``.
    """

    def __init__(
        self,
        encoder: ImageEncoder,
        use_sigmoid: bool,
        depth: int,
        res_blocks: int,
        activation,
        normalization_type: str,
        resize_type: str,
        dist_type: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.use_sigmoid = use_sigmoid
        self.dist_type = dist_type

        stages = encoder.stages
        depth = depth * 2 ** (stages - 1)
        img_shape = encoder.img_shape
        resized_img_shape = encoder.resized_img_shape

        # --- in layers
        self.in_layer = kl.Dense(resized_img_shape[0] * resized_img_shape[1] * resized_img_shape[2])
        self.reshape_layer = kl.Reshape([resized_img_shape[0], resized_img_shape[1], resized_img_shape[2]])

        # --- conv layers
        _conv_kw: dict = dict(
            kernel_initializer=tf.initializers.TruncatedNormal(),
            bias_initializer="zero",
        )
        self.blocks = []
        for i in range(stages):
            is_last = i == stages - 1

            # --- residual blocks (pre-activation)
            res_blocks_layers = []
            for _ in range(res_blocks):
                res_layers = []
                if normalization_type == "layer":
                    res_layers.append(kl.LayerNormalization())
                res_layers.append(kl.Activation(activation))
                res_layers.append(kl.Conv2D(depth, 3, 1, padding="same", **_conv_kw))
                if normalization_type == "layer":
                    res_layers.append(kl.LayerNormalization())
                res_layers.append(kl.Activation(activation))
                res_layers.append(kl.Conv2D(depth, 3, 1, padding="same", **_conv_kw))
                res_blocks_layers.append(res_layers)

            depth = img_shape[-1] if is_last else depth // 2

            # --- upsampling conv
            # The last stage emits the image channels directly. It gets a bias and
            # no normalization/activation: kl.LayerNormalization normalizes over the
            # last axis only (the few colour channels of each pixel), which would
            # discard each pixel's channel mean and bottleneck the reconstruction.
            use_bias = is_last or normalization_type == "none"
            if resize_type == "stride":
                cnn_layers = [kl.Conv2DTranspose(depth, 4, 2, use_bias=use_bias, padding="same", **_conv_kw)]
            elif resize_type == "stride3":
                # mirrors the encoder, whose first stage downsamples by 3
                s = 3 if is_last else 2
                k = 5 if is_last else 4
                cnn_layers = [kl.Conv2DTranspose(depth, k, s, use_bias=use_bias, padding="same", **_conv_kw)]
            elif resize_type in ("max", "mean"):
                # pooling encoders are mirrored with nearest upsampling + conv
                cnn_layers = [
                    kl.UpSampling2D((2, 2)),
                    kl.Conv2D(depth, 3, 1, use_bias=use_bias, padding="same", **_conv_kw),
                ]
            else:
                raise NotImplementedError(resize_type)
            if not is_last:
                if normalization_type == "layer":
                    cnn_layers.append(kl.LayerNormalization())
                cnn_layers.append(kl.Activation(activation))

            self.blocks.append([res_blocks_layers, cnn_layers])

        if dist_type == "linear":
            self.out_dist = LinearBlock(depth)
        elif dist_type == "normal":
            self.out_dist = NormalDistBlock(depth)
        else:
            raise UndefinedError(dist_type)

    def call(self, x):
        x = self.in_layer(x)
        x = self.reshape_layer(x)
        for res_blocks_layers, cnn_layers in self.blocks:
            for res_layers in res_blocks_layers:
                skip = x
                for h in res_layers:
                    x = h(x)
                x += cast(Any, skip)
            for h in cnn_layers:
                x = h(x)

        if self.use_sigmoid:
            x = tf.nn.sigmoid(x)
        else:
            x = cast(Any, x) + 0.5
        return self.out_dist(x)

    @tf.function
    def compute_train_loss(self, feat, state):
        """Reconstruction loss: MSE ("linear") or negative log-likelihood ("normal")."""
        dist = self(feat)
        if self.dist_type == "linear":
            return tf.reduce_mean(tf.square(state - dist.y))
        elif self.dist_type == "normal":
            return -tf.reduce_mean(dist.log_prob(state))
        else:
            raise UndefinedError(self.dist_type)
class LinearEncoder(keras.Model):
    """MLP encoder for non-image observations.

    A stack of Dense layers sharing one activation; ``out_size`` is the width
    of the last layer.
    """

    def __init__(
        self,
        hidden_layer_sizes: Tuple[int, ...],
        activation: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_layers = [kl.Dense(units, activation=activation) for units in hidden_layer_sizes]
        self.out_size: int = hidden_layer_sizes[-1]

    @tf.function
    def call(self, x):
        out = x
        for dense in self.hidden_layers:
            out = dense(out)
        return out
# ------------------------------------------------------
# Parameter
# ------------------------------------------------------
class Parameter(RLParameter):
def __init__(self, *args):
super().__init__(*args)
self.config: Config = self.config
# --- encode/decode
if SpaceTypes.is_image(self.config.observation_space.stype):
self.encode = ImageEncoder(
self.config.observation_space.shape,
self.config.cnn_depth,
self.config.cnn_blocks,
self.config.cnn_activation,
self.config.cnn_normalization_type,
self.config.cnn_resize_type,
self.config.cnn_resized_image_size,
name="ImageEncoder",
)
self.decode = ImageDecoder(
self.encode,
self.config.cnn_use_sigmoid,
self.config.cnn_depth,
self.config.cnn_blocks,
self.config.cnn_activation,
self.config.cnn_normalization_type,
self.config.cnn_resize_type,
self.config.encoder_decoder_dist,
name="ImageDecoder",
)
logger.info(f"Encoder/Decoder: Image({self.config.encoder_decoder_dist})")
else:
self.encode = LinearEncoder(
self.config.encoder_decoder_mlp,
self.config.dense_act,
name="LinearEncoder",
)
if self.config.encoder_decoder_dist == "linear":
self.decode = LinearBlock(
self.config.observation_space.shape[-1],
list(reversed(self.config.encoder_decoder_mlp)),