Skip to content

Commit

Permalink
Merge branch 'master' into ft-projector_mixing
Browse files Browse the repository at this point in the history
  • Loading branch information
tspooner committed May 4, 2018
2 parents 918749d + 5714e68 commit ed2c998
Show file tree
Hide file tree
Showing 16 changed files with 264 additions and 160 deletions.
6 changes: 4 additions & 2 deletions src/approximators/mod.rs
@@ -1,5 +1,5 @@
use error::*;
use projectors::{IndexT, IndexSet};
use projectors::{IndexSet, IndexT};
use std::collections::HashMap;

mod simple;
Expand Down Expand Up @@ -28,7 +28,9 @@ pub trait Approximator<I: ?Sized> {
impl<I: ?Sized, T: Approximator<I>> Approximator<I> for Box<T> {
type Value = T::Value;

fn evaluate(&self, input: &I) -> EvaluationResult<Self::Value> { (**self).evaluate(input) }
fn evaluate(&self, input: &I) -> EvaluationResult<Self::Value> {
(**self).evaluate(input)
}

fn update(&mut self, input: &I, update: Self::Value) -> UpdateResult<()> {
(**self).update(input, update)
Expand Down
54 changes: 29 additions & 25 deletions src/approximators/multi.rs
@@ -1,10 +1,10 @@
use approximators::Approximator;
use error::AdaptError;
use geometry::{Vector, Matrix};
use projectors::{Projection, IndexT, IndexSet};
use geometry::{Matrix, Vector};
use projectors::{IndexSet, IndexT, Projection};
use std::collections::HashMap;
use std::mem::replace;
use {EvaluationResult, UpdateResult, AdaptResult};
use {AdaptResult, EvaluationResult, UpdateResult};

#[derive(Clone, Serialize, Deserialize)]
pub struct Multi {
Expand All @@ -24,17 +24,16 @@ impl Multi {
let n_rows_new = new_rows.len();

// Weight matrix stored in row-major format.
let mut weights = unsafe {
replace(&mut self.weights, Matrix::uninitialized((0, 0))).into_raw_vec()
};
let mut weights =
unsafe { replace(&mut self.weights, Matrix::uninitialized((0, 0))).into_raw_vec() };

weights.reserve_exact(n_rows_new);

for row in new_rows {
weights.extend(row);
}

self.weights = Matrix::from_shape_vec((n_rows+n_rows_new, n_cols), weights).unwrap();
self.weights = Matrix::from_shape_vec((n_rows + n_rows_new, n_cols), weights).unwrap();
}
}

Expand Down Expand Up @@ -63,7 +62,7 @@ impl Approximator<Projection> for Multi {
let error_matrix = errors.view().into_shape((1, self.weights.cols())).unwrap();

self.weights.scaled_add(1.0 / z, &view.dot(&error_matrix))
},
}
&Projection::Sparse(ref sparse) => for c in 0..self.weights.cols() {
let mut col = self.weights.column_mut(c);
let error = errors[c];
Expand All @@ -82,25 +81,30 @@ impl Approximator<Projection> for Multi {

let max_index = self.weights.len() + n_nfs - 1;

let new_weights: Result<Vec<Vec<f64>>, _> = new_features.into_iter().map(|(&i, idx)| {
if i > max_index {
Err(AdaptError::Failed)
} else {
Ok((0..n_outputs).map(|c| {
let c = self.weights.column(c);

idx.iter().fold(0.0, |acc, r| acc + c[*r])
}).collect())
}
}).collect();
let new_weights: Result<Vec<Vec<f64>>, _> = new_features
.into_iter()
.map(|(&i, idx)| {
if i > max_index {
Err(AdaptError::Failed)
} else {
Ok((0..n_outputs)
.map(|c| {
let c = self.weights.column(c);

idx.iter().fold(0.0, |acc, r| acc + c[*r])
})
.collect())
}
})
.collect();

match new_weights {
Ok(new_weights) => {
self.append_weight_rows(new_weights);

Ok(n_nfs)
},
Err(err) => Err(err)
}
Err(err) => Err(err),
}
}
}
Expand All @@ -113,7 +117,7 @@ mod tests {
use approximators::{Approximator, Multi};
use geometry::Vector;
use projectors::fixed::{Fourier, TileCoding};
use std::collections::{HashMap, BTreeSet};
use std::collections::{BTreeSet, HashMap};
use std::hash::BuildHasherDefault;

type SHBuilder = BuildHasherDefault<seahash::SeaHasher>;
Expand Down Expand Up @@ -168,9 +172,9 @@ mod tests {
let c0 = f.weights.column(0);
let c1 = f.weights.column(1);

assert_eq!(c0[100], c0[10]/2.0 + c0[90]/2.0);
assert_eq!(c1[100], c1[10]/2.0 + c1[90]/2.0);
},
assert_eq!(c0[100], c0[10] / 2.0 + c0[90] / 2.0);
assert_eq!(c1[100], c1[10] / 2.0 + c1[90] / 2.0);
}
Err(err) => panic!("Simple::adapt failed with AdaptError::{:?}", err),
}
}
Expand Down
34 changes: 18 additions & 16 deletions src/approximators/simple.rs
@@ -1,10 +1,10 @@
use approximators::Approximator;
use error::AdaptError;
use geometry::Vector;
use projectors::{Projection, IndexT, IndexSet};
use projectors::{IndexSet, IndexT, Projection};
use std::collections::HashMap;
use std::mem::replace;
use {EvaluationResult, UpdateResult, AdaptResult};
use {AdaptResult, EvaluationResult, UpdateResult};

