Commit

Feat/module no grad (#274)

nathanielsimard committed Apr 7, 2023
1 parent d8f64ce commit f04fe10
Showing 38 changed files with 364 additions and 387 deletions.
24 changes: 19 additions & 5 deletions burn-autodiff/src/ops/tensor.rs
@@ -745,16 +745,30 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}

fn detach<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
// When we detach a tensor, we remove it from the graph, but we still want to keep the
// `require_grad` setting.
let is_require_grad = Self::is_require_grad(&tensor);
let tensor = ADTensor::new(tensor.primitive);

-        match tensor.node.requirement {
-            Requirement::Grad => tensor.require_grad(),
-            _ => tensor,
match is_require_grad {
true => tensor.require_grad(),
false => tensor,
}
}

-    fn require_grad<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
-        tensor.require_grad()
fn set_require_grad<const D: usize>(
tensor: ADTensor<B, D>,
require_grad: bool,
) -> ADTensor<B, D> {
if require_grad {
return tensor.require_grad();
}

ADTensor::new(tensor.primitive)
}

fn is_require_grad<const D: usize>(tensor: &ADTensor<B, D>) -> bool {
matches!(tensor.node.requirement, Requirement::Grad)
}

fn mean<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, 1> {
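The new backend contract is easiest to read through the public `Tensor` methods that the `Module` code later in this commit calls (`detach`, `require_grad`, `set_require_grad`, `is_require_grad`). A minimal sketch of the intended semantics; the assertions are illustrative, not a test from the repository:

```rust
use burn_tensor::backend::ADBackend;
use burn_tensor::Tensor;

// Sketch: `detach` drops the autodiff graph but keeps the `require_grad` flag,
// while `set_require_grad(false)` clears the flag entirely.
fn detach_vs_no_grad<B: ADBackend>(tensor: Tensor<B, 2>) {
    let tracked = tensor.require_grad();
    assert!(tracked.is_require_grad());

    // Removed from the graph, but still marked as requiring gradients.
    let detached = tracked.detach();
    assert!(detached.is_require_grad());

    // Explicitly disabling gradients clears the flag as well.
    let frozen = detached.set_require_grad(false);
    assert!(!frozen.is_require_grad());
}
```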
149 changes: 126 additions & 23 deletions burn-core/src/module/base.rs
@@ -1,4 +1,4 @@
-use alloc::{format, string::String, vec::Vec};
use alloc::vec::Vec;

use super::ParamId;
use crate::{
@@ -8,6 +8,58 @@ use crate::{
pub use burn_derive::Module;
use burn_tensor::Tensor;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
(map=$module:ident, ops=$item:expr) => {{
struct Mapper;
impl<B: Backend> ModuleMapper<B> for Mapper {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor)
}
}
let mut mapper = Mapper;
$module.map(&mut mapper)
}};
(map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
struct Mapper<'a, B: Backend> {
capture: &'a $ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
let func = $item;
func(tensor, self.capture)
}
}
let mut mapper = Mapper {
capture: $capture,
backend: core::marker::PhantomData::default(),
};
$module.map(&mut mapper)
}};
(visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
struct Visitor<'a, B: Backend> {
state: &'a mut $state_ty,
backend: core::marker::PhantomData<B>,
}
impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
let func = $item;
func(tensor, &mut self.state)
}
}
let mut state = $init();
let mut visitor = Visitor {
state: &mut state,
backend: core::marker::PhantomData::default(),
};
$module.visit(&mut visitor);
state
}};
}
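
For readers unfamiliar with the new `module!` helper, the following is a rough hand expansion of its `visit` form, matching what the `num_params` default below asks it to generate. The function and its import paths are illustrative assumptions, not code from the repository:

```rust
use burn_tensor::{backend::Backend, Tensor};

use burn::module::{Module, ModuleVisitor, ParamId};

// Hand expansion of `module!(visit = self, ops = ..., state = usize, init = || 0)`:
// a throwaway visitor that threads a mutable accumulator through every parameter.
fn num_params_expanded<B: Backend, M: Module<B>>(module: &M) -> usize {
    struct Visitor<'a, B: Backend> {
        state: &'a mut usize,
        backend: core::marker::PhantomData<B>,
    }

    impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
        fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
            // The `ops` closure body: count the elements of each parameter tensor.
            *self.state += tensor.shape().num_elements();
        }
    }

    let mut state = 0;
    let mut visitor = Visitor {
        state: &mut state,
        backend: core::marker::PhantomData,
    };
    module.visit(&mut visitor);
    state
}
```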

/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
@@ -42,13 +94,80 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
type Record: Record;

/// Get the device list of the module and all of its sub-modules.
-    fn devices(&self) -> Vec<B::Device>;
fn devices(&self) -> Vec<B::Device> {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
let device = tensor.device();
if !state.contains(&device) {
state.push(device);
}
},
state = Vec<B::Device>,
init = Vec::new
)
}
/// Fork the module and all of its sub-modules to the given device.
///
/// # Notes
///
/// This is similar to [to_device](Module::to_device), but it ensures the module will
/// have its own autodiff graph.
fn fork(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| {
let is_require_grad = tensor.is_require_grad();
let mut tensor = tensor.to_device(device).detach();

if is_require_grad {
tensor = tensor.require_grad();
}

tensor
},
capture = { device: B::Device }
)
}
/// Move the module and all of its sub-modules to the given device.
-    fn to_device(self, device: &B::Device) -> Self;
-    /// Detach the module from the graph.
-    fn detach(self) -> Self;
///
/// # Warnings
///
/// The device operations will be registered in the autodiff graph. Therefore, be sure to call
/// backward only one time even if you have the same module on multiple devices. If you want to
/// call backward multiple times, look into using [fork](Module::fork) instead.
fn to_device(self, device: &B::Device) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
capture = { device: B::Device }
)
}
/// Each tensor in the module tree will not require grad.
///
/// # Warnings
///
    /// This should not be used for inference; use [valid](ADModule::valid) when using
    /// AD modules. This is mostly useful for partial fine-tuning, where only a small
    /// fraction of the parameters is updated instead of all of them.
fn no_grad(self) -> Self {
module!(
map = self,
ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
)
}

/// Get the number of parameters the module has, including all of its sub-modules.
-    fn num_params(&self) -> usize;
fn num_params(&self) -> usize {
module!(
visit = self,
ops = |tensor: &Tensor<B, D>, state: &mut usize| {
*state += tensor.shape().num_elements();
},
state = usize,
init = || 0
)
}
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
/// Map each tensor in the module with a [mapper](ModuleMapper).
@@ -72,21 +191,5 @@ pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
type InnerModule: Module<B::InnerBackend>;

/// Get the same module, but on the inner backend without auto-differentiation.
-    fn inner(self) -> Self::InnerModule;
-    fn from_inner(module: Self::InnerModule) -> Self;
    fn valid(&self) -> Self::InnerModule;
}

-#[derive(new, Debug)]
-pub struct LoadingError {
-    message: String,
-}

-impl core::fmt::Display for LoadingError {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        f.write_str(format!("Loading error: {}", self.message).as_str())
-    }
-}

-// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765)
-#[cfg(feature = "std")]
-impl std::error::Error for LoadingError {}
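
Viewed from the caller's side, the reworked trait reads as follows. A hedged usage sketch, generic over any `ADModule` so that no concrete layer type or constructor from the crate is assumed:

```rust
use burn_tensor::backend::ADBackend;

use burn::module::{ADModule, Module};

// Freeze a module for partial fine-tuning and obtain an inference-only copy.
fn freeze_and_split<B: ADBackend, M: ADModule<B>>(
    module: M,
    device: &B::Device,
) -> (M, M::InnerModule) {
    // `fork` moves the parameters and guarantees a fresh autodiff graph,
    // unlike `to_device`, which records the device ops in the existing graph.
    let module = module.fork(device);

    // `valid` borrows the module and returns it on the inner backend, without
    // auto-differentiation; this commit offers it in place of `inner`/`from_inner`.
    let inference = module.valid();

    // `no_grad` turns off gradient tracking on every parameter,
    // e.g. for the frozen part of a partially fine-tuned model.
    (module.no_grad(), inference)
}
```

Since `devices`, `fork`, `to_device`, `no_grad`, and `num_params` now have default bodies, downstream modules only need to supply `visit`, `map`, and the record-related methods, which the [derive](burn_derive::Module) macro generates in practice.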
8 changes: 3 additions & 5 deletions burn-core/src/module/param/base.rs
@@ -1,10 +1,8 @@
-use alloc::format;
-use serde::{Deserialize, Serialize};

use super::ParamId;
use alloc::format;

-/// Define a trainable parameter.
-#[derive(new, Debug, Clone, Serialize, Deserialize)]
/// Define a parameter.
#[derive(new, Debug, Clone)]
pub struct Param<T> {
pub(crate) id: ParamId,
pub(crate) value: T,
24 changes: 2 additions & 22 deletions burn-core/src/module/param/constant.rs
@@ -5,22 +5,6 @@ macro_rules! constant {
(module) => {
type Record = ();

-        fn devices(&self) -> alloc::vec::Vec<<B as burn_tensor::backend::Backend>::Device> {
-            alloc::vec::Vec::new()
-        }

-        fn to_device(self, _device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
-            self
-        }

-        fn detach(self) -> Self {
-            self
-        }

-        fn num_params(&self) -> usize {
-            0
-        }

fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
@@ -39,12 +23,8 @@
(ad_module, $type:ty) => {
type InnerModule = $type;

-        fn inner(self) -> Self::InnerModule {
-            self
-        }

-        fn from_inner(module: Self::InnerModule) -> Self {
-            module
fn valid(&self) -> Self::InnerModule {
self.clone()
}
};

8 changes: 4 additions & 4 deletions burn-core/src/module/param/id.rs
@@ -1,10 +1,7 @@
use alloc::string::{String, ToString};

use burn_common::id::IdGenerator;

-use serde::{Deserialize, Serialize};

-#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct ParamId {
value: String,
}
@@ -35,6 +32,9 @@ impl ParamId {
value: IdGenerator::generate(),
}
}
pub fn into_string(self) -> String {
self.value
}
}

impl core::fmt::Display for ParamId {
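The id is no longer serialized through serde derives, so `into_string` (together with the existing `Display` impl) is presumably what the record machinery uses to persist ids; that motivation is an inference, not stated in this hunk. A tiny sketch using only the method added above (the `burn::module::ParamId` path and the checkpoint use case are assumptions):

```rust
use std::collections::BTreeSet;

use burn::module::ParamId;

// Collect a set of parameter ids as plain strings, e.g. to compare two checkpoints.
fn id_keys(ids: Vec<ParamId>) -> BTreeSet<String> {
    ids.into_iter().map(ParamId::into_string).collect()
}
```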
74 changes: 6 additions & 68 deletions burn-core/src/module/param/primitive.rs
@@ -10,29 +10,6 @@
{
type Record = Option<T::Record>;

-    fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
-        if let Some(module) = self {
-            return Module::<B>::devices(module);
-        }

-        Vec::new()
-    }

-    fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
-        self.map(|module| module.to_device(device))
-    }

-    fn detach(self) -> Self {
-        self.map(|module| module.detach())
-    }

-    fn num_params(&self) -> usize {
-        match &self {
-            Some(module) => module.num_params(),
-            None => 0,
-        }
-    }

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
if let Some(module) = self {
module.visit(visitor)
@@ -60,12 +37,8 @@
{
type InnerModule = Option<T::InnerModule>;

-    fn inner(self) -> Self::InnerModule {
-        self.map(|module| module.inner())
-    }

-    fn from_inner(module: Self::InnerModule) -> Self {
-        module.map(|module| T::from_inner(module))
fn valid(&self) -> Self::InnerModule {
self.as_ref().map(|module| module.valid())
}
}

@@ -76,22 +49,6 @@
{
type Record = Vec<T::Record>;

-    fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
-        let mut devices = Vec::new();
-        for module in self.iter() {
-            devices.append(&mut module.devices());
-        }
-        devices
-    }

-    fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
-        self.into_iter().map(|val| val.to_device(device)).collect()
-    }

-    fn detach(self) -> Self {
-        self.into_iter().map(|module| module.detach()).collect()
-    }

fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.iter() {
@@ -130,15 +87,8 @@
{
type InnerModule = Vec<T::InnerModule>;

-    fn inner(self) -> Self::InnerModule {
-        self.into_iter().map(|module| module.inner()).collect()
-    }

-    fn from_inner(module: Self::InnerModule) -> Self {
-        module
-            .into_iter()
-            .map(|module| T::from_inner(module))
-            .collect()
fn valid(&self) -> Self::InnerModule {
self.iter().map(|module| module.valid()).collect()
}
}

Expand All @@ -158,14 +108,6 @@ where
devices
}

-    fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
-        self.map(|val| val.to_device(device))
-    }

-    fn detach(self) -> Self {
-        self.map(|module| module.detach())
-    }

fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.iter() {
@@ -209,11 +151,7 @@
{
type InnerModule = [T::InnerModule; N];

-    fn inner(self) -> Self::InnerModule {
-        self.map(|module| module.inner())
-    }

-    fn from_inner(module: Self::InnerModule) -> Self {
-        module.map(|module| T::from_inner(module))
fn valid(&self) -> Self::InnerModule {
self.map(|module| module.valid())
}
}
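
Because the container impls forward `valid` element-wise, optional and repeated sub-modules convert to their inner, no-autodiff form in one call. A generic sketch, assuming only the `Option`/`Vec` impls shown above and standard trait bounds:

```rust
use burn_tensor::backend::ADBackend;

use burn::module::ADModule;

// `Option<M>` and `Vec<M>` are themselves `ADModule`s, so `valid` recurses for free.
fn containers_to_valid<B: ADBackend, M: ADModule<B>>(
    maybe: Option<M>,
    stack: Vec<M>,
) -> (Option<M::InnerModule>, Vec<M::InnerModule>) {
    (maybe.valid(), stack.valid())
}
```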
