diff --git a/Cargo.lock b/Cargo.lock index a04573b5eb..c6fa6919dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -522,6 +522,26 @@ dependencies = [ "thiserror", ] +[[package]] +name = "burn-cube" +version = "0.15.0" +dependencies = [ + "burn-cube-macros", + "burn-jit", + "derive-new", + "log", +] + +[[package]] +name = "burn-cube-macros" +version = "0.15.0" +dependencies = [ + "derive-new", + "proc-macro2", + "quote", + "syn 2.0.61", +] + [[package]] name = "burn-dataset" version = "0.15.0" @@ -4317,14 +4337,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "refactor" -version = "0.15.0" -dependencies = [ - "burn", - "serde", -] - [[package]] name = "regex" version = "1.10.4" diff --git a/crates/burn-cube-macros/Cargo.toml b/crates/burn-cube-macros/Cargo.toml new file mode 100644 index 0000000000..11757df219 --- /dev/null +++ b/crates/burn-cube-macros/Cargo.toml @@ -0,0 +1,27 @@ +[package] +authors = [ + "nathanielsimard ", + "louisfd bool { + if self.num_used > 1 { + self.num_used -= 1; + true + } else { + self.loop_level_declared < loop_level + } + } +} + +#[derive(Debug)] +/// Information about all variables in the Cube code, transmitted to codegen +pub(crate) struct CodeAnalysis { + pub variable_analyses: HashMap, +} + +#[derive(Debug, Default)] +/// Reads the Cube code and accumulates information, to generate a CodeAnalysis artefact +pub(crate) struct CodeAnalysisBuilder { + declarations: Vec<(VariableKey, usize)>, + var_uses: Vec, +} + +impl CodeAnalysis { + pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool { + let key: VariableKey = ident.into(); + match self.variable_analyses.remove(&key) { + Some(mut var) => { + let should_clone = var.should_clone(loop_level); + self.variable_analyses.insert(key, var); + should_clone + } + None => panic!("Ident {ident} not part of analysis"), + } + } + + pub fn create(func: &syn::ItemFn) -> CodeAnalysis { + let code_analysis_builder = CodeAnalysisBuilder::default(); + code_analysis_builder.analyze(func) + } +} + +impl CodeAnalysisBuilder { + fn analyze(mut self, func: &syn::ItemFn) -> CodeAnalysis { + // Build the vector of (Id, depth), using recursion + self.signature_declarations(&func.sig); + self.find_occurrences_in_stmts(&func.block.stmts, 0); + + CodeAnalysis { + variable_analyses: self.to_map(), + } + } + + fn to_map(&self) -> HashMap { + // Run through the vec and build hashmap, without recursion + let mut variable_analyses = HashMap::::new(); + for declaration in self.declarations.iter() { + let id = declaration.0.clone(); + let new_analysis = match variable_analyses.remove(&id) { + Some(_) => { + panic!("Analysis: Multiple variables with the same identifier is not supported") + } + None => VariableAnalysis { + num_used: 0, + loop_level_declared: declaration.1, + }, + }; + + variable_analyses.insert(id, new_analysis); + } + + for id in self.var_uses.iter() { + let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| { + panic!( + "Analysis: Variable {:?} should be declared before it's used", + id + ) + }); + let new_analysis = VariableAnalysis { + num_used: prev_analysis.num_used + 1, + loop_level_declared: prev_analysis.loop_level_declared, + }; + variable_analyses.insert(id.clone(), new_analysis); + } + + variable_analyses + } + + fn signature_declarations(&mut self, sig: &syn::Signature) { + for input in &sig.inputs { + match input { + syn::FnArg::Typed(pat) => { + let ident = &*pat.pat; + match ident { + syn::Pat::Ident(pat_ident) => { + let id = &pat_ident.ident; + self.declarations.push((id.into(), 0)); + } + _ => todo!("Analysis: unsupported ident {ident:?}"), + } + } + _ => todo!("Analysis: unsupported input {input:?}"), + } + } + } + + fn find_occurrences_in_stmts(&mut self, stmts: &Vec, depth: usize) { + for stmt in stmts { + match stmt { + // Declaration + syn::Stmt::Local(local) => { + let id = match &local.pat { + syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), + syn::Pat::Type(pat_type) => Some(match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => &pat_ident.ident, + _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), + }), + syn::Pat::Wild(_) => None, + _ => todo!("Analysis: unsupported path {:?}", local.pat), + }; + if let Some(id) = id { + self.declarations.push((id.into(), depth)); + } + if let Some(local_init) = &local.init { + self.find_occurrences_in_expr(&local_init.expr, depth) + } + } + syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth), + _ => todo!("Analysis: unsupported stmt {stmt:?}"), + } + } + } + + fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: usize) { + match expr { + syn::Expr::ForLoop(expr) => { + let depth = depth + 1; + + // Declaration of iterator + if let syn::Pat::Ident(pat_ident) = &*expr.pat { + let id = &pat_ident.ident; + self.declarations.push((id.into(), depth)); + } + + self.find_occurrences_in_stmts(&expr.body.stmts, depth); + } + syn::Expr::While(expr) => { + let depth = depth + 1; + + self.find_occurrences_in_expr(&expr.cond, depth); + self.find_occurrences_in_stmts(&expr.body.stmts, depth); + } + syn::Expr::Loop(expr) => { + let depth = depth + 1; + + self.find_occurrences_in_stmts(&expr.body.stmts, depth); + } + syn::Expr::If(expr) => { + let depth = depth + 1; + + self.find_occurrences_in_expr(&expr.cond, depth); + self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth); + if let Some((_, expr)) = &expr.else_branch { + if let syn::Expr::Block(expr_block) = &**expr { + self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); + } else { + todo!("Analysis: Only block else expr is supported") + } + } + } + syn::Expr::Assign(expr) => { + self.find_occurrences_in_expr(&expr.left, depth); + self.find_occurrences_in_expr(&expr.right, depth); + } + syn::Expr::Index(expr) => { + self.find_occurrences_in_expr(&expr.expr, depth); + self.find_occurrences_in_expr(&expr.index, depth); + } + syn::Expr::Path(expr) => { + let ident = expr + .path + .get_ident() + .expect("Analysis: only ident path are supported."); + + // Use + self.var_uses.push(ident.into()); + } + syn::Expr::Binary(expr) => { + self.find_occurrences_in_expr(&expr.left, depth); + self.find_occurrences_in_expr(&expr.right, depth); + } + syn::Expr::Lit(_) => {} + syn::Expr::Call(expr) => { + match &*expr.func { + syn::Expr::Path(expr_path) => { + if let Some(first_segment) = expr_path.path.segments.first() { + // Check if the path segment has generic arguments + if let PathArguments::AngleBracketed(arguments) = + &first_segment.arguments + { + // Extract the generic arguments + for arg in &arguments.args { + match arg { + syn::GenericArgument::Type(_) + | syn::GenericArgument::Constraint(_) => {} + _ => todo!("Analysis: Generic {:?} not supported", arg), + } + } + } + } + } + _ => todo!("Analysis: unsupported func expr {:?}", expr.func), + } + for arg in expr.args.iter() { + self.find_occurrences_in_expr(arg, depth); + } + } + syn::Expr::MethodCall(expr) => { + self.find_occurrences_in_expr(&expr.receiver, depth); + for arg in expr.args.iter() { + self.find_occurrences_in_expr(arg, depth); + } + } + syn::Expr::Break(_) => {} + syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth), + _ => todo!("Analysis: unsupported expr {expr:?}"), + } + } +} diff --git a/crates/burn-cube-macros/src/codegen/base.rs b/crates/burn-cube-macros/src/codegen/base.rs new file mode 100644 index 0000000000..857d3bb518 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/base.rs @@ -0,0 +1,89 @@ +use proc_macro2::TokenStream; + +use crate::analysis::CodeAnalysis; + +use super::{ + branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop}, + function::{codegen_call, codegen_closure, codegen_expr_method_call}, + operation::codegen_binary, + variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs}, +}; + +/// Codegen for a statement (generally one line) +/// Entry point of code generation +pub fn codegen_statement( + statement: &syn::Stmt, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match statement { + syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_analyses), + syn::Stmt::Expr(expr, semi) => { + let expr = codegen_expr(expr, loop_level, variable_analyses); + match semi { + Some(_semi) => quote::quote!( + #expr; + ), + None => expr, + } + } + _ => todo!("Codegen: statement {statement:?} not supported"), + } +} + +/// Codegen for a code block (a list of statements) +pub(crate) fn codegen_block( + block: &syn::Block, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let mut statements = quote::quote!(); + + for statement in block.stmts.iter() { + statements.extend(codegen_statement(statement, loop_level, variable_analyses)); + } + + quote::quote! { + { + #statements + } + } +} + +/// Codegen for an expression containing a block +pub(crate) fn codegen_expr_block( + block: &syn::ExprBlock, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + codegen_block(&block.block, loop_level, variable_analyses) +} + +/// Codegen for expressions +/// There are many variants of expression, treated differently +pub(crate) fn codegen_expr( + expr: &syn::Expr, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match expr { + syn::Expr::Binary(op) => codegen_binary(op, loop_level, variable_analyses), + syn::Expr::Path(path) => codegen_path_rhs(path, loop_level, variable_analyses), + syn::Expr::Call(call) => codegen_call(call, loop_level, variable_analyses), + syn::Expr::Lit(lit) => codegen_lit(lit), + syn::Expr::Closure(closure) => codegen_closure(closure, loop_level, variable_analyses), + syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_analyses), + syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_analyses), + syn::Expr::ForLoop(for_loop) => codegen_for_loop(for_loop, loop_level, variable_analyses), + syn::Expr::While(while_loop) => { + codegen_while_loop(while_loop, loop_level, variable_analyses) + } + syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_analyses), + syn::Expr::Break(_) => codegen_break(), + syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), + syn::Expr::MethodCall(call) => codegen_expr_method_call(call), + syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), + syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses), + _ => panic!("Codegen: Unsupported {:?}", expr), + } +} diff --git a/crates/burn-cube-macros/src/codegen/branch.rs b/crates/burn-cube-macros/src/codegen/branch.rs new file mode 100644 index 0000000000..96318a9fa9 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/branch.rs @@ -0,0 +1,126 @@ +use proc_macro2::TokenStream; + +use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; + +use super::{base::codegen_block, operation::codegen_binary, variable::codegen_lit}; + +/// Codegen of for loops +/// Supports range: +/// for i in range(start, end, unroll) {...} +pub(crate) fn codegen_for_loop( + for_loop: &syn::ExprForLoop, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let i = &for_loop.pat; + let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses); + + match for_loop.expr.as_ref() { + syn::Expr::Call(call) => { + let func_name = match call.func.as_ref() { + syn::Expr::Path(path) => path + .path + .get_ident() + .expect("Codegen: func in for loop should have ident"), + _ => todo!("Codegen: Only path call supported"), + }; + + if &func_name.to_string() == "range" { + let mut args = quote::quote! { + context, + }; + + for argument in call.args.iter() { + let arg = codegen_expr(argument, loop_level, variable_analyses); + args.extend(quote::quote! { #arg, }); + } + + quote::quote! { + range_expand(#args |context, #i| #block); + } + } else { + todo!("Codegen: Only range is supported") + } + } + _ => todo!("Codegen: Only call is supported {for_loop:?}"), + } +} + +/// Codegen for condition of an if or a while +pub(crate) fn codegen_cond( + cond: &syn::Expr, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match cond { + syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses), + syn::Expr::Lit(expr) => codegen_lit(expr), + _ => todo!("{cond:?} cond not supported"), + } +} + +/// Codegen for break statement +pub(crate) fn codegen_break() -> TokenStream { + quote::quote! { + break_expand(context); + } +} + +/// Codegen for if and if/else statements +/// Supports: +/// if cond {...} +/// if cond {...} else {...} +pub(crate) fn codegen_if( + expr_if: &syn::ExprIf, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses); + + let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses); + + if let Some((_, expr)) = &expr_if.else_branch { + if let syn::Expr::Block(expr_block) = &**expr { + let else_block = codegen_block(&expr_block.block, loop_level + 1, variable_analyses); + + quote::quote! { + let _cond = #cond; + if_else_expand(context, _cond, |context| #then_block, |context| #else_block); + } + } else { + todo!("Analysis: Only block else expr is supported") + } + } else { + quote::quote! { + let _cond = #cond; + if_expand(context, _cond, |context| #then_block); + } + } +} + +/// Codegen of loop +pub(crate) fn codegen_loop( + loop_expr: &syn::ExprLoop, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let block = codegen_block(&loop_expr.body, loop_level + 1, variable_analyses); + + quote::quote! { + loop_expand(context, |context| #block); + } +} + +/// Codegen for while loop +pub(crate) fn codegen_while_loop( + while_loop: &syn::ExprWhile, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let cond = codegen_cond(&while_loop.cond, loop_level + 1, variable_analyses); + let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses); + + quote::quote! { + while_loop_expand(context, |context| #cond, |context| #block); + } +} diff --git a/crates/burn-cube-macros/src/codegen/function.rs b/crates/burn-cube-macros/src/codegen/function.rs new file mode 100644 index 0000000000..c680aac563 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/function.rs @@ -0,0 +1,98 @@ +use proc_macro2::TokenStream; +use quote::quote_spanned; +use syn::PathArguments; + +use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; + +/// Codegen for method call +pub(crate) fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { + quote::quote!( #call ) +} + +/// Codegen for a closure +pub(crate) fn codegen_closure( + closure: &syn::ExprClosure, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let mut inputs = quote::quote! {}; + for input in closure.inputs.iter() { + let ident = match input { + syn::Pat::Ident(ident) => &ident.ident, + _ => panic!("Codegen: Unsupported {:?}", input), + }; + inputs.extend(quote::quote! { + #ident, + }); + } + + let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses); + + quote::quote! { + |context, #inputs| #body + } +} + +/// Codegen for a function call +/// Supports: +/// func() +/// func::() +/// T::func() +pub(crate) fn codegen_call( + call: &syn::ExprCall, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + // We start with parsing the function path + let (mut idents, generics) = match call.func.as_ref() { + syn::Expr::Path(expr_path) => { + let mut idents = Vec::new(); + let mut generics = None; + for (index, segment) in expr_path.path.segments.iter().enumerate() { + idents.push(&segment.ident); + + if index == expr_path.path.segments.len() - 1 { + if let PathArguments::AngleBracketed(arguments) = &segment.arguments { + generics = Some(arguments) + } + } + } + (idents, generics) + } + _ => todo!("Codegen: func call {:?} not supported", call.func), + }; + + // Function name with support for longer path + let func_name = idents + .pop() + .expect("Codegen: Func should have at least one ident"); + + let mut previous_tokens = TokenStream::new(); + for ident in idents.iter() { + previous_tokens.extend(quote_spanned! {ident.span() => #ident :: }); + } + let func_name_expand = syn::Ident::new( + format!("{func_name}_expand").as_str(), + proc_macro2::Span::call_site(), + ); + + // Generics + let generics = match generics { + Some(generics) => quote::quote! { #generics }, + None => quote::quote! {}, + }; + + // Arguments + let mut args = quote::quote! { + context, + }; + for argument in call.args.iter() { + let arg = codegen_expr(argument, loop_level, variable_analyses); + args.extend(quote::quote! { #arg, }); + } + + // Codegen + quote::quote! { + #previous_tokens #func_name_expand #generics (#args) + } +} diff --git a/crates/burn-cube-macros/src/codegen/mod.rs b/crates/burn-cube-macros/src/codegen/mod.rs new file mode 100644 index 0000000000..c5e5f757f3 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/mod.rs @@ -0,0 +1,7 @@ +mod base; +mod branch; +mod function; +mod operation; +mod variable; + +pub(crate) use base::codegen_statement; diff --git a/crates/burn-cube-macros/src/codegen/operation.rs b/crates/burn-cube-macros/src/codegen/operation.rs new file mode 100644 index 0000000000..76924c080b --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/operation.rs @@ -0,0 +1,98 @@ +use proc_macro2::TokenStream; + +use crate::analysis::CodeAnalysis; + +use super::base::codegen_expr; + +/// Codegen for binary operations (+, -, *, etc.) +pub(crate) fn codegen_binary( + binary: &syn::ExprBinary, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let lhs = codegen_expr(&binary.left, loop_level, variable_analyses); + let rhs = codegen_expr(&binary.right, loop_level, variable_analyses); + + match binary.op { + syn::BinOp::Add(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::add::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Sub(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::sub::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Mul(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::mul::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Div(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::div::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Rem(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::rem::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Ne(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::ne::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Gt(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::gt::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Lt(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::lt::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::AddAssign(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::add_assign_op::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::BitAnd(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::and::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::And(_) => unimplemented!("Logical and (&&) not overridable in Rust due to its short circuiting nature. Use bitwise instead (&). "), + syn::BinOp::BitOr(_) => quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::or::expand(context, _lhs, _rhs) + } + }, + syn::BinOp::Or(_) => unimplemented!("Logical or (||) not overridable in Rust due to its short circuiting nature. Use bitwise instead (|). "), + _ => todo!("Codegen: unsupported op {:?}", binary.op), + } +} diff --git a/crates/burn-cube-macros/src/codegen/variable.rs b/crates/burn-cube-macros/src/codegen/variable.rs new file mode 100644 index 0000000000..769bc6dff0 --- /dev/null +++ b/crates/burn-cube-macros/src/codegen/variable.rs @@ -0,0 +1,141 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::Lit; + +use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; + +/// Codegen for literals +pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { + match lit.lit { + // We treat floats differently to avoid getting 4..into() for instance + Lit::Float(_) => { + let lit_str = lit.lit.to_token_stream().to_string(); + let float_lit = lit_str.parse::().unwrap(); + quote::quote! { #float_lit.into() } + } + _ => { + quote::quote! { #lit.into() } + } + } +} + +/// Codegen for a local declaration (let ...) +/// Supports: +/// let x = ... +/// let x: T = ... +/// let _ = ... +pub(crate) fn codegen_local( + local: &syn::Local, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let let_tok = local.let_token; + + let ident = match &local.pat { + syn::Pat::Ident(ident) => ident.to_token_stream(), + syn::Pat::Type(pat_type) => match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(), + _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), + }, + syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), + _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), + }; + + match local.init.as_ref() { + Some(init) => { + let init = codegen_expr(&init.expr, loop_level, variable_analyses); + + quote::quote! { + #let_tok #ident = #init; + } + } + None => { + quote::quote! { + #let_tok #ident; + } + } + } +} + +/// Codegen for indexed access +pub(crate) fn codegen_index( + index: &syn::ExprIndex, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = codegen_expr(&index.index, loop_level, variable_analyses); + + quote::quote! { + { + let _array = #array; + let _index = #index; + burn_cube::index::expand(context, _array, _index) + } + } +} + +/// Codegen for assignation +/// Supports: +/// - scalar +/// - indexed array +pub(crate) fn codegen_assign( + assign: &syn::ExprAssign, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + match assign.left.as_ref() { + syn::Expr::Index(index) => { + let array = codegen_expr(&index.expr, loop_level, variable_analyses); + let index = codegen_expr(&index.index, loop_level, variable_analyses); + let value = codegen_expr(&assign.right, loop_level, variable_analyses); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #value; + burn_cube::index_assign::expand(context, _array, _index, _value) + } + } + } + syn::Expr::Path(_) => { + let lhs = codegen_expr(&assign.left, loop_level, variable_analyses); + let rhs = codegen_expr(&assign.right, loop_level, variable_analyses); + + quote::quote! { + { + let _assign_lhs = #lhs; + let _assign_rhs = #rhs; + burn_cube::assign::expand(context, _assign_rhs, _assign_lhs) + } + } + } + _ => todo!("Assign of expr {:?} unsupported", assign.left), + } +} + +/// Codegen for a variable used in rhs of a statement +/// This function adds cloning when necessary +pub(crate) fn codegen_path_rhs( + path: &syn::ExprPath, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let ident = path + .path + .get_ident() + .expect("Codegen: Only ident path are supported."); + + let will_be_used_again = variable_analyses.should_clone(ident, loop_level); + + if will_be_used_again { + quote::quote! { + #ident.clone() + } + } else { + quote::quote! { + #ident + } + } +} diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs new file mode 100644 index 0000000000..ffb7f6f4b4 --- /dev/null +++ b/crates/burn-cube-macros/src/lib.rs @@ -0,0 +1,130 @@ +mod analysis; +mod codegen; + +use analysis::CodeAnalysis; +use codegen::codegen_statement; +use proc_macro::TokenStream; +use quote::ToTokens; +use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta}; + +enum CubeMode { + /// Generates the expanded version of the function + Default, + /// Panics and prints the generated code, useful when debugging + /// Use by writing #[cube(panic)] + Debug, +} + +/// Derive macro for the module. +#[proc_macro_attribute] +pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { + let args = parse_macro_input!(attr with Punctuated::::parse_terminated); + let mode = parse_mode(args); + + let func: syn::ItemFn = syn::parse(tokens).unwrap(); + let mut variable_analyses = CodeAnalysis::create(&func); + + let code = codegen_cube(&func, &mut variable_analyses); + match mode { + CubeMode::Default => code, + CubeMode::Debug => panic!("{code}"), + } +} + +fn parse_mode(args: Punctuated) -> CubeMode { + let mut mode = CubeMode::Default; + + if let Some(arg) = args.first() { + match arg { + Meta::Path(path) => { + if let Some(ident) = path.get_ident().map(|id| id.to_string()) { + match ident.as_str() { + "debug" => { + mode = CubeMode::Debug; + } + _ => panic!("Attribute {ident} is not supported"), + } + } else { + panic!("Only ident attribute supported"); + } + } + Meta::List(_) => panic!("No List attribute supported"), + Meta::NameValue(_) => panic!("No NameValue attribute supported"), + } + } + + mode +} + +#[derive(Hash, PartialEq, Eq, Debug, Clone)] +struct VariableKey { + name: String, +} + +impl From<&syn::Ident> for VariableKey { + fn from(value: &syn::Ident) -> Self { + VariableKey { + name: value.to_string(), + } + } +} + +/// Generate the expanded version of a function marked with the cube macro +fn codegen_cube(func: &syn::ItemFn, code_analysis: &mut CodeAnalysis) -> TokenStream { + let signature = expand_sig(&func.sig); + let mut body = quote::quote! {}; + + for statement in func.block.stmts.iter() { + let tokens = codegen_statement(statement, 0, code_analysis); + body.extend(tokens); + } + + quote::quote! { + #[allow(dead_code)] + #func + + #[allow(unused_mut)] + #signature { + #body + } + } + .into() +} + +fn expand_sig(sig: &syn::Signature) -> proc_macro2::TokenStream { + let mut inputs = quote::quote!(); + + for input in &sig.inputs { + match input { + syn::FnArg::Typed(pat) => { + let ty = &pat.ty; + let ident = pat.pat.clone(); + + inputs.extend(quote::quote! { + #ident: <#ty as burn_cube::CubeType>::ExpandType, + }); + } + _ => todo!("Only Typed inputs are supported"), + } + } + + let mut output = quote::quote!(); + + match &sig.output { + syn::ReturnType::Default => output.extend(quote::quote! { ()}), + syn::ReturnType::Type(_, ty) => { + output.extend(quote::quote! { + <#ty as burn_cube::CubeType>::ExpandType + }); + } + } + + let ident = &sig.ident; + let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span()); + + let generics = sig.generics.clone().into_token_stream(); + + quote::quote! { + pub fn #ident #generics (context: &mut burn_cube::CubeContext, #inputs) -> #output + } +} diff --git a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml new file mode 100644 index 0000000000..9302568ce9 --- /dev/null +++ b/crates/burn-cube/Cargo.toml @@ -0,0 +1,27 @@ +[package] +authors = [ + "nathanielsimard ", + "louisfd (start: S, end: E, _unroll: bool) -> core::ops::Range +where + S: Into, + E: Into, +{ + let start: UInt = start.into(); + let end: UInt = end.into(); + + core::ops::Range { + start: start.val as usize, + end: end.val as usize, + } +} + +pub fn range_expand( + context: &mut CubeContext, + start: ExpandElement, + end: ExpandElement, + unroll: bool, + mut func: F, +) where + F: FnMut(&mut CubeContext, ExpandElement), +{ + if unroll { + let start = match start.deref() { + Variable::ConstantScalar(val, _) => *val as usize, + _ => panic!("Only constant start can be unrolled."), + }; + let end = match end.deref() { + Variable::ConstantScalar(val, _) => *val as usize, + _ => panic!("Only constant end can be unrolled."), + }; + + for i in start..end { + func(context, i.into()) + } + } else { + let mut child = context.child(); + let index_ty = Item::Scalar(Elem::UInt); + let i = child.scope.borrow_mut().create_local_undeclared(index_ty); + let i = ExpandElement::new(Rc::new(i)); + + func(&mut child, i.clone()); + + context.register(Branch::RangeLoop(gpu::RangeLoop { + i: *i, + start: *start, + end: *end, + scope: child.into_scope(), + })); + } +} + +pub fn if_expand(context: &mut CubeContext, cond: ExpandElement, mut block: IF) +where + IF: FnMut(&mut CubeContext), +{ + let mut child = context.child(); + + block(&mut child); + + context.register(Branch::If(gpu::If { + cond: *cond, + scope: child.into_scope(), + })); +} + +pub fn if_else_expand( + context: &mut CubeContext, + cond: ExpandElement, + mut then_block: IF, + mut else_block: EL, +) where + IF: FnMut(&mut CubeContext), + EL: FnMut(&mut CubeContext), +{ + let mut then_child = context.child(); + then_block(&mut then_child); + + let mut else_child = context.child(); + else_block(&mut else_child); + + context.register(Branch::IfElse(gpu::IfElse { + cond: *cond, + scope_if: then_child.into_scope(), + scope_else: else_child.into_scope(), + })); +} + +pub fn break_expand(context: &mut CubeContext) { + context.register(Branch::Break); +} + +pub fn loop_expand(context: &mut CubeContext, mut block: FB) +where + FB: FnMut(&mut CubeContext), +{ + let mut inside_loop = context.child(); + + block(&mut inside_loop); + context.register(Branch::Loop(gpu::Loop { + scope: inside_loop.into_scope(), + })); +} + +pub fn while_loop_expand(context: &mut CubeContext, mut cond_fn: FC, mut block: FB) +where + FC: FnMut(&mut CubeContext) -> ExpandElement, + FB: FnMut(&mut CubeContext), +{ + let mut inside_loop = context.child(); + + let cond: ExpandElement = cond_fn(&mut inside_loop); + if_expand(&mut inside_loop, cond, break_expand); + + block(&mut inside_loop); + context.register(Branch::Loop(gpu::Loop { + scope: inside_loop.into_scope(), + })); +} diff --git a/crates/burn-cube/src/context.rs b/crates/burn-cube/src/context.rs new file mode 100644 index 0000000000..00afd24387 --- /dev/null +++ b/crates/burn-cube/src/context.rs @@ -0,0 +1,107 @@ +use crate::ExpandElement; +use alloc::rc::Rc; +use burn_jit::gpu::{self, Item, Scope}; +use core::cell::RefCell; +use std::collections::HashMap; + +#[derive(Default, Clone)] +pub struct VariablePool { + map: Rc>>>, +} + +impl VariablePool { + /// Returns an old, not used anymore variable, if there exists one. + pub fn reuse(&self, item: Item) -> Option { + let map = self.map.borrow(); + + // Filter for candidate variables of the same Item + let variables = match map.get(&item) { + Some(val) => val, + None => return None, + }; + + // Among the candidates, take a variable if it's only referenced by the map + // Arbitrarily takes the first it finds + for variable in variables.iter() { + if Rc::strong_count(&variable.inner) == 1 { + // println!("Reuse var {:?}", variable.inner); + return Some(variable.clone()); + } + } + + // If no candidate was found, a new var will be needed + None + } + + /// Insert a new variable in the map, which is classified by Item + pub fn insert(&mut self, var: ExpandElement) { + let mut map = self.map.borrow_mut(); + let item = var.item(); + + if let Some(variables) = map.get_mut(&item) { + variables.push(var.clone()); + } else { + map.insert(var.item(), vec![var.clone()]); + } + } +} + +pub struct CubeContext { + pub root: Rc>, + pub scope: Rc>, + pub pool: VariablePool, +} + +impl CubeContext { + /// Create a new cube context, with a root scope + /// A root scope is at the root of a compute shader + /// Therefore there is one cube context per shader + pub fn root() -> CubeContext { + let root = Rc::new(RefCell::new(Scope::root())); + let scope = root.clone(); + + Self { + pool: Default::default(), + scope, + root, + } + } + + pub fn register>(&mut self, op: O) { + self.scope.borrow_mut().register(op) + } + + pub fn child(&mut self) -> CubeContext { + let scope = self.scope.borrow_mut().child(); + + Self { + scope: Rc::new(RefCell::new(scope)), + root: self.root.clone(), + pool: self.pool.clone(), + } + } + + pub fn into_scope(self) -> Scope { + core::mem::drop(self.root); + + Rc::into_inner(self.scope) + .expect("Only one reference") + .into_inner() + } + + /// When a new variable is required, we check if we can reuse an old one + /// Otherwise we create a new one. + pub fn create_local(&mut self, item: Item) -> ExpandElement { + // Reuse an old variable if possible + if let Some(var) = self.pool.reuse(item) { + return var; + } + + // Create a new variable at the root scope + // Insert it in the variable pool for potential reuse + let new = ExpandElement::new(Rc::new(self.root.borrow_mut().create_local(item))); + self.pool.insert(new.clone()); + + new + } +} diff --git a/crates/burn-cube/src/element/array.rs b/crates/burn-cube/src/element/array.rs new file mode 100644 index 0000000000..e1c360b2ab --- /dev/null +++ b/crates/burn-cube/src/element/array.rs @@ -0,0 +1,10 @@ +use crate::{CubeType, ExpandElement}; + +#[derive(new, Clone)] +pub struct Array { + pub vals: Vec, +} + +impl CubeType for Array { + type ExpandType = ExpandElement; +} diff --git a/crates/burn-cube/src/element/base.rs b/crates/burn-cube/src/element/base.rs new file mode 100644 index 0000000000..2e7a925473 --- /dev/null +++ b/crates/burn-cube/src/element/base.rs @@ -0,0 +1,39 @@ +use alloc::rc::Rc; +use burn_jit::gpu::Variable; + +/// Types used in a cube function must implement this trait +/// +/// Variables whose values will be known at runtime must +/// have ExpandElement as associated type +/// Variables whose values will be known at compile time +/// must have the primitive type as associated type +/// +/// Note: Cube functions should be written using CubeTypes, +/// so that the code generated uses the associated ExpandType. +/// This allows Cube code to not necessitate cloning, which is cumbersome +/// in algorithmic code. The necessary cloning will automatically appear in +/// the generated code. +pub trait CubeType { + type ExpandType: Clone; +} + +#[derive(new, Clone, Debug)] +/// Reference to a JIT variable +/// It's the expand element that is actually kept in the variable pool +pub struct ExpandElement { + pub(crate) inner: Rc, +} + +impl core::ops::Deref for ExpandElement { + type Target = Variable; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} + +impl From for Variable { + fn from(value: ExpandElement) -> Self { + *value.inner + } +} diff --git a/crates/burn-cube/src/element/bool.rs b/crates/burn-cube/src/element/bool.rs new file mode 100644 index 0000000000..baf6f084cc --- /dev/null +++ b/crates/burn-cube/src/element/bool.rs @@ -0,0 +1,54 @@ +use burn_jit::gpu::Elem; + +use crate::{CubeContext, CubeType, ExpandElement, PrimitiveVariable}; + +#[derive(Clone, Copy)] +/// Boolean type for kernels +pub struct Bool { + pub val: bool, + pub vectorization: u8, +} + +impl CubeType for Bool { + type ExpandType = ExpandElement; +} + +impl Bool { + /// Make a boolean literal + pub fn lit(val: bool) -> Self { + Self { + val, + vectorization: 1, + } + } + + /// Expand version of lit + pub fn lit_expand(_context: &mut CubeContext, val: bool) -> ::ExpandType { + val.into() + } + + /// Create a Bool from primitive bool + pub fn from_primitive(val: bool) -> Self { + Self::lit(val) + } + + /// Expand version of from_primitive + pub fn from_primitive_expand( + context: &mut CubeContext, + val: bool, + ) -> ::ExpandType { + Self::lit_expand(context, val) + } +} + +impl PrimitiveVariable for Bool { + type Primitive = bool; + + fn val(&self) -> Self::Primitive { + self.val + } + + fn into_elem() -> Elem { + Elem::Bool + } +} diff --git a/crates/burn-cube/src/element/conversion.rs b/crates/burn-cube/src/element/conversion.rs new file mode 100644 index 0000000000..438a7bd698 --- /dev/null +++ b/crates/burn-cube/src/element/conversion.rs @@ -0,0 +1,133 @@ +use crate::{Bool, Float, Int, PrimitiveVariable, UInt, BF16, F16, F32, F64, I32, I64}; + +// Enable elegant casting from any to any primitive variable + +macro_rules! impl_to_float { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(value.val() as f64) + } + } + }; +} + +macro_rules! impl_to_float_from_bool { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(match value.val() { + true => 1., + false => 0., + }) + } + } + }; +} + +impl_to_float!(F16, BF16); +impl_to_float!(F16, F32); +impl_to_float!(F16, F64); +impl_to_float!(F16, I32); +impl_to_float!(F16, I64); +impl_to_float!(F16, UInt); +impl_to_float_from_bool!(F16, Bool); + +impl_to_float!(BF16, F16); +impl_to_float!(BF16, F32); +impl_to_float!(BF16, F64); +impl_to_float!(BF16, I32); +impl_to_float!(BF16, I64); +impl_to_float!(BF16, UInt); +impl_to_float_from_bool!(BF16, Bool); + +impl_to_float!(F32, F16); +impl_to_float!(F32, BF16); +impl_to_float!(F32, F64); +impl_to_float!(F32, I32); +impl_to_float!(F32, I64); +impl_to_float!(F32, UInt); +impl_to_float_from_bool!(F32, Bool); + +impl_to_float!(F64, F16); +impl_to_float!(F64, BF16); +impl_to_float!(F64, F32); +impl_to_float!(F64, I32); +impl_to_float!(F64, I64); +impl_to_float!(F64, UInt); +impl_to_float_from_bool!(F64, Bool); + +macro_rules! impl_to_int { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(value.val() as i64) + } + } + }; +} + +macro_rules! impl_to_int_from_bool { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(match value.val() { + true => 1, + false => 0, + }) + } + } + }; +} + +impl_to_int!(I32, F16); +impl_to_int!(I32, BF16); +impl_to_int!(I32, F32); +impl_to_int!(I32, F64); +impl_to_int!(I32, I64); +impl_to_int!(I32, UInt); +impl_to_int_from_bool!(I32, Bool); + +impl_to_int!(I64, F16); +impl_to_int!(I64, BF16); +impl_to_int!(I64, F32); +impl_to_int!(I64, F64); +impl_to_int!(I64, I32); +impl_to_int!(I64, UInt); +impl_to_int_from_bool!(I64, Bool); + +impl_to_int!(UInt, F16); +impl_to_int!(UInt, BF16); +impl_to_int!(UInt, F32); +impl_to_int!(UInt, F64); +impl_to_int!(UInt, I32); +impl_to_int!(UInt, I64); +impl_to_int_from_bool!(UInt, Bool); + +macro_rules! impl_to_bool_from_float { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(value.val() > 0.) + } + } + }; +} + +macro_rules! impl_to_bool_from_int { + ($to:ident, $from1:ident) => { + impl From<$from1> for $to { + fn from(value: $from1) -> Self { + Self::from_primitive(value.val() > 0) + } + } + }; +} + +impl_to_bool_from_float!(Bool, F16); +impl_to_bool_from_float!(Bool, BF16); +impl_to_bool_from_float!(Bool, F32); +impl_to_bool_from_float!(Bool, F64); +impl_to_bool_from_int!(Bool, I32); +impl_to_bool_from_int!(Bool, I64); +impl_to_bool_from_int!(Bool, UInt); diff --git a/crates/burn-cube/src/element/float.rs b/crates/burn-cube/src/element/float.rs new file mode 100644 index 0000000000..87a2cbc0cb --- /dev/null +++ b/crates/burn-cube/src/element/float.rs @@ -0,0 +1,76 @@ +use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; +use burn_jit::gpu::{Elem, FloatKind, Variable}; +use std::rc::Rc; + +/// Floating point numbers. Used as input in float kernels +pub trait Float: Numeric { + /// Create a Float from a float literal + fn from_primitive(val: f64) -> Self; + /// Expand version of from_primitive + fn from_primitive_expand(context: &mut CubeContext, val: f64) + -> ::ExpandType; +} + +macro_rules! impl_float { + ($type:ident) => { + #[derive(Clone, Copy)] + pub struct $type { + pub val: f64, + pub vectorization: usize, + } + + impl CubeType for $type { + type ExpandType = ExpandElement; + } + + impl PrimitiveVariable for $type { + /// Note: all float types have f64 primitive on CPU to + /// ease casting. On GPU the type will be given by into_elem. + type Primitive = f64; + + /// Return the value of the float (on CPU) + fn val(&self) -> Self::Primitive { + self.val + } + + /// Return the element type to use on GPU + fn into_elem() -> Elem { + Elem::Float(FloatKind::$type) + } + } + + impl Float for $type { + fn from_primitive(val: f64) -> Self { + Self { + val, + vectorization: 1, + } + } + + fn from_primitive_expand( + _context: &mut CubeContext, + val: f64, + ) -> ::ExpandType { + let new_var = Variable::ConstantScalar(val, Self::into_elem()); + ExpandElement::new(Rc::new(new_var)) + } + } + + impl Numeric for $type { + // Method new takes an i64, because it is used when treating the float as numeric, + // which must be an int in the cube kernel because new numerics need to be supported by Int as well + fn lit(val: i64) -> Self { + Self::from_primitive(val as f64) + } + + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + ::from_primitive_expand(context, val as f64) + } + } + }; +} + +impl_float!(F16); +impl_float!(BF16); +impl_float!(F32); +impl_float!(F64); diff --git a/crates/burn-cube/src/element/int.rs b/crates/burn-cube/src/element/int.rs new file mode 100644 index 0000000000..355ef66131 --- /dev/null +++ b/crates/burn-cube/src/element/int.rs @@ -0,0 +1,64 @@ +use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; +use burn_jit::gpu::{Elem, IntKind, Variable}; +use std::rc::Rc; + +/// Signed integer. Used as input in int kernels +pub trait Int: Numeric + std::ops::Rem { + fn from_primitive(val: i64) -> Self; + fn from_primitive_expand(context: &mut CubeContext, val: i64) + -> ::ExpandType; +} + +macro_rules! impl_int { + ($type:ident) => { + #[derive(Clone, Copy)] + pub struct $type { + pub val: i64, + pub vectorization: usize, + } + + impl CubeType for $type { + type ExpandType = ExpandElement; + } + + impl PrimitiveVariable for $type { + type Primitive = i64; + fn val(&self) -> Self::Primitive { + self.val + } + fn into_elem() -> Elem { + Elem::Int(IntKind::$type) + } + } + + impl Int for $type { + fn from_primitive(val: i64) -> Self { + Self { + val, + vectorization: 1, + } + } + + fn from_primitive_expand( + _context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { + let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); + ExpandElement::new(Rc::new(new_var)) + } + } + + impl Numeric for $type { + fn lit(val: i64) -> Self { + Self::from_primitive(val) + } + + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + ::from_primitive_expand(context, val) + } + } + }; +} + +impl_int!(I32); +impl_int!(I64); diff --git a/crates/burn-cube/src/element/mod.rs b/crates/burn-cube/src/element/mod.rs new file mode 100644 index 0000000000..df4f96b39b --- /dev/null +++ b/crates/burn-cube/src/element/mod.rs @@ -0,0 +1,19 @@ +mod array; +mod base; +mod bool; +mod conversion; +mod float; +mod int; +mod numeric; +mod primitive; +mod static_type; +mod uint; + +pub use array::*; +pub use base::*; +pub use bool::*; +pub use float::*; +pub use int::*; +pub use numeric::*; +pub use primitive::*; +pub use uint::*; diff --git a/crates/burn-cube/src/element/numeric.rs b/crates/burn-cube/src/element/numeric.rs new file mode 100644 index 0000000000..aa59df6241 --- /dev/null +++ b/crates/burn-cube/src/element/numeric.rs @@ -0,0 +1,25 @@ +use crate::{CubeContext, CubeType, PrimitiveVariable}; + +/// Type that encompasses both (unsigned or signed) integers and floats +/// Used in kernels that should work for both. +pub trait Numeric: + Clone + + Copy + + PrimitiveVariable + + std::ops::Add + + std::ops::AddAssign + + std::ops::Sub + + std::ops::Mul + + std::ops::Div + + std::cmp::PartialOrd +{ + /// Create a new constant numeric. + /// + /// Note: since this must work for both integer and float + /// only the less expressive of both can be created (int) + /// If a number with decimals is needed, use Float::from_primitive. + fn lit(val: i64) -> Self; + + /// Expand version of new + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType; +} diff --git a/crates/burn-cube/src/element/primitive.rs b/crates/burn-cube/src/element/primitive.rs new file mode 100644 index 0000000000..0a42a2bf34 --- /dev/null +++ b/crates/burn-cube/src/element/primitive.rs @@ -0,0 +1,46 @@ +use std::rc::Rc; + +use burn_jit::gpu::{Elem, Item, Variable}; + +use crate::{assign, CubeContext, CubeType, ExpandElement}; + +/// Form of CubeType that encapsulates all primitive types: +/// Numeric, UInt, Bool +pub trait PrimitiveVariable: CubeType { + /// Type of the value kept CPU-side. + /// Does not necessarily match the GPU type. + type Primitive; + + /// Return the value of the float on CPU + fn val(&self) -> Self::Primitive; + + /// Return the element type to use on GPU + fn into_elem() -> Elem; + + /// Expand version of from, of the trait From + fn from_expand( + context: &mut CubeContext, + val: ExpandElement, + ) -> ::ExpandType { + let new_var = context.create_local(Item::Scalar(::into_elem())); + assign::expand(context, val, new_var.clone()); + new_var + } +} + +macro_rules! impl_into_expand_element { + ($type:ty) => { + impl From<$type> for ExpandElement { + fn from(value: $type) -> Self { + ExpandElement::new(Rc::new(Variable::from(value))) + } + } + }; +} + +impl_into_expand_element!(u32); +impl_into_expand_element!(usize); +impl_into_expand_element!(bool); +impl_into_expand_element!(f32); +impl_into_expand_element!(i32); +impl_into_expand_element!(i64); diff --git a/crates/burn-cube/src/element/static_type.rs b/crates/burn-cube/src/element/static_type.rs new file mode 100644 index 0000000000..ceb54cf953 --- /dev/null +++ b/crates/burn-cube/src/element/static_type.rs @@ -0,0 +1,22 @@ +use crate::CubeType; + +/// Types that exist within a cube function +/// but should not be turned into a JIT variable +/// For instance: a bool that determines at compile time +/// if we should unroll a for loop or not + +impl CubeType for bool { + type ExpandType = bool; +} + +impl CubeType for u32 { + type ExpandType = u32; +} + +impl CubeType for f32 { + type ExpandType = f32; +} + +impl CubeType for i32 { + type ExpandType = i32; +} diff --git a/crates/burn-cube/src/element/uint.rs b/crates/burn-cube/src/element/uint.rs new file mode 100644 index 0000000000..11e21f48e4 --- /dev/null +++ b/crates/burn-cube/src/element/uint.rs @@ -0,0 +1,63 @@ +use burn_jit::gpu::Elem; + +use crate::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; + +#[derive(Clone, Copy)] +/// An unsigned int. +/// Preferred for indexing operations +pub struct UInt { + pub val: ::Primitive, + pub vectorization: u8, +} + +impl UInt { + pub fn from_primitive(val: i64) -> Self { + Self { + val, + vectorization: 1, + } + } + + pub fn from_primitive_expand( + _context: &mut CubeContext, + val: i64, + ) -> ::ExpandType { + (val as u32).into() + } +} + +impl CubeType for UInt { + type ExpandType = ExpandElement; +} + +impl PrimitiveVariable for UInt { + type Primitive = i64; + fn val(&self) -> Self::Primitive { + self.val + } + fn into_elem() -> Elem { + Elem::UInt + } +} + +impl Numeric for UInt { + fn lit(val: i64) -> Self { + Self::from_primitive(val) + } + + fn lit_expand(context: &mut CubeContext, val: i64) -> ::ExpandType { + Self::from_primitive_expand(context, val) + } +} + +impl From for UInt { + fn from(value: u32) -> Self { + UInt::from_primitive(value as ::Primitive) + } +} + +impl From for UInt { + fn from(value: usize) -> Self { + UInt::from_primitive(value as ::Primitive) + } +} diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs new file mode 100644 index 0000000000..4009395922 --- /dev/null +++ b/crates/burn-cube/src/lib.rs @@ -0,0 +1,17 @@ +extern crate alloc; + +#[macro_use] +extern crate derive_new; + +// For use with * +pub mod branch; + +mod context; +mod element; +mod operation; + +pub use context::*; +pub use element::*; +pub use operation::*; + +pub use burn_cube_macros::cube; diff --git a/crates/burn-cube/src/operation/assignation.rs b/crates/burn-cube/src/operation/assignation.rs new file mode 100644 index 0000000000..4c768354c4 --- /dev/null +++ b/crates/burn-cube/src/operation/assignation.rs @@ -0,0 +1,94 @@ +use crate::{Array, CubeContext, ExpandElement, UInt}; +use burn_jit::gpu::{self}; + +pub mod assign { + use super::*; + + pub fn expand(context: &mut CubeContext, input: ExpandElement, output: ExpandElement) { + let input = *input; + let out = *output; + + context.register(gpu::Operator::Assign(gpu::UnaryOperator { input, out })); + } +} + +pub mod index_assign { + use crate::CubeType; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + array: ExpandElement, + index: ExpandElement, + value: ExpandElement, + ) { + context.register(gpu::Operator::IndexAssign(gpu::BinaryOperator { + lhs: *index, + rhs: *value, + out: *array, + })) + } + + impl> core::ops::IndexMut for Array { + fn index_mut(&mut self, index: I) -> &mut Self::Output { + let index = index.into().val; + &mut self.vals[index as usize] + } + } +} + +pub mod index { + use crate::{operation::base::binary_expand, CubeType}; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + array: ExpandElement, + index: ExpandElement, + ) -> ExpandElement { + binary_expand(context, array, index, gpu::Operator::Index) + } + + impl> core::ops::Index for Array { + type Output = E; + + fn index(&self, index: I) -> &Self::Output { + let index = index.into().val; + &self.vals[index as usize] + } + } +} + +pub mod add_assign_op { + use crate::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + assign_op_expand(context, lhs, rhs, gpu::Operator::Add) + } + + macro_rules! impl_add_assign { + ($type:ty) => { + impl core::ops::AddAssign for $type { + fn add_assign(&mut self, rhs: Self) { + self.val += rhs.val + } + } + }; + } + + impl_add_assign!(F16); + impl_add_assign!(BF16); + impl_add_assign!(F32); + impl_add_assign!(F64); + impl_add_assign!(I32); + impl_add_assign!(I64); + impl_add_assign!(UInt); +} diff --git a/crates/burn-cube/src/operation/base.rs b/crates/burn-cube/src/operation/base.rs new file mode 100644 index 0000000000..96a921d73f --- /dev/null +++ b/crates/burn-cube/src/operation/base.rs @@ -0,0 +1,84 @@ +use crate::{CubeContext, ExpandElement}; +use burn_jit::gpu::{self, Elem, Variable}; + +pub(crate) fn binary_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs: Variable = *lhs; + let rhs: Variable = *rhs; + + let item = lhs.item(); + let out = context.create_local(item); + let out_var = *out; + + let op = func(gpu::BinaryOperator { + lhs, + rhs, + out: out_var, + }); + + context.register(op); + + out +} + +pub(crate) fn cmp_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs: Variable = *lhs; + let rhs: Variable = *rhs; + + let out_item = match lhs.item() { + gpu::Item::Vec4(_) => gpu::Item::Vec4(Elem::Bool), + gpu::Item::Vec3(_) => gpu::Item::Vec3(Elem::Bool), + gpu::Item::Vec2(_) => gpu::Item::Vec2(Elem::Bool), + gpu::Item::Scalar(_) => gpu::Item::Scalar(Elem::Bool), + }; + let out = context.create_local(out_item); + let out_var = *out; + + let op = func(gpu::BinaryOperator { + lhs, + rhs, + out: out_var, + }); + + context.register(op); + + out +} + +pub(crate) fn assign_op_expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + func: F, +) -> ExpandElement +where + F: Fn(gpu::BinaryOperator) -> gpu::Operator, +{ + let lhs_var: Variable = *lhs; + let rhs: Variable = *rhs; + + let op = func(gpu::BinaryOperator { + lhs: lhs_var, + rhs, + out: lhs_var, + }); + + context.register(op); + + lhs +} diff --git a/crates/burn-cube/src/operation/binary.rs b/crates/burn-cube/src/operation/binary.rs new file mode 100644 index 0000000000..70cd1f6bc3 --- /dev/null +++ b/crates/burn-cube/src/operation/binary.rs @@ -0,0 +1,254 @@ +use crate::operation::base::binary_expand; +use crate::{CubeContext, ExpandElement, Float, Int, UInt, BF16, F16, F32, F64, I32, I64}; +use burn_jit::gpu::{self}; + +pub mod add { + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Add) + } + + macro_rules! impl_add { + ($type:ty, $trait:ty) => { + impl core::ops::Add for $type { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + <$type as $trait>::from_primitive(self.val + rhs.val) + } + } + }; + + ($type:ty) => { + impl core::ops::Add for $type { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + <$type>::from_primitive(self.val + rhs.val) + } + } + }; + } + + impl_add!(F16, Float); + impl_add!(BF16, Float); + impl_add!(F32, Float); + impl_add!(F64, Float); + impl_add!(I32, Int); + impl_add!(I64, Int); + impl_add!(UInt); +} + +pub mod sub { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Sub) + } + + macro_rules! impl_sub { + ($type:ty, $trait:ty) => { + impl core::ops::Sub for $type { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + <$type as $trait>::from_primitive(self.val - rhs.val) + } + } + }; + + ($type:ty) => { + impl core::ops::Sub for $type { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + <$type>::from_primitive(self.val - rhs.val) + } + } + }; + } + + impl_sub!(F16, Float); + impl_sub!(BF16, Float); + impl_sub!(F32, Float); + impl_sub!(F64, Float); + impl_sub!(I32, Int); + impl_sub!(I64, Int); + impl_sub!(UInt); +} + +pub mod mul { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Mul) + } + + macro_rules! impl_mul { + ($type:ty, $trait:ty) => { + impl core::ops::Mul for $type { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + <$type as $trait>::from_primitive(self.val * rhs.val) + } + } + }; + + ($type:ty) => { + impl core::ops::Mul for $type { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + <$type>::from_primitive(self.val * rhs.val) + } + } + }; + } + + impl_mul!(F16, Float); + impl_mul!(BF16, Float); + impl_mul!(F32, Float); + impl_mul!(F64, Float); + impl_mul!(I32, Int); + impl_mul!(I64, Int); + impl_mul!(UInt); +} + +pub mod div { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Div) + } + + macro_rules! impl_div { + ($type:ty, $trait:ty) => { + impl core::ops::Div for $type { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + <$type as $trait>::from_primitive(self.val / rhs.val) + } + } + }; + + ($type:ty) => { + impl core::ops::Div for $type { + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + <$type>::from_primitive(self.val / rhs.val) + } + } + }; + } + + impl_div!(F16, Float); + impl_div!(BF16, Float); + impl_div!(F32, Float); + impl_div!(F64, Float); + impl_div!(I32, Int); + impl_div!(I64, Int); + impl_div!(UInt); +} + +pub mod rem { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Modulo) + } + + macro_rules! impl_rem { + ($type:ty, $trait:ty) => { + impl core::ops::Rem for $type { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + <$type as $trait>::from_primitive(self.val % rhs.val) + } + } + }; + + ($type:ty) => { + impl core::ops::Rem for $type { + type Output = Self; + + fn rem(self, rhs: Self) -> Self::Output { + <$type>::from_primitive(self.val % rhs.val) + } + } + }; + } + + impl_rem!(I32, Int); + impl_rem!(I64, Int); + impl_rem!(UInt); +} + +pub mod and { + use crate::Bool; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::And) + } + + impl core::ops::BitAnd for Bool { + type Output = Bool; + + fn bitand(self, rhs: Self) -> Self::Output { + Bool::lit(self.val && rhs.val) + } + } +} + +pub mod or { + use crate::Bool; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + binary_expand(context, lhs, rhs, gpu::Operator::Or) + } + + impl core::ops::BitOr for Bool { + type Output = Bool; + + fn bitor(self, rhs: Self) -> Self::Output { + Bool::lit(self.val || rhs.val) + } + } +} diff --git a/crates/burn-cube/src/operation/cmp.rs b/crates/burn-cube/src/operation/cmp.rs new file mode 100644 index 0000000000..169f21bcca --- /dev/null +++ b/crates/burn-cube/src/operation/cmp.rs @@ -0,0 +1,80 @@ +use crate::operation::base::cmp_expand; +use crate::{CubeContext, ExpandElement, UInt, BF16, F16, F32, F64, I32, I64}; +use burn_jit::gpu::{self}; + +macro_rules! impl_cmp { + ($type:ty) => { + impl core::cmp::PartialEq for $type { + fn eq(&self, other: &Self) -> bool { + self.val == other.val && self.vectorization == other.vectorization + } + } + impl core::cmp::Eq for $type {} + impl core::cmp::PartialOrd for $type { + fn partial_cmp(&self, other: &Self) -> Option { + match self.val.partial_cmp(&other.val) { + Some(core::cmp::Ordering::Equal) => {} + ord => return ord, + } + self.vectorization.partial_cmp(&other.vectorization) + } + } + }; +} + +impl_cmp!(F16); +impl_cmp!(BF16); +impl_cmp!(F32); +impl_cmp!(F64); +impl_cmp!(I32); +impl_cmp!(I64); +impl_cmp!(UInt); + +pub mod ne { + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::NotEqual) + } +} + +pub mod gt { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::Greater) + } +} + +pub mod lt { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::Lower) + } +} + +pub mod add_assign { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElement, + rhs: ExpandElement, + ) -> ExpandElement { + cmp_expand(context, lhs, rhs, gpu::Operator::Add) + } +} diff --git a/crates/burn-cube/src/operation/mod.rs b/crates/burn-cube/src/operation/mod.rs new file mode 100644 index 0000000000..db2470087f --- /dev/null +++ b/crates/burn-cube/src/operation/mod.rs @@ -0,0 +1,8 @@ +mod assignation; +mod base; +mod binary; +mod cmp; + +pub use assignation::*; +pub use binary::*; +pub use cmp::*; diff --git a/crates/burn-cube/tests/cast_elem.rs b/crates/burn-cube/tests/cast_elem.rs new file mode 100644 index 0000000000..4b9196951d --- /dev/null +++ b/crates/burn-cube/tests/cast_elem.rs @@ -0,0 +1,310 @@ +use burn_cube::{cube, Bool, CubeContext, Numeric, PrimitiveVariable, UInt, F32, I32}; +use burn_jit::{ + gpu, + gpu::{Elem, Item, Variable}, +}; + +macro_rules! cast_test { + ($name:ident, $module:ident, $from:expr, $to:expr) => { + #[test] + fn $name() { + let mut context = CubeContext::root(); + + let x = context.create_local($from); + + $module(&mut context, x); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_cast($from, $to) + ); + } + }; + + ($name:ident, $module:ident, $ty:expr) => { + #[test] + fn $name() { + let mut context = CubeContext::root(); + + let x = context.create_local($ty); + + $module(&mut context, x); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_identity($ty) + ); + } + }; +} + +// From float +#[cube] +#[allow(clippy::useless_conversion)] +pub fn float_to_float(x: F32) { + let y = x + F32::lit(2); + let _ = F32::from(y) + F32::lit(34); +} + +#[cube] +pub fn float_to_int(x: F32) { + let y = x + F32::lit(2); + let _ = I32::from(y) + I32::lit(34); +} + +#[cube] +pub fn float_to_uint(x: F32) { + let y = x + F32::lit(2); + let _ = UInt::from(y) + UInt::lit(34); +} + +#[cube] +pub fn float_to_bool(x: F32) { + let y = x + F32::lit(2); + let _ = Bool::from(y) | Bool::lit(true); +} + +cast_test!( + cube_float_to_float_test, + float_to_float_expand, + Item::Scalar(F32::into_elem()) +); + +cast_test!( + cube_float_to_int_test, + float_to_int_expand, + Item::Scalar(F32::into_elem()), + Item::Scalar(I32::into_elem()) +); + +cast_test!( + cube_float_to_uint_test, + float_to_uint_expand, + Item::Scalar(F32::into_elem()), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_float_to_bool_test, + float_to_bool_expand, + Item::Scalar(F32::into_elem()), + Item::Scalar(Elem::Bool) +); + +// // From int +#[cube] +pub fn int_to_float(x: I32) { + let y = x + I32::lit(2); + let _ = F32::from(y) + F32::lit(34); +} + +#[cube] +#[allow(clippy::useless_conversion)] +pub fn int_to_int(x: I32) { + let y = x + I32::lit(2); + let _ = I32::from(y) + I32::lit(34); +} + +#[cube] +pub fn int_to_uint(x: I32) { + let y = x + I32::lit(2); + let _ = UInt::from(y) + UInt::lit(34); +} + +#[cube] +pub fn int_to_bool(x: I32) { + let y = x + I32::lit(2); + let _ = Bool::from(y) | Bool::lit(true); +} + +cast_test!( + cube_int_to_float_test, + int_to_float_expand, + Item::Scalar(I32::into_elem()), + Item::Scalar(F32::into_elem()) +); + +cast_test!( + cube_int_to_int_test, + int_to_int_expand, + Item::Scalar(I32::into_elem()) +); + +cast_test!( + cube_int_to_uint_test, + int_to_uint_expand, + Item::Scalar(I32::into_elem()), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_int_to_bool_test, + int_to_bool_expand, + Item::Scalar(I32::into_elem()), + Item::Scalar(Elem::Bool) +); + +// // From uint +#[cube] +pub fn uint_to_float(x: UInt) { + let y = x + UInt::lit(2); + let _ = F32::from(y) + F32::lit(34); +} + +#[cube] +pub fn uint_to_int(x: UInt) { + let y = x + UInt::lit(2); + let _ = I32::from(y) + I32::lit(34); +} + +#[cube] +#[allow(clippy::useless_conversion)] +pub fn uint_to_uint(x: UInt) { + let y = x + UInt::lit(2); + let _ = UInt::from(y) + UInt::lit(34); +} + +#[cube] +pub fn uint_to_bool(x: UInt) { + let y = x + UInt::lit(2); + let _ = Bool::from(y) | Bool::lit(true); +} + +cast_test!( + cube_uint_to_float_test, + uint_to_float_expand, + Item::Scalar(Elem::UInt), + Item::Scalar(F32::into_elem()) +); + +cast_test!( + cube_uint_to_int_test, + uint_to_int_expand, + Item::Scalar(Elem::UInt), + Item::Scalar(I32::into_elem()) +); + +cast_test!( + cube_uint_to_uint_test, + uint_to_uint_expand, + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_uint_to_bool_test, + uint_to_bool_expand, + Item::Scalar(Elem::UInt), + Item::Scalar(Elem::Bool) +); + +// From bool +#[cube] +pub fn bool_to_float(x: Bool) { + let y = x & Bool::lit(false); + let _ = F32::from(y) + F32::lit(34); +} + +#[cube] +pub fn bool_to_int(x: Bool) { + let y = x & Bool::lit(false); + let _ = I32::from(y) + I32::lit(34); +} + +#[cube] +pub fn bool_to_uint(x: Bool) { + let y = x & Bool::lit(false); + let _ = UInt::from(y) + UInt::lit(34); +} + +#[cube] +#[allow(clippy::useless_conversion)] +pub fn bool_to_bool(x: Bool) { + let y = x & Bool::lit(false); + let _ = Bool::from(y) | Bool::lit(true); +} + +cast_test!( + cube_bool_to_float_test, + bool_to_float_expand, + Item::Scalar(Elem::Bool), + Item::Scalar(F32::into_elem()) +); + +cast_test!( + cube_bool_to_int_test, + bool_to_int_expand, + Item::Scalar(Elem::Bool), + Item::Scalar(I32::into_elem()) +); + +cast_test!( + cube_bool_to_uint_test, + bool_to_uint_expand, + Item::Scalar(Elem::Bool), + Item::Scalar(Elem::UInt) +); + +cast_test!( + cube_bool_to_bool_test, + bool_to_bool_expand, + Item::Scalar(Elem::Bool) +); + +fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String { + let mut context = CubeContext::root(); + let x = context.create_local(from_item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y = scope.create_local(from_item); + let y_casted = scope.create_local(to_item); + let z = scope.create_local(to_item); + + match from_item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 2f32), + Elem::Int(_) => gpu!(scope, y = x + 2i32), + Elem::UInt => gpu!(scope, y = x + 2u32), + Elem::Bool => gpu!(scope, y = x && false), + } + + gpu!(scope, y_casted = cast(y)); + + match to_item.elem() { + Elem::Float(_) => gpu!(scope, z = y_casted + 34f32), + Elem::Int(_) => gpu!(scope, z = y_casted + 34i32), + Elem::UInt => gpu!(scope, z = y_casted + 34u32), + Elem::Bool => gpu!(scope, z = y_casted || true), + } + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_identity(item: Item) -> String { + // When staying with the same type variables are automatically reused in cube + let mut context = CubeContext::root(); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y = scope.create_local(item); + + match item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 2f32), + Elem::Int(_) => gpu!(scope, y = x + 2i32), + Elem::UInt => gpu!(scope, y = x + 2u32), + Elem::Bool => gpu!(scope, y = x && false), + } + + gpu!(scope, x = cast(y)); + + match item.elem() { + Elem::Float(_) => gpu!(scope, y = x + 34f32), + Elem::Int(_) => gpu!(scope, y = x + 34i32), + Elem::UInt => gpu!(scope, y = x + 34u32), + Elem::Bool => gpu!(scope, y = x || true), + } + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/cast_kind.rs b/crates/burn-cube/tests/cast_kind.rs new file mode 100644 index 0000000000..00d90dd1b4 --- /dev/null +++ b/crates/burn-cube/tests/cast_kind.rs @@ -0,0 +1,99 @@ +use burn_cube::{cube, CubeContext, Float, Int, Numeric, PrimitiveVariable, F32, F64, I32, I64}; +use burn_jit::{gpu, gpu::Item}; + +#[cube] +pub fn cast_float_kind>(input: F1) { + let x = input + F1::from_primitive(5.9); + let y = F2::from(x); + let _ = y + F2::from_primitive(2.3); +} + +#[cube] +pub fn cast_int_kind>(input: I1) { + let x = input + I1::from_primitive(5); + let y = I2::from(x); + let _ = y + I2::from_primitive(2); +} + +#[cube] +pub fn cast_numeric_to_kind>(input: T) { + let x = input + T::lit(5); + let y = I2::from(x); + let _ = y + I2::lit(2); +} + +#[test] +fn cube_cast_float_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(F64::into_elem()); + + let input = context.create_local(item); + + // F16 not testable with the gpu macro, but should work the same + cast_float_kind_expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); +} + +#[test] +fn cube_cast_int_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(I32::into_elem()); + + let input = context.create_local(item); + + cast_int_kind_expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); +} + +#[test] +fn cube_cast_numeric_kind_test() { + let mut context = CubeContext::root(); + let item = Item::Scalar(I32::into_elem()); + + let input = context.create_local(item); + + cast_numeric_to_kind_expand::(&mut context, input); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); +} + +fn inline_macro_ref_float() -> String { + let mut context = CubeContext::root(); + let float_64 = Item::Scalar(F64::into_elem()); + let float_32 = Item::Scalar(F32::into_elem()); + let input = context.create_local(float_64); + + let mut scope = context.into_scope(); + let x = scope.create_local(float_64); + let y = scope.create_local(float_32); + let z = scope.create_local(float_32); + + gpu!(scope, x = input + 5.9f32 as f64); + gpu!(scope, y = cast(x)); + gpu!(scope, z = y + 2.3f32); + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_int() -> String { + let mut context = CubeContext::root(); + let int_32 = Item::Scalar(I32::into_elem()); + let int_64 = Item::Scalar(I64::into_elem()); + let input = context.create_local(int_32); + + let mut scope = context.into_scope(); + let x = scope.create_local(int_32); + let y = scope.create_local(int_64); + let z = scope.create_local(int_64); + + gpu!(scope, x = input + 5i32); + gpu!(scope, y = cast(x)); + gpu!(scope, z = y + 2i64); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/for_loop.rs b/crates/burn-cube/tests/for_loop.rs new file mode 100644 index 0000000000..9ee7b0abcd --- /dev/null +++ b/crates/burn-cube/tests/for_loop.rs @@ -0,0 +1,76 @@ +use burn_cube::{branch::*, cube, Array, CubeContext, Float, PrimitiveVariable, UInt, F32}; +use burn_jit::{ + gpu, + gpu::{Item, Variable}, +}; + +type ElemType = F32; + +#[cube] +pub fn for_loop(mut lhs: Array, rhs: F, end: UInt, unroll: bool) { + let tmp1 = rhs * rhs; + let tmp2 = tmp1 + rhs; + + for i in range(0u32, end, unroll) { + lhs[i] = tmp2 + lhs[i]; + } +} + +#[test] +fn test_for_loop_with_unroll() { + let mut context = CubeContext::root(); + let unroll = true; + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let rhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let end = 4u32.into(); + + for_loop_expand::(&mut context, lhs, rhs, end, unroll); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); +} + +#[test] +fn test_for_loop_no_unroll() { + let mut context = CubeContext::root(); + let unroll = false; + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let rhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let end = 4u32.into(); + + for_loop_expand::(&mut context, lhs, rhs, end, unroll); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); +} + +fn inline_macro_ref(unroll: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + + let lhs = context.create_local(item); + let rhs = context.create_local(item); + let lhs: Variable = lhs.into(); + let rhs: Variable = rhs.into(); + let end = 4u32; + let mut scope = context.into_scope(); + + // Kernel + let tmp1 = scope.create_local(item); + let tmp2 = scope.create_local(item); + gpu!(scope, tmp1 = rhs * rhs); + gpu!(scope, tmp2 = tmp1 + rhs); + + gpu!( + &mut scope, + range(0u32, end, unroll).for_each(|i, scope| { + gpu!(scope, rhs = lhs[i]); + gpu!(scope, tmp1 = tmp2 + rhs); + gpu!(scope, lhs[i] = tmp1); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/function_call.rs b/crates/burn-cube/tests/function_call.rs new file mode 100644 index 0000000000..9540059669 --- /dev/null +++ b/crates/burn-cube/tests/function_call.rs @@ -0,0 +1,103 @@ +use burn_cube::{cube, CubeContext, Numeric, PrimitiveVariable, UInt, I64}; +use burn_jit::gpu::{Elem, Item}; + +#[cube] +pub fn caller_no_arg(x: UInt) { + let _ = x + callee_no_arg(); +} + +#[cube] +pub fn callee_no_arg() -> UInt { + UInt::lit(8) +} + +#[cube] +pub fn no_call_no_arg(x: UInt) { + let _ = x + UInt::lit(8); +} + +#[cube] +pub fn caller_with_arg(x: UInt) { + let _ = x + callee_with_arg(x); +} + +#[cube] +pub fn callee_with_arg(x: UInt) -> UInt { + x * UInt::lit(8) +} + +#[cube] +pub fn no_call_with_arg(x: UInt) { + let _ = x + x * UInt::lit(8); +} + +#[cube] +pub fn caller_with_generics(x: T) { + let _ = x + callee_with_generics::(x); +} + +#[cube] +pub fn callee_with_generics(x: T) -> T { + x * T::lit(8) +} + +#[cube] +pub fn no_call_with_generics(x: T) { + let _ = x + x * T::lit(8); +} + +#[test] +fn cube_call_equivalent_to_no_call_no_arg_test() { + let mut caller_context = CubeContext::root(); + let x = caller_context.create_local(Item::Scalar(Elem::UInt)); + caller_no_arg_expand(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(Elem::UInt)); + no_call_no_arg_expand(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} + +#[test] +fn cube_call_equivalent_to_no_call_with_arg_test() { + let mut caller_context = CubeContext::root(); + + let x = caller_context.create_local(Item::Scalar(Elem::UInt)); + caller_with_arg_expand(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(Elem::UInt)); + no_call_with_arg_expand(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} + +#[test] +fn cube_call_equivalent_to_no_call_with_generics_test() { + let mut caller_context = CubeContext::root(); + type ElemType = I64; + let x = caller_context.create_local(Item::Scalar(ElemType::into_elem())); + caller_with_generics_expand::(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem())); + no_call_with_generics_expand::(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} diff --git a/crates/burn-cube/tests/generic_kernel.rs b/crates/burn-cube/tests/generic_kernel.rs new file mode 100644 index 0000000000..f555ec7938 --- /dev/null +++ b/crates/burn-cube/tests/generic_kernel.rs @@ -0,0 +1,55 @@ +use burn_cube::{cube, CubeContext, Numeric, PrimitiveVariable, F32, I32}; +use burn_jit::{gpu, gpu::Item}; + +#[cube] +pub fn generic_kernel(lhs: T) { + let _ = lhs + T::lit(5); +} + +#[test] +fn cube_generic_float_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(F32::into_elem())); + + generic_kernel_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); +} + +#[test] +fn cube_generic_int_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(I32::into_elem())); + + generic_kernel_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); +} + +fn inline_macro_ref_float() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(F32::into_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let out = scope.create_local(item); + gpu!(scope, out = lhs + 5.0f32); + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_int() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(I32::into_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let out = scope.create_local(item); + gpu!(scope, out = lhs + 5); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/if.rs b/crates/burn-cube/tests/if.rs new file mode 100644 index 0000000000..7f193182a8 --- /dev/null +++ b/crates/burn-cube/tests/if.rs @@ -0,0 +1,44 @@ +use burn_cube::{branch::*, cube, CubeContext, Numeric, PrimitiveVariable, F32}; +use burn_jit::{ + gpu, + gpu::{Elem, Item, Variable}, +}; + +type ElemType = F32; + +#[cube] +pub fn if_greater(lhs: T) { + if lhs > T::lit(0) { + let _ = lhs + T::lit(4); + } +} + +#[test] +fn cube_if_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + if_greater_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let y = scope.create_local(item); + + gpu!(scope, cond = lhs > 0f32); + gpu!(&mut scope, if(cond).then(|scope| { + gpu!(scope, y = lhs + 4.0f32); + })); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/if_else.rs b/crates/burn-cube/tests/if_else.rs new file mode 100644 index 0000000000..ba121f9f4a --- /dev/null +++ b/crates/burn-cube/tests/if_else.rs @@ -0,0 +1,48 @@ +use burn_cube::{branch::*, cube, CubeContext, Float, PrimitiveVariable, F32}; +use burn_jit::{ + gpu, + gpu::{Elem, Item, Variable}, +}; + +type ElemType = F32; + +#[cube] +pub fn if_then_else(lhs: F) { + if lhs < F::lit(0) { + let _ = lhs + F::lit(4); + } else { + let _ = lhs - F::lit(5); + } +} + +#[test] +fn cube_if_else_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + if_then_else_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let y = scope.create_local(item); + + gpu!(scope, cond = lhs < 0f32); + gpu!(&mut scope, if(cond).then(|scope| { + gpu!(scope, y = lhs + 4.0f32); + }).else(|scope|{ + gpu!(scope, y = lhs - 5.0f32); + })); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/literal.rs b/crates/burn-cube/tests/literal.rs new file mode 100644 index 0000000000..39778fef9d --- /dev/null +++ b/crates/burn-cube/tests/literal.rs @@ -0,0 +1,50 @@ +use burn_cube::{cube, CubeContext, Float, PrimitiveVariable, F32}; +use burn_jit::{gpu, gpu::Item}; + +type ElemType = F32; + +#[cube] +pub fn literal(lhs: F) { + let _ = lhs + F::lit(5); +} + +#[cube] +pub fn literal_float_no_decimals(lhs: F) { + let _ = lhs + F::from_primitive(5.); +} + +#[test] +fn cube_literal_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + literal_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +#[test] +fn cube_literal_float_no_decimal_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + literal_float_no_decimals_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let out = scope.create_local(item); + gpu!(scope, out = lhs + 5.0f32); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/loop.rs b/crates/burn-cube/tests/loop.rs new file mode 100644 index 0000000000..b3530bac2f --- /dev/null +++ b/crates/burn-cube/tests/loop.rs @@ -0,0 +1,72 @@ +use burn_cube::{branch::*, cube, CubeContext, Int, PrimitiveVariable, I32}; +use burn_jit::gpu; +use burn_jit::gpu::Branch; +use burn_jit::gpu::{Elem, Item, Variable}; + +type ElemType = I32; + +#[cube] +pub fn while_not(lhs: I) { + while lhs != I::lit(0) { + let _ = lhs % I::lit(1); + } +} + +#[cube] +pub fn manual_loop_break(lhs: I) { + loop { + if lhs != I::lit(0) { + break; + } + let _ = lhs % I::lit(1); + } +} + +#[test] +fn cube_while_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + while_not_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +#[test] +fn cube_loop_break_test() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + + manual_loop_break_expand::(&mut context, lhs); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let lhs = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let lhs: Variable = lhs.into(); + let rhs = scope.create_local(item); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = lhs != 0); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, rhs = lhs % 1i32); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/module_import.rs b/crates/burn-cube/tests/module_import.rs new file mode 100644 index 0000000000..04633df163 --- /dev/null +++ b/crates/burn-cube/tests/module_import.rs @@ -0,0 +1,47 @@ +use burn_cube::{CubeContext, PrimitiveVariable, F32}; +use burn_jit::gpu::Item; + +type ElemType = F32; + +mod elsewhere { + use burn_cube::{cube, Float}; + + #[cube] + pub fn my_func(x: F) -> F { + x * F::lit(2) + } +} + +mod here { + use burn_cube::{cube, Float}; + + use crate::elsewhere; + + #[cube] + pub fn caller(x: F) { + let _ = x + elsewhere::my_func::(x); + } + + #[cube] + pub fn no_call_ref(x: F) { + let _ = x + x * F::lit(2); + } +} + +#[test] +fn cube_call_equivalent_to_no_call_no_arg_test() { + let mut caller_context = CubeContext::root(); + let x = caller_context.create_local(Item::Scalar(ElemType::into_elem())); + here::caller_expand::(&mut caller_context, x); + let caller_scope = caller_context.into_scope(); + + let mut no_call_context = CubeContext::root(); + let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem())); + here::no_call_ref_expand::(&mut no_call_context, x); + let no_call_scope = no_call_context.into_scope(); + + assert_eq!( + format!("{:?}", caller_scope.operations), + format!("{:?}", no_call_scope.operations) + ); +} diff --git a/crates/burn-cube/tests/parenthesis.rs b/crates/burn-cube/tests/parenthesis.rs new file mode 100644 index 0000000000..7a6b2049db --- /dev/null +++ b/crates/burn-cube/tests/parenthesis.rs @@ -0,0 +1,45 @@ +use burn_cube::{cube, CubeContext, Numeric, PrimitiveVariable, F32}; +use burn_jit::{ + gpu, + gpu::{Item, Variable}, +}; + +type ElemType = F32; + +#[cube] +pub fn parenthesis(x: T, y: T, z: T) -> T { + x * (y + z) +} + +#[test] +fn cube_parenthesis_priority_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + let z = context.create_local(Item::Scalar(ElemType::into_elem())); + + parenthesis_expand::(&mut context, x, y, z); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); +} + +fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + let y = context.create_local(item); + let z = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y: Variable = y.into(); + let z: Variable = z.into(); + let tmp = scope.create_local(item); + + gpu!(scope, tmp = y + z); + gpu!(scope, y = x * tmp); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/reuse.rs b/crates/burn-cube/tests/reuse.rs new file mode 100644 index 0000000000..f22d401ce3 --- /dev/null +++ b/crates/burn-cube/tests/reuse.rs @@ -0,0 +1,99 @@ +use burn_cube::{branch::*, cube, CubeContext, Int, PrimitiveVariable, I32}; +use burn_jit::{ + gpu, + gpu::{Branch, Elem, Item, Variable}, +}; + +type ElemType = I32; + +#[cube] +#[allow(clippy::assign_op_pattern)] +pub fn reuse(mut x: I) { + // a += b is more efficient than a = a + b + // Because the latter does not assume that a is the same in lhs and rhs + // Normally clippy should detect it + while x < I::lit(10) { + x = x + I::lit(1); + } +} + +#[cube] +pub fn reuse_incr(mut x: I) { + while x < I::lit(10) { + x += I::lit(1); + } +} + +#[test] +fn cube_reuse_assign_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + + reuse_expand::(&mut context, x); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); +} + +#[test] +fn cube_reuse_incr_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + + reuse_incr_expand::(&mut context, x); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); +} + +fn inline_macro_ref_assign() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let x: Variable = x.into(); + let tmp = scope.create_local(item); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = x < 10); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, tmp = x + 1); + gpu!(scope, x = tmp); + }) + ); + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_incr() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let x: Variable = x.into(); + + gpu!( + &mut scope, + loop(|scope| { + gpu!(scope, cond = x < 10); + gpu!(scope, if(cond).then(|scope|{ + scope.register(Branch::Break); + })); + + gpu!(scope, x = x + 1); + }) + ); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-cube/tests/trait.rs b/crates/burn-cube/tests/trait.rs new file mode 100644 index 0000000000..bd8ad9dfc4 --- /dev/null +++ b/crates/burn-cube/tests/trait.rs @@ -0,0 +1,201 @@ +use burn_cube::{cube, CubeContext, CubeType, Float, Numeric, PrimitiveVariable, F32}; +use burn_jit::{ + gpu, + gpu::{Item, Variable}, +}; + +type ElemType = F32; + +/// Traits used in Cube kernels must expose an _expand variant +/// for all their methods. However, one does not need to provide its +/// implementation, see examples below. +trait Strategy { + fn operation(input_1: T, input_2: T) -> T; + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType; +} + +struct AddStrategy; + +#[cube] +/// The actual implementation of AddStrategy's operation +/// Automatically generated an _expand variant +fn add_strategy_operation(input_1: T, input_2: T) -> T { + input_1 + input_2 +} + +impl Strategy for AddStrategy { + /// Here we link the trait's method to the cube function + fn operation(input_1: T, input_2: T) -> T { + add_strategy_operation(input_1, input_2) + } + + /// Here we link the trait's expanded method to the cube expanded function + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType { + add_strategy_operation_expand::(context, input_1, input_2) + } +} + +struct SubStrategy; + +#[cube] +fn sub_strategy_operation(input_1: T, input_2: T) -> T { + input_1 - input_2 +} + +impl Strategy for SubStrategy { + fn operation(input_1: T, input_2: T) -> T { + sub_strategy_operation(input_1, input_2) + } + + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType { + sub_strategy_operation_expand::(context, input_1, input_2) + } +} + +#[cube] +fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { + S::operation(x, y) +} + +#[cube] +fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { + let z = S1::operation(x, y); + S2::operation(z, y) +} + +trait MethodTypedStrategy { + fn operation(input_1: T, input_2: T) -> T; + fn operation_expand( + _context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType; +} + +impl MethodTypedStrategy for AddStrategy { + fn operation(input_1: T, input_2: T) -> T { + add_strategy_operation(input_1, input_2) + } + + fn operation_expand( + context: &mut CubeContext, + input_1: ::ExpandType, + input_2: ::ExpandType, + ) -> ::ExpandType { + add_strategy_operation_expand::(context, input_1, input_2) + } +} + +#[cube] +fn with_trait_generic_method(x: T, y: T) -> T { + S::operation::(x, y) +} + +#[test] +fn cube_strategy_trait_add_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + with_strategy_trait_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_one(true) + ); +} + +#[test] +fn cube_strategy_trait_sub_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + with_strategy_trait_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_one(false) + ); +} + +#[test] +fn cube_two_strategy_traits_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + two_strategy_traits_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two()); +} + +#[test] +fn cube_trait_generic_method_test() { + let mut context = CubeContext::root(); + + let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let y = context.create_local(Item::Scalar(ElemType::into_elem())); + + with_trait_generic_method_expand::(&mut context, x, y); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_one(true) + ); +} + +fn inline_macro_ref_one(is_add_strategy: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + let y = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y: Variable = y.into(); + let tmp = scope.create_local(item); + + match is_add_strategy { + true => gpu!(scope, tmp = x + y), + false => gpu!(scope, tmp = x - y), + } + + format!("{:?}", scope.operations) +} + +fn inline_macro_ref_two() -> String { + let mut context = CubeContext::root(); + let item = Item::Scalar(ElemType::into_elem()); + let x = context.create_local(item); + let y = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + let y: Variable = y.into(); + let tmp = scope.create_local(item); + + gpu!(scope, tmp = x - y); + gpu!(scope, x = tmp + y); + + format!("{:?}", scope.operations) +} diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index ecb095f356..bacd7510fe 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -1,5 +1,7 @@ use super::Variable; +#[macro_export(local_inner_macros)] +/// Macro for generating JIT intermediate representation, in a concise way macro_rules! gpu { // out = lhs + rhs ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => { @@ -406,12 +408,24 @@ impl From for Variable { } } +impl From for Variable { + fn from(value: i64) -> Self { + Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I64)) + } +} + impl From for Variable { fn from(value: f32) -> Self { Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32)) } } +impl From for Variable { + fn from(value: f64) -> Self { + Self::ConstantScalar(value, super::Elem::Float(super::FloatKind::F64)) + } +} + impl From for Variable { fn from(value: u32) -> Self { Self::ConstantScalar(value as f64, super::Elem::UInt) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs index 585c59a601..1bfada540a 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/scope.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/scope.rs @@ -44,7 +44,7 @@ impl Scope { /// [compute shader](crate::codegen::dialect::gpu::ComputeShader). /// /// A local scope can be created with the [child](Self::child) method. - pub(crate) fn root() -> Self { + pub fn root() -> Self { Self { depth: 0, operations: Vec::new(), @@ -69,7 +69,7 @@ impl Scope { } /// Create a variable initialized at some value. - pub(crate) fn create_with_value + Copy>( + pub fn create_with_value + Copy>( &mut self, value: E, item: I, @@ -81,7 +81,7 @@ impl Scope { } /// Create a local variable of the given [item type](Item). - pub(crate) fn create_local>(&mut self, item: I) -> Variable { + pub fn create_local>(&mut self, item: I) -> Variable { let item = item.into(); let index = self.new_local_index(); let local = Variable::Local(index, item, self.depth); @@ -92,7 +92,7 @@ impl Scope { /// Create a new local variable, but doesn't perform the declaration. /// /// Useful for _for loops_ and other algorithms that require the control over initialization. - pub(crate) fn create_local_undeclared(&mut self, item: Item) -> Variable { + pub fn create_local_undeclared(&mut self, item: Item) -> Variable { let index = self.new_local_index(); let local = Variable::Local(index, item, self.depth); self.undeclared += 1; @@ -205,12 +205,12 @@ impl Scope { } /// Register an [operation](Operation) into the scope. - pub(crate) fn register>(&mut self, operation: T) { + pub fn register>(&mut self, operation: T) { self.operations.push(operation.into()) } /// Create an empty child scope. - pub(crate) fn child(&mut self) -> Self { + pub fn child(&mut self) -> Self { Self { depth: self.depth + 1, operations: Vec::new(), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs index b007533e72..f985d6244e 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs @@ -80,7 +80,7 @@ impl Display for Elem { } } -#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)] #[allow(missing_docs)] pub enum Item { Vec4(Elem), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs index 4a6a336d6a..3f4330050a 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs @@ -63,6 +63,7 @@ impl Variable { Variable::NumWorkgroupsZ => None, } } + /// Fetch the item of the variable. pub fn item(&self) -> Item { match self { @@ -96,7 +97,7 @@ impl Variable { } } -// Useful with the gpu! macro. +// Useful with the cube_inline macro. impl From<&Variable> for Variable { fn from(value: &Variable) -> Self { *value diff --git a/crates/burn-jit/src/codegen/mod.rs b/crates/burn-jit/src/codegen/mod.rs index cc9ca47404..cd3a3bef6c 100644 --- a/crates/burn-jit/src/codegen/mod.rs +++ b/crates/burn-jit/src/codegen/mod.rs @@ -1,6 +1,7 @@ mod compilation; pub(crate) mod compiler; -pub(crate) mod dialect; +/// Contains Intermediate Representation +pub mod dialect; mod kernel; diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index 0a32f74129..c9a14d423f 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -285,7 +285,7 @@ pub(crate) fn lcg_step(scope: &mut Scope, z: Variable) { } pub(crate) fn cast_uint_to_float(scope: &mut Scope, int_random: Variable, float_random: Variable) { - let tmp: Variable = 2.328_306_4e-10.into(); + let tmp: Variable = 2.328_306_4e-10f32.into(); gpu!(scope, float_random = cast(int_random)); gpu!(scope, float_random *= tmp); } diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index fa08bb2924..0ea613f315 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -15,7 +15,8 @@ pub mod kernel; /// Tensor module. pub mod tensor; -pub(crate) mod codegen; +/// Useful in Cube, should be moved over there +pub mod codegen; pub(crate) mod tune; mod element;