Skip to content

Commit

Permalink
Implemented load_buffer_format_* conversions (#295)
Browse files Browse the repository at this point in the history
* Implemented load_buffer_format_* conversions

* clang-format insists on ugly things
  • Loading branch information
vladmikhalin committed Jul 16, 2024
1 parent c6cdfcf commit f9e9679
Show file tree
Hide file tree
Showing 15 changed files with 469 additions and 85 deletions.
229 changes: 195 additions & 34 deletions src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "shader_recompiler/backend/spirv/emit_spirv_instructions.h"
#include "shader_recompiler/backend/spirv/spirv_emit_context.h"

#include <magic_enum.hpp>

namespace Shader::Backend::SPIRV {
namespace {

Expand Down Expand Up @@ -209,57 +211,216 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, u32 elemen
ctx.OpStore(pointer, value);
}

Id EmitLoadBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
const auto info = inst->Flags<IR::BufferInstInfo>();
Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferF32(ctx, inst, handle, address);
}

template <int N>
static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) {
const auto& buffer = ctx.buffers[handle];
if (info.index_enable && info.offset_enable) {
UNREACHABLE();
} else if (info.index_enable) {
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) {
const Id ptr{
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, address)};
return ctx.OpLoad(buffer.data_types->Get(1), ptr);
} else {
boost::container::static_vector<Id, N> ids;
for (u32 i = 0; i < N; i++) {
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id ptr{
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
}
return ctx.OpCompositeConstruct(buffer.data_types->Get(N), ids);
}
UNREACHABLE();
}

Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferF32(ctx, inst, handle, address);
Id EmitLoadBufferF32(EmitContext& ctx, IR::Inst*, u32 handle, Id address) {
return EmitLoadBufferF32xN<1>(ctx, handle, address);
}

