Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CubeCL first iteration #1756

Merged
merged 60 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
2142730
WIP
nathanielsimard Apr 9, 2024
94f476f
WIP
nathanielsimard Apr 10, 2024
8e4d39e
WIP
nathanielsimard Apr 10, 2024
85a0d83
Merge branch 'main' into feat/cube-pl
nathanielsimard Apr 11, 2024
f15919f
WIP
nathanielsimard Apr 12, 2024
3c64001
Push
nathanielsimard Apr 16, 2024
8c82ea8
Wip
nathanielsimard Apr 16, 2024
8ed5217
little refactor
louisfd Apr 17, 2024
dbbfdb7
wip
louisfd Apr 18, 2024
797d997
the right number of clones
louisfd Apr 19, 2024
2617e6b
add comments
louisfd Apr 19, 2024
16b454a
wip
louisfd Apr 22, 2024
c024e73
merge main
louisfd Apr 26, 2024
d756412
for loop tests
louisfd Apr 29, 2024
faea9d7
wip
louisfd Apr 29, 2024
709522e
refactor identify leaves and deletables
louisfd Apr 29, 2024
7f387f4
Merge branch 'main' of github.com:tracel-ai/burn
louisfd Apr 30, 2024
51b6085
Merge branch 'main' into feat/cube-pl
louisfd Apr 30, 2024
32bfcae
merge main
louisfd Apr 30, 2024
7a25525
wip
louisfd May 1, 2024
99fe8d8
prelude working
louisfd May 2, 2024
ad83d02
fix for loop
louisfd May 2, 2024
9547998
wip
louisfd May 2, 2024
bcd3c86
if
louisfd May 3, 2024
9cfafd9
if using variable inside
louisfd May 3, 2024
9310451
while loop
louisfd May 3, 2024
08f0f6e
loop and break
louisfd May 3, 2024
fb65985
assign add
louisfd May 3, 2024
8fd3541
wip
louisfd May 6, 2024
a14b653
variable reuse
louisfd May 6, 2024
fab9201
cast elem
louisfd May 6, 2024
b9bac61
wip cast
louisfd May 8, 2024
6aa7ccc
cast kind float
louisfd May 8, 2024
16e38ac
refactor elements
louisfd May 9, 2024
4f7c56f
make tests work
louisfd May 9, 2024
6375736
cast kind done
louisfd May 9, 2024
0c50975
rename gpu macro
louisfd May 9, 2024
da92893
refactor
louisfd May 10, 2024
2047e47
T::new
louisfd May 10, 2024
3292198
rename cast
louisfd May 10, 2024
9046baa
type system becoming great
louisfd May 11, 2024
8266bae
merge main + dirty fix for F64
louisfd May 11, 2024
9c347da
fmt
louisfd May 11, 2024
2cf7603
clippy
louisfd May 11, 2024
3b0052f
found the culprit
louisfd May 11, 2024
6a42148
typo
louisfd May 11, 2024
9a4a94c
add doc
louisfd May 13, 2024
80312ac
refactor codegen into files
louisfd May 13, 2024
a2b0369
minor refactor
louisfd May 13, 2024
d342d7e
prevent clippy from breaking tests
louisfd May 13, 2024
fa648e4
add doc
louisfd May 13, 2024
a2b63a3
more clean
louisfd May 13, 2024
f3341e9
fix expand outputs
louisfd May 13, 2024
cd183e9
traits, modules and parenthesis
louisfd May 13, 2024
652c68a
fmt
louisfd May 13, 2024
040bbe0
clippy
louisfd May 13, 2024
73ecc19
fmt again?
louisfd May 13, 2024
2e348f1
ops not compiling
louisfd May 15, 2024
c7a2454
make uint numeric
louisfd May 15, 2024
2bf7db9
rename new/constant to lit
louisfd May 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 20 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions crates/burn-cube-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[package]
authors = [
"nathanielsimard <nathaniel.simard.42@gmail.com>",
"louisfd <louisfd94@gmail.com",
]
categories = ["science"]
description = "TODO"
edition.workspace = true
keywords = []
license.workspace = true
name = "burn-cube-macros"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube-macros"
version.workspace = true

[lib]
proc-macro = true

[features]
default = []
std = []

[dependencies]
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }
derive-new = { workspace = true }
243 changes: 243 additions & 0 deletions crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
use std::collections::HashMap;

use syn::{PathArguments, Stmt};

use crate::VariableKey;

#[derive(Debug)]
/// Information about a single variable's use in Cube code
/// Useful to figure out when the generated variable will need cloning
pub(crate) struct VariableAnalysis {
num_used: usize,
loop_level_declared: usize,
}

impl VariableAnalysis {
pub fn should_clone(&mut self, loop_level: usize) -> bool {
if self.num_used > 1 {
self.num_used -= 1;
true
} else {
self.loop_level_declared < loop_level
}
}
}

#[derive(Debug)]
/// Information about all variables in the Cube code, transmitted to codegen
pub(crate) struct CodeAnalysis {
pub variable_analyses: HashMap<VariableKey, VariableAnalysis>,
}

#[derive(Debug, Default)]
/// Reads the Cube code and accumulates information, to generate a CodeAnalysis artefact
pub(crate) struct CodeAnalysisBuilder {
declarations: Vec<(VariableKey, usize)>,
var_uses: Vec<VariableKey>,
}

impl CodeAnalysis {
pub fn should_clone(&mut self, ident: &syn::Ident, loop_level: usize) -> bool {
let key: VariableKey = ident.into();
match self.variable_analyses.remove(&key) {
Some(mut var) => {
let should_clone = var.should_clone(loop_level);
self.variable_analyses.insert(key, var);
should_clone
}
None => panic!("Ident {ident} not part of analysis"),
}
}

pub fn create(func: &syn::ItemFn) -> CodeAnalysis {
let code_analysis_builder = CodeAnalysisBuilder::default();
code_analysis_builder.analyze(func)
}
}

