Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cube: first ported kernel + comptime support + variable reuse + cleanup #1797

Merged
merged 20 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ os_info = { workspace = true }
indicatif = { workspace = true }
percent-encoding = { workspace = true }
rand = { workspace = true }
reqwest = {workspace = true, features = ["blocking", "json"]}
reqwest = { workspace = true, features = ["blocking", "json"] }
serde = { workspace = true }
serde_json = { workspace = true }
strum = { workspace = true }
Expand Down Expand Up @@ -70,6 +70,10 @@ name = "conv-transpose2d"
path = "benches/conv_transpose2d.rs"
harness = false

[[bench]]
name = "conv2d"
harness = false

[[bench]]
name = "matmul"
harness = false
Expand Down
103 changes: 103 additions & 0 deletions backend-comparison/benches/conv2d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use backend_comparison::persistence::save;
use burn::tensor::{
backend::Backend, module::conv2d, ops::ConvOptions, Distribution, Shape, Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};

/// Configuration for benchmarking a single 2D convolution forward pass
/// on an arbitrary burn backend `B`.
pub struct Conv2dBenchmark<B: Backend> {
    // Shape of the input tensor; by convention (batch, channels_in, height, width) — TODO confirm against burn's conv2d docs.
    input_shape: Shape<4>,
    // Shape of the convolution weight tensor (4-D).
    weight_shape: Shape<4>,
    // Shape of the bias vector (one entry per output channel).
    bias_shape: Shape<1>,
    // Stride / padding / dilation / groups settings forwarded to conv2d.
    options: ConvOptions<2>,
    // Device the tensors are allocated on and the kernel runs on.
    device: B::Device,
}

impl<B: Backend> Benchmark for Conv2dBenchmark<B> {
    /// The benchmark operates on an (input, weight, bias) tensor triple.
    type Args = (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 1>);

    /// Human-readable benchmark identifier.
    fn name(&self) -> String {
        String::from("conv2d")
    }

    /// Report the shapes of all three argument tensors, in argument order.
    fn shapes(&self) -> Vec<Vec<usize>> {
        let input = self.input_shape.dims.to_vec();
        let weight = self.weight_shape.dims.to_vec();
        let bias = self.bias_shape.dims.to_vec();
        vec![input, weight, bias]
    }

    /// Run one timed forward convolution.
    fn execute(&self, args: Self::Args) {
        let (input, weight, bias) = args;
        conv2d(input, weight, Some(bias), self.options.clone());
    }

    /// Allocate fresh randomly-initialized argument tensors on the target device.
    fn prepare(&self) -> Self::Args {
        let device = &self.device;
        let input = Tensor::random(self.input_shape.clone(), Distribution::Default, device);
        let weight = Tensor::random(self.weight_shape.clone(), Distribution::Default, device);
        let bias = Tensor::random(self.bias_shape.clone(), Distribution::Default, device);
        (input, weight, bias)
    }

    /// Block until all queued device work has completed.
    fn sync(&self) {
        B::sync(&self.device)
    }
}

#[allow(dead_code)]
/// Build and run the conv2d benchmark on backend `B`, then persist the result.
///
/// * `device` - device to allocate tensors on and execute the kernel with.
/// * `feature_name` - cargo feature name identifying the backend, stored with the result.
/// * `url` / `token` - optional endpoint and auth token for uploading the measurement.
fn bench<B: Backend>(
    device: &B::Device,
    feature_name: &str,
    url: Option<&str>,
    token: Option<&str>,
) {
    // Shapes
    let batch_size = 16;
    let channels_in = 16;
    let channels_out = 16;
    let height_in = 512;
    let width_in = 512;
    let kernel_size_0 = 3;
    let kernel_size_1 = 3;

    // Options
    let strides = [1, 1];
    let padding = [0, 0];
    let dilations = [1, 1];
    let groups = 1;
    let options = ConvOptions::new(strides, padding, dilations, groups);
    let benchmark = Conv2dBenchmark::<B> {
        input_shape: [batch_size, channels_in, height_in, width_in].into(),
        // conv2d expects weights shaped [channels_out, channels_in / groups, kh, kw].
        // The previous order ([channels_in, channels_out / groups, ...]) only worked
        // because channels_in == channels_out and groups == 1 here.
        weight_shape: [
            channels_out,
            channels_in / groups,
            kernel_size_0,
            kernel_size_1,
        ]
        .into(),
        bias_shape: [channels_out].into(),
        options,
        device: device.clone(),
    };

    // Run the benchmark once and upload/record the timing.
    save::<B>(
        vec![run_benchmark(benchmark)],
        device,
        feature_name,
        url,
        token,
    )
    .unwrap();
}

