diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs
index 57c3e5b5b7..6dbbbab976 100644
--- a/burn-autodiff/src/ops/tensor.rs
+++ b/burn-autodiff/src/ops/tensor.rs
@@ -745,16 +745,30 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     }
 
     fn detach<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
+        // When we detach a tensor, we remove it from the graph, but we still want to keep the
+        // `require_grad` setting.
+        let is_require_grad = Self::is_require_grad(&tensor);
         let tensor = ADTensor::new(tensor.primitive);
 
-        match tensor.node.requirement {
-            Requirement::Grad => tensor.require_grad(),
-            _ => tensor,
+        match is_require_grad {
+            true => tensor.require_grad(),
+            false => tensor,
         }
     }
 
-    fn require_grad<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, D> {
-        tensor.require_grad()
+    fn set_require_grad<const D: usize>(
+        tensor: ADTensor<B, D>,
+        require_grad: bool,
+    ) -> ADTensor<B, D> {
+        if require_grad {
+            return tensor.require_grad();
+        }
+
+        ADTensor::new(tensor.primitive)
+    }
+
+    fn is_require_grad<const D: usize>(tensor: &ADTensor<B, D>) -> bool {
+        matches!(tensor.node.requirement, Requirement::Grad)
     }
 
     fn mean<const D: usize>(tensor: ADTensor<B, D>) -> ADTensor<B, 1> {
diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs
index 1ca4194a8b..be922a506b 100644
--- a/burn-core/src/module/base.rs
+++ b/burn-core/src/module/base.rs
@@ -1,4 +1,4 @@
-use alloc::{format, string::String, vec::Vec};
+use alloc::vec::Vec;
 
 use super::ParamId;
 use crate::{
@@ -8,6 +8,58 @@ use crate::{
 pub use burn_derive::Module;
 use burn_tensor::Tensor;
 
+// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
+// We may consider making it public in the future.
+macro_rules! module {
+    (map=$module:ident, ops=$item:expr) => {{
+        struct Mapper;
+        impl<B: Backend> ModuleMapper<B> for Mapper {
+            fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
+                let func = $item;
+                func(tensor)
+            }
+        }
+        let mut mapper = Mapper;
+        $module.map(&mut mapper)
+    }};
+    (map=$module:ident, ops=$item:expr, capture={$capture:ident: $ty:ty}) => {{
+        struct Mapper<'a, B: Backend> {
+            capture: &'a $ty,
+            backend: core::marker::PhantomData<B>,
+        }
+        impl<'a, B: Backend> ModuleMapper<B> for Mapper<'a, B> {
+            fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
+                let func = $item;
+                func(tensor, self.capture)
+            }
+        }
+        let mut mapper = Mapper {
+            capture: $capture,
+            backend: core::marker::PhantomData::default(),
+        };
+        $module.map(&mut mapper)
+    }};
+    (visit=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
+        struct Visitor<'a, B: Backend> {
+            state: &'a mut $state_ty,
+            backend: core::marker::PhantomData<B>,
+        }
+        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
+            fn visit<const D: usize>(&mut self, _id: &ParamId, tensor: &Tensor<B, D>) {
+                let func = $item;
+                func(tensor, &mut self.state)
+            }
+        }
+        let mut state = $init();
+        let mut visitor = Visitor {
+            state: &mut state,
+            backend: core::marker::PhantomData::default(),
+        };
+        $module.visit(&mut visitor);
+        state
+    }};
+}
+
 /// Trait for all neural network modules.
 ///
 /// Modules should be created using the [derive](burn_derive::Module) attribute.
@@ -42,13 +94,80 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
     type Record: Record;
 
     /// Get the device list of the module and all of its sub-modules.
-    fn devices(&self) -> Vec<B::Device>;
+    fn devices(&self) -> Vec<B::Device> {
+        module!(
+            visit = self,
+            ops = |tensor: &Tensor<B, D>, state: &mut Vec<B::Device>| {
+                let device = tensor.device();
+                if !state.contains(&device) {
+                    state.push(device);
+                }
+            },
+            state = Vec<B::Device>,
+            init = Vec::new
+        )
+    }
+    /// Fork the module and all of its sub-modules to the given device.
+    ///
+    /// # Notes
+    ///
+    /// This is similar to [to_device](Module::to_device), but it ensures the module will
+    /// have its own autodiff graph.
+    fn fork(self, device: &B::Device) -> Self {
+        module!(
+            map = self,
+            ops = |tensor: Tensor<B, D>, device: &B::Device| {
+                let is_require_grad = tensor.is_require_grad();
+                let mut tensor = tensor.to_device(device).detach();
+
+                if is_require_grad {
+                    tensor = tensor.require_grad();
+                }
+
+                tensor
+            },
+            capture = { device: B::Device }
+        )
+    }
     /// Move the module and all of its sub-modules to the given device.
-    fn to_device(self, device: &B::Device) -> Self;
-    /// Detach the module from the graph.
-    fn detach(self) -> Self;
+    ///
+    /// # Warnings
+    ///
+    /// The device operations will be registered in the autodiff graph. Therefore, be sure to call
+    /// backward only once even if the same module is used on multiple devices. If you want to
+    /// call backward multiple times, look into using [fork](Module::fork) instead.
+    fn to_device(self, device: &B::Device) -> Self {
+        module!(
+            map = self,
+            ops = |tensor: Tensor<B, D>, device: &B::Device| tensor.to_device(device),
+            capture = { device: B::Device }
+        )
+    }
+    /// Each tensor in the module tree will not require grad.
+    ///
+    /// # Warnings
+    ///
+    /// This should not be used for inference; use [valid](ADModule::valid) when using
+    /// AD modules. This is mostly useful when performing partial fine-tuning, where only
+    /// a small fraction of the parameters is updated instead of fine-tuning all of them.
+    fn no_grad(self) -> Self {
+        module!(
+            map = self,
+            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
+        )
+    }
+
     /// Get the number of parameters the module has, including all of its sub-modules.
-    fn num_params(&self) -> usize;
+    fn num_params(&self) -> usize {
+        module!(
+            visit = self,
+            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
+                *state += tensor.shape().num_elements();
+            },
+            state = usize,
+            init = || 0
+        )
+    }
     /// Visit each tensor in the module with a [visitor](ModuleVisitor).
     fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
     /// Map each tensor in the module with a [mapper](ModuleMapper).
@@ -72,21 +191,5 @@ pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
    type InnerModule: Module<B::InnerBackend>;
 
    /// Get the same module, but on the inner backend without auto-differentiation.
- fn inner(self) -> Self::InnerModule; - fn from_inner(module: Self::InnerModule) -> Self; -} - -#[derive(new, Debug)] -pub struct LoadingError { - message: String, + fn valid(&self) -> Self::InnerModule; } - -impl core::fmt::Display for LoadingError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("Loading error: {}", self.message).as_str()) - } -} - -// TODO: Move from std to core after Error is core (see https://github.com/rust-lang/rust/issues/103765) -#[cfg(feature = "std")] -impl std::error::Error for LoadingError {} diff --git a/burn-core/src/module/param/base.rs b/burn-core/src/module/param/base.rs index bfdad44730..407019d355 100644 --- a/burn-core/src/module/param/base.rs +++ b/burn-core/src/module/param/base.rs @@ -1,10 +1,8 @@ -use alloc::format; -use serde::{Deserialize, Serialize}; - use super::ParamId; +use alloc::format; -/// Define a trainable parameter. -#[derive(new, Debug, Clone, Serialize, Deserialize)] +/// Define a parameter. +#[derive(new, Debug, Clone)] pub struct Param { pub(crate) id: ParamId, pub(crate) value: T, diff --git a/burn-core/src/module/param/constant.rs b/burn-core/src/module/param/constant.rs index 4f59f877ef..200d3414bf 100644 --- a/burn-core/src/module/param/constant.rs +++ b/burn-core/src/module/param/constant.rs @@ -5,22 +5,6 @@ macro_rules! constant { (module) => { type Record = (); - fn devices(&self) -> alloc::vec::Vec<::Device> { - alloc::vec::Vec::new() - } - - fn to_device(self, _device: &::Device) -> Self { - self - } - - fn detach(self) -> Self { - self - } - - fn num_params(&self) -> usize { - 0 - } - fn visit>(&self, _visitor: &mut V) { // Nothing to do } @@ -39,12 +23,8 @@ macro_rules! constant { (ad_module, $type:ty) => { type InnerModule = $type; - fn inner(self) -> Self::InnerModule { - self - } - - fn from_inner(module: Self::InnerModule) -> Self { - module + fn valid(&self) -> Self::InnerModule { + self.clone() } }; diff --git a/burn-core/src/module/param/id.rs b/burn-core/src/module/param/id.rs index 3c4e36fa25..92f88d442d 100644 --- a/burn-core/src/module/param/id.rs +++ b/burn-core/src/module/param/id.rs @@ -1,10 +1,7 @@ use alloc::string::{String, ToString}; - use burn_common::id::IdGenerator; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Hash, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct ParamId { value: String, } @@ -35,6 +32,9 @@ impl ParamId { value: IdGenerator::generate(), } } + pub fn into_string(self) -> String { + self.value + } } impl core::fmt::Display for ParamId { diff --git a/burn-core/src/module/param/primitive.rs b/burn-core/src/module/param/primitive.rs index 7007474e9c..639979136b 100644 --- a/burn-core/src/module/param/primitive.rs +++ b/burn-core/src/module/param/primitive.rs @@ -10,29 +10,6 @@ where { type Record = Option; - fn devices(&self) -> Vec<::Device> { - if let Some(module) = self { - return Module::::devices(module); - } - - Vec::new() - } - - fn to_device(self, device: &::Device) -> Self { - self.map(|module| module.to_device(device)) - } - - fn detach(self) -> Self { - self.map(|module| module.detach()) - } - - fn num_params(&self) -> usize { - match &self { - Some(module) => module.num_params(), - None => 0, - } - } - fn visit>(&self, visitor: &mut V) { if let Some(module) = self { module.visit(visitor) @@ -60,12 +37,8 @@ where { type InnerModule = Option; - fn inner(self) -> Self::InnerModule { - self.map(|module| module.inner()) - } - - fn from_inner(module: 
Self::InnerModule) -> Self { - module.map(|module| T::from_inner(module)) + fn valid(&self) -> Self::InnerModule { + self.as_ref().map(|module| module.valid()) } } @@ -76,22 +49,6 @@ where { type Record = Vec; - fn devices(&self) -> Vec<::Device> { - let mut devices = Vec::new(); - for module in self.iter() { - devices.append(&mut module.devices()); - } - devices - } - - fn to_device(self, device: &::Device) -> Self { - self.into_iter().map(|val| val.to_device(device)).collect() - } - - fn detach(self) -> Self { - self.into_iter().map(|module| module.detach()).collect() - } - fn num_params(&self) -> usize { let mut num_params = 0; for module in self.iter() { @@ -130,15 +87,8 @@ where { type InnerModule = Vec; - fn inner(self) -> Self::InnerModule { - self.into_iter().map(|module| module.inner()).collect() - } - - fn from_inner(module: Self::InnerModule) -> Self { - module - .into_iter() - .map(|module| T::from_inner(module)) - .collect() + fn valid(&self) -> Self::InnerModule { + self.iter().map(|module| module.valid()).collect() } } @@ -158,14 +108,6 @@ where devices } - fn to_device(self, device: &::Device) -> Self { - self.map(|val| val.to_device(device)) - } - - fn detach(self) -> Self { - self.map(|module| module.detach()) - } - fn num_params(&self) -> usize { let mut num_params = 0; for module in self.iter() { @@ -209,11 +151,7 @@ where { type InnerModule = [T::InnerModule; N]; - fn inner(self) -> Self::InnerModule { - self.map(|module| module.inner()) - } - - fn from_inner(module: Self::InnerModule) -> Self { - module.map(|module| T::from_inner(module)) + fn valid(&self) -> Self::InnerModule { + self.map(|module| module.valid()) } } diff --git a/burn-core/src/module/param/running.rs b/burn-core/src/module/param/running.rs index 7aab31c8d6..08934ca2c8 100644 --- a/burn-core/src/module/param/running.rs +++ b/burn-core/src/module/param/running.rs @@ -1,4 +1,4 @@ -use alloc::{sync::Arc, vec, vec::Vec}; +use alloc::sync::Arc; use super::ParamId; use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor, Param}; @@ -40,62 +40,22 @@ use threading::*; /// The state value is the average of all updates on all threads. 
#[derive(Clone, Debug)] pub struct RunningState { + id: ParamId, values: Arc>>, value: Arc>, } -impl From>> - for Param>> -{ - fn from(value: RunningState>) -> Self { - Param { - id: ParamId::new(), - value, - } - } -} - -impl Module for Param>> { +impl Module for RunningState> { type Record = Param>; - fn num_params(&self) -> usize { - let tensor = self.value.value.read().unwrap(); - tensor.shape().num_elements() - } - - fn devices(&self) -> Vec { - let tensor = self.value.value.read().unwrap(); - vec![tensor.device()] - } - - fn to_device(self, device: &B::Device) -> Self { - self.value.sync(); - - let mut tensor = self.value.value.write().unwrap(); - tensor.inplace(|tensor| tensor.to_device(device)); - core::mem::drop(tensor); - - self - } - - fn detach(self) -> Self { - self.sync(); - - let mut tensor = self.value.value.write().unwrap(); - tensor.inplace(|tensor| tensor.detach()); - core::mem::drop(tensor); - - self - } - fn visit>(&self, visitor: &mut V) { - let tensor = self.value.value.read().unwrap(); + let tensor = self.value.read().unwrap(); visitor.visit(&self.id, &tensor) } fn map>(self, mapper: &mut M) -> Self { - let mut tensor = self.value.value.write().unwrap(); + let mut tensor = self.value.write().unwrap(); let tensor_out = mapper.map(&self.id, tensor.clone()); *tensor = tensor_out; @@ -106,13 +66,13 @@ impl Module for Param>> fn into_record(self) -> Self::Record { self.sync(); - let tensor = self.value.value.read().unwrap(); + let tensor = self.value.read().unwrap(); Param::new(self.id, tensor.clone()) } fn load_record(mut self, record: Self::Record) -> Self { - let mut tensor = self.value.value.write().unwrap(); + let mut tensor = self.value.write().unwrap(); *tensor = record.value.to_device(&tensor.device()); self.id = record.id; @@ -126,6 +86,16 @@ impl RunningState> { /// Create a new running state. pub fn new(value: Tensor) -> Self { Self { + id: ParamId::new(), + values: Arc::new(Mutex::new(HashMap::new())), + value: Arc::new(RwLock::new(value)), + } + } + + /// Create a new running state. 
+ pub fn with_id(id: ParamId, value: Tensor) -> Self { + Self { + id, values: Arc::new(Mutex::new(HashMap::new())), value: Arc::new(RwLock::new(value)), } @@ -200,27 +170,13 @@ impl RunningState> { } } -impl ADModule for Param>> { - type InnerModule = Param>>; +impl ADModule for RunningState> { + type InnerModule = RunningState>; - fn inner(self) -> Self::InnerModule { + fn valid(&self) -> Self::InnerModule { self.sync(); - let value = self.value.value(); + let value = self.value(); - Param { - id: self.id, - value: RunningState::new(value.inner()), - } - } - - fn from_inner(module: Self::InnerModule) -> Self { - module.sync(); - - let value = module.value.value(); - - Param { - id: module.id, - value: RunningState::new(Tensor::from_inner(value)), - } + RunningState::with_id(self.id.clone(), value.inner()) } } diff --git a/burn-core/src/module/param/tensor.rs b/burn-core/src/module/param/tensor.rs index eceeb3df29..fd923bb4a0 100644 --- a/burn-core/src/module/param/tensor.rs +++ b/burn-core/src/module/param/tensor.rs @@ -1,5 +1,3 @@ -use alloc::{vec, vec::Vec}; - use super::{Param, ParamId}; use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor}; use crate::tensor::{ @@ -9,45 +7,20 @@ use crate::tensor::{ impl From> for Param> { fn from(value: Tensor) -> Self { - Param { - id: ParamId::new(), - value: value.require_grad(), - } + Param::new(ParamId::new(), value.require_grad()) } } impl Module for Param> { type Record = Param>; - fn num_params(&self) -> usize { - self.value.shape().num_elements() - } - - fn devices(&self) -> Vec { - vec![self.value.device()] - } - - fn to_device(self, device: &B::Device) -> Self { - Self { - id: self.id, - value: self.value.to_device(device).require_grad(), - } - } - - fn detach(self) -> Self { - Self { - id: self.id, - value: self.value.detach().require_grad(), - } - } - fn visit>(&self, visitor: &mut V) { visitor.visit(&self.id, &self.value) } fn map>(self, mapper: &mut M) -> Self { - let value = mapper.map(&self.id, self.value).require_grad(); - Self { id: self.id, value } + let value = mapper.map(&self.id, self.value); + Self::new(self.id, value) } fn into_record(self) -> Self::Record { @@ -55,24 +28,63 @@ impl Module for Param> { } fn load_record(self, record: Self::Record) -> Self { - record.to_device(&self.device()) + let mut tensor = record.value.detach(); + let device = self.device(); + + // Make sure we load the record into the same module device. + if tensor.device() != device { + tensor = tensor.to_device(&device).detach(); + } + + // Make sure we load the record with the same autodiff setting. 
+ if self.is_require_grad() { + tensor = tensor.require_grad(); + } + + Self::new(record.id, tensor) } } impl ADModule for Param> { type InnerModule = Param>; - fn inner(self) -> Self::InnerModule { - Param { - id: self.id, - value: self.value.inner(), - } + fn valid(&self) -> Self::InnerModule { + Param::new( + self.id.clone(), + self.value.clone().inner().set_require_grad(false), + ) } +} - fn from_inner(module: Self::InnerModule) -> Self { - Param { - id: module.id, - value: Tensor::from_inner(module.value).require_grad(), - } +#[cfg(all(test, feature = "std"))] +mod tests { + use crate::{ + record::{NoStdInferenceRecordSettings, Record}, + TestADBackend, + }; + + use super::*; + + #[test] + fn test_load_record_setting() { + let tensor = Tensor::::ones([3, 3]); + let bytes = Param::from(tensor.clone()) + .into_record() + .record::(()) + .unwrap(); + + let no_grad_is_require_grad = Param::from(tensor.clone()) + .no_grad() + .load_record(Param::load::(bytes.clone()).unwrap()) + .value + .is_require_grad(); + + let with_default_is_require_grad = Param::from(tensor) + .load_record(Param::load::(bytes).unwrap()) + .value + .is_require_grad(); + + assert!(!no_grad_is_require_grad); + assert!(with_default_is_require_grad); } } diff --git a/burn-core/src/nn/attention/mha.rs b/burn-core/src/nn/attention/mha.rs index 3d232d28b2..37ebcb1c10 100644 --- a/burn-core/src/nn/attention/mha.rs +++ b/burn-core/src/nn/attention/mha.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::nn::cache::TensorCache; @@ -9,7 +7,6 @@ use crate::{ nn, tensor::{activation, backend::Backend, Bool, Tensor}, }; - use libm::sqrtf; /// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer. @@ -257,6 +254,7 @@ pub struct MHAAutoregressiveCache { mod tests { use super::*; use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; + use alloc::vec::Vec; use burn::tensor::{Distribution, Shape}; use burn_tensor::Int; diff --git a/burn-core/src/nn/conv/conv1d.rs b/burn-core/src/nn/conv/conv1d.rs index c9000c40e0..942607f6b2 100644 --- a/burn-core/src/nn/conv/conv1d.rs +++ b/burn-core/src/nn/conv/conv1d.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::config::Config; diff --git a/burn-core/src/nn/conv/conv2d.rs b/burn-core/src/nn/conv/conv2d.rs index 519f0a93b7..6392c48b1c 100644 --- a/burn-core/src/nn/conv/conv2d.rs +++ b/burn-core/src/nn/conv/conv2d.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::config::Config; diff --git a/burn-core/src/nn/embedding.rs b/burn-core/src/nn/embedding.rs index fa0665bb83..2688acbee3 100644 --- a/burn-core/src/nn/embedding.rs +++ b/burn-core/src/nn/embedding.rs @@ -1,6 +1,3 @@ -use alloc::vec::Vec; -use burn_tensor::Int; - use crate as burn; use super::Initializer; @@ -9,6 +6,7 @@ use crate::module::Module; use crate::module::Param; use crate::tensor::backend::Backend; use crate::tensor::Tensor; +use burn_tensor::Int; /// Configuration to create an [Embedding](Embedding) layer. 
#[derive(Config)] diff --git a/burn-core/src/nn/linear.rs b/burn-core/src/nn/linear.rs index a79869c785..8815d29c3e 100644 --- a/burn-core/src/nn/linear.rs +++ b/burn-core/src/nn/linear.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::config::Config; diff --git a/burn-core/src/nn/norm/batch.rs b/burn-core/src/nn/norm/batch.rs index 8bf13ec858..f188983b51 100644 --- a/burn-core/src/nn/norm/batch.rs +++ b/burn-core/src/nn/norm/batch.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::{ @@ -28,8 +26,8 @@ pub struct BatchNormConfig { pub struct BatchNorm { gamma: Param>, beta: Param>, - running_mean: Param>>, - running_var: Param>>, + running_mean: RunningState>, + running_var: RunningState>, momentum: f64, epsilon: f64, } @@ -46,8 +44,8 @@ impl BatchNormConfig { BatchNorm { gamma: Param::from(gamma), beta: Param::from(beta), - running_mean: Param::from(RunningState::new(running_mean)), - running_var: Param::from(RunningState::new(running_var)), + running_mean: RunningState::new(running_mean), + running_var: RunningState::new(running_var), momentum: self.momentum, epsilon: self.epsilon, } @@ -76,8 +74,8 @@ impl BatchNorm { fn forward_inference(&self, input: Tensor) -> Tensor { let channels = input.dims()[1]; - let mean = self.running_mean.val().value(); - let var = self.running_var.val().value(); + let mean = self.running_mean.value(); + let var = self.running_var.value(); let mut shape = [1; DI]; shape[1] = channels; @@ -192,7 +190,7 @@ mod tests_1d { let module = BatchNormConfig::new(3).init::(); module.forward(input_tensor()); - let module = module.inner(); + let module = module.valid(); let output = module.forward(input_tensor()); output.to_data().assert_approx_eq( @@ -247,7 +245,7 @@ mod tests_2d { let module = BatchNormConfig::new(3).init::(); module.forward(input_tensor()); - let module = module.inner(); + let module = module.valid(); let output = module.forward(input_tensor()); output.to_data().assert_approx_eq( @@ -301,11 +299,9 @@ mod tests_2d { let _output = module.forward(input_tensor()); - let module_valid = module.inner(); + let module_valid = module.valid(); let running_mean = module_valid.running_mean.value(); - - let module_train = BatchNorm::::from_inner(module_valid); - let running_mean_after = module_train.running_mean.value(); + let running_mean_after = module.running_mean.value(); running_mean_after .into_data() diff --git a/burn-core/src/nn/norm/layer.rs b/burn-core/src/nn/norm/layer.rs index 31250eb98f..c025233ba3 100644 --- a/burn-core/src/nn/norm/layer.rs +++ b/burn-core/src/nn/norm/layer.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::config::Config; diff --git a/burn-core/src/nn/transformer/pwff.rs b/burn-core/src/nn/transformer/pwff.rs index 865e53fa06..f0287a867f 100644 --- a/burn-core/src/nn/transformer/pwff.rs +++ b/burn-core/src/nn/transformer/pwff.rs @@ -1,5 +1,3 @@ -use alloc::vec::Vec; - use crate as burn; use crate::{ diff --git a/burn-core/src/optim/grads.rs b/burn-core/src/optim/grads.rs index e0121d08f6..74877259e4 100644 --- a/burn-core/src/optim/grads.rs +++ b/burn-core/src/optim/grads.rs @@ -97,9 +97,7 @@ mod tests { fn test_convert_grads() { let layer_1 = layer(); let mut layer_2 = layer_1.clone(); - layer_2 = layer_2 - .to_device(&::Device::default()) - .detach(); + layer_2 = layer_2.fork(&::Device::default()); let loss_1 = layer_1.forward(random_tensor()); let loss_2 = layer_2.forward(random_tensor()); let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1); diff 
--git a/burn-core/src/optim/simple/adaptor.rs b/burn-core/src/optim/simple/adaptor.rs index bbb06289d7..3a6768464d 100644 --- a/burn-core/src/optim/simple/adaptor.rs +++ b/burn-core/src/optim/simple/adaptor.rs @@ -83,7 +83,9 @@ where if let Some(grad) = grad { let device = grad.device(); + let is_require_grad = tensor.is_require_grad(); let (key, record) = self.records.remove_entry(id).unzip(); + let (tensor, state) = self.optimizer.step( tensor.inner(), grad, @@ -97,7 +99,11 @@ where ); } - return Tensor::from_inner(tensor); + let mut tensor = Tensor::from_inner(tensor); + if is_require_grad { + tensor = tensor.require_grad(); + } + return tensor; } tensor diff --git a/burn-core/src/record/primitive.rs b/burn-core/src/record/primitive.rs index d3ac5c78ff..f502d21544 100644 --- a/burn-core/src/record/primitive.rs +++ b/burn-core/src/record/primitive.rs @@ -1,6 +1,8 @@ use alloc::string::String; use alloc::string::ToString; use alloc::vec::Vec; +use serde::Deserialize; +use serde::Serialize; use super::{Record, RecordSettings}; use crate::module::{Param, ParamId}; @@ -87,21 +89,22 @@ impl Record for DataSerialize { } } +/// (De)serialize parameters into a clean format. +#[derive(new, Debug, Clone, Serialize, Deserialize)] +pub struct ParamSerde { + id: String, + param: T, +} + impl Record for Param { - type Item = Param>; + type Item = ParamSerde>; fn into_item(self) -> Self::Item { - Param { - id: self.id, - value: self.value.into_item(), - } + ParamSerde::new(self.id.into_string(), self.value.into_item()) } fn from_item(item: Self::Item) -> Self { - Param { - id: item.id, - value: T::from_item(item.value), - } + Param::new(ParamId::from(item.id), T::from_item(item.param)) } } diff --git a/burn-core/tests/derive_module.rs b/burn-core/tests/derive_module.rs index 6a753995a3..74994c5c0f 100644 --- a/burn-core/tests/derive_module.rs +++ b/burn-core/tests/derive_module.rs @@ -4,6 +4,8 @@ use burn::tensor::{Distribution, Shape, Tensor}; use burn_core as burn; pub type TestBackend = burn_ndarray::NdArrayBackend; +#[cfg(feature = "std")] +pub type TestADBackend = burn_autodiff::ADBackendDecorator; #[derive(Module, Debug)] pub struct ModuleBasic { @@ -93,3 +95,53 @@ mod num_params { assert_eq!(2 * 20 * 20, module.num_params()); } } + +#[cfg(feature = "std")] +mod require_grad { + use burn_tensor::backend::ADBackend; + + use super::*; + + #[test] + fn should_have_grad_by_default() { + let module = ModuleBasic::::new(); + let mut grads = calculate_grads(&module); + + let grad_x = module.weight_basic.grad_remove(&mut grads); + + assert!(grad_x.is_some()); + } + + #[test] + fn should_have_no_grad_after_no_grad() { + let module = ModuleBasic::::new().no_grad(); + let mut grads = calculate_grads(&module); + + let grad_x = module.weight_basic.grad_remove(&mut grads); + + assert!(grad_x.is_none()); + } + + #[test] + fn should_have_grad_when_from_record() { + let module = ModuleBasic::::new(); + let record = ModuleBasicRecord { + weight_basic: module.weight_basic.clone(), // Even when param is no_grad, + }; + let module = module.load_record(record); + let mut grads = calculate_grads(&module); + + let grad_x = module.weight_basic.grad_remove(&mut grads); + + assert!(grad_x.is_some()); + } + + fn calculate_grads( + module: &ModuleBasic, + ) -> ::Gradients { + let x = Tensor::ones([20, 20]).require_grad(); + let y = module.weight_basic.val().matmul(x); + + y.backward() + } +} diff --git a/burn-derive/src/module/base.rs b/burn-derive/src/module/base.rs index 81239ccad2..a3d38d998b 100644 --- 
a/burn-derive/src/module/base.rs +++ b/burn-derive/src/module/base.rs @@ -25,13 +25,9 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream { let num_params_fn = generator.gen_num_params_fn(); let visit = generator.gen_visit_fn(); let map_mut = generator.gen_map_fn(); - let devices_fn = generator.gen_devices_fn(); - let to_device_fn = generator.gen_to_device_fn(); - let inner_fn = generator.gen_inner_fn(); - let from_inner_fn = generator.gen_from_inner_fn(); + let valid_fn = generator.gen_valid_fn(); let into_record_fn = generator.gen_into_record_fn(); let load_record_fn = generator.gen_load_record_fn(); - let detach_fn = generator.gen_detach_fn(); let clone_fn = generator.gen_clone_fn(); let generics_names_except_backend = generics_names_except_backend(&ast.generics); @@ -45,14 +41,10 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream { impl #generics burn::module::Module for #name #generics_ty #generics_where { type Record = #record_name #generics_ty; - #devices_fn - #to_device_fn - #load_record_fn #into_record_fn #num_params_fn - #detach_fn #visit #map_mut @@ -61,8 +53,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream { impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, { type InnerModule=#name; - #inner_fn - #from_inner_fn + #valid_fn } impl #generics core::fmt::Display for #name #generics_ty #generics_where { diff --git a/burn-derive/src/module/generator.rs b/burn-derive/src/module/generator.rs index 51d416d6ea..234349ef1f 100644 --- a/burn-derive/src/module/generator.rs +++ b/burn-derive/src/module/generator.rs @@ -96,68 +96,15 @@ impl FnGenerator { } } - pub fn gen_devices_fn(&self) -> TokenStream { - let body = self.gen_fields_fn(|name| { - quote! { - devices.append(&mut burn::module::Module::::devices(&self.#name)); - } - }); - - quote! { - fn devices(&self) -> Vec { - let mut devices = Vec::new(); - #body - devices - } - } - } - - pub fn gen_to_device_fn(&self) -> TokenStream { + pub fn gen_valid_fn(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name| { quote! { - let #name = burn::module::Module::::to_device(self.#name, device); + let #name = burn::module::ADModule::::valid(&self.#name); } }); quote! { - fn to_device(self, device: &B::Device) -> Self { - #body - - Self { - #(#names),* - } - } - } - } - - pub fn gen_detach_fn(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::Module::::detach(self.#name); - } - }); - - quote! { - fn detach(self) -> Self { - #body - - Self { - #(#names),* - } - - } - } - } - - pub fn gen_inner_fn(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::ADModule::::inner(self.#name); - } - }); - - quote! { - fn inner(self) -> Self::InnerModule { + fn valid(&self) -> Self::InnerModule { #body Self::InnerModule { @@ -167,24 +114,6 @@ impl FnGenerator { } } - pub fn gen_from_inner_fn(&self) -> TokenStream { - let (names, body) = self.gen_fields_fn_names(|name| { - quote! { - let #name = burn::module::ADModule::::from_inner(module.#name); - } - }); - - quote! { - fn from_inner(module: Self::InnerModule) -> Self { - #body - - Self { - #(#names),* - } - } - } - } - pub fn gen_clone_fn(&self) -> TokenStream { let (names, body) = self.gen_fields_fn_names(|name| { quote! 
{ diff --git a/burn-ndarray/src/backend.rs b/burn-ndarray/src/backend.rs index 9147bb1c50..3ce136baf5 100644 --- a/burn-ndarray/src/backend.rs +++ b/burn-ndarray/src/backend.rs @@ -16,7 +16,7 @@ use burn_common::stub::Mutex; pub(crate) static SEED: Mutex> = Mutex::new(None); -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NdArrayDevice { Cpu, } diff --git a/burn-no-std-tests/src/conv.rs b/burn-no-std-tests/src/conv.rs index 99f720b1e7..c901b27ae6 100644 --- a/burn-no-std-tests/src/conv.rs +++ b/burn-no-std-tests/src/conv.rs @@ -1,7 +1,5 @@ // Orginally copied from the burn/examples/mnist package -use alloc::vec::Vec; - use burn::{ config::Config, module::Module, diff --git a/burn-no-std-tests/src/model.rs b/burn-no-std-tests/src/model.rs index 89eb07939f..ed0d6478bb 100644 --- a/burn-no-std-tests/src/model.rs +++ b/burn-no-std-tests/src/model.rs @@ -1,7 +1,5 @@ // Orginally copied from the burn/examples/mnist package -use alloc::vec::Vec; - use crate::{ conv::{ConvBlock, ConvBlockConfig}, mlp::{Mlp, MlpConfig}, diff --git a/burn-tch/src/backend.rs b/burn-tch/src/backend.rs index 1b918b4790..67a0cd5b06 100644 --- a/burn-tch/src/backend.rs +++ b/burn-tch/src/backend.rs @@ -2,7 +2,7 @@ use super::element::TchElement; use super::TchTensor; use burn_tensor::backend::Backend; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] /// The device struct when using the `tch` backend. /// /// Note that you need to provide the device index when using Cuda. diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs index 418456df35..42f84667d8 100644 --- a/burn-tensor/src/tensor/api/float.rs +++ b/burn-tensor/src/tensor/api/float.rs @@ -305,7 +305,20 @@ where /// Mark the tensor to keep gradients during the backward pass. /// This function does nothing when autodiff is not enabled. pub fn require_grad(self) -> Self { - Self::new(B::require_grad(self.primitive)) + self.set_require_grad(true) + } + + /// Returns true if the tensor requires gradients during the backward pass. + pub fn is_require_grad(&self) -> bool { + B::is_require_grad(&self.primitive) + } + + /// Mark the tensor as tracked or untracked depending on the require grad argument. + /// When tracked, the gradients will be available after the backward pass. + /// + /// This function does nothing when autodiff is not enabled. + pub fn set_require_grad(self, require_grad: bool) -> Self { + Self::new(B::set_require_grad(self.primitive, require_grad)) } /// Applies the relu function to the tensor. diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 6af55b9fa7..8bfae1aad0 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -63,7 +63,7 @@ pub trait Backend: + 'static { /// Device type. - type Device: Clone + Default + core::fmt::Debug + Send + Sync; + type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; /// Pointer to another backend that have a full precision float element type type FullPrecisionBackend: Backend; diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs index 350ec325ea..bcf2cbfec1 100644 --- a/burn-tensor/src/tensor/ops/tensor.rs +++ b/burn-tensor/src/tensor/ops/tensor.rs @@ -193,10 +193,17 @@ pub trait TensorOps { // Should only be overriden by autodiff backends. 
tensor } - fn require_grad(tensor: B::TensorPrimitive) -> B::TensorPrimitive { + fn set_require_grad( + tensor: B::TensorPrimitive, + _require_grad: bool, + ) -> B::TensorPrimitive { // Should only be overriden by autodiff backends. tensor } + fn is_require_grad(_tensor: &B::TensorPrimitive) -> bool { + // Should only be overriden by autodiff backends. + false + } fn sum(tensor: B::TensorPrimitive) -> B::TensorPrimitive<1>; fn sum_dim(tensor: B::TensorPrimitive, dim: usize) -> B::TensorPrimitive; fn mean(tensor: B::TensorPrimitive) -> B::TensorPrimitive<1>; diff --git a/burn-train/src/learner/builder.rs b/burn-train/src/learner/builder.rs index 037b214a97..820f1ccbe6 100644 --- a/burn-train/src/learner/builder.rs +++ b/burn-train/src/learner/builder.rs @@ -192,7 +192,6 @@ where } None => None, }; - let model = model.detach(); Learner { model, diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs index b3fb06a8a8..81f51d3d92 100644 --- a/burn-train/src/learner/epoch.rs +++ b/burn-train/src/learner/epoch.rs @@ -24,14 +24,14 @@ pub struct TrainEpoch { } impl ValidEpoch { - pub fn run(&self, model: M, callback: &mut Box>) -> M + pub fn run(&self, model: &M, callback: &mut Box>) where B: ADBackend, M: ADModule, M::InnerModule: ValidStep, { log::info!("Executing validation step for epoch {}", self.epoch); - let model = model.inner(); + let model = model.valid(); let mut iterator = self.dataloader.iter(); let mut iteration = 0; @@ -50,8 +50,6 @@ impl ValidEpoch { )); } callback.on_valid_end_epoch(self.epoch); - - ADModule::from_inner(model) } } @@ -77,6 +75,7 @@ impl TrainEpoch { while let Some(item) = iterator.next() { iteration += 1; + log::info!("Iteration {}", iteration); let progress = iterator.progress(); let item = model.step(item); @@ -154,7 +153,6 @@ impl TrainEpoch { let grads = item.grads.to_device(&device_main, &model); - log::info!("Updated device"); accumulator.accumulate(&model, grads); accumulation_current += 1; diff --git a/burn-train/src/learner/step/train.rs b/burn-train/src/learner/step/train.rs index ccf8fe056b..f077fd2198 100644 --- a/burn-train/src/learner/step/train.rs +++ b/burn-train/src/learner/step/train.rs @@ -47,7 +47,7 @@ where spawn(move || loop { match receiver_input.recv() { Ok(item) => { - let step = item.model.to_device(&device).detach(); + let step = item.model.fork(&device); let output = step.step(item.item); sender_output.send(output).unwrap(); diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs index af78f3b467..b691a453d0 100644 --- a/burn-train/src/learner/train_val.rs +++ b/burn-train/src/learner/train_val.rs @@ -49,7 +49,7 @@ where log::info!("Fitting {}", self.model.to_string()); // The reference model is always on the first device provided. 
if let Some(device) = self.devices.get(0) { - self.model = self.model.to_device(device).detach(); + self.model = self.model.fork(device); } let starting_epoch = match self.checkpoint { @@ -83,7 +83,7 @@ where } let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs); - model = epoch_valid.run(model, &mut self.callback); + epoch_valid.run(&model, &mut self.callback); Self::checkpoint( &model, diff --git a/examples/mnist-inference-web/model.bin b/examples/mnist-inference-web/model.bin index 8ce947d569..cc565aeb9c 100644 Binary files a/examples/mnist-inference-web/model.bin and b/examples/mnist-inference-web/model.bin differ diff --git a/examples/mnist-inference-web/src/model.rs b/examples/mnist-inference-web/src/model.rs index a1460bfc04..0a0bd5ed91 100644 --- a/examples/mnist-inference-web/src/model.rs +++ b/examples/mnist-inference-web/src/model.rs @@ -2,8 +2,6 @@ // Orginally copied from the burn/examples/mnist package -use alloc::vec::Vec; - use burn::{ module::Module, nn::{self, conv::Conv2dPaddingConfig, BatchNorm}, @@ -15,6 +13,7 @@ pub struct Model { conv1: ConvBlock, conv2: ConvBlock, conv3: ConvBlock, + dropout: nn::Dropout, fc1: nn::Linear, fc2: nn::Linear, activation: nn::GELU, @@ -27,7 +26,6 @@ impl Model { let conv1 = ConvBlock::new([1, 8], [3, 3]); // out: [Batch,8,26,26] let conv2 = ConvBlock::new([8, 16], [3, 3]); // out: [Batch,16,24x24] let conv3 = ConvBlock::new([16, 24], [3, 3]); // out: [Batch,24,22x22] - let hidden_size = 24 * 22 * 22; let fc1 = nn::LinearConfig::new(hidden_size, 32) .with_bias(false) @@ -36,12 +34,15 @@ impl Model { .with_bias(false) .init(); + let dropout = nn::DropoutConfig::new(0.5).init(); + Self { conv1, conv2, conv3, fc1, fc2, + dropout, activation: nn::GELU::new(), } } @@ -57,6 +58,7 @@ impl Model { let [batch_size, channels, heigth, width] = x.dims(); let x = x.reshape([batch_size, channels * heigth * width]); + let x = self.dropout.forward(x); let x = self.fc1.forward(x); let x = self.activation.forward(x); diff --git a/examples/mnist/src/model.rs b/examples/mnist/src/model.rs index fa754cffb1..3073474bd3 100644 --- a/examples/mnist/src/model.rs +++ b/examples/mnist/src/model.rs @@ -36,7 +36,7 @@ impl Model { .with_bias(false) .init(); - let dropout = nn::DropoutConfig::new(0.3).init(); + let dropout = nn::DropoutConfig::new(0.5).init(); Self { conv1, @@ -60,9 +60,9 @@ impl Model { let [batch_size, channels, heigth, width] = x.dims(); let x = x.reshape([batch_size, channels * heigth * width]); + let x = self.dropout.forward(x); let x = self.fc1.forward(x); let x = self.activation.forward(x); - let x = self.dropout.forward(x); self.fc2.forward(x) } diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs index 95bfceed16..0f01605252 100644 --- a/examples/mnist/src/training.rs +++ b/examples/mnist/src/training.rs @@ -21,7 +21,7 @@ static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist"; #[derive(Config)] pub struct MnistTrainingConfig { - #[config(default = 4)] + #[config(default = 10)] pub num_epochs: usize, #[config(default = 64)] diff --git a/examples/text-classification/src/inference.rs b/examples/text-classification/src/inference.rs index fff181fccb..653bde7abb 100644 --- a/examples/text-classification/src/inference.rs +++ b/examples/text-classification/src/inference.rs @@ -42,7 +42,7 @@ pub fn infer( let record = Record::load::(format!("{artifact_dir}/model").into()) .expect("Trained model weights"); let model = model.load_record(record); - let model = model.to_device(&device); + let model = 
model.fork(&device); println!("Running inference ..."); let item = batcher.batch(samples.clone());
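
Taken together, the reworked module surface can be exercised roughly as follows. This is a minimal sketch, assuming the `ModuleBasic` module and the `TestADBackend` alias defined in `burn-core/tests/derive_module.rs` above; the wrapper function, bindings, and the default device are illustrative only, while `fork`, `no_grad`, and `valid` are the methods introduced in this diff.

    use burn_core::module::{ADModule, Module};

    // Assumed to exist: `ModuleBasic` and `TestADBackend` as defined in the
    // `burn-core/tests/derive_module.rs` test file above.
    fn sketch(module: ModuleBasic<TestADBackend>) {
        let device = Default::default();

        // Move to `device` and detach from the current autodiff graph, keeping
        // each parameter's `require_grad` flag (replaces `to_device(..).detach()`).
        let module = module.fork(&device);

        // Disable gradient tracking on every parameter, e.g. for partial fine-tuning.
        let _frozen = module.clone().no_grad();

        // Borrow a validation/inference view on the inner backend, replacing the
        // old `inner`/`from_inner` round trip.
        let _valid = module.valid();
    }

The same `fork` call is what the learner, the multi-device training step, and the text-classification inference example above now use in place of `to_device(..).detach()`.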