#[derive(Clone, Serialize, Deserialize)]
pub struct Simple {
Expand All @@ -19,9 +19,8 @@ impl Simple {
}

fn extend_weights(&mut self, new_weights: Vec<f64>) {
let mut weights = unsafe {
replace(&mut self.weights, Vector::uninitialized((0,))).into_raw_vec()
};
let mut weights =
unsafe { replace(&mut self.weights, Vector::uninitialized((0,))).into_raw_vec() };

weights.extend(new_weights);

Expand All @@ -37,7 +36,7 @@ impl Approximator<Projection> for Simple {
&Projection::Dense(ref dense) => self.weights.dot(&(dense / p.z())),
&Projection::Sparse(ref sparse) => {
sparse.iter().fold(0.0, |acc, idx| acc + self.weights[*idx])
},
}
})
}

Expand All @@ -56,13 +55,16 @@ impl Approximator<Projection> for Simple {
let n_nfs = new_features.len();
let max_index = self.weights.len() + n_nfs - 1;

let new_weights: Result<Vec<f64>, _> = new_features.into_iter().map(|(&i, idx)| {
if i > max_index {
Err(AdaptError::Failed)
} else {
Ok(idx.iter().fold(0.0, |acc, j| acc + self.weights[*j]) / (idx.len() as f64))
}
}).collect();
let new_weights: Result<Vec<f64>, _> = new_features
.into_iter()
.map(|(&i, idx)| {
if i > max_index {
Err(AdaptError::Failed)
} else {
Ok(idx.iter().fold(0.0, |acc, j| acc + self.weights[*j]) / (idx.len() as f64))
}
})
.collect();

self.extend_weights(new_weights?);

Expand All @@ -77,7 +79,7 @@ mod tests {
use LFA;
use approximators::{Approximator, Simple};
use projectors::fixed::{Fourier, TileCoding};
use std::collections::{HashMap, BTreeSet};
use std::collections::{BTreeSet, HashMap};
use std::hash::BuildHasherDefault;

type SHBuilder = BuildHasherDefault<seahash::SeaHasher>;
Expand Down Expand Up @@ -126,8 +128,8 @@ mod tests {
Ok(n) => {
assert_eq!(n, 1);
assert_eq!(f.weights.len(), 101);
assert_eq!(f.weights[100], f.weights[10]/2.0 + f.weights[90]/2.0);
},
assert_eq!(f.weights[100], f.weights[10] / 2.0 + f.weights[90] / 2.0);
}
Err(err) => panic!("Simple::adapt failed with AdaptError::{:?}", err),
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/lfa.rs
@@ -1,6 +1,6 @@
use approximators::{Approximator, Simple, Multi};
use approximators::{Approximator, Multi, Simple};
use error::*;
use projectors::{Projector, Projection, IndexT, IndexSet};
use projectors::{IndexSet, IndexT, Projection, Projector};
use std::collections::HashMap;
use std::marker::PhantomData;

Expand Down Expand Up @@ -52,7 +52,8 @@ impl<I: ?Sized, P: Projector<I>, A: Approximator<Projection>> Approximator<I> fo
}

fn update(&mut self, input: &I, update: Self::Value) -> UpdateResult<()> {
self.approximator.update(&self.projector.project(input), update)
self.approximator
.update(&self.projector.project(input), update)
}

fn adapt(&mut self, new_features: &HashMap<IndexT, IndexSet>) -> AdaptResult<usize> {
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
@@ -1,6 +1,6 @@
extern crate rand;
extern crate ndarray;
extern crate itertools;
extern crate ndarray;
extern crate rand;

pub extern crate spaces as geometry;

Expand All @@ -14,7 +14,7 @@ mod error;
pub use self::error::*;

pub mod projectors;
pub use self::projectors::{Projection, Projector, AdaptiveProjector};
pub use self::projectors::{AdaptiveProjector, Projection, Projector};

pub mod approximators;
pub use self::approximators::Approximator;
Expand Down

0 comments on commit ed2c998

Please sign in to comment.