-
Notifications
You must be signed in to change notification settings - Fork 382
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor reduce into separate traits (#1798)
- Loading branch information
Showing
19 changed files
with
332 additions
and
259 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,11 @@ | ||
mod argmax_dim; | ||
mod argmin_dim; | ||
mod base; | ||
mod mean_dim; | ||
mod naive_reduce_shader; | ||
mod naive; | ||
mod prod; | ||
mod prod_dim; | ||
mod shared_reduce_shader; | ||
mod shared; | ||
mod sum; | ||
mod sum_dim; | ||
mod tune; | ||
|
||
pub(crate) use argmax_dim::*; | ||
pub(crate) use argmin_dim::*; | ||
pub use base::*; | ||
pub(crate) use mean_dim::*; | ||
pub use naive_reduce_shader::*; | ||
pub use prod::*; | ||
pub(crate) use prod_dim::*; | ||
pub use shared_reduce_shader::*; | ||
pub use sum::*; | ||
pub(crate) use sum_dim::*; | ||
pub use tune::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use crate::{kernel::reduce::Argmax, JitElement}; | ||
use burn_cube::{ | ||
cpa, | ||
dialect::{Elem, Item, Scope, Variable}, | ||
}; | ||
|
||
use super::base::ReduceDimNaive; | ||
|
||
impl<E: JitElement> ReduceDimNaive<E> for Argmax { | ||
type Accumulator = (Variable, Variable); | ||
|
||
fn initialize_naive( | ||
scope: &mut Scope, | ||
input_item: Item, | ||
_output_item: Item, | ||
) -> Self::Accumulator { | ||
let index = scope.create_local(Elem::UInt); | ||
let max = scope.create_local(input_item); | ||
let max_initial = | ||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem()); | ||
cpa!(scope, max = max_initial); | ||
|
||
(max, index) | ||
} | ||
|
||
fn inner_loop_naive( | ||
scope: &mut Scope, | ||
(max, index): Self::Accumulator, | ||
value: Variable, | ||
i: Variable, | ||
) { | ||
let condition = scope.create_local(Elem::Bool); | ||
cpa!(scope, condition = value > max); | ||
cpa!(scope, if(condition).then(|scope| { | ||
cpa!(scope, max = value); | ||
cpa!(scope, index = i); | ||
})); | ||
} | ||
|
||
fn assign_naive( | ||
scope: &mut Scope, | ||
output: Variable, | ||
(_max, index): Self::Accumulator, | ||
_shape_reduce_dim: Variable, | ||
) { | ||
let id = Variable::Id; | ||
cpa!(scope, output[id] = index); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
use burn_cube::{ | ||
cpa, | ||
dialect::{Elem, Item, Scope, Variable}, | ||
}; | ||
|
||
use crate::{kernel::reduce::Argmin, JitElement}; | ||
|
||
use super::base::ReduceDimNaive; | ||
|
||
impl<E: JitElement> ReduceDimNaive<E> for Argmin { | ||
type Accumulator = (Variable, Variable); | ||
|
||
fn initialize_naive( | ||
scope: &mut Scope, | ||
input_item: Item, | ||
_output_item: Item, | ||
) -> Self::Accumulator { | ||
let index = scope.create_local(Elem::UInt); | ||
let min = scope.create_local(input_item); | ||
let min_initial = | ||
Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem()); | ||
cpa!(scope, min = min_initial); | ||
|
||
(min, index) | ||
} | ||
|
||
fn inner_loop_naive( | ||
scope: &mut Scope, | ||
(min, index): Self::Accumulator, | ||
value: Variable, | ||
i: Variable, | ||
) { | ||
let condition = scope.create_local(Elem::Bool); | ||
cpa!(scope, condition = value < min); | ||
cpa!(scope, if(condition).then(|scope| { | ||
cpa!(scope, min = value); | ||
cpa!(scope, index = i); | ||
})); | ||
} | ||
|
||
fn assign_naive( | ||
scope: &mut Scope, | ||
output: Variable, | ||
(_min, index): Self::Accumulator, | ||
_shape_reduce_dim: Variable, | ||
) { | ||
let id = Variable::Id; | ||
cpa!(scope, output[id] = index); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use burn_cube::dialect::{Item, Scope, Variable}; | ||
|
||
use crate::JitElement; | ||
|
||
/// Specifies the reduce dim algorithm in use | ||
pub trait ReduceDimNaive<E: JitElement>: Send + Sync + 'static { | ||
/// The reduction accumulator | ||
type Accumulator: Copy; | ||
|
||
/// Initialization for naive algorithm | ||
fn initialize_naive( | ||
scope: &mut Scope, | ||
input_item: Item, | ||
output_item: Item, | ||
) -> Self::Accumulator; | ||
|
||
/// Inner loop for naive algorithm | ||
fn inner_loop_naive( | ||
scope: &mut Scope, | ||
accumulator: Self::Accumulator, | ||
current_value: Variable, | ||
i: Variable, | ||
); | ||
|
||
/// Assignation for naive algorithm | ||
fn assign_naive( | ||
scope: &mut Scope, | ||
output: Variable, | ||
accumulator: Self::Accumulator, | ||
shape_reduce_dim: Variable, | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use crate::{kernel::reduce::MeanDim, JitElement}; | ||
use burn_cube::{ | ||
cpa, | ||
dialect::{Item, Scope, Variable}, | ||
}; | ||
|
||
use super::base::ReduceDimNaive; | ||
|
||
impl<E: JitElement> ReduceDimNaive<E> for MeanDim { | ||
type Accumulator = Variable; | ||
|
||
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable { | ||
scope.zero(output_item) | ||
} | ||
|
||
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { | ||
cpa!(scope, accumulator += value); | ||
} | ||
|
||
fn assign_naive( | ||
scope: &mut Scope, | ||
output: Variable, | ||
accumulator: Variable, | ||
shape_reduce_dim: Variable, | ||
) { | ||
let id = Variable::Id; | ||
let denominator = scope.create_local(accumulator.item()); | ||
cpa!(scope, denominator = cast(shape_reduce_dim)); | ||
cpa!(scope, accumulator = accumulator / denominator); | ||
cpa!(scope, output[id] = accumulator); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
pub(crate) mod argmax; | ||
pub(crate) mod argmin; | ||
pub(crate) mod base; | ||
pub(crate) mod mean_dim; | ||
pub(crate) mod prod_dim; | ||
pub(crate) mod shader; | ||
pub(crate) mod sum_dim; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use crate::{kernel::reduce::ProdDim, JitElement}; | ||
use burn_cube::{ | ||
cpa, | ||
dialect::{Item, Scope, Variable}, | ||
}; | ||
|
||
use super::base::ReduceDimNaive; | ||
|
||
impl<E: JitElement> ReduceDimNaive<E> for ProdDim { | ||
type Accumulator = Variable; | ||
|
||
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable { | ||
scope.create_with_value(1, output_item) | ||
} | ||
|
||
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { | ||
cpa!(scope, accumulator *= value); | ||
} | ||
|
||
fn assign_naive( | ||
scope: &mut Scope, | ||
output: Variable, | ||
accumulator: Variable, | ||
_shape_reduce_dim: Variable, | ||
) { | ||
let id = Variable::Id; | ||
cpa!(scope, output[id] = accumulator); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.