@@ -262,27 +262,6 @@ def __init__(self):
262262 self .drawn_vars = dict ()
263263
264264
class _DrawValuesContextDetacher(_DrawValuesContext,
                                 metaclass=InitContextMeta):
    """Detached child drawn-variables context.

    Context manager that opens a new drawn-variables context seeded with a
    copy of the parent context's ``drawn_vars`` dict. Changes made inside
    this context never leak back into the parent automatically, which makes
    it possible to call the same random method repeatedly and obtain
    different results while still honouring variables drawn in enclosing
    contexts. Call :meth:`update_parent` to push results upward explicitly.
    """

    def __new__(cls, *args, **kwargs):
        # Swallow constructor arguments here so any argument-sensitive
        # __new__ further up the MRO is bypassed.
        return super().__new__(cls)

    def __init__(self):
        # Detach from the parent: replace the inherited drawn_vars mapping
        # with an independent copy (presumably provided by the metaclass /
        # base context — TODO confirm) so local draws stay local.
        self.drawn_vars = dict(self.drawn_vars)

    def update_parent(self):
        """Merge this context's drawn_vars into the parent context, if any."""
        enclosing = self.parent
        if enclosing is not None:
            enclosing.drawn_vars.update(self.drawn_vars)
285-
286265def is_fast_drawable (var ):
287266 return isinstance (var , (numbers .Number ,
288267 np .ndarray ,
@@ -660,20 +639,30 @@ def generate_samples(generator, *args, **kwargs):
660639 samples = generator (size = broadcast_shape , * args , ** kwargs )
661640 elif dist_shape == broadcast_shape :
662641 samples = generator (size = size_tup + dist_shape , * args , ** kwargs )
663- elif len (dist_shape ) == 0 and size_tup and broadcast_shape [:len (size_tup )] == size_tup :
664- # Input's dist_shape is scalar, but it has size repetitions.
665- # So now the size matches but we have to manually broadcast to
666- # the right dist_shape
667- samples = [generator (* args , ** kwargs )]
668- if samples [0 ].shape == broadcast_shape :
669- samples = samples [0 ]
642+ elif len (dist_shape ) == 0 and size_tup and broadcast_shape :
643+ # There is no dist_shape (scalar distribution) but the parameters
644+ # broadcast shape and size_tup determine the size to provide to
645+ # the generator
646+ if broadcast_shape [:len (size_tup )] == size_tup :
647+ # Input's dist_shape is scalar, but it has size repetitions.
648+ # So now the size matches but we have to manually broadcast to
649+ # the right dist_shape
650+ samples = [generator (* args , ** kwargs )]
651+ if samples [0 ].shape == broadcast_shape :
652+ samples = samples [0 ]
653+ else :
654+ suffix = broadcast_shape [len (size_tup ):] + dist_shape
655+ samples .extend ([generator (* args , ** kwargs ).
656+ reshape (broadcast_shape )[..., np .newaxis ]
657+ for _ in range (np .prod (suffix ,
658+ dtype = int ) - 1 )])
659+ samples = np .hstack (samples ).reshape (size_tup + suffix )
670660 else :
671- suffix = broadcast_shape [len (size_tup ):] + dist_shape
672- samples .extend ([generator (* args , ** kwargs ).
673- reshape (broadcast_shape )[..., np .newaxis ]
674- for _ in range (np .prod (suffix ,
675- dtype = int ) - 1 )])
676- samples = np .hstack (samples ).reshape (size_tup + suffix )
661+ # The parameter shape is given, but we have to concatenate it
662+ # with the size tuple
663+ samples = generator (size = size_tup + broadcast_shape ,
664+ * args ,
665+ ** kwargs )
677666 else :
678667 samples = None
679668 # Args have been broadcast correctly, can just ask for the right shape out
0 commit comments