Skip to content

Commit

Permalink
refactor: create tape only during backprop (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jul 20, 2022
1 parent 30fadcb commit 8e80502
Show file tree
Hide file tree
Showing 19 changed files with 320 additions and 443 deletions.
98 changes: 0 additions & 98 deletions burn-tensor/src/graph/node.rs

This file was deleted.

5 changes: 5 additions & 0 deletions burn-tensor/src/graph/node/mod.rs
@@ -0,0 +1,5 @@
mod node;
mod state;

pub use node::*;
pub use state::*;
33 changes: 33 additions & 0 deletions burn-tensor/src/graph/node/node.rs
@@ -0,0 +1,33 @@
use super::NodeStateRef;
use crate::{ops::RecordedOpsRef, tape::Tape};
use std::rc::Rc;

#[derive(new, Debug)]
pub struct Node<Out> {
pub state: NodeStateRef<Out>,
pub ops: RecordedOpsRef,
}

impl<Out> Node<Out> {
pub fn record(&self, tape: &mut Tape) {
let mut all_ops = self.ops.parents_ops();
tape.add(self.ops.clone());

loop {
if all_ops.len() == 0 {
return;
}
let ops = all_ops.pop().unwrap();
all_ops.append(&mut ops.parents_ops());
tape.add(ops);
}
}
}
pub type NodeRef<Out> = Rc<Node<Out>>;

pub trait Zeros<T> {
fn zeros(&self) -> T;
}
pub trait Ones<T> {
fn ones(&self) -> T;
}
45 changes: 45 additions & 0 deletions burn-tensor/src/graph/node/state.rs
@@ -0,0 +1,45 @@
use crate::node::Zeros;
use std::{cell::RefCell, ops::Add, rc::Rc};

#[derive(Debug)]
pub struct NodeState<Out> {
pub value: Out,
pub grad: Option<Out>,
}
pub type NodeStateRef<Out> = Rc<RefCell<NodeState<Out>>>;

impl<Out> NodeState<Out> {
pub fn new(value: Out) -> Self {
Self { value, grad: None }
}
pub fn new_mut(value: Out) -> NodeStateRef<Out> {
Rc::new(RefCell::new(Self::new(value)))
}
}
impl<Out> NodeState<Out>
where
Out: Clone,
{
pub fn value(&self) -> Out {
self.value.clone()
}
}

impl<Out> NodeState<Out>
where
Out: Zeros<Out> + Clone + Add<Output = Out>,
Out: std::fmt::Debug,
{
pub fn grad(&mut self) -> Out {
let grad_self = match &self.grad {
Some(val) => val.clone(),
None => self.value.zeros(),
};
self.grad = Some(grad_self.clone());
grad_self
}

pub fn update_grad(&mut self, grad: Out) {
self.grad = Some(self.grad() + grad);
}
}
87 changes: 21 additions & 66 deletions burn-tensor/src/graph/ops/binary.rs
@@ -1,99 +1,54 @@
use super::{BinaryRecordedState, RecordedOps};
use crate::node::{Node, NodeId, NodeRef, Ones, Zeros};
use super::{BinaryOpsNodeState, RecordedOps, RecordedOpsRef};
use crate::node::{NodeRef, NodeStateRef, Ones, Zeros};
use std::ops::{Add, Mul};

pub trait BinaryOps<Lhs, Rhs, Out>: std::fmt::Debug {
fn partial_left(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Lhs;
fn partial_right(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Rhs;
}

#[derive(Debug)]
pub struct BinaryOpsNode<Lhs, Rhs, Out> {
pub id: NodeId,
pub parent_left: NodeRef<Lhs>,
pub parent_right: NodeRef<Rhs>,
pub value: Out,
pub grad: Option<Out>,
fn partial_left(&self, state: &BinaryOpsNodeState<Lhs, Rhs, Out>) -> Lhs;
fn partial_right(&self, state: &BinaryOpsNodeState<Lhs, Rhs, Out>) -> Rhs;
}

#[derive(new, Debug)]
pub struct BinaryRecordedOps<Lhs, Rhs, Out, Ops> {
lhs: NodeRef<Lhs>,
rhs: NodeRef<Rhs>,
out: NodeRef<Out>,
out: NodeStateRef<Out>,
ops: Ops,
}

impl<Lhs, Rhs, Out> BinaryOpsNode<Lhs, Rhs, Out> {
pub fn new(parent_left: NodeRef<Lhs>, parent_right: NodeRef<Rhs>, value: Out) -> Self {
Self {
id: NodeId::new(),
parent_left,
parent_right,
value,
grad: None,
}
}
}

impl<Lhs, Rhs, Out> Node<Out> for BinaryOpsNode<Lhs, Rhs, Out>
where
Out: Zeros<Out> + Clone + Mul<Output = Out> + Add<Output = Out>,
Lhs: std::fmt::Debug,
Rhs: std::fmt::Debug,
Out: std::fmt::Debug,
{
fn id(&self) -> NodeId {
self.id.clone()
}
fn value(&self) -> Out {
self.value.clone()
}

fn grad(&mut self) -> Out {
let grad_self = match &self.grad {
Some(val) => val.clone(),
None => self.value.zeros(),
};
self.grad = Some(grad_self.clone());
grad_self
}

fn update_grad(&mut self, grad: Out) {
self.grad = Some(self.grad() + grad);
}
}

impl<Lhs, Rhs, Out, Ops> RecordedOps for BinaryRecordedOps<Lhs, Rhs, Out, Ops>
where
Lhs: Clone + Zeros<Lhs> + Mul<Out, Output = Lhs>,
Rhs: Clone + Zeros<Rhs> + Mul<Out, Output = Rhs>,
Out: Clone + Zeros<Out> + Ones<Out> + 'static,
Lhs: Clone + Zeros<Lhs> + Mul<Out, Output = Lhs> + Add<Output = Lhs> + 'static,
Rhs: Clone + Zeros<Rhs> + Mul<Out, Output = Rhs> + Add<Output = Rhs> + 'static,
Out: Clone + Zeros<Out> + Ones<Out> + Add<Output = Out> + 'static,
Lhs: std::fmt::Debug,
Rhs: std::fmt::Debug,
Out: std::fmt::Debug,
Ops: BinaryOps<Lhs, Rhs, Out>,
Ops: BinaryOps<Lhs, Rhs, Out> + 'static,
{
fn id(&self) -> NodeId {
self.out.borrow().id()
}

fn backward(&mut self) {
let state = BinaryRecordedState::new(&self.lhs, &self.rhs, &self.out);
fn backward(&self) {
let state = BinaryOpsNodeState::new(&self.lhs.state, &self.rhs.state, &self.out);

let partial_left = self.ops.partial_left(&state);
let partial_right: Rhs = self.ops.partial_right(&state);

let grad_mine = self.out.borrow_mut().grad();

self.lhs
.state
.borrow_mut()
.update_grad(partial_left * grad_mine.clone());
self.rhs.borrow_mut().update_grad(partial_right * grad_mine);
self.rhs
.state
.borrow_mut()
.update_grad(partial_right * grad_mine);
}

fn set_last_ops(&mut self) {
fn set_last_ops(&self) {
let value = self.out.borrow().value();
self.out.borrow_mut().update_grad(value.ones());
}

fn parents_ops(&self) -> Vec<RecordedOpsRef> {
vec![self.lhs.ops.clone(), self.rhs.ops.clone()]
}
}
21 changes: 11 additions & 10 deletions burn-tensor/src/graph/ops/ops.rs
@@ -1,21 +1,22 @@
use crate::node::{NodeId, NodeRef};
use crate::node::NodeStateRef;
use std::rc::Rc;

#[derive(new)]
pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> {
pub left: &'a NodeRef<Lhs>,
pub right: &'a NodeRef<Rhs>,
pub output: &'a NodeRef<Out>,
pub struct BinaryOpsNodeState<'a, Lhs, Rhs, Out> {
pub left: &'a NodeStateRef<Lhs>,
pub right: &'a NodeStateRef<Rhs>,
pub output: &'a NodeStateRef<Out>,
}

#[derive(new)]
pub struct SingleRecordedState<'a, In, Out> {
pub struct SingleOpsNodeState<'a, In, Out> {
pub input: &'a In,
pub output: &'a Out,
}

pub trait RecordedOps: std::fmt::Debug {
fn id(&self) -> NodeId;
fn backward(&mut self);
fn set_last_ops(&mut self);
fn backward(&self);
fn set_last_ops(&self);
fn parents_ops(&self) -> Vec<RecordedOpsRef>;
}
pub type RecordedOpsRef = Box<dyn RecordedOps>;
pub type RecordedOpsRef = Rc<dyn RecordedOps>;
23 changes: 12 additions & 11 deletions burn-tensor/src/graph/ops/root.rs
@@ -1,23 +1,24 @@
use super::RecordedOps;
use crate::node::{NodeId, NodeRef, Ones, Zeros};
use super::{RecordedOps, RecordedOpsRef};
use crate::node::{NodeStateRef, Ones, Zeros};
use std::ops::Add;

#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
pub struct InitRecordedOps<Out> {
root: NodeRef<Out>,
root: NodeStateRef<Out>,
}

impl<Out> RecordedOps for InitRecordedOps<Out>
where
Out: Clone + Zeros<Out> + Ones<Out> + 'static,
Out: Clone + Zeros<Out> + Ones<Out> + Add<Output = Out> + 'static,
Out: std::fmt::Debug,
{
fn id(&self) -> NodeId {
self.root.borrow().id()
}

fn backward(&mut self) {}
fn set_last_ops(&mut self) {
fn backward(&self) {}
fn set_last_ops(&self) {
let value = self.root.borrow().value();
self.root.borrow_mut().update_grad(value.ones());
}

fn parents_ops(&self) -> Vec<RecordedOpsRef> {
vec![]
}
}

0 comments on commit 8e80502

Please sign in to comment.