Feat/dataloader (#17)
nathanielsimard committed Aug 22, 2022
1 parent 1b8b1e3 commit 0f6b50f
Showing 21 changed files with 589 additions and 62 deletions.
9 changes: 7 additions & 2 deletions Cargo.toml
@@ -16,11 +16,16 @@ tch = ["burn-tensor/tch"]
ndarray = ["burn-tensor/ndarray"]

[dependencies]
num-traits = "0.2"
burn-tensor = { path = "./burn-tensor", version = "0.1.0", default-features = false }
burn-dataset = { path = "./burn-dataset", version = "0.1.0", default-features = false }
burn-derive = { path = "./burn-derive", version = "0.1.0" }
burn-dataset = { path = "./burn-dataset", version = "0.1.0" }

num-traits = "0.2"
derive-new = "0.5"
rand = "0.8"

serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

[dev-dependencies]
burn-dataset = { path = "./burn-dataset", version = "0.1.0", features = ["fake"] }
5 changes: 5 additions & 0 deletions burn-dataset/Cargo.toml
@@ -14,10 +14,15 @@ categories = ["science"]
license = "MIT"
edition = "2021"

[features]
default = ["fake"]
fake = ["dep:fake"]

[dependencies]
dirs = "4.0"
rand = "0.8.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
fake = { version = "2.5", optional = true }

[dev-dependencies]
2 changes: 1 addition & 1 deletion burn-dataset/src/dataset/dataset.rs
@@ -1,6 +1,6 @@
use crate::DatasetIterator;

pub trait Dataset<I> {
pub trait Dataset<I>: Send + Sync {
fn iter<'a>(&'a self) -> DatasetIterator<'a, I>;
fn get(&self, index: usize) -> Option<I>;
fn len(&self) -> usize;
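For context, a minimal sketch of why the new Send + Sync bound matters: it lets a dataset trait object be shared across worker threads, which a data loader needs. Not part of this commit; the crate path and the worker helper are illustrative.

use burn_dataset::Dataset;
use std::sync::Arc;
use std::thread;

// Only compiles because `dyn Dataset<I>` is now `Send + Sync`.
fn spawn_readers<I: 'static>(dataset: Arc<dyn Dataset<I>>) {
    let handles: Vec<_> = (0..4)
        .map(|worker_id| {
            let dataset = dataset.clone();
            thread::spawn(move || {
                let _item = dataset.get(worker_id);
                println!("worker {} sees {} items", worker_id, dataset.len());
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}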
32 changes: 32 additions & 0 deletions burn-dataset/src/dataset/fake.rs
@@ -0,0 +1,32 @@
use crate::{Dataset, DatasetIterator, InMemDataset};
use fake::{Dummy, Fake, Faker};

pub struct FakeDataset<I> {
dataset: InMemDataset<I>,
}

impl<I: Dummy<Faker>> FakeDataset<I> {
pub fn new(size: usize) -> Self {
let mut items = Vec::with_capacity(size);
for _ in 0..size {
items.push(Faker.fake());
}
let dataset = InMemDataset::new(items);

Self { dataset }
}
}

impl<I: Send + Sync + Clone> Dataset<I> for FakeDataset<I> {
fn iter<'a>(&'a self) -> DatasetIterator<'a, I> {
DatasetIterator::new(self)
}

fn get(&self, index: usize) -> Option<I> {
self.dataset.get(index)
}

fn len(&self) -> usize {
self.dataset.len()
}
}
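A quick usage sketch of the new FakeDataset (not part of this commit; requires the fake feature, crate path assumed). It mirrors how the partial-dataset tests below use it.

use burn_dataset::{Dataset, FakeDataset};

fn main() {
    // 10 items of random `String` data generated through the `fake` crate.
    let dataset = FakeDataset::<String>::new(10);
    assert_eq!(dataset.len(), 10);
    for item in dataset.iter() {
        println!("{}", item);
    }
}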
2 changes: 1 addition & 1 deletion burn-dataset/src/dataset/in_memory.rs
@@ -17,7 +17,7 @@ impl<I> InMemDataset<I> {

impl<I> Dataset<I> for InMemDataset<I>
where
I: Clone,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
match self.items.get(index) {
4 changes: 4 additions & 0 deletions burn-dataset/src/dataset/mod.rs
@@ -1,7 +1,11 @@
mod dataset;
#[cfg(feature = "fake")]
mod fake;
mod in_memory;
mod iterator;

#[cfg(feature = "fake")]
pub use self::fake::*;
pub use dataset::*;
pub use in_memory::*;
pub use iterator::*;
25 changes: 6 additions & 19 deletions burn-dataset/src/source/huggingface/mnist.rs
@@ -1,25 +1,24 @@
use super::downloader::cache_dir;
use crate::source::huggingface::downloader::{download, Extractor};
use crate::{Dataset, DatasetIterator, InMemDataset};
use serde::{Deserialize, Serialize};

use super::downloader::cache_dir;

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Item {
pub struct MNISTItem {
pub image: [[f32; 28]; 28],
pub label: usize,
}

pub struct MNISTDataset {
dataset: InMemDataset<Item>,
dataset: InMemDataset<MNISTItem>,
}

impl Dataset<Item> for MNISTDataset {
fn iter<'a>(&'a self) -> crate::DatasetIterator<'a, Item> {
impl Dataset<MNISTItem> for MNISTDataset {
fn iter<'a>(&'a self) -> crate::DatasetIterator<'a, MNISTItem> {
DatasetIterator::new(self)
}

fn get(&self, index: usize) -> Option<Item> {
fn get(&self, index: usize) -> Option<MNISTItem> {
self.dataset.get(index)
}

@@ -58,15 +57,3 @@ impl MNISTDataset {
Self { dataset }
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test() {
let dataset = MNISTDataset::test();
println!("{:?}", dataset.len());
assert_ne!(3, 3);
}
}
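For reference, a sketch of consuming the renamed MNISTItem type. Not part of this commit; the test() constructor is taken from the removed test module above and the module path is an assumption.

use burn_dataset::source::huggingface::MNISTDataset; // path assumed
use burn_dataset::Dataset;

fn main() {
    let dataset = MNISTDataset::test();
    if let Some(item) = dataset.get(0) {
        println!("label: {}, first pixel: {}", item.label, item.image[0][0]);
    }
}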
4 changes: 3 additions & 1 deletion burn-dataset/src/transform/mapper.rs
@@ -17,7 +17,9 @@ impl<M, I> MapperDataset<M, I> {

impl<M, I, O> Dataset<O> for MapperDataset<M, I>
where
M: Mapper<I, O>,
M: Mapper<I, O> + Send + Sync,
I: Send + Sync,
O: Send + Sync,
{
fn get(&self, index: usize) -> Option<O> {
let item = self.dataset.get(index);
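A sketch of a Mapper that satisfies the new bounds: a stateless mapper is trivially Send + Sync. Not part of this commit; the Mapper path and its map signature are assumptions, since neither appears in this diff.

use burn_dataset::transform::Mapper; // path and `map` signature assumed

struct Doubler;

impl Mapper<i32, i32> for Doubler {
    // No internal state, so `Doubler` is automatically `Send + Sync`.
    fn map(&self, item: &i32) -> i32 {
        item * 2
    }
}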
2 changes: 2 additions & 0 deletions burn-dataset/src/transform/mod.rs
@@ -1,7 +1,9 @@
mod composed;
mod mapper;
mod partial;
mod random;

pub use composed::*;
pub use mapper::*;
pub use partial::*;
pub use random::*;
144 changes: 144 additions & 0 deletions burn-dataset/src/transform/partial.rs
@@ -0,0 +1,144 @@
use crate::{Dataset, DatasetIterator};
use std::sync::Arc;

pub struct PartialDataset<I> {
dataset: Arc<dyn Dataset<I>>,
start_index: usize,
end_index: usize,
}

impl<I> PartialDataset<I> {
pub fn new(dataset: Arc<dyn Dataset<I>>, start_index: usize, end_index: usize) -> Self {
Self {
dataset,
start_index,
end_index,
}
}
pub fn split(dataset: Arc<dyn Dataset<I>>, num: usize) -> Vec<PartialDataset<I>> {
let mut current = 0;
let mut datasets = Vec::with_capacity(num);

let batch_size = dataset.len() / num;

for i in 0..num {
let start = current;
let mut end = current + batch_size;

if i == (num - 1) {
end = dataset.len();
}

let dataset = PartialDataset::new(dataset.clone(), start, end);

current += batch_size;
datasets.push(dataset);
}

datasets
}
}

impl<I> Dataset<I> for PartialDataset<I>
where
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
let index = index + self.start_index;
if index < self.start_index || index >= self.end_index {
return None;
}
self.dataset.get(index)
}

fn iter<'a>(&'a self) -> DatasetIterator<'a, I> {
DatasetIterator::new(self)
}
fn len(&self) -> usize {
usize::min(self.end_index - self.start_index, self.dataset.len())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::FakeDataset;
use std::collections::HashSet;

#[test]
fn test_start_from_beginning() {
let dataset_original = Arc::new(FakeDataset::<String>::new(27));
let dataset_partial = PartialDataset::new(dataset_original.clone(), 0, 10);

let mut items_original_1 = HashSet::new();
let mut items_original_2 = HashSet::new();
let mut items_partial = HashSet::new();

for (i, item) in dataset_original.iter().enumerate() {
if i >= 10 {
items_original_2.insert(item);
} else {
items_original_1.insert(item);
}
}

for item in dataset_partial.iter() {
items_partial.insert(item);
}

assert_eq!(dataset_partial.len(), 10);
assert_eq!(items_original_1, items_partial);
for item in items_original_2 {
assert!(!items_partial.contains(&item));
}
}

#[test]
fn test_start_inside() {
let dataset_original = Arc::new(FakeDataset::<String>::new(27));
let dataset_partial = PartialDataset::new(dataset_original.clone(), 10, 20);

let mut items_original_1 = HashSet::new();
let mut items_original_2 = HashSet::new();
let mut items_partial = HashSet::new();

for (i, item) in dataset_original.iter().enumerate() {
if i < 10 || i >= 20 {
items_original_2.insert(item);
} else {
items_original_1.insert(item);
}
}

for item in dataset_partial.iter() {
items_partial.insert(item);
}

assert_eq!(dataset_partial.len(), 10);
assert_eq!(items_original_1, items_partial);
for item in items_original_2 {
assert!(!items_partial.contains(&item));
}
}

#[test]
fn test_split_contains_all_items_without_duplicates() {
let dataset_original = Arc::new(FakeDataset::<String>::new(27));
let dataset_partials = PartialDataset::split(dataset_original.clone(), 4);

let mut items_original = Vec::new();
let mut items_partial = Vec::new();

for item in dataset_original.iter() {
items_original.push(item);
}

for dataset in dataset_partials {
for item in dataset.iter() {
items_partial.push(item);
}
}

assert_eq!(items_original, items_partial);
}
}
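A short usage sketch of PartialDataset::split, e.g. carving a dataset into equal parts, one per data-loader worker. Not part of this commit; crate paths assumed and FakeDataset is used only for illustration.

use burn_dataset::transform::PartialDataset; // path assumed
use burn_dataset::{Dataset, FakeDataset};
use std::sync::Arc;

fn main() {
    let dataset: Arc<dyn Dataset<String>> = Arc::new(FakeDataset::<String>::new(100));
    // Four non-overlapping views over the same data; the last part absorbs the remainder.
    let parts = PartialDataset::split(dataset, 4);
    for (i, part) in parts.iter().enumerate() {
        println!("partition {} holds {} items", i, part.len());
    }
}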
16 changes: 10 additions & 6 deletions burn-dataset/src/transform/random.rs
@@ -1,23 +1,27 @@
use crate::{Dataset, DatasetIterator};
use rand::seq::SliceRandom;
use rand::thread_rng;
use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
use std::sync::Arc;

pub struct ShuffledDataset<I> {
dataset: Box<dyn Dataset<I>>,
dataset: Arc<dyn Dataset<I>>,
indexes: Vec<usize>,
}

impl<I> ShuffledDataset<I> {
pub fn new(dataset: Box<dyn Dataset<I>>) -> Self {
pub fn new(dataset: Arc<dyn Dataset<I>>, rng: &mut StdRng) -> Self {
let mut indexes = Vec::with_capacity(dataset.len());
for i in 0..dataset.len() {
indexes.push(i);
}
let mut rng = thread_rng();
indexes.shuffle(&mut rng);
indexes.shuffle(rng);

Self { dataset, indexes }
}

pub fn with_seed(dataset: Arc<dyn Dataset<I>>, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
Self::new(dataset, &mut rng)
}
}

impl<I> Dataset<I> for ShuffledDataset<I>
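And a sketch of the reworked ShuffledDataset, which is now reproducible through a seed. Not part of this commit; crate paths are assumed.

use burn_dataset::transform::ShuffledDataset; // path assumed
use burn_dataset::{Dataset, FakeDataset};
use std::sync::Arc;

fn main() {
    let dataset: Arc<dyn Dataset<String>> = Arc::new(FakeDataset::<String>::new(16));
    // Same seed, same permutation: shuffling is deterministic and repeatable.
    let shuffled_a = ShuffledDataset::with_seed(dataset.clone(), 42);
    let shuffled_b = ShuffledDataset::with_seed(dataset, 42);
    assert_eq!(shuffled_a.get(0), shuffled_b.get(0));
}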
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/api/ad.rs
@@ -20,6 +20,10 @@ impl<const D: usize, B: ADBackend> Tensor<B, D> {
pub fn update(&mut self, other_inner: Tensor<B::InnerBackend, D>) {
self.value = B::from_inner(other_inner.value);
}

pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
Self::new(B::from_inner(inner.value))
}
}

#[cfg(feature = "ndarray")]
4 changes: 3 additions & 1 deletion burn-tensor/src/tensor/backend/backend.rs
@@ -22,6 +22,8 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'static
+ TensorOpsMask<Self, D>
+ TensorOpsMapComparison<Self, D>
+ ReLU<Self::Elem, D>
+ Send
+ Sync
+ 'static;
type BoolTensorPrimitive<const D: usize>: TensorOpsUtilities<bool, D>
+ Clone
@@ -58,7 +60,7 @@ pub type ADBackendTensorPrimitive<const D: usize, B> =
<<B as ADBackend>::InnerBackend as Backend>::TensorPrimitive<D>;

pub trait ADBackend: Backend {
type InnerBackend: Backend;
type InnerBackend: Backend<Device = Self::Device>;

fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Gradients;
fn grad<const D: usize>(
