Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/jit fusion #1750

Merged
merged 5 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 40 additions & 8 deletions crates/burn-jit/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,16 @@ impl OutputInfo {
#[allow(dead_code)]
pub fn item(&self) -> Item {
match self {
OutputInfo::ArrayWrite { item, local: _ } => *item,
OutputInfo::ArrayWrite {
item,
local: _,
position: _,
} => *item,
OutputInfo::InputArrayWrite {
item,
input: _,
local: _,
position: _,
} => *item,
OutputInfo::Array { item } => *item,
}
Expand All @@ -298,9 +303,18 @@ pub enum OutputInfo {
/// Write the local variable to a new array.
///
/// This will create a new binding in the [compute shader](ComputeShader).
ArrayWrite { item: Item, local: u16 },
ArrayWrite {
item: Item,
local: u16,
position: Variable,
},
/// Write the local variable to an existing input binding.
InputArrayWrite { item: Item, input: u16, local: u16 },
InputArrayWrite {
item: Item,
input: u16,
local: u16,
position: Variable,
},
/// Simply register the output, but don't automatically add a write to it.
///
/// Useful when a [procedure](gpu::Procedure) writes to the output using
Expand All @@ -312,11 +326,16 @@ impl OutputInfo {
#[allow(dead_code)]
pub fn elem_size<R: Runtime>(&self) -> usize {
let elem = match self {
OutputInfo::ArrayWrite { item, local: _ } => bool_elem(item.elem()),
OutputInfo::ArrayWrite {
item,
local: _,
position: _,
} => bool_elem(item.elem()),
OutputInfo::InputArrayWrite {
item,
input: _,
local: _,
position: _,
} => bool_elem(item.elem()),
OutputInfo::Array { item } => bool_elem(item.elem()),
};
Expand Down Expand Up @@ -424,7 +443,11 @@ impl Compilation {

for array in self.info.outputs.drain(..) {
match array {
OutputInfo::ArrayWrite { item, local } => {
OutputInfo::ArrayWrite {
item,
local,
position,
} => {
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
Expand All @@ -441,10 +464,16 @@ impl Compilation {
self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalOutputArray(index, elem_adapted),
position,
);
index += 1;
}
OutputInfo::InputArrayWrite { item, input, local } => {
OutputInfo::InputArrayWrite {
item,
input,
local,
position,
} => {
let item = if let Some(vectorization) = settings.vectorization {
item.vectorize(vectorization)
} else {
Expand All @@ -454,6 +483,7 @@ impl Compilation {
self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
position,
);
}
OutputInfo::Array { item } => {
Expand Down Expand Up @@ -483,12 +513,13 @@ impl Compilation {
None => panic!("No output found."),
};

let (item, local) = match output {
OutputInfo::ArrayWrite { item, local } => (item, local),
let (item, local, position) = match output {
OutputInfo::ArrayWrite { item, local, position } => (item, local, position),
OutputInfo::InputArrayWrite {
item: _,
input,
local: _,
position: _,
} => {
assert_eq!(
*input, mapping.pos_input as u16,
Expand Down Expand Up @@ -521,6 +552,7 @@ impl Compilation {
item,
input: mapping.pos_input as u16,
local: *local,
position: *position,
};
}
}
Expand Down
23 changes: 17 additions & 6 deletions crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub struct ReadGlobal {
pub global: Variable,
/// The output variable to write the result.
pub out: Variable,
/// The reference position index.
pub position: Variable,
}

/// Read a global array with the given layout.
Expand All @@ -20,21 +22,24 @@ pub struct ReadGlobalWithLayout {
pub outs: Vec<Variable>,
/// The layout to be used.
pub layout: Variable,
/// The reference position index.
pub position: Variable,
}

impl ReadGlobal {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
scope.register(Operator::Index(BinaryOperator {
lhs: self.global,
rhs: Variable::Id,
rhs: self.position,
out: self.out,
}));
}
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
Self {
global: self.global.vectorize(vectorization),
out: self.out.vectorize(vectorization),
position: self.position,
}
}
}
Expand All @@ -47,6 +52,10 @@ impl ReadGlobalWithLayout {
return None;
}

if self.position != other.position {
return None;
}

let mut globals = Vec::with_capacity(self.globals.len() + other.globals.len());
globals.extend(&self.globals);
globals.extend(&other.globals);
Expand All @@ -59,6 +68,7 @@ impl ReadGlobalWithLayout {
globals,
outs,
layout: self.layout,
position: self.position,
})
}

Expand All @@ -75,7 +85,7 @@ impl ReadGlobalWithLayout {
tensors: tensors.clone(),
layout: self.layout,
indexes: indexes.clone(),
index_ref: Variable::Id,
position: self.position,
dim_start: 0u32.into(),
dim_end: Variable::Rank,
}
Expand Down Expand Up @@ -103,6 +113,7 @@ impl ReadGlobalWithLayout {
.iter()
.map(|o| o.vectorize(vectorization))
.collect(),
position: self.position,
}
}
}
Expand All @@ -117,10 +128,10 @@ pub struct IndexOffsetGlobalWithLayout {
pub indexes: Vec<Variable>,
/// Reference layout.
pub layout: Variable,
/// Index that corresponds to the reference layout.
/// Position index that corresponds to the reference layout.
///
/// All other indexes will be made to be compatible with this one.
pub index_ref: Variable,
pub position: Variable,
pub dim_start: Variable,
pub dim_end: Variable,
}
Expand All @@ -130,7 +141,7 @@ impl IndexOffsetGlobalWithLayout {
pub fn expand(self, scope: &mut Scope) {
let layout = self.layout;
let index_item_ty = Item::Scalar(Elem::UInt);
let offset_ref = self.index_ref;
let offset_ref = self.position;
let zero: Variable = 0u32.into();
let vectorization_factor: Variable = match self.tensors[0].item() {
Item::Vec4(_) => 4u32,
Expand Down Expand Up @@ -187,7 +198,7 @@ impl IndexOffsetGlobalWithLayout {
.map(|t| t.vectorize(vectorization))
.collect(),
layout: self.layout.vectorize(vectorization),
index_ref: self.index_ref.vectorize(vectorization),
position: self.position.vectorize(vectorization),
dim_start: self.dim_start,
dim_end: self.dim_end,
}
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ use serde::{Deserialize, Serialize};
pub struct WriteGlobal {
pub input: Variable,
pub global: Variable,
pub position: Variable,
}

impl WriteGlobal {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
let output = self.global;
let input = self.input;
let position = Variable::Id;
let position = self.position;

gpu!(scope, output[position] = input);
}
Expand All @@ -23,6 +24,7 @@ impl WriteGlobal {
Self {
input: self.input.vectorize(vectorization),
global: self.global.vectorize(vectorization),
position: self.position,
}
}
}
53 changes: 33 additions & 20 deletions crates/burn-jit/src/codegen/dialect/gpu/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ pub struct Scope {
locals: Vec<Variable>,
shared_memories: Vec<Variable>,
local_arrays: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
index_offset_with_output_layout_position: Vec<usize>,
writes_global: Vec<(Variable, Variable)>,
writes_global: Vec<(Variable, Variable, Variable)>,
reads_scalar: Vec<(Variable, Variable)>,
pub layout_ref: Option<Variable>,
undeclared: u16,
Expand Down Expand Up @@ -102,8 +102,13 @@ impl Scope {
/// Reads an input array to a local variable.
///
/// The index refers to the argument position of the array in the compute shader.
pub(crate) fn read_array<I: Into<Item>>(&mut self, index: u16, item: I) -> Variable {
self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout)
pub(crate) fn read_array<I: Into<Item>>(
&mut self,
index: u16,
item: I,
position: Variable,
) -> Variable {
self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout, position)
}

/// Add the procedure into the scope.
Expand Down Expand Up @@ -143,27 +148,31 @@ impl Scope {
self.locals
.iter_mut()
.for_each(|var| *var = var.vectorize(vectorization));
self.reads_global.iter_mut().for_each(|(input, _, output)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.writes_global.iter_mut().for_each(|(input, output)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.reads_global
.iter_mut()
.for_each(|(input, _, output, _position)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.writes_global
.iter_mut()
.for_each(|(input, output, _)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
}

/// Writes a variable to given output.
///
/// Notes:
///
/// This should only be used when doing compilation.
pub(crate) fn write_global(&mut self, input: Variable, output: Variable) {
pub(crate) fn write_global(&mut self, input: Variable, output: Variable, position: Variable) {
// This assumes that all outputs have the same layout
if self.layout_ref.is_none() {
self.layout_ref = Some(output);
}
self.writes_global.push((input, output));
self.writes_global.push((input, output, position));
}

/// Writes a variable to given output.
Expand All @@ -184,10 +193,10 @@ impl Scope {
///
/// This should only be used when doing compilation.
pub(crate) fn update_read(&mut self, index: u16, strategy: ReadingStrategy) {
if let Some((_, strategy_old, _)) = self
if let Some((_, strategy_old, _, _position)) = self
.reads_global
.iter_mut()
.find(|(var, _, _)| var.index() == Some(index))
.find(|(var, _, _, _)| var.index() == Some(index))
{
*strategy_old = strategy;
}
Expand All @@ -197,7 +206,7 @@ impl Scope {
pub(crate) fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> {
self.reads_global
.iter()
.map(|(var, strategy, _)| match var {
.map(|(var, strategy, _, _)| match var {
Variable::GlobalInputArray(id, _) => (*id, *strategy),
_ => panic!("Can only read global input arrays."),
})
Expand Down Expand Up @@ -250,7 +259,7 @@ impl Scope {

let mut operations = Vec::new();

for (input, strategy, local) in self.reads_global.drain(..) {
for (input, strategy, local, position) in self.reads_global.drain(..) {
match strategy {
ReadingStrategy::OutputLayout => {
let output = self.layout_ref.expect(
Expand All @@ -261,13 +270,15 @@ impl Scope {
globals: vec![input],
layout: output,
outs: vec![local],
position,
},
)));
}
ReadingStrategy::Plain => {
operations.push(Operation::Procedure(Procedure::ReadGlobal(ReadGlobal {
global: input,
out: local,
position,
})))
}
}
Expand All @@ -288,10 +299,11 @@ impl Scope {
operations.push(op);
}

for (input, global) in self.writes_global.drain(..) {
for (input, global, position) in self.writes_global.drain(..) {
operations.push(Operation::Procedure(Procedure::WriteGlobal(WriteGlobal {
input,
global,
position,
})))
}

Expand Down Expand Up @@ -323,6 +335,7 @@ impl Scope {
index: u16,
item: Item,
strategy: ReadingStrategy,
position: Variable,
) -> Variable {
let item_global = match item.elem() {
Elem::Bool => match item {
Expand All @@ -336,7 +349,7 @@ impl Scope {
let input = Variable::GlobalInputArray(index, item_global);
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
self.reads_global.push((input, strategy, local));
self.reads_global.push((input, strategy, local, position));
self.locals.push(local);
local
}
Expand Down