Skip to content

Commit

Permalink
ndarray for LargeLists
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 9, 2020
1 parent 8f3e53c commit 73e18eb
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 5 deletions.
6 changes: 4 additions & 2 deletions examples/iris_classifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
polars = {path = "../../polars", features = ["random"]}
reqwest = {version = "0.10.8", features = ["blocking"]}
polars = {path = "../../polars", features = ["random", "ndarray"]}
reqwest = {version = "0.10.8", features = ["blocking"]}
ndarray = "0.13"
itertools = "0.9"
16 changes: 13 additions & 3 deletions examples/iris_classifier/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,17 @@
//! | 0.005 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" |
//! +--------------+-------------+-------------+--------------+---------------+-------------+
//!
use itertools::Itertools;
use ndarray::prelude::*;
use polars::prelude::*;
use reqwest;
use std::fs::File;
use std::io::Write;
use std::path::Path;

/// Names of the four feature columns; `normalize` iterates over them and `train` selects them.
const FEATURES: [&str; 4] = ["sepal.length", "sepal.width", "petal.width", "petal.length"];
/// NOTE(review): not referenced in the visible code — presumably the step size for the
/// training loop that `train` does not implement yet; confirm once it is written.
const LEARNING_RATE: f64 = 0.01;

fn download_iris() -> std::io::Result<()> {
let r = reqwest::blocking::get(
"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
Expand Down Expand Up @@ -129,7 +134,7 @@ fn enforce_schema(mut df: DataFrame) -> Result<DataFrame> {
}

fn normalize(mut df: DataFrame) -> Result<DataFrame> {
let cols = &["sepal.length", "sepal.width", "petal.length", "petal.width"];
let cols = &FEATURES;

for &col in cols {
df.may_apply(col, |s| {
Expand Down Expand Up @@ -198,7 +203,13 @@ fn pipe() -> Result<DataFrame> {
.expect("could not ohe")
.pipe(print_state)
}
fn train(mut df: DataFrame) {
/// Prepares the training inputs from `df` as 2D `f64` ndarrays.
///
/// The `FEATURES` columns are selected and converted to one array, and the `"ohe"`
/// column is unpacked as a large list and converted to another. Any failure in those
/// steps is propagated to the caller. The training step itself is not implemented yet:
/// once both arrays have been built the function hits `todo!()` and panics.
fn train(mut df: DataFrame) -> Result<()> {
    // Feature matrix built from the selected feature columns.
    let feature_frame = df.select(&FEATURES)?;
    let features = feature_frame.to_ndarray::<Float64Type>()?;

    // Target matrix built from the nested lists stored in the "ohe" column.
    let ohe_column = df.column("ohe")?;
    let ohe_lists = ohe_column.large_list()?;
    let labels = ohe_lists.to_ndarray::<Float64Type>()?;

    todo!()
}

Expand All @@ -208,5 +219,4 @@ fn main() {
}

let df = pipe().expect("could not prepare DataFrame");
train(df);
}
44 changes: 44 additions & 0 deletions polars/src/chunked_array/ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,50 @@ where
}
}

impl LargeListChunked {
    /// If all nested `Series` have the same length, a 2 dimensional `ndarray::Array` is returned.
    ///
    /// Each nested `Series` is cast to `N` and becomes one row of the result, so the array
    /// has shape `(self.len(), width)`, where `width` is the length of the first nested
    /// `Series`.
    ///
    /// # Errors
    ///
    /// * `PolarsError::HasNullValues` if this array contains null values.
    /// * `PolarsError::NoData` if there is no nested `Series` to take the width from.
    /// * `PolarsError::ShapeMisMatch` if a nested `Series` differs in length from the first.
    /// * Any error raised while casting a nested `Series` to `N` or converting it to an
    ///   ndarray.
    pub fn to_ndarray<N>(&self) -> Result<Array2<N::Native>>
    where
        N: PolarsNumericType,
    {
        if self.null_count() != 0 {
            return Err(PolarsError::HasNullValues);
        }

        let mut iter = self.into_no_null_iter();

        // The first nested series determines the width of every row.
        let first = match iter.next() {
            Some(series) => series,
            None => return Err(PolarsError::NoData),
        };
        let width = first.len();

        // SAFETY: the loop below assigns one full row per nested series, and an error
        // return drops the array without handing it out. This relies on the no-null
        // iterator yielding `self.len()` series, so that every row is overwritten
        // before the array is returned.
        let mut ndarray = unsafe { Array::uninitialized((self.len(), width)) };

        // Re-attach the first series so that all rows go through the same code path.
        for (row_idx, series) in std::iter::once(first).chain(iter).enumerate() {
            if series.len() != width {
                return Err(PolarsError::ShapeMisMatch);
            }
            let series = series.cast::<N>()?;
            let ca = series.unpack::<N>()?;
            let a = ca.to_ndarray()?;
            // Write into the row that matches this series' position; always targeting
            // row 0 would leave all other rows uninitialized.
            let mut row = ndarray.slice_mut(s![row_idx, ..]);
            row.assign(&a);
        }

        Ok(ndarray)
    }
}

impl DataFrame {
/// Create a 2D `ndarray::Array` from this `DataFrame`. This requires all columns in the
/// `DataFrame` to be non-null and numeric. They will be casted to the same data type
Expand Down
5 changes: 5 additions & 0 deletions polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ impl Series {
unpack_series!(self, IntervalYearMonth)
}

/// Unpack this `Series` to a reference to its underlying `LargeListChunked`.
///
/// Delegates to the `unpack_series!` macro with the `LargeList` variant.
/// NOTE(review): presumably returns an `Err` when the `Series` holds a different data
/// type — the macro is defined outside this view, confirm there.
pub fn large_list(&self) -> Result<&LargeListChunked> {
unpack_series!(self, LargeList)
}

pub fn append_array(&mut self, other: ArrayRef) -> Result<&mut Self> {
apply_method_all_series!(self, append_array, other)?;
Ok(self)
Expand Down

0 comments on commit 73e18eb

Please sign in to comment.