From 2142730edd67007a35d0dbf1d22cacef4ba0fc25 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 9 Apr 2024 14:40:44 -0400 Subject: [PATCH 01/54] WIP --- Cargo.lock | 19 +++ crates/burn-cube-macros/Cargo.toml | 24 ++++ crates/burn-cube-macros/src/lib.rs | 116 ++++++++++++++++++ crates/burn-cube/Cargo.toml | 21 ++++ crates/burn-cube/src/lib.rs | 96 +++++++++++++++ crates/burn-cube/tests/cube.rs | 27 ++++ .../burn-jit/src/codegen/dialect/gpu/scope.rs | 10 +- .../src/codegen/dialect/gpu/shader.rs | 2 +- .../src/codegen/dialect/gpu/variable.rs | 2 +- crates/burn-jit/src/codegen/mod.rs | 2 + 10 files changed, 312 insertions(+), 7 deletions(-) create mode 100644 crates/burn-cube-macros/Cargo.toml create mode 100644 crates/burn-cube-macros/src/lib.rs create mode 100644 crates/burn-cube/Cargo.toml create mode 100644 crates/burn-cube/src/lib.rs create mode 100644 crates/burn-cube/tests/cube.rs diff --git a/Cargo.lock b/Cargo.lock index a149a10c80..e140d33285 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -446,6 +446,25 @@ dependencies = [ "thiserror", ] +[[package]] +name = "burn-cube" +version = "0.13.0" +dependencies = [ + "burn-cube-macros", + "burn-jit", + "log", +] + +[[package]] +name = "burn-cube-macros" +version = "0.13.0" +dependencies = [ + "derive-new", + "proc-macro2", + "quote", + "syn 2.0.55", +] + [[package]] name = "burn-dataset" version = "0.13.0" diff --git a/crates/burn-cube-macros/Cargo.toml b/crates/burn-cube-macros/Cargo.toml new file mode 100644 index 0000000000..06565dec1d --- /dev/null +++ b/crates/burn-cube-macros/Cargo.toml @@ -0,0 +1,24 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "TODO" +edition.workspace = true +keywords = [] +license.workspace = true +name = "burn-cube-macros" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube-macros" +version.workspace = true + +[lib] +proc-macro = true + +[features] +default = [] +std = [] + +[dependencies] +proc-macro2 = { 
workspace = true } +quote = { workspace = true } +syn = { workspace = true } +derive-new = { workspace = true } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs new file mode 100644 index 0000000000..f4a901e375 --- /dev/null +++ b/crates/burn-cube-macros/src/lib.rs @@ -0,0 +1,116 @@ +use proc_macro::TokenStream; + +/// Derive macro for the module. +#[proc_macro_attribute] +pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { + let func: syn::ItemFn = syn::parse(tokens).unwrap(); + let signature = expand_sig(&func.sig); + let mut body = quote::quote! {}; + + for statement in func.block.stmts.iter() { + let mut statement_gen = quote::quote! {}; + let mut skipped = false; + match statement { + syn::Stmt::Expr(expr, _) => match expr { + syn::Expr::Binary(binary) => { + let lhs = &binary.left; + let rhs = &binary.right; + + match binary.op { + syn::BinOp::Add(_add) => { + skipped = true; + statement_gen.extend(quote::quote! { + burn_cube::float_add_expand(context, #lhs, #rhs) + }); + } + _ => (), + } + } + _ => (), + }, + syn::Stmt::Local(local) => { + skipped = true; + let output = &local.pat; + match &output { + syn::Pat::Const(_) => todo!(), + syn::Pat::Ident(ident) => panic!("path ident {:?}", ident), + syn::Pat::Lit(_) => todo!(), + syn::Pat::Macro(_) => todo!(), + syn::Pat::Or(_) => todo!(), + syn::Pat::Paren(_) => todo!(), + syn::Pat::Path(_) => todo!(), + syn::Pat::Range(_) => todo!(), + syn::Pat::Reference(_) => todo!(), + syn::Pat::Rest(_) => todo!(), + syn::Pat::Slice(_) => todo!(), + syn::Pat::Struct(_) => todo!(), + syn::Pat::Tuple(_) => todo!(), + syn::Pat::TupleStruct(_) => todo!(), + syn::Pat::Type(_) => todo!(), + syn::Pat::Verbatim(_) => todo!(), + syn::Pat::Wild(_) => todo!(), + _ => todo!(), + } + } + syn::Stmt::Item(_) => panic!("item"), + syn::Stmt::Macro(_) => panic!("Macros not supported."), + }; + + if !skipped { + body.extend(quote::quote! 
{ + #statement + }); + } else { + body.extend(statement_gen); + } + } + + let code = quote::quote! { + #func + + #signature { + #body + } + } + .into(); + + // panic!("{code}"); + code +} + +fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { + let mut inputs = quote::quote!(); + + for input in &sig.inputs { + match input { + syn::FnArg::Typed(pat) => { + let ty = &pat.ty; + let ident = pat.pat.clone(); + + inputs.extend(quote::quote! { + #ident: <#ty as burn_cube::CubeVariable>::Variable, + }); + } + _ => todo!(), + } + } + + let mut output = quote::quote!(); + + match &sig.output { + syn::ReturnType::Default => output.extend(quote::quote! { ()}), + syn::ReturnType::Type(_, ty) => { + output.extend(quote::quote! { + <#ty as burn_cube::CubeVariable>::Variable + }); + } + } + + let ident = &sig.ident; + let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); + + quote::quote! { + pub fn #ident(context: &mut burn_cube::CodegenContext<'_>, #inputs) -> #output + } + .into() +} diff --git a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml new file mode 100644 index 0000000000..977b877ff8 --- /dev/null +++ b/crates/burn-cube/Cargo.toml @@ -0,0 +1,21 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "TODO" +edition.workspace = true +keywords = [] +license.workspace = true +name = "burn-cube" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube" +version.workspace = true + +[features] +default = [] +std = [] + +[dependencies] +burn-jit = { path = "../burn-jit", version = "0.13.0", default-features = false, features= ["autotune"] } +burn-cube-macros = { path = "../burn-cube-macros", version = "0.13.0" } + +log = { workspace = true } diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs new file mode 100644 index 0000000000..4705ba94e0 --- /dev/null +++ b/crates/burn-cube/src/lib.rs @@ -0,0 +1,96 @@ +extern crate alloc; +pub use 
burn_cube_macros::cube; + +use alloc::sync::Arc; +use burn_jit::gpu::{self, Item, Scope, Variable}; +use std::collections::HashMap; + +pub struct CodegenContext<'a> { + pub scope: &'a mut Scope, + pub pool: HashMap>, +} + +impl<'a> CodegenContext<'a> { + pub fn crate_float(&mut self, item: Item) -> FloatVariable { + let var = self.create_local(item); + FloatVariable { var } + } + pub fn create_local(&mut self, item: Item) -> Arc { + for variable in self.pool.get(&item).iter() { + if Arc::strong_count(variable) == 1 { + return Arc::clone(variable); + } + } + + let new = Arc::new(self.scope.create_local(item)); + self.pool.insert(item, new.clone()); + + new + } +} + +#[derive(Copy, Clone)] +pub struct Float { + pub kind: u32, + pub val: f32, + pub vectorization: u8, +} + +impl core::ops::Add for Float { + type Output = Self; + + fn add(self, _rhs: Self) -> Self::Output { + panic!("Only used for types"); + } +} + +#[derive(Clone)] +pub struct FloatVariable { + var: Arc, +} + +impl CubeVariable for Float { + type Variable = FloatVariable; +} + +pub trait CubeVariable { + type Variable: Clone; +} + +pub fn float_add_expand( + context: &mut CodegenContext<'_>, + lhs: FloatVariable, + rhs: FloatVariable, +) -> FloatVariable { + let item = lhs.var.item(); + let out = context.create_local(item); + let out = FloatVariable { var: out }; + + let op = gpu::Operator::Add(gpu::BinaryOperator { + lhs: *lhs.var, + rhs: *rhs.var, + out: *out.var, + }); + + context.scope.register(op); + + out +} + +pub fn float_assign_expand( + context: &mut CodegenContext<'_>, + input: FloatVariable, +) -> FloatVariable { + let item = input.var.item(); + let out = context.create_local(item); + let out = FloatVariable { var: out }; + + let op = gpu::Operator::Assign(gpu::UnaryOperator { + input: *input.var, + out: *out.var, + }); + + context.scope.register(op); + + out +} diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs new file mode 100644 index 0000000000..5de36de541 --- 
/dev/null +++ b/crates/burn-cube/tests/cube.rs @@ -0,0 +1,27 @@ +use burn_cube::{cube, CodegenContext, Float}; +use burn_jit::gpu::{Elem, Item, Scope}; + +#[cube] +pub fn kernel(lhs: Float, rhs: Float) -> Float { + let mut output = lhs; + + for i in 0..10 { + output = lhs + rhs + } + + output +} + +#[test] +fn test_simple_add() { + let mut scope = Scope::root(); + let mut context = CodegenContext { + scope: &mut scope, + pool: Default::default(), + }; + + let lhs = context.crate_float(Item::Vec4(Elem::Float)); + let rhs = context.crate_float(Item::Vec4(Elem::Float)); + + kernel_expand(&mut context, lhs, rhs); +} diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index 585c59a601..d3d4220845 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -44,7 +44,7 @@ impl Scope { /// [compute shader](crate::codegen::dialect::gpu::ComputeShader). /// /// A local scope can be created with the [child](Self::child) method. - pub(crate) fn root() -> Self { + pub fn root() -> Self { Self { depth: 0, operations: Vec::new(), @@ -69,7 +69,7 @@ impl Scope { } /// Create a variable initialized at some value. - pub(crate) fn create_with_value + Copy>( + pub fn create_with_value + Copy>( &mut self, value: E, item: I, @@ -81,7 +81,7 @@ impl Scope { } /// Create a local variable of the given [item type](Item). - pub(crate) fn create_local>(&mut self, item: I) -> Variable { + pub fn create_local>(&mut self, item: I) -> Variable { let item = item.into(); let index = self.new_local_index(); let local = Variable::Local(index, item, self.depth); @@ -205,12 +205,12 @@ impl Scope { } /// Register an [operation](Operation) into the scope. - pub(crate) fn register>(&mut self, operation: T) { + pub fn register>(&mut self, operation: T) { self.operations.push(operation.into()) } /// Create an empty child scope. 
- pub(crate) fn child(&mut self) -> Self { + pub fn child(&mut self) -> Self { Self { depth: self.depth + 1, operations: Vec::new(), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs index 7fbba14a31..99eb3e9f33 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs @@ -43,7 +43,7 @@ impl Display for Elem { } } -#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)] #[allow(missing_docs)] pub enum Item { Vec4(Elem), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs index cd52ea4134..d2015cfd98 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs @@ -63,7 +63,7 @@ impl Variable { Variable::NumWorkgroupsZ => None, } } - pub(crate) fn item(&self) -> Item { + pub fn item(&self) -> Item { match self { Variable::GlobalInputArray(_, item) => *item, Variable::GlobalOutputArray(_, item) => *item, diff --git a/crates/burn-jit/src/codegen/mod.rs b/crates/burn-jit/src/codegen/mod.rs index cc9ca47404..964b2eca03 100644 --- a/crates/burn-jit/src/codegen/mod.rs +++ b/crates/burn-jit/src/codegen/mod.rs @@ -7,3 +7,5 @@ mod kernel; pub(crate) use compilation::*; pub(crate) use compiler::*; pub(crate) use kernel::*; + +pub use dialect::*; From 94f476f4df6419b68e1a5b5ff50cac85da9f3772 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 10 Apr 2024 12:40:28 -0400 Subject: [PATCH 02/54] WIP --- Cargo.lock | 1 + crates/burn-cube-macros/src/lib.rs | 66 +------ crates/burn-cube-macros/src/statement.rs | 66 +++++++ crates/burn-cube/Cargo.toml | 1 + crates/burn-cube/src/branch.rs | 60 ++++++ crates/burn-cube/src/context.rs | 70 +++++++ crates/burn-cube/src/element.rs | 48 +++++ crates/burn-cube/src/lib.rs | 102 ++--------- 
crates/burn-cube/src/operation.rs | 173 ++++++++++++++++++ crates/burn-cube/tests/cube.rs | 24 +-- crates/burn-fusion/src/tensor.rs | 3 +- .../burn-jit/src/codegen/dialect/gpu/scope.rs | 2 +- 12 files changed, 448 insertions(+), 168 deletions(-) create mode 100644 crates/burn-cube-macros/src/statement.rs create mode 100644 crates/burn-cube/src/branch.rs create mode 100644 crates/burn-cube/src/context.rs create mode 100644 crates/burn-cube/src/element.rs create mode 100644 crates/burn-cube/src/operation.rs diff --git a/Cargo.lock b/Cargo.lock index e140d33285..78a429dd29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -452,6 +452,7 @@ version = "0.13.0" dependencies = [ "burn-cube-macros", "burn-jit", + "derive-new", "log", ] diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index f4a901e375..2c8a83cbb3 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,3 +1,6 @@ +mod statement; +use statement::parse_statement; + use proc_macro::TokenStream; /// Derive macro for the module. @@ -8,61 +11,8 @@ pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { let mut body = quote::quote! {}; for statement in func.block.stmts.iter() { - let mut statement_gen = quote::quote! {}; - let mut skipped = false; - match statement { - syn::Stmt::Expr(expr, _) => match expr { - syn::Expr::Binary(binary) => { - let lhs = &binary.left; - let rhs = &binary.right; - - match binary.op { - syn::BinOp::Add(_add) => { - skipped = true; - statement_gen.extend(quote::quote! 
{ - burn_cube::float_add_expand(context, #lhs, #rhs) - }); - } - _ => (), - } - } - _ => (), - }, - syn::Stmt::Local(local) => { - skipped = true; - let output = &local.pat; - match &output { - syn::Pat::Const(_) => todo!(), - syn::Pat::Ident(ident) => panic!("path ident {:?}", ident), - syn::Pat::Lit(_) => todo!(), - syn::Pat::Macro(_) => todo!(), - syn::Pat::Or(_) => todo!(), - syn::Pat::Paren(_) => todo!(), - syn::Pat::Path(_) => todo!(), - syn::Pat::Range(_) => todo!(), - syn::Pat::Reference(_) => todo!(), - syn::Pat::Rest(_) => todo!(), - syn::Pat::Slice(_) => todo!(), - syn::Pat::Struct(_) => todo!(), - syn::Pat::Tuple(_) => todo!(), - syn::Pat::TupleStruct(_) => todo!(), - syn::Pat::Type(_) => todo!(), - syn::Pat::Verbatim(_) => todo!(), - syn::Pat::Wild(_) => todo!(), - _ => todo!(), - } - } - syn::Stmt::Item(_) => panic!("item"), - syn::Stmt::Macro(_) => panic!("Macros not supported."), - }; - - if !skipped { - body.extend(quote::quote! { - #statement - }); - } else { - body.extend(statement_gen); - } + let tokens = parse_statement(statement); + body.extend(tokens); } let code = quote::quote! { @@ -88,7 +38,7 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let ident = pat.pat.clone(); inputs.extend(quote::quote! { - #ident: <#ty as burn_cube::CubeVariable>::Variable, + #ident: <#ty as burn_cube::RuntimeType>::ExpandType, }); } _ => todo!(), @@ -101,7 +51,7 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { syn::ReturnType::Default => output.extend(quote::quote! { ()}), syn::ReturnType::Type(_, ty) => { output.extend(quote::quote! { - <#ty as burn_cube::CubeVariable>::Variable + <#ty as burn_cube::RuntimeType>::ExpandType }); } } @@ -110,7 +60,7 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); quote::quote! 
{ - pub fn #ident(context: &mut burn_cube::CodegenContext<'_>, #inputs) -> #output + pub fn #ident(context: &mut burn_cube::CubeContext, #inputs) -> #output } .into() } diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs new file mode 100644 index 0000000000..993daab9dc --- /dev/null +++ b/crates/burn-cube-macros/src/statement.rs @@ -0,0 +1,66 @@ +use proc_macro2::TokenStream; + +pub fn parse_statement(statement: &syn::Stmt) -> TokenStream { + match statement { + syn::Stmt::Local(local) => parse_local(local), + syn::Stmt::Item(_) => todo!(), + syn::Stmt::Expr(expr, _) => parse_expr(expr), + syn::Stmt::Macro(_) => todo!(), + } +} + +fn parse_local(local: &syn::Local) -> TokenStream { + let init = local + .init + .as_ref() + .expect("Can't use let without an initialization."); + let ident = match &local.pat { + syn::Pat::Ident(ident) => &ident.ident, + _ => panic!("Only ident declaration is supported."), + }; + let init = parse_expr(&init.expr); + + quote::quote! { + let #ident = #init; + } +} + +fn parse_expr(expr: &syn::Expr) -> TokenStream { + match expr { + syn::Expr::Binary(op) => parse_binary(op), + syn::Expr::Path(path) => parse_path(path), + _ => panic!("Unsupported {:?}", expr), + } +} + +fn parse_path(path: &syn::ExprPath) -> TokenStream { + let ident = path + .path + .get_ident() + .expect("Only ident path are supported."); + + quote::quote! { + #ident + } +} + +fn parse_binary(binary: &syn::ExprBinary) -> TokenStream { + let lhs = &binary.left; + let rhs = &binary.right; + + match binary.op { + syn::BinOp::Add(_) => quote::quote! { + burn_cube::add::expand(context, #lhs, #rhs) + }, + syn::BinOp::Sub(_) => quote::quote! { + burn_cube::sub::expand(context, #lhs, #rhs) + }, + syn::BinOp::Mul(_) => quote::quote! { + burn_cube::mul::expand(context, #lhs, #rhs) + }, + syn::BinOp::Div(_) => quote::quote! 
{ + burn_cube::div::expand(context, #lhs, #rhs) + }, + _ => todo!("{:?}", binary.op), + } +} diff --git a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml index 977b877ff8..f78beebd48 100644 --- a/crates/burn-cube/Cargo.toml +++ b/crates/burn-cube/Cargo.toml @@ -17,5 +17,6 @@ std = [] [dependencies] burn-jit = { path = "../burn-jit", version = "0.13.0", default-features = false, features= ["autotune"] } burn-cube-macros = { path = "../burn-cube-macros", version = "0.13.0" } +derive-new = { workspace = true } log = { workspace = true } diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs new file mode 100644 index 0000000000..1455eaae05 --- /dev/null +++ b/crates/burn-cube/src/branch.rs @@ -0,0 +1,60 @@ +use crate::{CubeContext, ExpandElement, UInt}; +use burn_jit::gpu::{self, Variable}; + +pub fn range(start: S, end: E, _unroll: bool, func: F) +where + S: Into, + E: Into, + F: Fn(UInt), +{ + let start: UInt = start.into(); + let end: UInt = end.into(); + + for i in start.val..end.val { + func(UInt::new(i, 1)) + } +} + +pub mod for_each { + use burn_jit::gpu::{Branch, Elem, Item}; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + start: ExpandElement, + end: ExpandElement, + unroll: bool, + func: F, + ) where + F: Fn(&mut CubeContext, Variable), + { + if unroll { + let start = match start.as_ref() { + Variable::ConstantScalar(val, _) => *val as usize, + _ => panic!("Only constant start can be unrolled."), + }; + let end = match end.as_ref() { + Variable::ConstantScalar(val, _) => *val as usize, + _ => panic!("Only constant end can be unrolled."), + }; + + for i in start..end { + func(context, i.into()) + } + } else { + let mut child = context.child(); + let index_ty = Item::Scalar(Elem::UInt); + let i = child.scope.create_local_undeclared(index_ty); + + func(&mut child, i); + + context.scope.register(Branch::RangeLoop(gpu::RangeLoop { + i, + start: *start, + end: *end, + scope: child.scope, + })); + } + } +} diff 
--git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs new file mode 100644 index 0000000000..39208bae53 --- /dev/null +++ b/crates/burn-cube/src/context.rs @@ -0,0 +1,70 @@ +use crate::ExpandElement; +use burn_jit::gpu::{Item, Scope, Variable}; +use std::{collections::HashMap, sync::Arc}; + +#[derive(Default, Clone)] +pub struct VariablePool { + map: std::rc::Rc>>>>, +} + +impl VariablePool { + pub fn reuse(&self, item: Item) -> Option { + let map = self.map.borrow(); + + let variables = match map.get(&item) { + Some(val) => val, + None => return None, + }; + + for variable in variables.iter() { + if Arc::strong_count(variable) == 1 { + return Some(Arc::clone(variable)); + } + } + + None + } + pub fn insert(&mut self, var: ExpandElement) { + let mut map = self.map.borrow_mut(); + let item = var.item(); + + if let Some(variables) = map.get_mut(&item) { + variables.push(var.clone()); + } else { + map.insert(var.item(), vec![var.clone()]); + } + } +} + +pub struct CubeContext { + pub scope: Scope, + pub pool: VariablePool, +} + +impl CubeContext { + pub fn root() -> CubeContext { + Self { + pool: Default::default(), + scope: Scope::root(), + } + } + pub fn child(&mut self) -> CubeContext { + let scope = self.scope.child(); + + Self { + scope, + pool: self.pool.clone(), + } + } + + pub fn create_local(&mut self, item: Item) -> ExpandElement { + if let Some(var) = self.pool.reuse(item) { + return var; + } + + let new = Arc::new(self.scope.create_local(item)); + self.pool.insert(new.clone()); + + new + } +} diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs new file mode 100644 index 0000000000..2326074785 --- /dev/null +++ b/crates/burn-cube/src/element.rs @@ -0,0 +1,48 @@ +use burn_jit::gpu::Variable; +use std::sync::Arc; + +pub trait RuntimeType { + type ExpandType: Clone; +} + +pub type ExpandElement = Arc; + +#[derive(new, Clone)] +pub struct Float { + pub val: f32, + pub vectorization: u8, +} + +#[derive(new, Clone)] 
+pub struct Int { + pub val: u32, + pub vectorization: u8, +} + +#[derive(new, Clone)] +pub struct UInt { + pub val: u32, + pub vectorization: u8, +} + +#[derive(new, Clone)] +pub struct Bool { + pub val: bool, + pub vectorization: u8, +} + +impl RuntimeType for Float { + type ExpandType = Arc; +} + +impl RuntimeType for Int { + type ExpandType = Arc; +} + +impl RuntimeType for UInt { + type ExpandType = Arc; +} + +impl RuntimeType for Bool { + type ExpandType = Arc; +} diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs index 4705ba94e0..78e37f19ea 100644 --- a/crates/burn-cube/src/lib.rs +++ b/crates/burn-cube/src/lib.rs @@ -1,96 +1,16 @@ extern crate alloc; -pub use burn_cube_macros::cube; - -use alloc::sync::Arc; -use burn_jit::gpu::{self, Item, Scope, Variable}; -use std::collections::HashMap; - -pub struct CodegenContext<'a> { - pub scope: &'a mut Scope, - pub pool: HashMap>, -} - -impl<'a> CodegenContext<'a> { - pub fn crate_float(&mut self, item: Item) -> FloatVariable { - let var = self.create_local(item); - FloatVariable { var } - } - pub fn create_local(&mut self, item: Item) -> Arc { - for variable in self.pool.get(&item).iter() { - if Arc::strong_count(variable) == 1 { - return Arc::clone(variable); - } - } - - let new = Arc::new(self.scope.create_local(item)); - self.pool.insert(item, new.clone()); - - new - } -} - -#[derive(Copy, Clone)] -pub struct Float { - pub kind: u32, - pub val: f32, - pub vectorization: u8, -} - -impl core::ops::Add for Float { - type Output = Self; - - fn add(self, _rhs: Self) -> Self::Output { - panic!("Only used for types"); - } -} -#[derive(Clone)] -pub struct FloatVariable { - var: Arc, -} +#[macro_use] +extern crate derive_new; -impl CubeVariable for Float { - type Variable = FloatVariable; -} +mod branch; +mod context; +mod element; +mod operation; -pub trait CubeVariable { - type Variable: Clone; -} +pub use branch::*; +pub use context::*; +pub use element::*; +pub use operation::*; -pub fn 
float_add_expand( - context: &mut CodegenContext<'_>, - lhs: FloatVariable, - rhs: FloatVariable, -) -> FloatVariable { - let item = lhs.var.item(); - let out = context.create_local(item); - let out = FloatVariable { var: out }; - - let op = gpu::Operator::Add(gpu::BinaryOperator { - lhs: *lhs.var, - rhs: *rhs.var, - out: *out.var, - }); - - context.scope.register(op); - - out -} - -pub fn float_assign_expand( - context: &mut CodegenContext<'_>, - input: FloatVariable, -) -> FloatVariable { - let item = input.var.item(); - let out = context.create_local(item); - let out = FloatVariable { var: out }; - - let op = gpu::Operator::Assign(gpu::UnaryOperator { - input: *input.var, - out: *out.var, - }); - - context.scope.register(op); - - out -} +pub use burn_cube_macros::cube; diff --git a/crates/burn-cube/src/operation.rs b/crates/burn-cube/src/operation.rs new file mode 100644 index 0000000000..3ddcec0919 --- /dev/null +++ b/crates/burn-cube/src/operation.rs @@ -0,0 +1,173 @@ +use crate::{CubeContext, ExpandElement, Float, Int, UInt}; +use burn_jit::gpu::{self, Variable}; + +pub mod add { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Add) + } + + impl core::ops::Add for Float { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Float::new(self.val + rhs.val, 1) + } + } + + impl core::ops::Add for Int { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Int::new(self.val + rhs.val, 1) + } + } + + impl core::ops::Add for UInt { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + UInt::new(self.val + rhs.val, 1) + } + } +} + +pub mod sub { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Sub) + } + + impl core::ops::Sub for Float { + type Output = Self; 
+ + fn sub(self, rhs: Self) -> Self::Output { + Float::new(self.val - rhs.val, 1) + } + } + + impl core::ops::Sub for Int { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Int::new(self.val - rhs.val, 1) + } + } + + impl core::ops::Sub for UInt { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + UInt::new(self.val - rhs.val, 1) + } + } +} + +pub mod mul { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Mul) + } + + impl core::ops::Mul for Float { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Float::new(self.val * rhs.val, 1) + } + } + + impl core::ops::Mul for Int { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Int::new(self.val * rhs.val, 1) + } + } + + impl core::ops::Mul for UInt { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + UInt::new(self.val * rhs.val, 1) + } + } +} + +pub mod div { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Div) + } + + impl core::ops::Div for Float { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + Float::new(self.val / rhs.val, 1) + } + } + + impl core::ops::Div for Int { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + Int::new(self.val / rhs.val, 1) + } + } + + impl core::ops::Div for UInt { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + UInt::new(self.val / rhs.val, 1) + } + } +} + +fn binary_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs: Variable = *lhs; + let rhs: Variable = *rhs; + + let item = lhs.item(); + let out = context.create_local(item); + let out_var = 
*out; + + let op = func(gpu::BinaryOperator { + lhs, + rhs, + out: out_var, + }); + + context.scope.register(op); + + out +} diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 5de36de541..1bbc58641c 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,27 +1,19 @@ -use burn_cube::{cube, CodegenContext, Float}; -use burn_jit::gpu::{Elem, Item, Scope}; +use burn_cube::{cube, CubeContext, Float}; +use burn_jit::gpu::{Elem, Item}; #[cube] pub fn kernel(lhs: Float, rhs: Float) -> Float { - let mut output = lhs; - - for i in 0..10 { - output = lhs + rhs - } - - output + let out = lhs.clone() + rhs; + let out = lhs - out; + out } #[test] fn test_simple_add() { - let mut scope = Scope::root(); - let mut context = CodegenContext { - scope: &mut scope, - pool: Default::default(), - }; + let mut context = CubeContext::root(); - let lhs = context.crate_float(Item::Vec4(Elem::Float)); - let rhs = context.crate_float(Item::Vec4(Elem::Float)); + let lhs = context.create_local(Item::Vec4(Elem::Float)); + let rhs = context.create_local(Item::Vec4(Elem::Float)); kernel_expand(&mut context, lhs, rhs); } diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 5b430158df..8efd2221ff 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -153,8 +153,7 @@ pub enum TensorStatus { /// /// 1. Status::NotInit /// 2. Status::ReadOnly -/// 3. Status::ReadOnly -/// 4. Status::ReadWrite +/// 3. Status::ReadWrite #[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct TensorDescription { /// The [tensor id](TensorId). 
diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index d3d4220845..1bfada540a 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -92,7 +92,7 @@ impl Scope { /// Create a new local variable, but doesn't perform the declaration. /// /// Useful for _for loops_ and other algorithms that require the control over initialization. - pub(crate) fn create_local_undeclared(&mut self, item: Item) -> Variable { + pub fn create_local_undeclared(&mut self, item: Item) -> Variable { let index = self.new_local_index(); let local = Variable::Local(index, item, self.depth); self.undeclared += 1; From 8e4d39e7a3211c54b500c3f9374f822867ae2e6f Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 10 Apr 2024 17:27:24 -0400 Subject: [PATCH 03/54] WIP --- crates/burn-cube-macros/src/lib.rs | 3 +- crates/burn-cube-macros/src/statement.rs | 136 ++++++++++++++++++++++- crates/burn-cube/src/branch.rs | 86 +++++++------- crates/burn-cube/src/context.rs | 39 +++++-- crates/burn-cube/src/element.rs | 43 +++++-- crates/burn-cube/src/operation.rs | 19 +++- crates/burn-cube/tests/cube.rs | 21 +++- 7 files changed, 272 insertions(+), 75 deletions(-) diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 2c8a83cbb3..39ca5d3e48 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,7 +1,6 @@ mod statement; -use statement::parse_statement; - use proc_macro::TokenStream; +use statement::parse_statement; /// Derive macro for the module. 
#[proc_macro_attribute] diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index 993daab9dc..8697018aa5 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -4,7 +4,16 @@ pub fn parse_statement(statement: &syn::Stmt) -> TokenStream { match statement { syn::Stmt::Local(local) => parse_local(local), syn::Stmt::Item(_) => todo!(), - syn::Stmt::Expr(expr, _) => parse_expr(expr), + syn::Stmt::Expr(expr, semi) => { + if let Some(_semi) = semi { + let expr = parse_expr(expr); + quote::quote! { + #expr; + } + } else { + parse_expr(expr) + } + } syn::Stmt::Macro(_) => todo!(), } } @@ -15,13 +24,15 @@ fn parse_local(local: &syn::Local) -> TokenStream { .as_ref() .expect("Can't use let without an initialization."); let ident = match &local.pat { - syn::Pat::Ident(ident) => &ident.ident, + syn::Pat::Ident(ident) => ident, _ => panic!("Only ident declaration is supported."), }; let init = parse_expr(&init.expr); + let let_tok = local.let_token; + quote::quote! { - let #ident = #init; + #let_tok #ident = #init; } } @@ -29,10 +40,125 @@ fn parse_expr(expr: &syn::Expr) -> TokenStream { match expr { syn::Expr::Binary(op) => parse_binary(op), syn::Expr::Path(path) => parse_path(path), + syn::Expr::Call(call) => parse_call(call), + syn::Expr::Lit(lit) => quote::quote! 
{ #lit.into() }, + syn::Expr::Closure(closure) => parse_closure(closure), + syn::Expr::Block(block) => parse_expr_block(block), + syn::Expr::Assign(assign) => parse_assign(assign), + syn::Expr::ForLoop(for_loop) => parse_for_loop(for_loop), + syn::Expr::MethodCall(call) => parse_expr_method_call(call), _ => panic!("Unsupported {:?}", expr), } } +fn parse_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { + quote::quote!( #call ) +} + +fn parse_for_loop(for_loop: &syn::ExprForLoop) -> TokenStream { + let i = &for_loop.pat; + let block = parse_block(&for_loop.body); + + match for_loop.expr.as_ref() { + syn::Expr::Call(call) => { + let func_name = match call.func.as_ref() { + syn::Expr::Path(path) => path.path.get_ident().unwrap(), + _ => todo!("Only path call supported"), + }; + if &func_name.to_string() == "range" { + let mut args = quote::quote! { + context, + }; + for argument in call.args.iter() { + let arg = parse_expr(argument); + args.extend(quote::quote! { #arg, }); + } + + return quote::quote! { + range_expand(#args |context, #i| #block); + }; + } + } + _ => todo!("Only call is supported {for_loop:?}"), + } + + todo!(); +} + +fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { + let lhs = parse_expr(&assign.left); + let rhs = parse_expr(&assign.right); + + quote::quote! { + { + // The clone is necessary when mutating a variable that is of a parent scope. + let _assign_lhs = #lhs.clone(); + // This is necessary is the rhs is an expression that need a mutable reference on the + // context. + let _assign_rhs = #rhs; + #lhs = burn_cube::assign::expand(context, _assign_lhs, _assign_rhs) + } + } +} + +fn parse_block(block: &syn::Block) -> TokenStream { + let mut statements = quote::quote!(); + + for statement in block.stmts.iter() { + statements.extend(parse_statement(statement)); + } + + quote::quote! 
{ + { + #statements + } + } +} + +fn parse_expr_block(block: &syn::ExprBlock) -> TokenStream { + parse_block(&block.block) +} + +fn parse_closure(closure: &syn::ExprClosure) -> TokenStream { + let mut inputs = quote::quote! {}; + for input in closure.inputs.iter() { + let ident = match input { + syn::Pat::Ident(ident) => &ident.ident, + _ => panic!("Unsupported {:?}", input), + }; + inputs.extend(quote::quote! { + #ident, + }); + } + + let body = parse_expr(closure.body.as_ref()); + + quote::quote! { + |context, #inputs| #body + } +} + +fn parse_call(call: &syn::ExprCall) -> TokenStream { + let func_name = match call.func.as_ref() { + syn::Expr::Path(path) => path.path.get_ident().unwrap(), + _ => todo!("Only path call supported"), + }; + let mut args = quote::quote! { + context, + }; + let func_name_expand = + syn::Ident::new(format!("{func_name}_expand").as_str(), func_name.span()); + + for argument in call.args.iter() { + let arg = parse_expr(argument); + args.extend(quote::quote! { #arg, }); + } + + quote::quote! { + #func_name_expand(#args) + } +} + fn parse_path(path: &syn::ExprPath) -> TokenStream { let ident = path .path @@ -45,8 +171,8 @@ fn parse_path(path: &syn::ExprPath) -> TokenStream { } fn parse_binary(binary: &syn::ExprBinary) -> TokenStream { - let lhs = &binary.left; - let rhs = &binary.right; + let lhs = parse_expr(&binary.left); + let rhs = parse_expr(&binary.right); match binary.op { syn::BinOp::Add(_) => quote::quote! 
{ diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index 1455eaae05..0c14571421 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -1,60 +1,56 @@ +use std::ops::Deref; + use crate::{CubeContext, ExpandElement, UInt}; -use burn_jit::gpu::{self, Variable}; +use burn_jit::gpu::{self, Branch, Elem, Item, Variable}; -pub fn range(start: S, end: E, _unroll: bool, func: F) +pub fn range(start: S, end: E, _unroll: bool) -> core::ops::Range where S: Into, E: Into, - F: Fn(UInt), { let start: UInt = start.into(); let end: UInt = end.into(); - for i in start.val..end.val { - func(UInt::new(i, 1)) + core::ops::Range { + start: start.val as usize, + end: end.val as usize, } } -pub mod for_each { - use burn_jit::gpu::{Branch, Elem, Item}; - - use super::*; - - pub fn expand( - context: &mut CubeContext, - start: ExpandElement, - end: ExpandElement, - unroll: bool, - func: F, - ) where - F: Fn(&mut CubeContext, Variable), - { - if unroll { - let start = match start.as_ref() { - Variable::ConstantScalar(val, _) => *val as usize, - _ => panic!("Only constant start can be unrolled."), - }; - let end = match end.as_ref() { - Variable::ConstantScalar(val, _) => *val as usize, - _ => panic!("Only constant end can be unrolled."), - }; - - for i in start..end { - func(context, i.into()) - } - } else { - let mut child = context.child(); - let index_ty = Item::Scalar(Elem::UInt); - let i = child.scope.create_local_undeclared(index_ty); - - func(&mut child, i); - - context.scope.register(Branch::RangeLoop(gpu::RangeLoop { - i, - start: *start, - end: *end, - scope: child.scope, - })); +pub fn range_expand( + context: &mut CubeContext, + start: ExpandElement, + end: ExpandElement, + unroll: bool, + mut func: F, +) where + F: FnMut(&mut CubeContext, Variable), +{ + if unroll { + let start = match start.deref() { + Variable::ConstantScalar(val, _) => *val as usize, + _ => panic!("Only constant start can be unrolled."), + }; + let end 
= match end.deref() { + Variable::ConstantScalar(val, _) => *val as usize, + _ => panic!("Only constant end can be unrolled."), + }; + + for i in start..end { + func(context, i.into()) } + } else { + let mut child = context.child(); + let index_ty = Item::Scalar(Elem::UInt); + let i = child.scope.borrow_mut().create_local_undeclared(index_ty); + + func(&mut child, i); + + context.register(Branch::RangeLoop(gpu::RangeLoop { + i, + start: *start, + end: *end, + scope: child.into_scope(), + })); } } diff --git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs index 39208bae53..76ec7e5f19 100644 --- a/crates/burn-cube/src/context.rs +++ b/crates/burn-cube/src/context.rs @@ -1,10 +1,12 @@ use crate::ExpandElement; -use burn_jit::gpu::{Item, Scope, Variable}; -use std::{collections::HashMap, sync::Arc}; +use alloc::rc::Rc; +use burn_jit::gpu::{self, Item, Scope}; +use core::cell::RefCell; +use std::collections::HashMap; #[derive(Default, Clone)] pub struct VariablePool { - map: std::rc::Rc>>>>, + map: Rc>>>, } impl VariablePool { @@ -17,8 +19,8 @@ impl VariablePool { }; for variable in variables.iter() { - if Arc::strong_count(variable) == 1 { - return Some(Arc::clone(variable)); + if Rc::strong_count(&variable.inner) == 1 { + return Some(variable.clone()); } } @@ -37,32 +39,49 @@ impl VariablePool { } pub struct CubeContext { - pub scope: Scope, + pub root: Rc>, + pub scope: Rc>, pub pool: VariablePool, } impl CubeContext { pub fn root() -> CubeContext { + let root = Rc::new(RefCell::new(Scope::root())); + let scope = root.clone(); + Self { pool: Default::default(), - scope: Scope::root(), + scope, + root, } } + pub fn register>(&mut self, op: O) { + self.scope.borrow_mut().register(op) + } pub fn child(&mut self) -> CubeContext { - let scope = self.scope.child(); + let scope = self.scope.borrow_mut().child(); Self { - scope, + scope: Rc::new(RefCell::new(scope)), + root: self.root.clone(), pool: self.pool.clone(), } } + pub fn into_scope(self) -> 
Scope { + core::mem::drop(self.root); + + Rc::into_inner(self.scope) + .expect("Only one reference") + .into_inner() + } + pub fn create_local(&mut self, item: Item) -> ExpandElement { if let Some(var) = self.pool.reuse(item) { return var; } - let new = Arc::new(self.scope.create_local(item)); + let new = ExpandElement::new(Rc::new(self.root.borrow_mut().create_local(item))); self.pool.insert(new.clone()); new diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 2326074785..300f041180 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -1,11 +1,34 @@ -use burn_jit::gpu::Variable; -use std::sync::Arc; +use alloc::rc::Rc; +use burn_jit::gpu::{Item, Variable}; pub trait RuntimeType { type ExpandType: Clone; } -pub type ExpandElement = Arc; +#[derive(new, Clone)] +pub struct ExpandElement { + pub(crate) inner: Rc, +} + +impl ExpandElement { + pub fn item(&self) -> Item { + self.inner.item() + } +} + +impl From for ExpandElement { + fn from(value: u32) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + +impl core::ops::Deref for ExpandElement { + type Target = Variable; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} #[derive(new, Clone)] pub struct Float { @@ -32,17 +55,23 @@ pub struct Bool { } impl RuntimeType for Float { - type ExpandType = Arc; + type ExpandType = ExpandElement; } impl RuntimeType for Int { - type ExpandType = Arc; + type ExpandType = ExpandElement; } impl RuntimeType for UInt { - type ExpandType = Arc; + type ExpandType = ExpandElement; } impl RuntimeType for Bool { - type ExpandType = Arc; + type ExpandType = ExpandElement; +} + +impl From for UInt { + fn from(value: u32) -> Self { + UInt::new(value, 1) + } } diff --git a/crates/burn-cube/src/operation.rs b/crates/burn-cube/src/operation.rs index 3ddcec0919..837e637743 100644 --- a/crates/burn-cube/src/operation.rs +++ b/crates/burn-cube/src/operation.rs @@ -145,6 +145,23 @@ pub mod div { } 
} +pub mod assign { + use super::*; + + pub fn expand( + context: &mut CubeContext, + input: ExpandElement, + output: ExpandElement, + ) -> ExpandElement { + let input = *input; + let out = *output; + + context.register(gpu::Operator::Assign(gpu::UnaryOperator { input, out })); + + output + } +} + fn binary_expand( context: &mut CubeContext, lhs: ExpandElement, @@ -167,7 +184,7 @@ where out: out_var, }); - context.scope.register(op); + context.register(op); out } diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 1bbc58641c..7cd68c1649 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,19 +1,30 @@ -use burn_cube::{cube, CubeContext, Float}; +use burn_cube::{cube, range, range_expand, CubeContext, Float, UInt}; use burn_jit::gpu::{Elem, Item}; #[cube] -pub fn kernel(lhs: Float, rhs: Float) -> Float { - let out = lhs.clone() + rhs; - let out = lhs - out; +pub fn kernel(lhs: Float, rhs: Float, end: UInt) -> Float { + let mut out = lhs.clone() + rhs.clone(); + + for i in range(0, end, false) { + let temp = out.clone() * rhs.clone(); + out = kernel_inner(out.clone(), temp); + } + out } +#[cube] +pub fn kernel_inner(lhs: Float, rhs: Float) -> Float { + lhs + rhs +} + #[test] fn test_simple_add() { let mut context = CubeContext::root(); let lhs = context.create_local(Item::Vec4(Elem::Float)); let rhs = context.create_local(Item::Vec4(Elem::Float)); + let end = context.create_local(Item::Scalar(Elem::UInt)); - kernel_expand(&mut context, lhs, rhs); + kernel_expand(&mut context, lhs, rhs, end); } From f15919fd36f9ecb1f516d5f0af0a5b1b668d7358 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Fri, 12 Apr 2024 16:02:36 -0400 Subject: [PATCH 04/54] WIP --- crates/burn-cube-macros/src/lib.rs | 3 +- crates/burn-cube-macros/src/statement.rs | 54 ++++++++++++++++++++++-- crates/burn-cube/src/branch.rs | 9 ++-- crates/burn-cube/src/element.rs | 43 +++++++++++++++++++ crates/burn-cube/src/operation.rs | 51 
+++++++++++++++++++++- crates/burn-cube/tests/cube.rs | 23 ++++------ 6 files changed, 159 insertions(+), 24 deletions(-) diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 39ca5d3e48..82851fe008 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -17,13 +17,14 @@ pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { let code = quote::quote! { #func + #[allow(unused_mut)] #signature { #body } } .into(); - // panic!("{code}"); + panic!("{code}"); code } diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index 8697018aa5..be5146a725 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -47,10 +47,24 @@ fn parse_expr(expr: &syn::Expr) -> TokenStream { syn::Expr::Assign(assign) => parse_assign(assign), syn::Expr::ForLoop(for_loop) => parse_for_loop(for_loop), syn::Expr::MethodCall(call) => parse_expr_method_call(call), + syn::Expr::Index(index) => parse_expr_index(index), _ => panic!("Unsupported {:?}", expr), } } +fn parse_expr_index(index: &syn::ExprIndex) -> TokenStream { + let array = parse_expr(&index.expr); + let index = parse_expr(&index.index); + + quote::quote! { + { + let _array = #array; + let _index = #index.clone(); + burn_cube::index::expand(context, _array, _index) + } + } +} + fn parse_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { quote::quote!( #call ) } @@ -86,6 +100,22 @@ fn parse_for_loop(for_loop: &syn::ExprForLoop) -> TokenStream { } fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { + if let syn::Expr::Index(index) = assign.left.as_ref() { + let array = parse_expr(&index.expr); + let index = parse_expr(&index.index); + let value = parse_expr(&assign.right); + + return quote::quote! { + { + // The clone is necessary when mutating a variable that is of a parent scope. 
+ let _array = #array.clone(); + let _index = #index.clone(); + let _value = #value; + burn_cube::index_assign::expand(context, _array, _index, _value) + } + }; + }; + let lhs = parse_expr(&assign.left); let rhs = parse_expr(&assign.right); @@ -176,16 +206,32 @@ fn parse_binary(binary: &syn::ExprBinary) -> TokenStream { match binary.op { syn::BinOp::Add(_) => quote::quote! { - burn_cube::add::expand(context, #lhs, #rhs) + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::add::expand(context, _lhs, _rhs) + } }, syn::BinOp::Sub(_) => quote::quote! { - burn_cube::sub::expand(context, #lhs, #rhs) + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::sub::expand(context, _lhs, _rhs) + } }, syn::BinOp::Mul(_) => quote::quote! { - burn_cube::mul::expand(context, #lhs, #rhs) + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::mul::expand(context, _lhs, _rhs) + } }, syn::BinOp::Div(_) => quote::quote! { - burn_cube::div::expand(context, #lhs, #rhs) + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::div::expand(context, _lhs, _rhs) + } }, _ => todo!("{:?}", binary.op), } diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index 0c14571421..1221a34517 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::{ops::Deref, rc::Rc}; use crate::{CubeContext, ExpandElement, UInt}; use burn_jit::gpu::{self, Branch, Elem, Item, Variable}; @@ -24,7 +24,7 @@ pub fn range_expand( unroll: bool, mut func: F, ) where - F: FnMut(&mut CubeContext, Variable), + F: FnMut(&mut CubeContext, ExpandElement), { if unroll { let start = match start.deref() { @@ -43,11 +43,12 @@ pub fn range_expand( let mut child = context.child(); let index_ty = Item::Scalar(Elem::UInt); let i = child.scope.borrow_mut().create_local_undeclared(index_ty); + let i = ExpandElement::new(Rc::new(i)); - func(&mut child, i); + func(&mut child, i.clone()); context.register(Branch::RangeLoop(gpu::RangeLoop { - i, + i: 
*i, start: *start, end: *end, scope: child.into_scope(), diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 300f041180..9b59218be7 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -22,6 +22,18 @@ impl From for ExpandElement { } } +impl From for ExpandElement { + fn from(value: usize) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + +impl From for ExpandElement { + fn from(value: bool) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + impl core::ops::Deref for ExpandElement { type Target = Variable; @@ -54,10 +66,31 @@ pub struct Bool { pub vectorization: u8, } +#[derive(new, Clone)] +pub struct Array { + pub vals: Vec, +} + impl RuntimeType for Float { type ExpandType = ExpandElement; } +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + impl RuntimeType for Int { type ExpandType = ExpandElement; } @@ -70,8 +103,18 @@ impl RuntimeType for Bool { type ExpandType = ExpandElement; } +impl RuntimeType for bool { + type ExpandType = bool; +} + impl From for UInt { fn from(value: u32) -> Self { UInt::new(value, 1) } } + +impl From for UInt { + fn from(value: usize) -> Self { + UInt::new(value as u32, 1) + } +} diff --git a/crates/burn-cube/src/operation.rs b/crates/burn-cube/src/operation.rs index 837e637743..31fae41b5f 100644 --- a/crates/burn-cube/src/operation.rs +++ b/crates/burn-cube/src/operation.rs @@ -1,4 +1,4 @@ -use crate::{CubeContext, ExpandElement, Float, Int, UInt}; +use crate::{Array, CubeContext, ExpandElement, Float, Int, UInt}; use burn_jit::gpu::{self, Variable}; pub mod add { @@ -162,6 +162,55 @@ pub mod assign { } } +pub mod index { + use crate::RuntimeType; + + use super::*; + + pub fn expand( + context: &mut 
CubeContext, + array: ExpandElement, + index: ExpandElement, + ) -> ExpandElement { + binary_expand(context, array, index, gpu::Operator::Index) + } + + impl> core::ops::Index for Array { + type Output = E; + + fn index(&self, index: I) -> &Self::Output { + let index = index.into().val; + &self.vals[index as usize] + } + } +} + +pub mod index_assign { + use crate::RuntimeType; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + array: ExpandElement, + index: ExpandElement, + value: ExpandElement, + ) { + context.register(gpu::Operator::IndexAssign(gpu::BinaryOperator { + lhs: *index, + rhs: *value, + out: *array, + })) + } + + impl> core::ops::IndexMut for Array { + fn index_mut(&mut self, index: I) -> &mut Self::Output { + let index = index.into().val; + &mut self.vals[index as usize] + } + } +} + fn binary_expand( context: &mut CubeContext, lhs: ExpandElement, diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 7cd68c1649..6e2f71b93f 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,22 +1,17 @@ -use burn_cube::{cube, range, range_expand, CubeContext, Float, UInt}; +use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; use burn_jit::gpu::{Elem, Item}; #[cube] -pub fn kernel(lhs: Float, rhs: Float, end: UInt) -> Float { - let mut out = lhs.clone() + rhs.clone(); - - for i in range(0, end, false) { - let temp = out.clone() * rhs.clone(); - out = kernel_inner(out.clone(), temp); +pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { + for i in range(0usize, end, unroll) { + lhs[i] = rhs.clone() + lhs[i].clone(); } - - out } -#[cube] -pub fn kernel_inner(lhs: Float, rhs: Float) -> Float { - lhs + rhs -} +// #[cube] +// pub fn kernel_inner(lhs: Float, rhs: Float) -> Float { +// lhs + rhs +// } #[test] fn test_simple_add() { @@ -26,5 +21,5 @@ fn test_simple_add() { let rhs = context.create_local(Item::Vec4(Elem::Float)); let end = 
context.create_local(Item::Scalar(Elem::UInt)); - kernel_expand(&mut context, lhs, rhs, end); + kernel_expand(&mut context, lhs, rhs, end, false); } From 3c64001ea61835bed1a10d975a3874d0a46961f0 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 16 Apr 2024 11:31:18 -0400 Subject: [PATCH 05/54] Push --- crates/burn-cube-macros/src/statement.rs | 25 ++++++---- crates/burn-cube/src/context.rs | 7 +++ crates/burn-cube/src/element.rs | 10 ++-- crates/burn-cube/tests/cube.rs | 63 +++++++++++++++++++++--- 4 files changed, 82 insertions(+), 23 deletions(-) diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index be5146a725..1f71d330fc 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -59,7 +59,7 @@ fn parse_expr_index(index: &syn::ExprIndex) -> TokenStream { quote::quote! { { let _array = #array; - let _index = #index.clone(); + let _index = #index; burn_cube::index::expand(context, _array, _index) } } @@ -107,9 +107,8 @@ fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { return quote::quote! { { - // The clone is necessary when mutating a variable that is of a parent scope. - let _array = #array.clone(); - let _index = #index.clone(); + let _array = #array; + let _index = #index; let _value = #value; burn_cube::index_assign::expand(context, _array, _index, _value) } @@ -121,10 +120,7 @@ fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { quote::quote! { { - // The clone is necessary when mutating a variable that is of a parent scope. - let _assign_lhs = #lhs.clone(); - // This is necessary is the rhs is an expression that need a mutable reference on the - // context. + let _assign_lhs = #lhs; let _assign_rhs = #rhs; #lhs = burn_cube::assign::expand(context, _assign_lhs, _assign_rhs) } @@ -195,8 +191,17 @@ fn parse_path(path: &syn::ExprPath) -> TokenStream { .get_ident() .expect("Only ident path are supported."); - quote::quote! 
{ - #ident + // TODO: Check in the following statements if the indent is overriden, or reused. + let will_be_used_again = true; + + if will_be_used_again { + quote::quote! { + #ident.clone() + } + } else { + quote::quote! { + #ident + } } } diff --git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs index 76ec7e5f19..1878c98664 100644 --- a/crates/burn-cube/src/context.rs +++ b/crates/burn-cube/src/context.rs @@ -20,10 +20,17 @@ impl VariablePool { for variable in variables.iter() { if Rc::strong_count(&variable.inner) == 1 { + println!("Reuse var {:?}", variable.inner); return Some(variable.clone()); } } + println!("New var"); + for variable in variables { + let count = Rc::strong_count(&variable.inner); + println!("{:?} => {}", variable.inner, count); + } + None } pub fn insert(&mut self, var: ExpandElement) { diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 9b59218be7..2a480b62ee 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -5,7 +5,7 @@ pub trait RuntimeType { type ExpandType: Clone; } -#[derive(new, Clone)] +#[derive(new, Clone, Debug)] pub struct ExpandElement { pub(crate) inner: Rc, } @@ -42,25 +42,25 @@ impl core::ops::Deref for ExpandElement { } } -#[derive(new, Clone)] +#[derive(new, Clone, Copy)] pub struct Float { pub val: f32, pub vectorization: u8, } -#[derive(new, Clone)] +#[derive(new, Clone, Copy)] pub struct Int { pub val: u32, pub vectorization: u8, } -#[derive(new, Clone)] +#[derive(new, Clone, Copy)] pub struct UInt { pub val: u32, pub vectorization: u8, } -#[derive(new, Clone)] +#[derive(new, Clone, Copy)] pub struct Bool { pub val: bool, pub vectorization: u8, diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 6e2f71b93f..6f119a2338 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,25 +1,72 @@ use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; use 
burn_jit::gpu::{Elem, Item}; -#[cube] +// #[cube] pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { + let tmp1 = rhs * rhs; + let tmp2 = tmp1 + rhs; + for i in range(0usize, end, unroll) { - lhs[i] = rhs.clone() + lhs[i].clone(); + lhs[i] = tmp2 + lhs[i]; } } -// #[cube] -// pub fn kernel_inner(lhs: Float, rhs: Float) -> Float { -// lhs + rhs -// } +#[allow(unused_mut)] +pub fn kernel_expand( + context: &mut burn_cube::CubeContext, + mut lhs: as burn_cube::RuntimeType>::ExpandType, + rhs: ::ExpandType, + end: ::ExpandType, + unroll: ::ExpandType, +) -> () { + let tmp1 = { + let _lhs = rhs.clone(); + let _rhs = rhs.clone(); + burn_cube::mul::expand(context, _lhs, _rhs) + }; + let tmp2 = { + let _lhs = tmp1; + let _rhs = rhs; + burn_cube::add::expand(context, _lhs, _rhs) + }; + range_expand( + context, + 0usize.into(), + end.clone(), + unroll.clone(), + |context, i| { + { + let _array = lhs.clone(); + let _index = i.clone(); + let _value = { + let _lhs = tmp2.clone(); + let _rhs = { + let _array = lhs.clone(); + let _index = i.clone(); + burn_cube::index::expand(context, _array, _index) + }; + burn_cube::add::expand(context, _lhs, _rhs) + }; + burn_cube::index_assign::expand(context, _array, _index, _value) + }; + }, + ); +} #[test] fn test_simple_add() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Vec4(Elem::Float)); - let rhs = context.create_local(Item::Vec4(Elem::Float)); + let lhs = context.create_local(Item::Scalar(Elem::Float)); + let rhs = context.create_local(Item::Scalar(Elem::Float)); let end = context.create_local(Item::Scalar(Elem::UInt)); kernel_expand(&mut context, lhs, rhs, end, false); + let scope = context.into_scope(); + + for op in scope.operations.iter() { + println!("{op:?}"); + } + + panic!("nop"); } From 8c82ea82c5a5a554f08e7339cb1502e9b464859f Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 16 Apr 2024 12:01:07 -0400 Subject: [PATCH 06/54] Wip --- crates/burn-cube-macros/src/lib.rs | 59 
+++++++++++++++++- crates/burn-cube-macros/src/statement.rs | 78 ++++++++++++------------ 2 files changed, 96 insertions(+), 41 deletions(-) diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 82851fe008..f399f9835c 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,4 +1,6 @@ mod statement; +use std::collections::HashMap; + use proc_macro::TokenStream; use statement::parse_statement; @@ -6,11 +8,65 @@ use statement::parse_statement; #[proc_macro_attribute] pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { let func: syn::ItemFn = syn::parse(tokens).unwrap(); + let mut variables = VariableAnalyses::create(&func); + + codegen_cube(&func, &mut variables) +} + +#[derive(Hash, PartialEq, Eq, Debug, Clone)] +struct VariableKey { + name: String, +} + +impl From<&syn::Ident> for VariableKey { + fn from(value: &syn::Ident) -> Self { + VariableKey { + name: value.to_string(), + } + } +} + +struct VariableAnalysis { + num_used: usize, + loop_level_declared: usize, +} + +impl VariableAnalysis { + pub fn should_clone(&self, loop_level: usize) -> bool { + if self.num_used == 1 && self.loop_level_declared >= loop_level { + return false; + } + + true + } +} + +struct VariableAnalyses { + analyses: HashMap, +} + +impl VariableAnalyses { + pub fn should_clone(&self, ident: &syn::Ident, loop_level: usize) -> bool { + let key: VariableKey = ident.into(); + if let Some(var) = self.analyses.get(&key) { + return var.should_clone(loop_level); + } + + false + } + pub fn create(func: &syn::ItemFn) -> Self { + Self { + analyses: Default::default(), + } + } +} + +fn codegen_cube(func: &syn::ItemFn, variables: &mut VariableAnalyses) -> TokenStream { let signature = expand_sig(&func.sig); let mut body = quote::quote! 
{}; for statement in func.block.stmts.iter() { - let tokens = parse_statement(statement); + let tokens = parse_statement(statement, 0); body.extend(tokens); } @@ -24,7 +80,6 @@ pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { } .into(); - panic!("{code}"); code } diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index 1f71d330fc..3b05579a6e 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -1,24 +1,24 @@ use proc_macro2::TokenStream; -pub fn parse_statement(statement: &syn::Stmt) -> TokenStream { +pub fn parse_statement(statement: &syn::Stmt, loop_level: usize) -> TokenStream { match statement { - syn::Stmt::Local(local) => parse_local(local), + syn::Stmt::Local(local) => parse_local(local, loop_level), syn::Stmt::Item(_) => todo!(), syn::Stmt::Expr(expr, semi) => { if let Some(_semi) = semi { - let expr = parse_expr(expr); + let expr = parse_expr(expr, loop_level); quote::quote! 
{ #expr; } } else { - parse_expr(expr) + parse_expr(expr, loop_level) } } syn::Stmt::Macro(_) => todo!(), } } -fn parse_local(local: &syn::Local) -> TokenStream { +fn parse_local(local: &syn::Local, loop_level: usize) -> TokenStream { let init = local .init .as_ref() @@ -27,7 +27,7 @@ fn parse_local(local: &syn::Local) -> TokenStream { syn::Pat::Ident(ident) => ident, _ => panic!("Only ident declaration is supported."), }; - let init = parse_expr(&init.expr); + let init = parse_expr(&init.expr, loop_level); let let_tok = local.let_token; @@ -36,25 +36,25 @@ fn parse_local(local: &syn::Local) -> TokenStream { } } -fn parse_expr(expr: &syn::Expr) -> TokenStream { +fn parse_expr(expr: &syn::Expr, loop_level: usize) -> TokenStream { match expr { - syn::Expr::Binary(op) => parse_binary(op), - syn::Expr::Path(path) => parse_path(path), - syn::Expr::Call(call) => parse_call(call), + syn::Expr::Binary(op) => parse_binary(op, loop_level), + syn::Expr::Path(path) => parse_path(path, loop_level), + syn::Expr::Call(call) => parse_call(call, loop_level), syn::Expr::Lit(lit) => quote::quote! 
{ #lit.into() }, - syn::Expr::Closure(closure) => parse_closure(closure), - syn::Expr::Block(block) => parse_expr_block(block), - syn::Expr::Assign(assign) => parse_assign(assign), - syn::Expr::ForLoop(for_loop) => parse_for_loop(for_loop), + syn::Expr::Closure(closure) => parse_closure(closure, loop_level), + syn::Expr::Block(block) => parse_expr_block(block, loop_level), + syn::Expr::Assign(assign) => parse_assign(assign, loop_level), + syn::Expr::ForLoop(for_loop) => parse_for_loop(for_loop, loop_level), syn::Expr::MethodCall(call) => parse_expr_method_call(call), - syn::Expr::Index(index) => parse_expr_index(index), + syn::Expr::Index(index) => parse_expr_index(index, loop_level), _ => panic!("Unsupported {:?}", expr), } } -fn parse_expr_index(index: &syn::ExprIndex) -> TokenStream { - let array = parse_expr(&index.expr); - let index = parse_expr(&index.index); +fn parse_expr_index(index: &syn::ExprIndex, loop_level: usize) -> TokenStream { + let array = parse_expr(&index.expr, loop_level); + let index = parse_expr(&index.index, loop_level); quote::quote! { { @@ -69,9 +69,9 @@ fn parse_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { quote::quote!( #call ) } -fn parse_for_loop(for_loop: &syn::ExprForLoop) -> TokenStream { +fn parse_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStream { let i = &for_loop.pat; - let block = parse_block(&for_loop.body); + let block = parse_block(&for_loop.body, loop_level + 1); match for_loop.expr.as_ref() { syn::Expr::Call(call) => { @@ -84,7 +84,7 @@ fn parse_for_loop(for_loop: &syn::ExprForLoop) -> TokenStream { context, }; for argument in call.args.iter() { - let arg = parse_expr(argument); + let arg = parse_expr(argument, loop_level); args.extend(quote::quote! 
{ #arg, }); } @@ -99,11 +99,11 @@ fn parse_for_loop(for_loop: &syn::ExprForLoop) -> TokenStream { todo!(); } -fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { +fn parse_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { if let syn::Expr::Index(index) = assign.left.as_ref() { - let array = parse_expr(&index.expr); - let index = parse_expr(&index.index); - let value = parse_expr(&assign.right); + let array = parse_expr(&index.expr, loop_level); + let index = parse_expr(&index.index, loop_level); + let value = parse_expr(&assign.right, loop_level); return quote::quote! { { @@ -115,8 +115,8 @@ fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { }; }; - let lhs = parse_expr(&assign.left); - let rhs = parse_expr(&assign.right); + let lhs = parse_expr(&assign.left, loop_level); + let rhs = parse_expr(&assign.right, loop_level); quote::quote! { { @@ -127,11 +127,11 @@ fn parse_assign(assign: &syn::ExprAssign) -> TokenStream { } } -fn parse_block(block: &syn::Block) -> TokenStream { +fn parse_block(block: &syn::Block, loop_level: usize) -> TokenStream { let mut statements = quote::quote!(); for statement in block.stmts.iter() { - statements.extend(parse_statement(statement)); + statements.extend(parse_statement(statement, loop_level)); } quote::quote! { @@ -141,11 +141,11 @@ fn parse_block(block: &syn::Block) -> TokenStream { } } -fn parse_expr_block(block: &syn::ExprBlock) -> TokenStream { - parse_block(&block.block) +fn parse_expr_block(block: &syn::ExprBlock, loop_level: usize) -> TokenStream { + parse_block(&block.block, loop_level) } -fn parse_closure(closure: &syn::ExprClosure) -> TokenStream { +fn parse_closure(closure: &syn::ExprClosure, loop_level: usize) -> TokenStream { let mut inputs = quote::quote! 
{}; for input in closure.inputs.iter() { let ident = match input { @@ -157,14 +157,14 @@ fn parse_closure(closure: &syn::ExprClosure) -> TokenStream { }); } - let body = parse_expr(closure.body.as_ref()); + let body = parse_expr(closure.body.as_ref(), loop_level); quote::quote! { |context, #inputs| #body } } -fn parse_call(call: &syn::ExprCall) -> TokenStream { +fn parse_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { let func_name = match call.func.as_ref() { syn::Expr::Path(path) => path.path.get_ident().unwrap(), _ => todo!("Only path call supported"), @@ -176,7 +176,7 @@ fn parse_call(call: &syn::ExprCall) -> TokenStream { syn::Ident::new(format!("{func_name}_expand").as_str(), func_name.span()); for argument in call.args.iter() { - let arg = parse_expr(argument); + let arg = parse_expr(argument, loop_level); args.extend(quote::quote! { #arg, }); } @@ -185,7 +185,7 @@ fn parse_call(call: &syn::ExprCall) -> TokenStream { } } -fn parse_path(path: &syn::ExprPath) -> TokenStream { +fn parse_path(path: &syn::ExprPath, loop_level: usize) -> TokenStream { let ident = path .path .get_ident() @@ -205,9 +205,9 @@ fn parse_path(path: &syn::ExprPath) -> TokenStream { } } -fn parse_binary(binary: &syn::ExprBinary) -> TokenStream { - let lhs = parse_expr(&binary.left); - let rhs = parse_expr(&binary.right); +fn parse_binary(binary: &syn::ExprBinary, loop_level: usize) -> TokenStream { + let lhs = parse_expr(&binary.left, loop_level); + let rhs = parse_expr(&binary.right, loop_level); match binary.op { syn::BinOp::Add(_) => quote::quote! 
{ From 8ed521787c97b853407c17564a199e57d4da936e Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 17 Apr 2024 10:37:51 -0400 Subject: [PATCH 07/54] little refactor --- crates/burn-cube-macros/src/lib.rs | 11 +-- crates/burn-cube-macros/src/statement.rs | 103 ++++++++++++----------- 2 files changed, 57 insertions(+), 57 deletions(-) diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index f399f9835c..b5a48f074b 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -2,7 +2,7 @@ mod statement; use std::collections::HashMap; use proc_macro::TokenStream; -use statement::parse_statement; +use statement::codegen_statement; /// Derive macro for the module. #[proc_macro_attribute] @@ -33,11 +33,7 @@ struct VariableAnalysis { impl VariableAnalysis { pub fn should_clone(&self, loop_level: usize) -> bool { - if self.num_used == 1 && self.loop_level_declared >= loop_level { - return false; - } - - true + self.num_used > 1 || self.loop_level_declared < loop_level } } @@ -54,6 +50,7 @@ impl VariableAnalyses { false } + pub fn create(func: &syn::ItemFn) -> Self { Self { analyses: Default::default(), @@ -66,7 +63,7 @@ fn codegen_cube(func: &syn::ItemFn, variables: &mut VariableAnalyses) -> TokenSt let mut body = quote::quote! 
{}; for statement in func.block.stmts.iter() { - let tokens = parse_statement(statement, 0); + let tokens = codegen_statement(statement, 0); body.extend(tokens); } diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index 3b05579a6e..b97b6c651b 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -1,24 +1,23 @@ use proc_macro2::TokenStream; -pub fn parse_statement(statement: &syn::Stmt, loop_level: usize) -> TokenStream { +pub fn codegen_statement(statement: &syn::Stmt, loop_level: usize) -> TokenStream { match statement { - syn::Stmt::Local(local) => parse_local(local, loop_level), + syn::Stmt::Local(local) => codegen_local(local, loop_level), syn::Stmt::Item(_) => todo!(), syn::Stmt::Expr(expr, semi) => { - if let Some(_semi) = semi { - let expr = parse_expr(expr, loop_level); - quote::quote! { + let expr = codegen_expr(expr, loop_level); + match semi { + Some(_semi) => quote::quote!( #expr; - } - } else { - parse_expr(expr, loop_level) + ), + None => expr, } } syn::Stmt::Macro(_) => todo!(), } } -fn parse_local(local: &syn::Local, loop_level: usize) -> TokenStream { +fn codegen_local(local: &syn::Local, loop_level: usize) -> TokenStream { let init = local .init .as_ref() @@ -27,7 +26,7 @@ fn parse_local(local: &syn::Local, loop_level: usize) -> TokenStream { syn::Pat::Ident(ident) => ident, _ => panic!("Only ident declaration is supported."), }; - let init = parse_expr(&init.expr, loop_level); + let init = codegen_expr(&init.expr, loop_level); let let_tok = local.let_token; @@ -36,25 +35,25 @@ fn parse_local(local: &syn::Local, loop_level: usize) -> TokenStream { } } -fn parse_expr(expr: &syn::Expr, loop_level: usize) -> TokenStream { +fn codegen_expr(expr: &syn::Expr, loop_level: usize) -> TokenStream { match expr { - syn::Expr::Binary(op) => parse_binary(op, loop_level), - syn::Expr::Path(path) => parse_path(path, loop_level), - syn::Expr::Call(call) => 
parse_call(call, loop_level), + syn::Expr::Binary(op) => codegen_binary(op, loop_level), + syn::Expr::Path(path) => codegen_path(path, loop_level), + syn::Expr::Call(call) => codegen_call(call, loop_level), syn::Expr::Lit(lit) => quote::quote! { #lit.into() }, - syn::Expr::Closure(closure) => parse_closure(closure, loop_level), - syn::Expr::Block(block) => parse_expr_block(block, loop_level), - syn::Expr::Assign(assign) => parse_assign(assign, loop_level), - syn::Expr::ForLoop(for_loop) => parse_for_loop(for_loop, loop_level), - syn::Expr::MethodCall(call) => parse_expr_method_call(call), - syn::Expr::Index(index) => parse_expr_index(index, loop_level), + syn::Expr::Closure(closure) => codegen_closure(closure, loop_level), + syn::Expr::Block(block) => codegen_expr_block(block, loop_level), + syn::Expr::Assign(assign) => codegen_assign(assign, loop_level), + syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level), + syn::Expr::MethodCall(call) => codegen_expr_method_call(call), + syn::Expr::Index(index) => codegen_expr_index(index, loop_level), _ => panic!("Unsupported {:?}", expr), } } -fn parse_expr_index(index: &syn::ExprIndex, loop_level: usize) -> TokenStream { - let array = parse_expr(&index.expr, loop_level); - let index = parse_expr(&index.index, loop_level); +fn codegen_expr_index(index: &syn::ExprIndex, loop_level: usize) -> TokenStream { + let array = codegen_expr(&index.expr, loop_level); + let index = codegen_expr(&index.index, loop_level); quote::quote! 
{ { @@ -65,13 +64,13 @@ fn parse_expr_index(index: &syn::ExprIndex, loop_level: usize) -> TokenStream { } } -fn parse_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { +fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { quote::quote!( #call ) } -fn parse_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStream { +fn codegen_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStream { let i = &for_loop.pat; - let block = parse_block(&for_loop.body, loop_level + 1); + let block = codegen_block(&for_loop.body, loop_level + 1); match for_loop.expr.as_ref() { syn::Expr::Call(call) => { @@ -79,31 +78,33 @@ fn parse_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStream syn::Expr::Path(path) => path.path.get_ident().unwrap(), _ => todo!("Only path call supported"), }; + if &func_name.to_string() == "range" { let mut args = quote::quote! { context, }; + for argument in call.args.iter() { - let arg = parse_expr(argument, loop_level); + let arg = codegen_expr(argument, loop_level); args.extend(quote::quote! { #arg, }); } - return quote::quote! { + quote::quote! { range_expand(#args |context, #i| #block); - }; + } + } else { + todo!("Only range is supported") } } _ => todo!("Only call is supported {for_loop:?}"), } - - todo!(); } -fn parse_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { +fn codegen_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { if let syn::Expr::Index(index) = assign.left.as_ref() { - let array = parse_expr(&index.expr, loop_level); - let index = parse_expr(&index.index, loop_level); - let value = parse_expr(&assign.right, loop_level); + let array = codegen_expr(&index.expr, loop_level); + let index = codegen_expr(&index.index, loop_level); + let value = codegen_expr(&assign.right, loop_level); return quote::quote! 
{ { @@ -115,8 +116,8 @@ fn parse_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { }; }; - let lhs = parse_expr(&assign.left, loop_level); - let rhs = parse_expr(&assign.right, loop_level); + let lhs = codegen_expr(&assign.left, loop_level); + let rhs = codegen_expr(&assign.right, loop_level); quote::quote! { { @@ -127,11 +128,11 @@ fn parse_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { } } -fn parse_block(block: &syn::Block, loop_level: usize) -> TokenStream { +fn codegen_block(block: &syn::Block, loop_level: usize) -> TokenStream { let mut statements = quote::quote!(); for statement in block.stmts.iter() { - statements.extend(parse_statement(statement, loop_level)); + statements.extend(codegen_statement(statement, loop_level)); } quote::quote! { @@ -141,11 +142,11 @@ fn parse_block(block: &syn::Block, loop_level: usize) -> TokenStream { } } -fn parse_expr_block(block: &syn::ExprBlock, loop_level: usize) -> TokenStream { - parse_block(&block.block, loop_level) +fn codegen_expr_block(block: &syn::ExprBlock, loop_level: usize) -> TokenStream { + codegen_block(&block.block, loop_level) } -fn parse_closure(closure: &syn::ExprClosure, loop_level: usize) -> TokenStream { +fn codegen_closure(closure: &syn::ExprClosure, loop_level: usize) -> TokenStream { let mut inputs = quote::quote! {}; for input in closure.inputs.iter() { let ident = match input { @@ -157,26 +158,28 @@ fn parse_closure(closure: &syn::ExprClosure, loop_level: usize) -> TokenStream { }); } - let body = parse_expr(closure.body.as_ref(), loop_level); + let body = codegen_expr(closure.body.as_ref(), loop_level); quote::quote! 
{ |context, #inputs| #body } } -fn parse_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { +fn codegen_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { let func_name = match call.func.as_ref() { syn::Expr::Path(path) => path.path.get_ident().unwrap(), _ => todo!("Only path call supported"), }; + let mut args = quote::quote! { context, }; + let func_name_expand = syn::Ident::new(format!("{func_name}_expand").as_str(), func_name.span()); for argument in call.args.iter() { - let arg = parse_expr(argument, loop_level); + let arg = codegen_expr(argument, loop_level); args.extend(quote::quote! { #arg, }); } @@ -185,13 +188,13 @@ fn parse_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { } } -fn parse_path(path: &syn::ExprPath, loop_level: usize) -> TokenStream { +fn codegen_path(path: &syn::ExprPath, loop_level: usize) -> TokenStream { let ident = path .path .get_ident() .expect("Only ident path are supported."); - // TODO: Check in the following statements if the indent is overriden, or reused. + // TODO: Check in the following statements if the ident is overriden, or reused. let will_be_used_again = true; if will_be_used_again { @@ -205,9 +208,9 @@ fn parse_path(path: &syn::ExprPath, loop_level: usize) -> TokenStream { } } -fn parse_binary(binary: &syn::ExprBinary, loop_level: usize) -> TokenStream { - let lhs = parse_expr(&binary.left, loop_level); - let rhs = parse_expr(&binary.right, loop_level); +fn codegen_binary(binary: &syn::ExprBinary, loop_level: usize) -> TokenStream { + let lhs = codegen_expr(&binary.left, loop_level); + let rhs = codegen_expr(&binary.right, loop_level); match binary.op { syn::BinOp::Add(_) => quote::quote! 
{ From dbbfdb7f50be6c385a6566abe9a384472a42b8c4 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 18 Apr 2024 16:17:47 -0400 Subject: [PATCH 08/54] wip --- crates/burn-cube-macros/src/analysis.rs | 126 +++++++++++++++++++++ crates/burn-cube/tests/cube.rs | 48 +------- crates/burn-jit/src/codegen/compilation.rs | 4 +- 3 files changed, 132 insertions(+), 46 deletions(-) create mode 100644 crates/burn-cube-macros/src/analysis.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs new file mode 100644 index 0000000000..b01d378eba --- /dev/null +++ b/crates/burn-cube-macros/src/analysis.rs @@ -0,0 +1,126 @@ +use std::collections::HashMap; + +use syn::Stmt; + +use crate::VariableKey; + +#[derive(Debug)] +pub(crate) struct VariableAnalysis { + num_used: usize, + loop_level_declared: usize, +} + +impl VariableAnalysis { + pub fn should_clone(&self, loop_level: usize) -> bool { + self.num_used > 1 || self.loop_level_declared < loop_level + } +} + +#[derive(Debug)] +pub(crate) struct VariableAnalyses { + pub analyses: HashMap, +} + +impl VariableAnalyses { + pub fn should_clone(&self, ident: &syn::Ident, loop_level: usize) -> bool { + let key: VariableKey = ident.into(); + if let Some(var) = self.analyses.get(&key) { + return var.should_clone(loop_level); + } + + false + } + + pub fn create(func: &syn::ItemFn) -> Self { + let analyses = analyze(func); + + Self { analyses } + } +} + +pub(crate) fn analyze(func: &syn::ItemFn) -> HashMap { + // Build the vector of (Id, depth), using recursion + let mut declarations = Vec::new(); + let mut var_uses = Vec::new(); + list_occurrences(&func.block.stmts, 0, &mut declarations, &mut var_uses); + + // Run through the vec and build hashmap, without recursion + let mut analyses = HashMap::::new(); + for declaration in declarations.into_iter() { + let id = declaration.0; + let new_analysis = match analyses.remove(&id) { + Some(_) => { + panic!("Multiple variables with the same identifier is not 
supported") + } + None => VariableAnalysis { + num_used: 0, + loop_level_declared: declaration.1, + }, + }; + + analyses.insert(id, new_analysis); + } + + for id in var_uses.into_iter() { + let prev_analysis = analyses + .remove(&id) + .expect("Variable {id} should be declared before it's used"); + let new_analysis = VariableAnalysis { + num_used: prev_analysis.num_used + 1, + loop_level_declared: prev_analysis.loop_level_declared, + }; + analyses.insert(id, new_analysis); + } + + analyses +} + +fn list_occurrences( + stmts: &Vec, + depth: usize, + declarations: &mut Vec<(VariableKey, usize)>, + uses: &mut Vec, +) { + for stmt in stmts { + match stmt { + // Declaration + syn::Stmt::Local(local) => match &local.pat { + syn::Pat::Ident(pat_ident) => { + let id = &pat_ident.ident; + declarations.push((id.into(), depth)); + } + _ => todo!(), + }, + syn::Stmt::Expr(expr, _) => occ_expr(expr, depth, declarations, uses), + _ => todo!(), + } + } +} + +fn occ_expr( + expr: &syn::Expr, + depth: usize, + declarations: &mut Vec<(VariableKey, usize)>, + uses: &mut Vec, +) { + match expr { + syn::Expr::ForLoop(expr) => { + // Declaration + if let syn::Pat::Ident(pat_ident) = &*expr.pat { + let id = &pat_ident.ident; + declarations.push((id.into(), depth)); + } + + list_occurrences(&expr.body.stmts, depth + 1, declarations, uses); + } + syn::Expr::Assign(expr) => { + occ_expr(&expr.right, depth, declarations, uses); + } + syn::Expr::Index(expr) => panic!("{expr:?}"), + syn::Expr::Path(expr) => panic!("{expr:?}"), + syn::Expr::Binary(expr) => { + + }, + _ => todo!(), + } +} diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 6f119a2338..8a12c4e815 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,7 +1,7 @@ use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; use burn_jit::gpu::{Elem, Item}; -// #[cube] +#[cube] pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { let tmp1 = 
rhs * rhs; let tmp2 = tmp1 + rhs; @@ -11,48 +11,6 @@ pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { } } -#[allow(unused_mut)] -pub fn kernel_expand( - context: &mut burn_cube::CubeContext, - mut lhs: as burn_cube::RuntimeType>::ExpandType, - rhs: ::ExpandType, - end: ::ExpandType, - unroll: ::ExpandType, -) -> () { - let tmp1 = { - let _lhs = rhs.clone(); - let _rhs = rhs.clone(); - burn_cube::mul::expand(context, _lhs, _rhs) - }; - let tmp2 = { - let _lhs = tmp1; - let _rhs = rhs; - burn_cube::add::expand(context, _lhs, _rhs) - }; - range_expand( - context, - 0usize.into(), - end.clone(), - unroll.clone(), - |context, i| { - { - let _array = lhs.clone(); - let _index = i.clone(); - let _value = { - let _lhs = tmp2.clone(); - let _rhs = { - let _array = lhs.clone(); - let _index = i.clone(); - burn_cube::index::expand(context, _array, _index) - }; - burn_cube::add::expand(context, _lhs, _rhs) - }; - burn_cube::index_assign::expand(context, _array, _index, _value) - }; - }, - ); -} - #[test] fn test_simple_add() { let mut context = CubeContext::root(); @@ -61,11 +19,11 @@ fn test_simple_add() { let rhs = context.create_local(Item::Scalar(Elem::Float)); let end = context.create_local(Item::Scalar(Elem::UInt)); - kernel_expand(&mut context, lhs, rhs, end, false); + // kernel_expand(&mut context, lhs, rhs, end, false); let scope = context.into_scope(); for op in scope.operations.iter() { - println!("{op:?}"); + // println!("{op:?}"); } panic!("nop"); diff --git a/crates/burn-jit/src/codegen/compilation.rs b/crates/burn-jit/src/codegen/compilation.rs index 50136a88df..fdd20b5cfe 100644 --- a/crates/burn-jit/src/codegen/compilation.rs +++ b/crates/burn-jit/src/codegen/compilation.rs @@ -193,7 +193,9 @@ impl CompilationSettings { if chosen.is_some() { break; } - if desc.shape == desc_input.shape && input.item() == output.item() { + if desc.shape == desc_input.shape + && input.item() == output.item() + { chosen = Some(index); } } From 
797d99715e4778561530acbcab7f19de5f2d6b62 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 19 Apr 2024 16:23:14 -0400 Subject: [PATCH 09/54] the right number of clones --- crates/burn-cube-macros/src/analysis.rs | 85 ++++++++++++--- crates/burn-cube-macros/src/lib.rs | 40 +------ crates/burn-cube-macros/src/statement.rs | 128 ++++++++++++++++------- crates/burn-cube/tests/cube.rs | 4 +- 4 files changed, 165 insertions(+), 92 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index b01d378eba..fdbac773bc 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -11,8 +11,13 @@ pub(crate) struct VariableAnalysis { } impl VariableAnalysis { - pub fn should_clone(&self, loop_level: usize) -> bool { - self.num_used > 1 || self.loop_level_declared < loop_level + pub fn should_clone(&mut self, loop_level: usize) -> bool { + if self.num_used > 1 { + self.num_used -= 1; + true + } else { + self.loop_level_declared < loop_level + } } } @@ -22,13 +27,16 @@ pub(crate) struct VariableAnalyses { } impl VariableAnalyses { - pub fn should_clone(&self, ident: &syn::Ident, loop_level: usize) -> bool { + pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool { let key: VariableKey = ident.into(); - if let Some(var) = self.analyses.get(&key) { - return var.should_clone(loop_level); + match self.analyses.remove(&key) { + Some(mut var) => { + let should_clone = var.should_clone(loop_level); + self.analyses.insert(key, var); + return should_clone; + } + None => panic!("Ident {ident} not part of analysis"), } - - false } pub fn create(func: &syn::ItemFn) -> Self { @@ -41,8 +49,11 @@ impl VariableAnalyses { pub(crate) fn analyze(func: &syn::ItemFn) -> HashMap { // Build the vector of (Id, depth), using recursion let mut declarations = Vec::new(); + list_declarations_in_signature(&func.sig, &mut declarations); + let mut var_uses = Vec::new(); 
list_occurrences(&func.block.stmts, 0, &mut declarations, &mut var_uses); + // panic!("{var_uses:?}"); // Run through the vec and build hashmap, without recursion let mut analyses = HashMap::::new(); @@ -62,19 +73,42 @@ pub(crate) fn analyze(func: &syn::ItemFn) -> HashMap, +) { + for input in &sig.inputs { + match input { + syn::FnArg::Typed(pat) => { + let ident = &*pat.pat; + match ident { + syn::Pat::Ident(pat_ident) => { + let id = &pat_ident.ident; + declarations.push((id.into(), 0)); + } + _ => todo!(), + } + } + _ => todo!(), + } + } +} + fn list_occurrences( stmts: &Vec, depth: usize, @@ -88,6 +122,10 @@ fn list_occurrences( syn::Pat::Ident(pat_ident) => { let id = &pat_ident.ident; declarations.push((id.into(), depth)); + + if let Some(local_init) = &local.init { + occ_expr(&local_init.expr, depth, declarations, uses) + } } _ => todo!(), }, @@ -105,22 +143,37 @@ fn occ_expr( ) { match expr { syn::Expr::ForLoop(expr) => { - // Declaration + let depth = depth + 1; + + // Declaration of iterator if let syn::Pat::Ident(pat_ident) = &*expr.pat { let id = &pat_ident.ident; declarations.push((id.into(), depth)); } - list_occurrences(&expr.body.stmts, depth + 1, declarations, uses); + list_occurrences(&expr.body.stmts, depth, declarations, uses); } syn::Expr::Assign(expr) => { + occ_expr(&expr.left, depth, declarations, uses); occ_expr(&expr.right, depth, declarations, uses); } - syn::Expr::Index(expr) => panic!("{expr:?}"), - syn::Expr::Path(expr) => panic!("{expr:?}"), + syn::Expr::Index(expr) => { + occ_expr(&expr.expr, depth, declarations, uses); + occ_expr(&expr.index, depth, declarations, uses); + } + syn::Expr::Path(expr) => { + let ident = expr + .path + .get_ident() + .expect("Only ident path are supported."); + + // Use + uses.push(ident.into()); + } syn::Expr::Binary(expr) => { - - }, + occ_expr(&expr.left, depth, declarations, uses); + occ_expr(&expr.right, depth, declarations, uses); + } _ => todo!(), } } diff --git 
a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index b5a48f074b..29e51858a6 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,6 +1,7 @@ +mod analysis; mod statement; -use std::collections::HashMap; +use analysis::VariableAnalyses; use proc_macro::TokenStream; use statement::codegen_statement; @@ -26,44 +27,12 @@ impl From<&syn::Ident> for VariableKey { } } -struct VariableAnalysis { - num_used: usize, - loop_level_declared: usize, -} - -impl VariableAnalysis { - pub fn should_clone(&self, loop_level: usize) -> bool { - self.num_used > 1 || self.loop_level_declared < loop_level - } -} - -struct VariableAnalyses { - analyses: HashMap, -} - -impl VariableAnalyses { - pub fn should_clone(&self, ident: &syn::Ident, loop_level: usize) -> bool { - let key: VariableKey = ident.into(); - if let Some(var) = self.analyses.get(&key) { - return var.should_clone(loop_level); - } - - false - } - - pub fn create(func: &syn::ItemFn) -> Self { - Self { - analyses: Default::default(), - } - } -} - -fn codegen_cube(func: &syn::ItemFn, variables: &mut VariableAnalyses) -> TokenStream { +fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut VariableAnalyses) -> TokenStream { let signature = expand_sig(&func.sig); let mut body = quote::quote! 
{}; for statement in func.block.stmts.iter() { - let tokens = codegen_statement(statement, 0); + let tokens = codegen_statement(statement, 0, variable_analyses); body.extend(tokens); } @@ -76,6 +45,7 @@ fn codegen_cube(func: &syn::ItemFn, variables: &mut VariableAnalyses) -> TokenSt } } .into(); + // panic!("{code}"); code } diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index b97b6c651b..808972813c 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -1,11 +1,17 @@ use proc_macro2::TokenStream; -pub fn codegen_statement(statement: &syn::Stmt, loop_level: usize) -> TokenStream { +use crate::analysis::VariableAnalyses; + +pub fn codegen_statement( + statement: &syn::Stmt, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { match statement { - syn::Stmt::Local(local) => codegen_local(local, loop_level), + syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), syn::Stmt::Item(_) => todo!(), syn::Stmt::Expr(expr, semi) => { - let expr = codegen_expr(expr, loop_level); + let expr = codegen_expr(expr, loop_level, variable_analyses); match semi { Some(_semi) => quote::quote!( #expr; @@ -17,7 +23,11 @@ pub fn codegen_statement(statement: &syn::Stmt, loop_level: usize) -> TokenStrea } } -fn codegen_local(local: &syn::Local, loop_level: usize) -> TokenStream { +fn codegen_local( + local: &syn::Local, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { let init = local .init .as_ref() @@ -26,7 +36,7 @@ fn codegen_local(local: &syn::Local, loop_level: usize) -> TokenStream { syn::Pat::Ident(ident) => ident, _ => panic!("Only ident declaration is supported."), }; - let init = codegen_expr(&init.expr, loop_level); + let init = codegen_expr(&init.expr, loop_level, variable_analyses); let let_tok = local.let_token; @@ -35,25 +45,33 @@ fn codegen_local(local: &syn::Local, loop_level: usize) -> 
TokenStream { } } -fn codegen_expr(expr: &syn::Expr, loop_level: usize) -> TokenStream { +fn codegen_expr( + expr: &syn::Expr, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { match expr { - syn::Expr::Binary(op) => codegen_binary(op, loop_level), - syn::Expr::Path(path) => codegen_path(path, loop_level), - syn::Expr::Call(call) => codegen_call(call, loop_level), + syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), + syn::Expr::Path(path) => codegen_path(path, loop_level, variable_analyses), + syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), syn::Expr::Lit(lit) => quote::quote! { #lit.into() }, - syn::Expr::Closure(closure) => codegen_closure(closure, loop_level), - syn::Expr::Block(block) => codegen_expr_block(block, loop_level), - syn::Expr::Assign(assign) => codegen_assign(assign, loop_level), - syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level), + syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses), + syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses), + syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses), + syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses), syn::Expr::MethodCall(call) => codegen_expr_method_call(call), - syn::Expr::Index(index) => codegen_expr_index(index, loop_level), + syn::Expr::Index(index) => codegen_expr_index(index, loop_level, variable_analyses), _ => panic!("Unsupported {:?}", expr), } } -fn codegen_expr_index(index: &syn::ExprIndex, loop_level: usize) -> TokenStream { - let array = codegen_expr(&index.expr, loop_level); - let index = codegen_expr(&index.index, loop_level); +fn codegen_expr_index( + index: &syn::ExprIndex, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = 
codegen_expr(&index.index, loop_level, variable_analyses); quote::quote! { { @@ -68,9 +86,13 @@ fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { quote::quote!( #call ) } -fn codegen_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStream { +fn codegen_for_loop( + for_loop: &syn::ExprForLoop, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { let i = &for_loop.pat; - let block = codegen_block(&for_loop.body, loop_level + 1); + let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses); match for_loop.expr.as_ref() { syn::Expr::Call(call) => { @@ -85,7 +107,7 @@ fn codegen_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStre }; for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level); + let arg = codegen_expr(argument, loop_level, variable_analyses); args.extend(quote::quote! { #arg, }); } @@ -100,11 +122,15 @@ fn codegen_for_loop(for_loop: &syn::ExprForLoop, loop_level: usize) -> TokenStre } } -fn codegen_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { +fn codegen_assign( + assign: &syn::ExprAssign, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { if let syn::Expr::Index(index) = assign.left.as_ref() { - let array = codegen_expr(&index.expr, loop_level); - let index = codegen_expr(&index.index, loop_level); - let value = codegen_expr(&assign.right, loop_level); + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = codegen_expr(&index.index, loop_level, variable_analyses); + let value = codegen_expr(&assign.right, loop_level, variable_analyses); return quote::quote! 
{ { @@ -116,8 +142,8 @@ fn codegen_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { }; }; - let lhs = codegen_expr(&assign.left, loop_level); - let rhs = codegen_expr(&assign.right, loop_level); + let lhs = codegen_expr(&assign.left, loop_level, variable_analyses); + let rhs = codegen_expr(&assign.right, loop_level, variable_analyses); quote::quote! { { @@ -128,11 +154,15 @@ fn codegen_assign(assign: &syn::ExprAssign, loop_level: usize) -> TokenStream { } } -fn codegen_block(block: &syn::Block, loop_level: usize) -> TokenStream { +fn codegen_block( + block: &syn::Block, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { let mut statements = quote::quote!(); for statement in block.stmts.iter() { - statements.extend(codegen_statement(statement, loop_level)); + statements.extend(codegen_statement(statement, loop_level, variable_analyses)); } quote::quote! { @@ -142,11 +172,19 @@ fn codegen_block(block: &syn::Block, loop_level: usize) -> TokenStream { } } -fn codegen_expr_block(block: &syn::ExprBlock, loop_level: usize) -> TokenStream { - codegen_block(&block.block, loop_level) +fn codegen_expr_block( + block: &syn::ExprBlock, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { + codegen_block(&block.block, loop_level, variable_analyses) } -fn codegen_closure(closure: &syn::ExprClosure, loop_level: usize) -> TokenStream { +fn codegen_closure( + closure: &syn::ExprClosure, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { let mut inputs = quote::quote! {}; for input in closure.inputs.iter() { let ident = match input { @@ -158,14 +196,18 @@ fn codegen_closure(closure: &syn::ExprClosure, loop_level: usize) -> TokenStream }); } - let body = codegen_expr(closure.body.as_ref(), loop_level); + let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses); quote::quote! 
{ |context, #inputs| #body } } -fn codegen_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { +fn codegen_call( + call: &syn::ExprCall, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { let func_name = match call.func.as_ref() { syn::Expr::Path(path) => path.path.get_ident().unwrap(), _ => todo!("Only path call supported"), @@ -179,7 +221,7 @@ fn codegen_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { syn::Ident::new(format!("{func_name}_expand").as_str(), func_name.span()); for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level); + let arg = codegen_expr(argument, loop_level, variable_analyses); args.extend(quote::quote! { #arg, }); } @@ -188,14 +230,18 @@ fn codegen_call(call: &syn::ExprCall, loop_level: usize) -> TokenStream { } } -fn codegen_path(path: &syn::ExprPath, loop_level: usize) -> TokenStream { +fn codegen_path( + path: &syn::ExprPath, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { let ident = path .path .get_ident() .expect("Only ident path are supported."); // TODO: Check in the following statements if the ident is overriden, or reused. - let will_be_used_again = true; + let will_be_used_again = variable_analyses.should_clone(ident, loop_level); if will_be_used_again { quote::quote! { @@ -208,9 +254,13 @@ fn codegen_path(path: &syn::ExprPath, loop_level: usize) -> TokenStream { } } -fn codegen_binary(binary: &syn::ExprBinary, loop_level: usize) -> TokenStream { - let lhs = codegen_expr(&binary.left, loop_level); - let rhs = codegen_expr(&binary.right, loop_level); +fn codegen_binary( + binary: &syn::ExprBinary, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { + let lhs = codegen_expr(&binary.left, loop_level, variable_analyses); + let rhs = codegen_expr(&binary.right, loop_level, variable_analyses); match binary.op { syn::BinOp::Add(_) => quote::quote! 
{ diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 8a12c4e815..7760a61900 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -19,11 +19,11 @@ fn test_simple_add() { let rhs = context.create_local(Item::Scalar(Elem::Float)); let end = context.create_local(Item::Scalar(Elem::UInt)); - // kernel_expand(&mut context, lhs, rhs, end, false); + kernel_expand(&mut context, lhs, rhs, end, false); let scope = context.into_scope(); for op in scope.operations.iter() { - // println!("{op:?}"); + println!("{op:?}"); } panic!("nop"); From 2617e6bd68bfe191c83fc2c62e811f51bad2f07c Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 19 Apr 2024 18:23:50 -0400 Subject: [PATCH 10/54] add comments --- crates/burn-cube-macros/src/analysis.rs | 33 +++++++++------------ crates/burn-cube-macros/src/lib.rs | 5 ++-- crates/burn-cube/src/context.rs | 27 ++++++++++++++---- crates/burn-cube/src/element.rs | 15 ++++++++++ crates/burn-cube/tests/cube.rs | 38 ++++++++++++++++++++++++- 5 files changed, 91 insertions(+), 27 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index fdbac773bc..04e1d3f73b 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -49,11 +49,10 @@ impl VariableAnalyses { pub(crate) fn analyze(func: &syn::ItemFn) -> HashMap { // Build the vector of (Id, depth), using recursion let mut declarations = Vec::new(); - list_declarations_in_signature(&func.sig, &mut declarations); + signature_declarations(&func.sig, &mut declarations); let mut var_uses = Vec::new(); - list_occurrences(&func.block.stmts, 0, &mut declarations, &mut var_uses); - // panic!("{var_uses:?}"); + stmts_occurrences(&func.block.stmts, 0, &mut declarations, &mut var_uses); // Run through the vec and build hashmap, without recursion let mut analyses = HashMap::::new(); @@ -83,15 +82,11 @@ pub(crate) fn analyze(func: &syn::ItemFn) -> HashMap, -) 
{ +fn signature_declarations(sig: &syn::Signature, declarations: &mut Vec<(VariableKey, usize)>) { for input in &sig.inputs { match input { syn::FnArg::Typed(pat) => { @@ -109,7 +104,7 @@ fn list_declarations_in_signature( } } -fn list_occurrences( +fn stmts_occurrences( stmts: &Vec, depth: usize, declarations: &mut Vec<(VariableKey, usize)>, @@ -124,18 +119,18 @@ fn list_occurrences( declarations.push((id.into(), depth)); if let Some(local_init) = &local.init { - occ_expr(&local_init.expr, depth, declarations, uses) + expr_occurrences(&local_init.expr, depth, declarations, uses) } } _ => todo!(), }, - syn::Stmt::Expr(expr, _) => occ_expr(expr, depth, declarations, uses), + syn::Stmt::Expr(expr, _) => expr_occurrences(expr, depth, declarations, uses), _ => todo!(), } } } -fn occ_expr( +fn expr_occurrences( expr: &syn::Expr, depth: usize, declarations: &mut Vec<(VariableKey, usize)>, @@ -151,15 +146,15 @@ fn occ_expr( declarations.push((id.into(), depth)); } - list_occurrences(&expr.body.stmts, depth, declarations, uses); + stmts_occurrences(&expr.body.stmts, depth, declarations, uses); } syn::Expr::Assign(expr) => { - occ_expr(&expr.left, depth, declarations, uses); - occ_expr(&expr.right, depth, declarations, uses); + expr_occurrences(&expr.left, depth, declarations, uses); + expr_occurrences(&expr.right, depth, declarations, uses); } syn::Expr::Index(expr) => { - occ_expr(&expr.expr, depth, declarations, uses); - occ_expr(&expr.index, depth, declarations, uses); + expr_occurrences(&expr.expr, depth, declarations, uses); + expr_occurrences(&expr.index, depth, declarations, uses); } syn::Expr::Path(expr) => { let ident = expr @@ -171,8 +166,8 @@ fn occ_expr( uses.push(ident.into()); } syn::Expr::Binary(expr) => { - occ_expr(&expr.left, depth, declarations, uses); - occ_expr(&expr.right, depth, declarations, uses); + expr_occurrences(&expr.left, depth, declarations, uses); + expr_occurrences(&expr.right, depth, declarations, uses); } _ => todo!(), } diff --git 
a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 29e51858a6..3be328789a 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -9,9 +9,9 @@ use statement::codegen_statement; #[proc_macro_attribute] pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { let func: syn::ItemFn = syn::parse(tokens).unwrap(); - let mut variables = VariableAnalyses::create(&func); + let mut variable_analyses = VariableAnalyses::create(&func); - codegen_cube(&func, &mut variables) + codegen_cube(&func, &mut variable_analyses) } #[derive(Hash, PartialEq, Eq, Debug, Clone)] @@ -27,6 +27,7 @@ impl From<&syn::Ident> for VariableKey { } } +/// Generate the expanded version of a function marked with the cube macro fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut VariableAnalyses) -> TokenStream { let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; diff --git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs index 1878c98664..813d23bdfd 100644 --- a/crates/burn-cube/src/context.rs +++ b/crates/burn-cube/src/context.rs @@ -10,14 +10,18 @@ pub struct VariablePool { } impl VariablePool { + /// Returns an old, not used anymore variable, if there exists one. 
pub fn reuse(&self, item: Item) -> Option { let map = self.map.borrow(); + // Filter for candidate variables of the same Item let variables = match map.get(&item) { Some(val) => val, None => return None, }; + // Among the candidates, take a variable if it's only referenced by the map + // Arbitrarily takes the first it finds for variable in variables.iter() { if Rc::strong_count(&variable.inner) == 1 { println!("Reuse var {:?}", variable.inner); @@ -25,14 +29,17 @@ impl VariablePool { } } - println!("New var"); - for variable in variables { - let count = Rc::strong_count(&variable.inner); - println!("{:?} => {}", variable.inner, count); - } + // println!("New var"); + // for variable in variables { + // let count = Rc::strong_count(&variable.inner); + // println!("{:?} => {}", variable.inner, count); + // } + // If no candidate was found, a new var will be needed None } + + /// Insert a new variable in the map, which is classified by Item pub fn insert(&mut self, var: ExpandElement) { let mut map = self.map.borrow_mut(); let item = var.item(); @@ -52,6 +59,9 @@ pub struct CubeContext { } impl CubeContext { + /// Create a new cube context, with a root scope + /// A root scope is at the root of a compute shader + /// Therefore there is one cube context per shader pub fn root() -> CubeContext { let root = Rc::new(RefCell::new(Scope::root())); let scope = root.clone(); @@ -62,9 +72,11 @@ impl CubeContext { root, } } + pub fn register>(&mut self, op: O) { self.scope.borrow_mut().register(op) } + pub fn child(&mut self) -> CubeContext { let scope = self.scope.borrow_mut().child(); @@ -83,11 +95,16 @@ impl CubeContext { .into_inner() } + /// When a new variable is required, we check if we can reuse an old one + /// Otherwise we create a new one. 
pub fn create_local(&mut self, item: Item) -> ExpandElement { + // Reuse an old variable if possible if let Some(var) = self.pool.reuse(item) { return var; } + // Create a new variable at the root scope + // Insert it in the variable pool for potential reuse let new = ExpandElement::new(Rc::new(self.root.borrow_mut().create_local(item))); self.pool.insert(new.clone()); diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 2a480b62ee..ec64fe42df 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -1,16 +1,31 @@ use alloc::rc::Rc; use burn_jit::gpu::{Item, Variable}; +/// Types used in a cube function must implement this trait +/// +/// Variables whose values will be known at runtime must +/// have ExpandElement as associated type +/// Variables whose values will be known at compile time +/// must have the primitive type as associated type +/// +/// Note: Cube functions should be written using RuntimeTypes, +/// so that the code generated uses the associated ExpandType. +/// This allows Cube code to not necessitate cloning, which is cumbersome +/// in algorithmic code. The necessary cloning will automatically appear in +/// the generated code. 
pub trait RuntimeType { type ExpandType: Clone; } #[derive(new, Clone, Debug)] +/// Reference to a JIT variable +/// It's the expand element that is actually kept in the variable pool pub struct ExpandElement { pub(crate) inner: Rc, } impl ExpandElement { + /// Returns the Item of the variable pub fn item(&self) -> Item { self.inner.item() } diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index 7760a61900..cedf2aa1f3 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,7 +1,7 @@ use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; use burn_jit::gpu::{Elem, Item}; -#[cube] +// #[cube] pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { let tmp1 = rhs * rhs; let tmp2 = tmp1 + rhs; @@ -11,6 +11,42 @@ pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { } } +#[allow(unused_mut)] +pub fn kernel_expand( + context: &mut burn_cube::CubeContext, + mut lhs: as burn_cube::RuntimeType>::ExpandType, + rhs: ::ExpandType, + end: ::ExpandType, + unroll: ::ExpandType, +) -> () { + let tmp1 = { + let _lhs = rhs.clone(); + let _rhs = rhs.clone(); + burn_cube::mul::expand(context, _lhs, _rhs) + }; + let tmp2 = { + let _lhs = tmp1; + let _rhs = rhs; + burn_cube::add::expand(context, _lhs, _rhs) + }; + range_expand(context, 0usize.into(), end, unroll, |context, i| { + { + let _array = lhs.clone(); + let _index = i.clone(); + let _value = { + let _lhs = tmp2.clone(); + let _rhs = { + let _array = lhs.clone(); + let _index = i; + burn_cube::index::expand(context, _array, _index) + }; + burn_cube::add::expand(context, _lhs, _rhs) + }; + burn_cube::index_assign::expand(context, _array, _index, _value) + }; + }); +} + #[test] fn test_simple_add() { let mut context = CubeContext::root(); From 16b454acbb0e94215226e56edf271c551141e5a2 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 22 Apr 2024 09:30:37 -0400 Subject: [PATCH 11/54] wip --- crates/burn-cube/src/context.rs | 2 
+- crates/burn-cube/src/element.rs | 17 +++++++--- crates/burn-cube/tests/cube.rs | 32 ++++++++++++++++++- .../src/codegen/dialect/gpu/macros.rs | 2 ++ crates/burn-jit/src/codegen/mod.rs | 5 ++- crates/burn-jit/src/lib.rs | 2 +- 6 files changed, 49 insertions(+), 11 deletions(-) diff --git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs index 813d23bdfd..3effc70c49 100644 --- a/crates/burn-cube/src/context.rs +++ b/crates/burn-cube/src/context.rs @@ -24,7 +24,7 @@ impl VariablePool { // Arbitrarily takes the first it finds for variable in variables.iter() { if Rc::strong_count(&variable.inner) == 1 { - println!("Reuse var {:?}", variable.inner); + // println!("Reuse var {:?}", variable.inner); return Some(variable.clone()); } } diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index ec64fe42df..62aee734de 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -7,12 +7,12 @@ use burn_jit::gpu::{Item, Variable}; /// have ExpandElement as associated type /// Variables whose values will be known at compile time /// must have the primitive type as associated type -/// -/// Note: Cube functions should be written using RuntimeTypes, -/// so that the code generated uses the associated ExpandType. +/// +/// Note: Cube functions should be written using RuntimeTypes, +/// so that the code generated uses the associated ExpandType. /// This allows Cube code to not necessitate cloning, which is cumbersome -/// in algorithmic code. The necessary cloning will automatically appear in -/// the generated code. +/// in algorithmic code. The necessary cloning will automatically appear in +/// the generated code. pub trait RuntimeType { type ExpandType: Clone; } @@ -133,3 +133,10 @@ impl From for UInt { UInt::new(value as u32, 1) } } + +impl From for Variable { + fn from(value: ExpandElement) -> Self { + // Is it ok to do that? 
+ (*value.inner).clone() + } +} diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs index cedf2aa1f3..ff0e42b3bf 100644 --- a/crates/burn-cube/tests/cube.rs +++ b/crates/burn-cube/tests/cube.rs @@ -1,5 +1,6 @@ use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; -use burn_jit::gpu::{Elem, Item}; +use burn_jit::gpu; +use burn_jit::gpu::{Elem, Item, Variable}; // #[cube] pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { @@ -64,3 +65,32 @@ fn test_simple_add() { panic!("nop"); } + +#[test] +fn gpu_macro_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float); + + let lhs = context.create_local(item); + let rhs = context.create_local(item); + let end = context.create_local(Item::Scalar(Elem::UInt)); + let out = context.create_local(item); + let mut scope = context.into_scope(); + + // Kernel + let tmp1 = scope.create_local(item); + let tmp2 = scope.create_local(item); + let rhs: Variable = rhs.into(); + gpu!(scope, tmp1 = rhs * rhs); + gpu!(scope, tmp2 = tmp1 + rhs); + + for i in range(0usize, end, unroll) { + lhs[i] = tmp2 + lhs[i]; + } + + for op in scope.operations.iter() { + println!("{op:?}"); + } + + panic!("nop"); +} diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 4feef5d0e3..d4d3f1745b 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -1,5 +1,7 @@ use super::Variable; +#[macro_export(local_inner_macros)] +/// Macro for generating JIT intermediate representation, in a concise way macro_rules! 
gpu { // out = lhs + rhs ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => { diff --git a/crates/burn-jit/src/codegen/mod.rs b/crates/burn-jit/src/codegen/mod.rs index 964b2eca03..cd3a3bef6c 100644 --- a/crates/burn-jit/src/codegen/mod.rs +++ b/crates/burn-jit/src/codegen/mod.rs @@ -1,11 +1,10 @@ mod compilation; pub(crate) mod compiler; -pub(crate) mod dialect; +/// Contains Intermediate Representation +pub mod dialect; mod kernel; pub(crate) use compilation::*; pub(crate) use compiler::*; pub(crate) use kernel::*; - -pub use dialect::*; diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index e178eabaed..f8095f7a6b 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -15,7 +15,7 @@ pub mod kernel; /// Tensor module. pub mod tensor; -pub(crate) mod codegen; +pub mod codegen; pub(crate) mod tune; mod element; From d75641275009647c2a11f9f1c37330c84b07a4dc Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 29 Apr 2024 10:01:18 -0400 Subject: [PATCH 12/54] for loop tests --- Cargo.lock | 6 +- crates/burn-cube-macros/src/statement.rs | 5 +- crates/burn-cube/Cargo.toml | 4 +- crates/burn-cube/src/element.rs | 4 + crates/burn-cube/tests/cube.rs | 96 ---------------------- crates/burn-cube/tests/for_loop_dynamic.rs | 57 +++++++++++++ crates/burn-cube/tests/for_loop_static.rs | 73 ++++++++++++++++ crates/burn-cube/tests/gcd.rs | 66 +++++++++++++++ 8 files changed, 207 insertions(+), 104 deletions(-) delete mode 100644 crates/burn-cube/tests/cube.rs create mode 100644 crates/burn-cube/tests/for_loop_dynamic.rs create mode 100644 crates/burn-cube/tests/for_loop_static.rs create mode 100644 crates/burn-cube/tests/gcd.rs diff --git a/Cargo.lock b/Cargo.lock index 3ebbd1aa84..b5920520b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -455,7 +455,7 @@ dependencies = [ [[package]] name = "burn-cube" -version = "0.13.0" +version = "0.14.0" dependencies = [ "burn-cube-macros", "burn-jit", @@ -465,12 +465,12 @@ dependencies = [ [[package]] 
name = "burn-cube-macros" -version = "0.13.0" +version = "0.14.0" dependencies = [ "derive-new", "proc-macro2", "quote", - "syn 2.0.55", + "syn 2.0.60", ] [[package]] diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/statement.rs index 808972813c..44fc188bd1 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/statement.rs @@ -240,16 +240,15 @@ fn codegen_path( .get_ident() .expect("Only ident path are supported."); - // TODO: Check in the following statements if the ident is overriden, or reused. let will_be_used_again = variable_analyses.should_clone(ident, loop_level); if will_be_used_again { quote::quote! { - #ident.clone() + #ident.clone().into() } } else { quote::quote! { - #ident + #ident.into() } } } diff --git a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml index f78beebd48..c8b47dba1f 100644 --- a/crates/burn-cube/Cargo.toml +++ b/crates/burn-cube/Cargo.toml @@ -15,8 +15,8 @@ default = [] std = [] [dependencies] -burn-jit = { path = "../burn-jit", version = "0.13.0", default-features = false, features= ["autotune"] } -burn-cube-macros = { path = "../burn-cube-macros", version = "0.13.0" } +burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features= ["autotune"] } +burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" } derive-new = { workspace = true } log = { workspace = true } diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 62aee734de..783921e66b 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -122,6 +122,10 @@ impl RuntimeType for bool { type ExpandType = bool; } +impl RuntimeType for u32 { + type ExpandType = u32; +} + impl From for UInt { fn from(value: u32) -> Self { UInt::new(value, 1) diff --git a/crates/burn-cube/tests/cube.rs b/crates/burn-cube/tests/cube.rs deleted file mode 100644 index ff0e42b3bf..0000000000 --- 
a/crates/burn-cube/tests/cube.rs +++ /dev/null @@ -1,96 +0,0 @@ -use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; -use burn_jit::gpu; -use burn_jit::gpu::{Elem, Item, Variable}; - -// #[cube] -pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { - let tmp1 = rhs * rhs; - let tmp2 = tmp1 + rhs; - - for i in range(0usize, end, unroll) { - lhs[i] = tmp2 + lhs[i]; - } -} - -#[allow(unused_mut)] -pub fn kernel_expand( - context: &mut burn_cube::CubeContext, - mut lhs: as burn_cube::RuntimeType>::ExpandType, - rhs: ::ExpandType, - end: ::ExpandType, - unroll: ::ExpandType, -) -> () { - let tmp1 = { - let _lhs = rhs.clone(); - let _rhs = rhs.clone(); - burn_cube::mul::expand(context, _lhs, _rhs) - }; - let tmp2 = { - let _lhs = tmp1; - let _rhs = rhs; - burn_cube::add::expand(context, _lhs, _rhs) - }; - range_expand(context, 0usize.into(), end, unroll, |context, i| { - { - let _array = lhs.clone(); - let _index = i.clone(); - let _value = { - let _lhs = tmp2.clone(); - let _rhs = { - let _array = lhs.clone(); - let _index = i; - burn_cube::index::expand(context, _array, _index) - }; - burn_cube::add::expand(context, _lhs, _rhs) - }; - burn_cube::index_assign::expand(context, _array, _index, _value) - }; - }); -} - -#[test] -fn test_simple_add() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::Scalar(Elem::Float)); - let rhs = context.create_local(Item::Scalar(Elem::Float)); - let end = context.create_local(Item::Scalar(Elem::UInt)); - - kernel_expand(&mut context, lhs, rhs, end, false); - let scope = context.into_scope(); - - for op in scope.operations.iter() { - println!("{op:?}"); - } - - panic!("nop"); -} - -#[test] -fn gpu_macro_test() { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float); - - let lhs = context.create_local(item); - let rhs = context.create_local(item); - let end = context.create_local(Item::Scalar(Elem::UInt)); - let out = 
context.create_local(item); - let mut scope = context.into_scope(); - - // Kernel - let tmp1 = scope.create_local(item); - let tmp2 = scope.create_local(item); - let rhs: Variable = rhs.into(); - gpu!(scope, tmp1 = rhs * rhs); - gpu!(scope, tmp2 = tmp1 + rhs); - - for i in range(0usize, end, unroll) { - lhs[i] = tmp2 + lhs[i]; - } - - for op in scope.operations.iter() { - println!("{op:?}"); - } - - panic!("nop"); -} diff --git a/crates/burn-cube/tests/for_loop_dynamic.rs b/crates/burn-cube/tests/for_loop_dynamic.rs new file mode 100644 index 0000000000..23e2dce791 --- /dev/null +++ b/crates/burn-cube/tests/for_loop_dynamic.rs @@ -0,0 +1,57 @@ +use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; +use burn_jit::gpu; +use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { + let tmp1 = rhs * rhs; + let tmp2 = tmp1 + rhs; + + for i in range(0u32, end, unroll) { + lhs[i] = tmp2 + lhs[i]; + } +} + +#[test] +fn test_for_loop() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let end = context.create_local(Item::Scalar(Elem::UInt)); + + kernel_expand(&mut context, lhs, rhs, end, false); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F32)); + + let lhs = context.create_local(item); + let rhs = context.create_local(item); + let lhs: Variable = lhs.into(); + let rhs: Variable = rhs.into(); + let end = context.create_local(Item::Scalar(Elem::UInt)); + let mut scope = context.into_scope(); + + // Kernel + let tmp1 = scope.create_local(item); + let tmp2 = scope.create_local(item); + gpu!(scope, tmp1 = rhs * rhs); + gpu!(scope, tmp2 = tmp1 + rhs); + + 
gpu!( + &mut scope, + range(0usize, end).for_each(|i, scope| { + gpu!(scope, rhs = lhs[i]); + gpu!(scope, tmp1 = tmp2 + rhs); + gpu!(scope, lhs[i] = tmp1); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/for_loop_static.rs b/crates/burn-cube/tests/for_loop_static.rs new file mode 100644 index 0000000000..e859461d2a --- /dev/null +++ b/crates/burn-cube/tests/for_loop_static.rs @@ -0,0 +1,73 @@ +use burn_cube::{cube, range, range_expand, Array, CubeContext, Float}; +use burn_jit::gpu; +use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn kernel(mut lhs: Array, rhs: Float, end: u32, unroll: bool) { + let tmp1 = rhs * rhs; + let tmp2 = tmp1 + rhs; + + for i in range(0u32, end, unroll) { + lhs[i] = tmp2 + lhs[i]; + } +} + +#[test] +fn test_for_loop_with_unroll() { + let mut context = CubeContext::root(); + let unroll = true; + + let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let end = 4u32; + + kernel_expand(&mut context, lhs, rhs, end, unroll); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); +} + +#[test] +fn test_for_loop_no_unroll() { + let mut context = CubeContext::root(); + let unroll = false; + + let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let end = 4u32; + + kernel_expand(&mut context, lhs, rhs, end, unroll); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); +} + +fn gpu_macro_ref(unroll: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F32)); + + let lhs = context.create_local(item); + let rhs = context.create_local(item); + let lhs: Variable = lhs.into(); + let rhs: Variable = rhs.into(); + let end = 4u32; + let mut scope = 
context.into_scope(); + + // Kernel + let tmp1 = scope.create_local(item); + let tmp2 = scope.create_local(item); + gpu!(scope, tmp1 = rhs * rhs); + gpu!(scope, tmp2 = tmp1 + rhs); + + gpu!( + &mut scope, + range(0u32, end, unroll).for_each(|i, scope| { + gpu!(scope, rhs = lhs[i]); + gpu!(scope, tmp1 = tmp2 + rhs); + gpu!(scope, lhs[i] = tmp1); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/gcd.rs b/crates/burn-cube/tests/gcd.rs new file mode 100644 index 0000000000..62a48fa119 --- /dev/null +++ b/crates/burn-cube/tests/gcd.rs @@ -0,0 +1,66 @@ +// use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; +// use burn_jit::gpu; +// use burn_jit::gpu::FloatKind::F32; +// use burn_jit::gpu::{Elem, Item, Variable}; + +// #[cube] +// pub fn gcd(lhs: Float, rhs: Float) { +// let tmp1 = rhs * rhs; +// let tmp2 = tmp1 + rhs; + +// for i in range(0usize, end, unroll) { +// lhs[i] = tmp2 + lhs[i]; +// } +// } + +// #[test] +// fn cube_function_test() { +// let mut context = CubeContext::root(); +// // +// let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); +// let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); +// let end = context.create_local(Item::Scalar(Elem::UInt)); + +// kernel_expand(&mut context, lhs, rhs, end, false); +// let scope = context.into_scope(); + +// let mut ops = String::new(); +// for op in scope.operations.iter() { +// ops.push_str(&format!("{op:?}")); +// } + +// assert_eq!(ops, gpu_macro_ref()); +// } + +// fn gpu_macro_ref() -> String { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Float(F32)); + +// let lhs = context.create_local(item); +// let rhs = context.create_local(item); +// let lhs: Variable = lhs.into(); +// let rhs: Variable = rhs.into(); +// let end = context.create_local(Item::Scalar(Elem::UInt)); +// let mut scope = context.into_scope(); + +// // Kernel +// let tmp1 = scope.create_local(item); +// let tmp2 = 
scope.create_local(item); +// gpu!(scope, tmp1 = rhs * rhs); +// gpu!(scope, tmp2 = tmp1 + rhs); + +// gpu!( +// &mut scope, +// range(0usize, end).for_each(|i, scope| { +// gpu!(scope, rhs = lhs[i]); +// gpu!(scope, tmp1 = tmp2 + rhs); +// gpu!(scope, lhs[i] = tmp1); +// }) +// ); + +// let mut ops = String::new(); +// for op in scope.operations.iter() { +// ops.push_str(&format!("{op:?}")); +// } +// ops +// } From faea9d7151ea31f03578b2f752568d59076dee0b Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 29 Apr 2024 17:01:51 -0400 Subject: [PATCH 13/54] wip --- crates/burn-cube-macros/src/analysis.rs | 44 ++++--- .../src/{statement.rs => codegen.rs} | 61 +++++++-- crates/burn-cube-macros/src/lib.rs | 5 +- crates/burn-cube/src/element.rs | 16 +++ crates/burn-cube/src/operation/assignation.rs | 68 ++++++++++ crates/burn-cube/src/operation/base.rs | 29 ++++ .../src/{operation.rs => operation/binary.rs} | 96 +++----------- crates/burn-cube/src/operation/cmp.rs | 39 ++++++ crates/burn-cube/src/operation/mod.rs | 8 ++ crates/burn-cube/tests/for_loop_dynamic.rs | 1 + crates/burn-cube/tests/gcd.rs | 124 ++++++++---------- crates/burn-cube/tests/literal.rs | 32 +++++ 12 files changed, 355 insertions(+), 168 deletions(-) rename crates/burn-cube-macros/src/{statement.rs => codegen.rs} (80%) create mode 100644 crates/burn-cube/src/operation/assignation.rs create mode 100644 crates/burn-cube/src/operation/base.rs rename crates/burn-cube/src/{operation.rs => operation/binary.rs} (62%) create mode 100644 crates/burn-cube/src/operation/cmp.rs create mode 100644 crates/burn-cube/src/operation/mod.rs create mode 100644 crates/burn-cube/tests/literal.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 04e1d3f73b..0462c45d4d 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -96,10 +96,10 @@ fn signature_declarations(sig: &syn::Signature, declarations: &mut Vec<(Variable let id = 
&pat_ident.ident; declarations.push((id.into(), 0)); } - _ => todo!(), + _ => todo!("Analysis: unsupported ident {ident:?}"), } } - _ => todo!(), + _ => todo!("Analysis: unsupported input {input:?}"), } } } @@ -113,19 +113,22 @@ fn stmts_occurrences( for stmt in stmts { match stmt { // Declaration - syn::Stmt::Local(local) => match &local.pat { - syn::Pat::Ident(pat_ident) => { - let id = &pat_ident.ident; - declarations.push((id.into(), depth)); - - if let Some(local_init) = &local.init { - expr_occurrences(&local_init.expr, depth, declarations, uses) - } + syn::Stmt::Local(local) => { + let id = match &local.pat { + syn::Pat::Ident(pat_ident) => &pat_ident.ident, + syn::Pat::Type(pat_type) => match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => &pat_ident.ident, + _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), + }, + _ => todo!("Analysis: unsupported path {:?}", local.pat), + }; + declarations.push((id.into(), depth)); + if let Some(local_init) = &local.init { + expr_occurrences(&local_init.expr, depth, declarations, uses) } - _ => todo!(), - }, + } syn::Stmt::Expr(expr, _) => expr_occurrences(expr, depth, declarations, uses), - _ => todo!(), + _ => todo!("Analysis: unsupported stmt {stmt:?}"), } } } @@ -148,6 +151,12 @@ fn expr_occurrences( stmts_occurrences(&expr.body.stmts, depth, declarations, uses); } + syn::Expr::While(expr) => { + let depth = depth + 1; + + expr_occurrences(&expr.cond, depth, declarations, uses); + stmts_occurrences(&expr.body.stmts, depth, declarations, uses); + } syn::Expr::Assign(expr) => { expr_occurrences(&expr.left, depth, declarations, uses); expr_occurrences(&expr.right, depth, declarations, uses); @@ -160,7 +169,7 @@ fn expr_occurrences( let ident = expr .path .get_ident() - .expect("Only ident path are supported."); + .expect("Analysis: only ident path are supported."); // Use uses.push(ident.into()); @@ -169,6 +178,11 @@ fn expr_occurrences( expr_occurrences(&expr.left, depth, declarations, uses); 
expr_occurrences(&expr.right, depth, declarations, uses); } - _ => todo!(), + syn::Expr::MethodCall(expr) => { + if expr.args.is_empty() { + panic!("Analysis: method call with args is unsupported") + } + } + _ => todo!("Analysis: unsupported expr {expr:?}"), } } diff --git a/crates/burn-cube-macros/src/statement.rs b/crates/burn-cube-macros/src/codegen.rs similarity index 80% rename from crates/burn-cube-macros/src/statement.rs rename to crates/burn-cube-macros/src/codegen.rs index 44fc188bd1..bf11e0c100 100644 --- a/crates/burn-cube-macros/src/statement.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -34,7 +34,11 @@ fn codegen_local( .expect("Can't use let without an initialization."); let ident = match &local.pat { syn::Pat::Ident(ident) => ident, - _ => panic!("Only ident declaration is supported."), + syn::Pat::Type(pat_type) => match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => pat_ident, + _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), + }, + _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), }; let init = codegen_expr(&init.expr, loop_level, variable_analyses); @@ -54,17 +58,24 @@ fn codegen_expr( syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), syn::Expr::Path(path) => codegen_path(path, loop_level, variable_analyses), syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), - syn::Expr::Lit(lit) => quote::quote! 
{ #lit.into() }, + syn::Expr::Lit(lit) => codegen_lit(lit), syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses), syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses), syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses), syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses), + syn::Expr::While(while_loop) => { + codegen_while_loop(while_loop, loop_level, variable_analyses) + } syn::Expr::MethodCall(call) => codegen_expr_method_call(call), syn::Expr::Index(index) => codegen_expr_index(index, loop_level, variable_analyses), _ => panic!("Unsupported {:?}", expr), } } +fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { + quote::quote! { #lit.into() } +} + fn codegen_expr_index( index: &syn::ExprIndex, loop_level: usize, @@ -98,7 +109,7 @@ fn codegen_for_loop( syn::Expr::Call(call) => { let func_name = match call.func.as_ref() { syn::Expr::Path(path) => path.path.get_ident().unwrap(), - _ => todo!("Only path call supported"), + _ => todo!("Codegen: Only path call supported"), }; if &func_name.to_string() == "range" { @@ -115,10 +126,28 @@ fn codegen_for_loop( range_expand(#args |context, #i| #block); } } else { - todo!("Only range is supported") + todo!("Codegen: Only range is supported") } } - _ => todo!("Only call is supported {for_loop:?}"), + _ => todo!("Codegen: Only call is supported {for_loop:?}"), + } +} + +fn codegen_while_loop( + while_loop: &syn::ExprWhile, + loop_level: usize, + variable_analyses: &mut VariableAnalyses, +) -> TokenStream { + let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); + + let cond = match while_loop.cond.as_ref() { + syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses), + syn::Expr::Lit(expr) => codegen_lit(expr), + _ => todo!("{while_loop:?} cond not supported"), + }; + + quote::quote! 
{ + loop_expand(#cond |context| #block); } } @@ -189,7 +218,7 @@ fn codegen_closure( for input in closure.inputs.iter() { let ident = match input { syn::Pat::Ident(ident) => &ident.ident, - _ => panic!("Unsupported {:?}", input), + _ => panic!("Codegen: Unsupported {:?}", input), }; inputs.extend(quote::quote! { #ident, @@ -210,7 +239,7 @@ fn codegen_call( ) -> TokenStream { let func_name = match call.func.as_ref() { syn::Expr::Path(path) => path.path.get_ident().unwrap(), - _ => todo!("Only path call supported"), + _ => todo!("Codegen: Only path call supported"), }; let mut args = quote::quote! { @@ -238,7 +267,7 @@ fn codegen_path( let ident = path .path .get_ident() - .expect("Only ident path are supported."); + .expect("Codegen: Only ident path are supported."); let will_be_used_again = variable_analyses.should_clone(ident, loop_level); @@ -290,6 +319,20 @@ fn codegen_binary( burn_cube::div::expand(context, _lhs, _rhs) } }, - _ => todo!("{:?}", binary.op), + syn::BinOp::Rem(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::rem::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Ne(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::ne::expand(context, _lhs, _rhs) + } + }, + _ => todo!("Codegen: unsupported op {:?}", binary.op), } } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 3be328789a..8979cd5716 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,9 +1,9 @@ mod analysis; -mod statement; +mod codegen; use analysis::VariableAnalyses; use proc_macro::TokenStream; -use statement::codegen_statement; +use codegen::codegen_statement; /// Derive macro for the module. #[proc_macro_attribute] @@ -32,6 +32,7 @@ fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut VariableAnalyses) -> let signature = expand_sig(&func.sig); let mut body = quote::quote! 
{}; + // panic!("WG"); for statement in func.block.stmts.iter() { let tokens = codegen_statement(statement, 0, variable_analyses); body.extend(tokens); diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 783921e66b..51a79083d6 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -49,6 +49,12 @@ impl From for ExpandElement { } } +impl From for ExpandElement { + fn from(value: f32) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + impl core::ops::Deref for ExpandElement { type Target = Variable; @@ -132,6 +138,16 @@ impl From for UInt { } } +impl RuntimeType for f32 { + type ExpandType = f32; +} + +impl From for Float { + fn from(value: f32) -> Self { + Float::new(value, 1) + } +} + impl From for UInt { fn from(value: usize) -> Self { UInt::new(value as u32, 1) diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs new file mode 100644 index 0000000000..3ac02e7e79 --- /dev/null +++ b/crates/burn-cube/src/operation/assignation.rs @@ -0,0 +1,68 @@ +use crate::{Array, CubeContext, ExpandElement, UInt}; +use burn_jit::gpu::{self}; + +pub mod assign { + use super::*; + + pub fn expand( + context: &mut CubeContext, + input: ExpandElement, + output: ExpandElement, + ) -> ExpandElement { + let input = *input; + let out = *output; + + context.register(gpu::Operator::Assign(gpu::UnaryOperator { input, out })); + + output + } +} + +pub mod index_assign { + use crate::RuntimeType; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + array: ExpandElement, + index: ExpandElement, + value: ExpandElement, + ) { + context.register(gpu::Operator::IndexAssign(gpu::BinaryOperator { + lhs: *index, + rhs: *value, + out: *array, + })) + } + + impl> core::ops::IndexMut for Array { + fn index_mut(&mut self, index: I) -> &mut Self::Output { + let index = index.into().val; + &mut self.vals[index as usize] + } + } +} + +pub mod index { + use 
crate::{operation::base::binary_expand, RuntimeType}; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + array: ExpandElement, + index: ExpandElement, + ) -> ExpandElement { + binary_expand(context, array, index, gpu::Operator::Index) + } + + impl> core::ops::Index for Array { + type Output = E; + + fn index(&self, index: I) -> &Self::Output { + let index = index.into().val; + &self.vals[index as usize] + } + } +} diff --git a/crates/burn-cube/src/operation/base.rs b/crates/burn-cube/src/operation/base.rs new file mode 100644 index 0000000000..59bd48e0d0 --- /dev/null +++ b/crates/burn-cube/src/operation/base.rs @@ -0,0 +1,29 @@ +use crate::{CubeContext, ExpandElement}; +use burn_jit::gpu::{self, Variable}; + +pub(crate) fn binary_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs: Variable = *lhs; + let rhs: Variable = *rhs; + + let item = lhs.item(); + let out = context.create_local(item); + let out_var = *out; + + let op = func(gpu::BinaryOperator { + lhs, + rhs, + out: out_var, + }); + + context.register(op); + + out +} diff --git a/crates/burn-cube/src/operation.rs b/crates/burn-cube/src/operation/binary.rs similarity index 62% rename from crates/burn-cube/src/operation.rs rename to crates/burn-cube/src/operation/binary.rs index 31fae41b5f..6c5c325da2 100644 --- a/crates/burn-cube/src/operation.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -1,5 +1,6 @@ -use crate::{Array, CubeContext, ExpandElement, Float, Int, UInt}; -use burn_jit::gpu::{self, Variable}; +use crate::operation::base::binary_expand; +use crate::{CubeContext, ExpandElement, Float, Int, UInt}; +use burn_jit::gpu::{self}; pub mod add { use super::*; @@ -145,95 +146,38 @@ pub mod div { } } -pub mod assign { +pub mod rem { use super::*; pub fn expand( context: &mut CubeContext, - input: ExpandElement, - output: ExpandElement, + lhs: ExpandElement, + 
rhs: ExpandElement, ) -> ExpandElement { - let input = *input; - let out = *output; - - context.register(gpu::Operator::Assign(gpu::UnaryOperator { input, out })); - - output + binary_expand(context, lhs, rhs, gpu::Operator::Modulo) } -} - -pub mod index { - use crate::RuntimeType; - use super::*; + impl core::ops::Rem for Float { + type Output = Self; - pub fn expand( - context: &mut CubeContext, - array: ExpandElement, - index: ExpandElement, - ) -> ExpandElement { - binary_expand(context, array, index, gpu::Operator::Index) + fn rem(self, rhs: Self) -> Self::Output { + Float::new(self.val % rhs.val, 1) + } } - impl> core::ops::Index for Array { - type Output = E; + impl core::ops::Rem for Int { + type Output = Self; - fn index(&self, index: I) -> &Self::Output { - let index = index.into().val; - &self.vals[index as usize] + fn rem(self, rhs: Self) -> Self::Output { + Int::new(self.val % rhs.val, 1) } } -} -pub mod index_assign { - use crate::RuntimeType; - - use super::*; + impl core::ops::Rem for UInt { + type Output = Self; - pub fn expand( - context: &mut CubeContext, - array: ExpandElement, - index: ExpandElement, - value: ExpandElement, - ) { - context.register(gpu::Operator::IndexAssign(gpu::BinaryOperator { - lhs: *index, - rhs: *value, - out: *array, - })) - } - - impl> core::ops::IndexMut for Array { - fn index_mut(&mut self, index: I) -> &mut Self::Output { - let index = index.into().val; - &mut self.vals[index as usize] + fn rem(self, rhs: Self) -> Self::Output { + UInt::new(self.val % rhs.val, 1) } } } - -fn binary_expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - func: F, -) -> ExpandElement -where - F: Fn(gpu::BinaryOperator) -> gpu::Operator, -{ - let lhs: Variable = *lhs; - let rhs: Variable = *rhs; - - let item = lhs.item(); - let out = context.create_local(item); - let out_var = *out; - - let op = func(gpu::BinaryOperator { - lhs, - rhs, - out: out_var, - }); - - context.register(op); - - out -} diff --git 
a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs new file mode 100644 index 0000000000..ad966c61ec --- /dev/null +++ b/crates/burn-cube/src/operation/cmp.rs @@ -0,0 +1,39 @@ +use crate::operation::base::binary_expand; +use crate::{CubeContext, ExpandElement, Float, Int, UInt}; +use burn_jit::gpu::{self}; + +impl core::cmp::PartialEq for Float { + fn eq(&self, other: &Self) -> bool { + self.val == other.val && self.vectorization == other.vectorization + } +} + +impl core::cmp::PartialEq for Int { + fn eq(&self, other: &Self) -> bool { + self.val == other.val && self.vectorization == other.vectorization + } +} + +impl core::cmp::PartialEq for UInt { + fn eq(&self, other: &Self) -> bool { + self.val == other.val && self.vectorization == other.vectorization + } +} + +impl core::cmp::Eq for Float {} + +impl core::cmp::Eq for Int {} + +impl core::cmp::Eq for UInt {} + +pub mod ne { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::NotEqual) + } +} diff --git a/crates/burn-cube/src/operation/mod.rs b/crates/burn-cube/src/operation/mod.rs new file mode 100644 index 0000000000..fc4d8c0abe --- /dev/null +++ b/crates/burn-cube/src/operation/mod.rs @@ -0,0 +1,8 @@ +mod base; +mod assignation; +mod binary; +mod cmp; + +pub use assignation::*; +pub use binary::*; +pub use cmp::*; diff --git a/crates/burn-cube/tests/for_loop_dynamic.rs b/crates/burn-cube/tests/for_loop_dynamic.rs index 23e2dce791..4dd8634ffb 100644 --- a/crates/burn-cube/tests/for_loop_dynamic.rs +++ b/crates/burn-cube/tests/for_loop_dynamic.rs @@ -20,6 +20,7 @@ fn test_for_loop() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); let end = context.create_local(Item::Scalar(Elem::UInt)); + // let end =Variable::ConstantScalar(5.0, Item::Scalar(Elem::UInt)); kernel_expand(&mut 
context, lhs, rhs, end, false); let scope = context.into_scope(); diff --git a/crates/burn-cube/tests/gcd.rs b/crates/burn-cube/tests/gcd.rs index 62a48fa119..202e110657 100644 --- a/crates/burn-cube/tests/gcd.rs +++ b/crates/burn-cube/tests/gcd.rs @@ -1,66 +1,58 @@ -// use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; -// use burn_jit::gpu; -// use burn_jit::gpu::FloatKind::F32; -// use burn_jit::gpu::{Elem, Item, Variable}; - -// #[cube] -// pub fn gcd(lhs: Float, rhs: Float) { -// let tmp1 = rhs * rhs; -// let tmp2 = tmp1 + rhs; - -// for i in range(0usize, end, unroll) { -// lhs[i] = tmp2 + lhs[i]; -// } -// } - -// #[test] -// fn cube_function_test() { -// let mut context = CubeContext::root(); -// // -// let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); -// let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); -// let end = context.create_local(Item::Scalar(Elem::UInt)); - -// kernel_expand(&mut context, lhs, rhs, end, false); -// let scope = context.into_scope(); - -// let mut ops = String::new(); -// for op in scope.operations.iter() { -// ops.push_str(&format!("{op:?}")); -// } - -// assert_eq!(ops, gpu_macro_ref()); -// } - -// fn gpu_macro_ref() -> String { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Float(F32)); - -// let lhs = context.create_local(item); -// let rhs = context.create_local(item); -// let lhs: Variable = lhs.into(); -// let rhs: Variable = rhs.into(); -// let end = context.create_local(Item::Scalar(Elem::UInt)); -// let mut scope = context.into_scope(); - -// // Kernel -// let tmp1 = scope.create_local(item); -// let tmp2 = scope.create_local(item); -// gpu!(scope, tmp1 = rhs * rhs); -// gpu!(scope, tmp2 = tmp1 + rhs); - -// gpu!( -// &mut scope, -// range(0usize, end).for_each(|i, scope| { -// gpu!(scope, rhs = lhs[i]); -// gpu!(scope, tmp1 = tmp2 + rhs); -// gpu!(scope, lhs[i] = tmp1); -// }) -// ); - -// let mut ops = String::new(); -// for op in 
scope.operations.iter() { -// ops.push_str(&format!("{op:?}")); -// } -// ops -// } +use burn_cube::{cube, CubeContext, Int}; +use burn_jit::gpu; +use burn_jit::gpu::Branch; +use burn_jit::gpu::IntKind::I32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn gcd(lhs: Int, rhs: Int) { + while rhs != 0u32 { + let tmp = rhs; + rhs = lhs % rhs; + lhs = tmp + } + // use lhs as output +} + +#[test] +fn cube_function_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + let rhs = context.create_local(Item::Scalar(Elem::Int(I32))); + + gcd_expand(&mut context, lhs, rhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32)); + + let lhs = context.create_local(item); + let rhs = context.create_local(item); + let lhs: Variable = lhs.into(); + let rhs: Variable = rhs.into(); + let mut scope = context.into_scope(); + + // Kernel + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let tmp = scope.create_local(Item::Scalar(Elem::Int(I32))); + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = rhs != 0); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, tmp = rhs); + gpu!(scope, rhs = lhs % rhs); + gpu!(scope, lhs = tmp); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs new file mode 100644 index 0000000000..a24ae9af37 --- /dev/null +++ b/crates/burn-cube/tests/literal.rs @@ -0,0 +1,32 @@ +use burn_cube::{cube, CubeContext, Float}; +use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn literal(lhs: Float) { + let rhs: Float = 5.9f32.into(); +} + +#[test] +fn cube_literal_test() { + let mut context = CubeContext::root(); + + let lhs = 
context.create_local(Item::Scalar(Elem::Float(F32))); + + literal_expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F32)); + + let lhs = context.create_local(item); + let lhs: Variable = lhs.into(); + let mut scope = context.into_scope(); + scope.create_with_value(5.9, item); + + format!("{:?}", scope.operations) +} From 709522e4b9b1dc8194ad97da8d2e86df70e11207 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 29 Apr 2024 19:20:33 -0400 Subject: [PATCH 14/54] refactor identify leaves and deletables --- .../src/runtime/memory_management.rs | 62 ++++++++++++------- .../src/tests/memory_management.rs | 22 +++++++ 2 files changed, 61 insertions(+), 23 deletions(-) diff --git a/crates/burn-autodiff/src/runtime/memory_management.rs b/crates/burn-autodiff/src/runtime/memory_management.rs index 134dd461d8..69f0e0d679 100644 --- a/crates/burn-autodiff/src/runtime/memory_management.rs +++ b/crates/burn-autodiff/src/runtime/memory_management.rs @@ -161,35 +161,51 @@ impl GraphMemoryManagement { } } - fn is_referenced(&self, node_id: NodeID) -> bool { - match self.nodes.get_key_value(&node_id) { - Some((key, _value)) => Arc::strong_count(key) > 1, - None => panic!("Node should be in the nodes map"), - } - } - fn identify_leaves_and_deletables( &self, - node_id: NodeID, + leaf_id: NodeID, new_leaves: &mut HashSet, to_delete: &mut Vec, ) { - let current_status = self - .statuses - .get(&node_id) - .expect("Node should have status"); - - match current_status { - NodeMemoryStatus::Useful => { - new_leaves.insert(node_id); - } - _ => { - let parents = self.nodes.get(&node_id).cloned().unwrap_or(vec![]); - for parent in parents { - self.identify_leaves_and_deletables(parent, new_leaves, to_delete) + let mut visited = HashSet::new(); + let mut to_visit = Vec::new(); + + 
to_visit.push(leaf_id); + + while let Some(node_id) = to_visit.pop() { + visited.insert(node_id); + + match self + .statuses + .get(&node_id) + .expect("Node should have status") + { + NodeMemoryStatus::Useful => { + new_leaves.insert(node_id); } - to_delete.push(node_id); - } + _ => { + to_delete.push(node_id); + + for parent in self + .nodes + .get(&node_id) + .cloned() + .unwrap_or(vec![]) + .into_iter() + { + if !visited.contains(&parent) { + to_visit.push(parent); + } + } + } + }; + } + } + + fn is_referenced(&self, node_id: NodeID) -> bool { + match self.nodes.get_key_value(&node_id) { + Some((key, _value)) => Arc::strong_count(key) > 1, + None => panic!("Node should be in the nodes map"), } } } diff --git a/crates/burn-autodiff/src/tests/memory_management.rs b/crates/burn-autodiff/src/tests/memory_management.rs index 716afd5e9d..4f3495c75b 100644 --- a/crates/burn-autodiff/src/tests/memory_management.rs +++ b/crates/burn-autodiff/src/tests/memory_management.rs @@ -238,4 +238,26 @@ mod tests { assert!(tensor_2.grad(&grads).is_some()); assert!(tensor_3.grad(&grads).is_none()); } + + #[test] + #[should_panic] + fn test_mm_deletables_propagate_well() { + let data = Data::from([[1.0, 2.0], [3.0, 4.0]]); + let device = Default::default(); + + let tensor_0 = + Tensor::::from_data(data.clone(), &device).require_grad(); + let tensor_1 = + Tensor::::from_data(data.clone(), &device).require_grad(); + + let tensor_2 = tensor_0 * tensor_1; + let tensor_3 = tensor_2.clone().exp(); + let tensor_4 = tensor_3.clone().log(); + + let grads = tensor_2.backward(); + + // We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but + // the intermediate tensor_3 as well + let grads = tensor_3.backward(); + } } From 7a255256553c5a77170c6a976ef9a90a7145a888 Mon Sep 17 00:00:00 2001 From: louisfd Date: Tue, 30 Apr 2024 20:06:34 -0400 Subject: [PATCH 15/54] wip --- crates/burn-cube-macros/src/analysis.rs | 39 +++++--- 
crates/burn-cube-macros/src/codegen.rs | 54 +++++++---- crates/burn-cube-macros/src/lib.rs | 13 ++- crates/burn-cube-macros/src/prelude.rs | 53 +++++++++++ crates/burn-cube/src/element.rs | 18 +++- crates/burn-cube/tests/gcd.rs | 116 ++++++++++++------------ crates/burn-cube/tests/literal.rs | 11 ++- 7 files changed, 204 insertions(+), 100 deletions(-) create mode 100644 crates/burn-cube-macros/src/prelude.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 0462c45d4d..2f03c02187 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -22,17 +22,18 @@ impl VariableAnalysis { } #[derive(Debug)] -pub(crate) struct VariableAnalyses { - pub analyses: HashMap, +pub(crate) struct CodeAnalysis { + pub needed_functions: Vec, + pub variable_analyses: HashMap, } -impl VariableAnalyses { +impl CodeAnalysis { pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool { let key: VariableKey = ident.into(); - match self.analyses.remove(&key) { + match self.variable_analyses.remove(&key) { Some(mut var) => { let should_clone = var.should_clone(loop_level); - self.analyses.insert(key, var); + self.variable_analyses.insert(key, var); return should_clone; } None => panic!("Ident {ident} not part of analysis"), @@ -41,8 +42,12 @@ impl VariableAnalyses { pub fn create(func: &syn::ItemFn) -> Self { let analyses = analyze(func); + // TODO WE WANT TO KEEP TRACK OF USED FUNCTIONS AS WELL + // AT THIS POINT LET'S MAKE A BUILDER PATTERN INSTEAD - Self { analyses } + Self { + variable_analyses: analyses, + } } } @@ -115,14 +120,17 @@ fn stmts_occurrences( // Declaration syn::Stmt::Local(local) => { let id = match &local.pat { - syn::Pat::Ident(pat_ident) => &pat_ident.ident, - syn::Pat::Type(pat_type) => match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), + syn::Pat::Type(pat_type) => Some(match &*pat_type.pat { syn::Pat::Ident(pat_ident) => 
&pat_ident.ident, _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), - }, + }), + syn::Pat::Wild(_) => None, _ => todo!("Analysis: unsupported path {:?}", local.pat), }; - declarations.push((id.into(), depth)); + if let Some(id) = id { + declarations.push((id.into(), depth)); + } if let Some(local_init) = &local.init { expr_occurrences(&local_init.expr, depth, declarations, uses) } @@ -178,9 +186,16 @@ fn expr_occurrences( expr_occurrences(&expr.left, depth, declarations, uses); expr_occurrences(&expr.right, depth, declarations, uses); } + syn::Expr::Lit(_) => {} + syn::Expr::Call(expr) => { + for arg in expr.args.iter() { + expr_occurrences(arg, depth, declarations, uses); + } + } syn::Expr::MethodCall(expr) => { - if expr.args.is_empty() { - panic!("Analysis: method call with args is unsupported") + expr_occurrences(&expr.receiver, depth, declarations, uses); + for arg in expr.args.iter() { + expr_occurrences(arg, depth, declarations, uses); } } _ => todo!("Analysis: unsupported expr {expr:?}"), diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index bf11e0c100..95f7138f7a 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -1,11 +1,11 @@ use proc_macro2::TokenStream; -use crate::analysis::VariableAnalyses; +use crate::analysis::CodeAnalysis; pub fn codegen_statement( statement: &syn::Stmt, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { match statement { syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), @@ -26,24 +26,32 @@ pub fn codegen_statement( fn codegen_local( local: &syn::Local, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let init = local .init .as_ref() .expect("Can't use let without an initialization."); + + let init = codegen_expr(&init.expr, loop_level, 
variable_analyses); + + let let_tok = local.let_token; + + if let syn::Pat::Wild(_) = &local.pat { + return quote::quote! { + #let_tok _ = #init; + }; + } + let ident = match &local.pat { syn::Pat::Ident(ident) => ident, syn::Pat::Type(pat_type) => match &*pat_type.pat { syn::Pat::Ident(pat_ident) => pat_ident, _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), }, + syn::Pat::Wild(_) => unreachable!(), _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), }; - let init = codegen_expr(&init.expr, loop_level, variable_analyses); - - let let_tok = local.let_token; - quote::quote! { #let_tok #ident = #init; } @@ -52,7 +60,7 @@ fn codegen_local( fn codegen_expr( expr: &syn::Expr, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { match expr { syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), @@ -79,7 +87,7 @@ fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { fn codegen_expr_index( index: &syn::ExprIndex, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let array = codegen_expr(&index.expr, loop_level, variable_analyses); let index = codegen_expr(&index.index, loop_level, variable_analyses); @@ -100,7 +108,7 @@ fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { fn codegen_for_loop( for_loop: &syn::ExprForLoop, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let i = &for_loop.pat; let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses); @@ -108,7 +116,10 @@ fn codegen_for_loop( match for_loop.expr.as_ref() { syn::Expr::Call(call) => { let func_name = match call.func.as_ref() { - syn::Expr::Path(path) => path.path.get_ident().unwrap(), + syn::Expr::Path(path) => path + .path + .get_ident() + .expect("Codegen: func in for loop should have ident"), _ => 
todo!("Codegen: Only path call supported"), }; @@ -136,7 +147,7 @@ fn codegen_for_loop( fn codegen_while_loop( while_loop: &syn::ExprWhile, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); @@ -154,7 +165,7 @@ fn codegen_while_loop( fn codegen_assign( assign: &syn::ExprAssign, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { if let syn::Expr::Index(index) = assign.left.as_ref() { let array = codegen_expr(&index.expr, loop_level, variable_analyses); @@ -186,7 +197,7 @@ fn codegen_assign( fn codegen_block( block: &syn::Block, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let mut statements = quote::quote!(); @@ -204,7 +215,7 @@ fn codegen_block( fn codegen_expr_block( block: &syn::ExprBlock, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { codegen_block(&block.block, loop_level, variable_analyses) } @@ -212,7 +223,7 @@ fn codegen_expr_block( fn codegen_closure( closure: &syn::ExprClosure, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let mut inputs = quote::quote! 
{}; for input in closure.inputs.iter() { @@ -235,10 +246,13 @@ fn codegen_closure( fn codegen_call( call: &syn::ExprCall, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let func_name = match call.func.as_ref() { - syn::Expr::Path(path) => path.path.get_ident().unwrap(), + syn::Expr::Path(path) => path + .path + .get_ident() + .expect("Codegen: func called path should have ident"), _ => todo!("Codegen: Only path call supported"), }; @@ -262,7 +276,7 @@ fn codegen_call( fn codegen_path( path: &syn::ExprPath, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let ident = path .path @@ -285,7 +299,7 @@ fn codegen_path( fn codegen_binary( binary: &syn::ExprBinary, loop_level: usize, - variable_analyses: &mut VariableAnalyses, + variable_analyses: &mut CodeAnalysis, ) -> TokenStream { let lhs = codegen_expr(&binary.left, loop_level, variable_analyses); let rhs = codegen_expr(&binary.right, loop_level, variable_analyses); diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 8979cd5716..10ed4f0488 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,15 +1,17 @@ mod analysis; mod codegen; +mod prelude; -use analysis::VariableAnalyses; -use proc_macro::TokenStream; +use analysis::CodeAnalysis; use codegen::codegen_statement; +use prelude::get_prelude; +use proc_macro::TokenStream; /// Derive macro for the module. 
#[proc_macro_attribute] pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { let func: syn::ItemFn = syn::parse(tokens).unwrap(); - let mut variable_analyses = VariableAnalyses::create(&func); + let mut variable_analyses = CodeAnalysis::create(&func); codegen_cube(&func, &mut variable_analyses) } @@ -28,7 +30,8 @@ impl From<&syn::Ident> for VariableKey { } /// Generate the expanded version of a function marked with the cube macro -fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut VariableAnalyses) -> TokenStream { +fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut CodeAnalysis) -> TokenStream { + let prelude = get_prelude(); let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; @@ -39,6 +42,8 @@ fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut VariableAnalyses) -> } let code = quote::quote! { + #prelude + #func #[allow(unused_mut)] diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs new file mode 100644 index 0000000000..1a07c42f2a --- /dev/null +++ b/crates/burn-cube-macros/src/prelude.rs @@ -0,0 +1,53 @@ +// TODO: in analysis, identify needed functions from prelude and add only those + +pub(crate) fn get_prelude(needed_functions: Vec<&str>) -> proc_macro2::TokenStream { + let mut prelude = proc_macro2::TokenStream::new(); + + for func_name in needed_functions { + let func_code = match func_name { + "float_new" => Some(codegen_float_new()), + "int_new" => Some(codegen_int_new()), + _ => None, + }; + + if func_code.is_some() { + prelude.extend(func_code); + } + } + + prelude +} + +fn codegen_float_new() -> proc_macro2::TokenStream { + quote::quote! { + pub fn float_new(val: f32) -> Float { + Float { + val, + vectorization: 1, + } + } + pub fn float_new_expand( + context: &mut CubeContext, + val: f32, + ) -> ::ExpandType { + val.into() + } + } +} + +fn codegen_int_new() -> proc_macro2::TokenStream { + quote::quote! 
{ + pub fn int_new(val: i32) -> Int { + Int { + val, + vectorization: 1, + } + } + pub fn int_new_expand( + context: &mut CubeContext, + val: i32, + ) -> ::ExpandType { + val.into() + } + } +} diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 51a79083d6..01eecf16f5 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -55,6 +55,12 @@ impl From for ExpandElement { } } +impl From for ExpandElement { + fn from(value: i32) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + impl core::ops::Deref for ExpandElement { type Target = Variable; @@ -71,7 +77,7 @@ pub struct Float { #[derive(new, Clone, Copy)] pub struct Int { - pub val: u32, + pub val: i32, pub vectorization: u8, } @@ -154,6 +160,16 @@ impl From for UInt { } } +impl RuntimeType for i32 { + type ExpandType = i32; +} + +impl From for Int { + fn from(value: i32) -> Self { + Int::new(value, 1) + } +} + impl From for Variable { fn from(value: ExpandElement) -> Self { // Is it ok to do that? 
diff --git a/crates/burn-cube/tests/gcd.rs b/crates/burn-cube/tests/gcd.rs index 202e110657..240f93b67e 100644 --- a/crates/burn-cube/tests/gcd.rs +++ b/crates/burn-cube/tests/gcd.rs @@ -1,58 +1,58 @@ -use burn_cube::{cube, CubeContext, Int}; -use burn_jit::gpu; -use burn_jit::gpu::Branch; -use burn_jit::gpu::IntKind::I32; -use burn_jit::gpu::{Elem, Item, Variable}; - -#[cube] -pub fn gcd(lhs: Int, rhs: Int) { - while rhs != 0u32 { - let tmp = rhs; - rhs = lhs % rhs; - lhs = tmp - } - // use lhs as output -} - -#[test] -fn cube_function_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - let rhs = context.create_local(Item::Scalar(Elem::Int(I32))); - - gcd_expand(&mut context, lhs, rhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -} - -fn gpu_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32)); - - let lhs = context.create_local(item); - let rhs = context.create_local(item); - let lhs: Variable = lhs.into(); - let rhs: Variable = rhs.into(); - let mut scope = context.into_scope(); - - // Kernel - let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let tmp = scope.create_local(Item::Scalar(Elem::Int(I32))); - gpu!( - &mut scope, - loop(|scope| { - gpu!(scope, cond = rhs != 0); - gpu!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - gpu!(scope, tmp = rhs); - gpu!(scope, rhs = lhs % rhs); - gpu!(scope, lhs = tmp); - }) - ); - - format!("{:?}", scope.operations) -} +// use burn_cube::cube; +// use burn_jit::gpu; +// use burn_jit::gpu::Branch; +// use burn_jit::gpu::IntKind::I32; +// use burn_jit::gpu::{Elem, Item, Variable}; + +// #[cube] +// pub fn gcd(lhs: Int, rhs: Int) { +// while rhs != int_new(0) { +// let tmp = rhs; +// rhs = lhs % rhs; +// lhs = tmp +// } +// // TODO: use lhs as output +// } + +// #[test] +// fn cube_function_test() { +// let mut 
context = CubeContext::root(); + +// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); +// let rhs = context.create_local(Item::Scalar(Elem::Int(I32))); + +// gcd_expand(&mut context, lhs, rhs); +// let scope = context.into_scope(); + +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +// } + +// fn gpu_macro_ref() -> String { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Int(I32)); + +// let lhs = context.create_local(item); +// let rhs = context.create_local(item); +// let lhs: Variable = lhs.into(); +// let rhs: Variable = rhs.into(); +// let mut scope = context.into_scope(); + +// // Kernel +// let cond = scope.create_local(Item::Scalar(Elem::Bool)); +// let tmp = scope.create_local(Item::Scalar(Elem::Int(I32))); +// gpu!( +// &mut scope, +// loop(|scope| { +// gpu!(scope, cond = rhs != 0); +// gpu!(scope, if(cond).then(|scope|{ +// scope.register(Branch::Break); +// })); + +// gpu!(scope, tmp = rhs); +// gpu!(scope, rhs = lhs % rhs); +// gpu!(scope, lhs = tmp); +// }) +// ); + +// format!("{:?}", scope.operations) +// } diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index a24ae9af37..1293d0c943 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,10 +1,11 @@ use burn_cube::{cube, CubeContext, Float}; +use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; -use burn_jit::gpu::{Elem, Item, Variable}; +use burn_jit::gpu::{Elem, Item}; #[cube] pub fn literal(lhs: Float) { - let rhs: Float = 5.9f32.into(); + let _ = lhs + float_new(5.9); } #[test] @@ -22,11 +23,11 @@ fn cube_literal_test() { fn gpu_macro_ref() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Float(F32)); - let lhs = context.create_local(item); - let lhs: Variable = lhs.into(); + let mut scope = context.into_scope(); - scope.create_with_value(5.9, item); + let out = scope.create_local(item); + gpu!(scope, out = lhs + 5.9f32); 
format!("{:?}", scope.operations) } From 99fe8d8e0b2e20e33f95ce46886fb0c615556423 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 2 May 2024 10:33:06 -0400 Subject: [PATCH 16/54] prelude working --- crates/burn-cube-macros/src/analysis.rs | 282 ++++++++++++------------ crates/burn-cube-macros/src/lib.rs | 6 +- crates/burn-cube-macros/src/prelude.rs | 11 +- 3 files changed, 157 insertions(+), 142 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 2f03c02187..a66cc507cf 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use syn::Stmt; @@ -23,10 +23,17 @@ impl VariableAnalysis { #[derive(Debug)] pub(crate) struct CodeAnalysis { - pub needed_functions: Vec, + pub needed_functions: HashSet, pub variable_analyses: HashMap, } +#[derive(Debug, Default)] +pub(crate) struct CodeAnalysisBuilder { + declarations: Vec<(VariableKey, usize)>, + var_uses: Vec, + function_calls: HashSet, +} + impl CodeAnalysis { pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool { let key: VariableKey = ident.into(); @@ -40,164 +47,165 @@ impl CodeAnalysis { } } - pub fn create(func: &syn::ItemFn) -> Self { - let analyses = analyze(func); - // TODO WE WANT TO KEEP TRACK OF USED FUNCTIONS AS WELL - // AT THIS POINT LET'S MAKE A BUILDER PATTERN INSTEAD - - Self { - variable_analyses: analyses, - } + pub fn create(func: &syn::ItemFn) -> CodeAnalysis { + let code_analysis_builder = CodeAnalysisBuilder::default(); + code_analysis_builder.analyze(func) } } -pub(crate) fn analyze(func: &syn::ItemFn) -> HashMap { - // Build the vector of (Id, depth), using recursion - let mut declarations = Vec::new(); - signature_declarations(&func.sig, &mut declarations); - - let mut var_uses = Vec::new(); - stmts_occurrences(&func.block.stmts, 0, &mut declarations, &mut var_uses); - - // Run 
through the vec and build hashmap, without recursion - let mut analyses = HashMap::::new(); - for declaration in declarations.into_iter() { - let id = declaration.0; - let new_analysis = match analyses.remove(&id) { - Some(_) => { - panic!("Multiple variables with the same identifier is not supported") - } - None => VariableAnalysis { - num_used: 0, - loop_level_declared: declaration.1, - }, - }; +impl CodeAnalysisBuilder { + fn analyze(mut self, func: &syn::ItemFn) -> CodeAnalysis { + // Build the vector of (Id, depth), using recursion + self.signature_declarations(&func.sig); + self.stmts_occurrences(&func.block.stmts, 0); - analyses.insert(id, new_analysis); + CodeAnalysis { + variable_analyses: self.to_map(), + needed_functions: self.function_calls, + } } - for id in var_uses.into_iter() { - let prev_analysis = analyses.remove(&id).expect(&format!( - "Variable {:?} should be declared before it's used", - id - )); - let new_analysis = VariableAnalysis { - num_used: prev_analysis.num_used + 1, - loop_level_declared: prev_analysis.loop_level_declared, - }; - analyses.insert(id, new_analysis); - } + fn to_map(&self) -> HashMap { + // Run through the vec and build hashmap, without recursion + let mut variable_analyses = HashMap::::new(); + for declaration in self.declarations.iter() { + let id = declaration.0.clone(); + let new_analysis = match variable_analyses.remove(&id) { + Some(_) => { + panic!("Multiple variables with the same identifier is not supported") + } + None => VariableAnalysis { + num_used: 0, + loop_level_declared: declaration.1, + }, + }; - analyses -} + variable_analyses.insert(id, new_analysis); + } -fn signature_declarations(sig: &syn::Signature, declarations: &mut Vec<(VariableKey, usize)>) { - for input in &sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = &*pat.pat; - match ident { - syn::Pat::Ident(pat_ident) => { - let id = &pat_ident.ident; - declarations.push((id.into(), 0)); + for id in self.var_uses.iter() { + let 
prev_analysis = variable_analyses.remove(&id).expect(&format!( + "Variable {:?} should be declared before it's used", + id + )); + let new_analysis = VariableAnalysis { + num_used: prev_analysis.num_used + 1, + loop_level_declared: prev_analysis.loop_level_declared, + }; + variable_analyses.insert(id.clone(), new_analysis); + } + + variable_analyses + } + + fn signature_declarations(&mut self, sig: &syn::Signature) { + for input in &sig.inputs { + match input { + syn::FnArg::Typed(pat) => { + let ident = &*pat.pat; + match ident { + syn::Pat::Ident(pat_ident) => { + let id = &pat_ident.ident; + self.declarations.push((id.into(), 0)); + } + _ => todo!("Analysis: unsupported ident {ident:?}"), } - _ => todo!("Analysis: unsupported ident {ident:?}"), } + _ => todo!("Analysis: unsupported input {input:?}"), } - _ => todo!("Analysis: unsupported input {input:?}"), } } -} -fn stmts_occurrences( - stmts: &Vec, - depth: usize, - declarations: &mut Vec<(VariableKey, usize)>, - uses: &mut Vec, -) { - for stmt in stmts { - match stmt { - // Declaration - syn::Stmt::Local(local) => { - let id = match &local.pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - syn::Pat::Type(pat_type) => Some(match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => &pat_ident.ident, - _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), - }), - syn::Pat::Wild(_) => None, - _ => todo!("Analysis: unsupported path {:?}", local.pat), - }; - if let Some(id) = id { - declarations.push((id.into(), depth)); - } - if let Some(local_init) = &local.init { - expr_occurrences(&local_init.expr, depth, declarations, uses) + fn stmts_occurrences(&mut self, stmts: &Vec, depth: usize) { + for stmt in stmts { + match stmt { + // Declaration + syn::Stmt::Local(local) => { + let id = match &local.pat { + syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), + syn::Pat::Type(pat_type) => Some(match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => &pat_ident.ident, + _ => todo!("Analysis: 
unsupported typed path {:?}", pat_type.pat), + }), + syn::Pat::Wild(_) => None, + _ => todo!("Analysis: unsupported path {:?}", local.pat), + }; + if let Some(id) = id { + self.declarations.push((id.into(), depth)); + } + if let Some(local_init) = &local.init { + self.expr_occurrences(&local_init.expr, depth) + } } + syn::Stmt::Expr(expr, _) => self.expr_occurrences(expr, depth), + _ => todo!("Analysis: unsupported stmt {stmt:?}"), } - syn::Stmt::Expr(expr, _) => expr_occurrences(expr, depth, declarations, uses), - _ => todo!("Analysis: unsupported stmt {stmt:?}"), } } -} -fn expr_occurrences( - expr: &syn::Expr, - depth: usize, - declarations: &mut Vec<(VariableKey, usize)>, - uses: &mut Vec, -) { - match expr { - syn::Expr::ForLoop(expr) => { - let depth = depth + 1; - - // Declaration of iterator - if let syn::Pat::Ident(pat_ident) = &*expr.pat { - let id = &pat_ident.ident; - declarations.push((id.into(), depth)); - } + fn expr_occurrences(&mut self, expr: &syn::Expr, depth: usize) { + match expr { + syn::Expr::ForLoop(expr) => { + let depth = depth + 1; - stmts_occurrences(&expr.body.stmts, depth, declarations, uses); - } - syn::Expr::While(expr) => { - let depth = depth + 1; + // Declaration of iterator + if let syn::Pat::Ident(pat_ident) = &*expr.pat { + let id = &pat_ident.ident; + self.declarations.push((id.into(), depth)); + } - expr_occurrences(&expr.cond, depth, declarations, uses); - stmts_occurrences(&expr.body.stmts, depth, declarations, uses); - } - syn::Expr::Assign(expr) => { - expr_occurrences(&expr.left, depth, declarations, uses); - expr_occurrences(&expr.right, depth, declarations, uses); - } - syn::Expr::Index(expr) => { - expr_occurrences(&expr.expr, depth, declarations, uses); - expr_occurrences(&expr.index, depth, declarations, uses); - } - syn::Expr::Path(expr) => { - let ident = expr - .path - .get_ident() - .expect("Analysis: only ident path are supported."); - - // Use - uses.push(ident.into()); - } - syn::Expr::Binary(expr) => { - 
expr_occurrences(&expr.left, depth, declarations, uses); - expr_occurrences(&expr.right, depth, declarations, uses); - } - syn::Expr::Lit(_) => {} - syn::Expr::Call(expr) => { - for arg in expr.args.iter() { - expr_occurrences(arg, depth, declarations, uses); + self.stmts_occurrences(&expr.body.stmts, depth); } - } - syn::Expr::MethodCall(expr) => { - expr_occurrences(&expr.receiver, depth, declarations, uses); - for arg in expr.args.iter() { - expr_occurrences(arg, depth, declarations, uses); + syn::Expr::While(expr) => { + let depth = depth + 1; + + self.expr_occurrences(&expr.cond, depth); + self.stmts_occurrences(&expr.body.stmts, depth); + } + syn::Expr::Assign(expr) => { + self.expr_occurrences(&expr.left, depth); + self.expr_occurrences(&expr.right, depth); + } + syn::Expr::Index(expr) => { + self.expr_occurrences(&expr.expr, depth); + self.expr_occurrences(&expr.index, depth); + } + syn::Expr::Path(expr) => { + let ident = expr + .path + .get_ident() + .expect("Analysis: only ident path are supported."); + + // Use + self.var_uses.push(ident.into()); + } + syn::Expr::Binary(expr) => { + self.expr_occurrences(&expr.left, depth); + self.expr_occurrences(&expr.right, depth); + } + syn::Expr::Lit(_) => {} + syn::Expr::Call(expr) => { + match &*expr.func { + syn::Expr::Path(expr_path) => { + let ident = expr_path + .path + .get_ident() + .expect("Analysis: only ident supported for function call"); + self.function_calls.insert(ident.into()); + } + _ => todo!("Analysis: unsupported func expr {:?}", expr.func), + } + for arg in expr.args.iter() { + self.expr_occurrences(arg, depth); + } + } + syn::Expr::MethodCall(expr) => { + self.expr_occurrences(&expr.receiver, depth); + for arg in expr.args.iter() { + self.expr_occurrences(arg, depth); + } } + _ => todo!("Analysis: unsupported expr {expr:?}"), } - _ => todo!("Analysis: unsupported expr {expr:?}"), } } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 
10ed4f0488..5f9f175092 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -30,14 +30,14 @@ impl From<&syn::Ident> for VariableKey { } /// Generate the expanded version of a function marked with the cube macro -fn codegen_cube(func: &syn::ItemFn, variable_analyses: &mut CodeAnalysis) -> TokenStream { - let prelude = get_prelude(); +fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenStream { + let prelude = get_prelude(&code_analysis.needed_functions); let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; // panic!("WG"); for statement in func.block.stmts.iter() { - let tokens = codegen_statement(statement, 0, variable_analyses); + let tokens = codegen_statement(statement, 0, code_analysis); body.extend(tokens); } diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index 1a07c42f2a..b359b226b3 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -1,9 +1,16 @@ // TODO: in analysis, identify needed functions from prelude and add only those -pub(crate) fn get_prelude(needed_functions: Vec<&str>) -> proc_macro2::TokenStream { +use std::collections::HashSet; + +use crate::VariableKey; + +pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro2::TokenStream { let mut prelude = proc_macro2::TokenStream::new(); - for func_name in needed_functions { + for func_name in needed_functions + .iter() + .map(|variable| variable.name.as_str()) + { let func_code = match func_name { "float_new" => Some(codegen_float_new()), "int_new" => Some(codegen_int_new()), From ad83d0290516194dcc6bad527b2b09210f112b2b Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 2 May 2024 10:35:57 -0400 Subject: [PATCH 17/54] fix for loop --- .../tests/{for_loop_static.rs => for_loop.rs} | 8 +-- crates/burn-cube/tests/for_loop_dynamic.rs | 58 ------------------- 2 files changed, 4 insertions(+), 62 deletions(-) rename 
crates/burn-cube/tests/{for_loop_static.rs => for_loop.rs} (93%) delete mode 100644 crates/burn-cube/tests/for_loop_dynamic.rs diff --git a/crates/burn-cube/tests/for_loop_static.rs b/crates/burn-cube/tests/for_loop.rs similarity index 93% rename from crates/burn-cube/tests/for_loop_static.rs rename to crates/burn-cube/tests/for_loop.rs index e859461d2a..13194d7ed0 100644 --- a/crates/burn-cube/tests/for_loop_static.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,10 +1,10 @@ -use burn_cube::{cube, range, range_expand, Array, CubeContext, Float}; +use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; #[cube] -pub fn kernel(mut lhs: Array, rhs: Float, end: u32, unroll: bool) { +pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { let tmp1 = rhs * rhs; let tmp2 = tmp1 + rhs; @@ -20,7 +20,7 @@ fn test_for_loop_with_unroll() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); - let end = 4u32; + let end = 4u32.into(); kernel_expand(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); @@ -35,7 +35,7 @@ fn test_for_loop_no_unroll() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); - let end = 4u32; + let end = 4u32.into(); kernel_expand(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); diff --git a/crates/burn-cube/tests/for_loop_dynamic.rs b/crates/burn-cube/tests/for_loop_dynamic.rs deleted file mode 100644 index 4dd8634ffb..0000000000 --- a/crates/burn-cube/tests/for_loop_dynamic.rs +++ /dev/null @@ -1,58 +0,0 @@ -use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; -use burn_jit::gpu; -use burn_jit::gpu::FloatKind::F32; -use burn_jit::gpu::{Elem, Item, Variable}; - -#[cube] -pub fn kernel(mut lhs: Array, rhs: 
Float, end: UInt, unroll: bool) { - let tmp1 = rhs * rhs; - let tmp2 = tmp1 + rhs; - - for i in range(0u32, end, unroll) { - lhs[i] = tmp2 + lhs[i]; - } -} - -#[test] -fn test_for_loop() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); - let end = context.create_local(Item::Scalar(Elem::UInt)); - // let end =Variable::ConstantScalar(5.0, Item::Scalar(Elem::UInt)); - - kernel_expand(&mut context, lhs, rhs, end, false); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -} - -fn gpu_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F32)); - - let lhs = context.create_local(item); - let rhs = context.create_local(item); - let lhs: Variable = lhs.into(); - let rhs: Variable = rhs.into(); - let end = context.create_local(Item::Scalar(Elem::UInt)); - let mut scope = context.into_scope(); - - // Kernel - let tmp1 = scope.create_local(item); - let tmp2 = scope.create_local(item); - gpu!(scope, tmp1 = rhs * rhs); - gpu!(scope, tmp2 = tmp1 + rhs); - - gpu!( - &mut scope, - range(0usize, end).for_each(|i, scope| { - gpu!(scope, rhs = lhs[i]); - gpu!(scope, tmp1 = tmp2 + rhs); - gpu!(scope, lhs[i] = tmp1); - }) - ); - - format!("{:?}", scope.operations) -} From 9547998a720ecc6b14b57de4f241269984b08283 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 2 May 2024 14:37:32 -0400 Subject: [PATCH 18/54] wip --- crates/burn-cube-macros/src/analysis.rs | 10 ++++ crates/burn-cube-macros/src/codegen.rs | 13 +++-- crates/burn-cube-macros/src/lib.rs | 1 - crates/burn-cube/src/branch.rs | 17 ++++++ crates/burn-cube/tests/for_loop.rs | 42 +++++++++++++++ crates/burn-cube/tests/gcd.rs | 71 +++++++++++++++++++++++-- crates/burn-cube/tests/if.rs | 40 ++++++++++++++ 7 files changed, 186 insertions(+), 8 deletions(-) create mode 100644 crates/burn-cube/tests/if.rs 
diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index a66cc507cf..63bdae290c 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -162,6 +162,16 @@ impl CodeAnalysisBuilder { self.expr_occurrences(&expr.cond, depth); self.stmts_occurrences(&expr.body.stmts, depth); } + syn::Expr::If(expr) => { + // Not sure if should update depth + + if expr.else_branch.is_some() { + todo!("Analysis: else branch not supported"); + } + + self.expr_occurrences(&expr.cond, depth); + self.stmts_occurrences(&expr.then_branch.stmts, depth); + } syn::Expr::Assign(expr) => { self.expr_occurrences(&expr.left, depth); self.expr_occurrences(&expr.right, depth); diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 95f7138f7a..e589d30217 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -1,4 +1,5 @@ use proc_macro2::TokenStream; +use syn::token::If; use crate::analysis::CodeAnalysis; @@ -9,7 +10,6 @@ pub fn codegen_statement( ) -> TokenStream { match statement { syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), - syn::Stmt::Item(_) => todo!(), syn::Stmt::Expr(expr, semi) => { let expr = codegen_expr(expr, loop_level, variable_analyses); match semi { @@ -19,7 +19,7 @@ pub fn codegen_statement( None => expr, } } - syn::Stmt::Macro(_) => todo!(), + _ => todo!("Codegen: statement {statement:?} not supported"), } } @@ -74,9 +74,10 @@ fn codegen_expr( syn::Expr::While(while_loop) => { codegen_while_loop(while_loop, loop_level, variable_analyses) } + syn::Expr::If(expr_if) => codegen_if(expr_if, variable_analyses), syn::Expr::MethodCall(call) => codegen_expr_method_call(call), syn::Expr::Index(index) => codegen_expr_index(index, loop_level, variable_analyses), - _ => panic!("Unsupported {:?}", expr), + _ => panic!("Codegen: Unsupported {:?}", expr), } } @@ -144,6 +145,10 @@ fn 
codegen_for_loop( } } +fn codegen_if(expr_if: &syn::ExprIf, variable_analyses: &mut CodeAnalysis) { + todo!("RENDU ICITTE") +} + fn codegen_while_loop( while_loop: &syn::ExprWhile, loop_level: usize, @@ -158,7 +163,7 @@ fn codegen_while_loop( }; quote::quote! { - loop_expand(#cond |context| #block); + loop_expand(context, |context| #cond, |context| #block); } } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 5f9f175092..0b51f26c60 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -35,7 +35,6 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; - // panic!("WG"); for statement in func.block.stmts.iter() { let tokens = codegen_statement(statement, 0, code_analysis); body.extend(tokens); diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index 1221a34517..55f92d6065 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -55,3 +55,20 @@ pub fn range_expand( })); } } + +// pub fn loop_expand(context: &mut CubeContext, mut cond_fn: F, mut block: F) +// where +// F: FnMut(&mut CubeContext), +// { +// let mut child = context.child(); +// // let cond: Variable = cond_fn... 
+ +// child.register(Branch::If(gpu::If { +// cond, scope: child.into_scope() +// })); + +// block(&mut child); +// context.register(Branch::Loop(gpu::Loop { +// scope: child.into_scope(), +// })) +// } diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 13194d7ed0..038cf377d9 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -13,6 +13,48 @@ pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { } } +// #[allow(unused_mut)] +// pub fn kernel_expand( +// context: &mut burn_cube::CubeContext, +// mut lhs: as burn_cube::RuntimeType>::ExpandType, +// rhs: ::ExpandType, +// end: ::ExpandType, +// unroll: ::ExpandType, +// ) -> () { +// let tmp1 = { +// let _lhs = rhs.clone().into(); +// let _rhs = rhs.clone().into(); +// burn_cube::mul::expand(context, _lhs, _rhs) +// }; +// let tmp2 = { +// let _lhs = tmp1.into(); +// let _rhs = rhs.into(); +// burn_cube::add::expand(context, _lhs, _rhs) +// }; +// range_expand( +// context, +// 0u32.into(), +// end.into(), +// unroll.into(), +// |context, i| { +// { +// let _array = lhs.clone().into(); +// let _index = i.clone().into(); +// let _value = { +// let _lhs = tmp2.clone().into(); +// let _rhs = { +// let _array = lhs.clone().into(); +// let _index = i.into(); +// burn_cube::index::expand(context, _array, _index) +// }; +// burn_cube::add::expand(context, _lhs, _rhs) +// }; +// burn_cube::index_assign::expand(context, _array, _index, _value) +// }; +// }, +// ); +// } + #[test] fn test_for_loop_with_unroll() { let mut context = CubeContext::root(); diff --git a/crates/burn-cube/tests/gcd.rs b/crates/burn-cube/tests/gcd.rs index 240f93b67e..71a0787233 100644 --- a/crates/burn-cube/tests/gcd.rs +++ b/crates/burn-cube/tests/gcd.rs @@ -1,4 +1,4 @@ -// use burn_cube::cube; +// use burn_cube::{cube, loop_expand, CubeContext, Int}; // use burn_jit::gpu; // use burn_jit::gpu::Branch; // use burn_jit::gpu::IntKind::I32; @@ -8,12 +8,77 @@ // 
pub fn gcd(lhs: Int, rhs: Int) { // while rhs != int_new(0) { // let tmp = rhs; -// rhs = lhs % rhs; -// lhs = tmp +// // rhs = lhs % rhs; +// // lhs = tmp; // } // // TODO: use lhs as output // } +// // pub fn int_new(val: i32) -> Int { +// // Int { +// // val, +// // vectorization: 1, +// // } +// // } +// // pub fn int_new_expand( +// // context: &mut CubeContext, +// // val: i32, +// // ) -> ::ExpandType { +// // val.into() +// // } +// // pub fn gcd(lhs: Int, rhs: Int) { +// // while rhs != int_new(0) { +// // let tmp = rhs; +// // } +// // } +// // #[allow(unused_mut)] +// // pub fn gcd_expand( +// // context: &mut burn_cube::CubeContext, +// // lhs: ::ExpandType, +// // rhs: ::ExpandType, +// // ) -> () { +// // loop_expand( +// // context, +// // |context| { +// // let _lhs = rhs.into(); +// // let _rhs = int_new_expand(context, 0.into()); +// // burn_cube::ne::expand(context, _lhs, _rhs) +// // }, +// // |context| { +// // let tmp = rhs.clone().into(); +// // }, +// // ); +// // } + +// // pub fn int_new(val: i32) -> Int { +// // Int { +// // val, +// // vectorization: 1, +// // } +// // } +// // pub fn int_new_expand( +// // context: &mut CubeContext, +// // val: i32, +// // ) -> ::ExpandType { +// // val.into() +// // } +// // pub fn gcd(lhs: Int, rhs: Int) { +// // while rhs != int_new(0) {} +// // } +// // #[allow(unused_mut)] +// // pub fn gcd_expand( +// // context: &mut burn_cube::CubeContext, +// // lhs: ::ExpandType, +// // rhs: ::ExpandType, +// // ) -> () { +// // let _cond = { +// // let _lhs = rhs.into(); +// // let _rhs = int_new_expand(context, 0.into()); +// // burn_cube::ne::expand(context, _lhs, _rhs) +// // }; +// // loop_expand(context, _cond, |context| {}); +// // } + // #[test] // fn cube_function_test() { // let mut context = CubeContext::root(); diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs new file mode 100644 index 0000000000..f412bf2e47 --- /dev/null +++ b/crates/burn-cube/tests/if.rs @@ -0,0 +1,40 @@ 
+use burn_cube::{cube, CubeContext, Float}; +use burn_jit::gpu; +use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::{Elem, Item}; + +#[cube] +pub fn if_greater(lhs: Float) { + if lhs > float_new(0.) { + let _ = lhs; + } +} + +#[test] +fn cube_if_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + + if_greater_expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F32)); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let out = scope.create_local(item); + gpu!(scope, cond = lhs > 0f32); + + gpu!(&mut scope, if(cond).then(|scope|{ + gpu!(scope, out = lhs); + })); + + format!("{:?}", scope.operations) +} From bcd3c866477f833162ff0c3f15e65e7506dab6c9 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 3 May 2024 10:51:56 -0400 Subject: [PATCH 19/54] if --- crates/burn-cube-macros/src/analysis.rs | 2 - crates/burn-cube-macros/src/codegen.rs | 50 ++++++++++++++++++++----- crates/burn-cube-macros/src/prelude.rs | 3 +- crates/burn-cube/src/branch.rs | 14 +++++++ crates/burn-cube/src/operation/base.rs | 34 ++++++++++++++++- crates/burn-cube/src/operation/cmp.rs | 47 ++++++++++++++++++++++- crates/burn-cube/tests/if.rs | 12 +++--- 7 files changed, 138 insertions(+), 24 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 63bdae290c..9c26137945 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -163,8 +163,6 @@ impl CodeAnalysisBuilder { self.stmts_occurrences(&expr.body.stmts, depth); } syn::Expr::If(expr) => { - // Not sure if should update depth - if expr.else_branch.is_some() { todo!("Analysis: else branch not 
supported"); } diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index e589d30217..5d2a9aca3e 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -74,7 +74,7 @@ fn codegen_expr( syn::Expr::While(while_loop) => { codegen_while_loop(while_loop, loop_level, variable_analyses) } - syn::Expr::If(expr_if) => codegen_if(expr_if, variable_analyses), + syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), syn::Expr::MethodCall(call) => codegen_expr_method_call(call), syn::Expr::Index(index) => codegen_expr_index(index, loop_level, variable_analyses), _ => panic!("Codegen: Unsupported {:?}", expr), @@ -145,8 +145,35 @@ fn codegen_for_loop( } } -fn codegen_if(expr_if: &syn::ExprIf, variable_analyses: &mut CodeAnalysis) { - todo!("RENDU ICITTE") +fn codegen_cond( + cond: &syn::Expr, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match cond { + syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses), + syn::Expr::Lit(expr) => codegen_lit(expr), + _ => todo!("{cond:?} cond not supported"), + } +} + +fn codegen_if( + expr_if: &syn::ExprIf, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + if expr_if.else_branch.is_some() { + todo!("Codegen: else branch not supported"); + } + + let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); + + let block = codegen_block(&expr_if.then_branch, loop_level, variable_analyses); + + quote::quote! 
{ + let _cond = #cond; + if_expand(context, _cond, |context| #block); + } } fn codegen_while_loop( @@ -156,11 +183,7 @@ fn codegen_while_loop( ) -> TokenStream { let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); - let cond = match while_loop.cond.as_ref() { - syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses), - syn::Expr::Lit(expr) => codegen_lit(expr), - _ => todo!("{while_loop:?} cond not supported"), - }; + let cond = codegen_cond(&while_loop.cond, loop_level, variable_analyses); quote::quote! { loop_expand(context, |context| #cond, |context| #block); @@ -292,11 +315,11 @@ fn codegen_path( if will_be_used_again { quote::quote! { - #ident.clone().into() + #ident.clone() } } else { quote::quote! { - #ident.into() + #ident } } } @@ -352,6 +375,13 @@ fn codegen_binary( burn_cube::ne::expand(context, _lhs, _rhs) } }, + syn::BinOp::Gt(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::gt::expand(context, _lhs, _rhs) + } + }, _ => todo!("Codegen: unsupported op {:?}", binary.op), } } diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index b359b226b3..d979385494 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -1,5 +1,3 @@ -// TODO: in analysis, identify needed functions from prelude and add only those - use std::collections::HashSet; use crate::VariableKey; @@ -37,6 +35,7 @@ fn codegen_float_new() -> proc_macro2::TokenStream { context: &mut CubeContext, val: f32, ) -> ::ExpandType { + // TODO: 0. 
becomes 0..into() val.into() } } diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index 55f92d6065..d6a8a735ee 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -56,6 +56,20 @@ pub fn range_expand( } } +pub fn if_expand(context: &mut CubeContext, cond: ExpandElement, mut block: IF) +where + IF: FnMut(&mut CubeContext), +{ + let mut child = context.child(); + + block(&mut child); + + context.register(Branch::If(gpu::If { + cond: *cond, + scope: child.into_scope(), + })); +} + // pub fn loop_expand(context: &mut CubeContext, mut cond_fn: F, mut block: F) // where // F: FnMut(&mut CubeContext), diff --git a/crates/burn-cube/src/operation/base.rs b/crates/burn-cube/src/operation/base.rs index 59bd48e0d0..b71050fc0b 100644 --- a/crates/burn-cube/src/operation/base.rs +++ b/crates/burn-cube/src/operation/base.rs @@ -1,5 +1,5 @@ use crate::{CubeContext, ExpandElement}; -use burn_jit::gpu::{self, Variable}; +use burn_jit::gpu::{self, Elem, Variable}; pub(crate) fn binary_expand( context: &mut CubeContext, @@ -27,3 +27,35 @@ where out } + +pub(crate) fn cmp_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs: Variable = *lhs; + let rhs: Variable = *rhs; + + let out_item = match lhs.item() { + gpu::Item::Vec4(_) => gpu::Item::Vec4(Elem::Bool), + gpu::Item::Vec3(_) => gpu::Item::Vec3(Elem::Bool), + gpu::Item::Vec2(_) => gpu::Item::Vec2(Elem::Bool), + gpu::Item::Scalar(_) => gpu::Item::Scalar(Elem::Bool), + }; + let out = context.create_local(out_item); + let out_var = *out; + + let op = func(gpu::BinaryOperator { + lhs, + rhs, + out: out_var, + }); + + context.register(op); + + out +} diff --git a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs index ad966c61ec..e2afdc7d41 100644 --- a/crates/burn-cube/src/operation/cmp.rs +++ 
b/crates/burn-cube/src/operation/cmp.rs @@ -1,4 +1,4 @@ -use crate::operation::base::binary_expand; +use crate::operation::base::cmp_expand; use crate::{CubeContext, ExpandElement, Float, Int, UInt}; use burn_jit::gpu::{self}; @@ -26,7 +26,50 @@ impl core::cmp::Eq for Int {} impl core::cmp::Eq for UInt {} +impl core::cmp::PartialOrd for Float { + fn partial_cmp(&self, other: &Self) -> Option { + match self.val.partial_cmp(&other.val) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.vectorization.partial_cmp(&other.vectorization) + } +} + +impl core::cmp::PartialOrd for Int { + fn partial_cmp(&self, other: &Self) -> Option { + match self.val.partial_cmp(&other.val) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.vectorization.partial_cmp(&other.vectorization) + } +} + +impl core::cmp::PartialOrd for UInt { + fn partial_cmp(&self, other: &Self) -> Option { + match self.val.partial_cmp(&other.val) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.vectorization.partial_cmp(&other.vectorization) + } +} + pub mod ne { + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::NotEqual) + } +} + +pub mod gt { use super::*; pub fn expand( @@ -34,6 +77,6 @@ pub mod ne { lhs: ExpandElement, rhs: ExpandElement, ) -> ExpandElement { - binary_expand(context, lhs, rhs, gpu::Operator::NotEqual) + cmp_expand(context, lhs, rhs, gpu::Operator::Greater) } } diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index f412bf2e47..3d9681224b 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,11 +1,11 @@ -use burn_cube::{cube, CubeContext, Float}; +use burn_cube::{cube, if_expand, CubeContext, Float}; use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; -use burn_jit::gpu::{Elem, Item}; +use burn_jit::gpu::{Elem, Item, Variable}; #[cube] pub 
fn if_greater(lhs: Float) { - if lhs > float_new(0.) { + if lhs > float_new(0.0) { let _ = lhs; } } @@ -29,12 +29,10 @@ fn gpu_macro_ref() -> String { let mut scope = context.into_scope(); let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let out = scope.create_local(item); + let lhs: Variable = lhs.into(); gpu!(scope, cond = lhs > 0f32); - gpu!(&mut scope, if(cond).then(|scope|{ - gpu!(scope, out = lhs); - })); + gpu!(&mut scope, if(cond).then(|_scope| {})); format!("{:?}", scope.operations) } From 9cfafd9330e1ae5d9d8d5fb8e1ab1021096f3920 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 3 May 2024 13:08:21 -0400 Subject: [PATCH 20/54] if using variable inside --- crates/burn-cube-macros/src/analysis.rs | 1 + crates/burn-cube-macros/src/codegen.rs | 2 +- crates/burn-cube/tests/if.rs | 9 ++++++--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 9c26137945..a1541ab7ae 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -166,6 +166,7 @@ impl CodeAnalysisBuilder { if expr.else_branch.is_some() { todo!("Analysis: else branch not supported"); } + let depth = depth + 1; self.expr_occurrences(&expr.cond, depth); self.stmts_occurrences(&expr.then_branch.stmts, depth); diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 5d2a9aca3e..e2a0cd155b 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -168,7 +168,7 @@ fn codegen_if( let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); - let block = codegen_block(&expr_if.then_branch, loop_level, variable_analyses); + let block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); quote::quote! 
{ let _cond = #cond; diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 3d9681224b..48050beaec 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -6,7 +6,7 @@ use burn_jit::gpu::{Elem, Item, Variable}; #[cube] pub fn if_greater(lhs: Float) { if lhs > float_new(0.0) { - let _ = lhs; + let _ = lhs + float_new(4.0); } } @@ -30,9 +30,12 @@ fn gpu_macro_ref() -> String { let mut scope = context.into_scope(); let cond = scope.create_local(Item::Scalar(Elem::Bool)); let lhs: Variable = lhs.into(); - gpu!(scope, cond = lhs > 0f32); + let y = scope.create_local(item); - gpu!(&mut scope, if(cond).then(|_scope| {})); + gpu!(scope, cond = lhs > 0f32); + gpu!(&mut scope, if(cond).then(|scope| { + gpu!(scope, y = lhs + 4.0); + })); format!("{:?}", scope.operations) } From 931045128fbf1587b9c89dd41d3414cc404849f9 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 3 May 2024 13:43:53 -0400 Subject: [PATCH 21/54] while loop --- crates/burn-cube-macros/src/codegen.rs | 4 +-- crates/burn-cube/src/branch.rs | 29 +++++++-------- crates/burn-cube/tests/while.rs | 49 ++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 crates/burn-cube/tests/while.rs diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index e2a0cd155b..460e678e48 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -167,7 +167,6 @@ fn codegen_if( } let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); - let block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); quote::quote! 
{ @@ -181,10 +180,9 @@ fn codegen_while_loop( loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { + let cond = codegen_cond(&while_loop.cond, loop_level + 1, variable_analyses); let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); - let cond = codegen_cond(&while_loop.cond, loop_level, variable_analyses); - quote::quote! { loop_expand(context, |context| #cond, |context| #block); } diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index d6a8a735ee..93c7785c3c 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -70,19 +70,20 @@ where })); } -// pub fn loop_expand(context: &mut CubeContext, mut cond_fn: F, mut block: F) -// where -// F: FnMut(&mut CubeContext), -// { -// let mut child = context.child(); -// // let cond: Variable = cond_fn... +pub fn loop_expand(context: &mut CubeContext, mut cond_fn: FC, mut block: FB) +where + FC: FnMut(&mut CubeContext) -> ExpandElement, + FB: FnMut(&mut CubeContext), +{ + let mut inside_loop = context.child(); + let cond: ExpandElement = cond_fn(&mut inside_loop); -// child.register(Branch::If(gpu::If { -// cond, scope: child.into_scope() -// })); + if_expand(&mut inside_loop, cond, |context| { + context.register(Branch::Break) + }); -// block(&mut child); -// context.register(Branch::Loop(gpu::Loop { -// scope: child.into_scope(), -// })) -// } + block(&mut inside_loop); + context.register(Branch::Loop(gpu::Loop { + scope: inside_loop.into_scope(), + })); +} diff --git a/crates/burn-cube/tests/while.rs b/crates/burn-cube/tests/while.rs new file mode 100644 index 0000000000..adfaf4a7e6 --- /dev/null +++ b/crates/burn-cube/tests/while.rs @@ -0,0 +1,49 @@ +use burn_cube::{cube, loop_expand, CubeContext, Int}; +use burn_jit::gpu; +use burn_jit::gpu::Branch; +use burn_jit::gpu::IntKind::I32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn while_not(lhs: Int) { + while lhs != int_new(0) { + let _ = lhs - 
int_new(1); + } +} + +#[test] +fn cube_while_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + + while_not_expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32)); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let rhs = scope.create_local(item); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = lhs != 0); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, rhs = lhs - 1i32); + }) + ); + + format!("{:?}", scope.operations) +} From 08f0f6e32fbc4a87afbc508e8382266d741c8822 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 3 May 2024 14:46:53 -0400 Subject: [PATCH 22/54] loop and break --- crates/burn-cube-macros/src/analysis.rs | 6 ++ crates/burn-cube-macros/src/codegen.rs | 23 ++++- crates/burn-cube-macros/src/lib.rs | 25 +++-- crates/burn-cube-macros/src/prelude.rs | 4 +- crates/burn-cube/src/branch.rs | 22 ++++- crates/burn-cube/tests/for_loop.rs | 48 +-------- crates/burn-cube/tests/gcd.rs | 123 ------------------------ crates/burn-cube/tests/if.rs | 2 +- crates/burn-cube/tests/literal.rs | 2 +- crates/burn-cube/tests/loop.rs | 71 ++++++++++++++ crates/burn-cube/tests/while.rs | 98 +++++++++---------- 11 files changed, 192 insertions(+), 232 deletions(-) delete mode 100644 crates/burn-cube/tests/gcd.rs create mode 100644 crates/burn-cube/tests/loop.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index a1541ab7ae..98132de77d 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -162,6 +162,11 @@ impl CodeAnalysisBuilder { 
self.expr_occurrences(&expr.cond, depth); self.stmts_occurrences(&expr.body.stmts, depth); } + syn::Expr::Loop(expr) => { + let depth = depth + 1; + + self.stmts_occurrences(&expr.body.stmts, depth); + } syn::Expr::If(expr) => { if expr.else_branch.is_some() { todo!("Analysis: else branch not supported"); @@ -214,6 +219,7 @@ impl CodeAnalysisBuilder { self.expr_occurrences(arg, depth); } } + syn::Expr::Break(_) => {} _ => todo!("Analysis: unsupported expr {expr:?}"), } } diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 460e678e48..ed9f6ae8c6 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -1,5 +1,4 @@ use proc_macro2::TokenStream; -use syn::token::If; use crate::analysis::CodeAnalysis; @@ -74,6 +73,8 @@ fn codegen_expr( syn::Expr::While(while_loop) => { codegen_while_loop(while_loop, loop_level, variable_analyses) } + syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_analyses), + syn::Expr::Break(_) => codegen_break(), syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), syn::Expr::MethodCall(call) => codegen_expr_method_call(call), syn::Expr::Index(index) => codegen_expr_index(index, loop_level, variable_analyses), @@ -157,6 +158,12 @@ fn codegen_cond( } } +fn codegen_break() -> TokenStream { + quote::quote! { + break_expand(context); + } +} + fn codegen_if( expr_if: &syn::ExprIf, loop_level: usize, @@ -175,6 +182,18 @@ fn codegen_if( } } +fn codegen_loop( + loop_expr: &syn::ExprLoop, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let block = codegen_block(&loop_expr.body, loop_level + 1, variable_analyses); + + quote::quote! 
{ + loop_expand(context, |context| #block); + } +} + fn codegen_while_loop( while_loop: &syn::ExprWhile, loop_level: usize, @@ -184,7 +203,7 @@ fn codegen_while_loop( let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); quote::quote! { - loop_expand(context, |context| #cond, |context| #block); + while_loop_expand(context, |context| #cond, |context| #block); } } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 0b51f26c60..5aface3824 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -32,6 +32,7 @@ impl From<&syn::Ident> for VariableKey { /// Generate the expanded version of a function marked with the cube macro fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenStream { let prelude = get_prelude(&code_analysis.needed_functions); + let mod_name = get_name(&func.sig); let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; @@ -41,13 +42,16 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt } let code = quote::quote! { - #prelude + mod #mod_name { + #prelude - #func + #[allow(dead_code)] + #func - #[allow(unused_mut)] - #signature { - #body + #[allow(unused_mut)] + #signature { + #body + } } } .into(); @@ -56,6 +60,15 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt code } +fn get_name(sig: &syn::Signature) -> proc_macro2::TokenStream { + let ident = &sig.ident; + + quote::quote! { + #ident + } + .into() +} + fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let mut inputs = quote::quote!(); @@ -85,7 +98,7 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { } let ident = &sig.ident; - let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); + let ident = syn::Ident::new("expand", ident.span()); quote::quote! 
{ pub fn #ident(context: &mut burn_cube::CubeContext, #inputs) -> #output diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index d979385494..bfda43bd5e 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -3,7 +3,9 @@ use std::collections::HashSet; use crate::VariableKey; pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro2::TokenStream { - let mut prelude = proc_macro2::TokenStream::new(); + let mut prelude = quote::quote! { + use super::*; + }; for func_name in needed_functions .iter() diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index 93c7785c3c..db14e1381f 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -70,7 +70,23 @@ where })); } -pub fn loop_expand(context: &mut CubeContext, mut cond_fn: FC, mut block: FB) +pub fn break_expand(context: &mut CubeContext) { + context.register(Branch::Break); +} + +pub fn loop_expand(context: &mut CubeContext, mut block: FB) +where + FB: FnMut(&mut CubeContext), +{ + let mut inside_loop = context.child(); + + block(&mut inside_loop); + context.register(Branch::Loop(gpu::Loop { + scope: inside_loop.into_scope(), + })); +} + +pub fn while_loop_expand(context: &mut CubeContext, mut cond_fn: FC, mut block: FB) where FC: FnMut(&mut CubeContext) -> ExpandElement, FB: FnMut(&mut CubeContext), @@ -78,9 +94,7 @@ where let mut inside_loop = context.child(); let cond: ExpandElement = cond_fn(&mut inside_loop); - if_expand(&mut inside_loop, cond, |context| { - context.register(Branch::Break) - }); + if_expand(&mut inside_loop, cond, |context| break_expand(context)); block(&mut inside_loop); context.register(Branch::Loop(gpu::Loop { diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 038cf377d9..743b9ebb00 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -4,7 +4,7 @@ use 
burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; #[cube] -pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { +pub fn for_loop(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { let tmp1 = rhs * rhs; let tmp2 = tmp1 + rhs; @@ -13,48 +13,6 @@ pub fn kernel(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { } } -// #[allow(unused_mut)] -// pub fn kernel_expand( -// context: &mut burn_cube::CubeContext, -// mut lhs: as burn_cube::RuntimeType>::ExpandType, -// rhs: ::ExpandType, -// end: ::ExpandType, -// unroll: ::ExpandType, -// ) -> () { -// let tmp1 = { -// let _lhs = rhs.clone().into(); -// let _rhs = rhs.clone().into(); -// burn_cube::mul::expand(context, _lhs, _rhs) -// }; -// let tmp2 = { -// let _lhs = tmp1.into(); -// let _rhs = rhs.into(); -// burn_cube::add::expand(context, _lhs, _rhs) -// }; -// range_expand( -// context, -// 0u32.into(), -// end.into(), -// unroll.into(), -// |context, i| { -// { -// let _array = lhs.clone().into(); -// let _index = i.clone().into(); -// let _value = { -// let _lhs = tmp2.clone().into(); -// let _rhs = { -// let _array = lhs.clone().into(); -// let _index = i.into(); -// burn_cube::index::expand(context, _array, _index) -// }; -// burn_cube::add::expand(context, _lhs, _rhs) -// }; -// burn_cube::index_assign::expand(context, _array, _index, _value) -// }; -// }, -// ); -// } - #[test] fn test_for_loop_with_unroll() { let mut context = CubeContext::root(); @@ -64,7 +22,7 @@ fn test_for_loop_with_unroll() { let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); let end = 4u32.into(); - kernel_expand(&mut context, lhs, rhs, end, unroll); + for_loop::expand(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); @@ -79,7 +37,7 @@ fn test_for_loop_no_unroll() { let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); let end = 4u32.into(); - kernel_expand(&mut context, lhs, 
rhs, end, unroll); + for_loop::expand(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); diff --git a/crates/burn-cube/tests/gcd.rs b/crates/burn-cube/tests/gcd.rs deleted file mode 100644 index 71a0787233..0000000000 --- a/crates/burn-cube/tests/gcd.rs +++ /dev/null @@ -1,123 +0,0 @@ -// use burn_cube::{cube, loop_expand, CubeContext, Int}; -// use burn_jit::gpu; -// use burn_jit::gpu::Branch; -// use burn_jit::gpu::IntKind::I32; -// use burn_jit::gpu::{Elem, Item, Variable}; - -// #[cube] -// pub fn gcd(lhs: Int, rhs: Int) { -// while rhs != int_new(0) { -// let tmp = rhs; -// // rhs = lhs % rhs; -// // lhs = tmp; -// } -// // TODO: use lhs as output -// } - -// // pub fn int_new(val: i32) -> Int { -// // Int { -// // val, -// // vectorization: 1, -// // } -// // } -// // pub fn int_new_expand( -// // context: &mut CubeContext, -// // val: i32, -// // ) -> ::ExpandType { -// // val.into() -// // } -// // pub fn gcd(lhs: Int, rhs: Int) { -// // while rhs != int_new(0) { -// // let tmp = rhs; -// // } -// // } -// // #[allow(unused_mut)] -// // pub fn gcd_expand( -// // context: &mut burn_cube::CubeContext, -// // lhs: ::ExpandType, -// // rhs: ::ExpandType, -// // ) -> () { -// // loop_expand( -// // context, -// // |context| { -// // let _lhs = rhs.into(); -// // let _rhs = int_new_expand(context, 0.into()); -// // burn_cube::ne::expand(context, _lhs, _rhs) -// // }, -// // |context| { -// // let tmp = rhs.clone().into(); -// // }, -// // ); -// // } - -// // pub fn int_new(val: i32) -> Int { -// // Int { -// // val, -// // vectorization: 1, -// // } -// // } -// // pub fn int_new_expand( -// // context: &mut CubeContext, -// // val: i32, -// // ) -> ::ExpandType { -// // val.into() -// // } -// // pub fn gcd(lhs: Int, rhs: Int) { -// // while rhs != int_new(0) {} -// // } -// // #[allow(unused_mut)] -// // pub fn gcd_expand( -// // context: &mut burn_cube::CubeContext, -// 
// lhs: ::ExpandType, -// // rhs: ::ExpandType, -// // ) -> () { -// // let _cond = { -// // let _lhs = rhs.into(); -// // let _rhs = int_new_expand(context, 0.into()); -// // burn_cube::ne::expand(context, _lhs, _rhs) -// // }; -// // loop_expand(context, _cond, |context| {}); -// // } - -// #[test] -// fn cube_function_test() { -// let mut context = CubeContext::root(); - -// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); -// let rhs = context.create_local(Item::Scalar(Elem::Int(I32))); - -// gcd_expand(&mut context, lhs, rhs); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -// } - -// fn gpu_macro_ref() -> String { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Int(I32)); - -// let lhs = context.create_local(item); -// let rhs = context.create_local(item); -// let lhs: Variable = lhs.into(); -// let rhs: Variable = rhs.into(); -// let mut scope = context.into_scope(); - -// // Kernel -// let cond = scope.create_local(Item::Scalar(Elem::Bool)); -// let tmp = scope.create_local(Item::Scalar(Elem::Int(I32))); -// gpu!( -// &mut scope, -// loop(|scope| { -// gpu!(scope, cond = rhs != 0); -// gpu!(scope, if(cond).then(|scope|{ -// scope.register(Branch::Break); -// })); - -// gpu!(scope, tmp = rhs); -// gpu!(scope, rhs = lhs % rhs); -// gpu!(scope, lhs = tmp); -// }) -// ); - -// format!("{:?}", scope.operations) -// } diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 48050beaec..5f8321da45 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -16,7 +16,7 @@ fn cube_if_test() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - if_greater_expand(&mut context, lhs); + if_greater::expand(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 
1293d0c943..642eb2c4b7 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -14,7 +14,7 @@ fn cube_literal_test() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - literal_expand(&mut context, lhs); + literal::expand(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs new file mode 100644 index 0000000000..fb7830cf81 --- /dev/null +++ b/crates/burn-cube/tests/loop.rs @@ -0,0 +1,71 @@ +use burn_cube::{break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int}; +use burn_jit::gpu; +use burn_jit::gpu::Branch; +use burn_jit::gpu::IntKind::I32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn while_not(lhs: Int) { + while lhs != int_new(0) { + let _ = lhs - int_new(1); + } +} + +#[cube] +pub fn manual_loop_break(lhs: Int) { + loop { + if lhs != int_new(0) { + break; + } + let _ = lhs - int_new(1); + } +} + +#[test] +fn cube_while_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + + while_not::expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +#[test] +fn cube_loop_break_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + + manual_loop_break::expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32)); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let rhs = scope.create_local(item); + + gpu!( + &mut scope, + 
loop(|scope| { + gpu!(scope, cond = lhs != 0); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, rhs = lhs - 1i32); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/while.rs b/crates/burn-cube/tests/while.rs index adfaf4a7e6..2a9ae9d92a 100644 --- a/crates/burn-cube/tests/while.rs +++ b/crates/burn-cube/tests/while.rs @@ -1,49 +1,49 @@ -use burn_cube::{cube, loop_expand, CubeContext, Int}; -use burn_jit::gpu; -use burn_jit::gpu::Branch; -use burn_jit::gpu::IntKind::I32; -use burn_jit::gpu::{Elem, Item, Variable}; - -#[cube] -pub fn while_not(lhs: Int) { - while lhs != int_new(0) { - let _ = lhs - int_new(1); - } -} - -#[test] -fn cube_while_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - - while_not_expand(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -} - -fn gpu_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32)); - let lhs = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let lhs: Variable = lhs.into(); - let rhs = scope.create_local(item); - - gpu!( - &mut scope, - loop(|scope| { - gpu!(scope, cond = lhs != 0); - gpu!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - gpu!(scope, rhs = lhs - 1i32); - }) - ); - - format!("{:?}", scope.operations) -} +// use burn_cube::{cube, while_loop_expand, CubeContext, Int}; +// use burn_jit::gpu; +// use burn_jit::gpu::Branch; +// use burn_jit::gpu::IntKind::I32; +// use burn_jit::gpu::{Elem, Item, Variable}; + +// #[cube] +// pub fn while_not(lhs: Int) { +// while lhs != int_new(0) { +// let _ = lhs - int_new(1); +// } +// } + +// #[test] +// fn cube_while_test() { +// let mut context = CubeContext::root(); + +// let lhs = 
context.create_local(Item::Scalar(Elem::Int(I32))); + +// while_not_expand(&mut context, lhs); +// let scope = context.into_scope(); + +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +// } + +// fn gpu_macro_ref() -> String { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Int(I32)); +// let lhs = context.create_local(item); + +// let mut scope = context.into_scope(); +// let cond = scope.create_local(Item::Scalar(Elem::Bool)); +// let lhs: Variable = lhs.into(); +// let rhs = scope.create_local(item); + +// gpu!( +// &mut scope, +// loop(|scope| { +// gpu!(scope, cond = lhs != 0); +// gpu!(scope, if(cond).then(|scope|{ +// scope.register(Branch::Break); +// })); + +// gpu!(scope, rhs = lhs - 1i32); +// }) +// ); + +// format!("{:?}", scope.operations) +// } From fb65985f4fbf6a19cd57824aecd6deeecb6781dc Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 3 May 2024 15:55:03 -0400 Subject: [PATCH 23/54] assign add --- crates/burn-cube-macros/src/analysis.rs | 10 ++- crates/burn-cube-macros/src/codegen.rs | 40 ++++++++--- crates/burn-cube/src/branch.rs | 24 ++++++- crates/burn-cube/src/operation/assignation.rs | 32 +++++++++ crates/burn-cube/src/operation/base.rs | 23 +++++++ crates/burn-cube/src/operation/cmp.rs | 24 +++++++ crates/burn-cube/src/operation/mod.rs | 2 +- crates/burn-cube/tests/if_else.rs | 45 +++++++++++++ crates/burn-cube/tests/reuse.rs | 66 +++++++++++++++++++ crates/burn-cube/tests/while.rs | 49 -------------- 10 files changed, 253 insertions(+), 62 deletions(-) create mode 100644 crates/burn-cube/tests/if_else.rs create mode 100644 crates/burn-cube/tests/reuse.rs delete mode 100644 crates/burn-cube/tests/while.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 98132de77d..581b216e6e 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -168,13 +168,17 @@ impl CodeAnalysisBuilder { 
self.stmts_occurrences(&expr.body.stmts, depth); } syn::Expr::If(expr) => { - if expr.else_branch.is_some() { - todo!("Analysis: else branch not supported"); - } let depth = depth + 1; self.expr_occurrences(&expr.cond, depth); self.stmts_occurrences(&expr.then_branch.stmts, depth); + if let Some((_, expr)) = &expr.else_branch { + if let syn::Expr::Block(expr_block) = &**expr { + self.stmts_occurrences(&expr_block.block.stmts, depth); + } else { + todo!("Analysis: Only block else expr is supported") + } + } } syn::Expr::Assign(expr) => { self.expr_occurrences(&expr.left, depth); diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index ed9f6ae8c6..5145a278a6 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -169,16 +169,26 @@ fn codegen_if( loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { - if expr_if.else_branch.is_some() { - todo!("Codegen: else branch not supported"); - } - let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); - let block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); - quote::quote! { - let _cond = #cond; - if_expand(context, _cond, |context| #block); + let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); + + if let Some((_, expr)) = &expr_if.else_branch { + if let syn::Expr::Block(expr_block) = &**expr { + let else_block = codegen_block(&expr_block.block, loop_level + 1, variable_analyses); + + quote::quote! { + let _cond = #cond; + if_else_expand(context, _cond, |context| #then_block, |context| #else_block); + } + } else { + todo!("Analysis: Only block else expr is supported") + } + } else { + quote::quote! { + let _cond = #cond; + if_expand(context, _cond, |context| #then_block); + } } } @@ -399,6 +409,20 @@ fn codegen_binary( burn_cube::gt::expand(context, _lhs, _rhs) } }, + syn::BinOp::Lt(_) => quote::quote! 
{ + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::lt::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::AddAssign(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::add_assign_op::expand(context, _lhs, _rhs) + } + }, _ => todo!("Codegen: unsupported op {:?}", binary.op), } } diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index db14e1381f..e4a5ea97d0 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -70,6 +70,28 @@ where })); } +pub fn if_else_expand( + context: &mut CubeContext, + cond: ExpandElement, + mut then_block: IF, + mut else_block: EL, +) where + IF: FnMut(&mut CubeContext), + EL: FnMut(&mut CubeContext), +{ + let mut then_child = context.child(); + then_block(&mut then_child); + + let mut else_child = context.child(); + else_block(&mut else_child); + + context.register(Branch::IfElse(gpu::IfElse { + cond: *cond, + scope_if: then_child.into_scope(), + scope_else: else_child.into_scope(), + })); +} + pub fn break_expand(context: &mut CubeContext) { context.register(Branch::Break); } @@ -92,8 +114,8 @@ where FB: FnMut(&mut CubeContext), { let mut inside_loop = context.child(); - let cond: ExpandElement = cond_fn(&mut inside_loop); + let cond: ExpandElement = cond_fn(&mut inside_loop); if_expand(&mut inside_loop, cond, |context| break_expand(context)); block(&mut inside_loop); diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs index 3ac02e7e79..f317185333 100644 --- a/crates/burn-cube/src/operation/assignation.rs +++ b/crates/burn-cube/src/operation/assignation.rs @@ -66,3 +66,35 @@ pub mod index { } } } + +pub mod add_assign_op { + use crate::{operation::base::assign_op_expand, Float, Int}; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + assign_op_expand(context, lhs, rhs, gpu::Operator::Add) + } + + impl 
core::ops::AddAssign for Float { + fn add_assign(&mut self, rhs: Self) { + self.val += rhs.val + } + } + + impl core::ops::AddAssign for Int { + fn add_assign(&mut self, rhs: Self) { + self.val += rhs.val + } + } + + impl core::ops::AddAssign for UInt { + fn add_assign(&mut self, rhs: Self) { + self.val += rhs.val + } + } +} diff --git a/crates/burn-cube/src/operation/base.rs b/crates/burn-cube/src/operation/base.rs index b71050fc0b..96a921d73f 100644 --- a/crates/burn-cube/src/operation/base.rs +++ b/crates/burn-cube/src/operation/base.rs @@ -59,3 +59,26 @@ where out } + +pub(crate) fn assign_op_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs_var: Variable = *lhs; + let rhs: Variable = *rhs; + + let op = func(gpu::BinaryOperator { + lhs: lhs_var, + rhs, + out: lhs_var, + }); + + context.register(op); + + lhs +} diff --git a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs index e2afdc7d41..a422c93b27 100644 --- a/crates/burn-cube/src/operation/cmp.rs +++ b/crates/burn-cube/src/operation/cmp.rs @@ -80,3 +80,27 @@ pub mod gt { cmp_expand(context, lhs, rhs, gpu::Operator::Greater) } } + +pub mod lt { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::Lower) + } +} + +pub mod add_assign { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::Add) + } +} diff --git a/crates/burn-cube/src/operation/mod.rs b/crates/burn-cube/src/operation/mod.rs index fc4d8c0abe..db2470087f 100644 --- a/crates/burn-cube/src/operation/mod.rs +++ b/crates/burn-cube/src/operation/mod.rs @@ -1,5 +1,5 @@ -mod base; mod assignation; +mod base; mod binary; mod cmp; diff --git 
a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs new file mode 100644 index 0000000000..b378c3cddd --- /dev/null +++ b/crates/burn-cube/tests/if_else.rs @@ -0,0 +1,45 @@ +use burn_cube::{cube, if_else_expand, CubeContext, Float}; +use burn_jit::gpu; +use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::{Elem, Item, Variable}; + +#[cube] +pub fn if_else(lhs: Float) { + if lhs < float_new(0.0) { + let _ = lhs + float_new(4.0); + } else { + let _ = lhs - float_new(5.0); + } +} + +#[test] +fn cube_if_else_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + + if_else::expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F32)); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let y = scope.create_local(item); + + gpu!(scope, cond = lhs < 0f32); + gpu!(&mut scope, if(cond).then(|scope| { + gpu!(scope, y = lhs + 4.0); + }).else(|scope|{ + gpu!(scope, y = lhs - 5.0); + })); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs new file mode 100644 index 0000000000..c0c7f4b37e --- /dev/null +++ b/crates/burn-cube/tests/reuse.rs @@ -0,0 +1,66 @@ +use burn_cube::{cube, while_loop_expand, CubeContext, Int}; +use burn_jit::gpu; +use burn_jit::gpu::IntKind::I32; +use burn_jit::gpu::{Branch, Elem, Item, Variable}; + +// #[cube] +// pub fn reuse(lhs: Int) { +// while lhs < int_new(10) { +// lhs = lhs + int_new(1); +// } +// } + +#[cube] +pub fn reuse_incr(mut lhs: Int) { + while lhs < int_new(10) { + lhs += int_new(1); + } +} + +// #[test] +// fn cube_reuse_test() { +// let mut context = CubeContext::root(); + 
+// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + +// reuse::expand(&mut context, lhs); +// let scope = context.into_scope(); + +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +// } + +#[test] +fn cube_reuse_incr_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + + reuse_incr::expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32)); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = lhs < 10); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, lhs = lhs + 1); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/while.rs b/crates/burn-cube/tests/while.rs deleted file mode 100644 index 2a9ae9d92a..0000000000 --- a/crates/burn-cube/tests/while.rs +++ /dev/null @@ -1,49 +0,0 @@ -// use burn_cube::{cube, while_loop_expand, CubeContext, Int}; -// use burn_jit::gpu; -// use burn_jit::gpu::Branch; -// use burn_jit::gpu::IntKind::I32; -// use burn_jit::gpu::{Elem, Item, Variable}; - -// #[cube] -// pub fn while_not(lhs: Int) { -// while lhs != int_new(0) { -// let _ = lhs - int_new(1); -// } -// } - -// #[test] -// fn cube_while_test() { -// let mut context = CubeContext::root(); - -// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - -// while_not_expand(&mut context, lhs); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -// } - -// fn gpu_macro_ref() -> String { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Int(I32)); 
-// let lhs = context.create_local(item); - -// let mut scope = context.into_scope(); -// let cond = scope.create_local(Item::Scalar(Elem::Bool)); -// let lhs: Variable = lhs.into(); -// let rhs = scope.create_local(item); - -// gpu!( -// &mut scope, -// loop(|scope| { -// gpu!(scope, cond = lhs != 0); -// gpu!(scope, if(cond).then(|scope|{ -// scope.register(Branch::Break); -// })); - -// gpu!(scope, rhs = lhs - 1i32); -// }) -// ); - -// format!("{:?}", scope.operations) -// } From 8fd3541fbb9c6cdc12e0a1910eefa452f29b2bc5 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 6 May 2024 08:55:16 -0400 Subject: [PATCH 24/54] wip --- crates/burn-cube/tests/reuse.rs | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index c0c7f4b37e..e1bea2e415 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -3,12 +3,21 @@ use burn_jit::gpu; use burn_jit::gpu::IntKind::I32; use burn_jit::gpu::{Branch, Elem, Item, Variable}; -// #[cube] -// pub fn reuse(lhs: Int) { -// while lhs < int_new(10) { -// lhs = lhs + int_new(1); -// } -// } +#[cube] +pub fn reuse_not_in_rhs(mut lhs: Int) { + while lhs < int_new(10) { + lhs = int_new(1); + } +} +// TODO: remove clone in lhs + +#[cube] +pub fn reuse_in_rhs(mut lhs: Int) { + while lhs < int_new(10) { + lhs = lhs + int_new(1); + } +} +// TODO: allow to be borrowed immutably in rhs while being mutable in lhs :o #[cube] pub fn reuse_incr(mut lhs: Int) { @@ -17,18 +26,6 @@ pub fn reuse_incr(mut lhs: Int) { } } -// #[test] -// fn cube_reuse_test() { -// let mut context = CubeContext::root(); - -// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - -// reuse::expand(&mut context, lhs); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -// } - #[test] fn cube_reuse_incr_test() { let mut context = CubeContext::root(); From 
a14b6532fb4873f4da2d8521b411a72995cf63a2 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 6 May 2024 13:17:39 -0400 Subject: [PATCH 25/54] variable reuse --- crates/burn-cube-macros/src/analysis.rs | 4 +- crates/burn-cube-macros/src/codegen.rs | 59 ++++++++++------- crates/burn-cube/src/operation/assignation.rs | 8 +-- crates/burn-cube/tests/reuse.rs | 64 ++++++++++++++----- 4 files changed, 90 insertions(+), 45 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 581b216e6e..cb7bf1fdd7 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -72,7 +72,7 @@ impl CodeAnalysisBuilder { let id = declaration.0.clone(); let new_analysis = match variable_analyses.remove(&id) { Some(_) => { - panic!("Multiple variables with the same identifier is not supported") + panic!("Analysis: Multiple variables with the same identifier is not supported") } None => VariableAnalysis { num_used: 0, @@ -85,7 +85,7 @@ impl CodeAnalysisBuilder { for id in self.var_uses.iter() { let prev_analysis = variable_analyses.remove(&id).expect(&format!( - "Variable {:?} should be declared before it's used", + "Analyis: Variable {:?} should be declared before it's used", id )); let new_analysis = VariableAnalysis { diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 5145a278a6..613db4b4d5 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -63,7 +63,7 @@ fn codegen_expr( ) -> TokenStream { match expr { syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), - syn::Expr::Path(path) => codegen_path(path, loop_level, variable_analyses), + syn::Expr::Path(path) => codegen_path_rhs(path, loop_level, variable_analyses), syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), syn::Expr::Lit(lit) => codegen_lit(lit), syn::Expr::Closure(closure) => 
codegen_closure(closure, loop_level, variable_analyses), @@ -222,30 +222,34 @@ fn codegen_assign( loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { - if let syn::Expr::Index(index) = assign.left.as_ref() { - let array = codegen_expr(&index.expr, loop_level, variable_analyses); - let index = codegen_expr(&index.index, loop_level, variable_analyses); - let value = codegen_expr(&assign.right, loop_level, variable_analyses); + match assign.left.as_ref() { + syn::Expr::Index(index) => { + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = codegen_expr(&index.index, loop_level, variable_analyses); + let value = codegen_expr(&assign.right, loop_level, variable_analyses); - return quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #value; - burn_cube::index_assign::expand(context, _array, _index, _value) + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #value; + burn_cube::index_assign::expand(context, _array, _index, _value) + } } - }; - }; - - let lhs = codegen_expr(&assign.left, loop_level, variable_analyses); - let rhs = codegen_expr(&assign.right, loop_level, variable_analyses); + } + syn::Expr::Path(_) => { + let lhs = codegen_expr(&assign.left, loop_level, variable_analyses); + let rhs = codegen_expr(&assign.right, loop_level, variable_analyses); - quote::quote! { - { - let _assign_lhs = #lhs; - let _assign_rhs = #rhs; - #lhs = burn_cube::assign::expand(context, _assign_lhs, _assign_rhs) + quote::quote! 
{ + { + let _assign_lhs = #lhs; + let _assign_rhs = #rhs; + burn_cube::assign::expand(context, _assign_rhs, _assign_lhs) + } + } } + _ => todo!("Assign of expr {:?} unsupported", assign.left), } } @@ -328,7 +332,7 @@ fn codegen_call( } } -fn codegen_path( +fn codegen_path_rhs( path: &syn::ExprPath, loop_level: usize, variable_analyses: &mut CodeAnalysis, @@ -351,6 +355,17 @@ fn codegen_path( } } +fn codegen_path_lhs(path: &syn::ExprPath) -> TokenStream { + let ident = path + .path + .get_ident() + .expect("Codegen: Only ident path are supported."); + + quote::quote! { + #ident + } +} + fn codegen_binary( binary: &syn::ExprBinary, loop_level: usize, diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs index f317185333..ffefcb3d09 100644 --- a/crates/burn-cube/src/operation/assignation.rs +++ b/crates/burn-cube/src/operation/assignation.rs @@ -4,17 +4,13 @@ use burn_jit::gpu::{self}; pub mod assign { use super::*; - pub fn expand( - context: &mut CubeContext, - input: ExpandElement, - output: ExpandElement, - ) -> ExpandElement { + pub fn expand(context: &mut CubeContext, input: ExpandElement, output: ExpandElement) { let input = *input; let out = *output; context.register(gpu::Operator::Assign(gpu::UnaryOperator { input, out })); - output + // output } } diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index e1bea2e415..ada076a811 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -3,27 +3,35 @@ use burn_jit::gpu; use burn_jit::gpu::IntKind::I32; use burn_jit::gpu::{Branch, Elem, Item, Variable}; +// TODO +// a += b is more efficient than a = a + b +// because the latter does not assume that a is the same in lhs and rhs +// It could be detected and optimized + #[cube] -pub fn reuse_not_in_rhs(mut lhs: Int) { - while lhs < int_new(10) { - lhs = int_new(1); +pub fn reuse(mut x: Int) { + while x < int_new(10) { + x = x + int_new(1); } } -// TODO: 
remove clone in lhs #[cube] -pub fn reuse_in_rhs(mut lhs: Int) { - while lhs < int_new(10) { - lhs = lhs + int_new(1); +pub fn reuse_incr(mut x: Int) { + while x < int_new(10) { + x += int_new(1); } } -// TODO: allow to be borrowed immutably in rhs while being mutable in lhs :o -#[cube] -pub fn reuse_incr(mut lhs: Int) { - while lhs < int_new(10) { - lhs += int_new(1); - } +#[test] +fn cube_reuse_assign_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + + reuse::expand(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); } #[test] @@ -35,10 +43,36 @@ fn cube_reuse_incr_test() { reuse_incr::expand(&mut context, lhs); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); +} + +fn gpu_macro_ref_assign() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32)); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let tmp = scope.create_local(item); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = lhs < 10); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, tmp = lhs + 1); + gpu!(scope, lhs = tmp); + }) + ); + + format!("{:?}", scope.operations) } -fn gpu_macro_ref() -> String { +fn gpu_macro_ref_incr() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Int(I32)); let lhs = context.create_local(item); From fab9201e57b60d58f2c27a3e753f995ee21dad90 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 6 May 2024 18:31:14 -0400 Subject: [PATCH 26/54] cast elem --- crates/burn-cube-macros/src/codegen.rs | 27 +- crates/burn-cube-macros/src/prelude.rs | 145 +++++++++++ 
crates/burn-cube/src/operation/binary.rs | 44 ++++ crates/burn-cube/tests/cast_elem.rs | 306 +++++++++++++++++++++++ crates/burn-cube/tests/reuse.rs | 26 +- 5 files changed, 524 insertions(+), 24 deletions(-) create mode 100644 crates/burn-cube/tests/cast_elem.rs diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 613db4b4d5..8ffe39a1d1 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -355,17 +355,6 @@ fn codegen_path_rhs( } } -fn codegen_path_lhs(path: &syn::ExprPath) -> TokenStream { - let ident = path - .path - .get_ident() - .expect("Codegen: Only ident path are supported."); - - quote::quote! { - #ident - } -} - fn codegen_binary( binary: &syn::ExprBinary, loop_level: usize, @@ -438,6 +427,22 @@ fn codegen_binary( burn_cube::add_assign_op::expand(context, _lhs, _rhs) } }, + syn::BinOp::BitAnd(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::and::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::And(_) => unimplemented!("Logical and (&&) not overridable in Rust due to its short circuiting nature. Use bitwise instead (&). "), + syn::BinOp::BitOr(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::or::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Or(_) => unimplemented!("Logical or (||) not overridable in Rust due to its short circuiting nature. Use bitwise instead (|). 
"), _ => todo!("Codegen: unsupported op {:?}", binary.op), } } diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index bfda43bd5e..ded3fb5e11 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -14,6 +14,12 @@ pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro let func_code = match func_name { "float_new" => Some(codegen_float_new()), "int_new" => Some(codegen_int_new()), + "uint_new" => Some(codegen_uint_new()), + "bool_new" => Some(codegen_bool_new()), + "to_int" => Some(codegen_to_int()), + "to_float" => Some(codegen_to_float()), + "to_uint" => Some(codegen_to_uint()), + "to_bool" => Some(codegen_to_bool()), _ => None, }; @@ -59,3 +65,142 @@ fn codegen_int_new() -> proc_macro2::TokenStream { } } } + +fn codegen_uint_new() -> proc_macro2::TokenStream { + quote::quote! { + pub fn uint_new(val: u32) -> UInt { + UInt { + val, + vectorization: 1, + } + } + pub fn uint_new_expand( + context: &mut CubeContext, + val: u32, + ) -> ::ExpandType { + val.into() + } + } +} + +fn codegen_bool_new() -> proc_macro2::TokenStream { + quote::quote! { + pub fn bool_new(val: bool) -> Bool{ + Bool { + val, + vectorization: 1, + } + } + pub fn bool_new_expand( + context: &mut CubeContext, + val: bool, + ) -> ::ExpandType { + val.into() + } + } +} + +fn codegen_to_int() -> proc_macro2::TokenStream { + quote::quote! 
{ + pub fn to_int(input: R) -> Int { + Int { + val: 0, + vectorization: 1, + } + } + pub fn to_int_expand( + context: &mut CubeContext, + val: burn_cube::ExpandElement, + ) -> ::ExpandType { + let elem = Elem::Int(I32); + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + burn_cube::assign::expand(context, val.into(), new_var.clone()); + new_var + } + } +} + +fn codegen_to_float() -> proc_macro2::TokenStream { + quote::quote! { + pub fn to_float(input: R) -> Float { + Float { + val: 0., + vectorization: 1, + } + // TODO: make val and vectorization accessible through trait + // Float { + // val: input.val as f32, + // vectorization: input.vectorization, + // } + } + pub fn to_float_expand( + context: &mut CubeContext, + val: burn_cube::ExpandElement, + ) -> ::ExpandType { + let elem = Elem::Float(F32); + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + burn_cube::assign::expand(context, val.into(), new_var.clone()); + new_var + } + } +} + +fn codegen_to_uint() -> proc_macro2::TokenStream { + quote::quote! { + pub fn to_uint(input: R) -> UInt { + UInt { + val: 0, + vectorization: 1, + } + } + pub fn to_uint_expand( + context: &mut CubeContext, + val: burn_cube::ExpandElement, + ) -> ::ExpandType { + let elem = Elem::UInt; + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + burn_cube::assign::expand(context, val.into(), new_var.clone()); + new_var + } + } +} + +fn codegen_to_bool() -> proc_macro2::TokenStream { + quote::quote! 
{ + pub fn to_bool(input: R) -> Bool { + Bool { + val: true, + vectorization: 1, + } + } + pub fn to_bool_expand( + context: &mut CubeContext, + val: burn_cube::ExpandElement, + ) -> ::ExpandType { + let elem = Elem::Bool; + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + burn_cube::assign::expand(context, val.into(), new_var.clone()); + new_var + } + } +} diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 6c5c325da2..08b79961e7 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -181,3 +181,47 @@ pub mod rem { } } } + +pub mod and { + use crate::Bool; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::And) + } + + impl core::ops::BitAnd for Bool { + type Output = Bool; + + fn bitand(self, rhs: Self) -> Self::Output { + Bool::new(self.val && rhs.val, 1) + } + } +} + +pub mod or { + use crate::Bool; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Or) + } + + impl core::ops::BitOr for Bool { + type Output = Bool; + + fn bitor(self, rhs: Self) -> Self::Output { + Bool::new(self.val || rhs.val, 1) + } + } +} diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs new file mode 100644 index 0000000000..075e57e701 --- /dev/null +++ b/crates/burn-cube/tests/cast_elem.rs @@ -0,0 +1,306 @@ +use burn_cube::{cube, Bool, CubeContext, Float, Int, UInt}; +use burn_jit::gpu; +use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::IntKind::I32; +use burn_jit::gpu::{Elem, Item, Variable}; + +macro_rules! 
cast_test { + ($name:ident, $module:ident, $from:expr, $to:expr) => { + #[test] + fn $name() { + let mut context = CubeContext::root(); + + let x = context.create_local($from); + + $module::expand(&mut context, x); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + gpu_macro_ref_cast($from, $to) + ); + } + }; + + ($name:ident, $module:ident, $ty:expr) => { + #[test] + fn $name() { + let mut context = CubeContext::root(); + + let x = context.create_local($ty); + + $module::expand(&mut context, x); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + gpu_macro_ref_identity($ty) + ); + } + }; +} + +// From float +#[cube] +pub fn float_to_float(x: Float) { + let y = x + float_new(2.0); + let _ = to_float(y) + float_new(34.0); +} + +#[cube] +pub fn float_to_int(x: Float) { + let y = x + float_new(2.0); + let _ = to_int(y) + int_new(34); +} + +#[cube] +pub fn float_to_uint(x: Float) { + let y = x + float_new(2.0); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn float_to_bool(x: Float) { + let y = x + float_new(2.0); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_float_to_int_test, + float_to_int, + Item::Scalar(Elem::Float(F32)), + Item::Scalar(Elem::Int(I32)) +); + +cast_test!( + cube_float_to_uint_test, + float_to_uint, + Item::Scalar(Elem::Float(F32)), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_float_to_float_test, + float_to_float, + Item::Scalar(Elem::Float(F32)) +); + +cast_test!( + cube_float_to_bool_test, + float_to_bool, + Item::Scalar(Elem::Float(F32)), + Item::Scalar(Elem::Bool) +); + +// From int +#[cube] +pub fn int_to_float(x: Int) { + let y = x + int_new(2); + let _ = to_float(y) + float_new(34.0); +} + +#[cube] +pub fn int_to_int(x: Int) { + let y = x + int_new(2); + let _ = to_int(y) + int_new(34); +} + +#[cube] +pub fn int_to_uint(x: Int) { + let y = x + int_new(2); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn 
int_to_bool(x: Int) { + let y = x + int_new(2); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_int_to_float_test, + int_to_float, + Item::Scalar(Elem::Int(I32)), + Item::Scalar(Elem::Float(F32)) +); + +cast_test!( + cube_int_to_int_test, + int_to_int, + Item::Scalar(Elem::Int(I32)) +); + +cast_test!( + cube_int_to_uint_test, + int_to_uint, + Item::Scalar(Elem::Int(I32)), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_int_to_bool_test, + int_to_bool, + Item::Scalar(Elem::Int(I32)), + Item::Scalar(Elem::Bool) +); + +// From uint +#[cube] +pub fn uint_to_float(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_float(y) + float_new(34.0); +} + +#[cube] +pub fn uint_to_int(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_int(y) + int_new(34); +} + +#[cube] +pub fn uint_to_uint(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn uint_to_bool(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_uint_to_float_test, + uint_to_float, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Float(F32)) +); + +cast_test!( + cube_uint_to_int_test, + uint_to_int, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Int(I32)) +); + +cast_test!( + cube_uint_to_uint_test, + uint_to_uint, + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_uint_to_bool_test, + uint_to_bool, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Bool) +); + +// From bool +#[cube] +pub fn bool_to_float(x: Bool) { + let y = x & bool_new(false); + let _ = to_float(y) + float_new(34.0); +} + +#[cube] +pub fn bool_to_int(x: Bool) { + let y = x & bool_new(false); + let _ = to_int(y) + int_new(34); +} + +#[cube] +pub fn bool_to_uint(x: Bool) { + let y = x & bool_new(false); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn bool_to_bool(x: Bool) { + let y = x & bool_new(false); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_bool_to_float_test, + bool_to_float, + 
Item::Scalar(Elem::Bool), + Item::Scalar(Elem::Float(F32)) +); + +cast_test!( + cube_bool_to_int_test, + bool_to_int, + Item::Scalar(Elem::Bool), + Item::Scalar(Elem::Int(I32)) +); + +cast_test!( + cube_bool_to_uint_test, + bool_to_uint, + Item::Scalar(Elem::Bool), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_bool_to_bool_test, + bool_to_bool, + Item::Scalar(Elem::Bool) +); + +fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { + let mut context = CubeContext::root(); + let x = context.create_local(from_item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y = scope.create_local(from_item); + let y_casted = scope.create_local(to_item); + let z = scope.create_local(to_item); + + match from_item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 2f32), + Elem::Int(_) => gpu!(scope, y = x + 2i32), + Elem::UInt => gpu!(scope, y = x + 2u32), + Elem::Bool => gpu!(scope, y = x && false), + } + + gpu!(scope, y_casted = cast(y)); + + match to_item.elem() { + Elem::Float(_) => gpu!(scope, z = y_casted + 34f32), + Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), + Elem::UInt => gpu!(scope, z = y_casted + 34u32), + Elem::Bool => gpu!(scope, z = y_casted || true), + } + + format!("{:?}", scope.operations) +} + +fn gpu_macro_ref_identity(item: Item) -> String { + // When staying with the same type variables are automatically reused in cube + let mut context = CubeContext::root(); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y = scope.create_local(item); + + match item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 2f32), + Elem::Int(_) => gpu!(scope, y = x + 2i32), + Elem::UInt => gpu!(scope, y = x + 2u32), + Elem::Bool => gpu!(scope, y = x && false), + } + + gpu!(scope, x = cast(y)); + + match item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 34f32), + Elem::Int(_) => gpu!(scope, y = x + 34i32), + Elem::UInt => gpu!(scope, y = x + 34u32), + 
Elem::Bool => gpu!(scope, y = x || true), + } + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index ada076a811..e0a441c8c6 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -26,9 +26,9 @@ pub fn reuse_incr(mut x: Int) { fn cube_reuse_assign_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + let x = context.create_local(Item::Scalar(Elem::Int(I32))); - reuse::expand(&mut context, lhs); + reuse::expand(&mut context, x); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); @@ -38,9 +38,9 @@ fn cube_reuse_assign_test() { fn cube_reuse_incr_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); + let x = context.create_local(Item::Scalar(Elem::Int(I32))); - reuse_incr::expand(&mut context, lhs); + reuse_incr::expand(&mut context, x); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); @@ -49,23 +49,23 @@ fn cube_reuse_incr_test() { fn gpu_macro_ref_assign() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Int(I32)); - let lhs = context.create_local(item); + let x = context.create_local(item); let mut scope = context.into_scope(); let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let lhs: Variable = lhs.into(); + let x: Variable = x.into(); let tmp = scope.create_local(item); gpu!( &mut scope, loop(|scope| { - gpu!(scope, cond = lhs < 10); + gpu!(scope, cond = x < 10); gpu!(scope, if(cond).then(|scope|{ scope.register(Branch::Break); })); - gpu!(scope, tmp = lhs + 1); - gpu!(scope, lhs = tmp); + gpu!(scope, tmp = x + 1); + gpu!(scope, x = tmp); }) ); @@ -75,21 +75,21 @@ fn gpu_macro_ref_assign() -> String { fn gpu_macro_ref_incr() -> String { let mut context = CubeContext::root(); let item = 
Item::Scalar(Elem::Int(I32)); - let lhs = context.create_local(item); + let x = context.create_local(item); let mut scope = context.into_scope(); let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let lhs: Variable = lhs.into(); + let x: Variable = x.into(); gpu!( &mut scope, loop(|scope| { - gpu!(scope, cond = lhs < 10); + gpu!(scope, cond = x < 10); gpu!(scope, if(cond).then(|scope|{ scope.register(Branch::Break); })); - gpu!(scope, lhs = lhs + 1); + gpu!(scope, x = x + 1); }) ); From b9bac61c2aec07c1809d536cba6b6eac755ff590 Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 8 May 2024 07:44:32 -0400 Subject: [PATCH 27/54] wip cast --- crates/burn-cube-macros/src/analysis.rs | 26 +++++-- crates/burn-cube-macros/src/codegen.rs | 32 +++++++-- crates/burn-cube-macros/src/lib.rs | 37 ++++++++-- crates/burn-cube-macros/src/prelude.rs | 38 +++++----- crates/burn-cube/src/context.rs | 6 -- crates/burn-cube/src/element.rs | 61 ++++++++++++++-- crates/burn-cube/src/operation/assignation.rs | 4 +- crates/burn-cube/src/operation/binary.rs | 22 ++++-- crates/burn-cube/src/operation/cmp.rs | 8 +-- crates/burn-cube/tests/cast_elem.rs | 70 +++++++++---------- crates/burn-cube/tests/cast_kind.rs | 43 ++++++++++++ crates/burn-cube/tests/for_loop.rs | 8 +-- crates/burn-cube/tests/if.rs | 12 ++-- crates/burn-cube/tests/if_else.rs | 16 ++--- crates/burn-cube/tests/literal.rs | 12 ++-- .../src/codegen/dialect/gpu/macros.rs | 6 ++ 16 files changed, 280 insertions(+), 121 deletions(-) create mode 100644 crates/burn-cube/tests/cast_kind.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index cb7bf1fdd7..1f6cd9c41a 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, HashSet}; -use syn::Stmt; +use syn::{PathArguments, Stmt}; use crate::VariableKey; @@ -205,11 +205,25 @@ impl CodeAnalysisBuilder { syn::Expr::Call(expr) => { match &*expr.func 
{ syn::Expr::Path(expr_path) => { - let ident = expr_path - .path - .get_ident() - .expect("Analysis: only ident supported for function call"); - self.function_calls.insert(ident.into()); + if let Some(first_segment) = expr_path.path.segments.first() { + // Extract the identifier of the path segment + let ident = &first_segment.ident; + self.function_calls.insert(ident.into()); + + // Check if the path segment has generic arguments + if let PathArguments::AngleBracketed(arguments) = + &first_segment.arguments + { + // Extract the generic arguments + for arg in &arguments.args { + match arg { + syn::GenericArgument::Type(_) + | syn::GenericArgument::Constraint(_) => {} + _ => todo!("Analysis: Generic {:?} not supported", arg), + } + } + } + } } _ => todo!("Analysis: unsupported func expr {:?}", expr.func), } diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 8ffe39a1d1..b5b4d3e2da 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -1,4 +1,5 @@ use proc_macro2::TokenStream; +use syn::PathArguments; use crate::analysis::CodeAnalysis; @@ -307,18 +308,35 @@ fn codegen_call( loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { - let func_name = match call.func.as_ref() { - syn::Expr::Path(path) => path - .path - .get_ident() - .expect("Codegen: func called path should have ident"), - _ => todo!("Codegen: Only path call supported"), + let (func_name, generics) = match call.func.as_ref() { + syn::Expr::Path(expr_path) => { + if let Some(first_segment) = expr_path.path.segments.first() { + // Extract the identifier of the path segment + let ident = &first_segment.ident; + let generics = + if let PathArguments::AngleBracketed(arguments) = &first_segment.arguments { + Some(arguments) + } else { + None + }; + + (ident, generics) + } else { + panic!("Codegen: func call must have an ident"); + } + } + _ => todo!("Codegen: func call {:?} not supported", 
call.func), }; let mut args = quote::quote! { context, }; + let generics = match generics { + Some(generics) => quote::quote! { #generics }, + None => quote::quote! {}, + }; + let func_name_expand = syn::Ident::new(format!("{func_name}_expand").as_str(), func_name.span()); @@ -328,7 +346,7 @@ fn codegen_call( } quote::quote! { - #func_name_expand(#args) + #func_name_expand #generics (#args) } } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 5aface3824..b91f45935c 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -6,14 +6,42 @@ use analysis::CodeAnalysis; use codegen::codegen_statement; use prelude::get_prelude; use proc_macro::TokenStream; +use quote::ToTokens; +use syn::{parse_macro_input, punctuated::Punctuated, Meta}; /// Derive macro for the module. #[proc_macro_attribute] -pub fn cube(_attr: TokenStream, tokens: TokenStream) -> TokenStream { +pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { + let args = parse_macro_input!(attr with Punctuated::::parse_terminated); + + let mut panic_mode = false; + if let Some(arg) = args.first() { + match arg { + Meta::Path(path) => { + if let Some(ident) = path.get_ident().map(|id| id.to_string()) { + match ident.as_str() { + "panic" => { + panic_mode = true; + } + _ => panic!("Attribute {ident} is not supported"), + } + } else { + panic!("Only ident attribute supported"); + } + } + Meta::List(_) => panic!("No List attribute supported"), + Meta::NameValue(_) => panic!("No NameValue attribute supported"), + } + } + let func: syn::ItemFn = syn::parse(tokens).unwrap(); let mut variable_analyses = CodeAnalysis::create(&func); - codegen_cube(&func, &mut variable_analyses) + let code = codegen_cube(&func, &mut variable_analyses); + match panic_mode { + true => panic!("{code}"), + false => code, + } } #[derive(Hash, PartialEq, Eq, Debug, Clone)] @@ -55,7 +83,6 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut 
CodeAnalysis) -> TokenSt } } .into(); - // panic!("{code}"); code } @@ -100,8 +127,10 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let ident = &sig.ident; let ident = syn::Ident::new("expand", ident.span()); + let generics = sig.generics.clone().into_token_stream(); + quote::quote! { - pub fn #ident(context: &mut burn_cube::CubeContext, #inputs) -> #output + pub fn #ident #generics (context: &mut burn_cube::CubeContext, #inputs) -> #output } .into() } diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index ded3fb5e11..01cc702be3 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -33,18 +33,19 @@ pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro fn codegen_float_new() -> proc_macro2::TokenStream { quote::quote! { - pub fn float_new(val: f32) -> Float { - Float { - val, - vectorization: 1, - } + use std::{rc::Rc}; + use burn_cube::ExpandElement; + use burn_jit::gpu::Variable; + pub fn float_new(val: f32) -> Float { + Float::new(val, 1) } - pub fn float_new_expand( + pub fn float_new_expand( context: &mut CubeContext, val: f32, - ) -> ::ExpandType { - // TODO: 0. becomes 0..into() - val.into() + ) -> as burn_cube::RuntimeType>::ExpandType { + let elem = F::to_elem(); + let new_var = Variable::ConstantScalar(val as f64, elem); + ExpandElement::new(Rc::new(new_var)) } } } @@ -126,23 +127,18 @@ fn codegen_to_int() -> proc_macro2::TokenStream { } fn codegen_to_float() -> proc_macro2::TokenStream { + // R: type we come from + // F: kind of float we want as output quote::quote! 
{ - pub fn to_float(input: R) -> Float { - Float { - val: 0., - vectorization: 1, - } + pub fn to_float(input: R) -> Float { // TODO: make val and vectorization accessible through trait - // Float { - // val: input.val as f32, - // vectorization: input.vectorization, - // } + Float::new(0., 1) } - pub fn to_float_expand( + pub fn to_float_expand( context: &mut CubeContext, val: burn_cube::ExpandElement, - ) -> ::ExpandType { - let elem = Elem::Float(F32); + ) -> as burn_cube::RuntimeType>::ExpandType { + let elem = F::to_elem(); let new_var = context.create_local(match val.item() { Item::Vec4(_) => Item::Vec4(elem), Item::Vec3(_) => Item::Vec3(elem), diff --git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs index 3effc70c49..00afd24387 100644 --- a/crates/burn-cube/src/context.rs +++ b/crates/burn-cube/src/context.rs @@ -29,12 +29,6 @@ impl VariablePool { } } - // println!("New var"); - // for variable in variables { - // let count = Rc::strong_count(&variable.inner); - // println!("{:?} => {}", variable.inner, count); - // } - // If no candidate was found, a new var will be needed None } diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 01eecf16f5..25191bf674 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -1,5 +1,7 @@ +use std::marker::PhantomData; + use alloc::rc::Rc; -use burn_jit::gpu::{Item, Variable}; +use burn_jit::gpu::{Elem, Item, Variable}; /// Types used in a cube function must implement this trait /// @@ -69,12 +71,59 @@ impl core::ops::Deref for ExpandElement { } } -#[derive(new, Clone, Copy)] -pub struct Float { +// Why _ suffixes? 
Just to avoid clashing with JIT float kind types +// TODO refactor +pub trait FloatKind_: Clone + Copy { + fn to_elem() -> Elem; +} +#[derive(Clone, Copy)] +pub struct F32_; +#[derive(Clone, Copy)] +pub struct BF16_; +#[derive(Clone, Copy)] +pub struct F32_; +#[derive(Clone, Copy)] +pub struct F64_; +impl FloatKind_ for F32_ { + fn to_elem() -> Elem { + Elem::Float(burn_jit::gpu::FloatKind::F32) + } +} +impl FloatKind_ for BF16_ { + fn to_elem() -> Elem { + Elem::Float(burn_jit::gpu::FloatKind::BF16) + } +} +impl FloatKind_ for F32_ { + fn to_elem() -> Elem { + Elem::Float(burn_jit::gpu::FloatKind::F32) + } +} +impl FloatKind_ for F64_ { + fn to_elem() -> Elem { + Elem::Float(burn_jit::gpu::FloatKind::F64) + } +} + +#[derive(Clone, Copy)] +pub struct Float { pub val: f32, pub vectorization: u8, + pub _type: PhantomData, } +impl Float { + pub fn new(val: f32, vectorization: u8) -> Self { + Self { + val, + vectorization, + _type: PhantomData, + } + } +} + +pub type Float_ = Float; + #[derive(new, Clone, Copy)] pub struct Int { pub val: i32, @@ -98,11 +147,11 @@ pub struct Array { pub vals: Vec, } -impl RuntimeType for Float { +impl RuntimeType for Float { type ExpandType = ExpandElement; } -impl RuntimeType for Array { +impl RuntimeType for Array> { type ExpandType = ExpandElement; } @@ -148,7 +197,7 @@ impl RuntimeType for f32 { type ExpandType = f32; } -impl From for Float { +impl From for Float { fn from(value: f32) -> Self { Float::new(value, 1) } diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs index ffefcb3d09..a99de50168 100644 --- a/crates/burn-cube/src/operation/assignation.rs +++ b/crates/burn-cube/src/operation/assignation.rs @@ -64,7 +64,7 @@ pub mod index { } pub mod add_assign_op { - use crate::{operation::base::assign_op_expand, Float, Int}; + use crate::{operation::base::assign_op_expand, Float, FloatKind_, Int}; use super::*; @@ -76,7 +76,7 @@ pub mod add_assign_op { assign_op_expand(context, 
lhs, rhs, gpu::Operator::Add) } - impl core::ops::AddAssign for Float { + impl core::ops::AddAssign for Float { fn add_assign(&mut self, rhs: Self) { self.val += rhs.val } diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 08b79961e7..75c2a0bdc8 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -3,6 +3,8 @@ use crate::{CubeContext, ExpandElement, Float, Int, UInt}; use burn_jit::gpu::{self}; pub mod add { + use crate::FloatKind_; + use super::*; pub fn expand( @@ -13,11 +15,11 @@ pub mod add { binary_expand(context, lhs, rhs, gpu::Operator::Add) } - impl core::ops::Add for Float { + impl core::ops::Add for Float { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - Float::new(self.val + rhs.val, 1) + Float::::new(self.val + rhs.val, 1) } } @@ -39,6 +41,8 @@ pub mod add { } pub mod sub { + use crate::FloatKind_; + use super::*; pub fn expand( @@ -49,7 +53,7 @@ pub mod sub { binary_expand(context, lhs, rhs, gpu::Operator::Sub) } - impl core::ops::Sub for Float { + impl core::ops::Sub for Float { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { @@ -75,6 +79,8 @@ pub mod sub { } pub mod mul { + use crate::FloatKind_; + use super::*; pub fn expand( @@ -85,7 +91,7 @@ pub mod mul { binary_expand(context, lhs, rhs, gpu::Operator::Mul) } - impl core::ops::Mul for Float { + impl core::ops::Mul for Float { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { @@ -111,6 +117,8 @@ pub mod mul { } pub mod div { + use crate::FloatKind_; + use super::*; pub fn expand( @@ -121,7 +129,7 @@ pub mod div { binary_expand(context, lhs, rhs, gpu::Operator::Div) } - impl core::ops::Div for Float { + impl core::ops::Div for Float { type Output = Self; fn div(self, rhs: Self) -> Self::Output { @@ -147,6 +155,8 @@ pub mod div { } pub mod rem { + use crate::FloatKind_; + use super::*; pub fn expand( @@ -157,7 +167,7 @@ pub mod rem { binary_expand(context, 
lhs, rhs, gpu::Operator::Modulo) } - impl core::ops::Rem for Float { + impl core::ops::Rem for Float { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { diff --git a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs index a422c93b27..3f4b85d6ce 100644 --- a/crates/burn-cube/src/operation/cmp.rs +++ b/crates/burn-cube/src/operation/cmp.rs @@ -1,8 +1,8 @@ use crate::operation::base::cmp_expand; -use crate::{CubeContext, ExpandElement, Float, Int, UInt}; +use crate::{CubeContext, ExpandElement, Float, FloatKind_, Int, UInt}; use burn_jit::gpu::{self}; -impl core::cmp::PartialEq for Float { +impl core::cmp::PartialEq for Float { fn eq(&self, other: &Self) -> bool { self.val == other.val && self.vectorization == other.vectorization } @@ -20,13 +20,13 @@ impl core::cmp::PartialEq for UInt { } } -impl core::cmp::Eq for Float {} +impl core::cmp::Eq for Float {} impl core::cmp::Eq for Int {} impl core::cmp::Eq for UInt {} -impl core::cmp::PartialOrd for Float { +impl core::cmp::PartialOrd for Float { fn partial_cmp(&self, other: &Self) -> Option { match self.val.partial_cmp(&other.val) { Some(core::cmp::Ordering::Equal) => {} diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 075e57e701..2a4ff29b4f 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, Bool, CubeContext, Float, Int, UInt}; +use burn_cube::{cube, Bool, CubeContext, Float, FloatKind_, Int, UInt}; use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::IntKind::I32; @@ -12,7 +12,7 @@ macro_rules! cast_test { let x = context.create_local($from); - $module::expand(&mut context, x); + $module::expand::(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -29,7 +29,7 @@ macro_rules! 
cast_test { let x = context.create_local($ty); - $module::expand(&mut context, x); + $module::expand::(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -42,29 +42,35 @@ macro_rules! cast_test { // From float #[cube] -pub fn float_to_float(x: Float) { - let y = x + float_new(2.0); - let _ = to_float(y) + float_new(34.0); +pub fn float_to_float(x: Float) { + let y = x + float_new::(2.0); + let _ = to_float::, F>(y) + float_new::(34.0); } #[cube] -pub fn float_to_int(x: Float) { - let y = x + float_new(2.0); +pub fn float_to_int(x: Float) { + let y = x + float_new::(2.0); let _ = to_int(y) + int_new(34); } #[cube] -pub fn float_to_uint(x: Float) { - let y = x + float_new(2.0); +pub fn float_to_uint(x: Float) { + let y = x + float_new::(2.0); let _ = to_uint(y) + uint_new(34u32); } #[cube] -pub fn float_to_bool(x: Float) { - let y = x + float_new(2.0); +pub fn float_to_bool(x: Float) { + let y = x + float_new::(2.0); let _ = to_bool(y) | bool_new(true); } +cast_test!( + cube_float_to_float_test, + float_to_float, + Item::Scalar(Elem::Float(F32)) +); + cast_test!( cube_float_to_int_test, float_to_int, @@ -79,12 +85,6 @@ cast_test!( Item::Scalar(Elem::UInt) ); -cast_test!( - cube_float_to_float_test, - float_to_float, - Item::Scalar(Elem::Float(F32)) -); - cast_test!( cube_float_to_bool_test, float_to_bool, @@ -92,27 +92,27 @@ cast_test!( Item::Scalar(Elem::Bool) ); -// From int +// // From int #[cube] -pub fn int_to_float(x: Int) { +pub fn int_to_float(x: Int) { let y = x + int_new(2); - let _ = to_float(y) + float_new(34.0); + let _ = to_float::(y) + float_new::(34.0); } #[cube] -pub fn int_to_int(x: Int) { +pub fn int_to_int(x: Int) { let y = x + int_new(2); let _ = to_int(y) + int_new(34); } #[cube] -pub fn int_to_uint(x: Int) { +pub fn int_to_uint(x: Int) { let y = x + int_new(2); let _ = to_uint(y) + uint_new(34u32); } #[cube] -pub fn int_to_bool(x: Int) { +pub fn int_to_bool(x: Int) { let y = x + int_new(2); let _ = to_bool(y) | bool_new(true); } 
@@ -144,27 +144,27 @@ cast_test!( Item::Scalar(Elem::Bool) ); -// From uint +// // From uint #[cube] -pub fn uint_to_float(x: UInt) { +pub fn uint_to_float(x: UInt) { let y = x + uint_new(2u32); - let _ = to_float(y) + float_new(34.0); + let _ = to_float::(y) + float_new::(34.0); } #[cube] -pub fn uint_to_int(x: UInt) { +pub fn uint_to_int(x: UInt) { let y = x + uint_new(2u32); let _ = to_int(y) + int_new(34); } #[cube] -pub fn uint_to_uint(x: UInt) { +pub fn uint_to_uint(x: UInt) { let y = x + uint_new(2u32); let _ = to_uint(y) + uint_new(34u32); } #[cube] -pub fn uint_to_bool(x: UInt) { +pub fn uint_to_bool(x: UInt) { let y = x + uint_new(2u32); let _ = to_bool(y) | bool_new(true); } @@ -198,25 +198,25 @@ cast_test!( // From bool #[cube] -pub fn bool_to_float(x: Bool) { +pub fn bool_to_float(x: Bool) { let y = x & bool_new(false); - let _ = to_float(y) + float_new(34.0); + let _ = to_float::(y) + float_new::(34.0); } #[cube] -pub fn bool_to_int(x: Bool) { +pub fn bool_to_int(x: Bool) { let y = x & bool_new(false); let _ = to_int(y) + int_new(34); } #[cube] -pub fn bool_to_uint(x: Bool) { +pub fn bool_to_uint(x: Bool) { let y = x & bool_new(false); let _ = to_uint(y) + uint_new(34u32); } #[cube] -pub fn bool_to_bool(x: Bool) { +pub fn bool_to_bool(x: Bool) { let y = x & bool_new(false); let _ = to_bool(y) | bool_new(true); } diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs new file mode 100644 index 0000000000..20586c39af --- /dev/null +++ b/crates/burn-cube/tests/cast_kind.rs @@ -0,0 +1,43 @@ +use burn_cube::{cube, CubeContext, Float, FloatKind_, F32_, F64_}; +use burn_jit::gpu; +use burn_jit::gpu::FloatKind; +use burn_jit::gpu::{Elem, Item}; + +#[cube] +pub fn cast_kind(input: Float) { + let x = input + float_new::(5.9f32); + let y = to_float::, F2>(x); + let _ = y + float_new::(2.3f32); +} + +#[test] +fn cube_cast_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(FloatKind::F64)); 
+ + let input = context.create_local(item); + + // F16 not testable with the gpu macro, but should work the same + cast_kind::expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let float_64 = Item::Scalar(Elem::Float(FloatKind::F64)); + let float_32 = Item::Scalar(Elem::Float(FloatKind::F32)); + let input = context.create_local(float_64); + + let mut scope = context.into_scope(); + let x = scope.create_local(float_64); + let y = scope.create_local(float_32); + let z = scope.create_local(float_32); + + gpu!(scope, x = input + 5.9f32 as f64); + gpu!(scope, y = cast(x)); + gpu!(scope, z = y + 2.9f32); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 743b9ebb00..087648d06e 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,10 +1,10 @@ -use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt}; +use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, FloatKind_, UInt, F32_}; use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; #[cube] -pub fn for_loop(mut lhs: Array, rhs: Float, end: UInt, unroll: bool) { +pub fn for_loop(mut lhs: Array>, rhs: Float, end: UInt, unroll: bool) { let tmp1 = rhs * rhs; let tmp2 = tmp1 + rhs; @@ -22,7 +22,7 @@ fn test_for_loop_with_unroll() { let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); let end = 4u32.into(); - for_loop::expand(&mut context, lhs, rhs, end, unroll); + for_loop::expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); @@ -37,7 +37,7 @@ fn test_for_loop_no_unroll() { let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); let end = 4u32.into(); - 
for_loop::expand(&mut context, lhs, rhs, end, unroll); + for_loop::expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 5f8321da45..08942d2aa3 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,12 +1,12 @@ -use burn_cube::{cube, if_expand, CubeContext, Float}; +use burn_cube::{cube, if_expand, CubeContext, Float, FloatKind_, F32_}; use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; #[cube] -pub fn if_greater(lhs: Float) { - if lhs > float_new(0.0) { - let _ = lhs + float_new(4.0); +pub fn if_greater(lhs: Float) { + if lhs > float_new::(0.0) { + let _ = lhs + float_new::(4.0); } } @@ -16,7 +16,7 @@ fn cube_if_test() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - if_greater::expand(&mut context, lhs); + if_greater::expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); @@ -34,7 +34,7 @@ fn gpu_macro_ref() -> String { gpu!(scope, cond = lhs > 0f32); gpu!(&mut scope, if(cond).then(|scope| { - gpu!(scope, y = lhs + 4.0); + gpu!(scope, y = lhs + 4.0f32); })); format!("{:?}", scope.operations) diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index b378c3cddd..a61c7a005a 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,14 +1,14 @@ -use burn_cube::{cube, if_else_expand, CubeContext, Float}; +use burn_cube::{cube, if_else_expand, CubeContext, Float, FloatKind_, F32_}; use burn_jit::gpu; use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; #[cube] -pub fn if_else(lhs: Float) { - if lhs < float_new(0.0) { - let _ = lhs + float_new(4.0); +pub fn if_else(lhs: Float) { + if lhs < float_new::(0.0) { + let _ = lhs + float_new::(4.0); } 
else { - let _ = lhs - float_new(5.0); + let _ = lhs - float_new::(5.0); } } @@ -18,7 +18,7 @@ fn cube_if_else_test() { let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - if_else::expand(&mut context, lhs); + if_else::expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); @@ -36,9 +36,9 @@ fn gpu_macro_ref() -> String { gpu!(scope, cond = lhs < 0f32); gpu!(&mut scope, if(cond).then(|scope| { - gpu!(scope, y = lhs + 4.0); + gpu!(scope, y = lhs + 4.0f32); }).else(|scope|{ - gpu!(scope, y = lhs - 5.0); + gpu!(scope, y = lhs - 5.0f32); })); format!("{:?}", scope.operations) diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 642eb2c4b7..d424cca67b 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,18 +1,18 @@ -use burn_cube::{cube, CubeContext, Float}; +use burn_cube::{cube, CubeContext, Float, F32_}; use burn_jit::gpu; -use burn_jit::gpu::FloatKind::F32; +use burn_jit::gpu::FloatKind; use burn_jit::gpu::{Elem, Item}; #[cube] -pub fn literal(lhs: Float) { - let _ = lhs + float_new(5.9); +pub fn literal(lhs: Float) { + let _ = lhs + float_new::(5.9); } #[test] fn cube_literal_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let lhs = context.create_local(Item::Scalar(Elem::Float(FloatKind::F32))); literal::expand(&mut context, lhs); let scope = context.into_scope(); @@ -22,7 +22,7 @@ fn cube_literal_test() { fn gpu_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F32)); + let item = Item::Scalar(Elem::Float(FloatKind::F32)); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 991235b3ab..af3a1938bb 100644 --- 
a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -414,6 +414,12 @@ impl From for Variable { } } +impl From for Variable { + fn from(value: f64) -> Self { + Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F64)) + } +} + impl From for Variable { fn from(value: u32) -> Self { Self::ConstantScalar(value as f64, super::Elem::UInt) From 6aa7ccc543aafb71c4b7cfebc26b948d93da6e47 Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 8 May 2024 07:46:45 -0400 Subject: [PATCH 28/54] cast kind float --- crates/burn-cube/src/element.rs | 6 +++--- crates/burn-cube/tests/cast_kind.rs | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs index 25191bf674..f9bb27afa6 100644 --- a/crates/burn-cube/src/element.rs +++ b/crates/burn-cube/src/element.rs @@ -77,16 +77,16 @@ pub trait FloatKind_: Clone + Copy { fn to_elem() -> Elem; } #[derive(Clone, Copy)] -pub struct F32_; +pub struct F16_; #[derive(Clone, Copy)] pub struct BF16_; #[derive(Clone, Copy)] pub struct F32_; #[derive(Clone, Copy)] pub struct F64_; -impl FloatKind_ for F32_ { +impl FloatKind_ for F16_ { fn to_elem() -> Elem { - Elem::Float(burn_jit::gpu::FloatKind::F32) + Elem::Float(burn_jit::gpu::FloatKind::F16) } } impl FloatKind_ for BF16_ { diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 20586c39af..10152096c4 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -4,7 +4,7 @@ use burn_jit::gpu::FloatKind; use burn_jit::gpu::{Elem, Item}; #[cube] -pub fn cast_kind(input: Float) { +pub fn cast_float_kind(input: Float) { let x = input + float_new::(5.9f32); let y = to_float::, F2>(x); let _ = y + float_new::(2.3f32); @@ -18,7 +18,7 @@ fn cube_cast_kind_test() { let input = context.create_local(item); // F16 not testable with the gpu macro, but should work the same - 
cast_kind::expand::(&mut context, input); + cast_float_kind::expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); @@ -37,7 +37,7 @@ fn gpu_macro_ref() -> String { gpu!(scope, x = input + 5.9f32 as f64); gpu!(scope, y = cast(x)); - gpu!(scope, z = y + 2.9f32); + gpu!(scope, z = y + 2.3f32); format!("{:?}", scope.operations) } From 16e38ac206d56cb2699e21645a0c4ee318e9c7b8 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 9 May 2024 10:39:58 -0400 Subject: [PATCH 29/54] refactor elements --- crates/burn-cube-macros/src/prelude.rs | 20 +- crates/burn-cube/src/element.rs | 227 ------- crates/burn-cube/src/element/array.rs | 22 + crates/burn-cube/src/element/base.rs | 47 ++ crates/burn-cube/src/element/bool.rs | 11 + crates/burn-cube/src/element/float.rs | 84 +++ crates/burn-cube/src/element/int.rs | 17 + crates/burn-cube/src/element/mod.rs | 14 + crates/burn-cube/src/element/primitive.rs | 51 ++ crates/burn-cube/src/element/uint.rs | 23 + crates/burn-cube/src/operation/assignation.rs | 31 +- crates/burn-cube/src/operation/binary.rs | 186 +++--- crates/burn-cube/src/operation/cmp.rs | 73 +-- crates/burn-cube/tests/cast_elem.rs | 612 +++++++++--------- crates/burn-cube/tests/cast_kind.rs | 86 +-- crates/burn-cube/tests/for_loop.rs | 21 +- crates/burn-cube/tests/if.rs | 13 +- crates/burn-cube/tests/if_else.rs | 13 +- crates/burn-cube/tests/literal.rs | 15 +- crates/burn-cube/tests/loop.rs | 112 ++-- crates/burn-cube/tests/reuse.rs | 194 +++--- 21 files changed, 931 insertions(+), 941 deletions(-) delete mode 100644 crates/burn-cube/src/element.rs create mode 100644 crates/burn-cube/src/element/array.rs create mode 100644 crates/burn-cube/src/element/base.rs create mode 100644 crates/burn-cube/src/element/bool.rs create mode 100644 crates/burn-cube/src/element/float.rs create mode 100644 crates/burn-cube/src/element/int.rs create mode 100644 crates/burn-cube/src/element/mod.rs create mode 100644 
crates/burn-cube/src/element/primitive.rs create mode 100644 crates/burn-cube/src/element/uint.rs diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index 01cc702be3..0b70689246 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -36,14 +36,14 @@ fn codegen_float_new() -> proc_macro2::TokenStream { use std::{rc::Rc}; use burn_cube::ExpandElement; use burn_jit::gpu::Variable; - pub fn float_new(val: f32) -> Float { - Float::new(val, 1) + pub fn float_new(val: f32) -> F { + F::new(val, 1) } - pub fn float_new_expand( + pub fn float_new_expand( context: &mut CubeContext, val: f32, - ) -> as burn_cube::RuntimeType>::ExpandType { - let elem = F::to_elem(); + ) -> ::ExpandType { + let elem = Elem::Float(F::into_kind()); let new_var = Variable::ConstantScalar(val as f64, elem); ExpandElement::new(Rc::new(new_var)) } @@ -130,15 +130,15 @@ fn codegen_to_float() -> proc_macro2::TokenStream { // R: type we come from // F: kind of float we want as output quote::quote! 
{ - pub fn to_float(input: R) -> Float { + pub fn to_float(input: R) -> F { // TODO: make val and vectorization accessible through trait - Float::new(0., 1) + F::new(0., 1) } - pub fn to_float_expand( + pub fn to_float_expand( context: &mut CubeContext, val: burn_cube::ExpandElement, - ) -> as burn_cube::RuntimeType>::ExpandType { - let elem = F::to_elem(); + ) -> burn_cube::ExpandElement { + let elem = Elem::Float(F::into_kind()); let new_var = context.create_local(match val.item() { Item::Vec4(_) => Item::Vec4(elem), Item::Vec3(_) => Item::Vec3(elem), diff --git a/crates/burn-cube/src/element.rs b/crates/burn-cube/src/element.rs deleted file mode 100644 index f9bb27afa6..0000000000 --- a/crates/burn-cube/src/element.rs +++ /dev/null @@ -1,227 +0,0 @@ -use std::marker::PhantomData; - -use alloc::rc::Rc; -use burn_jit::gpu::{Elem, Item, Variable}; - -/// Types used in a cube function must implement this trait -/// -/// Variables whose values will be known at runtime must -/// have ExpandElement as associated type -/// Variables whose values will be known at compile time -/// must have the primitive type as associated type -/// -/// Note: Cube functions should be written using RuntimeTypes, -/// so that the code generated uses the associated ExpandType. -/// This allows Cube code to not necessitate cloning, which is cumbersome -/// in algorithmic code. The necessary cloning will automatically appear in -/// the generated code. 
-pub trait RuntimeType { - type ExpandType: Clone; -} - -#[derive(new, Clone, Debug)] -/// Reference to a JIT variable -/// It's the expand element that is actually kept in the variable pool -pub struct ExpandElement { - pub(crate) inner: Rc, -} - -impl ExpandElement { - /// Returns the Item of the variable - pub fn item(&self) -> Item { - self.inner.item() - } -} - -impl From for ExpandElement { - fn from(value: u32) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: usize) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: bool) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: f32) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: i32) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl core::ops::Deref for ExpandElement { - type Target = Variable; - - fn deref(&self) -> &Self::Target { - self.inner.as_ref() - } -} - -// Why _ suffixes? 
Just to avoid clashing with JIT float kind types -// TODO refactor -pub trait FloatKind_: Clone + Copy { - fn to_elem() -> Elem; -} -#[derive(Clone, Copy)] -pub struct F16_; -#[derive(Clone, Copy)] -pub struct BF16_; -#[derive(Clone, Copy)] -pub struct F32_; -#[derive(Clone, Copy)] -pub struct F64_; -impl FloatKind_ for F16_ { - fn to_elem() -> Elem { - Elem::Float(burn_jit::gpu::FloatKind::F16) - } -} -impl FloatKind_ for BF16_ { - fn to_elem() -> Elem { - Elem::Float(burn_jit::gpu::FloatKind::BF16) - } -} -impl FloatKind_ for F32_ { - fn to_elem() -> Elem { - Elem::Float(burn_jit::gpu::FloatKind::F32) - } -} -impl FloatKind_ for F64_ { - fn to_elem() -> Elem { - Elem::Float(burn_jit::gpu::FloatKind::F64) - } -} - -#[derive(Clone, Copy)] -pub struct Float { - pub val: f32, - pub vectorization: u8, - pub _type: PhantomData, -} - -impl Float { - pub fn new(val: f32, vectorization: u8) -> Self { - Self { - val, - vectorization, - _type: PhantomData, - } - } -} - -pub type Float_ = Float; - -#[derive(new, Clone, Copy)] -pub struct Int { - pub val: i32, - pub vectorization: u8, -} - -#[derive(new, Clone, Copy)] -pub struct UInt { - pub val: u32, - pub vectorization: u8, -} - -#[derive(new, Clone, Copy)] -pub struct Bool { - pub val: bool, - pub vectorization: u8, -} - -#[derive(new, Clone)] -pub struct Array { - pub vals: Vec, -} - -impl RuntimeType for Float { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array> { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Int { - type ExpandType = ExpandElement; -} - -impl RuntimeType for UInt { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Bool { - type ExpandType = ExpandElement; -} - -impl RuntimeType for bool { - type ExpandType = bool; -} - -impl RuntimeType for u32 { - 
type ExpandType = u32; -} - -impl From for UInt { - fn from(value: u32) -> Self { - UInt::new(value, 1) - } -} - -impl RuntimeType for f32 { - type ExpandType = f32; -} - -impl From for Float { - fn from(value: f32) -> Self { - Float::new(value, 1) - } -} - -impl From for UInt { - fn from(value: usize) -> Self { - UInt::new(value as u32, 1) - } -} - -impl RuntimeType for i32 { - type ExpandType = i32; -} - -impl From for Int { - fn from(value: i32) -> Self { - Int::new(value, 1) - } -} - -impl From for Variable { - fn from(value: ExpandElement) -> Self { - // Is it ok to do that? - (*value.inner).clone() - } -} diff --git a/crates/burn-cube/src/element/array.rs b/crates/burn-cube/src/element/array.rs new file mode 100644 index 0000000000..c7300b6cc9 --- /dev/null +++ b/crates/burn-cube/src/element/array.rs @@ -0,0 +1,22 @@ +use crate::{Bool, ExpandElement, Float, Int, RuntimeType, UInt}; + +#[derive(new, Clone)] +pub struct Array { + pub vals: Vec, +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} + +impl RuntimeType for Array { + type ExpandType = ExpandElement; +} diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs new file mode 100644 index 0000000000..aca38847de --- /dev/null +++ b/crates/burn-cube/src/element/base.rs @@ -0,0 +1,47 @@ +use alloc::rc::Rc; +use burn_jit::gpu::{Item, Variable}; + +/// Types used in a cube function must implement this trait +/// +/// Variables whose values will be known at runtime must +/// have ExpandElement as associated type +/// Variables whose values will be known at compile time +/// must have the primitive type as associated type +/// +/// Note: Cube functions should be written using RuntimeTypes, +/// so that the code generated uses the associated ExpandType. 
+/// This allows Cube code to not necessitate cloning, which is cumbersome +/// in algorithmic code. The necessary cloning will automatically appear in +/// the generated code. +pub trait RuntimeType { + type ExpandType: Clone; +} + +#[derive(new, Clone, Debug)] +/// Reference to a JIT variable +/// It's the expand element that is actually kept in the variable pool +pub struct ExpandElement { + pub(crate) inner: Rc, +} + +impl ExpandElement { + /// Returns the Item of the variable + pub fn item(&self) -> Item { + self.inner.item() + } +} + +impl core::ops::Deref for ExpandElement { + type Target = Variable; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} + +impl From for Variable { + fn from(value: ExpandElement) -> Self { + // Is it ok to do that? + (*value.inner).clone() + } +} diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs new file mode 100644 index 0000000000..dc3d0c01c3 --- /dev/null +++ b/crates/burn-cube/src/element/bool.rs @@ -0,0 +1,11 @@ +use crate::{ExpandElement, RuntimeType}; + +#[derive(new, Clone, Copy)] +pub struct Bool { + pub val: bool, + pub vectorization: u8, +} + +impl RuntimeType for Bool { + type ExpandType = ExpandElement; +} diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs new file mode 100644 index 0000000000..3674310605 --- /dev/null +++ b/crates/burn-cube/src/element/float.rs @@ -0,0 +1,84 @@ +use crate::{ExpandElement, RuntimeType}; + +pub trait Float: + Clone + + Copy + + RuntimeType + + std::cmp::PartialOrd + + std::ops::Add + + std::ops::Mul + + std::ops::Sub +{ + fn into_kind() -> burn_jit::gpu::FloatKind; + fn new(val: f32, vectorization: usize) -> Self; +} + +#[derive(Clone, Copy)] +pub struct F16 { + pub val: f32, + pub vectorization: usize, +} +#[derive(Clone, Copy)] +pub struct BF16 { + pub val: f32, + pub vectorization: usize, +} +#[derive(Clone, Copy)] +pub struct F32 { + pub val: f32, + pub vectorization: usize, +} 
+#[derive(Clone, Copy)] +pub struct F64 { + pub val: f32, + pub vectorization: usize, +} + +impl RuntimeType for F16 { + type ExpandType = ExpandElement; +} + +impl RuntimeType for BF16 { + type ExpandType = ExpandElement; +} + +impl RuntimeType for F32 { + type ExpandType = ExpandElement; +} + +impl RuntimeType for F64 { + type ExpandType = ExpandElement; +} + +impl Float for F16 { + fn into_kind() -> burn_jit::gpu::FloatKind { + burn_jit::gpu::FloatKind::F16 + } + fn new(val: f32, vectorization: usize) -> Self { + Self { val, vectorization } + } +} +impl Float for BF16 { + fn into_kind() -> burn_jit::gpu::FloatKind { + burn_jit::gpu::FloatKind::BF16 + } + fn new(val: f32, vectorization: usize) -> Self { + Self { val, vectorization } + } +} +impl Float for F32 { + fn into_kind() -> burn_jit::gpu::FloatKind { + burn_jit::gpu::FloatKind::F32 + } + fn new(val: f32, vectorization: usize) -> Self { + Self { val, vectorization } + } +} +impl Float for F64 { + fn into_kind() -> burn_jit::gpu::FloatKind { + burn_jit::gpu::FloatKind::F64 + } + fn new(val: f32, vectorization: usize) -> Self { + Self { val, vectorization } + } +} diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs new file mode 100644 index 0000000000..3437165ba8 --- /dev/null +++ b/crates/burn-cube/src/element/int.rs @@ -0,0 +1,17 @@ +use crate::{ExpandElement, RuntimeType}; + +#[derive(new, Clone, Copy)] +pub struct Int { + pub val: i32, + pub vectorization: u8, +} + +impl RuntimeType for Int { + type ExpandType = ExpandElement; +} + +impl From for Int { + fn from(value: i32) -> Self { + Int::new(value, 1) + } +} diff --git a/crates/burn-cube/src/element/mod.rs b/crates/burn-cube/src/element/mod.rs new file mode 100644 index 0000000000..81b0c98a61 --- /dev/null +++ b/crates/burn-cube/src/element/mod.rs @@ -0,0 +1,14 @@ +mod array; +mod base; +mod bool; +mod float; +mod int; +mod primitive; +mod uint; + +pub use array::*; +pub use base::*; +pub use bool::*; +pub use 
float::*; +pub use int::*; +pub use uint::*; diff --git a/crates/burn-cube/src/element/primitive.rs b/crates/burn-cube/src/element/primitive.rs new file mode 100644 index 0000000000..ef2eebf10c --- /dev/null +++ b/crates/burn-cube/src/element/primitive.rs @@ -0,0 +1,51 @@ +use std::rc::Rc; + +use burn_jit::gpu::Variable; + +use crate::{ExpandElement, RuntimeType}; + +impl RuntimeType for bool { + type ExpandType = bool; +} + +impl RuntimeType for u32 { + type ExpandType = u32; +} + +impl RuntimeType for f32 { + type ExpandType = f32; +} + +impl RuntimeType for i32 { + type ExpandType = i32; +} + +impl From for ExpandElement { + fn from(value: u32) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + +impl From for ExpandElement { + fn from(value: usize) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + +impl From for ExpandElement { + fn from(value: bool) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + +impl From for ExpandElement { + fn from(value: f32) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} + +impl From for ExpandElement { + fn from(value: i32) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs new file mode 100644 index 0000000000..917274bd68 --- /dev/null +++ b/crates/burn-cube/src/element/uint.rs @@ -0,0 +1,23 @@ +use crate::{ExpandElement, RuntimeType}; + +#[derive(new, Clone, Copy)] +pub struct UInt { + pub val: u32, + pub vectorization: u8, +} + +impl RuntimeType for UInt { + type ExpandType = ExpandElement; +} + +impl From for UInt { + fn from(value: u32) -> Self { + UInt::new(value, 1) + } +} + +impl From for UInt { + fn from(value: usize) -> Self { + UInt::new(value as u32, 1) + } +} diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs index a99de50168..fd6ea3897f 100644 --- 
a/crates/burn-cube/src/operation/assignation.rs +++ b/crates/burn-cube/src/operation/assignation.rs @@ -64,7 +64,7 @@ pub mod index { } pub mod add_assign_op { - use crate::{operation::base::assign_op_expand, Float, FloatKind_, Int}; + use crate::{operation::base::assign_op_expand, Int, BF16, F16, F32, F64}; use super::*; @@ -76,21 +76,20 @@ pub mod add_assign_op { assign_op_expand(context, lhs, rhs, gpu::Operator::Add) } - impl core::ops::AddAssign for Float { - fn add_assign(&mut self, rhs: Self) { - self.val += rhs.val - } - } - - impl core::ops::AddAssign for Int { - fn add_assign(&mut self, rhs: Self) { - self.val += rhs.val - } + macro_rules! impl_add_assign { + ($type:ty) => { + impl core::ops::AddAssign for $type { + fn add_assign(&mut self, rhs: Self) { + self.val += rhs.val + } + } + }; } - impl core::ops::AddAssign for UInt { - fn add_assign(&mut self, rhs: Self) { - self.val += rhs.val - } - } + impl_add_assign!(F16); + impl_add_assign!(BF16); + impl_add_assign!(F32); + impl_add_assign!(F64); + impl_add_assign!(Int); + impl_add_assign!(UInt); } diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 75c2a0bdc8..f3e2f15c14 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -1,9 +1,8 @@ use crate::operation::base::binary_expand; -use crate::{CubeContext, ExpandElement, Float, Int, UInt}; +use crate::{CubeContext, ExpandElement, Float, Int, UInt, BF16, F16, F32, F64}; use burn_jit::gpu::{self}; pub mod add { - use crate::FloatKind_; use super::*; @@ -15,34 +14,27 @@ pub mod add { binary_expand(context, lhs, rhs, gpu::Operator::Add) } - impl core::ops::Add for Float { - type Output = Self; + macro_rules! 
impl_add { + ($type:ty) => { + impl core::ops::Add for $type { + type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - Float::::new(self.val + rhs.val, 1) - } - } - - impl core::ops::Add for Int { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Int::new(self.val + rhs.val, 1) - } + fn add(self, rhs: Self) -> Self::Output { + <$type>::new(self.val + rhs.val, 1) + } + } + }; } - impl core::ops::Add for UInt { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - UInt::new(self.val + rhs.val, 1) - } - } + impl_add!(F16); + impl_add!(BF16); + impl_add!(F32); + impl_add!(F64); + impl_add!(Int); + impl_add!(UInt); } pub mod sub { - use crate::FloatKind_; - use super::*; pub fn expand( @@ -53,33 +45,28 @@ pub mod sub { binary_expand(context, lhs, rhs, gpu::Operator::Sub) } - impl core::ops::Sub for Float { - type Output = Self; + macro_rules! impl_sub { + ($type:ty) => { + impl core::ops::Sub for $type { + type Output = Self; - fn sub(self, rhs: Self) -> Self::Output { - Float::new(self.val - rhs.val, 1) - } + fn sub(self, rhs: Self) -> Self::Output { + <$type>::new(self.val - rhs.val, 1) + } + } + }; } - impl core::ops::Sub for Int { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Int::new(self.val - rhs.val, 1) - } - } - - impl core::ops::Sub for UInt { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - UInt::new(self.val - rhs.val, 1) - } - } + impl_sub!(F16); + impl_sub!(BF16); + impl_sub!(F32); + impl_sub!(F64); + impl_sub!(Int); + impl_sub!(UInt); } pub mod mul { - use crate::FloatKind_; + use crate::Float; use super::*; @@ -91,33 +78,28 @@ pub mod mul { binary_expand(context, lhs, rhs, gpu::Operator::Mul) } - impl core::ops::Mul for Float { - type Output = Self; + macro_rules! 
impl_mul { + ($type:ty) => { + impl core::ops::Mul for $type { + type Output = Self; - fn mul(self, rhs: Self) -> Self::Output { - Float::new(self.val * rhs.val, 1) - } + fn mul(self, rhs: Self) -> Self::Output { + <$type>::new(self.val * rhs.val, 1) + } + } + }; } - impl core::ops::Mul for Int { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - Int::new(self.val * rhs.val, 1) - } - } - - impl core::ops::Mul for UInt { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - UInt::new(self.val * rhs.val, 1) - } - } + impl_mul!(F16); + impl_mul!(BF16); + impl_mul!(F32); + impl_mul!(F64); + impl_mul!(Int); + impl_mul!(UInt); } pub mod div { - use crate::FloatKind_; + use crate::Float; use super::*; @@ -129,33 +111,28 @@ pub mod div { binary_expand(context, lhs, rhs, gpu::Operator::Div) } - impl core::ops::Div for Float { - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - Float::new(self.val / rhs.val, 1) - } - } - - impl core::ops::Div for Int { - type Output = Self; + macro_rules! impl_div { + ($type:ty) => { + impl core::ops::Div for $type { + type Output = Self; - fn div(self, rhs: Self) -> Self::Output { - Int::new(self.val / rhs.val, 1) - } + fn div(self, rhs: Self) -> Self::Output { + <$type>::new(self.val / rhs.val, 1) + } + } + }; } - impl core::ops::Div for UInt { - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - UInt::new(self.val / rhs.val, 1) - } - } + impl_div!(F16); + impl_div!(BF16); + impl_div!(F32); + impl_div!(F64); + impl_div!(Int); + impl_div!(UInt); } pub mod rem { - use crate::FloatKind_; + use crate::Float; use super::*; @@ -167,29 +144,24 @@ pub mod rem { binary_expand(context, lhs, rhs, gpu::Operator::Modulo) } - impl core::ops::Rem for Float { - type Output = Self; - - fn rem(self, rhs: Self) -> Self::Output { - Float::new(self.val % rhs.val, 1) - } - } - - impl core::ops::Rem for Int { - type Output = Self; + macro_rules! 
impl_rem { + ($type:ty) => { + impl core::ops::Rem for $type { + type Output = Self; - fn rem(self, rhs: Self) -> Self::Output { - Int::new(self.val % rhs.val, 1) - } + fn rem(self, rhs: Self) -> Self::Output { + <$type>::new(self.val % rhs.val, 1) + } + } + }; } - impl core::ops::Rem for UInt { - type Output = Self; - - fn rem(self, rhs: Self) -> Self::Output { - UInt::new(self.val % rhs.val, 1) - } - } + impl_rem!(F16); + impl_rem!(BF16); + impl_rem!(F32); + impl_rem!(F64); + impl_rem!(Int); + impl_rem!(UInt); } pub mod and { diff --git a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs index 3f4b85d6ce..af71f0ca10 100644 --- a/crates/burn-cube/src/operation/cmp.rs +++ b/crates/burn-cube/src/operation/cmp.rs @@ -1,60 +1,33 @@ use crate::operation::base::cmp_expand; -use crate::{CubeContext, ExpandElement, Float, FloatKind_, Int, UInt}; +use crate::{CubeContext, ExpandElement, Int, UInt, BF16, F16, F32, F64}; use burn_jit::gpu::{self}; -impl core::cmp::PartialEq for Float { - fn eq(&self, other: &Self) -> bool { - self.val == other.val && self.vectorization == other.vectorization - } -} - -impl core::cmp::PartialEq for Int { - fn eq(&self, other: &Self) -> bool { - self.val == other.val && self.vectorization == other.vectorization - } -} - -impl core::cmp::PartialEq for UInt { - fn eq(&self, other: &Self) -> bool { - self.val == other.val && self.vectorization == other.vectorization - } -} - -impl core::cmp::Eq for Float {} - -impl core::cmp::Eq for Int {} - -impl core::cmp::Eq for UInt {} - -impl core::cmp::PartialOrd for Float { - fn partial_cmp(&self, other: &Self) -> Option { - match self.val.partial_cmp(&other.val) { - Some(core::cmp::Ordering::Equal) => {} - ord => return ord, +macro_rules! 
impl_cmp { + ($type:ty) => { + impl core::cmp::PartialEq for $type { + fn eq(&self, other: &Self) -> bool { + self.val == other.val && self.vectorization == other.vectorization + } } - self.vectorization.partial_cmp(&other.vectorization) - } -} - -impl core::cmp::PartialOrd for Int { - fn partial_cmp(&self, other: &Self) -> Option { - match self.val.partial_cmp(&other.val) { - Some(core::cmp::Ordering::Equal) => {} - ord => return ord, + impl core::cmp::Eq for $type {} + impl core::cmp::PartialOrd for $type { + fn partial_cmp(&self, other: &Self) -> Option { + match self.val.partial_cmp(&other.val) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.vectorization.partial_cmp(&other.vectorization) + } } - self.vectorization.partial_cmp(&other.vectorization) - } + }; } -impl core::cmp::PartialOrd for UInt { - fn partial_cmp(&self, other: &Self) -> Option { - match self.val.partial_cmp(&other.val) { - Some(core::cmp::Ordering::Equal) => {} - ord => return ord, - } - self.vectorization.partial_cmp(&other.vectorization) - } -} +impl_cmp!(F16); +impl_cmp!(BF16); +impl_cmp!(F32); +impl_cmp!(F64); +impl_cmp!(Int); +impl_cmp!(UInt); pub mod ne { diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 2a4ff29b4f..fbd11d3f2c 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,306 +1,306 @@ -use burn_cube::{cube, Bool, CubeContext, Float, FloatKind_, Int, UInt}; -use burn_jit::gpu; -use burn_jit::gpu::FloatKind::F32; -use burn_jit::gpu::IntKind::I32; -use burn_jit::gpu::{Elem, Item, Variable}; - -macro_rules! 
cast_test { - ($name:ident, $module:ident, $from:expr, $to:expr) => { - #[test] - fn $name() { - let mut context = CubeContext::root(); - - let x = context.create_local($from); - - $module::expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - gpu_macro_ref_cast($from, $to) - ); - } - }; - - ($name:ident, $module:ident, $ty:expr) => { - #[test] - fn $name() { - let mut context = CubeContext::root(); - - let x = context.create_local($ty); - - $module::expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - gpu_macro_ref_identity($ty) - ); - } - }; -} - -// From float -#[cube] -pub fn float_to_float(x: Float) { - let y = x + float_new::(2.0); - let _ = to_float::, F>(y) + float_new::(34.0); -} - -#[cube] -pub fn float_to_int(x: Float) { - let y = x + float_new::(2.0); - let _ = to_int(y) + int_new(34); -} - -#[cube] -pub fn float_to_uint(x: Float) { - let y = x + float_new::(2.0); - let _ = to_uint(y) + uint_new(34u32); -} - -#[cube] -pub fn float_to_bool(x: Float) { - let y = x + float_new::(2.0); - let _ = to_bool(y) | bool_new(true); -} - -cast_test!( - cube_float_to_float_test, - float_to_float, - Item::Scalar(Elem::Float(F32)) -); - -cast_test!( - cube_float_to_int_test, - float_to_int, - Item::Scalar(Elem::Float(F32)), - Item::Scalar(Elem::Int(I32)) -); - -cast_test!( - cube_float_to_uint_test, - float_to_uint, - Item::Scalar(Elem::Float(F32)), - Item::Scalar(Elem::UInt) -); - -cast_test!( - cube_float_to_bool_test, - float_to_bool, - Item::Scalar(Elem::Float(F32)), - Item::Scalar(Elem::Bool) -); - -// // From int -#[cube] -pub fn int_to_float(x: Int) { - let y = x + int_new(2); - let _ = to_float::(y) + float_new::(34.0); -} - -#[cube] -pub fn int_to_int(x: Int) { - let y = x + int_new(2); - let _ = to_int(y) + int_new(34); -} - -#[cube] -pub fn int_to_uint(x: Int) { - let y = x + int_new(2); - let _ = to_uint(y) + uint_new(34u32); -} - 
-#[cube] -pub fn int_to_bool(x: Int) { - let y = x + int_new(2); - let _ = to_bool(y) | bool_new(true); -} - -cast_test!( - cube_int_to_float_test, - int_to_float, - Item::Scalar(Elem::Int(I32)), - Item::Scalar(Elem::Float(F32)) -); - -cast_test!( - cube_int_to_int_test, - int_to_int, - Item::Scalar(Elem::Int(I32)) -); - -cast_test!( - cube_int_to_uint_test, - int_to_uint, - Item::Scalar(Elem::Int(I32)), - Item::Scalar(Elem::UInt) -); - -cast_test!( - cube_int_to_bool_test, - int_to_bool, - Item::Scalar(Elem::Int(I32)), - Item::Scalar(Elem::Bool) -); - -// // From uint -#[cube] -pub fn uint_to_float(x: UInt) { - let y = x + uint_new(2u32); - let _ = to_float::(y) + float_new::(34.0); -} - -#[cube] -pub fn uint_to_int(x: UInt) { - let y = x + uint_new(2u32); - let _ = to_int(y) + int_new(34); -} - -#[cube] -pub fn uint_to_uint(x: UInt) { - let y = x + uint_new(2u32); - let _ = to_uint(y) + uint_new(34u32); -} - -#[cube] -pub fn uint_to_bool(x: UInt) { - let y = x + uint_new(2u32); - let _ = to_bool(y) | bool_new(true); -} - -cast_test!( - cube_uint_to_float_test, - uint_to_float, - Item::Scalar(Elem::UInt), - Item::Scalar(Elem::Float(F32)) -); - -cast_test!( - cube_uint_to_int_test, - uint_to_int, - Item::Scalar(Elem::UInt), - Item::Scalar(Elem::Int(I32)) -); - -cast_test!( - cube_uint_to_uint_test, - uint_to_uint, - Item::Scalar(Elem::UInt) -); - -cast_test!( - cube_uint_to_bool_test, - uint_to_bool, - Item::Scalar(Elem::UInt), - Item::Scalar(Elem::Bool) -); - -// From bool -#[cube] -pub fn bool_to_float(x: Bool) { - let y = x & bool_new(false); - let _ = to_float::(y) + float_new::(34.0); -} - -#[cube] -pub fn bool_to_int(x: Bool) { - let y = x & bool_new(false); - let _ = to_int(y) + int_new(34); -} - -#[cube] -pub fn bool_to_uint(x: Bool) { - let y = x & bool_new(false); - let _ = to_uint(y) + uint_new(34u32); -} - -#[cube] -pub fn bool_to_bool(x: Bool) { - let y = x & bool_new(false); - let _ = to_bool(y) | bool_new(true); -} - -cast_test!( - 
cube_bool_to_float_test, - bool_to_float, - Item::Scalar(Elem::Bool), - Item::Scalar(Elem::Float(F32)) -); - -cast_test!( - cube_bool_to_int_test, - bool_to_int, - Item::Scalar(Elem::Bool), - Item::Scalar(Elem::Int(I32)) -); - -cast_test!( - cube_bool_to_uint_test, - bool_to_uint, - Item::Scalar(Elem::Bool), - Item::Scalar(Elem::UInt) -); - -cast_test!( - cube_bool_to_bool_test, - bool_to_bool, - Item::Scalar(Elem::Bool) -); - -fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { - let mut context = CubeContext::root(); - let x = context.create_local(from_item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y = scope.create_local(from_item); - let y_casted = scope.create_local(to_item); - let z = scope.create_local(to_item); - - match from_item.elem() { - Elem::Float(_) => gpu!(scope, y = x + 2f32), - Elem::Int(_) => gpu!(scope, y = x + 2i32), - Elem::UInt => gpu!(scope, y = x + 2u32), - Elem::Bool => gpu!(scope, y = x && false), - } - - gpu!(scope, y_casted = cast(y)); - - match to_item.elem() { - Elem::Float(_) => gpu!(scope, z = y_casted + 34f32), - Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), - Elem::UInt => gpu!(scope, z = y_casted + 34u32), - Elem::Bool => gpu!(scope, z = y_casted || true), - } - - format!("{:?}", scope.operations) -} - -fn gpu_macro_ref_identity(item: Item) -> String { - // When staying with the same type variables are automatically reused in cube - let mut context = CubeContext::root(); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y = scope.create_local(item); - - match item.elem() { - Elem::Float(_) => gpu!(scope, y = x + 2f32), - Elem::Int(_) => gpu!(scope, y = x + 2i32), - Elem::UInt => gpu!(scope, y = x + 2u32), - Elem::Bool => gpu!(scope, y = x && false), - } - - gpu!(scope, x = cast(y)); - - match item.elem() { - Elem::Float(_) => gpu!(scope, y = x + 34f32), - Elem::Int(_) => gpu!(scope, y = x + 34i32), - 
Elem::UInt => gpu!(scope, y = x + 34u32), - Elem::Bool => gpu!(scope, y = x || true), - } - - format!("{:?}", scope.operations) -} +// use burn_cube::{cube, Bool, CubeContext, FloatX, Float, Int, UInt}; +// use burn_jit::gpu; +// use burn_jit::gpu::FloatKind::F32; +// use burn_jit::gpu::IntKind::I32; +// use burn_jit::gpu::{Elem, Item, Variable}; + +// macro_rules! cast_test { +// ($name:ident, $module:ident, $from:expr, $to:expr) => { +// #[test] +// fn $name() { +// let mut context = CubeContext::root(); + +// let x = context.create_local($from); + +// $module::expand::(&mut context, x); +// let scope = context.into_scope(); + +// assert_eq!( +// format!("{:?}", scope.operations), +// gpu_macro_ref_cast($from, $to) +// ); +// } +// }; + +// ($name:ident, $module:ident, $ty:expr) => { +// #[test] +// fn $name() { +// let mut context = CubeContext::root(); + +// let x = context.create_local($ty); + +// $module::expand::(&mut context, x); +// let scope = context.into_scope(); + +// assert_eq!( +// format!("{:?}", scope.operations), +// gpu_macro_ref_identity($ty) +// ); +// } +// }; +// } + +// // From float +// #[cube] +// pub fn float_to_float(x: FloatX) { +// let y = x + float_new::(2.0); +// let _ = to_float::, F>(y) + float_new::(34.0); +// } + +// #[cube] +// pub fn float_to_int(x: FloatX) { +// let y = x + float_new::(2.0); +// let _ = to_int(y) + int_new(34); +// } + +// #[cube] +// pub fn float_to_uint(x: FloatX) { +// let y = x + float_new::(2.0); +// let _ = to_uint(y) + uint_new(34u32); +// } + +// #[cube] +// pub fn float_to_bool(x: FloatX) { +// let y = x + float_new::(2.0); +// let _ = to_bool(y) | bool_new(true); +// } + +// cast_test!( +// cube_float_to_float_test, +// float_to_float, +// Item::Scalar(Elem::Float(F32)) +// ); + +// cast_test!( +// cube_float_to_int_test, +// float_to_int, +// Item::Scalar(Elem::Float(F32)), +// Item::Scalar(Elem::Int(I32)) +// ); + +// cast_test!( +// cube_float_to_uint_test, +// float_to_uint, +// 
Item::Scalar(Elem::Float(F32)), +// Item::Scalar(Elem::UInt) +// ); + +// cast_test!( +// cube_float_to_bool_test, +// float_to_bool, +// Item::Scalar(Elem::Float(F32)), +// Item::Scalar(Elem::Bool) +// ); + +// // // From int +// #[cube] +// pub fn int_to_float(x: Int) { +// let y = x + int_new(2); +// let _ = to_float::(y) + float_new::(34.0); +// } + +// #[cube] +// pub fn int_to_int(x: Int) { +// let y = x + int_new(2); +// let _ = to_int(y) + int_new(34); +// } + +// #[cube] +// pub fn int_to_uint(x: Int) { +// let y = x + int_new(2); +// let _ = to_uint(y) + uint_new(34u32); +// } + +// #[cube] +// pub fn int_to_bool(x: Int) { +// let y = x + int_new(2); +// let _ = to_bool(y) | bool_new(true); +// } + +// cast_test!( +// cube_int_to_float_test, +// int_to_float, +// Item::Scalar(Elem::Int(I32)), +// Item::Scalar(Elem::Float(F32)) +// ); + +// cast_test!( +// cube_int_to_int_test, +// int_to_int, +// Item::Scalar(Elem::Int(I32)) +// ); + +// cast_test!( +// cube_int_to_uint_test, +// int_to_uint, +// Item::Scalar(Elem::Int(I32)), +// Item::Scalar(Elem::UInt) +// ); + +// cast_test!( +// cube_int_to_bool_test, +// int_to_bool, +// Item::Scalar(Elem::Int(I32)), +// Item::Scalar(Elem::Bool) +// ); + +// // // From uint +// #[cube] +// pub fn uint_to_float(x: UInt) { +// let y = x + uint_new(2u32); +// let _ = to_float::(y) + float_new::(34.0); +// } + +// #[cube] +// pub fn uint_to_int(x: UInt) { +// let y = x + uint_new(2u32); +// let _ = to_int(y) + int_new(34); +// } + +// #[cube] +// pub fn uint_to_uint(x: UInt) { +// let y = x + uint_new(2u32); +// let _ = to_uint(y) + uint_new(34u32); +// } + +// #[cube] +// pub fn uint_to_bool(x: UInt) { +// let y = x + uint_new(2u32); +// let _ = to_bool(y) | bool_new(true); +// } + +// cast_test!( +// cube_uint_to_float_test, +// uint_to_float, +// Item::Scalar(Elem::UInt), +// Item::Scalar(Elem::Float(F32)) +// ); + +// cast_test!( +// cube_uint_to_int_test, +// uint_to_int, +// Item::Scalar(Elem::UInt), +// 
Item::Scalar(Elem::Int(I32)) +// ); + +// cast_test!( +// cube_uint_to_uint_test, +// uint_to_uint, +// Item::Scalar(Elem::UInt) +// ); + +// cast_test!( +// cube_uint_to_bool_test, +// uint_to_bool, +// Item::Scalar(Elem::UInt), +// Item::Scalar(Elem::Bool) +// ); + +// // From bool +// #[cube] +// pub fn bool_to_float(x: Bool) { +// let y = x & bool_new(false); +// let _ = to_float::(y) + float_new::(34.0); +// } + +// #[cube] +// pub fn bool_to_int(x: Bool) { +// let y = x & bool_new(false); +// let _ = to_int(y) + int_new(34); +// } + +// #[cube] +// pub fn bool_to_uint(x: Bool) { +// let y = x & bool_new(false); +// let _ = to_uint(y) + uint_new(34u32); +// } + +// #[cube] +// pub fn bool_to_bool(x: Bool) { +// let y = x & bool_new(false); +// let _ = to_bool(y) | bool_new(true); +// } + +// cast_test!( +// cube_bool_to_float_test, +// bool_to_float, +// Item::Scalar(Elem::Bool), +// Item::Scalar(Elem::Float(F32)) +// ); + +// cast_test!( +// cube_bool_to_int_test, +// bool_to_int, +// Item::Scalar(Elem::Bool), +// Item::Scalar(Elem::Int(I32)) +// ); + +// cast_test!( +// cube_bool_to_uint_test, +// bool_to_uint, +// Item::Scalar(Elem::Bool), +// Item::Scalar(Elem::UInt) +// ); + +// cast_test!( +// cube_bool_to_bool_test, +// bool_to_bool, +// Item::Scalar(Elem::Bool) +// ); + +// fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { +// let mut context = CubeContext::root(); +// let x = context.create_local(from_item); + +// let mut scope = context.into_scope(); +// let x: Variable = x.into(); +// let y = scope.create_local(from_item); +// let y_casted = scope.create_local(to_item); +// let z = scope.create_local(to_item); + +// match from_item.elem() { +// Elem::Float(_) => gpu!(scope, y = x + 2f32), +// Elem::Int(_) => gpu!(scope, y = x + 2i32), +// Elem::UInt => gpu!(scope, y = x + 2u32), +// Elem::Bool => gpu!(scope, y = x && false), +// } + +// gpu!(scope, y_casted = cast(y)); + +// match to_item.elem() { +// Elem::Float(_) => gpu!(scope, z 
= y_casted + 34f32), +// Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), +// Elem::UInt => gpu!(scope, z = y_casted + 34u32), +// Elem::Bool => gpu!(scope, z = y_casted || true), +// } + +// format!("{:?}", scope.operations) +// } + +// fn gpu_macro_ref_identity(item: Item) -> String { +// // When staying with the same type variables are automatically reused in cube +// let mut context = CubeContext::root(); +// let x = context.create_local(item); + +// let mut scope = context.into_scope(); +// let x: Variable = x.into(); +// let y = scope.create_local(item); + +// match item.elem() { +// Elem::Float(_) => gpu!(scope, y = x + 2f32), +// Elem::Int(_) => gpu!(scope, y = x + 2i32), +// Elem::UInt => gpu!(scope, y = x + 2u32), +// Elem::Bool => gpu!(scope, y = x && false), +// } + +// gpu!(scope, x = cast(y)); + +// match item.elem() { +// Elem::Float(_) => gpu!(scope, y = x + 34f32), +// Elem::Int(_) => gpu!(scope, y = x + 34i32), +// Elem::UInt => gpu!(scope, y = x + 34u32), +// Elem::Bool => gpu!(scope, y = x || true), +// } + +// format!("{:?}", scope.operations) +// } diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 10152096c4..634211834c 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,43 +1,43 @@ -use burn_cube::{cube, CubeContext, Float, FloatKind_, F32_, F64_}; -use burn_jit::gpu; -use burn_jit::gpu::FloatKind; -use burn_jit::gpu::{Elem, Item}; - -#[cube] -pub fn cast_float_kind(input: Float) { - let x = input + float_new::(5.9f32); - let y = to_float::, F2>(x); - let _ = y + float_new::(2.3f32); -} - -#[test] -fn cube_cast_kind_test() { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(FloatKind::F64)); - - let input = context.create_local(item); - - // F16 not testable with the gpu macro, but should work the same - cast_float_kind::expand::(&mut context, input); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", 
scope.operations), gpu_macro_ref()); -} - -fn gpu_macro_ref() -> String { - let mut context = CubeContext::root(); - let float_64 = Item::Scalar(Elem::Float(FloatKind::F64)); - let float_32 = Item::Scalar(Elem::Float(FloatKind::F32)); - let input = context.create_local(float_64); - - let mut scope = context.into_scope(); - let x = scope.create_local(float_64); - let y = scope.create_local(float_32); - let z = scope.create_local(float_32); - - gpu!(scope, x = input + 5.9f32 as f64); - gpu!(scope, y = cast(x)); - gpu!(scope, z = y + 2.3f32); - - format!("{:?}", scope.operations) -} +// use burn_cube::{cube, CubeContext, FloatX, Float, F32_, F64_}; +// use burn_jit::gpu; +// use burn_jit::gpu::FloatKind; +// use burn_jit::gpu::{Elem, Item}; + +// #[cube] +// pub fn cast_float_kind(input: FloatX) { +// let x = input + float_new::(5.9f32); +// let y = to_float::, F2>(x); +// let _ = y + float_new::(2.3f32); +// } + +// #[test] +// fn cube_cast_kind_test() { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Float(FloatKind::F64)); + +// let input = context.create_local(item); + +// // F16 not testable with the gpu macro, but should work the same +// cast_float_kind::expand::(&mut context, input); +// let scope = context.into_scope(); + +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +// } + +// fn gpu_macro_ref() -> String { +// let mut context = CubeContext::root(); +// let float_64 = Item::Scalar(Elem::Float(FloatKind::F64)); +// let float_32 = Item::Scalar(Elem::Float(FloatKind::F32)); +// let input = context.create_local(float_64); + +// let mut scope = context.into_scope(); +// let x = scope.create_local(float_64); +// let y = scope.create_local(float_32); +// let z = scope.create_local(float_32); + +// gpu!(scope, x = input + 5.9f32 as f64); +// gpu!(scope, y = cast(x)); +// gpu!(scope, z = y + 2.3f32); + +// format!("{:?}", scope.operations) +// } diff --git a/crates/burn-cube/tests/for_loop.rs 
b/crates/burn-cube/tests/for_loop.rs index 087648d06e..59e83151b5 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,10 +1,11 @@ -use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, FloatKind_, UInt, F32_}; +use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt, F32}; use burn_jit::gpu; -use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; +type ElemType = F32; + #[cube] -pub fn for_loop(mut lhs: Array>, rhs: Float, end: UInt, unroll: bool) { +pub fn for_loop(mut lhs: Array, rhs: F, end: UInt, unroll: bool) { let tmp1 = rhs * rhs; let tmp2 = tmp1 + rhs; @@ -18,11 +19,11 @@ fn test_for_loop_with_unroll() { let mut context = CubeContext::root(); let unroll = true; - let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let rhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); let end = 4u32.into(); - for_loop::expand::(&mut context, lhs, rhs, end, unroll); + for_loop::expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); @@ -33,11 +34,11 @@ fn test_for_loop_no_unroll() { let mut context = CubeContext::root(); let unroll = false; - let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); - let rhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let rhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); let end = 4u32.into(); - for_loop::expand::(&mut context, lhs, rhs, end, unroll); + for_loop::expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); @@ -45,7 +46,7 @@ fn 
test_for_loop_no_unroll() { fn gpu_macro_ref(unroll: bool) -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F32)); + let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); let rhs = context.create_local(item); diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 08942d2aa3..ab28c3e6d8 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,10 +1,11 @@ -use burn_cube::{cube, if_expand, CubeContext, Float, FloatKind_, F32_}; +use burn_cube::{cube, if_expand, CubeContext, Float, F32}; use burn_jit::gpu; -use burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; +type ElemType = F32; + #[cube] -pub fn if_greater(lhs: Float) { +pub fn if_greater(lhs: F) { if lhs > float_new::(0.0) { let _ = lhs + float_new::(4.0); } @@ -14,9 +15,9 @@ pub fn if_greater(lhs: Float) { fn cube_if_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - if_greater::expand::(&mut context, lhs); + if_greater::expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); @@ -24,7 +25,7 @@ fn cube_if_test() { fn gpu_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F32)); + let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index a61c7a005a..f02587cfb3 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,10 +1,11 @@ -use burn_cube::{cube, if_else_expand, CubeContext, Float, FloatKind_, F32_}; +use burn_cube::{cube, if_else_expand, CubeContext, Float, F32}; use burn_jit::gpu; -use 
burn_jit::gpu::FloatKind::F32; use burn_jit::gpu::{Elem, Item, Variable}; +type ElemType = F32; + #[cube] -pub fn if_else(lhs: Float) { +pub fn if_else(lhs: F) { if lhs < float_new::(0.0) { let _ = lhs + float_new::(4.0); } else { @@ -16,9 +17,9 @@ pub fn if_else(lhs: Float) { fn cube_if_else_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(F32))); + let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - if_else::expand::(&mut context, lhs); + if_else::expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); @@ -26,7 +27,7 @@ fn cube_if_else_test() { fn gpu_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F32)); + let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index d424cca67b..0ced8ca5e2 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,20 +1,21 @@ -use burn_cube::{cube, CubeContext, Float, F32_}; +use burn_cube::{cube, CubeContext, Float, F32}; use burn_jit::gpu; -use burn_jit::gpu::FloatKind; use burn_jit::gpu::{Elem, Item}; +type ElemType = F32; + #[cube] -pub fn literal(lhs: Float) { - let _ = lhs + float_new::(5.9); +pub fn literal(lhs: F) { + let _ = lhs + float_new::(5.9); } #[test] fn cube_literal_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(FloatKind::F32))); + let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - literal::expand(&mut context, lhs); + literal::expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); @@ -22,7 +23,7 @@ fn cube_literal_test() { fn gpu_macro_ref() -> 
String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(FloatKind::F32)); + let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index fb7830cf81..0513a7f9eb 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,71 +1,71 @@ -use burn_cube::{break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int}; -use burn_jit::gpu; -use burn_jit::gpu::Branch; -use burn_jit::gpu::IntKind::I32; -use burn_jit::gpu::{Elem, Item, Variable}; +// use burn_cube::{break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int}; +// use burn_jit::gpu; +// use burn_jit::gpu::Branch; +// use burn_jit::gpu::IntKind::I32; +// use burn_jit::gpu::{Elem, Item, Variable}; -#[cube] -pub fn while_not(lhs: Int) { - while lhs != int_new(0) { - let _ = lhs - int_new(1); - } -} +// #[cube] +// pub fn while_not(lhs: Int) { +// while lhs != int_new(0) { +// let _ = lhs - int_new(1); +// } +// } -#[cube] -pub fn manual_loop_break(lhs: Int) { - loop { - if lhs != int_new(0) { - break; - } - let _ = lhs - int_new(1); - } -} +// #[cube] +// pub fn manual_loop_break(lhs: Int) { +// loop { +// if lhs != int_new(0) { +// break; +// } +// let _ = lhs - int_new(1); +// } +// } -#[test] -fn cube_while_test() { - let mut context = CubeContext::root(); +// #[test] +// fn cube_while_test() { +// let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); +// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - while_not::expand(&mut context, lhs); - let scope = context.into_scope(); +// while_not::expand(&mut context, lhs); +// let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -} +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +// } 
-#[test] -fn cube_loop_break_test() { - let mut context = CubeContext::root(); +// #[test] +// fn cube_loop_break_test() { +// let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); +// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - manual_loop_break::expand(&mut context, lhs); - let scope = context.into_scope(); +// manual_loop_break::expand(&mut context, lhs); +// let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -} +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +// } -fn gpu_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32)); - let lhs = context.create_local(item); +// fn gpu_macro_ref() -> String { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Int(I32)); +// let lhs = context.create_local(item); - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let lhs: Variable = lhs.into(); - let rhs = scope.create_local(item); +// let mut scope = context.into_scope(); +// let cond = scope.create_local(Item::Scalar(Elem::Bool)); +// let lhs: Variable = lhs.into(); +// let rhs = scope.create_local(item); - gpu!( - &mut scope, - loop(|scope| { - gpu!(scope, cond = lhs != 0); - gpu!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); +// gpu!( +// &mut scope, +// loop(|scope| { +// gpu!(scope, cond = lhs != 0); +// gpu!(scope, if(cond).then(|scope|{ +// scope.register(Branch::Break); +// })); - gpu!(scope, rhs = lhs - 1i32); - }) - ); +// gpu!(scope, rhs = lhs - 1i32); +// }) +// ); - format!("{:?}", scope.operations) -} +// format!("{:?}", scope.operations) +// } diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index e0a441c8c6..f9b245242a 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,97 +1,97 @@ -use burn_cube::{cube, 
while_loop_expand, CubeContext, Int}; -use burn_jit::gpu; -use burn_jit::gpu::IntKind::I32; -use burn_jit::gpu::{Branch, Elem, Item, Variable}; - -// TODO -// a += b is more efficient than a = a + b -// because the latter does not assume that a is the same in lhs and rhs -// It could be detected and optimized - -#[cube] -pub fn reuse(mut x: Int) { - while x < int_new(10) { - x = x + int_new(1); - } -} - -#[cube] -pub fn reuse_incr(mut x: Int) { - while x < int_new(10) { - x += int_new(1); - } -} - -#[test] -fn cube_reuse_assign_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::Scalar(Elem::Int(I32))); - - reuse::expand(&mut context, x); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); -} - -#[test] -fn cube_reuse_incr_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::Scalar(Elem::Int(I32))); - - reuse_incr::expand(&mut context, x); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); -} - -fn gpu_macro_ref_assign() -> String { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32)); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let x: Variable = x.into(); - let tmp = scope.create_local(item); - - gpu!( - &mut scope, - loop(|scope| { - gpu!(scope, cond = x < 10); - gpu!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - gpu!(scope, tmp = x + 1); - gpu!(scope, x = tmp); - }) - ); - - format!("{:?}", scope.operations) -} - -fn gpu_macro_ref_incr() -> String { - let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32)); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); - let x: Variable = x.into(); - - gpu!( - &mut scope, - 
loop(|scope| { - gpu!(scope, cond = x < 10); - gpu!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - gpu!(scope, x = x + 1); - }) - ); - - format!("{:?}", scope.operations) -} +// use burn_cube::{cube, while_loop_expand, CubeContext, Int}; +// use burn_jit::gpu; +// use burn_jit::gpu::IntKind::I32; +// use burn_jit::gpu::{Branch, Elem, Item, Variable}; + +// // TODO +// // a += b is more efficient than a = a + b +// // because the latter does not assume that a is the same in lhs and rhs +// // It could be detected and optimized + +// #[cube] +// pub fn reuse(mut x: Int) { +// while x < int_new(10) { +// x = x + int_new(1); +// } +// } + +// #[cube] +// pub fn reuse_incr(mut x: Int) { +// while x < int_new(10) { +// x += int_new(1); +// } +// } + +// #[test] +// fn cube_reuse_assign_test() { +// let mut context = CubeContext::root(); + +// let x = context.create_local(Item::Scalar(Elem::Int(I32))); + +// reuse::expand(&mut context, x); +// let scope = context.into_scope(); + +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); +// } + +// #[test] +// fn cube_reuse_incr_test() { +// let mut context = CubeContext::root(); + +// let x = context.create_local(Item::Scalar(Elem::Int(I32))); + +// reuse_incr::expand(&mut context, x); +// let scope = context.into_scope(); + +// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); +// } + +// fn gpu_macro_ref_assign() -> String { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Int(I32)); +// let x = context.create_local(item); + +// let mut scope = context.into_scope(); +// let cond = scope.create_local(Item::Scalar(Elem::Bool)); +// let x: Variable = x.into(); +// let tmp = scope.create_local(item); + +// gpu!( +// &mut scope, +// loop(|scope| { +// gpu!(scope, cond = x < 10); +// gpu!(scope, if(cond).then(|scope|{ +// scope.register(Branch::Break); +// })); + +// gpu!(scope, tmp = x + 1); +// gpu!(scope, x = tmp); +// }) +// ); + 
+// format!("{:?}", scope.operations) +// } + +// fn gpu_macro_ref_incr() -> String { +// let mut context = CubeContext::root(); +// let item = Item::Scalar(Elem::Int(I32)); +// let x = context.create_local(item); + +// let mut scope = context.into_scope(); +// let cond = scope.create_local(Item::Scalar(Elem::Bool)); +// let x: Variable = x.into(); + +// gpu!( +// &mut scope, +// loop(|scope| { +// gpu!(scope, cond = x < 10); +// gpu!(scope, if(cond).then(|scope|{ +// scope.register(Branch::Break); +// })); + +// gpu!(scope, x = x + 1); +// }) +// ); + +// format!("{:?}", scope.operations) +// } From 4f7c56f3f30ad7c28511da02cfeee3f95b86011d Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 9 May 2024 11:34:04 -0400 Subject: [PATCH 30/54] make tests work --- crates/burn-cube-macros/src/prelude.rs | 33 +- crates/burn-cube/src/element/array.rs | 16 +- crates/burn-cube/src/element/base.rs | 1 - crates/burn-cube/src/element/float.rs | 96 +-- crates/burn-cube/src/element/int.rs | 53 +- crates/burn-cube/src/operation/assignation.rs | 5 +- crates/burn-cube/src/operation/binary.rs | 23 +- crates/burn-cube/src/operation/cmp.rs | 5 +- crates/burn-cube/tests/cast_elem.rs | 610 +++++++++--------- crates/burn-cube/tests/loop.rs | 145 +++-- crates/burn-cube/tests/reuse.rs | 195 +++--- 11 files changed, 579 insertions(+), 603 deletions(-) diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index 0b70689246..6a560c267a 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -33,35 +33,29 @@ pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro fn codegen_float_new() -> proc_macro2::TokenStream { quote::quote! 
{ - use std::{rc::Rc}; - use burn_cube::ExpandElement; - use burn_jit::gpu::Variable; pub fn float_new(val: f32) -> F { F::new(val, 1) } + pub fn float_new_expand( context: &mut CubeContext, val: f32, ) -> ::ExpandType { - let elem = Elem::Float(F::into_kind()); - let new_var = Variable::ConstantScalar(val as f64, elem); - ExpandElement::new(Rc::new(new_var)) + val.into() } } } fn codegen_int_new() -> proc_macro2::TokenStream { quote::quote! { - pub fn int_new(val: i32) -> Int { - Int { - val, - vectorization: 1, - } + pub fn int_new(val: i32) -> I { + I::new(val, 1) } - pub fn int_new_expand( + + pub fn int_new_expand( context: &mut CubeContext, val: i32, - ) -> ::ExpandType { + ) -> ::ExpandType { val.into() } } @@ -103,17 +97,14 @@ fn codegen_bool_new() -> proc_macro2::TokenStream { fn codegen_to_int() -> proc_macro2::TokenStream { quote::quote! { - pub fn to_int(input: R) -> Int { - Int { - val: 0, - vectorization: 1, - } + pub fn to_int(input: R) -> I { + I::new(0, 1) } - pub fn to_int_expand( + pub fn to_int_expand( context: &mut CubeContext, val: burn_cube::ExpandElement, - ) -> ::ExpandType { - let elem = Elem::Int(I32); + ) -> ::ExpandType { + let elem = Elem::Int(I::into_kind()); let new_var = context.create_local(match val.item() { Item::Vec4(_) => Item::Vec4(elem), Item::Vec3(_) => Item::Vec3(elem), diff --git a/crates/burn-cube/src/element/array.rs b/crates/burn-cube/src/element/array.rs index c7300b6cc9..edab275011 100644 --- a/crates/burn-cube/src/element/array.rs +++ b/crates/burn-cube/src/element/array.rs @@ -1,22 +1,10 @@ -use crate::{Bool, ExpandElement, Float, Int, RuntimeType, UInt}; +use crate::{ExpandElement, RuntimeType}; #[derive(new, Clone)] pub struct Array { pub vals: Vec, } -impl RuntimeType for Array { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array { - type ExpandType = ExpandElement; -} - -impl RuntimeType for Array { +impl RuntimeType for Array 
{ type ExpandType = ExpandElement; } diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs index aca38847de..a14cd24ea8 100644 --- a/crates/burn-cube/src/element/base.rs +++ b/crates/burn-cube/src/element/base.rs @@ -41,7 +41,6 @@ impl core::ops::Deref for ExpandElement { impl From for Variable { fn from(value: ExpandElement) -> Self { - // Is it ok to do that? (*value.inner).clone() } } diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index 3674310605..f7d9a46ea1 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -6,79 +6,43 @@ pub trait Float: + RuntimeType + std::cmp::PartialOrd + std::ops::Add - + std::ops::Mul + std::ops::Sub + + std::ops::Mul + + std::ops::Div { fn into_kind() -> burn_jit::gpu::FloatKind; fn new(val: f32, vectorization: usize) -> Self; } -#[derive(Clone, Copy)] -pub struct F16 { - pub val: f32, - pub vectorization: usize, -} -#[derive(Clone, Copy)] -pub struct BF16 { - pub val: f32, - pub vectorization: usize, -} -#[derive(Clone, Copy)] -pub struct F32 { - pub val: f32, - pub vectorization: usize, -} -#[derive(Clone, Copy)] -pub struct F64 { - pub val: f32, - pub vectorization: usize, -} +macro_rules! 
impl_float { + ($type:ident) => { + #[derive(Clone, Copy)] + pub struct $type { + pub val: f32, + pub vectorization: usize, + } -impl RuntimeType for F16 { - type ExpandType = ExpandElement; -} + impl RuntimeType for $type { + type ExpandType = ExpandElement; + } -impl RuntimeType for BF16 { - type ExpandType = ExpandElement; + impl Float for $type { + fn into_kind() -> burn_jit::gpu::FloatKind { + burn_jit::gpu::FloatKind::$type + } + fn new(val: f32, vectorization: usize) -> Self { + Self { val, vectorization } + } + } + impl From for $type { + fn from(value: f32) -> Self { + $type::new(value, 1) + } + } + }; } -impl RuntimeType for F32 { - type ExpandType = ExpandElement; -} - -impl RuntimeType for F64 { - type ExpandType = ExpandElement; -} - -impl Float for F16 { - fn into_kind() -> burn_jit::gpu::FloatKind { - burn_jit::gpu::FloatKind::F16 - } - fn new(val: f32, vectorization: usize) -> Self { - Self { val, vectorization } - } -} -impl Float for BF16 { - fn into_kind() -> burn_jit::gpu::FloatKind { - burn_jit::gpu::FloatKind::BF16 - } - fn new(val: f32, vectorization: usize) -> Self { - Self { val, vectorization } - } -} -impl Float for F32 { - fn into_kind() -> burn_jit::gpu::FloatKind { - burn_jit::gpu::FloatKind::F32 - } - fn new(val: f32, vectorization: usize) -> Self { - Self { val, vectorization } - } -} -impl Float for F64 { - fn into_kind() -> burn_jit::gpu::FloatKind { - burn_jit::gpu::FloatKind::F64 - } - fn new(val: f32, vectorization: usize) -> Self { - Self { val, vectorization } - } -} +impl_float!(F16); +impl_float!(BF16); +impl_float!(F32); +impl_float!(F64); diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 3437165ba8..673c9d456c 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -1,17 +1,48 @@ use crate::{ExpandElement, RuntimeType}; -#[derive(new, Clone, Copy)] -pub struct Int { - pub val: i32, - pub vectorization: u8, +pub trait Int: + Clone + + Copy + + 
RuntimeType + + std::cmp::PartialOrd + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div + + std::ops::AddAssign +{ + fn into_kind() -> burn_jit::gpu::IntKind; + fn new(val: i32, vectorization: usize) -> Self; } -impl RuntimeType for Int { - type ExpandType = ExpandElement; -} +macro_rules! impl_int { + ($type:ident) => { + #[derive(Clone, Copy)] + pub struct $type { + pub val: i32, + pub vectorization: usize, + } + + impl RuntimeType for $type { + type ExpandType = ExpandElement; + } -impl From for Int { - fn from(value: i32) -> Self { - Int::new(value, 1) - } + impl Int for $type { + fn into_kind() -> burn_jit::gpu::IntKind { + burn_jit::gpu::IntKind::$type + } + fn new(val: i32, vectorization: usize) -> Self { + Self { val, vectorization } + } + } + + impl From for $type { + fn from(value: i32) -> Self { + $type::new(value, 1) + } + } + }; } + +impl_int!(I32); +impl_int!(I64); diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs index fd6ea3897f..ff5eba0ad3 100644 --- a/crates/burn-cube/src/operation/assignation.rs +++ b/crates/burn-cube/src/operation/assignation.rs @@ -64,7 +64,7 @@ pub mod index { } pub mod add_assign_op { - use crate::{operation::base::assign_op_expand, Int, BF16, F16, F32, F64}; + use crate::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; use super::*; @@ -90,6 +90,7 @@ pub mod add_assign_op { impl_add_assign!(BF16); impl_add_assign!(F32); impl_add_assign!(F64); - impl_add_assign!(Int); + impl_add_assign!(I32); + impl_add_assign!(I64); impl_add_assign!(UInt); } diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index f3e2f15c14..5ae1adecc8 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -1,5 +1,5 @@ use crate::operation::base::binary_expand; -use crate::{CubeContext, ExpandElement, Float, Int, UInt, BF16, F16, F32, F64}; +use crate::{CubeContext, 
ExpandElement, Float, Int, UInt, BF16, F16, F32, F64, I32, I64}; use burn_jit::gpu::{self}; pub mod add { @@ -30,7 +30,8 @@ pub mod add { impl_add!(BF16); impl_add!(F32); impl_add!(F64); - impl_add!(Int); + impl_add!(I32); + impl_add!(I64); impl_add!(UInt); } @@ -61,7 +62,8 @@ pub mod sub { impl_sub!(BF16); impl_sub!(F32); impl_sub!(F64); - impl_sub!(Int); + impl_sub!(I32); + impl_sub!(I64); impl_sub!(UInt); } @@ -94,7 +96,8 @@ pub mod mul { impl_mul!(BF16); impl_mul!(F32); impl_mul!(F64); - impl_mul!(Int); + impl_mul!(I32); + impl_mul!(I64); impl_mul!(UInt); } @@ -127,13 +130,12 @@ pub mod div { impl_div!(BF16); impl_div!(F32); impl_div!(F64); - impl_div!(Int); + impl_div!(I32); + impl_div!(I64); impl_div!(UInt); } pub mod rem { - use crate::Float; - use super::*; pub fn expand( @@ -156,11 +158,8 @@ pub mod rem { }; } - impl_rem!(F16); - impl_rem!(BF16); - impl_rem!(F32); - impl_rem!(F64); - impl_rem!(Int); + impl_rem!(I32); + impl_rem!(I64); impl_rem!(UInt); } diff --git a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs index af71f0ca10..169f21bcca 100644 --- a/crates/burn-cube/src/operation/cmp.rs +++ b/crates/burn-cube/src/operation/cmp.rs @@ -1,5 +1,5 @@ use crate::operation::base::cmp_expand; -use crate::{CubeContext, ExpandElement, Int, UInt, BF16, F16, F32, F64}; +use crate::{CubeContext, ExpandElement, UInt, BF16, F16, F32, F64, I32, I64}; use burn_jit::gpu::{self}; macro_rules! 
impl_cmp { @@ -26,7 +26,8 @@ impl_cmp!(F16); impl_cmp!(BF16); impl_cmp!(F32); impl_cmp!(F64); -impl_cmp!(Int); +impl_cmp!(I32); +impl_cmp!(I64); impl_cmp!(UInt); pub mod ne { diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index fbd11d3f2c..52a8feaa7c 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,306 +1,304 @@ -// use burn_cube::{cube, Bool, CubeContext, FloatX, Float, Int, UInt}; -// use burn_jit::gpu; -// use burn_jit::gpu::FloatKind::F32; -// use burn_jit::gpu::IntKind::I32; -// use burn_jit::gpu::{Elem, Item, Variable}; - -// macro_rules! cast_test { -// ($name:ident, $module:ident, $from:expr, $to:expr) => { -// #[test] -// fn $name() { -// let mut context = CubeContext::root(); - -// let x = context.create_local($from); - -// $module::expand::(&mut context, x); -// let scope = context.into_scope(); - -// assert_eq!( -// format!("{:?}", scope.operations), -// gpu_macro_ref_cast($from, $to) -// ); -// } -// }; - -// ($name:ident, $module:ident, $ty:expr) => { -// #[test] -// fn $name() { -// let mut context = CubeContext::root(); - -// let x = context.create_local($ty); - -// $module::expand::(&mut context, x); -// let scope = context.into_scope(); - -// assert_eq!( -// format!("{:?}", scope.operations), -// gpu_macro_ref_identity($ty) -// ); -// } -// }; -// } - -// // From float -// #[cube] -// pub fn float_to_float(x: FloatX) { -// let y = x + float_new::(2.0); -// let _ = to_float::, F>(y) + float_new::(34.0); -// } - -// #[cube] -// pub fn float_to_int(x: FloatX) { -// let y = x + float_new::(2.0); -// let _ = to_int(y) + int_new(34); -// } - -// #[cube] -// pub fn float_to_uint(x: FloatX) { -// let y = x + float_new::(2.0); -// let _ = to_uint(y) + uint_new(34u32); -// } - -// #[cube] -// pub fn float_to_bool(x: FloatX) { -// let y = x + float_new::(2.0); -// let _ = to_bool(y) | bool_new(true); -// } - -// cast_test!( -// cube_float_to_float_test, -// 
float_to_float, -// Item::Scalar(Elem::Float(F32)) -// ); - -// cast_test!( -// cube_float_to_int_test, -// float_to_int, -// Item::Scalar(Elem::Float(F32)), -// Item::Scalar(Elem::Int(I32)) -// ); - -// cast_test!( -// cube_float_to_uint_test, -// float_to_uint, -// Item::Scalar(Elem::Float(F32)), -// Item::Scalar(Elem::UInt) -// ); - -// cast_test!( -// cube_float_to_bool_test, -// float_to_bool, -// Item::Scalar(Elem::Float(F32)), -// Item::Scalar(Elem::Bool) -// ); - -// // // From int -// #[cube] -// pub fn int_to_float(x: Int) { -// let y = x + int_new(2); -// let _ = to_float::(y) + float_new::(34.0); -// } - -// #[cube] -// pub fn int_to_int(x: Int) { -// let y = x + int_new(2); -// let _ = to_int(y) + int_new(34); -// } - -// #[cube] -// pub fn int_to_uint(x: Int) { -// let y = x + int_new(2); -// let _ = to_uint(y) + uint_new(34u32); -// } - -// #[cube] -// pub fn int_to_bool(x: Int) { -// let y = x + int_new(2); -// let _ = to_bool(y) | bool_new(true); -// } - -// cast_test!( -// cube_int_to_float_test, -// int_to_float, -// Item::Scalar(Elem::Int(I32)), -// Item::Scalar(Elem::Float(F32)) -// ); - -// cast_test!( -// cube_int_to_int_test, -// int_to_int, -// Item::Scalar(Elem::Int(I32)) -// ); - -// cast_test!( -// cube_int_to_uint_test, -// int_to_uint, -// Item::Scalar(Elem::Int(I32)), -// Item::Scalar(Elem::UInt) -// ); - -// cast_test!( -// cube_int_to_bool_test, -// int_to_bool, -// Item::Scalar(Elem::Int(I32)), -// Item::Scalar(Elem::Bool) -// ); - -// // // From uint -// #[cube] -// pub fn uint_to_float(x: UInt) { -// let y = x + uint_new(2u32); -// let _ = to_float::(y) + float_new::(34.0); -// } - -// #[cube] -// pub fn uint_to_int(x: UInt) { -// let y = x + uint_new(2u32); -// let _ = to_int(y) + int_new(34); -// } - -// #[cube] -// pub fn uint_to_uint(x: UInt) { -// let y = x + uint_new(2u32); -// let _ = to_uint(y) + uint_new(34u32); -// } - -// #[cube] -// pub fn uint_to_bool(x: UInt) { -// let y = x + uint_new(2u32); -// let _ = to_bool(y) 
| bool_new(true); -// } - -// cast_test!( -// cube_uint_to_float_test, -// uint_to_float, -// Item::Scalar(Elem::UInt), -// Item::Scalar(Elem::Float(F32)) -// ); - -// cast_test!( -// cube_uint_to_int_test, -// uint_to_int, -// Item::Scalar(Elem::UInt), -// Item::Scalar(Elem::Int(I32)) -// ); - -// cast_test!( -// cube_uint_to_uint_test, -// uint_to_uint, -// Item::Scalar(Elem::UInt) -// ); - -// cast_test!( -// cube_uint_to_bool_test, -// uint_to_bool, -// Item::Scalar(Elem::UInt), -// Item::Scalar(Elem::Bool) -// ); - -// // From bool -// #[cube] -// pub fn bool_to_float(x: Bool) { -// let y = x & bool_new(false); -// let _ = to_float::(y) + float_new::(34.0); -// } - -// #[cube] -// pub fn bool_to_int(x: Bool) { -// let y = x & bool_new(false); -// let _ = to_int(y) + int_new(34); -// } - -// #[cube] -// pub fn bool_to_uint(x: Bool) { -// let y = x & bool_new(false); -// let _ = to_uint(y) + uint_new(34u32); -// } - -// #[cube] -// pub fn bool_to_bool(x: Bool) { -// let y = x & bool_new(false); -// let _ = to_bool(y) | bool_new(true); -// } - -// cast_test!( -// cube_bool_to_float_test, -// bool_to_float, -// Item::Scalar(Elem::Bool), -// Item::Scalar(Elem::Float(F32)) -// ); - -// cast_test!( -// cube_bool_to_int_test, -// bool_to_int, -// Item::Scalar(Elem::Bool), -// Item::Scalar(Elem::Int(I32)) -// ); - -// cast_test!( -// cube_bool_to_uint_test, -// bool_to_uint, -// Item::Scalar(Elem::Bool), -// Item::Scalar(Elem::UInt) -// ); - -// cast_test!( -// cube_bool_to_bool_test, -// bool_to_bool, -// Item::Scalar(Elem::Bool) -// ); - -// fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { -// let mut context = CubeContext::root(); -// let x = context.create_local(from_item); - -// let mut scope = context.into_scope(); -// let x: Variable = x.into(); -// let y = scope.create_local(from_item); -// let y_casted = scope.create_local(to_item); -// let z = scope.create_local(to_item); - -// match from_item.elem() { -// Elem::Float(_) => gpu!(scope, y = x 
+ 2f32), -// Elem::Int(_) => gpu!(scope, y = x + 2i32), -// Elem::UInt => gpu!(scope, y = x + 2u32), -// Elem::Bool => gpu!(scope, y = x && false), -// } - -// gpu!(scope, y_casted = cast(y)); - -// match to_item.elem() { -// Elem::Float(_) => gpu!(scope, z = y_casted + 34f32), -// Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), -// Elem::UInt => gpu!(scope, z = y_casted + 34u32), -// Elem::Bool => gpu!(scope, z = y_casted || true), -// } - -// format!("{:?}", scope.operations) -// } - -// fn gpu_macro_ref_identity(item: Item) -> String { -// // When staying with the same type variables are automatically reused in cube -// let mut context = CubeContext::root(); -// let x = context.create_local(item); - -// let mut scope = context.into_scope(); -// let x: Variable = x.into(); -// let y = scope.create_local(item); - -// match item.elem() { -// Elem::Float(_) => gpu!(scope, y = x + 2f32), -// Elem::Int(_) => gpu!(scope, y = x + 2i32), -// Elem::UInt => gpu!(scope, y = x + 2u32), -// Elem::Bool => gpu!(scope, y = x && false), -// } - -// gpu!(scope, x = cast(y)); - -// match item.elem() { -// Elem::Float(_) => gpu!(scope, y = x + 34f32), -// Elem::Int(_) => gpu!(scope, y = x + 34i32), -// Elem::UInt => gpu!(scope, y = x + 34u32), -// Elem::Bool => gpu!(scope, y = x || true), -// } - -// format!("{:?}", scope.operations) -// } +use burn_cube::{cube, Bool, CubeContext, Float, Int, UInt, F32, I32}; +use burn_jit::gpu; +use burn_jit::gpu::{Elem, Item, Variable}; + +macro_rules! 
cast_test { + ($name:ident, $module:ident, $from:expr, $to:expr) => { + #[test] + fn $name() { + let mut context = CubeContext::root(); + + let x = context.create_local($from); + + $module::expand(&mut context, x); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + gpu_macro_ref_cast($from, $to) + ); + } + }; + + ($name:ident, $module:ident, $ty:expr) => { + #[test] + fn $name() { + let mut context = CubeContext::root(); + + let x = context.create_local($ty); + + $module::expand(&mut context, x); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + gpu_macro_ref_identity($ty) + ); + } + }; +} + +// From float +#[cube] +pub fn float_to_float(x: F32) { + let y = x + float_new::(2.0); + let _ = to_float::(y) + float_new::(34.0); +} + +#[cube] +pub fn float_to_int(x: F32) { + let y = x + float_new::(2.0); + let _ = to_int::(y) + int_new::(34); +} + +#[cube] +pub fn float_to_uint(x: F32) { + let y = x + float_new::(2.0); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn float_to_bool(x: F32) { + let y = x + float_new::(2.0); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_float_to_float_test, + float_to_float, + Item::Scalar(Elem::Float(F32::into_kind())) +); + +cast_test!( + cube_float_to_int_test, + float_to_int, + Item::Scalar(Elem::Float(F32::into_kind())), + Item::Scalar(Elem::Int(I32::into_kind())) +); + +cast_test!( + cube_float_to_uint_test, + float_to_uint, + Item::Scalar(Elem::Float(F32::into_kind())), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_float_to_bool_test, + float_to_bool, + Item::Scalar(Elem::Float(F32::into_kind())), + Item::Scalar(Elem::Bool) +); + +// // From int +#[cube] +pub fn int_to_float(x: I32) { + let y = x + int_new::(2); + let _ = to_float::(y) + float_new::(34.0); +} + +#[cube] +pub fn int_to_int(x: I32) { + let y = x + int_new::(2); + let _ = to_int::(y) + int_new::(34); +} + +#[cube] +pub fn int_to_uint(x: I32) { + let y = 
x + int_new::(2); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn int_to_bool(x: I32) { + let y = x + int_new::(2); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_int_to_float_test, + int_to_float, + Item::Scalar(Elem::Int(I32::into_kind())), + Item::Scalar(Elem::Float(F32::into_kind())) +); + +cast_test!( + cube_int_to_int_test, + int_to_int, + Item::Scalar(Elem::Int(I32::into_kind())) +); + +cast_test!( + cube_int_to_uint_test, + int_to_uint, + Item::Scalar(Elem::Int(I32::into_kind())), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_int_to_bool_test, + int_to_bool, + Item::Scalar(Elem::Int(I32::into_kind())), + Item::Scalar(Elem::Bool) +); + +// // From uint +#[cube] +pub fn uint_to_float(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_float::(y) + float_new::(34.0); +} + +#[cube] +pub fn uint_to_int(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_int::(y) + int_new::(34); +} + +#[cube] +pub fn uint_to_uint(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_uint(y) + uint_new(34u32); +} + +#[cube] +pub fn uint_to_bool(x: UInt) { + let y = x + uint_new(2u32); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_uint_to_float_test, + uint_to_float, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Float(F32::into_kind())) +); + +cast_test!( + cube_uint_to_int_test, + uint_to_int, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Int(I32::into_kind())) +); + +cast_test!( + cube_uint_to_uint_test, + uint_to_uint, + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_uint_to_bool_test, + uint_to_bool, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Bool) +); + +// From bool +#[cube] +pub fn bool_to_float(x: Bool) { + let y = x & bool_new(false); + let _ = to_float::(y) + float_new::(34.0); +} + +#[cube] +pub fn bool_to_int(x: Bool) { + let y = x & bool_new(false); + let _ = to_int::(y) + int_new::(34); +} + +#[cube] +pub fn bool_to_uint(x: Bool) { + let y = x & bool_new(false); + let _ = to_uint(y) + 
uint_new(34u32); +} + +#[cube] +pub fn bool_to_bool(x: Bool) { + let y = x & bool_new(false); + let _ = to_bool(y) | bool_new(true); +} + +cast_test!( + cube_bool_to_float_test, + bool_to_float, + Item::Scalar(Elem::Bool), + Item::Scalar(Elem::Float(F32::into_kind())) +); + +cast_test!( + cube_bool_to_int_test, + bool_to_int, + Item::Scalar(Elem::Bool), + Item::Scalar(Elem::Int(I32::into_kind())) +); + +cast_test!( + cube_bool_to_uint_test, + bool_to_uint, + Item::Scalar(Elem::Bool), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_bool_to_bool_test, + bool_to_bool, + Item::Scalar(Elem::Bool) +); + +fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { + let mut context = CubeContext::root(); + let x = context.create_local(from_item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y = scope.create_local(from_item); + let y_casted = scope.create_local(to_item); + let z = scope.create_local(to_item); + + match from_item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 2f32), + Elem::Int(_) => gpu!(scope, y = x + 2i32), + Elem::UInt => gpu!(scope, y = x + 2u32), + Elem::Bool => gpu!(scope, y = x && false), + } + + gpu!(scope, y_casted = cast(y)); + + match to_item.elem() { + Elem::Float(_) => gpu!(scope, z = y_casted + 34f32), + Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), + Elem::UInt => gpu!(scope, z = y_casted + 34u32), + Elem::Bool => gpu!(scope, z = y_casted || true), + } + + format!("{:?}", scope.operations) +} + +fn gpu_macro_ref_identity(item: Item) -> String { + // When staying with the same type variables are automatically reused in cube + let mut context = CubeContext::root(); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y = scope.create_local(item); + + match item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 2f32), + Elem::Int(_) => gpu!(scope, y = x + 2i32), + Elem::UInt => gpu!(scope, y = x + 2u32), + Elem::Bool => gpu!(scope, y 
= x && false), + } + + gpu!(scope, x = cast(y)); + + match item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 34f32), + Elem::Int(_) => gpu!(scope, y = x + 34i32), + Elem::UInt => gpu!(scope, y = x + 34u32), + Elem::Bool => gpu!(scope, y = x || true), + } + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index 0513a7f9eb..0c1fd8b788 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,71 +1,74 @@ -// use burn_cube::{break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int}; -// use burn_jit::gpu; -// use burn_jit::gpu::Branch; -// use burn_jit::gpu::IntKind::I32; -// use burn_jit::gpu::{Elem, Item, Variable}; - -// #[cube] -// pub fn while_not(lhs: Int) { -// while lhs != int_new(0) { -// let _ = lhs - int_new(1); -// } -// } - -// #[cube] -// pub fn manual_loop_break(lhs: Int) { -// loop { -// if lhs != int_new(0) { -// break; -// } -// let _ = lhs - int_new(1); -// } -// } - -// #[test] -// fn cube_while_test() { -// let mut context = CubeContext::root(); - -// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - -// while_not::expand(&mut context, lhs); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -// } - -// #[test] -// fn cube_loop_break_test() { -// let mut context = CubeContext::root(); - -// let lhs = context.create_local(Item::Scalar(Elem::Int(I32))); - -// manual_loop_break::expand(&mut context, lhs); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); -// } - -// fn gpu_macro_ref() -> String { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Int(I32)); -// let lhs = context.create_local(item); - -// let mut scope = context.into_scope(); -// let cond = scope.create_local(Item::Scalar(Elem::Bool)); -// let lhs: Variable = lhs.into(); -// let rhs = scope.create_local(item); - 
-// gpu!( -// &mut scope, -// loop(|scope| { -// gpu!(scope, cond = lhs != 0); -// gpu!(scope, if(cond).then(|scope|{ -// scope.register(Branch::Break); -// })); - -// gpu!(scope, rhs = lhs - 1i32); -// }) -// ); - -// format!("{:?}", scope.operations) -// } +use burn_cube::{ + break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int, I32, +}; +use burn_jit::gpu; +use burn_jit::gpu::Branch; +use burn_jit::gpu::{Elem, Item, Variable}; + +type ElemType = I32; + +#[cube] +pub fn while_not(lhs: I) { + while lhs != int_new::(0) { + let _ = lhs - int_new::(1); + } +} + +#[cube] +pub fn manual_loop_break(lhs: I) { + loop { + if lhs != int_new::(0) { + break; + } + let _ = lhs - int_new::(1); + } +} + +#[test] +fn cube_while_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + + while_not::expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +#[test] +fn cube_loop_break_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + + manual_loop_break::expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); +} + +fn gpu_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(ElemType::into_kind())); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let rhs = scope.create_local(item); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = lhs != 0); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, rhs = lhs - 1i32); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/reuse.rs 
b/crates/burn-cube/tests/reuse.rs index f9b245242a..08fb6e457a 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,97 +1,98 @@ -// use burn_cube::{cube, while_loop_expand, CubeContext, Int}; -// use burn_jit::gpu; -// use burn_jit::gpu::IntKind::I32; -// use burn_jit::gpu::{Branch, Elem, Item, Variable}; - -// // TODO -// // a += b is more efficient than a = a + b -// // because the latter does not assume that a is the same in lhs and rhs -// // It could be detected and optimized - -// #[cube] -// pub fn reuse(mut x: Int) { -// while x < int_new(10) { -// x = x + int_new(1); -// } -// } - -// #[cube] -// pub fn reuse_incr(mut x: Int) { -// while x < int_new(10) { -// x += int_new(1); -// } -// } - -// #[test] -// fn cube_reuse_assign_test() { -// let mut context = CubeContext::root(); - -// let x = context.create_local(Item::Scalar(Elem::Int(I32))); - -// reuse::expand(&mut context, x); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); -// } - -// #[test] -// fn cube_reuse_incr_test() { -// let mut context = CubeContext::root(); - -// let x = context.create_local(Item::Scalar(Elem::Int(I32))); - -// reuse_incr::expand(&mut context, x); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); -// } - -// fn gpu_macro_ref_assign() -> String { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Int(I32)); -// let x = context.create_local(item); - -// let mut scope = context.into_scope(); -// let cond = scope.create_local(Item::Scalar(Elem::Bool)); -// let x: Variable = x.into(); -// let tmp = scope.create_local(item); - -// gpu!( -// &mut scope, -// loop(|scope| { -// gpu!(scope, cond = x < 10); -// gpu!(scope, if(cond).then(|scope|{ -// scope.register(Branch::Break); -// })); - -// gpu!(scope, tmp = x + 1); -// gpu!(scope, x = tmp); -// }) -// ); - -// format!("{:?}", scope.operations) 
-// } - -// fn gpu_macro_ref_incr() -> String { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Int(I32)); -// let x = context.create_local(item); - -// let mut scope = context.into_scope(); -// let cond = scope.create_local(Item::Scalar(Elem::Bool)); -// let x: Variable = x.into(); - -// gpu!( -// &mut scope, -// loop(|scope| { -// gpu!(scope, cond = x < 10); -// gpu!(scope, if(cond).then(|scope|{ -// scope.register(Branch::Break); -// })); - -// gpu!(scope, x = x + 1); -// }) -// ); - -// format!("{:?}", scope.operations) -// } +use burn_cube::{cube, while_loop_expand, CubeContext, Int, I32}; +use burn_jit::gpu; +use burn_jit::gpu::{Branch, Elem, Item, Variable}; + +// TODO +// a += b is more efficient than a = a + b +// because the latter does not assume that a is the same in lhs and rhs +// It could be detected and optimized + +type ElemType = I32; + +#[cube] +pub fn reuse(mut x: I) { + while x < int_new::(10) { + x = x + int_new::(1); + } +} + +#[cube] +pub fn reuse_incr(mut x: I) { + while x < int_new::(10) { + x += int_new::(1); + } +} + +#[test] +fn cube_reuse_assign_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + + reuse::expand::(&mut context, x); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); +} + +#[test] +fn cube_reuse_incr_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + + reuse_incr::expand::(&mut context, x); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); +} + +fn gpu_macro_ref_assign() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(ElemType::into_kind())); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = 
scope.create_local(Item::Scalar(Elem::Bool)); + let x: Variable = x.into(); + let tmp = scope.create_local(item); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = x < 10); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, tmp = x + 1); + gpu!(scope, x = tmp); + }) + ); + + format!("{:?}", scope.operations) +} + +fn gpu_macro_ref_incr() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(ElemType::into_kind())); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let x: Variable = x.into(); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = x < 10); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, x = x + 1); + }) + ); + + format!("{:?}", scope.operations) +} From 63757367cb0c0625a565bd09db15af273bc84d3e Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 9 May 2024 12:10:04 -0400 Subject: [PATCH 31/54] cast kind done --- crates/burn-cube-macros/src/prelude.rs | 8 +- crates/burn-cube/src/element/float.rs | 17 ++- crates/burn-cube/src/element/int.rs | 18 +-- crates/burn-cube/src/operation/binary.rs | 4 - crates/burn-cube/tests/cast_kind.rs | 122 ++++++++++++------ .../src/codegen/dialect/gpu/macros.rs | 6 + 6 files changed, 110 insertions(+), 65 deletions(-) diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index 6a560c267a..aaa0958927 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -33,15 +33,16 @@ pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro fn codegen_float_new() -> proc_macro2::TokenStream { quote::quote! 
{ + // TODO ENCAPSULATE IMPORTS + pub fn float_new(val: f32) -> F { F::new(val, 1) } - pub fn float_new_expand( context: &mut CubeContext, val: f32, ) -> ::ExpandType { - val.into() + F::new_expand(val) } } } @@ -51,12 +52,11 @@ fn codegen_int_new() -> proc_macro2::TokenStream { pub fn int_new(val: i32) -> I { I::new(val, 1) } - pub fn int_new_expand( context: &mut CubeContext, val: i32, ) -> ::ExpandType { - val.into() + I::new_expand(val) } } } diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index f7d9a46ea1..c8efc03c06 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -1,4 +1,6 @@ use crate::{ExpandElement, RuntimeType}; +use burn_jit::gpu::{Elem, FloatKind, Variable}; +use std::rc::Rc; pub trait Float: Clone @@ -10,8 +12,9 @@ pub trait Float: + std::ops::Mul + std::ops::Div { - fn into_kind() -> burn_jit::gpu::FloatKind; + fn into_kind() -> FloatKind; fn new(val: f32, vectorization: usize) -> Self; + fn new_expand(val: f32) -> ExpandElement; } macro_rules! impl_float { @@ -27,16 +30,16 @@ macro_rules! 
impl_float { } impl Float for $type { - fn into_kind() -> burn_jit::gpu::FloatKind { - burn_jit::gpu::FloatKind::$type + fn into_kind() -> FloatKind { + FloatKind::$type } fn new(val: f32, vectorization: usize) -> Self { Self { val, vectorization } } - } - impl From for $type { - fn from(value: f32) -> Self { - $type::new(value, 1) + fn new_expand(val: f32) -> ExpandElement { + let elem = Elem::Float(Self::into_kind()); + let new_var = Variable::ConstantScalar(val as f64, elem); + ExpandElement::new(Rc::new(new_var)) } } }; diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 673c9d456c..f880bc1742 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -1,4 +1,6 @@ use crate::{ExpandElement, RuntimeType}; +use burn_jit::gpu::{Elem, IntKind, Variable}; +use std::rc::Rc; pub trait Int: Clone @@ -11,8 +13,9 @@ pub trait Int: + std::ops::Div + std::ops::AddAssign { - fn into_kind() -> burn_jit::gpu::IntKind; + fn into_kind() -> IntKind; fn new(val: i32, vectorization: usize) -> Self; + fn new_expand(val: i32) -> ExpandElement; } macro_rules! impl_int { @@ -28,17 +31,16 @@ macro_rules! 
impl_int { } impl Int for $type { - fn into_kind() -> burn_jit::gpu::IntKind { - burn_jit::gpu::IntKind::$type + fn into_kind() -> IntKind { + IntKind::$type } fn new(val: i32, vectorization: usize) -> Self { Self { val, vectorization } } - } - - impl From for $type { - fn from(value: i32) -> Self { - $type::new(value, 1) + fn new_expand(val: i32) -> ExpandElement { + let elem = Elem::Int(Self::into_kind()); + let new_var = Variable::ConstantScalar(val as f64, elem); + ExpandElement::new(Rc::new(new_var)) } } }; diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 5ae1adecc8..d8b370c653 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -68,8 +68,6 @@ pub mod sub { } pub mod mul { - use crate::Float; - use super::*; pub fn expand( @@ -102,8 +100,6 @@ pub mod mul { } pub mod div { - use crate::Float; - use super::*; pub fn expand( diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 634211834c..2a7f1c37b4 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,43 +1,81 @@ -// use burn_cube::{cube, CubeContext, FloatX, Float, F32_, F64_}; -// use burn_jit::gpu; +use burn_cube::{cube, CubeContext, Float, Int, F32, F64, I32, I64}; +use burn_jit::gpu; // use burn_jit::gpu::FloatKind; -// use burn_jit::gpu::{Elem, Item}; - -// #[cube] -// pub fn cast_float_kind(input: FloatX) { -// let x = input + float_new::(5.9f32); -// let y = to_float::, F2>(x); -// let _ = y + float_new::(2.3f32); -// } - -// #[test] -// fn cube_cast_kind_test() { -// let mut context = CubeContext::root(); -// let item = Item::Scalar(Elem::Float(FloatKind::F64)); - -// let input = context.create_local(item); - -// // F16 not testable with the gpu macro, but should work the same -// cast_float_kind::expand::(&mut context, input); -// let scope = context.into_scope(); - -// assert_eq!(format!("{:?}", scope.operations), 
gpu_macro_ref()); -// } - -// fn gpu_macro_ref() -> String { -// let mut context = CubeContext::root(); -// let float_64 = Item::Scalar(Elem::Float(FloatKind::F64)); -// let float_32 = Item::Scalar(Elem::Float(FloatKind::F32)); -// let input = context.create_local(float_64); - -// let mut scope = context.into_scope(); -// let x = scope.create_local(float_64); -// let y = scope.create_local(float_32); -// let z = scope.create_local(float_32); - -// gpu!(scope, x = input + 5.9f32 as f64); -// gpu!(scope, y = cast(x)); -// gpu!(scope, z = y + 2.3f32); - -// format!("{:?}", scope.operations) -// } +use burn_jit::gpu::{Elem, Item}; + +#[cube] +pub fn cast_float_kind(input: F1) { + let x = input + float_new::(5.9); + let y = to_float::(x); + let _ = y + float_new::(2.3); +} + +#[cube] +pub fn cast_int_kind(input: I1) { + let x = input + int_new::(5); + let y = to_int::(x); + let _ = y + int_new::(2); +} + +#[test] +fn cube_cast_float_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F64::into_kind())); + + let input = context.create_local(item); + + // F16 not testable with the gpu macro, but should work the same + cast_float_kind::expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_float()); +} + +#[test] +fn cube_cast_int_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32::into_kind())); + + let input = context.create_local(item); + + cast_int_kind::expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_int()); +} + +fn gpu_macro_ref_float() -> String { + let mut context = CubeContext::root(); + let float_64 = Item::Scalar(Elem::Float(F64::into_kind())); + let float_32 = Item::Scalar(Elem::Float(F32::into_kind())); + let input = context.create_local(float_64); + + let mut scope = context.into_scope(); + let x = 
scope.create_local(float_64); + let y = scope.create_local(float_32); + let z = scope.create_local(float_32); + + gpu!(scope, x = input + 5.9f32 as f64); + gpu!(scope, y = cast(x)); + gpu!(scope, z = y + 2.3f32); + + format!("{:?}", scope.operations) +} + +fn gpu_macro_ref_int() -> String { + let mut context = CubeContext::root(); + let int_32 = Item::Scalar(Elem::Int(I32::into_kind())); + let int_64 = Item::Scalar(Elem::Int(I64::into_kind())); + let input = context.create_local(int_32); + + let mut scope = context.into_scope(); + let x = scope.create_local(int_32); + let y = scope.create_local(int_64); + let z = scope.create_local(int_64); + + gpu!(scope, x = input + 5i32); + gpu!(scope, y = cast(x)); + gpu!(scope, z = y + 2i64); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index af3a1938bb..b2c1c1a457 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -408,6 +408,12 @@ impl From for Variable { } } +impl From for Variable { + fn from(value: i64) -> Self { + Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I64)) + } +} + impl From for Variable { fn from(value: f32) -> Self { Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32)) From 0c509759bad452cf33df465161b4c243dc35f817 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 9 May 2024 13:06:52 -0400 Subject: [PATCH 32/54] rename gpu macro --- crates/burn-cube-macros/src/prelude.rs | 2 - crates/burn-cube/tests/cast_elem.rs | 50 +-- crates/burn-cube/tests/cast_kind.rs | 27 +- crates/burn-cube/tests/for_loop.rs | 24 +- crates/burn-cube/tests/if.rs | 13 +- crates/burn-cube/tests/if_else.rs | 18 +- crates/burn-cube/tests/literal.rs | 9 +- crates/burn-cube/tests/loop.rs | 16 +- crates/burn-cube/tests/reuse.rs | 33 +- .../src/codegen/dialect/gpu/macros.rs | 140 ++++---- 
.../burn-jit/src/codegen/dialect/gpu/mod.rs | 2 +- .../codegen/dialect/gpu/procedure/assign.rs | 16 +- .../codegen/dialect/gpu/procedure/index.rs | 20 +- .../src/codegen/dialect/gpu/procedure/read.rs | 26 +- .../codegen/dialect/gpu/procedure/write.rs | 4 +- .../burn-jit/src/codegen/dialect/gpu/scope.rs | 6 +- .../src/codegen/dialect/gpu/variable.rs | 2 +- crates/burn-jit/src/kernel/cast/base.rs | 6 +- crates/burn-jit/src/kernel/cast/bool_cast.rs | 10 +- crates/burn-jit/src/kernel/contiguous.rs | 6 +- crates/burn-jit/src/kernel/conv/conv2d.rs | 176 +++++----- .../src/kernel/conv/conv_transpose2d.rs | 242 +++++++------- crates/burn-jit/src/kernel/index/flip.rs | 28 +- crates/burn-jit/src/kernel/index/gather.rs | 14 +- crates/burn-jit/src/kernel/index/repeat.rs | 22 +- crates/burn-jit/src/kernel/index/scatter.rs | 62 ++-- crates/burn-jit/src/kernel/index/select.rs | 30 +- .../src/kernel/index/select_assign.rs | 66 ++-- crates/burn-jit/src/kernel/index/slice.rs | 24 +- .../burn-jit/src/kernel/index/slice_assign.rs | 30 +- .../src/kernel/interpolate/bicubic.rs | 316 +++++++++--------- .../src/kernel/interpolate/bilinear.rs | 178 +++++----- .../src/kernel/interpolate/nearest.rs | 94 +++--- .../kernel/interpolate/nearest_backward.rs | 112 +++---- crates/burn-jit/src/kernel/mask/shader.rs | 18 +- crates/burn-jit/src/kernel/matmul/simple.rs | 68 ++-- .../src/kernel/matmul/tiling2d_shader/base.rs | 6 +- .../matmul/tiling2d_shader/computation.rs | 40 +-- .../tiling2d_shader/load_shared_memory.rs | 164 ++++----- .../tiling2d_shader/shader_information.rs | 96 +++--- .../matmul/tiling2d_shader/write_output.rs | 42 +-- .../pool/adaptive_avg_pool2d_backward.rs | 134 ++++---- .../src/kernel/pool/adaptive_pool2d_shader.rs | 122 +++---- crates/burn-jit/src/kernel/pool/avg_pool2d.rs | 16 +- .../src/kernel/pool/avg_pool2d_backward.rs | 210 ++++++------ crates/burn-jit/src/kernel/pool/max_pool2d.rs | 26 +- .../src/kernel/pool/max_pool2d_backward.rs | 164 ++++----- 
.../burn-jit/src/kernel/pool/pool2d_shader.rs | 108 +++--- crates/burn-jit/src/kernel/prng/base.rs | 60 ++-- crates/burn-jit/src/kernel/prng/bernoulli.rs | 18 +- crates/burn-jit/src/kernel/prng/normal.rs | 54 +-- crates/burn-jit/src/kernel/prng/uniform.rs | 26 +- .../burn-jit/src/kernel/reduce/argmax_dim.rs | 36 +- .../burn-jit/src/kernel/reduce/argmin_dim.rs | 36 +- crates/burn-jit/src/kernel/reduce/mean_dim.rs | 30 +- .../src/kernel/reduce/naive_reduce_shader.rs | 36 +- crates/burn-jit/src/kernel/reduce/prod_dim.rs | 22 +- .../src/kernel/reduce/shared_reduce_shader.rs | 68 ++-- crates/burn-jit/src/kernel/reduce/sum_dim.rs | 22 +- 59 files changed, 1725 insertions(+), 1721 deletions(-) diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs index aaa0958927..f69196d762 100644 --- a/crates/burn-cube-macros/src/prelude.rs +++ b/crates/burn-cube-macros/src/prelude.rs @@ -33,8 +33,6 @@ pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro fn codegen_float_new() -> proc_macro2::TokenStream { quote::quote! { - // TODO ENCAPSULATE IMPORTS - pub fn float_new(val: f32) -> F { F::new(val, 1) } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 52a8feaa7c..123cabe191 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,6 +1,8 @@ use burn_cube::{cube, Bool, CubeContext, Float, Int, UInt, F32, I32}; -use burn_jit::gpu; -use burn_jit::gpu::{Elem, Item, Variable}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item, Variable}, +}; macro_rules! cast_test { ($name:ident, $module:ident, $from:expr, $to:expr) => { @@ -15,7 +17,7 @@ macro_rules! cast_test { assert_eq!( format!("{:?}", scope.operations), - gpu_macro_ref_cast($from, $to) + inline_macro_ref_cast($from, $to) ); } }; @@ -32,7 +34,7 @@ macro_rules! 
cast_test { assert_eq!( format!("{:?}", scope.operations), - gpu_macro_ref_identity($ty) + inline_macro_ref_identity($ty) ); } }; @@ -246,7 +248,7 @@ cast_test!( Item::Scalar(Elem::Bool) ); -fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { +fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String { let mut context = CubeContext::root(); let x = context.create_local(from_item); @@ -257,25 +259,25 @@ fn gpu_macro_ref_cast(from_item: Item, to_item: Item) -> String { let z = scope.create_local(to_item); match from_item.elem() { - Elem::Float(_) => gpu!(scope, y = x + 2f32), - Elem::Int(_) => gpu!(scope, y = x + 2i32), - Elem::UInt => gpu!(scope, y = x + 2u32), - Elem::Bool => gpu!(scope, y = x && false), + Elem::Float(_) => cube_inline!(scope, y = x + 2f32), + Elem::Int(_) => cube_inline!(scope, y = x + 2i32), + Elem::UInt => cube_inline!(scope, y = x + 2u32), + Elem::Bool => cube_inline!(scope, y = x && false), } - gpu!(scope, y_casted = cast(y)); + cube_inline!(scope, y_casted = cast(y)); match to_item.elem() { - Elem::Float(_) => gpu!(scope, z = y_casted + 34f32), - Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), - Elem::UInt => gpu!(scope, z = y_casted + 34u32), - Elem::Bool => gpu!(scope, z = y_casted || true), + Elem::Float(_) => cube_inline!(scope, z = y_casted + 34f32), + Elem::Int(_) => cube_inline!(scope, z = y_casted + 34i32), + Elem::UInt => cube_inline!(scope, z = y_casted + 34u32), + Elem::Bool => cube_inline!(scope, z = y_casted || true), } format!("{:?}", scope.operations) } -fn gpu_macro_ref_identity(item: Item) -> String { +fn inline_macro_ref_identity(item: Item) -> String { // When staying with the same type variables are automatically reused in cube let mut context = CubeContext::root(); let x = context.create_local(item); @@ -285,19 +287,19 @@ fn gpu_macro_ref_identity(item: Item) -> String { let y = scope.create_local(item); match item.elem() { - Elem::Float(_) => gpu!(scope, y = x + 2f32), - Elem::Int(_) => 
gpu!(scope, y = x + 2i32), - Elem::UInt => gpu!(scope, y = x + 2u32), - Elem::Bool => gpu!(scope, y = x && false), + Elem::Float(_) => cube_inline!(scope, y = x + 2f32), + Elem::Int(_) => cube_inline!(scope, y = x + 2i32), + Elem::UInt => cube_inline!(scope, y = x + 2u32), + Elem::Bool => cube_inline!(scope, y = x && false), } - gpu!(scope, x = cast(y)); + cube_inline!(scope, x = cast(y)); match item.elem() { - Elem::Float(_) => gpu!(scope, y = x + 34f32), - Elem::Int(_) => gpu!(scope, y = x + 34i32), - Elem::UInt => gpu!(scope, y = x + 34u32), - Elem::Bool => gpu!(scope, y = x || true), + Elem::Float(_) => cube_inline!(scope, y = x + 34f32), + Elem::Int(_) => cube_inline!(scope, y = x + 34i32), + Elem::UInt => cube_inline!(scope, y = x + 34u32), + Elem::Bool => cube_inline!(scope, y = x || true), } format!("{:?}", scope.operations) diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 2a7f1c37b4..e4793988bb 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,7 +1,8 @@ use burn_cube::{cube, CubeContext, Float, Int, F32, F64, I32, I64}; -use burn_jit::gpu; -// use burn_jit::gpu::FloatKind; -use burn_jit::gpu::{Elem, Item}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item}, +}; #[cube] pub fn cast_float_kind(input: F1) { @@ -28,7 +29,7 @@ fn cube_cast_float_kind_test() { cast_float_kind::expand::(&mut context, input); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_float()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); } #[test] @@ -41,10 +42,10 @@ fn cube_cast_int_kind_test() { cast_int_kind::expand::(&mut context, input); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_int()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); } -fn gpu_macro_ref_float() -> String { +fn inline_macro_ref_float() -> String { let mut context = 
CubeContext::root(); let float_64 = Item::Scalar(Elem::Float(F64::into_kind())); let float_32 = Item::Scalar(Elem::Float(F32::into_kind())); @@ -55,14 +56,14 @@ fn gpu_macro_ref_float() -> String { let y = scope.create_local(float_32); let z = scope.create_local(float_32); - gpu!(scope, x = input + 5.9f32 as f64); - gpu!(scope, y = cast(x)); - gpu!(scope, z = y + 2.3f32); + cube_inline!(scope, x = input + 5.9f32 as f64); + cube_inline!(scope, y = cast(x)); + cube_inline!(scope, z = y + 2.3f32); format!("{:?}", scope.operations) } -fn gpu_macro_ref_int() -> String { +fn inline_macro_ref_int() -> String { let mut context = CubeContext::root(); let int_32 = Item::Scalar(Elem::Int(I32::into_kind())); let int_64 = Item::Scalar(Elem::Int(I64::into_kind())); @@ -73,9 +74,9 @@ fn gpu_macro_ref_int() -> String { let y = scope.create_local(int_64); let z = scope.create_local(int_64); - gpu!(scope, x = input + 5i32); - gpu!(scope, y = cast(x)); - gpu!(scope, z = y + 2i64); + cube_inline!(scope, x = input + 5i32); + cube_inline!(scope, y = cast(x)); + cube_inline!(scope, z = y + 2i64); format!("{:?}", scope.operations) } diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 59e83151b5..d0cfe89bbc 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,6 +1,8 @@ use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt, F32}; -use burn_jit::gpu; -use burn_jit::gpu::{Elem, Item, Variable}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item, Variable}, +}; type ElemType = F32; @@ -26,7 +28,7 @@ fn test_for_loop_with_unroll() { for_loop::expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); } #[test] @@ -41,10 +43,10 @@ fn test_for_loop_no_unroll() { for_loop::expand::(&mut context, lhs, rhs, end, unroll); let scope = 
context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref(unroll)); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); } -fn gpu_macro_ref(unroll: bool) -> String { +fn inline_macro_ref(unroll: bool) -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Float(ElemType::into_kind())); @@ -58,15 +60,15 @@ fn gpu_macro_ref(unroll: bool) -> String { // Kernel let tmp1 = scope.create_local(item); let tmp2 = scope.create_local(item); - gpu!(scope, tmp1 = rhs * rhs); - gpu!(scope, tmp2 = tmp1 + rhs); + cube_inline!(scope, tmp1 = rhs * rhs); + cube_inline!(scope, tmp2 = tmp1 + rhs); - gpu!( + cube_inline!( &mut scope, range(0u32, end, unroll).for_each(|i, scope| { - gpu!(scope, rhs = lhs[i]); - gpu!(scope, tmp1 = tmp2 + rhs); - gpu!(scope, lhs[i] = tmp1); + cube_inline!(scope, rhs = lhs[i]); + cube_inline!(scope, tmp1 = tmp2 + rhs); + cube_inline!(scope, lhs[i] = tmp1); }) ); diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index ab28c3e6d8..e7f3d41987 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,6 +1,5 @@ use burn_cube::{cube, if_expand, CubeContext, Float, F32}; -use burn_jit::gpu; -use burn_jit::gpu::{Elem, Item, Variable}; +use burn_jit::{cube_inline, gpu::{Elem, Item, Variable}}; type ElemType = F32; @@ -20,10 +19,10 @@ fn cube_if_test() { if_greater::expand::(&mut context, lhs); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); } -fn gpu_macro_ref() -> String { +fn inline_macro_ref() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); @@ -33,9 +32,9 @@ fn gpu_macro_ref() -> String { let lhs: Variable = lhs.into(); let y = scope.create_local(item); - gpu!(scope, cond = lhs > 0f32); - gpu!(&mut scope, 
if(cond).then(|scope| { - gpu!(scope, y = lhs + 4.0f32); + cube_inline!(scope, cond = lhs > 0f32); + cube_inline!(&mut scope, if(cond).then(|scope| { + cube_inline!(scope, y = lhs + 4.0f32); })); format!("{:?}", scope.operations) diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index f02587cfb3..605035967b 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,6 +1,8 @@ use burn_cube::{cube, if_else_expand, CubeContext, Float, F32}; -use burn_jit::gpu; -use burn_jit::gpu::{Elem, Item, Variable}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item, Variable}, +}; type ElemType = F32; @@ -22,10 +24,10 @@ fn cube_if_else_test() { if_else::expand::(&mut context, lhs); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); } -fn gpu_macro_ref() -> String { +fn inline_macro_ref() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); @@ -35,11 +37,11 @@ fn gpu_macro_ref() -> String { let lhs: Variable = lhs.into(); let y = scope.create_local(item); - gpu!(scope, cond = lhs < 0f32); - gpu!(&mut scope, if(cond).then(|scope| { - gpu!(scope, y = lhs + 4.0f32); + cube_inline!(scope, cond = lhs < 0f32); + cube_inline!(&mut scope, if(cond).then(|scope| { + cube_inline!(scope, y = lhs + 4.0f32); }).else(|scope|{ - gpu!(scope, y = lhs - 5.0f32); + cube_inline!(scope, y = lhs - 5.0f32); })); format!("{:?}", scope.operations) diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 0ced8ca5e2..0024de8de8 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,6 +1,5 @@ use burn_cube::{cube, CubeContext, Float, F32}; -use burn_jit::gpu; -use burn_jit::gpu::{Elem, Item}; +use burn_jit::{cube_inline, gpu::{Elem, Item}}; type ElemType = F32; 
@@ -18,17 +17,17 @@ fn cube_literal_test() { literal::expand::(&mut context, lhs); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); } -fn gpu_macro_ref() -> String { +fn inline_macro_ref() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Float(ElemType::into_kind())); let lhs = context.create_local(item); let mut scope = context.into_scope(); let out = scope.create_local(item); - gpu!(scope, out = lhs + 5.9f32); + cube_inline!(scope, out = lhs + 5.9f32); format!("{:?}", scope.operations) } diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index 0c1fd8b788..8611078687 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,7 +1,7 @@ use burn_cube::{ break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int, I32, }; -use burn_jit::gpu; +use burn_jit::cube_inline; use burn_jit::gpu::Branch; use burn_jit::gpu::{Elem, Item, Variable}; @@ -33,7 +33,7 @@ fn cube_while_test() { while_not::expand::(&mut context, lhs); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); } #[test] @@ -45,10 +45,10 @@ fn cube_loop_break_test() { manual_loop_break::expand::(&mut context, lhs); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); } -fn gpu_macro_ref() -> String { +fn inline_macro_ref() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Int(ElemType::into_kind())); let lhs = context.create_local(item); @@ -58,15 +58,15 @@ fn gpu_macro_ref() -> String { let lhs: Variable = lhs.into(); let rhs = scope.create_local(item); - gpu!( + cube_inline!( &mut scope, loop(|scope| { - gpu!(scope, cond = lhs 
!= 0); - gpu!(scope, if(cond).then(|scope|{ + cube_inline!(scope, cond = lhs != 0); + cube_inline!(scope, if(cond).then(|scope|{ scope.register(Branch::Break); })); - gpu!(scope, rhs = lhs - 1i32); + cube_inline!(scope, rhs = lhs - 1i32); }) ); diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 08fb6e457a..86af146a84 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,8 +1,9 @@ use burn_cube::{cube, while_loop_expand, CubeContext, Int, I32}; -use burn_jit::gpu; -use burn_jit::gpu::{Branch, Elem, Item, Variable}; +use burn_jit::{ + cube_inline, + gpu::{Branch, Elem, Item, Variable}, +}; -// TODO // a += b is more efficient than a = a + b // because the latter does not assume that a is the same in lhs and rhs // It could be detected and optimized @@ -32,7 +33,7 @@ fn cube_reuse_assign_test() { reuse::expand::(&mut context, x); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_assign()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); } #[test] @@ -44,10 +45,10 @@ fn cube_reuse_incr_test() { reuse_incr::expand::(&mut context, x); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), gpu_macro_ref_incr()); + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); } -fn gpu_macro_ref_assign() -> String { +fn inline_macro_ref_assign() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Int(ElemType::into_kind())); let x = context.create_local(item); @@ -57,23 +58,23 @@ fn gpu_macro_ref_assign() -> String { let x: Variable = x.into(); let tmp = scope.create_local(item); - gpu!( + cube_inline!( &mut scope, loop(|scope| { - gpu!(scope, cond = x < 10); - gpu!(scope, if(cond).then(|scope|{ + cube_inline!(scope, cond = x < 10); + cube_inline!(scope, if(cond).then(|scope|{ scope.register(Branch::Break); })); - gpu!(scope, tmp = x + 1); - gpu!(scope, x = tmp); + 
cube_inline!(scope, tmp = x + 1); + cube_inline!(scope, x = tmp); }) ); format!("{:?}", scope.operations) } -fn gpu_macro_ref_incr() -> String { +fn inline_macro_ref_incr() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(Elem::Int(ElemType::into_kind())); let x = context.create_local(item); @@ -82,15 +83,15 @@ fn gpu_macro_ref_incr() -> String { let cond = scope.create_local(Item::Scalar(Elem::Bool)); let x: Variable = x.into(); - gpu!( + cube_inline!( &mut scope, loop(|scope| { - gpu!(scope, cond = x < 10); - gpu!(scope, if(cond).then(|scope|{ + cube_inline!(scope, cond = x < 10); + cube_inline!(scope, if(cond).then(|scope|{ scope.register(Branch::Break); })); - gpu!(scope, x = x + 1); + cube_inline!(scope, x = x + 1); }) ); diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index b2c1c1a457..348fb7772f 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -2,336 +2,336 @@ use super::Variable; #[macro_export(local_inner_macros)] /// Macro for generating JIT intermediate representation, in a concise way -macro_rules! gpu { +macro_rules! 
cube_inline { // out = lhs + rhs ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => { - gpu!($scope, $out = add($lhs, $rhs)) + cube_inline!($scope, $out = add($lhs, $rhs)) }; // out += input ($scope:expr, $out:ident += $input:ident) => { - gpu!($scope, $out = add($out, $input)) + cube_inline!($scope, $out = add($out, $input)) }; // out = add(lhs, rhs) ($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Add( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs - rhs ($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => { - gpu!($scope, $out = sub($lhs, $rhs)) + cube_inline!($scope, $out = sub($lhs, $rhs)) }; // out = sub(lhs, rhs) ($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Sub( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs * rhs ($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => { - gpu!($scope, $out = mul($lhs, $rhs)) + cube_inline!($scope, $out = mul($lhs, $rhs)) }; // out *= input ($scope:expr, $out:ident *= $input:ident) => { - gpu!($scope, $out = mul($out, $input)) + cube_inline!($scope, $out = mul($out, $input)) }; // out = mul(lhs, rhs) ($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Mul( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs / rhs ($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => { - gpu!($scope, $out = div($lhs, $rhs)) + cube_inline!($scope, $out = div($lhs, $rhs)) }; // out = div(lhs, rhs) ($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Div( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs % rhs ($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => { - gpu!($scope, $out = modulo($lhs, $rhs)) + 
cube_inline!($scope, $out = modulo($lhs, $rhs)) }; // out = modulo(lhs, rhs) ($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Modulo( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = powf(lhs, rhs) ($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Powf( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs && rhs ($scope:expr, $out:ident = $lhs:ident && $rhs:expr) => { - gpu!($scope, $out = and($lhs, $rhs)) + cube_inline!($scope, $out = and($lhs, $rhs)) }; // out = and(lhs, rhs) ($scope:expr, $out:ident = and($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::And( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs || rhs ($scope:expr, $out:ident = $lhs:ident || $rhs:expr) => { - gpu!($scope, $out = or($lhs, $rhs)) + cube_inline!($scope, $out = or($lhs, $rhs)) }; // out = or(lhs, rhs) ($scope:expr, $out:ident = or($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Or( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = !input ($scope:expr, $out:ident = !$input:expr) => { - gpu!($scope, $out = not($input)) + cube_inline!($scope, $out = not($input)) }; // out = not(input) ($scope:expr, $out:ident = not($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Not( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = lhs & rhs ($scope:expr, $out: ident = $lhs:ident & $rhs:ident) => { - gpu!($scope, $out = bitwise_and($lhs, $rhs)) + cube_inline!($scope, $out = bitwise_and($lhs, $rhs)) }; // out = bitwise_and(lhs, rhs) ($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::BitwiseAnd( - gpu!(binary $lhs, 
$rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs ^ rhs ($scope:expr, $out: ident = $lhs:ident ^ $rhs:ident) => { - gpu!($scope, $out = bitwise_xor($lhs, $rhs)) + cube_inline!($scope, $out = bitwise_xor($lhs, $rhs)) }; // out = bitwise_xor(lhs, rhs) ($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::BitwiseXor( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs << rhs ($scope:expr, $out: ident = $lhs:ident << $rhs:ident) => { - gpu!($scope, $out = shift_left($lhs, $rhs)) + cube_inline!($scope, $out = shift_left($lhs, $rhs)) }; // out = shift_left(lhs, rhs) ($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::ShiftLeft( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs >> rhs ($scope:expr, $out: ident = $lhs:ident >> $rhs:ident) => { - gpu!($scope, $out = shift_right($lhs, $rhs)) + cube_inline!($scope, $out = shift_right($lhs, $rhs)) }; // out = shift_right(lhs, rhs) ($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::ShiftRight( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs == rhs ($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => { - gpu!($scope, $out = equal($lhs, $rhs)) + cube_inline!($scope, $out = equal($lhs, $rhs)) }; // out = equal(lhs, rhs) ($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Equal( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs != rhs ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => { - gpu!($scope, $out = not_equal($lhs, $rhs)) + cube_inline!($scope, $out = not_equal($lhs, $rhs)) }; // out = not_equal(lhs, rhs) ($scope:expr, $out:ident = not_equal($lhs:expr, 
$rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::NotEqual( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs > rhs ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => { - gpu!($scope, $out = greater($lhs, $rhs)) + cube_inline!($scope, $out = greater($lhs, $rhs)) }; // out = greater(lhs, rhs) ($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Greater( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs >= rhs ($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => { - gpu!($scope, $out = greater_equal($lhs, $rhs)) + cube_inline!($scope, $out = greater_equal($lhs, $rhs)) }; // out = greater_equal(lhs, rhs) ($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::GreaterEqual( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs < rhs ($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => { - gpu!($scope, $out = lower($lhs, $rhs)) + cube_inline!($scope, $out = lower($lhs, $rhs)) }; // out = lower(lhs, rhs) ($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Lower( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs <= rhs ($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => { - gpu!($scope, $out = lower_equal($lhs, $rhs)) + cube_inline!($scope, $out = lower_equal($lhs, $rhs)) }; // out = lower_equal(lhs, rhs) ($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::LowerEqual( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = max(lhs, rhs) ($scope:expr, $out:ident = max($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Max( - gpu!(binary $lhs, 
$rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = min(lhs, rhs) ($scope:expr, $out:ident = min($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Min( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = lhs[rhs] ($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => { - gpu!($scope, $out = index($lhs, $rhs)) + cube_inline!($scope, $out = index($lhs, $rhs)) }; // out = index(lhs, rhs) ($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Index( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = unchecked(lhs[rhs]) ($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => { $scope.register($crate::codegen::dialect::gpu::Operator::UncheckedIndex( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out[lhs] = rhs ($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => { $scope.register($crate::codegen::dialect::gpu::Operator::IndexAssign( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // unchecked(out[lhs]) = rhs ($scope:expr, unchecked($out:ident[$lhs:ident]) = $rhs:expr) => { $scope.register($crate::codegen::dialect::gpu::Operator::UncheckedIndexAssign( - gpu!(binary $lhs, $rhs, $out) + cube_inline!(binary $lhs, $rhs, $out) )); }; // out = |input| ($scope:expr, $out:ident = |$input:ident|) => { - gpu!($scope, $out = abs($input)) + cube_inline!($scope, $out = abs($input)) }; // out = abs(input) ($scope:expr, $out:ident = abs($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Abs( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = exp(input) ($scope:expr, $out:ident = exp($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Exp( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = log(input) ($scope:expr, $out:ident = 
log($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Log( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = log1p(input) ($scope:expr, $out:ident = log1p($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Log1p( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = cos(input) ($scope:expr, $out:ident = cos($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Cos( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = sin(input) ($scope:expr, $out:ident = sin($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Sin( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = tanh(input) ($scope:expr, $out:ident = tanh($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Tanh( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = sqrt(input) ($scope:expr, $out:ident = sqrt($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Sqrt( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = floor(input) ($scope:expr, $out:ident = floor($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Floor( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = ceil(input) ($scope:expr, $out:ident = ceil($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Ceil( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = erf(input) ($scope:expr, $out:ident = erf($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Erf( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = input ($scope:expr, $out:ident = $input:ident) => { $scope.register($crate::codegen::dialect::gpu::Operator::Assign( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = vec4(a, b, 
c, d) ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => { let i = $scope.zero(Elem::UInt); - gpu!($scope, $out[i] = $a); - gpu!($scope, i = i + 1u32); - gpu!($scope, $out[i] = $b); - gpu!($scope, i = i + 1u32); - gpu!($scope, $out[i] = $c); - gpu!($scope, i = i + 1u32); - gpu!($scope, $out[i] = $d); + cube_inline!($scope, $out[i] = $a); + cube_inline!($scope, i = i + 1u32); + cube_inline!($scope, $out[i] = $b); + cube_inline!($scope, i = i + 1u32); + cube_inline!($scope, $out[i] = $c); + cube_inline!($scope, i = i + 1u32); + cube_inline!($scope, $out[i] = $d); }; // out = input ($scope:expr, $out:ident = $input:ident) => { - gpu!($scope, $out = cast($input)) + cube_inline!($scope, $out = cast($input)) }; // out = cast(input) ($scope:expr, $out:ident = cast($input:expr)) => { $scope.register($crate::codegen::dialect::gpu::Operator::Assign( - gpu!(unary $input, $out) + cube_inline!(unary $input, $out) )); }; // out = shape(tensor, dim) @@ -438,4 +438,4 @@ impl From for Variable { } } -pub(crate) use gpu; +pub(crate) use cube_inline; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/mod.rs b/crates/burn-jit/src/codegen/dialect/gpu/mod.rs index c5fa96d929..e575305c79 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/mod.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/mod.rs @@ -18,4 +18,4 @@ pub use synchronization::*; pub use variable::*; pub use vectorization::*; -pub(crate) use macros::gpu; +pub(crate) use macros::cube_inline; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/assign.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/assign.rs index 0e3a3b705a..e1643a565d 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/procedure/assign.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/assign.rs @@ -1,4 +1,4 @@ -use crate::codegen::dialect::gpu::{macros::gpu, Item, Scope, Variable, Vectorization}; +use crate::codegen::dialect::gpu::{macros::cube_inline, Item, Scope, Variable, Vectorization}; use 
serde::{Deserialize, Serialize}; /// Assign value to a variable based on a given condition. @@ -23,7 +23,7 @@ impl ConditionalAssign { Item::Scalar(_) => var, _ => { let out = scope.create_local(var.item().elem()); - gpu!(scope, out = var[index]); + cube_inline!(scope, out = var[index]); out } }; @@ -31,14 +31,14 @@ impl ConditionalAssign { let mut assign_index = |index: usize| { let cond = index_var(scope, cond, index); - gpu!(scope, if (cond).then(|scope| { + cube_inline!(scope, if (cond).then(|scope| { let lhs = index_var(scope, lhs, index); let index: Variable = index.into(); - gpu!(scope, out[index] = lhs); + cube_inline!(scope, out[index] = lhs); }).else(|scope| { let rhs = index_var(scope, rhs, index); let index: Variable = index.into(); - gpu!(scope, out[index] = rhs); + cube_inline!(scope, out[index] = rhs); })); }; @@ -59,10 +59,10 @@ impl ConditionalAssign { assign_index(1); } Item::Scalar(_) => { - gpu!(scope, if (cond).then(|scope| { - gpu!(scope, out = lhs); + cube_inline!(scope, if (cond).then(|scope| { + cube_inline!(scope, out = lhs); }).else(|scope| { - gpu!(scope, out = rhs); + cube_inline!(scope, out = rhs); })); } }; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs index 95b0797e2b..659e781f73 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs @@ -1,4 +1,4 @@ -use crate::codegen::dialect::gpu::{macros::gpu, Item, Scope, Variable, Vectorization}; +use crate::codegen::dialect::gpu::{macros::cube_inline, Item, Scope, Variable, Vectorization}; use serde::{Deserialize, Serialize}; /// Perform a check bound on the index (lhs) of value (rhs) @@ -19,13 +19,13 @@ impl CheckedIndex { let array_len = scope.create_local(Item::Scalar(crate::gpu::Elem::UInt)); let inside_bound = scope.create_local(Item::Scalar(crate::gpu::Elem::Bool)); - gpu!(scope, array_len = len(lhs)); - gpu!(scope, 
inside_bound = rhs < array_len); + cube_inline!(scope, array_len = len(lhs)); + cube_inline!(scope, inside_bound = rhs < array_len); - gpu!(scope, if(inside_bound).then(|scope| { - gpu!(scope, out = unchecked(lhs[rhs])); + cube_inline!(scope, if(inside_bound).then(|scope| { + cube_inline!(scope, out = unchecked(lhs[rhs])); }).else(|scope| { - gpu!(scope, out = cast(0)); + cube_inline!(scope, out = cast(0)); })); } @@ -56,11 +56,11 @@ impl CheckedIndexAssign { let array_len = scope.create_local(Item::Scalar(crate::gpu::Elem::UInt)); let inside_bound = scope.create_local(Item::Scalar(crate::gpu::Elem::Bool)); - gpu!(scope, array_len = len(out)); - gpu!(scope, inside_bound = lhs < array_len); + cube_inline!(scope, array_len = len(out)); + cube_inline!(scope, inside_bound = lhs < array_len); - gpu!(scope, if(inside_bound).then(|scope| { - gpu!(scope, unchecked(out[lhs]) = rhs); + cube_inline!(scope, if(inside_bound).then(|scope| { + cube_inline!(scope, unchecked(out[lhs]) = rhs); })); } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs index 086e0db782..b5526192ba 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/read.rs @@ -1,4 +1,4 @@ -use super::super::{gpu, Elem, Item, Operator, Scope, Variable}; +use super::super::{cube_inline, Elem, Item, Operator, Scope, Variable}; use crate::codegen::dialect::gpu::{BinaryOperator, Vectorization}; use serde::{Deserialize, Serialize}; @@ -86,7 +86,7 @@ impl ReadGlobalWithLayout { let output = outputs[i]; let index = indexes[i]; - gpu!(scope, output = tensor[index]); + cube_inline!(scope, output = tensor[index]); } } @@ -141,36 +141,36 @@ impl IndexOffsetGlobalWithLayout { .into(); for index in self.indexes.iter() { - gpu!(scope, index = zero); + cube_inline!(scope, index = zero); } - gpu!( + cube_inline!( scope, range(self.dim_start, self.dim_end).for_each(|i, scope| { let 
stride_layout = scope.create_local(index_item_ty); let ogwl = scope.create_local(index_item_ty); - gpu!(scope, stride_layout = stride(layout, i)); - gpu!(scope, ogwl = offset_ref * vectorization_factor); - gpu!(scope, ogwl = ogwl / stride_layout); + cube_inline!(scope, stride_layout = stride(layout, i)); + cube_inline!(scope, ogwl = offset_ref * vectorization_factor); + cube_inline!(scope, ogwl = ogwl / stride_layout); for (tensor, index) in self.tensors.iter().zip(self.indexes.iter()) { let stride = scope.create_local(index_item_ty); let shape = scope.create_local(index_item_ty); let tmp = scope.create_local(index_item_ty); - gpu!(scope, stride = stride(tensor, i)); - gpu!(scope, shape = shape(tensor, i)); + cube_inline!(scope, stride = stride(tensor, i)); + cube_inline!(scope, shape = shape(tensor, i)); - gpu!(scope, tmp = ogwl % shape); - gpu!(scope, tmp = tmp * stride); - gpu!(scope, index = index + tmp); + cube_inline!(scope, tmp = ogwl % shape); + cube_inline!(scope, tmp = tmp * stride); + cube_inline!(scope, index = index + tmp); } }) ); for index in self.indexes { - gpu!(scope, index = index / vectorization_factor); + cube_inline!(scope, index = index / vectorization_factor); } } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs index aaad996163..9215ea54f8 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/write.rs @@ -1,4 +1,4 @@ -use crate::codegen::dialect::gpu::{macros::gpu, Scope, Variable, Vectorization}; +use crate::codegen::dialect::gpu::{macros::cube_inline, Scope, Variable, Vectorization}; use serde::{Deserialize, Serialize}; /// Write to a global array. 
@@ -16,7 +16,7 @@ impl WriteGlobal { let input = self.input; let position = Variable::Id; - gpu!(scope, output[position] = input); + cube_inline!(scope, output[position] = input); } pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index 1bfada540a..02a125233c 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -1,7 +1,7 @@ use crate::JitElement; use super::{ - gpu, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Operation, Operator, + cube_inline, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Operation, Operator, Procedure, ReadGlobal, ReadGlobalWithLayout, UnaryOperator, Variable, Vectorization, WriteGlobal, }; @@ -64,7 +64,7 @@ impl Scope { pub(crate) fn zero>(&mut self, item: I) -> Variable { let local = self.create_local(item); let zero: Variable = 0u32.into(); - gpu!(self, local = zero); + cube_inline!(self, local = zero); local } @@ -76,7 +76,7 @@ impl Scope { ) -> Variable { let local = self.create_local(item); let value = Variable::ConstantScalar(value.to_f64().unwrap(), item.into().elem()); - gpu!(self, local = value); + cube_inline!(self, local = value); local } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs index b726867138..3f4330050a 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs @@ -97,7 +97,7 @@ impl Variable { } } -// Useful with the gpu! macro. +// Useful with the cube_inline macro. 
impl From<&Variable> for Variable { fn from(value: &Variable) -> Self { *value diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index dfb1c92dbe..28ba6ebf20 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitElement, Runtime, @@ -103,7 +103,7 @@ impl CastShader { let output = self.output; let value = scope.create_local(output.item()); - gpu!(scope, value = tensor[id]); - gpu!(scope, output[id] = value); + cube_inline!(scope, value = tensor[id]); + cube_inline!(scope, output[id] = value); } } diff --git a/crates/burn-jit/src/kernel/cast/bool_cast.rs b/crates/burn-jit/src/kernel/cast/bool_cast.rs index a1c3f7b47c..ea7162fceb 100644 --- a/crates/burn-jit/src/kernel/cast/bool_cast.rs +++ b/crates/burn-jit/src/kernel/cast/bool_cast.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Elem, Item, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, Item, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitElement, Runtime, @@ -99,11 +99,11 @@ impl BoolCastShader { let output = self.output; let represents_true = scope.create_local(Elem::Bool); - gpu!(scope, represents_true = tensor[id]); - gpu!(scope, if(represents_true).then(|scope|{ - gpu!(scope, output[id] = 1); + cube_inline!(scope, represents_true = tensor[id]); + cube_inline!(scope, if(represents_true).then(|scope|{ + cube_inline!(scope, output[id] = 1); }).else(|scope|{ - gpu!(scope, output[id] = 0); + cube_inline!(scope, output[id] 
= 0); })); } } diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 607b846ea2..db9efb3cf8 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, Visibility}, tensor::JitTensor, JitElement, Runtime, }; @@ -110,7 +110,7 @@ impl IntoContiguousShader { .expand(scope); let value = scope.create_local(tensor.item()); - gpu!(scope, value = tensor[offset_input]); - gpu!(scope, output[id] = value); + cube_inline!(scope, value = tensor[offset_input]); + cube_inline!(scope, output[id] = value); } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index 6c72c02b31..d201e4e027 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -6,7 +6,7 @@ use std::marker::PhantomData; use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -51,14 +51,14 @@ impl Conv2dComputeShader { let input_shape_1 = scope.create_local(Elem::UInt); let input_shape_2 = scope.create_local(Elem::UInt); let input_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, input_shape_0 = shape(input, 0u32)); - gpu!(scope, input_shape_1 = shape(input, 1u32)); - gpu!(scope, 
input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, input_shape_0 = shape(input, 0u32)); + cube_inline!(scope, input_shape_1 = shape(input, 1u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); let output_stride_0 = scope.create_local(Elem::UInt); let output_stride_1 = scope.create_local(Elem::UInt); @@ -68,14 +68,14 @@ impl Conv2dComputeShader { let output_shape_1 = scope.create_local(Elem::UInt); let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let weight_stride_0 = scope.create_local(Elem::UInt); let weight_stride_1 = scope.create_local(Elem::UInt); @@ -85,14 +85,14 @@ impl Conv2dComputeShader { let in_channels = 
scope.create_local(Elem::UInt); let kernel_size_0 = scope.create_local(Elem::UInt); let kernel_size_1 = scope.create_local(Elem::UInt); - gpu!(scope, weight_stride_0 = stride(weight, 0u32)); - gpu!(scope, weight_stride_1 = stride(weight, 1u32)); - gpu!(scope, weight_stride_2 = stride(weight, 2u32)); - gpu!(scope, weight_stride_3 = stride(weight, 3u32)); - gpu!(scope, weight_shape_0 = shape(weight, 0u32)); - gpu!(scope, in_channels = shape(weight, 1u32)); - gpu!(scope, kernel_size_0 = shape(weight, 2u32)); - gpu!(scope, kernel_size_1 = shape(weight, 3u32)); + cube_inline!(scope, weight_stride_0 = stride(weight, 0u32)); + cube_inline!(scope, weight_stride_1 = stride(weight, 1u32)); + cube_inline!(scope, weight_stride_2 = stride(weight, 2u32)); + cube_inline!(scope, weight_stride_3 = stride(weight, 3u32)); + cube_inline!(scope, weight_shape_0 = shape(weight, 0u32)); + cube_inline!(scope, in_channels = shape(weight, 1u32)); + cube_inline!(scope, kernel_size_0 = shape(weight, 2u32)); + cube_inline!(scope, kernel_size_1 = shape(weight, 3u32)); let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt); let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt); @@ -111,26 +111,26 @@ impl Conv2dComputeShader { let ic_start = scope.create_local(Elem::UInt); let ic_end = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, oc = id / output_stride_1); - gpu!(scope, oc = oc % output_shape_1); + cube_inline!(scope, oc = id / output_stride_1); + cube_inline!(scope, oc = oc % output_shape_1); - gpu!(scope, oh = id / output_stride_2); - gpu!(scope, oh = oh % output_shape_2); + cube_inline!(scope, oh = id / output_stride_2); + cube_inline!(scope, oh = oh % output_shape_2); - gpu!(scope, ow = id / output_stride_3); - gpu!(scope, ow = ow % output_shape_3); + cube_inline!(scope, ow = id / output_stride_3); + 
cube_inline!(scope, ow = ow % output_shape_3); - gpu!(scope, g = weight_shape_0 + oc); - gpu!(scope, g = g % groups); + cube_inline!(scope, g = weight_shape_0 + oc); + cube_inline!(scope, g = g % groups); - gpu!(scope, ic_start = in_channels * g); - gpu!(scope, ic_end = ic_start + in_channels); + cube_inline!(scope, ic_start = in_channels * g); + cube_inline!(scope, ic_end = ic_start + in_channels); let sum = scope.create_local(output.item()); - gpu!(scope, sum = bias[oc]); + cube_inline!(scope, sum = bias[oc]); let ih_base = scope.create_local(Elem::UInt); let iw_base = scope.create_local(Elem::UInt); @@ -163,65 +163,65 @@ impl Conv2dComputeShader { let weight_value = scope.create_local(weight.item()); let value_product = scope.create_local(input.item()); - gpu!(scope, ih_base = oh * conv_stride_0); - gpu!(scope, iw_base = ow * conv_stride_1); + cube_inline!(scope, ih_base = oh * conv_stride_0); + cube_inline!(scope, iw_base = ow * conv_stride_1); - gpu!(scope, border_top = padding_0); - gpu!(scope, border_left = padding_1); - gpu!(scope, border_bottom = input_shape_2 + padding_0); - gpu!(scope, border_right = input_shape_3 + padding_1); + cube_inline!(scope, border_top = padding_0); + cube_inline!(scope, border_left = padding_1); + cube_inline!(scope, border_bottom = input_shape_2 + padding_0); + cube_inline!(scope, border_right = input_shape_3 + padding_1); - gpu!(scope, index_input_0 = b * input_stride_0); - gpu!(scope, index_weight_0 = oc * weight_stride_0); + cube_inline!(scope, index_input_0 = b * input_stride_0); + cube_inline!(scope, index_weight_0 = oc * weight_stride_0); - gpu!( + cube_inline!( scope, range(ic_start, ic_end).for_each(|ic, scope| { - gpu!(scope, index_input_1 = ic * input_stride_1); - gpu!(scope, index_weight_1 = ic - ic_start); - gpu!(scope, index_weight_1 *= weight_stride_1); + cube_inline!(scope, index_input_1 = ic * input_stride_1); + cube_inline!(scope, index_weight_1 = ic - ic_start); + cube_inline!(scope, index_weight_1 *= 
weight_stride_1); - gpu!( + cube_inline!( scope, range(0u32, kernel_size_0).for_each(|kh, scope| { - gpu!( + cube_inline!( scope, range(0u32, kernel_size_1).for_each(|kw, scope| { - gpu!(scope, ih = kh * dilation_0); - gpu!(scope, ih += ih_base); - gpu!(scope, iw = kw * dilation_1); - gpu!(scope, iw += iw_base); - - gpu!(scope, padding_accumulator = ih >= border_top); - gpu!(scope, padding = ih < border_bottom); - gpu!(scope, padding_accumulator = padding_accumulator && padding); - gpu!(scope, padding = iw >= border_left); - gpu!(scope, padding_accumulator = padding_accumulator && padding); - gpu!(scope, padding = iw < border_right); - gpu!(scope, padding_accumulator = padding_accumulator && padding); - - gpu!(scope, if(padding_accumulator).then(|scope|{ - gpu!(scope, ih_pad = ih - padding_0); - gpu!(scope, iw_pad = iw - padding_1); - - gpu!(scope, index_input_2 = ih_pad * input_stride_2); - gpu!(scope, index_input_3 = iw_pad * input_stride_3); - gpu!(scope, index_weight_2 = kh * weight_stride_2); - gpu!(scope, index_weight_3 = kw * weight_stride_3); - - gpu!(scope, index_input = index_input_0); - gpu!(scope, index_input += index_input_1); - gpu!(scope, index_input += index_input_2); - gpu!(scope, index_input += index_input_3); - gpu!(scope, index_weight = index_weight_0); - gpu!(scope, index_weight += index_weight_1); - gpu!(scope, index_weight += index_weight_2); - gpu!(scope, index_weight += index_weight_3); - - gpu!(scope, input_value = input[index_input]); - gpu!(scope, weight_value = weight[index_weight]); - gpu!(scope, value_product = input_value * weight_value); - gpu!(scope, sum += value_product); + cube_inline!(scope, ih = kh * dilation_0); + cube_inline!(scope, ih += ih_base); + cube_inline!(scope, iw = kw * dilation_1); + cube_inline!(scope, iw += iw_base); + + cube_inline!(scope, padding_accumulator = ih >= border_top); + cube_inline!(scope, padding = ih < border_bottom); + cube_inline!(scope, padding_accumulator = padding_accumulator && padding); + 
cube_inline!(scope, padding = iw >= border_left); + cube_inline!(scope, padding_accumulator = padding_accumulator && padding); + cube_inline!(scope, padding = iw < border_right); + cube_inline!(scope, padding_accumulator = padding_accumulator && padding); + + cube_inline!(scope, if(padding_accumulator).then(|scope|{ + cube_inline!(scope, ih_pad = ih - padding_0); + cube_inline!(scope, iw_pad = iw - padding_1); + + cube_inline!(scope, index_input_2 = ih_pad * input_stride_2); + cube_inline!(scope, index_input_3 = iw_pad * input_stride_3); + cube_inline!(scope, index_weight_2 = kh * weight_stride_2); + cube_inline!(scope, index_weight_3 = kw * weight_stride_3); + + cube_inline!(scope, index_input = index_input_0); + cube_inline!(scope, index_input += index_input_1); + cube_inline!(scope, index_input += index_input_2); + cube_inline!(scope, index_input += index_input_3); + cube_inline!(scope, index_weight = index_weight_0); + cube_inline!(scope, index_weight += index_weight_1); + cube_inline!(scope, index_weight += index_weight_2); + cube_inline!(scope, index_weight += index_weight_3); + + cube_inline!(scope, input_value = input[index_input]); + cube_inline!(scope, weight_value = weight[index_weight]); + cube_inline!(scope, value_product = input_value * weight_value); + cube_inline!(scope, sum += value_product); })); }) ); @@ -230,7 +230,7 @@ impl Conv2dComputeShader { }) ); - gpu!(scope, output[id] = sum); + cube_inline!(scope, output[id] = sum); } } diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs index d4896cbc05..dc5add0b74 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs @@ -6,7 +6,7 @@ use crate::{ OutputInfo, WorkgroupLaunch, }, element::JitElement, - gpu::{gpu, ComputeShader, Elem, IntKind, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, IntKind, Scope, Variable, Visibility}, kernel::{self, 
GpuComputeShaderPhase}, ops::{ numeric::{empty_device, zeros_device}, @@ -48,14 +48,14 @@ impl Conv2dTransposeComputeShader { let input_shape_1 = scope.create_local(Elem::UInt); let input_shape_2 = scope.create_local(Elem::UInt); let input_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, input_shape_0 = shape(input, 0u32)); - gpu!(scope, input_shape_1 = shape(input, 1u32)); - gpu!(scope, input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, input_shape_0 = shape(input, 0u32)); + cube_inline!(scope, input_shape_1 = shape(input, 1u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); let output_stride_0 = scope.create_local(Elem::UInt); let output_stride_1 = scope.create_local(Elem::UInt); @@ -65,14 +65,14 @@ impl Conv2dTransposeComputeShader { let output_shape_1 = scope.create_local(Elem::UInt); let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + 
cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let weight_stride_0 = scope.create_local(Elem::UInt); let weight_stride_1 = scope.create_local(Elem::UInt); @@ -82,14 +82,14 @@ impl Conv2dTransposeComputeShader { let weight_shape_1 = scope.create_local(Elem::UInt); let kernel_size_0 = scope.create_local(Elem::UInt); let kernel_size_1 = scope.create_local(Elem::UInt); - gpu!(scope, weight_stride_0 = stride(weight, 0u32)); - gpu!(scope, weight_stride_1 = stride(weight, 1u32)); - gpu!(scope, weight_stride_2 = stride(weight, 2u32)); - gpu!(scope, weight_stride_3 = stride(weight, 3u32)); - gpu!(scope, in_channels = shape(weight, 0u32)); - gpu!(scope, weight_shape_1 = shape(weight, 1u32)); - gpu!(scope, kernel_size_0 = shape(weight, 2u32)); - gpu!(scope, kernel_size_1 = shape(weight, 3u32)); + cube_inline!(scope, weight_stride_0 = stride(weight, 0u32)); + cube_inline!(scope, weight_stride_1 = stride(weight, 1u32)); + cube_inline!(scope, weight_stride_2 = stride(weight, 2u32)); + cube_inline!(scope, weight_stride_3 = stride(weight, 3u32)); + cube_inline!(scope, in_channels = shape(weight, 0u32)); + cube_inline!(scope, weight_shape_1 = shape(weight, 1u32)); + cube_inline!(scope, kernel_size_0 = shape(weight, 2u32)); + cube_inline!(scope, kernel_size_1 = shape(weight, 3u32)); let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt); let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt); @@ -101,8 +101,8 @@ impl Conv2dTransposeComputeShader { let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, stride_0_i = 
cast(conv_stride_0)); - gpu!(scope, stride_1_i = cast(conv_stride_1)); + cube_inline!(scope, stride_0_i = cast(conv_stride_0)); + cube_inline!(scope, stride_1_i = cast(conv_stride_1)); let oc_out = scope.create_local(Elem::UInt); let oc = scope.create_local(Elem::UInt); @@ -117,26 +117,26 @@ impl Conv2dTransposeComputeShader { let ic_end = scope.create_local(Elem::UInt); let ic_tmp = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, oc_out = id / output_stride_1); - gpu!(scope, oc_out = oc_out % output_shape_1); + cube_inline!(scope, oc_out = id / output_stride_1); + cube_inline!(scope, oc_out = oc_out % output_shape_1); - gpu!(scope, oh = id / output_stride_2); - gpu!(scope, oh = oh % output_shape_2); + cube_inline!(scope, oh = id / output_stride_2); + cube_inline!(scope, oh = oh % output_shape_2); - gpu!(scope, ow = id / output_stride_3); - gpu!(scope, ow = ow % output_shape_3); + cube_inline!(scope, ow = id / output_stride_3); + cube_inline!(scope, ow = ow % output_shape_3); - gpu!(scope, k = oc_out / weight_shape_1); - gpu!(scope, g = k % groups); - gpu!(scope, oc = weight_shape_1 * g); - gpu!(scope, oc = oc_out - oc); + cube_inline!(scope, k = oc_out / weight_shape_1); + cube_inline!(scope, g = k % groups); + cube_inline!(scope, oc = weight_shape_1 * g); + cube_inline!(scope, oc = oc_out - oc); - gpu!(scope, ic_tmp = in_channels / groups); - gpu!(scope, ic_start = g * ic_tmp); - gpu!(scope, ic_end = ic_start + ic_tmp); + cube_inline!(scope, ic_tmp = in_channels / groups); + cube_inline!(scope, ic_start = g * ic_tmp); + cube_inline!(scope, ic_end = ic_start + ic_tmp); let tmp_u = scope.create_local(Elem::UInt); let tmp_i = scope.create_local(Elem::Int(IntKind::I32)); @@ -153,37 +153,37 @@ impl Conv2dTransposeComputeShader { let ih_end = scope.create_local(Elem::UInt); let iw_end = 
scope.create_local(Elem::UInt); - gpu!(scope, kms_u = kernel_size_0 * dilation_0); - gpu!(scope, kms_0 = cast(kms_u)); - gpu!(scope, kms_0 = kms_0 - stride_0_i); - gpu!(scope, kms_u = kernel_size_1 * dilation_1); - gpu!(scope, kms_1 = cast(kms_u)); - gpu!(scope, kms_1 = kms_1 - stride_1_i); - - gpu!(scope, tmp_u = oh + padding_0); - gpu!(scope, tmp_i = cast(tmp_u)); - gpu!(scope, ih_start_tmp = tmp_i - kms_0); - gpu!(scope, ih_start_tmp = ih_start_tmp / stride_0_i); - gpu!(scope, tmp_u = ow + padding_1); - gpu!(scope, tmp_i = cast(tmp_u)); - gpu!(scope, iw_start_tmp = tmp_i - kms_1); - gpu!(scope, iw_start_tmp = iw_start_tmp / stride_1_i); - - gpu!(scope, tmp_i = max(ih_start_tmp, zero_i)); - gpu!(scope, ih_start = cast(tmp_i)); - gpu!(scope, tmp_i = kms_0 + ih_start_tmp); - gpu!(scope, tmp_i += one_i); - gpu!(scope, tmp_i = max(tmp_i, zero_i)); - gpu!(scope, tmp_u = cast(tmp_i)); - gpu!(scope, ih_end = min(tmp_u, input_shape_2)); - - gpu!(scope, tmp_i = max(iw_start_tmp, zero_i)); - gpu!(scope, iw_start = cast(tmp_i)); - gpu!(scope, tmp_i = kms_1 + iw_start_tmp); - gpu!(scope, tmp_i += one_i); - gpu!(scope, tmp_i = max(tmp_i, zero_i)); - gpu!(scope, tmp_u = cast(tmp_i)); - gpu!(scope, iw_end = min(tmp_u, input_shape_3)); + cube_inline!(scope, kms_u = kernel_size_0 * dilation_0); + cube_inline!(scope, kms_0 = cast(kms_u)); + cube_inline!(scope, kms_0 = kms_0 - stride_0_i); + cube_inline!(scope, kms_u = kernel_size_1 * dilation_1); + cube_inline!(scope, kms_1 = cast(kms_u)); + cube_inline!(scope, kms_1 = kms_1 - stride_1_i); + + cube_inline!(scope, tmp_u = oh + padding_0); + cube_inline!(scope, tmp_i = cast(tmp_u)); + cube_inline!(scope, ih_start_tmp = tmp_i - kms_0); + cube_inline!(scope, ih_start_tmp = ih_start_tmp / stride_0_i); + cube_inline!(scope, tmp_u = ow + padding_1); + cube_inline!(scope, tmp_i = cast(tmp_u)); + cube_inline!(scope, iw_start_tmp = tmp_i - kms_1); + cube_inline!(scope, iw_start_tmp = iw_start_tmp / stride_1_i); + + cube_inline!(scope, tmp_i 
= max(ih_start_tmp, zero_i)); + cube_inline!(scope, ih_start = cast(tmp_i)); + cube_inline!(scope, tmp_i = kms_0 + ih_start_tmp); + cube_inline!(scope, tmp_i += one_i); + cube_inline!(scope, tmp_i = max(tmp_i, zero_i)); + cube_inline!(scope, tmp_u = cast(tmp_i)); + cube_inline!(scope, ih_end = min(tmp_u, input_shape_2)); + + cube_inline!(scope, tmp_i = max(iw_start_tmp, zero_i)); + cube_inline!(scope, iw_start = cast(tmp_i)); + cube_inline!(scope, tmp_i = kms_1 + iw_start_tmp); + cube_inline!(scope, tmp_i += one_i); + cube_inline!(scope, tmp_i = max(tmp_i, zero_i)); + cube_inline!(scope, tmp_u = cast(tmp_i)); + cube_inline!(scope, iw_end = min(tmp_u, input_shape_3)); let index_input = scope.create_local(Elem::UInt); let index_weight = scope.create_local(Elem::UInt); @@ -197,13 +197,13 @@ impl Conv2dTransposeComputeShader { let index_weight_kh = scope.create_local(Elem::UInt); let index_weight_kw = scope.create_local(Elem::UInt); - gpu!(scope, index_input_b = b * input_stride_0); - gpu!(scope, index_weight_oc = oc * weight_stride_1); + cube_inline!(scope, index_input_b = b * input_stride_0); + cube_inline!(scope, index_weight_oc = oc * weight_stride_1); let prod = scope.create_local(output.item()); let prod_tmp = scope.create_local(output.item()); let sum = scope.create_local(output.item()); - gpu!(scope, sum = bias[oc_out]); + cube_inline!(scope, sum = bias[oc_out]); let kh = scope.create_local(Elem::UInt); let kw = scope.create_local(Elem::UInt); @@ -218,61 +218,61 @@ impl Conv2dTransposeComputeShader { let not_neg = scope.create_local(Elem::Bool); let cond = scope.create_local(Elem::Bool); - gpu!(scope, numerator_h_base = oh + padding_0); - gpu!(scope, numerator_w_base = ow + padding_1); + cube_inline!(scope, numerator_h_base = oh + padding_0); + cube_inline!(scope, numerator_w_base = ow + padding_1); - gpu!( + cube_inline!( scope, range(ic_start, ic_end).for_each(|ic, scope| { - gpu!(scope, index_input_ic = ic * input_stride_1); - gpu!(scope, index_weight_ic = 
ic * weight_stride_0); + cube_inline!(scope, index_input_ic = ic * input_stride_1); + cube_inline!(scope, index_weight_ic = ic * weight_stride_0); - gpu!( + cube_inline!( scope, range(ih_start, ih_end).for_each(|ih, scope| { - gpu!(scope, numerator_tmp = ih * conv_stride_0); - gpu!(scope, not_neg = numerator_h_base >= numerator_tmp); - gpu!(scope, numerator_h = numerator_h_base - numerator_tmp); + cube_inline!(scope, numerator_tmp = ih * conv_stride_0); + cube_inline!(scope, not_neg = numerator_h_base >= numerator_tmp); + cube_inline!(scope, numerator_h = numerator_h_base - numerator_tmp); - gpu!(scope, numerator_mod = numerator_h % dilation_0); - gpu!(scope, divisible = numerator_mod == zero); - gpu!(scope, cond = not_neg && divisible); + cube_inline!(scope, numerator_mod = numerator_h % dilation_0); + cube_inline!(scope, divisible = numerator_mod == zero); + cube_inline!(scope, cond = not_neg && divisible); - gpu!(scope, if(cond).then(|scope|{ - gpu!(scope, kh = numerator_h / dilation_0); - gpu!(scope, index_input_ih = ih * input_stride_2); - gpu!(scope, index_weight_kh = kh * weight_stride_2); + cube_inline!(scope, if(cond).then(|scope|{ + cube_inline!(scope, kh = numerator_h / dilation_0); + cube_inline!(scope, index_input_ih = ih * input_stride_2); + cube_inline!(scope, index_weight_kh = kh * weight_stride_2); - gpu!( + cube_inline!( scope, range(iw_start, iw_end).for_each(|iw, scope| { - gpu!(scope, numerator_tmp = iw * conv_stride_1); - gpu!(scope, not_neg = numerator_w_base >= numerator_tmp); - gpu!(scope, numerator_w = numerator_w_base - numerator_tmp); - - gpu!(scope, numerator_mod = numerator_w % dilation_1); - gpu!(scope, divisible = numerator_mod == zero); - gpu!(scope, cond = not_neg && divisible); - - gpu!(scope, if(cond).then(|scope|{ - gpu!(scope, kw = numerator_w / dilation_1); - gpu!(scope, index_input_iw = iw * input_stride_3); - gpu!(scope, index_weight_kw = kw * weight_stride_3); - - gpu!(scope, index_input = index_input_b); - gpu!(scope, 
index_input += index_input_ic); - gpu!(scope, index_input += index_input_ih); - gpu!(scope, index_input += index_input_iw); - - gpu!(scope, index_weight = index_weight_ic); - gpu!(scope, index_weight += index_weight_oc); - gpu!(scope, index_weight += index_weight_kh); - gpu!(scope, index_weight += index_weight_kw); - - gpu!(scope, prod = input[index_input]); - gpu!(scope, prod_tmp = weight[index_weight]); - gpu!(scope, prod *= prod_tmp); - gpu!(scope, sum += prod); + cube_inline!(scope, numerator_tmp = iw * conv_stride_1); + cube_inline!(scope, not_neg = numerator_w_base >= numerator_tmp); + cube_inline!(scope, numerator_w = numerator_w_base - numerator_tmp); + + cube_inline!(scope, numerator_mod = numerator_w % dilation_1); + cube_inline!(scope, divisible = numerator_mod == zero); + cube_inline!(scope, cond = not_neg && divisible); + + cube_inline!(scope, if(cond).then(|scope|{ + cube_inline!(scope, kw = numerator_w / dilation_1); + cube_inline!(scope, index_input_iw = iw * input_stride_3); + cube_inline!(scope, index_weight_kw = kw * weight_stride_3); + + cube_inline!(scope, index_input = index_input_b); + cube_inline!(scope, index_input += index_input_ic); + cube_inline!(scope, index_input += index_input_ih); + cube_inline!(scope, index_input += index_input_iw); + + cube_inline!(scope, index_weight = index_weight_ic); + cube_inline!(scope, index_weight += index_weight_oc); + cube_inline!(scope, index_weight += index_weight_kh); + cube_inline!(scope, index_weight += index_weight_kw); + + cube_inline!(scope, prod = input[index_input]); + cube_inline!(scope, prod_tmp = weight[index_weight]); + cube_inline!(scope, prod *= prod_tmp); + cube_inline!(scope, sum += prod); })); }) ); @@ -283,7 +283,7 @@ impl Conv2dTransposeComputeShader { }) ); - gpu!(scope, output[id] = sum); + cube_inline!(scope, output[id] = sum); } } diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 8dd71a6f6c..9e19cd96f8 100644 --- 
a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -42,29 +42,29 @@ impl FlipComputeShader { let flip_bool = scope.create_local(Elem::Bool); for i in 0..self.rank { - gpu!(scope, stride = stride(input, i)); - gpu!(scope, shape = shape(output, i)); - gpu!( + cube_inline!(scope, stride = stride(input, i)); + cube_inline!(scope, shape = shape(output, i)); + cube_inline!( scope, flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) ); - gpu!(scope, flip_bool = flip == 1u32); + cube_inline!(scope, flip_bool = flip == 1u32); - gpu!(scope, offset_local = id / stride); - gpu!(scope, offset_local = offset_local % shape); + cube_inline!(scope, offset_local = id / stride); + cube_inline!(scope, offset_local = offset_local % shape); - gpu!(scope, if(flip_bool).then(|scope| { - gpu!(scope, offset_local = shape - offset_local); - gpu!(scope, offset_local = offset_local - 1u32); + cube_inline!(scope, if(flip_bool).then(|scope| { + cube_inline!(scope, offset_local = shape - offset_local); + cube_inline!(scope, offset_local = offset_local - 1u32); })); - gpu!(scope, offset_local = offset_local * stride); + cube_inline!(scope, offset_local = offset_local * stride); - gpu!(scope, offset_input += offset_local); + cube_inline!(scope, offset_input += offset_local); } let result = scope.create_local(input.item()); - gpu!(scope, result = input[offset_input]); - gpu!(scope, output[id] = result); + cube_inline!(scope, result = input[offset_input]); + cube_inline!(scope, output[id] = result); } } diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 3622a70fcd..8bba9d0abb 100644 --- 
a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -1,4 +1,4 @@ -use crate::codegen::dialect::gpu::{gpu, Elem, Scope, Variable}; +use crate::codegen::dialect::gpu::{cube_inline, Elem, Scope, Variable}; use crate::codegen::Execution; use crate::gpu::{ComputeShader, IntKind}; use crate::{ @@ -43,9 +43,9 @@ impl GatherComputeShader { let offset = scope.create_local(Elem::UInt); // The offset of the `dim` dimension is obtained by the indices tensor. - gpu!(scope, offset = cast(self.indices)); - gpu!(scope, stride = stride(tensor, self.dim)); - gpu!(scope, offset = offset * stride); + cube_inline!(scope, offset = cast(self.indices)); + cube_inline!(scope, stride = stride(tensor, self.dim)); + cube_inline!(scope, offset = offset * stride); // We fetch the offset before the `dim` dimension. if self.dim > 0 { @@ -58,7 +58,7 @@ impl GatherComputeShader { dim_start: 0u32.into(), dim_end: self.dim.into(), }); - gpu!(scope, offset += offset_before); + cube_inline!(scope, offset += offset_before); } let offset_after = scope.create_local(Elem::UInt); @@ -70,9 +70,9 @@ impl GatherComputeShader { dim_start: (self.dim + 1).into(), dim_end: Variable::Rank, }); - gpu!(scope, offset += offset_after); + cube_inline!(scope, offset += offset_after); - gpu!(scope, output = tensor[offset]); + cube_inline!(scope, output = tensor[offset]); } } diff --git a/crates/burn-jit/src/kernel/index/repeat.rs b/crates/burn-jit/src/kernel/index/repeat.rs index 3c4dc0a52d..443f4f7bd0 100644 --- a/crates/burn-jit/src/kernel/index/repeat.rs +++ b/crates/burn-jit/src/kernel/index/repeat.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -42,20 +42,20 @@ impl RepeatComputeShader { for i in 0..self.rank { if i != self.dim { - 
gpu!(scope, stride_input = stride(input, i)); - gpu!(scope, stride_output = stride(output, i)); - gpu!(scope, shape_output = shape(output, i)); - - gpu!(scope, offset_local = id / stride_output); - gpu!(scope, offset_local = offset_local % shape_output); - gpu!(scope, offset_local = offset_local * stride_input); - gpu!(scope, offset_input += offset_local); + cube_inline!(scope, stride_input = stride(input, i)); + cube_inline!(scope, stride_output = stride(output, i)); + cube_inline!(scope, shape_output = shape(output, i)); + + cube_inline!(scope, offset_local = id / stride_output); + cube_inline!(scope, offset_local = offset_local % shape_output); + cube_inline!(scope, offset_local = offset_local * stride_input); + cube_inline!(scope, offset_input += offset_local); } } let result = scope.create_local(input.item()); - gpu!(scope, result = input[offset_input]); - gpu!(scope, output[id] = result); + cube_inline!(scope, result = input[offset_input]); + cube_inline!(scope, output[id] = result); } } impl GpuComputeShaderPhase for RepeatEagerKernel { diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 1312f1f393..7422aa075d 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -1,4 +1,4 @@ -use crate::codegen::dialect::gpu::{gpu, Branch, Elem, Scope, Variable}; +use crate::codegen::dialect::gpu::{cube_inline, Branch, Elem, Scope, Variable}; use crate::codegen::Execution; use crate::gpu::ComputeShader; use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT}; @@ -48,22 +48,22 @@ impl ScatterComputeShader { let stride_input = scope.create_local(Elem::UInt); let shape_value = scope.create_local(Elem::UInt); - gpu!(scope, stride_input = stride(input, self.dim)); - gpu!(scope, shape_value = shape(value, self.dim)); + cube_inline!(scope, stride_input = stride(input, self.dim)); + cube_inline!(scope, shape_value = shape(value, self.dim)); let id = Variable::Id; let 
offset_input = scope.zero(Elem::UInt); let offset_value = scope.zero(Elem::UInt); let num_elems = scope.create_local(Elem::UInt); - gpu!(scope, num_elems = cast(1usize)); - gpu!( + cube_inline!(scope, num_elems = cast(1usize)); + cube_inline!( scope, range(0u32, Variable::Rank).for_each(|i, scope| { let should_skip = scope.create_local(Elem::Bool); - gpu!(scope, should_skip = i == self.dim); + cube_inline!(scope, should_skip = i == self.dim); - gpu!(scope, if(should_skip).then(|_| { + cube_inline!(scope, if(should_skip).then(|_| { // Nothing to do. }).else(|scope| { let shape_input_loop = scope.create_local(Elem::UInt); @@ -75,30 +75,30 @@ impl ScatterComputeShader { let offset_tmp = scope.create_local(Elem::UInt); let stride_input_loop = scope.create_local(Elem::UInt); - gpu!(scope, stride_value_loop = stride(value, i)); - gpu!(scope, stride_input_loop = stride(input, i)); - gpu!(scope, stride_tmp = stride(indices, i)); + cube_inline!(scope, stride_value_loop = stride(value, i)); + cube_inline!(scope, stride_input_loop = stride(input, i)); + cube_inline!(scope, stride_tmp = stride(indices, i)); - gpu!(scope, shape_value_loop = shape(value, i)); - gpu!(scope, shape_input_loop = shape(input, i)); + cube_inline!(scope, shape_value_loop = shape(value, i)); + cube_inline!(scope, shape_input_loop = shape(input, i)); - gpu!(scope, num_blocks = id / stride_tmp); - gpu!(scope, num_blocks = num_blocks % shape_input_loop); + cube_inline!(scope, num_blocks = id / stride_tmp); + cube_inline!(scope, num_blocks = num_blocks % shape_input_loop); - gpu!(scope, offset_tmp = num_blocks * stride_input_loop); - gpu!(scope, offset_input += offset_tmp); + cube_inline!(scope, offset_tmp = num_blocks * stride_input_loop); + cube_inline!(scope, offset_input += offset_tmp); - gpu!(scope, offset_tmp = num_blocks * stride_value_loop); - gpu!(scope, offset_value += offset_tmp); + cube_inline!(scope, offset_tmp = num_blocks * stride_value_loop); + cube_inline!(scope, offset_value += 
offset_tmp); - gpu!(scope, num_elems = num_elems * shape_value_loop); + cube_inline!(scope, num_elems = num_elems * shape_value_loop); })); }) ); let should_stop = scope.create_local(Elem::Bool); - gpu!(scope, should_stop = id >= num_elems); - gpu!(scope, if (should_stop).then(|scope|{ + cube_inline!(scope, should_stop = id >= num_elems); + cube_inline!(scope, if (should_stop).then(|scope|{ scope.register(Branch::Return); })); @@ -109,21 +109,21 @@ impl ScatterComputeShader { let result_value = scope.create_local(value.item()); let result_indices = scope.create_local(Elem::UInt); - gpu!( + cube_inline!( scope, range(0u32, shape_value).for_each(|i, scope| { - gpu!(scope, index = stride_input * i); - gpu!(scope, index += offset_value); + cube_inline!(scope, index = stride_input * i); + cube_inline!(scope, index += offset_value); - gpu!(scope, result_value = value[index]); - gpu!(scope, result_indices = indices[index]); + cube_inline!(scope, result_value = value[index]); + cube_inline!(scope, result_indices = indices[index]); - gpu!(scope, index_input = stride_input * result_indices); - gpu!(scope, index_input += offset_input); + cube_inline!(scope, index_input = stride_input * result_indices); + cube_inline!(scope, index_input += offset_input); - gpu!(scope, result_input = input[index_input]); - gpu!(scope, result_input += result_value); - gpu!(scope, input[index_input] = result_input); + cube_inline!(scope, result_input = input[index_input]); + cube_inline!(scope, result_input += result_value); + cube_inline!(scope, input[index_input] = result_input); }) ); } diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 403fe8b49d..4966e766d1 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, IntKind, Item, Scope, 
Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -35,38 +35,38 @@ impl SelectComputeShader { let id = Variable::Id; let offset_input = scope.zero(Elem::UInt); - gpu!( + cube_inline!( scope, range(0u32, Variable::Rank).for_each(|i, scope| { let stride_input = scope.create_local(Elem::UInt); let stride_output = scope.create_local(Elem::UInt); let shape_output = scope.create_local(Elem::UInt); - gpu!(scope, stride_input = stride(input, i)); - gpu!(scope, stride_output = stride(output, i)); - gpu!(scope, shape_output = shape(output, i)); + cube_inline!(scope, stride_input = stride(input, i)); + cube_inline!(scope, stride_output = stride(output, i)); + cube_inline!(scope, shape_output = shape(output, i)); let offset_local = scope.create_local(Elem::UInt); - gpu!(scope, offset_local = id / stride_output); - gpu!(scope, offset_local = offset_local % shape_output); + cube_inline!(scope, offset_local = id / stride_output); + cube_inline!(scope, offset_local = offset_local % shape_output); let dim_index = scope.create_local(Elem::Bool); - gpu!(scope, dim_index = i == self.dim); + cube_inline!(scope, dim_index = i == self.dim); - gpu!(scope, if(dim_index).then(|scope| { - gpu!(scope, offset_local = indices[offset_local]); - gpu!(scope, offset_local = offset_local * stride_input); + cube_inline!(scope, if(dim_index).then(|scope| { + cube_inline!(scope, offset_local = indices[offset_local]); + cube_inline!(scope, offset_local = offset_local * stride_input); }).else(|scope| { - gpu!(scope, offset_local = offset_local * stride_input); + cube_inline!(scope, offset_local = offset_local * stride_input); })); - gpu!(scope, offset_input += offset_local); + cube_inline!(scope, offset_input += offset_local); }) ); let value = scope.create_local(input.item()); - gpu!(scope, value = input[offset_input]); - gpu!(scope, output[id] = value); + cube_inline!(scope, value = input[offset_input]); + 
cube_inline!(scope, output[id] = value); } } diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index ca6be7ae34..0fc5e8d58b 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Branch, Elem, IntKind, Item, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Branch, Elem, IntKind, Item, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, WorkgroupLaunch, }, @@ -41,65 +41,65 @@ impl SelectAssignComputeShader { let shape_value_dim = scope.create_local(Elem::UInt); let num_elems = scope.create_local(Elem::UInt); - gpu!(scope, num_elems = cast(1u32)); + cube_inline!(scope, num_elems = cast(1u32)); - gpu!( + cube_inline!( scope, range(0u32, Variable::Rank).for_each(|i, scope| { let shape_value = scope.create_local(Elem::UInt); let stride_tensor = scope.create_local(Elem::UInt); let stride_value = scope.create_local(Elem::UInt); - gpu!(scope, stride_tensor = stride(tensor, i)); - gpu!(scope, stride_value = stride(value, i)); - gpu!(scope, shape_value = shape(value, i)); + cube_inline!(scope, stride_tensor = stride(tensor, i)); + cube_inline!(scope, stride_value = stride(value, i)); + cube_inline!(scope, shape_value = shape(value, i)); let dim_index = scope.create_local(Elem::Bool); - gpu!(scope, dim_index = i == self.dim); + cube_inline!(scope, dim_index = i == self.dim); - gpu!(scope, if(dim_index).then(|scope| { - gpu!(scope, shape_value_dim = shape_value); - gpu!(scope, stride_tensor_dim = stride_tensor); - gpu!(scope, stride_value_dim = stride_value); + cube_inline!(scope, if(dim_index).then(|scope| { + cube_inline!(scope, shape_value_dim = shape_value); + cube_inline!(scope, stride_tensor_dim = stride_tensor); + cube_inline!(scope, stride_value_dim = stride_value); }).else(|scope| { let stride_tmp = 
scope.create_local(Elem::UInt); let shape_tensor = scope.create_local(Elem::UInt); - gpu!(scope, stride_tmp = stride(indices, i)); - gpu!(scope, shape_tensor = shape(tensor, i)); + cube_inline!(scope, stride_tmp = stride(indices, i)); + cube_inline!(scope, shape_tensor = shape(tensor, i)); - gpu!(scope, num_elems = num_elems * shape_tensor); + cube_inline!(scope, num_elems = num_elems * shape_tensor); let offset_local = scope.create_local(Elem::UInt); let offset_local_tensor = scope.create_local(Elem::UInt); let offset_local_value = scope.create_local(Elem::UInt); - gpu!(scope, offset_local = id / stride_tmp); + cube_inline!(scope, offset_local = id / stride_tmp); - gpu!(scope, offset_local_tensor = offset_local % shape_tensor); - gpu!( + cube_inline!(scope, offset_local_tensor = offset_local % shape_tensor); + cube_inline!( scope, offset_local_tensor = offset_local_tensor * stride_tensor ); - gpu!(scope, offset_tensor += offset_local_tensor); + cube_inline!(scope, offset_tensor += offset_local_tensor); - gpu!(scope, offset_local_value = offset_local % shape_value); - gpu!( + cube_inline!(scope, offset_local_value = offset_local % shape_value); + cube_inline!( scope, offset_local_value = offset_local_value * stride_value ); - gpu!(scope, offset_value += offset_local_value); + cube_inline!(scope, offset_value += offset_local_value); })); }) ); let should_stop = scope.create_local(Elem::Bool); - gpu!(scope, should_stop = id >= num_elems); - gpu!(scope, if(should_stop).then(|scope| { + cube_inline!(scope, should_stop = id >= num_elems); + cube_inline!(scope, if(should_stop).then(|scope| { scope.register(Branch::Return); })); - gpu!( + cube_inline!( scope, range(0u32, shape_value_dim).for_each(|i, scope| { let index = scope.create_local(Elem::UInt); @@ -110,19 +110,19 @@ impl SelectAssignComputeShader { let result_value = scope.create_local(value.item()); let result = scope.create_local(tensor.item()); - gpu!(scope, index = indices[i]); + cube_inline!(scope, index = 
indices[i]); - gpu!(scope, index_tensor = index * stride_tensor_dim); - gpu!(scope, index_tensor += offset_tensor); + cube_inline!(scope, index_tensor = index * stride_tensor_dim); + cube_inline!(scope, index_tensor += offset_tensor); - gpu!(scope, index_value = i * stride_value_dim); - gpu!(scope, index_value += offset_value); + cube_inline!(scope, index_value = i * stride_value_dim); + cube_inline!(scope, index_value += offset_value); - gpu!(scope, result_tensor = tensor[index_tensor]); - gpu!(scope, result_value = value[index_value]); - gpu!(scope, result = result_value + result_tensor); + cube_inline!(scope, result_tensor = tensor[index_tensor]); + cube_inline!(scope, result_value = value[index_value]); + cube_inline!(scope, result = result_value + result_tensor); - gpu!(scope, tensor[index_tensor] = result); + cube_inline!(scope, tensor[index_tensor] = result); }) ); } diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index a9023e30ac..2b33ec22b6 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -42,25 +42,25 @@ impl SliceComputeShader { let range_start = scope.create_local(Elem::UInt); for i in 0..self.rank { - gpu!(scope, stride_input = stride(input, i)); - gpu!(scope, stride_output = stride(output, i)); - gpu!(scope, shape_output = shape(output, i)); - gpu!( + cube_inline!(scope, stride_input = stride(input, i)); + cube_inline!(scope, stride_output = stride(output, i)); + cube_inline!(scope, shape_output = shape(output, i)); + cube_inline!( scope, range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) ); - gpu!(scope, offset_local = id / stride_output); - gpu!(scope, 
offset_local = offset_local % shape_output); - gpu!(scope, offset_local = offset_local + range_start); - gpu!(scope, offset_local = offset_local * stride_input); + cube_inline!(scope, offset_local = id / stride_output); + cube_inline!(scope, offset_local = offset_local % shape_output); + cube_inline!(scope, offset_local = offset_local + range_start); + cube_inline!(scope, offset_local = offset_local * stride_input); - gpu!(scope, offset_input += offset_local); + cube_inline!(scope, offset_input += offset_local); } let result = scope.create_local(input.item()); - gpu!(scope, result = input[offset_input]); - gpu!(scope, output[id] = result); + cube_inline!(scope, result = input[offset_input]); + cube_inline!(scope, output[id] = result); } } diff --git a/crates/burn-jit/src/kernel/index/slice_assign.rs b/crates/burn-jit/src/kernel/index/slice_assign.rs index d316b05aa0..2263715fd9 100644 --- a/crates/burn-jit/src/kernel/index/slice_assign.rs +++ b/crates/burn-jit/src/kernel/index/slice_assign.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, WorkgroupLaunch, }, @@ -46,32 +46,32 @@ impl SliceAssignComputeShader { let range_start = scope.create_local(Elem::UInt); for i in 0..self.rank { - gpu!(scope, stride_input = stride(input, i)); - gpu!(scope, stride_value = stride(value, i)); - gpu!(scope, shape_value = shape(value, i)); - gpu!(scope, shape_input = shape(input, i)); - gpu!( + cube_inline!(scope, stride_input = stride(input, i)); + cube_inline!(scope, stride_value = stride(value, i)); + cube_inline!(scope, shape_value = shape(value, i)); + cube_inline!(scope, shape_input = shape(input, i)); + cube_inline!( scope, range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) ); - gpu!(scope, offset_local = id / stride_value); - gpu!(scope, offset_local = 
offset_local % shape_value); + cube_inline!(scope, offset_local = id / stride_value); + cube_inline!(scope, offset_local = offset_local % shape_value); - gpu!(scope, offset_local_value = offset_local * stride_value); - gpu!(scope, offset_local_input = offset_local + range_start); - gpu!( + cube_inline!(scope, offset_local_value = offset_local * stride_value); + cube_inline!(scope, offset_local_input = offset_local + range_start); + cube_inline!( scope, offset_local_input = offset_local_input * stride_input ); - gpu!(scope, offset_value += offset_local_value); - gpu!(scope, offset_input += offset_local_input); + cube_inline!(scope, offset_value += offset_local_value); + cube_inline!(scope, offset_input += offset_local_input); } let result = scope.create_local(input.item()); - gpu!(scope, result = value[offset_value]); - gpu!(scope, input[offset_input] = result); + cube_inline!(scope, result = value[offset_value]); + cube_inline!(scope, input[offset_input] = result); } } diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index d00fbd1ab0..dd6dc981ad 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitElement, Runtime, @@ -48,40 +48,40 @@ impl InterpolateBicubicShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, 
input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let h = scope.create_local(Elem::UInt); let w = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, h = id / output_stride_2); - gpu!(scope, h = h % output_shape_2); + cube_inline!(scope, h = id / 
output_stride_2); + cube_inline!(scope, h = h % output_shape_2); - gpu!(scope, w = id / output_stride_3); - gpu!(scope, w = w % output_shape_3); + cube_inline!(scope, w = id / output_stride_3); + cube_inline!(scope, w = w % output_shape_3); let input_height = scope.create_local(Elem::UInt); let output_height = scope.create_local(Elem::UInt); @@ -101,28 +101,28 @@ impl InterpolateBicubicShader { let yw = scope.create_local(elem); let y_tmp = scope.create_local(Elem::UInt); - gpu!(scope, input_height = input_shape_2 - 1u32); - gpu!(scope, output_height = output_shape_2 - 1u32); - gpu!(scope, numerator = h * input_height); - gpu!(scope, numerator_float = cast(numerator)); - gpu!(scope, output_height_float = cast(output_height)); - gpu!(scope, frac = numerator_float / output_height_float); - gpu!(scope, y_in_float = floor(frac)); - gpu!(scope, y_in = cast(y_in_float)); - gpu!(scope, yw = frac - y_in_float); + cube_inline!(scope, input_height = input_shape_2 - 1u32); + cube_inline!(scope, output_height = output_shape_2 - 1u32); + cube_inline!(scope, numerator = h * input_height); + cube_inline!(scope, numerator_float = cast(numerator)); + cube_inline!(scope, output_height_float = cast(output_height)); + cube_inline!(scope, frac = numerator_float / output_height_float); + cube_inline!(scope, y_in_float = floor(frac)); + cube_inline!(scope, y_in = cast(y_in_float)); + cube_inline!(scope, yw = frac - y_in_float); let y0 = scope.zero(Elem::UInt); - gpu!(scope, not_zero = y_in != 0u32); - gpu!(scope, if(not_zero).then(|scope|{ - gpu!(scope, y0 = y_in - 1u32); + cube_inline!(scope, not_zero = y_in != 0u32); + cube_inline!(scope, if(not_zero).then(|scope|{ + cube_inline!(scope, y0 = y_in - 1u32); })); let y1 = y_in; - gpu!(scope, y_tmp = y_in + 1u32); + cube_inline!(scope, y_tmp = y_in + 1u32); let y2 = Self::min(scope, y_tmp, input_height); - gpu!(scope, y_tmp = y_in + 2u32); + cube_inline!(scope, y_tmp = y_in + 2u32); let y3 = Self::min(scope, y_tmp, input_height); let 
x_in_float = scope.create_local(elem); @@ -130,36 +130,36 @@ impl InterpolateBicubicShader { let xw = scope.create_local(elem); let x_tmp = scope.create_local(Elem::UInt); - gpu!(scope, input_width = input_shape_3 - 1u32); - gpu!(scope, output_width = output_shape_3 - 1u32); - gpu!(scope, numerator = w * input_width); - gpu!(scope, numerator_float = cast(numerator)); - gpu!(scope, output_width_float = cast(output_width)); - gpu!(scope, frac = numerator_float / output_width_float); - gpu!(scope, x_in_float = floor(frac)); - gpu!(scope, x_in = cast(x_in_float)); - gpu!(scope, xw = frac - x_in_float); + cube_inline!(scope, input_width = input_shape_3 - 1u32); + cube_inline!(scope, output_width = output_shape_3 - 1u32); + cube_inline!(scope, numerator = w * input_width); + cube_inline!(scope, numerator_float = cast(numerator)); + cube_inline!(scope, output_width_float = cast(output_width)); + cube_inline!(scope, frac = numerator_float / output_width_float); + cube_inline!(scope, x_in_float = floor(frac)); + cube_inline!(scope, x_in = cast(x_in_float)); + cube_inline!(scope, xw = frac - x_in_float); let x0 = scope.zero(Elem::UInt); - gpu!(scope, not_zero = x_in != 0u32); - gpu!(scope, if(not_zero).then(|scope|{ - gpu!(scope, x0 = x_in - 1u32); + cube_inline!(scope, not_zero = x_in != 0u32); + cube_inline!(scope, if(not_zero).then(|scope|{ + cube_inline!(scope, x0 = x_in - 1u32); })); - gpu!(scope, x_tmp = x_in - 1u32); + cube_inline!(scope, x_tmp = x_in - 1u32); let x1 = x_in; - gpu!(scope, x_tmp = x_in + 1u32); + cube_inline!(scope, x_tmp = x_in + 1u32); let x2 = Self::min(scope, x_tmp, input_width); - gpu!(scope, x_tmp = x_in + 2u32); + cube_inline!(scope, x_tmp = x_in + 2u32); let x3 = Self::min(scope, x_tmp, input_width); let index_base = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt); - gpu!(scope, index_base = b * input_stride_0); - gpu!(scope, index_tmp = c * input_stride_1); - gpu!(scope, index_base += index_tmp); + 
cube_inline!(scope, index_base = b * input_stride_0); + cube_inline!(scope, index_tmp = c * input_stride_1); + cube_inline!(scope, index_base += index_tmp); let y0_stride = scope.create_local(Elem::UInt); let y1_stride = scope.create_local(Elem::UInt); @@ -169,14 +169,14 @@ impl InterpolateBicubicShader { let x1_stride = scope.create_local(Elem::UInt); let x2_stride = scope.create_local(Elem::UInt); let x3_stride = scope.create_local(Elem::UInt); - gpu!(scope, y0_stride = y0 * input_stride_2); - gpu!(scope, y1_stride = y1 * input_stride_2); - gpu!(scope, y2_stride = y2 * input_stride_2); - gpu!(scope, y3_stride = y3 * input_stride_2); - gpu!(scope, x0_stride = x0 * input_stride_3); - gpu!(scope, x1_stride = x1 * input_stride_3); - gpu!(scope, x2_stride = x2 * input_stride_3); - gpu!(scope, x3_stride = x3 * input_stride_3); + cube_inline!(scope, y0_stride = y0 * input_stride_2); + cube_inline!(scope, y1_stride = y1 * input_stride_2); + cube_inline!(scope, y2_stride = y2 * input_stride_2); + cube_inline!(scope, y3_stride = y3 * input_stride_2); + cube_inline!(scope, x0_stride = x0 * input_stride_3); + cube_inline!(scope, x1_stride = x1 * input_stride_3); + cube_inline!(scope, x2_stride = x2 * input_stride_3); + cube_inline!(scope, x3_stride = x3 * input_stride_3); let index_0 = scope.create_local(Elem::UInt); let index_1 = scope.create_local(Elem::UInt); @@ -187,79 +187,79 @@ impl InterpolateBicubicShader { let inp_2 = scope.create_local(input.item()); let inp_3 = scope.create_local(input.item()); - gpu!(scope, index_0 = index_base); - gpu!(scope, index_0 += y0_stride); - gpu!(scope, index_0 += x0_stride); - gpu!(scope, inp_0 = input[index_0]); - gpu!(scope, index_1 = index_base); - gpu!(scope, index_1 += y0_stride); - gpu!(scope, index_1 += x1_stride); - gpu!(scope, inp_1 = input[index_1]); - gpu!(scope, index_2 = index_base); - gpu!(scope, index_2 += y0_stride); - gpu!(scope, index_2 += x2_stride); - gpu!(scope, inp_2 = input[index_2]); - gpu!(scope, index_3 = 
index_base); - gpu!(scope, index_3 += y0_stride); - gpu!(scope, index_3 += x3_stride); - gpu!(scope, inp_3 = input[index_3]); + cube_inline!(scope, index_0 = index_base); + cube_inline!(scope, index_0 += y0_stride); + cube_inline!(scope, index_0 += x0_stride); + cube_inline!(scope, inp_0 = input[index_0]); + cube_inline!(scope, index_1 = index_base); + cube_inline!(scope, index_1 += y0_stride); + cube_inline!(scope, index_1 += x1_stride); + cube_inline!(scope, inp_1 = input[index_1]); + cube_inline!(scope, index_2 = index_base); + cube_inline!(scope, index_2 += y0_stride); + cube_inline!(scope, index_2 += x2_stride); + cube_inline!(scope, inp_2 = input[index_2]); + cube_inline!(scope, index_3 = index_base); + cube_inline!(scope, index_3 += y0_stride); + cube_inline!(scope, index_3 += x3_stride); + cube_inline!(scope, inp_3 = input[index_3]); let coefficients0 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); - gpu!(scope, index_0 = index_base); - gpu!(scope, index_0 += y1_stride); - gpu!(scope, index_0 += x0_stride); - gpu!(scope, inp_0 = input[index_0]); - gpu!(scope, index_1 = index_base); - gpu!(scope, index_1 += y1_stride); - gpu!(scope, index_1 += x1_stride); - gpu!(scope, inp_1 = input[index_1]); - gpu!(scope, index_2 = index_base); - gpu!(scope, index_2 += y1_stride); - gpu!(scope, index_2 += x2_stride); - gpu!(scope, inp_2 = input[index_2]); - gpu!(scope, index_3 = index_base); - gpu!(scope, index_3 += y1_stride); - gpu!(scope, index_3 += x3_stride); - gpu!(scope, inp_3 = input[index_3]); + cube_inline!(scope, index_0 = index_base); + cube_inline!(scope, index_0 += y1_stride); + cube_inline!(scope, index_0 += x0_stride); + cube_inline!(scope, inp_0 = input[index_0]); + cube_inline!(scope, index_1 = index_base); + cube_inline!(scope, index_1 += y1_stride); + cube_inline!(scope, index_1 += x1_stride); + cube_inline!(scope, inp_1 = input[index_1]); + cube_inline!(scope, index_2 = index_base); + cube_inline!(scope, index_2 += y1_stride); + 
cube_inline!(scope, index_2 += x2_stride); + cube_inline!(scope, inp_2 = input[index_2]); + cube_inline!(scope, index_3 = index_base); + cube_inline!(scope, index_3 += y1_stride); + cube_inline!(scope, index_3 += x3_stride); + cube_inline!(scope, inp_3 = input[index_3]); let coefficients1 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); - gpu!(scope, index_0 = index_base); - gpu!(scope, index_0 += y2_stride); - gpu!(scope, index_0 += x0_stride); - gpu!(scope, inp_0 = input[index_0]); - gpu!(scope, index_1 = index_base); - gpu!(scope, index_1 += y2_stride); - gpu!(scope, index_1 += x1_stride); - gpu!(scope, inp_1 = input[index_1]); - gpu!(scope, index_2 = index_base); - gpu!(scope, index_2 += y2_stride); - gpu!(scope, index_2 += x2_stride); - gpu!(scope, inp_2 = input[index_2]); - gpu!(scope, index_3 = index_base); - gpu!(scope, index_3 += y2_stride); - gpu!(scope, index_3 += x3_stride); - gpu!(scope, inp_3 = input[index_3]); + cube_inline!(scope, index_0 = index_base); + cube_inline!(scope, index_0 += y2_stride); + cube_inline!(scope, index_0 += x0_stride); + cube_inline!(scope, inp_0 = input[index_0]); + cube_inline!(scope, index_1 = index_base); + cube_inline!(scope, index_1 += y2_stride); + cube_inline!(scope, index_1 += x1_stride); + cube_inline!(scope, inp_1 = input[index_1]); + cube_inline!(scope, index_2 = index_base); + cube_inline!(scope, index_2 += y2_stride); + cube_inline!(scope, index_2 += x2_stride); + cube_inline!(scope, inp_2 = input[index_2]); + cube_inline!(scope, index_3 = index_base); + cube_inline!(scope, index_3 += y2_stride); + cube_inline!(scope, index_3 += x3_stride); + cube_inline!(scope, inp_3 = input[index_3]); let coefficients2 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); - gpu!(scope, index_0 = index_base); - gpu!(scope, index_0 += y3_stride); - gpu!(scope, index_0 += x0_stride); - gpu!(scope, inp_0 = input[index_0]); - gpu!(scope, index_1 = index_base); - gpu!(scope, index_1 += y3_stride); - gpu!(scope, 
index_1 += x1_stride); - gpu!(scope, inp_1 = input[index_1]); - gpu!(scope, index_2 = index_base); - gpu!(scope, index_2 += y3_stride); - gpu!(scope, index_2 += x2_stride); - gpu!(scope, inp_2 = input[index_2]); - gpu!(scope, index_3 = index_base); - gpu!(scope, index_3 += y3_stride); - gpu!(scope, index_3 += x3_stride); - gpu!(scope, inp_3 = input[index_3]); + cube_inline!(scope, index_0 = index_base); + cube_inline!(scope, index_0 += y3_stride); + cube_inline!(scope, index_0 += x0_stride); + cube_inline!(scope, inp_0 = input[index_0]); + cube_inline!(scope, index_1 = index_base); + cube_inline!(scope, index_1 += y3_stride); + cube_inline!(scope, index_1 += x1_stride); + cube_inline!(scope, inp_1 = input[index_1]); + cube_inline!(scope, index_2 = index_base); + cube_inline!(scope, index_2 += y3_stride); + cube_inline!(scope, index_2 += x2_stride); + cube_inline!(scope, inp_2 = input[index_2]); + cube_inline!(scope, index_3 = index_base); + cube_inline!(scope, index_3 += y3_stride); + cube_inline!(scope, index_3 += x3_stride); + cube_inline!(scope, inp_3 = input[index_3]); let coefficients3 = Self::cubic_interp1d(scope, inp_0, inp_1, inp_2, inp_3, xw); @@ -272,18 +272,18 @@ impl InterpolateBicubicShader { yw, ); - gpu!(scope, output[id] = val); + cube_inline!(scope, output[id] = val); } fn min(scope: &mut Scope, a: Variable, b: Variable) -> Variable { let cond = scope.create_local(Elem::Bool); let res = scope.create_local(a.item()); - gpu!(scope, cond = a < b); - gpu!(scope, if(cond).then(|scope|{ - gpu!(scope, res = a); + cube_inline!(scope, cond = a < b); + cube_inline!(scope, if(cond).then(|scope|{ + cube_inline!(scope, res = a); }).else(|scope|{ - gpu!(scope, res = b); + cube_inline!(scope, res = b); })); res @@ -305,24 +305,24 @@ impl InterpolateBicubicShader { let cubic = scope.create_local(item); let cubic_tmp = scope.create_local(item); - gpu!(scope, x = t + one); + cube_inline!(scope, x = t + one); let coeffs0 = Self::cubic_convolution2(scope, x, a); let 
coeffs1 = Self::cubic_convolution1(scope, t, a); - gpu!(scope, x = one - t); + cube_inline!(scope, x = one - t); let coeffs2 = Self::cubic_convolution1(scope, x, a); - gpu!(scope, x = two - t); + cube_inline!(scope, x = two - t); let coeffs3 = Self::cubic_convolution2(scope, x, a); - gpu!(scope, cubic = x0 * coeffs0); - gpu!(scope, cubic_tmp = x1 * coeffs1); - gpu!(scope, cubic += cubic_tmp); - gpu!(scope, cubic_tmp = x2 * coeffs2); - gpu!(scope, cubic += cubic_tmp); - gpu!(scope, cubic_tmp = x3 * coeffs3); - gpu!(scope, cubic += cubic_tmp); + cube_inline!(scope, cubic = x0 * coeffs0); + cube_inline!(scope, cubic_tmp = x1 * coeffs1); + cube_inline!(scope, cubic += cubic_tmp); + cube_inline!(scope, cubic_tmp = x2 * coeffs2); + cube_inline!(scope, cubic += cubic_tmp); + cube_inline!(scope, cubic_tmp = x3 * coeffs3); + cube_inline!(scope, cubic += cubic_tmp); cubic } @@ -335,13 +335,13 @@ impl InterpolateBicubicShader { let two = scope.create_with_value(2, item); let three = scope.create_with_value(3, item); - gpu!(scope, conv = a + two); - gpu!(scope, conv *= x); - gpu!(scope, tmp = a + three); - gpu!(scope, conv = conv - tmp); - gpu!(scope, conv *= x); - gpu!(scope, conv *= x); - gpu!(scope, conv += one); + cube_inline!(scope, conv = a + two); + cube_inline!(scope, conv *= x); + cube_inline!(scope, tmp = a + three); + cube_inline!(scope, conv = conv - tmp); + cube_inline!(scope, conv *= x); + cube_inline!(scope, conv *= x); + cube_inline!(scope, conv += one); conv } @@ -354,15 +354,15 @@ impl InterpolateBicubicShader { let five = scope.create_with_value(5, item); let eight = scope.create_with_value(8, item); - gpu!(scope, conv = a * x); - gpu!(scope, tmp = five * a); - gpu!(scope, conv = conv - tmp); - gpu!(scope, conv *= x); - gpu!(scope, tmp = eight * a); - gpu!(scope, conv += tmp); - gpu!(scope, conv *= x); - gpu!(scope, tmp = four * a); - gpu!(scope, conv = conv - tmp); + cube_inline!(scope, conv = a * x); + cube_inline!(scope, tmp = five * a); + 
cube_inline!(scope, conv = conv - tmp); + cube_inline!(scope, conv *= x); + cube_inline!(scope, tmp = eight * a); + cube_inline!(scope, conv += tmp); + cube_inline!(scope, conv *= x); + cube_inline!(scope, tmp = four * a); + cube_inline!(scope, conv = conv - tmp); conv } diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index b73fe8eba2..197c44166c 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitElement, Runtime, @@ -46,40 +46,40 @@ impl InterpolateBilinearShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = 
stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let h = scope.create_local(Elem::UInt); let w = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, h = id / output_stride_2); - gpu!(scope, h = h % output_shape_2); + cube_inline!(scope, h = id / output_stride_2); + cube_inline!(scope, h = h % output_shape_2); - gpu!(scope, w = id / output_stride_3); - gpu!(scope, w = w % output_shape_3); + cube_inline!(scope, w = id / output_stride_3); + cube_inline!(scope, w = w % output_shape_3); let factor_float = scope.create_local(input.item()); let numerator_float = scope.create_local(input.item()); @@ -102,33 +102,33 @@ impl InterpolateBilinearShader { let xw = scope.create_local(input.item()); let xw_ = scope.create_local(input.item()); - gpu!(scope, numerator_int = input_shape_2 - 1u32); - gpu!(scope, denominator_int = output_shape_2 - 1u32); - gpu!(scope, factor_float = cast(h)); - gpu!(scope, numerator_float = cast(numerator_int)); - gpu!(scope, 
denominator_float = cast(denominator_int)); - gpu!(scope, frac = factor_float * numerator_float); - gpu!(scope, frac = frac / denominator_float); - gpu!(scope, v0 = floor(frac)); - gpu!(scope, v1 = ceil(frac)); - gpu!(scope, yw = frac - v0); - gpu!(scope, yw_ = one - yw); - gpu!(scope, y0 = cast(v0)); - gpu!(scope, y1 = cast(v1)); - - gpu!(scope, numerator_int = input_shape_3 - 1u32); - gpu!(scope, denominator_int = output_shape_3 - 1u32); - gpu!(scope, factor_float = cast(w)); - gpu!(scope, numerator_float = cast(numerator_int)); - gpu!(scope, denominator_float = cast(denominator_int)); - gpu!(scope, frac = factor_float * numerator_float); - gpu!(scope, frac = frac / denominator_float); - gpu!(scope, v0 = floor(frac)); - gpu!(scope, v1 = ceil(frac)); - gpu!(scope, xw = frac - v0); - gpu!(scope, xw_ = one - xw); - gpu!(scope, x0 = cast(v0)); - gpu!(scope, x1 = cast(v1)); + cube_inline!(scope, numerator_int = input_shape_2 - 1u32); + cube_inline!(scope, denominator_int = output_shape_2 - 1u32); + cube_inline!(scope, factor_float = cast(h)); + cube_inline!(scope, numerator_float = cast(numerator_int)); + cube_inline!(scope, denominator_float = cast(denominator_int)); + cube_inline!(scope, frac = factor_float * numerator_float); + cube_inline!(scope, frac = frac / denominator_float); + cube_inline!(scope, v0 = floor(frac)); + cube_inline!(scope, v1 = ceil(frac)); + cube_inline!(scope, yw = frac - v0); + cube_inline!(scope, yw_ = one - yw); + cube_inline!(scope, y0 = cast(v0)); + cube_inline!(scope, y1 = cast(v1)); + + cube_inline!(scope, numerator_int = input_shape_3 - 1u32); + cube_inline!(scope, denominator_int = output_shape_3 - 1u32); + cube_inline!(scope, factor_float = cast(w)); + cube_inline!(scope, numerator_float = cast(numerator_int)); + cube_inline!(scope, denominator_float = cast(denominator_int)); + cube_inline!(scope, frac = factor_float * numerator_float); + cube_inline!(scope, frac = frac / denominator_float); + cube_inline!(scope, v0 = floor(frac)); + 
cube_inline!(scope, v1 = ceil(frac)); + cube_inline!(scope, xw = frac - v0); + cube_inline!(scope, xw_ = one - xw); + cube_inline!(scope, x0 = cast(v0)); + cube_inline!(scope, x1 = cast(v1)); let index_base = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt); @@ -142,47 +142,47 @@ impl InterpolateBilinearShader { let p_c = scope.create_local(input.item()); let p_d = scope.create_local(input.item()); - gpu!(scope, index_base = b * input_stride_0); - gpu!(scope, index_tmp = c * input_stride_1); - gpu!(scope, index_base += index_tmp); - gpu!(scope, y0_stride = y0 * input_stride_2); - gpu!(scope, y1_stride = y1 * input_stride_2); - gpu!(scope, x0_stride = x0 * input_stride_3); - gpu!(scope, x1_stride = x1 * input_stride_3); - - gpu!(scope, index = index_base); - gpu!(scope, index += y0_stride); - gpu!(scope, index += x0_stride); - gpu!(scope, p_a = input[index]); - gpu!(scope, p_a *= xw_); - gpu!(scope, p_a *= yw_); - - gpu!(scope, index = index_base); - gpu!(scope, index += y0_stride); - gpu!(scope, index += x1_stride); - gpu!(scope, p_b = input[index]); - gpu!(scope, p_b *= xw); - gpu!(scope, p_b *= yw_); - - gpu!(scope, index = index_base); - gpu!(scope, index += y1_stride); - gpu!(scope, index += x0_stride); - gpu!(scope, p_c = input[index]); - gpu!(scope, p_c *= xw_); - gpu!(scope, p_c *= yw); - - gpu!(scope, index = index_base); - gpu!(scope, index += y1_stride); - gpu!(scope, index += x1_stride); - gpu!(scope, p_d = input[index]); - gpu!(scope, p_d *= xw); - gpu!(scope, p_d *= yw); + cube_inline!(scope, index_base = b * input_stride_0); + cube_inline!(scope, index_tmp = c * input_stride_1); + cube_inline!(scope, index_base += index_tmp); + cube_inline!(scope, y0_stride = y0 * input_stride_2); + cube_inline!(scope, y1_stride = y1 * input_stride_2); + cube_inline!(scope, x0_stride = x0 * input_stride_3); + cube_inline!(scope, x1_stride = x1 * input_stride_3); + + cube_inline!(scope, index = index_base); + cube_inline!(scope, index += 
y0_stride); + cube_inline!(scope, index += x0_stride); + cube_inline!(scope, p_a = input[index]); + cube_inline!(scope, p_a *= xw_); + cube_inline!(scope, p_a *= yw_); + + cube_inline!(scope, index = index_base); + cube_inline!(scope, index += y0_stride); + cube_inline!(scope, index += x1_stride); + cube_inline!(scope, p_b = input[index]); + cube_inline!(scope, p_b *= xw); + cube_inline!(scope, p_b *= yw_); + + cube_inline!(scope, index = index_base); + cube_inline!(scope, index += y1_stride); + cube_inline!(scope, index += x0_stride); + cube_inline!(scope, p_c = input[index]); + cube_inline!(scope, p_c *= xw_); + cube_inline!(scope, p_c *= yw); + + cube_inline!(scope, index = index_base); + cube_inline!(scope, index += y1_stride); + cube_inline!(scope, index += x1_stride); + cube_inline!(scope, p_d = input[index]); + cube_inline!(scope, p_d *= xw); + cube_inline!(scope, p_d *= yw); let sum = scope.create_local(input.item()); - gpu!(scope, sum = p_a + p_b); - gpu!(scope, sum += p_c); - gpu!(scope, sum += p_d); - gpu!(scope, output[id] = sum); + cube_inline!(scope, sum = p_a + p_b); + cube_inline!(scope, sum += p_c); + cube_inline!(scope, sum += p_d); + cube_inline!(scope, output[id] = sum); } } diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 776a58a2e0..1edc4b740f 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitElement, Runtime, @@ -48,40 +48,40 @@ impl InterpolateNearestShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - 
gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let h = scope.create_local(Elem::UInt); let w = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c 
% output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, h = id / output_stride_2); - gpu!(scope, h = h % output_shape_2); + cube_inline!(scope, h = id / output_stride_2); + cube_inline!(scope, h = h % output_shape_2); - gpu!(scope, w = id / output_stride_3); - gpu!(scope, w = w % output_shape_3); + cube_inline!(scope, w = id / output_stride_3); + cube_inline!(scope, w = w % output_shape_3); let factor_float = scope.create_local(elem); let numerator_float = scope.create_local(elem); @@ -91,36 +91,36 @@ impl InterpolateNearestShader { let xu = scope.create_local(Elem::UInt); let yu = scope.create_local(Elem::UInt); - gpu!(scope, factor_float = cast(h)); - gpu!(scope, numerator_float = cast(input_shape_2)); - gpu!(scope, denominator_float = cast(output_shape_2)); - gpu!(scope, y = factor_float * numerator_float); - gpu!(scope, y = y / denominator_float); - gpu!(scope, y = floor(y)); - gpu!(scope, yu = cast(y)); - - gpu!(scope, factor_float = cast(w)); - gpu!(scope, numerator_float = cast(input_shape_3)); - gpu!(scope, denominator_float = cast(output_shape_3)); - gpu!(scope, x = factor_float * numerator_float); - gpu!(scope, x = x / denominator_float); - gpu!(scope, x = floor(x)); - gpu!(scope, xu = cast(x)); + cube_inline!(scope, factor_float = cast(h)); + cube_inline!(scope, numerator_float = cast(input_shape_2)); + cube_inline!(scope, denominator_float = cast(output_shape_2)); + cube_inline!(scope, y = factor_float * numerator_float); + cube_inline!(scope, y = y / denominator_float); + cube_inline!(scope, y = floor(y)); + cube_inline!(scope, yu = cast(y)); + + cube_inline!(scope, factor_float = cast(w)); + cube_inline!(scope, numerator_float = cast(input_shape_3)); + cube_inline!(scope, denominator_float = cast(output_shape_3)); + cube_inline!(scope, x = factor_float * numerator_float); + cube_inline!(scope, x = x / denominator_float); + cube_inline!(scope, x = floor(x)); + cube_inline!(scope, 
xu = cast(x)); let index = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt); let val = scope.create_local(output.item()); - gpu!(scope, index = b * input_stride_0); - gpu!(scope, index_tmp = c * input_stride_1); - gpu!(scope, index += index_tmp); - gpu!(scope, index_tmp = yu * input_stride_2); - gpu!(scope, index += index_tmp); - gpu!(scope, index_tmp = xu * input_stride_3); - gpu!(scope, index += index_tmp); + cube_inline!(scope, index = b * input_stride_0); + cube_inline!(scope, index_tmp = c * input_stride_1); + cube_inline!(scope, index += index_tmp); + cube_inline!(scope, index_tmp = yu * input_stride_2); + cube_inline!(scope, index += index_tmp); + cube_inline!(scope, index_tmp = xu * input_stride_3); + cube_inline!(scope, index += index_tmp); - gpu!(scope, val = input[index]); - gpu!(scope, output[id] = val); + cube_inline!(scope, val = input[index]); + cube_inline!(scope, output[id] = val); } } diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index d52438f60c..046ba0594c 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -5,7 +5,7 @@ use crate::{ Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitElement, Runtime, @@ -49,42 +49,42 @@ impl InterpolateNearestBackwardShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, grad_stride_0 = stride(grad, 0u32)); - gpu!(scope, grad_stride_1 = stride(grad, 1u32)); - gpu!(scope, grad_stride_2 = stride(grad, 2u32)); - gpu!(scope, grad_stride_3 = stride(grad, 3u32)); + cube_inline!(scope, 
grad_stride_0 = stride(grad, 0u32)); + cube_inline!(scope, grad_stride_1 = stride(grad, 1u32)); + cube_inline!(scope, grad_stride_2 = stride(grad, 2u32)); + cube_inline!(scope, grad_stride_3 = stride(grad, 3u32)); - gpu!(scope, grad_shape_0 = shape(grad, 0u32)); - gpu!(scope, grad_shape_1 = shape(grad, 1u32)); - gpu!(scope, grad_shape_2 = shape(grad, 2u32)); - gpu!(scope, grad_shape_3 = shape(grad, 3u32)); + cube_inline!(scope, grad_shape_0 = shape(grad, 0u32)); + cube_inline!(scope, grad_shape_1 = shape(grad, 1u32)); + cube_inline!(scope, grad_shape_2 = shape(grad, 2u32)); + cube_inline!(scope, grad_shape_3 = shape(grad, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let oh = scope.create_local(Elem::UInt); let ow = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + 
cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, oh = id / output_stride_2); - gpu!(scope, oh = oh % output_shape_2); + cube_inline!(scope, oh = id / output_stride_2); + cube_inline!(scope, oh = oh % output_shape_2); - gpu!(scope, ow = id / output_stride_3); - gpu!(scope, ow = ow % output_shape_3); + cube_inline!(scope, ow = id / output_stride_3); + cube_inline!(scope, ow = ow % output_shape_3); let gh_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2); let gh_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2); @@ -99,34 +99,34 @@ impl InterpolateNearestBackwardShader { let index_grad_2 = scope.create_local(Elem::UInt); let index_grad_3 = scope.create_local(Elem::UInt); - gpu!(scope, index_grad_0 = b * grad_stride_0); - gpu!(scope, index_grad_1 = c * grad_stride_1); + cube_inline!(scope, index_grad_0 = b * grad_stride_0); + cube_inline!(scope, index_grad_1 = c * grad_stride_1); let sum = scope.zero(output.item()); - gpu!( + cube_inline!( scope, range(gh_start, gh_end).for_each(|gh, scope| { - gpu!( + cube_inline!( scope, range(gw_start, gw_end).for_each(|gw, scope| { - gpu!(scope, index_grad_2 = gh * grad_stride_2); - gpu!(scope, index_grad_3 = gw * grad_stride_3); + cube_inline!(scope, index_grad_2 = gh * grad_stride_2); + cube_inline!(scope, index_grad_3 = gw * grad_stride_3); - gpu!(scope, index_grad = index_grad_0); - gpu!(scope, index_grad += index_grad_1); - gpu!(scope, index_grad += index_grad_2); - gpu!(scope, index_grad += index_grad_3); + cube_inline!(scope, index_grad = index_grad_0); + cube_inline!(scope, index_grad += index_grad_1); + cube_inline!(scope, index_grad += index_grad_2); + cube_inline!(scope, index_grad += index_grad_3); - gpu!(scope, result = grad[index_grad]); + cube_inline!(scope, result = grad[index_grad]); - gpu!(scope, sum += result); + cube_inline!(scope, sum += result); }) ); }) ); - gpu!(scope, output[id] = sum); + cube_inline!(scope, output[id] 
= sum); } fn start_index( @@ -140,12 +140,12 @@ impl InterpolateNearestBackwardShader { let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); - gpu!(scope, index = input_index * output_size); - gpu!(scope, numerator_float = cast(index)); - gpu!(scope, div = cast(input_size)); - gpu!(scope, div = numerator_float / div); - gpu!(scope, div = ceil(div)); - gpu!(scope, index = cast(div)); + cube_inline!(scope, index = input_index * output_size); + cube_inline!(scope, numerator_float = cast(index)); + cube_inline!(scope, div = cast(input_size)); + cube_inline!(scope, div = numerator_float / div); + cube_inline!(scope, div = ceil(div)); + cube_inline!(scope, index = cast(div)); index } @@ -163,19 +163,19 @@ impl InterpolateNearestBackwardShader { let min = scope.create_local(Elem::Bool); let end_index = scope.create_local(Elem::UInt); - gpu!(scope, index = input_index + 1u32); - gpu!(scope, index *= output_size); - gpu!(scope, numerator_float = cast(index)); - gpu!(scope, div = cast(input_size)); - gpu!(scope, div = numerator_float / div); - gpu!(scope, div = ceil(div)); - gpu!(scope, index = cast(div)); - - gpu!(scope, min = output_size < index); - gpu!(scope, if(min).then(|scope|{ - gpu!(scope, end_index = output_size); + cube_inline!(scope, index = input_index + 1u32); + cube_inline!(scope, index *= output_size); + cube_inline!(scope, numerator_float = cast(index)); + cube_inline!(scope, div = cast(input_size)); + cube_inline!(scope, div = numerator_float / div); + cube_inline!(scope, div = ceil(div)); + cube_inline!(scope, index = cast(div)); + + cube_inline!(scope, min = output_size < index); + cube_inline!(scope, if(min).then(|scope|{ + cube_inline!(scope, end_index = output_size); }).else(|scope|{ - gpu!(scope, end_index = index); + cube_inline!(scope, end_index = index); })); end_index diff --git a/crates/burn-jit/src/kernel/mask/shader.rs b/crates/burn-jit/src/kernel/mask/shader.rs index 3615ed30c1..0cfce5767b 100644 --- 
a/crates/burn-jit/src/kernel/mask/shader.rs +++ b/crates/burn-jit/src/kernel/mask/shader.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use crate::{ codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo}, gpu::{ - gpu, ComputeShader, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable, Visibility, + cube_inline, ComputeShader, Elem, IndexOffsetGlobalWithLayout, Item, Scope, Variable, Visibility, }, kernel::GpuComputeShaderPhase, JitElement, Runtime, @@ -30,7 +30,7 @@ impl MaskStrategy for MaskFill { value: Variable, _index: Variable, ) -> Variable { - gpu!(scope, masked_value = value); + cube_inline!(scope, masked_value = value); masked_value } @@ -55,7 +55,7 @@ impl MaskStrategy for MaskWhere { value: Variable, index: Variable, ) -> Variable { - gpu!(scope, masked_value = value[index]); + cube_inline!(scope, masked_value = value[index]); masked_value } @@ -248,22 +248,22 @@ impl MaskShader { // Determine if index should be masked let value_in_mask = scope.create_local(mask.item()); - gpu!(scope, value_in_mask = mask[index_mask]); + cube_inline!(scope, value_in_mask = mask[index_mask]); let masked = scope.create_local(Elem::Bool); let zero = scope.zero(value_in_mask.item()); if self.reversed { - gpu!(scope, masked = value_in_mask == zero); + cube_inline!(scope, masked = value_in_mask == zero); } else { - gpu!(scope, masked = value_in_mask != zero); + cube_inline!(scope, masked = value_in_mask != zero); } // Assign a value at the index let used_value = scope.create_local(output.item()); - gpu!(scope, if(masked).then(|scope| { + cube_inline!(scope, if(masked).then(|scope| { M::mask(scope, used_value, value, index_input ); }).else(|scope| { - gpu!(scope, used_value = input[index_input]); + cube_inline!(scope, used_value = input[index_input]); })); - gpu!(scope, output[id] = used_value); + cube_inline!(scope, output[id] = used_value); } } diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs 
b/crates/burn-jit/src/kernel/matmul/simple.rs index f3ba6843f5..97d597dd55 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -1,5 +1,5 @@ use crate::codegen::dialect::gpu::{ - gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, + cube_inline, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, }; use crate::codegen::Execution; use crate::gpu::ComputeShader; @@ -50,17 +50,17 @@ impl MatmulComputeShader { let col = scope.create_local(Elem::UInt); // Row position. - gpu!(scope, tmp_index = local_idx / block_size); - gpu!(scope, row = block_size * Variable::WorkgroupIdX); - gpu!(scope, row = row + tmp_index); + cube_inline!(scope, tmp_index = local_idx / block_size); + cube_inline!(scope, row = block_size * Variable::WorkgroupIdX); + cube_inline!(scope, row = row + tmp_index); // Col position. - gpu!(scope, tmp_index = local_idx % block_size); - gpu!(scope, col = block_size * Variable::WorkgroupIdY); - gpu!(scope, col = col + tmp_index); + cube_inline!(scope, tmp_index = local_idx % block_size); + cube_inline!(scope, col = block_size * Variable::WorkgroupIdY); + cube_inline!(scope, col = col + tmp_index); // Batch position. - gpu!(scope, batch_dims = rank - 2u32); + cube_inline!(scope, batch_dims = rank - 2u32); // Define the matrix size. let n_rows = scope.create_local(Elem::UInt); @@ -68,24 +68,24 @@ impl MatmulComputeShader { let k = scope.create_local(Elem::UInt); // Number of rows. - gpu!(scope, n_rows = shape(out, batch_dims)); + cube_inline!(scope, n_rows = shape(out, batch_dims)); // Number of cols. - gpu!(scope, tmp_index = batch_dims + 1u32); - gpu!(scope, n_cols = shape(out, tmp_index)); + cube_inline!(scope, tmp_index = batch_dims + 1u32); + cube_inline!(scope, n_cols = shape(out, tmp_index)); // The dimension that is going to be squashed. 
- gpu!(scope, k = shape(lhs, tmp_index)); + cube_inline!(scope, k = shape(lhs, tmp_index)); // Check if there is some work to be done. let should_stop = scope.create_local(Elem::Bool); - gpu!(scope, should_stop = row >= n_rows); - gpu!(scope, if (should_stop).then(|scope| { + cube_inline!(scope, should_stop = row >= n_rows); + cube_inline!(scope, if (should_stop).then(|scope| { scope.register(Branch::Return); })); - gpu!(scope, should_stop = col >= n_cols); - gpu!(scope, if (should_stop).then(|scope| { + cube_inline!(scope, should_stop = col >= n_cols); + cube_inline!(scope, if (should_stop).then(|scope| { scope.register(Branch::Return); })); @@ -95,8 +95,8 @@ impl MatmulComputeShader { let offset_output = scope.create_local(Elem::UInt); // Batch offset for the output. - gpu!(scope, offset_output = n_rows * n_cols); - gpu!(scope, offset_output = offset_output * batch); + cube_inline!(scope, offset_output = n_rows * n_cols); + cube_inline!(scope, offset_output = offset_output * batch); // Batch offset for the lhs & rhs matrices. IndexOffsetGlobalWithLayout { @@ -114,10 +114,10 @@ impl MatmulComputeShader { // Initialize the sum to zero. let zero: Variable = 0f32.into(); - gpu!(scope, sum = zero); + cube_inline!(scope, sum = zero); // Loop over the k dimension. 
- gpu!( + cube_inline!( scope, range(0u32, k).for_each(|i, scope| { let lhs_index = scope.create_local(Elem::UInt); @@ -127,28 +127,28 @@ impl MatmulComputeShader { let rhs_value = scope.create_local(rhs.item()); let out_value = scope.create_local(out.item()); - gpu!(scope, lhs_index = row * k); - gpu!(scope, lhs_index = lhs_index + i); - gpu!(scope, lhs_index = lhs_index + offset_lhs); + cube_inline!(scope, lhs_index = row * k); + cube_inline!(scope, lhs_index = lhs_index + i); + cube_inline!(scope, lhs_index = lhs_index + offset_lhs); - gpu!(scope, rhs_index = i * n_cols); - gpu!(scope, rhs_index = rhs_index + col); - gpu!(scope, rhs_index = rhs_index + offset_rhs); + cube_inline!(scope, rhs_index = i * n_cols); + cube_inline!(scope, rhs_index = rhs_index + col); + cube_inline!(scope, rhs_index = rhs_index + offset_rhs); - gpu!(scope, lhs_value = lhs[lhs_index]); - gpu!(scope, rhs_value = rhs[rhs_index]); + cube_inline!(scope, lhs_value = lhs[lhs_index]); + cube_inline!(scope, rhs_value = rhs[rhs_index]); - gpu!(scope, out_value = lhs_value * rhs_value); - gpu!(scope, sum += out_value); + cube_inline!(scope, out_value = lhs_value * rhs_value); + cube_inline!(scope, sum += out_value); }) ); let out_index = scope.create_local(Elem::UInt); - gpu!(scope, out_index = row * n_cols); - gpu!(scope, out_index += col); - gpu!(scope, out_index += offset_output); - gpu!(scope, out[out_index] = sum); + cube_inline!(scope, out_index = row * n_cols); + cube_inline!(scope, out_index += col); + cube_inline!(scope, out_index += offset_output); + cube_inline!(scope, out[out_index] = sum); } } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs index 06f0f1bf08..3810723c4b 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs @@ -1,4 +1,4 @@ -use crate::gpu::{gpu, BinaryOperator, Scope, Synchronization, Variable}; +use 
crate::gpu::{cube_inline, BinaryOperator, Scope, Synchronization, Variable}; use crate::kernel::matmul::tiling2d_shader::{ computation_loop, gather_shader_information, load_shared_memory, write_to_output, @@ -45,12 +45,12 @@ impl MatmulTiling2dShader { let shader_state = gather_shader_information(scope, &self); let block_size_k: Variable = self.config.block_size_k.into(); - gpu!( + cube_inline!( scope, range(0u32, shader_state.n_loops).for_each(|i, scope| { // From 0 to K with steps block_size_k let k = shader_state.k; - gpu!(scope, k = i * block_size_k); + cube_inline!(scope, k = i * block_size_k); load_shared_memory(scope, &self, &shader_state); diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs index 1c0675fc60..7c9465edfe 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs @@ -1,4 +1,4 @@ -use crate::gpu::{gpu, Elem, Scope, Variable}; +use crate::gpu::{cube_inline, Elem, Scope, Variable}; use super::{MatmulTiling2dShader, Tiling2dState}; @@ -31,7 +31,7 @@ pub fn computation_loop( let results_before = scope.create_local(elem); let results_after = scope.create_local(elem); - gpu!( + cube_inline!( scope, range( 0u32, @@ -40,40 +40,40 @@ pub fn computation_loop( ) .for_each(|dot_index, scope| { // Load a subcolumn of values from lhs - gpu!(scope, lhs_sm_position = thread_row / 4u32); - gpu!(scope, lhs_sm_position *= block_size_k); - gpu!(scope, lhs_sm_position += dot_index); - gpu!(scope, register_m = shared_lhs[lhs_sm_position]); + cube_inline!(scope, lhs_sm_position = thread_row / 4u32); + cube_inline!(scope, lhs_sm_position *= block_size_k); + cube_inline!(scope, lhs_sm_position += dot_index); + cube_inline!(scope, register_m = shared_lhs[lhs_sm_position]); // Load a subrow of values from rhs - gpu!(scope, rhs_sm_position = dot_index * block_size_n); - gpu!(scope, 
rhs_sm_position += thread_col); - gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32); - gpu!(scope, register_n = shared_rhs[rhs_sm_position]); + cube_inline!(scope, rhs_sm_position = dot_index * block_size_n); + cube_inline!(scope, rhs_sm_position += thread_col); + cube_inline!(scope, rhs_sm_position = rhs_sm_position / 4u32); + cube_inline!(scope, register_n = shared_rhs[rhs_sm_position]); - gpu!( + cube_inline!( scope, range(0u32, shader.config.tile_size_m as u32, shader.config.unroll).for_each( |res_idx_m, scope| { - gpu!( + cube_inline!( scope, range(0u32, shader.config.tile_size_n as u32, shader.config.unroll) .for_each(|res_idx_n, scope| { - gpu!(scope, registered_m = register_m[res_idx_m]); - gpu!(scope, registered_n = register_n[res_idx_n]); + cube_inline!(scope, registered_m = register_m[res_idx_m]); + cube_inline!(scope, registered_n = register_n[res_idx_n]); - gpu!(scope, multiplied = registered_m * registered_n); + cube_inline!(scope, multiplied = registered_m * registered_n); - gpu!( + cube_inline!( scope, results_position = res_idx_m * shader.config.tile_size_n ); - gpu!(scope, results_position += res_idx_n); + cube_inline!(scope, results_position += res_idx_n); - gpu!(scope, results_before = results[results_position]); - gpu!(scope, results_after = results_before + multiplied); + cube_inline!(scope, results_before = results[results_position]); + cube_inline!(scope, results_after = results_before + multiplied); - gpu!(scope, results[results_position] = results_after); + cube_inline!(scope, results[results_position] = results_after); }) ); } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs index 567566fe6b..c6c6240c18 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs @@ -1,4 +1,4 @@ -use crate::gpu::{gpu, Elem, Scope, Variable}; 
+use crate::gpu::{cube_inline, Elem, Scope, Variable}; use super::{MatmulTiling2dShader, Tiling2dState}; @@ -68,7 +68,7 @@ fn load_shared_memory_with_bound_check( // How close is the thread to the end of the matrix. // If < 4 then it is an edge case let remain = scope.create_local(Elem::UInt); - gpu!(scope, remain = dim - pos_in_dim); + cube_inline!(scope, remain = dim - pos_in_dim); let block_size_k: Variable = shader.config.block_size_k.into(); let block_size_n: Variable = shader.config.block_size_n.into(); @@ -96,88 +96,88 @@ fn load_shared_memory_with_bound_check( let val_3 = scope.create_local(elem); let zero: Variable = 0u32.into(); - gpu!( + cube_inline!( scope, range(0_u32, 4u32, shader.config.unroll).for_each(|j, scope| { - gpu!(scope, current = thread_idx_1 + j); + cube_inline!(scope, current = thread_idx_1 + j); - gpu!(scope, aligned_with_shared_memory = current < block_size_k); + cube_inline!(scope, aligned_with_shared_memory = current < block_size_k); // To avoid overwriting following row in shared memory - gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + cube_inline!(scope, if(aligned_with_shared_memory).then(|scope|{ // Position in shared memory match input_identifier { InputIdentifier::Lhs => { - gpu!(scope, sm_position = thread_idx_2 / 4u32); - gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current); + cube_inline!(scope, sm_position = thread_idx_2 / 4u32); + cube_inline!(scope, sm_position *= block_size_k); + cube_inline!(scope, sm_position += current); }, InputIdentifier::Rhs => { - gpu!(scope, sm_position = current * block_size_n); - gpu!(scope, sm_position += thread_idx_2); - gpu!(scope, sm_position = sm_position / 4u32); + cube_inline!(scope, sm_position = current * block_size_n); + cube_inline!(scope, sm_position += thread_idx_2); + cube_inline!(scope, sm_position = sm_position / 4u32); } } // To pad with zeros if outside lhs - gpu!(scope, current_with_k = current + k); - gpu!(scope, within_input = 
current_with_k < dim_k); - gpu!(scope, remain_at_least_1 = remain >= 1u32); - gpu!(scope, read_condition = within_input && remain_at_least_1); - - gpu!(scope, if(read_condition).then(|scope| { - gpu!(scope, position_0 = k + current); - gpu!(scope, position_0 *= stride_1); - gpu!(scope, tmp = thread_idx_2 * stride_2); - gpu!(scope, position_0 += tmp); - gpu!(scope, position_0 += input_offset); - gpu!(scope, position_1 = position_0 + stride_2); - gpu!(scope, position_2 = position_1 + stride_2); - gpu!(scope, position_3 = position_2 + stride_2); - - gpu!(scope, remain_n = remain >= 4u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = input[position_1]); - gpu!(scope, val_2 = input[position_2]); - gpu!(scope, val_3 = input[position_3]); + cube_inline!(scope, current_with_k = current + k); + cube_inline!(scope, within_input = current_with_k < dim_k); + cube_inline!(scope, remain_at_least_1 = remain >= 1u32); + cube_inline!(scope, read_condition = within_input && remain_at_least_1); + + cube_inline!(scope, if(read_condition).then(|scope| { + cube_inline!(scope, position_0 = k + current); + cube_inline!(scope, position_0 *= stride_1); + cube_inline!(scope, tmp = thread_idx_2 * stride_2); + cube_inline!(scope, position_0 += tmp); + cube_inline!(scope, position_0 += input_offset); + cube_inline!(scope, position_1 = position_0 + stride_2); + cube_inline!(scope, position_2 = position_1 + stride_2); + cube_inline!(scope, position_3 = position_2 + stride_2); + + cube_inline!(scope, remain_n = remain >= 4u32); + cube_inline!(scope, if(remain_n).then(|scope|{ + cube_inline!(scope, val_0 = input[position_0]); + cube_inline!(scope, val_1 = input[position_1]); + cube_inline!(scope, val_2 = input[position_2]); + cube_inline!(scope, val_3 = input[position_3]); }).else(|scope|{ - gpu!(scope, remain_n = remain == 3u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = 
input[position_1]); - gpu!(scope, val_2 = input[position_2]); - gpu!(scope, val_3 = zero); + cube_inline!(scope, remain_n = remain == 3u32); + cube_inline!(scope, if(remain_n).then(|scope|{ + cube_inline!(scope, val_0 = input[position_0]); + cube_inline!(scope, val_1 = input[position_1]); + cube_inline!(scope, val_2 = input[position_2]); + cube_inline!(scope, val_3 = zero); }).else(|scope|{ - gpu!(scope, remain_n = remain == 2u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = input[position_1]); - gpu!(scope, val_2 = zero); - gpu!(scope, val_3 = zero); + cube_inline!(scope, remain_n = remain == 2u32); + cube_inline!(scope, if(remain_n).then(|scope|{ + cube_inline!(scope, val_0 = input[position_0]); + cube_inline!(scope, val_1 = input[position_1]); + cube_inline!(scope, val_2 = zero); + cube_inline!(scope, val_3 = zero); }).else(|scope|{ - gpu!(scope, remain_n = remain == 1u32); - gpu!(scope, if(remain_n).then(|scope|{ - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = zero); - gpu!(scope, val_2 = zero); - gpu!(scope, val_3 = zero); + cube_inline!(scope, remain_n = remain == 1u32); + cube_inline!(scope, if(remain_n).then(|scope|{ + cube_inline!(scope, val_0 = input[position_0]); + cube_inline!(scope, val_1 = zero); + cube_inline!(scope, val_2 = zero); + cube_inline!(scope, val_3 = zero); })); })); })); })); - gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); - gpu!(scope, shared_memory[sm_position] = val_vec4); + cube_inline!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); + cube_inline!(scope, shared_memory[sm_position] = val_vec4); }).else(|scope|{ - gpu!(scope, val_0 = zero); - gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0)); - gpu!(scope, shared_memory[sm_position] = val_vec4); + cube_inline!(scope, val_0 = zero); + cube_inline!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0)); + cube_inline!(scope, shared_memory[sm_position] = val_vec4); })); })); }) @@ -233,45 
+233,45 @@ fn load_shared_memory_no_bound_check( let val_3 = scope.create_local(elem); let val_vec4 = scope.create_local(shared_memory.item()); - gpu!( + cube_inline!( scope, range(0_u32, 4u32, shader.config.unroll).for_each(|j, scope| { - gpu!(scope, current = thread_idx_1 + j); + cube_inline!(scope, current = thread_idx_1 + j); - gpu!(scope, aligned_with_shared_memory = current < block_size_k); + cube_inline!(scope, aligned_with_shared_memory = current < block_size_k); // To avoid overwriting following row in shared memory - gpu!(scope, if(aligned_with_shared_memory).then(|scope|{ + cube_inline!(scope, if(aligned_with_shared_memory).then(|scope|{ match input_identifier { InputIdentifier::Lhs => { - gpu!(scope, sm_position = thread_idx_2 / 4u32); - gpu!(scope, sm_position *= block_size_k); - gpu!(scope, sm_position += current); + cube_inline!(scope, sm_position = thread_idx_2 / 4u32); + cube_inline!(scope, sm_position *= block_size_k); + cube_inline!(scope, sm_position += current); }, InputIdentifier::Rhs => { - gpu!(scope, sm_position = current * block_size_n); - gpu!(scope, sm_position += thread_idx_2); - gpu!(scope, sm_position = sm_position / 4u32); + cube_inline!(scope, sm_position = current * block_size_n); + cube_inline!(scope, sm_position += thread_idx_2); + cube_inline!(scope, sm_position = sm_position / 4u32); } } - gpu!(scope, position_0 = k + current); - gpu!(scope, position_0 *= stride_1); - gpu!(scope, tmp = thread_idx_2 * stride_2); - gpu!(scope, position_0 += tmp); - gpu!(scope, position_0 += input_offset); - gpu!(scope, position_1 = position_0 + stride_2); - gpu!(scope, position_2 = position_1 + stride_2); - gpu!(scope, position_3 = position_2 + stride_2); - - gpu!(scope, val_0 = input[position_0]); - gpu!(scope, val_1 = input[position_1]); - gpu!(scope, val_2 = input[position_2]); - gpu!(scope, val_3 = input[position_3]); - - gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); - gpu!(scope, shared_memory[sm_position] = val_vec4); + 
cube_inline!(scope, position_0 = k + current); + cube_inline!(scope, position_0 *= stride_1); + cube_inline!(scope, tmp = thread_idx_2 * stride_2); + cube_inline!(scope, position_0 += tmp); + cube_inline!(scope, position_0 += input_offset); + cube_inline!(scope, position_1 = position_0 + stride_2); + cube_inline!(scope, position_2 = position_1 + stride_2); + cube_inline!(scope, position_3 = position_2 + stride_2); + + cube_inline!(scope, val_0 = input[position_0]); + cube_inline!(scope, val_1 = input[position_1]); + cube_inline!(scope, val_2 = input[position_2]); + cube_inline!(scope, val_3 = input[position_3]); + + cube_inline!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); + cube_inline!(scope, shared_memory[sm_position] = val_vec4); })); }) ); diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs index fca13cebed..3fe389cad6 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs @@ -1,4 +1,4 @@ -use crate::gpu::{gpu, Elem, Item, Scope, Variable}; +use crate::gpu::{cube_inline, Elem, Item, Scope, Variable}; use super::{MatmulTiling2dShader, Tiling2dState}; @@ -32,11 +32,11 @@ pub(crate) fn gather_shader_information( let dim_m = scope.create_local(Elem::UInt); let dim_k = scope.create_local(Elem::UInt); let dim_n = scope.create_local(Elem::UInt); - gpu!(scope, last_dim = rank - 1u32); - gpu!(scope, second_to_last_dim = rank - 2u32); - gpu!(scope, dim_m = shape(lhs, second_to_last_dim)); - gpu!(scope, dim_k = shape(lhs, last_dim)); - gpu!(scope, dim_n = shape(rhs, last_dim)); + cube_inline!(scope, last_dim = rank - 1u32); + cube_inline!(scope, second_to_last_dim = rank - 2u32); + cube_inline!(scope, dim_m = shape(lhs, second_to_last_dim)); + cube_inline!(scope, dim_k = shape(lhs, last_dim)); + cube_inline!(scope, dim_n = shape(rhs, last_dim)); // 
Strides let lhs_stride_row = scope.create_local(Elem::UInt); @@ -45,48 +45,48 @@ pub(crate) fn gather_shader_information( let rhs_stride_col = scope.create_local(Elem::UInt); let out_stride_row = scope.create_local(Elem::UInt); let out_stride_col = scope.create_local(Elem::UInt); - gpu!(scope, lhs_stride_row = stride(lhs, second_to_last_dim)); - gpu!(scope, lhs_stride_col = stride(lhs, last_dim)); - gpu!(scope, rhs_stride_row = stride(rhs, second_to_last_dim)); - gpu!(scope, rhs_stride_col = stride(rhs, last_dim)); - gpu!(scope, out_stride_row = stride(out, second_to_last_dim)); - gpu!(scope, out_stride_col = stride(out, last_dim)); + cube_inline!(scope, lhs_stride_row = stride(lhs, second_to_last_dim)); + cube_inline!(scope, lhs_stride_col = stride(lhs, last_dim)); + cube_inline!(scope, rhs_stride_row = stride(rhs, second_to_last_dim)); + cube_inline!(scope, rhs_stride_col = stride(rhs, last_dim)); + cube_inline!(scope, out_stride_row = stride(out, second_to_last_dim)); + cube_inline!(scope, out_stride_col = stride(out, last_dim)); // Workgroup offset let skip_row = scope.create_local(Elem::UInt); let skip_col = scope.create_local(Elem::UInt); let workgroup_id_x = Variable::WorkgroupIdX; let workgroup_id_y = Variable::WorkgroupIdY; - gpu!(scope, skip_row = workgroup_id_x); - gpu!(scope, skip_row *= block_size_m); - gpu!(scope, skip_col = workgroup_id_y); - gpu!(scope, skip_col *= block_size_n); + cube_inline!(scope, skip_row = workgroup_id_x); + cube_inline!(scope, skip_row *= block_size_m); + cube_inline!(scope, skip_col = workgroup_id_y); + cube_inline!(scope, skip_col *= block_size_n); // Position of the first element of the thread, relative to the block let thread_row = scope.create_local(Elem::UInt); let thread_col = scope.create_local(Elem::UInt); - gpu!(scope, thread_row = local_idx / n_threads_per_row); - gpu!(scope, thread_row *= tile_size_m); - gpu!(scope, thread_col = local_idx % n_threads_per_row); - gpu!(scope, thread_col *= tile_size_n); + 
cube_inline!(scope, thread_row = local_idx / n_threads_per_row); + cube_inline!(scope, thread_row *= tile_size_m); + cube_inline!(scope, thread_col = local_idx % n_threads_per_row); + cube_inline!(scope, thread_col *= tile_size_n); // Position of the first element of the thread, in absolute (in one batch) let row = scope.create_local(Elem::UInt); let col = scope.create_local(Elem::UInt); - gpu!(scope, row = skip_row + thread_row); - gpu!(scope, col = skip_col + thread_col); + cube_inline!(scope, row = skip_row + thread_row); + cube_inline!(scope, col = skip_col + thread_col); // Calculate offset. let offset_lhs = scope.create_local(Elem::UInt); let offset_rhs = scope.create_local(Elem::UInt); - gpu!(scope, offset_lhs = skip_row * lhs_stride_row); - gpu!(scope, offset_rhs = skip_col * rhs_stride_col); + cube_inline!(scope, offset_lhs = skip_row * lhs_stride_row); + cube_inline!(scope, offset_rhs = skip_col * rhs_stride_col); // Batch offset for the output. let offset_output = scope.create_local(Elem::UInt); let batch_dims = scope.create_local(Elem::UInt); - gpu!(scope, offset_output = dim_m * dim_n); - gpu!(scope, offset_output = offset_output * batch); + cube_inline!(scope, offset_output = dim_m * dim_n); + cube_inline!(scope, offset_output = offset_output * batch); // Batch offset for the lhs & rhs matrices. 
let stride_lhs = scope.create_local(Elem::UInt); @@ -97,24 +97,24 @@ pub(crate) fn gather_shader_information( let tmp = scope.create_local(Elem::UInt); let tmp_lhs = scope.create_local(Elem::UInt); let tmp_rhs = scope.create_local(Elem::UInt); - gpu!(scope, batch_dims = rank - 2u32); - gpu!( + cube_inline!(scope, batch_dims = rank - 2u32); + cube_inline!( scope, range(0u32, batch_dims).for_each(|b, scope| { - gpu!(scope, stride_lhs = stride(lhs, b)); - gpu!(scope, stride_rhs = stride(rhs, b)); - gpu!(scope, stride_output = stride(out, b)); - gpu!(scope, shape_lhs = shape(lhs, b)); - gpu!(scope, shape_rhs = shape(rhs, b)); - - gpu!(scope, tmp = offset_output / stride_output); - gpu!(scope, tmp_lhs = tmp % shape_lhs); - gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs); - gpu!(scope, offset_lhs += tmp_lhs); - - gpu!(scope, tmp_rhs = tmp % shape_rhs); - gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs); - gpu!(scope, offset_rhs += tmp_rhs); + cube_inline!(scope, stride_lhs = stride(lhs, b)); + cube_inline!(scope, stride_rhs = stride(rhs, b)); + cube_inline!(scope, stride_output = stride(out, b)); + cube_inline!(scope, shape_lhs = shape(lhs, b)); + cube_inline!(scope, shape_rhs = shape(rhs, b)); + + cube_inline!(scope, tmp = offset_output / stride_output); + cube_inline!(scope, tmp_lhs = tmp % shape_lhs); + cube_inline!(scope, tmp_lhs = tmp_lhs * stride_lhs); + cube_inline!(scope, offset_lhs += tmp_lhs); + + cube_inline!(scope, tmp_rhs = tmp % shape_rhs); + cube_inline!(scope, tmp_rhs = tmp_rhs * stride_rhs); + cube_inline!(scope, offset_rhs += tmp_rhs); }) ); @@ -140,13 +140,13 @@ pub(crate) fn gather_shader_information( let dim_k_float = scope.create_local(elem); let block_size_k_float = scope.create_local(elem); let n_loops_float = scope.create_local(elem); - gpu!(scope, dim_k_float = dim_k); - gpu!(scope, block_size_k_float = block_size_k); - gpu!(scope, n_loops_float = dim_k_float / block_size_k_float); - gpu!(scope, n_loops_float = ceil(n_loops_float)); - gpu!(scope, n_loops 
= n_loops_float); + cube_inline!(scope, dim_k_float = dim_k); + cube_inline!(scope, block_size_k_float = block_size_k); + cube_inline!(scope, n_loops_float = dim_k_float / block_size_k_float); + cube_inline!(scope, n_loops_float = ceil(n_loops_float)); + cube_inline!(scope, n_loops = n_loops_float); } else { - gpu!(scope, n_loops = dim_k / block_size_k); + cube_inline!(scope, n_loops = dim_k / block_size_k); } Tiling2dState { diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs index 3a75da6903..a756a487e0 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs @@ -1,4 +1,4 @@ -use crate::gpu::{gpu, Elem, Scope, Variable}; +use crate::gpu::{cube_inline, Elem, Scope, Variable}; use super::{MatmulTiling2dShader, Tiling2dState}; @@ -21,22 +21,22 @@ pub fn write_to_output( let within_output = scope.create_local(Elem::Bool); let within_output_tmp = scope.create_local(Elem::Bool); - gpu!( + cube_inline!( scope, range(0u32, shader.config.tile_size_m as u32, shader.config.unroll).for_each( |res_idx_m, scope| { - gpu!( + cube_inline!( scope, range(0u32, shader.config.tile_size_n as u32, shader.config.unroll) .for_each(|res_idx_n, scope| { - gpu!(scope, row_index = row + res_idx_m); - gpu!(scope, col_index = col + res_idx_n); + cube_inline!(scope, row_index = row + res_idx_m); + cube_inline!(scope, col_index = col + res_idx_n); - gpu!(scope, within_output = row_index < dim_m); - gpu!(scope, within_output_tmp = col_index < dim_n); - gpu!(scope, within_output = within_output && within_output_tmp); + cube_inline!(scope, within_output = row_index < dim_m); + cube_inline!(scope, within_output_tmp = col_index < dim_n); + cube_inline!(scope, within_output = within_output && within_output_tmp); - gpu!(scope, if(within_output).then(|scope|{ + cube_inline!(scope, if(within_output).then(|scope|{ 
write_inner( scope, shader, @@ -53,16 +53,16 @@ pub fn write_to_output( ) ); } else { - gpu!( + cube_inline!( scope, range(0u32, shader.config.tile_size_m as u32, shader.config.unroll).for_each( |res_idx_m, scope| { - gpu!( + cube_inline!( scope, range(0u32, shader.config.tile_size_n as u32, shader.config.unroll) .for_each(|res_idx_n, scope| { - gpu!(scope, row_index = row + res_idx_m); - gpu!(scope, col_index = col + res_idx_n); + cube_inline!(scope, row_index = row + res_idx_m); + cube_inline!(scope, col_index = col + res_idx_n); write_inner( scope, @@ -102,18 +102,18 @@ fn write_inner( let result = scope.create_local(elem); let output_position = scope.create_local(Elem::UInt); - gpu!( + cube_inline!( scope, results_position = res_idx_m * shader.config.tile_size_n ); - gpu!(scope, results_position += res_idx_n); + cube_inline!(scope, results_position += res_idx_n); - gpu!(scope, result = results[results_position]); + cube_inline!(scope, result = results[results_position]); - gpu!(scope, row_index *= out_stride_row); - gpu!(scope, col_index *= out_stride_col); - gpu!(scope, output_position = row_index + col_index); - gpu!(scope, output_position += offset_output); + cube_inline!(scope, row_index *= out_stride_row); + cube_inline!(scope, col_index *= out_stride_col); + cube_inline!(scope, output_position = row_index + col_index); + cube_inline!(scope, output_position += offset_output); - gpu!(scope, out[output_position] = result); + cube_inline!(scope, out[output_position] = result); } diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs index e09060bb7a..72b77694e3 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -6,7 +6,7 @@ use crate::{ OutputInfo, WorkgroupLaunch, }, element::JitElement, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{cube_inline, 
ComputeShader, Elem, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, tensor::JitTensor, Runtime, @@ -48,40 +48,40 @@ impl AdaptiveAvgPool2dBackwardComputeShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, grad_stride_0 = stride(grad, 0u32)); - gpu!(scope, grad_stride_1 = stride(grad, 1u32)); - gpu!(scope, grad_stride_2 = stride(grad, 2u32)); - gpu!(scope, grad_stride_3 = stride(grad, 3u32)); + cube_inline!(scope, grad_stride_0 = stride(grad, 0u32)); + cube_inline!(scope, grad_stride_1 = stride(grad, 1u32)); + cube_inline!(scope, grad_stride_2 = stride(grad, 2u32)); + cube_inline!(scope, grad_stride_3 = stride(grad, 3u32)); - gpu!(scope, grad_shape_2 = shape(grad, 2u32)); - gpu!(scope, grad_shape_3 = shape(grad, 3u32)); + cube_inline!(scope, grad_shape_2 = shape(grad, 2u32)); + cube_inline!(scope, grad_shape_3 = shape(grad, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let ih = scope.create_local(Elem::UInt); let iw = 
scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, ih = id / output_stride_2); - gpu!(scope, ih = ih % output_shape_2); + cube_inline!(scope, ih = id / output_stride_2); + cube_inline!(scope, ih = ih % output_shape_2); - gpu!(scope, iw = id / output_stride_3); - gpu!(scope, iw = iw % output_shape_3); + cube_inline!(scope, iw = id / output_stride_3); + cube_inline!(scope, iw = iw % output_shape_3); let oh_start = Self::start_index(scope, ih, output_shape_2, grad_shape_2); let oh_end = Self::end_index(scope, ih, output_shape_2, grad_shape_2); @@ -103,46 +103,46 @@ impl AdaptiveAvgPool2dBackwardComputeShader { let index_base = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt); let index = scope.create_local(Elem::UInt); - gpu!(scope, index_base = b * grad_stride_0); - gpu!(scope, index_tmp = c * grad_stride_1); - gpu!(scope, index_base += index_tmp); + cube_inline!(scope, index_base = b * grad_stride_0); + cube_inline!(scope, index_tmp = c * grad_stride_1); + cube_inline!(scope, index_base += index_tmp); - gpu!( + cube_inline!( scope, range(oh_start, oh_end).for_each(|oh, scope| { let ih_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2); let ih_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2); - gpu!(scope, contributed_h = ih >= ih_start); - gpu!(scope, contributed_tmp = ih < ih_end); - gpu!(scope, contributed_h = contributed_h && contributed_tmp); + cube_inline!(scope, contributed_h = ih >= ih_start); + cube_inline!(scope, contributed_tmp = ih < ih_end); + cube_inline!(scope, contributed_h = contributed_h && contributed_tmp); - gpu!(scope, 
if(contributed_h).then(|scope|{ - gpu!( + cube_inline!(scope, if(contributed_h).then(|scope|{ + cube_inline!( scope, range(ow_start, ow_end).for_each(|ow, scope| { let iw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3); let iw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3); - gpu!(scope, contributed_w = iw >= iw_start); - gpu!(scope, contributed_tmp = iw < iw_end); - gpu!(scope, contributed_w = contributed_w && contributed_tmp); + cube_inline!(scope, contributed_w = iw >= iw_start); + cube_inline!(scope, contributed_tmp = iw < iw_end); + cube_inline!(scope, contributed_w = contributed_w && contributed_tmp); - gpu!(scope, if(contributed_w).then(|scope|{ - gpu!(scope, count = ih_end - ih_start); - gpu!(scope, count_tmp = iw_end - iw_start); - gpu!(scope, count *= count_tmp); - gpu!(scope, count_float = cast(count)); + cube_inline!(scope, if(contributed_w).then(|scope|{ + cube_inline!(scope, count = ih_end - ih_start); + cube_inline!(scope, count_tmp = iw_end - iw_start); + cube_inline!(scope, count *= count_tmp); + cube_inline!(scope, count_float = cast(count)); - gpu!(scope, index = index_base); - gpu!(scope, index_tmp = oh * grad_stride_2); - gpu!(scope, index += index_tmp); - gpu!(scope, index_tmp = ow * grad_stride_3); - gpu!(scope, index += index_tmp); + cube_inline!(scope, index = index_base); + cube_inline!(scope, index_tmp = oh * grad_stride_2); + cube_inline!(scope, index += index_tmp); + cube_inline!(scope, index_tmp = ow * grad_stride_3); + cube_inline!(scope, index += index_tmp); - gpu!(scope, the_grad = grad[index]); - gpu!(scope, avg = the_grad / count_float); - gpu!(scope, grad_acc += avg); + cube_inline!(scope, the_grad = grad[index]); + cube_inline!(scope, avg = the_grad / count_float); + cube_inline!(scope, grad_acc += avg); })); }) ); @@ -150,7 +150,7 @@ impl AdaptiveAvgPool2dBackwardComputeShader { }) ); - gpu!(scope, output[id] = grad_acc); + cube_inline!(scope, output[id] = grad_acc); } fn start_index( @@ 
-164,12 +164,12 @@ impl AdaptiveAvgPool2dBackwardComputeShader { let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); - gpu!(scope, index = output_size_index * input_size); - gpu!(scope, numerator_float = cast(index)); - gpu!(scope, div = cast(output_size)); - gpu!(scope, div = numerator_float / div); - gpu!(scope, div = floor(div)); - gpu!(scope, index = cast(div)); + cube_inline!(scope, index = output_size_index * input_size); + cube_inline!(scope, numerator_float = cast(index)); + cube_inline!(scope, div = cast(output_size)); + cube_inline!(scope, div = numerator_float / div); + cube_inline!(scope, div = floor(div)); + cube_inline!(scope, index = cast(div)); index } @@ -186,19 +186,19 @@ impl AdaptiveAvgPool2dBackwardComputeShader { let min = scope.create_local(Elem::Bool); let end_index = scope.create_local(Elem::UInt); - gpu!(scope, index = output_size_index + 1u32); - gpu!(scope, index *= input_size); - gpu!(scope, numerator_float = cast(index)); - gpu!(scope, div = cast(output_size)); - gpu!(scope, div = numerator_float / div); - gpu!(scope, div = ceil(div)); - gpu!(scope, index = cast(div)); - - gpu!(scope, min = input_size < index); - gpu!(scope, if(min).then(|scope|{ - gpu!(scope, end_index = input_size); + cube_inline!(scope, index = output_size_index + 1u32); + cube_inline!(scope, index *= input_size); + cube_inline!(scope, numerator_float = cast(index)); + cube_inline!(scope, div = cast(output_size)); + cube_inline!(scope, div = numerator_float / div); + cube_inline!(scope, div = ceil(div)); + cube_inline!(scope, index = cast(div)); + + cube_inline!(scope, min = input_size < index); + cube_inline!(scope, if(min).then(|scope|{ + cube_inline!(scope, end_index = input_size); }).else(|scope|{ - gpu!(scope, end_index = index); + cube_inline!(scope, end_index = index); })); end_index } diff --git a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs index 
ec9669e534..e6b4cd66f1 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::{ codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo}, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, JitElement, Runtime, }; @@ -40,42 +40,42 @@ impl AdaptivePool2dComputeShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, input_shape_0 = shape(input, 2u32)); - gpu!(scope, input_shape_1 = shape(input, 3u32)); - gpu!(scope, input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_shape_0 = shape(input, 2u32)); + cube_inline!(scope, input_shape_1 = shape(input, 3u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = 
stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let oh = scope.create_local(Elem::UInt); let ow = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, oh = id / output_stride_2); - gpu!(scope, oh = oh % output_shape_2); + cube_inline!(scope, oh = id / output_stride_2); + cube_inline!(scope, oh = oh % output_shape_2); - gpu!(scope, ow = id / output_stride_3); - gpu!(scope, ow = ow % output_shape_3); + cube_inline!(scope, ow = id / output_stride_3); + cube_inline!(scope, ow = ow % output_shape_3); let ih_start = Self::start_index(scope, oh, output_shape_2, input_shape_2); let ih_end = Self::end_index(scope, oh, output_shape_2, input_shape_2); @@ -90,28 +90,28 @@ impl AdaptivePool2dComputeShader { let index_input_2 = scope.create_local(Elem::UInt); let index_input_3 = scope.create_local(Elem::UInt); - gpu!(scope, index_input_0 = b * input_stride_0); - gpu!(scope, index_input_1 = c * input_stride_1); + cube_inline!(scope, index_input_0 = b * input_stride_0); + cube_inline!(scope, index_input_1 = c * input_stride_1); let sum = scope.zero(output.item()); - gpu!( + cube_inline!( scope, range(ih_start, ih_end).for_each(|ih, scope| { 
- gpu!( + cube_inline!( scope, range(iw_start, iw_end).for_each(|iw, scope| { - gpu!(scope, index_input_2 = ih * input_stride_2); - gpu!(scope, index_input_3 = iw * input_stride_3); + cube_inline!(scope, index_input_2 = ih * input_stride_2); + cube_inline!(scope, index_input_3 = iw * input_stride_3); - gpu!(scope, index_input = index_input_0); - gpu!(scope, index_input += index_input_1); - gpu!(scope, index_input += index_input_2); - gpu!(scope, index_input += index_input_3); + cube_inline!(scope, index_input = index_input_0); + cube_inline!(scope, index_input += index_input_1); + cube_inline!(scope, index_input += index_input_2); + cube_inline!(scope, index_input += index_input_3); - gpu!(scope, result = input[index_input]); + cube_inline!(scope, result = input[index_input]); - gpu!(scope, sum += result); + cube_inline!(scope, sum += result); }) ); }) @@ -122,13 +122,13 @@ impl AdaptivePool2dComputeShader { let count_float = scope.create_local(output.item()); let avg = scope.create_local(output.item()); - gpu!(scope, count = ih_end - ih_start); - gpu!(scope, count_tmp = iw_end - iw_start); - gpu!(scope, count *= count_tmp); + cube_inline!(scope, count = ih_end - ih_start); + cube_inline!(scope, count_tmp = iw_end - iw_start); + cube_inline!(scope, count *= count_tmp); - gpu!(scope, count_float = cast(count)); - gpu!(scope, avg = sum / count_float); - gpu!(scope, output[id] = avg); + cube_inline!(scope, count_float = cast(count)); + cube_inline!(scope, avg = sum / count_float); + cube_inline!(scope, output[id] = avg); } fn start_index( @@ -142,12 +142,12 @@ impl AdaptivePool2dComputeShader { let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); - gpu!(scope, index = output_size_index * input_size); - gpu!(scope, numerator_float = cast(index)); - gpu!(scope, div = cast(output_size)); - gpu!(scope, div = numerator_float / div); - gpu!(scope, div = floor(div)); - gpu!(scope, index = cast(div)); + cube_inline!(scope, index = output_size_index 
* input_size); + cube_inline!(scope, numerator_float = cast(index)); + cube_inline!(scope, div = cast(output_size)); + cube_inline!(scope, div = numerator_float / div); + cube_inline!(scope, div = floor(div)); + cube_inline!(scope, index = cast(div)); index } @@ -164,19 +164,19 @@ impl AdaptivePool2dComputeShader { let min = scope.create_local(Elem::Bool); let end_index = scope.create_local(Elem::UInt); - gpu!(scope, index = output_size_index + 1u32); - gpu!(scope, index *= input_size); - gpu!(scope, numerator_float = cast(index)); - gpu!(scope, div = cast(output_size)); - gpu!(scope, div = numerator_float / div); - gpu!(scope, div = ceil(div)); - gpu!(scope, index = cast(div)); - - gpu!(scope, min = input_size < index); - gpu!(scope, if(min).then(|scope|{ - gpu!(scope, end_index = input_size); + cube_inline!(scope, index = output_size_index + 1u32); + cube_inline!(scope, index *= input_size); + cube_inline!(scope, numerator_float = cast(index)); + cube_inline!(scope, div = cast(output_size)); + cube_inline!(scope, div = numerator_float / div); + cube_inline!(scope, div = ceil(div)); + cube_inline!(scope, index = cast(div)); + + cube_inline!(scope, min = input_size < index); + cube_inline!(scope, if(min).then(|scope|{ + cube_inline!(scope, end_index = input_size); }).else(|scope|{ - gpu!(scope, end_index = index); + cube_inline!(scope, end_index = index); })); end_index } diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d.rs index 6ff6d5cc1e..7b7fb65600 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d.rs @@ -1,7 +1,7 @@ use crate::{ codegen::{dialect::gpu::Variable, EagerHandle, Execution, WorkgroupLaunch}, element::JitElement, - gpu::{gpu, Elem, Item, Scope}, + gpu::{cube_inline, Elem, Item, Scope}, ops::numeric::empty_device, tensor::JitTensor, Runtime, @@ -25,10 +25,10 @@ impl PoolStrategy for AvgPool { let count = scope.create_local(Elem::UInt); if 
self.count_include_pad { let kernel_size: Variable = (self.kernel_size[0] * self.kernel_size[1]).into(); - gpu!(scope, count = kernel_size); + cube_inline!(scope, count = kernel_size); } else { let zero: Variable = 0u32.into(); - gpu!(scope, count = zero); + cube_inline!(scope, count = zero); } (sum, count) } @@ -43,9 +43,9 @@ impl PoolStrategy for AvgPool { let (sum, count) = accumulator; if !self.count_include_pad { let one: Variable = 1u32.into(); - gpu!(scope, count += one); + cube_inline!(scope, count += one); } - gpu!(scope, sum += result); + cube_inline!(scope, sum += result); (sum, count) } @@ -60,9 +60,9 @@ impl PoolStrategy for AvgPool { let (sum, count) = accumulator; let avg = scope.create_local(output.item()); let count_float = scope.create_local(output.item()); - gpu!(scope, count_float = cast(count)); - gpu!(scope, avg = sum / count_float); - gpu!(scope, output[id] = avg); + cube_inline!(scope, count_float = cast(count)); + cube_inline!(scope, avg = sum / count_float); + cube_inline!(scope, output[id] = avg); } fn with_indices() -> bool { diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index 1b8cac774b..7fc266533e 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, IntKind, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, IntKind, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -52,23 +52,23 @@ impl AvgPool2dBackwardComputeShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, grad_stride_0 = stride(grad, 0u32)); - gpu!(scope, grad_stride_1 = stride(grad, 1u32)); - gpu!(scope, grad_stride_2 = stride(grad, 2u32)); - gpu!(scope, grad_stride_3 = 
stride(grad, 3u32)); + cube_inline!(scope, grad_stride_0 = stride(grad, 0u32)); + cube_inline!(scope, grad_stride_1 = stride(grad, 1u32)); + cube_inline!(scope, grad_stride_2 = stride(grad, 2u32)); + cube_inline!(scope, grad_stride_3 = stride(grad, 3u32)); - gpu!(scope, grad_shape_2 = shape(grad, 2u32)); - gpu!(scope, grad_shape_3 = shape(grad, 3u32)); + cube_inline!(scope, grad_shape_2 = shape(grad, 2u32)); + cube_inline!(scope, grad_shape_3 = shape(grad, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); @@ -81,24 +81,24 @@ impl AvgPool2dBackwardComputeShader { let ih = scope.create_local(Elem::UInt); let iw = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % 
output_shape_1); - gpu!(scope, ih = id / output_stride_2); - gpu!(scope, ih = ih % output_shape_2); + cube_inline!(scope, ih = id / output_stride_2); + cube_inline!(scope, ih = ih % output_shape_2); - gpu!(scope, iw = id / output_stride_3); - gpu!(scope, iw = iw % output_shape_3); + cube_inline!(scope, iw = id / output_stride_3); + cube_inline!(scope, iw = iw % output_shape_3); let index_current = scope.create_local(Elem::UInt); let index_current_tmp = scope.create_local(Elem::UInt); - gpu!(scope, index_current = ih * output_stride_2); - gpu!(scope, index_current_tmp = iw * output_stride_3); - gpu!(scope, index_current += index_current_tmp); + cube_inline!(scope, index_current = ih * output_stride_2); + cube_inline!(scope, index_current_tmp = iw * output_stride_3); + cube_inline!(scope, index_current += index_current_tmp); let index = scope.create_local(Elem::UInt); let index_tmp = scope.create_local(Elem::UInt); @@ -111,7 +111,7 @@ impl AvgPool2dBackwardComputeShader { let count_include_pad = self.count_include_pad; if count_include_pad { let kernel_size: Variable = (self.kernel_size[0] * self.kernel_size[1]).into(); - gpu!(scope, count = kernel_size); + cube_inline!(scope, count = kernel_size); } let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges( @@ -124,9 +124,9 @@ impl AvgPool2dBackwardComputeShader { output_stride_3, ); - gpu!(scope, index_base = b * grad_stride_0); - gpu!(scope, index_tmp = c * grad_stride_1); - gpu!(scope, index_base += index_tmp); + cube_inline!(scope, index_base = b * grad_stride_0); + cube_inline!(scope, index_tmp = c * grad_stride_1); + cube_inline!(scope, index_base += index_tmp); let border_bottom = scope.create_local(Elem::UInt); let border_right = scope.create_local(Elem::UInt); @@ -140,67 +140,67 @@ impl AvgPool2dBackwardComputeShader { let before_end = scope.create_local(Elem::Bool); let contributed_h = scope.create_local(Elem::Bool); let contributed_w = scope.create_local(Elem::Bool); - gpu!(scope, border_bottom = 
output_shape_2 + padding_0); - gpu!(scope, border_right = output_shape_3 + padding_1); - gpu!(scope, begin_h = ih + padding_0); - gpu!(scope, begin_w = iw + padding_1); + cube_inline!(scope, border_bottom = output_shape_2 + padding_0); + cube_inline!(scope, border_right = output_shape_3 + padding_1); + cube_inline!(scope, begin_h = ih + padding_0); + cube_inline!(scope, begin_w = iw + padding_1); let ih_diff = scope.create_local(Elem::UInt); let iw_diff = scope.create_local(Elem::UInt); let count_int = scope.create_local(Elem::UInt); - gpu!( + cube_inline!( scope, range(oh_start, oh_end).for_each(|oh, scope| { // Contributed h - gpu!(scope, ih_start = oh * pool_stride_0); - gpu!(scope, ih_end = ih_start + kernel_size_0); - gpu!(scope, ih_start = max(ih_start, padding_0)); - gpu!(scope, ih_end = min(ih_end, border_bottom)); - gpu!(scope, after_start = begin_h >= ih_start); - gpu!(scope, before_end = ih < ih_end); - gpu!(scope, contributed_h = after_start && before_end); + cube_inline!(scope, ih_start = oh * pool_stride_0); + cube_inline!(scope, ih_end = ih_start + kernel_size_0); + cube_inline!(scope, ih_start = max(ih_start, padding_0)); + cube_inline!(scope, ih_end = min(ih_end, border_bottom)); + cube_inline!(scope, after_start = begin_h >= ih_start); + cube_inline!(scope, before_end = ih < ih_end); + cube_inline!(scope, contributed_h = after_start && before_end); if !count_include_pad { - gpu!(scope, ih_diff = ih_end - ih_start); + cube_inline!(scope, ih_diff = ih_end - ih_start); } - gpu!(scope, if(contributed_h).then(|scope|{ - gpu!( + cube_inline!(scope, if(contributed_h).then(|scope|{ + cube_inline!( scope, range(ow_start, ow_end).for_each(|ow, scope| { - gpu!(scope, index = index_base); - gpu!(scope, index_tmp = oh * grad_stride_2); - gpu!(scope, index += index_tmp); - gpu!(scope, index_tmp = ow * grad_stride_3); - gpu!(scope, index += index_tmp); + cube_inline!(scope, index = index_base); + cube_inline!(scope, index_tmp = oh * grad_stride_2); + 
cube_inline!(scope, index += index_tmp); + cube_inline!(scope, index_tmp = ow * grad_stride_3); + cube_inline!(scope, index += index_tmp); // Contributed w - gpu!(scope, iw_start = ow * pool_stride_1); - gpu!(scope, iw_end = iw_start + kernel_size_1); - gpu!(scope, iw_start = max(iw_start, padding_1)); - gpu!(scope, iw_end = min(iw_end, border_right)); - gpu!(scope, after_start = begin_w >= iw_start); - gpu!(scope, before_end = iw < iw_end); - gpu!(scope, contributed_w = after_start && before_end); - - gpu!(scope, if(contributed_w).then(|scope|{ + cube_inline!(scope, iw_start = ow * pool_stride_1); + cube_inline!(scope, iw_end = iw_start + kernel_size_1); + cube_inline!(scope, iw_start = max(iw_start, padding_1)); + cube_inline!(scope, iw_end = min(iw_end, border_right)); + cube_inline!(scope, after_start = begin_w >= iw_start); + cube_inline!(scope, before_end = iw < iw_end); + cube_inline!(scope, contributed_w = after_start && before_end); + + cube_inline!(scope, if(contributed_w).then(|scope|{ if !count_include_pad { - gpu!(scope, iw_diff = iw_end - iw_start); - gpu!(scope, count_int = ih_diff * iw_diff); - gpu!(scope, count = cast(count_int)); + cube_inline!(scope, iw_diff = iw_end - iw_start); + cube_inline!(scope, count_int = ih_diff * iw_diff); + cube_inline!(scope, count = cast(count_int)); } - gpu!(scope, result = grad[index]); - gpu!(scope, result = result / count); - gpu!(scope, grad_accumulation += result); + cube_inline!(scope, result = grad[index]); + cube_inline!(scope, result = result / count); + cube_inline!(scope, grad_accumulation += result); })); })); })); }) ); - gpu!(scope, output[id] = grad_accumulation); + cube_inline!(scope, output[id] = grad_accumulation); } #[allow(clippy::too_many_arguments)] @@ -235,53 +235,53 @@ impl AvgPool2dBackwardComputeShader { let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32)); let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, signed_pool_stride_0 = 
cast(pool_stride_0)); - gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1)); - gpu!(scope, signed_dilation_0 = cast(dilation_0)); - gpu!(scope, signed_dilation_1 = cast(dilation_1)); - gpu!(scope, signed_padding_0 = cast(padding_0)); - gpu!(scope, signed_padding_1 = cast(padding_1)); + cube_inline!(scope, signed_pool_stride_0 = cast(pool_stride_0)); + cube_inline!(scope, signed_pool_stride_1 = cast(pool_stride_1)); + cube_inline!(scope, signed_dilation_0 = cast(dilation_0)); + cube_inline!(scope, signed_dilation_1 = cast(dilation_1)); + cube_inline!(scope, signed_padding_0 = cast(padding_0)); + cube_inline!(scope, signed_padding_1 = cast(padding_1)); - gpu!(scope, signed_kernel_size_0 = cast(kernel_size_0)); - gpu!(scope, signed_kernel_size_1 = cast(kernel_size_1)); + cube_inline!(scope, signed_kernel_size_0 = cast(kernel_size_0)); + cube_inline!(scope, signed_kernel_size_1 = cast(kernel_size_1)); - gpu!(scope, signed_ih = cast(ih)); - gpu!(scope, signed_iw = cast(iw)); + cube_inline!(scope, signed_ih = cast(ih)); + cube_inline!(scope, signed_iw = cast(iw)); let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0); - gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0); + cube_inline!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0); + cube_inline!(scope, kms_0 = kms_0 - signed_pool_stride_0); - gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1); - gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1); + cube_inline!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1); + cube_inline!(scope, kms_1 = kms_1 - signed_pool_stride_1); let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0); - gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0); - gpu!(scope, oh_start_tmp = oh_start_tmp / 
signed_pool_stride_0); + cube_inline!(scope, oh_start_tmp = signed_ih + signed_padding_0); + cube_inline!(scope, oh_start_tmp = oh_start_tmp - kms_0); + cube_inline!(scope, oh_start_tmp = oh_start_tmp / signed_pool_stride_0); - gpu!(scope, ow_start_tmp = signed_iw + signed_padding_1); - gpu!(scope, ow_start_tmp = ow_start_tmp - kms_1); - gpu!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1); + cube_inline!(scope, ow_start_tmp = signed_iw + signed_padding_1); + cube_inline!(scope, ow_start_tmp = ow_start_tmp - kms_1); + cube_inline!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1); - gpu!(scope, oh_start_tmp = max(oh_start_tmp, 0i32)); - gpu!(scope, ow_start_tmp = max(ow_start_tmp, 0i32)); + cube_inline!(scope, oh_start_tmp = max(oh_start_tmp, 0i32)); + cube_inline!(scope, ow_start_tmp = max(ow_start_tmp, 0i32)); let oh_start = scope.create_local(Elem::UInt); let ow_start = scope.create_local(Elem::UInt); - gpu!(scope, oh_start = cast(oh_start_tmp)); - gpu!(scope, ow_start = cast(ow_start_tmp)); + cube_inline!(scope, oh_start = cast(oh_start_tmp)); + cube_inline!(scope, ow_start = cast(ow_start_tmp)); let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, oh_end_tmp = max(kms_0, 0i32)); - gpu!(scope, ow_end_tmp = max(kms_1, 0i32)); + cube_inline!(scope, oh_end_tmp = max(kms_0, 0i32)); + cube_inline!(scope, ow_end_tmp = max(kms_1, 0i32)); let oh_end = scope.create_local(Elem::UInt); let ow_end = scope.create_local(Elem::UInt); @@ -289,27 +289,27 @@ impl AvgPool2dBackwardComputeShader { let oh_end_limit = scope.create_local(Elem::UInt); let ow_end_limit = scope.create_local(Elem::UInt); - gpu!(scope, oh_end = cast(oh_end_tmp)); - gpu!(scope, ow_end = cast(ow_end_tmp)); + cube_inline!(scope, oh_end = cast(oh_end_tmp)); + cube_inline!(scope, ow_end = cast(ow_end_tmp)); - gpu!(scope, oh_end = oh_end + oh_start); - gpu!(scope, oh_end_limit = grad_shape_2 - 1u32); + 
cube_inline!(scope, oh_end = oh_end + oh_start); + cube_inline!(scope, oh_end_limit = grad_shape_2 - 1u32); - gpu!(scope, ow_end = ow_end + ow_start); - gpu!(scope, ow_end_limit = grad_shape_3 - 1u32); + cube_inline!(scope, ow_end = ow_end + ow_start); + cube_inline!(scope, ow_end_limit = grad_shape_3 - 1u32); - gpu!(scope, oh_end = min(oh_end, oh_end_limit)); - gpu!(scope, ow_end = min(ow_end, ow_end_limit)); + cube_inline!(scope, oh_end = min(oh_end, oh_end_limit)); + cube_inline!(scope, ow_end = min(ow_end, ow_end_limit)); let index_current = scope.create_local(Elem::UInt); let index_current_tmp = scope.create_local(Elem::UInt); - gpu!(scope, index_current = ih * output_stride_2); - gpu!(scope, index_current_tmp = iw * output_stride_3); - gpu!(scope, index_current += index_current_tmp); + cube_inline!(scope, index_current = ih * output_stride_2); + cube_inline!(scope, index_current_tmp = iw * output_stride_3); + cube_inline!(scope, index_current += index_current_tmp); - gpu!(scope, oh_end = oh_end + 1u32); - gpu!(scope, ow_end = ow_end + 1u32); + cube_inline!(scope, oh_end = oh_end + 1u32); + cube_inline!(scope, ow_end = ow_end + 1u32); (oh_start, oh_end, ow_start, ow_end) } diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d.rs b/crates/burn-jit/src/kernel/pool/max_pool2d.rs index b1d871071d..115982ff85 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d.rs @@ -3,7 +3,7 @@ use std::{fmt::Debug, marker::PhantomData}; use crate::{ codegen::{dialect::gpu::Variable, EagerHandle, Execution, WorkgroupLaunch}, element::JitElement, - gpu::{gpu, Elem, Item, Scope}, + gpu::{cube_inline, Elem, Item, Scope}, ops::numeric::empty_device, tensor::JitTensor, Runtime, @@ -24,7 +24,7 @@ impl PoolStrategy for MaxPool { let max_val = scope.create_local(item); let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem()); - gpu!(scope, max_val = max_initial); + cube_inline!(scope, max_val = 
max_initial); max_val } @@ -36,9 +36,9 @@ impl PoolStrategy for MaxPool { _idx: Variable, ) -> Self::Accumulator { let is_max = scope.create_local(Elem::Bool); - gpu!(scope, is_max = result > accumulator); - gpu!(scope, if(is_max).then(|scope|{ - gpu!(scope, accumulator = result); + cube_inline!(scope, is_max = result > accumulator); + cube_inline!(scope, if(is_max).then(|scope|{ + cube_inline!(scope, accumulator = result); })); accumulator } @@ -51,7 +51,7 @@ impl PoolStrategy for MaxPool { _indices: Option, accumulator: Self::Accumulator, ) { - gpu!(scope, output[id] = accumulator); + cube_inline!(scope, output[id] = accumulator); } fn with_indices() -> bool { @@ -71,7 +71,7 @@ impl PoolStrategy for MaxPoolWithIndices { let max_val = scope.create_local(item); let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem()); - gpu!(scope, max_val = max_initial); + cube_inline!(scope, max_val = max_initial); let max_index = scope.create_local(Elem::UInt); (max_val, max_index) } @@ -84,10 +84,10 @@ impl PoolStrategy for MaxPoolWithIndices { idx: Variable, ) -> Self::Accumulator { let is_max = scope.create_local(Elem::Bool); - gpu!(scope, is_max = result > max_val); - gpu!(scope, if(is_max).then(|scope|{ - gpu!(scope, max_val = result); - gpu!(scope, max_index = idx); + cube_inline!(scope, is_max = result > max_val); + cube_inline!(scope, if(is_max).then(|scope|{ + cube_inline!(scope, max_val = result); + cube_inline!(scope, max_index = idx); })); (max_val, max_index) } @@ -101,8 +101,8 @@ impl PoolStrategy for MaxPoolWithIndices { (max_val, max_index): Self::Accumulator, ) { let indices = indices.unwrap(); - gpu!(scope, output[id] = max_val); - gpu!(scope, indices[id] = max_index); + cube_inline!(scope, output[id] = max_val); + cube_inline!(scope, indices[id] = max_index); } fn with_indices() -> bool { diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 
9b86c5e56f..a6c2b641a2 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, IntKind, Item, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -52,47 +52,47 @@ impl MaxPool2dBackwardComputeShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, grad_stride_0 = stride(grad, 0u32)); - gpu!(scope, grad_stride_1 = stride(grad, 1u32)); - gpu!(scope, grad_stride_2 = stride(grad, 2u32)); - gpu!(scope, grad_stride_3 = stride(grad, 3u32)); + cube_inline!(scope, grad_stride_0 = stride(grad, 0u32)); + cube_inline!(scope, grad_stride_1 = stride(grad, 1u32)); + cube_inline!(scope, grad_stride_2 = stride(grad, 2u32)); + cube_inline!(scope, grad_stride_3 = stride(grad, 3u32)); - gpu!(scope, grad_shape_2 = shape(grad, 2u32)); - gpu!(scope, grad_shape_3 = shape(grad, 3u32)); + cube_inline!(scope, grad_shape_2 = shape(grad, 2u32)); + cube_inline!(scope, grad_shape_3 = shape(grad, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, 
output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); let ih = scope.create_local(Elem::UInt); let iw = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, ih = id / output_stride_2); - gpu!(scope, ih = ih % output_shape_2); + cube_inline!(scope, ih = id / output_stride_2); + cube_inline!(scope, ih = ih % output_shape_2); - gpu!(scope, iw = id / output_stride_3); - gpu!(scope, iw = iw % output_shape_3); + cube_inline!(scope, iw = id / output_stride_3); + cube_inline!(scope, iw = iw % output_shape_3); let index_current = scope.create_local(Elem::UInt); let index_current_tmp = scope.create_local(Elem::UInt); - gpu!(scope, index_current = ih * output_stride_2); - gpu!(scope, index_current_tmp = iw * output_stride_3); - gpu!(scope, index_current += index_current_tmp); + cube_inline!(scope, index_current = ih * output_stride_2); + cube_inline!(scope, index_current_tmp = iw * output_stride_3); + cube_inline!(scope, index_current += index_current_tmp); let index_select = scope.create_local(Elem::Int(IntKind::I32)); @@ -116,37 +116,37 @@ impl MaxPool2dBackwardComputeShader { output_stride_3, ); - gpu!(scope, index_base = b * grad_stride_0); - gpu!(scope, index_tmp = c * grad_stride_1); - gpu!(scope, index_base += index_tmp); + cube_inline!(scope, index_base = b * grad_stride_0); + cube_inline!(scope, index_tmp = c * grad_stride_1); + cube_inline!(scope, index_base += index_tmp); - gpu!( + 
cube_inline!( scope, range(oh_start, oh_end).for_each(|oh, scope| { - gpu!( + cube_inline!( scope, range(ow_start, ow_end).for_each(|ow, scope| { - gpu!(scope, index = index_base); - gpu!(scope, index_tmp = oh * grad_stride_2); - gpu!(scope, index += index_tmp); - gpu!(scope, index_tmp = ow * grad_stride_3); - gpu!(scope, index += index_tmp); + cube_inline!(scope, index = index_base); + cube_inline!(scope, index_tmp = oh * grad_stride_2); + cube_inline!(scope, index += index_tmp); + cube_inline!(scope, index_tmp = ow * grad_stride_3); + cube_inline!(scope, index += index_tmp); - gpu!(scope, index_select = indices[index]); - gpu!(scope, index_max = cast(index_select)); + cube_inline!(scope, index_select = indices[index]); + cube_inline!(scope, index_max = cast(index_select)); - gpu!(scope, is_max = index_max == index_current); + cube_inline!(scope, is_max = index_max == index_current); - gpu!(scope, if(is_max).then(|scope|{ - gpu!(scope, result = grad[index]); - gpu!(scope, grad_accumulation += result); + cube_inline!(scope, if(is_max).then(|scope|{ + cube_inline!(scope, result = grad[index]); + cube_inline!(scope, grad_accumulation += result); })); }) ); }) ); - gpu!(scope, output[id] = grad_accumulation); + cube_inline!(scope, output[id] = grad_accumulation); } #[allow(clippy::too_many_arguments)] @@ -181,53 +181,53 @@ impl MaxPool2dBackwardComputeShader { let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32)); let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0)); - gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1)); - gpu!(scope, signed_dilation_0 = cast(dilation_0)); - gpu!(scope, signed_dilation_1 = cast(dilation_1)); - gpu!(scope, signed_padding_0 = cast(padding_0)); - gpu!(scope, signed_padding_1 = cast(padding_1)); + cube_inline!(scope, signed_pool_stride_0 = cast(pool_stride_0)); + cube_inline!(scope, signed_pool_stride_1 = cast(pool_stride_1)); + 
cube_inline!(scope, signed_dilation_0 = cast(dilation_0)); + cube_inline!(scope, signed_dilation_1 = cast(dilation_1)); + cube_inline!(scope, signed_padding_0 = cast(padding_0)); + cube_inline!(scope, signed_padding_1 = cast(padding_1)); - gpu!(scope, signed_kernel_size_0 = cast(kernel_size_0)); - gpu!(scope, signed_kernel_size_1 = cast(kernel_size_1)); + cube_inline!(scope, signed_kernel_size_0 = cast(kernel_size_0)); + cube_inline!(scope, signed_kernel_size_1 = cast(kernel_size_1)); - gpu!(scope, signed_ih = cast(ih)); - gpu!(scope, signed_iw = cast(iw)); + cube_inline!(scope, signed_ih = cast(ih)); + cube_inline!(scope, signed_iw = cast(iw)); let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0); - gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0); + cube_inline!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0); + cube_inline!(scope, kms_0 = kms_0 - signed_pool_stride_0); - gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1); - gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1); + cube_inline!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1); + cube_inline!(scope, kms_1 = kms_1 - signed_pool_stride_1); let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0); - gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0); - gpu!(scope, oh_start_tmp = oh_start_tmp / signed_pool_stride_0); + cube_inline!(scope, oh_start_tmp = signed_ih + signed_padding_0); + cube_inline!(scope, oh_start_tmp = oh_start_tmp - kms_0); + cube_inline!(scope, oh_start_tmp = oh_start_tmp / signed_pool_stride_0); - gpu!(scope, ow_start_tmp = signed_iw + signed_padding_1); - gpu!(scope, ow_start_tmp = ow_start_tmp - kms_1); - gpu!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1); + cube_inline!(scope, ow_start_tmp = 
signed_iw + signed_padding_1); + cube_inline!(scope, ow_start_tmp = ow_start_tmp - kms_1); + cube_inline!(scope, ow_start_tmp = ow_start_tmp / signed_pool_stride_1); - gpu!(scope, oh_start_tmp = max(oh_start_tmp, 0i32)); - gpu!(scope, ow_start_tmp = max(ow_start_tmp, 0i32)); + cube_inline!(scope, oh_start_tmp = max(oh_start_tmp, 0i32)); + cube_inline!(scope, ow_start_tmp = max(ow_start_tmp, 0i32)); let oh_start = scope.create_local(Elem::UInt); let ow_start = scope.create_local(Elem::UInt); - gpu!(scope, oh_start = cast(oh_start_tmp)); - gpu!(scope, ow_start = cast(ow_start_tmp)); + cube_inline!(scope, oh_start = cast(oh_start_tmp)); + cube_inline!(scope, ow_start = cast(ow_start_tmp)); let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); - gpu!(scope, oh_end_tmp = max(kms_0, 0i32)); - gpu!(scope, ow_end_tmp = max(kms_1, 0i32)); + cube_inline!(scope, oh_end_tmp = max(kms_0, 0i32)); + cube_inline!(scope, ow_end_tmp = max(kms_1, 0i32)); let oh_end = scope.create_local(Elem::UInt); let ow_end = scope.create_local(Elem::UInt); @@ -235,27 +235,27 @@ impl MaxPool2dBackwardComputeShader { let oh_end_limit = scope.create_local(Elem::UInt); let ow_end_limit = scope.create_local(Elem::UInt); - gpu!(scope, oh_end = cast(oh_end_tmp)); - gpu!(scope, ow_end = cast(ow_end_tmp)); + cube_inline!(scope, oh_end = cast(oh_end_tmp)); + cube_inline!(scope, ow_end = cast(ow_end_tmp)); - gpu!(scope, oh_end = oh_end + oh_start); - gpu!(scope, oh_end_limit = grad_shape_2 - 1u32); + cube_inline!(scope, oh_end = oh_end + oh_start); + cube_inline!(scope, oh_end_limit = grad_shape_2 - 1u32); - gpu!(scope, ow_end = ow_end + ow_start); - gpu!(scope, ow_end_limit = grad_shape_3 - 1u32); + cube_inline!(scope, ow_end = ow_end + ow_start); + cube_inline!(scope, ow_end_limit = grad_shape_3 - 1u32); - gpu!(scope, oh_end = min(oh_end, oh_end_limit)); - gpu!(scope, ow_end = min(ow_end, ow_end_limit)); + cube_inline!(scope, oh_end = 
min(oh_end, oh_end_limit)); + cube_inline!(scope, ow_end = min(ow_end, ow_end_limit)); let index_current = scope.create_local(Elem::UInt); let index_current_tmp = scope.create_local(Elem::UInt); - gpu!(scope, index_current = ih * output_stride_2); - gpu!(scope, index_current_tmp = iw * output_stride_3); - gpu!(scope, index_current += index_current_tmp); + cube_inline!(scope, index_current = ih * output_stride_2); + cube_inline!(scope, index_current_tmp = iw * output_stride_3); + cube_inline!(scope, index_current += index_current_tmp); - gpu!(scope, oh_end = oh_end + 1u32); - gpu!(scope, ow_end = ow_end + 1u32); + cube_inline!(scope, oh_end = oh_end + 1u32); + cube_inline!(scope, ow_end = ow_end + 1u32); (oh_start, oh_end, ow_start, ow_end) } diff --git a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs index 64101b7146..e4c511df70 100644 --- a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::{ codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo}, - gpu::{gpu, ComputeShader, Elem, IntKind, Item, Scope, Variable, Visibility}, + gpu::{cube_inline, ComputeShader, Elem, IntKind, Item, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, JitElement, Runtime, }; @@ -45,25 +45,25 @@ impl Pool2dComputeShader { let output_shape_2 = scope.create_local(Elem::UInt); let output_shape_3 = scope.create_local(Elem::UInt); - gpu!(scope, input_stride_0 = stride(input, 0u32)); - gpu!(scope, input_stride_1 = stride(input, 1u32)); - gpu!(scope, input_stride_2 = stride(input, 2u32)); - gpu!(scope, input_stride_3 = stride(input, 3u32)); + cube_inline!(scope, input_stride_0 = stride(input, 0u32)); + cube_inline!(scope, input_stride_1 = stride(input, 1u32)); + cube_inline!(scope, input_stride_2 = stride(input, 2u32)); + cube_inline!(scope, input_stride_3 = stride(input, 3u32)); - gpu!(scope, 
input_shape_0 = shape(input, 2u32)); - gpu!(scope, input_shape_1 = shape(input, 3u32)); - gpu!(scope, input_shape_2 = shape(input, 2u32)); - gpu!(scope, input_shape_3 = shape(input, 3u32)); + cube_inline!(scope, input_shape_0 = shape(input, 2u32)); + cube_inline!(scope, input_shape_1 = shape(input, 3u32)); + cube_inline!(scope, input_shape_2 = shape(input, 2u32)); + cube_inline!(scope, input_shape_3 = shape(input, 3u32)); - gpu!(scope, output_stride_0 = stride(output, 0u32)); - gpu!(scope, output_stride_1 = stride(output, 1u32)); - gpu!(scope, output_stride_2 = stride(output, 2u32)); - gpu!(scope, output_stride_3 = stride(output, 3u32)); + cube_inline!(scope, output_stride_0 = stride(output, 0u32)); + cube_inline!(scope, output_stride_1 = stride(output, 1u32)); + cube_inline!(scope, output_stride_2 = stride(output, 2u32)); + cube_inline!(scope, output_stride_3 = stride(output, 3u32)); - gpu!(scope, output_shape_0 = shape(output, 0u32)); - gpu!(scope, output_shape_1 = shape(output, 1u32)); - gpu!(scope, output_shape_2 = shape(output, 2u32)); - gpu!(scope, output_shape_3 = shape(output, 3u32)); + cube_inline!(scope, output_shape_0 = shape(output, 0u32)); + cube_inline!(scope, output_shape_1 = shape(output, 1u32)); + cube_inline!(scope, output_shape_2 = shape(output, 2u32)); + cube_inline!(scope, output_shape_3 = shape(output, 3u32)); let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); @@ -77,17 +77,17 @@ impl Pool2dComputeShader { let oh = scope.create_local(Elem::UInt); let ow = scope.create_local(Elem::UInt); - gpu!(scope, b = id / output_stride_0); - gpu!(scope, b = b % output_shape_0); + cube_inline!(scope, b = id / output_stride_0); + cube_inline!(scope, b = b % output_shape_0); - gpu!(scope, c = id / output_stride_1); - gpu!(scope, c = c % output_shape_1); + cube_inline!(scope, c = id / output_stride_1); + cube_inline!(scope, c = c % output_shape_1); - gpu!(scope, oh = id / output_stride_2); - 
gpu!(scope, oh = oh % output_shape_2); + cube_inline!(scope, oh = id / output_stride_2); + cube_inline!(scope, oh = oh % output_shape_2); - gpu!(scope, ow = id / output_stride_3); - gpu!(scope, ow = ow % output_shape_3); + cube_inline!(scope, ow = id / output_stride_3); + cube_inline!(scope, ow = ow % output_shape_3); let ih = scope.create_local(Elem::UInt); let iw = scope.create_local(Elem::UInt); @@ -110,48 +110,48 @@ impl Pool2dComputeShader { let border_bottom = scope.create_local(Elem::UInt); let border_right = scope.create_local(Elem::UInt); - gpu!(scope, border_bottom = input_shape_2 + padding_0); - gpu!(scope, border_right = input_shape_3 + padding_1); + cube_inline!(scope, border_bottom = input_shape_2 + padding_0); + cube_inline!(scope, border_right = input_shape_3 + padding_1); - gpu!(scope, index_input_0 = b * input_stride_0); - gpu!(scope, index_input_1 = c * input_stride_1); + cube_inline!(scope, index_input_0 = b * input_stride_0); + cube_inline!(scope, index_input_1 = c * input_stride_1); let accumulator = self.pool_strategy.initialize(scope, input.item()); (0..self.kernel_size[0]).for_each(|kh| { - gpu!(scope, ih = oh * pool_stride_0); - gpu!(scope, dilated = kh * dilation_0); - gpu!(scope, ih += dilated); + cube_inline!(scope, ih = oh * pool_stride_0); + cube_inline!(scope, dilated = kh * dilation_0); + cube_inline!(scope, ih += dilated); - gpu!(scope, within_padding_h = ih >= padding_0); - gpu!(scope, tmp_padding = ih < border_bottom); - gpu!(scope, within_padding_h = within_padding_h && tmp_padding); + cube_inline!(scope, within_padding_h = ih >= padding_0); + cube_inline!(scope, tmp_padding = ih < border_bottom); + cube_inline!(scope, within_padding_h = within_padding_h && tmp_padding); - gpu!(scope, if (within_padding_h).then(|scope| { + cube_inline!(scope, if (within_padding_h).then(|scope| { (0..self.kernel_size[1]).for_each(|kw| { - gpu!(scope, iw = ow * pool_stride_1); - gpu!(scope, dilated = kw * dilation_1); - gpu!(scope, iw += dilated); 
+ cube_inline!(scope, iw = ow * pool_stride_1); + cube_inline!(scope, dilated = kw * dilation_1); + cube_inline!(scope, iw += dilated); - gpu!(scope, within_padding_w = iw >= padding_1); - gpu!(scope, tmp_padding = iw < border_right); - gpu!(scope, within_padding_w = within_padding_w && tmp_padding); + cube_inline!(scope, within_padding_w = iw >= padding_1); + cube_inline!(scope, tmp_padding = iw < border_right); + cube_inline!(scope, within_padding_w = within_padding_w && tmp_padding); - gpu!(scope, if (within_padding_w).then(|scope| { - gpu!(scope, ih_pad = ih - padding_0); - gpu!(scope, iw_pad = iw - padding_1); + cube_inline!(scope, if (within_padding_w).then(|scope| { + cube_inline!(scope, ih_pad = ih - padding_0); + cube_inline!(scope, iw_pad = iw - padding_1); - gpu!(scope, index_input_2 = ih_pad * input_stride_2); - gpu!(scope, idx = index_input_2); - gpu!(scope, idx += iw_pad); - gpu!(scope, index_input_3 = iw_pad * input_stride_3); + cube_inline!(scope, index_input_2 = ih_pad * input_stride_2); + cube_inline!(scope, idx = index_input_2); + cube_inline!(scope, idx += iw_pad); + cube_inline!(scope, index_input_3 = iw_pad * input_stride_3); - gpu!(scope, index_input = index_input_0); - gpu!(scope, index_input += index_input_1); - gpu!(scope, index_input += index_input_2); - gpu!(scope, index_input += index_input_3); + cube_inline!(scope, index_input = index_input_0); + cube_inline!(scope, index_input += index_input_1); + cube_inline!(scope, index_input += index_input_2); + cube_inline!(scope, index_input += index_input_3); - gpu!(scope, result = input[index_input]); + cube_inline!(scope, result = input[index_input]); self.pool_strategy.process_result(scope, accumulator, result, idx); })); diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index 0a32f74129..08687e67f8 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -6,7 +6,7 @@ use crate::{ OutputInfo, 
WorkgroupLaunch, }, compute::WorkGroup, - gpu::{gpu, ComputeShader, Elem, Scope, Variable}, + gpu::{cube_inline, ComputeShader, Elem, Scope, Variable}, kernel::{GpuComputeShaderPhase, WORKGROUP_DEFAULT}, tensor::JitTensor, JitElement, Runtime, SEED, @@ -175,41 +175,41 @@ impl, E: JitElement> PrngShader { let local_index = Variable::LocalInvocationIndex; let n_invocations = scope.create_local(Elem::UInt); - gpu!(scope, n_invocations = workgroup_size_x); - gpu!(scope, n_invocations *= workgroup_size_y); + cube_inline!(scope, n_invocations = workgroup_size_x); + cube_inline!(scope, n_invocations *= workgroup_size_y); let workgroup_offset = scope.create_local(Elem::UInt); - gpu!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y); - gpu!(scope, workgroup_offset += workgroup_id_y); - gpu!(scope, workgroup_offset *= n_invocations); + cube_inline!(scope, workgroup_offset = workgroup_id_x * num_workgroups_y); + cube_inline!(scope, workgroup_offset += workgroup_id_y); + cube_inline!(scope, workgroup_offset *= n_invocations); let write_index_base = scope.create_local(Elem::UInt); - gpu!(scope, write_index_base = workgroup_offset); - gpu!(scope, write_index_base *= n_values_per_thread); - gpu!(scope, write_index_base += local_index); + cube_inline!(scope, write_index_base = workgroup_offset); + cube_inline!(scope, write_index_base *= n_values_per_thread); + cube_inline!(scope, write_index_base += local_index); // Set state with unique seeds let thread_seed = scope.create_local(Elem::UInt); - gpu!(scope, thread_seed = cast(1000000007)); + cube_inline!(scope, thread_seed = cast(1000000007)); let thread_seed_index = scope.create_local(Elem::UInt); - gpu!(scope, thread_seed_index = workgroup_offset + local_index); - gpu!(scope, thread_seed *= thread_seed_index); + cube_inline!(scope, thread_seed_index = workgroup_offset + local_index); + cube_inline!(scope, thread_seed *= thread_seed_index); let state_0 = scope.create_local(Elem::UInt); - gpu!(scope, state_0 = 
thread_seed); - gpu!(scope, state_0 += seed_0); + cube_inline!(scope, state_0 = thread_seed); + cube_inline!(scope, state_0 += seed_0); let state_1 = scope.create_local(Elem::UInt); - gpu!(scope, state_1 = thread_seed); - gpu!(scope, state_1 += seed_1); + cube_inline!(scope, state_1 = thread_seed); + cube_inline!(scope, state_1 += seed_1); let state_2 = scope.create_local(Elem::UInt); - gpu!(scope, state_2 = thread_seed); - gpu!(scope, state_2 += seed_2); + cube_inline!(scope, state_2 = thread_seed); + cube_inline!(scope, state_2 += seed_2); let state_3 = scope.create_local(Elem::UInt); - gpu!(scope, state_3 = thread_seed); - gpu!(scope, state_3 += seed_3); + cube_inline!(scope, state_3 = thread_seed); + cube_inline!(scope, state_3 += seed_3); // Creation of n_values_per_thread values, specific to the distribution P::inner_loop( @@ -269,25 +269,25 @@ fn taus_step( m: Variable, ) { let b = scope.create_local(Elem::UInt); - gpu!(scope, b = z << s1); - gpu!(scope, b = b ^ z); - gpu!(scope, b = b >> s2); - gpu!(scope, z = z & m); - gpu!(scope, z = z << s3); - gpu!(scope, z = z ^ b); + cube_inline!(scope, b = z << s1); + cube_inline!(scope, b = b ^ z); + cube_inline!(scope, b = b >> s2); + cube_inline!(scope, z = z & m); + cube_inline!(scope, z = z << s3); + cube_inline!(scope, z = z ^ b); } pub(crate) fn lcg_step(scope: &mut Scope, z: Variable) { let a: Variable = 1664525u32.into(); let b: Variable = 1013904223u32.into(); - gpu!(scope, z *= a); - gpu!(scope, z += b); + cube_inline!(scope, z *= a); + cube_inline!(scope, z += b); } pub(crate) fn cast_uint_to_float(scope: &mut Scope, int_random: Variable, float_random: Variable) { let tmp: Variable = 2.328_306_4e-10.into(); - gpu!(scope, float_random = cast(int_random)); - gpu!(scope, float_random *= tmp); + cube_inline!(scope, float_random = cast(int_random)); + cube_inline!(scope, float_random *= tmp); } #[allow(missing_docs)] diff --git a/crates/burn-jit/src/kernel/prng/bernoulli.rs 
b/crates/burn-jit/src/kernel/prng/bernoulli.rs index df21278efa..9049b04503 100644 --- a/crates/burn-jit/src/kernel/prng/bernoulli.rs +++ b/crates/burn-jit/src/kernel/prng/bernoulli.rs @@ -1,7 +1,7 @@ use burn_tensor::Shape; use crate::{ - gpu::{gpu, Elem, Scope, Variable}, + gpu::{cube_inline, Elem, Scope, Variable}, kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, tensor::JitTensor, JitElement, Runtime, @@ -31,7 +31,7 @@ impl Prng for Bernoulli { output: Variable, ) { let prob = args[0]; - gpu!( + cube_inline!( scope, range(0u32, n_values_per_thread).for_each(|i, scope| { taus_step_0(scope, state_0); @@ -40,20 +40,20 @@ impl Prng for Bernoulli { lcg_step(scope, state_3); let int_random = scope.create_local(Elem::UInt); - gpu!(scope, int_random = state_0 ^ state_1); - gpu!(scope, int_random = int_random ^ state_2); - gpu!(scope, int_random = int_random ^ state_3); + cube_inline!(scope, int_random = state_0 ^ state_1); + cube_inline!(scope, int_random = int_random ^ state_2); + cube_inline!(scope, int_random = int_random ^ state_3); let float_random = scope.create_local(E::gpu_elem()); cast_uint_to_float(scope, int_random, float_random); let bernoulli = scope.create_local(Elem::Bool); - gpu!(scope, bernoulli = float_random < prob); + cube_inline!(scope, bernoulli = float_random < prob); let write_index = scope.create_local(Elem::UInt); - gpu!(scope, write_index = i * n_invocations); - gpu!(scope, write_index += write_index_base); - gpu!(scope, output[write_index] = bernoulli); + cube_inline!(scope, write_index = i * n_invocations); + cube_inline!(scope, write_index += write_index_base); + cube_inline!(scope, output[write_index] = bernoulli); }) ); } diff --git a/crates/burn-jit/src/kernel/prng/normal.rs b/crates/burn-jit/src/kernel/prng/normal.rs index b8f62fb1d1..5b175fd84f 100644 --- a/crates/burn-jit/src/kernel/prng/normal.rs +++ b/crates/burn-jit/src/kernel/prng/normal.rs @@ -3,7 +3,7 @@ use std::f32::consts::PI; use 
burn_tensor::Shape; use crate::{ - gpu::{gpu, Elem, Scope, Variable}, + gpu::{cube_inline, Elem, Scope, Variable}, kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, tensor::JitTensor, JitElement, Runtime, @@ -41,7 +41,7 @@ impl Prng for Normal { let t_neg = scope.create_with_value(-2.0, item); let two: Variable = 2u32.into(); - gpu!( + cube_inline!( scope, range(0u32, n_values_per_thread / 2).for_each(|i, scope| { let int_random = scope.create_local(Elem::UInt); @@ -52,9 +52,9 @@ impl Prng for Normal { taus_step_2(scope, state_2); lcg_step(scope, state_3); - gpu!(scope, int_random = state_0 ^ state_1); - gpu!(scope, int_random = int_random ^ state_2); - gpu!(scope, int_random = int_random ^ state_3); + cube_inline!(scope, int_random = state_0 ^ state_1); + cube_inline!(scope, int_random = int_random ^ state_2); + cube_inline!(scope, int_random = int_random ^ state_3); let unit_0 = scope.create_local(elem); cast_uint_to_float(scope, int_random, unit_0); @@ -65,44 +65,44 @@ impl Prng for Normal { taus_step_2(scope, state_2); lcg_step(scope, state_3); - gpu!(scope, int_random = state_0 ^ state_1); - gpu!(scope, int_random = int_random ^ state_2); - gpu!(scope, int_random = int_random ^ state_3); + cube_inline!(scope, int_random = state_0 ^ state_1); + cube_inline!(scope, int_random = int_random ^ state_2); + cube_inline!(scope, int_random = int_random ^ state_3); let unit_1 = scope.create_local(elem); cast_uint_to_float(scope, int_random, unit_1); // Box-Muller transform let coeff = scope.create_local(item); - gpu!(scope, coeff = log(unit_0)); - gpu!(scope, coeff *= t_neg); - gpu!(scope, coeff = sqrt(coeff)); - gpu!(scope, coeff *= std); + cube_inline!(scope, coeff = log(unit_0)); + cube_inline!(scope, coeff *= t_neg); + cube_inline!(scope, coeff = sqrt(coeff)); + cube_inline!(scope, coeff *= std); let trigo_arg = scope.create_local(item); - gpu!(scope, trigo_arg = two_pi * unit_1); + cube_inline!(scope, trigo_arg = two_pi * unit_1); 
let normal_0 = scope.create_local(item); let normal_1 = scope.create_local(item); - gpu!(scope, normal_0 = cos(trigo_arg)); - gpu!(scope, normal_0 *= coeff); - gpu!(scope, normal_0 += mean); - gpu!(scope, normal_1 = sin(trigo_arg)); - gpu!(scope, normal_1 *= coeff); - gpu!(scope, normal_1 += mean); + cube_inline!(scope, normal_0 = cos(trigo_arg)); + cube_inline!(scope, normal_0 *= coeff); + cube_inline!(scope, normal_0 += mean); + cube_inline!(scope, normal_1 = sin(trigo_arg)); + cube_inline!(scope, normal_1 *= coeff); + cube_inline!(scope, normal_1 += mean); // Write to output let write_index_0 = scope.create_local(Elem::UInt); let write_index_1 = scope.create_local(Elem::UInt); let iteration_offset = scope.create_local(Elem::UInt); - gpu!(scope, write_index_0 = write_index_base); - gpu!(scope, iteration_offset = two * i); - gpu!(scope, iteration_offset *= n_invocations); - gpu!(scope, write_index_0 += iteration_offset); - gpu!(scope, write_index_1 = write_index_0 + n_invocations); - - gpu!(scope, output[write_index_0] = normal_0); - gpu!(scope, output[write_index_1] = normal_1); + cube_inline!(scope, write_index_0 = write_index_base); + cube_inline!(scope, iteration_offset = two * i); + cube_inline!(scope, iteration_offset *= n_invocations); + cube_inline!(scope, write_index_0 += iteration_offset); + cube_inline!(scope, write_index_1 = write_index_0 + n_invocations); + + cube_inline!(scope, output[write_index_0] = normal_0); + cube_inline!(scope, output[write_index_1] = normal_1); }) ); } diff --git a/crates/burn-jit/src/kernel/prng/uniform.rs b/crates/burn-jit/src/kernel/prng/uniform.rs index 06d206cf1e..e29e6e5ac9 100644 --- a/crates/burn-jit/src/kernel/prng/uniform.rs +++ b/crates/burn-jit/src/kernel/prng/uniform.rs @@ -1,7 +1,7 @@ use burn_tensor::Shape; use crate::{ - gpu::{gpu, Elem, Scope, Variable}, + gpu::{cube_inline, Elem, Scope, Variable}, kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, tensor::JitTensor, 
JitElement, Runtime, @@ -36,9 +36,9 @@ impl Prng for Uniform { let lower_bound = args[0]; let upper_bound = args[1]; let scale = scope.create_local(item); - gpu!(scope, scale = upper_bound - lower_bound); + cube_inline!(scope, scale = upper_bound - lower_bound); - gpu!( + cube_inline!( scope, range(0u32, n_values_per_thread).for_each(|i, scope| { taus_step_0(scope, state_0); @@ -47,25 +47,25 @@ impl Prng for Uniform { lcg_step(scope, state_3); let int_random = scope.create_local(Elem::UInt); - gpu!(scope, int_random = state_0 ^ state_1); - gpu!(scope, int_random = int_random ^ state_2); - gpu!(scope, int_random = int_random ^ state_3); + cube_inline!(scope, int_random = state_0 ^ state_1); + cube_inline!(scope, int_random = int_random ^ state_2); + cube_inline!(scope, int_random = int_random ^ state_3); let float_random = scope.create_local(elem); let float_scale = scope.create_local(elem); cast_uint_to_float(scope, int_random, float_random); - gpu!(scope, float_scale = cast(scale)); + cube_inline!(scope, float_scale = cast(scale)); let uniform_float = scope.create_local(elem); let uniform = scope.create_local(item); - gpu!(scope, uniform_float = float_random * float_scale); - gpu!(scope, uniform = cast(uniform_float)); - gpu!(scope, uniform += lower_bound); + cube_inline!(scope, uniform_float = float_random * float_scale); + cube_inline!(scope, uniform = cast(uniform_float)); + cube_inline!(scope, uniform += lower_bound); let write_index = scope.create_local(Elem::UInt); - gpu!(scope, write_index = i * n_invocations); - gpu!(scope, write_index += write_index_base); - gpu!(scope, output[write_index] = uniform); + cube_inline!(scope, write_index = i * n_invocations); + cube_inline!(scope, write_index += write_index_base); + cube_inline!(scope, output[write_index] = uniform); }) ); } diff --git a/crates/burn-jit/src/kernel/reduce/argmax_dim.rs b/crates/burn-jit/src/kernel/reduce/argmax_dim.rs index c207ed5b22..86ec2eaf5c 100644 --- 
a/crates/burn-jit/src/kernel/reduce/argmax_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/argmax_dim.rs @@ -1,5 +1,5 @@ use crate::{ - codegen::dialect::gpu::{gpu, Elem, Item, Scope, Variable}, + codegen::dialect::gpu::{cube_inline, Elem, Item, Scope, Variable}, JitElement, }; @@ -19,7 +19,7 @@ impl ReduceDimAlgorithm for ArgMax { let max = scope.create_local(input_item); let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem()); - gpu!(scope, max = max_initial); + cube_inline!(scope, max = max_initial); (max, index) } @@ -31,10 +31,10 @@ impl ReduceDimAlgorithm for ArgMax { i: Variable, ) { let condition = scope.create_local(Elem::Bool); - gpu!(scope, condition = value > max); - gpu!(scope, if(condition).then(|scope| { - gpu!(scope, max = value); - gpu!(scope, index = i); + cube_inline!(scope, condition = value > max); + cube_inline!(scope, if(condition).then(|scope| { + cube_inline!(scope, max = value); + cube_inline!(scope, index = i); })); } @@ -45,7 +45,7 @@ impl ReduceDimAlgorithm for ArgMax { _shape_reduce_dim: Variable, ) { let id = Variable::Id; - gpu!(scope, output[id] = index); + cube_inline!(scope, output[id] = index); } fn initialize_shared( @@ -58,7 +58,7 @@ impl ReduceDimAlgorithm for ArgMax { let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); let max = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem()); - gpu!(scope, value_shared_memory[write_position] = max); + cube_inline!(scope, value_shared_memory[write_position] = max); (value_shared_memory, index_shared_memory) } @@ -70,13 +70,13 @@ impl ReduceDimAlgorithm for ArgMax { ) { let (value_shared_memory, index_shared_memory) = shared_memory; let current_value = scope.create_local(value.item()); - gpu!(scope, current_value = value_shared_memory[write_position]); + cube_inline!(scope, current_value = value_shared_memory[write_position]); let condition = scope.create_local(Elem::Bool); - gpu!(scope, 
condition = value > current_value); - gpu!(scope, if(condition).then(|scope| { - gpu!(scope, value_shared_memory[write_position] = value); - gpu!(scope, index_shared_memory[write_position] = index); + cube_inline!(scope, condition = value > current_value); + cube_inline!(scope, if(condition).then(|scope| { + cube_inline!(scope, value_shared_memory[write_position] = value); + cube_inline!(scope, index_shared_memory[write_position] = index); })); } @@ -87,7 +87,7 @@ impl ReduceDimAlgorithm for ArgMax { i: Variable, ) -> Self::Accumulator { let value = scope.create_local(input.item()); - gpu!(scope, value = input[read_position]); + cube_inline!(scope, value = input[read_position]); (value, i) } @@ -98,9 +98,9 @@ impl ReduceDimAlgorithm for ArgMax { ) -> Self::Accumulator { let (value_shared_memory, index_shared_memory) = shared_memory; let value = scope.create_local(value_shared_memory.item()); - gpu!(scope, value = value_shared_memory[read_position]); + cube_inline!(scope, value = value_shared_memory[read_position]); let index = scope.create_local(index_shared_memory.item()); - gpu!(scope, index = index_shared_memory[read_position]); + cube_inline!(scope, index = index_shared_memory[read_position]); (value, index) } @@ -113,7 +113,7 @@ impl ReduceDimAlgorithm for ArgMax { ) { let (_, index_shared_memory) = shared_memory; let final_value = scope.create_local(output.item()); - gpu!(scope, final_value = index_shared_memory[0]); - gpu!(scope, output[write_position] = final_value); + cube_inline!(scope, final_value = index_shared_memory[0]); + cube_inline!(scope, output[write_position] = final_value); } } diff --git a/crates/burn-jit/src/kernel/reduce/argmin_dim.rs b/crates/burn-jit/src/kernel/reduce/argmin_dim.rs index 139dc4b960..c87d4a3b75 100644 --- a/crates/burn-jit/src/kernel/reduce/argmin_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/argmin_dim.rs @@ -1,5 +1,5 @@ use crate::{ - codegen::dialect::gpu::{gpu, Elem, Item, Scope, Variable}, + 
codegen::dialect::gpu::{cube_inline, Elem, Item, Scope, Variable}, JitElement, }; @@ -19,7 +19,7 @@ impl ReduceDimAlgorithm for ArgMin { let min = scope.create_local(input_item); let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem()); - gpu!(scope, min = min_initial); + cube_inline!(scope, min = min_initial); (min, index) } @@ -31,10 +31,10 @@ impl ReduceDimAlgorithm for ArgMin { i: Variable, ) { let condition = scope.create_local(Elem::Bool); - gpu!(scope, condition = value < min); - gpu!(scope, if(condition).then(|scope| { - gpu!(scope, min = value); - gpu!(scope, index = i); + cube_inline!(scope, condition = value < min); + cube_inline!(scope, if(condition).then(|scope| { + cube_inline!(scope, min = value); + cube_inline!(scope, index = i); })); } @@ -45,7 +45,7 @@ impl ReduceDimAlgorithm for ArgMin { _shape_reduce_dim: Variable, ) { let id = Variable::Id; - gpu!(scope, output[id] = index); + cube_inline!(scope, output[id] = index); } fn initialize_shared( @@ -58,7 +58,7 @@ impl ReduceDimAlgorithm for ArgMin { let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); let min = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem()); - gpu!(scope, value_shared_memory[write_position] = min); + cube_inline!(scope, value_shared_memory[write_position] = min); (value_shared_memory, index_shared_memory) } @@ -70,13 +70,13 @@ impl ReduceDimAlgorithm for ArgMin { ) { let (value_shared_memory, index_shared_memory) = shared_memory; let current_value = scope.create_local(value.item()); - gpu!(scope, current_value = value_shared_memory[write_position]); + cube_inline!(scope, current_value = value_shared_memory[write_position]); let condition = scope.create_local(Elem::Bool); - gpu!(scope, condition = value < current_value); - gpu!(scope, if(condition).then(|scope| { - gpu!(scope, value_shared_memory[write_position] = value); - gpu!(scope, index_shared_memory[write_position] = index); 
+ cube_inline!(scope, condition = value < current_value); + cube_inline!(scope, if(condition).then(|scope| { + cube_inline!(scope, value_shared_memory[write_position] = value); + cube_inline!(scope, index_shared_memory[write_position] = index); })); } @@ -87,7 +87,7 @@ impl ReduceDimAlgorithm for ArgMin { i: Variable, ) -> Self::Accumulator { let value = scope.create_local(input.item()); - gpu!(scope, value = input[read_position]); + cube_inline!(scope, value = input[read_position]); (value, i) } @@ -98,9 +98,9 @@ impl ReduceDimAlgorithm for ArgMin { ) -> Self::Accumulator { let (value_shared_memory, index_shared_memory) = shared_memory; let value = scope.create_local(value_shared_memory.item()); - gpu!(scope, value = value_shared_memory[read_position]); + cube_inline!(scope, value = value_shared_memory[read_position]); let index = scope.create_local(index_shared_memory.item()); - gpu!(scope, index = index_shared_memory[read_position]); + cube_inline!(scope, index = index_shared_memory[read_position]); (value, index) } @@ -113,7 +113,7 @@ impl ReduceDimAlgorithm for ArgMin { ) { let (_, index_shared_memory) = shared_memory; let final_value = scope.create_local(output.item()); - gpu!(scope, final_value = index_shared_memory[0]); - gpu!(scope, output[write_position] = final_value); + cube_inline!(scope, final_value = index_shared_memory[0]); + cube_inline!(scope, output[write_position] = final_value); } } diff --git a/crates/burn-jit/src/kernel/reduce/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/mean_dim.rs index e547bab8a8..8d4887936f 100644 --- a/crates/burn-jit/src/kernel/reduce/mean_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/mean_dim.rs @@ -1,5 +1,5 @@ use crate::{ - codegen::dialect::gpu::{gpu, Item, Scope, Variable}, + codegen::dialect::gpu::{cube_inline, Item, Scope, Variable}, JitElement, }; @@ -15,7 +15,7 @@ impl ReduceDimAlgorithm for MeanDim { } fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { - 
gpu!(scope, accumulator += value); + cube_inline!(scope, accumulator += value); } fn assign_naive( @@ -26,9 +26,9 @@ impl ReduceDimAlgorithm for MeanDim { ) { let id = Variable::Id; let denominator = scope.create_local(accumulator.item()); - gpu!(scope, denominator = cast(shape_reduce_dim)); - gpu!(scope, accumulator = accumulator / denominator); - gpu!(scope, output[id] = accumulator); + cube_inline!(scope, denominator = cast(shape_reduce_dim)); + cube_inline!(scope, accumulator = accumulator / denominator); + cube_inline!(scope, output[id] = accumulator); } fn initialize_shared( @@ -39,7 +39,7 @@ impl ReduceDimAlgorithm for MeanDim { ) -> Self::Accumulator { let shared_memory = scope.create_shared(input_item, shared_memory_size); let neutral_element = scope.zero(shared_memory.item()); - gpu!(scope, shared_memory[write_position] = neutral_element); + cube_inline!(scope, shared_memory[write_position] = neutral_element); shared_memory } @@ -51,9 +51,9 @@ impl ReduceDimAlgorithm for MeanDim { ) { let current_value = scope.create_local(value.item()); let computed = scope.create_local(value.item()); - gpu!(scope, current_value = shared_memory[write_position]); - gpu!(scope, computed = current_value + value); - gpu!(scope, shared_memory[write_position] = computed); + cube_inline!(scope, current_value = shared_memory[write_position]); + cube_inline!(scope, computed = current_value + value); + cube_inline!(scope, shared_memory[write_position] = computed); } fn read_from_input( @@ -63,7 +63,7 @@ impl ReduceDimAlgorithm for MeanDim { _i: Variable, ) -> Self::Accumulator { let value = scope.create_local(input.item()); - gpu!(scope, value = input[read_position]); + cube_inline!(scope, value = input[read_position]); value } @@ -73,7 +73,7 @@ impl ReduceDimAlgorithm for MeanDim { read_position: Variable, ) -> Variable { let read_value = scope.create_local(shared_memory.item()); - gpu!(scope, read_value = shared_memory[read_position]); + cube_inline!(scope, read_value = 
shared_memory[read_position]); read_value } @@ -85,11 +85,11 @@ impl ReduceDimAlgorithm for MeanDim { shape_reduce_dim: Variable, ) { let final_value = scope.create_local(output.item()); - gpu!(scope, final_value = shared_memory[0]); + cube_inline!(scope, final_value = shared_memory[0]); let denominator = scope.create_local(output.item()); - gpu!(scope, denominator = cast(shape_reduce_dim)); - gpu!(scope, final_value = final_value / denominator); - gpu!(scope, output[write_position] = final_value); + cube_inline!(scope, denominator = cast(shape_reduce_dim)); + cube_inline!(scope, final_value = final_value / denominator); + cube_inline!(scope, output[write_position] = final_value); } } diff --git a/crates/burn-jit/src/kernel/reduce/naive_reduce_shader.rs b/crates/burn-jit/src/kernel/reduce/naive_reduce_shader.rs index a28c016ea8..cc68dce061 100644 --- a/crates/burn-jit/src/kernel/reduce/naive_reduce_shader.rs +++ b/crates/burn-jit/src/kernel/reduce/naive_reduce_shader.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{cube_inline, Elem, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -92,45 +92,45 @@ impl> NaiveReduceDimComputeShader ReduceDimAlgorithm for ProdDim { } fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { - gpu!(scope, accumulator *= value); + cube_inline!(scope, accumulator *= value); } fn assign_naive( @@ -25,7 +25,7 @@ impl ReduceDimAlgorithm for ProdDim { _shape_reduce_dim: Variable, ) { let id = Variable::Id; - gpu!(scope, output[id] = accumulator); + cube_inline!(scope, output[id] = accumulator); } fn initialize_shared( @@ -36,7 +36,7 @@ impl ReduceDimAlgorithm for ProdDim { ) -> Self::Accumulator { let shared_memory = scope.create_shared(input_item, shared_memory_size); let neutral_element = 
scope.create_with_value(1, shared_memory.item()); - gpu!(scope, shared_memory[write_position] = neutral_element); + cube_inline!(scope, shared_memory[write_position] = neutral_element); shared_memory } @@ -48,9 +48,9 @@ impl ReduceDimAlgorithm for ProdDim { ) { let current_value = scope.create_local(value.item()); let computed = scope.create_local(value.item()); - gpu!(scope, current_value = shared_memory[write_position]); - gpu!(scope, computed = current_value * value); - gpu!(scope, shared_memory[write_position] = computed); + cube_inline!(scope, current_value = shared_memory[write_position]); + cube_inline!(scope, computed = current_value * value); + cube_inline!(scope, shared_memory[write_position] = computed); } fn read_from_input( @@ -60,7 +60,7 @@ impl ReduceDimAlgorithm for ProdDim { _i: Variable, ) -> Self::Accumulator { let value = scope.create_local(input.item()); - gpu!(scope, value = input[read_position]); + cube_inline!(scope, value = input[read_position]); value } @@ -70,7 +70,7 @@ impl ReduceDimAlgorithm for ProdDim { read_position: Variable, ) -> Self::Accumulator { let read_value = scope.create_local(shared_memory.item()); - gpu!(scope, read_value = shared_memory[read_position]); + cube_inline!(scope, read_value = shared_memory[read_position]); read_value } @@ -82,7 +82,7 @@ impl ReduceDimAlgorithm for ProdDim { _shape_reduce_dim: Variable, ) { let final_value = scope.create_local(output.item()); - gpu!(scope, final_value = shared_memory[0]); - gpu!(scope, output[write_position] = final_value); + cube_inline!(scope, final_value = shared_memory[0]); + cube_inline!(scope, output[write_position] = final_value); } } diff --git a/crates/burn-jit/src/kernel/reduce/shared_reduce_shader.rs b/crates/burn-jit/src/kernel/reduce/shared_reduce_shader.rs index 567f5549bd..3d9af3b90f 100644 --- a/crates/burn-jit/src/kernel/reduce/shared_reduce_shader.rs +++ b/crates/burn-jit/src/kernel/reduce/shared_reduce_shader.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; 
use crate::{ codegen::{ dialect::gpu::{ - gpu, Branch, Elem, Scope, Synchronization, Variable, Visibility, WorkgroupSize, + cube_inline, Branch, Elem, Scope, Synchronization, Variable, Visibility, WorkgroupSize, }, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, @@ -124,41 +124,41 @@ impl> SharedReduceDimComputeShader> SharedReduceDimComputeShader> SharedReduceDimComputeShader> SharedReduceDimComputeShader ReduceDimAlgorithm for SumDim { } fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { - gpu!(scope, accumulator += value); + cube_inline!(scope, accumulator += value); } fn assign_naive( @@ -25,7 +25,7 @@ impl ReduceDimAlgorithm for SumDim { _shape_reduce_dim: Variable, ) { let id = Variable::Id; - gpu!(scope, output[id] = accumulator); + cube_inline!(scope, output[id] = accumulator); } fn initialize_shared( @@ -36,7 +36,7 @@ impl ReduceDimAlgorithm for SumDim { ) -> Self::Accumulator { let shared_memory = scope.create_shared(input_item, shared_memory_size); let neutral_element = scope.zero(shared_memory.item()); - gpu!(scope, shared_memory[write_position] = neutral_element); + cube_inline!(scope, shared_memory[write_position] = neutral_element); shared_memory } @@ -48,9 +48,9 @@ impl ReduceDimAlgorithm for SumDim { ) { let current_value = scope.create_local(value.item()); let computed = scope.create_local(value.item()); - gpu!(scope, current_value = shared_memory[write_position]); - gpu!(scope, computed = current_value + value); - gpu!(scope, shared_memory[write_position] = computed); + cube_inline!(scope, current_value = shared_memory[write_position]); + cube_inline!(scope, computed = current_value + value); + cube_inline!(scope, shared_memory[write_position] = computed); } fn read_from_input( @@ -60,7 +60,7 @@ impl ReduceDimAlgorithm for SumDim { _i: Variable, ) -> Self::Accumulator { let value = scope.create_local(input.item()); - gpu!(scope, value = 
input[read_position]); + cube_inline!(scope, value = input[read_position]); value } @@ -70,7 +70,7 @@ impl ReduceDimAlgorithm for SumDim { read_position: Variable, ) -> Self::Accumulator { let read_value = scope.create_local(shared_memory.item()); - gpu!(scope, read_value = shared_memory[read_position]); + cube_inline!(scope, read_value = shared_memory[read_position]); read_value } @@ -82,7 +82,7 @@ impl ReduceDimAlgorithm for SumDim { _shape_reduce_dim: Variable, ) { let final_value = scope.create_local(output.item()); - gpu!(scope, final_value = shared_memory[0]); - gpu!(scope, output[write_position] = final_value); + cube_inline!(scope, final_value = shared_memory[0]); + cube_inline!(scope, output[write_position] = final_value); } } From da92893331a43acfe56d2e88390f4e494eeb6059 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 10 May 2024 14:23:01 -0400 Subject: [PATCH 33/54] refactor --- crates/burn-cube-macros/src/analysis.rs | 9 +- crates/burn-cube-macros/src/codegen.rs | 104 ++++++---- crates/burn-cube-macros/src/lib.rs | 32 ++- crates/burn-cube-macros/src/prelude.rs | 191 ------------------ crates/burn-cube/src/element/array.rs | 4 +- crates/burn-cube/src/element/base.rs | 6 +- crates/burn-cube/src/element/bool.rs | 4 +- crates/burn-cube/src/element/float.rs | 24 ++- crates/burn-cube/src/element/int.rs | 25 ++- crates/burn-cube/src/element/mod.rs | 2 + crates/burn-cube/src/element/numeric.rs | 10 + crates/burn-cube/src/element/primitive.rs | 10 +- crates/burn-cube/src/element/uint.rs | 19 +- crates/burn-cube/src/elemtype.rs | 111 ++++++++++ crates/burn-cube/src/lib.rs | 6 +- crates/burn-cube/src/operation/assignation.rs | 10 +- crates/burn-cube/src/operation/binary.rs | 117 +++++++---- crates/burn-cube/tests/cast_elem.rs | 70 +++---- crates/burn-cube/tests/cast_kind.rs | 34 +++- crates/burn-cube/tests/for_loop.rs | 6 +- crates/burn-cube/tests/function_call.rs | 68 +++++++ crates/burn-cube/tests/generic_kernel.rs | 58 ++++++ crates/burn-cube/tests/if.rs | 13 
+- crates/burn-cube/tests/if_else.rs | 12 +- crates/burn-cube/tests/literal.rs | 14 +- crates/burn-cube/tests/loop.rs | 16 +- crates/burn-cube/tests/reuse.rs | 14 +- crates/burn-jit/src/lib.rs | 1 + 28 files changed, 576 insertions(+), 414 deletions(-) delete mode 100644 crates/burn-cube-macros/src/prelude.rs create mode 100644 crates/burn-cube/src/element/numeric.rs create mode 100644 crates/burn-cube/src/elemtype.rs create mode 100644 crates/burn-cube/tests/function_call.rs create mode 100644 crates/burn-cube/tests/generic_kernel.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 1f6cd9c41a..856d11dd24 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use syn::{PathArguments, Stmt}; @@ -23,7 +23,6 @@ impl VariableAnalysis { #[derive(Debug)] pub(crate) struct CodeAnalysis { - pub needed_functions: HashSet, pub variable_analyses: HashMap, } @@ -31,7 +30,6 @@ pub(crate) struct CodeAnalysis { pub(crate) struct CodeAnalysisBuilder { declarations: Vec<(VariableKey, usize)>, var_uses: Vec, - function_calls: HashSet, } impl CodeAnalysis { @@ -61,7 +59,6 @@ impl CodeAnalysisBuilder { CodeAnalysis { variable_analyses: self.to_map(), - needed_functions: self.function_calls, } } @@ -206,10 +203,6 @@ impl CodeAnalysisBuilder { match &*expr.func { syn::Expr::Path(expr_path) => { if let Some(first_segment) = expr_path.path.segments.first() { - // Extract the identifier of the path segment - let ident = &first_segment.ident; - self.function_calls.insert(ident.into()); - // Check if the path segment has generic arguments if let PathArguments::AngleBracketed(arguments) = &first_segment.arguments diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index b5b4d3e2da..74c38871d0 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ 
b/crates/burn-cube-macros/src/codegen.rs @@ -1,5 +1,6 @@ use proc_macro2::TokenStream; -use syn::PathArguments; +use quote::{quote_spanned, ToTokens}; +use syn::{Lit, PathArguments}; use crate::analysis::CodeAnalysis; @@ -28,32 +29,31 @@ fn codegen_local( loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { - let init = local - .init - .as_ref() - .expect("Can't use let without an initialization."); - - let init = codegen_expr(&init.expr, loop_level, variable_analyses); - let let_tok = local.let_token; - if let syn::Pat::Wild(_) = &local.pat { - return quote::quote! { - #let_tok _ = #init; - }; - } - let ident = match &local.pat { - syn::Pat::Ident(ident) => ident, + syn::Pat::Ident(ident) => ident.to_token_stream(), syn::Pat::Type(pat_type) => match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => pat_ident, + syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(), _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), }, - syn::Pat::Wild(_) => unreachable!(), + syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), }; - quote::quote! { - #let_tok #ident = #init; + + match local.init.as_ref() { + Some(init) => { + let init = codegen_expr(&init.expr, loop_level, variable_analyses); + + quote::quote! { + #let_tok #ident = #init; + } + } + None => { + quote::quote! { + #let_tok #ident; + } + } } } @@ -84,7 +84,17 @@ fn codegen_expr( } fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { - quote::quote! { #lit.into() } + match lit.lit { + // We treat floats differently to avoid getting 4..into() for instance + Lit::Float(_) => { + let lit_str = lit.lit.to_token_stream().to_string(); + let float_lit = lit_str.parse::().unwrap(); + quote::quote! { #float_lit.into() } + } + _ => { + quote::quote! 
{ #lit.into() } + } + } } fn codegen_expr_index( @@ -308,45 +318,57 @@ fn codegen_call( loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { - let (func_name, generics) = match call.func.as_ref() { + // Possibilities: + // a() + // a::() + // T::a + let (mut idents, generics) = match call.func.as_ref() { syn::Expr::Path(expr_path) => { - if let Some(first_segment) = expr_path.path.segments.first() { - // Extract the identifier of the path segment - let ident = &first_segment.ident; - let generics = - if let PathArguments::AngleBracketed(arguments) = &first_segment.arguments { - Some(arguments) - } else { - None - }; - - (ident, generics) - } else { - panic!("Codegen: func call must have an ident"); + let mut idents = Vec::new(); + let mut generics = None; + for (index, segment) in expr_path.path.segments.iter().enumerate() { + idents.push(&segment.ident); + + if index == expr_path.path.segments.len() - 1 { + if let PathArguments::AngleBracketed(arguments) = &segment.arguments { + generics = Some(arguments) + } + } } + (idents, generics) } _ => todo!("Codegen: func call {:?} not supported", call.func), }; - let mut args = quote::quote! { - context, - }; + let func_name = idents + .pop() + .expect("Codegen: Func should have at least one ident"); + + let mut previous_tokens = TokenStream::new(); + for ident in idents.iter() { + previous_tokens.extend(quote_spanned! {ident.span() => #ident :: }); + } + + let func_name_expand = syn::Ident::new( + format!("{func_name}_expand").as_str(), + proc_macro2::Span::call_site(), + ); let generics = match generics { Some(generics) => quote::quote! { #generics }, None => quote::quote! {}, }; - let func_name_expand = - syn::Ident::new(format!("{func_name}_expand").as_str(), func_name.span()); - + let mut args = quote::quote! { + context, + }; for argument in call.args.iter() { let arg = codegen_expr(argument, loop_level, variable_analyses); args.extend(quote::quote! { #arg, }); } quote::quote! 
{ - #func_name_expand #generics (#args) + #previous_tokens #func_name_expand #generics (#args) } } diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index b91f45935c..19265f68f8 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -1,10 +1,8 @@ mod analysis; mod codegen; -mod prelude; use analysis::CodeAnalysis; use codegen::codegen_statement; -use prelude::get_prelude; use proc_macro::TokenStream; use quote::ToTokens; use syn::{parse_macro_input, punctuated::Punctuated, Meta}; @@ -59,8 +57,8 @@ impl From<&syn::Ident> for VariableKey { /// Generate the expanded version of a function marked with the cube macro fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenStream { - let prelude = get_prelude(&code_analysis.needed_functions); - let mod_name = get_name(&func.sig); + // let prelude = get_prelude(&code_analysis.needed_functions); + // let mod_name = get_name(&func.sig); let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; @@ -70,8 +68,8 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt } let code = quote::quote! { - mod #mod_name { - #prelude + // mod #mod_name { + // #prelude #[allow(dead_code)] #func @@ -80,21 +78,21 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt #signature { #body } - } + // } } .into(); code } -fn get_name(sig: &syn::Signature) -> proc_macro2::TokenStream { - let ident = &sig.ident; +// fn get_name(sig: &syn::Signature) -> proc_macro2::TokenStream { +// let ident = &sig.ident; - quote::quote! { - #ident - } - .into() -} +// quote::quote! { +// #ident +// } +// .into() +// } fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let mut inputs = quote::quote!(); @@ -106,7 +104,7 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let ident = pat.pat.clone(); inputs.extend(quote::quote! 
{ - #ident: <#ty as burn_cube::RuntimeType>::ExpandType, + #ident: <#ty as burn_cube::CubeType>::ExpandType, }); } _ => todo!(), @@ -119,13 +117,13 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { syn::ReturnType::Default => output.extend(quote::quote! { ()}), syn::ReturnType::Type(_, ty) => { output.extend(quote::quote! { - <#ty as burn_cube::RuntimeType>::ExpandType + <#ty as burn_cube::CubeType>::ExpandType }); } } let ident = &sig.ident; - let ident = syn::Ident::new("expand", ident.span()); + let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); let generics = sig.generics.clone().into_token_stream(); diff --git a/crates/burn-cube-macros/src/prelude.rs b/crates/burn-cube-macros/src/prelude.rs deleted file mode 100644 index f69196d762..0000000000 --- a/crates/burn-cube-macros/src/prelude.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::collections::HashSet; - -use crate::VariableKey; - -pub(crate) fn get_prelude(needed_functions: &HashSet) -> proc_macro2::TokenStream { - let mut prelude = quote::quote! { - use super::*; - }; - - for func_name in needed_functions - .iter() - .map(|variable| variable.name.as_str()) - { - let func_code = match func_name { - "float_new" => Some(codegen_float_new()), - "int_new" => Some(codegen_int_new()), - "uint_new" => Some(codegen_uint_new()), - "bool_new" => Some(codegen_bool_new()), - "to_int" => Some(codegen_to_int()), - "to_float" => Some(codegen_to_float()), - "to_uint" => Some(codegen_to_uint()), - "to_bool" => Some(codegen_to_bool()), - _ => None, - }; - - if func_code.is_some() { - prelude.extend(func_code); - } - } - - prelude -} - -fn codegen_float_new() -> proc_macro2::TokenStream { - quote::quote! { - pub fn float_new(val: f32) -> F { - F::new(val, 1) - } - pub fn float_new_expand( - context: &mut CubeContext, - val: f32, - ) -> ::ExpandType { - F::new_expand(val) - } - } -} - -fn codegen_int_new() -> proc_macro2::TokenStream { - quote::quote! 
{ - pub fn int_new(val: i32) -> I { - I::new(val, 1) - } - pub fn int_new_expand( - context: &mut CubeContext, - val: i32, - ) -> ::ExpandType { - I::new_expand(val) - } - } -} - -fn codegen_uint_new() -> proc_macro2::TokenStream { - quote::quote! { - pub fn uint_new(val: u32) -> UInt { - UInt { - val, - vectorization: 1, - } - } - pub fn uint_new_expand( - context: &mut CubeContext, - val: u32, - ) -> ::ExpandType { - val.into() - } - } -} - -fn codegen_bool_new() -> proc_macro2::TokenStream { - quote::quote! { - pub fn bool_new(val: bool) -> Bool{ - Bool { - val, - vectorization: 1, - } - } - pub fn bool_new_expand( - context: &mut CubeContext, - val: bool, - ) -> ::ExpandType { - val.into() - } - } -} - -fn codegen_to_int() -> proc_macro2::TokenStream { - quote::quote! { - pub fn to_int(input: R) -> I { - I::new(0, 1) - } - pub fn to_int_expand( - context: &mut CubeContext, - val: burn_cube::ExpandElement, - ) -> ::ExpandType { - let elem = Elem::Int(I::into_kind()); - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); - burn_cube::assign::expand(context, val.into(), new_var.clone()); - new_var - } - } -} - -fn codegen_to_float() -> proc_macro2::TokenStream { - // R: type we come from - // F: kind of float we want as output - quote::quote! 
{ - pub fn to_float(input: R) -> F { - // TODO: make val and vectorization accessible through trait - F::new(0., 1) - } - pub fn to_float_expand( - context: &mut CubeContext, - val: burn_cube::ExpandElement, - ) -> burn_cube::ExpandElement { - let elem = Elem::Float(F::into_kind()); - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); - burn_cube::assign::expand(context, val.into(), new_var.clone()); - new_var - } - } -} - -fn codegen_to_uint() -> proc_macro2::TokenStream { - quote::quote! { - pub fn to_uint(input: R) -> UInt { - UInt { - val: 0, - vectorization: 1, - } - } - pub fn to_uint_expand( - context: &mut CubeContext, - val: burn_cube::ExpandElement, - ) -> ::ExpandType { - let elem = Elem::UInt; - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); - burn_cube::assign::expand(context, val.into(), new_var.clone()); - new_var - } - } -} - -fn codegen_to_bool() -> proc_macro2::TokenStream { - quote::quote! 
{ - pub fn to_bool(input: R) -> Bool { - Bool { - val: true, - vectorization: 1, - } - } - pub fn to_bool_expand( - context: &mut CubeContext, - val: burn_cube::ExpandElement, - ) -> ::ExpandType { - let elem = Elem::Bool; - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); - burn_cube::assign::expand(context, val.into(), new_var.clone()); - new_var - } - } -} diff --git a/crates/burn-cube/src/element/array.rs b/crates/burn-cube/src/element/array.rs index edab275011..234f6e0c45 100644 --- a/crates/burn-cube/src/element/array.rs +++ b/crates/burn-cube/src/element/array.rs @@ -1,10 +1,10 @@ -use crate::{ExpandElement, RuntimeType}; +use crate::{ExpandElement, CubeType}; #[derive(new, Clone)] pub struct Array { pub vals: Vec, } -impl RuntimeType for Array { +impl CubeType for Array { type ExpandType = ExpandElement; } diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs index a14cd24ea8..0b8ecb8c64 100644 --- a/crates/burn-cube/src/element/base.rs +++ b/crates/burn-cube/src/element/base.rs @@ -4,16 +4,16 @@ use burn_jit::gpu::{Item, Variable}; /// Types used in a cube function must implement this trait /// /// Variables whose values will be known at runtime must -/// have ExpandElement as associated type +/// have ExpandElement as associated type (using RuntimeType) /// Variables whose values will be known at compile time /// must have the primitive type as associated type /// -/// Note: Cube functions should be written using RuntimeTypes, +/// Note: Cube functions should be written using CubeTypes, /// so that the code generated uses the associated ExpandType. /// This allows Cube code to not necessitate cloning, which is cumbersome /// in algorithmic code. The necessary cloning will automatically appear in /// the generated code. 
-pub trait RuntimeType { +pub trait CubeType { type ExpandType: Clone; } diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs index dc3d0c01c3..c2f99c4c21 100644 --- a/crates/burn-cube/src/element/bool.rs +++ b/crates/burn-cube/src/element/bool.rs @@ -1,4 +1,4 @@ -use crate::{ExpandElement, RuntimeType}; +use crate::{CubeType, ExpandElement}; #[derive(new, Clone, Copy)] pub struct Bool { @@ -6,6 +6,6 @@ pub struct Bool { pub vectorization: u8, } -impl RuntimeType for Bool { +impl CubeType for Bool { type ExpandType = ExpandElement; } diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index c8efc03c06..bfab89d20d 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -1,31 +1,29 @@ -use crate::{ExpandElement, RuntimeType}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric}; use burn_jit::gpu::{Elem, FloatKind, Variable}; use std::rc::Rc; pub trait Float: Clone + Copy - + RuntimeType + std::cmp::PartialOrd + std::ops::Add + std::ops::Sub + std::ops::Mul + std::ops::Div + + Numeric { fn into_kind() -> FloatKind; - fn new(val: f32, vectorization: usize) -> Self; - fn new_expand(val: f32) -> ExpandElement; } macro_rules! impl_float { ($type:ident) => { #[derive(Clone, Copy)] pub struct $type { - pub val: f32, + pub val: f64, pub vectorization: usize, } - impl RuntimeType for $type { + impl CubeType for $type { type ExpandType = ExpandElement; } @@ -33,12 +31,18 @@ macro_rules! 
impl_float { fn into_kind() -> FloatKind { FloatKind::$type } - fn new(val: f32, vectorization: usize) -> Self { - Self { val, vectorization } + } + + impl Numeric for $type { + fn new(val: f64) -> Self { + Self { + val, + vectorization: 1, + } } - fn new_expand(val: f32) -> ExpandElement { + fn new_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { let elem = Elem::Float(Self::into_kind()); - let new_var = Variable::ConstantScalar(val as f64, elem); + let new_var = Variable::ConstantScalar(val, elem); ExpandElement::new(Rc::new(new_var)) } } diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index f880bc1742..2450dc3702 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -1,32 +1,29 @@ -use crate::{ExpandElement, RuntimeType}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric}; use burn_jit::gpu::{Elem, IntKind, Variable}; use std::rc::Rc; pub trait Int: Clone + Copy - + RuntimeType + std::cmp::PartialOrd - + std::ops::Add + std::ops::Sub + std::ops::Mul + std::ops::Div + std::ops::AddAssign + + Numeric { fn into_kind() -> IntKind; - fn new(val: i32, vectorization: usize) -> Self; - fn new_expand(val: i32) -> ExpandElement; } macro_rules! impl_int { ($type:ident) => { #[derive(Clone, Copy)] pub struct $type { - pub val: i32, + pub val: f64, pub vectorization: usize, } - impl RuntimeType for $type { + impl CubeType for $type { type ExpandType = ExpandElement; } @@ -34,12 +31,18 @@ macro_rules! 
impl_int { fn into_kind() -> IntKind { IntKind::$type } - fn new(val: i32, vectorization: usize) -> Self { - Self { val, vectorization } + } + + impl Numeric for $type { + fn new(val: f64) -> Self { + Self { + val, + vectorization: 1, + } } - fn new_expand(val: i32) -> ExpandElement { + fn new_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { let elem = Elem::Int(Self::into_kind()); - let new_var = Variable::ConstantScalar(val as f64, elem); + let new_var = Variable::ConstantScalar(val, elem); ExpandElement::new(Rc::new(new_var)) } } diff --git a/crates/burn-cube/src/element/mod.rs b/crates/burn-cube/src/element/mod.rs index 81b0c98a61..9ff6b110b1 100644 --- a/crates/burn-cube/src/element/mod.rs +++ b/crates/burn-cube/src/element/mod.rs @@ -3,6 +3,7 @@ mod base; mod bool; mod float; mod int; +mod numeric; mod primitive; mod uint; @@ -11,4 +12,5 @@ pub use base::*; pub use bool::*; pub use float::*; pub use int::*; +pub use numeric::*; pub use uint::*; diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs new file mode 100644 index 0000000000..0494b2cfc1 --- /dev/null +++ b/crates/burn-cube/src/element/numeric.rs @@ -0,0 +1,10 @@ +// use crate::{BF16, F16, F32, F64, I32, I64}; + +use crate::{CubeContext, CubeType, ExpandElement}; + +pub trait Numeric: + Clone + Copy + CubeType + std::ops::Add +{ + fn new(val: f64) -> Self; + fn new_expand(context: &mut CubeContext, val: f64) -> ExpandElement; +} diff --git a/crates/burn-cube/src/element/primitive.rs b/crates/burn-cube/src/element/primitive.rs index ef2eebf10c..9cb34b6cbb 100644 --- a/crates/burn-cube/src/element/primitive.rs +++ b/crates/burn-cube/src/element/primitive.rs @@ -2,21 +2,21 @@ use std::rc::Rc; use burn_jit::gpu::Variable; -use crate::{ExpandElement, RuntimeType}; +use crate::{ExpandElement, CubeType}; -impl RuntimeType for bool { +impl CubeType for bool { type ExpandType = bool; } -impl RuntimeType for u32 { +impl CubeType for u32 { type ExpandType 
= u32; } -impl RuntimeType for f32 { +impl CubeType for f32 { type ExpandType = f32; } -impl RuntimeType for i32 { +impl CubeType for i32 { type ExpandType = i32; } diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs index 917274bd68..1e255c0828 100644 --- a/crates/burn-cube/src/element/uint.rs +++ b/crates/burn-cube/src/element/uint.rs @@ -1,23 +1,32 @@ -use crate::{ExpandElement, RuntimeType}; +use crate::{ExpandElement, CubeType}; -#[derive(new, Clone, Copy)] +#[derive(Clone, Copy)] pub struct UInt { pub val: u32, pub vectorization: u8, } -impl RuntimeType for UInt { +impl UInt { + pub fn new(val: u32) -> Self { + Self { + val, + vectorization: 1, + } + } +} + +impl CubeType for UInt { type ExpandType = ExpandElement; } impl From for UInt { fn from(value: u32) -> Self { - UInt::new(value, 1) + UInt::new(value) } } impl From for UInt { fn from(value: usize) -> Self { - UInt::new(value as u32, 1) + UInt::new(value as u32) } } diff --git a/crates/burn-cube/src/elemtype.rs b/crates/burn-cube/src/elemtype.rs new file mode 100644 index 0000000000..19a6c3b3d7 --- /dev/null +++ b/crates/burn-cube/src/elemtype.rs @@ -0,0 +1,111 @@ +use burn_jit::gpu::{Elem, Item}; + +use crate::{assign, Bool, CubeContext, CubeType, ExpandElement, Float, Int, Numeric, UInt}; + +// pub fn new(val: f) -> T { +// T::new(val) +// } + +// pub fn new_expand(_context: &mut CubeContext, val: f32) -> ::ExpandType { +// T::new_expand(_context, val) +// } + +pub fn uint_new(val: u32) -> UInt { + UInt { + val, + vectorization: 1, + } +} +pub fn uint_new_expand(_context: &mut CubeContext, val: u32) -> ::ExpandType { + val.into() +} + +pub fn bool_new(val: bool) -> Bool { + Bool { + val, + vectorization: 1, + } +} +pub fn bool_new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { + val.into() +} + +pub fn to_int(_input: R) -> I { + I::new(0.) 
+} +pub fn to_int_expand( + context: &mut CubeContext, + val: ExpandElement, +) -> ::ExpandType { + let elem = Elem::Int(I::into_kind()); + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + assign::expand(context, val.into(), new_var.clone()); + new_var +} + +pub fn to_float(_input: R) -> F { + // TODO: make val accessible through trait + F::new(0.) +} + +pub fn to_float_expand( + context: &mut CubeContext, + val: ExpandElement, +) -> ExpandElement { + let elem = Elem::Float(F::into_kind()); + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + assign::expand(context, val.into(), new_var.clone()); + new_var +} + +pub fn to_uint(_input: R) -> UInt { + UInt { + val: 0, + vectorization: 1, + } +} +pub fn to_uint_expand( + context: &mut CubeContext, + val: ExpandElement, +) -> ::ExpandType { + let elem = Elem::UInt; + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + assign::expand(context, val.into(), new_var.clone()); + new_var +} + +pub fn to_bool(_input: R) -> Bool { + Bool { + val: true, + vectorization: 1, + } +} +pub fn to_bool_expand( + context: &mut CubeContext, + val: ExpandElement, +) -> ::ExpandType { + let elem = Elem::Bool; + let new_var = context.create_local(match val.item() { + Item::Vec4(_) => Item::Vec4(elem), + Item::Vec3(_) => Item::Vec3(elem), + Item::Vec2(_) => Item::Vec2(elem), + Item::Scalar(_) => Item::Scalar(elem), + }); + assign::expand(context, val.into(), new_var.clone()); + new_var +} diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs 
index 78e37f19ea..1a2043d6c3 100644 --- a/crates/burn-cube/src/lib.rs +++ b/crates/burn-cube/src/lib.rs @@ -3,12 +3,14 @@ extern crate alloc; #[macro_use] extern crate derive_new; -mod branch; +// For use with * +pub mod branch; +pub mod elemtype; + mod context; mod element; mod operation; -pub use branch::*; pub use context::*; pub use element::*; pub use operation::*; diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs index ff5eba0ad3..4c768354c4 100644 --- a/crates/burn-cube/src/operation/assignation.rs +++ b/crates/burn-cube/src/operation/assignation.rs @@ -9,13 +9,11 @@ pub mod assign { let out = *output; context.register(gpu::Operator::Assign(gpu::UnaryOperator { input, out })); - - // output } } pub mod index_assign { - use crate::RuntimeType; + use crate::CubeType; use super::*; @@ -32,7 +30,7 @@ pub mod index_assign { })) } - impl> core::ops::IndexMut for Array { + impl> core::ops::IndexMut for Array { fn index_mut(&mut self, index: I) -> &mut Self::Output { let index = index.into().val; &mut self.vals[index as usize] @@ -41,7 +39,7 @@ pub mod index_assign { } pub mod index { - use crate::{operation::base::binary_expand, RuntimeType}; + use crate::{operation::base::binary_expand, CubeType}; use super::*; @@ -53,7 +51,7 @@ pub mod index { binary_expand(context, array, index, gpu::Operator::Index) } - impl> core::ops::Index for Array { + impl> core::ops::Index for Array { type Output = E; fn index(&self, index: I) -> &Self::Output { diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index d8b370c653..21f281eb20 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -1,5 +1,5 @@ use crate::operation::base::binary_expand; -use crate::{CubeContext, ExpandElement, Float, Int, UInt, BF16, F16, F32, F64, I32, I64}; +use crate::{CubeContext, ExpandElement, Float, Int, Numeric, UInt, BF16, F16, F32, F64, I32, I64}; 
use burn_jit::gpu::{self}; pub mod add { @@ -15,23 +15,33 @@ pub mod add { } macro_rules! impl_add { + ($type:ty, $trait:ty) => { + impl core::ops::Add for $type { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + <$type as Numeric>::new(self.val + rhs.val) + } + } + }; + ($type:ty) => { impl core::ops::Add for $type { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - <$type>::new(self.val + rhs.val, 1) + <$type>::new(self.val + rhs.val) } } }; } - impl_add!(F16); - impl_add!(BF16); - impl_add!(F32); - impl_add!(F64); - impl_add!(I32); - impl_add!(I64); + impl_add!(F16, Float); + impl_add!(BF16, Float); + impl_add!(F32, Float); + impl_add!(F64, Float); + impl_add!(I32, Int); + impl_add!(I64, Int); impl_add!(UInt); } @@ -47,23 +57,33 @@ pub mod sub { } macro_rules! impl_sub { + ($type:ty, $trait:ty) => { + impl core::ops::Sub for $type { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + <$type as Numeric>::new(self.val - rhs.val) + } + } + }; + ($type:ty) => { impl core::ops::Sub for $type { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - <$type>::new(self.val - rhs.val, 1) + <$type>::new(self.val - rhs.val) } } }; } - impl_sub!(F16); - impl_sub!(BF16); - impl_sub!(F32); - impl_sub!(F64); - impl_sub!(I32); - impl_sub!(I64); + impl_sub!(F16, Float); + impl_sub!(BF16, Float); + impl_sub!(F32, Float); + impl_sub!(F64, Float); + impl_sub!(I32, Int); + impl_sub!(I64, Int); impl_sub!(UInt); } @@ -79,23 +99,33 @@ pub mod mul { } macro_rules! 
impl_mul { + ($type:ty, $trait:ty) => { + impl core::ops::Mul for $type { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + <$type as Numeric>::new(self.val * rhs.val) + } + } + }; + ($type:ty) => { impl core::ops::Mul for $type { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - <$type>::new(self.val * rhs.val, 1) + <$type>::new(self.val - rhs.val) } } }; } - impl_mul!(F16); - impl_mul!(BF16); - impl_mul!(F32); - impl_mul!(F64); - impl_mul!(I32); - impl_mul!(I64); + impl_mul!(F16, Float); + impl_mul!(BF16, Float); + impl_mul!(F32, Float); + impl_mul!(F64, Float); + impl_mul!(I32, Int); + impl_mul!(I64, Int); impl_mul!(UInt); } @@ -111,24 +141,33 @@ pub mod div { } macro_rules! impl_div { + ($type:ty, $trait:ty) => { + impl core::ops::Div for $type { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + <$type as Numeric>::new(self.val / rhs.val) + } + } + }; + ($type:ty) => { impl core::ops::Div for $type { type Output = Self; fn div(self, rhs: Self) -> Self::Output { - <$type>::new(self.val / rhs.val, 1) + <$type>::new(self.val / rhs.val) } } }; } - impl_div!(F16); - impl_div!(BF16); - impl_div!(F32); - impl_div!(F64); - impl_div!(I32); - impl_div!(I64); - impl_div!(UInt); + impl_div!(F16, Float); + impl_div!(BF16, Float); + impl_div!(F32, Float); + impl_div!(F64, Float); + impl_div!(I32, Int); + impl_div!(I64, Int); } pub mod rem { @@ -143,19 +182,29 @@ pub mod rem { } macro_rules! 
impl_rem { - ($type:ty) => { + ($type:ty, $trait:ty) => { impl core::ops::Rem for $type { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { - <$type>::new(self.val % rhs.val, 1) + <$type as Numeric>::new(self.val % rhs.val) + } + } + }; + + ($type:ty) => { + impl core::ops::Div for $type { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + <$type>::new(self.val / rhs.val) } } }; } - impl_rem!(I32); - impl_rem!(I64); + impl_rem!(I32, Int); + impl_rem!(I64, Int); impl_rem!(UInt); } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 123cabe191..ce3cc6f95c 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, Bool, CubeContext, Float, Int, UInt, F32, I32}; +use burn_cube::{cube, elemtype::*, Bool, CubeContext, Float, Int, Numeric, UInt, F32, I32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -12,7 +12,7 @@ macro_rules! cast_test { let x = context.create_local($from); - $module::expand(&mut context, x); + $module(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -29,7 +29,7 @@ macro_rules! cast_test { let x = context.create_local($ty); - $module::expand(&mut context, x); + $module(&mut context, x); let scope = context.into_scope(); assert_eq!( @@ -43,51 +43,51 @@ macro_rules! 
cast_test { // From float #[cube] pub fn float_to_float(x: F32) { - let y = x + float_new::(2.0); - let _ = to_float::(y) + float_new::(34.0); + let y = x + F32::new(2.0); + let _ = to_float::(y) + F32::new(34.0); } #[cube] pub fn float_to_int(x: F32) { - let y = x + float_new::(2.0); - let _ = to_int::(y) + int_new::(34); + let y = x + F32::new(2.0); + let _ = to_int::(y) + I32::new(34.); } #[cube] pub fn float_to_uint(x: F32) { - let y = x + float_new::(2.0); + let y = x + F32::new(2.0); let _ = to_uint(y) + uint_new(34u32); } #[cube] pub fn float_to_bool(x: F32) { - let y = x + float_new::(2.0); + let y = x + F32::new(2.0); let _ = to_bool(y) | bool_new(true); } cast_test!( cube_float_to_float_test, - float_to_float, + float_to_float_expand, Item::Scalar(Elem::Float(F32::into_kind())) ); cast_test!( cube_float_to_int_test, - float_to_int, + float_to_int_expand, Item::Scalar(Elem::Float(F32::into_kind())), Item::Scalar(Elem::Int(I32::into_kind())) ); cast_test!( cube_float_to_uint_test, - float_to_uint, + float_to_uint_expand, Item::Scalar(Elem::Float(F32::into_kind())), Item::Scalar(Elem::UInt) ); cast_test!( cube_float_to_bool_test, - float_to_bool, + float_to_bool_expand, Item::Scalar(Elem::Float(F32::into_kind())), Item::Scalar(Elem::Bool) ); @@ -95,51 +95,51 @@ cast_test!( // // From int #[cube] pub fn int_to_float(x: I32) { - let y = x + int_new::(2); - let _ = to_float::(y) + float_new::(34.0); + let y = x + I32::new(2.); + let _ = to_float::(y) + F32::new(34.0); } #[cube] pub fn int_to_int(x: I32) { - let y = x + int_new::(2); - let _ = to_int::(y) + int_new::(34); + let y = x + I32::new(2.); + let _ = to_int::(y) + I32::new(34.); } #[cube] pub fn int_to_uint(x: I32) { - let y = x + int_new::(2); + let y = x + I32::new(2.); let _ = to_uint(y) + uint_new(34u32); } #[cube] pub fn int_to_bool(x: I32) { - let y = x + int_new::(2); + let y = x + I32::new(2.); let _ = to_bool(y) | bool_new(true); } cast_test!( cube_int_to_float_test, - int_to_float, + 
int_to_float_expand, Item::Scalar(Elem::Int(I32::into_kind())), Item::Scalar(Elem::Float(F32::into_kind())) ); cast_test!( cube_int_to_int_test, - int_to_int, + int_to_int_expand, Item::Scalar(Elem::Int(I32::into_kind())) ); cast_test!( cube_int_to_uint_test, - int_to_uint, + int_to_uint_expand, Item::Scalar(Elem::Int(I32::into_kind())), Item::Scalar(Elem::UInt) ); cast_test!( cube_int_to_bool_test, - int_to_bool, + int_to_bool_expand, Item::Scalar(Elem::Int(I32::into_kind())), Item::Scalar(Elem::Bool) ); @@ -148,13 +148,13 @@ cast_test!( #[cube] pub fn uint_to_float(x: UInt) { let y = x + uint_new(2u32); - let _ = to_float::(y) + float_new::(34.0); + let _ = to_float::(y) + F32::new(34.0); } #[cube] pub fn uint_to_int(x: UInt) { let y = x + uint_new(2u32); - let _ = to_int::(y) + int_new::(34); + let _ = to_int::(y) + I32::new(34.); } #[cube] @@ -171,27 +171,27 @@ pub fn uint_to_bool(x: UInt) { cast_test!( cube_uint_to_float_test, - uint_to_float, + uint_to_float_expand, Item::Scalar(Elem::UInt), Item::Scalar(Elem::Float(F32::into_kind())) ); cast_test!( cube_uint_to_int_test, - uint_to_int, + uint_to_int_expand, Item::Scalar(Elem::UInt), Item::Scalar(Elem::Int(I32::into_kind())) ); cast_test!( cube_uint_to_uint_test, - uint_to_uint, + uint_to_uint_expand, Item::Scalar(Elem::UInt) ); cast_test!( cube_uint_to_bool_test, - uint_to_bool, + uint_to_bool_expand, Item::Scalar(Elem::UInt), Item::Scalar(Elem::Bool) ); @@ -200,13 +200,13 @@ cast_test!( #[cube] pub fn bool_to_float(x: Bool) { let y = x & bool_new(false); - let _ = to_float::(y) + float_new::(34.0); + let _ = to_float::(y) + F32::new(34.0); } #[cube] pub fn bool_to_int(x: Bool) { let y = x & bool_new(false); - let _ = to_int::(y) + int_new::(34); + let _ = to_int::(y) + I32::new(34.); } #[cube] @@ -223,28 +223,28 @@ pub fn bool_to_bool(x: Bool) { cast_test!( cube_bool_to_float_test, - bool_to_float, + bool_to_float_expand, Item::Scalar(Elem::Bool), Item::Scalar(Elem::Float(F32::into_kind())) ); cast_test!( 
cube_bool_to_int_test, - bool_to_int, + bool_to_int_expand, Item::Scalar(Elem::Bool), Item::Scalar(Elem::Int(I32::into_kind())) ); cast_test!( cube_bool_to_uint_test, - bool_to_uint, + bool_to_uint_expand, Item::Scalar(Elem::Bool), Item::Scalar(Elem::UInt) ); cast_test!( cube_bool_to_bool_test, - bool_to_bool, + bool_to_bool_expand, Item::Scalar(Elem::Bool) ); diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index e4793988bb..4a777ac3bb 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Float, Int, F32, F64, I32, I64}; +use burn_cube::{cube, elemtype::*, CubeContext, Float, Int, Numeric, F32, F64, I32, I64}; use burn_jit::{ cube_inline, gpu::{Elem, Item}, @@ -6,16 +6,23 @@ use burn_jit::{ #[cube] pub fn cast_float_kind(input: F1) { - let x = input + float_new::(5.9); + let x = input + F1::new(5.9); let y = to_float::(x); - let _ = y + float_new::(2.3); + let _ = y + F2::new(2.3); } #[cube] pub fn cast_int_kind(input: I1) { - let x = input + int_new::(5); + let x = input + I1::new(5.); let y = to_int::(x); - let _ = y + int_new::(2); + let _ = y + I2::new(2.); +} + +#[cube] +pub fn cast_numeric_to_kind(input: T) { + let x = input + T::new(5.); + let y = to_int::(x); + let _ = y + I2::new(2.); } #[test] @@ -26,7 +33,7 @@ fn cube_cast_float_kind_test() { let input = context.create_local(item); // F16 not testable with the gpu macro, but should work the same - cast_float_kind::expand::(&mut context, input); + cast_float_kind_expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); @@ -39,7 +46,20 @@ fn cube_cast_int_kind_test() { let input = context.create_local(item); - cast_int_kind::expand::(&mut context, input); + cast_int_kind_expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), 
inline_macro_ref_int()); +} + +#[test] +fn cube_cast_numeric_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32::into_kind())); + + let input = context.create_local(item); + + cast_numeric_to_kind_expand::(&mut context, input); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index d0cfe89bbc..0e533c4c38 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, range, range_expand, Array, CubeContext, Float, UInt, F32}; +use burn_cube::{branch::*, cube, Array, CubeContext, Float, UInt, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -25,7 +25,7 @@ fn test_for_loop_with_unroll() { let rhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); let end = 4u32.into(); - for_loop::expand::(&mut context, lhs, rhs, end, unroll); + for_loop_expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); @@ -40,7 +40,7 @@ fn test_for_loop_no_unroll() { let rhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); let end = 4u32.into(); - for_loop::expand::(&mut context, lhs, rhs, end, unroll); + for_loop_expand::(&mut context, lhs, rhs, end, unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); diff --git a/crates/burn-cube/tests/function_call.rs b/crates/burn-cube/tests/function_call.rs new file mode 100644 index 0000000000..a442cc5d8e --- /dev/null +++ b/crates/burn-cube/tests/function_call.rs @@ -0,0 +1,68 @@ +use burn_cube::{cube, elemtype::*, CubeContext, UInt}; +use burn_jit::gpu::{Elem, Item}; + +#[cube] +pub fn caller_no_arg(x: UInt) { + let _ = x + callee_no_arg(); +} + +#[cube] +pub fn callee_no_arg() -> 
UInt { + uint_new(8u32) +} + +#[cube] +pub fn no_call_no_arg(x: UInt) { + let _ = x + uint_new(8u32); +} + +#[cube] +pub fn caller_with_arg(x: UInt) { + let _ = x + callee_with_arg(x); +} + +#[cube] +pub fn callee_with_arg(x: UInt) -> UInt { + x * uint_new(8u32) +} + +#[cube] +pub fn no_call_with_arg(x: UInt) { + let _ = x + x * uint_new(8u32); +} + +#[test] +fn cube_call_equivalent_to_no_call_no_arg_test() { + let mut caller_context = CubeContext::root(); + let x = caller_context.create_local(Item::Scalar(Elem::UInt)); + caller_no_arg_expand(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(Elem::UInt)); + no_call_no_arg_expand(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} + +#[test] +fn cube_call_equivalent_to_no_call_with_arg_test() { + let mut caller_context = CubeContext::root(); + let x = caller_context.create_local(Item::Scalar(Elem::UInt)); + caller_with_arg_expand(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(Elem::UInt)); + no_call_with_arg_expand(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs new file mode 100644 index 0000000000..f0d6b6228c --- /dev/null +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -0,0 +1,58 @@ +use burn_cube::{cube, CubeContext, Float, Int, Numeric, F32, I32}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item}, +}; + +#[cube] +pub fn generic_kernel(lhs: T) { + let _ = lhs + T::new(5.); +} + 
+#[test] +fn cube_generic_float_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Float(F32::into_kind()))); + + generic_kernel_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); +} + +#[test] +fn cube_generic_int_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(Elem::Int(I32::into_kind()))); + + generic_kernel_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); +} + +fn inline_macro_ref_float() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Float(F32::into_kind())); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let out = scope.create_local(item); + cube_inline!(scope, out = lhs + 5.0f32); + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_int() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(Elem::Int(I32::into_kind())); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let out = scope.create_local(item); + cube_inline!(scope, out = lhs + 5); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index e7f3d41987..2dbbb468d2 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,12 +1,15 @@ -use burn_cube::{cube, if_expand, CubeContext, Float, F32}; -use burn_jit::{cube_inline, gpu::{Elem, Item, Variable}}; +use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Float, F32}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item, Variable}, +}; type ElemType = F32; #[cube] pub fn if_greater(lhs: F) { - if lhs > float_new::(0.0) { - let _ = lhs + float_new::(4.0); + if lhs > F::new(0.0) { + let _ = lhs + F::new(4.0); } } @@ -16,7 +19,7 @@ fn 
cube_if_test() { let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - if_greater::expand::(&mut context, lhs); + if_greater_expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index 605035967b..b8c07151a0 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, if_else_expand, CubeContext, Float, F32}; +use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Float, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -7,11 +7,11 @@ use burn_jit::{ type ElemType = F32; #[cube] -pub fn if_else(lhs: F) { - if lhs < float_new::(0.0) { - let _ = lhs + float_new::(4.0); +pub fn if_then_else(lhs: F) { + if lhs < F::new(0.0) { + let _ = lhs + F::new(4.0); } else { - let _ = lhs - float_new::(5.0); + let _ = lhs - F::new(5.0); } } @@ -21,7 +21,7 @@ fn cube_if_else_test() { let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - if_else::expand::(&mut context, lhs); + if_then_else_expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 0024de8de8..f861a75883 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,11 +1,14 @@ -use burn_cube::{cube, CubeContext, Float, F32}; -use burn_jit::{cube_inline, gpu::{Elem, Item}}; +use burn_cube::{cube, elemtype::*, CubeContext, Float, F32}; +use burn_jit::{ + cube_inline, + gpu::{Elem, Item}, +}; type ElemType = F32; #[cube] pub fn literal(lhs: F) { - let _ = lhs + float_new::(5.9); + let _ = lhs + F::new(5.); } #[test] @@ -13,8 +16,9 @@ fn cube_literal_test() { let mut context = CubeContext::root(); let lhs = 
context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + // let lhs = context.create_local(ElemType::into_item()); - literal::expand::(&mut context, lhs); + literal_expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); @@ -27,7 +31,7 @@ fn inline_macro_ref() -> String { let mut scope = context.into_scope(); let out = scope.create_local(item); - cube_inline!(scope, out = lhs + 5.9f32); + cube_inline!(scope, out = lhs + 5.0f32); format!("{:?}", scope.operations) } diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index 8611078687..13c31dd695 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,6 +1,4 @@ -use burn_cube::{ - break_expand, cube, if_expand, loop_expand, while_loop_expand, CubeContext, Int, I32, -}; +use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Int, I32}; use burn_jit::cube_inline; use burn_jit::gpu::Branch; use burn_jit::gpu::{Elem, Item, Variable}; @@ -9,18 +7,18 @@ type ElemType = I32; #[cube] pub fn while_not(lhs: I) { - while lhs != int_new::(0) { - let _ = lhs - int_new::(1); + while lhs != I::new(0.) { + let _ = lhs - I::new(1.); } } #[cube] pub fn manual_loop_break(lhs: I) { loop { - if lhs != int_new::(0) { + if lhs != I::new(0.) 
{ break; } - let _ = lhs - int_new::(1); + let _ = lhs - I::new(1.); } } @@ -30,7 +28,7 @@ fn cube_while_test() { let lhs = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); - while_not::expand::(&mut context, lhs); + while_not_expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); @@ -42,7 +40,7 @@ fn cube_loop_break_test() { let lhs = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); - manual_loop_break::expand::(&mut context, lhs); + manual_loop_break_expand::(&mut context, lhs); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 86af146a84..4f1894dd52 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, while_loop_expand, CubeContext, Int, I32}; +use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Int, I32}; use burn_jit::{ cube_inline, gpu::{Branch, Elem, Item, Variable}, @@ -12,15 +12,15 @@ type ElemType = I32; #[cube] pub fn reuse(mut x: I) { - while x < int_new::(10) { - x = x + int_new::(1); + while x < I::new(10.) { + x = x + I::new(1.); } } #[cube] pub fn reuse_incr(mut x: I) { - while x < int_new::(10) { - x += int_new::(1); + while x < I::new(10.) 
{ + x += I::new(1.); } } @@ -30,7 +30,7 @@ fn cube_reuse_assign_test() { let x = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); - reuse::expand::(&mut context, x); + reuse_expand::(&mut context, x); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); @@ -42,7 +42,7 @@ fn cube_reuse_incr_test() { let x = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); - reuse_incr::expand::(&mut context, x); + reuse_incr_expand::(&mut context, x); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index 985df1f2df..0ea613f315 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -15,6 +15,7 @@ pub mod kernel; /// Tensor module. pub mod tensor; +/// Useful in Cube, should be moved over there pub mod codegen; pub(crate) mod tune; From 2047e470fdfc9d133d8952e063a4aca5fa09fe47 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 10 May 2024 16:14:21 -0400 Subject: [PATCH 34/54] T::new --- crates/burn-cube/src/element/float.rs | 28 +++++++---- crates/burn-cube/src/element/int.rs | 28 +++++++---- crates/burn-cube/src/element/numeric.rs | 5 +- crates/burn-cube/src/elemtype.rs | 46 +++--------------- crates/burn-cube/src/operation/binary.rs | 12 ++--- crates/burn-cube/tests/cast_elem.rs | 60 ++++++++++++------------ crates/burn-cube/tests/cast_kind.rs | 31 ++++++------ crates/burn-cube/tests/for_loop.rs | 12 ++--- crates/burn-cube/tests/generic_kernel.rs | 15 +++--- crates/burn-cube/tests/if.rs | 10 ++-- crates/burn-cube/tests/if_else.rs | 12 ++--- crates/burn-cube/tests/literal.rs | 13 ++--- crates/burn-cube/tests/loop.rs | 16 +++---- crates/burn-cube/tests/reuse.rs | 18 +++---- 14 files changed, 143 insertions(+), 163 deletions(-) diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index bfab89d20d..8d9da01f2f 
100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -12,7 +12,9 @@ pub trait Float: + std::ops::Div + Numeric { - fn into_kind() -> FloatKind; + fn into_elem() -> Elem; + fn from(val: f64) -> Self; + fn from_expand(context: &mut CubeContext, val: f64) -> ExpandElement; } macro_rules! impl_float { @@ -28,22 +30,30 @@ macro_rules! impl_float { } impl Float for $type { - fn into_kind() -> FloatKind { - FloatKind::$type + fn into_elem() -> Elem { + Elem::Float(FloatKind::$type) + } + fn from(val: f64) -> Self { + Self { + val, + vectorization: 1, + } + } + fn from_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { + let new_var = Variable::ConstantScalar(val, Self::into_elem()); + ExpandElement::new(Rc::new(new_var)) } } impl Numeric for $type { - fn new(val: f64) -> Self { + fn new(val: i64) -> Self { Self { - val, + val: val as f64, vectorization: 1, } } - fn new_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { - let elem = Elem::Float(Self::into_kind()); - let new_var = Variable::ConstantScalar(val, elem); - ExpandElement::new(Rc::new(new_var)) + fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { + ::from_expand(context, val as f64) } } }; diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 2450dc3702..37fee7b7b1 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -12,14 +12,16 @@ pub trait Int: + std::ops::AddAssign + Numeric { - fn into_kind() -> IntKind; + fn into_elem() -> Elem; + fn from(val: i64) -> Self; + fn from_expand(context: &mut CubeContext, val: i64) -> ExpandElement; } macro_rules! impl_int { ($type:ident) => { #[derive(Clone, Copy)] pub struct $type { - pub val: f64, + pub val: i64, pub vectorization: usize, } @@ -28,22 +30,30 @@ macro_rules! 
impl_int { } impl Int for $type { - fn into_kind() -> IntKind { - IntKind::$type + fn into_elem() -> Elem { + Elem::Int(IntKind::$type) + } + fn from(val: i64) -> Self { + Self { + val, + vectorization: 1, + } + } + fn from_expand(_context: &mut CubeContext, val: i64) -> ExpandElement { + let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); + ExpandElement::new(Rc::new(new_var)) } } impl Numeric for $type { - fn new(val: f64) -> Self { + fn new(val: i64) -> Self { Self { val, vectorization: 1, } } - fn new_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { - let elem = Elem::Int(Self::into_kind()); - let new_var = Variable::ConstantScalar(val, elem); - ExpandElement::new(Rc::new(new_var)) + fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { + ::from_expand(context, val) } } }; diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index 0494b2cfc1..954d3ba16c 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -5,6 +5,7 @@ use crate::{CubeContext, CubeType, ExpandElement}; pub trait Numeric: Clone + Copy + CubeType + std::ops::Add { - fn new(val: f64) -> Self; - fn new_expand(context: &mut CubeContext, val: f64) -> ExpandElement; + // If we use numeric then constants are necessarily ints + fn new(val: i64) -> Self; + fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement; } diff --git a/crates/burn-cube/src/elemtype.rs b/crates/burn-cube/src/elemtype.rs index 19a6c3b3d7..79b1e18082 100644 --- a/crates/burn-cube/src/elemtype.rs +++ b/crates/burn-cube/src/elemtype.rs @@ -1,14 +1,6 @@ use burn_jit::gpu::{Elem, Item}; -use crate::{assign, Bool, CubeContext, CubeType, ExpandElement, Float, Int, Numeric, UInt}; - -// pub fn new(val: f) -> T { -// T::new(val) -// } - -// pub fn new_expand(_context: &mut CubeContext, val: f32) -> ::ExpandType { -// T::new_expand(_context, val) -// } +use crate::{assign, Bool, CubeContext, 
CubeType, ExpandElement, Float, Int, UInt}; pub fn uint_new(val: u32) -> UInt { UInt { @@ -31,39 +23,27 @@ pub fn bool_new_expand(_context: &mut CubeContext, val: bool) -> (_input: R) -> I { - I::new(0.) + I::new(0) } pub fn to_int_expand( context: &mut CubeContext, val: ExpandElement, ) -> ::ExpandType { - let elem = Elem::Int(I::into_kind()); - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); + let new_var = context.create_local(Item::Scalar(I::into_elem())); assign::expand(context, val.into(), new_var.clone()); new_var } pub fn to_float(_input: R) -> F { // TODO: make val accessible through trait - F::new(0.) + F::new(0) } pub fn to_float_expand( context: &mut CubeContext, val: ExpandElement, ) -> ExpandElement { - let elem = Elem::Float(F::into_kind()); - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); + let new_var = context.create_local(Item::Scalar(F::into_elem())); assign::expand(context, val.into(), new_var.clone()); new_var } @@ -78,13 +58,7 @@ pub fn to_uint_expand( context: &mut CubeContext, val: ExpandElement, ) -> ::ExpandType { - let elem = Elem::UInt; - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); + let new_var = context.create_local(Item::Scalar(Elem::UInt)); assign::expand(context, val.into(), new_var.clone()); new_var } @@ -99,13 +73,7 @@ pub fn to_bool_expand( context: &mut CubeContext, val: ExpandElement, ) -> ::ExpandType { - let elem = Elem::Bool; - let new_var = context.create_local(match val.item() { - Item::Vec4(_) => Item::Vec4(elem), - Item::Vec3(_) => 
Item::Vec3(elem), - Item::Vec2(_) => Item::Vec2(elem), - Item::Scalar(_) => Item::Scalar(elem), - }); + let new_var = context.create_local(Item::Scalar(Elem::Bool)); assign::expand(context, val.into(), new_var.clone()); new_var } diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 21f281eb20..46955ae06d 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -1,5 +1,5 @@ use crate::operation::base::binary_expand; -use crate::{CubeContext, ExpandElement, Float, Int, Numeric, UInt, BF16, F16, F32, F64, I32, I64}; +use crate::{CubeContext, ExpandElement, Float, Int, UInt, BF16, F16, F32, F64, I32, I64}; use burn_jit::gpu::{self}; pub mod add { @@ -20,7 +20,7 @@ pub mod add { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - <$type as Numeric>::new(self.val + rhs.val) + <$type as $trait>::from(self.val + rhs.val) } } }; @@ -62,7 +62,7 @@ pub mod sub { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - <$type as Numeric>::new(self.val - rhs.val) + <$type as $trait>::from(self.val - rhs.val) } } }; @@ -104,7 +104,7 @@ pub mod mul { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - <$type as Numeric>::new(self.val * rhs.val) + <$type as $trait>::from(self.val * rhs.val) } } }; @@ -146,7 +146,7 @@ pub mod div { type Output = Self; fn div(self, rhs: Self) -> Self::Output { - <$type as Numeric>::new(self.val / rhs.val) + <$type as $trait>::from(self.val / rhs.val) } } }; @@ -187,7 +187,7 @@ pub mod rem { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { - <$type as Numeric>::new(self.val % rhs.val) + <$type as $trait>::from(self.val % rhs.val) } } }; diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index ce3cc6f95c..09b6f3ab45 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -43,104 +43,104 @@ macro_rules! 
cast_test { // From float #[cube] pub fn float_to_float(x: F32) { - let y = x + F32::new(2.0); - let _ = to_float::(y) + F32::new(34.0); + let y = x + F32::new(2); + let _ = to_float::(y) + F32::new(34); } #[cube] pub fn float_to_int(x: F32) { - let y = x + F32::new(2.0); - let _ = to_int::(y) + I32::new(34.); + let y = x + F32::new(2); + let _ = to_int::(y) + I32::new(34); } #[cube] pub fn float_to_uint(x: F32) { - let y = x + F32::new(2.0); + let y = x + F32::new(2); let _ = to_uint(y) + uint_new(34u32); } #[cube] pub fn float_to_bool(x: F32) { - let y = x + F32::new(2.0); + let y = x + F32::new(2); let _ = to_bool(y) | bool_new(true); } cast_test!( cube_float_to_float_test, float_to_float_expand, - Item::Scalar(Elem::Float(F32::into_kind())) + Item::Scalar(F32::into_elem()) ); cast_test!( cube_float_to_int_test, float_to_int_expand, - Item::Scalar(Elem::Float(F32::into_kind())), - Item::Scalar(Elem::Int(I32::into_kind())) + Item::Scalar(F32::into_elem()), + Item::Scalar(I32::into_elem()) ); cast_test!( cube_float_to_uint_test, float_to_uint_expand, - Item::Scalar(Elem::Float(F32::into_kind())), + Item::Scalar(F32::into_elem()), Item::Scalar(Elem::UInt) ); cast_test!( cube_float_to_bool_test, float_to_bool_expand, - Item::Scalar(Elem::Float(F32::into_kind())), + Item::Scalar(F32::into_elem()), Item::Scalar(Elem::Bool) ); // // From int #[cube] pub fn int_to_float(x: I32) { - let y = x + I32::new(2.); - let _ = to_float::(y) + F32::new(34.0); + let y = x + I32::new(2); + let _ = to_float::(y) + F32::new(34); } #[cube] pub fn int_to_int(x: I32) { - let y = x + I32::new(2.); - let _ = to_int::(y) + I32::new(34.); + let y = x + I32::new(2); + let _ = to_int::(y) + I32::new(34); } #[cube] pub fn int_to_uint(x: I32) { - let y = x + I32::new(2.); + let y = x + I32::new(2); let _ = to_uint(y) + uint_new(34u32); } #[cube] pub fn int_to_bool(x: I32) { - let y = x + I32::new(2.); + let y = x + I32::new(2); let _ = to_bool(y) | bool_new(true); } cast_test!( 
cube_int_to_float_test, int_to_float_expand, - Item::Scalar(Elem::Int(I32::into_kind())), - Item::Scalar(Elem::Float(F32::into_kind())) + Item::Scalar(I32::into_elem()), + Item::Scalar(F32::into_elem()) ); cast_test!( cube_int_to_int_test, int_to_int_expand, - Item::Scalar(Elem::Int(I32::into_kind())) + Item::Scalar(I32::into_elem()) ); cast_test!( cube_int_to_uint_test, int_to_uint_expand, - Item::Scalar(Elem::Int(I32::into_kind())), + Item::Scalar(I32::into_elem()), Item::Scalar(Elem::UInt) ); cast_test!( cube_int_to_bool_test, int_to_bool_expand, - Item::Scalar(Elem::Int(I32::into_kind())), + Item::Scalar(I32::into_elem()), Item::Scalar(Elem::Bool) ); @@ -148,13 +148,13 @@ cast_test!( #[cube] pub fn uint_to_float(x: UInt) { let y = x + uint_new(2u32); - let _ = to_float::(y) + F32::new(34.0); + let _ = to_float::(y) + F32::new(34); } #[cube] pub fn uint_to_int(x: UInt) { let y = x + uint_new(2u32); - let _ = to_int::(y) + I32::new(34.); + let _ = to_int::(y) + I32::new(34); } #[cube] @@ -173,14 +173,14 @@ cast_test!( cube_uint_to_float_test, uint_to_float_expand, Item::Scalar(Elem::UInt), - Item::Scalar(Elem::Float(F32::into_kind())) + Item::Scalar(F32::into_elem()) ); cast_test!( cube_uint_to_int_test, uint_to_int_expand, Item::Scalar(Elem::UInt), - Item::Scalar(Elem::Int(I32::into_kind())) + Item::Scalar(I32::into_elem()) ); cast_test!( @@ -200,13 +200,13 @@ cast_test!( #[cube] pub fn bool_to_float(x: Bool) { let y = x & bool_new(false); - let _ = to_float::(y) + F32::new(34.0); + let _ = to_float::(y) + F32::new(34); } #[cube] pub fn bool_to_int(x: Bool) { let y = x & bool_new(false); - let _ = to_int::(y) + I32::new(34.); + let _ = to_int::(y) + I32::new(34); } #[cube] @@ -225,14 +225,14 @@ cast_test!( cube_bool_to_float_test, bool_to_float_expand, Item::Scalar(Elem::Bool), - Item::Scalar(Elem::Float(F32::into_kind())) + Item::Scalar(F32::into_elem()) ); cast_test!( cube_bool_to_int_test, bool_to_int_expand, Item::Scalar(Elem::Bool), - 
Item::Scalar(Elem::Int(I32::into_kind())) + Item::Scalar(I32::into_elem()) ); cast_test!( diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 4a777ac3bb..8eec42e74d 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,34 +1,31 @@ use burn_cube::{cube, elemtype::*, CubeContext, Float, Int, Numeric, F32, F64, I32, I64}; -use burn_jit::{ - cube_inline, - gpu::{Elem, Item}, -}; +use burn_jit::{cube_inline, gpu::Item}; #[cube] pub fn cast_float_kind(input: F1) { - let x = input + F1::new(5.9); + let x = input + F1::from(5.9); let y = to_float::(x); - let _ = y + F2::new(2.3); + let _ = y + F2::from(2.3); } #[cube] pub fn cast_int_kind(input: I1) { - let x = input + I1::new(5.); + let x = input + I1::from(5); let y = to_int::(x); - let _ = y + I2::new(2.); + let _ = y + I2::from(2); } #[cube] pub fn cast_numeric_to_kind(input: T) { - let x = input + T::new(5.); + let x = input + T::new(5); let y = to_int::(x); - let _ = y + I2::new(2.); + let _ = y + I2::new(2); } #[test] fn cube_cast_float_kind_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F64::into_kind())); + let item = Item::Scalar(F64::into_elem()); let input = context.create_local(item); @@ -42,7 +39,7 @@ fn cube_cast_float_kind_test() { #[test] fn cube_cast_int_kind_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32::into_kind())); + let item = Item::Scalar(I32::into_elem()); let input = context.create_local(item); @@ -55,7 +52,7 @@ fn cube_cast_int_kind_test() { #[test] fn cube_cast_numeric_kind_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32::into_kind())); + let item = Item::Scalar(I32::into_elem()); let input = context.create_local(item); @@ -67,8 +64,8 @@ fn cube_cast_numeric_kind_test() { fn inline_macro_ref_float() -> String { let mut context = CubeContext::root(); - let float_64 = 
Item::Scalar(Elem::Float(F64::into_kind())); - let float_32 = Item::Scalar(Elem::Float(F32::into_kind())); + let float_64 = Item::Scalar(F64::into_elem()); + let float_32 = Item::Scalar(F32::into_elem()); let input = context.create_local(float_64); let mut scope = context.into_scope(); @@ -85,8 +82,8 @@ fn inline_macro_ref_float() -> String { fn inline_macro_ref_int() -> String { let mut context = CubeContext::root(); - let int_32 = Item::Scalar(Elem::Int(I32::into_kind())); - let int_64 = Item::Scalar(Elem::Int(I64::into_kind())); + let int_32 = Item::Scalar(I32::into_elem()); + let int_64 = Item::Scalar(I64::into_elem()); let input = context.create_local(int_32); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 0e533c4c38..666287a6c0 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,7 +1,7 @@ use burn_cube::{branch::*, cube, Array, CubeContext, Float, UInt, F32}; use burn_jit::{ cube_inline, - gpu::{Elem, Item, Variable}, + gpu::{Item, Variable}, }; type ElemType = F32; @@ -21,8 +21,8 @@ fn test_for_loop_with_unroll() { let mut context = CubeContext::root(); let unroll = true; - let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - let rhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let rhs = context.create_local(Item::Scalar(ElemType::into_elem())); let end = 4u32.into(); for_loop_expand::(&mut context, lhs, rhs, end, unroll); @@ -36,8 +36,8 @@ fn test_for_loop_no_unroll() { let mut context = CubeContext::root(); let unroll = false; - let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); - let rhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let rhs = 
context.create_local(Item::Scalar(ElemType::into_elem())); let end = 4u32.into(); for_loop_expand::(&mut context, lhs, rhs, end, unroll); @@ -48,7 +48,7 @@ fn test_for_loop_no_unroll() { fn inline_macro_ref(unroll: bool) -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let lhs = context.create_local(item); let rhs = context.create_local(item); diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs index f0d6b6228c..d8be4c68e6 100644 --- a/crates/burn-cube/tests/generic_kernel.rs +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -1,19 +1,16 @@ use burn_cube::{cube, CubeContext, Float, Int, Numeric, F32, I32}; -use burn_jit::{ - cube_inline, - gpu::{Elem, Item}, -}; +use burn_jit::{cube_inline, gpu::Item}; #[cube] pub fn generic_kernel(lhs: T) { - let _ = lhs + T::new(5.); + let _ = lhs + T::new(5); } #[test] fn cube_generic_float_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(F32::into_kind()))); + let lhs = context.create_local(Item::Scalar(F32::into_elem())); generic_kernel_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -25,7 +22,7 @@ fn cube_generic_float_test() { fn cube_generic_int_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(I32::into_kind()))); + let lhs = context.create_local(Item::Scalar(I32::into_elem())); generic_kernel_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -35,7 +32,7 @@ fn cube_generic_int_test() { fn inline_macro_ref_float() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(F32::into_kind())); + let item = Item::Scalar(F32::into_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); @@ -47,7 +44,7 @@ fn inline_macro_ref_float() -> String { fn inline_macro_ref_int() -> 
String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(I32::into_kind())); + let item = Item::Scalar(I32::into_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 2dbbb468d2..e60eb7a9cc 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Float, F32}; +use burn_cube::{branch::*, cube, CubeContext, Float, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -8,8 +8,8 @@ type ElemType = F32; #[cube] pub fn if_greater(lhs: F) { - if lhs > F::new(0.0) { - let _ = lhs + F::new(4.0); + if lhs > F::new(0) { + let _ = lhs + F::new(4); } } @@ -17,7 +17,7 @@ pub fn if_greater(lhs: F) { fn cube_if_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); if_greater_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -27,7 +27,7 @@ fn cube_if_test() { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index b8c07151a0..2866124599 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Float, F32}; +use burn_cube::{branch::*, cube, CubeContext, Float, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -8,10 +8,10 @@ type ElemType = F32; #[cube] pub fn if_then_else(lhs: F) { - if lhs < F::new(0.0) { - let _ = lhs + F::new(4.0); + if lhs < F::new(0) { + let _ = lhs 
+ F::new(4); } else { - let _ = lhs - F::new(5.0); + let _ = lhs - F::new(5); } } @@ -19,7 +19,7 @@ pub fn if_then_else(lhs: F) { fn cube_if_else_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); if_then_else_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -29,7 +29,7 @@ fn cube_if_else_test() { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index f861a75883..2f338722d5 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,21 +1,18 @@ -use burn_cube::{cube, elemtype::*, CubeContext, Float, F32}; -use burn_jit::{ - cube_inline, - gpu::{Elem, Item}, -}; +use burn_cube::{cube, CubeContext, Float, F32}; +use burn_jit::{cube_inline, gpu::Item}; type ElemType = F32; #[cube] pub fn literal(lhs: F) { - let _ = lhs + F::new(5.); + let _ = lhs + F::new(5); } #[test] fn cube_literal_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Float(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); // let lhs = context.create_local(ElemType::into_item()); literal_expand::(&mut context, lhs); @@ -26,7 +23,7 @@ fn cube_literal_test() { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Float(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index 13c31dd695..15988df487 
100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Int, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, I32}; use burn_jit::cube_inline; use burn_jit::gpu::Branch; use burn_jit::gpu::{Elem, Item, Variable}; @@ -7,18 +7,18 @@ type ElemType = I32; #[cube] pub fn while_not(lhs: I) { - while lhs != I::new(0.) { - let _ = lhs - I::new(1.); + while lhs != I::new(0) { + let _ = lhs - I::new(1); } } #[cube] pub fn manual_loop_break(lhs: I) { loop { - if lhs != I::new(0.) { + if lhs != I::new(0) { break; } - let _ = lhs - I::new(1.); + let _ = lhs - I::new(1); } } @@ -26,7 +26,7 @@ pub fn manual_loop_break(lhs: I) { fn cube_while_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); while_not_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -38,7 +38,7 @@ fn cube_while_test() { fn cube_loop_break_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); manual_loop_break_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -48,7 +48,7 @@ fn cube_loop_break_test() { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 4f1894dd52..5915d6c6fb 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, elemtype::*, CubeContext, Int, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, 
I32}; use burn_jit::{ cube_inline, gpu::{Branch, Elem, Item, Variable}, @@ -12,15 +12,15 @@ type ElemType = I32; #[cube] pub fn reuse(mut x: I) { - while x < I::new(10.) { - x = x + I::new(1.); + while x < I::new(10) { + x = x + I::new(1); } } #[cube] pub fn reuse_incr(mut x: I) { - while x < I::new(10.) { - x += I::new(1.); + while x < I::new(10) { + x += I::new(1); } } @@ -28,7 +28,7 @@ pub fn reuse_incr(mut x: I) { fn cube_reuse_assign_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + let x = context.create_local(Item::Scalar(ElemType::into_elem())); reuse_expand::(&mut context, x); let scope = context.into_scope(); @@ -40,7 +40,7 @@ fn cube_reuse_assign_test() { fn cube_reuse_incr_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(Elem::Int(ElemType::into_kind()))); + let x = context.create_local(Item::Scalar(ElemType::into_elem())); reuse_incr_expand::(&mut context, x); let scope = context.into_scope(); @@ -50,7 +50,7 @@ fn cube_reuse_incr_test() { fn inline_macro_ref_assign() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let x = context.create_local(item); let mut scope = context.into_scope(); @@ -76,7 +76,7 @@ fn inline_macro_ref_assign() -> String { fn inline_macro_ref_incr() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(Elem::Int(ElemType::into_kind())); + let item = Item::Scalar(ElemType::into_elem()); let x = context.create_local(item); let mut scope = context.into_scope(); From 329219874bdd4d090fcef5914c5aaaddd42a5f39 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 10 May 2024 16:35:27 -0400 Subject: [PATCH 35/54] rename cast --- crates/burn-cube/src/element/base.rs | 2 +- crates/burn-cube/src/element/float.rs | 7 +++--- crates/burn-cube/src/element/int.rs | 7 +++--- 
crates/burn-cube/src/element/numeric.rs | 4 ++++ crates/burn-cube/src/elemtype.rs | 27 ++++++++---------------- crates/burn-cube/tests/cast_elem.rs | 18 ++++++++-------- crates/burn-cube/tests/cast_kind.rs | 6 +++--- crates/burn-cube/tests/for_loop.rs | 2 +- crates/burn-cube/tests/generic_kernel.rs | 2 +- crates/burn-cube/tests/if.rs | 2 +- crates/burn-cube/tests/if_else.rs | 2 +- crates/burn-cube/tests/literal.rs | 2 +- crates/burn-cube/tests/loop.rs | 2 +- crates/burn-cube/tests/reuse.rs | 2 +- 14 files changed, 39 insertions(+), 46 deletions(-) diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs index 0b8ecb8c64..055d75909c 100644 --- a/crates/burn-cube/src/element/base.rs +++ b/crates/burn-cube/src/element/base.rs @@ -4,7 +4,7 @@ use burn_jit::gpu::{Item, Variable}; /// Types used in a cube function must implement this trait /// /// Variables whose values will be known at runtime must -/// have ExpandElement as associated type (using RuntimeType) +/// have ExpandElement as associated type /// Variables whose values will be known at compile time /// must have the primitive type as associated type /// diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index 8d9da01f2f..e9f0870519 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -12,7 +12,6 @@ pub trait Float: + std::ops::Div + Numeric { - fn into_elem() -> Elem; fn from(val: f64) -> Self; fn from_expand(context: &mut CubeContext, val: f64) -> ExpandElement; } @@ -30,9 +29,6 @@ macro_rules! impl_float { } impl Float for $type { - fn into_elem() -> Elem { - Elem::Float(FloatKind::$type) - } fn from(val: f64) -> Self { Self { val, @@ -55,6 +51,9 @@ macro_rules! 
impl_float { fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { ::from_expand(context, val as f64) } + fn into_elem() -> Elem { + Elem::Float(FloatKind::$type) + } } }; } diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 37fee7b7b1..39a63c3905 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -12,7 +12,6 @@ pub trait Int: + std::ops::AddAssign + Numeric { - fn into_elem() -> Elem; fn from(val: i64) -> Self; fn from_expand(context: &mut CubeContext, val: i64) -> ExpandElement; } @@ -30,9 +29,6 @@ macro_rules! impl_int { } impl Int for $type { - fn into_elem() -> Elem { - Elem::Int(IntKind::$type) - } fn from(val: i64) -> Self { Self { val, @@ -55,6 +51,9 @@ macro_rules! impl_int { fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { ::from_expand(context, val) } + fn into_elem() -> Elem { + Elem::Int(IntKind::$type) + } } }; } diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index 954d3ba16c..e8a6b394ab 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -1,5 +1,7 @@ // use crate::{BF16, F16, F32, F64, I32, I64}; +use burn_jit::gpu::Elem; + use crate::{CubeContext, CubeType, ExpandElement}; pub trait Numeric: @@ -8,4 +10,6 @@ pub trait Numeric: // If we use numeric then constants are necessarily ints fn new(val: i64) -> Self; fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement; + + fn into_elem() -> Elem; } diff --git a/crates/burn-cube/src/elemtype.rs b/crates/burn-cube/src/elemtype.rs index 79b1e18082..137f377d27 100644 --- a/crates/burn-cube/src/elemtype.rs +++ b/crates/burn-cube/src/elemtype.rs @@ -1,6 +1,6 @@ use burn_jit::gpu::{Elem, Item}; -use crate::{assign, Bool, CubeContext, CubeType, ExpandElement, Float, Int, UInt}; +use crate::{assign, Bool, CubeContext, CubeType, ExpandElement, Numeric, UInt}; pub fn uint_new(val: u32) 
-> UInt { UInt { @@ -22,28 +22,19 @@ pub fn bool_new_expand(_context: &mut CubeContext, val: bool) -> (_input: R) -> I { - I::new(0) -} -pub fn to_int_expand( - context: &mut CubeContext, - val: ExpandElement, -) -> ::ExpandType { - let new_var = context.create_local(Item::Scalar(I::into_elem())); - assign::expand(context, val.into(), new_var.clone()); - new_var -} - -pub fn to_float(_input: R) -> F { +// Why i'm stuck with this kind of cast +// R is useless, but removing it I can't compile because of the input +// Any would need boxing +// Might as well use R to figure the val to cast +pub fn cast(_input: R) -> T { // TODO: make val accessible through trait - F::new(0) + T::new(0) } - -pub fn to_float_expand( +pub fn cast_expand( context: &mut CubeContext, val: ExpandElement, ) -> ExpandElement { - let new_var = context.create_local(Item::Scalar(F::into_elem())); + let new_var = context.create_local(Item::Scalar(T::into_elem())); assign::expand(context, val.into(), new_var.clone()); new_var } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 09b6f3ab45..f0ec4e2aa7 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, elemtype::*, Bool, CubeContext, Float, Int, Numeric, UInt, F32, I32}; +use burn_cube::{cube, elemtype::*, Bool, CubeContext, Numeric, UInt, F32, I32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -44,13 +44,13 @@ macro_rules! 
cast_test { #[cube] pub fn float_to_float(x: F32) { let y = x + F32::new(2); - let _ = to_float::(y) + F32::new(34); + let _ = cast::(y) + F32::new(34); } #[cube] pub fn float_to_int(x: F32) { let y = x + F32::new(2); - let _ = to_int::(y) + I32::new(34); + let _ = cast::(y) + I32::new(34); } #[cube] @@ -96,13 +96,13 @@ cast_test!( #[cube] pub fn int_to_float(x: I32) { let y = x + I32::new(2); - let _ = to_float::(y) + F32::new(34); + let _ = cast::(y) + F32::new(34); } #[cube] pub fn int_to_int(x: I32) { let y = x + I32::new(2); - let _ = to_int::(y) + I32::new(34); + let _ = cast::(y) + I32::new(34); } #[cube] @@ -148,13 +148,13 @@ cast_test!( #[cube] pub fn uint_to_float(x: UInt) { let y = x + uint_new(2u32); - let _ = to_float::(y) + F32::new(34); + let _ = cast::(y) + F32::new(34); } #[cube] pub fn uint_to_int(x: UInt) { let y = x + uint_new(2u32); - let _ = to_int::(y) + I32::new(34); + let _ = cast::(y) + I32::new(34); } #[cube] @@ -200,13 +200,13 @@ cast_test!( #[cube] pub fn bool_to_float(x: Bool) { let y = x & bool_new(false); - let _ = to_float::(y) + F32::new(34); + let _ = cast::(y) + F32::new(34); } #[cube] pub fn bool_to_int(x: Bool) { let y = x & bool_new(false); - let _ = to_int::(y) + I32::new(34); + let _ = cast::(y) + I32::new(34); } #[cube] diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 8eec42e74d..f4bb448a50 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -4,21 +4,21 @@ use burn_jit::{cube_inline, gpu::Item}; #[cube] pub fn cast_float_kind(input: F1) { let x = input + F1::from(5.9); - let y = to_float::(x); + let y = cast::(x); let _ = y + F2::from(2.3); } #[cube] pub fn cast_int_kind(input: I1) { let x = input + I1::from(5); - let y = to_int::(x); + let y = cast::(x); let _ = y + I2::from(2); } #[cube] pub fn cast_numeric_to_kind(input: T) { let x = input + T::new(5); - let y = to_int::(x); + let y = cast::(x); let _ = y + I2::new(2); } diff --git 
a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 666287a6c0..5e0e069ded 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, Array, CubeContext, Float, UInt, F32}; +use burn_cube::{branch::*, cube, Array, CubeContext, Float, Numeric, UInt, F32}; use burn_jit::{ cube_inline, gpu::{Item, Variable}, diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs index d8be4c68e6..b757fa6e71 100644 --- a/crates/burn-cube/tests/generic_kernel.rs +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Float, Int, Numeric, F32, I32}; +use burn_cube::{cube, CubeContext, Numeric, F32, I32}; use burn_jit::{cube_inline, gpu::Item}; #[cube] diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index e60eb7a9cc..364a2ab3e9 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Float, F32}; +use burn_cube::{branch::*, cube, CubeContext, Float, Numeric, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index 2866124599..ae366dea98 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Float, F32}; +use burn_cube::{branch::*, cube, CubeContext, Float, Numeric, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 2f338722d5..36415fde9b 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Float, F32}; +use burn_cube::{cube, CubeContext, Float, Numeric, F32}; use burn_jit::{cube_inline, gpu::Item}; type ElemType = F32; 
diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index 15988df487..a0bb9ad757 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Int, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, Numeric, I32}; use burn_jit::cube_inline; use burn_jit::gpu::Branch; use burn_jit::gpu::{Elem, Item, Variable}; diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 5915d6c6fb..30c8a4f3dc 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Int, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, Numeric, I32}; use burn_jit::{ cube_inline, gpu::{Branch, Elem, Item, Variable}, From 9046baa146862fa44aec62c9b75b4514af86252a Mon Sep 17 00:00:00 2001 From: louisfd Date: Sat, 11 May 2024 16:13:27 -0400 Subject: [PATCH 36/54] type system becoming great --- crates/burn-cube/src/element/bool.rs | 29 +++- crates/burn-cube/src/element/conversion.rs | 154 +++++++++++++++++++++ crates/burn-cube/src/element/float.rs | 29 ++-- crates/burn-cube/src/element/int.rs | 29 ++-- crates/burn-cube/src/element/mod.rs | 3 + crates/burn-cube/src/element/numeric.rs | 18 +-- crates/burn-cube/src/element/primitive.rs | 8 +- crates/burn-cube/src/element/runtime.rs | 14 ++ crates/burn-cube/src/element/uint.rs | 26 +++- crates/burn-cube/src/elemtype.rs | 70 ---------- crates/burn-cube/src/lib.rs | 1 - crates/burn-cube/src/operation/binary.rs | 14 +- crates/burn-cube/tests/cast_elem.rs | 50 +++---- crates/burn-cube/tests/cast_kind.rs | 22 +-- crates/burn-cube/tests/for_loop.rs | 2 +- crates/burn-cube/tests/function_call.rs | 10 +- crates/burn-cube/tests/generic_kernel.rs | 2 +- crates/burn-cube/tests/if.rs | 8 +- crates/burn-cube/tests/if_else.rs | 2 +- crates/burn-cube/tests/literal.rs | 3 +- crates/burn-cube/tests/loop.rs | 8 +- crates/burn-cube/tests/reuse.rs 
| 2 +- 22 files changed, 337 insertions(+), 167 deletions(-) create mode 100644 crates/burn-cube/src/element/conversion.rs create mode 100644 crates/burn-cube/src/element/runtime.rs delete mode 100644 crates/burn-cube/src/elemtype.rs diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs index c2f99c4c21..3674d8cdf1 100644 --- a/crates/burn-cube/src/element/bool.rs +++ b/crates/burn-cube/src/element/bool.rs @@ -1,6 +1,8 @@ -use crate::{CubeType, ExpandElement}; +use burn_jit::gpu::Elem; -#[derive(new, Clone, Copy)] +use crate::{CubeContext, CubeType, ExpandElement, RuntimeType}; + +#[derive(Clone, Copy)] pub struct Bool { pub val: bool, pub vectorization: u8, @@ -9,3 +11,26 @@ pub struct Bool { impl CubeType for Bool { type ExpandType = ExpandElement; } + +impl Bool { + pub fn new(val: bool) -> Self { + Self { + val, + vectorization: 1, + } + } + pub fn new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { + val.into() + } +} + +impl RuntimeType for Bool { + type Primitive = bool; + fn val(&self) -> Self::Primitive { + self.val + } + + fn into_elem() -> Elem { + Elem::Bool + } +} diff --git a/crates/burn-cube/src/element/conversion.rs b/crates/burn-cube/src/element/conversion.rs new file mode 100644 index 0000000000..48403e8bfe --- /dev/null +++ b/crates/burn-cube/src/element/conversion.rs @@ -0,0 +1,154 @@ +use crate::{Bool, Float, Int, RuntimeType, UInt, BF16, F16, F32, F64, I32, I64}; + +macro_rules! impl_to_float { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(value.val() as f64) + } + } + }; +} + +macro_rules! 
impl_to_float_from_bool { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(match value.val() { + true => 1., + false => 0., + }) + } + } + }; +} + +impl_to_float!(F16, BF16); +impl_to_float!(F16, F32); +impl_to_float!(F16, F64); +impl_to_float!(F16, I32); +impl_to_float!(F16, I64); +impl_to_float!(F16, UInt); +impl_to_float_from_bool!(F16, Bool); + +impl_to_float!(BF16, F16); +impl_to_float!(BF16, F32); +impl_to_float!(BF16, F64); +impl_to_float!(BF16, I32); +impl_to_float!(BF16, I64); +impl_to_float!(BF16, UInt); +impl_to_float_from_bool!(BF16, Bool); + +impl_to_float!(F32, F16); +impl_to_float!(F32, BF16); +impl_to_float!(F32, F64); +impl_to_float!(F32, I32); +impl_to_float!(F32, I64); +impl_to_float!(F32, UInt); +impl_to_float_from_bool!(F32, Bool); + +impl_to_float!(F64, F16); +impl_to_float!(F64, BF16); +impl_to_float!(F64, F32); +impl_to_float!(F64, I32); +impl_to_float!(F64, I64); +impl_to_float!(F64, UInt); +impl_to_float_from_bool!(F64, Bool); + +macro_rules! impl_to_int { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(value.val() as i64) + } + } + }; +} + +macro_rules! impl_to_int_from_bool { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(match value.val() { + true => 1, + false => 0, + }) + } + } + }; +} + +impl_to_int!(I32, F16); +impl_to_int!(I32, BF16); +impl_to_int!(I32, F32); +impl_to_int!(I32, F64); +impl_to_int!(I32, I64); +impl_to_int!(I32, UInt); +impl_to_int_from_bool!(I32, Bool); + +impl_to_int!(I64, F16); +impl_to_int!(I64, BF16); +impl_to_int!(I64, F32); +impl_to_int!(I64, F64); +impl_to_int!(I64, I32); +impl_to_int!(I64, UInt); +impl_to_int_from_bool!(I64, Bool); + +macro_rules! 
impl_to_uint { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::new(value.val() as i64) + } + } + }; +} + +macro_rules! impl_to_uint_from_bool { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::new(match value.val() { + true => 1, + false => 0, + }) + } + } + }; +} + +impl_to_uint!(UInt, F16); +impl_to_uint!(UInt, BF16); +impl_to_uint!(UInt, F32); +impl_to_uint!(UInt, F64); +impl_to_uint!(UInt, I32); +impl_to_uint!(UInt, I64); +impl_to_uint_from_bool!(UInt, Bool); + +macro_rules! impl_to_bool_from_float { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::new(value.val() > 0.) + } + } + }; +} + +macro_rules! impl_to_bool_from_int { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::new(value.val() > 0) + } + } + }; +} + +impl_to_bool_from_float!(Bool, F16); +impl_to_bool_from_float!(Bool, BF16); +impl_to_bool_from_float!(Bool, F32); +impl_to_bool_from_float!(Bool, F64); +impl_to_bool_from_int!(Bool, I32); +impl_to_bool_from_int!(Bool, I64); +impl_to_bool_from_int!(Bool, UInt); diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index e9f0870519..3311e5c689 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -1,4 +1,4 @@ -use crate::{CubeContext, CubeType, ExpandElement, Numeric}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric, RuntimeType}; use burn_jit::gpu::{Elem, FloatKind, Variable}; use std::rc::Rc; @@ -10,10 +10,11 @@ pub trait Float: + std::ops::Sub + std::ops::Mul + std::ops::Div + + std::ops::AddAssign + Numeric { - fn from(val: f64) -> Self; - fn from_expand(context: &mut CubeContext, val: f64) -> ExpandElement; + fn from_primitive(val: f64) -> Self; + fn from_primitive_expand(context: &mut CubeContext, val: f64) -> ExpandElement; } 
macro_rules! impl_float { @@ -28,31 +29,41 @@ macro_rules! impl_float { type ExpandType = ExpandElement; } + impl RuntimeType for $type { + type Primitive = f64; + fn val(&self) -> Self::Primitive { + self.val + } + fn into_elem() -> Elem { + Elem::Float(FloatKind::$type) + } + } + impl Float for $type { - fn from(val: f64) -> Self { + fn from_primitive(val: f64) -> Self { Self { val, vectorization: 1, } } - fn from_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { + fn from_primitive_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { let new_var = Variable::ConstantScalar(val, Self::into_elem()); ExpandElement::new(Rc::new(new_var)) } } impl Numeric for $type { + // Method new takes an i64, because it is used when treating the float as numeric, + // which must be an int in the cube kernel because new numerics need to be supported by Int as well fn new(val: i64) -> Self { Self { val: val as f64, vectorization: 1, } } + fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { - ::from_expand(context, val as f64) - } - fn into_elem() -> Elem { - Elem::Float(FloatKind::$type) + ::from_primitive_expand(context, val as f64) } } }; diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 39a63c3905..4f4bb87ec8 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -1,4 +1,4 @@ -use crate::{CubeContext, CubeType, ExpandElement, Numeric}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric, RuntimeType}; use burn_jit::gpu::{Elem, IntKind, Variable}; use std::rc::Rc; @@ -6,14 +6,16 @@ pub trait Int: Clone + Copy + std::cmp::PartialOrd + + std::ops::Add + std::ops::Sub + std::ops::Mul + std::ops::Div + std::ops::AddAssign + + std::ops::Rem + Numeric { - fn from(val: i64) -> Self; - fn from_expand(context: &mut CubeContext, val: i64) -> ExpandElement; + fn from_primitive(val: i64) -> Self; + fn from_primitive_expand(context: &mut CubeContext, val: i64) -> 
ExpandElement; } macro_rules! impl_int { @@ -28,14 +30,25 @@ macro_rules! impl_int { type ExpandType = ExpandElement; } + impl RuntimeType for $type { + type Primitive = i64; + fn val(&self) -> Self::Primitive { + self.val + } + fn into_elem() -> Elem { + Elem::Int(IntKind::$type) + } + } + impl Int for $type { - fn from(val: i64) -> Self { + fn from_primitive(val: i64) -> Self { Self { val, vectorization: 1, } } - fn from_expand(_context: &mut CubeContext, val: i64) -> ExpandElement { + + fn from_primitive_expand(_context: &mut CubeContext, val: i64) -> ExpandElement { let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); ExpandElement::new(Rc::new(new_var)) } @@ -48,11 +61,9 @@ macro_rules! impl_int { vectorization: 1, } } + fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { - ::from_expand(context, val) - } - fn into_elem() -> Elem { - Elem::Int(IntKind::$type) + ::from_primitive_expand(context, val) } } }; diff --git a/crates/burn-cube/src/element/mod.rs b/crates/burn-cube/src/element/mod.rs index 9ff6b110b1..2269ac8e25 100644 --- a/crates/burn-cube/src/element/mod.rs +++ b/crates/burn-cube/src/element/mod.rs @@ -1,10 +1,12 @@ mod array; mod base; mod bool; +mod conversion; mod float; mod int; mod numeric; mod primitive; +mod runtime; mod uint; pub use array::*; @@ -13,4 +15,5 @@ pub use bool::*; pub use float::*; pub use int::*; pub use numeric::*; +pub use runtime::*; pub use uint::*; diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index e8a6b394ab..cf825a0fff 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -1,15 +1,17 @@ -// use crate::{BF16, F16, F32, F64, I32, I64}; - -use burn_jit::gpu::Elem; - -use crate::{CubeContext, CubeType, ExpandElement}; +use crate::{CubeContext, ExpandElement, RuntimeType}; pub trait Numeric: - Clone + Copy + CubeType + std::ops::Add + Clone + + Copy + + RuntimeType + + std::ops::Add + + 
std::ops::AddAssign + + std::ops::Sub + + std::ops::Mul + + std::ops::Div + + std::cmp::PartialOrd { // If we use numeric then constants are necessarily ints fn new(val: i64) -> Self; fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement; - - fn into_elem() -> Elem; } diff --git a/crates/burn-cube/src/element/primitive.rs b/crates/burn-cube/src/element/primitive.rs index 9cb34b6cbb..856858e68b 100644 --- a/crates/burn-cube/src/element/primitive.rs +++ b/crates/burn-cube/src/element/primitive.rs @@ -2,7 +2,7 @@ use std::rc::Rc; use burn_jit::gpu::Variable; -use crate::{ExpandElement, CubeType}; +use crate::{CubeType, ExpandElement}; impl CubeType for bool { type ExpandType = bool; @@ -49,3 +49,9 @@ impl From for ExpandElement { ExpandElement::new(Rc::new(Variable::from(value))) } } + +impl From for ExpandElement { + fn from(value: i64) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } +} diff --git a/crates/burn-cube/src/element/runtime.rs b/crates/burn-cube/src/element/runtime.rs new file mode 100644 index 0000000000..8d4bf7e22f --- /dev/null +++ b/crates/burn-cube/src/element/runtime.rs @@ -0,0 +1,14 @@ +use burn_jit::gpu::{Elem, Item}; + +use crate::{assign, CubeContext, CubeType, ExpandElement}; + +pub trait RuntimeType: CubeType { + type Primitive; + fn val(&self) -> Self::Primitive; + fn into_elem() -> Elem; + fn from_expand(context: &mut CubeContext, val: ExpandElement) -> ExpandElement { + let new_var = context.create_local(Item::Scalar(::into_elem())); + assign::expand(context, val.into(), new_var.clone()); + new_var + } +} diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs index 1e255c0828..230f9f2003 100644 --- a/crates/burn-cube/src/element/uint.rs +++ b/crates/burn-cube/src/element/uint.rs @@ -1,32 +1,48 @@ -use crate::{ExpandElement, CubeType}; +use burn_jit::gpu::Elem; + +use crate::{CubeContext, CubeType, ExpandElement, RuntimeType}; #[derive(Clone, Copy)] pub struct UInt { - pub 
val: u32, + pub val: ::Primitive, pub vectorization: u8, } impl UInt { - pub fn new(val: u32) -> Self { + // Use with integer literal + pub fn new(val: i64) -> Self { Self { val, vectorization: 1, } } + pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { + (val as u32).into() + } } impl CubeType for UInt { type ExpandType = ExpandElement; } +impl RuntimeType for UInt { + type Primitive = i64; + fn val(&self) -> Self::Primitive { + self.val + } + fn into_elem() -> Elem { + Elem::UInt + } +} + impl From for UInt { fn from(value: u32) -> Self { - UInt::new(value) + UInt::new(value as ::Primitive) } } impl From for UInt { fn from(value: usize) -> Self { - UInt::new(value as u32) + UInt::new(value as ::Primitive) } } diff --git a/crates/burn-cube/src/elemtype.rs b/crates/burn-cube/src/elemtype.rs deleted file mode 100644 index 137f377d27..0000000000 --- a/crates/burn-cube/src/elemtype.rs +++ /dev/null @@ -1,70 +0,0 @@ -use burn_jit::gpu::{Elem, Item}; - -use crate::{assign, Bool, CubeContext, CubeType, ExpandElement, Numeric, UInt}; - -pub fn uint_new(val: u32) -> UInt { - UInt { - val, - vectorization: 1, - } -} -pub fn uint_new_expand(_context: &mut CubeContext, val: u32) -> ::ExpandType { - val.into() -} - -pub fn bool_new(val: bool) -> Bool { - Bool { - val, - vectorization: 1, - } -} -pub fn bool_new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { - val.into() -} - -// Why i'm stuck with this kind of cast -// R is useless, but removing it I can't compile because of the input -// Any would need boxing -// Might as well use R to figure the val to cast -pub fn cast(_input: R) -> T { - // TODO: make val accessible through trait - T::new(0) -} -pub fn cast_expand( - context: &mut CubeContext, - val: ExpandElement, -) -> ExpandElement { - let new_var = context.create_local(Item::Scalar(T::into_elem())); - assign::expand(context, val.into(), new_var.clone()); - new_var -} - -pub fn to_uint(_input: R) -> UInt { - UInt { - val: 0, - 
vectorization: 1, - } -} -pub fn to_uint_expand( - context: &mut CubeContext, - val: ExpandElement, -) -> ::ExpandType { - let new_var = context.create_local(Item::Scalar(Elem::UInt)); - assign::expand(context, val.into(), new_var.clone()); - new_var -} - -pub fn to_bool(_input: R) -> Bool { - Bool { - val: true, - vectorization: 1, - } -} -pub fn to_bool_expand( - context: &mut CubeContext, - val: ExpandElement, -) -> ::ExpandType { - let new_var = context.create_local(Item::Scalar(Elem::Bool)); - assign::expand(context, val.into(), new_var.clone()); - new_var -} diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs index 1a2043d6c3..4009395922 100644 --- a/crates/burn-cube/src/lib.rs +++ b/crates/burn-cube/src/lib.rs @@ -5,7 +5,6 @@ extern crate derive_new; // For use with * pub mod branch; -pub mod elemtype; mod context; mod element; diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 46955ae06d..08c085d8c7 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -20,7 +20,7 @@ pub mod add { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - <$type as $trait>::from(self.val + rhs.val) + <$type as $trait>::from_primitive(self.val + rhs.val) } } }; @@ -62,7 +62,7 @@ pub mod sub { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - <$type as $trait>::from(self.val - rhs.val) + <$type as $trait>::from_primitive(self.val - rhs.val) } } }; @@ -104,7 +104,7 @@ pub mod mul { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - <$type as $trait>::from(self.val * rhs.val) + <$type as $trait>::from_primitive(self.val * rhs.val) } } }; @@ -146,7 +146,7 @@ pub mod div { type Output = Self; fn div(self, rhs: Self) -> Self::Output { - <$type as $trait>::from(self.val / rhs.val) + <$type as $trait>::from_primitive(self.val / rhs.val) } } }; @@ -187,7 +187,7 @@ pub mod rem { type Output = Self; fn rem(self, rhs: Self) -> 
Self::Output { - <$type as $trait>::from(self.val % rhs.val) + <$type as $trait>::from_primitive(self.val % rhs.val) } } }; @@ -225,7 +225,7 @@ pub mod and { type Output = Bool; fn bitand(self, rhs: Self) -> Self::Output { - Bool::new(self.val && rhs.val, 1) + Bool::new(self.val && rhs.val) } } } @@ -247,7 +247,7 @@ pub mod or { type Output = Bool; fn bitor(self, rhs: Self) -> Self::Output { - Bool::new(self.val || rhs.val, 1) + Bool::new(self.val || rhs.val) } } } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index f0ec4e2aa7..9dd80f96c5 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, elemtype::*, Bool, CubeContext, Numeric, UInt, F32, I32}; +use burn_cube::{cube, Bool, CubeContext, Numeric, RuntimeType, UInt, F32, I32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -44,25 +44,25 @@ macro_rules! cast_test { #[cube] pub fn float_to_float(x: F32) { let y = x + F32::new(2); - let _ = cast::(y) + F32::new(34); + let _ = F32::from(y) + F32::new(34); } #[cube] pub fn float_to_int(x: F32) { let y = x + F32::new(2); - let _ = cast::(y) + I32::new(34); + let _ = I32::from(y) + I32::new(34); } #[cube] pub fn float_to_uint(x: F32) { let y = x + F32::new(2); - let _ = to_uint(y) + uint_new(34u32); + let _ = UInt::from(y) + UInt::new(34); } #[cube] pub fn float_to_bool(x: F32) { let y = x + F32::new(2); - let _ = to_bool(y) | bool_new(true); + let _ = Bool::from(y) | Bool::new(true); } cast_test!( @@ -96,25 +96,25 @@ cast_test!( #[cube] pub fn int_to_float(x: I32) { let y = x + I32::new(2); - let _ = cast::(y) + F32::new(34); + let _ = F32::from(y) + F32::new(34); } #[cube] pub fn int_to_int(x: I32) { let y = x + I32::new(2); - let _ = cast::(y) + I32::new(34); + let _ = I32::from(y) + I32::new(34); } #[cube] pub fn int_to_uint(x: I32) { let y = x + I32::new(2); - let _ = to_uint(y) + uint_new(34u32); + let _ = UInt::from(y) + UInt::new(34); 
} #[cube] pub fn int_to_bool(x: I32) { let y = x + I32::new(2); - let _ = to_bool(y) | bool_new(true); + let _ = Bool::from(y) | Bool::new(true); } cast_test!( @@ -147,26 +147,26 @@ cast_test!( // // From uint #[cube] pub fn uint_to_float(x: UInt) { - let y = x + uint_new(2u32); - let _ = cast::(y) + F32::new(34); + let y = x + UInt::new(2); + let _ = F32::from(y) + F32::new(34); } #[cube] pub fn uint_to_int(x: UInt) { - let y = x + uint_new(2u32); - let _ = cast::(y) + I32::new(34); + let y = x + UInt::new(2); + let _ = I32::from(y) + I32::new(34); } #[cube] pub fn uint_to_uint(x: UInt) { - let y = x + uint_new(2u32); - let _ = to_uint(y) + uint_new(34u32); + let y = x + UInt::new(2); + let _ = UInt::from(y) + UInt::new(34); } #[cube] pub fn uint_to_bool(x: UInt) { - let y = x + uint_new(2u32); - let _ = to_bool(y) | bool_new(true); + let y = x + UInt::new(2); + let _ = Bool::from(y) | Bool::new(true); } cast_test!( @@ -199,26 +199,26 @@ cast_test!( // From bool #[cube] pub fn bool_to_float(x: Bool) { - let y = x & bool_new(false); - let _ = cast::(y) + F32::new(34); + let y = x & Bool::new(false); + let _ = F32::from(y) + F32::new(34); } #[cube] pub fn bool_to_int(x: Bool) { - let y = x & bool_new(false); - let _ = cast::(y) + I32::new(34); + let y = x & Bool::new(false); + let _ = I32::from(y) + I32::new(34); } #[cube] pub fn bool_to_uint(x: Bool) { - let y = x & bool_new(false); - let _ = to_uint(y) + uint_new(34u32); + let y = x & Bool::new(false); + let _ = UInt::from(y) + UInt::new(34); } #[cube] pub fn bool_to_bool(x: Bool) { - let y = x & bool_new(false); - let _ = to_bool(y) | bool_new(true); + let y = x & Bool::new(false); + let _ = Bool::from(y) | Bool::new(true); } cast_test!( diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index f4bb448a50..d161e6bb45 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,24 +1,24 @@ -use burn_cube::{cube, elemtype::*, CubeContext, Float, 
Int, Numeric, F32, F64, I32, I64}; +use burn_cube::{cube, CubeContext, Float, Int, Numeric, RuntimeType, F32, F64, I32, I64}; use burn_jit::{cube_inline, gpu::Item}; #[cube] -pub fn cast_float_kind(input: F1) { - let x = input + F1::from(5.9); - let y = cast::(x); - let _ = y + F2::from(2.3); +pub fn cast_float_kind>(input: F1) { + let x = input + F1::from_primitive(5.9); + let y = F2::from(x); + let _ = y + F2::from_primitive(2.3); } #[cube] -pub fn cast_int_kind(input: I1) { - let x = input + I1::from(5); - let y = cast::(x); - let _ = y + I2::from(2); +pub fn cast_int_kind>(input: I1) { + let x = input + I1::from_primitive(5); + let y = I2::from(x); + let _ = y + I2::from_primitive(2); } #[cube] -pub fn cast_numeric_to_kind(input: T) { +pub fn cast_numeric_to_kind>(input: T) { let x = input + T::new(5); - let y = cast::(x); + let y = I2::from(x); let _ = y + I2::new(2); } diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 5e0e069ded..544beb75a4 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, Array, CubeContext, Float, Numeric, UInt, F32}; +use burn_cube::{branch::*, cube, Array, CubeContext, Float, RuntimeType, UInt, F32}; use burn_jit::{ cube_inline, gpu::{Item, Variable}, diff --git a/crates/burn-cube/tests/function_call.rs b/crates/burn-cube/tests/function_call.rs index a442cc5d8e..3d3578f911 100644 --- a/crates/burn-cube/tests/function_call.rs +++ b/crates/burn-cube/tests/function_call.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, elemtype::*, CubeContext, UInt}; +use burn_cube::{cube, CubeContext, UInt}; use burn_jit::gpu::{Elem, Item}; #[cube] @@ -8,12 +8,12 @@ pub fn caller_no_arg(x: UInt) { #[cube] pub fn callee_no_arg() -> UInt { - uint_new(8u32) + UInt::new(8) } #[cube] pub fn no_call_no_arg(x: UInt) { - let _ = x + uint_new(8u32); + let _ = x + UInt::new(8); } #[cube] @@ -23,12 +23,12 @@ pub fn caller_with_arg(x: UInt) { 
#[cube] pub fn callee_with_arg(x: UInt) -> UInt { - x * uint_new(8u32) + x * UInt::new(8) } #[cube] pub fn no_call_with_arg(x: UInt) { - let _ = x + x * uint_new(8u32); + let _ = x + x * UInt::new(8); } #[test] diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs index b757fa6e71..522d2f4ddf 100644 --- a/crates/burn-cube/tests/generic_kernel.rs +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Numeric, F32, I32}; +use burn_cube::{cube, CubeContext, Numeric, RuntimeType, F32, I32}; use burn_jit::{cube_inline, gpu::Item}; #[cube] diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 364a2ab3e9..9e3237206b 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Float, Numeric, F32}; +use burn_cube::{branch::*, cube, CubeContext, Numeric, RuntimeType, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, @@ -7,9 +7,9 @@ use burn_jit::{ type ElemType = F32; #[cube] -pub fn if_greater(lhs: F) { - if lhs > F::new(0) { - let _ = lhs + F::new(4); +pub fn if_greater(lhs: T) { + if lhs > T::new(0) { + let _ = lhs + T::new(4); } } diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index ae366dea98..4871d3c849 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Float, Numeric, F32}; +use burn_cube::{branch::*, cube, CubeContext, Float, RuntimeType, F32}; use burn_jit::{ cube_inline, gpu::{Elem, Item, Variable}, diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 36415fde9b..b6f4751a8e 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Float, Numeric, F32}; +use burn_cube::{cube, CubeContext, Float, RuntimeType, F32}; 
use burn_jit::{cube_inline, gpu::Item}; type ElemType = F32; @@ -13,7 +13,6 @@ fn cube_literal_test() { let mut context = CubeContext::root(); let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); - // let lhs = context.create_local(ElemType::into_item()); literal_expand::(&mut context, lhs); let scope = context.into_scope(); diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index a0bb9ad757..8a6a5c28b0 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Int, Numeric, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, RuntimeType, I32}; use burn_jit::cube_inline; use burn_jit::gpu::Branch; use burn_jit::gpu::{Elem, Item, Variable}; @@ -8,7 +8,7 @@ type ElemType = I32; #[cube] pub fn while_not(lhs: I) { while lhs != I::new(0) { - let _ = lhs - I::new(1); + let _ = lhs % I::new(1); } } @@ -18,7 +18,7 @@ pub fn manual_loop_break(lhs: I) { if lhs != I::new(0) { break; } - let _ = lhs - I::new(1); + let _ = lhs % I::new(1); } } @@ -64,7 +64,7 @@ fn inline_macro_ref() -> String { scope.register(Branch::Break); })); - cube_inline!(scope, rhs = lhs - 1i32); + cube_inline!(scope, rhs = lhs % 1i32); }) ); diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 30c8a4f3dc..a2707cc9a0 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Int, Numeric, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, RuntimeType, I32}; use burn_jit::{ cube_inline, gpu::{Branch, Elem, Item, Variable}, From 9c347da427da90c4b6fd5c5d7e4c3f9e79a7aadb Mon Sep 17 00:00:00 2001 From: louisfd Date: Sat, 11 May 2024 17:32:17 -0400 Subject: [PATCH 37/54] fmt --- crates/burn-cube/src/element/array.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-cube/src/element/array.rs 
b/crates/burn-cube/src/element/array.rs index 234f6e0c45..e1c360b2ab 100644 --- a/crates/burn-cube/src/element/array.rs +++ b/crates/burn-cube/src/element/array.rs @@ -1,4 +1,4 @@ -use crate::{ExpandElement, CubeType}; +use crate::{CubeType, ExpandElement}; #[derive(new, Clone)] pub struct Array { From 2cf76039668bc04cda02bf27e987602902b50cda Mon Sep 17 00:00:00 2001 From: louisfd Date: Sat, 11 May 2024 19:45:06 -0400 Subject: [PATCH 38/54] clippy --- crates/burn-cube-macros/src/analysis.rs | 8 +++----- crates/burn-cube-macros/src/lib.rs | 9 ++++----- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 856d11dd24..5b28f95356 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -39,7 +39,7 @@ impl CodeAnalysis { Some(mut var) => { let should_clone = var.should_clone(loop_level); self.variable_analyses.insert(key, var); - return should_clone; + should_clone } None => panic!("Ident {ident} not part of analysis"), } @@ -81,10 +81,8 @@ impl CodeAnalysisBuilder { } for id in self.var_uses.iter() { - let prev_analysis = variable_analyses.remove(&id).expect(&format!( - "Analyis: Variable {:?} should be declared before it's used", - id - )); + let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| panic!("Analyis: Variable {:?} should be declared before it's used", + id)); let new_analysis = VariableAnalysis { num_used: prev_analysis.num_used + 1, loop_level_declared: prev_analysis.loop_level_declared, diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 19265f68f8..f6f6b6bfbb 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -67,7 +67,9 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt body.extend(tokens); } - let code = quote::quote! { + + + quote::quote! 
{ // mod #mod_name { // #prelude @@ -80,9 +82,7 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt } // } } - .into(); - - code + .into() } // fn get_name(sig: &syn::Signature) -> proc_macro2::TokenStream { @@ -130,5 +130,4 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { quote::quote! { pub fn #ident #generics (context: &mut burn_cube::CubeContext, #inputs) -> #output } - .into() } From 3b0052f6f1fe9fd4915108306533c061bdce66d7 Mon Sep 17 00:00:00 2001 From: louisfd Date: Sat, 11 May 2024 19:48:16 -0400 Subject: [PATCH 39/54] found the culprit --- crates/burn-jit/src/kernel/prng/base.rs | 2 +- crates/burn-wgpu/src/compiler/wgsl/compiler.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index 0a32f74129..c9a14d423f 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -285,7 +285,7 @@ pub(crate) fn lcg_step(scope: &mut Scope, z: Variable) { } pub(crate) fn cast_uint_to_float(scope: &mut Scope, int_random: Variable, float_random: Variable) { - let tmp: Variable = 2.328_306_4e-10.into(); + let tmp: Variable = 2.328_306_4e-10f32.into(); gpu!(scope, float_random = cast(int_random)); gpu!(scope, float_random *= tmp); } diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 7656744e9e..9a267a85b4 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -103,7 +103,7 @@ impl WgslCompiler { gpu::FloatKind::F16 => panic!("f16 is not yet supported"), gpu::FloatKind::BF16 => panic!("f64 is not a valid WgpuElement"), gpu::FloatKind::F32 => wgsl::Elem::F32, - gpu::FloatKind::F64 => wgsl::Elem::F32, + gpu::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"), }, gpu::Elem::Int(i) => match i { gpu::IntKind::I32 => wgsl::Elem::I32, From 
6a4214849d278cecd366ff7f42154a47bdae4b66 Mon Sep 17 00:00:00 2001 From: louisfd Date: Sat, 11 May 2024 19:50:05 -0400 Subject: [PATCH 40/54] typo --- crates/burn-cube-macros/src/analysis.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 5b28f95356..e2a8d0c968 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -81,7 +81,7 @@ impl CodeAnalysisBuilder { } for id in self.var_uses.iter() { - let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| panic!("Analyis: Variable {:?} should be declared before it's used", + let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| panic!("Analysis: Variable {:?} should be declared before it's used", id)); let new_analysis = VariableAnalysis { num_used: prev_analysis.num_used + 1, From 9a4a94ceb656467ac9a159f78044443e4b087e7b Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 09:00:29 -0400 Subject: [PATCH 41/54] add doc --- crates/burn-cube-macros/Cargo.toml | 5 +- crates/burn-cube-macros/src/analysis.rs | 54 +++++++----- crates/burn-cube-macros/src/codegen.rs | 109 ++++++++++++++++-------- crates/burn-cube-macros/src/lib.rs | 66 +++++++------- crates/burn-cube/Cargo.toml | 9 +- crates/burn-cube/tests/literal.rs | 17 ++++ 6 files changed, 163 insertions(+), 97 deletions(-) diff --git a/crates/burn-cube-macros/Cargo.toml b/crates/burn-cube-macros/Cargo.toml index 06565dec1d..11757df219 100644 --- a/crates/burn-cube-macros/Cargo.toml +++ b/crates/burn-cube-macros/Cargo.toml @@ -1,5 +1,8 @@ [package] -authors = ["nathanielsimard "] +authors = [ + "nathanielsimard ", + "louisfd , } #[derive(Debug, Default)] +/// Reads the Cube code and accumulates information, to generate a CodeAnalysis artefact pub(crate) struct CodeAnalysisBuilder { declarations: Vec<(VariableKey, usize)>, var_uses: Vec, @@ -55,7 +59,7 @@ impl CodeAnalysisBuilder { fn 
analyze(mut self, func: &syn::ItemFn) -> CodeAnalysis { // Build the vector of (Id, depth), using recursion self.signature_declarations(&func.sig); - self.stmts_occurrences(&func.block.stmts, 0); + self.find_occurrences_in_stmts(&func.block.stmts, 0); CodeAnalysis { variable_analyses: self.to_map(), @@ -81,8 +85,12 @@ impl CodeAnalysisBuilder { } for id in self.var_uses.iter() { - let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| panic!("Analysis: Variable {:?} should be declared before it's used", - id)); + let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| { + panic!( + "Analysis: Variable {:?} should be declared before it's used", + id + ) + }); let new_analysis = VariableAnalysis { num_used: prev_analysis.num_used + 1, loop_level_declared: prev_analysis.loop_level_declared, @@ -111,7 +119,7 @@ impl CodeAnalysisBuilder { } } - fn stmts_occurrences(&mut self, stmts: &Vec, depth: usize) { + fn find_occurrences_in_stmts(&mut self, stmts: &Vec, depth: usize) { for stmt in stmts { match stmt { // Declaration @@ -129,16 +137,16 @@ impl CodeAnalysisBuilder { self.declarations.push((id.into(), depth)); } if let Some(local_init) = &local.init { - self.expr_occurrences(&local_init.expr, depth) + self.find_occurrences_in_expr(&local_init.expr, depth) } } - syn::Stmt::Expr(expr, _) => self.expr_occurrences(expr, depth), + syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth), _ => todo!("Analysis: unsupported stmt {stmt:?}"), } } } - fn expr_occurrences(&mut self, expr: &syn::Expr, depth: usize) { + fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: usize) { match expr { syn::Expr::ForLoop(expr) => { let depth = depth + 1; @@ -149,39 +157,39 @@ impl CodeAnalysisBuilder { self.declarations.push((id.into(), depth)); } - self.stmts_occurrences(&expr.body.stmts, depth); + self.find_occurrences_in_stmts(&expr.body.stmts, depth); } syn::Expr::While(expr) => { let depth = depth + 1; - self.expr_occurrences(&expr.cond, 
depth); - self.stmts_occurrences(&expr.body.stmts, depth); + self.find_occurrences_in_expr(&expr.cond, depth); + self.find_occurrences_in_stmts(&expr.body.stmts, depth); } syn::Expr::Loop(expr) => { let depth = depth + 1; - self.stmts_occurrences(&expr.body.stmts, depth); + self.find_occurrences_in_stmts(&expr.body.stmts, depth); } syn::Expr::If(expr) => { let depth = depth + 1; - self.expr_occurrences(&expr.cond, depth); - self.stmts_occurrences(&expr.then_branch.stmts, depth); + self.find_occurrences_in_expr(&expr.cond, depth); + self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth); if let Some((_, expr)) = &expr.else_branch { if let syn::Expr::Block(expr_block) = &**expr { - self.stmts_occurrences(&expr_block.block.stmts, depth); + self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); } else { todo!("Analysis: Only block else expr is supported") } } } syn::Expr::Assign(expr) => { - self.expr_occurrences(&expr.left, depth); - self.expr_occurrences(&expr.right, depth); + self.find_occurrences_in_expr(&expr.left, depth); + self.find_occurrences_in_expr(&expr.right, depth); } syn::Expr::Index(expr) => { - self.expr_occurrences(&expr.expr, depth); - self.expr_occurrences(&expr.index, depth); + self.find_occurrences_in_expr(&expr.expr, depth); + self.find_occurrences_in_expr(&expr.index, depth); } syn::Expr::Path(expr) => { let ident = expr @@ -193,8 +201,8 @@ impl CodeAnalysisBuilder { self.var_uses.push(ident.into()); } syn::Expr::Binary(expr) => { - self.expr_occurrences(&expr.left, depth); - self.expr_occurrences(&expr.right, depth); + self.find_occurrences_in_expr(&expr.left, depth); + self.find_occurrences_in_expr(&expr.right, depth); } syn::Expr::Lit(_) => {} syn::Expr::Call(expr) => { @@ -219,13 +227,13 @@ impl CodeAnalysisBuilder { _ => todo!("Analysis: unsupported func expr {:?}", expr.func), } for arg in expr.args.iter() { - self.expr_occurrences(arg, depth); + self.find_occurrences_in_expr(arg, depth); } } syn::Expr::MethodCall(expr) => 
{ - self.expr_occurrences(&expr.receiver, depth); + self.find_occurrences_in_expr(&expr.receiver, depth); for arg in expr.args.iter() { - self.expr_occurrences(arg, depth); + self.find_occurrences_in_expr(arg, depth); } } syn::Expr::Break(_) => {} diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs index 74c38871d0..3d1c099e29 100644 --- a/crates/burn-cube-macros/src/codegen.rs +++ b/crates/burn-cube-macros/src/codegen.rs @@ -4,6 +4,35 @@ use syn::{Lit, PathArguments}; use crate::analysis::CodeAnalysis; +/// Codegen for a code block (a list of statements) +fn codegen_block( + block: &syn::Block, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let mut statements = quote::quote!(); + + for statement in block.stmts.iter() { + statements.extend(codegen_statement(statement, loop_level, variable_analyses)); + } + + quote::quote! { + { + #statements + } + } +} + +/// Codegen for an expression containing a block +fn codegen_expr_block( + block: &syn::ExprBlock, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + codegen_block(&block.block, loop_level, variable_analyses) +} + +/// Codegen for a statement (generally one line) pub fn codegen_statement( statement: &syn::Stmt, loop_level: usize, @@ -24,6 +53,11 @@ pub fn codegen_statement( } } +/// Codegen for a local declaration (let ...) +/// Supports: +/// let x = ... +/// let x: T = ... +/// let _ = ... 
fn codegen_local( local: &syn::Local, loop_level: usize, @@ -57,6 +91,8 @@ fn codegen_local( } } +/// Codegen for expressions +/// There are many variants of expression, treated differently fn codegen_expr( expr: &syn::Expr, loop_level: usize, @@ -78,11 +114,12 @@ fn codegen_expr( syn::Expr::Break(_) => codegen_break(), syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), syn::Expr::MethodCall(call) => codegen_expr_method_call(call), - syn::Expr::Index(index) => codegen_expr_index(index, loop_level, variable_analyses), + syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), _ => panic!("Codegen: Unsupported {:?}", expr), } } +/// Codegen for literals fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { match lit.lit { // We treat floats differently to avoid getting 4..into() for instance @@ -97,7 +134,8 @@ fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { } } -fn codegen_expr_index( +/// Codegen for indexed access +fn codegen_index( index: &syn::ExprIndex, loop_level: usize, variable_analyses: &mut CodeAnalysis, @@ -107,17 +145,21 @@ fn codegen_expr_index( quote::quote! { { - let _array = #array; - let _index = #index; - burn_cube::index::expand(context, _array, _index) + let _array = #array; + let _index = #index; + burn_cube::index::expand(context, _array, _index) } } } +/// Codegen for method call fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { quote::quote!( #call ) } +/// Codegen for for loops +/// Supports range: +/// for i in range(start, end, unroll) {...} fn codegen_for_loop( for_loop: &syn::ExprForLoop, loop_level: usize, @@ -157,6 +199,7 @@ fn codegen_for_loop( } } +/// Codegen for condition of an if or a while fn codegen_cond( cond: &syn::Expr, loop_level: usize, @@ -169,12 +212,17 @@ fn codegen_cond( } } +/// Codegen for break statement fn codegen_break() -> TokenStream { quote::quote! 
{ break_expand(context); } } +/// Codegen for if and if/else statements +/// Supports: +/// if cond {...} +/// if cond {...} else {...} fn codegen_if( expr_if: &syn::ExprIf, loop_level: usize, @@ -203,6 +251,7 @@ fn codegen_if( } } +/// Codegen for loop fn codegen_loop( loop_expr: &syn::ExprLoop, loop_level: usize, @@ -215,6 +264,7 @@ fn codegen_loop( } } +/// Codegen for while loop fn codegen_while_loop( while_loop: &syn::ExprWhile, loop_level: usize, @@ -228,6 +278,10 @@ fn codegen_while_loop( } } +/// Codegen for assignation +/// Supports: +/// - scalar +/// - indexed array fn codegen_assign( assign: &syn::ExprAssign, loop_level: usize, @@ -264,32 +318,7 @@ fn codegen_assign( } } -fn codegen_block( - block: &syn::Block, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let mut statements = quote::quote!(); - - for statement in block.stmts.iter() { - statements.extend(codegen_statement(statement, loop_level, variable_analyses)); - } - - quote::quote! { - { - #statements - } - } -} - -fn codegen_expr_block( - block: &syn::ExprBlock, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - codegen_block(&block.block, loop_level, variable_analyses) -} - +/// Codegen for a closure fn codegen_closure( closure: &syn::ExprClosure, loop_level: usize, @@ -313,15 +342,17 @@ fn codegen_closure( } } +/// Codegen for a function call +/// Supports: +/// func() +/// func::() +/// T::func() fn codegen_call( call: &syn::ExprCall, loop_level: usize, variable_analyses: &mut CodeAnalysis, ) -> TokenStream { - // Possibilities: - // a() - // a::() - // T::a + // We start with parsing the function path let (mut idents, generics) = match call.func.as_ref() { syn::Expr::Path(expr_path) => { let mut idents = Vec::new(); @@ -340,6 +371,7 @@ fn codegen_call( _ => todo!("Codegen: func call {:?} not supported", call.func), }; + // Function name with support for longer path let func_name = idents .pop() .expect("Codegen: Func should 
have at least one ident"); @@ -348,17 +380,18 @@ fn codegen_call( for ident in idents.iter() { previous_tokens.extend(quote_spanned! {ident.span() => #ident :: }); } - let func_name_expand = syn::Ident::new( format!("{func_name}_expand").as_str(), proc_macro2::Span::call_site(), ); + // Generics let generics = match generics { Some(generics) => quote::quote! { #generics }, None => quote::quote! {}, }; + // Arguments let mut args = quote::quote! { context, }; @@ -367,11 +400,14 @@ fn codegen_call( args.extend(quote::quote! { #arg, }); } + // Codegen quote::quote! { #previous_tokens #func_name_expand #generics (#args) } } +/// Codegen for a variable used in rhs of a statement +/// This function adds cloning when necessary fn codegen_path_rhs( path: &syn::ExprPath, loop_level: usize, @@ -395,6 +431,7 @@ fn codegen_path_rhs( } } +/// Codegen for binary operations (+, -, *, etc.) fn codegen_binary( binary: &syn::ExprBinary, loop_level: usize, diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index f6f6b6bfbb..31128fc4d7 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -5,21 +5,41 @@ use analysis::CodeAnalysis; use codegen::codegen_statement; use proc_macro::TokenStream; use quote::ToTokens; -use syn::{parse_macro_input, punctuated::Punctuated, Meta}; +use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta}; + +enum CubeMode { + /// Generates the expanded version of the function + Default, + /// Panics and prints the generated code, useful when debugging + /// Use by writing #[cube(panic)] + Panic, +} /// Derive macro for the module. 
#[proc_macro_attribute] pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { let args = parse_macro_input!(attr with Punctuated::::parse_terminated); + let mode = parse_mode(args); + let func: syn::ItemFn = syn::parse(tokens).unwrap(); + let mut variable_analyses = CodeAnalysis::create(&func); + + let code = codegen_cube(&func, &mut variable_analyses); + match mode { + CubeMode::Default => code, + CubeMode::Panic => panic!("{code}"), + } +} + +fn parse_mode(args: Punctuated) -> CubeMode { + let mut mode = CubeMode::Default; - let mut panic_mode = false; if let Some(arg) = args.first() { match arg { Meta::Path(path) => { if let Some(ident) = path.get_ident().map(|id| id.to_string()) { match ident.as_str() { "panic" => { - panic_mode = true; + mode = CubeMode::Panic; } _ => panic!("Attribute {ident} is not supported"), } @@ -32,14 +52,7 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { } } - let func: syn::ItemFn = syn::parse(tokens).unwrap(); - let mut variable_analyses = CodeAnalysis::create(&func); - - let code = codegen_cube(&func, &mut variable_analyses); - match panic_mode { - true => panic!("{code}"), - false => code, - } + mode } #[derive(Hash, PartialEq, Eq, Debug, Clone)] @@ -57,8 +70,6 @@ impl From<&syn::Ident> for VariableKey { /// Generate the expanded version of a function marked with the cube macro fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenStream { - // let prelude = get_prelude(&code_analysis.needed_functions); - // let mod_name = get_name(&func.sig); let signature = expand_sig(&func.sig); let mut body = quote::quote! {}; @@ -67,33 +78,18 @@ fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenSt body.extend(tokens); } - - quote::quote! 
{ - // mod #mod_name { - // #prelude - - #[allow(dead_code)] - #func + #[allow(dead_code)] + #func - #[allow(unused_mut)] - #signature { - #body - } - // } + #[allow(unused_mut)] + #signature { + #body + } } .into() } -// fn get_name(sig: &syn::Signature) -> proc_macro2::TokenStream { -// let ident = &sig.ident; - -// quote::quote! { -// #ident -// } -// .into() -// } - fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { let mut inputs = quote::quote!(); @@ -107,7 +103,7 @@ fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { #ident: <#ty as burn_cube::CubeType>::ExpandType, }); } - _ => todo!(), + _ => todo!("Only Typed inputs are supported"), } } diff --git a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml index d7637ac633..9302568ce9 100644 --- a/crates/burn-cube/Cargo.toml +++ b/crates/burn-cube/Cargo.toml @@ -1,5 +1,8 @@ [package] -authors = ["nathanielsimard "] +authors = [ + "nathanielsimard ", + "louisfd (lhs: F) { let _ = lhs + F::new(5); } +#[cube] +pub fn literal_float_no_decimals(lhs: F) { + let _ = lhs + F::from_primitive(5.); +} + #[test] fn cube_literal_test() { let mut context = CubeContext::root(); @@ -20,6 +25,18 @@ fn cube_literal_test() { assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); } +#[test] +fn cube_literal_float_no_decimal_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + literal_float_no_decimals_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + fn inline_macro_ref() -> String { let mut context = CubeContext::root(); let item = Item::Scalar(ElemType::into_elem()); From 80312ac15a7e132a19841bdb70e01819dad8d089 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 09:12:07 -0400 Subject: [PATCH 42/54] refactor codegen into files --- crates/burn-cube-macros/src/codegen.rs | 525 ------------------ 
crates/burn-cube-macros/src/codegen/base.rs | 88 +++ crates/burn-cube-macros/src/codegen/branch.rs | 126 +++++ .../burn-cube-macros/src/codegen/function.rs | 98 ++++ crates/burn-cube-macros/src/codegen/mod.rs | 7 + .../burn-cube-macros/src/codegen/operation.rs | 98 ++++ .../burn-cube-macros/src/codegen/variable.rs | 141 +++++ 7 files changed, 558 insertions(+), 525 deletions(-) delete mode 100644 crates/burn-cube-macros/src/codegen.rs create mode 100644 crates/burn-cube-macros/src/codegen/base.rs create mode 100644 crates/burn-cube-macros/src/codegen/branch.rs create mode 100644 crates/burn-cube-macros/src/codegen/function.rs create mode 100644 crates/burn-cube-macros/src/codegen/mod.rs create mode 100644 crates/burn-cube-macros/src/codegen/operation.rs create mode 100644 crates/burn-cube-macros/src/codegen/variable.rs diff --git a/crates/burn-cube-macros/src/codegen.rs b/crates/burn-cube-macros/src/codegen.rs deleted file mode 100644 index 3d1c099e29..0000000000 --- a/crates/burn-cube-macros/src/codegen.rs +++ /dev/null @@ -1,525 +0,0 @@ -use proc_macro2::TokenStream; -use quote::{quote_spanned, ToTokens}; -use syn::{Lit, PathArguments}; - -use crate::analysis::CodeAnalysis; - -/// Codegen for a code block (a list of statements) -fn codegen_block( - block: &syn::Block, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let mut statements = quote::quote!(); - - for statement in block.stmts.iter() { - statements.extend(codegen_statement(statement, loop_level, variable_analyses)); - } - - quote::quote! 
{ - { - #statements - } - } -} - -/// Codegen for an expression containing a block -fn codegen_expr_block( - block: &syn::ExprBlock, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - codegen_block(&block.block, loop_level, variable_analyses) -} - -/// Codegen for a statement (generally one line) -pub fn codegen_statement( - statement: &syn::Stmt, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - match statement { - syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), - syn::Stmt::Expr(expr, semi) => { - let expr = codegen_expr(expr, loop_level, variable_analyses); - match semi { - Some(_semi) => quote::quote!( - #expr; - ), - None => expr, - } - } - _ => todo!("Codegen: statement {statement:?} not supported"), - } -} - -/// Codegen for a local declaration (let ...) -/// Supports: -/// let x = ... -/// let x: T = ... -/// let _ = ... -fn codegen_local( - local: &syn::Local, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let let_tok = local.let_token; - - let ident = match &local.pat { - syn::Pat::Ident(ident) => ident.to_token_stream(), - syn::Pat::Type(pat_type) => match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(), - _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), - }, - syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), - _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), - }; - - match local.init.as_ref() { - Some(init) => { - let init = codegen_expr(&init.expr, loop_level, variable_analyses); - - quote::quote! { - #let_tok #ident = #init; - } - } - None => { - quote::quote! 
{ - #let_tok #ident; - } - } - } -} - -/// Codegen for expressions -/// There are many variants of expression, treated differently -fn codegen_expr( - expr: &syn::Expr, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - match expr { - syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), - syn::Expr::Path(path) => codegen_path_rhs(path, loop_level, variable_analyses), - syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), - syn::Expr::Lit(lit) => codegen_lit(lit), - syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses), - syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses), - syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses), - syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses), - syn::Expr::While(while_loop) => { - codegen_while_loop(while_loop, loop_level, variable_analyses) - } - syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_analyses), - syn::Expr::Break(_) => codegen_break(), - syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), - syn::Expr::MethodCall(call) => codegen_expr_method_call(call), - syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), - _ => panic!("Codegen: Unsupported {:?}", expr), - } -} - -/// Codegen for literals -fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { - match lit.lit { - // We treat floats differently to avoid getting 4..into() for instance - Lit::Float(_) => { - let lit_str = lit.lit.to_token_stream().to_string(); - let float_lit = lit_str.parse::().unwrap(); - quote::quote! { #float_lit.into() } - } - _ => { - quote::quote! 
{ #lit.into() } - } - } -} - -/// Codegen for indexed access -fn codegen_index( - index: &syn::ExprIndex, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let array = codegen_expr(&index.expr, loop_level, variable_analyses); - let index = codegen_expr(&index.index, loop_level, variable_analyses); - - quote::quote! { - { - let _array = #array; - let _index = #index; - burn_cube::index::expand(context, _array, _index) - } - } -} - -/// Codegen for method call -fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { - quote::quote!( #call ) -} - -/// Codegen for for loops -/// Supports range: -/// for i in range(start, end, unroll) {...} -fn codegen_for_loop( - for_loop: &syn::ExprForLoop, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let i = &for_loop.pat; - let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses); - - match for_loop.expr.as_ref() { - syn::Expr::Call(call) => { - let func_name = match call.func.as_ref() { - syn::Expr::Path(path) => path - .path - .get_ident() - .expect("Codegen: func in for loop should have ident"), - _ => todo!("Codegen: Only path call supported"), - }; - - if &func_name.to_string() == "range" { - let mut args = quote::quote! { - context, - }; - - for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level, variable_analyses); - args.extend(quote::quote! { #arg, }); - } - - quote::quote! 
{ - range_expand(#args |context, #i| #block); - } - } else { - todo!("Codegen: Only range is supported") - } - } - _ => todo!("Codegen: Only call is supported {for_loop:?}"), - } -} - -/// Codegen for condition of an if or a while -fn codegen_cond( - cond: &syn::Expr, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - match cond { - syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses), - syn::Expr::Lit(expr) => codegen_lit(expr), - _ => todo!("{cond:?} cond not supported"), - } -} - -/// Codegen for break statement -fn codegen_break() -> TokenStream { - quote::quote! { - break_expand(context); - } -} - -/// Codegen for if and if/else statements -/// Supports: -/// if cond {...} -/// if cond {...} else {...} -fn codegen_if( - expr_if: &syn::ExprIf, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); - - let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); - - if let Some((_, expr)) = &expr_if.else_branch { - if let syn::Expr::Block(expr_block) = &**expr { - let else_block = codegen_block(&expr_block.block, loop_level + 1, variable_analyses); - - quote::quote! { - let _cond = #cond; - if_else_expand(context, _cond, |context| #then_block, |context| #else_block); - } - } else { - todo!("Analysis: Only block else expr is supported") - } - } else { - quote::quote! { - let _cond = #cond; - if_expand(context, _cond, |context| #then_block); - } - } -} - -/// Codegen for loop -fn codegen_loop( - loop_expr: &syn::ExprLoop, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let block = codegen_block(&loop_expr.body, loop_level + 1, variable_analyses); - - quote::quote! 
{ - loop_expand(context, |context| #block); - } -} - -/// Codegen for while loop -fn codegen_while_loop( - while_loop: &syn::ExprWhile, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let cond = codegen_cond(&while_loop.cond, loop_level + 1, variable_analyses); - let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); - - quote::quote! { - while_loop_expand(context, |context| #cond, |context| #block); - } -} - -/// Codegen for assignation -/// Supports: -/// - scalar -/// - indexed array -fn codegen_assign( - assign: &syn::ExprAssign, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - match assign.left.as_ref() { - syn::Expr::Index(index) => { - let array = codegen_expr(&index.expr, loop_level, variable_analyses); - let index = codegen_expr(&index.index, loop_level, variable_analyses); - let value = codegen_expr(&assign.right, loop_level, variable_analyses); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #value; - burn_cube::index_assign::expand(context, _array, _index, _value) - } - } - } - syn::Expr::Path(_) => { - let lhs = codegen_expr(&assign.left, loop_level, variable_analyses); - let rhs = codegen_expr(&assign.right, loop_level, variable_analyses); - - quote::quote! { - { - let _assign_lhs = #lhs; - let _assign_rhs = #rhs; - burn_cube::assign::expand(context, _assign_rhs, _assign_lhs) - } - } - } - _ => todo!("Assign of expr {:?} unsupported", assign.left), - } -} - -/// Codegen for a closure -fn codegen_closure( - closure: &syn::ExprClosure, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let mut inputs = quote::quote! {}; - for input in closure.inputs.iter() { - let ident = match input { - syn::Pat::Ident(ident) => &ident.ident, - _ => panic!("Codegen: Unsupported {:?}", input), - }; - inputs.extend(quote::quote! 
{ - #ident, - }); - } - - let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses); - - quote::quote! { - |context, #inputs| #body - } -} - -/// Codegen for a function call -/// Supports: -/// func() -/// func::() -/// T::func() -fn codegen_call( - call: &syn::ExprCall, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - // We start with parsing the function path - let (mut idents, generics) = match call.func.as_ref() { - syn::Expr::Path(expr_path) => { - let mut idents = Vec::new(); - let mut generics = None; - for (index, segment) in expr_path.path.segments.iter().enumerate() { - idents.push(&segment.ident); - - if index == expr_path.path.segments.len() - 1 { - if let PathArguments::AngleBracketed(arguments) = &segment.arguments { - generics = Some(arguments) - } - } - } - (idents, generics) - } - _ => todo!("Codegen: func call {:?} not supported", call.func), - }; - - // Function name with support for longer path - let func_name = idents - .pop() - .expect("Codegen: Func should have at least one ident"); - - let mut previous_tokens = TokenStream::new(); - for ident in idents.iter() { - previous_tokens.extend(quote_spanned! {ident.span() => #ident :: }); - } - let func_name_expand = syn::Ident::new( - format!("{func_name}_expand").as_str(), - proc_macro2::Span::call_site(), - ); - - // Generics - let generics = match generics { - Some(generics) => quote::quote! { #generics }, - None => quote::quote! {}, - }; - - // Arguments - let mut args = quote::quote! { - context, - }; - for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level, variable_analyses); - args.extend(quote::quote! { #arg, }); - } - - // Codegen - quote::quote! 
{ - #previous_tokens #func_name_expand #generics (#args) - } -} - -/// Codegen for a variable used in rhs of a statement -/// This function adds cloning when necessary -fn codegen_path_rhs( - path: &syn::ExprPath, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let ident = path - .path - .get_ident() - .expect("Codegen: Only ident path are supported."); - - let will_be_used_again = variable_analyses.should_clone(ident, loop_level); - - if will_be_used_again { - quote::quote! { - #ident.clone() - } - } else { - quote::quote! { - #ident - } - } -} - -/// Codegen for binary operations (+, -, *, etc.) -fn codegen_binary( - binary: &syn::ExprBinary, - loop_level: usize, - variable_analyses: &mut CodeAnalysis, -) -> TokenStream { - let lhs = codegen_expr(&binary.left, loop_level, variable_analyses); - let rhs = codegen_expr(&binary.right, loop_level, variable_analyses); - - match binary.op { - syn::BinOp::Add(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::add::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Sub(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::sub::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Mul(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::mul::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Div(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::div::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Rem(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::rem::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Ne(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::ne::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Gt(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::gt::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Lt(_) => quote::quote! 
{ - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::lt::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::AddAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::add_assign_op::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::BitAnd(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::and::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::And(_) => unimplemented!("Logical and (&&) not overridable in Rust due to its short circuiting nature. Use bitwise instead (&). "), - syn::BinOp::BitOr(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::or::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Or(_) => unimplemented!("Logical or (||) not overridable in Rust due to its short circuiting nature. Use bitwise instead (|). "), - _ => todo!("Codegen: unsupported op {:?}", binary.op), - } -} diff --git a/crates/burn-cube-macros/src/codegen/base.rs b/crates/burn-cube-macros/src/codegen/base.rs new file mode 100644 index 0000000000..0cea957469 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/base.rs @@ -0,0 +1,88 @@ +use proc_macro2::TokenStream; + +use crate::analysis::CodeAnalysis; + +use super::{ + branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop}, + function::{codegen_call, codegen_closure, codegen_expr_method_call}, + operation::codegen_binary, + variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs}, +}; + +/// Codegen for a statement (generally one line) +/// Entry point of codegen +pub fn codegen_statement( + statement: &syn::Stmt, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match statement { + syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), + syn::Stmt::Expr(expr, semi) => { + let expr = codegen_expr(expr, loop_level, variable_analyses); + match semi { + Some(_semi) => quote::quote!( + #expr; + ), + None => expr, + } + } + _ => 
todo!("Codegen: statement {statement:?} not supported"), + } +} + +/// Codegen for a code block (a list of statements) +pub(crate) fn codegen_block( + block: &syn::Block, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let mut statements = quote::quote!(); + + for statement in block.stmts.iter() { + statements.extend(codegen_statement(statement, loop_level, variable_analyses)); + } + + quote::quote! { + { + #statements + } + } +} + +/// Codegen for an expression containing a block +pub(crate) fn codegen_expr_block( + block: &syn::ExprBlock, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + codegen_block(&block.block, loop_level, variable_analyses) +} + +/// Codegen for expressions +/// There are many variants of expression, treated differently +pub(crate) fn codegen_expr( + expr: &syn::Expr, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match expr { + syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), + syn::Expr::Path(path) => codegen_path_rhs(path, loop_level, variable_analyses), + syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), + syn::Expr::Lit(lit) => codegen_lit(lit), + syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses), + syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses), + syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses), + syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses), + syn::Expr::While(while_loop) => { + codegen_while_loop(while_loop, loop_level, variable_analyses) + } + syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_analyses), + syn::Expr::Break(_) => codegen_break(), + syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), + syn::Expr::MethodCall(call) => codegen_expr_method_call(call), + 
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), + _ => panic!("Codegen: Unsupported {:?}", expr), + } +} diff --git a/crates/burn-cube-macros/src/codegen/branch.rs b/crates/burn-cube-macros/src/codegen/branch.rs new file mode 100644 index 0000000000..96318a9fa9 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/branch.rs @@ -0,0 +1,126 @@ +use proc_macro2::TokenStream; + +use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; + +use super::{base::codegen_block, operation::codegen_binary, variable::codegen_lit}; + +/// Codegen of for loops +/// Supports range: +/// for i in range(start, end, unroll) {...} +pub(crate) fn codegen_for_loop( + for_loop: &syn::ExprForLoop, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let i = &for_loop.pat; + let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses); + + match for_loop.expr.as_ref() { + syn::Expr::Call(call) => { + let func_name = match call.func.as_ref() { + syn::Expr::Path(path) => path + .path + .get_ident() + .expect("Codegen: func in for loop should have ident"), + _ => todo!("Codegen: Only path call supported"), + }; + + if &func_name.to_string() == "range" { + let mut args = quote::quote! { + context, + }; + + for argument in call.args.iter() { + let arg = codegen_expr(argument, loop_level, variable_analyses); + args.extend(quote::quote! { #arg, }); + } + + quote::quote! 
{ + range_expand(#args |context, #i| #block); + } + } else { + todo!("Codegen: Only range is supported") + } + } + _ => todo!("Codegen: Only call is supported {for_loop:?}"), + } +} + +/// Codegen for condition of an if or a while +pub(crate) fn codegen_cond( + cond: &syn::Expr, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match cond { + syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses), + syn::Expr::Lit(expr) => codegen_lit(expr), + _ => todo!("{cond:?} cond not supported"), + } +} + +/// Codegen for break statement +pub(crate) fn codegen_break() -> TokenStream { + quote::quote! { + break_expand(context); + } +} + +/// Codegen for if and if/else statements +/// Supports: +/// if cond {...} +/// if cond {...} else {...} +pub(crate) fn codegen_if( + expr_if: &syn::ExprIf, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); + + let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); + + if let Some((_, expr)) = &expr_if.else_branch { + if let syn::Expr::Block(expr_block) = &**expr { + let else_block = codegen_block(&expr_block.block, loop_level + 1, variable_analyses); + + quote::quote! { + let _cond = #cond; + if_else_expand(context, _cond, |context| #then_block, |context| #else_block); + } + } else { + todo!("Analysis: Only block else expr is supported") + } + } else { + quote::quote! { + let _cond = #cond; + if_expand(context, _cond, |context| #then_block); + } + } +} + +/// Codegen of loop +pub(crate) fn codegen_loop( + loop_expr: &syn::ExprLoop, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let block = codegen_block(&loop_expr.body, loop_level + 1, variable_analyses); + + quote::quote! 
{ + loop_expand(context, |context| #block); + } +} + +/// Codegen for while loop +pub(crate) fn codegen_while_loop( + while_loop: &syn::ExprWhile, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let cond = codegen_cond(&while_loop.cond, loop_level + 1, variable_analyses); + let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); + + quote::quote! { + while_loop_expand(context, |context| #cond, |context| #block); + } +} diff --git a/crates/burn-cube-macros/src/codegen/function.rs b/crates/burn-cube-macros/src/codegen/function.rs new file mode 100644 index 0000000000..c680aac563 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/function.rs @@ -0,0 +1,98 @@ +use proc_macro2::TokenStream; +use quote::quote_spanned; +use syn::PathArguments; + +use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; + +/// Codegen for method call +pub(crate) fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { + quote::quote!( #call ) +} + +/// Codegen for a closure +pub(crate) fn codegen_closure( + closure: &syn::ExprClosure, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let mut inputs = quote::quote! {}; + for input in closure.inputs.iter() { + let ident = match input { + syn::Pat::Ident(ident) => &ident.ident, + _ => panic!("Codegen: Unsupported {:?}", input), + }; + inputs.extend(quote::quote! { + #ident, + }); + } + + let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses); + + quote::quote! 
{ + |context, #inputs| #body + } +} + +/// Codegen for a function call +/// Supports: +/// func() +/// func::() +/// T::func() +pub(crate) fn codegen_call( + call: &syn::ExprCall, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + // We start with parsing the function path + let (mut idents, generics) = match call.func.as_ref() { + syn::Expr::Path(expr_path) => { + let mut idents = Vec::new(); + let mut generics = None; + for (index, segment) in expr_path.path.segments.iter().enumerate() { + idents.push(&segment.ident); + + if index == expr_path.path.segments.len() - 1 { + if let PathArguments::AngleBracketed(arguments) = &segment.arguments { + generics = Some(arguments) + } + } + } + (idents, generics) + } + _ => todo!("Codegen: func call {:?} not supported", call.func), + }; + + // Function name with support for longer path + let func_name = idents + .pop() + .expect("Codegen: Func should have at least one ident"); + + let mut previous_tokens = TokenStream::new(); + for ident in idents.iter() { + previous_tokens.extend(quote_spanned! {ident.span() => #ident :: }); + } + let func_name_expand = syn::Ident::new( + format!("{func_name}_expand").as_str(), + proc_macro2::Span::call_site(), + ); + + // Generics + let generics = match generics { + Some(generics) => quote::quote! { #generics }, + None => quote::quote! {}, + }; + + // Arguments + let mut args = quote::quote! { + context, + }; + for argument in call.args.iter() { + let arg = codegen_expr(argument, loop_level, variable_analyses); + args.extend(quote::quote! { #arg, }); + } + + // Codegen + quote::quote! 
{ + #previous_tokens #func_name_expand #generics (#args) + } +} diff --git a/crates/burn-cube-macros/src/codegen/mod.rs b/crates/burn-cube-macros/src/codegen/mod.rs new file mode 100644 index 0000000000..c5e5f757f3 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/mod.rs @@ -0,0 +1,7 @@ +mod base; +mod branch; +mod function; +mod operation; +mod variable; + +pub(crate) use base::codegen_statement; diff --git a/crates/burn-cube-macros/src/codegen/operation.rs b/crates/burn-cube-macros/src/codegen/operation.rs new file mode 100644 index 0000000000..76924c080b --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/operation.rs @@ -0,0 +1,98 @@ +use proc_macro2::TokenStream; + +use crate::analysis::CodeAnalysis; + +use super::base::codegen_expr; + +/// Codegen for binary operations (+, -, *, etc.) +pub(crate) fn codegen_binary( + binary: &syn::ExprBinary, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let lhs = codegen_expr(&binary.left, loop_level, variable_analyses); + let rhs = codegen_expr(&binary.right, loop_level, variable_analyses); + + match binary.op { + syn::BinOp::Add(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::add::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Sub(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::sub::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Mul(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::mul::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Div(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::div::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Rem(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::rem::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Ne(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::ne::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Gt(_) => quote::quote! 
{ + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::gt::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Lt(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::lt::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::AddAssign(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::add_assign_op::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::BitAnd(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::and::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::And(_) => unimplemented!("Logical and (&&) not overridable in Rust due to its short circuiting nature. Use bitwise instead (&). "), + syn::BinOp::BitOr(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::or::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Or(_) => unimplemented!("Logical or (||) not overridable in Rust due to its short circuiting nature. Use bitwise instead (|). "), + _ => todo!("Codegen: unsupported op {:?}", binary.op), + } +} diff --git a/crates/burn-cube-macros/src/codegen/variable.rs b/crates/burn-cube-macros/src/codegen/variable.rs new file mode 100644 index 0000000000..769bc6dff0 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/variable.rs @@ -0,0 +1,141 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::Lit; + +use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; + +/// Codegen for literals +pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { + match lit.lit { + // We treat floats differently to avoid getting 4..into() for instance + Lit::Float(_) => { + let lit_str = lit.lit.to_token_stream().to_string(); + let float_lit = lit_str.parse::().unwrap(); + quote::quote! { #float_lit.into() } + } + _ => { + quote::quote! { #lit.into() } + } + } +} + +/// Codegen for a local declaration (let ...) +/// Supports: +/// let x = ... +/// let x: T = ... +/// let _ = ... 
+pub(crate) fn codegen_local( + local: &syn::Local, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let let_tok = local.let_token; + + let ident = match &local.pat { + syn::Pat::Ident(ident) => ident.to_token_stream(), + syn::Pat::Type(pat_type) => match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(), + _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), + }, + syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), + _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), + }; + + match local.init.as_ref() { + Some(init) => { + let init = codegen_expr(&init.expr, loop_level, variable_analyses); + + quote::quote! { + #let_tok #ident = #init; + } + } + None => { + quote::quote! { + #let_tok #ident; + } + } + } +} + +/// Codegen for indexed access +pub(crate) fn codegen_index( + index: &syn::ExprIndex, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = codegen_expr(&index.index, loop_level, variable_analyses); + + quote::quote! { + { + let _array = #array; + let _index = #index; + burn_cube::index::expand(context, _array, _index) + } + } +} + +/// Codegen for assignation +/// Supports: +/// - scalar +/// - indexed array +pub(crate) fn codegen_assign( + assign: &syn::ExprAssign, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match assign.left.as_ref() { + syn::Expr::Index(index) => { + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = codegen_expr(&index.index, loop_level, variable_analyses); + let value = codegen_expr(&assign.right, loop_level, variable_analyses); + + quote::quote! 
{ + { + let _array = #array; + let _index = #index; + let _value = #value; + burn_cube::index_assign::expand(context, _array, _index, _value) + } + } + } + syn::Expr::Path(_) => { + let lhs = codegen_expr(&assign.left, loop_level, variable_analyses); + let rhs = codegen_expr(&assign.right, loop_level, variable_analyses); + + quote::quote! { + { + let _assign_lhs = #lhs; + let _assign_rhs = #rhs; + burn_cube::assign::expand(context, _assign_rhs, _assign_lhs) + } + } + } + _ => todo!("Assign of expr {:?} unsupported", assign.left), + } +} + +/// Codegen for a variable used in rhs of a statement +/// This function adds cloning when necessary +pub(crate) fn codegen_path_rhs( + path: &syn::ExprPath, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let ident = path + .path + .get_ident() + .expect("Codegen: Only ident path are supported."); + + let will_be_used_again = variable_analyses.should_clone(ident, loop_level); + + if will_be_used_again { + quote::quote! { + #ident.clone() + } + } else { + quote::quote! 
{ + #ident + } + } +} From a2b0369303f5b616812f71ec2ef1d8050e8397a3 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 09:35:22 -0400 Subject: [PATCH 43/54] minor refactor --- crates/burn-cube/src/branch.rs | 2 +- crates/burn-cube/src/element/base.rs | 2 +- crates/burn-cube/src/element/runtime.rs | 2 +- crates/burn-cube/src/operation/binary.rs | 8 ++++---- crates/burn-jit/src/codegen/dialect/gpu/macros.rs | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/burn-cube/src/branch.rs b/crates/burn-cube/src/branch.rs index e4a5ea97d0..e2b735b0fa 100644 --- a/crates/burn-cube/src/branch.rs +++ b/crates/burn-cube/src/branch.rs @@ -116,7 +116,7 @@ where let mut inside_loop = context.child(); let cond: ExpandElement = cond_fn(&mut inside_loop); - if_expand(&mut inside_loop, cond, |context| break_expand(context)); + if_expand(&mut inside_loop, cond, break_expand); block(&mut inside_loop); context.register(Branch::Loop(gpu::Loop { diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs index 055d75909c..c62e1ea79b 100644 --- a/crates/burn-cube/src/element/base.rs +++ b/crates/burn-cube/src/element/base.rs @@ -41,6 +41,6 @@ impl core::ops::Deref for ExpandElement { impl From for Variable { fn from(value: ExpandElement) -> Self { - (*value.inner).clone() + *value.inner } } diff --git a/crates/burn-cube/src/element/runtime.rs b/crates/burn-cube/src/element/runtime.rs index 8d4bf7e22f..d19e9debe0 100644 --- a/crates/burn-cube/src/element/runtime.rs +++ b/crates/burn-cube/src/element/runtime.rs @@ -8,7 +8,7 @@ pub trait RuntimeType: CubeType { fn into_elem() -> Elem; fn from_expand(context: &mut CubeContext, val: ExpandElement) -> ExpandElement { let new_var = context.create_local(Item::Scalar(::into_elem())); - assign::expand(context, val.into(), new_var.clone()); + assign::expand(context, val, new_var.clone()); new_var } } diff --git a/crates/burn-cube/src/operation/binary.rs 
b/crates/burn-cube/src/operation/binary.rs index 08c085d8c7..2e8b4f15b0 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -114,7 +114,7 @@ pub mod mul { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - <$type>::new(self.val - rhs.val) + <$type>::new(self.val * rhs.val) } } }; @@ -193,11 +193,11 @@ pub mod rem { }; ($type:ty) => { - impl core::ops::Div for $type { + impl core::ops::Rem for $type { type Output = Self; - fn div(self, rhs: Self) -> Self::Output { - <$type>::new(self.val / rhs.val) + fn rem(self, rhs: Self) -> Self::Output { + <$type>::new(self.val % rhs.val) } } }; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index b2c1c1a457..bacd7510fe 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -422,7 +422,7 @@ impl From for Variable { impl From for Variable { fn from(value: f64) -> Self { - Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F64)) + Self::ConstantScalar(value, super::Elem::Float(super::FloatKind::F64)) } } From d342d7ecf2742a6e0d5bee2715aff0b65e2e3aa7 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 09:43:15 -0400 Subject: [PATCH 44/54] prevent clippy from breaking tests --- crates/burn-cube/tests/cast_elem.rs | 4 ++++ crates/burn-cube/tests/reuse.rs | 1 + 2 files changed, 5 insertions(+) diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index b71133075b..65c274d970 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -42,6 +42,7 @@ macro_rules! 
cast_test { // From float #[cube] +#[allow(clippy::useless_conversion)] pub fn float_to_float(x: F32) { let y = x + F32::new(2); let _ = F32::from(y) + F32::new(34); @@ -100,6 +101,7 @@ pub fn int_to_float(x: I32) { } #[cube] +#[allow(clippy::useless_conversion)] pub fn int_to_int(x: I32) { let y = x + I32::new(2); let _ = I32::from(y) + I32::new(34); @@ -158,6 +160,7 @@ pub fn uint_to_int(x: UInt) { } #[cube] +#[allow(clippy::useless_conversion)] pub fn uint_to_uint(x: UInt) { let y = x + UInt::new(2); let _ = UInt::from(y) + UInt::new(34); @@ -216,6 +219,7 @@ pub fn bool_to_uint(x: Bool) { } #[cube] +#[allow(clippy::useless_conversion)] pub fn bool_to_bool(x: Bool) { let y = x & Bool::new(false); let _ = Bool::from(y) | Bool::new(true); diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 7a37292297..caa40b97b8 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -11,6 +11,7 @@ use burn_jit::{ type ElemType = I32; #[cube] +#[allow(clippy::assign_op_pattern)] pub fn reuse(mut x: I) { while x < I::new(10) { x = x + I::new(1); From fa648e4829a9d3888f06a196206a79029a76d706 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 12:54:57 -0400 Subject: [PATCH 45/54] add doc --- crates/burn-cube/src/element/bool.rs | 9 ++- crates/burn-cube/src/element/conversion.rs | 2 +- crates/burn-cube/src/element/float.rs | 19 +++--- crates/burn-cube/src/element/int.rs | 11 +-- crates/burn-cube/src/element/mod.rs | 4 +- crates/burn-cube/src/element/numeric.rs | 14 +++- crates/burn-cube/src/element/primitive.rs | 76 +++++++++------------ crates/burn-cube/src/element/runtime.rs | 14 ---- crates/burn-cube/src/element/static_type.rs | 22 ++++++ crates/burn-cube/src/element/uint.rs | 16 +++-- crates/burn-cube/tests/cast_elem.rs | 2 +- crates/burn-cube/tests/cast_kind.rs | 2 +- crates/burn-cube/tests/for_loop.rs | 2 +- crates/burn-cube/tests/generic_kernel.rs | 2 +- crates/burn-cube/tests/if.rs | 2 +- 
crates/burn-cube/tests/if_else.rs | 2 +- crates/burn-cube/tests/literal.rs | 2 +- crates/burn-cube/tests/loop.rs | 2 +- crates/burn-cube/tests/reuse.rs | 2 +- 19 files changed, 107 insertions(+), 98 deletions(-) delete mode 100644 crates/burn-cube/src/element/runtime.rs create mode 100644 crates/burn-cube/src/element/static_type.rs diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs index 3674d8cdf1..1a77ced3b3 100644 --- a/crates/burn-cube/src/element/bool.rs +++ b/crates/burn-cube/src/element/bool.rs @@ -1,8 +1,9 @@ use burn_jit::gpu::Elem; -use crate::{CubeContext, CubeType, ExpandElement, RuntimeType}; +use crate::{CubeContext, CubeType, ExpandElement, PrimitiveVariable}; #[derive(Clone, Copy)] +/// Boolean type for kernels pub struct Bool { pub val: bool, pub vectorization: u8, @@ -13,19 +14,23 @@ impl CubeType for Bool { } impl Bool { + /// Create a Bool from primitive bool pub fn new(val: bool) -> Self { Self { val, vectorization: 1, } } + + /// Expand version of new pub fn new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { val.into() } } -impl RuntimeType for Bool { +impl PrimitiveVariable for Bool { type Primitive = bool; + fn val(&self) -> Self::Primitive { self.val } diff --git a/crates/burn-cube/src/element/conversion.rs b/crates/burn-cube/src/element/conversion.rs index 48403e8bfe..52ec88d6f7 100644 --- a/crates/burn-cube/src/element/conversion.rs +++ b/crates/burn-cube/src/element/conversion.rs @@ -1,4 +1,4 @@ -use crate::{Bool, Float, Int, RuntimeType, UInt, BF16, F16, F32, F64, I32, I64}; +use crate::{Bool, Float, Int, PrimitiveVariable, UInt, BF16, F16, F32, F64, I32, I64}; macro_rules! 
impl_to_float { ($to:ident, $from1:ident) => { diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index 3311e5c689..91917d16ec 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -1,19 +1,16 @@ -use crate::{CubeContext, CubeType, ExpandElement, Numeric, RuntimeType}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; use burn_jit::gpu::{Elem, FloatKind, Variable}; use std::rc::Rc; +/// Floating point numbers. Used as input in float kernels pub trait Float: Clone + Copy - + std::cmp::PartialOrd - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div - + std::ops::AddAssign + Numeric { + /// Create a Float from a float literal fn from_primitive(val: f64) -> Self; + /// Expand version of from_primitive fn from_primitive_expand(context: &mut CubeContext, val: f64) -> ExpandElement; } @@ -29,11 +26,17 @@ macro_rules! impl_float { type ExpandType = ExpandElement; } - impl RuntimeType for $type { + impl PrimitiveVariable for $type { + /// Note: all float types have f64 primitive on CPU to + /// ease casting. On GPU the type will be given by into_elem. type Primitive = f64; + + /// Return the value of the float (on CPU) fn val(&self) -> Self::Primitive { self.val } + + /// Return the element type to use on GPU fn into_elem() -> Elem { Elem::Float(FloatKind::$type) } diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 4f4bb87ec8..34f66999b9 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -1,16 +1,11 @@ -use crate::{CubeContext, CubeType, ExpandElement, Numeric, RuntimeType}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; use burn_jit::gpu::{Elem, IntKind, Variable}; use std::rc::Rc; +/// Signed integer. 
Used as input in int kernels pub trait Int: Clone + Copy - + std::cmp::PartialOrd - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div - + std::ops::AddAssign + std::ops::Rem + Numeric { @@ -30,7 +25,7 @@ macro_rules! impl_int { type ExpandType = ExpandElement; } - impl RuntimeType for $type { + impl PrimitiveVariable for $type { type Primitive = i64; fn val(&self) -> Self::Primitive { self.val diff --git a/crates/burn-cube/src/element/mod.rs b/crates/burn-cube/src/element/mod.rs index 2269ac8e25..84fa4f585c 100644 --- a/crates/burn-cube/src/element/mod.rs +++ b/crates/burn-cube/src/element/mod.rs @@ -5,8 +5,8 @@ mod conversion; mod float; mod int; mod numeric; +mod static_type; mod primitive; -mod runtime; mod uint; pub use array::*; @@ -15,5 +15,5 @@ pub use bool::*; pub use float::*; pub use int::*; pub use numeric::*; -pub use runtime::*; +pub use primitive::*; pub use uint::*; diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index cf825a0fff..30097cb43a 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -1,9 +1,11 @@ -use crate::{CubeContext, ExpandElement, RuntimeType}; +use crate::{CubeContext, ExpandElement, PrimitiveVariable}; +/// Type that encompasses both integers and floats +/// Used in kernels that should work for both. pub trait Numeric: Clone + Copy - + RuntimeType + + PrimitiveVariable + std::ops::Add + std::ops::AddAssign + std::ops::Sub @@ -11,7 +13,13 @@ pub trait Numeric: + std::ops::Div + std::cmp::PartialOrd { - // If we use numeric then constants are necessarily ints + /// Create a new constant numeric. + /// + /// Note: since this must work for both integer and float + /// only the less expressive of both can be created (int) + /// If a number with decimals is needed, use Float::from_primitive. 
fn new(val: i64) -> Self; + + /// Expand version of new fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement; } diff --git a/crates/burn-cube/src/element/primitive.rs b/crates/burn-cube/src/element/primitive.rs index 856858e68b..006bc25134 100644 --- a/crates/burn-cube/src/element/primitive.rs +++ b/crates/burn-cube/src/element/primitive.rs @@ -1,57 +1,43 @@ use std::rc::Rc; -use burn_jit::gpu::Variable; +use burn_jit::gpu::{Elem, Item, Variable}; -use crate::{CubeType, ExpandElement}; +use crate::{assign, CubeContext, CubeType, ExpandElement}; -impl CubeType for bool { - type ExpandType = bool; -} +/// Form of CubeType that encapsulates all primitive types: +/// Numeric, UInt, Bool +pub trait PrimitiveVariable: CubeType { + /// Type of the value kept CPU-side. + /// Does not necessarily match the GPU type. + type Primitive; -impl CubeType for u32 { - type ExpandType = u32; -} + /// Return the value of the float on CPU + fn val(&self) -> Self::Primitive; -impl CubeType for f32 { - type ExpandType = f32; -} - -impl CubeType for i32 { - type ExpandType = i32; -} + /// Return the element type to use on GPU + fn into_elem() -> Elem; -impl From for ExpandElement { - fn from(value: u32) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) + /// Expand version of from, of the trait From + fn from_expand(context: &mut CubeContext, val: ExpandElement) -> ExpandElement { + let new_var = context.create_local(Item::Scalar(::into_elem())); + assign::expand(context, val, new_var.clone()); + new_var } } -impl From for ExpandElement { - fn from(value: usize) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: bool) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } +macro_rules! 
impl_into_expand_element { + ($type:ty) => { + impl From<$type> for ExpandElement { + fn from(value: $type) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } + } + }; } -impl From for ExpandElement { - fn from(value: f32) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: i32) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} - -impl From for ExpandElement { - fn from(value: i64) -> Self { - ExpandElement::new(Rc::new(Variable::from(value))) - } -} +impl_into_expand_element!(u32); +impl_into_expand_element!(usize); +impl_into_expand_element!(bool); +impl_into_expand_element!(f32); +impl_into_expand_element!(i32); +impl_into_expand_element!(i64); diff --git a/crates/burn-cube/src/element/runtime.rs b/crates/burn-cube/src/element/runtime.rs deleted file mode 100644 index d19e9debe0..0000000000 --- a/crates/burn-cube/src/element/runtime.rs +++ /dev/null @@ -1,14 +0,0 @@ -use burn_jit::gpu::{Elem, Item}; - -use crate::{assign, CubeContext, CubeType, ExpandElement}; - -pub trait RuntimeType: CubeType { - type Primitive; - fn val(&self) -> Self::Primitive; - fn into_elem() -> Elem; - fn from_expand(context: &mut CubeContext, val: ExpandElement) -> ExpandElement { - let new_var = context.create_local(Item::Scalar(::into_elem())); - assign::expand(context, val, new_var.clone()); - new_var - } -} diff --git a/crates/burn-cube/src/element/static_type.rs b/crates/burn-cube/src/element/static_type.rs new file mode 100644 index 0000000000..ceb54cf953 --- /dev/null +++ b/crates/burn-cube/src/element/static_type.rs @@ -0,0 +1,22 @@ +use crate::CubeType; + +/// Types that exist within a cube function +/// but should not be turned into a JIT variable +/// For instance: a bool that determines at compile time +/// if we should unroll a for loop or not + +impl CubeType for bool { + type ExpandType = bool; +} + +impl CubeType for u32 { + type ExpandType = u32; +} + +impl CubeType 
for f32 { + type ExpandType = f32; +} + +impl CubeType for i32 { + type ExpandType = i32; +} diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs index 230f9f2003..0d3520133d 100644 --- a/crates/burn-cube/src/element/uint.rs +++ b/crates/burn-cube/src/element/uint.rs @@ -1,21 +1,25 @@ use burn_jit::gpu::Elem; -use crate::{CubeContext, CubeType, ExpandElement, RuntimeType}; +use crate::{CubeContext, CubeType, ExpandElement, PrimitiveVariable}; #[derive(Clone, Copy)] +/// An unsigned int. +/// Preferred for indexing operations pub struct UInt { - pub val: ::Primitive, + pub val: ::Primitive, pub vectorization: u8, } impl UInt { - // Use with integer literal + /// Create a UInt. Use with integer literal pub fn new(val: i64) -> Self { Self { val, vectorization: 1, } } + + /// Expand version of new pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { (val as u32).into() } @@ -25,7 +29,7 @@ impl CubeType for UInt { type ExpandType = ExpandElement; } -impl RuntimeType for UInt { +impl PrimitiveVariable for UInt { type Primitive = i64; fn val(&self) -> Self::Primitive { self.val @@ -37,12 +41,12 @@ impl RuntimeType for UInt { impl From for UInt { fn from(value: u32) -> Self { - UInt::new(value as ::Primitive) + UInt::new(value as ::Primitive) } } impl From for UInt { fn from(value: usize) -> Self { - UInt::new(value as ::Primitive) + UInt::new(value as ::Primitive) } } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 65c274d970..c19b5db763 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, Bool, CubeContext, Numeric, RuntimeType, UInt, F32, I32}; +use burn_cube::{cube, Bool, CubeContext, Numeric, PrimitiveVariable, UInt, F32, I32}; use burn_jit::{ gpu, gpu::{Elem, Item, Variable}, diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index ae525a973a..8932cc03d8 
100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Float, Int, Numeric, RuntimeType, F32, F64, I32, I64}; +use burn_cube::{cube, CubeContext, Float, Int, Numeric, PrimitiveVariable, F32, F64, I32, I64}; use burn_jit::{gpu, gpu::Item}; #[cube] diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs index 1a3fee31ac..9ee7b0abcd 100644 --- a/crates/burn-cube/tests/for_loop.rs +++ b/crates/burn-cube/tests/for_loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, Array, CubeContext, Float, RuntimeType, UInt, F32}; +use burn_cube::{branch::*, cube, Array, CubeContext, Float, PrimitiveVariable, UInt, F32}; use burn_jit::{ gpu, gpu::{Item, Variable}, diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs index 287b11e8f0..1887b397bd 100644 --- a/crates/burn-cube/tests/generic_kernel.rs +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Numeric, RuntimeType, F32, I32}; +use burn_cube::{cube, CubeContext, Numeric, PrimitiveVariable, F32, I32}; use burn_jit::{gpu, gpu::Item}; #[cube] diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index 43a6b1eb58..f7a6b2ba2f 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Numeric, RuntimeType, F32}; +use burn_cube::{branch::*, cube, CubeContext, Numeric, PrimitiveVariable, F32}; use burn_jit::{ gpu, gpu::{Elem, Item, Variable}, diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index 514f7555fa..dc9d643deb 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Float, RuntimeType, F32}; +use burn_cube::{branch::*, cube, CubeContext, Float, PrimitiveVariable, F32}; use burn_jit::{ gpu, 
gpu::{Elem, Item, Variable}, diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 9cb10bc622..1718a3e780 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, Float, RuntimeType, F32}; +use burn_cube::{cube, CubeContext, Float, PrimitiveVariable, F32}; use burn_jit::{gpu, gpu::Item}; type ElemType = F32; diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index 4799d76741..f4e78294bc 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Int, RuntimeType, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, PrimitiveVariable, I32}; use burn_jit::gpu; use burn_jit::gpu::Branch; use burn_jit::gpu::{Elem, Item, Variable}; diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index caa40b97b8..e06a19aa06 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -1,4 +1,4 @@ -use burn_cube::{branch::*, cube, CubeContext, Int, RuntimeType, I32}; +use burn_cube::{branch::*, cube, CubeContext, Int, PrimitiveVariable, I32}; use burn_jit::{ gpu, gpu::{Branch, Elem, Item, Variable}, From a2b63a39ba57458f878045d3beed77f7076de4c8 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 13:03:08 -0400 Subject: [PATCH 46/54] more clean --- crates/burn-cube-macros/src/codegen/base.rs | 2 +- crates/burn-cube/src/element/conversion.rs | 2 ++ crates/burn-cube/src/element/float.rs | 6 +----- crates/burn-cube/src/element/int.rs | 7 +------ crates/burn-cube/src/element/mod.rs | 2 +- crates/burn-cube/tests/reuse.rs | 7 +++---- 6 files changed, 9 insertions(+), 17 deletions(-) diff --git a/crates/burn-cube-macros/src/codegen/base.rs b/crates/burn-cube-macros/src/codegen/base.rs index 0cea957469..5fbf91d58b 100644 --- a/crates/burn-cube-macros/src/codegen/base.rs +++ 
b/crates/burn-cube-macros/src/codegen/base.rs @@ -10,7 +10,7 @@ use super::{ }; /// Codegen for a statement (generally one line) -/// Entry point of codegen +/// Entry point of code generation pub fn codegen_statement( statement: &syn::Stmt, loop_level: usize, diff --git a/crates/burn-cube/src/element/conversion.rs b/crates/burn-cube/src/element/conversion.rs index 52ec88d6f7..dfa94fcea6 100644 --- a/crates/burn-cube/src/element/conversion.rs +++ b/crates/burn-cube/src/element/conversion.rs @@ -1,5 +1,7 @@ use crate::{Bool, Float, Int, PrimitiveVariable, UInt, BF16, F16, F32, F64, I32, I64}; +// Enable elegant casting from any to any primitive variable + macro_rules! impl_to_float { ($to:ident, $from1:ident) => { impl From<$from1> for $to { diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index 91917d16ec..c54a0f0dda 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -3,11 +3,7 @@ use burn_jit::gpu::{Elem, FloatKind, Variable}; use std::rc::Rc; /// Floating point numbers. Used as input in float kernels -pub trait Float: - Clone - + Copy - + Numeric -{ +pub trait Float: Numeric { /// Create a Float from a float literal fn from_primitive(val: f64) -> Self; /// Expand version of from_primitive diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 34f66999b9..02f6f6afa2 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -3,12 +3,7 @@ use burn_jit::gpu::{Elem, IntKind, Variable}; use std::rc::Rc; /// Signed integer. 
Used as input in int kernels -pub trait Int: - Clone - + Copy - + std::ops::Rem - + Numeric -{ +pub trait Int: Numeric + std::ops::Rem { fn from_primitive(val: i64) -> Self; fn from_primitive_expand(context: &mut CubeContext, val: i64) -> ExpandElement; } diff --git a/crates/burn-cube/src/element/mod.rs b/crates/burn-cube/src/element/mod.rs index 84fa4f585c..df4f96b39b 100644 --- a/crates/burn-cube/src/element/mod.rs +++ b/crates/burn-cube/src/element/mod.rs @@ -5,8 +5,8 @@ mod conversion; mod float; mod int; mod numeric; -mod static_type; mod primitive; +mod static_type; mod uint; pub use array::*; diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index e06a19aa06..b86a7eea75 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -4,15 +4,14 @@ use burn_jit::{ gpu::{Branch, Elem, Item, Variable}, }; -// a += b is more efficient than a = a + b -// because the latter does not assume that a is the same in lhs and rhs -// It could be detected and optimized - type ElemType = I32; #[cube] #[allow(clippy::assign_op_pattern)] pub fn reuse(mut x: I) { + // a += b is more efficient than a = a + b + // Because the latter does not assume that a is the same in lhs and rhs + // Normally clippy should detect it while x < I::new(10) { x = x + I::new(1); } From f3341e96859458110b0be44464db0608f8c396fa Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 13:17:12 -0400 Subject: [PATCH 47/54] fix expand outputs --- crates/burn-cube/src/element/base.rs | 9 +-------- crates/burn-cube/src/element/bool.rs | 2 +- crates/burn-cube/src/element/float.rs | 11 ++++++++--- crates/burn-cube/src/element/int.rs | 10 +++++++--- crates/burn-cube/src/element/numeric.rs | 4 ++-- crates/burn-cube/src/element/primitive.rs | 5 ++++- crates/burn-cube/src/element/uint.rs | 2 +- 7 files changed, 24 insertions(+), 19 deletions(-) diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs index 
c62e1ea79b..2e7a925473 100644 --- a/crates/burn-cube/src/element/base.rs +++ b/crates/burn-cube/src/element/base.rs @@ -1,5 +1,5 @@ use alloc::rc::Rc; -use burn_jit::gpu::{Item, Variable}; +use burn_jit::gpu::Variable; /// Types used in a cube function must implement this trait /// @@ -24,13 +24,6 @@ pub struct ExpandElement { pub(crate) inner: Rc, } -impl ExpandElement { - /// Returns the Item of the variable - pub fn item(&self) -> Item { - self.inner.item() - } -} - impl core::ops::Deref for ExpandElement { type Target = Variable; diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs index 1a77ced3b3..f9ab362ee1 100644 --- a/crates/burn-cube/src/element/bool.rs +++ b/crates/burn-cube/src/element/bool.rs @@ -23,7 +23,7 @@ impl Bool { } /// Expand version of new - pub fn new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { + pub fn new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { val.into() } } diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index c54a0f0dda..d3a072b5b9 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -7,7 +7,8 @@ pub trait Float: Numeric { /// Create a Float from a float literal fn from_primitive(val: f64) -> Self; /// Expand version of from_primitive - fn from_primitive_expand(context: &mut CubeContext, val: f64) -> ExpandElement; + fn from_primitive_expand(context: &mut CubeContext, val: f64) + -> ::ExpandType; } macro_rules! impl_float { @@ -45,7 +46,11 @@ macro_rules! impl_float { vectorization: 1, } } - fn from_primitive_expand(_context: &mut CubeContext, val: f64) -> ExpandElement { + + fn from_primitive_expand( + _context: &mut CubeContext, + val: f64, + ) -> ::ExpandType { let new_var = Variable::ConstantScalar(val, Self::into_elem()); ExpandElement::new(Rc::new(new_var)) } @@ -61,7 +66,7 @@ macro_rules! 
impl_float { } } - fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { + fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { ::from_primitive_expand(context, val as f64) } } diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 02f6f6afa2..4c6c88302f 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -5,7 +5,8 @@ use std::rc::Rc; /// Signed integer. Used as input in int kernels pub trait Int: Numeric + std::ops::Rem { fn from_primitive(val: i64) -> Self; - fn from_primitive_expand(context: &mut CubeContext, val: i64) -> ExpandElement; + fn from_primitive_expand(context: &mut CubeContext, val: i64) + -> ::ExpandType; } macro_rules! impl_int { @@ -38,7 +39,10 @@ macro_rules! impl_int { } } - fn from_primitive_expand(_context: &mut CubeContext, val: i64) -> ExpandElement { + fn from_primitive_expand( + _context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); ExpandElement::new(Rc::new(new_var)) } @@ -52,7 +56,7 @@ macro_rules! impl_int { } } - fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement { + fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { ::from_primitive_expand(context, val) } } diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index 30097cb43a..9b98814ed4 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -1,4 +1,4 @@ -use crate::{CubeContext, ExpandElement, PrimitiveVariable}; +use crate::{CubeContext, CubeType, PrimitiveVariable}; /// Type that encompasses both integers and floats /// Used in kernels that should work for both. 
@@ -21,5 +21,5 @@ pub trait Numeric: fn new(val: i64) -> Self; /// Expand version of new - fn new_expand(context: &mut CubeContext, val: i64) -> ExpandElement; + fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; } diff --git a/crates/burn-cube/src/element/primitive.rs b/crates/burn-cube/src/element/primitive.rs index 006bc25134..0a42a2bf34 100644 --- a/crates/burn-cube/src/element/primitive.rs +++ b/crates/burn-cube/src/element/primitive.rs @@ -18,7 +18,10 @@ pub trait PrimitiveVariable: CubeType { fn into_elem() -> Elem; /// Expand version of from, of the trait From - fn from_expand(context: &mut CubeContext, val: ExpandElement) -> ExpandElement { + fn from_expand( + context: &mut CubeContext, + val: ExpandElement, + ) -> ::ExpandType { let new_var = context.create_local(Item::Scalar(::into_elem())); assign::expand(context, val, new_var.clone()); new_var diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs index 0d3520133d..9cd983191c 100644 --- a/crates/burn-cube/src/element/uint.rs +++ b/crates/burn-cube/src/element/uint.rs @@ -20,7 +20,7 @@ impl UInt { } /// Expand version of new - pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { + pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { (val as u32).into() } } From cd183e94e4dbac1b9dcf1931ad2d476b72b392e9 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 18:54:54 -0400 Subject: [PATCH 48/54] traits, modules and parenthesis --- crates/burn-cube-macros/src/analysis.rs | 2 + crates/burn-cube-macros/src/codegen/base.rs | 1 + crates/burn-cube-macros/src/lib.rs | 9 +- crates/burn-cube/src/element/float.rs | 7 +- crates/burn-cube/src/element/int.rs | 4 +- crates/burn-cube/src/element/numeric.rs | 4 +- crates/burn-cube/tests/cast_elem.rs | 32 +-- crates/burn-cube/tests/cast_kind.rs | 4 +- crates/burn-cube/tests/function_call.rs | 37 +++- crates/burn-cube/tests/generic_kernel.rs | 2 +- 
crates/burn-cube/tests/if.rs | 4 +- crates/burn-cube/tests/if_else.rs | 6 +- crates/burn-cube/tests/literal.rs | 2 +- crates/burn-cube/tests/loop.rs | 8 +- crates/burn-cube/tests/module_import.rs | 47 +++++ crates/burn-cube/tests/parenthesis.rs | 45 +++++ crates/burn-cube/tests/reuse.rs | 8 +- crates/burn-cube/tests/trait.rs | 203 ++++++++++++++++++++ 18 files changed, 381 insertions(+), 44 deletions(-) create mode 100644 crates/burn-cube/tests/module_import.rs create mode 100644 crates/burn-cube/tests/parenthesis.rs create mode 100644 crates/burn-cube/tests/trait.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 6c86e2e8d4..50f819a84a 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -6,6 +6,7 @@ use crate::VariableKey; #[derive(Debug)] /// Information about a single variable's use in Cube code +/// Information about a single variable's use in Cube code /// Useful to figure out when the generated variable will need cloning pub(crate) struct VariableAnalysis { num_used: usize, @@ -237,6 +238,7 @@ impl CodeAnalysisBuilder { } } syn::Expr::Break(_) => {} + syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth), _ => todo!("Analysis: unsupported expr {expr:?}"), } } diff --git a/crates/burn-cube-macros/src/codegen/base.rs b/crates/burn-cube-macros/src/codegen/base.rs index 5fbf91d58b..857d3bb518 100644 --- a/crates/burn-cube-macros/src/codegen/base.rs +++ b/crates/burn-cube-macros/src/codegen/base.rs @@ -83,6 +83,7 @@ pub(crate) fn codegen_expr( syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), syn::Expr::MethodCall(call) => codegen_expr_method_call(call), syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), + syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses), _ => panic!("Codegen: Unsupported {:?}", expr), } } diff --git 
a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs index 31128fc4d7..ffb7f6f4b4 100644 --- a/crates/burn-cube-macros/src/lib.rs +++ b/crates/burn-cube-macros/src/lib.rs @@ -12,7 +12,7 @@ enum CubeMode { Default, /// Panics and prints the generated code, useful when debugging /// Use by writing #[cube(panic)] - Panic, + Debug, } /// Derive macro for the module. @@ -20,13 +20,14 @@ enum CubeMode { pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { let args = parse_macro_input!(attr with Punctuated::::parse_terminated); let mode = parse_mode(args); + let func: syn::ItemFn = syn::parse(tokens).unwrap(); let mut variable_analyses = CodeAnalysis::create(&func); let code = codegen_cube(&func, &mut variable_analyses); match mode { CubeMode::Default => code, - CubeMode::Panic => panic!("{code}"), + CubeMode::Debug => panic!("{code}"), } } @@ -38,8 +39,8 @@ fn parse_mode(args: Punctuated) -> CubeMode { Meta::Path(path) => { if let Some(ident) = path.get_ident().map(|id| id.to_string()) { match ident.as_str() { - "panic" => { - mode = CubeMode::Panic; + "debug" => { + mode = CubeMode::Debug; } _ => panic!("Attribute {ident} is not supported"), } diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index d3a072b5b9..5fb67f6dc5 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -59,14 +59,17 @@ macro_rules! 
impl_float { impl Numeric for $type { // Method new takes an i64, because it is used when treating the float as numeric, // which must be an int in the cube kernel because new numerics need to be supported by Int as well - fn new(val: i64) -> Self { + fn constant(val: i64) -> Self { Self { val: val as f64, vectorization: 1, } } - fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + fn constant_expand( + context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { ::from_primitive_expand(context, val as f64) } } diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index 4c6c88302f..f6ae942d41 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -49,14 +49,14 @@ macro_rules! impl_int { } impl Numeric for $type { - fn new(val: i64) -> Self { + fn constant(val: i64) -> Self { Self { val, vectorization: 1, } } - fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + fn constant_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { ::from_primitive_expand(context, val) } } diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index 9b98814ed4..210fc602f4 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -18,8 +18,8 @@ pub trait Numeric: /// Note: since this must work for both integer and float /// only the less expressive of both can be created (int) /// If a number with decimals is needed, use Float::from_primitive. 
- fn new(val: i64) -> Self; + fn constant(val: i64) -> Self; /// Expand version of new - fn new_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; + fn constant_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index c19b5db763..3b0e70b049 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -44,25 +44,25 @@ macro_rules! cast_test { #[cube] #[allow(clippy::useless_conversion)] pub fn float_to_float(x: F32) { - let y = x + F32::new(2); - let _ = F32::from(y) + F32::new(34); + let y = x + F32::constant(2); + let _ = F32::from(y) + F32::constant(34); } #[cube] pub fn float_to_int(x: F32) { - let y = x + F32::new(2); - let _ = I32::from(y) + I32::new(34); + let y = x + F32::constant(2); + let _ = I32::from(y) + I32::constant(34); } #[cube] pub fn float_to_uint(x: F32) { - let y = x + F32::new(2); + let y = x + F32::constant(2); let _ = UInt::from(y) + UInt::new(34); } #[cube] pub fn float_to_bool(x: F32) { - let y = x + F32::new(2); + let y = x + F32::constant(2); let _ = Bool::from(y) | Bool::new(true); } @@ -96,26 +96,26 @@ cast_test!( // // From int #[cube] pub fn int_to_float(x: I32) { - let y = x + I32::new(2); - let _ = F32::from(y) + F32::new(34); + let y = x + I32::constant(2); + let _ = F32::from(y) + F32::constant(34); } #[cube] #[allow(clippy::useless_conversion)] pub fn int_to_int(x: I32) { - let y = x + I32::new(2); - let _ = I32::from(y) + I32::new(34); + let y = x + I32::constant(2); + let _ = I32::from(y) + I32::constant(34); } #[cube] pub fn int_to_uint(x: I32) { - let y = x + I32::new(2); + let y = x + I32::constant(2); let _ = UInt::from(y) + UInt::new(34); } #[cube] pub fn int_to_bool(x: I32) { - let y = x + I32::new(2); + let y = x + I32::constant(2); let _ = Bool::from(y) | Bool::new(true); } @@ -150,13 +150,13 @@ cast_test!( #[cube] pub fn uint_to_float(x: UInt) { let y = x + UInt::new(2); - let _ 
= F32::from(y) + F32::new(34); + let _ = F32::from(y) + F32::constant(34); } #[cube] pub fn uint_to_int(x: UInt) { let y = x + UInt::new(2); - let _ = I32::from(y) + I32::new(34); + let _ = I32::from(y) + I32::constant(34); } #[cube] @@ -203,13 +203,13 @@ cast_test!( #[cube] pub fn bool_to_float(x: Bool) { let y = x & Bool::new(false); - let _ = F32::from(y) + F32::new(34); + let _ = F32::from(y) + F32::constant(34); } #[cube] pub fn bool_to_int(x: Bool) { let y = x & Bool::new(false); - let _ = I32::from(y) + I32::new(34); + let _ = I32::from(y) + I32::constant(34); } #[cube] diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index 8932cc03d8..d1cf9786bb 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -17,9 +17,9 @@ pub fn cast_int_kind>(input: I1) { #[cube] pub fn cast_numeric_to_kind>(input: T) { - let x = input + T::new(5); + let x = input + T::constant(5); let y = I2::from(x); - let _ = y + I2::new(2); + let _ = y + I2::constant(2); } #[test] diff --git a/crates/burn-cube/tests/function_call.rs b/crates/burn-cube/tests/function_call.rs index 3d3578f911..2426e40d26 100644 --- a/crates/burn-cube/tests/function_call.rs +++ b/crates/burn-cube/tests/function_call.rs @@ -1,4 +1,4 @@ -use burn_cube::{cube, CubeContext, UInt}; +use burn_cube::{cube, CubeContext, Numeric, PrimitiveVariable, UInt, I64}; use burn_jit::gpu::{Elem, Item}; #[cube] @@ -31,6 +31,21 @@ pub fn no_call_with_arg(x: UInt) { let _ = x + x * UInt::new(8); } +#[cube] +pub fn caller_with_generics(x: T) { + let _ = x + callee_with_generics::(x); +} + +#[cube] +pub fn callee_with_generics(x: T) -> T { + x * T::constant(8) +} + +#[cube] +pub fn no_call_with_generics(x: T) { + let _ = x + x * T::constant(8); +} + #[test] fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); @@ -52,6 +67,7 @@ fn cube_call_equivalent_to_no_call_no_arg_test() { #[test] fn 
cube_call_equivalent_to_no_call_with_arg_test() { let mut caller_context = CubeContext::root(); + let x = caller_context.create_local(Item::Scalar(Elem::UInt)); caller_with_arg_expand(&mut caller_context, x); let caller_scope = caller_context.into_scope(); @@ -66,3 +82,22 @@ fn cube_call_equivalent_to_no_call_with_arg_test() { format!("{:?}", no_call_scope.operations) ); } + +#[test] +fn cube_call_equivalent_to_no_call_with_generics_test() { + let mut caller_context = CubeContext::root(); + type ElemType = I64; + let x = caller_context.create_local(Item::Scalar(ElemType::into_elem())); + caller_with_generics_expand::(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem())); + no_call_with_generics_expand::(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs index 1887b397bd..f037f3088f 100644 --- a/crates/burn-cube/tests/generic_kernel.rs +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -3,7 +3,7 @@ use burn_jit::{gpu, gpu::Item}; #[cube] pub fn generic_kernel(lhs: T) { - let _ = lhs + T::new(5); + let _ = lhs + T::constant(5); } #[test] diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index f7a6b2ba2f..d631926ce7 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -8,8 +8,8 @@ type ElemType = F32; #[cube] pub fn if_greater(lhs: T) { - if lhs > T::new(0) { - let _ = lhs + T::new(4); + if lhs > T::constant(0) { + let _ = lhs + T::constant(4); } } diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index dc9d643deb..7a462c5e13 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs 
@@ -8,10 +8,10 @@ type ElemType = F32; #[cube] pub fn if_then_else(lhs: F) { - if lhs < F::new(0) { - let _ = lhs + F::new(4); + if lhs < F::constant(0) { + let _ = lhs + F::constant(4); } else { - let _ = lhs - F::new(5); + let _ = lhs - F::constant(5); } } diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index 1718a3e780..e7acfb73fe 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -5,7 +5,7 @@ type ElemType = F32; #[cube] pub fn literal(lhs: F) { - let _ = lhs + F::new(5); + let _ = lhs + F::constant(5); } #[cube] diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index f4e78294bc..c28fd6ae18 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -7,18 +7,18 @@ type ElemType = I32; #[cube] pub fn while_not(lhs: I) { - while lhs != I::new(0) { - let _ = lhs % I::new(1); + while lhs != I::constant(0) { + let _ = lhs % I::constant(1); } } #[cube] pub fn manual_loop_break(lhs: I) { loop { - if lhs != I::new(0) { + if lhs != I::constant(0) { break; } - let _ = lhs % I::new(1); + let _ = lhs % I::constant(1); } } diff --git a/crates/burn-cube/tests/module_import.rs b/crates/burn-cube/tests/module_import.rs new file mode 100644 index 0000000000..5eabe8b3ae --- /dev/null +++ b/crates/burn-cube/tests/module_import.rs @@ -0,0 +1,47 @@ +use burn_cube::{CubeContext, PrimitiveVariable, F32}; +use burn_jit::gpu::Item; + +type ElemType = F32; + +mod elsewhere { + use burn_cube::{cube, Float}; + + #[cube] + pub fn my_func(x: F) -> F { + x * F::constant(2) + } +} + +mod here { + use burn_cube::{cube, Float}; + + use crate::elsewhere; + + #[cube] + pub fn caller(x: F) { + let _ = x + elsewhere::my_func::(x); + } + + #[cube] + pub fn no_call_ref(x: F) { + let _ = x + x * F::constant(2); + } +} + +#[test] +fn cube_call_equivalent_to_no_call_no_arg_test() { + let mut caller_context = CubeContext::root(); + let x = 
caller_context.create_local(Item::Scalar(ElemType::into_elem())); + here::caller_expand::(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem())); + here::no_call_ref_expand::(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} diff --git a/crates/burn-cube/tests/parenthesis.rs b/crates/burn-cube/tests/parenthesis.rs new file mode 100644 index 0000000000..7a6b2049db --- /dev/null +++ b/crates/burn-cube/tests/parenthesis.rs @@ -0,0 +1,45 @@ +use burn_cube::{cube, CubeContext, Numeric, PrimitiveVariable, F32}; +use burn_jit::{ + gpu, + gpu::{Item, Variable}, +}; + +type ElemType = F32; + +#[cube] +pub fn parenthesis(x: T, y: T, z: T) -> T { + x * (y + z) +} + +#[test] +fn cube_parenthesis_priority_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + let z = context.create_local(Item::Scalar(ElemType::into_elem())); + + parenthesis_expand::(&mut context, x, y, z); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + let y = context.create_local(item); + let z = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y: Variable = y.into(); + let z: Variable = z.into(); + let tmp = scope.create_local(item); + + gpu!(scope, tmp = y + z); + gpu!(scope, y = x * tmp); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs 
index b86a7eea75..936e5d7ba4 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -12,15 +12,15 @@ pub fn reuse(mut x: I) { // a += b is more efficient than a = a + b // Because the latter does not assume that a is the same in lhs and rhs // Normally clippy should detect it - while x < I::new(10) { - x = x + I::new(1); + while x < I::constant(10) { + x = x + I::constant(1); } } #[cube] pub fn reuse_incr(mut x: I) { - while x < I::new(10) { - x += I::new(1); + while x < I::constant(10) { + x += I::constant(1); } } diff --git a/crates/burn-cube/tests/trait.rs b/crates/burn-cube/tests/trait.rs new file mode 100644 index 0000000000..911cf181db --- /dev/null +++ b/crates/burn-cube/tests/trait.rs @@ -0,0 +1,203 @@ +use burn_cube::{cube, CubeContext, CubeType, Float, Numeric, PrimitiveVariable, F32}; +use burn_jit::{ + gpu, + gpu::{Item, Variable}, +}; + +type ElemType = F32; + +/// Traits used in Cube kernels must expose an _expand variant +/// for all their methods. However, one does not need to provide its +/// implementation, see examples below. 
+trait Strategy { + fn operation(input_1: T, input_2: T) -> T; + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType; +} + +struct AddStrategy; + +#[cube] +/// The actual implementation of AddStrategy's operation +/// Automatically generated an _expand variant +fn add_strategy_operation(input_1: T, input_2: T) -> T { + input_1 + input_2 +} + +impl Strategy for AddStrategy { + /// Here we link the trait's method to the cube function + fn operation(input_1: T, input_2: T) -> T { + add_strategy_operation(input_1, input_2) + } + + /// Here we link the trait's expanded method to the cube expanded function + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType { + add_strategy_operation_expand::(context, input_1, input_2) + } +} + +struct SubStrategy; + +#[cube] +fn sub_strategy_operation(input_1: T, input_2: T) -> T { + input_1 - input_2 +} + +impl Strategy for SubStrategy { + fn operation(input_1: T, input_2: T) -> T { + sub_strategy_operation(input_1, input_2) + } + + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType { + sub_strategy_operation_expand::(context, input_1, input_2) + } +} + +#[cube] +fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { + let z = S::operation(x, y); + z +} + +#[cube] +fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { + let z = S1::operation(x, y); + S2::operation(z, y) +} + +trait MethodTypedStrategy { + fn operation(input_1: T, input_2: T) -> T; + fn operation_expand( + _context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType; +} + +impl MethodTypedStrategy for AddStrategy { + fn operation(input_1: T, input_2: T) -> T { + add_strategy_operation(input_1, input_2) + } + + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> 
::ExpandType { + add_strategy_operation_expand::(context, input_1, input_2) + } +} + +#[cube] +fn with_trait_generic_method(x: T, y: T) -> T { + let z = S::operation::(x, y); + z +} + +#[test] +fn cube_strategy_trait_add_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + with_strategy_trait_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_one(true) + ); +} + +#[test] +fn cube_strategy_trait_sub_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + with_strategy_trait_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_one(false) + ); +} + +#[test] +fn cube_two_strategy_traits_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + two_strategy_traits_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two()); +} + +#[test] +fn cube_trait_generic_method_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + with_trait_generic_method_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_one(true) + ); +} + +fn inline_macro_ref_one(is_add_strategy: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = 
context.create_local(item); + let y = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y: Variable = y.into(); + let tmp = scope.create_local(item); + + match is_add_strategy { + true => gpu!(scope, tmp = x + y), + false => gpu!(scope, tmp = x - y), + } + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_two() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + let y = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y: Variable = y.into(); + let tmp = scope.create_local(item); + + gpu!(scope, tmp = x - y); + gpu!(scope, x = tmp + y); + + format!("{:?}", scope.operations) +} From 652c68a21904b51f7288084306330c0793628ef3 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 18:55:25 -0400 Subject: [PATCH 49/54] fmt --- crates/burn-cube/src/element/int.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index f6ae942d41..fde86edb48 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -56,7 +56,10 @@ macro_rules! 
impl_int { } } - fn constant_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + fn constant_expand( + context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { ::from_primitive_expand(context, val) } } From 040bbe02f9281f339530f499933fcedec9675e58 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 18:55:55 -0400 Subject: [PATCH 50/54] clippy --- crates/burn-cube/tests/trait.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/burn-cube/tests/trait.rs b/crates/burn-cube/tests/trait.rs index 911cf181db..125de8e31e 100644 --- a/crates/burn-cube/tests/trait.rs +++ b/crates/burn-cube/tests/trait.rs @@ -66,8 +66,8 @@ impl Strategy for SubStrategy { #[cube] fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { - let z = S::operation(x, y); - z + + S::operation(x, y) } #[cube] @@ -101,8 +101,8 @@ impl MethodTypedStrategy for AddStrategy { #[cube] fn with_trait_generic_method(x: T, y: T) -> T { - let z = S::operation::(x, y); - z + + S::operation::(x, y) } #[test] From 73ecc1941498703fccdbcedd21de0e3e3c0e1c30 Mon Sep 17 00:00:00 2001 From: louisfd Date: Mon, 13 May 2024 19:00:02 -0400 Subject: [PATCH 51/54] fmt again? 
--- crates/burn-cube/tests/trait.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/burn-cube/tests/trait.rs b/crates/burn-cube/tests/trait.rs index 125de8e31e..bd8ad9dfc4 100644 --- a/crates/burn-cube/tests/trait.rs +++ b/crates/burn-cube/tests/trait.rs @@ -66,7 +66,6 @@ impl Strategy for SubStrategy { #[cube] fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { - S::operation(x, y) } @@ -101,7 +100,6 @@ impl MethodTypedStrategy for AddStrategy { #[cube] fn with_trait_generic_method(x: T, y: T) -> T { - S::operation::(x, y) } From 2e348f1e4e4ed9bd6f6cf785f25e2eac8cceba5f Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 15 May 2024 09:26:36 -0400 Subject: [PATCH 52/54] ops not compiling --- crates/burn-cube/tests/operation.rs | 180 ++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 crates/burn-cube/tests/operation.rs diff --git a/crates/burn-cube/tests/operation.rs b/crates/burn-cube/tests/operation.rs new file mode 100644 index 0000000000..1c18c74281 --- /dev/null +++ b/crates/burn-cube/tests/operation.rs @@ -0,0 +1,180 @@ +use burn_cube::{branch::*, cube, Bool, CubeContext, Float, Numeric, PrimitiveVariable, UInt, F32}; +use burn_jit::{ + gpu, + gpu::{Elem, Item, Variable}, +}; + +#[cube] +fn add_op(a: T, b: T) -> T { + a + b +} + +#[cube] +fn sub_op(a: T, b: T) -> T { + a - b +} + +#[cube] +fn mul_op(a: T, b: T) -> T { + a * b +} + +#[cube] +fn div_op(a: T, b: T) -> T { + a / b +} + +#[cube] +fn abs_op(a: T) -> T { + abs(a) +} + +#[cube] +fn exp_op(a: F) -> F { + exp(a) +} + +#[cube] +fn log_op(a: F) -> F { + log(a) +} + +#[cube] +fn log1p_op(a: F) -> F { + log1p(a) +} + +#[cube] +fn cos_op(a: F) -> F { + cos(a) +} + +#[cube] +fn sin_op(a: F) -> F { + sin(a) +} + +#[cube] +fn tanh_op(a: F) -> F { + tanh(a) +} + +#[cube] +fn powf_op(a: F, b: F) -> F { + powf(a, b) +} + +#[cube] +fn sqrt_op(a: F) -> F { + sqrt(a) +} + +#[cube] +fn floor_op(a: F) -> F { + floor(a) +} + +#[cube] +fn ceil_op(a: F) -> F { + ceil(a) +} + 
+#[cube] +fn erf_op(a: F) -> F { + erf(a) +} + +#[cube] +fn recip_op(a: F) -> F { + recip(a) +} + +#[cube] +fn equal_op(a: T, b: T) -> Bool { + a == b +} + +#[cube] +fn not_equal_op(a: T, b: T) -> Bool { + a != b +} + +#[cube] +fn lower_op(a: T, b: T) -> Bool { + a < b +} + +#[cube] +fn greater_op(a: T, b: T) -> Bool { + a > b +} + +#[cube] +fn lower_equal_op(a: T, b: T) -> Bool { + a <= b +} + +#[cube] +fn greater_equal_op(a: T, b: T) -> Bool { + a >= b +} + +#[cube] +fn clamp_op(a: T, l: T, u: T) -> T { + clamp(a, l, u) +} + +#[cube] +fn modulo_op(a: UInt, b: UInt) -> UInt { + a % b +} + +#[cube] +fn remainder_op(a: T, b: T) -> T { + rem(a, b) +} + +#[cube] +fn max_op(a: T, b: T) -> T { + max(a, b) +} + +#[cube] +fn min_op(a: T, b: T) -> T { + min(a, b) +} + +#[cube] +fn and_op(a: Bool, b: Bool) -> Bool { + a & b +} + +#[cube] +fn or_op(a: Bool, b: Bool) -> Bool { + a | b +} + +#[cube] +fn not_op(a: Bool) -> Bool { + !a +} + +#[cube] +fn bit_and_op(a: UInt, b: UInt) -> UInt { + a & b +} + +#[cube] +fn bit_or_op(a: UInt, b: UInt) -> UInt { + a | b +} + +#[cube] +fn shift_left_op(a: UInt, b: UInt) -> UInt { + a << b +} + +#[cube] +fn shift_left_op(a: UInt, b: UInt) -> UInt { + a >> b +} From c7a2454c1e70070d463a905a809040fc4fc03e65 Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 15 May 2024 09:39:06 -0400 Subject: [PATCH 53/54] make uint numeric --- crates/burn-cube/src/element/bool.rs | 4 +- crates/burn-cube/src/element/conversion.rs | 41 ++--- crates/burn-cube/src/element/numeric.rs | 2 +- crates/burn-cube/src/element/uint.rs | 40 ++++- crates/burn-cube/src/operation/binary.rs | 15 +- crates/burn-cube/tests/cast_elem.rs | 32 ++-- crates/burn-cube/tests/function_call.rs | 8 +- crates/burn-cube/tests/operation.rs | 180 --------------------- 8 files changed, 73 insertions(+), 249 deletions(-) delete mode 100644 crates/burn-cube/tests/operation.rs diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs index f9ab362ee1..8c1d826ad5 
100644 --- a/crates/burn-cube/src/element/bool.rs +++ b/crates/burn-cube/src/element/bool.rs @@ -15,7 +15,7 @@ impl CubeType for Bool { impl Bool { /// Create a Bool from primitive bool - pub fn new(val: bool) -> Self { + pub fn constant(val: bool) -> Self { Self { val, vectorization: 1, @@ -23,7 +23,7 @@ impl Bool { } /// Expand version of new - pub fn new_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { + pub fn constant_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { val.into() } } diff --git a/crates/burn-cube/src/element/conversion.rs b/crates/burn-cube/src/element/conversion.rs index dfa94fcea6..778427d5f7 100644 --- a/crates/burn-cube/src/element/conversion.rs +++ b/crates/burn-cube/src/element/conversion.rs @@ -96,42 +96,19 @@ impl_to_int!(I64, I32); impl_to_int!(I64, UInt); impl_to_int_from_bool!(I64, Bool); -macro_rules! impl_to_uint { - ($to:ident, $from1:ident) => { - impl From<$from1> for $to { - fn from(value: $from1) -> Self { - Self::new(value.val() as i64) - } - } - }; -} - -macro_rules! impl_to_uint_from_bool { - ($to:ident, $from1:ident) => { - impl From<$from1> for $to { - fn from(value: $from1) -> Self { - Self::new(match value.val() { - true => 1, - false => 0, - }) - } - } - }; -} - -impl_to_uint!(UInt, F16); -impl_to_uint!(UInt, BF16); -impl_to_uint!(UInt, F32); -impl_to_uint!(UInt, F64); -impl_to_uint!(UInt, I32); -impl_to_uint!(UInt, I64); -impl_to_uint_from_bool!(UInt, Bool); +impl_to_int!(UInt, F16); +impl_to_int!(UInt, BF16); +impl_to_int!(UInt, F32); +impl_to_int!(UInt, F64); +impl_to_int!(UInt, I32); +impl_to_int!(UInt, I64); +impl_to_int_from_bool!(UInt, Bool); macro_rules! impl_to_bool_from_float { ($to:ident, $from1:ident) => { impl From<$from1> for $to { fn from(value: $from1) -> Self { - Self::new(value.val() > 0.) + Self::constant(value.val() > 0.) } } }; @@ -141,7 +118,7 @@ macro_rules! 
impl_to_bool_from_int { ($to:ident, $from1:ident) => { impl From<$from1> for $to { fn from(value: $from1) -> Self { - Self::new(value.val() > 0) + Self::constant(value.val() > 0) } } }; diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index 210fc602f4..cbf68187a0 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -1,6 +1,6 @@ use crate::{CubeContext, CubeType, PrimitiveVariable}; -/// Type that encompasses both integers and floats +/// Type that encompasses both (unsigned or signed) integers and floats /// Used in kernels that should work for both. pub trait Numeric: Clone diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs index 9cd983191c..e9cb22c514 100644 --- a/crates/burn-cube/src/element/uint.rs +++ b/crates/burn-cube/src/element/uint.rs @@ -1,6 +1,6 @@ use burn_jit::gpu::Elem; -use crate::{CubeContext, CubeType, ExpandElement, PrimitiveVariable}; +use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; #[derive(Clone, Copy)] /// An unsigned int. @@ -11,18 +11,31 @@ pub struct UInt { } impl UInt { - /// Create a UInt. Use with integer literal - pub fn new(val: i64) -> Self { + pub fn from_primitive(val: i64) -> Self { Self { val, vectorization: 1, } } - /// Expand version of new - pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { + pub fn from_primitive_expand( + _context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { (val as u32).into() } + // /// Create a UInt. 
Use with integer literal + // pub fn new(val: i64) -> Self { + // Self { + // val, + // vectorization: 1, + // } + // } + + // /// Expand version of new + // pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { + // (val as u32).into() + // } } impl CubeType for UInt { @@ -39,14 +52,27 @@ impl PrimitiveVariable for UInt { } } +impl Numeric for UInt { + fn constant(val: i64) -> Self { + Self { + val, + vectorization: 1, + } + } + + fn constant_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + Self::from_primitive_expand(context, val) + } +} + impl From for UInt { fn from(value: u32) -> Self { - UInt::new(value as ::Primitive) + UInt::from_primitive(value as ::Primitive) } } impl From for UInt { fn from(value: usize) -> Self { - UInt::new(value as ::Primitive) + UInt::from_primitive(value as ::Primitive) } } diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 2e8b4f15b0..5647e41d57 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -30,7 +30,7 @@ pub mod add { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - <$type>::new(self.val + rhs.val) + <$type>::from_primitive(self.val + rhs.val) } } }; @@ -72,7 +72,7 @@ pub mod sub { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - <$type>::new(self.val - rhs.val) + <$type>::from_primitive(self.val - rhs.val) } } }; @@ -114,7 +114,7 @@ pub mod mul { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - <$type>::new(self.val * rhs.val) + <$type>::from_primitive(self.val * rhs.val) } } }; @@ -156,7 +156,7 @@ pub mod div { type Output = Self; fn div(self, rhs: Self) -> Self::Output { - <$type>::new(self.val / rhs.val) + <$type>::from_primitive(self.val / rhs.val) } } }; @@ -168,6 +168,7 @@ pub mod div { impl_div!(F64, Float); impl_div!(I32, Int); impl_div!(I64, Int); + impl_div!(UInt); } pub mod rem { @@ -197,7 +198,7 @@ pub mod rem { type Output = Self; 
fn rem(self, rhs: Self) -> Self::Output { - <$type>::new(self.val % rhs.val) + <$type>::from_primitive(self.val % rhs.val) } } }; @@ -225,7 +226,7 @@ pub mod and { type Output = Bool; fn bitand(self, rhs: Self) -> Self::Output { - Bool::new(self.val && rhs.val) + Bool::constant(self.val && rhs.val) } } } @@ -247,7 +248,7 @@ pub mod or { type Output = Bool; fn bitor(self, rhs: Self) -> Self::Output { - Bool::new(self.val || rhs.val) + Bool::constant(self.val || rhs.val) } } } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 3b0e70b049..25ed945bce 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -57,13 +57,13 @@ pub fn float_to_int(x: F32) { #[cube] pub fn float_to_uint(x: F32) { let y = x + F32::constant(2); - let _ = UInt::from(y) + UInt::new(34); + let _ = UInt::from(y) + UInt::constant(34); } #[cube] pub fn float_to_bool(x: F32) { let y = x + F32::constant(2); - let _ = Bool::from(y) | Bool::new(true); + let _ = Bool::from(y) | Bool::constant(true); } cast_test!( @@ -110,13 +110,13 @@ pub fn int_to_int(x: I32) { #[cube] pub fn int_to_uint(x: I32) { let y = x + I32::constant(2); - let _ = UInt::from(y) + UInt::new(34); + let _ = UInt::from(y) + UInt::constant(34); } #[cube] pub fn int_to_bool(x: I32) { let y = x + I32::constant(2); - let _ = Bool::from(y) | Bool::new(true); + let _ = Bool::from(y) | Bool::constant(true); } cast_test!( @@ -149,27 +149,27 @@ cast_test!( // // From uint #[cube] pub fn uint_to_float(x: UInt) { - let y = x + UInt::new(2); + let y = x + UInt::constant(2); let _ = F32::from(y) + F32::constant(34); } #[cube] pub fn uint_to_int(x: UInt) { - let y = x + UInt::new(2); + let y = x + UInt::constant(2); let _ = I32::from(y) + I32::constant(34); } #[cube] #[allow(clippy::useless_conversion)] pub fn uint_to_uint(x: UInt) { - let y = x + UInt::new(2); - let _ = UInt::from(y) + UInt::new(34); + let y = x + UInt::constant(2); + let _ = UInt::from(y) + 
UInt::constant(34); } #[cube] pub fn uint_to_bool(x: UInt) { - let y = x + UInt::new(2); - let _ = Bool::from(y) | Bool::new(true); + let y = x + UInt::constant(2); + let _ = Bool::from(y) | Bool::constant(true); } cast_test!( @@ -202,27 +202,27 @@ cast_test!( // From bool #[cube] pub fn bool_to_float(x: Bool) { - let y = x & Bool::new(false); + let y = x & Bool::constant(false); let _ = F32::from(y) + F32::constant(34); } #[cube] pub fn bool_to_int(x: Bool) { - let y = x & Bool::new(false); + let y = x & Bool::constant(false); let _ = I32::from(y) + I32::constant(34); } #[cube] pub fn bool_to_uint(x: Bool) { - let y = x & Bool::new(false); - let _ = UInt::from(y) + UInt::new(34); + let y = x & Bool::constant(false); + let _ = UInt::from(y) + UInt::constant(34); } #[cube] #[allow(clippy::useless_conversion)] pub fn bool_to_bool(x: Bool) { - let y = x & Bool::new(false); - let _ = Bool::from(y) | Bool::new(true); + let y = x & Bool::constant(false); + let _ = Bool::from(y) | Bool::constant(true); } cast_test!( diff --git a/crates/burn-cube/tests/function_call.rs b/crates/burn-cube/tests/function_call.rs index 2426e40d26..e1add5fa89 100644 --- a/crates/burn-cube/tests/function_call.rs +++ b/crates/burn-cube/tests/function_call.rs @@ -8,12 +8,12 @@ pub fn caller_no_arg(x: UInt) { #[cube] pub fn callee_no_arg() -> UInt { - UInt::new(8) + UInt::constant(8) } #[cube] pub fn no_call_no_arg(x: UInt) { - let _ = x + UInt::new(8); + let _ = x + UInt::constant(8); } #[cube] @@ -23,12 +23,12 @@ pub fn caller_with_arg(x: UInt) { #[cube] pub fn callee_with_arg(x: UInt) -> UInt { - x * UInt::new(8) + x * UInt::constant(8) } #[cube] pub fn no_call_with_arg(x: UInt) { - let _ = x + x * UInt::new(8); + let _ = x + x * UInt::constant(8); } #[cube] diff --git a/crates/burn-cube/tests/operation.rs b/crates/burn-cube/tests/operation.rs deleted file mode 100644 index 1c18c74281..0000000000 --- a/crates/burn-cube/tests/operation.rs +++ /dev/null @@ -1,180 +0,0 @@ -use 
burn_cube::{branch::*, cube, Bool, CubeContext, Float, Numeric, PrimitiveVariable, UInt, F32}; -use burn_jit::{ - gpu, - gpu::{Elem, Item, Variable}, -}; - -#[cube] -fn add_op(a: T, b: T) -> T { - a + b -} - -#[cube] -fn sub_op(a: T, b: T) -> T { - a - b -} - -#[cube] -fn mul_op(a: T, b: T) -> T { - a * b -} - -#[cube] -fn div_op(a: T, b: T) -> T { - a / b -} - -#[cube] -fn abs_op(a: T) -> T { - abs(a) -} - -#[cube] -fn exp_op(a: F) -> F { - exp(a) -} - -#[cube] -fn log_op(a: F) -> F { - log(a) -} - -#[cube] -fn log1p_op(a: F) -> F { - log1p(a) -} - -#[cube] -fn cos_op(a: F) -> F { - cos(a) -} - -#[cube] -fn sin_op(a: F) -> F { - sin(a) -} - -#[cube] -fn tanh_op(a: F) -> F { - tanh(a) -} - -#[cube] -fn powf_op(a: F, b: F) -> F { - powf(a, b) -} - -#[cube] -fn sqrt_op(a: F) -> F { - sqrt(a) -} - -#[cube] -fn floor_op(a: F) -> F { - floor(a) -} - -#[cube] -fn ceil_op(a: F) -> F { - ceil(a) -} - -#[cube] -fn erf_op(a: F) -> F { - erf(a) -} - -#[cube] -fn recip_op(a: F) -> F { - recip(a) -} - -#[cube] -fn equal_op(a: T, b: T) -> Bool { - a == b -} - -#[cube] -fn not_equal_op(a: T, b: T) -> Bool { - a != b -} - -#[cube] -fn lower_op(a: T, b: T) -> Bool { - a < b -} - -#[cube] -fn greater_op(a: T, b: T) -> Bool { - a > b -} - -#[cube] -fn lower_equal_op(a: T, b: T) -> Bool { - a <= b -} - -#[cube] -fn greater_equal_op(a: T, b: T) -> Bool { - a >= b -} - -#[cube] -fn clamp_op(a: T, l: T, u: T) -> T { - clamp(a, l, u) -} - -#[cube] -fn modulo_op(a: UInt, b: UInt) -> UInt { - a % b -} - -#[cube] -fn remainder_op(a: T, b: T) -> T { - rem(a, b) -} - -#[cube] -fn max_op(a: T, b: T) -> T { - max(a, b) -} - -#[cube] -fn min_op(a: T, b: T) -> T { - min(a, b) -} - -#[cube] -fn and_op(a: Bool, b: Bool) -> Bool { - a & b -} - -#[cube] -fn or_op(a: Bool, b: Bool) -> Bool { - a | b -} - -#[cube] -fn not_op(a: Bool) -> Bool { - !a -} - -#[cube] -fn bit_and_op(a: UInt, b: UInt) -> UInt { - a & b -} - -#[cube] -fn bit_or_op(a: UInt, b: UInt) -> UInt { - a | b -} - -#[cube] -fn 
shift_left_op(a: UInt, b: UInt) -> UInt { - a << b -} - -#[cube] -fn shift_right_op(a: UInt, b: UInt) -> UInt { - a >> b -} From 2bf7db97087d85f62ab19f682021865c8dc8f78e Mon Sep 17 00:00:00 2001 From: louisfd Date: Wed, 15 May 2024 09:45:23 -0400 Subject: [PATCH 54/54] rename new/constant to lit --- crates/burn-cube/src/element/bool.rs | 21 +++++-- crates/burn-cube/src/element/conversion.rs | 4 +- crates/burn-cube/src/element/float.rs | 12 +--- crates/burn-cube/src/element/int.rs | 12 +--- crates/burn-cube/src/element/numeric.rs | 4 +- crates/burn-cube/src/element/uint.rs | 21 +------ crates/burn-cube/src/operation/binary.rs | 4 +- crates/burn-cube/tests/cast_elem.rs | 64 +++++++++++----------- crates/burn-cube/tests/cast_kind.rs | 4 +- crates/burn-cube/tests/function_call.rs | 12 ++-- crates/burn-cube/tests/generic_kernel.rs | 2 +- crates/burn-cube/tests/if.rs | 4 +- crates/burn-cube/tests/if_else.rs | 6 +- crates/burn-cube/tests/literal.rs | 2 +- crates/burn-cube/tests/loop.rs | 8 +-- crates/burn-cube/tests/module_import.rs | 4 +- crates/burn-cube/tests/reuse.rs | 8 +-- 17 files changed, 89 insertions(+), 103 deletions(-) diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs index 8c1d826ad5..baf6f084cc 100644 --- a/crates/burn-cube/src/element/bool.rs +++ b/crates/burn-cube/src/element/bool.rs @@ -14,18 +14,31 @@ impl CubeType for Bool { } impl Bool { - /// Create a Bool from primitive bool - pub fn constant(val: bool) -> Self { + /// Make a boolean literal + pub fn lit(val: bool) -> Self { Self { val, vectorization: 1, } } - /// Expand version of new - pub fn constant_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { + /// Expand version of lit + pub fn lit_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { val.into() } + + /// Create a Bool from primitive bool + pub fn from_primitive(val: bool) -> Self { + Self::lit(val) + } + + /// Expand version of from_primitive + pub fn from_primitive_expand( 
context: &mut CubeContext, + val: bool, + ) -> ::ExpandType { + Self::lit_expand(context, val) + } } impl PrimitiveVariable for Bool { diff --git a/crates/burn-cube/src/element/conversion.rs b/crates/burn-cube/src/element/conversion.rs index 778427d5f7..438a7bd698 100644 --- a/crates/burn-cube/src/element/conversion.rs +++ b/crates/burn-cube/src/element/conversion.rs @@ -108,7 +108,7 @@ macro_rules! impl_to_bool_from_float { ($to:ident, $from1:ident) => { impl From<$from1> for $to { fn from(value: $from1) -> Self { - Self::constant(value.val() > 0.) + Self::from_primitive(value.val() > 0.) } } }; @@ -118,7 +118,7 @@ macro_rules! impl_to_bool_from_int { ($to:ident, $from1:ident) => { impl From<$from1> for $to { fn from(value: $from1) -> Self { - Self::constant(value.val() > 0) + Self::from_primitive(value.val() > 0) } } }; diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs index 5fb67f6dc5..87a2cbc0cb 100644 --- a/crates/burn-cube/src/element/float.rs +++ b/crates/burn-cube/src/element/float.rs @@ -59,17 +59,11 @@ macro_rules! impl_float { impl Numeric for $type { // Method new takes an i64, because it is used when treating the float as numeric, // which must be an int in the cube kernel because new numerics need to be supported by Int as well - fn constant(val: i64) -> Self { - Self { - val: val as f64, - vectorization: 1, - } + fn lit(val: i64) -> Self { + Self::from_primitive(val as f64) } - fn constant_expand( - context: &mut CubeContext, - val: i64, - ) -> ::ExpandType { + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { ::from_primitive_expand(context, val as f64) } } diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs index fde86edb48..355ef66131 100644 --- a/crates/burn-cube/src/element/int.rs +++ b/crates/burn-cube/src/element/int.rs @@ -49,17 +49,11 @@ macro_rules! 
impl_int { } impl Numeric for $type { - fn constant(val: i64) -> Self { - Self { - val, - vectorization: 1, - } + fn lit(val: i64) -> Self { + Self::from_primitive(val) } - fn constant_expand( - context: &mut CubeContext, - val: i64, - ) -> ::ExpandType { + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { ::from_primitive_expand(context, val) } } diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs index cbf68187a0..aa59df6241 100644 --- a/crates/burn-cube/src/element/numeric.rs +++ b/crates/burn-cube/src/element/numeric.rs @@ -18,8 +18,8 @@ pub trait Numeric: /// Note: since this must work for both integer and float /// only the less expressive of both can be created (int) /// If a number with decimals is needed, use Float::from_primitive. - fn constant(val: i64) -> Self; + fn lit(val: i64) -> Self; /// Expand version of new - fn constant_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; } diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs index e9cb22c514..11e21f48e4 100644 --- a/crates/burn-cube/src/element/uint.rs +++ b/crates/burn-cube/src/element/uint.rs @@ -24,18 +24,6 @@ impl UInt { ) -> ::ExpandType { (val as u32).into() } - // /// Create a UInt. 
Use with integer literal - // pub fn new(val: i64) -> Self { - // Self { - // val, - // vectorization: 1, - // } - // } - - // /// Expand version of new - // pub fn new_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { - // (val as u32).into() - // } } impl CubeType for UInt { @@ -53,14 +41,11 @@ impl PrimitiveVariable for UInt { } impl Numeric for UInt { - fn constant(val: i64) -> Self { - Self { - val, - vectorization: 1, - } + fn lit(val: i64) -> Self { + Self::from_primitive(val) } - fn constant_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { Self::from_primitive_expand(context, val) } } diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs index 5647e41d57..70cd1f6bc3 100644 --- a/crates/burn-cube/src/operation/binary.rs +++ b/crates/burn-cube/src/operation/binary.rs @@ -226,7 +226,7 @@ pub mod and { type Output = Bool; fn bitand(self, rhs: Self) -> Self::Output { - Bool::constant(self.val && rhs.val) + Bool::lit(self.val && rhs.val) } } } @@ -248,7 +248,7 @@ pub mod or { type Output = Bool; fn bitor(self, rhs: Self) -> Self::Output { - Bool::constant(self.val || rhs.val) + Bool::lit(self.val || rhs.val) } } } diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs index 25ed945bce..4b9196951d 100644 --- a/crates/burn-cube/tests/cast_elem.rs +++ b/crates/burn-cube/tests/cast_elem.rs @@ -44,26 +44,26 @@ macro_rules! 
cast_test { #[cube] #[allow(clippy::useless_conversion)] pub fn float_to_float(x: F32) { - let y = x + F32::constant(2); - let _ = F32::from(y) + F32::constant(34); + let y = x + F32::lit(2); + let _ = F32::from(y) + F32::lit(34); } #[cube] pub fn float_to_int(x: F32) { - let y = x + F32::constant(2); - let _ = I32::from(y) + I32::constant(34); + let y = x + F32::lit(2); + let _ = I32::from(y) + I32::lit(34); } #[cube] pub fn float_to_uint(x: F32) { - let y = x + F32::constant(2); - let _ = UInt::from(y) + UInt::constant(34); + let y = x + F32::lit(2); + let _ = UInt::from(y) + UInt::lit(34); } #[cube] pub fn float_to_bool(x: F32) { - let y = x + F32::constant(2); - let _ = Bool::from(y) | Bool::constant(true); + let y = x + F32::lit(2); + let _ = Bool::from(y) | Bool::lit(true); } cast_test!( @@ -96,27 +96,27 @@ cast_test!( // // From int #[cube] pub fn int_to_float(x: I32) { - let y = x + I32::constant(2); - let _ = F32::from(y) + F32::constant(34); + let y = x + I32::lit(2); + let _ = F32::from(y) + F32::lit(34); } #[cube] #[allow(clippy::useless_conversion)] pub fn int_to_int(x: I32) { - let y = x + I32::constant(2); - let _ = I32::from(y) + I32::constant(34); + let y = x + I32::lit(2); + let _ = I32::from(y) + I32::lit(34); } #[cube] pub fn int_to_uint(x: I32) { - let y = x + I32::constant(2); - let _ = UInt::from(y) + UInt::constant(34); + let y = x + I32::lit(2); + let _ = UInt::from(y) + UInt::lit(34); } #[cube] pub fn int_to_bool(x: I32) { - let y = x + I32::constant(2); - let _ = Bool::from(y) | Bool::constant(true); + let y = x + I32::lit(2); + let _ = Bool::from(y) | Bool::lit(true); } cast_test!( @@ -149,27 +149,27 @@ cast_test!( // // From uint #[cube] pub fn uint_to_float(x: UInt) { - let y = x + UInt::constant(2); - let _ = F32::from(y) + F32::constant(34); + let y = x + UInt::lit(2); + let _ = F32::from(y) + F32::lit(34); } #[cube] pub fn uint_to_int(x: UInt) { - let y = x + UInt::constant(2); - let _ = I32::from(y) + I32::constant(34); + let y = x 
+ UInt::lit(2); + let _ = I32::from(y) + I32::lit(34); } #[cube] #[allow(clippy::useless_conversion)] pub fn uint_to_uint(x: UInt) { - let y = x + UInt::constant(2); - let _ = UInt::from(y) + UInt::constant(34); + let y = x + UInt::lit(2); + let _ = UInt::from(y) + UInt::lit(34); } #[cube] pub fn uint_to_bool(x: UInt) { - let y = x + UInt::constant(2); - let _ = Bool::from(y) | Bool::constant(true); + let y = x + UInt::lit(2); + let _ = Bool::from(y) | Bool::lit(true); } cast_test!( @@ -202,27 +202,27 @@ cast_test!( // From bool #[cube] pub fn bool_to_float(x: Bool) { - let y = x & Bool::constant(false); - let _ = F32::from(y) + F32::constant(34); + let y = x & Bool::lit(false); + let _ = F32::from(y) + F32::lit(34); } #[cube] pub fn bool_to_int(x: Bool) { - let y = x & Bool::constant(false); - let _ = I32::from(y) + I32::constant(34); + let y = x & Bool::lit(false); + let _ = I32::from(y) + I32::lit(34); } #[cube] pub fn bool_to_uint(x: Bool) { - let y = x & Bool::constant(false); - let _ = UInt::from(y) + UInt::constant(34); + let y = x & Bool::lit(false); + let _ = UInt::from(y) + UInt::lit(34); } #[cube] #[allow(clippy::useless_conversion)] pub fn bool_to_bool(x: Bool) { - let y = x & Bool::constant(false); - let _ = Bool::from(y) | Bool::constant(true); + let y = x & Bool::lit(false); + let _ = Bool::from(y) | Bool::lit(true); } cast_test!( diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs index d1cf9786bb..00d90dd1b4 100644 --- a/crates/burn-cube/tests/cast_kind.rs +++ b/crates/burn-cube/tests/cast_kind.rs @@ -17,9 +17,9 @@ pub fn cast_int_kind>(input: I1) { #[cube] pub fn cast_numeric_to_kind>(input: T) { - let x = input + T::constant(5); + let x = input + T::lit(5); let y = I2::from(x); - let _ = y + I2::constant(2); + let _ = y + I2::lit(2); } #[test] diff --git a/crates/burn-cube/tests/function_call.rs b/crates/burn-cube/tests/function_call.rs index e1add5fa89..9540059669 100644 --- 
a/crates/burn-cube/tests/function_call.rs +++ b/crates/burn-cube/tests/function_call.rs @@ -8,12 +8,12 @@ pub fn caller_no_arg(x: UInt) { #[cube] pub fn callee_no_arg() -> UInt { - UInt::constant(8) + UInt::lit(8) } #[cube] pub fn no_call_no_arg(x: UInt) { - let _ = x + UInt::constant(8); + let _ = x + UInt::lit(8); } #[cube] @@ -23,12 +23,12 @@ pub fn caller_with_arg(x: UInt) { #[cube] pub fn callee_with_arg(x: UInt) -> UInt { - x * UInt::constant(8) + x * UInt::lit(8) } #[cube] pub fn no_call_with_arg(x: UInt) { - let _ = x + x * UInt::constant(8); + let _ = x + x * UInt::lit(8); } #[cube] @@ -38,12 +38,12 @@ pub fn caller_with_generics(x: T) { #[cube] pub fn callee_with_generics(x: T) -> T { - x * T::constant(8) + x * T::lit(8) } #[cube] pub fn no_call_with_generics(x: T) { - let _ = x + x * T::constant(8); + let _ = x + x * T::lit(8); } #[test] diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs index f037f3088f..f555ec7938 100644 --- a/crates/burn-cube/tests/generic_kernel.rs +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -3,7 +3,7 @@ use burn_jit::{gpu, gpu::Item}; #[cube] pub fn generic_kernel(lhs: T) { - let _ = lhs + T::constant(5); + let _ = lhs + T::lit(5); } #[test] diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs index d631926ce7..7f193182a8 100644 --- a/crates/burn-cube/tests/if.rs +++ b/crates/burn-cube/tests/if.rs @@ -8,8 +8,8 @@ type ElemType = F32; #[cube] pub fn if_greater(lhs: T) { - if lhs > T::constant(0) { - let _ = lhs + T::constant(4); + if lhs > T::lit(0) { + let _ = lhs + T::lit(4); } } diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs index 7a462c5e13..ba121f9f4a 100644 --- a/crates/burn-cube/tests/if_else.rs +++ b/crates/burn-cube/tests/if_else.rs @@ -8,10 +8,10 @@ type ElemType = F32; #[cube] pub fn if_then_else(lhs: F) { - if lhs < F::constant(0) { - let _ = lhs + F::constant(4); + if lhs < F::lit(0) { + let _ = lhs + F::lit(4); } 
else { - let _ = lhs - F::constant(5); + let _ = lhs - F::lit(5); } } diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs index e7acfb73fe..39778fef9d 100644 --- a/crates/burn-cube/tests/literal.rs +++ b/crates/burn-cube/tests/literal.rs @@ -5,7 +5,7 @@ type ElemType = F32; #[cube] pub fn literal(lhs: F) { - let _ = lhs + F::constant(5); + let _ = lhs + F::lit(5); } #[cube] diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs index c28fd6ae18..b3530bac2f 100644 --- a/crates/burn-cube/tests/loop.rs +++ b/crates/burn-cube/tests/loop.rs @@ -7,18 +7,18 @@ type ElemType = I32; #[cube] pub fn while_not(lhs: I) { - while lhs != I::constant(0) { - let _ = lhs % I::constant(1); + while lhs != I::lit(0) { + let _ = lhs % I::lit(1); } } #[cube] pub fn manual_loop_break(lhs: I) { loop { - if lhs != I::constant(0) { + if lhs != I::lit(0) { break; } - let _ = lhs % I::constant(1); + let _ = lhs % I::lit(1); } } diff --git a/crates/burn-cube/tests/module_import.rs b/crates/burn-cube/tests/module_import.rs index 5eabe8b3ae..04633df163 100644 --- a/crates/burn-cube/tests/module_import.rs +++ b/crates/burn-cube/tests/module_import.rs @@ -8,7 +8,7 @@ mod elsewhere { #[cube] pub fn my_func(x: F) -> F { - x * F::constant(2) + x * F::lit(2) } } @@ -24,7 +24,7 @@ mod here { #[cube] pub fn no_call_ref(x: F) { - let _ = x + x * F::constant(2); + let _ = x + x * F::lit(2); } } diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs index 936e5d7ba4..f22d401ce3 100644 --- a/crates/burn-cube/tests/reuse.rs +++ b/crates/burn-cube/tests/reuse.rs @@ -12,15 +12,15 @@ pub fn reuse(mut x: I) { // a += b is more efficient than a = a + b // Because the latter does not assume that a is the same in lhs and rhs // Normally clippy should detect it - while x < I::constant(10) { - x = x + I::constant(1); + while x < I::lit(10) { + x = x + I::lit(1); } } #[cube] pub fn reuse_incr(mut x: I) { - while x < I::constant(10) { - x += 
I::constant(1); + while x < I::lit(10) { + x += I::lit(1); } }