Skip to content

Commit

Permalink
feat: Make RetroForward public (#1905)
Browse files Browse the repository at this point in the history
  • Loading branch information
femshima committed Jun 18, 2024
1 parent 96468fc commit f8a7c54
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
6 changes: 4 additions & 2 deletions crates/burn-autodiff/src/checkpoint/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
/// Checkpointer module
pub mod base;
pub(crate) mod builder;
pub(crate) mod retro_forward;
pub(crate) mod state;
/// RetroForward module
pub mod retro_forward;
/// BackwardStates module
pub mod state;
/// CheckpointStrategy module
pub mod strategy;
3 changes: 2 additions & 1 deletion crates/burn-autodiff/src/checkpoint/retro_forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use super::state::{BackwardStates, State};

/// Definition of the forward function of a node, called during retropropagation only.
/// This is different from the normal forward function because it reads and writes from
/// the [InnerStates] map instead of having a clear function signature.
/// the [BackwardStates] map instead of having a clear function signature.
pub trait RetroForward: Debug + Send + 'static {
/// Applies the forward pass for retropropagation.
fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
}

Expand Down
9 changes: 5 additions & 4 deletions crates/burn-autodiff/src/checkpoint/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ impl State {
}

#[derive(new, Default, Debug)]
/// Links [NodeID]s to their current [State]
/// Links [NodeID]s to their current state
pub struct BackwardStates {
map: HashMap<NodeID, State>,
}

impl BackwardStates {
/// Returns the output in the [State] of the given [NodeID],
/// Returns the output in the state of the given [NodeID],
/// and decrements the number of times this state is required.
/// This function always gives ownership of the output, but will clone it if needed for further uses.
pub(crate) fn get_state<T>(&mut self, node_id: &NodeID) -> T
pub fn get_state<T>(&mut self, node_id: &NodeID) -> T
where
T: Clone + Send + 'static,
{
Expand Down Expand Up @@ -117,7 +117,8 @@ impl BackwardStates {
self.map.insert(node_id, state);
}

pub(crate) fn save<T>(&mut self, node_id: NodeID, saved_output: T)
/// Saves the output to the state of the given [NodeID].
pub fn save<T>(&mut self, node_id: NodeID, saved_output: T)
where
T: Clone + Send + 'static,
{
Expand Down

0 comments on commit f8a7c54

Please sign in to comment.