Add multi-label classification dataset and metric (#1572)
* Add multilabel classification dataset

- Add MultiLabel annotation support
- Refactor de/serialize annotation with AnnotationRaw
- Add ImageFolderDataset::with_items methods

* Fix custom-image-classification example deps

* Add image_folder_dataset_multilabel test

* Do not change class names order when provided

* Add hamming score and multi-label classification output

* Add new_classification_with_items test

* Fix clippy suggestions

* Implement default trait for hamming score

* Remove de/serialization and use AnnotationRaw as type

* Fix clippy

* Fix metric backend phantom data
laggui authored Apr 5, 2024
1 parent f5159b6 commit f3e0aa6
Showing 7 changed files with 413 additions and 36 deletions.
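
A minimal usage sketch of the new multi-label constructor introduced in this commit (the image paths and label names below are hypothetical, and the example assumes ImageFolderDataset and ImageLoaderError are reachable under burn_dataset::vision with the vision feature enabled):

use burn_dataset::vision::{ImageFolderDataset, ImageLoaderError};

fn build_dataset() -> Result<ImageFolderDataset, ImageLoaderError> {
    // Hypothetical image paths; each item pairs a path with the class names that apply to it.
    let items = vec![
        ("images/cat_garden.jpg", vec!["cat".to_string(), "outdoor".to_string()]),
        ("images/dog_sofa.png", vec!["dog".to_string(), "indoor".to_string()]),
    ];
    // Class order determines label indices: cat = 0, dog = 1, indoor = 2, outdoor = 3.
    ImageFolderDataset::new_multilabel_classification_with_items(
        items,
        &["cat", "dog", "indoor", "outdoor"],
    )
}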
247 changes: 217 additions & 30 deletions crates/burn-dataset/src/vision/image_folder.rs
@@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
}
}

/// Image target for different tasks.
/// Annotation type for different tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum Annotation {
/// Image-level label.
Label(usize),
/// Multiple image-level labels.
MultiLabel(Vec<usize>),
/// Object bounding boxes.
BoundingBoxes(Vec<BoundingBox>),
/// Segmentation mask.
@@ -97,30 +99,56 @@ pub struct ImageDatasetItem {
pub annotation: Annotation,
}

/// Raw annotation types.
#[derive(Deserialize, Serialize, Debug, Clone)]
enum AnnotationRaw {
Label(String),
MultiLabel(Vec<String>),
// TODO: bounding boxes and segmentation mask
}

#[derive(Deserialize, Serialize, Debug, Clone)]
struct ImageDatasetItemRaw {
/// Image path.
pub image_path: PathBuf,
image_path: PathBuf,

/// Image annotation.
/// The annotation bytes can represent a string (category name) or path to annotation file.
pub annotation: Vec<u8>,
annotation: AnnotationRaw,
}

impl ImageDatasetItemRaw {
fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
ImageDatasetItemRaw {
image_path: image_path.as_ref().to_path_buf(),
annotation,
}
}
}

struct PathToImageDatasetItem {
classes: HashMap<String, usize>,
}

/// Parse the image annotation to the corresponding type.
fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -> Annotation {
fn parse_image_annotation(
annotation: &AnnotationRaw,
classes: &HashMap<String, usize>,
) -> Annotation {
// TODO: add support for other annotations
// - [ ] Object bounding boxes
// - [ ] Segmentation mask
// For now, only image classification labels are supported.

// Map class string to label id
let name = std::str::from_utf8(annotation).unwrap();
Annotation::Label(*classes.get(name).unwrap())
match annotation {
AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()),
AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
names
.iter()
.map(|name| *classes.get(name).unwrap())
.collect(),
),
}
}

impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
@@ -212,7 +240,7 @@ pub enum ImageLoaderError {
type ImageDatasetMapper =
MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;

/// A generic dataset to load classification images from disk.
/// A generic dataset to load images from disk.
pub struct ImageFolderDataset {
dataset: ImageDatasetMapper,
}
@@ -259,26 +287,14 @@ impl ImageFolderDataset {
P: AsRef<Path>,
S: AsRef<str>,
{
/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(ImageLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}

// Glob all images with extensions
let walker = globwalk::GlobWalkerBuilder::from_patterns(
root.as_ref(),
&[format!(
"*.{{{}}}", // "*.{ext1,ext2,ext3}
extensions
.iter()
.map(check_extension)
.map(Self::check_extension)
.collect::<Result<Vec<_>, _>>()?
.join(",")
)],
@@ -312,21 +328,102 @@ impl ImageFolderDataset {

classes.insert(label.clone());

items.push(ImageDatasetItemRaw {
image_path: image_path.to_path_buf(),
annotation: label.into_bytes(),
})
items.push(ImageDatasetItemRaw::new(
image_path,
AnnotationRaw::Label(label),
))
}

// Sort class names
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();

Self::with_items(items, &classes)
}

/// Create an image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, String)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, label)| {
// Map image path and label
let path = path.as_ref();
let label = AnnotationRaw::Label(label);

Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(path, label))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}

/// Create a multi-label image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, Vec<String>)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, labels)| {
// Map image path and multi-label
let path = path.as_ref();
let labels = AnnotationRaw::MultiLabel(labels);

Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;

Ok(ImageDatasetItemRaw::new(path, labels))
})
.collect::<Result<Vec<_>, _>>()?;

Self::with_items(items, classes)
}

/// Create an image dataset with the specified items.
///
/// # Arguments
///
/// * `items` - Raw dataset items.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
fn with_items<S: AsRef<str>>(
items: Vec<ImageDatasetItemRaw>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// NOTE: right now we don't need to validate the supported image files since
// the method is private. We assume it's already validated.
let dataset = InMemDataset::new(items);

// Class names to index map
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();
let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
let classes_map: HashMap<_, _> = classes
.into_iter()
.enumerate()
.map(|(idx, cls)| (cls, idx))
.map(|(idx, cls)| (cls.to_string(), idx))
.collect();
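// With the sort removed above, label indices follow the order in which the
// caller supplied the classes: e.g. ["dot", "orange", "red"] maps to
// {"dot": 0, "orange": 1, "red": 2}.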

let mapper = PathToImageDatasetItem {
@@ -336,6 +433,18 @@ impl ImageFolderDataset {

Ok(Self { dataset })
}

/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(ImageLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}
}

#[cfg(test)]
@@ -370,6 +479,69 @@ mod tests {
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
}

#[test]
pub fn image_folder_dataset_with_items() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(root.join("orange").join("dot.jpg"), "orange".to_string()),
(root.join("red").join("dot.jpg"), "red".to_string()),
(root.join("red").join("dot.png"), "red".to_string()),
];
let dataset =
ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap();

// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);

// Dataset elements should be: orange (0), red (1), red (1)
assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));
}

#[test]
pub fn image_folder_dataset_multilabel() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(
root.join("orange").join("dot.jpg"),
vec!["dot".to_string(), "orange".to_string()],
),
(
root.join("red").join("dot.jpg"),
vec!["dot".to_string(), "red".to_string()],
),
(
root.join("red").join("dot.png"),
vec!["dot".to_string(), "red".to_string()],
),
];
let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
items,
&["dot", "orange", "red"],
)
.unwrap();

// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);

// Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
assert_eq!(
dataset.get(0).unwrap().annotation,
Annotation::MultiLabel(vec![0, 1])
);
assert_eq!(
dataset.get(1).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
assert_eq!(
dataset.get(2).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
}

#[test]
#[should_panic]
pub fn image_folder_dataset_invalid_extension() {
@@ -417,11 +589,26 @@ mod tests {
}

#[test]
pub fn parse_image_annotation_string() {
pub fn parse_image_annotation_label_string() {
let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
let anno = AnnotationRaw::Label("0".to_string());
assert_eq!(
parse_image_annotation(&"0".to_string().into_bytes(), &classes),
parse_image_annotation(&anno, &classes),
Annotation::Label(0)
);
}

#[test]
pub fn parse_image_annotation_multilabel_string() {
let classes = HashMap::from([
("0".to_string(), 0_usize),
("1".to_string(), 1_usize),
("2".to_string(), 2_usize),
]);
let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]);
assert_eq!(
parse_image_annotation(&anno, &classes),
Annotation::MultiLabel(vec![0, 2])
);
}
}
27 changes: 26 additions & 1 deletion crates/burn-train/src/learner/classification.rs
@@ -1,4 +1,4 @@
use crate::metric::{AccuracyInput, Adaptor, LossInput};
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};

@@ -26,3 +26,28 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
LossInput::new(self.loss.clone())
}
}

/// Multi-label classification output adapted for multiple metrics.
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,

/// The output.
pub output: Tensor<B, 2>,

/// The targets.
pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
HammingScoreInput::new(self.output.clone(), self.targets.clone())
}
}

impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
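
The HammingScore metric that consumes HammingScoreInput is added by this commit in a file not shown in this excerpt. As a rough, stand-alone sketch of one common definition of a hamming score for multi-label classification — the element-wise agreement between thresholded predictions and binary targets, averaged over the batch — the following plain-Rust function works on boolean vectors rather than burn tensors:

/// Sketch only: hamming score over plain boolean label vectors.
/// For each sample, count the labels where prediction and target agree,
/// divide by the number of classes, then average across samples.
/// Assumes non-empty, equally sized inputs.
fn hamming_score(predictions: &[Vec<bool>], targets: &[Vec<bool>]) -> f64 {
    let per_sample_sum: f64 = predictions
        .iter()
        .zip(targets)
        .map(|(pred, tgt)| {
            let agree = pred.iter().zip(tgt).filter(|(p, t)| p == t).count();
            agree as f64 / tgt.len() as f64
        })
        .sum();
    per_sample_sum / predictions.len() as f64
}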
4 changes: 3 additions & 1 deletion crates/burn-train/src/metric/acc.rs
@@ -1,3 +1,5 @@
use core::marker::PhantomData;

use super::state::{FormatOptions, NumericMetricState};
use super::{MetricEntry, MetricMetadata};
use crate::metric::{Metric, Numeric};
@@ -9,7 +11,7 @@ use burn_core::tensor::{ElementConversion, Int, Tensor};
pub struct AccuracyMetric<B: Backend> {
state: NumericMetricState,
pad_token: Option<usize>,
_b: B,
_b: PhantomData<B>,
}

/// The [accuracy metric](AccuracyMetric) input type.
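
The acc.rs hunk above replaces a stored backend value (_b: B) with a zero-sized marker. A small stand-alone sketch of the pattern, not tied to any burn API:

use core::marker::PhantomData;

// A struct that is generic over a type it never stores a value of.
// PhantomData<B> keeps the type parameter in the struct's signature without
// requiring an actual B value (or a B: Default bound) to construct it.
struct DummyMetric<B> {
    total: f64,
    _b: PhantomData<B>,
}

impl<B> DummyMetric<B> {
    fn new() -> Self {
        Self { total: 0.0, _b: PhantomData }
    }
}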