fn main() {
    // Expands to per-backend bench entry points selected by cargo features.
    backend_comparison::bench_on_backend!();
}
4 changes: 4 additions & 0 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ enum BenchmarkValues {
LoadRecord,
#[strum(to_string = "autodiff")]
Autodiff,
#[strum(to_string = "conv-transpose2d")]
ConvTranspose2d,
#[strum(to_string = "conv2d")]
Conv2d,
}

pub fn execute() {
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ impl CodeAnalysisBuilder {
fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: usize) {
match expr {
syn::Expr::ForLoop(expr) => {
self.find_occurrences_in_expr(&expr.expr, depth);

let depth = depth + 1;

// Declaration of iterator
Expand Down Expand Up @@ -248,6 +250,14 @@ impl CodeAnalysisBuilder {
}
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Closure(expr) => {
assert!(
expr.inputs.is_empty(),
"Analysis: closure with args not supported"
);

self.find_occurrences_in_expr(&expr.body, depth + 1)
}
_ => todo!("Analysis: unsupported expr {expr:?}"),
}
}
Expand Down
36 changes: 26 additions & 10 deletions crates/burn-cube-macros/src/codegen/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ use proc_macro2::TokenStream;

use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr};

use super::{base::codegen_block, operation::codegen_binary, variable::codegen_lit};
use super::{
base::codegen_block,
function::parse_function_call,
operation::codegen_binary,
variable::{codegen_lit, codegen_path_rhs},
};

/// Codegen of for loops
/// Supports range:
Expand All @@ -13,7 +18,6 @@ pub(crate) fn codegen_for_loop(
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
let i = &for_loop.pat;
let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses);

match for_loop.expr.as_ref() {
syn::Expr::Call(call) => {
Expand All @@ -35,6 +39,8 @@ pub(crate) fn codegen_for_loop(
args.extend(quote::quote! { #arg, });
}

let block = codegen_block(&for_loop.body, loop_level + 1, variable_analyses);

quote::quote! {
burn_cube::branch::range_expand(#args |context, #i| #block);
}
Expand All @@ -51,10 +57,12 @@ pub(crate) fn codegen_cond(
cond: &syn::Expr,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
) -> (TokenStream, bool) {
match cond {
syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_analyses),
syn::Expr::Lit(expr) => codegen_lit(expr),
syn::Expr::Binary(expr) => (codegen_binary(expr, loop_level, variable_analyses), false),
syn::Expr::Lit(expr) => (codegen_lit(expr), false),
syn::Expr::Path(expr) => (codegen_path_rhs(expr, loop_level, variable_analyses), false),
syn::Expr::Call(expr) => parse_function_call(expr, loop_level, variable_analyses),
_ => todo!("{cond:?} cond not supported"),
}
}
Expand All @@ -70,12 +78,18 @@ pub(crate) fn codegen_break() -> TokenStream {
/// Supports:
/// if cond {...}
/// if cond {...} else {...}
/// if Comptime::get(...) {...} [else {...}]
pub(crate) fn codegen_if(
expr_if: &syn::ExprIf,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
let cond = codegen_cond(&expr_if.cond, loop_level, variable_analyses);
let (cond, comptime) = codegen_cond(&expr_if.cond, loop_level, variable_analyses);
let comptime_bool = if comptime {
quote::quote! { Some(#cond) }
} else {
quote::quote! { None }
};

let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_analyses);

Expand All @@ -85,15 +99,15 @@ pub(crate) fn codegen_if(

quote::quote! {
let _cond = #cond;
burn_cube::branch::if_else_expand(context, _cond, |context| #then_block, |context| #else_block);
burn_cube::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
}
} else {
todo!("Analysis: Only block else expr is supported")
todo!("Codegen: Only block else expr is supported")
}
} else {
quote::quote! {
let _cond = #cond;
burn_cube::branch::if_expand(context, _cond, |context| #then_block);
burn_cube::branch::if_expand(context, #comptime_bool, _cond.into(), |context| #then_block);
}
}
}
Expand All @@ -117,7 +131,9 @@ pub(crate) fn codegen_while_loop(
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
let cond = codegen_cond(&while_loop.cond, loop_level + 1, variable_analyses);
let (cond, comptime) = codegen_cond(&while_loop.cond, loop_level + 1, variable_analyses);
assert!(!comptime, "Codegen: Comptime not supported for while");

let block = codegen_block(&while_loop.body, loop_level + 1, variable_analyses);

quote::quote! {
Expand Down
75 changes: 63 additions & 12 deletions crates/burn-cube-macros/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,25 @@ pub(crate) fn codegen_closure(
}

/// Codegen for a function call
pub(crate) fn codegen_call(
call: &syn::ExprCall,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
parse_function_call(call, loop_level, variable_analyses).0
}

/// Maps
/// [A[::<...>]?::]^* func[::<...>] (args)
/// to
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)
pub(crate) fn codegen_call(
///
/// Also returns a bool that is true if it's comptime
pub(crate) fn parse_function_call(
call: &syn::ExprCall,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
) -> (TokenStream, bool) {
// We start with parsing the function path
let path: Vec<(&Ident, Option<&AngleBracketedGenericArguments>)> = match call.func.as_ref() {
syn::Expr::Path(expr_path) => {
Expand All @@ -63,8 +73,19 @@ pub(crate) fn codegen_call(

// Path
let mut path_tokens = TokenStream::new();
let mut is_comptime = false;
let mut comptime_func: Option<String> = None;

for (i, (ident, generics)) in path.iter().enumerate() {
if *ident == "Comptime" {
is_comptime = true;
continue;
}
if i == path.len() - 1 {
if is_comptime {
comptime_func = Some(ident.to_string());
break;
}
let func_name_expand = syn::Ident::new(
format!("{ident}_expand").as_str(),
proc_macro2::Span::call_site(),
Expand All @@ -82,16 +103,46 @@ pub(crate) fn codegen_call(
}

// Arguments
let mut args = quote::quote! {
context,
};
for argument in call.args.iter() {
let arg = codegen_expr(argument, loop_level, variable_analyses);
args.extend(quote::quote! { #arg, });
}
if let Some(func_name) = comptime_func {
let tokens = match func_name.as_str() {
"get" | "new" => {
let code = call.args.first().unwrap();
quote::quote! {#code}
}
"unwrap_or_else" => {
let mut args = quote::quote! {};
args.extend(quote::quote! { context, });
for argument in call.args.iter() {
let arg = codegen_expr(argument, loop_level, variable_analyses);
args.extend(quote::quote! { #arg, });
}

// Codegen
quote::quote! {
#path_tokens (#args)
// Codegen
quote::quote! {
Comptime::unwrap_or_else_expand(#args)
}
}
"is_some" => {
let code = call.args.first().unwrap();
quote::quote! { #code.is_some() }
}
_ => panic!("Codegen: Comptime function {:?} does not exist", func_name),
};

(tokens, true)
} else {
let mut args = quote::quote! {};
args.extend(quote::quote! { context, });
for argument in call.args.iter() {
let arg = codegen_expr(argument, loop_level, variable_analyses);
args.extend(quote::quote! { #arg, });
}

// Codegen
let tokens = quote::quote! {
#path_tokens (#args)
};

(tokens, false)
}
}
Loading
Loading