Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd committed Oct 26, 2023
1 parent d157e6b commit c9ac019
Show file tree
Hide file tree
Showing 17 changed files with 45 additions and 47 deletions.
7 changes: 1 addition & 6 deletions burn-compute/src/channel/base.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use crate::{
server::{ComputeServer, Handle},
tune::AutotuneOperationSet,
};
use alloc::boxed::Box;
use crate::server::{ComputeServer, Handle};
use alloc::vec::Vec;
use burn_common::reader::Reader;

Expand All @@ -23,5 +19,4 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {

/// Wait for the completion of every task in the server.
fn sync(&self);

}
3 changes: 1 addition & 2 deletions burn-compute/src/tune/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
mod operation;
mod tune_benchmark;
mod tune_cache;
mod tuner;

pub use operation::*;
pub use tune_benchmark::*;
pub use tune_cache::*;

mod tuner;
pub use tuner::*;
8 changes: 6 additions & 2 deletions burn-compute/src/tune/operation.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
use alloc::string::String;
use alloc::vec::Vec;

/// Type of operation for the kernel
/// Groups operations of the same type for autotune
pub trait AutotuneOperationSet: Send {
/// The key used in the tune cache
fn key(&self) -> AutotuneKey;

/// All candidate operations for autotuning this operation type
/// Operations can run on toy tensors of relevant size
fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>>;

/// Returns the operation for the given index, matching the order
/// returned by autotunables
/// returned by autotunables. Operation obtained here runs on original tensors
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;
}

/// Contains operation to run and inputs on which to run it
pub trait AutotuneOperation {
/// Runs the operation
fn execute(self: Box<Self>);

/// Clones the operation and inputs
fn clone(&self) -> Box<dyn AutotuneOperation>;
}

Expand Down
2 changes: 1 addition & 1 deletion burn-compute/src/tune/tune_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use burn_common::benchmark::Benchmark;

use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;
use crate::{channel::ComputeChannel, server::Handle};

use super::AutotuneOperation;
use alloc::string::{String, ToString};
Expand Down
5 changes: 4 additions & 1 deletion burn-compute/src/tune/tune_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ use alloc::boxed::Box;

/// Use to find and reuse the best kernel for some input
#[derive(Debug)]
pub struct TuneCache<S> {
pub(crate) struct TuneCache<S> {
cache: HashMap<AutotuneKey, usize>,
_server: PhantomData<S>,
}

/// Result of the cache try
pub enum TuneCacheResult {
/// An operation is found and given
Hit(Box<dyn AutotuneOperation>),
/// No operation is found and the set is given back for ownership
Miss(Box<dyn AutotuneOperationSet>),
}

Expand Down
5 changes: 3 additions & 2 deletions burn-compute/src/tune/tuner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ use crate::client::ComputeClient;
use crate::server::ComputeServer;
use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};

/// Server wrapper with extra capability of autotuning kernels
#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<S, C> {
pub tune_cache: TuneCache<S>,
tune_cache: TuneCache<S>,
_server: PhantomData<S>,
_channel: PhantomData<C>,
}

impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
/// Returns a tuner with empty cache
pub fn new() -> Self {
Self {
tune_cache: TuneCache::new(),
Expand Down
2 changes: 1 addition & 1 deletion burn-wgpu/src/kernel/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::SourceTemplate;
use crate::{
compute::{Kernel, StaticKernel, WorkGroup},
compute::{StaticKernel, WorkGroup},
element::WgpuElement,
tensor::WgpuTensor,
};
Expand Down
4 changes: 2 additions & 2 deletions burn-wgpu/src/kernel/matmul/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
pub mod utils;

mod mem_coalescing;
mod naive;
mod tiling2d;
mod tune;
mod utils;

pub use mem_coalescing::*;
pub use naive::*;
pub use tiling2d::*;
pub use tune::*;
pub use utils::*;
2 changes: 0 additions & 2 deletions burn-wgpu/src/kernel/matmul/naive.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::sync::Arc;

use super::utils::shape_out;
use crate::{
compute::{StaticKernel, WorkGroup},
element::WgpuElement,
kernel::{build_info, into_contiguous, KernelSettings, SourceTemplate, StaticKernelSource},
kernel_wgsl,
ops::numeric::empty_device,
tensor::WgpuTensor,
};

Expand Down
3 changes: 1 addition & 2 deletions burn-wgpu/src/kernel/matmul/tiling2d/padding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use burn_tensor::{Element, Shape};

use crate::{
element::WgpuElement,
kernel::{slice, slice_assign, slice_on_output},
kernel::{slice_assign, slice_on_output},
ops::numeric::zeros_device,
tensor::WgpuTensor,
};
Expand Down Expand Up @@ -273,7 +273,6 @@ mod tests {
let col = 12;
let keep_cols = 10;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let shape = tensor.shape();
let expected_shape = [keep_rows, keep_cols].into();

let unpadded = crop(
Expand Down
33 changes: 18 additions & 15 deletions burn-wgpu/src/kernel/matmul/tune/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ use burn_tensor::Element;
use crate::{
element::WgpuElement,
kernel::matmul::{
autotune_tensors, utils::init_matmul_output, MemoryCoalescingMatmulAutotuneOperation,
Vec4TilingMatmulAutotuneOperation,
tune::utils::autotune_tensors, utils::init_matmul_output,
MemoryCoalescingMatmulAutotuneOperation, Vec4TilingMatmulAutotuneOperation,
},
tensor::WgpuTensor,
};

/// Set of matmul implementations available for autotune
/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
pub struct MatmulAutotuneOperationSet<E: WgpuElement, const D: usize> {
key: AutotuneKey,
lhs: WgpuTensor<E, D>,
Expand All @@ -25,13 +27,25 @@ impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
let k = lhs.shape.dims[D - 1];
let n = rhs.shape.dims[D - 1];
Self {
key: AutotuneKey::new("matmul".to_string(), log_mkn_input_key(m, k, n)),
key: AutotuneKey::new("matmul".to_string(), Self::log_mkn_input_key(m, k, n)),
lhs,
rhs,
out,
_element: PhantomData,
}
}

fn log_mkn_input_key(m: usize, k: usize, n: usize) -> String {
let mut desc = String::new();

for size in [m, k, n] {
let exp = f32::ceil(f32::log2(size as f32)) as u32;
desc.push_str(2_u32.pow(exp).to_string().as_str());
desc.push(',');
}

desc
}
}

impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
Expand Down Expand Up @@ -71,6 +85,7 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
}
}

/// Executes autotune on matmul operations
pub fn matmul_autotune<E: WgpuElement + Element, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
Expand All @@ -89,15 +104,3 @@ pub fn matmul_autotune<E: WgpuElement + Element, const D: usize>(

output
}

fn log_mkn_input_key(m: usize, k: usize, n: usize) -> String {
let mut desc = String::new();

for size in [m, k, n] {
let exp = f32::ceil(f32::log2(size as f32)) as u32;
desc.push_str(2_u32.pow(exp).to_string().as_str());
desc.push(',');
}

desc
}
1 change: 1 addition & 0 deletions burn-wgpu/src/kernel/matmul/tune/mem_coalescing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
};

#[derive(new)]
/// Memory coalescing matmul operation
pub struct MemoryCoalescingMatmulAutotuneOperation<E: WgpuElement, const D: usize> {
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
Expand Down
1 change: 0 additions & 1 deletion burn-wgpu/src/kernel/matmul/tune/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ mod vec4_tiling;

pub use base::*;
pub use mem_coalescing::*;
pub use utils::*;
pub use vec4_tiling::*;
10 changes: 1 addition & 9 deletions burn-wgpu/src/kernel/matmul/tune/utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
use burn_tensor::{Element, Shape};
use burn_tensor::Element;

use crate::{element::WgpuElement, ops::numeric::ones_device, tensor::WgpuTensor};

pub(crate) fn n_bytes<E, const D: usize>(shape: &Shape<D>) -> usize {
shape.num_elements() * core::mem::size_of::<E>()
}

pub(crate) fn autotune_tensors<E: WgpuElement + Element, const D: usize>(
tensor: &WgpuTensor<E, D>,
) -> WgpuTensor<E, 3> {
Expand All @@ -21,7 +17,3 @@ pub(crate) fn autotune_tensors<E: WgpuElement + Element, const D: usize>(
.into(),
)
}

pub(crate) fn fill_bytes<E: WgpuElement, const D: usize>(value: u8, shape: &Shape<D>) -> Vec<u8> {
vec![value; n_bytes::<E, D>(shape)]
}
1 change: 1 addition & 0 deletions burn-wgpu/src/kernel/matmul/tune/vec4_tiling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
};

#[derive(new)]
/// Tiling 2d with vec4 primitive matmul operation
pub struct Vec4TilingMatmulAutotuneOperation<E: WgpuElement, const D: usize> {
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
Expand Down
1 change: 1 addition & 0 deletions burn-wgpu/src/kernel/matmul/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{element::WgpuElement, ops::numeric::empty_device, tensor::WgpuTensor};
use burn_tensor::Shape;

/// Creates an empty output tensor with matmul output shape
pub fn init_matmul_output<E: WgpuElement, const D: usize>(
lhs: &WgpuTensor<E, D>,
rhs: &WgpuTensor<E, D>,
Expand Down
4 changes: 3 additions & 1 deletion burn-wgpu/src/ops/float_ops.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::{numeric, BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor};
use crate::kernel::matmul::init_matmul_output;
#[cfg(feature = "autotune")]
use crate::kernel::matmul::matmul_autotune;
use crate::kernel::matmul::utils::init_matmul_output;
#[cfg(not(feature = "autotune"))]
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::kernel::{
self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default,
Expand Down

0 comments on commit c9ac019

Please sign in to comment.