Skip to content

Commit

Permalink
Feat/early stopping + burn train refactor (#878)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Oct 20, 2023
1 parent 3eb7f38 commit af813d0
Show file tree
Hide file tree
Showing 36 changed files with 1,124 additions and 732 deletions.
19 changes: 14 additions & 5 deletions burn-train/src/checkpoint/strategy/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::EventCollector;
use std::ops::DerefMut;

use crate::metric::store::EventStoreClient;

/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).
#[derive(Clone, PartialEq, Debug)]
pub enum CheckpointingAction {
Expand All @@ -11,15 +12,23 @@ pub enum CheckpointingAction {
}

/// Defines when checkpoints should be saved and deleted.
pub trait CheckpointingStrategy<E: EventCollector> {
pub trait CheckpointingStrategy {
/// Based on the epoch, determine if the checkpoint should be saved.
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction>;
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction>;
}

// Implement the checkpointing strategy for the boxed trait object so that it can be used
// wherever a generic strategy is expected while still being dispatched dynamically.
impl<E: EventCollector> CheckpointingStrategy<E> for Box<dyn CheckpointingStrategy<E>> {
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction> {
self.deref_mut().checkpointing(epoch, collector)
}
}
54 changes: 25 additions & 29 deletions burn-train/src/checkpoint/strategy/composed.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,58 @@
use crate::metric::store::EventStoreClient;

use super::{CheckpointingAction, CheckpointingStrategy};
use crate::EventCollector;
use std::collections::HashSet;

/// Compose multiple checkpointing strategies and only delete checkpoints when both strategies
/// flag an epoch to be deleted.
pub struct ComposedCheckpointingStrategy<E: EventCollector> {
strategies: Vec<Box<dyn CheckpointingStrategy<E>>>,
pub struct ComposedCheckpointingStrategy {
strategies: Vec<Box<dyn CheckpointingStrategy>>,
deleted: Vec<HashSet<usize>>,
}

/// Helps build a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.
pub struct ComposedCheckpointingStrategyBuilder<E: EventCollector> {
strategies: Vec<Box<dyn CheckpointingStrategy<E>>>,
#[derive(Default)]
pub struct ComposedCheckpointingStrategyBuilder {
strategies: Vec<Box<dyn CheckpointingStrategy>>,
}

impl<E: EventCollector> Default for ComposedCheckpointingStrategyBuilder<E> {
fn default() -> Self {
Self {
strategies: Vec::new(),
}
}
}

impl<E: EventCollector> ComposedCheckpointingStrategyBuilder<E> {
impl ComposedCheckpointingStrategyBuilder {
/// Add a new [checkpointing strategy](CheckpointingStrategy).
#[allow(clippy::should_implement_trait)]
pub fn add<S>(mut self, strategy: S) -> Self
where
S: CheckpointingStrategy<E> + 'static,
S: CheckpointingStrategy + 'static,
{
self.strategies.push(Box::new(strategy));
self
}

/// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy).
pub fn build(self) -> ComposedCheckpointingStrategy<E> {
pub fn build(self) -> ComposedCheckpointingStrategy {
ComposedCheckpointingStrategy::new(self.strategies)
}
}

impl<E: EventCollector> ComposedCheckpointingStrategy<E> {
fn new(strategies: Vec<Box<dyn CheckpointingStrategy<E>>>) -> Self {
impl ComposedCheckpointingStrategy {
fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {
Self {
deleted: strategies.iter().map(|_| HashSet::new()).collect(),
strategies,
}
}
/// Create a new builder which helps compose multiple
/// [checkpointing strategies](CheckpointingStrategy).
pub fn builder() -> ComposedCheckpointingStrategyBuilder<E> {
pub fn builder() -> ComposedCheckpointingStrategyBuilder {
ComposedCheckpointingStrategyBuilder::default()
}
}

impl<E: EventCollector> CheckpointingStrategy<E> for ComposedCheckpointingStrategy<E> {
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for ComposedCheckpointingStrategy {
fn checkpointing(
&mut self,
epoch: usize,
collector: &EventStoreClient,
) -> Vec<CheckpointingAction> {
let mut saved = false;
let mut actions = Vec::new();
let mut epochs_to_check = Vec::new();
Expand Down Expand Up @@ -104,33 +103,30 @@ impl<E: EventCollector> CheckpointingStrategy<E> for ComposedCheckpointingStrate

#[cfg(test)]
mod tests {
use crate::{
checkpoint::KeepLastNCheckpoints, info::MetricsInfo, test_utils::TestEventCollector,
};

use super::*;
use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};

#[test]
fn should_delete_when_both_deletes() {
let mut collector = TestEventCollector::<f64, f64>::new(MetricsInfo::new());
let store = EventStoreClient::new(LogEventStore::default());
let mut strategy = ComposedCheckpointingStrategy::builder()
.add(KeepLastNCheckpoints::new(1))
.add(KeepLastNCheckpoints::new(2))
.build();

assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(1, &mut collector)
strategy.checkpointing(1, &store)
);

assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(2, &mut collector)
strategy.checkpointing(2, &store)
);

assert_eq!(
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
strategy.checkpointing(3, &mut collector)
strategy.checkpointing(3, &store)
);
}
}
21 changes: 12 additions & 9 deletions burn-train/src/checkpoint/strategy/lastn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::CheckpointingStrategy;
use crate::{checkpoint::CheckpointingAction, EventCollector};
use crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};

/// Keep the last N checkpoints.
///
Expand All @@ -10,8 +10,12 @@ pub struct KeepLastNCheckpoints {
num_keep: usize,
}

impl<E: EventCollector> CheckpointingStrategy<E> for KeepLastNCheckpoints {
fn checkpointing(&mut self, epoch: usize, _collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for KeepLastNCheckpoints {
fn checkpointing(
&mut self,
epoch: usize,
_store: &EventStoreClient,
) -> Vec<CheckpointingAction> {
let mut actions = vec![CheckpointingAction::Save];

if let Some(epoch) = usize::checked_sub(epoch, self.num_keep) {
Expand All @@ -26,28 +30,27 @@ impl<E: EventCollector> CheckpointingStrategy<E> for KeepLastNCheckpoints {

#[cfg(test)]
mod tests {
use crate::{info::MetricsInfo, test_utils::TestEventCollector};

use super::*;
use crate::metric::store::LogEventStore;

#[test]
fn should_always_delete_lastn_epoch_if_higher_than_one() {
let mut strategy = KeepLastNCheckpoints::new(2);
let mut collector = TestEventCollector::<f64, f64>::new(MetricsInfo::new());
let store = EventStoreClient::new(LogEventStore::default());

assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(1, &mut collector)
strategy.checkpointing(1, &store)
);

assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(2, &mut collector)
strategy.checkpointing(2, &store)
);

assert_eq!(
vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],
strategy.checkpointing(3, &mut collector)
strategy.checkpointing(3, &store)
);
}
}
99 changes: 42 additions & 57 deletions burn-train/src/checkpoint/strategy/metric.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use super::CheckpointingStrategy;
use crate::{
checkpoint::CheckpointingAction, metric::Metric, Aggregate, Direction, EventCollector, Split,
checkpoint::CheckpointingAction,
metric::{
store::{Aggregate, Direction, EventStoreClient, Split},
Metric,
},
};

/// Keep the best checkpoint based on a metric.
Expand Down Expand Up @@ -28,10 +32,14 @@ impl MetricCheckpointingStrategy {
}
}

impl<E: EventCollector> CheckpointingStrategy<E> for MetricCheckpointingStrategy {
fn checkpointing(&mut self, epoch: usize, collector: &mut E) -> Vec<CheckpointingAction> {
impl CheckpointingStrategy for MetricCheckpointingStrategy {
fn checkpointing(
&mut self,
epoch: usize,
store: &EventStoreClient,
) -> Vec<CheckpointingAction> {
let best_epoch =
match collector.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
match store.find_epoch(&self.name, self.aggregate, self.direction, self.split) {
Some(epoch_best) => epoch_best,
None => epoch,
};
Expand All @@ -56,93 +64,70 @@ impl<E: EventCollector> CheckpointingStrategy<E> for MetricCheckpointingStrategy

#[cfg(test)]
mod tests {
use burn_core::tensor::{backend::Backend, ElementConversion, Tensor};

use super::*;
use crate::{
info::MetricsInfo,
logger::InMemoryMetricLogger,
metric::{Adaptor, LossInput, LossMetric},
test_utils::TestEventCollector,
Event, LearnerItem, TestBackend,
metric::{
processor::{
test_utils::{end_epoch, process_train},
Metrics, MinimalEventProcessor,
},
store::LogEventStore,
LossMetric,
},
TestBackend,
};
use std::sync::Arc;

use super::*;

#[test]
fn always_keep_the_best_epoch() {
let mut store = LogEventStore::default();
let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
Aggregate::Mean,
Direction::Lowest,
Split::Train,
);
let mut info = MetricsInfo::new();
let mut metrics = Metrics::<f64, f64>::default();
// Register an in memory logger.
info.register_logger_train(InMemoryMetricLogger::default());
store.register_logger_train(InMemoryMetricLogger::default());
// Register the loss metric.
info.register_train_metric_numeric(LossMetric::<TestBackend>::new());

let mut collector = TestEventCollector::<f64, f64>::new(info);
metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());
let store = Arc::new(EventStoreClient::new(store));
let mut processor = MinimalEventProcessor::new(metrics, store.clone());

// Two points for the first epoch. Mean 0.75
let mut epoch = 1;
item(&mut collector, 1.0, epoch);
item(&mut collector, 0.5, epoch);
end_epoch(&mut collector, epoch);
process_train(&mut processor, 1.0, epoch);
process_train(&mut processor, 0.5, epoch);
end_epoch(&mut processor, epoch);

// Should save the current record.
assert_eq!(
vec![CheckpointingAction::Save],
strategy.checkpointing(epoch, &mut collector)
strategy.checkpointing(epoch, &store)
);

// Two points for the second epoch. Mean 0.4
epoch += 1;
item(&mut collector, 0.5, epoch);
item(&mut collector, 0.3, epoch);
end_epoch(&mut collector, epoch);
process_train(&mut processor, 0.5, epoch);
process_train(&mut processor, 0.3, epoch);
end_epoch(&mut processor, epoch);

// Should save the current record and delete the previous one.
assert_eq!(
vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],
strategy.checkpointing(epoch, &mut collector)
strategy.checkpointing(epoch, &store)
);

// Two points for the last epoch. Mean 2.0
epoch += 1;
item(&mut collector, 1.0, epoch);
item(&mut collector, 3.0, epoch);
end_epoch(&mut collector, epoch);
process_train(&mut processor, 1.0, epoch);
process_train(&mut processor, 3.0, epoch);
end_epoch(&mut processor, epoch);

// Should not delete the previous record, since it's the best one, and should not save a
// new one.
assert!(strategy.checkpointing(epoch, &mut collector).is_empty());
}

fn item(collector: &mut TestEventCollector<f64, f64>, value: f64, epoch: usize) {
let dummy_progress = burn_core::data::dataloader::Progress {
items_processed: 1,
items_total: 10,
};
let num_epochs = 3;
let dummy_iteration = 1;

collector.on_event_train(Event::ProcessedItem(LearnerItem::new(
value,
dummy_progress,
epoch,
num_epochs,
dummy_iteration,
None,
)));
}

fn end_epoch(collector: &mut TestEventCollector<f64, f64>, epoch: usize) {
collector.on_event_train(Event::EndEpoch(epoch));
collector.on_event_valid(Event::EndEpoch(epoch));
}

impl<B: Backend> Adaptor<LossInput<B>> for f64 {
fn adapt(&self) -> LossInput<B> {
LossInput::new(Tensor::from_data([self.elem()]))
}
assert!(strategy.checkpointing(epoch, &store).is_empty());
}
}
Loading

0 comments on commit af813d0

Please sign in to comment.