Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: create tape only during backprop #1

Merged
merged 7 commits into from Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 0 additions & 98 deletions burn-tensor/src/graph/node.rs

This file was deleted.

5 changes: 5 additions & 0 deletions burn-tensor/src/graph/node/mod.rs
@@ -0,0 +1,5 @@
mod node;
mod state;

pub use node::*;
pub use state::*;
33 changes: 33 additions & 0 deletions burn-tensor/src/graph/node/node.rs
@@ -0,0 +1,33 @@
use super::NodeStateRef;
use crate::{ops::RecordedOpsRef, tape::Tape};
use std::rc::Rc;

#[derive(new, Debug)]
pub struct Node<Out> {
pub state: NodeStateRef<Out>,
pub ops: RecordedOpsRef,
}

impl<Out> Node<Out> {
pub fn record(&self, tape: &mut Tape) {
let mut all_ops = self.ops.parents_ops();
tape.add(self.ops.clone());

loop {
if all_ops.len() == 0 {
return;
}
let ops = all_ops.pop().unwrap();
all_ops.append(&mut ops.parents_ops());
tape.add(ops);
}
}
}
pub type NodeRef<Out> = Rc<Node<Out>>;

pub trait Zeros<T> {
fn zeros(&self) -> T;
}
pub trait Ones<T> {
fn ones(&self) -> T;
}
45 changes: 45 additions & 0 deletions burn-tensor/src/graph/node/state.rs
@@ -0,0 +1,45 @@
use crate::node::Zeros;
use std::{cell::RefCell, ops::Add, rc::Rc};

#[derive(Debug)]
pub struct NodeState<Out> {
pub value: Out,
pub grad: Option<Out>,
}
pub type NodeStateRef<Out> = Rc<RefCell<NodeState<Out>>>;

impl<Out> NodeState<Out> {
pub fn new(value: Out) -> Self {
Self { value, grad: None }
}
pub fn new_mut(value: Out) -> NodeStateRef<Out> {
Rc::new(RefCell::new(Self::new(value)))
}
}
impl<Out> NodeState<Out>
where
Out: Clone,
{
pub fn value(&self) -> Out {
self.value.clone()
}
}

impl<Out> NodeState<Out>
where
Out: Zeros<Out> + Clone + Add<Output = Out>,
Out: std::fmt::Debug,
{
pub fn grad(&mut self) -> Out {
let grad_self = match &self.grad {
Some(val) => val.clone(),
None => self.value.zeros(),
};
self.grad = Some(grad_self.clone());
grad_self
}

pub fn update_grad(&mut self, grad: Out) {
self.grad = Some(self.grad() + grad);
}
}
87 changes: 21 additions & 66 deletions burn-tensor/src/graph/ops/binary.rs
@@ -1,99 +1,54 @@
use super::{BinaryRecordedState, RecordedOps};
use crate::node::{Node, NodeId, NodeRef, Ones, Zeros};
use super::{BinaryOpsNodeState, RecordedOps, RecordedOpsRef};
use crate::node::{NodeRef, NodeStateRef, Ones, Zeros};
use std::ops::{Add, Mul};

