Feat/enable cube cl (#1777)
* Ben WIP

* Compile burn-jit

* WGPU works

* Remove old code

* move language cube stuff

* cleaning up

* some import reworking

* remove cube reexport

* template feature flag in cube

* ci

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
louisfd and nathanielsimard committed May 19, 2024
1 parent 9c5b07c commit 499ff0d
Showing 161 changed files with 2,732 additions and 6,159 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

10 changes: 7 additions & 3 deletions crates/burn-cube/Cargo.toml
@@ -14,14 +14,18 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube"
 version.workspace = true
 
 [features]
-default = []
+default = ["tensor"]
 std = []
+template = []
+tensor = ["burn-tensor"]
 
 [dependencies]
 burn-compute = { path = "../burn-compute", version = "0.15.0", default-features = false }
+burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false, optional = true }
 
 bytemuck = { workspace = true }
-half = { workspace = true, features=["bytemuck"] }
-serde = { workspace = true }
+half = { workspace = true, features = ["bytemuck"] }
+serde = { workspace = true }
 burn-cube-macros = { path = "../burn-cube-macros", version = "0.15.0" }
 derive-new = { workspace = true }
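The `tensor` feature is now enabled by default and gates the optional `burn-tensor` dependency. For illustration only (this snippet is not part of the commit), a downstream Cargo.toml could opt out of the new default, or re-enable features explicitly:

[dependencies]
# Build burn-cube without pulling in burn-tensor:
burn-cube = { version = "0.15.0", default-features = false }
# Or select features explicitly:
# burn-cube = { version = "0.15.0", default-features = false, features = ["tensor", "template"] }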
117 changes: 1 addition & 116 deletions crates/burn-cube/src/codegen/compilation.rs
@@ -1,6 +1,3 @@
-#[cfg(feature = "fusion")]
-use crate::fusion::JitFusionHandle;
-
 use super::Compiler;
 use crate::{
     codegen::dialect::{
@@ -43,7 +40,7 @@ pub struct CompilationSettings {
     pub mappings: Vec<InplaceMapping>,
     vectorization: Option<Vectorization>,
     workgroup_size: WorkgroupSize,
-    reading_strategy: Vec<(u16, ReadingStrategy)>,
+    pub reading_strategy: Vec<(u16, ReadingStrategy)>,
 }
 
 impl core::fmt::Display for CompilationSettings {
@@ -119,118 +116,6 @@ impl CompilationSettings {
         self
     }
 
-    #[cfg(feature = "fusion")]
-    /// Apply dynamic settings based on the runtime information captured by the `burn-fusion`
-    /// project.
-    ///
-    /// Two optimizations are done here:
-    ///
-    /// 1. Find and remove unnecessary broadcasting procedures based on runtime tensor layouts.
-    ///
-    /// 2. (Optional) Find which inputs can be used inplaced based on runtime tensor layouts and captured tensor
-    /// descriptions. This is enabled only when stateful is set to true.
-    pub fn dynamic_settings<R: Runtime>(
-        self,
-        info: &CompilationInfo,
-        inputs: &[&burn_tensor::repr::TensorDescription],
-        outputs: &[&burn_tensor::repr::TensorDescription],
-        handles_inputs: &[JitFusionHandle<R>],
-        stateful: bool,
-    ) -> Self {
-        let mut settings = self;
-
-        if stateful {
-            settings = settings.dynamic_inplace(info, inputs, outputs, handles_inputs);
-        }
-
-        settings.dynamic_reading_strategy(info, inputs, outputs, handles_inputs)
-    }
-
-    #[cfg(feature = "fusion")]
-    fn dynamic_inplace<R: Runtime>(
-        self,
-        info: &CompilationInfo,
-        inputs: &[&burn_tensor::repr::TensorDescription],
-        outputs: &[&burn_tensor::repr::TensorDescription],
-        handles_inputs: &[JitFusionHandle<R>],
-    ) -> Self {
-        let mut potential_inplace = inputs
-            .iter()
-            .zip(info.inputs.iter())
-            .enumerate()
-            .filter_map(|(pos, (desc, input))| {
-                match desc.status {
-                    burn_tensor::repr::TensorStatus::ReadOnly => return None,
-                    burn_tensor::repr::TensorStatus::NotInit => return None,
-                    burn_tensor::repr::TensorStatus::ReadWrite => (),
-                };
-
-                let handle = &handles_inputs[pos];
-
-                if handle.handle.can_mut() && is_contiguous(&handle.strides) {
-                    Some((pos, desc, input))
-                } else {
-                    None
-                }
-            })
-            .collect::<Vec<_>>();
-
-        let mappings = outputs
-            .iter()
-            .zip(info.outputs.iter())
-            .enumerate()
-            .filter_map(|(pos, (desc, output))| {
-                if potential_inplace.is_empty() {
-                    return None;
-                }
-
-                for (index, (_, desc_input, input)) in potential_inplace.iter().enumerate() {
-                    if desc.shape == desc_input.shape && input.item() == output.item() {
-                        let (pos_input, _desc, _info) = potential_inplace.remove(index);
-                        return Some(InplaceMapping::new(pos_input, pos));
-                    }
-                }
-
-                None
-            })
-            .collect();
-
-        self.inplace(mappings)
-    }
-
-    #[cfg(feature = "fusion")]
-    fn dynamic_reading_strategy<R: Runtime>(
-        mut self,
-        info: &CompilationInfo,
-        inputs: &[&burn_tensor::repr::TensorDescription],
-        outputs: &[&burn_tensor::repr::TensorDescription],
-        handles_inputs: &[JitFusionHandle<R>],
-    ) -> Self {
-        // First output is chosen for the layout reference.
-        // but all outputs should have the same shape anyways.
-        let layout_shape = &outputs[0].shape;
-
-        for (input_id, strategy) in info.scope.read_globals() {
-            if let ReadingStrategy::Plain = strategy {
-                continue;
-            };
-
-            let index = input_id as usize;
-            let handle = &handles_inputs[index];
-            let description_input = &inputs[index];
-
-            if &description_input.shape != layout_shape {
-                continue;
-            }
-
-            if is_contiguous(&handle.strides) {
-                self.reading_strategy
-                    .push((input_id, ReadingStrategy::Plain));
-            }
-        }
-        self
-    }
-
     /// Set the grid size.
     #[allow(dead_code)] // Only used for fusion for now.
     pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
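Both removed helpers gate their optimizations on `is_contiguous(&handle.strides)`. For context, a minimal standalone sketch of a row-major contiguity test; the signature and exact behavior here are assumptions (the helper in the deleted code takes only the strides), not the project's implementation:

// Hypothetical contiguity check: every stride must equal the product of
// the dimensions to its right, i.e. the tensor is dense and row-major.
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    let mut expected = 1;
    for (dim, stride) in shape.iter().zip(strides.iter()).rev() {
        if *stride != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

fn main() {
    // A dense [2, 3] tensor has strides [3, 1] and passes the check,
    assert!(is_contiguous(&[2, 3], &[3, 1]));
    // while its transposed view ([3, 2] with strides [1, 3]) fails it.
    assert!(!is_contiguous(&[3, 2], &[1, 3]));
}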
16 changes: 7 additions & 9 deletions crates/burn-cube/src/codegen/dialect/scope.rs
@@ -1,5 +1,3 @@
-use crate::PrimitiveVariable;
-
 use super::{
     cpa, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Operation, Operator,
     Procedure, ReadGlobal, ReadGlobalWithLayout, UnaryOperator, Variable, Vectorization,
@@ -69,13 +67,13 @@ impl Scope {
     }
 
     /// Create a variable initialized at some value.
-    pub fn create_with_value<E: PrimitiveVariable, I: Into<Item> + Copy>(
-        &mut self,
-        value: E,
-        item: I,
-    ) -> Variable {
+    pub fn create_with_value<E, I>(&mut self, value: E, item: I) -> Variable
+    where
+        E: Into<f64>,
+        I: Into<Item> + Copy,
+    {
         let local = self.create_local(item);
-        let value = Variable::ConstantScalar(value.to_f64(), item.into().elem());
+        let value = Variable::ConstantScalar(value.into(), item.into().elem());
         cpa!(self, local = value);
         local
     }
@@ -203,7 +201,7 @@
     }
 
     #[allow(dead_code)]
-    pub(crate) fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> {
+    pub fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> {
         self.reads_global
             .iter()
             .map(|(var, strategy, _, _)| match var {
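The `create_with_value` rewrite swaps the crate-local `PrimitiveVariable` bound for a plain `Into<f64>`, so any standard numeric type can seed a constant without implementing a burn trait. A standalone sketch of that pattern, with stand-in types rather than burn's API:

// Stand-in for the constant-scalar variable; not burn's actual type.
#[derive(Debug)]
struct ConstantScalar(f64);

// Before, `E` needed a crate trait exposing `to_f64`; now any type with a
// standard lossless conversion to f64 is accepted.
fn create_with_value<E: Into<f64>>(value: E) -> ConstantScalar {
    ConstantScalar(value.into())
}

fn main() {
    // f32, i32, and u8 all implement Into<f64> out of the box.
    println!("{:?}", create_with_value(1.5f32));
    println!("{:?}", create_with_value(42i32));
    println!("{:?}", create_with_value(7u8));
}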
37 changes: 19 additions & 18 deletions crates/burn-cube/src/codegen/dialect/shader.rs
@@ -48,24 +48,25 @@ impl From<Elem> for Item {
     }
 }
 
-// impl From<DType> for Elem {
-//     fn from(dtype: DType) -> Self {
-//         match dtype {
-//             DType::F64 => Elem::Float(FloatKind::F64),
-//             DType::F32 => Elem::Float(FloatKind::F32),
-//             DType::F16 => Elem::Float(FloatKind::F16),
-//             DType::BF16 => Elem::Float(FloatKind::BF16),
-//             DType::I64 => Elem::Int(IntKind::I64),
-//             DType::I32 => Elem::Int(IntKind::I32),
-//             DType::I16 => panic!("i16 isn't supported yet."),
-//             DType::I8 => panic!("i8 isn't supported yet."),
-//             DType::U64 => Elem::UInt,
-//             DType::U32 => Elem::UInt,
-//             DType::U8 => panic!("u8 isn't supported yet."),
-//             DType::Bool => Elem::Bool,
-//         }
-//     }
-// }
+#[cfg(feature = "tensor")]
+impl From<burn_tensor::DType> for Elem {
+    fn from(dtype: burn_tensor::DType) -> Self {
+        match dtype {
+            burn_tensor::DType::F64 => Elem::Float(FloatKind::F64),
+            burn_tensor::DType::F32 => Elem::Float(FloatKind::F32),
+            burn_tensor::DType::F16 => Elem::Float(FloatKind::F16),
+            burn_tensor::DType::BF16 => Elem::Float(FloatKind::BF16),
+            burn_tensor::DType::I64 => Elem::Int(IntKind::I64),
+            burn_tensor::DType::I32 => Elem::Int(IntKind::I32),
+            burn_tensor::DType::I16 => panic!("i16 isn't supported yet."),
+            burn_tensor::DType::I8 => panic!("i8 isn't supported yet."),
+            burn_tensor::DType::U64 => Elem::UInt,
+            burn_tensor::DType::U32 => Elem::UInt,
+            burn_tensor::DType::U8 => panic!("u8 isn't supported yet."),
+            burn_tensor::DType::Bool => Elem::Bool,
+        }
+    }
+}
 
 impl Display for Elem {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
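This revives the previously commented-out conversion behind the new `tensor` feature, so `Elem` only knows about `burn_tensor::DType` when the optional dependency is present. A minimal standalone illustration of the cfg-gated `From` pattern, using mock enums rather than burn's types:

// Mock types standing in for burn_tensor::DType and the codegen Elem.
pub enum DType { F32, I32, Bool }
pub enum Elem { Float, Int, Bool }

// Compiled only when the crate is built with `--features tensor`; without
// the feature, neither this impl nor the optional dependency is needed.
#[cfg(feature = "tensor")]
impl From<DType> for Elem {
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::F32 => Elem::Float,
            DType::I32 => Elem::Int,
            DType::Bool => Elem::Bool,
        }
    }
}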
2 changes: 1 addition & 1 deletion crates/burn-cube/src/codegen/dialect/variable.rs
@@ -33,7 +33,7 @@ pub enum Variable {
 }
 
 impl Variable {
-    pub(crate) fn index(&self) -> Option<u16> {
+    pub fn index(&self) -> Option<u16> {
         match self {
             Variable::GlobalInputArray(idx, _) => Some(*idx),
             Variable::GlobalScalar(idx, _) => Some(*idx),
8 changes: 4 additions & 4 deletions crates/burn-cube/src/codegen/kernel.rs
@@ -330,9 +330,9 @@ fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeE
         Elem::Bool => panic!("Bool scalars are not supported"),
     };
     let scalar_priorities: [usize; 3] = [
-        element_priority(E1::elem()),
-        element_priority(E2::elem()),
-        element_priority(E3::elem()),
+        element_priority(E1::cube_elem()),
+        element_priority(E2::cube_elem()),
+        element_priority(E3::cube_elem()),
     ];
 
     let mut handles_scalars = Vec::new();
@@ -359,7 +359,7 @@ fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeE
     handles_scalars
 }
 
-pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
+pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
     let mut num_elems = 1;
     for i in shape.iter() {
         num_elems *= i;
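The newly public `calculate_num_elems_dyn_rank` is a plain product over the shape, independent of rank. A standalone sketch of the same computation with a couple of checks:

// Same computation as the helper above: a tensor's element count is the
// product of its dimensions, whatever the rank.
fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    assert_eq!(calculate_num_elems_dyn_rank(&[2, 3, 4]), 24);
    // The empty shape (a rank-0 scalar) yields 1, the empty product.
    assert_eq!(calculate_num_elems_dyn_rank(&[]), 1);
}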
File renamed with changes:
@@ -1,7 +1,7 @@
 use std::{ops::Deref, rc::Rc};
 
 use crate::dialect::{Branch, Elem, If, IfElse, Item, Loop, RangeLoop, Variable};
-use crate::{CubeContext, ExpandElement, UInt};
+use crate::language::{CubeContext, ExpandElement, UInt};
 
 pub fn range<S, E>(start: S, end: E, _unroll: bool) -> core::ops::Range<usize>
 where
File renamed with changes:
@@ -1,5 +1,5 @@
 use crate::dialect::{Item, Operation, Scope};
-use crate::ExpandElement;
+use crate::language::ExpandElement;
 use alloc::rc::Rc;
 use core::cell::RefCell;
 use std::collections::HashMap;
File renamed with changes:
@@ -1,4 +1,4 @@
-use crate::{CubeType, ExpandElement};
+use crate::language::{CubeType, ExpandElement};
 
 #[derive(new, Clone)]
 pub struct Array<E> {
File renamed without changes.
File renamed with changes:
@@ -1,6 +1,6 @@
 use crate::dialect::Elem;
 
-use crate::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
+use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
 
 #[derive(Clone, Copy)]
 /// Boolean type for kernels
File renamed with changes:
@@ -1,4 +1,5 @@
-use crate::{assign, dialect::Item, CubeContext, CubeType, PrimitiveVariable};
+use crate::dialect::Item;
+use crate::language::{assign, CubeContext, CubeType, PrimitiveVariable};
 
 // Enable elegant casting from any to any primitive variable
 
File renamed with changes:
@@ -1,5 +1,5 @@
 use crate::dialect::{Elem, FloatKind, Variable};
-use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
+use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
 use std::rc::Rc;
 
 /// Floating point numbers. Used as input in float kernels
File renamed with changes:
@@ -1,5 +1,5 @@
 use crate::dialect::{Elem, IntKind, Variable};
-use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
+use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
 use std::rc::Rc;
 
 /// Signed integer. Used as input in int kernels
File renamed without changes.
File renamed with changes:
@@ -1,7 +1,7 @@
+use crate::dialect::Variable;
+use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
 use std::rc::Rc;
 
-use crate::{dialect::Variable, CubeContext, CubeType, ExpandElement, PrimitiveVariable};
-
 /// Type that encompasses both (unsigned or signed) integers and floats
 /// Used in kernels that should work for both.
 pub trait Numeric:
File renamed with changes:
@@ -1,8 +1,7 @@
 use std::rc::Rc;
 
 use crate::dialect::{Elem, Variable};
-
-use crate::{CubeType, ExpandElement};
+use crate::language::{CubeType, ExpandElement};
 
 /// Form of CubeType that encapsulates all primitive types:
 /// Numeric, UInt, Bool
File renamed with changes:
@@ -1,4 +1,4 @@
-use crate::CubeType;
+use crate::language::CubeType;
 
 /// Types that exist within a cube function
 /// but should not be turned into a JIT variable
File renamed with changes:
@@ -1,8 +1,7 @@
 use std::rc::Rc;
 
 use crate::dialect::{Elem, Variable};
-
-use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
+use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
 
 #[derive(Clone, Copy)]
 /// An unsigned int.
9 changes: 9 additions & 0 deletions crates/burn-cube/src/language/mod.rs
@@ -0,0 +1,9 @@
+// For use with *
+pub mod branch;
+mod context;
+mod element;
+mod operation;
+
+pub use context::*;
+pub use element::*;
+pub use operation::*;
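This new module is the facade behind the `crate::language::{…}` imports seen throughout the diff: private submodules are declared once and re-exported, so callers use a single path. A standalone sketch of the re-export pattern, with toy modules rather than burn's:

mod language {
    pub mod branch {
        pub fn range(start: usize, end: usize) -> core::ops::Range<usize> {
            start..end
        }
    }
    mod context {
        pub struct CubeContext;
    }
    // Re-exporting the private submodule flattens the public path.
    pub use self::context::*;
}

fn main() {
    // Everything the facade exposes is reachable through `language::…`.
    let _ctx = language::CubeContext;
    for _i in language::branch::range(0, 3) {}
}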
