Skip to content

Commit

Permalink
burn-wgpu: f16 support
Browse files Browse the repository at this point in the history
The burn-wgpu backend currently does not support computations on 16 bit
floats. This. for example, limits the ability to run LLMs on top of
Burn, on widely available hardware. So, add 16 bit float support in
burn-wgpu.

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
  • Loading branch information
p1-0tr committed Apr 7, 2024
1 parent ca3dcb9 commit ee7ec2c
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/burn-jit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ burn-fusion = { path = "../burn-fusion", version = "0.13.0", optional = true }

bytemuck = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
log = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub enum Visibility {
#[allow(missing_docs)]
pub enum Elem {
Float,
Half,
Int,
UInt,
Bool,
Expand All @@ -36,6 +37,7 @@ impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Float => f.write_str("float"),
Self::Half => f.write_str("half"),
Self::Int => f.write_str("int"),
Self::UInt => f.write_str("uint"),
Self::Bool => f.write_str("bool"),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ fn create_scalar_handles<R: Runtime, E1: JitElement, E2: JitElement, E3: JitElem
Elem::Int => 1,
Elem::UInt => 2,
Elem::Bool => panic!("Bool scalars are not supported"),
Elem::Half => panic!("Half scalars are not supported"),
};
let scalar_priorities: [usize; 3] = [
element_priority(E1::gpu_elem()),
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-jit/src/element.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::codegen::dialect::gpu;
use burn_tensor::Element;
use half::f16;

/// The base element trait for the jit backend.
pub trait JitElement:
Expand Down Expand Up @@ -90,5 +91,27 @@ impl JitElement for f32 {
}
}

impl JitElement for f16 {
fn type_name() -> &'static str {
"f16"
}
fn as_bytes(slice: &[Self]) -> &[u8] {
bytemuck::cast_slice(slice)
}
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
fn gpu_elem() -> gpu::Elem {
gpu::Elem::Half
}
fn maximum_value() -> Self {
f16::MAX
}
fn minimum_value() -> Self {
f16::MIN
}
}

impl FloatElement for f32 {}
impl IntElement for i32 {}
impl FloatElement for f16 {}
1 change: 1 addition & 0 deletions crates/burn-jit/src/fusion/tracing/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Scalars {
pub(crate) num_float: usize,
pub(crate) num_half: usize,
pub(crate) num_int: usize,
pub(crate) num_uint: usize,
pub(crate) num_bool: usize,
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-jit/src/fusion/tracing/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ impl TraceBuilder {
self.scalars.num_float += 1;
var
}
gpu::Elem::Half => {
let var = self
.scope
.read_scalar(self.scalars.num_half as u16, elem_type);
self.scalars.num_half += 1;
var
}
gpu::Elem::Int => {
let var = self
.scope
Expand Down
1 change: 1 addition & 0 deletions crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ burn-common = { path = "../burn-common", version = "0.13.0" }
burn-fusion = { path = "../burn-fusion", version = "0.13.0", optional = true }

bytemuck = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] }
pollster = { workspace = true }

Expand Down
4 changes: 4 additions & 0 deletions crates/burn-wgpu/src/compiler/wgsl/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum Variable {
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Elem {
F32,
F16,
I32,
U32,
Bool,
Expand Down Expand Up @@ -165,6 +166,7 @@ impl Elem {
pub fn size(&self) -> usize {
match self {
Self::F32 => core::mem::size_of::<f32>(),
Self::F16 => core::mem::size_of::<half::f16>(),
Self::I32 => core::mem::size_of::<i32>(),
Self::U32 => core::mem::size_of::<u32>(),
Self::Bool => core::mem::size_of::<bool>(),
Expand All @@ -176,6 +178,7 @@ impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::F32 => f.write_str("f32"),
Self::F16 => f.write_str("f16"),
Self::I32 => f.write_str("i32"),
Self::U32 => f.write_str("u32"),
Self::Bool => f.write_str("bool"),
Expand Down Expand Up @@ -218,6 +221,7 @@ impl Display for Variable {
}
Variable::ConstantScalar(number, elem) => match elem {
Elem::F32 => f.write_fmt(format_args!("{number}f")),
Elem::F16 => f.write_fmt(format_args!("{number}f")),
Elem::I32 => f.write_fmt(format_args!("{number}i")),
Elem::U32 => f.write_fmt(format_args!("{number}u")),
Elem::Bool => f.write_fmt(format_args!("bool({number})")),
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
self.num_inputs = value.inputs.len();
self.num_outputs = value.outputs.len();

let features = match F::gpu_elem() {
gpu::Elem::Half => vec![wgsl::Feature::ShaderF16],
_ => vec![],
};

let instructions = self.compile_scope(&mut value.body);
let extensions = register_extensions(&instructions);
let body = wgsl::Body {
Expand Down Expand Up @@ -114,6 +119,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
workgroup_id: self.workgroup_id,
body,
extensions,
features,
}
}

Expand All @@ -129,6 +135,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
fn compile_elem(value: gpu::Elem) -> wgsl::Elem {
match value {
gpu::Elem::Float => F::wgpu_elem(),
gpu::Elem::Half => F::wgpu_elem(),
gpu::Elem::Int => I::wgpu_elem(),
gpu::Elem::UInt => wgsl::Elem::U32,
gpu::Elem::Bool => wgsl::Elem::Bool,
Expand Down
18 changes: 18 additions & 0 deletions crates/burn-wgpu/src/compiler/wgsl/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ impl LocalArray {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Feature {
ShaderF16,
}

#[derive(Debug, Clone)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
Expand All @@ -75,10 +80,15 @@ pub struct ComputeShader {
pub workgroup_id: bool,
pub body: Body,
pub extensions: Vec<Extension>,
pub features: Vec<Feature>,
}

impl Display for ComputeShader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for feature in self.features.iter() {
f.write_fmt(format_args!("{feature};\n"))?;
}

Self::format_bindings(f, "input", &self.inputs, 0)?;
Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?;

Expand Down Expand Up @@ -218,3 +228,11 @@ impl Display for Visibility {
}
}
}

impl Display for Feature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Feature::ShaderF16 => f.write_str("enable f16"),
}
}
}
8 changes: 8 additions & 0 deletions crates/burn-wgpu/src/element.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use burn_jit::JitElement;

use crate::compiler::wgsl;
use half::f16;

/// The base element trait for the wgpu backend.
pub trait WgpuElement: JitElement {
Expand Down Expand Up @@ -31,5 +32,12 @@ impl WgpuElement for f32 {
}
}

impl WgpuElement for f16 {
fn wgpu_elem() -> wgsl::Elem {
wgsl::Elem::F16
}
}

impl FloatElement for f32 {}
impl IntElement for i32 {}
impl FloatElement for f16 {}

0 comments on commit ee7ec2c

Please sign in to comment.