Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample naming #248

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions algorithms/linfa-preprocessing/src/linear_scaling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,12 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
/// Panics if the shape of the records is not compatible with the shape of the dataset used for fitting.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let sample_names = x.sample_names();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_sample_names(sample_names)
.with_feature_names(feature_names)
}
}
Expand Down Expand Up @@ -566,6 +568,17 @@ mod tests {
assert_eq!(original_feature_names, transformed.feature_names())
}

#[test]
fn test_retain_sample_names() {
    // Sample names attached to the input must survive a standard scaling.
    let dataset = linfa_datasets::diabetes();
    let expected = dataset.sample_names();

    let scaler = LinearScaler::standard().fit(&dataset).unwrap();
    let transformed = scaler.transform(dataset);

    assert_eq!(expected, transformed.sample_names())
}

#[test]
#[should_panic]
fn test_transform_wrong_size_array_standard() {
Expand Down
10 changes: 10 additions & 0 deletions algorithms/linfa-preprocessing/src/norm_scaling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
/// Substitutes the records of the dataset with their scaled versions with unit norm.
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let sample_names = x.sample_names();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_feature_names(feature_names)
.with_sample_names(sample_names)
}
}

Expand Down Expand Up @@ -151,4 +153,12 @@ mod tests {
let transformed = NormScaler::l2().transform(dataset);
assert_eq!(original_feature_names, transformed.feature_names())
}

#[test]
fn test_retain_sample_names() {
    // L2 normalization must leave the sample names untouched.
    let dataset = linfa_datasets::diabetes();
    let expected = dataset.sample_names();

    let transformed = NormScaler::l2().transform(dataset);

    assert_eq!(expected, transformed.sample_names())
}
}
13 changes: 13 additions & 0 deletions algorithms/linfa-preprocessing/src/whitening.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,12 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets>
{
fn transform(&self, x: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let feature_names = x.feature_names();
let sample_names = x.sample_names();
let (records, targets, weights) = (x.records, x.targets, x.weights);
let records = self.transform(records.to_owned());
DatasetBase::new(records, targets)
.with_weights(weights)
.with_sample_names(sample_names)
.with_feature_names(feature_names)
}
}
Expand Down Expand Up @@ -324,6 +326,17 @@ mod tests {
assert_eq!(original_feature_names, transformed.feature_names())
}

#[test]
fn test_retain_sample_names() {
    // Whitening must propagate the sample names of the input dataset.
    let dataset = linfa_datasets::diabetes();
    let expected = dataset.sample_names();

    let whitener = Whitener::cholesky().fit(&dataset).unwrap();
    let transformed = whitener.transform(dataset);

    assert_eq!(expected, transformed.sample_names())
}