pub trait BinaryOps<Lhs, Rhs, Out>: std::fmt::Debug {
fn partial_left(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Lhs;
fn partial_right(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Rhs;
}

#[derive(Debug)]
pub struct BinaryOpsNode<Lhs, Rhs, Out> {
pub id: NodeId,
pub parent_left: NodeRef<Lhs>,
pub parent_right: NodeRef<Rhs>,
pub value: Out,
pub grad: Option<Out>,
fn partial_left(&self, state: &BinaryOpsNodeState<Lhs, Rhs, Out>) -> Lhs;
fn partial_right(&self, state: &BinaryOpsNodeState<Lhs, Rhs, Out>) -> Rhs;
}

#[derive(new, Debug)]
pub struct BinaryRecordedOps<Lhs, Rhs, Out, Ops> {
lhs: NodeRef<Lhs>,
rhs: NodeRef<Rhs>,
out: NodeRef<Out>,
out: NodeStateRef<Out>,
ops: Ops,
}

impl<Lhs, Rhs, Out> BinaryOpsNode<Lhs, Rhs, Out> {
pub fn new(parent_left: NodeRef<Lhs>, parent_right: NodeRef<Rhs>, value: Out) -> Self {
Self {
id: NodeId::new(),
parent_left,
parent_right,
value,
grad: None,
}
}
}

impl<Lhs, Rhs, Out> Node<Out> for BinaryOpsNode<Lhs, Rhs, Out>
where
Out: Zeros<Out> + Clone + Mul<Output = Out> + Add<Output = Out>,
Lhs: std::fmt::Debug,
Rhs: std::fmt::Debug,
Out: std::fmt::Debug,
{
fn id(&self) -> NodeId {
self.id.clone()
}
fn value(&self) -> Out {
self.value.clone()
}

fn grad(&mut self) -> Out {
let grad_self = match &self.grad {
Some(val) => val.clone(),
None => self.value.zeros(),
};
self.grad = Some(grad_self.clone());
grad_self
}

fn update_grad(&mut self, grad: Out) {
self.grad = Some(self.grad() + grad);
}
}

impl<Lhs, Rhs, Out, Ops> RecordedOps for BinaryRecordedOps<Lhs, Rhs, Out, Ops>
where
Lhs: Clone + Zeros<Lhs> + Mul<Out, Output = Lhs>,
Rhs: Clone + Zeros<Rhs> + Mul<Out, Output = Rhs>,
Out: Clone + Zeros<Out> + Ones<Out> + 'static,
Lhs: Clone + Zeros<Lhs> + Mul<Out, Output = Lhs> + Add<Output = Lhs> + 'static,
Rhs: Clone + Zeros<Rhs> + Mul<Out, Output = Rhs> + Add<Output = Rhs> + 'static,
Out: Clone + Zeros<Out> + Ones<Out> + Add<Output = Out> + 'static,
Lhs: std::fmt::Debug,
Rhs: std::fmt::Debug,
Out: std::fmt::Debug,
Ops: BinaryOps<Lhs, Rhs, Out>,
Ops: BinaryOps<Lhs, Rhs, Out> + 'static,
{
fn id(&self) -> NodeId {
self.out.borrow().id()
}

fn backward(&mut self) {
let state = BinaryRecordedState::new(&self.lhs, &self.rhs, &self.out);
fn backward(&self) {
let state = BinaryOpsNodeState::new(&self.lhs.state, &self.rhs.state, &self.out);

let partial_left = self.ops.partial_left(&state);
let partial_right: Rhs = self.ops.partial_right(&state);

let grad_mine = self.out.borrow_mut().grad();

self.lhs
.state
.borrow_mut()
.update_grad(partial_left * grad_mine.clone());
self.rhs.borrow_mut().update_grad(partial_right * grad_mine);
self.rhs
.state
.borrow_mut()
.update_grad(partial_right * grad_mine);
}

fn set_last_ops(&mut self) {
fn set_last_ops(&self) {
let value = self.out.borrow().value();
self.out.borrow_mut().update_grad(value.ones());
}

fn parents_ops(&self) -> Vec<RecordedOpsRef> {
vec![self.lhs.ops.clone(), self.rhs.ops.clone()]
}
}
21 changes: 11 additions & 10 deletions burn-tensor/src/graph/ops/ops.rs
@@ -1,21 +1,22 @@
use crate::node::{NodeId, NodeRef};
use crate::node::NodeStateRef;
use std::rc::Rc;

#[derive(new)]
pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> {
pub left: &'a NodeRef<Lhs>,
pub right: &'a NodeRef<Rhs>,
pub output: &'a NodeRef<Out>,
pub struct BinaryOpsNodeState<'a, Lhs, Rhs, Out> {
pub left: &'a NodeStateRef<Lhs>,
pub right: &'a NodeStateRef<Rhs>,
pub output: &'a NodeStateRef<Out>,
}

#[derive(new)]
pub struct SingleRecordedState<'a, In, Out> {
pub struct SingleOpsNodeState<'a, In, Out> {
pub input: &'a In,
pub output: &'a Out,
}

pub trait RecordedOps: std::fmt::Debug {
fn id(&self) -> NodeId;
fn backward(&mut self);
fn set_last_ops(&mut self);
fn backward(&self);
fn set_last_ops(&self);
fn parents_ops(&self) -> Vec<RecordedOpsRef>;
}
pub type RecordedOpsRef = Box<dyn RecordedOps>;
pub type RecordedOpsRef = Rc<dyn RecordedOps>;
23 changes: 12 additions & 11 deletions burn-tensor/src/graph/ops/root.rs
@@ -1,23 +1,24 @@
use super::RecordedOps;
use crate::node::{NodeId, NodeRef, Ones, Zeros};
use super::{RecordedOps, RecordedOpsRef};
use crate::node::{NodeStateRef, Ones, Zeros};
use std::ops::Add;

#[derive(new, Debug)]
#[derive(new, Debug, Clone)]
pub struct InitRecordedOps<Out> {
root: NodeRef<Out>,
root: NodeStateRef<Out>,
}

impl<Out> RecordedOps for InitRecordedOps<Out>
where
Out: Clone + Zeros<Out> + Ones<Out> + 'static,
Out: Clone + Zeros<Out> + Ones<Out> + Add<Output = Out> + 'static,
Out: std::fmt::Debug,
{
fn id(&self) -> NodeId {
self.root.borrow().id()
}

fn backward(&mut self) {}
fn set_last_ops(&mut self) {
fn backward(&self) {}
fn set_last_ops(&self) {
let value = self.root.borrow().value();
self.root.borrow_mut().update_grad(value.ones());
}

fn parents_ops(&self) -> Vec<RecordedOpsRef> {
vec![]
}
}