From 9996482bbcca5e67e4d90d188aecb02db5e36cd6 Mon Sep 17 00:00:00 2001
From: lucarlig
Date: Wed, 22 Oct 2025 09:35:25 +0100
Subject: [PATCH 1/4] feat: implement vadd for + in gpu functions

---
 examples/core/assign.desc                     |   7 +
 .../core/{vec_add_1.desc.off => vadd.desc}    |   5 +-
 examples/core/vdiv.desc                       |   7 +
 examples/core/vec_add.desc                    |  48 ++++
 examples/core/vmul.desc                       |   7 +
 src/codegen/mlir/to_mlir/types.rs             | 207 +++++++++++++++++-
 tests/mlir/snapshots/mlir__core__assign.snap  |  12 +
 tests/mlir/snapshots/mlir__core__load_ub.snap |   4 +-
 .../snapshots/mlir__core__load_ub_twice.snap  |   8 +-
 tests/mlir/snapshots/mlir__core__vadd.snap    |  15 ++
 tests/mlir/snapshots/mlir__core__vdiv.snap    |  15 ++
 tests/mlir/snapshots/mlir__core__vec_add.snap |  17 ++
 tests/mlir/snapshots/mlir__core__vmul.snap    |  15 ++
 13 files changed, 351 insertions(+), 16 deletions(-)
 create mode 100644 examples/core/assign.desc
 rename examples/core/{vec_add_1.desc.off => vadd.desc} (55%)
 create mode 100644 examples/core/vdiv.desc
 create mode 100644 examples/core/vec_add.desc
 create mode 100644 examples/core/vmul.desc
 create mode 100644 tests/mlir/snapshots/mlir__core__assign.snap
 create mode 100644 tests/mlir/snapshots/mlir__core__vadd.snap
 create mode 100644 tests/mlir/snapshots/mlir__core__vdiv.snap
 create mode 100644 tests/mlir/snapshots/mlir__core__vec_add.snap
 create mode 100644 tests/mlir/snapshots/mlir__core__vmul.snap

diff --git a/examples/core/assign.desc b/examples/core/assign.desc
new file mode 100644
index 0000000..af7b0be
--- /dev/null
+++ b/examples/core/assign.desc
@@ -0,0 +1,7 @@
+fn assign<r: prv>(
+    a: &r shrd gpu.global [i16; 16],
+    b: &r uniq gpu.global [i16; 16]
+) -[grid: gpu.grid<X<1>, X<16>>]-> () {
+    b = a;
+    ()
+}
diff --git a/examples/core/vec_add_1.desc.off b/examples/core/vadd.desc
similarity index 55%
rename from examples/core/vec_add_1.desc.off
rename to examples/core/vadd.desc
index 1d9f64e..c89c49d 100644
--- a/examples/core/vec_add_1.desc.off
+++ b/examples/core/vadd.desc
@@ -1,8 +1,7 @@
 fn add<r: prv>(
     a: &r shrd gpu.global [i16; 16],
-    b: &r shrd gpu.global [i16; 16],
-    c: &r uniq gpu.global [i16; 16]
+    b: &r shrd gpu.global [i16; 16]
 ) -[grid: gpu.grid<X<1>, X<16>>]-> () {
-    c = a + b;
+    a + b;
     ()
 }
diff --git a/examples/core/vdiv.desc b/examples/core/vdiv.desc
new file mode 100644
index 0000000..0240255
--- /dev/null
+++ b/examples/core/vdiv.desc
@@ -0,0 +1,7 @@
+fn div<r: prv>(
+    a: &r shrd gpu.global [i16; 16],
+    b: &r shrd gpu.global [i16; 16]
+) -[grid: gpu.grid<X<1>, X<16>>]-> () {
+    a / b;
+    ()
+}
diff --git a/examples/core/vec_add.desc b/examples/core/vec_add.desc
new file mode 100644
index 0000000..5e831e6
--- /dev/null
+++ b/examples/core/vec_add.desc
@@ -0,0 +1,48 @@
+// Vector addition kernel demonstrating Descend's safe GPU programming model
+// This function showcases extended borrow checking, memory safety, and execution context tracking
+
+// Generic function with type parameters:
+// - n: nat - Natural number parameter (for array size, though not used in this specific function)
+// - r: prv - Provenance parameter tracking memory region/lifetime for all references
+fn add<n: nat, r: prv>(
+    // Shared reference to first input vector - multiple threads can read simultaneously
+    // Memory space: gpu.global (GPU global memory)
+    // Ownership: shrd (shared) - prevents write-after-read data races
+    // Type: 16-element array of 16-bit signed integers
+    a: &r shrd gpu.global [i16; 16],
+
+    // Shared reference to second input vector - multiple threads can read simultaneously
+    // Same memory space and ownership constraints as 'a'
+    b: &r shrd gpu.global [i16; 16],
+
+    // Unique reference to output vector - only one thread can write at a time
+    // Ownership: uniq (unique) - prevents write-after-write data races
+    // The compiler statically ensures no conflicting borrows exist
+    c: &r uniq gpu.global [i16; 16]
+
+// Execution context specification - defines how this function runs on GPU hardware
+// - grid: gpu.grid<X<1>, X<16>> - GPU execution grid with 1 block containing 16 threads
+// - The type system ensures GPU memory is only accessed in GPU execution contexts
+// - Prevents invalid cross-device memory accesses (CPU accessing GPU memory)
+) -[grid: gpu.grid<X<1>, X<16>>]-> () {
+
+    // Vector addition operation - element-wise addition of arrays
+    // The compiler generates safe parallel code that:
+    // 1. Loads data from global memory to local memory for each thread
+    // 2. Performs vectorized addition using HIVM dialect operations
+    // 3. Stores results back to global memory safely
+    // The ownership system ensures this operation is race-free
+    //
+    // LAZY LOADING: Descend's compiler implements lazy loading strategies:
+    // - Memory loads are deferred until actually needed by computation
+    // - The HIVM dialect generates 'hivm.hir.load' operations that load from
+    //   global memory (gm) to local memory (ub) only when data is accessed
+    // - This minimizes memory bandwidth usage and improves cache efficiency
+    // - The type system ensures loads happen in the correct execution context
+    // - Shared references enable read-only access without unnecessary copies
+    c = a + b;
+
+    // Unit return value - indicates successful completion
+    // In MLIR, this becomes a 'return' operation
+    ()
+}
diff --git a/examples/core/vmul.desc b/examples/core/vmul.desc
new file mode 100644
index 0000000..f6b4ac1
--- /dev/null
+++ b/examples/core/vmul.desc
@@ -0,0 +1,7 @@
+fn mul<r: prv>(
+    a: &r shrd gpu.global [i16; 16],
+    b: &r shrd gpu.global [i16; 16]
+) -[grid: gpu.grid<X<1>, X<16>>]-> () {
+    a * b;
+    ()
+}
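Each kernel above is picked up by descend_derive::generate_desc_tests!() and compared against an insta snapshot (the .snap files added below). A hand-written equivalent of one generated case might look like the following sketch; the exact macro expansion and the name of the snapshot binding are assumptions, not the macro's verbatim output:

    // Hypothetical shape of one test generated for examples/core/vadd.desc.
    #[test]
    fn vadd() {
        // descend::compile returns (generated code, pretty-printed AST) on success.
        let (output, _ast) =
            descend::compile("examples/core/vadd.desc", descend::Backend::Mlir).unwrap();
        insta::assert_snapshot!(output); // checked against mlir__core__vadd.snap
    }
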
diff --git a/src/codegen/mlir/to_mlir/types.rs b/src/codegen/mlir/to_mlir/types.rs
index 4279dc5..a0f117a 100644
--- a/src/codegen/mlir/to_mlir/types.rs
+++ b/src/codegen/mlir/to_mlir/types.rs
@@ -81,6 +81,20 @@ fn ident_to_mlir<'c>(ident: &crate::ast::Ident, context: &'c Context) -> Type<'c
     }
 }
 
+/// Helper function to map BinOp to HIVM operation names
+fn binop_to_hivm_operation(binop: &crate::ast::BinOp) -> &'static str {
+    match binop {
+        crate::ast::BinOp::Add => "hivm.hir.vadd",
+        crate::ast::BinOp::Sub => unimplemented!("HIVM dialect does not have vsub operation - subtraction not supported in HIVM vector operations"),
+        crate::ast::BinOp::Mul => "hivm.hir.vmul",
+        crate::ast::BinOp::Div => "hivm.hir.vdiv",
+        crate::ast::BinOp::Mod => "hivm.hir.vmod",
+        // For now, only support arithmetic operations
+        // Other operations (comparisons, logical, bitwise) can be added later
+        _ => unimplemented!("HIVM operation for {:?} not yet implemented", binop),
+    }
+}
+
 /// Helper function to apply HIVM address space to a memref type string
 fn apply_hivm_address_space(base_str: String, mem: &Memory) -> String {
     if base_str.starts_with("memref<") {
@@ -558,12 +572,182 @@ fn collect_used_parameters(fun: &crate::ast::FunDef) -> std::collections::HashSe
     used_params
 }
 
+/// Generate body operations for GPU functions
+fn generate_body_operations(
+    fun: &crate::ast::FunDef,
+    param_to_local: &std::collections::HashMap<String, String>,
+    alloc_counter: &mut usize,
+    context: &Context,
+) -> String {
+    use crate::ast::{Expr, ExprKind, PlaceExprKind, DataTyKind, Memory, TyKind};
+
+    let mut body_ops = String::new();
+
+    fn walk_expr(
+        expr: &Expr,
+        fun: &crate::ast::FunDef,
+        param_to_local: &std::collections::HashMap<String, String>,
+        alloc_counter: &mut usize,
+        context: &Context,
+        body_ops: &mut String,
+    ) -> Option<String> {
+        match &expr.expr {
+            ExprKind::BinOp(binop, lhs, rhs) => {
+                // Process left and right operands
+                let lhs_var = walk_expr(lhs, fun, param_to_local, alloc_counter, context, body_ops)?;
+                let rhs_var = walk_expr(rhs, fun, param_to_local, alloc_counter, context, body_ops)?;
+
+                // Generate output allocation
+                let output_alloc = if *alloc_counter == 0 {
+                    "%alloc".to_string()
+                } else {
+                    format!("%alloc_{}", *alloc_counter - 1)
+                };
+                *alloc_counter += 1;
+
+                // Determine the type for allocation (use the type of lhs as reference)
+                let lhs_type = match &lhs.expr {
+                    ExprKind::PlaceExpr(place_expr) => {
+                        if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
+                            // Find the parameter declaration to get its type
+                            if let Some(param_decl) = fun.param_decls.iter().find(|p| p.ident.name == ident.name) {
+                                if let Some(param_ty) = &param_decl.ty {
+                                    // Convert the parameter type to ub address space
+                                    match &param_ty.ty {
+                                        TyKind::Data(data_ty) => {
+                                            match &data_ty.dty {
+                                                DataTyKind::At(inner, _) => {
+                                                    let base_type = inner.to_mlir(context);
+                                                    let base_str = base_type.to_string();
+                                                    apply_hivm_address_space(base_str, &Memory::GpuLocal)
+                                                }
+                                                DataTyKind::Ref(ref_dty) => {
+                                                    let base_type = ref_dty.dty.to_mlir(context);
+                                                    let base_str = base_type.to_string();
+                                                    apply_hivm_address_space(base_str, &Memory::GpuLocal)
+                                                }
+                                                _ => get_mlir_type_string_with_address_space(param_ty, context),
+                                            }
+                                        }
+                                        _ => get_mlir_type_string_with_address_space(param_ty, context),
+                                    }
+                                } else {
+                                    return None;
+                                }
+                            } else {
+                                return None;
+                            }
+                        } else {
+                            return None;
+                        }
+                    }
+                    _ => return None,
+                };
+
+                // Generate allocation
+                body_ops.push_str(&format!("    {} = memref.alloc() : {}\n", output_alloc, lhs_type));
+
+                // Generate HIVM operation
+                let hivm_op = binop_to_hivm_operation(binop);
+                body_ops.push_str(&format!(
+                    "    {} ins({}, {} : {}, {}) outs({} : {})\n",
+                    hivm_op,
+                    lhs_var,
+                    rhs_var,
+                    lhs_type,
+                    lhs_type,
+                    output_alloc,
+                    lhs_type
+                ));
+
+                Some(output_alloc)
+            }
+            ExprKind::PlaceExpr(place_expr) => {
+                if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
+                    // Return the local variable name for this parameter
+                    param_to_local.get(&ident.name.to_string()).cloned()
+                } else {
+                    None
+                }
+            }
+            ExprKind::Assign(place_expr, value_expr) => {
+                // Handle assignment: b = a
+                // First, evaluate the right-hand side (value expression)
+                let value_var = walk_expr(value_expr, fun, param_to_local, alloc_counter, context, body_ops)?;
+
+                // Find the target parameter for the assignment
+                if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
+                    // Find the parameter declaration to get its type and index
+                    if let Some((param_idx, param_decl)) = fun.param_decls.iter().enumerate().find(|(_, p)| p.ident.name == ident.name) {
+                        if let Some(param_ty) = &param_decl.ty {
+                            // Generate the target parameter type (should be gm address space)
+                            let target_type = get_mlir_type_string_with_address_space(param_ty, context);
+
+                            // Generate the source type (should be ub address space for local allocations)
+                            let source_type = match &param_ty.ty {
+                                TyKind::Data(data_ty) => {
+                                    match &data_ty.dty {
+                                        DataTyKind::At(inner, _) => {
+                                            let base_type = inner.to_mlir(context);
+                                            let base_str = base_type.to_string();
+                                            apply_hivm_address_space(base_str, &Memory::GpuLocal)
+                                        }
+                                        DataTyKind::Ref(ref_dty) => {
+                                            let base_type = ref_dty.dty.to_mlir(context);
+                                            let base_str = base_type.to_string();
+                                            apply_hivm_address_space(base_str, &Memory::GpuLocal)
+                                        }
+                                        _ => get_mlir_type_string_with_address_space(param_ty, context),
+                                    }
+                                }
+                                _ => get_mlir_type_string_with_address_space(param_ty, context),
+                            };
+
+                            // Generate store operation: hivm.hir.store ins(value) outs(%argN)
+                            body_ops.push_str(&format!(
+                                "    hivm.hir.store ins({} : {}) outs(%arg{} : {})\n",
+                                value_var,
+                                source_type,
+                                param_idx,
+                                target_type
+                            ));
+                        }
+                    }
+                }
+
+                // Assignment doesn't produce a value
+                None
+            }
+            ExprKind::Seq(exprs) => {
+                // Process sequence expressions, return the result of the last expression
+                let mut last_result = None;
+                for expr in exprs {
+                    last_result = walk_expr(expr, fun, param_to_local, alloc_counter, context, body_ops);
+                }
+                last_result
+            }
+            ExprKind::Lit(_) => {
+                // Literals don't produce SSA values in this context
+                None
+            }
+            _ => {
+                // Other expression types not yet supported
+                None
+            }
+        }
+    }
+
+    walk_expr(&fun.body.body, fun, param_to_local, alloc_counter, context, &mut body_ops);
+    body_ops
+}
+
 /// Generate load operations for GPU parameters
+/// Returns (operations_string, param_to_local_map, final_alloc_counter)
 fn generate_load_operations(
     fun: &crate::ast::FunDef,
     param_usage: &std::collections::HashMap<String, ParameterUsage>,
     context: &Context
-) -> String {
+) -> (String, std::collections::HashMap<String, String>, usize) {
     use crate::ast::{DataTyKind, Memory, TyKind};
     use std::collections::HashMap;
 
@@ -617,19 +801,24 @@ fn generate_load_operations(
         };
 
         // Generate alloc and load operations
-        load_ops.push_str(&format!("    %alloc{} = memref.alloc() : {}\n", alloc_counter, ub_type));
-        load_ops.push_str(&format!("    hivm.hir.load ins(%arg{} : {}) outs(%alloc{} : {})\n",
-            i, gm_type, alloc_counter, ub_type));
+        let alloc_name = if alloc_counter == 0 {
+            "%alloc".to_string()
+        } else {
+            format!("%alloc_{}", alloc_counter - 1)
+        };
+        load_ops.push_str(&format!("    {} = memref.alloc() : {}\n", alloc_name, ub_type));
+        load_ops.push_str(&format!("    hivm.hir.load ins(%arg{} : {}) outs({} : {})\n",
+            i, gm_type, alloc_name, ub_type));
 
         // Map parameter to its local version
-        param_to_local.insert(param_name, format!("%alloc{}", alloc_counter));
+        param_to_local.insert(param_name, alloc_name);
         alloc_counter += 1;
                 }
             }
         }
     }
 
-    load_ops
+    (load_ops, param_to_local, alloc_counter)
 }
 
 /// Generate function with body including load operations for GPU parameters
@@ -676,9 +865,13 @@ fn generate_function_with_body(fun: &crate::ast::FunDef, context: &Context) -> S
     let param_usage = collect_parameter_usage(fun);
 
     // Generate load operations for GPU parameters (only for read usage)
-    let load_ops = generate_load_operations(fun, &param_usage, context);
+    let (load_ops, param_to_local, mut alloc_counter) = generate_load_operations(fun, &param_usage, context);
     signature.push_str(&load_ops);
 
+    // Generate body operations (binary operations, etc.)
+    let body_ops = generate_body_operations(fun, &param_to_local, &mut alloc_counter, context);
+    signature.push_str(&body_ops);
+
     signature.push_str("    return\n  }\n");
     signature
 }
diff --git a/tests/mlir/snapshots/mlir__core__assign.snap b/tests/mlir/snapshots/mlir__core__assign.snap
new file mode 100644
index 0000000..37d32e8
--- /dev/null
+++ b/tests/mlir/snapshots/mlir__core__assign.snap
@@ -0,0 +1,12 @@
+---
+source: tests/mlir/core.rs
+expression: output
+---
+module {
+  func.func @assign(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    hivm.hir.store ins(%alloc : memref<16xi16, #hivm.address_space<ub>>) outs(%arg1 : memref<16xi16, #hivm.address_space<gm>>)
+    return
+  }
+}
diff --git a/tests/mlir/snapshots/mlir__core__load_ub.snap b/tests/mlir/snapshots/mlir__core__load_ub.snap
index 00e9a34..f602ca9 100644
--- a/tests/mlir/snapshots/mlir__core__load_ub.snap
+++ b/tests/mlir/snapshots/mlir__core__load_ub.snap
@@ -5,8 +5,8 @@ expression: output
 ---
 module {
   func.func @memory_load(%arg0: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
-    %alloc0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
-    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc0 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
     return
   }
 }
diff --git a/tests/mlir/snapshots/mlir__core__load_ub_twice.snap b/tests/mlir/snapshots/mlir__core__load_ub_twice.snap
index 228deb3..aced21f 100644
--- a/tests/mlir/snapshots/mlir__core__load_ub_twice.snap
+++ b/tests/mlir/snapshots/mlir__core__load_ub_twice.snap
@@ -5,10 +5,10 @@ expression: output
 ---
 module {
   func.func @memory_load(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
-    %alloc0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
-    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc0 : memref<16xi16, #hivm.address_space<ub>>)
-    %alloc1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
-    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc1 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
     return
   }
 }
diff --git a/tests/mlir/snapshots/mlir__core__vadd.snap b/tests/mlir/snapshots/mlir__core__vadd.snap
new file mode 100644
index 0000000..fbeaeb9
--- /dev/null
+++ b/tests/mlir/snapshots/mlir__core__vadd.snap
@@ -0,0 +1,15 @@
+---
+source: tests/mlir/core.rs
+expression: output
+---
+module {
+  func.func @add(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.vadd ins(%alloc, %alloc_0 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>)
+    return
+  }
+}
diff --git a/tests/mlir/snapshots/mlir__core__vdiv.snap b/tests/mlir/snapshots/mlir__core__vdiv.snap
new file mode 100644
index 0000000..bc8b7cb
--- /dev/null
+++ b/tests/mlir/snapshots/mlir__core__vdiv.snap
@@ -0,0 +1,15 @@
+---
+source: tests/mlir/core.rs
+expression: output
+---
+module {
+  func.func @div(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.vdiv ins(%alloc, %alloc_0 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>)
+    return
+  }
+}
diff --git a/tests/mlir/snapshots/mlir__core__vec_add.snap b/tests/mlir/snapshots/mlir__core__vec_add.snap
new file mode 100644
index 0000000..312ef3d
--- /dev/null
+++ b/tests/mlir/snapshots/mlir__core__vec_add.snap
@@ -0,0 +1,17 @@
+---
+source: tests/mlir/core.rs
+assertion_line: 4
+expression: output
+---
+module {
+  func.func @add(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>, %arg2: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.vadd ins(%alloc, %alloc_0 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>)
+    hivm.hir.store ins(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>) outs(%arg2 : memref<16xi16, #hivm.address_space<gm>>)
+    return
+  }
+}
diff --git a/tests/mlir/snapshots/mlir__core__vmul.snap b/tests/mlir/snapshots/mlir__core__vmul.snap
new file mode 100644
index 0000000..ffc3664
--- /dev/null
+++ b/tests/mlir/snapshots/mlir__core__vmul.snap
@@ -0,0 +1,15 @@
+---
+source: tests/mlir/core.rs
+expression: output
+---
+module {
+  func.func @mul(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>} {
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.vmul ins(%alloc, %alloc_0 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>)
+    return
+  }
+}
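Patch 1 lowers a Descend binary operation in three steps: load each read parameter from gm into a fresh ub allocation, emit the matching HIVM vector op into another ub allocation, and, for assignments, store the result back to gm. A minimal unit-test sketch for the mapping helper added above (it would have to live in the same module, since binop_to_hivm_operation is private; the ast::BinOp path is taken from the diff):

    #[test]
    fn binops_map_to_hivm_vector_ops() {
        use crate::ast::BinOp;
        assert_eq!(binop_to_hivm_operation(&BinOp::Add), "hivm.hir.vadd");
        assert_eq!(binop_to_hivm_operation(&BinOp::Mul), "hivm.hir.vmul");
        assert_eq!(binop_to_hivm_operation(&BinOp::Div), "hivm.hir.vdiv");
        assert_eq!(binop_to_hivm_operation(&BinOp::Mod), "hivm.hir.vmod");
    }
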
From 74b581c5d39472f99a0713309b8e3eae2677b3b3 Mon Sep 17 00:00:00 2001
From: lucarlig
Date: Wed, 22 Oct 2025 12:19:20 +0100
Subject: [PATCH 2/4] feat: add type checks for MLIR codegen types

---
 .../error-examples/assign_to_shared_ref.desc  |  7 ++
 .../error-examples/vec_add_memory_issue.desc  | 50 ++++++++
 src/codegen/mlir/builder/control_flow.rs      |  2 +-
 src/codegen/mlir/builder/place.rs             | 15 +++
 src/codegen/mlir/mod.rs                       | 16 ---
 src/codegen/mlir/to_mlir/types.rs             | 112 +++---------------
 src/lib.rs                                    | 16 +--
 src/main.rs                                   | 18 +--
 tests/mlir.rs                                 |  3 +
 tests/mlir/core.rs                            |  2 -
 tests/mlir/error_examples.rs                  | 13 ++
 11 files changed, 106 insertions(+), 148 deletions(-)
 create mode 100644 examples/error-examples/assign_to_shared_ref.desc
 create mode 100644 examples/error-examples/vec_add_memory_issue.desc
 create mode 100644 tests/mlir/error_examples.rs

diff --git a/examples/error-examples/assign_to_shared_ref.desc b/examples/error-examples/assign_to_shared_ref.desc
new file mode 100644
index 0000000..9cc2aa1
--- /dev/null
+++ b/examples/error-examples/assign_to_shared_ref.desc
@@ -0,0 +1,7 @@
+fn assign_to_shared_ref<r: prv>(
+    a: &r shrd gpu.global [i16; 16],
+    b: &r shrd gpu.global [i16; 16]
+) -[grid: gpu.grid<X<1>, X<16>>]-> () {
+    b = a; // This should fail - cannot assign to shared reference
+    ()
+}
diff --git a/examples/error-examples/vec_add_memory_issue.desc b/examples/error-examples/vec_add_memory_issue.desc
new file mode 100644
index 0000000..caa2683
--- /dev/null
+++ b/examples/error-examples/vec_add_memory_issue.desc
@@ -0,0 +1,50 @@
+// Vector addition kernel demonstrating Descend's safe GPU programming model
+// This function showcases extended borrow checking, memory safety, and execution context tracking
+
+// Generic function with type parameters:
+// - n: nat - Natural number parameter (for array size, though not used in this specific function)
+// - r: prv - Provenance parameter tracking memory region/lifetime for all references
+fn add<n: nat, r: prv>(
+    // Shared reference to first input vector - multiple threads can read simultaneously
+    // Memory space: gpu.global (GPU global memory)
+    // Ownership: shrd (shared) - prevents write-after-read data races
+    // Type: 16-element array of 16-bit signed integers
+    a: &r shrd gpu.global [i16; 16],
+
+    // Shared reference to second input vector - multiple threads can read simultaneously
+    // Same memory space and ownership constraints as 'a'
+    b: &r shrd gpu.global [i16; 16],
+
+    // ERROR: This parameter should be declared as 'uniq' (unique) instead of 'shrd' (shared)
+    // because it's used for assignment (c = a + b). Shared references are read-only and
+    // prevent data races by allowing multiple concurrent readers. Unique references are
+    // required for write operations to ensure exclusive access and prevent race conditions.
+    // The Descend compiler will detect this ownership violation and fail compilation.
+    c: &r shrd gpu.global [i16; 16]
+
+// Execution context specification - defines how this function runs on GPU hardware
+// - grid: gpu.grid<X<1>, X<16>> - GPU execution grid with 1 block containing 16 threads
+// - The type system ensures GPU memory is only accessed in GPU execution contexts
+// - Prevents invalid cross-device memory accesses (CPU accessing GPU memory)
+) -[grid: gpu.grid<X<1>, X<16>>]-> () {
+
+    // Vector addition operation - element-wise addition of arrays
+    // The compiler generates safe parallel code that:
+    // 1. Loads data from global memory to local memory for each thread
+    // 2. Performs vectorized addition using HIVM dialect operations
+    // 3. Stores results back to global memory safely
+    // The ownership system ensures this operation is race-free
+    //
+    // LAZY LOADING: Descend's compiler implements lazy loading strategies:
+    // - Memory loads are deferred until actually needed by computation
+    // - The HIVM dialect generates 'hivm.hir.load' operations that load from
+    //   global memory (gm) to local memory (ub) only when data is accessed
+    // - This minimizes memory bandwidth usage and improves cache efficiency
+    // - The type system ensures loads happen in the correct execution context
+    // - Shared references enable read-only access without unnecessary copies
+    c = a + b;
+
+    // Unit return value - indicates successful completion
+    // In MLIR, this becomes a 'return' operation
+    ()
+}
diff --git a/src/codegen/mlir/builder/control_flow.rs b/src/codegen/mlir/builder/control_flow.rs
index 0f55f9d..3643e96 100644
--- a/src/codegen/mlir/builder/control_flow.rs
+++ b/src/codegen/mlir/builder/control_flow.rs
@@ -100,7 +100,7 @@ where
     let (then_region, true_value) = build_branch_region(case_true, ctx, build_expr)?;
 
     // Build the else region
-    let (else_region, false_value) = build_branch_region(case_false, ctx, build_expr)?;
+    let (else_region, _false_value) = build_branch_region(case_false, ctx, build_expr)?;
 
     // Determine result types based on whether branches produce values
     let result_types: Vec = if let Some(val) = true_value {
diff --git a/src/codegen/mlir/builder/place.rs b/src/codegen/mlir/builder/place.rs
index 77c0992..50954da 100644
--- a/src/codegen/mlir/builder/place.rs
+++ b/src/codegen/mlir/builder/place.rs
@@ -4,6 +4,7 @@ use super::super::error::MlirError;
 use super::context::MlirContext;
 use super::expr::build_expr;
 use crate::ast as desc;
+use crate::ast::{DataTyKind, Ownership};
 
 /// Build a place expression (variable lookup)
 pub fn build_place_expr<'ctx, 'a, 'b>(
@@ -57,6 +58,20 @@ where
 {
     use desc::PlaceExprKind;
 
+    // Check if we're assigning to a reference - if so, it must be unique
+    if let Some(ty) = &place_expr.ty {
+        if let desc::TyKind::Data(data_ty) = &ty.ty {
+            if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
+                if ref_dty.own != Ownership::Uniq {
+                    return Err(MlirError::General(format!(
+                        "Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}",
+                        ref_dty.own
+                    )));
+                }
+            }
+        }
+    }
+
     // Evaluate the right-hand side value
     let value = build_expr(value_expr, ctx)?
         .ok_or_else(|| MlirError::General("Missing value for assignment".to_string()))?;
diff --git a/src/codegen/mlir/mod.rs b/src/codegen/mlir/mod.rs
index 6d3d1e1..470a7cd 100644
--- a/src/codegen/mlir/mod.rs
+++ b/src/codegen/mlir/mod.rs
@@ -72,22 +72,6 @@ fn build_module_internal(comp_unit: &CompilUnit) -> Result {
     Ok(builder.module().as_operation().to_string())
 }
 
-pub fn gen(comp_unit: &CompilUnit, _idx_checks: bool) -> String {
-    // Check if we need HIVM address spaces
-    if needs_hivm_address_space(comp_unit) {
-        to_mlir::types::generate_mlir_string_with_hivm(comp_unit)
-    } else {
-        // Use internal helper, but handle errors by falling back to string generation
-        match build_module_internal(comp_unit) {
-            Ok(mlir_string) => mlir_string,
-            Err(_) => {
-                // Fallback to string generation if internal building fails
-                to_mlir::types::generate_mlir_string_with_hivm(comp_unit)
-            }
-        }
-    }
-}
-
 pub fn gen_checked(comp_unit: &CompilUnit, _idx_checks: bool) -> Result<String, MlirError> {
     // Check if we need HIVM address spaces
     if needs_hivm_address_space(comp_unit) {
diff --git a/src/codegen/mlir/to_mlir/types.rs b/src/codegen/mlir/to_mlir/types.rs
index a0f117a..2dadf49 100644
--- a/src/codegen/mlir/to_mlir/types.rs
+++ b/src/codegen/mlir/to_mlir/types.rs
@@ -1,4 +1,4 @@
-use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, ScalarTy, Ty, TyKind};
+use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, Ownership, ScalarTy, Ty, TyKind};
 use melior::{
     dialect::func,
     ir::{
@@ -23,35 +23,6 @@ fn nat_to_dimension(nat: &Nat) -> String {
     }
 }
 
-/// Helper function to create HACC attributes for GPU functions
-/// This function creates proper MLIR attribute objects for HACC attributes.
-/// Note: This requires the HACC dialect to be registered in the MLIR context.
-fn create_hacc_attributes<'c>(context: &'c Context) -> Vec<(Identifier<'c>, Attribute<'c>)> {
-    // Create HACC entry attribute (hacc.entry) - this is a unit attribute
-    let entry_attr = Attribute::parse(context, "unit")
-        .expect("Failed to create HACC entry attribute");
-
-    // Create HACC function type attribute (hacc.function_kind = DEVICE)
-    // Note: This will fail if the HACC dialect is not registered in the context
-    let func_type_attr = Attribute::parse(context, "#hacc.function_kind<DEVICE>")
-        .expect("Failed to create HACC function type attribute - ensure HACC dialect is registered");
-
-    vec![
-        (Identifier::new(context, "hacc.entry"), entry_attr),
-        (Identifier::new(context, "hacc.function_kind"), func_type_attr),
-    ]
-}
-
-/// Helper function to generate HACC attributes string for GPU functions
-/// This function generates the MLIR string representation of HACC attributes
-/// for use in string-based MLIR generation. The proper attribute objects
-/// are available through create_hacc_attributes() when the HACC dialect is registered.
-fn generate_hacc_attributes_string(context: &Context) -> String {
-    // Generate the string representation directly since we know the exact format
-    // The create_hacc_attributes() function is available for when the HACC dialect is registered
-    " attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>}".to_string()
-}
-
 /// Helper function to convert ScalarTy to MLIR Type
 fn scalar_ty_to_mlir<'c>(scalar_ty: &ScalarTy, context: &'c Context) -> Type<'c> {
     match scalar_ty {
@@ -505,73 +476,6 @@ fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMa
     param_usage
 }
 
-/// Collect which parameters are referenced in the function body (legacy function for compatibility)
-fn collect_used_parameters(fun: &crate::ast::FunDef) -> std::collections::HashSet<String> {
-    use crate::ast::{Expr, ExprKind, PlaceExprKind};
-    use std::collections::HashSet;
-
-    let mut used_params = HashSet::new();
-
-    fn walk_expr(expr: &Expr, used_params: &mut HashSet<String>) {
-        match &expr.expr {
-            ExprKind::PlaceExpr(place_expr) => {
-                if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
-                    used_params.insert(ident.name.to_string());
-                }
-            }
-            ExprKind::BinOp(_, lhs, rhs) => {
-                walk_expr(lhs, used_params);
-                walk_expr(rhs, used_params);
-            }
-            ExprKind::Let(_, _, value_expr) => {
-                walk_expr(value_expr, used_params);
-            }
-            ExprKind::Seq(exprs) => {
-                for expr in exprs {
-                    walk_expr(expr, used_params);
-                }
-            }
-            ExprKind::Assign(place_expr, value_expr) => {
-                if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
-                    used_params.insert(ident.name.to_string());
-                }
-                walk_expr(value_expr, used_params);
-            }
-            ExprKind::App(_, _, args) => {
-                for arg in args {
-                    walk_expr(arg, used_params);
-                }
-            }
-            ExprKind::IfElse(cond, case_true, case_false) => {
-                walk_expr(cond, used_params);
-                walk_expr(case_true, used_params);
-                walk_expr(case_false, used_params);
-            }
-            ExprKind::If(cond, case_true) => {
-                walk_expr(cond, used_params);
-                walk_expr(case_true, used_params);
-            }
-            ExprKind::ForNat(_, _, body) => {
-                walk_expr(body, used_params);
-            }
-            ExprKind::Ref(_, _, place_expr) => {
-                if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr {
-                    used_params.insert(ident.name.to_string());
-                }
-            }
-            ExprKind::Unsafe(expr) => {
-                walk_expr(expr, used_params);
-            }
-            _ => {
-                // Other expression types don't contain variable references
-            }
-        }
-    }
-
-    walk_expr(&fun.body.body, &mut used_params);
-    used_params
-}
-
 /// Generate body operations for GPU functions
 fn generate_body_operations(
     fun: &crate::ast::FunDef,
@@ -680,6 +584,18 @@ fn generate_body_operations(
             // Find the parameter declaration to get its type and index
             if let Some((param_idx, param_decl)) = fun.param_decls.iter().enumerate().find(|(_, p)| p.ident.name == ident.name) {
                 if let Some(param_ty) = &param_decl.ty {
+                    // Check if we're assigning to a reference - if so, it must be unique
+                    if let TyKind::Data(data_ty) = &param_ty.ty {
+                        if let DataTyKind::Ref(ref_dty) = &data_ty.dty {
+                            if ref_dty.own != Ownership::Uniq {
+                                panic!(
+                                    "Assignment to non-unique reference is not allowed. Expected unique reference, found {:?}",
+                                    ref_dty.own
+                                );
+                            }
+                        }
+                    }
+
                     // Generate the target parameter type (should be gm address space)
                     let target_type = get_mlir_type_string_with_address_space(param_ty, context);
 
@@ -856,7 +772,7 @@ fn generate_function_with_body(fun: &crate::ast::FunDef, context: &Context) -> S
         // TODO: When HACC dialect is registered in the MLIR context, replace this with:
         // let hacc_attributes = create_hacc_attributes(context);
         // and use the attributes with MLIR operation builders instead of string generation
-        signature.push_str(&generate_hacc_attributes_string(context));
+        signature.push_str(" attributes {hacc.entry, hacc.function_kind = #hacc.function_kind<DEVICE>}");
     }
 
     signature.push_str(" {\n");
diff --git a/src/lib.rs b/src/lib.rs
index e809ddf..998286a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,18 +29,4 @@ pub fn compile(file_path: &str, backend: Backend) -> Result<(String, String), Er
     let ast_string = format!("{:#?}", compil_unit.items);
 
     Ok((code_string, ast_string))
-}
-
-pub fn compile_unchecked(file_path: &str, backend: Backend) -> Result<String, ErrorReported> {
-    let source = parser::SourceCode::from_file(file_path)?;
-    let mut compil_unit = parser::parse(&source)?;
-
-    ty_check::ty_check(&mut compil_unit)?;
-
-    let code_string = match backend {
-        Backend::Cuda => codegen::cuda::gen(&compil_unit, false),
-        Backend::Mlir => codegen::mlir::gen(&compil_unit, false),
-    };
-
-    Ok(code_string)
-}
+}
\ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index b7c0b82..344b687 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -21,10 +21,6 @@ struct Args {
     /// Print Ast
     #[arg(short, long)]
     print_ast: bool,
-
-    /// Skip MLIR verification checks
-    #[arg(long)]
-    no_checks: bool,
 }
 
 /// Backend selection passed via CLI
@@ -52,23 +48,13 @@ fn main() {
     let output_dir = &args.output_dir;
 
     // Compile using Descend
-    let (code_string, ast_string) = if args.no_checks {
-        let code_string = match descend::compile_unchecked(&input_path.to_string_lossy(), backend) {
+    let (code_string, ast_string) = match descend::compile(&input_path.to_string_lossy(), backend) {
             Ok(output) => output,
             Err(e) => {
                 eprintln!("Compilation failed: {:?}", e);
                 std::process::exit(1);
             }
-        };
-        (code_string, String::new())
-    } else {
-        match descend::compile(&input_path.to_string_lossy(), backend) {
-            Ok(output) => output,
-            Err(e) => {
-                eprintln!("Compilation failed: {:?}", e);
-                std::process::exit(1);
-            }
-        }
+        };
 
 
     // Generate output file path with appropriate extension based on backend
diff --git a/tests/mlir.rs b/tests/mlir.rs
index ed3c60b..03c4cff 100644
--- a/tests/mlir.rs
+++ b/tests/mlir.rs
@@ -1,4 +1,7 @@
 #[path = "mlir/core.rs"]
 mod core;
 
+#[path = "mlir/error_examples.rs"]
+mod error_examples;
+
 const BACKEND: descend::Backend = descend::Backend::Mlir;
diff --git a/tests/mlir/core.rs b/tests/mlir/core.rs
index f7fa120..0aaec8b 100644
--- a/tests/mlir/core.rs
+++ b/tests/mlir/core.rs
@@ -1,4 +1,2 @@
-type Res = Result<(), descend::error::ErrorReported>;
-
 // Automatically generate tests for all .desc files in examples/core/
 descend_derive::generate_desc_tests!();
diff --git a/tests/mlir/error_examples.rs b/tests/mlir/error_examples.rs
new file mode 100644
index 0000000..2df4710
--- /dev/null
+++ b/tests/mlir/error_examples.rs
@@ -0,0 +1,13 @@
+use super::BACKEND;
+
+#[test]
+#[should_panic]
+fn vec_add_memory_issue_error() {
+    descend::compile("examples/error-examples/vec_add_memory_issue.desc", BACKEND).unwrap();
+}
+
+#[test]
+#[should_panic]
+fn assign_to_shared_ref_error() {
+    descend::compile("examples/error-examples/assign_to_shared_ref.desc", BACKEND).unwrap();
+}
-use crate::ast::{AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, Ownership, ScalarTy, Ty, TyKind}; +use crate::ast::{ + AtomicTy, BaseExec, DataTy, DataTyKind, FunDef, Memory, Nat, NatCtx, Ownership, ScalarTy, Ty, + TyKind, +}; use melior::{ dialect::func, ir::{ - attribute::{StringAttribute, TypeAttribute, Attribute}, - r#type::{FunctionType, IntegerType, TupleType}, - Location, Operation, Region, Type, Identifier, + attribute::{Attribute, StringAttribute, TypeAttribute}, + r#type::{FunctionType, IntegerType, MemRefType, TupleType}, + Identifier, Location, Operation, Region, Type, }, Context, }; @@ -13,42 +16,67 @@ pub trait ToMlir { fn to_mlir<'c>(&self, context: &'c Context) -> Self::Output<'c>; } -/// Helper function to convert Nat to a dimension string for MLIR types -fn nat_to_dimension(nat: &Nat) -> String { - // Try to evaluate the Nat with an empty context - let nat_ctx = NatCtx::new(); - match nat.eval(&nat_ctx) { - Ok(size) => size.to_string(), - Err(_) => "?".to_string(), // Use dynamic dimension for non-literal Nat +impl Nat { + fn to_dimension_i64(self: &Self) -> i64 { + // Try to evaluate the Nat with an empty context + let nat_ctx = NatCtx::new(); + match self.eval(&nat_ctx) { + Ok(size) => size as i64, + Err(_) => panic!( + "Array dimensions must be compile-time known. Dynamic arrays are not supported." + ), + } } } -/// Helper function to convert ScalarTy to MLIR Type -fn scalar_ty_to_mlir<'c>(scalar_ty: &ScalarTy, context: &'c Context) -> Type<'c> { - match scalar_ty { - ScalarTy::Unit => Type::parse(context, "none").expect("Failed to parse none type"), - ScalarTy::U8 => IntegerType::new(context, 8).into(), - ScalarTy::U32 => IntegerType::new(context, 32).into(), - ScalarTy::U64 => IntegerType::new(context, 64).into(), - ScalarTy::I32 => IntegerType::new(context, 32).into(), - ScalarTy::I64 => IntegerType::new(context, 64).into(), - ScalarTy::F32 => Type::parse(context, "f32").expect("Failed to parse f32 type"), - ScalarTy::F64 => Type::parse(context, "f64").expect("Failed to parse f64 type"), - ScalarTy::Bool => IntegerType::new(context, 1).into(), - ScalarTy::Gpu => IntegerType::new(context, 32).into(), // this will be ignored in the MLIR backend +impl ToMlir for ScalarTy { + type Output<'c> = Type<'c>; + + fn to_mlir<'c>(&self, context: &'c Context) -> Type<'c> { + match self { + ScalarTy::Unit => Type::none(context), + ScalarTy::U8 => Type::from(IntegerType::new(context, 8)), + ScalarTy::U32 => IntegerType::new(context, 32).into(), + ScalarTy::U64 => IntegerType::new(context, 64).into(), + ScalarTy::I32 => IntegerType::new(context, 32).into(), + ScalarTy::I64 => IntegerType::new(context, 64).into(), + ScalarTy::F32 => Type::float32(context), + ScalarTy::F64 => Type::float64(context), + ScalarTy::Bool => IntegerType::new(context, 1).into(), + ScalarTy::Gpu => IntegerType::new(context, 32).into(), // this will be ignored in the MLIR backend + } } } -/// Helper function to convert DataTyKind::Ident to MLIR Type -fn ident_to_mlir<'c>(ident: &crate::ast::Ident, context: &'c Context) -> Type<'c> { - match ident.name.as_ref() { - "i16" => IntegerType::new(context, 16).into(), - "i8" => IntegerType::new(context, 8).into(), - "u16" => IntegerType::new(context, 16).into(), - _ => unimplemented!( - "Type identifier '{}' not yet supported in MLIR conversion", - ident.name - ), +impl ScalarTy { + /// Convert scalar reference to MLIR memref type with address space + pub fn to_mlir_ref<'c>(&self, mem: &Memory, context: &'c Context) -> Type<'c> { + // Scalar 
reference -> rank-0 memref + let elem_type = self.to_mlir(context); + let memref_str = format!("memref<{}>", elem_type); + let base_type = + Type::parse(context, &memref_str).expect("Failed to parse rank-0 memref type"); + + // Add HIVM address space if needed + let base_str = base_type.to_string(); + let final_str = apply_hivm_address_space(base_str, mem); + parse_type_with_hivm_fallback(context, final_str, base_type) + } +} + +impl ToMlir for crate::ast::Ident { + type Output<'c> = Type<'c>; + + fn to_mlir<'c>(&self, context: &'c Context) -> Type<'c> { + match self.name.as_ref() { + "i16" => IntegerType::new(context, 16).into(), + "i8" => IntegerType::new(context, 8).into(), + "u16" => IntegerType::new(context, 16).into(), + _ => unimplemented!( + "Type identifier '{}' not yet supported in MLIR conversion", + self.name + ), + } } } @@ -112,28 +140,6 @@ fn parse_type_with_hivm_fallback<'c>( } } -/// Helper function to convert DataTy to element type for arrays/memrefs -fn data_ty_to_element_type<'c>(data_ty: &DataTy, context: &'c Context) -> Type<'c> { - match &data_ty.dty { - DataTyKind::Scalar(scalar_ty) => scalar_ty_to_mlir(scalar_ty, context), - DataTyKind::Ident(ident) => ident_to_mlir(ident, context), - _ => data_ty.to_mlir(context), // Fallback to full conversion for complex types - } -} - -/// Helper function to convert scalar reference to MLIR type -fn ref_scalar_to_mlir<'c>(scalar_ty: &ScalarTy, mem: &Memory, context: &'c Context) -> Type<'c> { - // Scalar reference -> rank-0 memref - let elem_type = scalar_ty_to_mlir(scalar_ty, context); - let memref_str = format!("memref<{}>", elem_type); - let base_type = Type::parse(context, &memref_str).expect("Failed to parse rank-0 memref type"); - - // Add HIVM address space if needed - let base_str = base_type.to_string(); - let final_str = apply_hivm_address_space(base_str, mem); - parse_type_with_hivm_fallback(context, final_str, base_type) -} - /// Helper function to convert array reference to MLIR type fn ref_array_to_mlir<'c>( elem_ty: &DataTy, @@ -142,10 +148,9 @@ fn ref_array_to_mlir<'c>( context: &'c Context, ) -> Type<'c> { // Array reference -> memref with dimensions - let elem_type = data_ty_to_element_type(elem_ty, context); - let dim = nat_to_dimension(size); - let memref_str = format!("memref<{}x{}>", dim, elem_type); - let base_type = Type::parse(context, &memref_str).expect("Failed to parse array memref type"); + let elem_type = elem_ty.to_mlir(context); + let dim = size.to_dimension_i64(); + let base_type: Type<'c> = MemRefType::new(elem_type, &[dim], None, None).into(); // Add HIVM address space if needed let base_str = base_type.to_string(); @@ -158,15 +163,14 @@ fn ref_at_to_mlir<'c>(inner: &DataTy, mem: &Memory, context: &'c Context) -> Typ // Build base memref type from the inner data type, then append address space if needed let base_type = match &inner.dty { DataTyKind::Scalar(scalar_ty) => { - let elem_type = scalar_ty_to_mlir(scalar_ty, context); + let elem_type = scalar_ty.to_mlir(context); let memref_str = format!("memref<{}>", elem_type); Type::parse(context, &memref_str).expect("Failed to parse scalar memref type") } DataTyKind::Array(elem_ty, size) | DataTyKind::ArrayShape(elem_ty, size) => { let elem_type = elem_ty.to_mlir(context); - let dim = nat_to_dimension(size); - let memref_str = format!("memref<{}x{}>", dim, elem_type); - Type::parse(context, &memref_str).expect("Failed to parse array memref type") + let dim = size.to_dimension_i64(); + MemRefType::new(elem_type, &[dim], None, None).into() } 
DataTyKind::Tuple(_) => { unimplemented!("Tuple references with At not yet supported in MLIR conversion") @@ -222,7 +226,7 @@ impl ToMlir for DataTy { fn to_mlir<'c>(&self, context: &'c Context) -> Type<'c> { match &self.dty { - DataTyKind::Scalar(scalar_ty) => scalar_ty_to_mlir(scalar_ty, context), + DataTyKind::Scalar(scalar_ty) => scalar_ty.to_mlir(context), DataTyKind::Atomic(atomic_ty) => match atomic_ty { AtomicTy::AtomicU32 => IntegerType::new(context, 32).into(), AtomicTy::AtomicI32 => IntegerType::new(context, 32).into(), @@ -232,20 +236,16 @@ impl ToMlir for DataTy { elem_tys.iter().map(|ty| ty.to_mlir(context)).collect(); TupleType::new(context, &elem_types).into() } - DataTyKind::Ident(ident) => ident_to_mlir(ident, context), + DataTyKind::Ident(ident) => ident.to_mlir(context), DataTyKind::Array(elem_ty, size) => { let elem_type = elem_ty.to_mlir(context); - let dim = nat_to_dimension(size); - let memref_str = format!("memref<{}x{}>", dim, elem_type); - Type::parse(context, &memref_str).expect("Failed to parse memref type") + let dim = size.to_dimension_i64(); + MemRefType::new(elem_type, &[dim], None, None).into() } DataTyKind::ArrayShape(elem_ty, size) => { - // ArrayShape is similar to Array but may have different semantics - // For now, treat it the same as Array using memref let elem_type = elem_ty.to_mlir(context); - let dim = nat_to_dimension(size); - let memref_str = format!("memref<{}x{}>", dim, elem_type); - Type::parse(context, &memref_str).expect("Failed to parse memref type") + let dim = size.to_dimension_i64(); + MemRefType::new(elem_type, &[dim], None, None).into() } DataTyKind::Struct(_) => { unimplemented!("Struct types not yet supported in MLIR conversion") @@ -262,9 +262,7 @@ impl ToMlir for DataTy { DataTyKind::Ref(ref_dty) => { // Convert the inner DataTy to MLIR based on its kind match &ref_dty.dty.dty { - DataTyKind::Scalar(scalar_ty) => { - ref_scalar_to_mlir(scalar_ty, &ref_dty.mem, context) - } + DataTyKind::Scalar(scalar_ty) => scalar_ty.to_mlir_ref(&ref_dty.mem, context), DataTyKind::Array(elem_ty, size) => { ref_array_to_mlir(elem_ty, size, &ref_dty.mem, context) } @@ -351,15 +349,16 @@ impl ToMlir for FunDef { let function_name = &self.ident.name; // Create attributes based on execution type - let attributes: Vec<(Identifier, Attribute)> = if matches!(self.exec.exec.base, BaseExec::GpuGrid(_, _)) { - // For GPU functions, we need to create HACC attributes - // Since we can't easily create HACC dialect attributes here, we'll use empty attributes - // The actual GPU attributes will be handled in the string-based generation path - vec![] - } else { - // For CPU functions, no special attributes needed - vec![] - }; + let attributes: Vec<(Identifier, Attribute)> = + if matches!(self.exec.exec.base, BaseExec::GpuGrid(_, _)) { + // For GPU functions, we need to create HACC attributes + // Since we can't easily create HACC dialect attributes here, we'll use empty attributes + // The actual GPU attributes will be handled in the string-based generation path + vec![] + } else { + // For CPU functions, no special attributes needed + vec![] + }; func::func( context, @@ -381,17 +380,20 @@ struct ParameterUsage { impl ParameterUsage { fn new() -> Self { - Self { read: false, write: false } + Self { + read: false, + write: false, + } } - + fn mark_read(&mut self) { self.read = true; } - + fn mark_write(&mut self) { self.write = true; } - + fn needs_ub_allocation(&self) -> bool { // Only allocate ub memory if the parameter is read from // (parameters that are only 
written to can write directly to global memory) @@ -400,18 +402,23 @@ impl ParameterUsage { } /// Collect which parameters are referenced in the function body and how they are used -fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMap { +fn collect_parameter_usage( + fun: &crate::ast::FunDef, +) -> std::collections::HashMap { use crate::ast::{Expr, ExprKind, PlaceExprKind}; use std::collections::HashMap; - + let mut param_usage = HashMap::new(); - + fn walk_expr(expr: &Expr, param_usage: &mut HashMap) { match &expr.expr { ExprKind::PlaceExpr(place_expr) => { // This is a read operation if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr { - param_usage.entry(ident.name.to_string()).or_insert_with(ParameterUsage::new).mark_read(); + param_usage + .entry(ident.name.to_string()) + .or_insert_with(ParameterUsage::new) + .mark_read(); } } ExprKind::BinOp(_, lhs, rhs) => { @@ -429,14 +436,20 @@ fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMa ExprKind::Assign(place_expr, value_expr) => { // This is a write operation if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr { - param_usage.entry(ident.name.to_string()).or_insert_with(ParameterUsage::new).mark_write(); + param_usage + .entry(ident.name.to_string()) + .or_insert_with(ParameterUsage::new) + .mark_write(); } walk_expr(value_expr, param_usage); } ExprKind::IdxAssign(place_expr, _, value_expr) => { // This is a write operation if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr { - param_usage.entry(ident.name.to_string()).or_insert_with(ParameterUsage::new).mark_write(); + param_usage + .entry(ident.name.to_string()) + .or_insert_with(ParameterUsage::new) + .mark_write(); } walk_expr(value_expr, param_usage); } @@ -460,7 +473,10 @@ fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMa ExprKind::Ref(_, _, place_expr) => { // Taking a reference is a read operation if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr { - param_usage.entry(ident.name.to_string()).or_insert_with(ParameterUsage::new).mark_read(); + param_usage + .entry(ident.name.to_string()) + .or_insert_with(ParameterUsage::new) + .mark_read(); } } ExprKind::Unsafe(expr) => { @@ -471,7 +487,7 @@ fn collect_parameter_usage(fun: &crate::ast::FunDef) -> std::collections::HashMa } } } - + walk_expr(&fun.body.body, &mut param_usage); param_usage } @@ -483,10 +499,10 @@ fn generate_body_operations( alloc_counter: &mut usize, context: &Context, ) -> String { - use crate::ast::{Expr, ExprKind, PlaceExprKind, DataTyKind, Memory, TyKind}; - + use crate::ast::{DataTyKind, Expr, ExprKind, Memory, PlaceExprKind, TyKind}; + let mut body_ops = String::new(); - + fn walk_expr( expr: &Expr, fun: &crate::ast::FunDef, @@ -498,9 +514,11 @@ fn generate_body_operations( match &expr.expr { ExprKind::BinOp(binop, lhs, rhs) => { // Process left and right operands - let lhs_var = walk_expr(lhs, fun, param_to_local, alloc_counter, context, body_ops)?; - let rhs_var = walk_expr(rhs, fun, param_to_local, alloc_counter, context, body_ops)?; - + let lhs_var = + walk_expr(lhs, fun, param_to_local, alloc_counter, context, body_ops)?; + let rhs_var = + walk_expr(rhs, fun, param_to_local, alloc_counter, context, body_ops)?; + // Generate output allocation let output_alloc = if *alloc_counter == 0 { "%alloc".to_string() @@ -508,32 +526,42 @@ fn generate_body_operations( format!("%alloc_{}", *alloc_counter - 1) }; *alloc_counter += 1; - + // Determine the type for allocation (use the type of lhs as reference) 
let lhs_type = match &lhs.expr { ExprKind::PlaceExpr(place_expr) => { if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr { // Find the parameter declaration to get its type - if let Some(param_decl) = fun.param_decls.iter().find(|p| p.ident.name == ident.name) { + if let Some(param_decl) = + fun.param_decls.iter().find(|p| p.ident.name == ident.name) + { if let Some(param_ty) = ¶m_decl.ty { // Convert the parameter type to ub address space match ¶m_ty.ty { - TyKind::Data(data_ty) => { - match &data_ty.dty { - DataTyKind::At(inner, _) => { - let base_type = inner.to_mlir(context); - let base_str = base_type.to_string(); - apply_hivm_address_space(base_str, &Memory::GpuLocal) - } - DataTyKind::Ref(ref_dty) => { - let base_type = ref_dty.dty.to_mlir(context); - let base_str = base_type.to_string(); - apply_hivm_address_space(base_str, &Memory::GpuLocal) - } - _ => get_mlir_type_string_with_address_space(param_ty, context), + TyKind::Data(data_ty) => match &data_ty.dty { + DataTyKind::At(inner, _) => { + let base_type = inner.to_mlir(context); + let base_str = base_type.to_string(); + apply_hivm_address_space( + base_str, + &Memory::GpuLocal, + ) + } + DataTyKind::Ref(ref_dty) => { + let base_type = ref_dty.dty.to_mlir(context); + let base_str = base_type.to_string(); + apply_hivm_address_space( + base_str, + &Memory::GpuLocal, + ) } - } - _ => get_mlir_type_string_with_address_space(param_ty, context), + _ => get_mlir_type_string_with_address_space( + param_ty, context, + ), + }, + _ => get_mlir_type_string_with_address_space( + param_ty, context, + ), } } else { return None; @@ -547,23 +575,20 @@ fn generate_body_operations( } _ => return None, }; - + // Generate allocation - body_ops.push_str(&format!(" {} = memref.alloc() : {}\n", output_alloc, lhs_type)); - + body_ops.push_str(&format!( + " {} = memref.alloc() : {}\n", + output_alloc, lhs_type + )); + // Generate HIVM operation let hivm_op = binop_to_hivm_operation(binop); body_ops.push_str(&format!( " {} ins({}, {} : {}, {}) outs({} : {})\n", - hivm_op, - lhs_var, - rhs_var, - lhs_type, - lhs_type, - output_alloc, - lhs_type + hivm_op, lhs_var, rhs_var, lhs_type, lhs_type, output_alloc, lhs_type )); - + Some(output_alloc) } ExprKind::PlaceExpr(place_expr) => { @@ -577,12 +602,24 @@ fn generate_body_operations( ExprKind::Assign(place_expr, value_expr) => { // Handle assignment: b = a // First, evaluate the right-hand side (value expression) - let value_var = walk_expr(value_expr, fun, param_to_local, alloc_counter, context, body_ops)?; - + let value_var = walk_expr( + value_expr, + fun, + param_to_local, + alloc_counter, + context, + body_ops, + )?; + // Find the target parameter for the assignment if let PlaceExprKind::Ident(ident) = &place_expr.pl_expr { // Find the parameter declaration to get its type and index - if let Some((param_idx, param_decl)) = fun.param_decls.iter().enumerate().find(|(_, p)| p.ident.name == ident.name) { + if let Some((param_idx, param_decl)) = fun + .param_decls + .iter() + .enumerate() + .find(|(_, p)| p.ident.name == ident.name) + { if let Some(param_ty) = ¶m_decl.ty { // Check if we're assigning to a reference - if so, it must be unique if let TyKind::Data(data_ty) = ¶m_ty.ty { @@ -595,42 +632,38 @@ fn generate_body_operations( } } } - + // Generate the target parameter type (should be gm address space) - let target_type = get_mlir_type_string_with_address_space(param_ty, context); - + let target_type = + get_mlir_type_string_with_address_space(param_ty, context); + // Generate the source type 
(should be ub address space for local allocations) let source_type = match ¶m_ty.ty { - TyKind::Data(data_ty) => { - match &data_ty.dty { - DataTyKind::At(inner, _) => { - let base_type = inner.to_mlir(context); - let base_str = base_type.to_string(); - apply_hivm_address_space(base_str, &Memory::GpuLocal) - } - DataTyKind::Ref(ref_dty) => { - let base_type = ref_dty.dty.to_mlir(context); - let base_str = base_type.to_string(); - apply_hivm_address_space(base_str, &Memory::GpuLocal) - } - _ => get_mlir_type_string_with_address_space(param_ty, context), + TyKind::Data(data_ty) => match &data_ty.dty { + DataTyKind::At(inner, _) => { + let base_type = inner.to_mlir(context); + let base_str = base_type.to_string(); + apply_hivm_address_space(base_str, &Memory::GpuLocal) } - } + DataTyKind::Ref(ref_dty) => { + let base_type = ref_dty.dty.to_mlir(context); + let base_str = base_type.to_string(); + apply_hivm_address_space(base_str, &Memory::GpuLocal) + } + _ => get_mlir_type_string_with_address_space(param_ty, context), + }, _ => get_mlir_type_string_with_address_space(param_ty, context), }; - + // Generate store operation: hivm.hir.store ins(value) outs(%argN) body_ops.push_str(&format!( " hivm.hir.store ins({} : {}) outs(%arg{} : {})\n", - value_var, - source_type, - param_idx, - target_type + value_var, source_type, param_idx, target_type )); } } } - + // Assignment doesn't produce a value None } @@ -638,7 +671,8 @@ fn generate_body_operations( // Process sequence expressions, return the result of the last expression let mut last_result = None; for expr in exprs { - last_result = walk_expr(expr, fun, param_to_local, alloc_counter, context, body_ops); + last_result = + walk_expr(expr, fun, param_to_local, alloc_counter, context, body_ops); } last_result } @@ -652,39 +686,46 @@ fn generate_body_operations( } } } - - walk_expr(&fun.body.body, fun, param_to_local, alloc_counter, context, &mut body_ops); + + walk_expr( + &fun.body.body, + fun, + param_to_local, + alloc_counter, + context, + &mut body_ops, + ); body_ops } /// Generate load operations for GPU parameters /// Returns (operations_string, param_to_local_map, final_alloc_counter) fn generate_load_operations( - fun: &crate::ast::FunDef, + fun: &crate::ast::FunDef, param_usage: &std::collections::HashMap, - context: &Context + context: &Context, ) -> (String, std::collections::HashMap, usize) { use crate::ast::{DataTyKind, Memory, TyKind}; use std::collections::HashMap; - + let mut load_ops = String::new(); let mut param_to_local = HashMap::new(); let mut alloc_counter = 0; - + for (i, param) in fun.param_decls.iter().enumerate() { let param_name = param.ident.name.to_string(); - + // Check if parameter is used and needs ub allocation let usage = match param_usage.get(¶m_name) { Some(usage) => usage, None => continue, // Parameter not used at all }; - + // Only allocate ub memory if the parameter is read from if !usage.needs_ub_allocation() { continue; } - + if let Some(ty) = ¶m.ty { if let TyKind::Data(data_ty) = &ty.ty { let needs_gpu_load = match &data_ty.dty { @@ -696,11 +737,11 @@ fn generate_load_operations( } _ => false, }; - + if needs_gpu_load { // Generate the original type with gm address space let gm_type = get_mlir_type_string_with_address_space(ty, context); - + // Generate the local type with ub address space let ub_type = match &data_ty.dty { DataTyKind::At(inner, _) => { @@ -715,17 +756,22 @@ fn generate_load_operations( } _ => gm_type.clone(), }; - + // Generate alloc and load operations let alloc_name = if alloc_counter 
== 0 {
                     "%alloc".to_string()
                 } else {
                     format!("%alloc_{}", alloc_counter - 1)
                 };
-                load_ops.push_str(&format!("    {} = memref.alloc() : {}\n", alloc_name, ub_type));
-                load_ops.push_str(&format!("    hivm.hir.load ins(%arg{} : {}) outs({} : {})\n",
-                    i, gm_type, alloc_name, ub_type));
-
+                load_ops.push_str(&format!(
+                    "    {} = memref.alloc() : {}\n",
+                    alloc_name, ub_type
+                ));
+                load_ops.push_str(&format!(
+                    "    hivm.hir.load ins(%arg{} : {}) outs({} : {})\n",
+                    i, gm_type, alloc_name, ub_type
+                ));
+
                 // Map parameter to its local version
                 param_to_local.insert(param_name, alloc_name);
                 alloc_counter += 1;
@@ -733,7 +779,7 @@ fn generate_load_operations(
             }
         }
     }
-    
+
     (load_ops, param_to_local, alloc_counter)
 }
 
@@ -772,22 +818,24 @@ fn generate_function_with_body(fun: &crate::ast::FunDef, context: &Context) -> S
         // TODO: When HACC dialect is registered in the MLIR context, replace this with:
         // let hacc_attributes = create_hacc_attributes(context);
         // and use the attributes with MLIR operation builders instead of string generation
-        signature.push_str(" attributes {hacc.entry, hacc.function_kind = #hacc.function_kind}");
+        signature
+            .push_str(" attributes {hacc.entry, hacc.function_kind = #hacc.function_kind}");
     }
     signature.push_str(" {\n");
-    
+
     // Collect parameter usage information
     let param_usage = collect_parameter_usage(fun);
-    
+
     // Generate load operations for GPU parameters (only for read usage)
-    let (load_ops, param_to_local, mut alloc_counter) = generate_load_operations(fun, &param_usage, context);
+    let (load_ops, param_to_local, mut alloc_counter) =
+        generate_load_operations(fun, &param_usage, context);
     signature.push_str(&load_ops);
-    
+
     // Generate body operations (binary operations, etc.)
     let body_ops = generate_body_operations(fun, &param_to_local, &mut alloc_counter, context);
     signature.push_str(&body_ops);
-    
+
     signature.push_str("    return\n  }\n");
     signature
 }
@@ -971,17 +1019,6 @@ mod tests {
         assert_eq!(mlir_type.to_string(), "memref<10xf32>");
     }
 
-    #[test]
-    fn test_array_with_dynamic_size_to_mlir() {
-        let context = Context::new();
-        let data_ty = make_data_ty(DataTyKind::Array(
-            Box::new(make_data_ty(DataTyKind::Scalar(ScalarTy::I64))),
-            Nat::Ident(Ident::new("n")),
-        ));
-        let mlir_type = data_ty.to_mlir(&context);
-        assert_eq!(mlir_type.to_string(), "memref<?xi64>");
-    }
-
     #[test]
     fn test_array_shape_with_literal_size_to_mlir() {
         let context = Context::new();
@@ -1038,23 +1075,6 @@ mod tests {
         assert_eq!(mlir_type.to_string(), "memref<10xf32>");
     }
 
-    #[test]
-    fn test_ref_array_dynamic_to_mlir() {
-        let context = Context::new();
-        let ref_dty = RefDty::new(
-            Provenance::Ident(Ident::new("r")),
-            Ownership::Shrd,
-            Memory::CpuMem,
-            make_data_ty(DataTyKind::Array(
-                Box::new(make_data_ty(DataTyKind::Scalar(ScalarTy::I64))),
-                Nat::Ident(Ident::new("n")),
-            )),
-        );
-        let data_ty = make_data_ty(DataTyKind::Ref(Box::new(ref_dty)));
-        let mlir_type = data_ty.to_mlir(&context);
-        assert_eq!(mlir_type.to_string(), "memref<?xi64>");
-    }
-
     /// Helper function to test At type lowering without MLIR parsing (avoids HIVM dialect registration)
     fn test_at_type_string(data_ty: &DataTy, context: &Context) -> String {
         match &data_ty.dty {
@@ -1097,8 +1117,10 @@ mod tests {
 
     /// Helper function to create a minimal GPU function with GpuGrid execution context
     fn make_gpu_function() -> FunDef {
-        use crate::ast::{Block, DataTy, DataTyKind, Dim, Dim1d, ExecExpr, ExecExprKind, Ident, ScalarTy, Span};
-
+        use crate::ast::{
+            Block, DataTy, DataTyKind, Dim, Dim1d, ExecExpr, ExecExprKind, Ident, ScalarTy, Span,
+        };
+
         FunDef {
ident: Ident { name: "gpu_kernel".into(), @@ -1137,8 +1159,10 @@ mod tests { /// Helper function to create a minimal CPU function with CpuThread execution context fn make_cpu_function() -> FunDef { - use crate::ast::{Block, DataTy, DataTyKind, ExecExpr, ExecExprKind, Ident, ScalarTy, Span}; - + use crate::ast::{ + Block, DataTy, DataTyKind, ExecExpr, ExecExprKind, Ident, ScalarTy, Span, + }; + FunDef { ident: Ident { name: "cpu_function".into(), @@ -1177,9 +1201,10 @@ mod tests { let context = Context::new(); let gpu_fun = make_gpu_function(); let signature = generate_function_with_body(&gpu_fun, &context); - + // Check that the signature contains the GPU attributes - assert!(signature.contains("attributes {hacc.entry, hacc.function_kind = #hacc.function_kind}")); + assert!(signature + .contains("attributes {hacc.entry, hacc.function_kind = #hacc.function_kind}")); assert!(signature.contains("func.func @gpu_kernel")); assert!(signature.contains(") attributes")); } @@ -1189,9 +1214,10 @@ mod tests { let context = Context::new(); let cpu_fun = make_cpu_function(); let signature = generate_function_with_body(&cpu_fun, &context); - + // Check that the signature does NOT contain GPU attributes - assert!(!signature.contains("attributes {hacc.entry, hacc.function_kind = #hacc.function_kind}")); + assert!(!signature + .contains("attributes {hacc.entry, hacc.function_kind = #hacc.function_kind}")); assert!(signature.contains("func.func @cpu_function")); assert!(signature.contains(") {")); assert!(!signature.contains(") attributes")); @@ -1202,17 +1228,19 @@ mod tests { let context = Context::new(); let gpu_fun = make_gpu_function(); let signature = generate_function_with_body(&gpu_fun, &context); - + // Check that attributes appear in the correct position (after params, before brace) let lines: Vec<&str> = signature.lines().collect(); - + // The function signature should have 3 lines: func declaration, return, closing brace assert_eq!(lines.len(), 3); - + let func_line = lines[0]; // First line should be the function declaration assert!(func_line.contains("func.func @gpu_kernel")); - assert!(func_line.contains(") attributes {hacc.entry, hacc.function_kind = #hacc.function_kind} {")); - + assert!(func_line.contains( + ") attributes {hacc.entry, hacc.function_kind = #hacc.function_kind} {" + )); + // Verify the structure: function_name() attributes { ... 
} {
         let parts: Vec<&str> = func_line.split(") attributes").collect();
         assert_eq!(parts.len(), 2);
diff --git a/src/lib.rs b/src/lib.rs
index 998286a..4a08706 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,4 +29,4 @@ pub fn compile(file_path: &str, backend: Backend) -> Result<(String, String), Er
     let ast_string = format!("{:#?}", compil_unit.items);
 
     Ok((code_string, ast_string))
-}
\ No newline at end of file
+}
diff --git a/src/main.rs b/src/main.rs
index 344b687..2dfeaba 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -49,12 +49,11 @@ fn main() {
 
     // Compile using Descend
     let (code_string, ast_string) = match descend::compile(&input_path.to_string_lossy(), backend)
-            Ok(output) => output,
-            Err(e) => {
-                eprintln!("Compilation failed: {:?}", e);
-                std::process::exit(1);
-            }
-
+        Ok(output) => output,
+        Err(e) => {
+            eprintln!("Compilation failed: {:?}", e);
+            std::process::exit(1);
+        }
     };
 
     // Generate output file path with appropriate extension based on backend

From da42644b9e02a5064626831425a7a8e2d7993d3b Mon Sep 17 00:00:00 2001
From: lucarlig
Date: Wed, 22 Oct 2025 15:43:12 +0100
Subject: [PATCH 4/4] add more tests

---
 examples/core/vec_add.desc.off                |  8 -------
 ...{vec_add_2.desc.off => vec_add_2_ops.desc} |  0
 examples/core/vec_add_3.desc.off              |  7 -------
 .../snapshots/mlir__core__vec_add_2_ops.snap  | 21 +++++++++++++++++++
 4 files changed, 21 insertions(+), 15 deletions(-)
 delete mode 100644 examples/core/vec_add.desc.off
 rename examples/core/{vec_add_2.desc.off => vec_add_2_ops.desc} (100%)
 delete mode 100644 examples/core/vec_add_3.desc.off
 create mode 100644 tests/mlir/snapshots/mlir__core__vec_add_2_ops.snap

diff --git a/examples/core/vec_add.desc.off b/examples/core/vec_add.desc.off
deleted file mode 100644
index 2b55c7b..0000000
--- a/examples/core/vec_add.desc.off
+++ /dev/null
@@ -1,8 +0,0 @@
-fn add(
-    a: &r shrd gpu.global [i16; 16],
-    b: &r shrd gpu.global [i16; 16],
-    c: &r uniq gpu.global [i16; 16]
-) -[grid: gpu.grid<X<1>, X<16>>]-> () {
-    a + b;
-    ()
-}
diff --git a/examples/core/vec_add_2.desc.off b/examples/core/vec_add_2_ops.desc
similarity index 100%
rename from examples/core/vec_add_2.desc.off
rename to examples/core/vec_add_2_ops.desc
diff --git a/examples/core/vec_add_3.desc.off b/examples/core/vec_add_3.desc.off
deleted file mode 100644
index c89c49d..0000000
--- a/examples/core/vec_add_3.desc.off
+++ /dev/null
@@ -1,7 +0,0 @@
-fn add(
-    a: &r shrd gpu.global [i16; 16],
-    b: &r shrd gpu.global [i16; 16]
-) -[grid: gpu.grid<X<1>, X<16>>]-> () {
-    a + b;
-    ()
-}
diff --git a/tests/mlir/snapshots/mlir__core__vec_add_2_ops.snap b/tests/mlir/snapshots/mlir__core__vec_add_2_ops.snap
new file mode 100644
index 0000000..1b8a5bf
--- /dev/null
+++ b/tests/mlir/snapshots/mlir__core__vec_add_2_ops.snap
@@ -0,0 +1,21 @@
+---
+source: tests/mlir/core.rs
+expression: output
+---
+module {
+  func.func @add(%arg0: memref<16xi16, #hivm.address_space<gm>>, %arg1: memref<16xi16, #hivm.address_space<gm>>, %arg2: memref<16xi16, #hivm.address_space<gm>>, %arg3: memref<16xi16, #hivm.address_space<gm>>, %arg4: memref<16xi16, #hivm.address_space<gm>>) attributes {hacc.entry, hacc.function_kind = #hacc.function_kind} {
+    %alloc = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg0 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_0 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg1 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_0 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_1 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.load ins(%arg2 : memref<16xi16, #hivm.address_space<gm>>) outs(%alloc_1 : memref<16xi16, #hivm.address_space<ub>>)
+    %alloc_2 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.vadd ins(%alloc, %alloc_0 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_2 : memref<16xi16, #hivm.address_space<ub>>)
+    hivm.hir.store ins(%alloc_2 : memref<16xi16, #hivm.address_space<ub>>) outs(%arg3 : memref<16xi16, #hivm.address_space<gm>>)
+    %alloc_3 = memref.alloc() : memref<16xi16, #hivm.address_space<ub>>
+    hivm.hir.vadd ins(%alloc_0, %alloc_1 : memref<16xi16, #hivm.address_space<ub>>, memref<16xi16, #hivm.address_space<ub>>) outs(%alloc_3 : memref<16xi16, #hivm.address_space<ub>>)
+    hivm.hir.store ins(%alloc_3 : memref<16xi16, #hivm.address_space<ub>>) outs(%arg4 : memref<16xi16, #hivm.address_space<gm>>)
+    return
+  }
+}
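
Note for reviewers checking the snapshot by hand: the alloc/op line shapes above come directly
from the string emission added in types.rs. The following stand-alone Rust sketch reproduces
that emission scheme; `BinOp`, `binop_to_hivm`, `alloc_name`, and the `ty` string are
illustrative stand-ins for this note only, not the crate's actual `crate::ast::BinOp` or
codegen API.

// Stand-in for the crate's BinOp; only the ops this patch series supports.
enum BinOp { Add, Mul, Div, Mod }

// Same operator-to-intrinsic mapping as binop_to_hivm_operation in the patch.
fn binop_to_hivm(op: &BinOp) -> &'static str {
    match op {
        BinOp::Add => "hivm.hir.vadd",
        BinOp::Mul => "hivm.hir.vmul",
        BinOp::Div => "hivm.hir.vdiv",
        BinOp::Mod => "hivm.hir.vmod",
    }
}

// Mirrors the patch's counter scheme (and MLIR's printer):
// the first result is %alloc, later ones %alloc_0, %alloc_1, ...
fn alloc_name(counter: usize) -> String {
    if counter == 0 {
        "%alloc".to_string()
    } else {
        format!("%alloc_{}", counter - 1)
    }
}

fn main() {
    // Assumed ub-side memref type, matching the vadd snapshot above.
    let ty = "memref<16xi16, #hivm.address_space<ub>>";
    let (lhs, rhs, out) = (alloc_name(0), alloc_name(1), alloc_name(2));
    println!("{} = memref.alloc() : {}", out, ty);
    println!(
        "{} ins({}, {} : {}, {}) outs({} : {})",
        binop_to_hivm(&BinOp::Add),
        lhs, rhs, ty, ty, out, ty
    );
}

Running it prints the same memref.alloc + hivm.hir.vadd pair that the compiler emits for
`a + b` in the snapshots above, modulo leading indentation.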