#[test]
#[should_panic]
fn test_pca_fail_on_empty_input() {
Expand Down
5 changes: 4 additions & 1 deletion algorithms/linfa-reduction/src/pca.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ impl<F: Float, D: Data<Elem = F>, T>
Transformer<DatasetBase<ArrayBase<D, Ix2>, T>, DatasetBase<Array2<F>, T>> for Pca<F>
{
fn transform(&self, ds: DatasetBase<ArrayBase<D, Ix2>, T>) -> DatasetBase<Array2<F>, T> {
let sample_names = ds.sample_names();
let DatasetBase {
records,
targets,
Expand All @@ -211,7 +212,9 @@ impl<F: Float, D: Data<Elem = F>, T>
let mut new_records = self.default_target(&records);
self.predict_inplace(&records, &mut new_records);

DatasetBase::new(new_records, targets).with_weights(weights)
DatasetBase::new(new_records, targets)
.with_weights(weights)
.with_sample_names(sample_names)
}
}
#[cfg(test)]
Expand Down
8 changes: 6 additions & 2 deletions algorithms/linfa-tsne/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,19 @@ impl<T, F: Float, R: Rng + Clone>
for TSneValidParams<F, R>
{
/// Embeds the records of `ds` with t-SNE, keeping targets, weights and
/// sample names of the input dataset.
///
/// ## Errors
///
/// Returns whatever error the underlying array-based transform produces.
fn transform(&self, ds: DatasetBase<Array2<F>, T>) -> Result<DatasetBase<Array2<F>, T>> {
    // Keep the sample names around; `ds` is consumed by the destructuring.
    let sample_names = ds.sample_names();
    let DatasetBase {
        records,
        targets,
        weights,
        ..
    } = ds;

    let embedded = self.transform(records)?;

    Ok(DatasetBase::new(embedded, targets)
        .with_weights(weights)
        .with_sample_names(sample_names))
}
}

Expand Down
22 changes: 20 additions & 2 deletions datasets/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ pub fn iris() -> Dataset<f64, usize, Ix1> {
);

let feature_names = vec!["sepal length", "sepal width", "petal length", "petal width"];
let sample_names = (0..data.nrows())
.map(|idx| format!("sample-{idx}"))
.collect();

Dataset::new(data, targets)
.map_targets(|x| *x as usize)
.with_feature_names(feature_names)
.with_sample_names(sample_names)
}

#[cfg(feature = "diabetes")]
Expand All @@ -57,8 +61,13 @@ pub fn diabetes() -> Dataset<f64, f64, Ix1> {
"lamotrigine",
"blood sugar level",
];
let sample_names = (0..data.nrows())
.map(|idx| format!("sample-{idx}"))
.collect();

Dataset::new(data, targets).with_feature_names(feature_names)
Dataset::new(data, targets)
.with_feature_names(feature_names)
.with_sample_names(sample_names)
}

#[cfg(feature = "winequality")]
Expand All @@ -85,10 +94,14 @@ pub fn winequality() -> Dataset<f64, usize, Ix1> {
"sulphates",
"alcohol",
];
let sample_names = (0..data.nrows())
.map(|idx| format!("sample-{idx}"))
.collect();

Dataset::new(data, targets)
.map_targets(|x| *x as usize)
.with_feature_names(feature_names)
.with_sample_names(sample_names)
}

#[cfg(feature = "linnerud")]
Expand All @@ -112,8 +125,13 @@ pub fn linnerud() -> Dataset<f64, f64> {
let output_array = array_from_buf(&output_data[..]);

let feature_names = vec!["Chins", "Situps", "Jumps"];
let sample_names = (0..input_array.nrows())
.map(|idx| format!("sample-{idx}"))
.collect();

Dataset::new(input_array, output_array).with_feature_names(feature_names)
Dataset::new(input_array, output_array)
.with_feature_names(feature_names)
.with_sample_names(sample_names)
}

#[cfg(test)]
Expand Down
76 changes: 74 additions & 2 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ impl<R: Records, S> DatasetBase<R, S> {
targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
sample_names: Vec::new(),
}
}

Expand Down Expand Up @@ -70,6 +71,19 @@ impl<R: Records, S> DatasetBase<R, S> {
}
}

/// Returns the sample names of the dataset.
///
/// A sample name is a human-readable string describing one row of the
/// records. If no names were set explicitly, default names of the form
/// `sample-<index>` are generated on the fly, one per sample.
pub fn sample_names(&self) -> Vec<String> {
    if self.sample_names.is_empty() {
        // No stored names: synthesize one default name per sample.
        (0..self.records.nsamples())
            .map(|idx| format!("sample-{}", idx))
            .collect()
    } else {
        self.sample_names.clone()
    }
}

/// Return records of a dataset
///
/// The records are data points from which predictions are made. This functions returns a
Expand All @@ -88,6 +102,7 @@ impl<R: Records, S> DatasetBase<R, S> {
targets: self.targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
sample_names: Vec::new(),
}
}

Expand All @@ -100,6 +115,7 @@ impl<R: Records, S> DatasetBase<R, S> {
targets,
weights: self.weights,
feature_names: self.feature_names,
sample_names: self.sample_names,
}
}

Expand All @@ -118,6 +134,29 @@ impl<R: Records, S> DatasetBase<R, S> {

self
}