impl CodeAnalysisBuilder {
fn analyze(mut self, func: &syn::ItemFn) -> CodeAnalysis {
// Build the vector of (Id, depth), using recursion
self.signature_declarations(&func.sig);
self.find_occurrences_in_stmts(&func.block.stmts, 0);

CodeAnalysis {
variable_analyses: self.to_map(),
}
}

fn to_map(&self) -> HashMap<VariableKey, VariableAnalysis> {
// Run through the vec and build hashmap, without recursion
let mut variable_analyses = HashMap::<VariableKey, VariableAnalysis>::new();
for declaration in self.declarations.iter() {
let id = declaration.0.clone();
let new_analysis = match variable_analyses.remove(&id) {
Some(_) => {
panic!("Analysis: Multiple variables with the same identifier is not supported")
}
None => VariableAnalysis {
num_used: 0,
loop_level_declared: declaration.1,
},
};

variable_analyses.insert(id, new_analysis);
}

for id in self.var_uses.iter() {
let prev_analysis = variable_analyses.remove(id).unwrap_or_else(|| {
panic!(
"Analysis: Variable {:?} should be declared before it's used",
id
)
});
let new_analysis = VariableAnalysis {
num_used: prev_analysis.num_used + 1,
loop_level_declared: prev_analysis.loop_level_declared,
};
variable_analyses.insert(id.clone(), new_analysis);
}

variable_analyses
}

fn signature_declarations(&mut self, sig: &syn::Signature) {
for input in &sig.inputs {
match input {
syn::FnArg::Typed(pat) => {
let ident = &*pat.pat;
match ident {
syn::Pat::Ident(pat_ident) => {
let id = &pat_ident.ident;
self.declarations.push((id.into(), 0));
}
_ => todo!("Analysis: unsupported ident {ident:?}"),
}
}
_ => todo!("Analysis: unsupported input {input:?}"),
}
}
}

fn find_occurrences_in_stmts(&mut self, stmts: &Vec<Stmt>, depth: usize) {
for stmt in stmts {
match stmt {
// Declaration
syn::Stmt::Local(local) => {
let id = match &local.pat {
syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
syn::Pat::Type(pat_type) => Some(match &*pat_type.pat {
syn::Pat::Ident(pat_ident) => &pat_ident.ident,
_ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat),
}),
syn::Pat::Wild(_) => None,
_ => todo!("Analysis: unsupported path {:?}", local.pat),
};
if let Some(id) = id {
self.declarations.push((id.into(), depth));
}
if let Some(local_init) = &local.init {
self.find_occurrences_in_expr(&local_init.expr, depth)
}
}
syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth),
_ => todo!("Analysis: unsupported stmt {stmt:?}"),
}
}
}

fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: usize) {
match expr {
syn::Expr::ForLoop(expr) => {
let depth = depth + 1;

// Declaration of iterator
if let syn::Pat::Ident(pat_ident) = &*expr.pat {
let id = &pat_ident.ident;
self.declarations.push((id.into(), depth));
}

self.find_occurrences_in_stmts(&expr.body.stmts, depth);
}
syn::Expr::While(expr) => {
let depth = depth + 1;

self.find_occurrences_in_expr(&expr.cond, depth);
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
}
syn::Expr::Loop(expr) => {
let depth = depth + 1;

self.find_occurrences_in_stmts(&expr.body.stmts, depth);
}
syn::Expr::If(expr) => {
let depth = depth + 1;

self.find_occurrences_in_expr(&expr.cond, depth);
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
if let Some((_, expr)) = &expr.else_branch {
if let syn::Expr::Block(expr_block) = &**expr {
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
} else {
todo!("Analysis: Only block else expr is supported")
}
}
}
syn::Expr::Assign(expr) => {
self.find_occurrences_in_expr(&expr.left, depth);
self.find_occurrences_in_expr(&expr.right, depth);
}
syn::Expr::Index(expr) => {
self.find_occurrences_in_expr(&expr.expr, depth);
self.find_occurrences_in_expr(&expr.index, depth);
}
syn::Expr::Path(expr) => {
let ident = expr
.path
.get_ident()
.expect("Analysis: only ident path are supported.");

// Use
self.var_uses.push(ident.into());
}
syn::Expr::Binary(expr) => {
self.find_occurrences_in_expr(&expr.left, depth);
self.find_occurrences_in_expr(&expr.right, depth);
}
syn::Expr::Lit(_) => {}
syn::Expr::Call(expr) => {
match &*expr.func {
syn::Expr::Path(expr_path) => {
if let Some(first_segment) = expr_path.path.segments.first() {
// Check if the path segment has generic arguments
if let PathArguments::AngleBracketed(arguments) =
&first_segment.arguments
{
// Extract the generic arguments
for arg in &arguments.args {
match arg {
syn::GenericArgument::Type(_)
| syn::GenericArgument::Constraint(_) => {}
_ => todo!("Analysis: Generic {:?} not supported", arg),
}
}
}
}
}
_ => todo!("Analysis: unsupported func expr {:?}", expr.func),
}
for arg in expr.args.iter() {
self.find_occurrences_in_expr(arg, depth);
}
}
syn::Expr::MethodCall(expr) => {
self.find_occurrences_in_expr(&expr.receiver, depth);
for arg in expr.args.iter() {
self.find_occurrences_in_expr(arg, depth);
}
}
syn::Expr::Break(_) => {}
_ => todo!("Analysis: unsupported expr {expr:?}"),
}
}
}