diff --git a/Cargo.toml b/Cargo.toml index 49eb0e0b7d..49cf3918dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,6 +98,9 @@ half = { version = "2.5", features = [ "num-traits", "serde", ], default-features = false } +num-complex = { version = "0.4", default-features = false, features = [ + "bytemuck", +] } num-traits = { version = "0.2.19", default-features = false, features = [ "libm", ] } # libm is for no_std diff --git a/crates/cubecl-common/src/arena.rs b/crates/cubecl-common/src/arena.rs index 97a08c9eaf..7ccf29caf9 100644 --- a/crates/cubecl-common/src/arena.rs +++ b/crates/cubecl-common/src/arena.rs @@ -175,31 +175,23 @@ impl Clone for ReservedMemory { impl Drop for ReservedMemory { fn drop(&mut self) { - // Ref-count lifecycle: - // reserve() → stores 1 (arena holds the slot) - // init() → consumes UninitReservedMemory, inherits count 1 - // clone() → fetch_add (count grows with each clone: 2, 3, …) - // drop() → fetch_sub (count shrinks) + // `ref_count` equals the number of live `ReservedMemory` clones. + // reserve() → stores 1 (the one clone that `init` will produce) + // init() → consumes UninitReservedMemory, count unchanged at 1 + // clone() → fetch_add (count grows: 2, 3, …) + // drop() → fetch_sub (count shrinks; previous == 1 means we + // were the last clone, so run the destructor) // - // When `fetch_sub` returns 2 it means the count just moved from 2→1, - // and we are the last `ReservedMemory` clone — the remaining "1" is - // the arena's own ref-count baseline. At this point no other clone - // can access the data, so we can safely run the destructor. - // - // If the arena has already been dropped, the same logic applies: the - // last clone still sees previous == 2 because the arena never - // decrements the logical ref_count — only `ReservedMemory::drop` does. - let drop_fn = || { + // The arena never touches `ref_count`; slot freeness is tracked via + // `Arc::strong_count` on the backing buffer instead. 
So the same + // logic is correct whether or not the arena is still alive. + let previous = self.ref_count.fetch_sub(1, Ordering::Release); + + if previous == 1 { // SAFETY: We are the last user of this slot. The data pointer is valid, // initialized, and no other `ReservedMemory` clone exists. let bytes_mut = unsafe { self.data.get().as_mut().unwrap() }; (self.drop_fn)(bytes_mut); - }; - - let previous = self.ref_count.fetch_sub(1, Ordering::Release); - - if previous == 2 { - drop_fn(); } } } @@ -406,6 +398,198 @@ mod tests { } } +#[cfg(test)] +mod drop_lifecycle_tests { + //! These tests verify the drop-timing contract of `ReservedMemory`: + //! `drop_fn` must run exactly once, and **only when the last clone is + //! released**. + //! + //! Every assertion is expressed through `Arc::strong_count` on a + //! payload-owned anchor. Each case is + //! runnable under `cargo miri test` to catch UB. + + use super::*; + use alloc::boxed::Box; + use alloc::vec::Vec; + use std::sync::Arc; + + struct Payload { + _anchor: Arc<()>, + } + + #[test] + fn last_clone_runs_destructor_with_one_clone() { + let anchor = Arc::new(()); + let mut arena = Arena::<4, 256>::new(); + let reserved = arena.reserve().unwrap().init(Payload { + _anchor: anchor.clone(), + }); + assert_eq!(Arc::strong_count(&anchor), 2); + drop(reserved); + assert_eq!( + Arc::strong_count(&anchor), + 1, + "single ReservedMemory must run drop_fn on drop" + ); + } + + #[test] + fn destructor_deferred_until_last_of_two_clones() { + let anchor = Arc::new(()); + let mut arena = Arena::<4, 256>::new(); + let a = arena.reserve().unwrap().init(Payload { + _anchor: anchor.clone(), + }); + let b = a.clone(); + + drop(a); + assert_eq!( + Arc::strong_count(&anchor), + 2, + "destructor fired prematurely — `b` still owns the payload" + ); + drop(b); + assert_eq!(Arc::strong_count(&anchor), 1); + } + + #[test] + fn destructor_deferred_until_last_of_many_clones() { + let anchor = Arc::new(()); + let mut arena = Arena::<4, 
256>::new(); + let first = arena.reserve().unwrap().init(Payload { + _anchor: anchor.clone(), + }); + + const N: usize = 16; + let clones: Vec<_> = (0..N).map(|_| first.clone()).collect(); + drop(first); + assert_eq!(Arc::strong_count(&anchor), 2); + + for (i, c) in clones.into_iter().enumerate() { + drop(c); + let expected = if i + 1 == N { 1 } else { 2 }; + assert_eq!( + Arc::strong_count(&anchor), + expected, + "premature destructor after dropping clone {i}" + ); + } + } + + /// After the destructor runs once, re-cloning a surviving clone and + /// dropping it again must not run the destructor a second time. + #[test] + fn destructor_runs_exactly_once_across_refill_cycle() { + let anchor1 = Arc::new(()); + let anchor2 = Arc::new(()); + let mut arena = Arena::<1, 256>::new(); + + let first = arena.reserve().unwrap().init(Payload { + _anchor: anchor1.clone(), + }); + drop(first); + assert_eq!(Arc::strong_count(&anchor1), 1); + + // Slot is free again — reuse it with a brand-new payload. + let second = arena.reserve().unwrap().init(Payload { + _anchor: anchor2.clone(), + }); + assert_eq!( + Arc::strong_count(&anchor1), + 1, + "refilling the slot must not touch the prior payload's anchor" + ); + assert_eq!(Arc::strong_count(&anchor2), 2); + drop(second); + assert_eq!(Arc::strong_count(&anchor2), 1); + } + + /// A payload owning a heap allocation (`Box`) gives Miri something + /// concrete to complain about if the destructor is missed or runs + /// twice: a double drop is a double-free. + #[test] + fn heap_owning_payload_drops_exactly_once() { + struct HeapOwner(#[allow(dead_code)] Box<[u64; 8]>); + + let mut arena = Arena::<2, 256>::new(); + let a = arena + .reserve() + .unwrap() + .init(HeapOwner(Box::new([1, 2, 3, 4, 5, 6, 7, 8]))); + let b = a.clone(); + let c = a.clone(); + drop(a); + drop(b); + drop(c); + // Miri would flag a double-free on the Box here if the destructor + // fired more than once, or a leak under `-Zmiri-ignore-leaks=no` + // if it never fired. 
+ } +} + +#[cfg(test)] +mod concurrent_drop_timing_tests { + //! Concurrent counterparts to `drop_lifecycle_tests`. The existing + //! `concurrent_tests` module checks that `drop_fn` runs exactly once + //! under contention but not *when* it runs — a premature destructor + //! satisfies "exactly once" while still corrupting surviving clones. + //! These tests bracket drop timing with live observers. + + use super::*; + use std::sync::{Arc, Barrier}; + use std::thread; + + /// Drop N clones concurrently and verify that the payload's anchor is + /// released exactly once, after all clones finish. Running under Miri + /// with `-Zmiri-disable-isolation -Zmiri-preemption-rate=...` flags + /// is not required — this test observes the post-condition on the + /// main thread after the join. + #[test] + fn concurrent_drops_release_anchor_exactly_once() { + let anchor = Arc::new(()); + let mut arena = Arena::<4, 256>::new(); + + struct Payload { + #[allow(dead_code)] + anchor: Arc<()>, + } + + let reserved = arena.reserve().unwrap().init(Payload { + anchor: anchor.clone(), + }); + + const N: usize = 8; + let barrier = Arc::new(Barrier::new(N)); + let mut handles = Vec::with_capacity(N); + for _ in 0..N - 1 { + let clone = reserved.clone(); + let b = barrier.clone(); + handles.push(thread::spawn(move || { + b.wait(); + drop(clone); + })); + } + { + let b = barrier.clone(); + let original = reserved; + handles.push(thread::spawn(move || { + b.wait(); + drop(original); + })); + } + + for h in handles { + h.join().unwrap(); + } + + assert_eq!( + Arc::strong_count(&anchor), + 1, + "after all clones drop, payload anchor must be released exactly once" + ); + } +} + #[cfg(test)] mod concurrent_tests { use super::*; diff --git a/crates/cubecl-common/src/bytes/base.rs b/crates/cubecl-common/src/bytes/base.rs index 2d006311e2..fb4c36c0ee 100644 --- a/crates/cubecl-common/src/bytes/base.rs +++ b/crates/cubecl-common/src/bytes/base.rs @@ -393,10 +393,12 @@ impl Bytes { }) } - /// Ensure 
the contained buffer is aligned to `align` by possibly moving it to a new buffer. + /// Ensure the allocation's reported alignment is at least `align`, reallocating + /// into a fresh controller if not. We check the controller's reported alignment + /// (not the raw pointer) because downstream callers such as `try_into_vec::` + /// depend on `alloc_align()` matching the element alignment. fn try_enforce_runtime_align(&mut self, align: usize) -> Result<(), LayoutError> { - if self.as_mut_ptr().align_offset(align) == 0 { - // data is already aligned correctly + if self.controller.alloc_align() >= align { return Ok(()); } *self = Self::try_from_data(align, self)?; @@ -667,6 +669,20 @@ mod tests { assert_eq!(right.len(), 0); } + /// `from_bytes_vec` enforces `MAX_ALIGN`, so converting the result to a Vec of + /// any type whose alignment is `<= MAX_ALIGN` must succeed. We iterate so the + /// test hits a range of underlying allocator addresses. + #[test_log::test] + fn test_from_bytes_vec_try_into_vec_aligned_type() { + for _ in 0..64 { + let bytes = Bytes::from_bytes_vec(vec![0u8; 16]); + let vec: Vec = bytes + .try_into_vec::() + .expect("MAX_ALIGN-aligned bytes must convert to Vec"); + assert_eq!(vec.len(), 1); + } + } + #[test_log::test] fn test_many_extends_with_growth() { let mut bytes = Bytes::from_elems::(vec![]); diff --git a/crates/cubecl-common/src/device/handle/channel.rs b/crates/cubecl-common/src/device/handle/channel.rs index b6f8f77a26..2e1b9d183e 100644 --- a/crates/cubecl-common/src/device/handle/channel.rs +++ b/crates/cubecl-common/src/device/handle/channel.rs @@ -255,6 +255,9 @@ struct ChannelService { } static RUNNERS: spin::Mutex>> = spin::Mutex::new(None); +/// Device/service map. The lock is held across the entire `init` sequence so `S::init` runs +/// once per `(DeviceId, TypeId)` pair. This serializes channel creation across all +/// backends. 
static CHANNELS: spin::Mutex>> = spin::Mutex::new(None); @@ -265,31 +268,33 @@ impl ChannelDeviceState { ) -> Result { let type_id = TypeId::of::(); let key = (device_id, type_id); + + // Hold the `CHANNELS` lock across the entire init sequence so that the + // "check missing, insert new" transition is atomic. Without this, two + // concurrent callers for the same key would both observe a missing entry, + // both run `S::init`, and race to insert. let mut guard_channel = CHANNELS.lock(); let channels = guard_channel.get_or_insert_with(HashMap::new); - // Most of the time the channel state is already initialized. - if let Some(value) = channels.get(&key) { - return Ok(value.clone()); - }; - - core::mem::drop(guard_channel); + if let Some(existing) = channels.get(&key) { + if service.is_some() { + // `insert(device, service)` cannot replace an existing state. + return Err(ServiceCreationError::new( + "Service already initialized.".into(), + )); + } + return Ok(existing.clone()); + } - // When initializing a service, we first need to make sure the device runner is - // initialized. - // - // # Notes - // // A single device runner can serve multiple [`DeviceService`]. 
- let mut guard = RUNNERS.lock(); - let runners = guard.get_or_insert_with(HashMap::new); - - let device_client = runners - .entry(device_id) - .or_insert_with(|| DeviceRunner::start(device_id)) - .clone(); - - core::mem::drop(guard); + let device_client = { + let mut guard = RUNNERS.lock(); + let runners = guard.get_or_insert_with(HashMap::new); + runners + .entry(device_id) + .or_insert_with(|| DeviceRunner::start(device_id)) + .clone() + }; let (callback, recv) = oneshot::channel(); @@ -342,8 +347,6 @@ impl ChannelDeviceState { service, }; - let mut guard_channel = CHANNELS.lock(); - let channels = guard_channel.get_or_insert_with(HashMap::new); channels.insert(key, channel.clone()); Ok(channel) @@ -398,18 +401,24 @@ mod task { use super::*; use core::sync::atomic::{AtomicPtr, Ordering}; use std::{ - mem::size_of, + mem::{align_of, size_of}, panic::{AssertUnwindSafe, catch_unwind}, }; /// The maximum size of a closure that can be stored without heap allocation. pub const GLOBAL_TASK_MAX_SIZE: usize = 4096; + /// The maximum size of a closure that can be stored using inlined memory. const INLINE_TASK_MAX_SIZE: usize = 48; - // We use u128 to force alignment. - pub type LargeTaskData = [u128; GLOBAL_TASK_MAX_SIZE / 16]; - pub type SmallTaskData = [u128; INLINE_TASK_MAX_SIZE / 16]; + /// One arena slot. `#[repr(C, align(64))]` makes every slot 64-byte + /// aligned on its own, so the slot alignment does not depend on the layout of any + /// enclosing type. `GLOBAL_TASK_MAX_SIZE` is a multiple of 64, so there is no + /// per-slot padding. + #[repr(C, align(64))] + pub struct ArenaSlot { + pub data: [u8; GLOBAL_TASK_MAX_SIZE], + } /// The return type of wrapped closures. pub type TaskResult = (); @@ -420,28 +429,46 @@ mod task { /// It fits in 64 bytes, ensuring multiple threads can initialize tasks at the same time /// without causing false sharing. 
pub struct Task { - // 48 bytes - data: SmallTaskData, + // 48 bytes; 64-aligned because it is the first field of a 64-aligned struct. + data: [u8; INLINE_TASK_MAX_SIZE], // 8 bytes (usize/u64 ptr) data_large_ptr: AtomicPtr, // 8 bytes (usize/u64 ptr) fn_ptr: fn(&mut Task) -> TaskResult, } + const _: () = { + // ArenaSlot is 4096 bytes and 64-aligned on its own. + assert!(core::mem::size_of::() == GLOBAL_TASK_MAX_SIZE); + // `Task::data` lives at offset 0 of a 64-aligned 64-byte struct, which is + // what lets the router assume the inline slot has `SLOT_ALIGN`-byte alignment. + assert!(core::mem::size_of::() == 64); + assert!(core::mem::align_of::() == core::mem::align_of::()); + assert!(core::mem::offset_of!(Task, data) == 0); + }; + impl Task { pub fn new(large_data_ptr: *mut u8) -> Self { - debug_assert!(size_of::() == 64usize); Self { - data: [0; INLINE_TASK_MAX_SIZE / 16], + data: [0u8; INLINE_TASK_MAX_SIZE], data_large_ptr: AtomicPtr::new(large_data_ptr), fn_ptr: |_| {}, } } - /// Initializes a task based on the given closure. + /// Store `func` in the inline slot, the arena slot, or on the heap depending on + /// its size and alignment. Both checks are required: writing into a slot whose + /// alignment is smaller than `align_of::()` would produce a misaligned + /// `ptr::write` (UB). The boxed fallback uses `Box::new`, whose allocation + /// satisfies any alignment. pub fn init TaskResult + Send + 'static>(&mut self, func: F) { - if size_of::() <= size_of::() { - // SAFETY: size checked above, read back exactly once by fn_ptr. + let fits_inline = size_of::() <= INLINE_TASK_MAX_SIZE + && align_of::() <= align_of::(); + let fits_arena = size_of::() <= GLOBAL_TASK_MAX_SIZE + && align_of::() <= align_of::(); + + if fits_inline { + // SAFETY: size + align checked above, read back exactly once by fn_ptr. unsafe { std::ptr::write(self.data.as_mut_ptr() as *mut F, func) }; self.fn_ptr = |task| { // SAFETY: Paired with the ptr::write to data above. 
@@ -450,8 +477,8 @@ mod task { log::warn!("Task failed: {err:?}"); } }; - } else if size_of::() <= size_of::() { - // SAFETY: size checked above, read back exactly once by fn_ptr. + } else if fits_arena { + // SAFETY: size + align checked above, read back exactly once by fn_ptr. unsafe { std::ptr::write(self.data_large_ptr.load(Ordering::Relaxed) as *mut F, func) }; @@ -465,8 +492,9 @@ mod task { } }; } else { - // Heap-allocate to make it pointer-sized, then recurse so we use this - // as a small task. + // Size or alignment exceeds both slots. Heap-allocate to get a + // properly-aligned, pointer-sized handle, then recurse as an inline + // task (the Box is a pointer so it trivially fits inline). let boxed: Box TaskResult + Send> = Box::new(func); self.init(boxed); } @@ -562,7 +590,7 @@ mod custom_channel { DeviceId, handle::{ CallError, - channel::task::{GLOBAL_TASK_MAX_SIZE, Task, TaskResult}, + channel::task::{ArenaSlot, GLOBAL_TASK_MAX_SIZE, Task, TaskResult}, }, }; use core::{ @@ -692,17 +720,19 @@ mod custom_channel { /// Owns a task buffer and its associated large-closure arena. struct TaskBuffer { tasks: Vec, - // u128 ensures 16-byte alignment, matching LargeTaskData = [u128; ...]. - _arena: Vec, + _arena: Vec, } impl TaskBuffer { fn new() -> Self { - let mut arena = std::vec![0u128; CHANNEL_MAX_TASK * GLOBAL_TASK_MAX_SIZE / 16]; + let mut arena: Vec = + Vec::from_iter((0..CHANNEL_MAX_TASK).map(|_| ArenaSlot { + data: [0u8; GLOBAL_TASK_MAX_SIZE], + })); + let arena_ptr = arena.as_mut_ptr() as *mut u8; let tasks = Vec::from_iter((0..CHANNEL_MAX_TASK).map(|index| { - // SAFETY: Each task gets a non-overlapping region of the arena. - // arena is CHANNEL_MAX_TASK * GLOBAL_TASK_MAX_SIZE bytes total. + // SAFETY: Each task owns a non-overlapping `ArenaSlot` region. 
Task::new(unsafe { arena_ptr.add(index * GLOBAL_TASK_MAX_SIZE) }) })); Self { @@ -884,8 +914,9 @@ mod tests { }); } - // Give the server a moment to process the batch - std::thread::sleep(Duration::from_millis(50)); + // Wait for tasks to complete. Miri is very slow with this test, so sleeping fails here. + let _ = handle.submit_blocking(|_| {}); + assert_eq!(completed_count.load(Ordering::SeqCst), 32); } @@ -956,7 +987,7 @@ mod tests { #[test] fn test_large_closure_uses_arena() { - // Closure captures > 48 bytes (SmallTaskData), forcing the arena path. + // Closure captures > 48 bytes (InlineSlot), forcing the arena path. let device_id = DeviceId { type_id: 0, index_id: 7, @@ -1024,4 +1055,108 @@ mod tests { assert_eq!(drop_count.load(Ordering::SeqCst), 1); } + + /// Concurrent callers racing on the same `(DeviceId, TypeId)` must share a single + /// `S::init` invocation. + #[test] + fn test_init_runs_exactly_once_under_contention() { + use alloc::vec::Vec; + use std::sync::Barrier; + use std::sync::atomic::AtomicUsize; + use std::thread; + + static INIT_CALLS: AtomicUsize = AtomicUsize::new(0); + + struct CountingService; + impl DeviceService for CountingService { + fn init(_: DeviceId) -> Self { + INIT_CALLS.fetch_add(1, Ordering::SeqCst); + CountingService + } + fn utilities(&self) -> ServerUtilitiesHandle { + Arc::new(()) + } + } + + INIT_CALLS.store(0, Ordering::SeqCst); + + const THREADS: usize = 4; + // Unique device_id so the global `CHANNELS` entry is independent of other tests. 
+ let device_id = DeviceId { + type_id: 0, + index_id: 77, + }; + + let barrier = Arc::new(Barrier::new(THREADS)); + let mut handles = Vec::new(); + for _ in 0..THREADS { + let b = barrier.clone(); + handles.push(thread::spawn(move || { + b.wait(); + ChannelDeviceHandle::::new(device_id) + })); + } + for h in handles { + let _ = h.join().unwrap(); + } + + assert_eq!( + INIT_CALLS.load(Ordering::SeqCst), + 1, + "CountingService::init must run exactly once across {THREADS} racing callers" + ); + } + + /// A closure that spills to the arena (size > 48) and carries the maximum arena + /// alignment (64) must be stored and executed soundly. + #[test] + fn test_task_init_arena_aligned_closure() { + use super::task::{ArenaSlot, GLOBAL_TASK_MAX_SIZE, Task}; + + #[repr(align(64))] + #[derive(Clone, Copy)] + struct A64 { + data: [u8; 128], + } + + // Mirror `TaskBuffer::new`: a 64-aligned 4KB region per slot. + let mut arena = alloc::boxed::Box::new(ArenaSlot { + data: [0u8; GLOBAL_TASK_MAX_SIZE], + }); + let arena_ptr = arena.data.as_mut_ptr(); + let mut task = Task::new(arena_ptr); + + let data = A64 { data: [0xCD; 128] }; + task.init(move || { + let d = core::hint::black_box(data); + let _: usize = d.data.iter().map(|&b| b as usize).sum(); + }); + task.run(); + } + + /// A closure whose alignment exceeds the arena slot alignment must take the + /// boxed fallback. 
+ #[test] + fn test_task_init_extremely_over_aligned_closure_uses_box() { + use super::task::{ArenaSlot, GLOBAL_TASK_MAX_SIZE, Task}; + + #[repr(align(256))] + #[derive(Clone, Copy)] + struct A256 { + data: [u8; 256], + } + + let mut arena = alloc::boxed::Box::new(ArenaSlot { + data: [0u8; GLOBAL_TASK_MAX_SIZE], + }); + let arena_ptr = arena.data.as_mut_ptr(); + let mut task = Task::new(arena_ptr); + + let data = A256 { data: [0xAA; 256] }; + task.init(move || { + let d = core::hint::black_box(data); + let _: usize = d.data.iter().map(|&b| b as usize).sum(); + }); + task.run(); + } } diff --git a/crates/cubecl-common/src/lib.rs b/crates/cubecl-common/src/lib.rs index a676aa43e4..dde9ec7ddc 100644 --- a/crates/cubecl-common/src/lib.rs +++ b/crates/cubecl-common/src/lib.rs @@ -30,9 +30,6 @@ pub mod device_handle { pub use super::device::handle::DeviceHandle; } -/// Map utilities and implementations. -pub mod map; - /// Utilities module to manipulate bytes. #[cfg(feature = "serde")] pub mod bytes; diff --git a/crates/cubecl-common/src/map.rs b/crates/cubecl-common/src/map.rs deleted file mode 100644 index 6f340fc4d7..0000000000 --- a/crates/cubecl-common/src/map.rs +++ /dev/null @@ -1,103 +0,0 @@ -use crate::stub::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use hashbrown::HashMap; - -/// A thread-safe map that allows concurrent access to values using read-write locks. -pub struct SharedStateMap { - state: Mutex>>, -} - -type State = HashMap>>; - -/// A value in the [`SharedStateMap`] that provides read and write access. -pub struct SharedState { - val: Arc>, -} - -impl SharedState { - /// Acquires a read lock on the value, returning a read guard. - pub fn read(&self) -> RwLockReadGuard<'_, V> { - self.val.read().unwrap() - } - - /// Acquires a write lock on the value, returning a write guard. 
- pub fn write(&self) -> RwLockWriteGuard<'_, V> { - self.val.write().unwrap() - } -} - -impl Default for SharedStateMap -where - K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq, -{ - fn default() -> Self { - Self::new() - } -} - -impl SharedStateMap -where - K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq, -{ - /// Creates a new, empty `SharedStateMap`. - pub const fn new() -> Self { - Self { - state: Mutex::new(None), - } - } - - /// Retrieves a value associated with the given key, if it exists. - pub fn get(&self, k: &K) -> Option> { - let mut state = self.state.lock().unwrap(); - let map = get_or_init::(&mut state); - - match map.get(k) { - Some(val) => Some(SharedState { val: val.clone() }), - None => None, - } - } - - /// Retrieves a value associated with the given key, or inserts a new value using the provided - /// initializer function if the key does not exist. - pub fn get_or_init V>(&self, k: &K, mut init: Fn) -> SharedState - where - K: Clone, - { - let mut state = self.state.lock().unwrap(); - let map = get_or_init::(&mut state); - - match map.get(k) { - Some(val) => SharedState { val: val.clone() }, - None => { - let val = init(k); - let val = Arc::new(RwLock::new(val)); - map.insert(k.clone(), val.clone()); - SharedState { val: val.clone() } - } - } - } - - /// Inserts a key-value pair into the map. - pub fn insert(&self, k: K, v: V) { - let mut state = self.state.lock().unwrap(); - let map = get_or_init::(&mut state); - - map.insert(k, Arc::new(RwLock::new(v))); - } - - /// Clears the map, removing all key-value pairs. 
- pub fn clear(&self) { - let mut state = self.state.lock().unwrap(); - let map = get_or_init::(&mut state); - map.clear(); - } -} - -fn get_or_init(state: &mut Option>) -> &mut State { - match state { - Some(state) => state, - None => { - *state = Some(State::::default()); - state.as_mut().unwrap() - } - } -} diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index 9df4a31f72..9dd1972c7a 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -56,6 +56,7 @@ enumset = { workspace = true } float-ord = { workspace = true } half = { workspace = true, features = ["bytemuck"] } hashbrown = { workspace = true } +num-complex = { workspace = true } num-traits = { workspace = true } paste = { workspace = true } serde = { workspace = true } diff --git a/crates/cubecl-core/src/frontend/container/vector/ops.rs b/crates/cubecl-core/src/frontend/container/vector/ops.rs index 5adc9813bd..85c5a63632 100644 --- a/crates/cubecl-core/src/frontend/container/vector/ops.rs +++ b/crates/cubecl-core/src/frontend/container/vector/ops.rs @@ -232,7 +232,9 @@ where } } -impl Abs for Vector {} +impl Abs for Vector { + type AbsElem = P::AbsElem; +} impl Log for Vector {} impl Log1p for Vector {} impl Erf for Vector {} @@ -268,6 +270,7 @@ impl IsNan for Vector {} impl IsInf for Vector {} impl Normalize for Vector {} impl Magnitude for Vector {} +impl VectorSum for Vector {} impl Degrees for Vector {} impl Radians for Vector {} diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index ba34805ac2..af0d05c37b 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -325,6 +325,8 @@ from_const!(e4m3); from_const!(e5m2); from_const!(ue8m0); from_const!(bool); +from_const!(num_complex::Complex); +from_const!(num_complex::Complex); macro_rules! 
tuple_cube_type { ($($P:ident),*) => { diff --git a/crates/cubecl-core/src/frontend/element/complex.rs b/crates/cubecl-core/src/frontend/element/complex.rs new file mode 100644 index 0000000000..ff6acf2023 --- /dev/null +++ b/crates/cubecl-core/src/frontend/element/complex.rs @@ -0,0 +1,130 @@ +use core::ops::{Add, Div, Mul, Neg, Sub}; + +use crate::{ + ir::{ComplexKind, ElemType, ManagedVariable, Scope, StorageType, Type}, + prelude::{CubePrimitive, CubeType, IntoRuntime, NativeAssign, NativeExpand, Scalar}, + unexpanded, +}; +use cubecl_ir::{Arithmetic, ConstantValue, Operator}; + +use crate::frontend::{ + Abs, + operation::{unary_expand, unary_expand_fixed_output}, +}; + +pub trait Complex: + CubePrimitive + + Abs + + Add + + Sub + + Mul + + Div + + Neg + + Copy + + Clone + + PartialEq + + core::fmt::Debug + + Send + + Sync + + 'static +{ + type FloatElem: Scalar; + + fn conj(self) -> Self { + unexpanded!() + } + + fn real_val(self) -> Self::FloatElem { + unexpanded!() + } + + fn imag_val(self) -> Self::FloatElem { + unexpanded!() + } +} + +pub trait ComplexExpand { + fn __expand_conj_method(self, scope: &mut Scope) -> Self; + fn __expand_real_val_method( + self, + scope: &mut Scope, + ) -> NativeExpand<::FloatElem>; + fn __expand_imag_val_method( + self, + scope: &mut Scope, + ) -> NativeExpand<::FloatElem>; + + type FloatElem: Scalar; +} + +impl ComplexExpand for NativeExpand { + type FloatElem = T::FloatElem; + + fn __expand_conj_method(self, scope: &mut Scope) -> Self { + unary_expand(scope, self.into(), Arithmetic::Conj).into() + } + + fn __expand_real_val_method(self, scope: &mut Scope) -> NativeExpand { + let expand_element: ManagedVariable = self.into(); + let item = ::as_type(scope); + unary_expand_fixed_output(scope, expand_element, item, Operator::Real).into() + } + + fn __expand_imag_val_method(self, scope: &mut Scope) -> NativeExpand { + let expand_element: ManagedVariable = self.into(); + let item = ::as_type(scope); + 
unary_expand_fixed_output(scope, expand_element, item, Operator::Imag).into() + } +} + +macro_rules! impl_complex { + ($primitive:ty, $kind:ident, $float:ty) => { + impl CubeType for $primitive { + type ExpandType = NativeExpand<$primitive>; + } + + impl CubePrimitive for $primitive { + type Scalar = Self; + type Size = crate::prelude::Const<1>; + type WithScalar = S; + + fn as_type_native() -> Option { + Some(StorageType::Scalar(ElemType::Complex(ComplexKind::$kind)).into()) + } + + fn from_const_value(value: ConstantValue) -> Self { + let ConstantValue::Complex(re, im) = value else { + unreachable!("expected Complex constant") + }; + <$primitive>::new(re as $float, im as $float) + } + } + + impl IntoRuntime for $primitive { + fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand { + self.into() + } + } + + impl NativeAssign for $primitive {} + + impl crate::prelude::IntoMut for $primitive { + fn into_mut(self, _scope: &mut Scope) -> Self { + self + } + } + + impl Scalar for $primitive {} + + impl Abs for $primitive { + type AbsElem = $float; + } + + impl Complex for $primitive { + type FloatElem = $float; + } + }; +} + +impl_complex!(num_complex::Complex, C32, f32); +impl_complex!(num_complex::Complex, C64, f64); diff --git a/crates/cubecl-core/src/frontend/element/mod.rs b/crates/cubecl-core/src/frontend/element/mod.rs index 800162f000..e0f568b9b1 100644 --- a/crates/cubecl-core/src/frontend/element/mod.rs +++ b/crates/cubecl-core/src/frontend/element/mod.rs @@ -2,6 +2,7 @@ mod atomic; mod base; mod bool; mod cast; +mod complex; mod cube_elem; mod float; mod int; @@ -13,6 +14,7 @@ pub use atomic::*; pub use base::*; pub use bool::*; pub use cast::*; +pub use complex::*; pub use cube_elem::*; pub use float::*; pub use int::*; diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index 17a1a7ac13..405362cdc2 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ 
b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -5,7 +5,7 @@ use num_traits::{NumCast, One, Zero}; use crate::compute::KernelLauncher; use crate::{IntoRuntime, ScalarArgType, compute::KernelBuilder}; use crate::{ - frontend::{Abs, Remainder}, + frontend::{Abs, Remainder, VectorSum}, unexpanded, }; use crate::{ @@ -23,7 +23,8 @@ use super::{LaunchArg, NativeAssign, NativeExpand}; /// Used in kernels that should work for both. pub trait Numeric: Copy - + Abs + + Abs + + VectorSum + Remainder + Scalar + NativeAssign diff --git a/crates/cubecl-core/src/frontend/element/typemap.rs b/crates/cubecl-core/src/frontend/element/typemap.rs index 257b0a51bf..797e707977 100644 --- a/crates/cubecl-core/src/frontend/element/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/typemap.rs @@ -115,6 +115,7 @@ impl Neg for DynamicScalar { ConstantValue::Float(val) => (-val).into(), ConstantValue::UInt(val) => (-(val as i64)).into(), ConstantValue::Bool(val) => (!val).into(), + ConstantValue::Complex(_, _) => panic!("Complex values aren't supported"), }) } } @@ -301,11 +302,14 @@ impl ScalarArgSettings for DynamicScalar { impl Normalize for DynamicScalar {} impl Dot for DynamicScalar {} impl Magnitude for DynamicScalar {} +impl VectorSum for DynamicScalar {} impl Recip for DynamicScalar {} impl Erf for DynamicScalar {} impl Exp for DynamicScalar {} impl Remainder for DynamicScalar {} -impl Abs for DynamicScalar {} +impl Abs for DynamicScalar { + type AbsElem = Self; +} impl Log for DynamicScalar {} impl Log1p for DynamicScalar {} impl Cos for DynamicScalar {} @@ -419,6 +423,7 @@ impl Not for DynamicScalar { ConstantValue::UInt(val) => (!val).into(), ConstantValue::Bool(val) => (!val).into(), ConstantValue::Float(val) => f64::from_bits(!val.to_bits()).into(), + ConstantValue::Complex(_, _) => panic!("Complex values aren't supported"), }) } } @@ -470,6 +475,7 @@ impl Shl for DynamicScalar { ConstantValue::Float(val) => f64::from_bits(val.to_bits() << rhs.val.as_u64()).into(), 
ConstantValue::UInt(val) => (val << rhs.val.as_u64()).into(), ConstantValue::Bool(_) => panic!("Invalid value"), + ConstantValue::Complex(_, _) => panic!("Complex values aren't supported"), }) } } @@ -482,6 +488,7 @@ impl Shr for DynamicScalar { ConstantValue::Float(val) => f64::from_bits(val.to_bits() >> rhs.val.as_u64()).into(), ConstantValue::UInt(val) => (val >> rhs.val.as_u64()).into(), ConstantValue::Bool(_) => panic!("Invalid value"), + ConstantValue::Complex(_, _) => panic!("Complex values aren't supported"), }) } } @@ -493,6 +500,7 @@ impl ShrAssign for DynamicScalar { ConstantValue::Float(val) => f64::from_bits(val.to_bits() >> rhs).into(), ConstantValue::UInt(val) => (val >> rhs).into(), ConstantValue::Bool(_) => panic!("Invalid value"), + ConstantValue::Complex(_, _) => panic!("Complex values aren't supported"), }); } } @@ -504,6 +512,7 @@ impl ShlAssign for DynamicScalar { ConstantValue::Float(val) => f64::from_bits(val.to_bits() << rhs).into(), ConstantValue::UInt(val) => (val << rhs).into(), ConstantValue::Bool(_) => panic!("Invalid value"), + ConstantValue::Complex(_, _) => panic!("Complex values aren't supported"), }); } } @@ -824,3 +833,7 @@ impl Zero for DynamicScalar { self.val.is_zero() } } + +impl Complex for DynamicScalar { + type FloatElem = Self; +} diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs index 7a64e70bbe..d31364067c 100644 --- a/crates/cubecl-core/src/frontend/operation/binary.rs +++ b/crates/cubecl-core/src/frontend/operation/binary.rs @@ -549,7 +549,9 @@ impl_binary_func!( flex32, tf32, f32, - f64 + f64, + num_complex::Complex, + num_complex::Complex ); impl_binary_func!( diff --git a/crates/cubecl-core/src/frontend/operation/branch.rs b/crates/cubecl-core/src/frontend/operation/branch.rs index e36f0b702a..d04dda6a68 100644 --- a/crates/cubecl-core/src/frontend/operation/branch.rs +++ b/crates/cubecl-core/src/frontend/operation/branch.rs @@ -30,6 +30,8 @@ pub 
fn select_many( } pub mod select { + use cubecl_ir::VariableKind; + use crate::ir::Instruction; use super::*; @@ -41,6 +43,15 @@ pub mod select { or_else: NativeExpand, ) -> NativeExpand { let cond = condition.expand.consume(); + + if let VariableKind::Constant(value) = cond.kind { + if value.as_bool() { + return then; + } else { + return or_else; + } + } + let then = then.expand.consume(); let or_else = or_else.expand.consume(); diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 292bec9634..abbb5c4c4c 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -6,7 +6,7 @@ use half::{bf16, f16}; use crate::{ flex32, ir::{Arithmetic, ManagedVariable, Scope}, - prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret}, + prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret, Scalar}, tf32, unexpanded, }; @@ -66,6 +66,57 @@ impl Exp for f32 { } } +pub trait Abs: + CubePrimitive + + CubeType< + ExpandType: AbsExpand< + AbsElem = Self::AbsElem, + AbsOut = NativeExpand>, + >, + > + Sized +{ + type AbsElem: Scalar; + + #[allow(unused_variables)] + fn abs(self) -> Self::WithScalar { + unexpanded!() + } + + fn __expand_abs( + scope: &mut Scope, + x: NativeExpand, + ) -> NativeExpand> { + x.__expand_abs_method(scope) + } +} + +pub trait AbsExpand: CubePrimitiveExpand { + type AbsElem: Scalar; + type AbsOut; + + fn __expand_abs_method(self, scope: &mut Scope) -> Self::AbsOut; +} + +impl AbsExpand for NativeExpand { + type AbsElem = T::AbsElem; + type AbsOut = NativeExpand>; + + fn __expand_abs_method(self, scope: &mut Scope) -> Self::AbsOut { + let expand_element: ManagedVariable = self.into(); + let item = ::as_type(scope) + .with_vector_size(expand_element.ty.vector_size()); + unary_expand_fixed_output(scope, expand_element, item, Arithmetic::Abs).into() + } +} + +macro_rules! 
impl_abs_same_type { + ($($type:ty),*) => { + $(impl Abs for $type { + type AbsElem = $type; + })* + }; +} + macro_rules! impl_unary_func_scalar_out { ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => { paste::paste! { @@ -158,43 +209,36 @@ impl_not!( Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize ); +impl_abs_same_type!( + e2m1, e4m3, e5m2, ue8m0, f16, bf16, flex32, tf32, f32, f64, i8, i16, i32, i64, u8, u16, u32, + u64, usize, isize +); impl_unary_func!( - Abs, - abs, - Arithmetic::Abs, - e2m1, - e4m3, - e5m2, - ue8m0, + Exp, + exp, + Arithmetic::Exp, f16, bf16, flex32, tf32, - f32, + // f32, f64, - i8, - i16, - i32, - i64, - u8, - u16, - u32, - u64, - usize, - isize + num_complex::Complex, + num_complex::Complex ); impl_unary_func!( - Exp, - exp, - Arithmetic::Exp, + Log, + ln, + Arithmetic::Log, f16, bf16, flex32, tf32, - // f32, - f64 + f32, + f64, + num_complex::Complex, + num_complex::Complex ); -impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64); impl_unary_func!( Log1p, log1p, @@ -206,8 +250,32 @@ impl_unary_func!( f32, f64 ); -impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64); -impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64); +impl_unary_func!( + Cos, + cos, + Arithmetic::Cos, + f16, + bf16, + flex32, + tf32, + f32, + f64, + num_complex::Complex, + num_complex::Complex +); +impl_unary_func!( + Sin, + sin, + Arithmetic::Sin, + f16, + bf16, + flex32, + tf32, + f32, + f64, + num_complex::Complex, + num_complex::Complex +); impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64); impl_unary_func!( Tanh, @@ -218,7 +286,9 @@ impl_unary_func!( flex32, tf32, f32, - f64 + f64, + num_complex::Complex, + num_complex::Complex ); impl_unary_func!( Sinh, @@ -339,7 +409,9 @@ impl_unary_func!( flex32, tf32, f32, - f64 + f64, + num_complex::Complex, + num_complex::Complex ); impl_unary_func!( InverseSqrt, @@ -419,6 
+491,31 @@ impl_unary_func_scalar_out!( f32, f64 ); +impl_unary_func_scalar_out!( + VectorSum, + vector_sum, + Arithmetic::VectorSum, + e2m1, + e4m3, + e5m2, + ue8m0, + f16, + bf16, + flex32, + tf32, + f32, + f64, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + usize, + isize +); impl_unary_func!( Normalize, normalize, diff --git a/crates/cubecl-core/src/frontend/scalar.rs b/crates/cubecl-core/src/frontend/scalar.rs index 32c055388a..a3d0dcdd81 100644 --- a/crates/cubecl-core/src/frontend/scalar.rs +++ b/crates/cubecl-core/src/frontend/scalar.rs @@ -80,6 +80,7 @@ impl InputScalar { UIntKind::U64 => write::(val, &mut out.data), }, ElemType::Bool => panic!("Bool isn't a scalar"), + ElemType::Complex(_) => unimplemented!("Complex not supported for scalar input"), }, other => unimplemented!("{other} not supported for scalars"), }; diff --git a/crates/cubecl-core/src/pod.rs b/crates/cubecl-core/src/pod.rs index 2d62b2228c..db013229b4 100644 --- a/crates/cubecl-core/src/pod.rs +++ b/crates/cubecl-core/src/pod.rs @@ -2,7 +2,7 @@ use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, flex32, tf32, ue8m0}; use cubecl_ir::StorageType; use crate::{ - ir::{ElemType, FloatKind, IntKind, UIntKind}, + ir::{ComplexKind, ElemType, FloatKind, IntKind, UIntKind}, prelude::{Numeric, Scalar}, }; @@ -438,3 +438,45 @@ impl CubeElement for e2m1x2 { e2m1x2::from_bits(min << 4 | min) } } + +impl CubeElement for num_complex::Complex { + fn type_name() -> &'static str { + "cf32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } + fn cube_type() -> StorageType { + ElemType::Complex(ComplexKind::C32).into() + } + fn maximum_value() -> Self { + num_complex::Complex::new(f32::MAX, 0.0) + } + fn minimum_value() -> Self { + num_complex::Complex::new(f32::MIN, 0.0) + } +} + +impl CubeElement for num_complex::Complex { + fn type_name() -> &'static str { + "cf64" + } + fn as_bytes(slice: &[Self]) 
-> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } + fn cube_type() -> StorageType { + ElemType::Complex(ComplexKind::C64).into() + } + fn maximum_value() -> Self { + num_complex::Complex::new(f64::MAX, 0.0) + } + fn minimum_value() -> Self { + num_complex::Complex::new(f64::MIN, 0.0) + } +} diff --git a/crates/cubecl-core/src/runtime_tests/complex.rs b/crates/cubecl-core/src/runtime_tests/complex.rs new file mode 100644 index 0000000000..0dbe429c00 --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/complex.rs @@ -0,0 +1,831 @@ +use crate::{self as cubecl}; +use alloc::vec; +use core::fmt::{Debug, Display}; +use cubecl::prelude::*; + +fn assert_exact_eq( + client: &ComputeClient, + output: cubecl_runtime::server::Handle, + expected: &[E], +) { + let actual = client.read_one_unchecked(output); + let actual = E::from_bytes(&actual); + + assert_eq!(actual, expected); +} + +fn assert_real_approx_eq( + client: &ComputeClient, + output: cubecl_runtime::server::Handle, + expected: &[F], + epsilon: F, +) { + let actual = client.read_one_unchecked(output); + let actual = F::from_bytes(&actual); + + for (index, (actual, expected)) in actual.iter().zip(expected.iter()).enumerate() { + assert!( + (*actual - *expected).abs() <= epsilon + || (actual.is_nan() && expected.is_nan()) + || (actual.is_infinite() + && expected.is_infinite() + && actual.is_sign_positive() == expected.is_sign_positive()), + "Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={}, index={}", + actual, + expected, + (*actual - *expected).abs(), + epsilon, + index + ); + } +} + +fn assert_complex_approx_eq( + client: &ComputeClient, + output: cubecl_runtime::server::Handle, + expected: &[num_complex::Complex], + epsilon: F, +) where + num_complex::Complex: CubeElement, +{ + let actual = client.read_one_unchecked(output); + let actual = as CubeElement>::from_bytes(&actual); + + for (index, (actual, 
expected)) in actual.iter().zip(expected.iter()).enumerate() { + let real_matches = (actual.re - expected.re).abs() <= epsilon + || (actual.re.is_nan() && expected.re.is_nan()) + || (actual.re.is_infinite() + && expected.re.is_infinite() + && actual.re.is_sign_positive() == expected.re.is_sign_positive()); + let imag_matches = (actual.im - expected.im).abs() <= epsilon + || (actual.im.is_nan() && expected.im.is_nan()) + || (actual.im.is_infinite() + && expected.im.is_infinite() + && actual.im.is_sign_positive() == expected.im.is_sign_positive()); + + assert!( + real_matches && imag_matches, + "Complex values differ more than epsilon: actual={:?}, expected={:?}, epsilon={}, index={}", + actual, + expected, + epsilon, + index + ); + } +} + +fn complex_abs_value(value: num_complex::Complex) -> T { + value.re.hypot(value.im) +} + +fn complex_exp_value( + value: num_complex::Complex, +) -> num_complex::Complex { + let magnitude = value.re.exp(); + num_complex::Complex::new(magnitude * value.im.cos(), magnitude * value.im.sin()) +} + +fn complex_ln_value(value: num_complex::Complex) -> num_complex::Complex { + num_complex::Complex::new(complex_abs_value(value).ln(), value.im.atan2(value.re)) +} + +fn complex_powc_value( + value: num_complex::Complex, + exp: num_complex::Complex, +) -> num_complex::Complex { + if exp.re.is_zero() && exp.im.is_zero() { + return num_complex::Complex::new(T::one(), T::zero()); + } + + complex_exp_value(exp * complex_ln_value(value)) +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_add(output: &mut Array, lhs: &Array, rhs: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = lhs[ABSOLUTE_POS] + rhs[ABSOLUTE_POS]; + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_mul(output: &mut Array, lhs: &Array, rhs: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]; + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_conj(output: &mut Array, input: &Array) { + 
if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].conj(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_constant( + output: &mut Array, + scale: C, +) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = output[ABSOLUTE_POS] * scale; + } +} + +macro_rules! test_complex_binary_eq_op { + ( + $test_name:ident, + $kernel:ident, + $ty:ty, + lhs: [$($lhs:expr),+ $(,)?], + rhs: [$($rhs:expr),+ $(,)?], + expect: |$lhs_var:ident, $rhs_var:ident| $expected:expr + ) => { + pub fn $test_name(client: ComputeClient) { + type C = $ty; + let lhs = vec![$($lhs),+]; + let rhs = vec![$($rhs),+]; + let expected = lhs + .iter() + .copied() + .zip(rhs.iter().copied()) + .map(|($lhs_var, $rhs_var)| $expected) + .collect::>(); + + let handle_output = client.empty(lhs.len() * core::mem::size_of::()); + let handle_lhs = client.create_from_slice(C::as_bytes(&lhs)); + let handle_rhs = client.create_from_slice(C::as_bytes(&rhs)); + + unsafe { + $kernel::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(lhs.len() as u32), + ArrayArg::from_raw_parts(handle_output.clone(), lhs.len()), + ArrayArg::from_raw_parts(handle_lhs, lhs.len()), + ArrayArg::from_raw_parts(handle_rhs, rhs.len()), + ) + }; + + assert_exact_eq(&client, handle_output, &expected); + } + }; +} + +macro_rules! 
test_complex_unary_eq_op { + ( + $test_name:ident, + $kernel:ident, + $ty:ty, + input: [$($value:expr),+ $(,)?], + expect: |$value_var:ident| $expected:expr + ) => { + pub fn $test_name(client: ComputeClient) { + type C = $ty; + let input = vec![$($value),+]; + let expected = input + .iter() + .copied() + .map(|$value_var| $expected) + .collect::>(); + + let handle_output = client.empty(input.len() * core::mem::size_of::()); + let handle_input = client.create_from_slice(C::as_bytes(&input)); + + unsafe { + $kernel::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(input.len() as u32), + ArrayArg::from_raw_parts(handle_output.clone(), input.len()), + ArrayArg::from_raw_parts(handle_input, input.len()), + ) + }; + + assert_exact_eq(&client, handle_output, &expected); + } + }; +} + +macro_rules! test_complex_scalar_eq_op { + ( + $test_name:ident, + $kernel:ident, + $ty:ty, + input: [$($value:expr),+ $(,)?], + scalar: $scalar:expr, + expect: |$value_var:ident, $scale_var:ident| $expected:expr + ) => { + pub fn $test_name(client: ComputeClient) { + type C = $ty; + let input = vec![$($value),+]; + let scale = $scalar; + let expected = input + .iter() + .copied() + .map(|$value_var| { + let $scale_var = scale; + $expected + }) + .collect::>(); + + let handle_output = client.create_from_slice(C::as_bytes(&input)); + + unsafe { + $kernel::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(input.len() as u32), + ArrayArg::from_raw_parts(handle_output.clone(), input.len()), + scale, + ) + }; + + assert_exact_eq(&client, handle_output, &expected); + } + }; +} + +test_complex_binary_eq_op!( + test_complex_add_cf32, + kernel_complex_add, + num_complex::Complex, + lhs: [ + num_complex::Complex::new(1.0f32, 2.0f32), + num_complex::Complex::new(3.0f32, 4.0f32), + ], + rhs: [ + num_complex::Complex::new(5.0f32, 6.0f32), + num_complex::Complex::new(7.0f32, 8.0f32), + ], + expect: |lhs, rhs| lhs + rhs +); +test_complex_binary_eq_op!( 
+ test_complex_add_cf64, + kernel_complex_add, + num_complex::Complex, + lhs: [ + num_complex::Complex::new(1.0f64, 2.0f64), + num_complex::Complex::new(3.0f64, 4.0f64), + ], + rhs: [ + num_complex::Complex::new(5.0f64, 6.0f64), + num_complex::Complex::new(7.0f64, 8.0f64), + ], + expect: |lhs, rhs| lhs + rhs +); +test_complex_binary_eq_op!( + test_complex_mul_cf32, + kernel_complex_mul, + num_complex::Complex, + lhs: [ + num_complex::Complex::new(1.0f32, 2.0f32), + num_complex::Complex::new(3.0f32, 4.0f32), + ], + rhs: [ + num_complex::Complex::new(3.0f32, 4.0f32), + num_complex::Complex::new(5.0f32, 6.0f32), + ], + expect: |lhs, rhs| lhs * rhs +); +test_complex_binary_eq_op!( + test_complex_mul_cf64, + kernel_complex_mul, + num_complex::Complex, + lhs: [ + num_complex::Complex::new(1.0f64, 2.0f64), + num_complex::Complex::new(3.0f64, 4.0f64), + ], + rhs: [ + num_complex::Complex::new(3.0f64, 4.0f64), + num_complex::Complex::new(5.0f64, 6.0f64), + ], + expect: |lhs, rhs| lhs * rhs +); +test_complex_unary_eq_op!( + test_complex_conj_cf32, + kernel_complex_conj, + num_complex::Complex, + input: [ + num_complex::Complex::new(1.0f32, 2.0f32), + num_complex::Complex::new(3.0f32, -4.0f32), + ], + expect: |value| num_complex::Complex::new(value.re, -value.im) +); +test_complex_unary_eq_op!( + test_complex_conj_cf64, + kernel_complex_conj, + num_complex::Complex, + input: [ + num_complex::Complex::new(1.0f64, 2.0f64), + num_complex::Complex::new(3.0f64, -4.0f64), + ], + expect: |value| num_complex::Complex::new(value.re, -value.im) +); +test_complex_scalar_eq_op!( + test_complex_constant_cf32, + kernel_complex_constant, + num_complex::Complex, + input: [ + num_complex::Complex::new(1.0f32, 2.0f32), + num_complex::Complex::new(3.0f32, 4.0f32), + ], + scalar: num_complex::Complex::new(2.0f32, -1.0f32), + expect: |value, scale| value * scale +); +test_complex_scalar_eq_op!( + test_complex_constant_cf64, + kernel_complex_constant, + num_complex::Complex, + input: [ + 
num_complex::Complex::new(1.0f64, 2.0f64), + num_complex::Complex::new(3.0f64, 4.0f64), + ], + scalar: num_complex::Complex::new(2.0f64, -1.0f64), + expect: |value, scale| value * scale +); + +#[cube(launch_unchecked)] +pub fn kernel_complex_abs_cf32(output: &mut Array, input: &Array>) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].abs(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_abs_cf64(output: &mut Array, input: &Array>) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].abs(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_exp(output: &mut Array, input: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].exp(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_log(output: &mut Array, input: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].ln(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_sin(output: &mut Array, input: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].sin(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_cos(output: &mut Array, input: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].cos(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_sqrt(output: &mut Array, input: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].sqrt(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_tanh(output: &mut Array, input: &Array) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = input[ABSOLUTE_POS].tanh(); + } +} + +#[cube(launch_unchecked)] +pub fn kernel_complex_powf( + output: &mut Array, + lhs: &Array, + rhs: &Array, +) { + if ABSOLUTE_POS < output.len() { + output[ABSOLUTE_POS] = ::powf(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]); + } +} + +pub fn test_complex_abs_cf32(client: ComputeClient) { + type C = 
num_complex::Complex; + let input = vec![C::new(3.0f32, 4.0f32), C::new(5.0f32, -12.0f32)]; + let expected = vec![complex_abs_value(input[0]), complex_abs_value(input[1])]; + + let handle_output = client.empty(2 * core::mem::size_of::()); + let handle_input = client.create_from_slice(C::as_bytes(&input)); + + unsafe { + kernel_complex_abs_cf32::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(2), + ArrayArg::from_raw_parts(handle_output.clone(), 2), + ArrayArg::from_raw_parts(handle_input, 2), + ) + }; + + assert_real_approx_eq::(&client, handle_output, &expected, 1.0e-5f32); +} + +pub fn test_complex_abs_cf64(client: ComputeClient) { + type C = num_complex::Complex; + let input = vec![C::new(3.0f64, 4.0f64), C::new(5.0f64, -12.0f64)]; + let expected = vec![complex_abs_value(input[0]), complex_abs_value(input[1])]; + + let handle_output = client.empty(2 * core::mem::size_of::()); + let handle_input = client.create_from_slice(C::as_bytes(&input)); + + unsafe { + kernel_complex_abs_cf64::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(2), + ArrayArg::from_raw_parts(handle_output.clone(), 2), + ArrayArg::from_raw_parts(handle_input, 2), + ) + }; + + assert_real_approx_eq::(&client, handle_output, &expected, 1.0e-12f64); +} + +macro_rules! 
test_complex_unary_op { + ($test_name:ident, $kernel:ident, $method:ident, $ty:ty, $epsilon:expr, [$($value:expr),+ $(,)?]) => { + pub fn $test_name(client: ComputeClient) { + type C = $ty; + let input = vec![$($value),+]; + let expected = input.iter().copied().map(|value| value.$method()).collect::>(); + + let handle_output = client.empty(input.len() * core::mem::size_of::()); + let handle_input = client.create_from_slice(C::as_bytes(&input)); + + unsafe { + $kernel::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(input.len() as u32), + ArrayArg::from_raw_parts(handle_output.clone(), input.len()), + ArrayArg::from_raw_parts(handle_input, input.len()), + ) + }; + + assert_complex_approx_eq::(&client, handle_output, &expected, $epsilon); + } + }; +} + +macro_rules! test_complex_powf_op { + ($test_name:ident, $ty:ty, $epsilon:expr, lhs: [$($lhs:expr),+ $(,)?], rhs: [$($rhs:expr),+ $(,)?]) => { + pub fn $test_name(client: ComputeClient) { + type C = $ty; + let lhs = vec![$($lhs),+]; + let rhs = vec![$($rhs),+]; + let expected = lhs + .iter() + .copied() + .zip(rhs.iter().copied()) + .map(|(lhs, rhs)| complex_powc_value(lhs, rhs)) + .collect::>(); + + let handle_output = client.empty(lhs.len() * core::mem::size_of::()); + let handle_lhs = client.create_from_slice(C::as_bytes(&lhs)); + let handle_rhs = client.create_from_slice(C::as_bytes(&rhs)); + + unsafe { + kernel_complex_powf::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(lhs.len() as u32), + ArrayArg::from_raw_parts(handle_output.clone(), lhs.len()), + ArrayArg::from_raw_parts(handle_lhs, lhs.len()), + ArrayArg::from_raw_parts(handle_rhs, rhs.len()), + ) + }; + + assert_complex_approx_eq::(&client, handle_output, &expected, $epsilon); + } + }; +} + +test_complex_unary_op!( + test_complex_exp_cf32, + kernel_complex_exp, + exp, + num_complex::Complex, + 1.0e-4f32, + [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 
0.25f32), + ] +); +test_complex_unary_op!( + test_complex_exp_cf64, + kernel_complex_exp, + exp, + num_complex::Complex, + 1.0e-12f64, + [ + num_complex::Complex::new(0.5f64, -0.75f64), + num_complex::Complex::new(-1.25f64, 0.25f64), + ] +); +test_complex_unary_op!( + test_complex_log_cf32, + kernel_complex_log, + ln, + num_complex::Complex, + 1.0e-4f32, + [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 0.25f32), + ] +); +test_complex_unary_op!( + test_complex_log_cf64, + kernel_complex_log, + ln, + num_complex::Complex, + 1.0e-12f64, + [ + num_complex::Complex::new(0.5f64, -0.75f64), + num_complex::Complex::new(-1.25f64, 0.25f64), + ] +); +test_complex_unary_op!( + test_complex_sin_cf32, + kernel_complex_sin, + sin, + num_complex::Complex, + 1.0e-4f32, + [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 0.25f32), + ] +); +test_complex_unary_op!( + test_complex_sin_cf64, + kernel_complex_sin, + sin, + num_complex::Complex, + 1.0e-12f64, + [ + num_complex::Complex::new(0.5f64, -0.75f64), + num_complex::Complex::new(-1.25f64, 0.25f64), + ] +); +test_complex_unary_op!( + test_complex_cos_cf32, + kernel_complex_cos, + cos, + num_complex::Complex, + 1.0e-4f32, + [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 0.25f32), + ] +); +test_complex_unary_op!( + test_complex_cos_cf64, + kernel_complex_cos, + cos, + num_complex::Complex, + 1.0e-12f64, + [ + num_complex::Complex::new(0.5f64, -0.75f64), + num_complex::Complex::new(-1.25f64, 0.25f64), + ] +); +test_complex_unary_op!( + test_complex_sqrt_cf32, + kernel_complex_sqrt, + sqrt, + num_complex::Complex, + 1.0e-4f32, + [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 0.25f32), + ] +); +test_complex_unary_op!( + test_complex_sqrt_cf64, + kernel_complex_sqrt, + sqrt, + num_complex::Complex, + 1.0e-12f64, + [ + num_complex::Complex::new(0.5f64, -0.75f64), + 
num_complex::Complex::new(-1.25f64, 0.25f64), + ] +); +test_complex_unary_op!( + test_complex_tanh_cf32, + kernel_complex_tanh, + tanh, + num_complex::Complex, + 1.0e-4f32, + [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 0.25f32), + ] +); +test_complex_unary_op!( + test_complex_tanh_cf64, + kernel_complex_tanh, + tanh, + num_complex::Complex, + 1.0e-12f64, + [ + num_complex::Complex::new(0.5f64, -0.75f64), + num_complex::Complex::new(-1.25f64, 0.25f64), + ] +); +test_complex_powf_op!( + test_complex_powf_cf32, + num_complex::Complex, + 1.0e-4f32, + lhs: [ + num_complex::Complex::new(0.5f32, -0.75f32), + num_complex::Complex::new(-1.25f32, 0.25f32), + ], + rhs: [ + num_complex::Complex::new(0.25f32, 0.5f32), + num_complex::Complex::new(-0.75f32, 0.125f32), + ] +); +test_complex_powf_op!( + test_complex_powf_cf64, + num_complex::Complex, + 1.0e-12f64, + lhs: [ + num_complex::Complex::new(0.5f64, -0.75f64), + num_complex::Complex::new(-1.25f64, 0.25f64), + ], + rhs: [ + num_complex::Complex::new(0.25f64, 0.5f64), + num_complex::Complex::new(-0.75f64, 0.125f64), + ] +); + +#[allow(missing_docs)] +#[macro_export] +macro_rules! 
testgen_complex { + () => { + use super::*; + + mod complex { + use super::*; + + #[$crate::runtime_tests::test_log::test] + fn test_complex_add_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_add_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_add_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_add_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_mul_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_mul_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_mul_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_mul_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_conj_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_conj_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_conj_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_conj_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_constant_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_constant_cf32::( + client, + ); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_constant_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_constant_cf64::( + client, + ); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_abs_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_abs_cf32::(client); + } + 
+ #[$crate::runtime_tests::test_log::test] + fn test_complex_abs_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_abs_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_exp_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_exp_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_exp_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_exp_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_log_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_log_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_log_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_log_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_sin_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_sin_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_sin_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_sin_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_cos_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_cos_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_cos_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_cos_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_sqrt_cf32() { + let client = 
TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_sqrt_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_sqrt_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_sqrt_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_tanh_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_tanh_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_tanh_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_tanh_cf64::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_powf_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_powf_cf32::(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_powf_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_powf_cf64::(client); + } + } + }; +} diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 6044ba2e1f..5552fd7da5 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -10,6 +10,7 @@ pub mod branch; pub mod cluster; pub mod cmma; pub mod comparison; +pub mod complex; pub mod const_match; pub mod constants; pub mod debug; diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 32c15cde4d..9d195c9c4e 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -598,6 +598,44 @@ test_unary_impl!( ] ); +test_unary_impl!( + test_vector_sum, + F, + Vector::vector_sum, + [ + { + input_vectorization: 1, + out_vectorization: 1, + 
input: as_type![F: -1., 23.1, -1.4, 5.1], + expected: as_type![F: -1., 23.1, -1.4, 5.1] + }, + { + input_vectorization: 2, + out_vectorization: 1, + input: as_type![F: 1., 3., 2., 5.], + expected: as_type![F: 4., 7.] + }, + { + input_vectorization: 4, + out_vectorization: 1, + input: as_type![F: 1., 2., 3., 4.], + expected: as_type![F: 10.] + }, + { + input_vectorization: 4, + out_vectorization: 1, + input: as_type![F: 0., 0., 0., 0.], + expected: as_type![F: 0.] + }, + { + input_vectorization: 4, + out_vectorization: 1, + input: as_type![F: -1., 1., -2., 2.], + expected: as_type![F: 0.] + } + ] +); + test_unary_impl!(test_abs, F, Vector::abs, [ { input_vectorization: 1, @@ -889,6 +927,7 @@ macro_rules! testgen_unary { add_test!(test_radians); add_test!(test_normalize); add_test!(test_magnitude); + add_test!(test_vector_sum); add_test!(test_sqrt); add_test!(test_inverse_sqrt); add_test!(test_abs); @@ -917,6 +956,87 @@ test_unary_impl_int!(test_abs_int, I, Abs::abs, [ } ]); +pub fn test_vector_sum_int(client: ComputeClient) { + #[cube(launch_unchecked)] + fn test_function( + input: &Array>, + output: &mut Array>, + ) { + if ABSOLUTE_POS < input.len() { + output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS].vector_sum()); + } + } + + // vec1: identity + { + let input = as_type![I: 3, -5, 7, -2]; + let output_handle = client.empty(input.len() * core::mem::size_of::()); + let input_handle = client.create_from_slice(I::as_bytes(input)); + + unsafe { + test_function::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new_1d(input.len() as u32), + 1usize, + 1usize, + ArrayArg::from_raw_parts(input_handle, input.len()), + ArrayArg::from_raw_parts(output_handle.clone(), input.len()), + ) + }; + + let actual = client.read_one_unchecked(output_handle); + let actual = I::from_bytes(&actual); + assert_eq!(actual, as_type![I: 3, -5, 7, -2]); + } + + // vec2: sum pairs + { + let input = as_type![I: 1, 3, 2, 5]; + let output_handle = client.empty(2 * 
core::mem::size_of::()); + let input_handle = client.create_from_slice(I::as_bytes(input)); + + unsafe { + test_function::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new_1d(2), + 2usize, + 1usize, + ArrayArg::from_raw_parts(input_handle, input.len()), + ArrayArg::from_raw_parts(output_handle.clone(), 2), + ) + }; + + let actual = client.read_one_unchecked(output_handle); + let actual = I::from_bytes(&actual); + assert_eq!(actual, as_type![I: 4, 7]); + } + + // vec4: sum all 4 + { + let input = as_type![I: 1, 2, 3, 4]; + let output_handle = client.empty(core::mem::size_of::()); + let input_handle = client.create_from_slice(I::as_bytes(input)); + + unsafe { + test_function::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new_1d(1), + 4usize, + 1usize, + ArrayArg::from_raw_parts(input_handle, input.len()), + ArrayArg::from_raw_parts(output_handle.clone(), 1), + ) + }; + + let actual = client.read_one_unchecked(output_handle); + let actual = I::from_bytes(&actual); + assert_eq!(actual, as_type![I: 10]); + } +} + #[allow(missing_docs)] #[macro_export] macro_rules! testgen_unary_int { @@ -937,6 +1057,7 @@ macro_rules! testgen_unary_int { } add_test!(test_abs_int); + add_test!(test_vector_sum_int); add_test!(test_count_ones); add_test!(test_reverse_bits); add_test!(test_leading_zeros); diff --git a/crates/cubecl-cpp/src/cuda/dialect.rs b/crates/cubecl-cpp/src/cuda/dialect.rs index 980bd43f15..0fcb074ee2 100644 --- a/crates/cubecl-cpp/src/cuda/dialect.rs +++ b/crates/cubecl-cpp/src/cuda/dialect.rs @@ -51,6 +51,127 @@ impl> DialectIncludes for CudaDialect { if flags.elem_f16 { f.write_str("#include \n")?; } + if flags.elem_complex { + // Use cuComplex.h instead of thrust/complex.h for NVRTC compatibility. + // thrust/complex.h requires C++ standard headers (, ) + // that are unavailable in NVRTC. cuComplex.h is a C API header that + // works with both nvcc and NVRTC. 
+ // + // Since cuComplex has no operator overloading, we define inline wrappers + // so that the shared codegen can emit `a + b` etc. for complex types. + f.write_str("#include \n")?; + f.write_str(concat!( + "__device__ __host__ inline cuFloatComplex operator+(cuFloatComplex a, cuFloatComplex b) { return cuCaddf(a, b); }\n", + "__device__ __host__ inline cuFloatComplex operator-(cuFloatComplex a, cuFloatComplex b) { return cuCsubf(a, b); }\n", + "__device__ __host__ inline cuFloatComplex operator*(cuFloatComplex a, cuFloatComplex b) { return cuCmulf(a, b); }\n", + "__device__ __host__ inline cuFloatComplex operator/(cuFloatComplex a, cuFloatComplex b) { return cuCdivf(a, b); }\n", + "__device__ __host__ inline cuFloatComplex operator-(cuFloatComplex a) { return make_cuFloatComplex(-cuCrealf(a), -cuCimagf(a)); }\n", + "__device__ __host__ inline bool operator==(cuFloatComplex a, cuFloatComplex b) { return cuCrealf(a)==cuCrealf(b) && cuCimagf(a)==cuCimagf(b); }\n", + "__device__ __host__ inline bool operator!=(cuFloatComplex a, cuFloatComplex b) { return !(a==b); }\n", + "__device__ __host__ inline cuDoubleComplex operator+(cuDoubleComplex a, cuDoubleComplex b) { return cuCadd(a, b); }\n", + "__device__ __host__ inline cuDoubleComplex operator-(cuDoubleComplex a, cuDoubleComplex b) { return cuCsub(a, b); }\n", + "__device__ __host__ inline cuDoubleComplex operator*(cuDoubleComplex a, cuDoubleComplex b) { return cuCmul(a, b); }\n", + "__device__ __host__ inline cuDoubleComplex operator/(cuDoubleComplex a, cuDoubleComplex b) { return cuCdiv(a, b); }\n", + "__device__ __host__ inline cuDoubleComplex operator-(cuDoubleComplex a) { return make_cuDoubleComplex(-cuCreal(a), -cuCimag(a)); }\n", + "__device__ __host__ inline bool operator==(cuDoubleComplex a, cuDoubleComplex b) { return cuCreal(a)==cuCreal(b) && cuCimag(a)==cuCimag(b); }\n", + "__device__ __host__ inline bool operator!=(cuDoubleComplex a, cuDoubleComplex b) { return !(a==b); }\n", + ))?; + f.write_str( + 
r#"__device__ __host__ inline float cubecl_abs(cuFloatComplex a) { + return hypotf(cuCrealf(a), cuCimagf(a)); +} +__device__ __host__ inline double cubecl_abs(cuDoubleComplex a) { + return hypot(cuCreal(a), cuCimag(a)); +} +__device__ __host__ inline cuFloatComplex cubecl_exp(cuFloatComplex a) { + const float x = cuCrealf(a); + const float y = cuCimagf(a); + const float ex = expf(x); + return make_cuFloatComplex(ex * cosf(y), ex * sinf(y)); +} +__device__ __host__ inline cuDoubleComplex cubecl_exp(cuDoubleComplex a) { + const double x = cuCreal(a); + const double y = cuCimag(a); + const double ex = exp(x); + return make_cuDoubleComplex(ex * cos(y), ex * sin(y)); +} +__device__ __host__ inline cuFloatComplex cubecl_log(cuFloatComplex a) { + const float x = cuCrealf(a); + const float y = cuCimagf(a); + return make_cuFloatComplex(logf(hypotf(x, y)), atan2f(y, x)); +} +__device__ __host__ inline cuDoubleComplex cubecl_log(cuDoubleComplex a) { + const double x = cuCreal(a); + const double y = cuCimag(a); + return make_cuDoubleComplex(log(hypot(x, y)), atan2(y, x)); +} +__device__ __host__ inline cuFloatComplex cubecl_sin(cuFloatComplex a) { + const float x = cuCrealf(a); + const float y = cuCimagf(a); + return make_cuFloatComplex(sinf(x) * coshf(y), cosf(x) * sinhf(y)); +} +__device__ __host__ inline cuDoubleComplex cubecl_sin(cuDoubleComplex a) { + const double x = cuCreal(a); + const double y = cuCimag(a); + return make_cuDoubleComplex(sin(x) * cosh(y), cos(x) * sinh(y)); +} +__device__ __host__ inline cuFloatComplex cubecl_cos(cuFloatComplex a) { + const float x = cuCrealf(a); + const float y = cuCimagf(a); + return make_cuFloatComplex(cosf(x) * coshf(y), -sinf(x) * sinhf(y)); +} +__device__ __host__ inline cuDoubleComplex cubecl_cos(cuDoubleComplex a) { + const double x = cuCreal(a); + const double y = cuCimag(a); + return make_cuDoubleComplex(cos(x) * cosh(y), -sin(x) * sinh(y)); +} +__device__ __host__ inline cuFloatComplex cubecl_sqrt(cuFloatComplex a) { + const 
float x = cuCrealf(a); + const float y = cuCimagf(a); + const float r = hypotf(x, y); + if (x >= 0.0f) { + const float re = sqrtf(0.5f * (r + x)); + const float im = re == 0.0f ? 0.0f : y / (2.0f * re); + return make_cuFloatComplex(re, im); + } + const float im = copysignf(sqrtf(0.5f * (r - x)), y); + const float re = im == 0.0f ? 0.0f : y / (2.0f * im); + return make_cuFloatComplex(re, im); +} +__device__ __host__ inline cuDoubleComplex cubecl_sqrt(cuDoubleComplex a) { + const double x = cuCreal(a); + const double y = cuCimag(a); + const double r = hypot(x, y); + if (x >= 0.0) { + const double re = sqrt(0.5 * (r + x)); + const double im = re == 0.0 ? 0.0 : y / (2.0 * re); + return make_cuDoubleComplex(re, im); + } + const double im = copysign(sqrt(0.5 * (r - x)), y); + const double re = im == 0.0 ? 0.0 : y / (2.0 * im); + return make_cuDoubleComplex(re, im); +} +__device__ __host__ inline cuFloatComplex cubecl_tanh(cuFloatComplex a) { + const float x2 = 2.0f * cuCrealf(a); + const float y2 = 2.0f * cuCimagf(a); + const float denom = coshf(x2) + cosf(y2); + return make_cuFloatComplex(sinhf(x2) / denom, sinf(y2) / denom); +} +__device__ __host__ inline cuDoubleComplex cubecl_tanh(cuDoubleComplex a) { + const double x2 = 2.0 * cuCreal(a); + const double y2 = 2.0 * cuCimag(a); + const double denom = cosh(x2) + cos(y2); + return make_cuDoubleComplex(sinh(x2) / denom, sin(y2) / denom); +} +__device__ __host__ inline cuFloatComplex cubecl_powf(cuFloatComplex a, cuFloatComplex b) { + return cubecl_exp(b * cubecl_log(a)); +} +__device__ __host__ inline cuDoubleComplex cubecl_powf(cuDoubleComplex a, cuDoubleComplex b) { + return cubecl_exp(b * cubecl_log(a)); +} +"#, + )?; + } // tf32 conversion function is in mma header if flags.inst_wmma || flags.elem_tf32 { @@ -271,6 +392,8 @@ impl> DialectTypes for CudaDialect { shared::Elem::U16 => f.write_str("ushort"), shared::Elem::U32 => f.write_str("uint"), shared::Elem::U64 => f.write_str("ulong"), + shared::Elem::CF32 => 
f.write_str("cuFloatComplex"), + shared::Elem::CF64 => f.write_str("cuDoubleComplex"), _ => Self::compile_elem(f, elem, false), } } else { @@ -296,6 +419,8 @@ impl> DialectTypes for CudaDialect { shared::Elem::U16 => f.write_str("uint16"), shared::Elem::U32 => f.write_str("uint32"), shared::Elem::U64 => f.write_str("uint64"), + shared::Elem::CF32 => f.write_str("cuFloatComplex"), + shared::Elem::CF64 => f.write_str("cuDoubleComplex"), shared::Elem::Bool => f.write_str("bool"), shared::Elem::Barrier(BarrierLevel::Unit) => { f.write_str("cuda::barrier") diff --git a/crates/cubecl-cpp/src/hip/dialect.rs b/crates/cubecl-cpp/src/hip/dialect.rs index 5cc9627cd5..68d67be10c 100644 --- a/crates/cubecl-cpp/src/hip/dialect.rs +++ b/crates/cubecl-cpp/src/hip/dialect.rs @@ -290,6 +290,9 @@ impl> DialectTypes for HipDialect { shared::Elem::Bool => f.write_str("bool"), shared::Elem::Barrier(_) => panic!("Barrier object not supported in HIP"), shared::Elem::Atomic(inner) => inner.fmt(f), + shared::Elem::CF32 | shared::Elem::CF64 => { + f.write_str("#error Complex not supported in HIP\n") + } shared::Elem::_Dialect(_) => Ok(()), } } diff --git a/crates/cubecl-cpp/src/metal/dialect.rs b/crates/cubecl-cpp/src/metal/dialect.rs index 9bfe4c83f3..587b857783 100644 --- a/crates/cubecl-cpp/src/metal/dialect.rs +++ b/crates/cubecl-cpp/src/metal/dialect.rs @@ -285,6 +285,9 @@ struct alignas({alignment}) {item} {{" shared::Elem::Bool => f.write_str("bool"), shared::Elem::Barrier(_) => unimplemented!("metal doesn't support barrier object"), shared::Elem::Atomic(inner) => inner.fmt(f), + shared::Elem::CF32 | shared::Elem::CF64 => { + f.write_str("#error Complex not supported in Metal\n") + } shared::Elem::_Dialect(_) => Ok(()), } } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 0bf58befad..f542c061a4 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -78,6 +78,7 @@ pub struct Flags { pub elem_bf16: 
bool, pub elem_f16: bool, pub elem_tf32: bool, + pub elem_complex: bool, pub indexes: CubeIndexFlags, pub op_barrier: bool, pub op_pipeline: bool, @@ -124,6 +125,7 @@ impl Default for Flags { elem_bf16: Default::default(), elem_f16: Default::default(), elem_tf32: Default::default(), + elem_complex: Default::default(), indexes: Default::default(), op_barrier: Default::default(), op_pipeline: Default::default(), @@ -252,6 +254,7 @@ impl CppCompiler { elem_bf16: self.flags.elem_bf16, elem_f16: self.flags.elem_f16, elem_tf32: self.flags.elem_tf32, + elem_complex: self.flags.elem_complex, inst_tma: self.flags.inst_tma, inst_tma_im2col: self.flags.inst_tma_im2col, inst_async_copy: self.flags.inst_async_copy, @@ -1366,6 +1369,7 @@ impl CppCompiler { gpu::ElemType::Int(_) => gpu::ConstantValue::Int(1), gpu::ElemType::UInt(_) => gpu::ConstantValue::UInt(1), gpu::ElemType::Bool => gpu::ConstantValue::Bool(true), + gpu::ElemType::Complex(_) => unimplemented!("Recip not supported for complex"), }; let div = Instruction::Div(BinaryInstruction { lhs: Variable::Constant(lhs, self.compile_type(op.input.ty)), @@ -1432,6 +1436,12 @@ impl CppCompiler { gpu::Arithmetic::Dot(op) => { instructions.push(Instruction::Dot(self.compile_binary(op, out))) } + gpu::Arithmetic::Conj(op) => { + instructions.push(Instruction::Conj(self.compile_unary(op, out))) + } + gpu::Arithmetic::VectorSum(op) => { + instructions.push(Instruction::VectorSum(self.compile_unary(op, out))) + } }; } @@ -1632,6 +1642,12 @@ impl CppCompiler { gpu::Operator::Reinterpret(op) => { instructions.push(Instruction::Bitcast(self.compile_unary(op, out))) } + gpu::Operator::Real(op) => { + instructions.push(Instruction::Real(self.compile_unary(op, out))) + } + gpu::Operator::Imag(op) => { + instructions.push(Instruction::Imag(self.compile_unary(op, out))) + } }; } @@ -2065,6 +2081,13 @@ impl CppCompiler { gpu::UIntKind::U64 => Elem::U64, }, gpu::ElemType::Bool => Elem::Bool, + gpu::ElemType::Complex(kind) => { + 
self.flags.elem_complex = true; + match kind { + gpu::ComplexKind::C32 => Elem::CF32, + gpu::ComplexKind::C64 => Elem::CF64, + } + } } } } @@ -2110,6 +2133,8 @@ pub fn register_supported_types(props: &mut DeviceProperties) { gpu::ElemType::Float(gpu::FloatKind::Flex32), // Causes CUDA_ERROR_INVALID_VALUE for matmul, disabling until that can be investigated //gpu::Elem::Float(gpu::FloatKind::F64), + gpu::ElemType::Complex(gpu::ComplexKind::C32), + gpu::ElemType::Complex(gpu::ComplexKind::C64), gpu::ElemType::Bool, ]; diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs index 87fbdd5090..492de58c97 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -227,6 +227,7 @@ impl Binary for Powf { let lhs = lhs.to_string(); let rhs = rhs.to_string(); match elem { + Elem::CF32 | Elem::CF64 => write!(f, "cubecl_powf({lhs}, {rhs})"), Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { let lhs = format!("float({lhs})"); let rhs = format!("float({rhs})"); diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index a2dd8c69a0..8073d31fe7 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -31,6 +31,8 @@ pub enum Elem { U16, U32, U64, + CF32, + CF64, Bool, Barrier(BarrierLevel), Atomic(AtomicKind), @@ -115,6 +117,8 @@ impl Elem { Elem::U16 => core::mem::size_of::(), Elem::U32 => core::mem::size_of::(), Elem::U64 => core::mem::size_of::(), + Elem::CF32 => core::mem::size_of::() * 2, + Elem::CF64 => core::mem::size_of::() * 2, Elem::Bool => core::mem::size_of::(), Elem::Barrier(_) => core::mem::size_of::(), Elem::Atomic(AtomicKind::I32) => core::mem::size_of::(), @@ -181,6 +185,8 @@ impl Elem { Elem::U16 => "u16", Elem::U32 => "u32", Elem::U64 => "u64", + Elem::CF32 => "cf32", + Elem::CF64 => "cf64", Elem::Bool => "bool", Elem::Barrier(BarrierLevel::Cube) => "cuda::barrier", 
Elem::Barrier(BarrierLevel::Unit) => "cuda::barrier", diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 9ac23eab30..7d5841f3b1 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -247,11 +247,15 @@ pub enum Instruction { out: Variable, }, Neg(UnaryInstruction), + Conj(UnaryInstruction), + Real(UnaryInstruction), + Imag(UnaryInstruction), Magnitude(UnaryInstruction), FastMagnitude(UnaryInstruction), Normalize(UnaryInstruction), FastNormalize(UnaryInstruction), Dot(BinaryInstruction), + VectorSum(UnaryInstruction), Copy { input: Variable, in_index: Variable, @@ -457,14 +461,21 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ let cond_elem = cond.item().elem; let out = out.fmt_left(); - let should_broadcast = - vf_cond > 1 || item_out != item_or_else || item_out != item_then; + // It seems to always be faster to broadcast the select, because the compiler is + // able to output branchless instructions when the ternary is done on native types + // rather than cubecl defined types. - if should_broadcast { - let vf = usize::max(vf_cond, vf_out); - let vf = usize::max(vf, vf_then); - let vf = usize::max(vf, vf_or_else); + let vf = usize::max(vf_cond, vf_out); + let vf = usize::max(vf, vf_then); + let vf = usize::max(vf, vf_or_else); + let should_broadcast = vf > 1; + + // Keep the condition here for future testing. + // + // let should_broadcast = + // vf_cond > 1 || item_out != item_or_else || item_out != item_then; + if should_broadcast { writeln!(f, "{out} = {item_out} {{")?; for i in 0..vf { let theni = then.index(i); @@ -652,6 +663,25 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ let out = out.fmt_left(); writeln!(f, "{out} = -{input};") } + Instruction::Conj(UnaryInstruction { input, out }) => { + let elem = out.elem(); + let out_left = out.fmt_left(); + // cuComplex structs have fields .x (real) and .y (imag). 
+ let make_fn = match format!("{elem}").as_str() { + "cuFloatComplex" => "make_cuFloatComplex", + "cuDoubleComplex" => "make_cuDoubleComplex", + _ => "make_cuDoubleComplex", // fallback + }; + writeln!(f, "{out_left} = {make_fn}({input}.x, -{input}.y);") + } + Instruction::Real(UnaryInstruction { input, out }) => { + let out = out.fmt_left(); + writeln!(f, "{out} = {input}.x;") + } + Instruction::Imag(UnaryInstruction { input, out }) => { + let out = out.fmt_left(); + writeln!(f, "{out} = {input}.y;") + } Instruction::Normalize(inst) => { Normalize::::format(f, &inst.input, &inst.out) } @@ -663,6 +693,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Magnitude::::format(f, &inst.input, &inst.out) } Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out), + Instruction::VectorSum(inst) => VectorSumFmt::::format(f, &inst.input, &inst.out), Instruction::VecInit { inputs, out } => { let item = out.item(); let inputs = inputs @@ -1070,6 +1101,27 @@ impl Dot { } } +struct VectorSumFmt { + _dialect: PhantomData, +} + +impl VectorSumFmt { + fn format( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + ) -> core::fmt::Result { + let num = input.item().vectorization; + + let elems = (0..num) + .map(|i| format!("{}", input.index(i))) + .collect::>(); + + let out = out.fmt_left(); + writeln!(f, "{out} = {};", elems.join(" + ")) + } +} + struct EnsureBoolArg<'a, V: Display, D: Dialect> { var: &'a V, elem: &'a Elem, diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index ca0ddc5e13..709ad3e452 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -101,6 +101,15 @@ pub trait FunctionFmt { input: Input, elem: Elem, ) -> std::fmt::Result { + if matches!(elem, Elem::CF32 | Elem::CF64) { + return match Self::base_function_name() { + "exp" | "log" | "sin" | "cos" | "sqrt" => { + write!(f, "cubecl_{}({input})", Self::base_function_name()) + } + 
_ => write!(f, "{}({input})", Self::function_name(elem)), + }; + } + if Self::half_support() { write!(f, "{}({input})", Self::function_name(elem)) } else { @@ -177,7 +186,35 @@ function!(FastRecip, "__frcp_rn", false); function!(FastTanh, "__tanhf", false); function!(Erf, "erf", false); -function!(Abs, "abs", false); + +pub struct Abs; + +impl FunctionFmt for Abs { + fn base_function_name() -> &'static str { + "abs" + } + + fn half_support() -> bool { + false + } +} + +impl Unary for Abs { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + _out_elem: Elem, + ) -> std::fmt::Result { + match input.elem() { + Elem::CF32 | Elem::CF64 => write!(f, "cubecl_abs({input})"), + elem => Self::format_unary(f, input, elem), + } + } + + fn can_optimize() -> bool { + false + } +} pub struct Log1p; @@ -203,7 +240,10 @@ impl Unary for Tanh { input: Input, _out_elem: Elem, ) -> std::fmt::Result { - D::compile_instruction_tanh_scalar(f, input) + match input.elem() { + Elem::CF32 | Elem::CF64 => write!(f, "cubecl_tanh({input})"), + _ => D::compile_instruction_tanh_scalar(f, input), + } } fn can_optimize() -> bool { @@ -411,9 +451,13 @@ impl Unary for IsNan { input: Input, _elem: Elem, ) -> std::fmt::Result { - // Format unary function name based on *input* elem dtype let elem = input.elem(); - write!(f, "{}({input})", elem_function_name("isnan", elem)) + match elem { + Elem::CF32 | Elem::CF64 => { + write!(f, "(isnan({input}.x) || isnan({input}.y))") + } + _ => write!(f, "{}({input})", elem_function_name("isnan", elem)), + } } fn can_optimize() -> bool { @@ -429,9 +473,13 @@ impl Unary for IsInf { input: Input, _elem: Elem, ) -> std::fmt::Result { - // Format unary function name based on *input* elem dtype let elem = input.elem(); - write!(f, "{}({input})", elem_function_name("isinf", elem)) + match elem { + Elem::CF32 | Elem::CF64 => { + write!(f, "(isinf({input}.x) || isinf({input}.y))") + } + _ => write!(f, "{}({input})", elem_function_name("isinf", elem)), + 
} } fn can_optimize() -> bool { diff --git a/crates/cubecl-cpp/src/shared/variable.rs b/crates/cubecl-cpp/src/shared/variable.rs index 2978778f35..caef4fcedb 100644 --- a/crates/cubecl-cpp/src/shared/variable.rs +++ b/crates/cubecl-cpp/src/shared/variable.rs @@ -209,6 +209,14 @@ impl Component for Variable { } pub(crate) fn format_const(number: &ConstantValue, item: &Item) -> String { + if let ConstantValue::Complex(re, im) = number { + return match item.elem() { + Elem::CF32 => format!("make_cuFloatComplex({re:?}f, {im:?}f)"), + Elem::CF64 => format!("make_cuDoubleComplex({re:?}, {im:?})"), + _ => format!("{number}"), + }; + } + // minifloats are represented as raw bits, so use special handling let number = match item.elem() { Elem::FP4(FP4Kind::E2M1) => e2m1::from_f64(number.as_f64()).to_bits(), @@ -245,7 +253,10 @@ impl Display for Variable { Variable::GlobalScalar { id, elem } => write!(f, "info.scalars_{elem}[{id}]"), Variable::Constant(number, item) if item.vectorization <= 1 => { let value = format_const(number, item); - write!(f, "{item}({value})") + match item.elem() { + Elem::CF32 | Elem::CF64 => write!(f, "{value}"), + _ => write!(f, "{item}({value})"), + } } Variable::Constant(number, item) => { let number = format_const(number, item); @@ -642,7 +653,10 @@ impl Display for IndexedVariable { if let Variable::Constant(value, item) = var { let value = format_const(value, item); - return write!(f, "{}({value})", item.elem()); + return match item.elem() { + Elem::CF32 | Elem::CF64 => write!(f, "{value}"), + _ => write!(f, "{}({value})", item.elem()), + }; } let ref_ = matches!(var, Variable::LocalConst { .. 
}) diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 9fd9a32833..9256021372 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -559,6 +559,24 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, output); } + Arithmetic::Conj(_) => unimplemented!("Conj not supported on CPU"), + Arithmetic::VectorSum(vector_sum) => { + let value = self.get_variable(vector_sum.input); + if vector_sum.input.ty.is_vectorized() { + let kind = Attribute::parse(self.context, "#vector.kind").unwrap(); + let result = vector_sum.input.storage_type().to_type(self.context); + let reduced = self.append_operation_with_result(vector::reduction( + self.context, + result, + value, + kind, + self.location, + )); + self.insert_variable(out, reduced); + } else { + self.insert_variable(out, value); + } + } } } diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs index 62822c81d8..cb8177e4db 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/operator.rs @@ -146,6 +146,9 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, value); } + Operator::Real(_) | Operator::Imag(_) => { + unimplemented!("Real/Imag not supported on CPU") + } } } diff --git a/crates/cubecl-cpu/src/compiler/visitor/variables.rs b/crates/cubecl-cpu/src/compiler/visitor/variables.rs index 2ab6037206..3369a3c07b 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/variables.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/variables.rs @@ -256,6 +256,9 @@ impl<'a> Visitor<'a> { IntegerAttribute::new(integer_type, bool as i64).into(); (integer_type, integer_attribute) } + ConstantValue::Complex(_, _) => { + unimplemented!("Complex constants are not supported on the CPU backend") + } }; 
let value = self.append_operation_with_result(arith::constant( self.context, diff --git a/crates/cubecl-cuda/src/compute/communication.rs b/crates/cubecl-cuda/src/compute/communication.rs index 034d0e6a05..ac48a4ee99 100644 --- a/crates/cubecl-cuda/src/compute/communication.rs +++ b/crates/cubecl-cuda/src/compute/communication.rs @@ -118,5 +118,6 @@ pub(crate) fn get_nccl_dtype_count( ), }, ElemType::Bool => panic!("NCCL doesn't support Bool format."), + ElemType::Complex(_) => panic!("NCCL doesn't support Complex format."), } } diff --git a/crates/cubecl-cuda/src/compute/mod.rs b/crates/cubecl-cuda/src/compute/mod.rs index c595738e3a..978b846183 100644 --- a/crates/cubecl-cuda/src/compute/mod.rs +++ b/crates/cubecl-cuda/src/compute/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod storage; pub(crate) mod stream; pub(crate) mod sync; -mod server; +pub(crate) mod server; pub use server::*; diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index c6f3e690c3..d419cde663 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -402,6 +402,24 @@ impl ServerCommunication for CudaServer { } impl CudaServer { + /// Returns the raw `CUstream` handle for the given stream ID. + /// + /// This allows external FFI libraries (cuBLAS, cuSOLVER, cuTENSOR) to + /// execute on the same CUDA stream as `CubeCL` kernels, eliminating the + /// need for inter-stream event synchronization. + /// + /// # Safety + /// + /// The returned `CUstream` is owned by `CubeCL`'s runtime. The caller must + /// not destroy it or use it after the server is dropped. + pub fn raw_stream( + &mut self, + stream_id: StreamId, + ) -> Result { + let mut resolved = self.streams.resolve(stream_id, [].into_iter(), false)?; + Ok(resolved.current().sys) + } + /// Create a new cuda server. 
pub(crate) fn new( ctx: CudaContext, diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index 409db3af87..b8be8f5313 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -9,6 +9,17 @@ mod runtime; pub use device::*; pub use runtime::*; +/// Re-exports for FFI interop with external CUDA libraries. +/// +/// These types allow extracting raw CUDA pointers and streams from `CubeCL`'s +/// managed resources, enabling zero-copy interop with cuBLAS, cuSOLVER, +/// cuTENSOR, and other CUDA FFI libraries. +pub mod ffi_interop { + pub use crate::compute::server::CudaServer; + pub use crate::compute::storage::gpu::GpuResource; + pub use crate::compute::stream::Stream; +} + #[cfg(feature = "ptx-wmma")] pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::PtxWmmaCompiler; @@ -77,4 +88,5 @@ mod tests { cubecl_std::testgen!(); cubecl_std::testgen_tensor_identity!([f16, bf16, f32, u32]); cubecl_std::testgen_quantized_view!(f16); + cubecl_core::testgen_complex!(); } diff --git a/crates/cubecl-ir/Cargo.toml b/crates/cubecl-ir/Cargo.toml index 7bfc1a8ca5..26691a2984 100644 --- a/crates/cubecl-ir/Cargo.toml +++ b/crates/cubecl-ir/Cargo.toml @@ -40,6 +40,7 @@ float-ord = { workspace = true } fnv = { workspace = true } foldhash = { workspace = true } half = { workspace = true } +num-complex = { workspace = true } num-traits = { workspace = true } serde = { workspace = true, optional = true, features = ["rc"] } variadics_please = { workspace = true } diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index e502502a3f..1e56ea3113 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -64,6 +64,8 @@ pub enum Arithmetic { Dot(BinaryOperator), #[operation(commutative)] MulHi(BinaryOperator), + Conj(UnaryOperator), + VectorSum(UnaryOperator), } impl Display for Arithmetic { @@ -119,6 +121,8 @@ impl Display for Arithmetic { Arithmetic::Normalize(op) => write!(f, 
"{}.normalize()", op.input), Arithmetic::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs), Arithmetic::MulHi(op) => write!(f, "mul_hi({}, {})", op.lhs, op.rhs), + Arithmetic::Conj(op) => write!(f, "{}.conj()", op.input), + Arithmetic::VectorSum(op) => write!(f, "{}.vector_sum()", op.input), } } } diff --git a/crates/cubecl-ir/src/operator.rs b/crates/cubecl-ir/src/operator.rs index 74fa36ed66..20dad8d74a 100644 --- a/crates/cubecl-ir/src/operator.rs +++ b/crates/cubecl-ir/src/operator.rs @@ -31,6 +31,10 @@ pub enum Operator { Cast(UnaryOperator), #[operation(pure)] Reinterpret(UnaryOperator), + #[operation(pure)] + Real(UnaryOperator), + #[operation(pure)] + Imag(UnaryOperator), /// A select statement/ternary #[operation(pure)] Select(Select), @@ -71,6 +75,8 @@ impl Display for Operator { } Operator::Cast(op) => write!(f, "cast({})", op.input), Operator::Reinterpret(op) => write!(f, "reinterpret({})", op.input), + Operator::Real(op) => write!(f, "{}.real()", op.input), + Operator::Imag(op) => write!(f, "{}.imag()", op.input), } } } diff --git a/crates/cubecl-ir/src/type.rs b/crates/cubecl-ir/src/type.rs index 32c8b4914c..668300fa7b 100644 --- a/crates/cubecl-ir/src/type.rs +++ b/crates/cubecl-ir/src/type.rs @@ -55,6 +55,14 @@ pub enum UIntKind { U64, } +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[allow(missing_docs)] +pub enum ComplexKind { + C32, + C64, +} + /// Conceptual element type, not necessarily the physical type used in the code #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)] @@ -63,6 +71,7 @@ pub enum ElemType { Float(FloatKind), Int(IntKind), UInt(UIntKind), + Complex(ComplexKind), Bool, } @@ -177,6 +186,10 @@ impl ElemType { UIntKind::U32 => core::mem::size_of::(), UIntKind::U64 => core::mem::size_of::(), }, + 
ElemType::Complex(kind) => match kind { + ComplexKind::C32 => core::mem::size_of::() * 2, + ComplexKind::C64 => core::mem::size_of::() * 2, + }, ElemType::Bool => core::mem::size_of::(), } } @@ -198,7 +211,9 @@ impl ElemType { | FloatKind::TF32 => self.size() * 8, FloatKind::E2M1 => 4, }, - ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8, + ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool | ElemType::Complex(_) => { + self.size() * 8 + } } } @@ -229,6 +244,17 @@ impl ElemType { matches!(self, ElemType::Bool) } + pub fn is_complex(&self) -> bool { + matches!(self, ElemType::Complex(_)) + } + + pub fn as_complex(&self) -> Option { + match self { + ElemType::Complex(kind) => Some(*kind), + _ => None, + } + } + pub fn as_float(&self) -> Option { match self { ElemType::Float(kind) => Some(*kind), @@ -265,6 +291,7 @@ impl ElemType { UIntKind::U64 => u64::MAX, } .into(), + ElemType::Complex(_) => panic!("Complex numbers have no maximum"), ElemType::Bool => true.into(), }; @@ -300,6 +327,7 @@ impl ElemType { UIntKind::U64 => u64::MIN, } .into(), + ElemType::Complex(_) => panic!("Complex numbers have no minimum"), ElemType::Bool => false.into(), }; @@ -320,7 +348,11 @@ impl ElemType { FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(), FloatKind::F64 => f64::EPSILON, }, - ElemType::Int(_) | ElemType::UInt(_) => 1.0, // step of 1 + ElemType::Int(_) | ElemType::UInt(_) => 1.0, + ElemType::Complex(kind) => match kind { + ComplexKind::C32 => f32::EPSILON.into(), + ComplexKind::C64 => f64::EPSILON, + }, ElemType::Bool => 1.0, } } @@ -420,7 +452,7 @@ macro_rules! 
storage_from_elem { }; } -storage_from_elem!(FloatKind, IntKind, UIntKind, ElemType); +storage_from_elem!(FloatKind, IntKind, UIntKind, ComplexKind, ElemType); impl From for StorageType { fn from(val: OpaqueType) -> Self { @@ -609,6 +641,10 @@ impl Display for ElemType { UIntKind::U32 => f.write_str("u32"), UIntKind::U64 => f.write_str("u64"), }, + Self::Complex(kind) => match kind { + ComplexKind::C32 => f.write_str("c32"), + ComplexKind::C64 => f.write_str("c64"), + }, Self::Bool => f.write_str("bool"), } } @@ -791,3 +827,21 @@ impl_into_variable!( usize => UIntKind::U32, isize => IntKind::I32, ); + +impl From> for Variable { + fn from(value: num_complex::Complex) -> Self { + Variable::new( + VariableKind::Constant(ConstantValue::Complex(value.re as f64, value.im as f64)), + StorageType::Scalar(ElemType::Complex(ComplexKind::C32)).into(), + ) + } +} + +impl From> for Variable { + fn from(value: num_complex::Complex) -> Self { + Variable::new( + VariableKind::Constant(ConstantValue::Complex(value.re, value.im)), + StorageType::Scalar(ElemType::Complex(ComplexKind::C64)).into(), + ) + } +} diff --git a/crates/cubecl-ir/src/variable.rs b/crates/cubecl-ir/src/variable.rs index b7ca22a946..646da93efa 100644 --- a/crates/cubecl-ir/src/variable.rs +++ b/crates/cubecl-ir/src/variable.rs @@ -2,7 +2,7 @@ use core::{fmt::Display, hash::Hash}; use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash}; -use super::{ElemType, Matrix, Type, UIntKind}; +use super::{ComplexKind, ElemType, Matrix, Type, UIntKind}; use cubecl_common::{e2m1, e4m3, e5m2, ue8m0}; use derive_more::From; use float_ord::FloatOrd; @@ -233,6 +233,7 @@ pub enum ConstantValue { Float(f64), UInt(u64), Bool(bool), + Complex(f64, f64), } impl Ord for ConstantValue { @@ -243,6 +244,12 @@ impl Ord for ConstantValue { (ConstantValue::Float(this), ConstantValue::Float(other)) => { FloatOrd(*this).cmp(&FloatOrd(*other)) } + ( + ConstantValue::Complex(this_re, this_im), + ConstantValue::Complex(other_re, 
other_im), + ) => FloatOrd(*this_re) + .cmp(&FloatOrd(*other_re)) + .then_with(|| FloatOrd(*this_im).cmp(&FloatOrd(*other_im))), _ => self.partial_cmp(other).unwrap(), } } @@ -265,6 +272,10 @@ impl Hash for ConstantValue { ConstantValue::Bool(f0) => { f0.hash(ra_expand_state); } + ConstantValue::Complex(f0, f1) => { + FloatOrd(*f0).hash(ra_expand_state); + FloatOrd(*f1).hash(ra_expand_state); + } } } } @@ -279,6 +290,7 @@ impl ConstantValue { ConstantValue::Int(val) => Some(*val as usize), ConstantValue::Float(_) => None, ConstantValue::Bool(_) => None, + ConstantValue::Complex(_, _) => None, } } @@ -289,6 +301,9 @@ impl ConstantValue { ConstantValue::Int(val) => *val as usize, ConstantValue::Float(val) => *val as usize, ConstantValue::Bool(val) => *val as usize, + ConstantValue::Complex(_, _) => { + panic!("Complex constants can't be converted to usize") + } } } @@ -315,6 +330,7 @@ impl ConstantValue { ConstantValue::Int(val) => Some(*val as u64), ConstantValue::Float(_) => None, ConstantValue::Bool(_) => None, + ConstantValue::Complex(_, _) => None, } } @@ -325,6 +341,7 @@ impl ConstantValue { ConstantValue::Int(val) => *val as u64, ConstantValue::Float(val) => *val as u64, ConstantValue::Bool(val) => *val as u64, + ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to u64"), } } @@ -337,6 +354,7 @@ impl ConstantValue { ConstantValue::Int(val) => Some(*val), ConstantValue::Float(_) => None, ConstantValue::Bool(_) => None, + ConstantValue::Complex(_, _) => None, } } @@ -347,6 +365,9 @@ impl ConstantValue { ConstantValue::Int(val) => *val as i128, ConstantValue::Float(val) => *val as i128, ConstantValue::Bool(val) => *val as i128, + ConstantValue::Complex(_, _) => { + panic!("Complex constants can't be converted to i128") + } } } @@ -357,6 +378,7 @@ impl ConstantValue { ConstantValue::Int(val) => *val, ConstantValue::Float(val) => *val as i64, ConstantValue::Bool(val) => *val as i64, + ConstantValue::Complex(_, _) => panic!("Complex 
constants can't be converted to i64"), } } @@ -367,6 +389,7 @@ impl ConstantValue { ConstantValue::Int(val) => *val as i32, ConstantValue::Float(val) => *val as i32, ConstantValue::Bool(val) => *val as i32, + ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to i32"), } } @@ -376,6 +399,7 @@ impl ConstantValue { pub fn try_as_f64(&self) -> Option { match self { ConstantValue::Float(val) => Some(*val), + ConstantValue::Complex(re, _) => Some(*re), _ => None, } } @@ -387,6 +411,7 @@ impl ConstantValue { ConstantValue::Int(val) => *val as f64, ConstantValue::Float(val) => *val, ConstantValue::Bool(val) => *val as u8 as f64, + ConstantValue::Complex(re, _) => *re, } } @@ -407,6 +432,9 @@ impl ConstantValue { ConstantValue::Int(val) => *val != 0, ConstantValue::Float(val) => *val != 0., ConstantValue::Bool(val) => *val, + ConstantValue::Complex(_, _) => { + panic!("Complex constants can't be converted to bool") + } } } @@ -416,6 +444,7 @@ impl ConstantValue { ConstantValue::Float(val) => *val == 0.0, ConstantValue::UInt(val) => *val == 0, ConstantValue::Bool(val) => !*val, + ConstantValue::Complex(re, im) => *re == 0.0 && *im == 0.0, } } @@ -425,6 +454,7 @@ impl ConstantValue { ConstantValue::Float(val) => *val == 1.0, ConstantValue::UInt(val) => *val == 1, ConstantValue::Bool(val) => *val, + ConstantValue::Complex(re, im) => *re == 1.0 && *im == 0.0, } } @@ -447,21 +477,48 @@ impl ConstantValue { FloatKind::F64 => self.as_f64(), } .into(), - ElemType::Int(kind) => match kind { - IntKind::I8 => self.as_i64() as i8 as i64, - IntKind::I16 => self.as_i64() as i16 as i64, - IntKind::I32 => self.as_i64() as i32 as i64, - IntKind::I64 => self.as_i64(), + ElemType::Int(kind) => { + let value = match self { + ConstantValue::Complex(re, _) => *re as i64, + _ => self.as_i64(), + }; + + match kind { + IntKind::I8 => value as i8 as i64, + IntKind::I16 => value as i16 as i64, + IntKind::I32 => value as i32 as i64, + IntKind::I64 => value, + } } .into(), - 
ElemType::UInt(kind) => match kind { - UIntKind::U8 => self.as_u64() as u8 as u64, - UIntKind::U16 => self.as_u64() as u16 as u64, - UIntKind::U32 => self.as_u64() as u32 as u64, - UIntKind::U64 => self.as_u64(), + ElemType::UInt(kind) => { + let value = match self { + ConstantValue::Complex(re, _) => *re as u64, + _ => self.as_u64(), + }; + + match kind { + UIntKind::U8 => value as u8 as u64, + UIntKind::U16 => value as u16 as u64, + UIntKind::U32 => value as u32 as u64, + UIntKind::U64 => value, + } } .into(), ElemType::Bool => self.as_bool().into(), + ElemType::Complex(kind) => match (self, kind) { + (ConstantValue::Complex(re, im), ComplexKind::C32) => { + ConstantValue::Complex(*re as f32 as f64, *im as f32 as f64) + } + (ConstantValue::Complex(re, im), ComplexKind::C64) => { + ConstantValue::Complex(*re, *im) + } + (_, ComplexKind::C32) => { + let re = self.as_f64() as f32 as f64; + ConstantValue::Complex(re, 0.0) + } + (_, ComplexKind::C64) => ConstantValue::Complex(self.as_f64(), 0.0), + }, }, StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => { e2m1::from_f64(self.as_f64()).to_f64().into() @@ -480,6 +537,7 @@ impl Display for ConstantValue { ConstantValue::Float(val) => write!(f, "{val:?}"), ConstantValue::UInt(val) => write!(f, "{val}"), ConstantValue::Bool(val) => write!(f, "{val}"), + ConstantValue::Complex(re, im) => write!(f, "({re:?}, {im:?})"), } } } diff --git a/crates/cubecl-macros/src/parse/kernel.rs b/crates/cubecl-macros/src/parse/kernel.rs index 3c901ed06e..42596eeade 100644 --- a/crates/cubecl-macros/src/parse/kernel.rs +++ b/crates/cubecl-macros/src/parse/kernel.rs @@ -246,7 +246,7 @@ impl GenericAnalysis { let marker_ty = format_ident!("_{ident}"); match name.as_str() { - "Float" | "Int" | "Numeric" | "CubePrimitive" => { + "Float" | "Int" | "Numeric" | "CubePrimitive" | "Complex" => { if explicit_defines { map.insert( ident.clone(), diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 
b8c1c13ad3..b9f7f6b28b 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -134,8 +134,10 @@ impl Optimizer { | Arithmetic::Erf(unary_operator) | Arithmetic::Recip(unary_operator) | Arithmetic::Neg(unary_operator) + | Arithmetic::Conj(unary_operator) | Arithmetic::Magnitude(unary_operator) - | Arithmetic::Normalize(unary_operator) => self.visit_unop(unary_operator, visit_read), + | Arithmetic::Normalize(unary_operator) + | Arithmetic::VectorSum(unary_operator) => self.visit_unop(unary_operator, visit_read), Arithmetic::Clamp(clamp_operator) => { visit_read(self, &mut clamp_operator.input); @@ -203,7 +205,9 @@ impl Optimizer { } Operator::Not(unary_operator) | Operator::Cast(unary_operator) - | Operator::Reinterpret(unary_operator) => self.visit_unop(unary_operator, visit_read), + | Operator::Reinterpret(unary_operator) + | Operator::Real(unary_operator) + | Operator::Imag(unary_operator) => self.visit_unop(unary_operator, visit_read), Operator::Index(index_operator) | Operator::UncheckedIndex(index_operator) => { visit_read(self, &mut index_operator.list); visit_read(self, &mut index_operator.index); diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index 8cbcc6a5e0..5b90621a7c 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -489,7 +489,9 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option { | Arithmetic::Hypot(_) | Arithmetic::Rhypot(_) | Arithmetic::Magnitude(_) - | Arithmetic::Normalize(_) => None, + | Arithmetic::Normalize(_) + | Arithmetic::Conj(_) + | Arithmetic::VectorSum(_) => None, } } @@ -505,6 +507,7 @@ fn try_const_eval_cmp(op: &mut Comparison) -> Option { use ConstantValue::*; op.input.as_const().map(|input| match input { Float(val) => Bool(val.is_nan()), + Complex(re, im) => Bool(re.is_nan() || im.is_nan()), // Integers, bools, uints can't be NaN, so always false Int(_) | 
UInt(_) | Bool(_) => Bool(false), }) @@ -513,6 +516,7 @@ fn try_const_eval_cmp(op: &mut Comparison) -> Option { use ConstantValue::*; op.input.as_const().map(|input| match input { Float(val) => Bool(val.is_infinite()), + Complex(re, im) => Bool(re.is_infinite() || im.is_infinite()), // Integers, bools, uints can't be infinite, so always false Int(_) | UInt(_) | Bool(_) => Bool(false), }) @@ -578,6 +582,8 @@ fn try_const_eval_operator(op: &mut Operator, out_ty: Option) -> Option None, } } diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 3e2d5438fb..004f0550b6 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -818,6 +818,20 @@ impl ComputeClient { .unwrap() } + /// Execute a closure with mutable access to the underlying compute server. + /// + /// This is the general-purpose escape hatch for backend-specific operations + /// that are not covered by the `ComputeClient` API (e.g., extracting a raw + /// CUDA stream for FFI interop). + /// + /// Returns `None` if the device handle call fails. 
+ pub fn with_server( + &self, + f: impl FnOnce(&mut R::Server) -> R2 + Send + 'static, + ) -> Option { + self.device.submit_blocking(f).ok() + } + /// Get all devices of a specific type available to this runtime pub fn enumerate_devices(&self, type_id: u16) -> Vec { R::enumerate_devices(type_id, self.info()) diff --git a/crates/cubecl-runtime/src/tune/base.rs b/crates/cubecl-runtime/src/tune/base.rs index 5c26886ca0..5f0bc6d750 100644 --- a/crates/cubecl-runtime/src/tune/base.rs +++ b/crates/cubecl-runtime/src/tune/base.rs @@ -1,16 +1,10 @@ use super::{AutotuneKey, IntoTuneFn, TuneFn}; -use alloc::format; -use alloc::string::String; -use alloc::sync::Arc; -use alloc::vec; -use alloc::vec::Vec; +use alloc::{format, string::String, sync::Arc, vec, vec::Vec}; use core::sync::atomic::{AtomicU32, Ordering}; use hashbrown::HashMap; -/// A tunable wraps a [function](TuneFn) that can be included in multiple [groups](TuneGroup). -/// -/// When a tunable is part of multiple groups, it will be autotuned when one of those groups is -/// prioritized. +/// A single candidate for autotune: a [`TuneFn`] plus the [groups](TuneGroup) it +/// belongs to. A tunable is autotuned whenever any of its groups is prioritized. pub struct Tunable { pub(crate) function: Arc>, groups: Vec<(TuneGroup, PriorityFunc)>, @@ -28,26 +22,25 @@ impl Tunable { } } - /// Tag the current tunable as part of the given [group](TuneGroup). - /// `group` is a tuning group with a corresponding priority function. - /// `priority` is the intra-group priority, applied after the group priority to further sort entries + /// Add this tunable to a [`TuneGroup`] with the given intra-group priority. /// - /// Groups are tuned in order of priority, and then each entry in the group is tuned based on the - /// intra-group priority. Negative priorities ensure the entry is never tuned for this key. 
- pub fn group i8 + 'static>(mut self, group: &TuneGroup, priority: F) -> Self { + /// Groups are autotuned in order of their priority; within each group, tunables are + /// tried in order of `priority(key)`. A negative priority skips the tunable for this + /// key. + pub fn group i8 + Send + Sync + 'static>( + mut self, + group: &TuneGroup, + priority: F, + ) -> Self { self.groups.push((group.clone(), Arc::new(priority))); self } } -/// A tune group encapsulates a priority that can be calculated based on an -/// [autotune key](AutotuneKey). -/// -/// During autotuning, the higher prioritized groups will be autotuned first, and if a tunable -/// returns a valid result, no more groups will be autotuned afterward. +/// A priority bucket for tunables, computed from the [autotune key](AutotuneKey). /// -/// Note that tunables themselves have a priority dictating the order in which they are autotuned in -/// each group. +/// Higher-priority groups are autotuned first; once any tunable in a group returns a +/// valid result, no later groups are tried. pub struct TuneGroup { id: u32, name: Arc, @@ -72,7 +65,7 @@ impl Clone for TuneGroup { impl TuneGroup { /// Create a new group based on a priority function. - pub fn new i8 + 'static, S: Into>(name: S, f: F) -> Self { + pub fn new i8 + Send + Sync + 'static, S: Into>(name: S, f: F) -> Self { let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed); Self { @@ -177,6 +170,7 @@ impl TunePlan { let (group_indices, cleanup) = self.group_plan_next(priority); // Some entries are skipped for this round of prioritizing. let skipped = cleanup.skipped || priority < 0; + let mut all_skip = true; self.cleanup(cleanup); @@ -187,6 +181,7 @@ impl TunePlan { } for (index, _name) in group_indices { if !self.returned.contains(&index) { + all_skip = false; indices.push(index); } } @@ -194,7 +189,8 @@ impl TunePlan { // The indices list is empty, but it doesn't mean we should stop // autotuning, since some entries were skipped. 
- if indices.is_empty() && skipped { + + if indices.is_empty() && (skipped || all_skip) { self.next(context_logs) } else { for i in indices.iter() { @@ -278,7 +274,7 @@ impl TunePlan { } } -type PriorityFunc = Arc i8>; +type PriorityFunc = Arc i8 + Send + Sync>; static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0); @@ -404,6 +400,167 @@ mod tests { assert!(plan.next(None).is_empty()); } + #[test_log::test] + fn test_plan_falls_through_when_all_group_tunables_fail() { + // Every tunable lives in exactly one group; the caller treats every batch as a failure + // by continuing to call next(). The plan must still surface every tunable, in priority + // order, before going empty. + let group0 = TuneGroup::::new("group0", |_| 2); + let group1 = TuneGroup::::new("group1", |_| 1); + + let tunable0 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 1); + let tunable1 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 2); + let tunable2 = + Tunable::::new("fake", fake_kernel).group(&group1, |_| 1); + let tunable3 = + Tunable::::new("fake", fake_kernel).group(&group1, |_| 2); + + let key = FakeAutotuneKey; + let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2, tunable3]); + + let mut all_returned: Vec = Vec::new(); + loop { + let batch = plan.next(None); + if batch.is_empty() { + break; + } + all_returned.extend(batch); + } + + // Highest group (prio 2) drains first from highest intra-priority down, then next group. + assert_eq!(all_returned, vec![1, 0, 3, 2]); + } + + #[test_log::test] + fn test_plan_single_group_exhausts_all_intra_priorities() { + // A single group with multiple intra-priorities should yield each batch separately, + // allowing the caller to continue on failures until the group is exhausted. 
+ let group0 = TuneGroup::::new("group0", |_| 0); + + let tunable0 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 1); + let tunable1 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 2); + let tunable2 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 3); + + let key = FakeAutotuneKey; + let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2]); + + assert_eq!(plan.next(None), vec![2]); + assert_eq!(plan.next(None), vec![1]); + assert_eq!(plan.next(None), vec![0]); + assert!(plan.next(None).is_empty()); + } + + #[test_log::test] + fn test_plan_all_negative_group_advances_to_next_group() { + // A group whose every tunable has a negative intra-priority should be skipped entirely + // without stopping autotuning — the next group must still be reached. + let group0 = TuneGroup::::new("group0", |_| 2); + let group1 = TuneGroup::::new("group1", |_| 1); + + let tunable0 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| -1); + let tunable1 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| -2); + let tunable2 = + Tunable::::new("fake", fake_kernel).group(&group1, |_| 1); + + let key = FakeAutotuneKey; + let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2]); + + assert_eq!(plan.next(None), vec![2]); + assert!(plan.next(None).is_empty()); + } + + #[test_log::test] + fn test_plan_no_group_tunables_only_emitted_once_even_on_failures() { + // The ungrouped tunables are emitted together with the first group batch. If the caller + // keeps calling next() (treating the first batch as failing), they must not be + // re-emitted, and the plan must still advance to later groups. 
+ let group0 = TuneGroup::::new("group0", |_| 2); + let group1 = TuneGroup::::new("group1", |_| 1); + + let tunable0 = Tunable::::new("fake", fake_kernel); + let tunable1 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 1); + let tunable2 = + Tunable::::new("fake", fake_kernel).group(&group1, |_| 1); + + let key = FakeAutotuneKey; + let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2]); + + assert_eq!(plan.next(None), vec![0, 1]); + assert_eq!(plan.next(None), vec![2]); + assert!(plan.next(None).is_empty()); + } + + #[test_log::test] + fn test_plan_multi_group_tunable_not_duplicated_across_failed_groups() { + // tunable1 belongs to both group0 and group1. It must be returned exactly once (via its + // higher-priority group), even if the caller continues iterating after failures. + let group0 = TuneGroup::::new("group0", |_| 1); + let group1 = TuneGroup::::new("group1", |_| 2); + + let tunable0 = Tunable::::new("fake", fake_kernel) + .group(&group0, |_| 1) + .group(&group1, |_| 1); + let tunable1 = + Tunable::::new("fake", fake_kernel).group(&group0, |_| 2); + + let key = FakeAutotuneKey; + let mut plan = TunePlan::new(&key, &[tunable0, tunable1]); + + let mut all_returned: Vec = Vec::new(); + loop { + let batch = plan.next(None); + if batch.is_empty() { + break; + } + all_returned.extend(batch); + } + + // tunable0 comes from group1 (higher priority). tunable1 is the sole member of group0 + // after cross-group dedup. No duplicates. + assert_eq!(all_returned, vec![0, 1]); + } + + #[test_log::test] + fn test_plan_recurses_when_batch_is_fully_already_returned() { + // Regression test: a tunable that lives in multiple groups was already emitted via its + // higher-priority group, so when its lower-priority group's batch fires the only index + // is one already present in `returned`. 
The plan must NOT return an empty batch here + // (that signals "no more work" to the caller and aborts with NoValidKernelFound); it + // must recurse to the next intra-priority and surface the remaining tunable. + // + // Cross-group dedup in group_plan_next compares (index, Arc group_name), so a + // tunable appearing in both group_hi and group_lo isn't auto-removed from group_lo + // when popped from group_hi — the `returned` + `all_skip` path is the only guard. + let group_hi = TuneGroup::::new("hi", |_| 2); + let group_lo = TuneGroup::::new("lo", |_| 1); + + // tunable0 is in both groups. tunable1 is only in group_lo at a lower intra-priority. + let tunable0 = Tunable::::new("fake", fake_kernel) + .group(&group_hi, |_| 1) + .group(&group_lo, |_| 2); + let tunable1 = + Tunable::::new("fake", fake_kernel).group(&group_lo, |_| 1); + + let key = FakeAutotuneKey; + let mut plan = TunePlan::new(&key, &[tunable0, tunable1]); + + // First call: group_hi yields tunable0. + assert_eq!(plan.next(None), vec![0]); + // Second call: group_lo's higher intra-priority batch is just tunable0 (already + // returned). Without the fix this returns [] and the autotuner aborts. With the fix + // the plan recurses and yields tunable1. 
+ assert_eq!(plan.next(None), vec![1]); + assert!(plan.next(None).is_empty()); + } + fn fake_kernel() -> Result<(), String> { Ok(()) } diff --git a/crates/cubecl-runtime/src/tune/input_generator.rs b/crates/cubecl-runtime/src/tune/input_generator.rs index 7c09aa0dd4..89a966d467 100644 --- a/crates/cubecl-runtime/src/tune/input_generator.rs +++ b/crates/cubecl-runtime/src/tune/input_generator.rs @@ -3,7 +3,7 @@ use core::marker::PhantomData; use variadics_please::all_tuples; /// A function that generates the input for autotuning passes -pub trait InputGenerator: 'static { +pub trait InputGenerator: Send + Sync + 'static { /// Generate a set of inputs for a given key and reference inputs fn generate(&self, key: &K, inputs: &Inputs) -> Inputs; } @@ -18,12 +18,13 @@ pub trait IntoInputGenerator { } /// An input generator implemented by an `Fn` -pub struct FunctionInputGenerator { +pub struct FunctionInputGenerator { func: F, _marker: PhantomData, } -impl InputGenerator for FunctionInputGenerator +impl InputGenerator + for FunctionInputGenerator where F: FunctionInputGen, { @@ -42,7 +43,8 @@ pub trait FunctionInputGen: 'static { fn execute(&self, key: &K, inputs: &Inputs) -> Inputs; } -impl IntoInputGenerator for F +impl IntoInputGenerator + for F where F: FunctionInputGen, { diff --git a/crates/cubecl-runtime/src/tune/key_generator.rs b/crates/cubecl-runtime/src/tune/key_generator.rs index 40b1ccba7f..673faa5af2 100644 --- a/crates/cubecl-runtime/src/tune/key_generator.rs +++ b/crates/cubecl-runtime/src/tune/key_generator.rs @@ -3,7 +3,7 @@ use core::marker::PhantomData; use variadics_please::all_tuples; /// A generator that creates a key for a given set of inputs -pub trait KeyGenerator: 'static { +pub trait KeyGenerator: Send + Sync + 'static { /// Generate a key from a set of inputs fn generate(&self, inputs: &Inputs) -> K; } @@ -18,12 +18,13 @@ pub trait IntoKeyGenerator { } /// A key generator implemented by an `Fn` -pub struct FunctionKeyGenerator { +pub struct 
FunctionKeyGenerator { func: F, _marker: PhantomData, } -impl KeyGenerator for FunctionKeyGenerator +impl KeyGenerator + for FunctionKeyGenerator where F: FunctionKeygen, { @@ -42,7 +43,8 @@ pub trait FunctionKeygen: 'static { fn execute(&self, inputs: &Inputs) -> K; } -impl IntoKeyGenerator for F +impl IntoKeyGenerator + for F where F: FunctionKeygen, { diff --git a/crates/cubecl-runtime/src/tune/local.rs b/crates/cubecl-runtime/src/tune/local.rs index 77933660b0..5d1f979a25 100644 --- a/crates/cubecl-runtime/src/tune/local.rs +++ b/crates/cubecl-runtime/src/tune/local.rs @@ -7,19 +7,17 @@ use core::{ fmt::Display, hash::Hash, }; -use cubecl_common::map::SharedStateMap; use hashbrown::HashMap; +use spin::Mutex; /// A local tuner allows to create a tuner for a specific key that can be different from the server /// key. pub struct LocalTuner { - state: SharedStateMap>, + state: Mutex>>>>, name: &'static str, sets: spin::RwLock>>>, } -unsafe impl Sync for LocalTuner {} - /// Create a local tuner with the provided name. #[macro_export] macro_rules! local_tuner { @@ -41,13 +39,16 @@ where /// Create a new local tuner. pub const fn new(name: &'static str) -> Self { Self { - state: SharedStateMap::new(), + state: Mutex::new(None), name, sets: spin::RwLock::new(None), } } - /// Init the [tunable set](TunableSet) + /// Get or initialize the [`TunableSet`] for this tuner. + /// + /// Returns a cached `Arc` keyed by the `TypeId` of `init_set`. The + /// initializer runs at most once per process. pub fn init(&self, init_set: F) -> Arc> where F: Fn() -> TunableSet + 'static + Send + Sync, @@ -90,7 +91,9 @@ where /// Clear the autotune state. pub fn clear(&self) { - self.state.clear() + if let Some(s) = self.state.lock().as_mut() { + s.clear() + } } #[cfg(feature = "autotune-checks")] @@ -110,24 +113,8 @@ where super::check_autotune_outputs(checks_outputs); } - /// Try every operation in order and return the first successful result. 
- /// - /// Used as a fallback when autotuning results aren't available yet - /// (e.g. on wasm where tuning is async). - fn try_all_operations(operations: &TunableSet, inputs: In) -> Out - where - In: Clone + Send + 'static, - Out: AutotuneOutput, - { - for i in 0..operations.len() { - if let Ok(output) = operations.fastest(i).execute(inputs.clone()) { - return output; - } - } - panic!("All autotune operations failed, no viable operation found."); - } - - /// Execute the best operation in the provided [tunable set](TunableSet) + /// Execute the fastest operation in a [`TunableSet`], triggering a tuning pass on + /// the first call for a given key. pub fn execute( &self, id: &ID, @@ -141,113 +128,62 @@ where { let key = operations.generate_key(&inputs); - // If this is cached and ready, use the operation. - let tuner_state = self.state.get_or_init(id, move |id| { - let name = self.name.replace("::", "-"); - Tuner::new(&name, &id.to_string()) - }); - let tuner = tuner_state.read(); + let tuner = { + let mut state_lock = self.state.lock(); + let state_map = state_lock.get_or_insert_with(|| HashMap::new()); + state_map + .entry(id.clone()) + .or_insert_with(move || { + let name = self.name.replace("::", "-"); + Arc::new(Tuner::new(&name, &id.to_string())) + }) + .clone() + }; + + // First, check for a cache hit under a read lock. + if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) { + #[cfg(feature = "autotune-checks")] + self.checks(&operations, &inputs); + return operations + .fastest(fastest_index) + .execute(inputs) + .expect("Should run when selected by autotune."); + } - let mut tuner = match tuner.fastest(&key) { + let fastest = tuner.check_tune( + &key, + &inputs, + &operations, + || operations.compute_checksum(), + client, + ); + + // Resolve the cache state into a `done_rx` we can wait on. Hit → run immediately; + // Pending → attach to the in-flight tune; Miss → kick one off. 
+ match fastest { TuneCacheResult::Hit { fastest_index } => { - core::mem::drop(tuner); - core::mem::drop(tuner_state); - #[cfg(feature = "autotune-checks")] self.checks(&operations, &inputs); - - let op = operations.fastest(fastest_index); - let result = op + return operations + .fastest(fastest_index) .execute(inputs) .expect("Should run when selected by autotune."); - return result; } - TuneCacheResult::Pending => { - core::mem::drop(tuner); - core::mem::drop(tuner_state); - - #[cfg(feature = "autotune-checks")] - self.checks(&operations, &inputs); - - return Self::try_all_operations(&operations, inputs); + TuneCacheResult::Unchecked | TuneCacheResult::Miss => { + panic!( + "Somehow we STILL didn't check a tuning checksum or start tuning, something has gone wrong." + ) } - #[cfg(std_io)] - TuneCacheResult::Unchecked => { - core::mem::drop(tuner); - let mut tuner = tuner_state.write(); - - // If the cache checksum hasn't been checked, do so now, and retry. - let checksum = operations.compute_checksum(); - tuner.validate_checksum(&key, &checksum); - - // Check if with validation we can use its result - if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) { - core::mem::drop(tuner); - core::mem::drop(tuner_state); - - let op = operations.fastest(fastest_index); - let result = op - .execute(inputs) - .expect("Should run when selected by autotune."); - return result; + TuneCacheResult::Pending => { + // If we're still waiting for the result, eg. on wasm, just fallback to trying all operations. 
+ let operations: &TunableSet = &operations; + for i in 0..operations.len() { + if let Ok(output) = operations.fastest(i).execute(inputs.clone()) { + return output; + } } - - tuner - } - - #[cfg(not(std_io))] - TuneCacheResult::Unchecked => { - core::mem::drop(tuner); - tuner_state.write() - } - TuneCacheResult::Miss => { - core::mem::drop(tuner); - tuner_state.write() + panic!("All autotune operations failed, no viable operation found."); } - }; - - let job = if !tuner.autotuning.contains(&key) { - tuner.autotuning.insert(key.clone()); - Some(tuner.prepare_autotune(key.clone(), &inputs, &operations, client)) - } else { - None - }; - - // Drop the write lock before running the (potentially blocking) job - // and before re-acquiring the lock below. - core::mem::drop(tuner); - core::mem::drop(tuner_state); - - if let Some(job) = job { - job(); } - - let index_to_run = { - let tuner_state = self.state.get(id).unwrap(); - let mut tuner = tuner_state.write(); - - tuner.handle_results(); - - match tuner.fastest(&key) { - TuneCacheResult::Hit { fastest_index } => { - // There's a known good value - just run it. - fastest_index - } - TuneCacheResult::Pending | TuneCacheResult::Miss => { - // We're still waiting for the results of the autotune task. - // This should only happen on wasm since we can't block waiting - // on the results there. Try all options. 
- return Self::try_all_operations(&operations, inputs); - } - TuneCacheResult::Unchecked => { - panic!("Should have checked the cache.") - } - } - }; - - operations - .fastest(index_to_run) - .execute(inputs) - .expect("Should run when selected by autotune.") } } diff --git a/crates/cubecl-runtime/src/tune/operation.rs b/crates/cubecl-runtime/src/tune/operation.rs index f76cd92aca..27cf581e00 100644 --- a/crates/cubecl-runtime/src/tune/operation.rs +++ b/crates/cubecl-runtime/src/tune/operation.rs @@ -5,9 +5,6 @@ use alloc::vec::Vec; use core::fmt::{Debug, Display}; use core::hash::Hash; -#[cfg(std_io)] -use alloc::format; - use super::{ AutotuneError, input_generator::{InputGenerator, IntoInputGenerator}, @@ -16,7 +13,6 @@ use super::{ use super::{Tunable, TunePlan}; /// Default checksum for an operation set -#[cfg(std_io)] pub fn compute_checksum( autotunables: impl Iterator>>, ) -> String { @@ -24,22 +20,19 @@ pub fn compute_checksum( autotunables.for_each(|op| { checksum += op.name(); }); - format!("{:x}", md5::compute(checksum)) + alloc::format!("{:x}", md5::compute(checksum)) } /// Groups operations of the same type for autotune pub struct TunableSet { tunables: Vec>, - key_gen: Arc>, - input_gen: Arc>, + key_gen: Arc + Send + Sync>, + input_gen: Arc + Send + Sync>, #[allow(clippy::type_complexity)] checksum_override: Option String + Send + Sync>>, } -unsafe impl Send for TunableSet {} -unsafe impl Sync for TunableSet {} - -impl +impl TunableSet { /// The number of tunables in the set. @@ -105,7 +98,6 @@ impl } /// Compute a checksum that can invalidate outdated cached auto-tune results. 
- #[cfg(std_io)] pub fn compute_checksum(&self) -> String { if let Some(checksum_override) = &self.checksum_override { checksum_override(self) diff --git a/crates/cubecl-runtime/src/tune/tune_benchmark.rs b/crates/cubecl-runtime/src/tune/tune_benchmark.rs index 29d59a0de9..04fcb3b206 100644 --- a/crates/cubecl-runtime/src/tune/tune_benchmark.rs +++ b/crates/cubecl-runtime/src/tune/tune_benchmark.rs @@ -5,7 +5,8 @@ use alloc::sync::Arc; use alloc::vec::Vec; use cubecl_common::profile::ProfileDuration; -/// A benchmark that runs on server handles +/// A single candidate's benchmark: ties a [`TuneFn`] to its inputs and a client, ready +/// to run warmup + profiling samples. #[derive(new)] pub struct TuneBenchmark { operation: Arc>, diff --git a/crates/cubecl-runtime/src/tune/tune_cache.rs b/crates/cubecl-runtime/src/tune/tune_cache.rs index 7af00c87ec..619ba5e61f 100644 --- a/crates/cubecl-runtime/src/tune/tune_cache.rs +++ b/crates/cubecl-runtime/src/tune/tune_cache.rs @@ -12,7 +12,6 @@ use super::{AutotuneError, AutotuneKey, AutotuneOutcome}; use alloc::string::String; use hashbrown::HashMap; -/// In-memory cache entry #[derive(Debug)] pub(crate) enum CacheEntry { Done { @@ -97,7 +96,9 @@ pub enum TuneCacheResult { }, /// The operation might be cached, but we don't know yet whether the checksum is valid. Unchecked, - /// We don't know yet what is fastest, but are waiting for a result to come in. + /// A tuning job is in flight for this key — the worker hasn't published a result yet. + /// The receiver wakes (with `Err(RecvError)`) when the worker commits the result. Native + /// callers `block_on` it and re-query; wasm callers drop it and fall back. Pending, /// No operation is found yet. 
Miss, @@ -134,42 +135,43 @@ impl TuneCache { } pub fn fastest(&self, key: &K) -> TuneCacheResult { - let result = self.in_memory_cache.get(key); - - let Some(val) = result else { + let Some(val) = self.in_memory_cache.get(key) else { return TuneCacheResult::Miss; }; - match val { - CacheEntry::Done { - checksum, - fastest_index, - } => { - if cfg!(std_io) { - match checksum { - ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet. - ChecksumState::NoMatch => TuneCacheResult::Miss, // Can't use this. - ChecksumState::Match => TuneCacheResult::Hit { - fastest_index: *fastest_index, - }, - } - } else { - // Clippy; - let _ = checksum; - TuneCacheResult::Hit { - fastest_index: *fastest_index, - } - } + let CacheEntry::Done { + checksum, + fastest_index, + } = val + else { + // Pending: clone the receiver so the caller can subscribe to the in-flight tune. + let CacheEntry::Pending = val else { + unreachable!() + }; + return TuneCacheResult::Pending; + }; + + if cfg!(std_io) { + match checksum { + ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet. + ChecksumState::NoMatch => TuneCacheResult::Miss, // Can't use this. + ChecksumState::Match => TuneCacheResult::Hit { + fastest_index: *fastest_index, + }, + } + } else { + // Clippy; + let _ = checksum; + TuneCacheResult::Hit { + fastest_index: *fastest_index, } - CacheEntry::Pending => TuneCacheResult::Pending, } } #[cfg(std_io)] - pub fn validate_checksum(&mut self, key: &K, checksum: &str) { - let result = self.in_memory_cache.get_mut(key); - let Some(val) = result else { - return; + pub fn validate_checksum(&mut self, key: &K, checksum: &str) -> TuneCacheResult { + let Some(val) = self.in_memory_cache.get_mut(key) else { + return TuneCacheResult::Miss; }; if let CacheEntry::Done { @@ -184,9 +186,13 @@ impl TuneCache { *checksum_state = ChecksumState::NoMatch; } } + + self.fastest(key) } - #[allow(unused)] + /// Mark a key as being tuned. 
Used by [`Tuner::tune`] under the cache mutex so that + /// concurrent callers see [`TuneCacheResult::Pending`] and wait on the same job instead of + /// starting a second one. Returns `(Sender, Receiver)`: pub(crate) fn mark_pending(&mut self, key: K) { self.in_memory_cache.insert(key, CacheEntry::Pending); } diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index b67f21d52f..9d8458cd4c 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -1,10 +1,7 @@ -use alloc::boxed::Box; use alloc::format; use alloc::sync::Arc; use alloc::vec::Vec; -use async_channel::{Receiver, Sender}; use cubecl_common::profile::ProfileDuration; -use hashbrown::HashSet; use core::time::Duration; @@ -16,15 +13,17 @@ use crate::server::LaunchError; use crate::tune::{AutotuneResult, TuneBenchmark, TuneCache}; use crate::{client::ComputeClient, runtime::Runtime}; -use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult, TuneFn, TunePlan}; +use super::{AutotuneKey, AutotuneOutput, TunableSet, TuneCacheResult}; #[derive(Debug)] -/// Executes autotune benchmarking and caching +/// Runs autotune benchmarks for a single device and caches the results. +/// +/// On wasm, [`tune`](Self::tune) spawns its work on the browser event loop; elsewhere +/// it blocks inline. Either way the benchmarking itself is synchronous; only the +/// per-sample profile resolution is awaited. pub struct Tuner { - tune_cache: TuneCache, - logger: Logger, - channel: (Sender>, Receiver>), - pub(crate) autotuning: HashSet, + cache: Arc>>, + logger: Arc>, } /// The measured outcome for a given autotune invocation. @@ -46,19 +45,6 @@ impl core::fmt::Display for AutotuneOutcome { } } -enum AutotuneMessage { - Done { - key: K, - fastest_index: usize, - results: Vec, - #[cfg(std_io)] - checksum: String, - context_logs: Option, - }, - #[allow(dead_code)] - Pending(K), -} - /// Error from running autotune. 
#[derive(Debug, Clone)] #[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))] @@ -100,378 +86,312 @@ impl From for AutotuneError { } } +/// A queued benchmark: one sample-set of profile futures plus its metadata. +struct PendingBench { + index: usize, + name: String, + profiles: Vec, +} + +/// A queued tuning job: all data needed to resolve samples and commit the result. +/// Holds no references so it's trivially `Send + 'static` for the wasm spawn path. +struct TuneRequest { + key: K, + results: Vec, + #[cfg(std_io)] + checksum: String, + context_logs: Option, + pending: Vec, +} + #[allow(clippy::new_without_default)] impl Tuner { - /// Returns a tuner with cache initialized from persistent cache + /// Create a tuner. Its cache is seeded from the persistent on-disk cache when + /// `std_io` is enabled. pub fn new(name: &str, device_id: &str) -> Self { - let channel = async_channel::unbounded(); - Self { - tune_cache: TuneCache::new(name, device_id), - logger: Logger::new(), - channel, - autotuning: HashSet::new(), + cache: Arc::new(spin::Mutex::new(TuneCache::new(name, device_id))), + logger: Arc::new(spin::Mutex::new(Logger::new())), } } /// Fetch the fastest autotune operation index for an autotune key. pub fn fastest(&self, key: &K) -> TuneCacheResult { - self.tune_cache.fastest(key) - } - - /// Fetch the fastest autotune operation index for an autotune key and validate the checksum. 
- #[cfg(std_io)] - pub fn validate_checksum(&mut self, key: &K, checksum: &str) { - if let AutotuneLogLevel::Full = self.logger.log_level_autotune() { - self.logger - .log_autotune(&format!("validate checksum key={key}, checksum={checksum}")); - } - self.tune_cache.validate_checksum(key, checksum) + self.cache.lock().fastest(key) } - /// Handle an autotune result message, see [`execute_autotune`] - fn handle_result(&mut self, msg: AutotuneMessage) { - match msg { - AutotuneMessage::Pending(key) => { - self.tune_cache.mark_pending(key); - } - AutotuneMessage::Done { - key, - fastest_index, - results, - #[cfg(std_io)] - checksum, - context_logs, - } => { - match self.logger.log_level_autotune() { - AutotuneLogLevel::Minimal => { - let top_times = results - .iter() - .map(|r| { - let time = r - .outcome - .as_ref() - .map(|r| r.computation.median) - .unwrap_or(Duration::MAX); - - let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default(); - (index, time) - }) - .take(3) - .collect::>(); - - let result = results - .first() - .expect("At least one kernel needed.") - .outcome - .as_ref() - .expect("At least one kernel has to succeed."); - - let context = match &context_logs { - Some(context) => context, - None => "", - }; - self.logger.log_autotune(&format!( - "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}", - result.name, - )); - } - AutotuneLogLevel::Full => { - let result = results - .first() - .expect("At least one kernel needed.") - .outcome - .as_ref() - .expect("At least one kernel has to succeed."); - - let context = match &context_logs { - Some(context) => context, - None => "", - }; - self.logger.log_autotune(&format!( - "Fastest result {}-{key}. 
Context: {context}", - result.name, - )); - - for result in results.iter() { - match &result.outcome { - Ok(val) => { - self.logger.log_autotune(&format!("{val}")); - } - Err(err) => self.logger.log_autotune(&format!("{err:?}")), - } - } - } - AutotuneLogLevel::Disabled => {} - }; - - self.tune_cache.cache_insert(key.clone(), fastest_index); - - #[cfg(std_io)] - { - self.tune_cache - .persistent_cache_insert(key, checksum, fastest_index, results); - } - } - } - } - - /// Check if any autotuning results have come in asynchronously. - pub fn handle_results(&mut self) { - // Handle any results that have come in. Note that execute_autotune pushes results to the channel immediately if possible. - // Since this function takes an &mut we know we have exclusive access, and no other threads are currently still adding results. - while let Ok(msg) = self.channel.1.try_recv() { - self.handle_result(msg); - } - } - - /// Execute benchmarks to find out what the fastest operation is. - pub fn prepare_autotune( + /// Kick off a tuning job for `key`, or return immediately if the cache already has + /// a result or another thread is tuning it. + pub fn check_tune( &self, - key: K, + key: &K, inputs: &In, tunables: &TunableSet, + #[allow(unused_variables)] checksum: impl FnOnce() -> String + Send + Sync, client: &ComputeClient, - ) -> Box { - log::info!("Tuning {key}"); + ) -> TuneCacheResult { + { + let mut cache = self.cache.lock(); - // Note that this message will be processed straight away by handle_results. - let sender = self.channel.0.clone(); + #[allow(unused_mut)] + let mut cur = cache.fastest(key); - let autotunables = tunables.autotunables(); - let mut results: Vec = Vec::with_capacity(autotunables.len()); + // Try to validate current if need be. + #[cfg(std_io)] + if matches!(cur, TuneCacheResult::Unchecked) { + // Checksum validation may retroactively turn an Unchecked entry into a Hit. 
+ let mut log = self.logger.lock(); + let checksum = checksum(); + if let AutotuneLogLevel::Full = log.log_level_autotune() { + log.log_autotune(&format!("validate checksum key={key}, checksum={checksum}")); + } + cur = cache.validate_checksum(key, &checksum) + } - for a in autotunables.iter() { - results.push(AutotuneResult::error(AutotuneError::Skip { - name: a.name().to_string(), - })); + match cur { + // Already pending or done, return current state. + TuneCacheResult::Hit { .. } | TuneCacheResult::Pending => return cur, + // Otherwise we start tuning, mark this as pending. + TuneCacheResult::Miss | TuneCacheResult::Unchecked => { + cache.mark_pending(key.clone()) + } + } + // Scope the guard: the rest of this function re-locks `self.cache` (fast + // path insert, `process_request`), and `spin::Mutex` is non-reentrant. } - if autotunables.len() == 1 { - let message = AutotuneMessage::Done { - key, - fastest_index: 0, - results, - #[cfg(std_io)] - checksum: tunables.compute_checksum(), - context_logs: None, - }; - - return Box::new(move || { - sender - .try_send(message) - .expect("Loss message channel somehow") - }); - } + log::info!("Tuning {key}"); - let client = client.clone(); - let key_cloned = key.clone(); - let plan = tunables.plan(&key); - let inputs_generator = tunables.inputs_generator(&key.clone(), inputs); + let autotunables = tunables.autotunables(); + let mut results: Vec = autotunables + .iter() + .map(|a| { + AutotuneResult::error(AutotuneError::Skip { + name: a.name().to_string(), + }) + }) + .collect(); #[cfg(std_io)] let checksum = tunables.compute_checksum(); - let context_logs = match self.logger.log_level_autotune() { - AutotuneLogLevel::Disabled => false, - AutotuneLogLevel::Minimal => false, - AutotuneLogLevel::Full => true, - }; - - let fut_result = async move { - let test_inputs = inputs_generator(); - - Self::generate_tune_message( - key_cloned, - &client, - plan, - autotunables, - test_inputs, - results, - #[cfg(std_io)] - checksum, 
- context_logs, - ) - .await - }; - Box::new(move || { - let message = { - cfg_if::cfg_if! { - if #[cfg(target_family = "wasm")] { - let sender = sender.clone(); - - let send_fut = async move { - // If the channel has been closed, ignore. Maybe the main app is exiting - // before the tune results come in. - let _ = sender.send(fut_result.await).await; - }; - // On wasm, spawn the tuning as a detached task. - wasm_bindgen_futures::spawn_local(send_fut); - // Mark the current tuning as pending. - AutotuneMessage::Pending(key) - } else { - cubecl_common::future::block_on(fut_result) - } - } - }; - - // Note that this message will be processed straight away by handle_results. - sender - .try_send(message) - .expect("Loss message channel somehow"); - }) - } - - #[allow(clippy::too_many_arguments)] - async fn generate_tune_message( - key: K, - client: &ComputeClient, - mut plan: TunePlan, - autotunables: Vec + 'static>>, - test_inputs: In, - mut results: Vec, - #[cfg(std_io)] checksum: String, - context_logs: bool, - ) -> AutotuneMessage { - let context_logs = match Self::execute_tune_plan( - client, - &mut plan, - autotunables, - &test_inputs, - &mut results, - context_logs, - ) - .await - { - Ok(context_logs) => context_logs, - Err(err) => { - panic!("Can't execute the autotune plan for key: {key:?}\n - Error: {err:?}"); - } - }; - - // Finds the fastest operation. - results.sort_by(|a, b| { - let a = a - .outcome - .as_ref() - .map(|r| r.computation.score()) - .unwrap_or(u64::MAX); - let b = b - .outcome - .as_ref() - .map(|r| r.computation.score()) - .unwrap_or(u64::MAX); - - a.cmp(&b) - }); - - // Log & send results. 
- let result = results - .first() - .expect("At least one kernel needed.") - .outcome - .as_ref() - .expect("At least one kernel has to succeed."); - - AutotuneMessage::Done { - key, - fastest_index: result.index, - results, - #[cfg(std_io)] - checksum, - context_logs, - } - } - - async fn execute_tune_plan( - client: &ComputeClient, - plan: &mut TunePlan, - autotunables: Vec + 'static>>, - test_inputs: &In, - results: &mut [AutotuneResult], - context_logs: bool, - ) -> Result, AutotuneError> { - #[derive(Debug)] - #[allow(unused_variables, dead_code)] // Only use for debug - struct Context<'a> { - plan: &'a TunePlan, - results: &'a [AutotuneResult], + // Fast path: single tunable, no benchmarking needed. + if autotunables.len() == 1 { + self.cache.lock().cache_insert(key.clone(), 0); + return TuneCacheResult::Hit { fastest_index: 0 }; } - let mut context_logs = match context_logs { - true => Some("".to_string()), - false => None, + let test_inputs = tunables.inputs_generator(key, inputs)(); + let mut plan = tunables.plan(key); + let mut context_logs = match self.logger.lock().log_level_autotune() { + AutotuneLogLevel::Full => Some(String::new()), + _ => None, }; + // Walk the plan batch by batch, launching each benchmark synchronously. A + // successful launch queues a `PendingBench` for the async resolver below; + // launch errors go straight into `results`. Retry the next batch if a whole + // batch failed to queue anything. 
+ let mut pending = Vec::::new(); loop { - let mut num_success = 0; let tunable_indices = plan.next(context_logs.as_mut()); if tunable_indices.is_empty() { - return Err(AutotuneError::NoValidKernelFound { - context: format!("{:?}", &Context { plan, results }), - }); + panic!( + "Can't execute the autotune plan for key: {key:?}\n - plan: {plan:?}\n - results: {results:?}" + ); } for index in tunable_indices { let op = &autotunables[index]; let name = op.name().to_string(); - let tuner = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone()); - let profiles = tuner.profile().map(|bench| (name, index, bench)); - - match profiles { - Ok(result) => { - // Wait for the results to come in, and determine the outcome. - let (name, index, profiles) = result; - let result = Self::process_autotune(name, index, profiles).await; - match result { - Ok(val) => { - results[index] = AutotuneResult::success(val); - num_success += 1; - } - Err(err) => { - results[index] = AutotuneResult::error(err); - } - } - } + let bench = TuneBenchmark::new(op.clone(), test_inputs.clone(), client.clone()); + match bench.profile() { + Ok(profiles) => pending.push(PendingBench { + index, + name, + profiles, + }), Err(err) => { results[index] = AutotuneResult::error(err); } } } - if num_success > 0 { + if !pending.is_empty() { break; } } - Ok(context_logs) - } + let request = TuneRequest { + key: key.clone(), + results, + #[cfg(std_io)] + checksum, + context_logs, + pending, + }; - async fn process_autotune( - name: String, - index: usize, - profiles: Vec, - ) -> Result { - let mut durations = Vec::new(); - if !profiles.is_empty() { - let timing_method = profiles.first().unwrap().timing_method(); - for profile in profiles { - durations.push(profile.resolve().await.duration()); - } - let bench_durations = BenchmarkDurations::from_durations(timing_method, durations); + // Resolve samples and commit the result. On wasm this runs on the browser + // event loop; elsewhere it blocks inline. 
+ #[cfg(target_family = "wasm")] + { + let cache = self.cache.clone(); + let logger = self.logger.clone(); + wasm_bindgen_futures::spawn_local(async move { + process_request(request, &cache, &logger).await; + }); - Ok(AutotuneOutcome::new( - name, - index, - BenchmarkComputations::new(&bench_durations), - )) - } else { - Err(AutotuneError::Unknown { + // Pending results. + return TuneCacheResult::Pending; + } + + #[cfg(not(target_family = "wasm"))] + cubecl_common::future::block_on(process_request(request, &self.cache, &self.logger)) + } +} + +/// Await every profile sample, pick the fastest tunable, commit to the cache. +async fn process_request( + request: TuneRequest, + cache: &spin::Mutex>, + logger: &spin::Mutex, +) -> TuneCacheResult { + let TuneRequest { + key, + mut results, + #[cfg(std_io)] + checksum, + context_logs, + pending, + } = request; + + for bench in pending { + let PendingBench { + index, + name, + profiles, + } = bench; + + if profiles.is_empty() { + results[index] = AutotuneResult::error(AutotuneError::Unknown { name, err: "No profiling available".to_string(), - }) + }); + continue; + } + + let timing_method = profiles.first().unwrap().timing_method(); + let mut durations = Vec::with_capacity(profiles.len()); + for profile in profiles { + durations.push(profile.resolve().await.duration()); + } + + results[index] = AutotuneResult::success(AutotuneOutcome::new( + name, + index, + BenchmarkComputations::new(&BenchmarkDurations::from_durations( + timing_method, + durations, + )), + )); + } + + results.sort_by(|a, b| { + let a = a + .outcome + .as_ref() + .map(|r| r.computation.score()) + .unwrap_or(u64::MAX); + let b = b + .outcome + .as_ref() + .map(|r| r.computation.score()) + .unwrap_or(u64::MAX); + a.cmp(&b) + }); + + let fastest_index = results + .first() + .expect("At least one kernel needed.") + .outcome + .as_ref() + .expect("At least one kernel has to succeed.") + .index; + + { + log_result(&mut logger.lock(), &key, &results, 
context_logs.as_deref()); + cache.lock().cache_insert(key.clone(), fastest_index); + #[cfg(std_io)] + cache + .lock() + .persistent_cache_insert(key, checksum, fastest_index, results); + } + + TuneCacheResult::Hit { fastest_index } +} + +/// Emit the autotune result through the logger at the currently configured level. +fn log_result( + logger: &mut Logger, + key: &K, + results: &[AutotuneResult], + context_logs: Option<&str>, +) { + match logger.log_level_autotune() { + AutotuneLogLevel::Minimal => { + let top_times = results + .iter() + .map(|r| { + let time = r + .outcome + .as_ref() + .map(|r| r.computation.median) + .unwrap_or(Duration::MAX); + + let index = r.outcome.as_ref().map(|r| r.index).unwrap_or_default(); + (index, time) + }) + .take(3) + .collect::>(); + + let result = results + .first() + .expect("At least one kernel needed.") + .outcome + .as_ref() + .expect("At least one kernel has to succeed."); + + let context = context_logs.unwrap_or(""); + logger.log_autotune(&format!( + "Fastest result {}-{key}. \n Top 3 times: {top_times:?}, context: {context}", + result.name, + )); + } + AutotuneLogLevel::Full => { + let result = results + .first() + .expect("At least one kernel needed.") + .outcome + .as_ref() + .expect("At least one kernel has to succeed."); + + let context = context_logs.unwrap_or(""); + logger.log_autotune(&format!( + "Fastest result {}-{key}. 
Context: {context}", + result.name, + )); + + for result in results.iter() { + match &result.outcome { + Ok(val) => { + logger.log_autotune(&format!("{val}")); + } + Err(err) => logger.log_autotune(&format!("{err:?}")), + } + } } + AutotuneLogLevel::Disabled => {} } } diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 5049c877e7..91454cf29f 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -285,6 +285,49 @@ impl SpirvCompiler { } }); } + Arithmetic::VectorSum(op) => { + let input_ir = op.input; + let input = self.compile_variable(input_ir); + let out = self.compile_variable(out); + let in_item = input.item(); + let out_ty = out.item(); + let vec_size = in_item.vectorization(); + let scalar_ty = out_ty.id(self); + let input_id = self.read(&input); + let out_id = self.write_id(&out); + self.mark_uniformity(out_id, uniform); + + if vec_size <= 1 { + // Scalar: identity + self.copy_object(scalar_ty, Some(out_id), input_id).unwrap(); + } else if matches!(out_ty.elem(), Elem::Float(..) 
| Elem::Relaxed) { + // Float vector: use OpDot with ones vector for optimal single instruction + self.declare_math_mode(modes, out_id); + let ones_val = in_item + .elem() + .constant(self, ConstVal::Bit32(1.0f32.to_bits())); + let vec_ty = in_item.id(self); + let ones = self.constant_composite(vec_ty, (0..vec_size).map(|_| ones_val)); + self.dot(scalar_ty, Some(out_id), input_id, ones).unwrap(); + if matches!(out_ty.elem(), Elem::Relaxed) { + self.decorate(out_id, Decoration::RelaxedPrecision, []); + } + } else { + // Integer vector: extract and add + let elem_ty = out_ty.id(self); + let mut acc = self + .composite_extract(elem_ty, None, input_id, vec![0]) + .unwrap(); + for i in 1..vec_size { + let elem = self + .composite_extract(elem_ty, None, input_id, vec![i]) + .unwrap(); + acc = self.i_add(elem_ty, None, acc, elem).unwrap(); + } + self.copy_object(scalar_ty, Some(out_id), acc).unwrap(); + } + self.write(&out, out_id); + } Arithmetic::Magnitude(op) => { self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { b.declare_math_mode(modes, out); @@ -668,6 +711,7 @@ impl SpirvCompiler { _ => unreachable!(), }, ), + Arithmetic::Conj(_) => unimplemented!("Conj not supported on SPIRV"), } } } diff --git a/crates/cubecl-spirv/src/instruction.rs b/crates/cubecl-spirv/src/instruction.rs index 76a7f381e3..48863457cc 100644 --- a/crates/cubecl-spirv/src/instruction.rs +++ b/crates/cubecl-spirv/src/instruction.rs @@ -300,6 +300,9 @@ impl SpirvCompiler { .unwrap(); } Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform), + Operator::Real(_) | Operator::Imag(_) => { + unimplemented!("Real/Imag not supported on SPIRV") + } } } diff --git a/crates/cubecl-spirv/src/item.rs b/crates/cubecl-spirv/src/item.rs index 0b9902f86f..9c1fc08270 100644 --- a/crates/cubecl-spirv/src/item.rs +++ b/crates/cubecl-spirv/src/item.rs @@ -396,6 +396,7 @@ impl SpirvCompiler { Elem::Int(8, false) } core::ElemType::Bool => Elem::Bool, + 
core::ElemType::Complex(_) => unimplemented!("Complex not supported on SPIRV"), } } diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 9771d6bf8a..27fcc25f5e 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -169,6 +169,7 @@ impl From<(ConstantValue, Item)> for ConstVal { ConstantValue::Float(val) => ConstVal::from_float(val, width, elem.float_encoding()), ConstantValue::UInt(val) => ConstVal::from_uint(val, width), ConstantValue::Bool(val) => ConstVal::from_bool(val), + ConstantValue::Complex(_, _) => unimplemented!("Complex not supported on SPIRV"), } } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 375332d068..444166bd58 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -254,6 +254,7 @@ impl WgslCompiler { kind => panic!("{kind:?} is not a valid WgpuElement"), }, cube::ElemType::Bool => wgsl::Elem::Bool, + cube::ElemType::Complex(_) => unimplemented!("Complex not supported in WGSL"), } } @@ -979,6 +980,11 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), }), + cube::Arithmetic::Conj(_) => unimplemented!("Conj not supported in WGSL"), + cube::Arithmetic::VectorSum(op) => instructions.push(wgsl::Instruction::VectorSum { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), } } @@ -1166,6 +1172,9 @@ impl WgslCompiler { or_else: self.compile_variable(op.or_else), out: self.compile_variable(out), }), + cube::Operator::Real(_) | cube::Operator::Imag(_) => { + unimplemented!("Real/Imag not supported in WGSL") + } } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 037786cd38..9fb4192e29 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ 
b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -421,6 +421,10 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + VectorSum { + input: Variable, + out: Variable, + }, IsNan { input: Variable, out: Variable, @@ -1111,6 +1115,18 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ writeln!(f, "{out} = dot({lhs}, {rhs});") } } + Instruction::VectorSum { input, out } => { + let vec_size = input.item().vectorization_factor(); + let out = out.fmt_left(); + if vec_size <= 1 { + writeln!(f, "{out} = {input};") + } else { + let elems = (0..vec_size) + .map(|i| format!("{}", input.index(i))) + .collect::>(); + writeln!(f, "{out} = {};", elems.join(" + ")) + } + } Instruction::VecInit { inputs, out } => { let item = out.item(); let inputs = inputs.iter().map(|var| var.to_string()).collect::>(); diff --git a/docs/plans/2026-04-14-complex-design.md b/docs/plans/2026-04-14-complex-design.md new file mode 100644 index 0000000000..a5027a6a20 --- /dev/null +++ b/docs/plans/2026-04-14-complex-design.md @@ -0,0 +1,177 @@ +# Interleaved Complex Number Support for CubeCL — Design + +## Overview + +Add interleaved `Complex` / `Complex` as first-class types in CubeCL IR, with CUDA and WGPU backend support. + +## Trait Hierarchy + +Independent `Complex` trait, parallel to `Float` and `Int`: + +``` +CubePrimitive + ├── Numeric (Add/Sub/Mul/Div/Neg/Abs/Remainder + num_traits) + │ ├── Float (Exp/Log/Sin/Cos/Sqrt/Tan/Powf/Ceil/Floor/...) + │ └── Int (Bitwise/Saturating/...) + └── Complex (Add/Sub/Mul/Div/Neg/Abs + Exp/Log/Sin/Cos/Sqrt/Powf + Conj/Real/Imag) +``` + +Complex does NOT implement Numeric, Float, or Int. Invalid ops (Remainder, Ceil/Floor/Round/Trunc, Bitwise, ordering comparisons) are excluded by type system. + +## Design Sections + +### 1. 
IR Type System (`cubecl-ir`) + +**File:** `crates/cubecl-ir/src/type.rs` + +Add to `ElemType`: + +```rust +#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)] +pub enum ComplexKind { + C32, // Complex<f32> + C64, // Complex<f64> +} + +pub enum ElemType { + Float(FloatKind), + Int(IntKind), + UInt(UIntKind), + Bool, + Complex(ComplexKind), // NEW +} +``` + +`size()` / `size_bits()`: return 2x base float size (C32 = 8 bytes/64 bits, C64 = 16 bytes/128 bits). + +Classification methods: +- `is_complex()`, `as_complex()` + +No changes to `StorageType` or `Type` — `Scalar(ElemType::Complex(C32))` works as-is. + +**~30 lines changed.** + +### 2. CUDA Backend (`cubecl-cpp` / `cubecl-cuda`) + +**Key enabler:** `thrust::complex` provides operator overloading for `+`, `-`, `*`, `/`, `==`, `!=`, plus `thrust::abs`, `thrust::exp`, `thrust::log`, `thrust::sqrt`, `thrust::conj`, etc. + +**Element type mapping:** + +| File | Change | +|------|--------| +| `shared/element.rs` | Add `CF32`, `CF64` variants to `Elem` | +| `cuda/dialect.rs` `compile_elem()` | `CF32` → `"thrust::complex<float>"`, `CF64` → `"thrust::complex<double>"` | +| `cuda/dialect.rs` `compile_includes()` | Conditional `#include <thrust/complex.h>` via `flags.elem_complex` | +| `shared/base.rs` | `ElemType::Complex(C32/C64)` → `Elem::CF32/CF64` mapping, set `flags.elem_complex` | + +**Binary ops:** No changes needed. Existing `operator!(Add, "+")` etc. work via thrust overloading. + +**Unary ops:** Most work as-is via `function!(Exp, "exp")` pattern with thrust namespace functions. Half-support disabled for complex types.
+ +**New custom ops:** + +| Op | Generated code | Location | +|----|----------------|----------| +| Conj | `thrust::conj(input)` | `unary.rs`, ~1 line | +| Real (extract) | `(input).real()` | New | +| Imag (extract) | `(input).imag()` | New | +| IsNan | `thrust::isnan(re) \|\| thrust::isnan(im)` | Custom | +| IsInf | `thrust::isinf(re) \|\| thrust::isinf(im)` | Custom | + +**Rejected ops (panic in backend dispatch):** +- All bitwise ops +- Ordering comparisons (`<`, `<=`, `>`, `>=`) +- `Ceil`, `Floor`, `Round`, `Trunc`, `%`, `MulHi`, Saturating ops + +**~65 lines changed.** + +### 3. Frontend (`cubecl-core` / `cubecl-macros`) + +**New file:** `crates/cubecl-core/src/frontend/element/complex.rs` + +```rust +macro_rules! impl_complex { + ($primitive:ty, $kind:ident) => { + impl CubeType for $primitive { type ExpandType = NativeExpand; } + impl Scalar for $primitive {} + impl CubePrimitive for $primitive { + type Scalar = Self; + type Size = Const<1>; + type WithScalar = S; + fn as_type_native() -> Option { + Some(StorageType::Scalar(ElemType::Complex(ComplexKind::$kind)).into()) + } + } + impl Complex for $primitive {} + }; +} + +impl_complex!(num_complex::Complex, C32); +impl_complex!(num_complex::Complex, C64); +``` + +**Complex trait:** + +```rust +pub trait Complex: CubePrimitive { + fn conj(self) -> Self { unexpanded!() } + fn real(self) -> Self::Scalar { unexpanded!() } + fn imag(self) -> Self::Scalar { unexpanded!() } +} +``` + +**Arithmetic expand:** New `impl_complex_binop!` macro (same body as `impl_core_binop!` but without `Numeric` bound). Calls `binary_expand()` — identical IR emission. + +**Transcendental expand:** `impl_complex_unary_func!` macro (same body as `impl_unary_func!`). Calls `unary_expand()` — identical IR emission. + +**`CubeElement` (`pod.rs`):** `unsafe impl bytemuck::Pod` for `num_complex::Complex` (they are `#[repr(C)]` with two f32/f64). + +**Dependency:** Add `num-complex` to `cubecl-core/Cargo.toml`. 
+ +**~163 lines changed.** + +### 4. Validation + +**Primary mechanism: Rust type system.** + +- `Complex` does not implement `Float` → Ceil/Floor/Round/Trunc/Sin/Cos caller sites reject +- `Complex` does not implement `Int` → Bitwise/Saturating caller sites reject +- `Complex` does not implement `Numeric` → Remainder (%) caller sites reject +- `Complex` does not implement `PartialOrd` → `<`, `>`, `<=`, `>=` reject + +**Secondary mechanism: Backend dispatch panic.** + +For IR instructions constructed directly (low-level paths), `base.rs` dispatch panics on invalid ops for complex types. + +**~20 lines changed.** + +## Scope Summary + +| Scope | Lines | Difficulty | +|-------|-------|------------| +| IR type system | ~30 | Low | +| CUDA backend | ~65 | Low | +| Frontend (Complex trait + macros) | ~163 | Low–Medium | +| Validation | ~20 | Low | +| **Total Phase 1** | **~278** | **Low–Medium** | + +Phase 2 (WGPU) and Phase 3 (CPU) are deferred. + +## Operation Compatibility Matrix + +| Category | Works (same IR, thrust overloading) | Needs custom impl | Not applicable | +|----------|--------------------------------------|-------------------|----------------| +| Binary arithmetic | `+`, `-`, `*`, `/` | `pow` | `%`, `mulhi`, saturating | +| Unary math | `abs`, `exp`, `log`, `sqrt`, `sin`, `cos` | `isnan`, `isinf` | `ceil`, `floor`, `round`, `trunc` | +| Comparison | `==`, `!=` | — | `<`, `<=`, `>`, `>=` | +| Bitwise | — | — | All | +| Structural | All (element-type agnostic) | — | — | +| Reduction | `sum`, `prod` | — | `max`, `min` | +| Complex-specific | — | `conj`, `real`, `imag` | — | + +## Decisions Log + +- **Interleaved (not split)**: Required for cuSOLVER/cuBLAS interop, cache locality, SIMD +- **Independent Complex trait (not Numeric extension)**: Clean separation, invalid ops excluded by type system, minimal code duplication (macros reuse `binary_expand`/`unary_expand`) +- **HIP deferred**: No CI, no local test hardware +- **WGPU deferred to Phase 2**: WGSL has 
no native complex type; CUDA is the priority diff --git a/docs/plans/2026-04-14-complex-impl-plan.md b/docs/plans/2026-04-14-complex-impl-plan.md new file mode 100644 index 0000000000..577048e809 --- /dev/null +++ b/docs/plans/2026-04-14-complex-impl-plan.md @@ -0,0 +1,933 @@ +# Complex Number Support — Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add interleaved `Complex<f32>` / `Complex<f64>` as first-class types in CubeCL IR with CUDA backend support. + +**Architecture:** New `ElemType::Complex(ComplexKind)` variant in IR, independent `Complex` trait in frontend (parallel to Float/Int), `thrust::complex` mapping in CUDA backend. Tests use the `testgen_*` macro pattern executed via `cargo test -p cubecl-cuda`. + +**Tech Stack:** CubeCL IR, cubecl-cpp (CUDA codegen), cubecl-core (frontend), cubecl-cuda (runtime tests), `thrust::complex`, `num_complex`. + +**Design doc:** `docs/plans/2026-04-14-complex-design.md` + +--- + +### Task 1: Add `ComplexKind` and `ElemType::Complex` to IR + +**Files:** +- Modify: `crates/cubecl-ir/src/type.rs:56-67` (add `ComplexKind` enum, add variant to `ElemType`) +- Modify: `crates/cubecl-ir/src/type.rs:152-203` (add `size()` / `size_bits()` arms) + +**Step 1: Add `ComplexKind` enum before `ElemType`** + +In `crates/cubecl-ir/src/type.rs`, after the `UIntKind` enum (line 56), add: + +```rust +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[allow(missing_docs)] +pub enum ComplexKind { + C32, + C64, +} +``` + +Then add variant to `ElemType` (after `Bool` at line 66): + +```rust +pub enum ElemType { + Float(FloatKind), + Int(IntKind), + UInt(UIntKind), + Bool, + Complex(ComplexKind), +} +``` + +Note: `ElemType` already derives `From` via `derive_more::From`. Since we added a new variant, `From<ComplexKind>` will be auto-derived. Verify this compiles.
+ +**Step 2: Add `size()` arm** + +In the `size()` method (line 180, after `ElemType::Bool`), add: + +```rust + ElemType::Complex(kind) => match kind { + ComplexKind::C32 => core::mem::size_of::<f32>() * 2, + ComplexKind::C64 => core::mem::size_of::<f64>() * 2, + }, +``` + +**Step 3: Add `size_bits()` arm** + +In `size_bits()` (line 201, after the existing match arms), the current code has: +```rust + ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8, +``` +Add a new arm: +```rust + ElemType::Complex(_) => self.size() * 8, +``` + +Or alternatively, change the existing arm to include Complex: +```rust + ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool | ElemType::Complex(_) => self.size() * 8, +``` + +**Step 4: Add classification methods** + +After `is_bool()` (line 228-230), add: + +```rust + pub fn is_complex(&self) -> bool { + matches!(self, ElemType::Complex(_)) + } + + pub fn as_complex(&self) -> Option<ComplexKind> { + match self { + ElemType::Complex(kind) => Some(*kind), + _ => None, + } + } +``` + +**Step 5: Check compilation** + +Run: `cargo check -p cubecl-ir` +Expected: Compile success. May have warnings about unused `ComplexKind` — that's fine.
+ +**Step 6: Commit** + +```bash +git add crates/cubecl-ir/src/type.rs +git commit -m "feat: add ComplexKind and ElemType::Complex to IR" +``` + +--- + +### Task 2: Register Complex types in CUDA backend (cpp layer) + +**Files:** +- Modify: `crates/cubecl-cpp/src/shared/element.rs:11-38` (add `CF32`, `CF64` to `Elem`) +- Modify: `crates/cubecl-cpp/src/shared/element.rs:159-188` (add `ident()` arms) +- Modify: `crates/cubecl-cpp/src/shared/base.rs:73-95` (add `elem_complex` flag to `Flags`) +- Modify: `crates/cubecl-cpp/src/shared/base.rs:2012-2066` (add IR→Elem mapping) + +**Step 1: Add `CF32`, `CF64` to `Elem` enum** + +In `crates/cubecl-cpp/src/shared/element.rs`, add before `Bool` (line 34): + +```rust + CF32, + CF64, +``` + +**Step 2: Add `ident()` arms** + +In the `ident()` method (around line 159-188), add matching arms: + +```rust + Elem::CF32 => "cf32", + Elem::CF64 => "cf64", +``` + +**Step 3: Add `elem_complex` flag to `Flags`** + +In `crates/cubecl-cpp/src/shared/base.rs`, add to `Flags` struct (after line 80): + +```rust + pub elem_complex: bool, +``` + +Initialize it in the `Flags` default/constructor (search for where other `elem_*` flags are initialized, likely `false`). + +**Step 4: Add IR→Elem mapping in `compile_elem()`** + +In `crates/cubecl-cpp/src/shared/base.rs`, in the `compile_elem` method that maps `gpu::ElemType` to `Elem` (around line 2064, after `gpu::ElemType::Bool => Elem::Bool`), add: + +```rust + gpu::ElemType::Complex(kind) => { + self.flags.elem_complex = true; + match kind { + gpu::ComplexKind::C32 => Elem::CF32, + gpu::ComplexKind::C64 => Elem::CF64, + } + } +``` + +**Step 5: Check compilation** + +Run: `cargo check -p cubecl-cpp` +Expected: May fail until CUDA dialect's `compile_elem` is updated (Task 3). If it fails, continue to Task 3. 
+ +**Step 6: Commit** + +```bash +git add crates/cubecl-cpp/src/shared/element.rs crates/cubecl-cpp/src/shared/base.rs +git commit -m "feat: add CF32/CF64 to cpp backend element types" +``` + +--- + +### Task 3: Add CUDA dialect codegen for Complex types + +**Files:** +- Modify: `crates/cubecl-cpp/src/cuda/dialect.rs` (add `compile_elem` arms + include) + +**Step 1: Add type name mapping in `compile_elem()`** + +In `crates/cubecl-cpp/src/cuda/dialect.rs`, find the `compile_elem` method (around line 255-309). In the non-`words` branch, add before `Bool`: + +```rust + Elem::CF32 => f.write_str("thrust::complex<float>"), + Elem::CF64 => f.write_str("thrust::complex<double>"), +``` + +**Step 2: Add `#include <thrust/complex.h>`** + +In the same file, find `compile_includes()` (around line 37-80). Add after the existing includes: + +```rust + if flags.elem_complex { + f.write_str("#include <thrust/complex.h>\n")?; + } +``` + +**Step 3: Add word-size type mapping (if needed)** + +In the `words` branch of `compile_elem()`, complex types don't have native vector types. For now, skip or map to the same type. If there's no match for `CF32`/`CF64` in the words branch, it may fall through — verify this doesn't cause issues. + +**Step 4: Check compilation** + +Run: `cargo check -p cubecl-cpp` +Expected: Compile success. + +**Step 5: Commit** + +```bash +git add crates/cubecl-cpp/src/cuda/dialect.rs +git commit -m "feat: add CUDA dialect codegen for complex types" +``` + +--- + +### Task 4: Add Complex element support in CUDA runtime + +**Files:** +- Modify: `crates/cubecl-cuda/src/runtime.rs` (register complex types as supported) + +**Step 1: Check how types are registered** + +Look at `crates/cubecl-cpp/src/shared/base.rs` function `register_supported_types()` (around line 2091). This is likely called during runtime initialization. Add: + +```rust + gpu::ElemType::Complex(gpu::ComplexKind::C32), + gpu::ElemType::Complex(gpu::ComplexKind::C64), +``` + +to the `supported_types` array. 
+ +**Step 2: Check compilation** + +Run: `cargo check -p cubecl-cuda` +Expected: Compile success. + +**Step 3: Commit** + +```bash +git add crates/cubecl-cuda/ crates/cubecl-cpp/src/shared/base.rs +git commit -m "feat: register complex types in CUDA runtime" +``` + +--- + +### Task 5: Add Complex frontend trait and CubePrimitive impl + +**Files:** +- Create: `crates/cubecl-core/src/frontend/element/complex.rs` +- Modify: `crates/cubecl-core/src/frontend/element/mod.rs` (add `mod complex; pub use complex::*;`) +- Modify: `crates/cubecl-core/Cargo.toml` (add `num-complex` dependency) + +**Step 1: Add `num-complex` dependency** + +In `crates/cubecl-core/Cargo.toml`, add to `[dependencies]`: + +```toml +num-complex = { workspace = true } +``` + +Check workspace `Cargo.toml` to see if `num-complex` is already in workspace dependencies. If not, add it there: + +```toml +num-complex = "0.4" +``` + +**Step 2: Create `complex.rs`** + +Create `crates/cubecl-core/src/frontend/element/complex.rs`: + +```rust +use core::ops::{Add, Div, Mul, Sub}; + +use crate::{ + ir::{Arithmetic, ComplexKind, ElemType, Scope, StorageType, Type}, + prelude::{ + unexpanded, Assign, AssignExpand, CubePrimitive, CubePrimitiveExpand, CubeType, + IntoRuntime, NativeAssign, NativeExpand, Scalar, + }, + unsafe_ignore_fmt, +}; +use cubecl_ir::ConstantValue; + +pub trait Complex: + CubePrimitive + + Add + + Sub + + Mul + + Div + + Neg + + Copy + + Clone + + PartialEq + + core::fmt::Debug + + Send + + Sync + + 'static +{ + fn conj(self) -> Self { + unexpanded!() + } + + fn real_val(self) -> Self::Scalar { + unexpanded!() + } + + fn imag_val(self) -> Self::Scalar { + unexpanded!() + } +} + +macro_rules! 
impl_complex { + ($primitive:ty, $kind:ident) => { + impl CubeType for $primitive { + type ExpandType = NativeExpand<$primitive>; + } + + impl Scalar for $primitive {} + + impl CubePrimitive for $primitive { + type Scalar = Self; + type Size = crate::prelude::Const<1>; + type WithScalar = S; + + fn as_type_native() -> Option { + Some(StorageType::Scalar(ElemType::Complex(ComplexKind::$kind)).into()) + } + + fn from_const_value(_value: ConstantValue) -> Self { + unimplemented!("Complex constants not yet supported") + } + } + + impl IntoRuntime for $primitive { + fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand { + self.into() + } + } + + impl NativeAssign for $primitive {} + + impl crate::prelude::IntoMut for $primitive { + fn into_mut(self, _scope: &mut Scope) -> Self { + self + } + } + + impl Complex for $primitive {} + }; +} + +impl_complex!(num_complex::Complex, C32); +impl_complex!(num_complex::Complex, C64); +``` + +**Step 3: Register module** + +In `crates/cubecl-core/src/frontend/element/mod.rs`, add: + +```rust +mod complex; +pub use complex::*; +``` + +**Step 4: Check compilation** + +Run: `cargo check -p cubecl-core` +Expected: May fail because `core::ops::Neg` is referenced but not imported. Also may fail because the `#[cube]` macro's `normalize_kernel_ty` doesn't know about Complex bounds. Fix compilation errors iteratively. 
+ +**Step 5: Commit** + +```bash +git add crates/cubecl-core/src/frontend/element/complex.rs crates/cubecl-core/src/frontend/element/mod.rs crates/cubecl-core/Cargo.toml +git commit -m "feat: add Complex trait and CubePrimitive impl for Complex32/64" +``` + +--- + +### Task 6: Add Complex arithmetic expand macros and operations + +**Files:** +- Modify: `crates/cubecl-core/src/frontend/element/complex.rs` (add expand macros) +- Modify: `crates/cubecl-core/src/frontend/operation/unary.rs` (add Complex types to applicable `impl_unary_func!` invocations) + +**Step 1: Add expand macros in `complex.rs`** + +The key insight: `impl_core_binop!` requires `T: Add + CubePrimitive` but NOT `Numeric`. Let's check if `Add`, `Sub`, `Mul`, `Div` are already blanket-implemented for `num_complex::Complex`. Yes, `num_complex::Complex` implements `Add`, `Sub`, `Mul`, `Div` when `T: Clone + Num`. + +So we just need the expand side. Add to `complex.rs`: + +```rust +use core::ops::Neg; + +macro_rules! impl_complex_binop { + ($trait: ident, $method: ident, $op: expr) => { + paste::paste! 
{ + pub trait []: $trait + CubePrimitive + CubeType]> + Sized { + fn [<__expand_ $method>]( + scope: &mut Scope, + lhs: NativeExpand, + rhs: NativeExpand, + ) -> NativeExpand { + lhs.[<__expand_ $method _method>](scope, rhs) + } + } + + pub trait [<$trait Expand>] { + fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self; + } + + impl + CubePrimitive> [] for T {} + impl + CubePrimitive> [<$trait Expand>] for NativeExpand { + fn [<__expand_ $method _method>](self, scope: &mut Scope, rhs: Self) -> Self { + crate::frontend::operation::base::binary_expand(scope, self.into(), rhs.into(), $op).into() + } + } + } + }; +} + +impl_complex_binop!(Add, add, Arithmetic::Add); +impl_complex_binop!(Sub, sub, Arithmetic::Sub); +impl_complex_binop!(Mul, mul, Arithmetic::Mul); +impl_complex_binop!(Div, div, Arithmetic::Div); +``` + +For Neg: +```rust +pub trait CubeNeg: Neg + CubePrimitive + Sized { + fn __expand_neg(scope: &mut Scope, x: NativeExpand) -> NativeExpand; +} + +impl + CubePrimitive> CubeNeg for T {} + +impl + CubePrimitive> crate::frontend::operation::unary::neg::NegExpand + for NativeExpand +{ + // This won't work directly since neg::expand is a free function +} +``` + +Actually, let's look at how `neg` works. The `neg::expand` function at `unary.rs:27-33` takes `E: CubePrimitive` — no `Numeric` bound. So `neg` should work out of the box for Complex types! + +Similarly, the `impl_unary_func!` macro takes a list of types. We can add Complex types to the relevant invocations. Let's verify which ones Complex needs: + +Complex needs: `Abs`, `Exp`, `Log`, `Sin`, `Cos`, `Sqrt`, `Powf` +Complex does NOT need: `Ceil`, `Floor`, `Round`, `Trunc`, `Erf`, `Tan`, `Tanh`, etc. 
+ +**Step 2: Add Complex types to `impl_unary_func!` invocations** + +In `crates/cubecl-core/src/frontend/operation/unary.rs`, add `num_complex::Complex, num_complex::Complex` to these invocations: + +- `impl_unary_func!(Abs, abs, ...)` (line 161) +- `impl_unary_func!(Exp, exp, ...)` (line 186) +- `impl_unary_func!(Log, ln, ...)` (line 197) +- `impl_unary_func!(Cos, cos, ...)` (line 209) +- `impl_unary_func!(Sin, sin, ...)` (line 210) +- `impl_unary_func!(Sqrt, sqrt, ...)` (line 334) + +For `Powf`, check how it's defined — it may require `Float` trait. If so, we'll need a separate Complex-specific pow expand. + +**Step 3: Check compilation** + +Run: `cargo check -p cubecl-core` +Expected: Compile success after adding the import for `num_complex` at the top of `unary.rs`. + +**Step 4: Commit** + +```bash +git add crates/cubecl-core/src/frontend/element/complex.rs crates/cubecl-core/src/frontend/operation/unary.rs +git commit -m "feat: add Complex arithmetic and transcendental expand ops" +``` + +--- + +### Task 7: Add Complex-specific operations (Conj, Real, Imag) — Frontend + +**Files:** +- Modify: `crates/cubecl-core/src/frontend/element/complex.rs` (add expand methods) +- Modify: `crates/cubecl-ir/src/arithmetic.rs` or `operator.rs` (add IR ops if needed) + +**Step 1: Check if IR has Conj operator** + +Search `crates/cubecl-ir/src/arithmetic.rs` for `Conj`. If it doesn't exist, add it to the `Arithmetic` enum: + +```rust +Conj, +``` + +Also check if `Operator` has `Real`/`Imag` or if we should use a different mechanism (e.g., lowering via metadata or struct access). 
+ +**Step 2: Add expand methods for Complex-specific ops** + +In `complex.rs`, add expand traits: + +```rust +pub trait ComplexExpand: CubePrimitive { + fn __expand_conj_method(self, scope: &mut Scope) -> Self; + fn __expand_real_val_method(self, scope: &mut Scope) -> NativeExpand; + fn __expand_imag_val_method(self, scope: &mut Scope) -> NativeExpand; +} + +impl ComplexExpand for NativeExpand { + fn __expand_conj_method(self, scope: &mut Scope) -> Self { + crate::frontend::operation::base::unary_expand(scope, self.into(), Arithmetic::Conj).into() + } + + fn __expand_real_val_method(self, scope: &mut Scope) -> NativeExpand { + let expand_element: crate::ir::ManagedVariable = self.into(); + let item = ::as_type(scope); + crate::frontend::operation::base::unary_expand_fixed_output( + scope, expand_element, item, Operator::Real, + ).into() + } + + fn __expand_imag_val_method(self, scope: &mut Scope) -> NativeExpand { + let expand_element: crate::ir::ManagedVariable = self.into(); + let item = ::as_type(scope); + crate::frontend::operation::base::unary_expand_fixed_output( + scope, expand_element, item, Operator::Imag, + ).into() + } +} +``` + +Note: `Operator::Real` and `Operator::Imag` may not exist yet. If they don't, we need to add them to `crates/cubecl-ir/src/operator.rs`. + +**Step 3: Add IR ops if needed** + +In `crates/cubecl-ir/src/operator.rs`, add to the `Operator` enum: + +```rust +Real, +Imag, +``` + +And add to `crates/cubecl-ir/src/arithmetic.rs`: + +```rust +Conj, +``` + +**Step 4: Check compilation** + +Run: `cargo check -p cubecl-core` +Expected: Compile success. 
+ +**Step 5: Commit** + +```bash +git add crates/cubecl-core/src/frontend/element/complex.rs crates/cubecl-ir/src/operator.rs crates/cubecl-ir/src/arithmetic.rs +git commit -m "feat: add Conj/Real/Imag IR ops and frontend expand" +``` + +--- + +### Task 8: Add CUDA codegen for Complex-specific ops (Conj/Real/Imag) + +**Files:** +- Modify: `crates/cubecl-cpp/src/shared/instruction.rs` (add dispatch for new ops) +- Modify: `crates/cubecl-cpp/src/shared/unary.rs` (add Conj formatter) +- Possibly: `crates/cubecl-cpp/src/cuda/dialect.rs` (if custom formatting needed) + +**Step 1: Add Conj unary formatter** + +In `crates/cubecl-cpp/src/shared/unary.rs`, add: + +```rust +pub struct Conj; +impl Unary for Conj { + fn format_scalar(f: &mut core::fmt::Formatter, input: Variable, elem: Elem) -> std::fmt::Result { + write!(f, "thrust::conj({input})") + } +} +``` + +**Step 2: Add Real/Imag extraction formatting** + +These extract a scalar from a complex value: + +```rust +pub struct RealExtract; +impl Unary for RealExtract { + fn format_scalar(f: &mut core::fmt::Formatter, input: Variable, _elem: Elem) -> std::fmt::Result { + let out_elem = /* the output element type */; + write!(f, "{out_elem}({input}.real())") + } +} +``` + +This may need special handling since the output type is different from the input type (Complex → float). + +**Step 3: Add instruction dispatch** + +In `crates/cubecl-cpp/src/shared/instruction.rs`, find where `Arithmetic::Abs` etc. 
are dispatched and add: + +```rust +gpu::Arithmetic::Conj => Conj::format(f, &it.input, &it.out), +``` + +And for `Operator::Real` / `Operator::Imag`: + +```rust +gpu::Operator::Real => RealExtract::format(f, &it.input, &it.out), +gpu::Operator::Imag => ImagExtract::format(f, &it.input, &it.out), +``` + +**Step 4: Add IsNan/IsInf for complex types in CUDA** + +In `crates/cubecl-cpp/src/shared/instruction.rs` or `comparison.rs`, where `IsNan` is handled, add complex-specific path: + +```rust +// In the IsNan handler, check if input is complex: +if input_elem.is_complex() { + write!(f, "({out} = (thrust::isnan({input}.real()) || thrust::isnan({input}.imag())))") +} +``` + +**Step 5: Check compilation** + +Run: `cargo check -p cubecl-cpp` +Expected: Compile success. + +**Step 6: Commit** + +```bash +git add crates/cubecl-cpp/ +git commit -m "feat: add CUDA codegen for Conj/Real/Imag complex ops" +``` + +--- + +### Task 9: Add CubeElement impl for Complex types + +**Files:** +- Modify: `crates/cubecl-core/src/pod.rs` (add `CubeElement` impl for `Complex` / `Complex`) + +**Step 1: Add `bytemuck::Pod` / `Zeroable` safety** + +`num_complex::Complex` is `#[repr(C)]` and `T: Copy`, so it should be safe to implement `Pod`. Check if `bytemuck` already has an impl for `num_complex::Complex`. 
If not: + +```rust +unsafe impl bytemuck::Pod for num_complex::Complex<f32> {} +unsafe impl bytemuck::Zeroable for num_complex::Complex<f32> {} +unsafe impl bytemuck::Pod for num_complex::Complex<f64> {} +unsafe impl bytemuck::Zeroable for num_complex::Complex<f64> {} +``` + +**Step 2: Add `CubeElement` impl** + +In `crates/cubecl-core/src/pod.rs`, add: + +```rust +impl CubeElement for num_complex::Complex<f32> { + fn type_name() -> &'static str { + "cf32" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } + fn cube_type() -> StorageType { + ElemType::Complex(ComplexKind::C32).into() + } + fn maximum_value() -> Self { + num_complex::Complex::new(f32::MAX, 0.0) + } + fn minimum_value() -> Self { + num_complex::Complex::new(f32::MIN, 0.0) + } +} + +impl CubeElement for num_complex::Complex<f64> { + fn type_name() -> &'static str { + "cf64" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } + fn cube_type() -> StorageType { + ElemType::Complex(ComplexKind::C64).into() + } + fn maximum_value() -> Self { + num_complex::Complex::new(f64::MAX, 0.0) + } + fn minimum_value() -> Self { + num_complex::Complex::new(f64::MIN, 0.0) + } +} +``` + +**Step 3: Check compilation** + +Run: `cargo check -p cubecl-core` +Expected: Compile success. 
+ +**Step 4: Commit** + +```bash +git add crates/cubecl-core/src/pod.rs +git commit -m "feat: add CubeElement impl for Complex32/64" +``` + +--- + +### Task 10: Add runtime test for Complex addition + +**Files:** +- Create: `crates/cubecl-core/src/runtime_tests/complex.rs` +- Modify: `crates/cubecl-core/src/runtime_tests/mod.rs` (add module + `testgen_complex!`) +- Modify: `crates/cubecl-cuda/src/lib.rs` (invoke `testgen_complex!`) + +**Step 1: Create runtime test kernel** + +Create `crates/cubecl-core/src/runtime_tests/complex.rs`: + +```rust +use crate::frontend::element::complex::Complex; +use crate::prelude::*; +use cubecl_runtime::runtime::Runtime; + +#[cube(launch)] +pub fn kernel_complex_add(a: &mut Array, b: Array) { + a[UNIT_POS] = a[UNIT_POS] + b[UNIT_POS]; +} + +pub fn test_complex_add(client: ComputeClient) { + let a = vec![C::new(1.0, 2.0), C::new(3.0, 4.0)]; // adjust for num_complex API + let b = vec![C::new(5.0, 6.0), C::new(7.0, 8.0)]; + + let handle_a = client.create_from_slice(C::as_bytes(&a)); + let handle_b = client.create_from_slice(C::as_bytes(&b)); + + kernel_complex_add::launch::( + &client, + CubeCount::Static(2, 1, 1), + CubeDim::new_1d(1), + unsafe { ArrayArg::from_raw_parts(handle_a.clone(), 2) }, + unsafe { ArrayArg::from_raw_parts(handle_b, 2) }, + ); + + let actual = client.read_one_unchecked(handle_a); + let actual = C::from_bytes(&actual); + assert_eq!(actual[0], C::new(6.0, 8.0)); // (1+5, 2+6) + assert_eq!(actual[1], C::new(10.0, 12.0)); // (3+7, 4+8) +} +``` + +Note: The `C::new(re, im)` constructor and `assert_eq!` require `num_complex::Complex` to implement `PartialEq` (it does). Adjust the constructor call to `num_complex::Complex::new(re, im)`. + +**Step 2: Register in `mod.rs`** + +Add to `crates/cubecl-core/src/runtime_tests/mod.rs`: +```rust +pub mod complex; +``` + +And export macro: +```rust +#[macro_export] +macro_rules! 
testgen_complex { + () => { + use super::*; + use num_complex; + + mod complex { + use super::*; + + #[$crate::runtime_tests::test_log::test] + fn test_complex_add_cf32() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_add::>(client); + } + + #[$crate::runtime_tests::test_log::test] + fn test_complex_add_cf64() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::complex::test_complex_add::>(client); + } + } + }; +} +``` + +**Step 3: Invoke in CUDA tests** + +In `crates/cubecl-cuda/src/lib.rs`, inside the `#[cfg(test)] mod tests` block (after line 76), add: + +```rust + cubecl_core::testgen_complex!(); +``` + +**Step 4: Run the test** + +Run: `cargo test -p cubecl-cuda test_complex_add` +Expected: PASS. This validates the full pipeline: frontend → IR → CUDA codegen → runtime. + +**Step 5: Commit** + +```bash +git add crates/cubecl-core/src/runtime_tests/complex.rs crates/cubecl-core/src/runtime_tests/mod.rs crates/cubecl-cuda/src/lib.rs +git commit -m "test: add complex addition runtime test" +``` + +--- + +### Task 11: Add runtime tests for complex multiply, conjugate, and abs + +**Files:** +- Modify: `crates/cubecl-core/src/runtime_tests/complex.rs` (add more test kernels + functions) + +**Step 1: Add multiply test** + +Complex multiplication: `(a+bi)(c+di) = (ac-bd) + (ad+bc)i` + +```rust +#[cube(launch)] +pub fn kernel_complex_mul(a: &mut Array, b: Array) { + a[UNIT_POS] = a[UNIT_POS] * b[UNIT_POS]; +} + +pub fn test_complex_mul(client: ComputeClient) { + // (1+2i) * (3+4i) = (3-8) + (4+6)i = -5 + 10i + let a = vec![num_complex::Complex::new(1.0f32, 2.0)]; + let b = vec![num_complex::Complex::new(3.0f32, 4.0)]; + // ... +} +``` + +Adjust for f64 as needed. 
+ +**Step 2: Add conjugate test** + +```rust +#[cube(launch)] +pub fn kernel_complex_conj(a: &mut Array) { + a[UNIT_POS] = a[UNIT_POS].conj(); +} +``` + +Expected: `conj(1+2i) = 1-2i` + +**Step 3: Add abs test** + +```rust +#[cube(launch)] +pub fn kernel_complex_abs(a: Array, out: &mut Array) { + out[UNIT_POS] = a[UNIT_POS].abs(); +} +``` + +Wait — `abs` returns `Self` for Complex (it's defined via `impl_unary_func!` which keeps same type). But `abs(complex)` should return the magnitude as a float. This needs careful design. + +Actually, in the `Arithmetic::Abs` IR op, the output type matches the input type. For complex, `abs(c)` could return the magnitude as a complex with zero imaginary part, OR we need a separate `magnitude` operation that returns float. + +Let's defer the abs design question and test just mul and conj for now. + +**Step 4: Run tests** + +Run: `cargo test -p cubecl-cuda test_complex` +Expected: All complex tests pass. + +**Step 5: Commit** + +```bash +git add crates/cubecl-core/src/runtime_tests/complex.rs +git commit -m "test: add complex multiply and conjugate runtime tests" +``` + +--- + +### Task 12: Handle #[cube] macro integration for Complex + +**Files:** +- Modify: `crates/cubecl-macros/src/parse/kernel.rs` (add `Complex` to recognized trait bounds) + +**Step 1: Add Complex to recognized bounds** + +In `crates/cubecl-macros/src/parse/kernel.rs`, around line 231-297, where `Float`, `Int`, `Numeric`, `CubePrimitive` are mapped to `DynamicScalar`, add: + +```rust +"Complex" => { + map.insert(ident, GenericArg { + expand_ty: parse_quote!(DynamicScalar), + kind: DefineKind::Type, + }); +} +``` + +**Step 2: Check compilation** + +Run: `cargo check -p cubecl-macros` +Expected: Compile success. 
+ +**Step 3: Commit** + +```bash +git add crates/cubecl-macros/src/parse/kernel.rs +git commit -m "feat: add Complex to #[cube] macro recognized bounds" +``` + +--- + +### Task 13: Full integration test run + +**Step 1: Run all cubecl-ir tests** + +Run: `cargo test -p cubecl-ir` +Expected: No tests exist, should complete immediately. + +**Step 2: Run all cubecl-core tests** + +Run: `cargo test -p cubecl-core` +Expected: Only trybuild compile-fail tests run. All pass. + +**Step 3: Run all cubecl-cuda tests** + +Run: `cargo test -p cubecl-cuda` +Expected: All existing tests pass + new complex tests pass. + +**Step 4: Run full xtask test suite** + +Run: `cargo xtask test` +Expected: All tests pass. + +**Step 5: Final commit (if any fixes needed)** + +```bash +git add -A +git commit -m "fix: integration fixes for complex type support" +``` diff --git a/docs/plans/2026-04-17-complex-math-resume.md b/docs/plans/2026-04-17-complex-math-resume.md new file mode 100644 index 0000000000..6ecf6055f5 --- /dev/null +++ b/docs/plans/2026-04-17-complex-math-resume.md @@ -0,0 +1,122 @@ +# Complex CUDA Math Surface Resume Notes + +## Snapshot + +- Date: 2026-04-17 +- Worktree: `/home/shinaoka/tensor4all/cubecl` +- Branch: `feat/complex-numbers` +- HEAD: `9163c50b` (`Add complex CUDA math helpers and tests`) +- Worktree status at note creation: clean + +## What Is Already Done + +The complex CUDA math follow-up has been implemented and committed. 
+ +Main changes: + +- `abs(complex)` now returns the underlying real scalar type +- complex `tanh` and `powf` are exposed in the frontend +- CUDA complex math helpers were added in `cubecl-cpp` +- runtime coverage was added for: + - `abs` + - `exp` + - `log` + - `sin` + - `cos` + - `sqrt` + - `tanh` + - `powf` + +Related docs: + +- `docs/plans/2026-04-17-complex-math-surface-design.md` +- `docs/plans/2026-04-17-complex-math-surface-impl.md` + +## Current Verification State + +Verified on this worktree after commit: + +- `cargo fmt --all --check` +- `cargo check -p cubecl-core -p cubecl-cpp -p cubecl-cuda` +- `cargo test -p cubecl-cuda --no-run` +- `cargo test -p cubecl-cuda test_complex_abs_cf32 -- --nocapture` +- `cargo test -p cubecl-cuda test_complex_exp_cf32 -- --nocapture` +- `cargo test -p cubecl-cuda test_complex_tanh_cf32 -- --nocapture` +- `cargo test -p cubecl-cuda test_complex_powf_cf32 -- --nocapture` + +At the end of this session, those focused CUDA runtime tests passed locally. + +## Important Note About The Earlier PTX Error + +Earlier in the session, CUDA runtime tests failed with: + +- `CUDA_ERROR_UNSUPPORTED_PTX_VERSION` + +That failure is not reproducing in the latest reruns listed above. + +Do not assume the driver/toolchain mismatch is still an active blocker unless the error reproduces again. If it does come back, investigate the NVRTC and driver combination before changing code. 
+ +Environment clues captured during debugging: + +- `/usr/local/cuda -> /usr/local/cuda-12.6` +- `/usr/local/cuda/lib64/libnvrtc.so -> .../libnvrtc.so.12.6.85` +- `/proc/driver/nvidia/version` showed `535.288.01` + +## Resume Procedure + +When resuming, start from this exact sequence: + +```bash +cd /home/shinaoka/tensor4all/cubecl +git status --short +git rev-parse --short HEAD +git branch --show-current +``` + +Expected: + +- clean worktree +- HEAD still at `9163c50b` or a descendant +- branch still `feat/complex-numbers` + +Then rerun the verification set: + +```bash +cargo fmt --all --check +cargo check -p cubecl-core -p cubecl-cpp -p cubecl-cuda +cargo test -p cubecl-cuda --no-run +cargo test -p cubecl-cuda test_complex_abs_cf32 -- --nocapture +cargo test -p cubecl-cuda test_complex_exp_cf32 -- --nocapture +cargo test -p cubecl-cuda test_complex_tanh_cf32 -- --nocapture +cargo test -p cubecl-cuda test_complex_powf_cf32 -- --nocapture +``` + +## If Everything Still Passes + +The implementation work is effectively in a handoff state. + +Likely next actions: + +1. Run a broader complex test sweep if desired. +2. Push the branch and open a PR. +3. Clean up or split docs only if needed. + +## If `CUDA_ERROR_UNSUPPORTED_PTX_VERSION` Comes Back + +Treat it as an environment issue first, not a code regression. + +Recommended checks: + +```bash +cat /proc/driver/nvidia/version +readlink -f /usr/local/cuda +readlink -f /usr/local/cuda/lib64/libnvrtc.so +``` + +Then reproduce with one focused test: + +```bash +cargo test -p cubecl-cuda test_complex_abs_cf32 -- --nocapture +``` + +If the PTX error is back, inspect the runtime loader path before editing code again. A driver upgrade to `580-server-open` was considered during debugging, but it should only be treated as necessary if the error is reproducible. 
diff --git a/docs/plans/2026-04-17-complex-math-surface-design.md b/docs/plans/2026-04-17-complex-math-surface-design.md new file mode 100644 index 0000000000..29935ea2a7 --- /dev/null +++ b/docs/plans/2026-04-17-complex-math-surface-design.md @@ -0,0 +1,202 @@ +# Complex CUDA Math Surface Follow-up — Design + +## Overview + +Close the remaining gap in CubeCL's CUDA complex-number support by making the complex math contract explicit and testable for downstream runtimes. + +This follow-up keeps interleaved `Complex<f32>` / `Complex<f64>` as the primary model, keeps invalid operations centrally rejected, and focuses on the CUDA backend only. + +## Scope + +This design covers: + +- `abs` +- `exp` +- `log` +- `sin` +- `cos` +- `sqrt` +- `tanh` +- `powf` +- runtime tests for `Complex32` and `Complex64` + +This design does not cover: + +- WGPU complex support +- CPU complex support +- GEMM or linalg integration +- HIP support + +## API Contract + +### `abs(complex)` returns a real scalar + +`abs(complex)` should return the magnitude, not a complex value with zero imaginary part. + +Concrete return types: + +- `num_complex::Complex<f32> -> f32` +- `num_complex::Complex<f64> -> f64` +- `Vector<num_complex::Complex<f32>, N> -> Vector<f32, N>` +- `Vector<num_complex::Complex<f64>, N> -> Vector<f64, N>` + +This aligns `abs(complex)` with downstream expectations and makes the result-type policy explicit rather than implicit. + +### `real_val()` / `imag_val()` remain real-valued + +The existing `real_val()` and `imag_val()` methods already return the underlying floating-point scalar. That policy remains unchanged. + +### `tanh(complex)` and `powf(complex, complex)` become first-class supported ops + +These operations should work through the same frontend/IR/CUDA pipeline as the already-supported complex arithmetic and unary math surface. + +## Frontend Design + +### 1. Generalize `Abs` + +Today `Abs` is modeled as a same-type unary trait. That is correct for real scalars and integers, but incorrect for complex values. 
+ +The `Abs` trait should be reshaped to return an associated output type based on `CubePrimitive::WithScalar<_>`: + +- real scalars still return themselves +- integers still return themselves +- complex returns its corresponding floating scalar +- vectors inherit the correct result through `WithScalar` + +This keeps `Abs` uniform at the trait level without introducing a complex-only escape hatch. + +### 2. Use fixed-output expansion for complex `abs` + +No new IR opcode is needed. The existing `Arithmetic::Abs` operation is still emitted. + +The frontend decides the output type: + +- real / int `abs` continues using same-type unary expansion +- complex `abs` uses `unary_expand_fixed_output(...)` with the matching float element type + +This mirrors the already-established approach for `Real` and `Imag`. + +### 3. Extend complex math trait coverage + +The frontend currently allows several unary complex ops but does not fully expose the downstream-required surface. + +Required additions: + +- add complex types to `Tanh` +- add complex types to `Powf` + +`Complex` itself remains an independent trait rather than implementing `Float` or `Numeric`. + +## IR Design + +No new IR instruction is required for this issue. + +The relevant existing operations are already sufficient: + +- `Arithmetic::Abs` +- `Arithmetic::Exp` +- `Arithmetic::Log` +- `Arithmetic::Sin` +- `Arithmetic::Cos` +- `Arithmetic::Sqrt` +- `Arithmetic::Tanh` +- `Arithmetic::Powf` +- `Operator::Real` +- `Operator::Imag` + +The main change is the frontend result-type policy for `Abs`, not the IR vocabulary. + +## CUDA Backend Design + +### 1. Stay on `cuComplex.h` + +The branch already switched from `thrust::complex` to `cuComplex.h` for NVRTC compatibility. This design keeps that direction. + +### 2. Add explicit complex math helpers + +The CUDA backend should not rely on generic `exp(z)` / `pow(z, w)` calls accidentally working for `cuFloatComplex` / `cuDoubleComplex`. 
+ +Instead, `cuda/dialect.rs` should define explicit inline helpers for: + +- `cubecl_abs` +- `cubecl_exp` +- `cubecl_log` +- `cubecl_sin` +- `cubecl_cos` +- `cubecl_sqrt` +- `cubecl_tanh` +- `cubecl_powf` + +These helpers should be overloaded for `cuFloatComplex` and `cuDoubleComplex`. + +Representative formulas: + +- `abs(z) = hypot(re, im)` +- `exp(x + iy) = exp(x) * (cos(y) + i sin(y))` +- `log(z) = log(|z|) + i atan2(im, re)` +- `sin(x + iy) = sin(x) cosh(y) + i cos(x) sinh(y)` +- `cos(x + iy) = cos(x) cosh(y) - i sin(x) sinh(y)` +- `sqrt(z)` via polar half-angle formula or equivalent numerically stable branch +- `tanh(z) = sinh(2x)/(cosh(2x)+cos(2y)) + i sin(2y)/(cosh(2x)+cos(2y))` +- `pow(z, w) = exp(w * log(z))` + +### 3. Route shared unary/binary formatting through explicit complex branches + +`shared/unary.rs` and `shared/binary.rs` should emit the explicit helper calls whenever the element type is `CF32` or `CF64`. + +This makes the generated CUDA source deterministic and backend-owned instead of depending on incidental overload resolution. + +### 4. Keep explicit unsupported behavior centralized + +The following remain intentionally unsupported for complex values: + +- ordering comparisons +- ordering-based `min` / `max` +- bitwise ops +- integer-only saturating ops +- `MulHi` +- `Remainder` + +## Testing Strategy + +### Runtime tests + +Extend `crates/cubecl-core/src/runtime_tests/complex.rs` with focused CUDA runtime tests for: + +- `abs` with real-valued outputs +- `exp` +- `log` +- `sin` +- `cos` +- `sqrt` +- `tanh` +- `powf` + +Each op should be tested for: + +- `Complex` +- `Complex` +- nontrivial values with nonzero real and imaginary parts +- branch-sensitive values for `log`, `sqrt`, and `powf` + +Reference values should come from `num_complex`. + +### Comparison policy + +Complex-valued outputs should be compared with per-component tolerances. + +Real-valued outputs from `abs` should be compared as floats. 
+ +## Validation Notes + +On this machine, `cargo test -p cubecl-cuda test_complex` currently fails at launch time with `CUDA_ERROR_UNSUPPORTED_PTX_VERSION`, so local verification for this issue should be split into: + +- compile-level verification locally +- runtime validation on CI or a CUDA environment with a compatible driver/toolchain + +## Decisions + +- Keep `cuComplex.h` as the CUDA representation for NVRTC compatibility. +- Make `abs(complex)` return a real scalar, not a zero-imaginary complex. +- Do not add new IR opcodes for this issue. +- Implement the remaining complex math surface explicitly in CUDA helper code instead of relying on ambient overloads. diff --git a/docs/plans/2026-04-17-complex-math-surface-impl.md b/docs/plans/2026-04-17-complex-math-surface-impl.md new file mode 100644 index 0000000000..f71ece1394 --- /dev/null +++ b/docs/plans/2026-04-17-complex-math-surface-impl.md @@ -0,0 +1,267 @@ +# Complex CUDA Math Surface Follow-up Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Finish the CUDA complex math surface by making `abs(complex)` return a real value and adding explicit complex support for `tanh` and `powf`, with runtime tests for the supported operator set. + +**Architecture:** Reshape the frontend `Abs` trait to support fixed output types, use that to lower complex `abs` as `Arithmetic::Abs` with a real-valued output, and route CUDA complex math through explicit `cuComplex` helper functions. Extend runtime tests to lock the contract against `num_complex` reference values. 
+ +**Tech Stack:** `cubecl-core`, `cubecl-cpp`, `cubecl-cuda`, `num-complex`, CUDA `cuComplex.h` + +**Design doc:** `docs/plans/2026-04-17-complex-math-surface-design.md` + +--- + +### Task 1: Add failing tests for the missing complex math contract + +**Files:** +- Modify: `crates/cubecl-core/src/runtime_tests/complex.rs` +- Test: `cargo test -p cubecl-core --no-run` + +**Step 1: Write failing test coverage for `abs(complex)`** + +Add kernels and host-side tests that assert: + +- `Complex<f32>.abs()` writes `f32` +- `Complex<f64>.abs()` writes `f64` +- results match `num_complex::Complex::norm()` + +Use output arrays with element type `C::FloatElem`. + +**Step 2: Write failing test coverage for `tanh` and `powf`** + +Add kernels and host-side tests for: + +- `tanh` on `Complex<f32>` and `Complex<f64>` +- `powf` on `Complex<f32>` and `Complex<f64>` + +Use `num_complex` as the reference implementation and compare with tolerances. + +**Step 3: Add coverage for the existing contract surface** + +Add or consolidate tests for: + +- `exp` +- `log` +- `sin` +- `cos` +- `sqrt` + +This ensures the whole supported surface is explicitly covered, not only the two newly missing ops. + +**Step 4: Run a compile-only check** + +Run: `cargo test -p cubecl-core --no-run` +Expected: The new tests compile, even though runtime execution is deferred to CUDA-capable CI/tooling.
+ +**Step 5: Commit** + +```bash +git add crates/cubecl-core/src/runtime_tests/complex.rs +git commit -m "test: add complex math surface coverage" +``` + +--- + +### Task 2: Generalize frontend `Abs` to support real-valued complex outputs + +**Files:** +- Modify: `crates/cubecl-core/src/frontend/operation/unary.rs` +- Modify: `crates/cubecl-core/src/frontend/element/complex.rs` +- Modify: `crates/cubecl-core/src/frontend/container/vector/ops.rs` +- Check: `crates/cubecl-core/src/frontend/element/numeric.rs` + +**Step 1: Replace same-type `Abs` with a fixed-output trait shape** + +Refactor `Abs` in `frontend/operation/unary.rs` so its return type is based on `Self::WithScalar<...>` instead of always `Self`. + +Keep the real/int behavior unchanged by making their output type equal to themselves. + +**Step 2: Teach complex `Abs` to lower with fixed output** + +In `frontend/element/complex.rs`, implement the expand side of `Abs` so: + +- the emitted operation is still `Arithmetic::Abs` +- the output item is `T::FloatElem` + +This should mirror how `real_val()` and `imag_val()` already use `unary_expand_fixed_output(...)`. + +**Step 3: Ensure vectors inherit the right output type** + +Update any vector trait bounds or impls so `Vector<Complex<f32>, N>::abs()` yields `Vector<f32, N>` through `WithScalar`. + +**Step 4: Compile-check the frontend** + +Run: `cargo check -p cubecl-core` +Expected: `abs` compiles for existing numeric users and now type-checks correctly for complex users.
+ +**Step 5: Commit** + +```bash +git add crates/cubecl-core/src/frontend/operation/unary.rs crates/cubecl-core/src/frontend/element/complex.rs crates/cubecl-core/src/frontend/container/vector/ops.rs +git commit -m "feat: make complex abs return real values" +``` + +--- + +### Task 3: Expose `tanh` and `powf` for complex types in the frontend + +**Files:** +- Modify: `crates/cubecl-core/src/frontend/operation/unary.rs` +- Modify: `crates/cubecl-core/src/frontend/operation/binary.rs` + +**Step 1: Add complex support to `Tanh`** + +Extend the `impl_unary_func!` invocation for `Tanh` to include: + +- `num_complex::Complex<f32>` +- `num_complex::Complex<f64>` + +**Step 2: Add complex support to `Powf`** + +Extend the `impl_binary_func!` invocation for `Powf` to include: + +- `num_complex::Complex<f32>` +- `num_complex::Complex<f64>` + +**Step 3: Re-run a focused compile check** + +Run: `cargo check -p cubecl-core` +Expected: complex kernels can now be written using `.tanh()` and `.powf(...)`. + +**Step 4: Commit** + +```bash +git add crates/cubecl-core/src/frontend/operation/unary.rs crates/cubecl-core/src/frontend/operation/binary.rs +git commit -m "feat: expose complex tanh and powf in frontend" +``` + +--- + +### Task 4: Add explicit CUDA complex math helpers + +**Files:** +- Modify: `crates/cubecl-cpp/src/cuda/dialect.rs` + +**Step 1: Add helper functions for complex unary math** + +In `cuda/dialect.rs`, under the existing `cuComplex` overload block, add inline helpers for: + +- `cubecl_abs(cuFloatComplex)` -> `float` +- `cubecl_abs(cuDoubleComplex)` -> `double` +- `cubecl_exp` +- `cubecl_log` +- `cubecl_sin` +- `cubecl_cos` +- `cubecl_sqrt` +- `cubecl_tanh` + +**Step 2: Add helper functions for complex binary math** + +Add inline overloads for: + +- `cubecl_powf(cuFloatComplex, cuFloatComplex)` +- `cubecl_powf(cuDoubleComplex, cuDoubleComplex)` + +Implement `powf(z, w)` as `exp(w * log(z))` using the same helper family.
+ +**Step 3: Keep helpers usable from both host and device builds** + +Mark them `__device__ __host__ inline` like the existing operator wrappers. + +**Step 4: Compile the CUDA backend** + +Run: `cargo check -p cubecl-cpp` +Expected: helper definitions compile cleanly with the current `cuComplex.h` path. + +**Step 5: Commit** + +```bash +git add crates/cubecl-cpp/src/cuda/dialect.rs +git commit -m "feat: add cuComplex math helpers for complex ops" +``` + +--- + +### Task 5: Route shared codegen through the explicit complex helpers + +**Files:** +- Modify: `crates/cubecl-cpp/src/shared/unary.rs` +- Modify: `crates/cubecl-cpp/src/shared/binary.rs` + +**Step 1: Add complex-aware unary formatting** + +Special-case `CF32` and `CF64` in the unary formatter implementations for: + +- `Abs` +- `Exp` +- `Log` +- `Sin` +- `Cos` +- `Sqrt` +- `Tanh` + +Emit the helper calls from Task 4 instead of generic `exp(...)`, `sqrt(...)`, etc. + +**Step 2: Add complex-aware binary formatting for `Powf`** + +Special-case `CF32` and `CF64` in `shared/binary.rs` so complex `powf` emits `cubecl_powf(lhs, rhs)`. + +**Step 3: Preserve existing scalar behavior** + +Do not change codegen for real or integer element types. + +**Step 4: Compile-check the codegen path** + +Run: `cargo check -p cubecl-cpp -p cubecl-cuda` +Expected: shared formatting compiles and CUDA codegen still builds. 
+ +**Step 5: Commit** + +```bash +git add crates/cubecl-cpp/src/shared/unary.rs crates/cubecl-cpp/src/shared/binary.rs +git commit -m "feat: route complex math codegen through explicit helpers" +``` + +--- + +### Task 6: Verify the whole surface and document remaining runtime limits + +**Files:** +- Modify: `crates/cubecl-core/src/runtime_tests/complex.rs` if minor tolerance fixes are needed +- Optional docs note in: `docs/plans/2026-04-17-complex-math-surface-design.md` + +**Step 1: Run compile-oriented verification** + +Run: + +```bash +cargo check -p cubecl-core -p cubecl-cpp -p cubecl-cuda +cargo test -p cubecl-core --no-run +cargo test -p cubecl-cuda --no-run +``` + +Expected: all compile successfully. + +**Step 2: Run the CUDA runtime tests if the environment allows** + +Run: `cargo test -p cubecl-cuda test_complex -- --nocapture` + +Expected on a compatible CUDA environment: + +- the complex runtime tests pass + +Expected on this current machine: + +- runtime launch may still fail with `CUDA_ERROR_UNSUPPORTED_PTX_VERSION` + +**Step 3: Record the verification result** + +If runtime execution is still blocked by the environment, keep that limitation explicit in the final summary instead of claiming runtime success. + +**Step 4: Commit** + +```bash +git add crates/cubecl-core/src/runtime_tests/complex.rs docs/plans/2026-04-17-complex-math-surface-design.md +git commit -m "test: verify complex CUDA math surface" +```