Skip to content

Commit

Permalink
random init and iris pipeline update
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 6, 2020
1 parent abf96f2 commit 27d46ac
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 17 deletions.
1 change: 1 addition & 0 deletions examples/iris_classifier/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
iris.csv
3 changes: 2 additions & 1 deletion examples/iris_classifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
polars = {path = "../../polars"}
polars = {path = "../../polars", features = ["random"]}
reqwest = {version = "0.10.8", features = ["blocking"]}
168 changes: 153 additions & 15 deletions examples/iris_classifier/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,105 @@
//! Running this program outputs:
//!
//! +----------+----------+----------+----------+---------------+
//! | column_1 | column_2 | column_3 | column_4 | column_5 |
//! | --- | --- | --- | --- | --- |
//! | f64 | f64 | f64 | f64 | str |
//! +==========+==========+==========+==========+===============+
//! | 5.1 | 3.5 | 1.4 | 0.2 | "Iris-setosa" |
//! +----------+----------+----------+----------+---------------+
//! | 4.9 | 3 | 1.4 | 0.2 | "Iris-setosa" |
//! +----------+----------+----------+----------+---------------+
//! | 4.7 | 3.2 | 1.3 | 0.2 | "Iris-setosa" |
//! +----------+----------+----------+----------+---------------+
//!
//! +--------------+-------------+-------------+--------------+---------------+
//! | sepal.length | sepal.width | petal.width | petal.length | class |
//! | --- | --- | --- | --- | --- |
//! | f64 | f64 | f64 | f64 | str |
//! +==============+=============+=============+==============+===============+
//! | 5.1 | 3.5 | 1.4 | 0.2 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//! | 4.9 | 3 | 1.4 | 0.2 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//! | 4.7 | 3.2 | 1.3 | 0.2 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//!
//! +--------------+-------------+-------------+--------------+---------------+
//! | sepal.length | sepal.width | petal.width | petal.length | class |
//! | --- | --- | --- | --- | --- |
//! | f64 | f64 | f64 | f64 | str |
//! +==============+=============+=============+==============+===============+
//! | 5.1 | 3.5 | 1.4 | 0.2 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//! | 4.9 | 3 | 1.4 | 0.2 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//! | 4.7 | 3.2 | 1.3 | 0.2 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//!
//! +--------------+-------------+-------------+--------------+---------------+
//! | sepal.length | sepal.width | petal.width | petal.length | class |
//! | --- | --- | --- | --- | --- |
//! | f64 | f64 | f64 | f64 | str |
//! +==============+=============+=============+==============+===============+
//! | 0.006 | 0.008 | 0.002 | 0.001 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//! | 0.006 | 0.007 | 0.002 | 0.001 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//! | 0.005 | 0.007 | 0.002 | 0.001 | "Iris-setosa" |
//! +--------------+-------------+-------------+--------------+---------------+
//!
//! +--------------+-------------+-------------+--------------+---------------+-------------+
//! | sepal.length | sepal.width | petal.width | petal.length | class | ohe |
//! | --- | --- | --- | --- | --- | --- |
//! | f64 | f64 | f64 | f64 | str | list [u32] |
//! +==============+=============+=============+==============+===============+=============+
//! | 0.006 | 0.008 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" |
//! +--------------+-------------+-------------+--------------+---------------+-------------+
//! | 0.006 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" |
//! +--------------+-------------+-------------+--------------+---------------+-------------+
//! | 0.005 | 0.007 | 0.002 | 0.001 | "Iris-setosa" | "[0, 1, 0]" |
//! +--------------+-------------+-------------+--------------+---------------+-------------+
//!
use polars::prelude::*;
use std::io::Cursor;
use reqwest;
use std::fs::File;
use std::io::Write;
use std::path::Path;

fn download_iris() -> std::io::Result<()> {
let r = reqwest::blocking::get(
"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
)
.expect("could not download iris");
let mut f = File::create("iris.csv")?;
f.write_all(r.text().unwrap().as_bytes())
}

fn read_csv() -> Result<DataFrame> {
let s = r#""sepal.length","sepal.width","petal.length","petal.width","variety"
5.1,3.5,1.4,.2,"Setosa"
4.9,3,1.4,.2,"Setosa"
4.7,3.2,1.3,.2,"Setosa"
4.6,3.1,1.5,.2,"Setosa"
5,3.6,1.4,.2,"Setosa"
5.4,3.9,1.7,.4,"Setosa"
4.6,3.4,1.4,.3,"Setosa""#;

let file = Cursor::new(s);
let file = File::open("iris.csv").expect("could not read iris file");
CsvReader::new(file)
.infer_schema(Some(100))
.has_header(true)
.has_header(false)
.with_batch_size(100)
.finish()
}

