Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: create tape only during backprop (#1)
- Loading branch information
1 parent
30fadcb
commit 8e80502
Showing
19 changed files
with
320 additions
and
443 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod node; | ||
mod state; | ||
|
||
pub use node::*; | ||
pub use state::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
use super::NodeStateRef; | ||
use crate::{ops::RecordedOpsRef, tape::Tape}; | ||
use std::rc::Rc; | ||
|
||
#[derive(new, Debug)] | ||
pub struct Node<Out> { | ||
pub state: NodeStateRef<Out>, | ||
pub ops: RecordedOpsRef, | ||
} | ||
|
||
impl<Out> Node<Out> { | ||
pub fn record(&self, tape: &mut Tape) { | ||
let mut all_ops = self.ops.parents_ops(); | ||
tape.add(self.ops.clone()); | ||
|
||
loop { | ||
if all_ops.len() == 0 { | ||
return; | ||
} | ||
let ops = all_ops.pop().unwrap(); | ||
all_ops.append(&mut ops.parents_ops()); | ||
tape.add(ops); | ||
} | ||
} | ||
} | ||
pub type NodeRef<Out> = Rc<Node<Out>>; | ||
|
||
pub trait Zeros<T> { | ||
fn zeros(&self) -> T; | ||
} | ||
pub trait Ones<T> { | ||
fn ones(&self) -> T; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
use crate::node::Zeros; | ||
use std::{cell::RefCell, ops::Add, rc::Rc}; | ||
|
||
#[derive(Debug)] | ||
pub struct NodeState<Out> { | ||
pub value: Out, | ||
pub grad: Option<Out>, | ||
} | ||
pub type NodeStateRef<Out> = Rc<RefCell<NodeState<Out>>>; | ||
|
||
impl<Out> NodeState<Out> { | ||
pub fn new(value: Out) -> Self { | ||
Self { value, grad: None } | ||
} | ||
pub fn new_mut(value: Out) -> NodeStateRef<Out> { | ||
Rc::new(RefCell::new(Self::new(value))) | ||
} | ||
} | ||
impl<Out> NodeState<Out> | ||
where | ||
Out: Clone, | ||
{ | ||
pub fn value(&self) -> Out { | ||
self.value.clone() | ||
} | ||
} | ||
|
||
impl<Out> NodeState<Out> | ||
where | ||
Out: Zeros<Out> + Clone + Add<Output = Out>, | ||
Out: std::fmt::Debug, | ||
{ | ||
pub fn grad(&mut self) -> Out { | ||
let grad_self = match &self.grad { | ||
Some(val) => val.clone(), | ||
None => self.value.zeros(), | ||
}; | ||
self.grad = Some(grad_self.clone()); | ||
grad_self | ||
} | ||
|
||
pub fn update_grad(&mut self, grad: Out) { | ||
self.grad = Some(self.grad() + grad); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,99 +1,54 @@ | ||
use super::{BinaryRecordedState, RecordedOps}; | ||
use crate::node::{Node, NodeId, NodeRef, Ones, Zeros}; | ||
use super::{BinaryOpsNodeState, RecordedOps, RecordedOpsRef}; | ||
use crate::node::{NodeRef, NodeStateRef, Ones, Zeros}; | ||
use std::ops::{Add, Mul}; | ||
|
||
pub trait BinaryOps<Lhs, Rhs, Out>: std::fmt::Debug { | ||
fn partial_left(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Lhs; | ||
fn partial_right(&self, state: &BinaryRecordedState<Lhs, Rhs, Out>) -> Rhs; | ||
} | ||
|
||
#[derive(Debug)] | ||
pub struct BinaryOpsNode<Lhs, Rhs, Out> { | ||
pub id: NodeId, | ||
pub parent_left: NodeRef<Lhs>, | ||
pub parent_right: NodeRef<Rhs>, | ||
pub value: Out, | ||
pub grad: Option<Out>, | ||
fn partial_left(&self, state: &BinaryOpsNodeState<Lhs, Rhs, Out>) -> Lhs; | ||
fn partial_right(&self, state: &BinaryOpsNodeState<Lhs, Rhs, Out>) -> Rhs; | ||
} | ||
|
||
#[derive(new, Debug)] | ||
pub struct BinaryRecordedOps<Lhs, Rhs, Out, Ops> { | ||
lhs: NodeRef<Lhs>, | ||
rhs: NodeRef<Rhs>, | ||
out: NodeRef<Out>, | ||
out: NodeStateRef<Out>, | ||
ops: Ops, | ||
} | ||
|
||
impl<Lhs, Rhs, Out> BinaryOpsNode<Lhs, Rhs, Out> { | ||
pub fn new(parent_left: NodeRef<Lhs>, parent_right: NodeRef<Rhs>, value: Out) -> Self { | ||
Self { | ||
id: NodeId::new(), | ||
parent_left, | ||
parent_right, | ||
value, | ||
grad: None, | ||
} | ||
} | ||
} | ||
|
||
impl<Lhs, Rhs, Out> Node<Out> for BinaryOpsNode<Lhs, Rhs, Out> | ||
where | ||
Out: Zeros<Out> + Clone + Mul<Output = Out> + Add<Output = Out>, | ||
Lhs: std::fmt::Debug, | ||
Rhs: std::fmt::Debug, | ||
Out: std::fmt::Debug, | ||
{ | ||
fn id(&self) -> NodeId { | ||
self.id.clone() | ||
} | ||
fn value(&self) -> Out { | ||
self.value.clone() | ||
} | ||
|
||
fn grad(&mut self) -> Out { | ||
let grad_self = match &self.grad { | ||
Some(val) => val.clone(), | ||
None => self.value.zeros(), | ||
}; | ||
self.grad = Some(grad_self.clone()); | ||
grad_self | ||
} | ||
|
||
fn update_grad(&mut self, grad: Out) { | ||
self.grad = Some(self.grad() + grad); | ||
} | ||
} | ||
|
||
impl<Lhs, Rhs, Out, Ops> RecordedOps for BinaryRecordedOps<Lhs, Rhs, Out, Ops> | ||
where | ||
Lhs: Clone + Zeros<Lhs> + Mul<Out, Output = Lhs>, | ||
Rhs: Clone + Zeros<Rhs> + Mul<Out, Output = Rhs>, | ||
Out: Clone + Zeros<Out> + Ones<Out> + 'static, | ||
Lhs: Clone + Zeros<Lhs> + Mul<Out, Output = Lhs> + Add<Output = Lhs> + 'static, | ||
Rhs: Clone + Zeros<Rhs> + Mul<Out, Output = Rhs> + Add<Output = Rhs> + 'static, | ||
Out: Clone + Zeros<Out> + Ones<Out> + Add<Output = Out> + 'static, | ||
Lhs: std::fmt::Debug, | ||
Rhs: std::fmt::Debug, | ||
Out: std::fmt::Debug, | ||
Ops: BinaryOps<Lhs, Rhs, Out>, | ||
Ops: BinaryOps<Lhs, Rhs, Out> + 'static, | ||
{ | ||
fn id(&self) -> NodeId { | ||
self.out.borrow().id() | ||
} | ||
|
||
fn backward(&mut self) { | ||
let state = BinaryRecordedState::new(&self.lhs, &self.rhs, &self.out); | ||
fn backward(&self) { | ||
let state = BinaryOpsNodeState::new(&self.lhs.state, &self.rhs.state, &self.out); | ||
|
||
let partial_left = self.ops.partial_left(&state); | ||
let partial_right: Rhs = self.ops.partial_right(&state); | ||
|
||
let grad_mine = self.out.borrow_mut().grad(); | ||
|
||
self.lhs | ||
.state | ||
.borrow_mut() | ||
.update_grad(partial_left * grad_mine.clone()); | ||
self.rhs.borrow_mut().update_grad(partial_right * grad_mine); | ||
self.rhs | ||
.state | ||
.borrow_mut() | ||
.update_grad(partial_right * grad_mine); | ||
} | ||
|
||
fn set_last_ops(&mut self) { | ||
fn set_last_ops(&self) { | ||
let value = self.out.borrow().value(); | ||
self.out.borrow_mut().update_grad(value.ones()); | ||
} | ||
|
||
fn parents_ops(&self) -> Vec<RecordedOpsRef> { | ||
vec![self.lhs.ops.clone(), self.rhs.ops.clone()] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,22 @@ | ||
use crate::node::{NodeId, NodeRef}; | ||
use crate::node::NodeStateRef; | ||
use std::rc::Rc; | ||
|
||
#[derive(new)] | ||
pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> { | ||
pub left: &'a NodeRef<Lhs>, | ||
pub right: &'a NodeRef<Rhs>, | ||
pub output: &'a NodeRef<Out>, | ||
pub struct BinaryOpsNodeState<'a, Lhs, Rhs, Out> { | ||
pub left: &'a NodeStateRef<Lhs>, | ||
pub right: &'a NodeStateRef<Rhs>, | ||
pub output: &'a NodeStateRef<Out>, | ||
} | ||
|
||
#[derive(new)] | ||
pub struct SingleRecordedState<'a, In, Out> { | ||
pub struct SingleOpsNodeState<'a, In, Out> { | ||
pub input: &'a In, | ||
pub output: &'a Out, | ||
} | ||
|
||
pub trait RecordedOps: std::fmt::Debug { | ||
fn id(&self) -> NodeId; | ||
fn backward(&mut self); | ||
fn set_last_ops(&mut self); | ||
fn backward(&self); | ||
fn set_last_ops(&self); | ||
fn parents_ops(&self) -> Vec<RecordedOpsRef>; | ||
} | ||
pub type RecordedOpsRef = Box<dyn RecordedOps>; | ||
pub type RecordedOpsRef = Rc<dyn RecordedOps>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,24 @@ | ||
use super::RecordedOps; | ||
use crate::node::{NodeId, NodeRef, Ones, Zeros}; | ||
use super::{RecordedOps, RecordedOpsRef}; | ||
use crate::node::{NodeStateRef, Ones, Zeros}; | ||
use std::ops::Add; | ||
|
||
#[derive(new, Debug)] | ||
#[derive(new, Debug, Clone)] | ||
pub struct InitRecordedOps<Out> { | ||
root: NodeRef<Out>, | ||
root: NodeStateRef<Out>, | ||
} | ||
|
||
impl<Out> RecordedOps for InitRecordedOps<Out> | ||
where | ||
Out: Clone + Zeros<Out> + Ones<Out> + 'static, | ||
Out: Clone + Zeros<Out> + Ones<Out> + Add<Output = Out> + 'static, | ||
Out: std::fmt::Debug, | ||
{ | ||
fn id(&self) -> NodeId { | ||
self.root.borrow().id() | ||
} | ||
|
||
fn backward(&mut self) {} | ||
fn set_last_ops(&mut self) { | ||
fn backward(&self) {} | ||
fn set_last_ops(&self) { | ||
let value = self.root.borrow().value(); | ||
self.root.borrow_mut().update_grad(value.ones()); | ||
} | ||
|
||
fn parents_ops(&self) -> Vec<RecordedOpsRef> { | ||
vec![] | ||
} | ||
} |
Oops, something went wrong.