1
- import random
2
- import numpy as np
3
- import math
4
- import itertools
5
-
6
- _WEIGHTS_MODES = {
7
- 'abs' : math .fabs ,
8
- 'reward' : lambda x : x ,
9
- 'same' : lambda _ : 1
10
- }
11
-
12
- class CebLinear :
13
- def __init__ (self , maxSize , sampleWeight = 'samp ' ):
14
- self .maxSize = maxSize
15
- self ._sizeLimit = math .floor (maxSize * 1.1 )
16
- self ._samples = []
17
- self ._sampleWeight = _WEIGHTS_MODES .get (sampleWeight , sampleWeight )
18
-
19
- def addEpisode (self , replay , terminated ):
20
- if 1 < len (replay ):
21
- for step in replay [:- 1 ]:
22
- self ._samples .append ((* step , 1 ))
23
- self ._samples .append ((* replay [- 1 ], - 1 if terminated else 0 ))
24
-
25
- self .update ()
26
- return
27
-
28
- def update (self ):
29
- if self ._sizeLimit < len (self ._samples ):
30
- self ._samples = self ._samples [- self .maxSize :]
31
- return
32
-
33
- def __len__ (self ):
34
- return len (self ._samples )
35
-
36
- def _fixRewardMultiplier (self , x ):
37
- if np .isscalar (x ):
38
- return abs (x )
39
-
40
- if isinstance (x , (np .ndarray , np .generic )):
41
- return np .abs (x )
42
-
43
- raise Exception ('Unknown reward type. (%s)' % type (x ))
44
-
45
- def _createBatch (self , batch_size , sampler ):
46
- samplesLeft = batch_size
47
- cumweights = list (itertools .accumulate (self ._sampleWeight (x [2 ]) for x in self ._samples ))
48
- indexRange = np .arange (len (self ._samples ))
49
- res = []
50
- while 0 < samplesLeft :
51
- indexes = set (random .choices (
52
- indexRange , cum_weights = cumweights ,
53
- k = min ((samplesLeft , len (self ._samples )))
54
- ))
55
-
56
- for i in indexes :
57
- sample = sampler (i )
58
- if sample :
59
- while len (res ) < len (sample ): res .append ([])
60
- for i , value in enumerate (sample [:- 1 ]):
61
- res [i ].append (value )
62
- res [- 1 ].append (self ._fixRewardMultiplier (sample [- 1 ]))
63
- samplesLeft -= 1
64
-
65
- return [np .array (values ) for values in res ]
66
-
67
- def sampleBatch (self , batch_size ):
68
- return self ._createBatch (batch_size , lambda i : self ._samples [i ])
69
-
70
- def sampleSequenceBatch (self , batch_size , sequenceLen , ** kwargs ):
71
- def sampler (ind ):
72
- sample = self ._samples [ind :ind + sequenceLen ]
73
- if not (sequenceLen == len (sample )): return None
74
- if 1 < sequenceLen :
75
- if any (x [- 1 ] < 1 for x in sample [:- 1 ]):
76
- return None
77
-
78
- transposed = [
79
- np .array ([x [col ] for x in sample ]) for col in range (len (sample [0 ]))
80
- ]
81
- return transposed
82
-
1
+ import random
2
+ import numpy as np
3
+ import math
4
+ import itertools
5
+
6
+ _WEIGHTS_MODES = {
7
+ 'abs' : math .fabs ,
8
+ 'reward' : lambda x : x ,
9
+ 'same' : lambda _ : 1
10
+ }
11
+
12
class CebLinear:
    """List-backed experience buffer with weighted random sampling.

    Every stored sample is the caller's step tuple with one extra trailing
    flag appended by ``addEpisode``:

    * ``1``  - a step that is not the last one of its episode,
    * ``0``  - the last step of an episode that was not terminated,
    * ``-1`` - the last step of a terminated episode.

    Sampling weights are computed from element ``[2]`` of each stored sample
    (presumably the reward - confirm against the callers), so step tuples
    must have at least three elements.
    """

    def __init__(self, maxSize, sampleWeight='same'):
        """Create an empty buffer.

        Args:
            maxSize: Number of most recent samples kept after a trim.
            sampleWeight: Either the name of an entry in ``_WEIGHTS_MODES``
                or a callable mapping ``sample[2]`` to a sampling weight.

        Raises:
            ValueError: If ``sampleWeight`` is neither a known mode name nor
                a callable.
        """
        self.maxSize = maxSize
        # The buffer may overgrow maxSize by ~10% before it is trimmed, so
        # the list is not re-sliced on every single episode.
        self._sizeLimit = math.floor(maxSize * 1.1)
        self._samples = []

        # Tolerate stray whitespace around a mode name (e.g. 'same ').
        mode = sampleWeight.strip() if isinstance(sampleWeight, str) else sampleWeight
        weigher = _WEIGHTS_MODES.get(mode, mode)
        if not callable(weigher):
            # Fail here instead of with an opaque TypeError during sampling.
            raise ValueError('Unknown sample weight mode. (%r)' % (sampleWeight,))
        self._sampleWeight = weigher

    def addEpisode(self, replay, terminated):
        """Append all steps of one episode, then trim the buffer if needed.

        Args:
            replay: Sequence of step tuples, in episode order.
            terminated: Whether the episode ended by termination; decides
                the flag of the last step (``-1`` if true, otherwise ``0``).
        """
        if not replay:
            return  # nothing to store; avoids IndexError on replay[-1]

        for step in replay[:-1]:
            self._samples.append((*step, 1))
        self._samples.append((*replay[-1], -1 if terminated else 0))

        self.update()
        return

    def update(self):
        """Keep only the newest ``maxSize`` samples once the limit is exceeded."""
        overflow = len(self._samples) - self.maxSize
        if self._sizeLimit < len(self._samples):
            # Slice from the front offset rather than with [-maxSize:],
            # because [-0:] would keep the whole list when maxSize is 0.
            self._samples = self._samples[overflow:]
        return

    def __len__(self):
        """Return the number of stored samples."""
        return len(self._samples)

    def _fixRewardMultiplier(self, x):
        """Return the absolute value of a flag (scalar or NumPy array).

        This turns the ``-1`` "terminated" marker into ``1`` while ``0`` and
        ``1`` are unchanged.

        Raises:
            Exception: If ``x`` is neither a scalar nor a NumPy value.
        """
        if np.isscalar(x):
            return abs(x)

        if isinstance(x, (np.ndarray, np.generic)):
            return np.abs(x)

        raise Exception('Unknown reward type. (%s)' % type(x))

    def _createBatch(self, batch_size, sampler):
        """Draw weighted random indexes until ``batch_size`` samples are accepted.

        Args:
            batch_size: Number of samples to collect.
            sampler: Callable ``index -> sample``; a falsy result (e.g.
                ``None``) rejects that index and another one is drawn.

        Returns:
            A list with one NumPy array per sample column; the last column
            holds the flags passed through ``_fixRewardMultiplier``. The
            list is empty when ``batch_size`` is not positive.

        Raises:
            IndexError: If samples are requested from an empty buffer.

        NOTE: the loop only ends once enough samples were accepted, so a
        sampler that rejects every index never returns.
        """
        samplesLeft = batch_size
        if 0 < samplesLeft and not self._samples:
            # random.choices would raise IndexError on its own here; keep
            # the exception type but make the reason explicit.
            raise IndexError('Cannot sample from an empty buffer.')

        cumweights = list(
            itertools.accumulate(self._sampleWeight(x[2]) for x in self._samples)
        )
        indexRange = np.arange(len(self._samples))
        res = []
        while 0 < samplesLeft:
            # A set drops duplicate draws, so one round may yield fewer than
            # k indexes; the enclosing loop draws again for the remainder.
            indexes = set(random.choices(
                indexRange, cum_weights=cumweights,
                k=min((samplesLeft, len(self._samples)))
            ))

            for index in indexes:
                sample = sampler(index)
                if not sample:
                    continue

                while len(res) < len(sample):
                    res.append([])
                for column, value in enumerate(sample[:-1]):
                    res[column].append(value)
                res[-1].append(self._fixRewardMultiplier(sample[-1]))
                samplesLeft -= 1

        return [np.array(values) for values in res]

    def sampleBatch(self, batch_size):
        """Return ``batch_size`` randomly chosen single samples, column-wise."""
        return self._createBatch(batch_size, lambda i: self._samples[i])

    def sampleSequenceBatch(self, batch_size, sequenceLen, **kwargs):
        """Return ``batch_size`` runs of ``sequenceLen`` consecutive samples.

        A run is accepted only if it has the full length and none of its
        steps except the last one is an episode end (flag below ``1``), so a
        run never crosses an episode boundary.

        Args:
            batch_size: Number of sequences to collect.
            sequenceLen: Number of consecutive samples per sequence.
            **kwargs: Accepted and ignored.

        Returns:
            A list with one NumPy array per sample column, each holding
            ``batch_size`` rows of ``sequenceLen`` values.

        Raises:
            ValueError: If the buffer holds fewer than ``sequenceLen``
                samples, in which case no sequence could ever be accepted.
            IndexError: If sequences are requested from an empty buffer.
        """
        if 0 < batch_size and 0 < len(self._samples) < sequenceLen:
            # Every index would be rejected and the retry loop in
            # _createBatch would never finish.
            raise ValueError(
                'Buffer holds %d samples, fewer than the sequence length %d.'
                % (len(self._samples), sequenceLen)
            )

        def sampler(ind):
            sample = self._samples[ind:ind + sequenceLen]
            if len(sample) != sequenceLen:
                return None  # run is cut off by the end of the buffer
            if any(x[-1] < 1 for x in sample[:-1]):
                return None  # an episode ends inside the run

            # Transpose: one array per column instead of one tuple per step.
            return [
                np.array([x[col] for x in sample]) for col in range(len(sample[0]))
            ]

        return self._createBatch(batch_size, sampler)
0 commit comments