fn rename_cols(mut df: DataFrame) -> Result<DataFrame> {
(0..5)
.zip(&[
"sepal.length",
"sepal.width",
"petal.width",
"petal.length",
"class",
])
.for_each(|(idx, name)| {
df[idx].rename(name);
});

Ok(df)
}

fn enforce_schema(mut df: DataFrame) -> Result<DataFrame> {
let dtypes = &[
ArrowDataType::Float64,
Expand Down Expand Up @@ -63,12 +144,69 @@ fn normalize(mut df: DataFrame) -> Result<DataFrame> {
Ok(df)
}

fn one_hot_encode(mut df: DataFrame) -> Result<DataFrame> {
let y = df["class"].utf8().unwrap();

let unique = y.unique();
let n_unique = unique.len();

let mut ohe = y
.into_iter()
.map(|opt_s| {
let mut ohe = vec![0; n_unique];
let mut idx = 0;
for i in 0..n_unique {
if unique.get(i) == opt_s {
idx = i;
break;
}
}
ohe[idx] = 1;
match opt_s {
Some(s) => UInt32Chunked::new_from_slice(s, &ohe).into_series(),
None => UInt32Chunked::new_from_slice("null", &ohe).into_series(),
}
})
.collect::<Series>();
ohe.rename("ohe");
df.add_column(ohe)?;

Ok(df)
}

fn print_state(df: DataFrame) -> Result<DataFrame> {
println!("{:?}", df.head(Some(3)));
Ok(df)
}

fn pipe() -> Result<DataFrame> {
read_csv()?.pipe(enforce_schema)?.pipe(normalize)
read_csv()?
.pipe(print_state)
.unwrap()
.pipe(rename_cols)
.expect("could not rename columns")
.pipe(print_state)
.unwrap()
.pipe(enforce_schema)
.expect("could not enforce schema")
.pipe(print_state)
.unwrap()
.pipe(normalize)?
.pipe(print_state)
.unwrap()
.pipe(one_hot_encode)
.expect("could not ohe")
.pipe(print_state)
}
fn train(mut df: DataFrame) {
todo!()
}

fn main() {
let df = pipe().unwrap();
if !Path::new("iris.csv").exists() {
download_iris().expect("could not create file")
}

println!("{:?}", df);
let df = pipe().expect("could not prepare DataFrame");
train(df);
}
5 changes: 4 additions & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ simd = ["arrow/packed_simd"]
docs = []
temporal = ["chrono"]
parquet_ser = ["parquet"]
random = ["rand", "rand_distr"]
default = ["pretty", "docs", "temporal"]

[dependencies]
Expand All @@ -30,4 +31,6 @@ prettytable-rs = { version="^0.8.0", features=["win_crlf"], optional = true, def
crossbeam = "^0.7"
chrono = {version = "^0.4.13", optional = true}
enum_dispatch = "^0.3.2"
parquet = {version = "1.0.1", optional = true}
parquet = {version = "1.0.1", optional = true}
rand = {version = "0.7.3", optional = true}
rand_distr = {version = "0.3.0", optional = true}
2 changes: 2 additions & 0 deletions polars/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub mod cast;
pub mod chunkops;
pub mod comparison;
pub mod iterator;
#[cfg(feature = "random")]
pub mod random;
pub mod set;
pub mod take;
#[cfg(feature = "temporal")]
Expand Down
33 changes: 33 additions & 0 deletions polars/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use crate::prelude::*;
use num::{Float, NumCast};
use rand::prelude::*;
use rand_distr::{Distribution, Normal, StandardNormal};

impl<T> ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: Float + NumCast,
{
/// Create `ChunkedArray` with samples from a Normal distribution.
pub fn rand_normal(name: &str, length: usize, mean: f64, std_dev: f64) -> Result<Self> {
let normal = Normal::new(mean, std_dev)?;
let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
for _ in 0..length {
let smpl = normal.sample(&mut rand::thread_rng());
let smpl = NumCast::from(smpl).unwrap();
builder.append_value(smpl)
}
Ok(builder.finish())
}

/// Create `ChunkedArray` with samples from a Standard Normal distribution.
pub fn rand_standard_normal(name: &str, length: usize) -> Self {
let mut builder = PrimitiveChunkedBuilder::<T>::new(name, length);
for _ in 0..length {
let smpl: f64 = thread_rng().sample(StandardNormal);
let smpl = NumCast::from(smpl).unwrap();
builder.append_value(smpl)
}
builder.finish()
}
}
3 changes: 3 additions & 0 deletions polars/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub enum PolarsError {
#[cfg(feature = "parquet_ser")]
#[error(transparent)]
ParquetError(#[from] parquet::errors::ParquetError),
#[cfg(feature = "random")]
#[error(transparent)]
RandError(#[from] rand_distr::NormalError),
}

pub type Result<T> = std::result::Result<T, PolarsError>;

0 comments on commit 27d46ac

Please sign in to comment.