refactor reduce into separate traits (#1798)
louisfd committed May 22, 2024
1 parent 0331719 commit e39b4d2
Showing 19 changed files with 332 additions and 259 deletions.
86 changes: 13 additions & 73 deletions crates/burn-jit/src/kernel/reduce/base.rs
@@ -1,78 +1,15 @@
use burn_cube::dialect::{Item, Scope, Variable};

#[cfg(feature = "autotune")]
use crate::kernel::reduce::reduce_dim_autotune;
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};

use super::{reduce_dim_naive, reduce_dim_shared, ArgMax, ArgMin, MeanDim, ProdDim, SumDim};

/// Specifies the reduce dim algorithm in use
pub trait ReduceDimAlgorithm<E: JitElement>: Send + Sync + 'static {
/// The reduction accumulator
type Accumulator: Copy;

/// Initialization for naive algorithm
fn initialize_naive(
scope: &mut Scope,
input_item: Item,
output_item: Item,
) -> Self::Accumulator;

/// Inner loop for naive algorithm
fn inner_loop_naive(
scope: &mut Scope,
accumulator: Self::Accumulator,
current_value: Variable,
i: Variable,
);

/// Assignment for naive algorithm
fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Self::Accumulator,
shape_reduce_dim: Variable,
);
use super::{
naive::{base::ReduceDimNaive, shader::reduce_dim_naive},
shared::{base::ReduceDimShared, shader::reduce_dim_shared},
};

/// Initialization for shared algorithm
fn initialize_shared(
scope: &mut Scope,
shared_memory_size: u32,
write_position: Variable,
input_item: Item,
) -> Self::Accumulator;

/// How to write to shared memory
fn write_to_shared(
scope: &mut Scope,
shared_memory: Self::Accumulator,
write_position: Variable,
value: Self::Accumulator,
);

/// How to read from input in shared algorithm
fn read_from_input(
scope: &mut Scope,
input: Variable,
read_position: Variable,
i: Variable,
) -> Self::Accumulator;

/// How to read from shared memory
fn read_from_shared(
scope: &mut Scope,
shared_memory: Self::Accumulator,
read_position: Variable,
) -> Self::Accumulator;

/// How to assign from shared memory
fn assign_shared(
scope: &mut Scope,
shared_memory: Self::Accumulator,
output: Variable,
write_position: Variable,
shape_reduce_dim: Variable,
);
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
ReduceDimNaive<E> + ReduceDimShared<E>
{
}

/// Creates an empty output tensor with reduce output shape
@@ -116,7 +53,10 @@ impl Default for ReduceStrategy {
}

macro_rules! reduce_operation {
($name:ident, $ops:ty) => {
($name:ident, $ops:ident) => {
pub(crate) struct $ops;
impl<E: JitElement> ReduceDimAlgorithm<E> for $ops {}

/// Executes the reduce operation with the given strategy.
pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
tensor: JitTensor<R, EI, D>,
@@ -143,5 +83,5 @@ macro_rules! reduce_operation {
reduce_operation!(sum_dim, SumDim);
reduce_operation!(mean_dim, MeanDim);
reduce_operation!(prod_dim, ProdDim);
reduce_operation!(argmin, ArgMin);
reduce_operation!(argmax, ArgMax);
reduce_operation!(argmin, Argmin);
reduce_operation!(argmax, Argmax);
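
Note on the new shape of base.rs: ReduceDimAlgorithm keeps its name but is now an empty marker trait whose supertraits carry all of the behavior, and reduce_operation! declares one unit struct per operation together with the empty impl. A minimal, self-contained sketch of the pattern, with hypothetical simplified method signatures (the real traits emit IR through a Scope rather than computing values):

trait ReduceDimNaive {
    fn naive_step(acc: &mut f32, value: f32);
}

trait ReduceDimShared {
    fn shared_step(acc: &mut f32, value: f32);
}

// The combined trait adds no methods of its own: implementing both
// backend traits is what qualifies a type, so the impl stays empty.
trait ReduceDimAlgorithm: ReduceDimNaive + ReduceDimShared {}

// Roughly the struct-and-marker part of `reduce_operation!(sum_dim, SumDim)`.
struct SumDim;
impl ReduceDimNaive for SumDim {
    fn naive_step(acc: &mut f32, value: f32) { *acc += value; }
}
impl ReduceDimShared for SumDim {
    fn shared_step(acc: &mut f32, value: f32) { *acc += value; }
}
impl ReduceDimAlgorithm for SumDim {}

Code that must pick a strategy at runtime (such as the autotuned path behind the "autotune" feature) can keep bounding on ReduceDimAlgorithm, while each shader bounds only on the supertrait it actually drives.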
16 changes: 2 additions & 14 deletions crates/burn-jit/src/kernel/reduce/mod.rs
@@ -1,23 +1,11 @@
mod argmax_dim;
mod argmin_dim;
mod base;
mod mean_dim;
mod naive_reduce_shader;
mod naive;
mod prod;
mod prod_dim;
mod shared_reduce_shader;
mod shared;
mod sum;
mod sum_dim;
mod tune;

pub(crate) use argmax_dim::*;
pub(crate) use argmin_dim::*;
pub use base::*;
pub(crate) use mean_dim::*;
pub use naive_reduce_shader::*;
pub use prod::*;
pub(crate) use prod_dim::*;
pub use shared_reduce_shader::*;
pub use sum::*;
pub(crate) use sum_dim::*;
pub use tune::*;
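
The module reorganization mirrors the trait split: the flat naive_reduce_shader and shared_reduce_shader modules move under naive/ and shared/, each strategy getting a base module for its trait, a shader module for the kernel driver, and one module per reduction, so a strategy's contract and its codegen now live side by side.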
49 changes: 49 additions & 0 deletions crates/burn-jit/src/kernel/reduce/naive/argmax.rs
@@ -0,0 +1,49 @@
use crate::{kernel::reduce::Argmax, JitElement};
use burn_cube::{
cpa,
dialect::{Elem, Item, Scope, Variable},
};

use super::base::ReduceDimNaive;

impl<E: JitElement> ReduceDimNaive<E> for Argmax {
type Accumulator = (Variable, Variable);

fn initialize_naive(
scope: &mut Scope,
input_item: Item,
_output_item: Item,
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let max = scope.create_local(input_item);
let max_initial =
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
cpa!(scope, max = max_initial);

(max, index)
}

fn inner_loop_naive(
scope: &mut Scope,
(max, index): Self::Accumulator,
value: Variable,
i: Variable,
) {
let condition = scope.create_local(Elem::Bool);
cpa!(scope, condition = value > max);
cpa!(scope, if(condition).then(|scope| {
cpa!(scope, max = value);
cpa!(scope, index = i);
}));
}

fn assign_naive(
scope: &mut Scope,
output: Variable,
(_max, index): Self::Accumulator,
_shape_reduce_dim: Variable,
) {
let id = Variable::Id;
cpa!(scope, output[id] = index);
}
}
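
For readers new to the cpa! dialect, here is a plain-Rust, CPU-side analogue (hypothetical, for illustration only) of what this Argmax impl generates: the accumulator is a (max, index) pair, the max is seeded with the element type's minimum value, and only the index is written to the output.

fn argmax_naive(values: &[f32]) -> u32 {
    let mut max = f32::MIN; // initialize_naive: E::minimum_value()
    let mut index = 0u32;
    for (i, &value) in values.iter().enumerate() {
        // inner_loop_naive: update both slots when a new maximum appears
        if value > max {
            max = value;
            index = i as u32;
        }
    }
    index // assign_naive: output[id] = index; the max itself is discarded
}

Argmin in the next file is the mirror image: it seeds with E::maximum_value() and flips the comparison to value < min.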
50 changes: 50 additions & 0 deletions crates/burn-jit/src/kernel/reduce/naive/argmin.rs
@@ -0,0 +1,50 @@
use burn_cube::{
cpa,
dialect::{Elem, Item, Scope, Variable},
};

use crate::{kernel::reduce::Argmin, JitElement};

use super::base::ReduceDimNaive;

impl<E: JitElement> ReduceDimNaive<E> for Argmin {
type Accumulator = (Variable, Variable);

fn initialize_naive(
scope: &mut Scope,
input_item: Item,
_output_item: Item,
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(input_item);
let min_initial =
Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
cpa!(scope, min = min_initial);

(min, index)
}

fn inner_loop_naive(
scope: &mut Scope,
(min, index): Self::Accumulator,
value: Variable,
i: Variable,
) {
let condition = scope.create_local(Elem::Bool);
cpa!(scope, condition = value < min);
cpa!(scope, if(condition).then(|scope| {
cpa!(scope, min = value);
cpa!(scope, index = i);
}));
}

fn assign_naive(
scope: &mut Scope,
output: Variable,
(_min, index): Self::Accumulator,
_shape_reduce_dim: Variable,
) {
let id = Variable::Id;
cpa!(scope, output[id] = index);
}
}
32 changes: 32 additions & 0 deletions crates/burn-jit/src/kernel/reduce/naive/base.rs
@@ -0,0 +1,32 @@
use burn_cube::dialect::{Item, Scope, Variable};

use crate::JitElement;

/// Specifies the naive reduce dim algorithm in use
pub trait ReduceDimNaive<E: JitElement>: Send + Sync + 'static {
/// The reduction accumulator
type Accumulator: Copy;

/// Initialization for naive algorithm
fn initialize_naive(
scope: &mut Scope,
input_item: Item,
output_item: Item,
) -> Self::Accumulator;

/// Inner loop for naive algorithm
fn inner_loop_naive(
scope: &mut Scope,
accumulator: Self::Accumulator,
current_value: Variable,
i: Variable,
);

/// Assignment for naive algorithm
fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Self::Accumulator,
shape_reduce_dim: Variable,
);
}
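
The trait reads as a three-phase contract that the naive shader weaves into generated code: create an accumulator, fold every element along the reduced dimension into it, then turn the accumulator into the output element. A CPU-side analogue (hypothetical; the real methods emit IR into a Scope) makes the division of labor visible:

trait NaiveReduce {
    type Acc;
    fn initialize() -> Self::Acc;
    fn inner_loop(acc: &mut Self::Acc, value: f32, i: u32);
    fn assign(acc: Self::Acc, shape_reduce_dim: u32) -> f32;
}

struct Mean;
impl NaiveReduce for Mean {
    type Acc = f32;
    fn initialize() -> f32 { 0.0 }
    fn inner_loop(acc: &mut f32, value: f32, _i: u32) { *acc += value; }
    fn assign(acc: f32, shape_reduce_dim: u32) -> f32 {
        acc / shape_reduce_dim as f32
    }
}

// The shader's expand() plays the role of this driver: one output element
// per invocation, looping over the reduced dimension.
fn reduce_row<R: NaiveReduce>(row: &[f32]) -> f32 {
    let mut acc = R::initialize();
    for (i, &value) in row.iter().enumerate() {
        R::inner_loop(&mut acc, value, i as u32);
    }
    R::assign(acc, row.len() as u32)
}

For example, reduce_row::<Mean>(&[2.0, 4.0, 6.0]) accumulates 12.0 and assigns 12.0 / 3 = 4.0.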
32 changes: 32 additions & 0 deletions crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs
@@ -0,0 +1,32 @@
use crate::{kernel::reduce::MeanDim, JitElement};
use burn_cube::{
cpa,
dialect::{Item, Scope, Variable},
};

use super::base::ReduceDimNaive;

impl<E: JitElement> ReduceDimNaive<E> for MeanDim {
type Accumulator = Variable;

fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
scope.zero(output_item)
}

fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
cpa!(scope, accumulator += value);
}

fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Variable,
shape_reduce_dim: Variable,
) {
let id = Variable::Id;
let denominator = scope.create_local(accumulator.item());
cpa!(scope, denominator = cast(shape_reduce_dim));
cpa!(scope, accumulator = accumulator / denominator);
cpa!(scope, output[id] = accumulator);
}
}
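
Two details worth noting in MeanDim: shape_reduce_dim is an integer shape value, so it is cast into the accumulator's item type before the division, and the division happens once per output element in assign_naive rather than once per input element in the inner loop.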
7 changes: 7 additions & 0 deletions crates/burn-jit/src/kernel/reduce/naive/mod.rs
@@ -0,0 +1,7 @@
pub(crate) mod argmax;
pub(crate) mod argmin;
pub(crate) mod base;
pub(crate) mod mean_dim;
pub(crate) mod prod_dim;
pub(crate) mod shader;
pub(crate) mod sum_dim;
29 changes: 29 additions & 0 deletions crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs
@@ -0,0 +1,29 @@
use crate::{kernel::reduce::ProdDim, JitElement};
use burn_cube::{
cpa,
dialect::{Item, Scope, Variable},
};

use super::base::ReduceDimNaive;

impl<E: JitElement> ReduceDimNaive<E> for ProdDim {
type Accumulator = Variable;

fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
scope.create_with_value(1, output_item)
}

fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
cpa!(scope, accumulator *= value);
}

fn assign_naive(
scope: &mut Scope,
output: Variable,
accumulator: Variable,
_shape_reduce_dim: Variable,
) {
let id = Variable::Id;
cpa!(scope, output[id] = accumulator);
}
}
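
ProdDim differs from SumDim only in its identity element and combining operator: the accumulator starts at the multiplicative identity 1 (scope.create_with_value(1, output_item)) and the inner loop multiplies instead of adds, with no finalization beyond the write.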
crates/burn-jit/src/kernel/reduce/naive_reduce_shader.rs → crates/burn-jit/src/kernel/reduce/naive/shader.rs
@@ -8,9 +8,9 @@ use std::marker::PhantomData;

use crate::{element::JitElement, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitRuntime};

use super::ReduceDimAlgorithm;
use super::base::ReduceDimNaive;

pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimAlgorithm<E>> {
pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimNaive<E>> {
tensor: Variable,
dim: usize,
output: Variable,
@@ -20,7 +20,7 @@ pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimAlgorithm<E>> {

#[derive(new)]
pub(crate) struct NaiveReduceDimEagerKernel<
RD: ReduceDimAlgorithm<EI>,
RD: ReduceDimNaive<EI>,
R: JitRuntime,
EI: JitElement,
EO: JitElement,
@@ -32,8 +32,8 @@ pub(crate) struct NaiveReduceDimEagerKernel<
_elem_out: PhantomData<EO>,
}

impl<RD: ReduceDimAlgorithm<EI>, R: JitRuntime, EI: JitElement, EO: JitElement>
GpuComputeShaderPhase for NaiveReduceDimEagerKernel<RD, R, EI, EO>
impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> GpuComputeShaderPhase
for NaiveReduceDimEagerKernel<RD, R, EI, EO>
{
fn compile(&self) -> ComputeShader {
let mut scope = Scope::root();
@@ -76,7 +76,7 @@ impl<RD: ReduceDimAlgorithm<EI>, R: JitRuntime, EI: JitElement, EO: JitElement>
}
}

impl<E: JitElement, RD: ReduceDimAlgorithm<E>> NaiveReduceDimComputeShader<E, RD> {
impl<E: JitElement, RD: ReduceDimNaive<E>> NaiveReduceDimComputeShader<E, RD> {
pub(crate) fn expand(self, scope: &mut Scope) {
let tensor = self.tensor;
let dim: Variable = self.dim.into();
@@ -136,7 +136,7 @@ impl<E: JitElement, RD: ReduceDimAlgorithm<E>> NaiveReduceDimComputeShader<E, RD> {

/// Executes the naive kernel for reduce dim
pub fn reduce_dim_naive<
RD: ReduceDimAlgorithm<EI>,
RD: ReduceDimNaive<EI>,
R: JitRuntime,
EI: JitElement,
EO: JitElement,
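The shader file shows the payoff of the split: NaiveReduceDimComputeShader, NaiveReduceDimEagerKernel, and reduce_dim_naive now bound RD on ReduceDimNaive<EI> alone, so a new naive reduction only has to implement the three naive methods instead of the full eight-method combined trait.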
