Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add f16 support in the wgpu backend #1582

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/burn-jit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ burn-fusion = { path = "../burn-fusion", version = "0.13.0", optional = true }

bytemuck = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
log = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub enum Visibility {
#[allow(missing_docs)]
pub enum Elem {
Float,
Half,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to add Half here, Float should cover all float types of all precisions in this context.

Int,
UInt,
Bool,
Expand All @@ -36,6 +37,7 @@ impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Float => f.write_str("float"),
Self::Half => f.write_str("half"),
Self::Int => f.write_str("int"),
Self::UInt => f.write_str("uint"),
Self::Bool => f.write_str("bool"),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ fn create_scalar_handles<R: Runtime, E1: JitElement, E2: JitElement, E3: JitElem
Elem::Int => 1,
Elem::UInt => 2,
Elem::Bool => panic!("Bool scalars are not supported"),
Elem::Half => panic!("Half scalars are not supported"),
};
let scalar_priorities: [usize; 3] = [
element_priority(E1::gpu_elem()),
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-jit/src/element.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::codegen::dialect::gpu;
use burn_tensor::Element;
use half::f16;

/// The base element trait for the jit backend.
pub trait JitElement:
Expand Down Expand Up @@ -90,5 +91,27 @@ impl JitElement for f32 {
}
}

impl JitElement for f16 {
fn type_name() -> &'static str {
"f16"
}
fn as_bytes(slice: &[Self]) -> &[u8] {
bytemuck::cast_slice(slice)
}
fn from_bytes(bytes: &[u8]) -> &[Self] {
bytemuck::cast_slice(bytes)
}
fn gpu_elem() -> gpu::Elem {
gpu::Elem::Half
}
Comment on lines +104 to +106
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gpu element would be Float here.

fn maximum_value() -> Self {
f16::MAX
}
fn minimum_value() -> Self {
f16::MIN
}
}

impl FloatElement for f32 {}
impl IntElement for i32 {}
impl FloatElement for f16 {}
1 change: 1 addition & 0 deletions crates/burn-jit/src/fusion/tracing/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Scalars {
pub(crate) num_float: usize,
pub(crate) num_half: usize,
pub(crate) num_int: usize,
pub(crate) num_uint: usize,
pub(crate) num_bool: usize,
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-jit/src/fusion/tracing/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ impl TraceBuilder {
self.scalars.num_float += 1;
var
}
gpu::Elem::Half => {
let var = self
.scope
.read_scalar(self.scalars.num_half as u16, elem_type);
self.scalars.num_half += 1;
var
}
gpu::Elem::Int => {
let var = self
.scope
Expand Down
1 change: 1 addition & 0 deletions crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ burn-common = { path = "../burn-common", version = "0.13.0" }
burn-fusion = { path = "../burn-fusion", version = "0.13.0", optional = true }

bytemuck = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] }
pollster = { workspace = true }

Expand Down
4 changes: 4 additions & 0 deletions crates/burn-wgpu/src/compiler/wgsl/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum Variable {
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Elem {
F32,
F16,
I32,
U32,
Bool,
Expand Down Expand Up @@ -165,6 +166,7 @@ impl Elem {
pub fn size(&self) -> usize {
match self {
Self::F32 => core::mem::size_of::<f32>(),
Self::F16 => core::mem::size_of::<half::f16>(),
Self::I32 => core::mem::size_of::<i32>(),
Self::U32 => core::mem::size_of::<u32>(),
Self::Bool => core::mem::size_of::<bool>(),
Expand All @@ -176,6 +178,7 @@ impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::F32 => f.write_str("f32"),
Self::F16 => f.write_str("f16"),
Self::I32 => f.write_str("i32"),
Self::U32 => f.write_str("u32"),
Self::Bool => f.write_str("bool"),
Expand Down Expand Up @@ -218,6 +221,7 @@ impl Display for Variable {
}
Variable::ConstantScalar(number, elem) => match elem {
Elem::F32 => f.write_fmt(format_args!("{number}f")),
Elem::F16 => f.write_fmt(format_args!("{number}f")),
Elem::I32 => f.write_fmt(format_args!("{number}i")),
Elem::U32 => f.write_fmt(format_args!("{number}u")),
Elem::Bool => f.write_fmt(format_args!("bool({number})")),
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
self.num_inputs = value.inputs.len();
self.num_outputs = value.outputs.len();

let features = match F::gpu_elem() {
gpu::Elem::Half => vec![wgsl::Feature::ShaderF16],
_ => vec![],
};
Comment on lines +81 to +84
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would check using F::wgpu_elem() == Elem::F16 instead.


let instructions = self.compile_scope(&mut value.body);
let extensions = register_extensions(&instructions);
let body = wgsl::Body {
Expand Down Expand Up @@ -114,6 +119,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
workgroup_id: self.workgroup_id,
body,
extensions,
features,
}
}

Expand All @@ -129,6 +135,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
fn compile_elem(value: gpu::Elem) -> wgsl::Elem {
match value {
gpu::Elem::Float => F::wgpu_elem(),
gpu::Elem::Half => F::wgpu_elem(),
Comment on lines 137 to +138
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line pretty much explains why we don't need Half in the gpu::Elem enum :)

gpu::Elem::Int => I::wgpu_elem(),
gpu::Elem::UInt => wgsl::Elem::U32,
gpu::Elem::Bool => wgsl::Elem::Bool,
Expand Down
18 changes: 18 additions & 0 deletions crates/burn-wgpu/src/compiler/wgsl/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ impl LocalArray {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Feature {
ShaderF16,
}

#[derive(Debug, Clone)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
Expand All @@ -75,10 +80,15 @@ pub struct ComputeShader {
pub workgroup_id: bool,
pub body: Body,
pub extensions: Vec<Extension>,
pub features: Vec<Feature>,
}

impl Display for ComputeShader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for feature in self.features.iter() {
f.write_fmt(format_args!("{feature};\n"))?;
}

Self::format_bindings(f, "input", &self.inputs, 0)?;
Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?;

Expand Down Expand Up @@ -218,3 +228,11 @@ impl Display for Visibility {
}
}
}

impl Display for Feature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Feature::ShaderF16 => f.write_str("enable f16"),
}
}
}
8 changes: 8 additions & 0 deletions crates/burn-wgpu/src/element.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use burn_jit::JitElement;

use crate::compiler::wgsl;
use half::f16;

/// The base element trait for the wgpu backend.
pub trait WgpuElement: JitElement {
Expand Down Expand Up @@ -31,5 +32,12 @@ impl WgpuElement for f32 {
}
}

impl WgpuElement for f16 {
fn wgpu_elem() -> wgsl::Elem {
wgsl::Elem::F16
}
}

impl FloatElement for f32 {}
impl IntElement for i32 {}
impl FloatElement for f16 {}
Loading