Commit 1073a08

WIP

antimora committed Jun 12, 2024
1 parent bf18455 commit 1073a08
Showing 16 changed files with 456 additions and 324 deletions.
328 changes: 172 additions & 156 deletions crates/burn-core/src/module/display.rs

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions crates/burn-core/src/module/param/base.rs
@@ -1,3 +1,5 @@
+use crate::module::{Content, ModuleDisplay, ModuleDisplayDefault};
+
 use super::ParamId;
 use alloc::boxed::Box;
 use alloc::format;
@@ -54,6 +56,16 @@ impl<T: Parameter> core::fmt::Debug for Param<T> {
     }
 }
+
+impl<T: Parameter> ModuleDisplay for Param<T> {}
+
+impl<T: Parameter> ModuleDisplayDefault for Param<T> {
+    fn content(&self, content: Content) -> Option<Content> {
+        content
+            .add_formatted(&format!("Param {{id: {}}}", self.id))
+            .optional()
+    }
+}
 
 /// Trait that defines what is necessary for a type to be a parameter.
 pub trait Parameter: Clone + core::fmt::Debug + Send {
     /// The device type to be used.
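For orientation, since every impl in this commit leans on the same builder-style API: below is a minimal, self-contained sketch of how Content, add/add_formatted, and optional() appear to compose. The Content here is a simplified stand-in for illustration, not burn's actual type.

// A stand-in `Content`, not burn's actual type: just enough to show how the
// builder calls in this diff (add, add_formatted, optional) chain by value.
use std::fmt::Display;

#[derive(Default, Debug)]
struct Content {
    lines: Vec<String>,
}

impl Content {
    // Named entry, as in `content.add("stride", &self.stride)`.
    fn add(mut self, name: &str, value: &dyn Display) -> Self {
        self.lines.push(format!("{name}: {value}"));
        self
    }

    // Pre-formatted entry, as in Param's `add_formatted` call above.
    fn add_formatted(mut self, value: &str) -> Self {
        self.lines.push(value.to_string());
        self
    }

    // `None` when nothing was recorded, letting callers fall back to a default.
    fn optional(self) -> Option<Self> {
        if self.lines.is_empty() {
            None
        } else {
            Some(self)
        }
    }
}

fn main() {
    let rendered = Content::default()
        .add_formatted("Param {id: 42}")
        .add("stride", &1_usize)
        .optional();
    println!("{rendered:?}");
}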
43 changes: 42 additions & 1 deletion crates/burn-core/src/module/param/constant.rs
@@ -1,6 +1,9 @@
 use crate::{
     self as burn,
-    module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor},
+    module::{
+        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,
+        ModuleMapper, ModuleVisitor,
+    },
     record::Record,
 };
 use burn::record::PrecisionSettings;
@@ -96,6 +99,15 @@ macro_rules! constant {
         impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {
            constant!(ad_module, $type);
         }
+
+        impl burn::module::ModuleDisplayDefault for $type {
+            fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
+                let string = format!("{}", self);
+                content.add_formatted(&string).optional()
+            }
+        }
+
+        impl burn::module::ModuleDisplay for $type {}
     };
 }
@@ -158,6 +170,15 @@ impl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {
     }
 }
+
+impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {
+    fn content(&self, content: Content) -> Option<Content> {
+        let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
+        content.add_single(&string).optional()
+    }
+}
+
+impl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}
 
 impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>
     for Tensor<B, D, K>
 {
@@ -200,6 +221,14 @@ impl<B: Backend> Module<B> for PhantomData<B> {
     }
 }
 
+impl<B: Backend> ModuleDisplayDefault for PhantomData<B> {
+    fn content(&self, content: Content) -> Option<Content> {
+        content.add_single(&"PhantomData".to_string()).optional()
+    }
+}
+
+impl<B: Backend> ModuleDisplay for PhantomData<B> {}
+
 impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
     type InnerModule = PhantomData<B::InnerBackend>;
 
@@ -248,6 +277,18 @@ where
     }
 }
+
+impl<T> ModuleDisplayDefault for Ignored<T>
+where
+    T: Sync + Send + core::fmt::Debug + Clone,
+{
+    fn content(&self, content: Content) -> Option<Content> {
+        // TODO: figure out whether we need to display the content of the ignored type.
+        content.add_single(&"Ignored".to_string()).optional()
+    }
+}
+
+impl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}
 
 impl<T> Display for Ignored<T>
 where
     T: Sync + Send + core::fmt::Debug + Clone,
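To make the macro change concrete: assuming constant! is invoked for usize (as it is for the primitive types elsewhere in burn, outside this diff), the two new arms expand to roughly the following.

// Approximate expansion of the two new constant! arms for `usize`; hedged,
// since the macro's invocation list itself is outside this diff.
impl burn::module::ModuleDisplayDefault for usize {
    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {
        let string = format!("{}", self);
        content.add_formatted(&string).optional()
    }
}

impl burn::module::ModuleDisplay for usize {}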
60 changes: 59 additions & 1 deletion crates/burn-core/src/module/param/primitive.rs
@@ -1,4 +1,7 @@
-use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
+use crate::module::{
+    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
+    ModuleVisitor,
+};
 use alloc::vec::Vec;
 use burn_tensor::backend::{AutodiffBackend, Backend};
 use core::fmt::Debug;
@@ -46,6 +49,20 @@ where
     }
 }
 
+impl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {
+    fn content(&self, content: Content) -> Option<Content> {
+        match self {
+            Some(module) => content.add_single(module).optional(),
+            None => {
+                let none_string = "None".to_string();
+                content.add_single(&none_string).optional()
+            }
+        }
+    }
+}
+
+impl<T: ModuleDisplay> ModuleDisplay for Option<T> {}
+
 impl<T, B> AutodiffModule<B> for Option<T>
 where
     T: AutodiffModule<B> + Debug + Send + Clone,
@@ -114,6 +131,31 @@ where
     }
 }
 
+// impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
+//     fn content(&self, content: Content) -> Option<Content> {
+//         let mut new_content = content.clone();
+//         for (i, module) in self.iter().enumerate() {
+//             let index = format!("{}: ", i);
+//             new_content.add(&index, module);
+//         }
+//         new_content.optional()
+//     }
+// }
+
+impl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {
+    fn content(&self, content: Content) -> Option<Content> {
+        self.iter()
+            .enumerate()
+            .fold(content, |acc, (i, module)| {
+                let index = format!("{}: ", i);
+                acc.add(&index, module)
+            })
+            .optional()
+    }
+}
+
+impl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}
+
 impl<T, B> AutodiffModule<B> for Vec<T>
 where
     T: AutodiffModule<B> + Debug + Send + Clone,
@@ -183,6 +225,20 @@ where
     }
 }
 
+impl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {
+    fn content(&self, content: Content) -> Option<Content> {
+        self.iter()
+            .enumerate()
+            .fold(content, |acc, (i, module)| {
+                let index = format!("{}: ", i);
+                acc.add(&index, module)
+            })
+            .optional()
+    }
+}
+
+impl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}
+
 impl<const N: usize, T, B> AutodiffModule<B> for [T; N]
 where
     T: AutodiffModule<B> + Debug + Send + Clone + Copy,
@@ -258,6 +314,8 @@ macro_rules! impl_module_tuple {
     };
 }
 
+// TODO: implement ModuleDisplay for tuples.
+
 impl_module_tuple!([L0, L1][0, 1]);
 impl_module_tuple!([L0, L1, L2][0, 1, 2]);
 impl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);
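The TODO above could plausibly be resolved with the same per-index fold pattern used for Vec and [T; N]. The sketch below is hypothetical (macro name and all), assuming only the Content::add/optional() builder API visible elsewhere in this diff.

// Hypothetical companion to impl_module_tuple!, not part of this commit:
// it labels each tuple field by index, matching the Vec and [T; N] impls.
macro_rules! impl_module_display_tuple {
    ([$($l:ident),*][$($i:tt),*]) => {
        impl<$($l: ModuleDisplay,)*> ModuleDisplayDefault for ($($l,)*) {
            fn content(&self, content: Content) -> Option<Content> {
                content
                    $(.add(&format!("{}: ", $i), &self.$i))*
                    .optional()
            }
        }

        impl<$($l: ModuleDisplay,)*> ModuleDisplay for ($($l,)*) {}
    };
}

impl_module_display_tuple!([L0, L1][0, 1]);
impl_module_display_tuple!([L0, L1, L2][0, 1, 2]);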
17 changes: 16 additions & 1 deletion crates/burn-core/src/module/param/running.rs
@@ -1,5 +1,8 @@
 use super::ParamId;
-use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
+use crate::module::{
+    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
+    ModuleVisitor, Param,
+};
 use alloc::sync::Arc;
 use alloc::vec::Vec;
 use burn_common::stub::Mutex;
@@ -54,6 +57,18 @@ impl<V> core::fmt::Display for RunningState<V> {
     }
 }
 
+impl<V> ModuleDisplayDefault for RunningState<V> {
+    fn content(&self, content: Content) -> Option<Content> {
+        // TODO: show the param id if the corresponding setting is enabled.
+
+        content
+            .add_formatted(&"RunningState".to_string())
+            .optional()
+    }
+}
+
+impl<V> ModuleDisplay for RunningState<V> {}
+
 impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
     type Record = Param<Tensor<B, D>>;
 
28 changes: 27 additions & 1 deletion crates/burn-core/src/module/param/tensor.rs
@@ -1,5 +1,8 @@
 use super::{Param, ParamId, Parameter};
-use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor};
+use crate::module::{
+    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
+    ModuleVisitor,
+};
 use crate::tensor::{
     backend::{AutodiffBackend, Backend},
     Tensor,
@@ -146,6 +149,13 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
         devices
     }
 }
+// impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {
+//     fn content(&self, content: Content) -> Option<Content> {
+//         let string = format!("ParamTensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
+//         content.add_formatted(&string).optional()
+//     }
+// }
+// impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}
 
 impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
     type Record = Param<Tensor<B, D, Int>>;
@@ -197,6 +207,13 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {
         devices
     }
 }
+// impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {
+//     fn content(&self, content: Content) -> Option<Content> {
+//         let string = format!("ParamTensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
+//         content.add_formatted(&string).optional()
+//     }
+// }
+// impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}
 
 impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
     type Record = Param<Tensor<B, D, Bool>>;
@@ -249,6 +266,15 @@ impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {
     }
 }
 
+// impl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {
+//     fn content(&self, content: Content) -> Option<Content> {
+//         let string = format!("ParamTensor {{rank: {D}, shape: {:?}}}", self.shape().dims);
+//         content.add_formatted(&string).optional()
+//     }
+// }
+
+// impl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}
+
 impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {
     type InnerModule = Param<Tensor<B::InnerBackend, D>>;
 
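A note on the three commented-out blocks above: they appear to be superseded by the blanket impl<T: Parameter> ModuleDisplayDefault for Param<T> added to base.rs in this same commit, which would conflict with these per-tensor-kind impls if both were enabled.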
11 changes: 7 additions & 4 deletions crates/burn-core/src/nn/conv/conv1d.rs
@@ -1,5 +1,5 @@
 use crate as burn;
-use crate::module::Attributes;
+use crate::module::Content;
 use crate::module::DisplaySettings;
 use crate::module::Ignored;
 use crate::module::ModuleDisplay;
@@ -74,13 +74,16 @@ impl<B: Backend> ModuleDisplay for Conv1d<B> {
             .optional()
     }
 
-    fn custom_attributes(&self, attributes: Attributes) -> Option<Attributes> {
-        attributes
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        // Since padding does not implement ModuleDisplay, we need to format it manually.
+        let padding_formatted = format!("{}", &self.padding);
+
+        content
             .add("stride", &self.stride)
             .add("kernel_size", &self.kernel_size)
             .add("dilation", &self.dilation)
             .add("groups", &self.groups)
-            .add("padding", &self.padding)
+            .add("padding", &padding_formatted)
             .optional()
     }
 }
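Worth noting the two strategies this commit uses for fields that do not implement ModuleDisplay: Conv1d formats padding manually inside custom_content, while the pooling modules below instead wrap padding in Ignored so it can flow through the display machinery unchanged.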
6 changes: 2 additions & 4 deletions crates/burn-core/src/nn/linear.rs
@@ -1,5 +1,4 @@
 use crate as burn;
-use crate::module::Attributes;
 use crate::module::DisplaySettings;
 use crate::module::ModuleDisplay;
 
@@ -92,10 +91,9 @@ impl<B: Backend> ModuleDisplay for Linear<B> {
             .optional()
     }
 
-    fn custom_attributes(&self, attributes: Attributes) -> Option<Attributes> {
+    fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
         let [d_input, d_output] = self.weight.shape().dims;
-
-        attributes
+        content
             .add("d_input", &d_input)
             .add("d_output", &d_output)
             .add("bias", &self.bias.is_some())
6 changes: 2 additions & 4 deletions crates/burn-core/src/nn/norm/layer.rs
@@ -1,5 +1,4 @@
 use crate as burn;
-use crate::module::Attributes;
 use crate::module::DisplaySettings;
 use crate::module::ModuleDisplay;
 use crate::nn::Initializer;
@@ -70,10 +69,9 @@ impl<B: Backend> ModuleDisplay for LayerNorm<B> {
             .optional()
     }
 
-    fn custom_attributes(&self, attributes: Attributes) -> Option<Attributes> {
+    fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
        let [d_model] = self.gamma.shape().dims;
-
-        attributes
+        content
             .add("d_model", &d_model)
             .add("epsilon", &self.epsilon)
             .optional()
6 changes: 3 additions & 3 deletions crates/burn-core/src/nn/pool/avg_pool1d.rs
@@ -1,7 +1,7 @@
 use crate as burn;
 
 use crate::config::Config;
-use crate::module::Module;
+use crate::module::{Ignored, Module};
 use crate::nn::PaddingConfig1d;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
@@ -42,7 +42,7 @@ pub struct AvgPool1dConfig {
 pub struct AvgPool1d {
     stride: usize,
     kernel_size: usize,
-    padding: PaddingConfig1d,
+    padding: Ignored<PaddingConfig1d>,
     count_include_pad: bool,
 }
 
@@ -52,7 +52,7 @@ impl AvgPool1dConfig {
         AvgPool1d {
             stride: self.stride,
             kernel_size: self.kernel_size,
-            padding: self.padding.clone(),
+            padding: Ignored(self.padding.clone()),
             count_include_pad: self.count_include_pad,
         }
     }
6 changes: 3 additions & 3 deletions crates/burn-core/src/nn/pool/avg_pool2d.rs
@@ -1,7 +1,7 @@
 use crate as burn;
 
 use crate::config::Config;
-use crate::module::Module;
+use crate::module::{Ignored, Module};
 use crate::nn::PaddingConfig2d;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
@@ -41,7 +41,7 @@ pub struct AvgPool2dConfig {
 pub struct AvgPool2d {
     stride: [usize; 2],
     kernel_size: [usize; 2],
-    padding: PaddingConfig2d,
+    padding: Ignored<PaddingConfig2d>,
     count_include_pad: bool,
 }
 
@@ -51,7 +51,7 @@ impl AvgPool2dConfig {
         AvgPool2d {
             stride: self.strides,
             kernel_size: self.kernel_size,
-            padding: self.padding.clone(),
+            padding: Ignored(self.padding.clone()),
             count_include_pad: self.count_include_pad,
         }
     }