Skip to content

Commit a843a93

Browse files
committed
Fix logic to sequence based indexing in row/col/slice functions
Prior to this change, the following functions were not checking for additional dimensional data beyond the dimension concerned with the particular function. - row - col - slice - rows - cols - slices - set_row - set_col - slice - set_rows - set_cols - set_slices Similar logic was missing in one particular matching pattern of view macro which is also fixed in this change. Few additional unit tests are added in macro and index module checking for the pitfalls this change addresses
1 parent 6be6c54 commit a843a93

File tree

2 files changed

+148
-32
lines changed

2 files changed

+148
-32
lines changed

src/core/index.rs

+93-31
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,11 @@ pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
293293
where
294294
T: HasAfEnum,
295295
{
296-
index(
297-
input,
298-
&[
299-
Seq::new(row_num as f64, row_num as f64, 1.0),
300-
Seq::default(),
301-
],
302-
)
296+
let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)];
297+
for _d in 1..input.dims().ndims() {
298+
seqs.push(Seq::default());
299+
}
300+
index(input, &seqs)
303301
}
304302

305303
/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
@@ -308,7 +306,7 @@ where
308306
T: HasAfEnum,
309307
{
310308
let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)];
311-
if inout.dims().ndims() > 1 {
309+
for _d in 1..inout.dims().ndims() {
312310
seqs.push(Seq::default());
313311
}
314312
assign_seq(inout, &seqs, new_row)
@@ -320,10 +318,11 @@ where
320318
T: HasAfEnum,
321319
{
322320
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
323-
index(
324-
input,
325-
&[Seq::new(first as f64, last as f64, step), Seq::default()],
326-
)
321+
let mut seqs = vec![Seq::new(first as f64, last as f64, step)];
322+
for _d in 1..input.dims().ndims() {
323+
seqs.push(Seq::default());
324+
}
325+
index(input, &seqs)
327326
}
328327

329328
/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
@@ -332,7 +331,10 @@ where
332331
T: HasAfEnum,
333332
{
334333
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
335-
let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
334+
let mut seqs = vec![Seq::new(first as f64, last as f64, step)];
335+
for _d in 1..inout.dims().ndims() {
336+
seqs.push(Seq::default());
337+
}
336338
assign_seq(inout, &seqs, new_rows)
337339
}
338340

@@ -352,24 +354,28 @@ pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
352354
where
353355
T: HasAfEnum,
354356
{
355-
index(
356-
input,
357-
&[
358-
Seq::default(),
359-
Seq::new(col_num as f64, col_num as f64, 1.0),
360-
],
361-
)
357+
let mut seqs = vec![
358+
Seq::default(),
359+
Seq::new(col_num as f64, col_num as f64, 1.0),
360+
];
361+
for _d in 2..input.dims().ndims() {
362+
seqs.push(Seq::default());
363+
}
364+
index(input, &seqs)
362365
}
363366

364367
/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
365368
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
366369
where
367370
T: HasAfEnum,
368371
{
369-
let seqs = [
372+
let mut seqs = vec![
370373
Seq::default(),
371374
Seq::new(col_num as f64, col_num as f64, 1.0),
372375
];
376+
for _d in 2..inout.dims().ndims() {
377+
seqs.push(Seq::default());
378+
}
373379
assign_seq(inout, &seqs, new_col)
374380
}
375381

@@ -379,10 +385,11 @@ where
379385
T: HasAfEnum,
380386
{
381387
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
382-
index(
383-
input,
384-
&[Seq::default(), Seq::new(first as f64, last as f64, step)],
385-
)
388+
let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)];
389+
for _d in 2..input.dims().ndims() {
390+
seqs.push(Seq::default());
391+
}
392+
index(input, &seqs)
386393
}
387394

388395
/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
@@ -391,7 +398,10 @@ where
391398
T: HasAfEnum,
392399
{
393400
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
394-
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
401+
let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)];
402+
for _d in 2..inout.dims().ndims() {
403+
seqs.push(Seq::default());
404+
}
395405
assign_seq(inout, &seqs, new_cols)
396406
}
397407

