@@ -91,6 +91,9 @@ def __init__(self, spaces: Tuple[str, ...], ndims: int, nelems: int, npoints: in
91
91
The number of points.
92
92
'''
93
93
94
+ for i , space in enumerate (spaces ):
95
+ if space in spaces [i + 1 :]:
96
+ raise ValueError (f'All spaces in a `Sample` must be unique, but space { space } is repeated.' )
94
97
self .spaces = spaces
95
98
self .ndims = ndims
96
99
self .nelems = nelems
@@ -479,6 +482,21 @@ def zip(*samples: 'Sample') -> 'Sample':
479
482
480
483
return _Zip (* samples )
481
484
485
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> 'Sample' :
486
+ '''Return a :class:`Sample` with spaces renamed according to ``map``.
487
+
488
+ Args
489
+ ----
490
+ map : mapping of :class:`str` to :class:`str`
491
+ A mapping of old space to new space.
492
+
493
+ Returns
494
+ -------
495
+ renamed : :class:`Sample`
496
+ '''
497
+
498
+ raise NotImplementedError
499
+
482
500
483
501
class _TransformChainsSample (Sample ):
484
502
@@ -554,6 +572,9 @@ def get_evaluable_indices(self, ielem: evaluable.Array) -> evaluable.Array:
554
572
def _bind (self , func : function .Array ) -> function .Array :
555
573
return _ConcatenatePoints (func , self )
556
574
575
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
576
+ return _DefaultIndex (map .get (self .space , self .space ), self .transforms , self .points )
577
+
557
578
558
579
class _CustomIndex (_TransformChainsSample ):
559
580
@@ -578,6 +599,9 @@ def tri(self) -> numpy.ndarray:
578
599
def hull (self ) -> numpy .ndarray :
579
600
return numpy .take (self ._index , self ._parent .hull )
580
601
602
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
603
+ return _CustomIndex (self ._parent .rename_spaces (map ), self ._index )
604
+
581
605
582
606
if os .environ .get ('NUTILS_TENSORIAL' , None ) == 'test' : # pragma: nocover
583
607
@@ -627,7 +651,7 @@ def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
627
651
return evaluable .Zeros ((evaluable .constant (0 ),) * len (self .spaces ), dtype = float )
628
652
629
653
def get_lower_args (self , __ielem : evaluable .Array ) -> function .LowerArgs :
630
- return function .LowerArgs ((), {}, {} )
654
+ return function .LowerArgs . empty ( )
631
655
632
656
def get_element_tri (self , ielem : int ) -> numpy .ndarray :
633
657
raise IndexError ('index out of range' )
@@ -647,6 +671,9 @@ def _bind(self, func: function.Array) -> function.Array:
647
671
def basis (self , interpolation : str = 'none' ) -> function .Array :
648
672
return function .zeros ((0 ,), float )
649
673
674
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
675
+ return _Empty (tuple (map .get (space , space ) for space in self .spaces ), self .ndims )
676
+
650
677
651
678
class _Add (_TensorialSample ):
652
679
@@ -694,6 +721,9 @@ def _integral(self, func: function.Array) -> function.Array:
694
721
def _bind (self , func : function .Array ) -> function .Array :
695
722
return numpy .concatenate ([self ._sample1 ._bind (func ), self ._sample2 ._bind (func )])
696
723
724
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
725
+ return _Add (self ._sample1 .rename_spaces (map ), self ._sample2 .rename_spaces (map ))
726
+
697
727
698
728
def _simplex_strip (strip ):
699
729
# Helper function that creates simplices for an extruded simplex, with
@@ -807,7 +837,7 @@ def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
807
837
808
838
def get_lower_args (self , __ielem : evaluable .Array ) -> function .LowerArgs :
809
839
ielem1 , ielem2 = evaluable .divmod (__ielem , self ._sample2 .nelems )
810
- return self ._sample1 .get_lower_args (ielem1 ) | self ._sample2 .get_lower_args (ielem2 )
840
+ return self ._sample1 .get_lower_args (ielem1 ) * self ._sample2 .get_lower_args (ielem2 )
811
841
812
842
@property
813
843
def _reversed_factors (self ):
@@ -900,6 +930,9 @@ def basis(self, interpolation: str = 'none') -> Sample:
900
930
assert basis1 .ndim == basis2 .ndim == 1
901
931
return numpy .ravel (basis1 [:, None ] * basis2 [None , :])
902
932
933
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
934
+ return _Mul (self ._sample1 .rename_spaces (map ), self ._sample2 .rename_spaces (map ))
935
+
903
936
904
937
class _Zip (Sample ):
905
938
@@ -940,15 +973,16 @@ def _getslice(self, ielem):
940
973
941
974
def get_lower_args (self , __ielem : evaluable .Array ) -> function .LowerArgs :
942
975
points_shape = evaluable .Take (evaluable .Constant (self ._sizes ), __ielem ),
943
- coordinates = {}
944
- transform_chains = {}
976
+ args = function .LowerArgs .empty (points_shape )
945
977
for samplei , ielemsi , ilocalsi in zip (self ._samples , self ._ielems , self ._ilocals ):
946
- argsi = samplei .get_lower_args (evaluable .Take (evaluable .Constant (ielemsi ), __ielem ))
947
978
slicei = evaluable .Take (evaluable .Constant (ilocalsi ), self ._getslice (__ielem ))
948
- transform_chains .update (argsi .transform_chains )
949
- for space , coords in argsi .coordinates .items ():
950
- coordinates [space ] = evaluable .Transpose .to_end (evaluable .Take (evaluable ._flat (evaluable .Transpose .from_end (coords , 0 ), ndim = 2 ), slicei ), 0 )
951
- return function .LowerArgs (points_shape , transform_chains , coordinates )
979
+ args += samplei \
980
+ .get_lower_args (evaluable .Take (evaluable .Constant (ielemsi ), __ielem )) \
981
+ .map_coordinates (
982
+ points_shape ,
983
+ lambda coords : evaluable .Transpose .to_end (evaluable .Take (evaluable ._flat (evaluable .Transpose .from_end (coords , 0 ), ndim = 2 ), slicei ), 0 ),
984
+ )
985
+ return args
952
986
953
987
def get_evaluable_indices (self , ielem ):
954
988
return evaluable .Take (evaluable .Constant (self ._indices ), self ._getslice (ielem ))
@@ -959,6 +993,9 @@ def get_evaluable_weights(self, ielem):
959
993
weights = self ._samples [0 ].get_evaluable_weights (ielem0 )
960
994
return evaluable ._take (evaluable ._flat (weights ), slice0 , axis = 0 )
961
995
996
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
997
+ return _Zip (* [smpl .rename_spaces (map ) for smpl in self ._samples ])
998
+
962
999
963
1000
class _TakeElements (_TensorialSample ):
964
1001
@@ -1014,6 +1051,9 @@ def get_element_hull(self, __ielem: int) -> numpy.ndarray:
1014
1051
def take_elements (self , __indices : numpy .ndarray ) -> Sample :
1015
1052
return self ._parent .take_elements (numpy .take (self ._indices , __indices ))
1016
1053
1054
+ def rename_spaces (self , map : Mapping [str , str ], / ) -> Sample :
1055
+ return _TakeElements (self ._parent .rename_spaces (map ), self ._indices )
1056
+
1017
1057
1018
1058
class _Integral (function .Array ):
1019
1059
@@ -1023,9 +1063,9 @@ def __init__(self, integrand: function.Array, sample: Sample) -> None:
1023
1063
super ().__init__ (shape = integrand .shape , dtype = float if integrand .dtype in (bool , int ) else integrand .dtype , spaces = integrand .spaces - frozenset (sample .spaces ), arguments = integrand .arguments )
1024
1064
1025
1065
def lower (self , args : function .LowerArgs ) -> evaluable .Array :
1026
- ielem = evaluable .loop_index ('_sample_' + '_' . join ( self . _sample . spaces ) , self ._sample .nelems )
1066
+ ielem = evaluable .loop_index (f '_sample_{ len ( args . args ) } ' , self ._sample .nelems )
1027
1067
weights = evaluable .astype (self ._sample .get_evaluable_weights (ielem ), self .dtype )
1028
- integrand = evaluable .astype (self ._integrand .lower (args | self ._sample .get_lower_args (ielem )), self .dtype )
1068
+ integrand = evaluable .astype (self ._integrand .lower (args * self ._sample .get_lower_args (ielem )), self .dtype )
1029
1069
elem_integral = evaluable .einsum ('B,ABC->AC' , weights , integrand , B = weights .ndim , C = self .ndim )
1030
1070
return evaluable .loop_sum (elem_integral , ielem )
1031
1071
@@ -1039,8 +1079,8 @@ def __init__(self, func: function.Array, sample: _TransformChainsSample) -> None
1039
1079
1040
1080
def lower (self , args : function .LowerArgs ) -> evaluable .Array :
1041
1081
axis = len (args .points_shape )
1042
- ielem = evaluable .loop_index ('_sample_' + '_' . join ( self . _sample . spaces ) , self ._sample .nelems )
1043
- args | = self ._sample .get_lower_args (ielem )
1082
+ ielem = evaluable .loop_index (f '_sample_{ len ( args . args ) } ' , self ._sample .nelems )
1083
+ args * = self ._sample .get_lower_args (ielem )
1044
1084
func = self ._func .lower (args )
1045
1085
func = evaluable .Transpose .to_end (func , * range (axis , len (args .points_shape )))
1046
1086
for i in range (len (args .points_shape ) - axis - 1 ):
@@ -1073,7 +1113,8 @@ def __init__(self, sample: _TransformChainsSample, interpolation: str) -> None:
1073
1113
super ().__init__ (shape = (sample .npoints ,), dtype = float , spaces = frozenset ({sample .space }), arguments = {})
1074
1114
1075
1115
def lower (self , args : function .LowerArgs ) -> evaluable .Array :
1076
- aligned_space_coords = args .coordinates [self ._sample .space ]
1116
+ arg = args [self ._sample .space ]
1117
+ aligned_space_coords = arg .coordinates
1077
1118
assert aligned_space_coords .ndim == len (args .points_shape ) + 1
1078
1119
space_coords , where = evaluable .unalign (aligned_space_coords )
1079
1120
# Reinsert the coordinate axis, the last axis of `aligned_space_coords`, or
@@ -1085,9 +1126,8 @@ def lower(self, args: function.LowerArgs) -> evaluable.Array:
1085
1126
space_coords = evaluable .Transpose (space_coords , numpy .argsort (where ))
1086
1127
where = tuple (sorted (where ))
1087
1128
1088
- (chain , * _ ), tip_index = args .transform_chains [self ._sample .space ]
1089
- index = evaluable .TransformIndex (self ._sample .transforms [0 ], chain , tip_index )
1090
- coords = evaluable .TransformCoords (self ._sample .transforms [0 ], chain , tip_index , space_coords )
1129
+ index = evaluable .TransformIndex (self ._sample .transforms [0 ], arg .transforms , arg .index )
1130
+ coords = evaluable .TransformCoords (self ._sample .transforms [0 ], arg .transforms , arg .index , space_coords )
1091
1131
expect = self ._sample .points .get_evaluable_coords (index )
1092
1132
sampled = evaluable .Sampled (coords , expect , self ._interpolation )
1093
1133
indices = self ._sample .get_evaluable_indices (index )
0 commit comments