Id EmitLoadBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
const auto info = inst->Flags<IR::BufferInstInfo>();
const auto& buffer = ctx.buffers[handle];
boost::container::static_vector<Id, 2> ids;
for (u32 i = 0; i < 2; i++) {
const Id index{ctx.OpIAdd(ctx.U32[1], address, ctx.ConstU32(i))};
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
Id EmitLoadBufferF32x2(EmitContext& ctx, IR::Inst*, u32 handle, Id address) {
return EmitLoadBufferF32xN<2>(ctx, handle, address);
}

Id EmitLoadBufferF32x3(EmitContext& ctx, IR::Inst*, u32 handle, Id address) {
return EmitLoadBufferF32xN<3>(ctx, handle, address);
}

Id EmitLoadBufferF32x4(EmitContext& ctx, IR::Inst*, u32 handle, Id address) {
return EmitLoadBufferF32xN<4>(ctx, handle, address);
}

static bool IsSignedInteger(AmdGpu::NumberFormat format) {
switch (format) {
case AmdGpu::NumberFormat::Unorm:
case AmdGpu::NumberFormat::Uscaled:
case AmdGpu::NumberFormat::Uint:
return false;
case AmdGpu::NumberFormat::Snorm:
case AmdGpu::NumberFormat::Sscaled:
case AmdGpu::NumberFormat::Sint:
case AmdGpu::NumberFormat::SnormNz:
return true;
case AmdGpu::NumberFormat::Float:
default:
UNREACHABLE();
}
return ctx.OpCompositeConstruct(buffer.data_types->Get(2), ids);
}

Id EmitLoadBufferF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
const auto info = inst->Flags<IR::BufferInstInfo>();
const auto& buffer = ctx.buffers[handle];
boost::container::static_vector<Id, 3> ids;
for (u32 i = 0; i < 3; i++) {
const Id index{ctx.OpIAdd(ctx.U32[1], address, ctx.ConstU32(i))};
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
static u32 UXBitsMax(u32 bit_width) {
return (1u << bit_width) - 1u;
}

static u32 SXBitsMax(u32 bit_width) {
return (1u << (bit_width - 1u)) - 1u;
}

static Id ConvertValue(EmitContext& ctx, Id value, AmdGpu::NumberFormat format, u32 bit_width) {
switch (format) {
case AmdGpu::NumberFormat::Unorm:
return ctx.OpFDiv(ctx.F32[1], value, ctx.ConstF32(float(UXBitsMax(bit_width))));
case AmdGpu::NumberFormat::Snorm:
return ctx.OpFDiv(ctx.F32[1], value, ctx.ConstF32(float(SXBitsMax(bit_width))));
case AmdGpu::NumberFormat::SnormNz:
// (x * 2 + 1) / (Format::SMAX * 2)
value = ctx.OpFMul(ctx.F32[1], value, ctx.ConstF32(2.f));
value = ctx.OpFAdd(ctx.F32[1], value, ctx.ConstF32(1.f));
return ctx.OpFDiv(ctx.F32[1], value, ctx.ConstF32(float(SXBitsMax(bit_width) * 2)));
case AmdGpu::NumberFormat::Uscaled:
case AmdGpu::NumberFormat::Sscaled:
case AmdGpu::NumberFormat::Uint:
case AmdGpu::NumberFormat::Sint:
case AmdGpu::NumberFormat::Float:
return value;
default:
UNREACHABLE_MSG("Unsupported number fromat for conversion: {}",
magic_enum::enum_name(format));
}
return ctx.OpCompositeConstruct(buffer.data_types->Get(3), ids);
}

Id EmitLoadBufferF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
const auto info = inst->Flags<IR::BufferInstInfo>();
static Id ComponentOffset(EmitContext& ctx, Id address, u32 stride, u32 bit_offset) {
Id comp_offset = ctx.ConstU32(bit_offset);
if (stride < 4) {
// comp_offset += (address % 4) * 8;
const Id byte_offset = ctx.OpUMod(ctx.U32[1], address, ctx.ConstU32(4u));
const Id bit_offset = ctx.OpShiftLeftLogical(ctx.U32[1], byte_offset, ctx.ConstU32(3u));
comp_offset = ctx.OpIAdd(ctx.U32[1], comp_offset, bit_offset);
}
return comp_offset;
}

static Id GetBufferFormatValue(EmitContext& ctx, u32 handle, Id address, u32 comp) {
const auto& buffer = ctx.buffers[handle];
boost::container::static_vector<Id, 4> ids;
for (u32 i = 0; i < 4; i++) {
const Id index{ctx.OpIAdd(ctx.U32[1], address, ctx.ConstU32(i))};
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
const auto format = buffer.buffer.GetDataFmt();
switch (format) {
case AmdGpu::DataFormat::FormatInvalid:
return ctx.f32_zero_value;
case AmdGpu::DataFormat::Format8:
case AmdGpu::DataFormat::Format16:
case AmdGpu::DataFormat::Format32:
case AmdGpu::DataFormat::Format8_8:
case AmdGpu::DataFormat::Format16_16:
case AmdGpu::DataFormat::Format10_11_11:
case AmdGpu::DataFormat::Format11_11_10:
case AmdGpu::DataFormat::Format10_10_10_2:
case AmdGpu::DataFormat::Format2_10_10_10:
case AmdGpu::DataFormat::Format8_8_8_8:
case AmdGpu::DataFormat::Format32_32:
case AmdGpu::DataFormat::Format16_16_16_16:
case AmdGpu::DataFormat::Format32_32_32:
case AmdGpu::DataFormat::Format32_32_32_32: {
const u32 num_components = AmdGpu::NumComponents(format);
if (comp >= num_components) {
return ctx.f32_zero_value;
}

// uint index = address / 4;
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
const u32 stride = buffer.buffer.GetStride();
if (stride > 4) {
const u32 index_offset = u32(AmdGpu::ComponentOffset(format, comp) / 32);
if (index_offset > 0) {
// index += index_offset;
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(index_offset));
}
}
const Id ptr = ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index);

const u32 bit_offset = AmdGpu::ComponentOffset(format, comp) % 32;
const u32 bit_width = AmdGpu::ComponentBits(format, comp);
const auto num_format = buffer.buffer.GetNumberFmt();
if (num_format == AmdGpu::NumberFormat::Float) {
if (bit_width == 32) {
return ctx.OpLoad(ctx.F32[1], ptr);
} else if (bit_width == 16) {
const Id comp_offset = ComponentOffset(ctx, address, stride, bit_offset);
Id value = ctx.OpLoad(ctx.U32[1], ptr);
value =
ctx.OpBitFieldSExtract(ctx.S32[1], value, comp_offset, ctx.ConstU32(bit_width));
value = ctx.OpSConvert(ctx.U16, value);
value = ctx.OpBitcast(ctx.F16[1], value);
return ctx.OpFConvert(ctx.F32[1], value);
} else {
UNREACHABLE_MSG("Invalid float bit width {}", bit_width);
}
} else {
Id value = ctx.OpLoad(ctx.U32[1], ptr);
const bool is_signed = IsSignedInteger(num_format);
if (bit_width < 32) {
const Id comp_offset = ComponentOffset(ctx, address, stride, bit_offset);
if (is_signed) {
value = ctx.OpBitFieldSExtract(ctx.S32[1], value, comp_offset,
ctx.ConstU32(bit_width));
value = ctx.OpConvertSToF(ctx.F32[1], value);
} else {
value = ctx.OpBitFieldUExtract(ctx.U32[1], value, comp_offset,
ctx.ConstU32(bit_width));
value = ctx.OpConvertUToF(ctx.F32[1], value);
}
} else {
if (is_signed) {
value = ctx.OpConvertSToF(ctx.F32[1], value);
} else {
value = ctx.OpConvertUToF(ctx.F32[1], value);
}
}
return ConvertValue(ctx, value, num_format, bit_width);
}
break;
}
default:
UNREACHABLE_MSG("Invalid format for conversion: {}", magic_enum::enum_name(format));
}
}

template <int N>
static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
if constexpr (N == 1) {
return GetBufferFormatValue(ctx, handle, address, 0);
} else {
boost::container::static_vector<Id, N> ids;
for (u32 i = 0; i < N; i++) {
ids.push_back(GetBufferFormatValue(ctx, handle, address, i));
}
return ctx.OpCompositeConstruct(ctx.F32[N], ids);
}
return ctx.OpCompositeConstruct(buffer.data_types->Get(4), ids);
}