/// Updates the sample names of a dataset
///
/// Each name is a human-readable string describing one row (sample) of the
/// records. Passing an empty vector is a no-op: the currently stored
/// sample names are left unchanged.
///
/// ## Panics
///
/// Panics if the names vector is non-empty and its length differs from the
/// number of samples in the records.
pub fn with_sample_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
    if names.len() == self.records().nsamples() {
        // Convert every name into an owned `String` before storing it.
        let sample_names = names.into_iter().map(|x| x.into()).collect();

        self.sample_names = sample_names;
    } else if !names.is_empty() {
        panic!(
            "Sample names vector length, {}, is different to nsamples, {}.",
            names.len(),
            self.records().nsamples()
        );
    }

    self
}
}

impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
Expand All @@ -143,6 +182,7 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
targets,
weights,
feature_names,
sample_names,
..
} = self;

Expand All @@ -153,6 +193,7 @@ impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
targets: targets.map(fnc),
weights,
feature_names,
sample_names,
}
}

Expand Down Expand Up @@ -215,6 +256,7 @@ where

DatasetBase::new(records, targets)
.with_feature_names(self.feature_names.clone())
.with_sample_names(self.sample_names.clone())
.with_weights(self.weights.clone())
}

Expand Down Expand Up @@ -287,13 +329,26 @@ where
} else {
(Array1::zeros(0), Array1::zeros(0))
};

let (first_sample_names, second_sample_names) =
if self.sample_names.len() == self.nsamples() {
(
self.sample_names.iter().take(n).collect(),
self.sample_names.iter().skip(n).collect(),
)
} else {
(Vec::new(), Vec::new())
};

let dataset1 = DatasetBase::new(records_first, targets_first)
.with_weights(first_weights)
.with_feature_names(self.feature_names.clone());
.with_feature_names(self.feature_names.clone())
.with_sample_names(first_sample_names);

let dataset2 = DatasetBase::new(records_second, targets_second)
.with_weights(second_weights)
.with_feature_names(self.feature_names.clone());
.with_feature_names(self.feature_names.clone())
.with_sample_names(second_sample_names);

(dataset1, dataset2)
}
Expand Down Expand Up @@ -339,6 +394,7 @@ where
label,
DatasetBase::new(self.records().view(), targets)
.with_feature_names(self.feature_names.clone())
.with_sample_names(self.sample_names.clone())
.with_weights(self.weights.clone()),
)
})
Expand Down Expand Up @@ -395,6 +451,7 @@ impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
targets: empty_targets,
weights: Array1::zeros(0),
feature_names: Vec::new(),
sample_names: Vec::new(),
}
}
}
Expand All @@ -411,6 +468,7 @@ where
targets: rec_tar.1,
weights: Array1::zeros(0),
feature_names: Vec::new(),
sample_names: Vec::new(),
}
}
}
Expand Down Expand Up @@ -977,12 +1035,26 @@ impl<F, E, I: TargetDim> Dataset<F, E, I> {
Array1::zeros(0)
};

// split sample_names into two disjoint Vec
let second_sample_names = if self.sample_names.len() == n1 + n2 {
let mut sample_names = self.sample_names;

let sample_names2 = sample_names.split_off(n1);
self.sample_names = sample_names;

sample_names2
} else {
Vec::new()
};

// create new datasets with attached weights
let dataset1 = Dataset::new(first, first_targets)
.with_weights(self.weights)
.with_sample_names(self.sample_names)
.with_feature_names(feature_names.clone());
let dataset2 = Dataset::new(second, second_targets)
.with_weights(second_weights)
.with_sample_names(second_sample_names)
.with_feature_names(feature_names);

(dataset1, dataset2)
Expand Down
1 change: 1 addition & 0 deletions src/dataset/impl_targets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ where
weights: Array1::from(weights),
targets,
feature_names: self.feature_names.clone(),
sample_names: self.sample_names.clone(),
}
}
}
2 changes: 2 additions & 0 deletions src/dataset/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ where
let mut targets = self.dataset.targets.as_targets();
let feature_names;
let weights = self.dataset.weights.clone();
let sample_names = self.dataset.sample_names.clone();

if !self.target_or_feature {
// This branch should only run for 2D targets
Expand All @@ -103,6 +104,7 @@ where
targets,
weights,
feature_names,
sample_names,
};

Some(dataset_view)
Expand Down
Loading