Skip to content

Commit

Permalink
add ndarray tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 4, 2021
1 parent 10ccde4 commit d0ed0e9
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions polars/polars-core/src/chunked_array/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,47 @@ impl DataFrame {
Ok(ndarr)
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_ndarray_from_ca() -> Result<()> {
let ca = Float64Chunked::new_from_slice("", &[1.0, 2.0, 3.0]);
let ndarr = ca.to_ndarray()?;
assert_eq!(ndarr, ArrayView1::from(&[1.0, 2.0, 3.0]));

let mut builder = ListPrimitiveChunkedBuilder::new("", 10, 10, DataType::Float64);
builder.append_slice(Some(&[1.0, 2.0, 3.0]));
builder.append_slice(Some(&[2.0, 4.0, 5.0]));
builder.append_slice(Some(&[6.0, 7.0, 8.0]));
let list = builder.finish();

let ndarr = list.to_ndarray::<Float64Type>()?;
let expected = array![[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [6.0, 7.0, 8.0]];
assert_eq!(ndarr, expected);

// test list array that is not square
let mut builder = ListPrimitiveChunkedBuilder::new("", 10, 10, DataType::Float64);
builder.append_slice(Some(&[1.0, 2.0, 3.0]));
builder.append_slice(Some(&[2.0]));
builder.append_slice(Some(&[6.0, 7.0, 8.0]));
let list = builder.finish();
assert!(list.to_ndarray::<Float64Type>().is_err());
Ok(())
}

#[test]
fn test_ndarray_from_df() -> Result<()> {
let df = df!["a"=> [1.0, 2.0, 3.0],
"b" => [2.0, 3.0, 4.0]
]?;

let ndarr = df.to_ndarray::<Float64Type>()?;
let expected = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
assert_eq!(ndarr, expected);

Ok(())
}
}

0 comments on commit d0ed0e9

Please sign in to comment.