diff --git a/burn-train/src/checkpoint/strategy/base.rs b/burn-train/src/checkpoint/strategy/base.rs index fdab52b6db..f16acfeb41 100644 --- a/burn-train/src/checkpoint/strategy/base.rs +++ b/burn-train/src/checkpoint/strategy/base.rs @@ -1,6 +1,7 @@ -use crate::EventCollector; use std::ops::DerefMut; +use crate::metric::store::EventStoreClient; + /// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer). #[derive(Clone, PartialEq, Debug)] pub enum CheckpointingAction { @@ -11,15 +12,23 @@ pub enum CheckpointingAction { } /// Define when checkpoint should be saved and deleted. -pub trait CheckpointingStrategy { +pub trait CheckpointingStrategy { /// Based on the epoch, determine if the checkpoint should be saved. - fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec; + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec; } // We make dyn box implement the checkpointing strategy so that it can be used with generic, but // still be dynamic. -impl CheckpointingStrategy for Box> { - fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec { +impl CheckpointingStrategy for Box { + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec { self.deref_mut().checkpointing(epoch, collector) } } diff --git a/burn-train/src/checkpoint/strategy/composed.rs b/burn-train/src/checkpoint/strategy/composed.rs index b43c3f7d5f..6ca3197788 100644 --- a/burn-train/src/checkpoint/strategy/composed.rs +++ b/burn-train/src/checkpoint/strategy/composed.rs @@ -1,45 +1,40 @@ +use crate::metric::store::EventStoreClient; + use super::{CheckpointingAction, CheckpointingStrategy}; -use crate::EventCollector; use std::collections::HashSet; /// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an /// epoch to be deleted. -pub struct ComposedCheckpointingStrategy { - strategies: Vec>>, +pub struct ComposedCheckpointingStrategy { + strategies: Vec>, deleted: Vec>, } /// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones. -pub struct ComposedCheckpointingStrategyBuilder { - strategies: Vec>>, +#[derive(Default)] +pub struct ComposedCheckpointingStrategyBuilder { + strategies: Vec>, } -impl Default for ComposedCheckpointingStrategyBuilder { - fn default() -> Self { - Self { - strategies: Vec::new(), - } - } -} - -impl ComposedCheckpointingStrategyBuilder { +impl ComposedCheckpointingStrategyBuilder { /// Add a new [checkpointing strategy](CheckpointingStrategy). + #[allow(clippy::should_implement_trait)] pub fn add(mut self, strategy: S) -> Self where - S: CheckpointingStrategy + 'static, + S: CheckpointingStrategy + 'static, { self.strategies.push(Box::new(strategy)); self } /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy). - pub fn build(self) -> ComposedCheckpointingStrategy { + pub fn build(self) -> ComposedCheckpointingStrategy { ComposedCheckpointingStrategy::new(self.strategies) } } -impl ComposedCheckpointingStrategy { - fn new(strategies: Vec>>) -> Self { +impl ComposedCheckpointingStrategy { + fn new(strategies: Vec>) -> Self { Self { deleted: strategies.iter().map(|_| HashSet::new()).collect(), strategies, @@ -47,13 +42,17 @@ impl ComposedCheckpointingStrategy { } /// Create a new builder which help compose multiple /// [checkpointing strategies](CheckpointingStrategy). - pub fn builder() -> ComposedCheckpointingStrategyBuilder { + pub fn builder() -> ComposedCheckpointingStrategyBuilder { ComposedCheckpointingStrategyBuilder::default() } } -impl CheckpointingStrategy for ComposedCheckpointingStrategy { - fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec { +impl CheckpointingStrategy for ComposedCheckpointingStrategy { + fn checkpointing( + &mut self, + epoch: usize, + collector: &EventStoreClient, + ) -> Vec { let mut saved = false; let mut actions = Vec::new(); let mut epochs_to_check = Vec::new(); @@ -104,15 +103,12 @@ impl CheckpointingStrategy for ComposedCheckpointingStrate #[cfg(test)] mod tests { - use crate::{ - checkpoint::KeepLastNCheckpoints, info::MetricsInfo, test_utils::TestEventCollector, - }; - use super::*; + use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore}; #[test] fn should_delete_when_both_deletes() { - let mut collector = TestEventCollector::::new(MetricsInfo::new()); + let store = EventStoreClient::new(LogEventStore::default()); let mut strategy = ComposedCheckpointingStrategy::builder() .add(KeepLastNCheckpoints::new(1)) .add(KeepLastNCheckpoints::new(2)) @@ -120,17 +116,17 @@ mod tests { assert_eq!( vec![CheckpointingAction::Save], - strategy.checkpointing(1, &mut collector) + strategy.checkpointing(1, &store) ); assert_eq!( vec![CheckpointingAction::Save], - strategy.checkpointing(2, &mut collector) + strategy.checkpointing(2, &store) ); assert_eq!( vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], - strategy.checkpointing(3, &mut collector) + strategy.checkpointing(3, &store) ); } } diff --git a/burn-train/src/checkpoint/strategy/lastn.rs b/burn-train/src/checkpoint/strategy/lastn.rs index 7bbbf0f0c5..66f5df91bf 100644 --- a/burn-train/src/checkpoint/strategy/lastn.rs +++ b/burn-train/src/checkpoint/strategy/lastn.rs @@ -1,5 +1,5 @@ use super::CheckpointingStrategy; -use crate::{checkpoint::CheckpointingAction, EventCollector}; +use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient}; /// Keep the last N checkpoints. /// @@ -10,8 +10,12 @@ pub struct KeepLastNCheckpoints { num_keep: usize, } -impl CheckpointingStrategy for KeepLastNCheckpoints { - fn checkpointing(&mut self, epoch: usize, _collector: &mut E) -> Vec { +impl CheckpointingStrategy for KeepLastNCheckpoints { + fn checkpointing( + &mut self, + epoch: usize, + _store: &EventStoreClient, + ) -> Vec { let mut actions = vec![CheckpointingAction::Save]; if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) { @@ -26,28 +30,27 @@ impl CheckpointingStrategy for KeepLastNCheckpoints { #[cfg(test)] mod tests { - use crate::{info::MetricsInfo, test_utils::TestEventCollector}; - use super::*; + use crate::metric::store::LogEventStore; #[test] fn should_always_delete_lastn_epoch_if_higher_than_one() { let mut strategy = KeepLastNCheckpoints::new(2); - let mut collector = TestEventCollector::::new(MetricsInfo::new()); + let store = EventStoreClient::new(LogEventStore::default()); assert_eq!( vec![CheckpointingAction::Save], - strategy.checkpointing(1, &mut collector) + strategy.checkpointing(1, &store) ); assert_eq!( vec![CheckpointingAction::Save], - strategy.checkpointing(2, &mut collector) + strategy.checkpointing(2, &store) ); assert_eq!( vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)], - strategy.checkpointing(3, &mut collector) + strategy.checkpointing(3, &store) ); } } diff --git a/burn-train/src/checkpoint/strategy/metric.rs b/burn-train/src/checkpoint/strategy/metric.rs index 5ef3826d24..f2aa58efeb 100644 --- a/burn-train/src/checkpoint/strategy/metric.rs +++ b/burn-train/src/checkpoint/strategy/metric.rs @@ -1,6 +1,10 @@ use super::CheckpointingStrategy; use crate::{ - checkpoint::CheckpointingAction, metric::Metric, Aggregate, Direction, EventCollector, Split, + checkpoint::CheckpointingAction, + metric::{ + store::{Aggregate, Direction, EventStoreClient, Split}, + Metric, + }, }; /// Keep the best checkpoint based on a metric. @@ -28,10 +32,14 @@ impl MetricCheckpointingStrategy { } } -impl CheckpointingStrategy for MetricCheckpointingStrategy { - fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec { +impl CheckpointingStrategy for MetricCheckpointingStrategy { + fn checkpointing( + &mut self, + epoch: usize, + store: &EventStoreClient, + ) -> Vec { let best_epoch = - match collector.find_epoch(&self.name, self.aggregate, self.direction, self.split) { + match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) { Some(epoch_best) => epoch_best, None => epoch, }; @@ -56,93 +64,70 @@ impl CheckpointingStrategy for MetricCheckpointingStrategy #[cfg(test)] mod tests { - use burn_core::tensor::{backend::Backend, ElementConversion, Tensor}; - - use super::*; use crate::{ - info::MetricsInfo, logger::InMemoryMetricLogger, - metric::{Adaptor, LossInput, LossMetric}, - test_utils::TestEventCollector, - Event, LearnerItem, TestBackend, + metric::{ + processor::{ + test_utils::{end_epoch, process_train}, + Metrics, MinimalEventProcessor, + }, + store::LogEventStore, + LossMetric, + }, + TestBackend, }; + use std::sync::Arc; + + use super::*; #[test] fn always_keep_the_best_epoch() { + let mut store = LogEventStore::default(); let mut strategy = MetricCheckpointingStrategy::new::>( Aggregate::Mean, Direction::Lowest, Split::Train, ); - let mut info = MetricsInfo::new(); + let mut metrics = Metrics::::default(); // Register an in memory logger. - info.register_logger_train(InMemoryMetricLogger::default()); + store.register_logger_train(InMemoryMetricLogger::default()); // Register the loss metric. - info.register_train_metric_numeric(LossMetric::::new()); - - let mut collector = TestEventCollector::::new(info); + metrics.register_train_metric_numeric(LossMetric::::new()); + let store = Arc::new(EventStoreClient::new(store)); + let mut processor = MinimalEventProcessor::new(metrics, store.clone()); // Two points for the first epoch. Mean 0.75 let mut epoch = 1; - item(&mut collector, 1.0, epoch); - item(&mut collector, 0.5, epoch); - end_epoch(&mut collector, epoch); + process_train(&mut processor, 1.0, epoch); + process_train(&mut processor, 0.5, epoch); + end_epoch(&mut processor, epoch); // Should save the current record. assert_eq!( vec![CheckpointingAction::Save], - strategy.checkpointing(epoch, &mut collector) + strategy.checkpointing(epoch, &store) ); // Two points for the second epoch. Mean 0.4 epoch += 1; - item(&mut collector, 0.5, epoch); - item(&mut collector, 0.3, epoch); - end_epoch(&mut collector, epoch); + process_train(&mut processor, 0.5, epoch); + process_train(&mut processor, 0.3, epoch); + end_epoch(&mut processor, epoch); // Should save the current record and delete the pervious one. assert_eq!( vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], - strategy.checkpointing(epoch, &mut collector) + strategy.checkpointing(epoch, &store) ); // Two points for the last epoch. Mean 2.0 epoch += 1; - item(&mut collector, 1.0, epoch); - item(&mut collector, 3.0, epoch); - end_epoch(&mut collector, epoch); + process_train(&mut processor, 1.0, epoch); + process_train(&mut processor, 3.0, epoch); + end_epoch(&mut processor, epoch); // Should not delete the previous record, since it's the best one, and should not save a // new one. - assert!(strategy.checkpointing(epoch, &mut collector).is_empty()); - } - - fn item(collector: &mut TestEventCollector, value: f64, epoch: usize) { - let dummy_progress = burn_core::data::dataloader::Progress { - items_processed: 1, - items_total: 10, - }; - let num_epochs = 3; - let dummy_iteration = 1; - - collector.on_event_train(Event::ProcessedItem(LearnerItem::new( - value, - dummy_progress, - epoch, - num_epochs, - dummy_iteration, - None, - ))); - } - - fn end_epoch(collector: &mut TestEventCollector, epoch: usize) { - collector.on_event_train(Event::EndEpoch(epoch)); - collector.on_event_valid(Event::EndEpoch(epoch)); - } - - impl Adaptor> for f64 { - fn adapt(&self) -> LossInput { - LossInput::new(Tensor::from_data([self.elem()])) - } + assert!(strategy.checkpointing(epoch, &store).is_empty()); } } diff --git a/burn-train/src/collector/async_collector.rs b/burn-train/src/collector/async_collector.rs deleted file mode 100644 index eb4f75d705..0000000000 --- a/burn-train/src/collector/async_collector.rs +++ /dev/null @@ -1,118 +0,0 @@ -use super::EventCollector; -use crate::{Aggregate, Direction, Event, Split}; -use std::{sync::mpsc, thread::JoinHandle}; - -enum Message { - OnEventTrain(Event), - OnEventValid(Event), - End, - FindEpoch( - String, - Aggregate, - Direction, - Split, - mpsc::SyncSender>, - ), -} - -/// Async [event collector](EventCollector). -/// -/// This will create a worker thread where all the computation is done ensuring that the training loop is -/// never blocked by metric calculation. -pub struct AsyncEventCollector { - sender: mpsc::Sender>, - handler: Option>, -} - -#[derive(new)] -struct WorkerThread { - collector: C, - receiver: mpsc::Receiver>, -} - -impl WorkerThread -where - C: EventCollector, -{ - fn run(mut self) { - for item in self.receiver.iter() { - match item { - Message::End => { - return; - } - Message::FindEpoch(name, aggregate, direction, split, sender) => { - let response = self - .collector - .find_epoch(&name, aggregate, direction, split); - sender.send(response).unwrap(); - } - Message::OnEventTrain(event) => self.collector.on_event_train(event), - Message::OnEventValid(event) => self.collector.on_event_valid(event), - } - } - } -} - -impl AsyncEventCollector { - /// Create a new async [event collector](EventCollector). - pub fn new(collector: C) -> Self - where - C: EventCollector + 'static, - { - let (sender, receiver) = mpsc::channel(); - let thread = WorkerThread::new(collector, receiver); - - let handler = std::thread::spawn(move || thread.run()); - let handler = Some(handler); - - Self { sender, handler } - } -} - -impl EventCollector for AsyncEventCollector { - type ItemTrain = T; - type ItemValid = V; - - fn on_event_train(&mut self, event: Event) { - self.sender.send(Message::OnEventTrain(event)).unwrap(); - } - - fn on_event_valid(&mut self, event: Event) { - self.sender.send(Message::OnEventValid(event)).unwrap(); - } - - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - let (sender, receiver) = mpsc::sync_channel(1); - self.sender - .send(Message::FindEpoch( - name.to_string(), - aggregate, - direction, - split, - sender, - )) - .unwrap(); - - match receiver.recv() { - Ok(value) => value, - Err(err) => panic!("Async server crashed: {:?}", err), - } - } -} - -impl Drop for AsyncEventCollector { - fn drop(&mut self) { - self.sender.send(Message::End).unwrap(); - let handler = self.handler.take(); - - if let Some(handler) = handler { - handler.join().unwrap(); - } - } -} diff --git a/burn-train/src/collector/base.rs b/burn-train/src/collector/base.rs deleted file mode 100644 index 0e2e3deebb..0000000000 --- a/burn-train/src/collector/base.rs +++ /dev/null @@ -1,134 +0,0 @@ -use burn_core::{data::dataloader::Progress, LearningRate}; - -/// Event happening during the training/validation process. -pub enum Event { - /// Signal that an item have been processed. - ProcessedItem(LearnerItem), - /// Signal the end of an epoch. - EndEpoch(usize), -} - -/// Defines how training and validation events are collected. -/// -/// This trait also exposes methods that uses the collected data to compute useful information. -pub trait EventCollector: Send { - /// Training item. - type ItemTrain; - /// Validation item. - type ItemValid; - - /// Collect the training event. - fn on_event_train(&mut self, event: Event); - - /// Collect the validaion event. - fn on_event_valid(&mut self, event: Event); - - /// Find the epoch following the given criteria from the collected data. - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option; -} - -#[derive(Copy, Clone)] -/// How to aggregate the metric. -pub enum Aggregate { - /// Compute the average. - Mean, -} - -#[derive(Copy, Clone)] -/// The split to use. -pub enum Split { - /// The training split. - Train, - /// The validation split. - Valid, -} - -#[derive(Copy, Clone)] -/// The direction of the query. -pub enum Direction { - /// Lower is better. - Lowest, - /// Higher is better. - Highest, -} - -/// A learner item. -#[derive(new)] -pub struct LearnerItem { - /// The item. - pub item: T, - - /// The progress. - pub progress: Progress, - - /// The epoch. - pub epoch: usize, - - /// The total number of epochs. - pub epoch_total: usize, - - /// The iteration. - pub iteration: usize, - - /// The learning rate. - pub lr: Option, -} - -#[cfg(test)] -pub mod test_utils { - use crate::{info::MetricsInfo, Aggregate, Direction, Event, EventCollector, Split}; - - #[derive(new)] - pub struct TestEventCollector - where - T: Send + Sync + 'static, - V: Send + Sync + 'static, - { - info: MetricsInfo, - } - - impl EventCollector for TestEventCollector - where - T: Send + Sync + 'static, - V: Send + Sync + 'static, - { - type ItemTrain = T; - type ItemValid = V; - - fn on_event_train(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let metadata = (&item).into(); - self.info.update_train(&item, &metadata); - } - Event::EndEpoch(epoch) => self.info.end_epoch_train(epoch), - } - } - - fn on_event_valid(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => { - let metadata = (&item).into(); - self.info.update_valid(&item, &metadata); - } - Event::EndEpoch(epoch) => self.info.end_epoch_valid(epoch), - } - } - - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - self.info.find_epoch(name, aggregate, direction, split) - } - } -} diff --git a/burn-train/src/collector/metrics/base.rs b/burn-train/src/collector/metrics/base.rs deleted file mode 100644 index 4b1ad2af47..0000000000 --- a/burn-train/src/collector/metrics/base.rs +++ /dev/null @@ -1,131 +0,0 @@ -use crate::{ - info::MetricsInfo, - metric::MetricMetadata, - renderer::{MetricState, MetricsRenderer, TrainingProgress}, - Aggregate, Direction, Event, EventCollector, LearnerItem, Split, -}; - -/// Collect training events in order to display metrics with a metrics renderer. -#[derive(new)] -pub(crate) struct RenderedMetricsEventCollector -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - renderer: Box, - info: MetricsInfo, -} - -impl EventCollector for RenderedMetricsEventCollector -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - type ItemTrain = T; - type ItemValid = V; - - fn on_event_train(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => self.on_train_item(item), - Event::EndEpoch(epoch) => self.on_train_end_epoch(epoch), - } - } - - fn on_event_valid(&mut self, event: Event) { - match event { - Event::ProcessedItem(item) => self.on_valid_item(item), - Event::EndEpoch(epoch) => self.on_valid_end_epoch(epoch), - } - } - - fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - self.info.find_epoch(name, aggregate, direction, split) - } -} - -impl RenderedMetricsEventCollector -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - fn on_train_item(&mut self, item: LearnerItem) { - let progress = (&item).into(); - let metadata = (&item).into(); - - let update = self.info.update_train(&item, &metadata); - - update - .entries - .into_iter() - .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); - - update - .entries_numeric - .into_iter() - .for_each(|(entry, value)| { - self.renderer - .update_train(MetricState::Numeric(entry, value)) - }); - - self.renderer.render_train(progress); - } - - fn on_valid_item(&mut self, item: LearnerItem) { - let progress = (&item).into(); - let metadata = (&item).into(); - - let update = self.info.update_valid(&item, &metadata); - - update - .entries - .into_iter() - .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); - - update - .entries_numeric - .into_iter() - .for_each(|(entry, value)| { - self.renderer - .update_valid(MetricState::Numeric(entry, value)) - }); - - self.renderer.render_train(progress); - } - - fn on_train_end_epoch(&mut self, epoch: usize) { - self.info.end_epoch_train(epoch); - } - - fn on_valid_end_epoch(&mut self, epoch: usize) { - self.info.end_epoch_valid(epoch); - } -} - -impl From<&LearnerItem> for TrainingProgress { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - } - } -} - -impl From<&LearnerItem> for MetricMetadata { - fn from(item: &LearnerItem) -> Self { - Self { - progress: item.progress.clone(), - epoch: item.epoch, - epoch_total: item.epoch_total, - iteration: item.iteration, - lr: item.lr, - } - } -} diff --git a/burn-train/src/collector/metrics/mod.rs b/burn-train/src/collector/metrics/mod.rs deleted file mode 100644 index 41e113f920..0000000000 --- a/burn-train/src/collector/metrics/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod base; - -pub(crate) use base::*; diff --git a/burn-train/src/collector/mod.rs b/burn-train/src/collector/mod.rs deleted file mode 100644 index ffb89eb298..0000000000 --- a/burn-train/src/collector/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod async_collector; -mod base; - -pub use async_collector::*; -pub use base::*; - -/// Metrics collector module. -pub mod metrics; diff --git a/burn-train/src/components.rs b/burn-train/src/components.rs index 2d8302bb83..fa20bad518 100644 --- a/burn-train/src/components.rs +++ b/burn-train/src/components.rs @@ -1,6 +1,6 @@ use crate::{ checkpoint::{Checkpointer, CheckpointingStrategy}, - EventCollector, + metric::processor::EventProcessor, }; use burn_core::{ lr_scheduler::LrScheduler, @@ -28,14 +28,13 @@ pub trait LearnerComponents { >; /// The checkpointer used for the scheduler. type CheckpointerLrScheduler: Checkpointer<::Record>; - /// Training event collector used for training tracking. - type EventCollector: EventCollector + 'static; + type EventProcessor: EventProcessor + 'static; /// The strategy to save and delete checkpoints. - type CheckpointerStrategy: CheckpointingStrategy; + type CheckpointerStrategy: CheckpointingStrategy; } /// Concrete type that implements [training components trait](TrainingComponents). -pub struct LearnerComponentsMarker { +pub struct LearnerComponentsMarker { _backend: PhantomData, _lr_scheduler: PhantomData, _model: PhantomData, @@ -43,12 +42,12 @@ pub struct LearnerComponentsMarker { _checkpointer_model: PhantomData, _checkpointer_optim: PhantomData, _checkpointer_scheduler: PhantomData, - _collector: PhantomData, + _event_processor: PhantomData, _strategy: S, } -impl LearnerComponents - for LearnerComponentsMarker +impl LearnerComponents + for LearnerComponentsMarker where B: ADBackend, LR: LrScheduler, @@ -57,8 +56,8 @@ where CM: Checkpointer, CO: Checkpointer, CS: Checkpointer, - EC: EventCollector + 'static, - S: CheckpointingStrategy, + EP: EventProcessor + 'static, + S: CheckpointingStrategy, { type Backend = B; type LrScheduler = LR; @@ -67,6 +66,6 @@ where type CheckpointerModel = CM; type CheckpointerOptimizer = CO; type CheckpointerLrScheduler = CS; - type EventCollector = EC; + type EventProcessor = EP; type CheckpointerStrategy = S; } diff --git a/burn-train/src/info/mod.rs b/burn-train/src/info/mod.rs deleted file mode 100644 index 0adbc16d10..0000000000 --- a/burn-train/src/info/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod aggregates; -mod metrics; - -pub(crate) use aggregates::*; -pub use metrics::*; diff --git a/burn-train/src/learner/base.rs b/burn-train/src/learner/base.rs index ae62bce788..d976ce0e7a 100644 --- a/burn-train/src/learner/base.rs +++ b/burn-train/src/learner/base.rs @@ -1,5 +1,7 @@ use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy}; use crate::components::LearnerComponents; +use crate::learner::EarlyStoppingStrategy; +use crate::metric::store::EventStoreClient; use burn_core::lr_scheduler::LrScheduler; use burn_core::module::Module; use burn_core::optim::Optimizer; @@ -19,8 +21,10 @@ pub struct Learner { pub(crate) grad_accumulation: Option, pub(crate) checkpointer: Option>, pub(crate) devices: Vec<::Device>, - pub(crate) collector: LC::EventCollector, pub(crate) interrupter: TrainingInterrupter, + pub(crate) early_stopping: Option>, + pub(crate) event_processor: LC::EventProcessor, + pub(crate) event_store: Arc, } #[derive(new)] @@ -38,9 +42,9 @@ impl LearnerCheckpointer { optim: &LC::Optimizer, scheduler: &LC::LrScheduler, epoch: usize, - collector: &mut LC::EventCollector, + store: &EventStoreClient, ) { - let actions = self.strategy.checkpointing(epoch, collector); + let actions = self.strategy.checkpointing(epoch, store); for action in actions { match action { diff --git a/burn-train/src/learner/builder.rs b/burn-train/src/learner/builder.rs index 60e6df42c8..0de9872520 100644 --- a/burn-train/src/learner/builder.rs +++ b/burn-train/src/learner/builder.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::log::install_file_logger; use super::Learner; use crate::checkpoint::{ @@ -5,13 +7,14 @@ use crate::checkpoint::{ KeepLastNCheckpoints, MetricCheckpointingStrategy, }; use crate::components::LearnerComponentsMarker; -use crate::info::MetricsInfo; use crate::learner::base::TrainingInterrupter; +use crate::learner::EarlyStoppingStrategy; use crate::logger::{FileMetricLogger, MetricLogger}; +use crate::metric::processor::{FullEventProcessor, Metrics}; +use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split}; use crate::metric::{Adaptor, LossMetric, Metric}; use crate::renderer::{default_renderer, MetricsRenderer}; -use crate::{collector::metrics::RenderedMetricsEventCollector, Aggregate, Direction, Split}; -use crate::{AsyncEventCollector, LearnerCheckpointer}; +use crate::LearnerCheckpointer; use burn_core::lr_scheduler::LrScheduler; use burn_core::module::ADModule; use burn_core::optim::Optimizer; @@ -43,11 +46,13 @@ where grad_accumulation: Option, devices: Vec, renderer: Option>, - info: MetricsInfo, + metrics: Metrics, + event_store: LogEventStore, interrupter: TrainingInterrupter, log_to_file: bool, num_loggers: usize, - checkpointer_strategy: Box>>, + checkpointer_strategy: Box, + early_stopping: Option>, } impl LearnerBuilder @@ -72,7 +77,8 @@ where directory: directory.to_string(), grad_accumulation: None, devices: vec![B::Device::default()], - info: MetricsInfo::new(), + metrics: Metrics::default(), + event_store: LogEventStore::default(), renderer: None, interrupter: TrainingInterrupter::new(), log_to_file: true, @@ -87,6 +93,7 @@ where )) .build(), ), + early_stopping: None, } } @@ -101,8 +108,8 @@ where MT: MetricLogger + 'static, MV: MetricLogger + 'static, { - self.info.register_logger_train(logger_train); - self.info.register_logger_valid(logger_valid); + self.event_store.register_logger_train(logger_train); + self.event_store.register_logger_valid(logger_valid); self.num_loggers += 1; self } @@ -110,7 +117,7 @@ where /// Update the checkpointing_strategy. pub fn with_checkpointing_strategy(&mut self, strategy: CS) where - CS: CheckpointingStrategy> + 'static, + CS: CheckpointingStrategy + 'static, { self.checkpointer_strategy = Box::new(strategy); } @@ -133,7 +140,7 @@ where where T: Adaptor, { - self.info.register_metric_train(metric); + self.metrics.register_metric_train(metric); self } @@ -142,7 +149,7 @@ where where V: Adaptor, { - self.info.register_valid_metric(metric); + self.metrics.register_valid_metric(metric); self } @@ -167,7 +174,7 @@ where Me: Metric + crate::metric::Numeric + 'static, T: Adaptor, { - self.info.register_train_metric_numeric(metric); + self.metrics.register_train_metric_numeric(metric); self } @@ -179,7 +186,7 @@ where where V: Adaptor, { - self.info.register_valid_metric_numeric(metric); + self.metrics.register_valid_metric_numeric(metric); self } @@ -206,6 +213,16 @@ where self.interrupter.clone() } + /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the + /// conditions are meet. + pub fn early_stopping(mut self, strategy: Strategy) -> Self + where + Strategy: EarlyStoppingStrategy + 'static, + { + self.early_stopping = Some(Box::new(strategy)); + self + } + /// By default, Rust logs are captured and written into /// `experiment.log`. If disabled, standard Rust log handling /// will apply. @@ -267,8 +284,8 @@ where AsyncCheckpointer, AsyncCheckpointer, AsyncCheckpointer, - AsyncEventCollector, - Box>>, + FullEventProcessor, + Box, >, > where @@ -285,16 +302,18 @@ where let directory = &self.directory; if self.num_loggers == 0 { - self.info.register_logger_train(FileMetricLogger::new( - format!("{directory}/train").as_str(), - )); - self.info.register_logger_valid(FileMetricLogger::new( - format!("{directory}/valid").as_str(), - )); + self.event_store + .register_logger_train(FileMetricLogger::new( + format!("{directory}/train").as_str(), + )); + self.event_store + .register_logger_valid(FileMetricLogger::new( + format!("{directory}/valid").as_str(), + )); } - let collector = - AsyncEventCollector::new(RenderedMetricsEventCollector::new(renderer, self.info)); + let event_store = Arc::new(EventStoreClient::new(self.event_store)); + let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone()); let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| { LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy) @@ -306,11 +325,13 @@ where lr_scheduler, checkpointer, num_epochs: self.num_epochs, - collector, + event_processor, + event_store, checkpoint: self.checkpoint, grad_accumulation: self.grad_accumulation, devices: self.devices, interrupter: self.interrupter, + early_stopping: self.early_stopping, } } diff --git a/burn-train/src/learner/early_stopping.rs b/burn-train/src/learner/early_stopping.rs new file mode 100644 index 0000000000..641d49551b --- /dev/null +++ b/burn-train/src/learner/early_stopping.rs @@ -0,0 +1,209 @@ +use crate::metric::{ + store::{Aggregate, Direction, EventStoreClient, Split}, + Metric, +}; + +/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow. +pub enum StoppingCondition { + /// When no improvement has happened since the given number of epochs. + NoImprovementSince { + /// The number of epochs allowed to worsen before it gets better. + n_epochs: usize, + }, +} + +/// A strategy that checks if the training should be stopped. +pub trait EarlyStoppingStrategy { + /// Update its current state and returns if the training should be stopped. + fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool; +} + +/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected +/// during training or validation. +pub struct MetricEarlyStoppingStrategy { + condition: StoppingCondition, + metric_name: String, + aggregate: Aggregate, + direction: Direction, + split: Split, + best_epoch: usize, + best_value: f64, +} + +impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy { + fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool { + let current_value = + match store.find_metric(&self.metric_name, epoch, self.aggregate, self.split) { + Some(value) => value, + None => { + log::warn!("Can't find metric for early stopping."); + return false; + } + }; + + let is_best = match self.direction { + Direction::Lowest => current_value < self.best_value, + Direction::Highest => current_value > self.best_value, + }; + + if is_best { + log::info!( + "New best epoch found {} {}: {}", + epoch, + self.metric_name, + current_value + ); + self.best_value = current_value; + self.best_epoch = epoch; + return false; + } + + match self.condition { + StoppingCondition::NoImprovementSince { n_epochs } => { + let should_stop = epoch - self.best_epoch >= n_epochs; + + if should_stop { + log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value); + } + + should_stop + } + } + } +} + +impl MetricEarlyStoppingStrategy { + /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected + /// during training or validation. + /// + /// # Notes + /// + /// The metric should be registered for early stopping to work, otherwise no data is collected. + pub fn new( + aggregate: Aggregate, + direction: Direction, + split: Split, + condition: StoppingCondition, + ) -> Self { + let init_value = match direction { + Direction::Lowest => f64::MAX, + Direction::Highest => f64::MIN, + }; + + Self { + metric_name: Me::NAME.to_string(), + condition, + aggregate, + direction, + split, + best_epoch: 1, + best_value: init_value, + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + logger::InMemoryMetricLogger, + metric::{ + processor::{ + test_utils::{end_epoch, process_train}, + Metrics, MinimalEventProcessor, + }, + store::LogEventStore, + LossMetric, + }, + TestBackend, + }; + + use super::*; + + #[test] + fn never_early_stop_while_it_is_improving() { + test_early_stopping( + 1, + &[ + (&[0.5, 0.3], false, "Should not stop first epoch"), + (&[0.4, 0.3], false, "Should not stop when improving"), + (&[0.3, 0.3], false, "Should not stop when improving"), + (&[0.2, 0.3], false, "Should not stop when improving"), + ], + ); + } + + #[test] + fn early_stop_when_no_improvement_since_two_epochs() { + test_early_stopping( + 2, + &[ + (&[1.0, 0.5], false, "Should not stop first epoch"), + (&[0.5, 0.3], false, "Should not stop when improving"), + ( + &[1.0, 3.0], + false, + "Should not stop first time it gets worse", + ), + ( + &[1.0, 2.0], + true, + "Should stop since two following epochs didn't improve", + ), + ], + ); + } + + #[test] + fn early_stop_when_stays_equal() { + test_early_stopping( + 2, + &[ + (&[0.5, 0.3], false, "Should not stop first epoch"), + ( + &[0.5, 0.3], + false, + "Should not stop first time it stars the same", + ), + ( + &[0.5, 0.3], + true, + "Should stop since two following epochs didn't improve", + ), + ], + ); + } + + fn test_early_stopping(n_epochs: usize, data: &[(&[f64], bool, &str)]) { + let mut early_stopping = MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Train, + StoppingCondition::NoImprovementSince { n_epochs }, + ); + let mut store = LogEventStore::default(); + let mut metrics = Metrics::::default(); + + store.register_logger_train(InMemoryMetricLogger::default()); + metrics.register_train_metric_numeric(LossMetric::::new()); + + let store = Arc::new(EventStoreClient::new(store)); + let mut processor = MinimalEventProcessor::new(metrics, store.clone()); + + let mut epoch = 1; + for (points, should_start, comment) in data { + for point in points.iter() { + process_train(&mut processor, *point, epoch); + } + end_epoch(&mut processor, epoch); + + assert_eq!( + *should_start, + early_stopping.should_stop(epoch, &store), + "{comment}" + ); + epoch += 1; + } + } +} diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs index f5c932d7af..475c3ec4b4 100644 --- a/burn-train/src/learner/epoch.rs +++ b/burn-train/src/learner/epoch.rs @@ -4,8 +4,9 @@ use burn_core::{ }; use std::sync::Arc; -use crate::{components::LearnerComponents, learner::base::TrainingInterrupter, Event}; -use crate::{EventCollector, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep}; +use crate::metric::processor::{Event, EventProcessor, LearnerItem}; +use crate::{components::LearnerComponents, learner::base::TrainingInterrupter}; +use crate::{MultiDevicesTrainStep, TrainStep, ValidStep}; /// A validation epoch. #[derive(new)] @@ -30,14 +31,14 @@ impl ValidEpoch { /// # Arguments /// /// * `model` - The model to validate. - /// * `callback` - The callback to use. + /// * `processor` - The event processor to use. pub fn run( &self, model: &LC::Model, - callback: &mut LC::EventCollector, + processor: &mut LC::EventProcessor, interrupter: &TrainingInterrupter, ) where - LC::EventCollector: EventCollector, + LC::EventProcessor: EventProcessor, >::InnerModule: ValidStep, { log::info!("Executing validation step for epoch {}", self.epoch); @@ -60,14 +61,14 @@ impl ValidEpoch { None, ); - callback.on_event_valid(Event::ProcessedItem(item)); + processor.process_valid(Event::ProcessedItem(item)); if interrupter.should_stop() { log::info!("Training interrupted."); break; } } - callback.on_event_valid(Event::EndEpoch(self.epoch)); + processor.process_valid(Event::EndEpoch(self.epoch)); } } @@ -79,7 +80,7 @@ impl TrainEpoch { /// * `model` - The model to train. /// * `optim` - The optimizer to use. /// * `scheduler` - The learning rate scheduler to use. - /// * `callback` - The callback to use. + /// * `processor` - The event processor to use. /// /// # Returns /// @@ -89,11 +90,11 @@ impl TrainEpoch { mut model: LC::Model, mut optim: LC::Optimizer, scheduler: &mut LC::LrScheduler, - callback: &mut LC::EventCollector, + processor: &mut LC::EventProcessor, interrupter: &TrainingInterrupter, ) -> (LC::Model, LC::Optimizer) where - LC::EventCollector: EventCollector, + LC::EventProcessor: EventProcessor, LC::Model: TrainStep, { log::info!("Executing training step for epoch {}", self.epoch,); @@ -134,13 +135,14 @@ impl TrainEpoch { Some(lr), ); - callback.on_event_train(Event::ProcessedItem(item)); + processor.process_train(Event::ProcessedItem(item)); + if interrupter.should_stop() { log::info!("Training interrupted."); break; } } - callback.on_event_train(Event::EndEpoch(self.epoch)); + processor.process_train(Event::EndEpoch(self.epoch)); (model, optim) } @@ -154,7 +156,7 @@ impl TrainEpoch { /// * `model` - The model to train. /// * `optim` - The optimizer to use. /// * `lr_scheduler` - The learning rate scheduler to use. - /// * `callback` - The callback to use. + /// * `processor` - The event processor to use. /// * `devices` - The devices to use. /// /// # Returns @@ -165,12 +167,12 @@ impl TrainEpoch { mut model: LC::Model, mut optim: LC::Optimizer, lr_scheduler: &mut LC::LrScheduler, - callback: &mut LC::EventCollector, + processor: &mut LC::EventProcessor, devices: Vec<::Device>, interrupter: &TrainingInterrupter, ) -> (LC::Model, LC::Optimizer) where - LC::EventCollector: EventCollector, + LC::EventProcessor: EventProcessor, LC::Model: TrainStep, TO: Send + 'static, TI: Send + 'static, @@ -224,7 +226,7 @@ impl TrainEpoch { Some(lr), ); - callback.on_event_train(Event::ProcessedItem(item)); + processor.process_train(Event::ProcessedItem(item)); if interrupter.should_stop() { log::info!("Training interrupted."); @@ -238,7 +240,7 @@ impl TrainEpoch { } } - callback.on_event_train(Event::EndEpoch(self.epoch)); + processor.process_train(Event::EndEpoch(self.epoch)); (model, optim) } diff --git a/burn-train/src/learner/mod.rs b/burn-train/src/learner/mod.rs index 5ef0a1c99e..e01080475a 100644 --- a/burn-train/src/learner/mod.rs +++ b/burn-train/src/learner/mod.rs @@ -1,6 +1,7 @@ mod base; mod builder; mod classification; +mod early_stopping; mod epoch; mod regression; mod step; @@ -11,6 +12,7 @@ pub(crate) mod log; pub use base::*; pub use builder::*; pub use classification::*; +pub use early_stopping::*; pub use epoch::*; pub use regression::*; pub use step::*; diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs index da3f081e26..0a06c8c4f9 100644 --- a/burn-train/src/learner/train_val.rs +++ b/burn-train/src/learner/train_val.rs @@ -1,5 +1,6 @@ use crate::components::LearnerComponents; -use crate::{EventCollector, Learner, TrainEpoch, ValidEpoch}; +use crate::metric::processor::EventProcessor; +use crate::{Learner, TrainEpoch, ValidEpoch}; use burn_core::data::dataloader::DataLoader; use burn_core::module::{ADModule, Module}; use burn_core::optim::{GradientsParams, Optimizer}; @@ -115,7 +116,7 @@ impl Learner { OutputValid: Send, LC::Model: TrainStep, >::InnerModule: ValidStep, - LC::EventCollector: EventCollector, + LC::EventProcessor: EventProcessor, { log::info!("Fitting {}", self.model.to_string()); // The reference model is always on the first device provided. @@ -151,7 +152,7 @@ impl Learner { self.model, self.optim, &mut self.lr_scheduler, - &mut self.collector, + &mut self.event_processor, self.devices.clone(), &self.interrupter, ) @@ -160,7 +161,7 @@ impl Learner { self.model, self.optim, &mut self.lr_scheduler, - &mut self.collector, + &mut self.event_processor, &self.interrupter, ); } @@ -170,7 +171,11 @@ impl Learner { } let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); - epoch_valid.run::(&self.model, &mut self.collector, &self.interrupter); + epoch_valid.run::( + &self.model, + &mut self.event_processor, + &self.interrupter, + ); if let Some(checkpointer) = &mut self.checkpointer { checkpointer.checkpoint( @@ -178,9 +183,15 @@ impl Learner { &self.optim, &self.lr_scheduler, epoch, - &mut self.collector, + &self.event_store, ); } + + if let Some(early_stopping) = &mut self.early_stopping { + if early_stopping.should_stop(epoch, &self.event_store) { + break; + } + } } self.model diff --git a/burn-train/src/lib.rs b/burn-train/src/lib.rs index 23d1a5e817..dcd1cfa4fc 100644 --- a/burn-train/src/lib.rs +++ b/burn-train/src/lib.rs @@ -19,13 +19,8 @@ pub mod logger; /// The metric module. pub mod metric; -/// All information collected during training. -pub mod info; - -mod collector; mod learner; -pub use collector::*; pub use learner::*; #[cfg(test)] diff --git a/burn-train/src/logger/in_memory.rs b/burn-train/src/logger/in_memory.rs new file mode 100644 index 0000000000..31cf3f165c --- /dev/null +++ b/burn-train/src/logger/in_memory.rs @@ -0,0 +1,16 @@ +use super::Logger; + +/// In memory logger. +#[derive(Default)] +pub struct InMemoryLogger { + pub(crate) values: Vec, +} + +impl Logger for InMemoryLogger +where + T: std::fmt::Display, +{ + fn log(&mut self, item: T) { + self.values.push(item.to_string()); + } +} diff --git a/burn-train/src/logger/metric.rs b/burn-train/src/logger/metric.rs index 450104851b..100e7a10ef 100644 --- a/burn-train/src/logger/metric.rs +++ b/burn-train/src/logger/metric.rs @@ -1,4 +1,4 @@ -use super::{AsyncLogger, FileLogger, Logger}; +use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger}; use crate::metric::MetricEntry; use std::collections::HashMap; @@ -16,7 +16,7 @@ pub trait MetricLogger: Send { /// # Arguments /// /// * `epoch` - The epoch. - fn epoch(&mut self, epoch: usize); + fn end_epoch(&mut self, epoch: usize); /// Read the logs for an epoch. fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String>; @@ -81,9 +81,9 @@ impl MetricLogger for FileMetricLogger { logger.log(value.clone()); } - fn epoch(&mut self, epoch: usize) { + fn end_epoch(&mut self, epoch: usize) { self.loggers.clear(); - self.epoch = epoch; + self.epoch = epoch + 1; } fn read_numeric(&mut self, name: &str, epoch: usize) -> Result, String> { @@ -125,23 +125,24 @@ impl MetricLogger for FileMetricLogger { /// In memory metric logger, useful when testing and debugging. #[derive(Default)] pub struct InMemoryMetricLogger { - values: HashMap>>, + values: HashMap>, } impl MetricLogger for InMemoryMetricLogger { fn log(&mut self, item: &MetricEntry) { if !self.values.contains_key(&item.name) { - self.values.insert(item.name.clone(), vec![vec![]]); + self.values + .insert(item.name.clone(), vec![InMemoryLogger::default()]); } let values = self.values.get_mut(&item.name).unwrap(); - values.last_mut().unwrap().push(item.serialize.clone()); + values.last_mut().unwrap().log(item.serialize.clone()); } - fn epoch(&mut self, _epoch: usize) { + fn end_epoch(&mut self, _epoch: usize) { for (_, values) in self.values.iter_mut() { - values.push(Vec::new()); + values.push(InMemoryLogger::default()); } } @@ -152,7 +153,8 @@ impl MetricLogger for InMemoryMetricLogger { }; match values.get(epoch - 1) { - Some(values) => Ok(values + Some(logger) => Ok(logger + .values .iter() .filter_map(|value| value.parse::().ok()) .collect()), diff --git a/burn-train/src/logger/mod.rs b/burn-train/src/logger/mod.rs index 996257c226..df727a35ea 100644 --- a/burn-train/src/logger/mod.rs +++ b/burn-train/src/logger/mod.rs @@ -1,9 +1,11 @@ mod async_logger; mod base; mod file; +mod in_memory; mod metric; pub use async_logger::*; pub use base::*; pub use file::*; +pub use in_memory::*; pub use metric::*; diff --git a/burn-train/src/metric/base.rs b/burn-train/src/metric/base.rs index d199290435..d84c039511 100644 --- a/burn-train/src/metric/base.rs +++ b/burn-train/src/metric/base.rs @@ -74,7 +74,7 @@ pub trait Numeric { } /// Data type that contains the current state of a metric at a given time. -#[derive(new, Debug)] +#[derive(new, Debug, Clone)] pub struct MetricEntry { /// The name of the metric. pub name: String, diff --git a/burn-train/src/metric/mod.rs b/burn-train/src/metric/mod.rs index 75ecdc132c..37ad5af73b 100644 --- a/burn-train/src/metric/mod.rs +++ b/burn-train/src/metric/mod.rs @@ -26,3 +26,7 @@ pub use learning_rate::*; pub use loss::*; #[cfg(feature = "metrics")] pub use memory_use::*; + +pub(crate) mod processor; +/// Module responsible to save and exposes data collected during training. +pub mod store; diff --git a/burn-train/src/metric/processor/base.rs b/burn-train/src/metric/processor/base.rs new file mode 100644 index 0000000000..9093d26457 --- /dev/null +++ b/burn-train/src/metric/processor/base.rs @@ -0,0 +1,45 @@ +use burn_core::data::dataloader::Progress; +use burn_core::LearningRate; + +/// Event happening during the training/validation process. +pub enum Event { + /// Signal that an item have been processed. + ProcessedItem(LearnerItem), + /// Signal the end of an epoch. + EndEpoch(usize), +} + +/// Process events happening during training and validation. +pub trait EventProcessor { + /// The training item. + type ItemTrain; + /// The validation item. + type ItemValid; + + /// Collect a training event. + fn process_train(&mut self, event: Event); + /// Collect a validation event. + fn process_valid(&mut self, event: Event); +} + +/// A learner item. +#[derive(new)] +pub struct LearnerItem { + /// The item. + pub item: T, + + /// The progress. + pub progress: Progress, + + /// The epoch. + pub epoch: usize, + + /// The total number of epochs. + pub epoch_total: usize, + + /// The iteration. + pub iteration: usize, + + /// The learning rate. + pub lr: Option, +} diff --git a/burn-train/src/metric/processor/full.rs b/burn-train/src/metric/processor/full.rs new file mode 100644 index 0000000000..b25870dfb4 --- /dev/null +++ b/burn-train/src/metric/processor/full.rs @@ -0,0 +1,100 @@ +use super::{Event, EventProcessor, Metrics}; +use crate::metric::store::EventStoreClient; +use crate::renderer::{MetricState, MetricsRenderer}; +use std::sync::Arc; + +/// An [event processor](EventProcessor) that handles: +/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). +/// - Render metrics using a [metrics renderer](MetricsRenderer). +pub struct FullEventProcessor { + metrics: Metrics, + renderer: Box, + store: Arc, +} + +impl FullEventProcessor { + pub(crate) fn new( + metrics: Metrics, + renderer: Box, + store: Arc, + ) -> Self { + Self { + metrics, + renderer, + store, + } + } +} + +impl EventProcessor for FullEventProcessor { + type ItemTrain = T; + type ItemValid = V; + + fn process_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let progress = (&item).into(); + let metadata = (&item).into(); + + let update = self.metrics.update_train(&item, &metadata); + + self.store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone())); + + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry))); + + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self.renderer + .update_train(MetricState::Numeric(entry, value)) + }); + + self.renderer.render_train(progress); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_train(); + self.store + .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); + } + } + } + + fn process_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let progress = (&item).into(); + let metadata = (&item).into(); + + let update = self.metrics.update_valid(&item, &metadata); + + self.store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone())); + + update + .entries + .into_iter() + .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry))); + + update + .entries_numeric + .into_iter() + .for_each(|(entry, value)| { + self.renderer + .update_valid(MetricState::Numeric(entry, value)) + }); + + self.renderer.render_valid(progress); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_valid(); + self.store + .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); + } + } + } +} diff --git a/burn-train/src/info/metrics.rs b/burn-train/src/metric/processor/metrics.rs similarity index 59% rename from burn-train/src/info/metrics.rs rename to burn-train/src/metric/processor/metrics.rs index 6a94b00636..e2992f12b0 100644 --- a/burn-train/src/info/metrics.rs +++ b/burn-train/src/metric/processor/metrics.rs @@ -1,74 +1,66 @@ -use super::NumericMetricsAggregate; +use super::LearnerItem; use crate::{ - logger::MetricLogger, - metric::{Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, - Aggregate, Direction, LearnerItem, Split, + metric::{store::MetricsUpdate, Adaptor, Metric, MetricEntry, MetricMetadata, Numeric}, + renderer::TrainingProgress, }; -/// Metrics information collected during training. -pub struct MetricsInfo -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ +pub(crate) struct Metrics { train: Vec>>, valid: Vec>>, train_numeric: Vec>>, valid_numeric: Vec>>, - loggers_train: Vec>, - loggers_valid: Vec>, - aggregate_train: NumericMetricsAggregate, - aggregate_valid: NumericMetricsAggregate, -} - -#[derive(new)] -pub(crate) struct MetricsUpdate { - pub(crate) entries: Vec, - pub(crate) entries_numeric: Vec<(MetricEntry, f64)>, } -impl MetricsInfo -where - T: Send + Sync + 'static, - V: Send + Sync + 'static, -{ - pub(crate) fn new() -> Self { +impl Default for Metrics { + fn default() -> Self { Self { - train: vec![], - valid: vec![], - train_numeric: vec![], - valid_numeric: vec![], - loggers_train: vec![], - loggers_valid: vec![], - aggregate_train: NumericMetricsAggregate::default(), - aggregate_valid: NumericMetricsAggregate::default(), + train: Vec::default(), + valid: Vec::default(), + train_numeric: Vec::default(), + valid_numeric: Vec::default(), } } +} - /// Signal the end of a training epoch. - pub(crate) fn end_epoch_train(&mut self, epoch: usize) { - for metric in self.train.iter_mut() { - metric.clear(); - } - for metric in self.train_numeric.iter_mut() { - metric.clear(); - } - for logger in self.loggers_train.iter_mut() { - logger.epoch(epoch + 1); - } +impl Metrics { + /// Register a training metric. + pub(crate) fn register_metric_train(&mut self, metric: Me) + where + T: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.train.push(Box::new(metric)) } - /// Signal the end of a validation epoch. - pub(crate) fn end_epoch_valid(&mut self, epoch: usize) { - for metric in self.valid.iter_mut() { - metric.clear(); - } - for metric in self.valid_numeric.iter_mut() { - metric.clear(); - } - for logger in self.loggers_valid.iter_mut() { - logger.epoch(epoch + 1); - } + /// Register a validation metric. + pub(crate) fn register_valid_metric(&mut self, metric: Me) + where + V: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.valid.push(Box::new(metric)) + } + + /// Register a numeric training metric. + pub(crate) fn register_train_metric_numeric( + &mut self, + metric: Me, + ) where + T: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.train_numeric.push(Box::new(metric)) + } + + /// Register a numeric validation metric. + pub(crate) fn register_valid_metric_numeric( + &mut self, + metric: Me, + ) where + V: Adaptor + 'static, + { + let metric = MetricWrapper::new(metric); + self.valid_numeric.push(Box::new(metric)) } /// Update the training information from the training item. @@ -82,20 +74,11 @@ where for metric in self.train.iter_mut() { let state = metric.update(item, metadata); - - for logger in self.loggers_train.iter_mut() { - logger.log(&state); - } - entries.push(state); } for metric in self.train_numeric.iter_mut() { let (state, value) = metric.update(item, metadata); - for logger in self.loggers_train.iter_mut() { - logger.log(&state); - } - entries_numeric.push((state, value)); } @@ -113,94 +96,58 @@ where for metric in self.valid.iter_mut() { let state = metric.update(item, metadata); - - for logger in self.loggers_valid.iter_mut() { - logger.log(&state); - } - entries.push(state); } for metric in self.valid_numeric.iter_mut() { let (state, value) = metric.update(item, metadata); - for logger in self.loggers_valid.iter_mut() { - logger.log(&state); - } - entries_numeric.push((state, value)); } MetricsUpdate::new(entries, entries_numeric) } - /// Find the epoch corresponding to the given criteria. - pub(crate) fn find_epoch( - &mut self, - name: &str, - aggregate: Aggregate, - direction: Direction, - split: Split, - ) -> Option { - match split { - Split::Train => { - self.aggregate_train - .find_epoch(name, aggregate, direction, &mut self.loggers_train) - } - Split::Valid => { - self.aggregate_valid - .find_epoch(name, aggregate, direction, &mut self.loggers_valid) - } + /// Signal the end of a training epoch. + pub(crate) fn end_epoch_train(&mut self) { + for metric in self.train.iter_mut() { + metric.clear(); + } + for metric in self.train_numeric.iter_mut() { + metric.clear(); } } - /// Register a logger for training metrics. - pub(crate) fn register_logger_train(&mut self, logger: ML) { - self.loggers_train.push(Box::new(logger)); - } - - /// Register a logger for validation metrics. - pub(crate) fn register_logger_valid(&mut self, logger: ML) { - self.loggers_valid.push(Box::new(logger)); - } - - /// Register a training metric. - pub(crate) fn register_metric_train(&mut self, metric: Me) - where - T: Adaptor, - { - let metric = MetricWrapper::new(metric); - self.train.push(Box::new(metric)) - } - - /// Register a validation metric. - pub(crate) fn register_valid_metric(&mut self, metric: Me) - where - V: Adaptor, - { - let metric = MetricWrapper::new(metric); - self.valid.push(Box::new(metric)) + /// Signal the end of a validation epoch. + pub(crate) fn end_epoch_valid(&mut self) { + for metric in self.valid.iter_mut() { + metric.clear(); + } + for metric in self.valid_numeric.iter_mut() { + metric.clear(); + } } +} - /// Register a numeric training metric. - pub(crate) fn register_train_metric_numeric( - &mut self, - metric: Me, - ) where - T: Adaptor, - { - let metric = MetricWrapper::new(metric); - self.train_numeric.push(Box::new(metric)) +impl From<&LearnerItem> for TrainingProgress { + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + } } +} - /// Register a numeric validation metric. - pub(crate) fn register_valid_metric_numeric( - &mut self, - metric: Me, - ) where - V: Adaptor, - { - let metric = MetricWrapper::new(metric); - self.valid_numeric.push(Box::new(metric)) +impl From<&LearnerItem> for MetricMetadata { + fn from(item: &LearnerItem) -> Self { + Self { + progress: item.progress.clone(), + epoch: item.epoch, + epoch_total: item.epoch_total, + iteration: item.iteration, + lr: item.lr, + } } } diff --git a/burn-train/src/metric/processor/minimal.rs b/burn-train/src/metric/processor/minimal.rs new file mode 100644 index 0000000000..bb60713e45 --- /dev/null +++ b/burn-train/src/metric/processor/minimal.rs @@ -0,0 +1,52 @@ +use super::{Event, EventProcessor, Metrics}; +use crate::metric::store::EventStoreClient; +use std::sync::Arc; + +/// An [event processor](EventProcessor) that handles: +/// - Computing and storing metrics in an [event store](crate::metric::store::EventStore). +#[derive(new)] +pub(crate) struct MinimalEventProcessor { + metrics: Metrics, + store: Arc, +} + +impl EventProcessor for MinimalEventProcessor { + type ItemTrain = T; + type ItemValid = V; + + fn process_train(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let metadata = (&item).into(); + + let update = self.metrics.update_train(&item, &metadata); + + self.store + .add_event_train(crate::metric::store::Event::MetricsUpdate(update)); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_train(); + self.store + .add_event_train(crate::metric::store::Event::EndEpoch(epoch)); + } + } + } + + fn process_valid(&mut self, event: Event) { + match event { + Event::ProcessedItem(item) => { + let metadata = (&item).into(); + + let update = self.metrics.update_valid(&item, &metadata); + + self.store + .add_event_valid(crate::metric::store::Event::MetricsUpdate(update)); + } + Event::EndEpoch(epoch) => { + self.metrics.end_epoch_valid(); + self.store + .add_event_valid(crate::metric::store::Event::EndEpoch(epoch)); + } + } + } +} diff --git a/burn-train/src/metric/processor/mod.rs b/burn-train/src/metric/processor/mod.rs new file mode 100644 index 0000000000..f889894098 --- /dev/null +++ b/burn-train/src/metric/processor/mod.rs @@ -0,0 +1,53 @@ +mod base; +mod full; +mod metrics; +mod minimal; + +pub use base::*; +pub(crate) use full::*; +pub(crate) use metrics::*; + +#[cfg(test)] +pub(crate) use minimal::*; + +#[cfg(test)] +pub(crate) mod test_utils { + use crate::metric::{ + processor::{Event, EventProcessor, LearnerItem, MinimalEventProcessor}, + Adaptor, LossInput, + }; + use burn_core::tensor::{backend::Backend, ElementConversion, Tensor}; + + impl Adaptor> for f64 { + fn adapt(&self) -> LossInput { + LossInput::new(Tensor::from_data([self.elem()])) + } + } + + pub(crate) fn process_train( + processor: &mut MinimalEventProcessor, + value: f64, + epoch: usize, + ) { + let dummy_progress = burn_core::data::dataloader::Progress { + items_processed: 1, + items_total: 10, + }; + let num_epochs = 3; + let dummy_iteration = 1; + + processor.process_train(Event::ProcessedItem(LearnerItem::new( + value, + dummy_progress, + epoch, + num_epochs, + dummy_iteration, + None, + ))); + } + + pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor, epoch: usize) { + processor.process_train(Event::EndEpoch(epoch)); + processor.process_valid(Event::EndEpoch(epoch)); + } +} diff --git a/burn-train/src/info/aggregates.rs b/burn-train/src/metric/store/aggregate.rs similarity index 82% rename from burn-train/src/info/aggregates.rs rename to burn-train/src/metric/store/aggregate.rs index e270b14ac9..679f6fa22e 100644 --- a/burn-train/src/info/aggregates.rs +++ b/burn-train/src/metric/store/aggregate.rs @@ -1,28 +1,32 @@ -use crate::{logger::MetricLogger, Aggregate, Direction}; +use crate::logger::MetricLogger; use std::collections::HashMap; +use super::{Aggregate, Direction}; + /// Type that can be used to fetch and use numeric metric aggregates. #[derive(Default, Debug)] pub(crate) struct NumericMetricsAggregate { - mean_for_each_epoch: HashMap, + value_for_each_epoch: HashMap, } #[derive(new, Hash, PartialEq, Eq, Debug)] struct Key { name: String, epoch: usize, + aggregate: Aggregate, } impl NumericMetricsAggregate { - pub(crate) fn mean( + pub(crate) fn aggregate( &mut self, name: &str, epoch: usize, + aggregate: Aggregate, loggers: &mut [Box], ) -> Option { - let key = Key::new(name.to_string(), epoch); + let key = Key::new(name.to_string(), epoch, aggregate); - if let Some(value) = self.mean_for_each_epoch.get(&key) { + if let Some(value) = self.value_for_each_epoch.get(&key) { return Some(*value); } @@ -45,10 +49,13 @@ impl NumericMetricsAggregate { } let num_points = points.len(); - let mean = points.into_iter().sum::() / num_points as f64; + let sum = points.into_iter().sum::(); + let value = match aggregate { + Aggregate::Mean => sum / num_points as f64, + }; - self.mean_for_each_epoch.insert(key, mean); - Some(mean) + self.value_for_each_epoch.insert(key, value); + Some(value) } pub(crate) fn find_epoch( @@ -61,16 +68,8 @@ impl NumericMetricsAggregate { let mut data = Vec::new(); let mut current_epoch = 1; - loop { - match aggregate { - Aggregate::Mean => match self.mean(name, current_epoch, loggers) { - Some(value) => { - data.push(value); - } - None => break, - }, - }; - + while let Some(value) = self.aggregate(name, current_epoch, aggregate, loggers) { + data.push(value); current_epoch += 1; } @@ -131,8 +130,8 @@ mod tests { )); } fn new_epoch(&mut self) { + self.logger.end_epoch(self.epoch); self.epoch += 1; - self.logger.epoch(self.epoch); } } diff --git a/burn-train/src/metric/store/base.rs b/burn-train/src/metric/store/base.rs new file mode 100644 index 0000000000..51592a683c --- /dev/null +++ b/burn-train/src/metric/store/base.rs @@ -0,0 +1,69 @@ +use crate::metric::MetricEntry; + +/// Event happening during the training/validation process. +pub enum Event { + /// Signal that metrics have been updated. + MetricsUpdate(MetricsUpdate), + /// Signal the end of an epoch. + EndEpoch(usize), +} + +/// Contains all metric information. +#[derive(new, Clone)] +pub struct MetricsUpdate { + /// Metrics information related to non-numeric metrics. + pub entries: Vec, + /// Metrics information related to numeric metrics. + pub entries_numeric: Vec<(MetricEntry, f64)>, +} + +/// Defines how training and validation events are collected and searched. +/// +/// This trait also exposes methods that uses the collected data to compute useful information. +pub trait EventStore: Send { + /// Collect a training/validation event. + fn add_event(&mut self, event: Event, split: Split); + + /// Find the epoch following the given criteria from the collected data. + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option; + + /// Find the metric value for the current epoch following the given criteria. + fn find_metric( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option; +} + +#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)] +/// How to aggregate the metric. +pub enum Aggregate { + /// Compute the average. + Mean, +} + +#[derive(Copy, Clone)] +/// The split to use. +pub enum Split { + /// The training split. + Train, + /// The validation split. + Valid, +} + +#[derive(Copy, Clone)] +/// The direction of the query. +pub enum Direction { + /// Lower is better. + Lowest, + /// Higher is better. + Highest, +} diff --git a/burn-train/src/metric/store/client.rs b/burn-train/src/metric/store/client.rs new file mode 100644 index 0000000000..e4f4d34be5 --- /dev/null +++ b/burn-train/src/metric/store/client.rs @@ -0,0 +1,149 @@ +use super::EventStore; +use super::{Aggregate, Direction, Event, Split}; +use std::{sync::mpsc, thread::JoinHandle}; + +/// Type that allows to communicate with an [event store](EventStore). +pub struct EventStoreClient { + sender: mpsc::Sender, + handler: Option>, +} + +impl EventStoreClient { + /// Create a new [event store](EventStore) client. + pub(crate) fn new(store: C) -> Self + where + C: EventStore + 'static, + { + let (sender, receiver) = mpsc::channel(); + let thread = WorkerThread::new(store, receiver); + + let handler = std::thread::spawn(move || thread.run()); + let handler = Some(handler); + + Self { sender, handler } + } +} + +impl EventStoreClient { + /// Add a training event to the [event store](EventStore). + pub(crate) fn add_event_train(&self, event: Event) { + self.sender.send(Message::OnEventTrain(event)).unwrap(); + } + + /// Add a validation event to the [event store](EventStore). + pub(crate) fn add_event_valid(&self, event: Event) { + self.sender.send(Message::OnEventValid(event)).unwrap(); + } + + /// Find the epoch following the given criteria from the collected data. + pub fn find_epoch( + &self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self.sender + .send(Message::FindEpoch( + name.to_string(), + aggregate, + direction, + split, + sender, + )) + .unwrap(); + + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Event store thread crashed: {:?}", err), + } + } + + /// Find the metric value for the current epoch following the given criteria. + pub fn find_metric( + &self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option { + let (sender, receiver) = mpsc::sync_channel(1); + self.sender + .send(Message::FindMetric( + name.to_string(), + epoch, + aggregate, + split, + sender, + )) + .unwrap(); + + match receiver.recv() { + Ok(value) => value, + Err(err) => panic!("Event store thread crashed: {:?}", err), + } + } +} + +#[derive(new)] +struct WorkerThread { + store: S, + receiver: mpsc::Receiver, +} + +impl WorkerThread +where + C: EventStore, +{ + fn run(mut self) { + for item in self.receiver.iter() { + match item { + Message::End => { + return; + } + Message::FindEpoch(name, aggregate, direction, split, sender) => { + let response = self.store.find_epoch(&name, aggregate, direction, split); + sender.send(response).unwrap(); + } + Message::FindMetric(name, epoch, aggregate, split, sender) => { + let response = self.store.find_metric(&name, epoch, aggregate, split); + sender.send(response).unwrap(); + } + Message::OnEventTrain(event) => self.store.add_event(event, Split::Train), + Message::OnEventValid(event) => self.store.add_event(event, Split::Valid), + } + } + } +} + +enum Message { + OnEventTrain(Event), + OnEventValid(Event), + End, + FindEpoch( + String, + Aggregate, + Direction, + Split, + mpsc::SyncSender>, + ), + FindMetric( + String, + usize, + Aggregate, + Split, + mpsc::SyncSender>, + ), +} + +impl Drop for EventStoreClient { + fn drop(&mut self) { + self.sender.send(Message::End).unwrap(); + let handler = self.handler.take(); + + if let Some(handler) = handler { + handler.join().unwrap(); + } + } +} diff --git a/burn-train/src/metric/store/log.rs b/burn-train/src/metric/store/log.rs new file mode 100644 index 0000000000..9272e32330 --- /dev/null +++ b/burn-train/src/metric/store/log.rs @@ -0,0 +1,101 @@ +use super::{aggregate::NumericMetricsAggregate, Aggregate, Direction, Event, EventStore, Split}; +use crate::logger::MetricLogger; + +#[derive(Default)] +pub(crate) struct LogEventStore { + loggers_train: Vec>, + loggers_valid: Vec>, + aggregate_train: NumericMetricsAggregate, + aggregate_valid: NumericMetricsAggregate, +} + +impl EventStore for LogEventStore { + fn add_event(&mut self, event: Event, split: Split) { + match event { + Event::MetricsUpdate(update) => match split { + Split::Train => { + update + .entries + .iter() + .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) + .for_each(|entry| { + self.loggers_train + .iter_mut() + .for_each(|logger| logger.log(entry)); + }); + } + Split::Valid => { + update + .entries + .iter() + .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) + .for_each(|entry| { + self.loggers_valid + .iter_mut() + .for_each(|logger| logger.log(entry)); + }); + } + }, + Event::EndEpoch(epoch) => match split { + Split::Train => self + .loggers_train + .iter_mut() + .for_each(|logger| logger.end_epoch(epoch)), + Split::Valid => self + .loggers_valid + .iter_mut() + .for_each(|logger| logger.end_epoch(epoch + 1)), + }, + } + } + + fn find_epoch( + &mut self, + name: &str, + aggregate: Aggregate, + direction: Direction, + split: Split, + ) -> Option { + match split { + Split::Train => { + self.aggregate_train + .find_epoch(name, aggregate, direction, &mut self.loggers_train) + } + Split::Valid => { + self.aggregate_valid + .find_epoch(name, aggregate, direction, &mut self.loggers_valid) + } + } + } + + fn find_metric( + &mut self, + name: &str, + epoch: usize, + aggregate: Aggregate, + split: Split, + ) -> Option { + match split { + Split::Train => { + self.aggregate_train + .aggregate(name, epoch, aggregate, &mut self.loggers_train) + } + Split::Valid => { + self.aggregate_valid + .aggregate(name, epoch, aggregate, &mut self.loggers_valid) + } + } + } +} + +impl LogEventStore { + /// Register a logger for training metrics. + pub(crate) fn register_logger_train(&mut self, logger: ML) { + self.loggers_train.push(Box::new(logger)); + } + + /// Register a logger for validation metrics. + pub(crate) fn register_logger_valid(&mut self, logger: ML) { + self.loggers_valid.push(Box::new(logger)); + } +} diff --git a/burn-train/src/metric/store/mod.rs b/burn-train/src/metric/store/mod.rs new file mode 100644 index 0000000000..b86c0f4c2d --- /dev/null +++ b/burn-train/src/metric/store/mod.rs @@ -0,0 +1,9 @@ +pub(crate) mod aggregate; + +mod base; +mod client; +mod log; + +pub(crate) use self::log::*; +pub use base::*; +pub use client::*; diff --git a/burn-train/src/renderer/cli.rs b/burn-train/src/renderer/cli.rs index 9bd01037df..66fd6b2d18 100644 --- a/burn-train/src/renderer/cli.rs +++ b/burn-train/src/renderer/cli.rs @@ -1,4 +1,4 @@ -use crate::metric::callback::{MetricState, MetricsRenderer, TrainingProgress}; +use crate::metric::renderer::{MetricState, MetricsRenderer, TrainingProgress}; /// A simple renderer for when the cli feature is not enabled. pub struct CliMetricsRenderer; diff --git a/examples/mnist/src/model.rs b/examples/mnist/src/model.rs index 7d04fff30d..10ae34d75c 100644 --- a/examples/mnist/src/model.rs +++ b/examples/mnist/src/model.rs @@ -1,5 +1,4 @@ use crate::data::MNISTBatch; - use burn::{ module::Module, nn::{self, loss::CrossEntropyLoss, BatchNorm, PaddingConfig2d}, diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs index 900069d594..d08e78e89a 100644 --- a/examples/mnist/src/training.rs +++ b/examples/mnist/src/training.rs @@ -5,7 +5,9 @@ use burn::module::Module; use burn::optim::decay::WeightDecayConfig; use burn::optim::AdamConfig; use burn::record::{CompactRecorder, NoStdTrainingRecorder}; +use burn::train::metric::store::{Aggregate, Direction, Split}; use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse}; +use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition}; use burn::{ config::Config, data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset}, @@ -69,6 +71,12 @@ pub fn run(device: B::Device) { .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) + .early_stopping(MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + StoppingCondition::NoImprovementSince { n_epochs: 1 }, + )) .devices(vec![device]) .num_epochs(config.num_epochs) .build(Model::new(), config.optimizer.init(), 1e-4);