Skip to content

Commit

Permalink
Autodiff Memory Management: BFS (#1710)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored and nathanielsimard committed May 3, 2024
1 parent 2115a22 commit 3bb0b8f
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 68 deletions.
193 changes: 125 additions & 68 deletions crates/burn-autodiff/src/runtime/memory_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,13 @@ pub struct GraphMemoryManagement {
statuses: HashMap<NodeID, NodeMemoryStatus>,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
enum NodeMemoryStatus {
Useful,
Unavailable,
Unknown,
}

#[derive(Clone)]
enum Mode {
TagAsUseful,
Explore,
}

impl GraphMemoryManagement {
/// Register a new node with its parent.
pub fn register(&mut self, node: NodeRefCount, parents: Vec<NodeID>) {
Expand Down Expand Up @@ -66,9 +60,7 @@ impl GraphMemoryManagement {
// available node with a tensor reference exist in their descendance.
// But some may seem useless from some leaf but be useful from another one,
// hence the need to iterate on all leaves.
for leaf in leaves.clone() {
self.useful_propagation(leaf, Mode::Explore);
}
self.useful_propagation(leaves.clone());

// New leaves are the roots of a useful backward sub-tree.
// Deletables are everything not marked as useful.
Expand Down Expand Up @@ -115,81 +107,146 @@ impl GraphMemoryManagement {
}
}

fn useful_propagation(&mut self, node_id: NodeID, mode: Mode) {
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);

match mode {
Mode::TagAsUseful => {
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
for parent in parents {
self.useful_propagation(parent, Mode::TagAsUseful)
}
}
Mode::Explore => {
let node_status = self
.statuses
.get(&node_id)
.expect("All nodes should have received a status at this point")
.clone();

match node_status {
NodeMemoryStatus::Useful => {
// Nothing to do, was already tagged through some other path
}
NodeMemoryStatus::Unavailable => {
// Even if this node is unavailable, it is still possible that an ancestor is useful if referenced
for parent in parents {
self.useful_propagation(parent, Mode::Explore);
fn useful_propagation(&mut self, leaves: HashSet<NodeID>) {
// Accumulate visited nodes
let mut explored = HashSet::new();
let mut tagged_useful = HashSet::new();

// Queue of nodes to visit
let mut to_tag_useful = PopNodeSet::default();
let mut to_explore = PopNodeSet::new(leaves);

// Utilitary function to iterate over a node's parents
let parents = |node_id| {
self.nodes
.get(&node_id)
.cloned()
.unwrap_or_default()
.into_iter()
};

loop {
// Pop a node id, greedily looking at tag_useful ones first
let (node_id, status) = match to_tag_useful.pop() {
Some(node_id) => (node_id, NodeMemoryStatus::Useful),
None => match to_explore.pop() {
Some(node_id) => {
let node_status = self
.statuses
.get(&node_id)
.expect("All nodes should have received a status during unavailable_propagation")
.to_owned();

if let NodeMemoryStatus::Unknown = node_status {
match self.is_referenced(node_id) {
true => (node_id, NodeMemoryStatus::Useful),
false => (node_id, NodeMemoryStatus::Unknown),
}
} else {
(node_id, node_status)
}
}
NodeMemoryStatus::Unknown => {
// If this node is referenced and not unavailable,
// then it is useful and we must retain all ancestors

let mut mode = Mode::Explore;
if self.is_referenced(node_id) {
self.statuses.insert(node_id, NodeMemoryStatus::Useful);
mode = Mode::TagAsUseful;
None => {
// There are no nodes in the queues anymore
break;
}
},
};

match status {
NodeMemoryStatus::Useful => {
tagged_useful.insert(node_id);
for parent in parents(node_id) {
// The node can be explored, as long as it's not already tagged useful
if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {
to_tag_useful.insert(parent);
}

for parent in parents {
self.useful_propagation(parent, mode.clone());
}
}
_ => {
explored.insert(node_id);
for parent in parents(node_id) {
if !(explored.contains(&parent) || to_explore.contains(&parent)) {
to_explore.insert(parent);
}
}
}
}
}
}

fn is_referenced(&self, node_id: NodeID) -> bool {
match self.nodes.get_key_value(&node_id) {
Some((key, _value)) => Arc::strong_count(key) > 1,
None => panic!("Node should be in the nodes map"),
self.statuses.insert(node_id, status);
}
}

fn identify_leaves_and_deletables(
&self,
node_id: NodeID,
leaf_id: NodeID,
new_leaves: &mut HashSet<NodeID>,
to_delete: &mut Vec<NodeID>,
) {
let current_status = self
.statuses
.get(&node_id)
.expect("Node should have status");

match current_status {
NodeMemoryStatus::Useful => {
new_leaves.insert(node_id);
}
_ => {
let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]);
for parent in parents {
self.identify_leaves_and_deletables(parent, new_leaves, to_delete)
let mut visited = HashSet::new();
let mut to_visit = vec![leaf_id];

while let Some(node_id) = to_visit.pop() {
visited.insert(node_id);

match self
.statuses
.get(&node_id)
.expect("Node should have status")
{
NodeMemoryStatus::Useful => {
new_leaves.insert(node_id);
}
to_delete.push(node_id);
}
_ => {
to_delete.push(node_id);

for parent in self
.nodes
.get(&node_id)
.cloned()
.unwrap_or_default()
.into_iter()
{
if !visited.contains(&parent) {
to_visit.push(parent);
}
}
}
};
}
}

fn is_referenced(&self, node_id: NodeID) -> bool {
match self.nodes.get_key_value(&node_id) {
Some((key, _value)) => Arc::strong_count(key) > 1,
None => panic!("Node should be in the nodes map"),
}
}
}

/// Wrapper over hash set for fast popping of any node
#[derive(new, Default)]
struct PopNodeSet {
hash_set: HashSet<NodeID>,
}

impl PopNodeSet {
#[inline(always)]
fn pop(&mut self) -> Option<NodeID> {
self.hash_set
.iter()
.next()
.copied()
.and_then(|node_id| self.hash_set.take(&node_id))
}

#[inline(always)]
fn contains(&self, node_id: &NodeID) -> bool {
self.hash_set.contains(node_id)
}

#[inline(always)]
fn insert(&mut self, node_id: NodeID) {
self.hash_set.insert(node_id);
}
}
52 changes: 52 additions & 0 deletions crates/burn-autodiff/src/tests/memory_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,56 @@ mod tests {
assert!(tensor_2.grad(&grads).is_some());
assert!(tensor_3.grad(&grads).is_none());
}

#[test]
#[should_panic]
fn test_mm_deletables_propagate_well() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();

let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();

let tensor_2 = tensor_0 * tensor_1;
let tensor_3 = tensor_2.clone().exp();
let tensor_4 = tensor_3.clone().log();

let grads = tensor_2.backward();

// We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but
// the intermediate tensor_3 as well
let grads = tensor_3.backward();
}

#[test]
fn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {
let data = Data::from([[1.0, 2.0], [3.0, 4.0]]);
let device = Default::default();

// The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative
// By repeating it many times it becomes almost impossible that it passes if it shouldn't
for _ in 0..12 {
let tensor_0 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();
let tensor_1 =
Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();

let tensor_2 = tensor_1.clone().exp();
let tensor_3 = tensor_0.exp();
let tensor_4 = tensor_3.clone() * tensor_2.clone();
let tensor_5 = tensor_2.exp();
let tensor_6 = tensor_5.exp();
let tensor_7 = tensor_6.exp();
let tensor_8 = tensor_7.exp();

// tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8
// which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search
tensor_3.backward();
let grads = tensor_8.backward();

assert!(tensor_1.grad(&grads).is_some());
}
}
}

0 comments on commit 3bb0b8f

Please sign in to comment.