@@ -8,6 +8,7 @@ use crate::{
8
8
ops:: {
9
9
broadcast_arrays_device, broadcast_to_device,
10
10
indexing:: { count_non_new_axis_operations, expand_ellipsis_operations} ,
11
+ reshape_device,
11
12
} ,
12
13
utils:: { resolve_index_signed_unchecked, VectorArray } ,
13
14
Array , Stream ,
@@ -84,40 +85,59 @@ fn update_slice(
84
85
let operations = expand_ellipsis_operations ( ndim, operations) ;
85
86
86
87
// If no non-None indices return the broadcasted update
87
- if count_non_new_axis_operations ( & operations) == 0 {
88
+ let non_new_axis_operation_count = count_non_new_axis_operations ( & operations) ;
89
+ if non_new_axis_operation_count == 0 {
88
90
return Ok ( Some ( broadcast_to_device ( & update, src. shape ( ) , & stream) ?) ) ;
89
91
}
90
92
91
93
// Process entries
92
- let mut update_expand_dims: SmallVec < [ i32 ; DEFAULT_STACK_VEC_LEN ] > = SmallVec :: new ( ) ;
93
- let mut axis = 0i32 ;
94
- for item in operations. iter ( ) {
94
+ // let mut update_expand_dims: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = SmallVec::new();
95
+ let mut update_reshape: SmallVec < [ i32 ; DEFAULT_STACK_VEC_LEN ] > = smallvec ! [ 0 ; ndim] ;
96
+ let mut axis = src. ndim ( ) - 1 ;
97
+ let mut update_axis = update. ndim ( ) as i32 - 1 ;
98
+
99
+ while axis >= non_new_axis_operation_count {
100
+ if update_axis >= 0 {
101
+ update_reshape[ axis] = update. dim ( update_axis) ;
102
+ update_axis -= 1 ;
103
+ } else {
104
+ update_reshape[ axis] = 1 ;
105
+ }
106
+ axis -= 1 ;
107
+ }
108
+
109
+ for item in operations. iter ( ) . rev ( ) {
95
110
use ArrayIndexOp :: * ;
96
111
97
112
match item {
98
113
TakeIndex { index } => {
99
- let size = src. dim ( axis) ;
114
+ let size = src. dim ( axis as i32 ) ;
100
115
let index = if index. is_negative ( ) {
101
116
size + index
102
117
} else {
103
118
* index
104
119
} ;
105
120
// SAFETY: axis is always non-negative
106
- starts[ axis as usize ] = index;
107
- ends[ axis as usize ] = index + 1 ;
108
- if ndim - ( axis as usize ) < update. ndim ( ) {
109
- update_expand_dims. push ( axis. saturating_sub_unsigned ( ndim as u32 ) ) ;
110
- }
121
+ starts[ axis] = index;
122
+ ends[ axis] = index. saturating_add ( 1 ) ;
111
123
112
- axis = axis. saturating_add ( 1 ) ;
124
+ update_reshape[ axis] = 1 ;
125
+ axis = axis. saturating_sub ( 1 ) ;
113
126
}
114
- Slice ( range_index ) => {
115
- let size = src. dim ( axis) ;
127
+ Slice ( slice ) => {
128
+ let size = src. dim ( axis as i32 ) ;
116
129
// SAFETY: axis is always non-negative
117
- starts[ axis as usize ] = range_index. start ( size) ;
118
- ends[ axis as usize ] = range_index. end ( size) ;
119
- strides[ axis as usize ] = range_index. stride ( ) ;
120
- axis = axis. saturating_add ( 1 ) ;
130
+ starts[ axis] = slice. start ( size) ;
131
+ ends[ axis] = slice. end ( size) ;
132
+ strides[ axis] = slice. stride ( ) ;
133
+
134
+ if update_axis >= 0 {
135
+ update_reshape[ axis] = update. dim ( update_axis) ;
136
+ update_axis = update_axis. saturating_sub ( 1 ) ;
137
+ } else {
138
+ update_reshape[ axis] = 1 ;
139
+ }
140
+ axis = axis. saturating_sub ( 1 ) ;
121
141
}
122
142
ExpandDims => { }
123
143
Ellipsis | TakeArray { indices : _ } | TakeArrayRef { indices : _ } => {
@@ -126,8 +146,8 @@ fn update_slice(
126
146
}
127
147
}
128
148
129
- if !update_expand_dims . is_empty ( ) {
130
- update = Cow :: Owned ( update . expand_dims_axes_device ( & update_expand_dims , & stream) ?) ;
149
+ if update . shape ( ) != & update_reshape [ .. ] {
150
+ update = Cow :: Owned ( reshape_device ( update , & update_reshape , & stream) ?) ;
131
151
}
132
152
133
153
Ok ( Some ( src. slice_update_device (
@@ -1277,7 +1297,10 @@ where
1277
1297
/// The unit tests below are adapted from the Swift binding tests
1278
1298
#[ cfg( test) ]
1279
1299
mod tests {
1280
- use crate :: { ops:: indexing:: * , Array } ;
1300
+ use crate :: {
1301
+ ops:: { indexing:: * , ones, zeros} ,
1302
+ Array ,
1303
+ } ;
1281
1304
1282
1305
#[ test]
1283
1306
fn test_array_mutate_single_index ( ) {
@@ -1473,4 +1496,16 @@ mod tests {
1473
1496
128142
1474
1497
) ;
1475
1498
}
1499
+
1500
+ #[ test]
1501
+ fn test_slice_update_with_broadcast ( ) {
1502
+ let mut xs = zeros :: < f32 > ( & [ 4 , 3 , 2 ] ) . unwrap ( ) ;
1503
+ let x = ones :: < f32 > ( & [ 4 , 2 ] ) . unwrap ( ) ;
1504
+
1505
+ let result = xs. try_index_mut ( ( .., 0 , ..) , x) ;
1506
+ assert ! (
1507
+ result. is_ok( ) ,
1508
+ "Failed to update slice with broadcast: {result:?}"
1509
+ ) ;
1510
+ }
1476
1511
}
0 commit comments