group_norm functional implementation
jwric committed Jun 13, 2024
1 parent 1a3655b commit f514031
Showing 2 changed files with 67 additions and 47 deletions.
98 changes: 60 additions & 38 deletions crates/burn-core/src/nn/norm/group.rs
@@ -24,15 +24,7 @@ pub struct GroupNormConfig {
     pub affine: bool,
 }
 
-/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
-///
-/// `Y = groupnorm(X) * γ + β`
-///
-/// Where:
-/// - `X` is the input tensor
-/// - `Y` is the output tensor
-/// - `γ` is the learnable weight
-/// - `β` is the learnable bias
+/// Applies Group Normalization over a mini-batch of inputs. See the [group_norm](group_norm) function for more information.
 ///
 /// Should be created using the [GroupNormConfig](GroupNormConfig) struct.
 #[derive(Module, Debug)]
@@ -87,45 +79,75 @@ impl<B: Backend> GroupNorm<B> {
     /// - input: `[batch_size, num_channels, *]`
     /// - output: `[batch_size, num_channels, *]`
     pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
-        let shape = input.shape();
-        if shape.num_elements() <= 2 {
-            panic!(
-                "input rank for GroupNorm should be at least 3, but got {}",
-                shape.num_elements()
-            );
-        }
-
-        let batch_size = shape.dims[0];
-        let num_channels = shape.dims[1];
-
-        if num_channels != self.num_channels {
-            panic!(
-                "expected {} channels but got {}",
-                self.num_channels, num_channels
-            );
-        }
-
-        let hidden_size =
-            shape.dims[2..].iter().product::<usize>() * num_channels / self.num_groups;
-        let input = input.reshape([batch_size, self.num_groups, hidden_size]);
-
-        let mean = input.clone().sum_dim(2) / hidden_size as f64;
-        let input = input.sub(mean);
-
-        let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
-        let input_normalized = input.div(var.sqrt().add_scalar(self.epsilon));
-
-        if self.affine {
-            let mut affine_shape = [1; D];
-            affine_shape[1] = num_channels;
-
-            input_normalized
-                .reshape(shape)
-                .mul(self.gamma.clone().unwrap().val().reshape(affine_shape))
-                .add(self.beta.clone().unwrap().val().reshape(affine_shape))
-        } else {
-            input_normalized.reshape(shape)
-        }
+        if input.shape().dims[1] != self.num_channels {
+            panic!(
+                "The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
+                self.num_channels,
+                input.shape().dims[1]
+            );
+        }
+
+        let gamma = self.gamma.as_ref().map(|x| x.val());
+        let beta = self.beta.as_ref().map(|x| x.val());
+
+        group_norm(input, gamma, beta, self.num_groups, self.epsilon, self.affine)
     }
 }
+
+/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
+///
+/// `Y = groupnorm(X) * γ + β`
+///
+/// Where:
+/// - `X` is the input tensor
+/// - `Y` is the output tensor
+/// - `γ` is the learnable weight
+/// - `β` is the learnable bias
+///
+pub(crate) fn group_norm<B: Backend, const D: usize>(
+    input: Tensor<B, D>,
+    gamma: Option<Tensor<B, 1>>,
+    beta: Option<Tensor<B, 1>>,
+    num_groups: usize,
+    epsilon: f64,
+    affine: bool
+) -> Tensor<B, D> {
+
+    if (affine && gamma.is_none()) || (affine && beta.is_none()) {
+        panic!("Affine is set to true, but gamma or beta is None");
+    }
+
+    let shape = input.shape();
+    if shape.num_elements() <= 2 {
+        panic!(
+            "input rank for GroupNorm should be at least 3, but got {}",
+            shape.num_elements()
+        );
+    }
+
+    let batch_size = shape.dims[0];
+    let num_channels = shape.dims[1];
+
+    let hidden_size =
+        shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
+    let input = input.reshape([batch_size, num_groups, hidden_size]);
+
+    let mean = input.clone().sum_dim(2) / hidden_size as f64;
+    let input = input.sub(mean);
+
+    let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
+    let input_normalized = input.div(var.sqrt().add_scalar(epsilon));
+
+    if affine {
+        let mut affine_shape = [1; D];
+        affine_shape[1] = num_channels;
+
+        input_normalized
+            .reshape(shape)
+            .mul(gamma.clone().unwrap().reshape(affine_shape))
+            .add(beta.clone().unwrap().reshape(affine_shape))
+    } else {
+        input_normalized.reshape(shape)
+    }
+}
 
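The new group_norm function views the [batch_size, num_channels, *] input as [batch_size, num_groups, hidden_size], where hidden_size is the number of elements normalized together per batch and group, and then takes the mean and variance over that last dimension. The following standalone sketch (not part of this commit; the concrete sizes are made up for illustration) walks through that shape arithmetic in plain Rust:

// Illustrative only: mirrors the reshape logic in group_norm for an input of
// shape [batch_size, num_channels, height, width] = [2, 6, 4, 4] with 3 groups.
fn main() {
    let dims = [2usize, 6, 4, 4];
    let num_groups = 3;

    let batch_size = dims[0];
    let num_channels = dims[1];

    // Elements normalized together per (batch, group):
    // channels-per-group times the spatial size.
    let hidden_size = dims[2..].iter().product::<usize>() * num_channels / num_groups;
    assert_eq!(hidden_size, 32); // 4 * 4 * 6 / 3

    // group_norm views the tensor as [batch_size, num_groups, hidden_size]
    // and computes mean and variance over the last dimension.
    println!("normalized view: [{batch_size}, {num_groups}, {hidden_size}]");
}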
16 changes: 7 additions & 9 deletions crates/burn-core/src/nn/norm/instance.rs
@@ -3,7 +3,7 @@ use crate as burn;
 use crate::config::Config;
 use crate::module::{Module, Param};
 use crate::tensor::{backend::Backend, Tensor};
-use crate::nn::norm::GroupNorm;
+use crate::nn::norm::group_norm;
 use crate::nn::Initializer;
 
 /// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).
@@ -70,14 +70,12 @@ impl<B: Backend> InstanceNorm<B> {
     /// - output: `[batch_size, num_channels, *]`
     pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
         // Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.
-        GroupNorm {
-            gamma: self.gamma.clone(),
-            beta: self.beta.clone(),
-            num_groups: self.num_channels,
-            num_channels: self.num_channels,
-            epsilon: self.epsilon,
-            affine: self.affine,
-        }.forward(input)
+        let num_groups = self.num_channels;
+
+        let gamma = self.gamma.as_ref().map(|x| x.val());
+        let beta = self.beta.as_ref().map(|x| x.val());
+
+        group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)
     }
 }
 
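With this change, InstanceNorm simply calls group_norm with num_groups equal to num_channels, so each channel forms its own group and statistics are taken per (batch, channel) over the remaining dimensions. A standalone numerical sketch of that special case (illustrative values, independent of burn's tensor API; it mirrors the (x - mean) / (sqrt(var) + epsilon) normalization used above, without the affine step):

// Illustrative only: instance norm as group norm with num_groups == num_channels,
// on a flattened [batch = 1, channels = 2, spatial = 4] input.
fn main() {
    let (batch, channels, spatial) = (1usize, 2usize, 4usize);
    let epsilon = 1e-5f64;
    let x = vec![1.0f64, 2.0, 3.0, 4.0, 10.0, 10.0, 10.0, 14.0];

    let num_groups = channels; // the instance-norm special case
    let hidden_size = spatial * channels / num_groups; // == spatial

    let mut y = vec![0.0; x.len()];
    for b in 0..batch {
        for g in 0..num_groups {
            let start = (b * num_groups + g) * hidden_size;
            let group = &x[start..start + hidden_size];
            let mean = group.iter().sum::<f64>() / hidden_size as f64;
            let var = group.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / hidden_size as f64;
            for (i, v) in group.iter().enumerate() {
                // Same normalization as group_norm above: (x - mean) / (sqrt(var) + epsilon).
                y[start + i] = (v - mean) / (var.sqrt() + epsilon);
            }
        }
    }
    println!("{y:?}"); // each channel now has roughly zero mean and unit variance
}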
