diff --git a/crates/burn-autodiff/src/runtime/memory_management.rs b/crates/burn-autodiff/src/runtime/memory_management.rs
index 134dd461d8..84a36816e7 100644
--- a/crates/burn-autodiff/src/runtime/memory_management.rs
+++ b/crates/burn-autodiff/src/runtime/memory_management.rs
@@ -12,19 +12,13 @@ pub struct GraphMemoryManagement {
     statuses: HashMap<NodeID, NodeMemoryStatus>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 enum NodeMemoryStatus {
     Useful,
     Unavailable,
     Unknown,
 }
 
-#[derive(Clone)]
-enum Mode {
-    TagAsUseful,
-    Explore,
-}
-
 impl GraphMemoryManagement {
     /// Register a new node with its parent.
     pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
@@ -66,9 +60,7 @@ impl GraphMemoryManagement {
         // available node with a tensor reference exist in their descendance.
         // But some may seem useless from some leaf but be useful from another one,
         // hence the need to iterate on all leaves.
-        for leaf in leaves.clone() {
-            self.useful_propagation(leaf, Mode::Explore);
-        }
+        self.useful_propagation(leaves.clone());
 
         // New leaves are the roots of a useful backward sub-tree.
         // Deletables are everything not marked as useful.
@@ -115,81 +107,146 @@ impl GraphMemoryManagement {
         }
     }
 
-    fn useful_propagation(&mut self, node_id: NodeID, mode: Mode) {
-        let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
-
-        match mode {
-            Mode::TagAsUseful => {
-                self.statuses.insert(node_id, NodeMemoryStatus::Useful);
-                for parent in parents {
-                    self.useful_propagation(parent, Mode::TagAsUseful)
-                }
-            }
-            Mode::Explore => {
-                let node_status = self
-                    .statuses
-                    .get(&node_id)
-                    .expect("All nodes should have received a status at this point")
-                    .clone();
-
-                match node_status {
-                    NodeMemoryStatus::Useful => {
-                        // Nothing to do, was already tagged through some other path
-                    }
-                    NodeMemoryStatus::Unavailable => {
-                        // Even if this node is unavailable, it is still possible that an ancestor is useful if referenced
-                        for parent in parents {
-                            self.useful_propagation(parent, Mode::Explore);
+    fn useful_propagation(&mut self, leaves: HashSet<NodeID>) {
+        // Accumulate visited nodes
+        let mut explored = HashSet::new();
+        let mut tagged_useful = HashSet::new();
+
+        // Queues of nodes to visit
+        let mut to_tag_useful = PopNodeSet::default();
+        let mut to_explore = PopNodeSet::new(leaves);
+
+        // Utility closure to iterate over a node's parents
+        let parents = |node_id| {
+            self.nodes
+                .get(&node_id)
+                .cloned()
+                .unwrap_or_default()
+                .into_iter()
+        };
+
+        loop {
+            // Pop a node id, greedily looking at tag_useful ones first
+            let (node_id, status) = match to_tag_useful.pop() {
+                Some(node_id) => (node_id, NodeMemoryStatus::Useful),
+                None => match to_explore.pop() {
+                    Some(node_id) => {
+                        let node_status = self
+                            .statuses
+                            .get(&node_id)
+                            .expect("All nodes should have received a status during unavailable_propagation")
+                            .to_owned();
+
+                        if let NodeMemoryStatus::Unknown = node_status {
+                            match self.is_referenced(node_id) {
+                                true => (node_id, NodeMemoryStatus::Useful),
+                                false => (node_id, NodeMemoryStatus::Unknown),
+                            }
+                        } else {
+                            (node_id, node_status)
                         }
                     }
-                    NodeMemoryStatus::Unknown => {
-                        // If this node is referenced and not unavailable,
-                        // then it is useful and we must retain all ancestors
-
-                        let mut mode = Mode::Explore;
-                        if self.is_referenced(node_id) {
-                            self.statuses.insert(node_id, NodeMemoryStatus::Useful);
-                            mode = Mode::TagAsUseful;
+                    None => {
+                        // There are no nodes in the queues anymore
+                        break;
+                    }
+                },
+            };
+
+            match status {
+                NodeMemoryStatus::Useful => {
+                    tagged_useful.insert(node_id);
+                    for parent in parents(node_id) {
+                        // Queue the parent to be tagged useful, unless it already is
+                        if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
+                            to_tag_useful.insert(parent);
                         }
-
-                        for parent in parents {
-                            self.useful_propagation(parent, mode.clone());
-                        }
-                    }
+                    }
+                }
+                _ => {
+                    explored.insert(node_id);
+                    for parent in parents(node_id) {
+                        if !(explored.contains(&parent) || to_explore.contains(&parent)) {
+                            to_explore.insert(parent);
+                        }
+                    }
                 }
             }
-        }
-    }
 
-    fn is_referenced(&self, node_id: NodeID) -> bool {
-        match self.nodes.get_key_value(&node_id) {
-            Some((key, _value)) => Arc::strong_count(key) > 1,
-            None => panic!("Node should be in the nodes map"),
+            self.statuses.insert(node_id, status);
         }
     }
 
     fn identify_leaves_and_deletables(
         &self,
-        node_id: NodeID,
+        leaf_id: NodeID,
         new_leaves: &mut HashSet<NodeID>,
         to_delete: &mut Vec<NodeID>,
     ) {
-        let current_status = self
-            .statuses
-            .get(&node_id)
-            .expect("Node should have status");
-
-        match current_status {
-            NodeMemoryStatus::Useful => {
-                new_leaves.insert(node_id);
-            }
-            _ => {
-                let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
-                for parent in parents {
-                    self.identify_leaves_and_deletables(parent, new_leaves, to_delete)
+        let mut visited = HashSet::new();
+        let mut to_visit = vec![leaf_id];
+
+        while let Some(node_id) = to_visit.pop() {
+            visited.insert(node_id);
+
+            match self
+                .statuses
+                .get(&node_id)
+                .expect("Node should have status")
+            {
+                NodeMemoryStatus::Useful => {
+                    new_leaves.insert(node_id);
                 }
-                to_delete.push(node_id);
-            }
+                _ => {
+                    to_delete.push(node_id);
+
+                    for parent in self
+                        .nodes
+                        .get(&node_id)
+                        .cloned()
+                        .unwrap_or_default()
+                        .into_iter()
+                    {
+                        if !visited.contains(&parent) {
+                            to_visit.push(parent);
+                        }
+                    }
+                }
+            };
+        }
+    }
+
+    fn is_referenced(&self, node_id: NodeID) -> bool {
+        match self.nodes.get_key_value(&node_id) {
+            Some((key, _value)) => Arc::strong_count(key) > 1,
+            None => panic!("Node should be in the nodes map"),
         }
     }
 }
+
+/// Wrapper over a hash set for fast popping of any node
+#[derive(new, Default)]
+struct PopNodeSet {
+    hash_set: HashSet<NodeID>,
+}
+
+impl PopNodeSet {
+    #[inline(always)]
+    fn pop(&mut self) -> Option<NodeID> {
+        self.hash_set
+            .iter()
+            .next()
+            .copied()
+            .and_then(|node_id| self.hash_set.take(&node_id))
+    }
+
+    #[inline(always)]
+    fn contains(&self, node_id: &NodeID) -> bool {
+        self.hash_set.contains(node_id)
+    }
+
+    #[inline(always)]
+    fn insert(&mut self, node_id: NodeID) {
+        self.hash_set.insert(node_id);
+    }
+}
diff --git a/crates/burn-autodiff/src/tests/memory_management.rs b/crates/burn-autodiff/src/tests/memory_management.rs
index 716afd5e9d..b3bff39e2a 100644
--- a/crates/burn-autodiff/src/tests/memory_management.rs
+++ b/crates/burn-autodiff/src/tests/memory_management.rs
@@ -238,4 +238,56 @@ mod tests {
         assert!(tensor_2.grad(&grads).is_some());
         assert!(tensor_3.grad(&grads).is_none());
     }
+
+    #[test]
+    #[should_panic]
+    fn test_mm_deletables_propagate_well() {
+        let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
+        let device = Default::default();
+
+        let tensor_0 =
+            Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
+        let tensor_1 =
+            Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
+
+        let tensor_2 = tensor_0 * tensor_1;
+        let tensor_3 = tensor_2.clone().exp();
+        let tensor_4 = tensor_3.clone().log();
+
+        let grads = tensor_2.backward();
+
+        // We are testing that after backward on tensor_2, not only is the leaf tensor_4 deleted,
+        // but the intermediate tensor_3 as well
+        let grads = tensor_3.backward();
+    }
+
+    #[test]
+    fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
+        let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
+        let device = Default::default();
+
+        // The test has a 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative.
+        // By repeating it many times, it becomes almost impossible that it passes if it shouldn't
+        for _ in 0..12 {
+            let tensor_0 =
+                Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
+            let tensor_1 =
+                Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
+
+            let tensor_2 = tensor_1.clone().exp();
+            let tensor_3 = tensor_0.exp();
+            let tensor_4 = tensor_3.clone() * tensor_2.clone();
+            let tensor_5 = tensor_2.exp();
+            let tensor_6 = tensor_5.exp();
+            let tensor_7 = tensor_6.exp();
+            let tensor_8 = tensor_7.exp();
+
+            // tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8,
+            // which should happen later because tensor_2 is deeper from tensor_8's point of view and we're doing breadth-first search
+            tensor_3.backward();
+            let grads = tensor_8.backward();
+
+            assert!(tensor_1.grad(&grads).is_some());
+        }
+    }
 }
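Side note for readers of the patch, not part of the diff: a minimal, self-contained sketch of the two-queue traversal that the new useful_propagation relies on. Here u32 stands in for NodeID, a plain referenced set stands in for the Arc strong-count check in is_referenced, and the names (propagate, pop, the toy graph in main) are illustrative only. It demonstrates the property exercised by the second test: because the tag-useful queue is always drained before the explore queue, a node first visited as unknown can later be promoted to useful when it is reached again from a referenced descendant.

use std::collections::{HashMap, HashSet};

// Toy status, with `u32` standing in for burn's `NodeID`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Status {
    Useful,
    Unknown,
}

// Pop an arbitrary element: the same trick as `PopNodeSet::pop` above.
fn pop(set: &mut HashSet<u32>) -> Option<u32> {
    set.iter().next().copied().and_then(|n| set.take(&n))
}

// Two-queue BFS: the useful queue is always drained before the explore queue,
// so a node first explored as unknown can later be promoted to useful.
fn propagate(
    parents: &HashMap<u32, Vec<u32>>,
    referenced: &HashSet<u32>,
    leaves: HashSet<u32>,
) -> HashMap<u32, Status> {
    let mut statuses: HashMap<u32, Status> = HashMap::new();
    let (mut tagged_useful, mut explored) = (HashSet::new(), HashSet::new());
    let mut to_tag_useful: HashSet<u32> = HashSet::new();
    let mut to_explore = leaves;

    loop {
        let (node, status) = if let Some(n) = pop(&mut to_tag_useful) {
            (n, Status::Useful)
        } else if let Some(n) = pop(&mut to_explore) {
            // Nodes already settled keep their status; fresh ones are
            // promoted only if something still references them.
            match statuses.get(&n) {
                Some(&s) => (n, s),
                None if referenced.contains(&n) => (n, Status::Useful),
                None => (n, Status::Unknown),
            }
        } else {
            break; // both queues are empty
        };

        for p in parents.get(&node).cloned().unwrap_or_default() {
            match status {
                // A parent found again from a useful node is re-queued for tagging
                // even if it was already explored as unknown.
                Status::Useful if !(tagged_useful.contains(&p) || to_tag_useful.contains(&p)) => {
                    to_tag_useful.insert(p);
                }
                Status::Unknown if !(explored.contains(&p) || to_explore.contains(&p)) => {
                    to_explore.insert(p);
                }
                _ => {}
            }
        }
        match status {
            Status::Useful => tagged_useful.insert(node),
            Status::Unknown => explored.insert(node),
        };
        statuses.insert(node, status);
    }
    statuses
}

fn main() {
    // 0 -> 1 -> {2, 3}: both 2 and 3 are leaves, and 1 is still referenced.
    let parents = HashMap::from([(2, vec![1]), (3, vec![1]), (1, vec![0])]);
    let referenced = HashSet::from([1]);
    let statuses = propagate(&parents, &referenced, HashSet::from([2, 3]));

    // 1 is referenced, so it and its ancestor 0 must end up useful,
    // regardless of which leaf the hash set happened to pop first.
    assert_eq!(statuses[&1], Status::Useful);
    assert_eq!(statuses[&0], Status::Useful);
}

Using hash sets as queues keeps the membership tests O(1); the pop order is arbitrary, which is exactly why the second test above repeats its loop 12 times instead of relying on one lucky ordering.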