Id EmitLoadBufferFormatF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferFormatF32xN<1>(ctx, inst, handle, address);
}

Id EmitLoadBufferFormatF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferFormatF32xN<2>(ctx, inst, handle, address);
}

Id EmitLoadBufferFormatF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferFormatF32xN<3>(ctx, inst, handle, address);
}

Id EmitLoadBufferFormatF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferFormatF32xN<4>(ctx, inst, handle, address);
}

void EmitStoreBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) {
Expand Down
4 changes: 4 additions & 0 deletions src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ Id EmitLoadBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferFormatF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferFormatF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferFormatF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferFormatF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address);
void EmitStoreBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
void EmitStoreBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value);
Expand Down
5 changes: 2 additions & 3 deletions src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,7 @@ void EmitContext::DefineBuffers(const Info& info) {
for (u32 i = 0; const auto& buffer : info.buffers) {
const auto* data_types = True(buffer.used_types & IR::Type::F32) ? &F32 : &U32;
const Id data_type = (*data_types)[1];
const u32 stride = buffer.stride == 0 ? 1 : buffer.stride;
const u32 num_elements = stride * buffer.num_records;
const Id record_array_type{TypeArray(data_type, ConstU32(num_elements))};
const Id record_array_type{TypeArray(data_type, ConstU32(buffer.length))};
const Id struct_type{TypeStruct(record_array_type)};
if (std::ranges::find(type_ids, record_array_type.value, &Id::value) == type_ids.end()) {
Decorate(record_array_type, spv::Decoration::ArrayStride, 4);
Expand Down Expand Up @@ -333,6 +331,7 @@ void EmitContext::DefineBuffers(const Info& info) {
.id = id,
.data_types = data_types,
.pointer_type = pointer_type,
.buffer = buffer.GetVsharp(info),
});
interfaces.push_back(id);
i++;
Expand Down
1 change: 1 addition & 0 deletions src/shader_recompiler/backend/spirv/spirv_emit_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class EmitContext final : public Sirit::Module {
Id id;
const VectorIds* data_types;
Id pointer_type;
AmdGpu::Buffer buffer;
};

u32& binding;
Expand Down
33 changes: 22 additions & 11 deletions src/shader_recompiler/frontend/translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ void Translator::EmitFetch(const GcnInst& inst) {
info.buffers.push_back({
.sgpr_base = attrib.sgpr_base,
.dword_offset = attrib.dword_offset,
.stride = buffer.GetStride(),
.num_records = buffer.num_records,
.length = buffer.num_records,
.used_types = IR::Type::F32,
.is_storage = true, // we may not fit into UBO with large meshes
.is_instance_data = true,
Expand Down Expand Up @@ -571,28 +570,40 @@ void Translate(IR::Block* block, u32 block_base, std::span<const GcnInst> inst_l
translator.V_CNDMASK_B32(inst);
break;
case Opcode::TBUFFER_LOAD_FORMAT_X:
translator.BUFFER_LOAD_FORMAT(1, true, inst);
translator.BUFFER_LOAD_FORMAT(1, true, true, inst);
break;
case Opcode::TBUFFER_LOAD_FORMAT_XY:
translator.BUFFER_LOAD_FORMAT(2, true, inst);
translator.BUFFER_LOAD_FORMAT(2, true, true, inst);
break;
case Opcode::TBUFFER_LOAD_FORMAT_XYZ:
translator.BUFFER_LOAD_FORMAT(3, true, inst);
translator.BUFFER_LOAD_FORMAT(3, true, true, inst);
break;
case Opcode::TBUFFER_LOAD_FORMAT_XYZW:
translator.BUFFER_LOAD_FORMAT(4, true, inst);
translator.BUFFER_LOAD_FORMAT(4, true, true, inst);
break;
case Opcode::BUFFER_LOAD_FORMAT_X:
case Opcode::BUFFER_LOAD_DWORD:
translator.BUFFER_LOAD_FORMAT(1, false, inst);
translator.BUFFER_LOAD_FORMAT(1, false, true, inst);
break;
case Opcode::BUFFER_LOAD_FORMAT_XY:
translator.BUFFER_LOAD_FORMAT(2, false, true, inst);
break;
case Opcode::BUFFER_LOAD_FORMAT_XYZ:
case Opcode::BUFFER_LOAD_DWORDX3:
translator.BUFFER_LOAD_FORMAT(3, false, inst);
translator.BUFFER_LOAD_FORMAT(3, false, true, inst);
break;
case Opcode::BUFFER_LOAD_FORMAT_XYZW:
translator.BUFFER_LOAD_FORMAT(4, false, true, inst);
break;
case Opcode::BUFFER_LOAD_DWORD:
translator.BUFFER_LOAD_FORMAT(1, false, false, inst);
break;
case Opcode::BUFFER_LOAD_DWORDX2:
translator.BUFFER_LOAD_FORMAT(2, false, false, inst);
break;
case Opcode::BUFFER_LOAD_DWORDX3:
translator.BUFFER_LOAD_FORMAT(3, false, false, inst);
break;
case Opcode::BUFFER_LOAD_DWORDX4:
translator.BUFFER_LOAD_FORMAT(4, false, inst);
translator.BUFFER_LOAD_FORMAT(4, false, false, inst);
break;
case Opcode::BUFFER_STORE_FORMAT_X:
case Opcode::BUFFER_STORE_DWORD:
Expand Down
2 changes: 1 addition & 1 deletion src/shader_recompiler/frontend/translate/translate.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class Translator {
void V_CMP_CLASS_F32(const GcnInst& inst);

// Vector Memory
void BUFFER_LOAD_FORMAT(u32 num_dwords, bool is_typed, const GcnInst& inst);
void BUFFER_LOAD_FORMAT(u32 num_dwords, bool is_typed, bool is_format, const GcnInst& inst);
void BUFFER_STORE_FORMAT(u32 num_dwords, bool is_typed, const GcnInst& inst);

// Vector interpolation
Expand Down
6 changes: 4 additions & 2 deletions src/shader_recompiler/frontend/translate/vector_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ void Translator::IMAGE_STORE(const GcnInst& inst) {
ir.ImageWrite(handle, body, value, {});
}

void Translator::BUFFER_LOAD_FORMAT(u32 num_dwords, bool is_typed, const GcnInst& inst) {
void Translator::BUFFER_LOAD_FORMAT(u32 num_dwords, bool is_typed, bool is_format,
const GcnInst& inst) {
const auto& mtbuf = inst.control.mtbuf;
const IR::VectorReg vaddr{inst.src[0].code};
const IR::ScalarReg sharp{inst.src[2].code * 4};
Expand Down Expand Up @@ -254,7 +255,8 @@ void Translator::BUFFER_LOAD_FORMAT(u32 num_dwords, bool is_typed, const GcnInst
const IR::Value handle =
ir.CompositeConstruct(ir.GetScalarReg(sharp), ir.GetScalarReg(sharp + 1),
ir.GetScalarReg(sharp + 2), ir.GetScalarReg(sharp + 3));
const IR::Value value = ir.LoadBuffer(num_dwords, handle, address, info);
const IR::Value value = is_format ? ir.LoadBufferFormat(num_dwords, handle, address, info)
: ir.LoadBuffer(num_dwords, handle, address, info);
const IR::VectorReg dst_reg{inst.src[1].code};
if (num_dwords == 1) {
ir.SetVectorReg(dst_reg, IR::F32{value});
Expand Down
Loading

0 comments on commit f9e9679

Please sign in to comment.