@@ -402,11 +412,14 @@ pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
402412
where
403413
T: HasAfEnum,
404414
{
405-
let seqs = [
415+
let mut seqs = vec![
406416
Seq::default(),
407417
Seq::default(),
408418
Seq::new(slice_num as f64, slice_num as f64, 1.0),
409419
];
420+
for _d in 3..input.dims().ndims() {
421+
seqs.push(Seq::default());
422+
}
410423
index(input, &seqs)
411424
}
412425

@@ -417,11 +430,14 @@ pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
417430
where
418431
T: HasAfEnum,
419432
{
420-
let seqs = [
433+
let mut seqs = vec![
421434
Seq::default(),
422435
Seq::default(),
423436
Seq::new(slice_num as f64, slice_num as f64, 1.0),
424437
];
438+
for _d in 3..inout.dims().ndims() {
439+
seqs.push(Seq::default());
440+
}
425441
assign_seq(inout, &seqs, new_slice)
426442
}
427443

@@ -433,11 +449,14 @@ where
433449
T: HasAfEnum,
434450
{
435451
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
436-
let seqs = [
452+
let mut seqs = vec![
437453
Seq::default(),
438454
Seq::default(),
439455
Seq::new(first as f64, last as f64, step),
440456
];
457+
for _d in 3..input.dims().ndims() {
458+
seqs.push(Seq::default());
459+
}
441460
index(input, &seqs)
442461
}
443462

@@ -449,11 +468,14 @@ where
449468
T: HasAfEnum,
450469
{
451470
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
452-
let seqs = [
471+
let mut seqs = vec![
453472
Seq::default(),
454473
Seq::default(),
455474
Seq::new(first as f64, last as f64, step),
456475
];
476+
for _d in 3..inout.dims().ndims() {
477+
seqs.push(Seq::default());
478+
}
457479
assign_seq(inout, &seqs, new_slices)
458480
}
459481

@@ -655,7 +677,7 @@ mod tests {
655677
use super::super::device::set_device;
656678
use super::super::dim4::Dim4;
657679
use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
658-
use super::super::index::{cols, rows};
680+
use super::super::index::{cols, rows, set_row, set_rows};
659681
use super::super::random::randu;
660682
use super::super::seq::Seq;
661683

@@ -868,4 +890,44 @@ mod tests {
868890
// 0.9675 0.3712 0.7896
869891
// ANCHOR_END: get_rows
870892
}
893+
894+
#[test]
895+
fn change_row() {
896+
set_device(0);
897+
898+
let v0: Vec<bool> = vec![true, true, true, true, true, true];
899+
let mut a0 = Array::new(&v0, dim4!(v0.len() as u64));
900+
901+
let v1: Vec<bool> = vec![false];
902+
let a1 = Array::new(&v1, dim4!(v1.len() as u64));
903+
904+
set_row(&mut a0, &a1, 2);
905+
906+
let mut res = vec![true; a0.elements()];
907+
a0.host(&mut res);
908+
909+
let gold = vec![true, true, false, true, true, true];
910+
911+
assert_eq!(gold, res);
912+
}
913+
914+
#[test]
915+
fn change_rows() {
916+
set_device(0);
917+
918+
let v0: Vec<bool> = vec![true, true, true, true, true, true];
919+
let mut a0 = Array::new(&v0, dim4!(v0.len() as u64));
920+
921+
let v1: Vec<bool> = vec![false, false];
922+
let a1 = Array::new(&v1, dim4!(v1.len() as u64));
923+
924+
set_rows(&mut a0, &a1, 2, 3);
925+
926+
let mut res = vec![true; a0.elements()];
927+
a0.host(&mut res);
928+
929+
let gold = vec![true, true, false, false, true, true];
930+
931+
assert_eq!(gold, res);
932+
}
871933
}

src/core/macros.rs

+55-1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ macro_rules! view {
190190
$(
191191
seq_vec.push($crate::seq!($start:$end:$step));
192192
)*
193+
for _d in seq_vec.len()..$array_ident.dims().ndims() {
194+
seq_vec.push($crate::seq!());
195+
}
193196
$crate::index(&$array_ident, &seq_vec)
194197
}
195198
};
@@ -354,7 +357,7 @@ mod tests {
354357
use super::super::array::Array;
355358
use super::super::data::constant;
356359
use super::super::device::set_device;
357-
use super::super::index::index;
360+
use super::super::index::{index, rows, set_rows};
358361
use super::super::random::randu;
359362

360363
#[test]
@@ -505,4 +508,55 @@ mod tests {
505508
let _ruu32_5x5 = randu!(u32; 5, 5);
506509
let _ruu8_5x5 = randu!(u8; 5, 5);
507510
}
511+
512+
#[test]
513+
fn match_eval_macro_with_set_rows() {
514+
set_device(0);
515+
516+
let inpt = vec![true, true, true, true, true, true, true, true, true, true];
517+
let gold = vec![
518+
true, true, false, false, true, true, true, false, false, true,
519+
];
520+
521+
let mut orig_arr = Array::new(&inpt, dim4!(5, 2));
522+
let mut orig_cln = orig_arr.clone();
523+
524+
let new_vals = vec![false, false, false, false];
525+
let new_arr = Array::new(&new_vals, dim4!(2, 2));
526+
527+
eval!( orig_arr[2:3:1,1:1:0] = new_arr );
528+
let mut res1 = vec![true; orig_arr.elements()];
529+
orig_arr.host(&mut res1);
530+
531+
set_rows(&mut orig_cln, &new_arr, 2, 3);
532+
let mut res2 = vec![true; orig_cln.elements()];
533+
orig_cln.host(&mut res2);
534+
535+
assert_eq!(gold, res1);
536+
assert_eq!(res1, res2);
537+
}
538+
539+
#[test]
540+
fn match_view_macro_with_get_rows() {
541+
set_device(0);
542+
543+
let inpt: Vec<i32> = (0..10).collect();
544+
let gold: Vec<i32> = vec![2, 3, 7, 8];
545+
546+
println!("input {:?}", inpt);
547+
println!("gold {:?}", gold);
548+
549+
let orig_arr = Array::new(&inpt, dim4!(5, 2));
550+
551+
let view_out = view!( orig_arr[2:3:1] );
552+
let mut res1 = vec![0i32; view_out.elements()];
553+
view_out.host(&mut res1);
554+
555+
let rows_out = rows(&orig_arr, 2, 3);
556+
let mut res2 = vec![0i32; rows_out.elements()];
557+
rows_out.host(&mut res2);
558+
559+
assert_eq!(gold, res1);
560+
assert_eq!(res1, res2);
561+
}
508562
}

0 commit comments

Comments
 (0)