Skip to content

Commit

Permalink
Issue #16 #33: Add compute shader support to D3D12
Browse files Browse the repository at this point in the history
Changes:
add: generic compute pipeline
add: compute pass support

Modules:
D3D12RenderSystem
  • Loading branch information
vasumahesh1 committed Oct 23, 2018
1 parent 68e1612 commit b314aab
Show file tree
Hide file tree
Showing 33 changed files with 2,422 additions and 520 deletions.
2 changes: 2 additions & 0 deletions Source/Azura/RenderSystem/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(

# Headers
"Inc/Generic/Constants.h"
"Inc/Generic/ComputePool.h"
"Inc/Generic/Drawable.h"
"Inc/Generic/GenericTypes.h"
"Inc/Generic/GLFWWindow.h"
Expand All @@ -18,6 +19,7 @@ set(
"Inc/Generic/Windows/Win32GLFWWindow.h"

# Sources
"Src/Generic/ComputePool.cpp"
"Src/Generic/Drawable.cpp"
"Src/Generic/GenericTypes.cpp"
"Src/Generic/GLFWWindow.cpp"
Expand Down
12 changes: 8 additions & 4 deletions Source/Azura/RenderSystem/D3D12.RenderSystem.cmake
Original file line number Diff line number Diff line change
@@ -1,37 +1,41 @@
set(D3D12_SOURCES
# Include
"Inc/D3D12/D3D12ComputePool.h"
"Inc/D3D12/D3D12Core.h"
"Inc/D3D12/D3D12Drawable.h"
"Inc/D3D12/D3D12DrawablePool.h"
"Inc/D3D12/D3D12Macros.h"
"Inc/D3D12/D3D12Renderer.h"
"Inc/D3D12/D3D12ScopedBuffer.h"
"Inc/D3D12/D3D12ScopedCommandBuffer.h"
"Inc/D3D12/D3D12ScopedSwapChain.h"
"Inc/D3D12/D3D12ScopedComputePass.h"
"Inc/D3D12/D3D12ScopedImage.h"
"Inc/D3D12/D3D12ScopedPipeline.h"
"Inc/D3D12/D3D12ScopedRenderPass.h"
"Inc/D3D12/D3D12ScopedSampler.h"
"Inc/D3D12/D3D12ScopedShader.h"
"Inc/D3D12/D3D12ScopedRenderPass.h"
"Inc/D3D12/D3D12ScopedSwapChain.h"
"Inc/D3D12/D3D12TextureManager.h"
"Inc/D3D12/D3D12TypeMapping.h"

"Inc/D3D12/d3dx12.h"

# Source
"Src/D3D12/D3D12ComputePool.cpp"
"Src/D3D12/D3D12Core.cpp"
"Src/D3D12/D3D12Drawable.cpp"
"Src/D3D12/D3D12DrawablePool.cpp"
"Src/D3D12/D3D12Renderer.cpp"
"Src/D3D12/D3D12RenderSystem.cpp"
"Src/D3D12/D3D12ScopedBuffer.cpp"
"Src/D3D12/D3D12ScopedCommandBuffer.cpp"
"Src/D3D12/D3D12ScopedSwapChain.cpp"
"Src/D3D12/D3D12ScopedComputePass.cpp"
"Src/D3D12/D3D12ScopedImage.cpp"
"Src/D3D12/D3D12ScopedPipeline.cpp"
"Src/D3D12/D3D12ScopedRenderPass.cpp"
"Src/D3D12/D3D12ScopedSampler.cpp"
"Src/D3D12/D3D12ScopedShader.cpp"
"Src/D3D12/D3D12ScopedRenderPass.cpp"
"Src/D3D12/D3D12ScopedSwapChain.cpp"
"Src/D3D12/D3D12TextureManager.cpp"
"Src/D3D12/D3D12TypeMapping.cpp"
)
Expand Down
104 changes: 104 additions & 0 deletions Source/Azura/RenderSystem/Inc/D3D12/D3D12ComputePool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#pragma once
#include "Generic/ComputePool.h"
#include "Log/Log.h"

#include "D3D12/D3D12Core.h"
#include "D3D12/D3D12ScopedPipeline.h"
#include "D3D12/D3D12Drawable.h"
#include "D3D12/D3D12ScopedBuffer.h"
#include "D3D12/D3D12ScopedCommandBuffer.h"
#include "D3D12/D3D12ScopedImage.h"
#include "D3D12/D3D12ScopedSampler.h"
#include "D3D12/D3D12ScopedComputePass.h"


namespace Azura {
namespace D3D12 {

struct D3D12ComputePassRecordEntry
{
ID3D12PipelineState* m_pso;
ID3D12GraphicsCommandList* m_bundle;
U32 m_poolIdx;
};

class D3D12ComputePool : public ComputePool {

public:
D3D12ComputePool(const Microsoft::WRL::ComPtr<ID3D12Device>& device,
const ComputePoolCreateInfo& createInfo,
const DescriptorCount& descriptorCount,
const Containers::Vector<DescriptorSlot>& descriptorSlots,
const Containers::Vector<D3D12ScopedShader>& shaders,
const Containers::Vector<D3D12ScopedComputePass>& renderPasses,
Microsoft::WRL::ComPtr<ID3D12CommandQueue> commandQueue,
Memory::Allocator& mainAllocator,
Memory::Allocator& initAllocator,
Log log);

void AddShader(U32 shaderId) override;
void BindTextureData(SlotID slot, const TextureDesc& desc, const U8* buffer) override;
void BindSampler(SlotID slot, const SamplerDesc& desc) override;
void BindUniformData(SlotID slot, const U8* buffer, U32 size) override;
void Submit() override;
void Record();

void BeginUpdates() override;
void UpdateUniformData(SlotID slot, const U8* buffer, U32 size) override;
void UpdateTextureData(SlotID slot, const U8* buffer) override;
void SubmitUpdates() override;

const Containers::Vector<ID3D12DescriptorHeap*>& GetAllDescriptorHeaps() const;
ID3D12PipelineState* GetPipelineState(U32 renderPassId) const;

void GetRecordEntries(Containers::Vector<std::pair<U32, D3D12ComputePassRecordEntry>>& recordList) const;

ID3D12GraphicsCommandList* GetSecondaryCommandList(U32 renderPassId) const;

private:
void CreateRenderPassReferences(const ComputePoolCreateInfo& createInfo, const Containers::Vector<D3D12ScopedComputePass>& renderPasses);
void CreateDescriptorHeap();

void SetTextureData(ID3D12GraphicsCommandList* oneTimeCommandList);
void SetSamplerData();

void CreateComputePassInputTargetSRV(const Containers::Vector<std::reference_wrapper<D3D12ScopedImage>>& renderPassInputs, U32 offsetTillThis) const;
void CreateComputePassInputTargetUAV(
const Containers::Vector<std::reference_wrapper<D3D12ScopedImage>>& computePassOutputs,
U32 offsetTillThis) const;

Log log_D3D12RenderSystem;

const Microsoft::WRL::ComPtr<ID3D12Device>& m_device;
const Containers::Vector<DescriptorSlot>& m_globalDescriptorSlots;
const Containers::Vector<D3D12ScopedShader>& m_shaders;

Containers::Vector<D3D12ScopedPipeline> m_pipelines;

Containers::Vector<std::reference_wrapper<D3D12ScopedComputePass>> m_computePasses;

Microsoft::WRL::ComPtr<ID3D12CommandQueue> m_graphicsCommandQueue;

D3D12PipelineFactory m_pipelineFactory;
D3D12ScopedBuffer m_updateBuffer;
D3D12ScopedBuffer m_stagingBuffer;
D3D12ScopedBuffer m_mainBuffer;

U32 m_cbvSrvDescriptorElementSize{0};
U32 m_samplerDescriptorElementSize{0};
U32 m_offsetToConstantBuffers{0};
U32 m_offsetToComputePassInputs{0};
U32 m_offsetToComputePassOutputs{0};

Containers::Vector<D3D12ScopedImage> m_images;
Containers::Vector<D3D12ScopedSampler> m_samplers;
Containers::Vector<D3D12ScopedCommandBuffer> m_secondaryCommandBuffers;

Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorComputeHeap;
Microsoft::WRL::ComPtr<ID3D12DescriptorHeap> m_descriptorSamplerHeap;

Containers::Vector<ID3D12DescriptorHeap*> m_allHeaps;
};

} // namespace D3D12
} // namespace Azura
14 changes: 12 additions & 2 deletions Source/Azura/RenderSystem/Inc/D3D12/D3D12Renderer.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
#include "Memory/HeapMemoryBuffer.h"
#include "D3D12/D3D12ScopedCommandBuffer.h"
#include "D3D12/D3D12ScopedSwapChain.h"
#include "D3D12ScopedRenderPass.h"
#include "D3D12/D3D12ScopedRenderPass.h"
#include "D3D12/D3D12ScopedComputePass.h"
#include "D3D12/D3D12ComputePool.h"


namespace Azura {
Expand All @@ -32,7 +34,10 @@ class D3D12Renderer : public Renderer {
Window& window);

String GetRenderingAPI() const override;

DrawablePool& CreateDrawablePool(const DrawablePoolCreateInfo& createInfo) override;
ComputePool& CreateComputePool(const ComputePoolCreateInfo& createInfo) override;

void Submit() override;
void RenderFrame() override;

Expand All @@ -53,7 +58,8 @@ class D3D12Renderer : public Renderer {
Microsoft::WRL::ComPtr<ID3D12Device> m_device;

Microsoft::WRL::ComPtr<ID3D12CommandAllocator> m_commandAllocator;
Microsoft::WRL::ComPtr<ID3D12CommandQueue> m_commandQueue;
Microsoft::WRL::ComPtr<ID3D12CommandQueue> m_mainGraphicsCommandQueue;
Microsoft::WRL::ComPtr<ID3D12CommandQueue> m_mainComputeCommandQueue;

UINT m_rtvDescriptorSize;
UINT m_dsvDescriptorSize;
Expand All @@ -63,10 +69,14 @@ class D3D12Renderer : public Renderer {

D3D12ScopedImage m_depthTexture{};

Containers::Vector<std::pair<U32, RenderPassType>> m_renderSequence;

Containers::Vector<D3D12ScopedImage> m_renderTargetImages;
Containers::Vector<D3D12ScopedRenderPass> m_renderPasses;
Containers::Vector<D3D12ScopedComputePass> m_computePasses;

Containers::Vector<D3D12DrawablePool> m_drawablePools;
Containers::Vector<D3D12ComputePool> m_computePools;
Containers::Vector<D3D12ScopedShader> m_shaders;
};

Expand Down
99 changes: 99 additions & 0 deletions Source/Azura/RenderSystem/Inc/D3D12/D3D12ScopedComputePass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#pragma once
#include "Types.h"

#include "Log/Log.h"
#include "D3D12/D3D12Core.h"
#include "D3D12/D3D12ScopedImage.h"
#include "D3D12/D3D12ScopedShader.h"
#include "D3D12/D3D12ScopedCommandBuffer.h"
#include "D3D12/D3D12ScopedSwapChain.h"

namespace Azura {
namespace Memory {
class Allocator;
}
}

namespace Azura {
namespace D3D12 {

class D3D12ScopedComputePass {
public:
D3D12ScopedComputePass(U32 idx, U32 internalId, Memory::Allocator& mainAllocator, Log logger);

void Create(const Microsoft::WRL::ComPtr<ID3D12Device>& device,
const PipelinePassCreateInfo& createInfo,
const Containers::Vector<RenderTargetCreateInfo>& pipelineBuffers,
const Containers::Vector<D3D12ScopedImage>& pipelineBufferImages,
const Containers::Vector<DescriptorSlot>& descriptorSlots,
const Containers::Vector<DescriptorTableEntry>& descriptorSetTable,
const Containers::Vector<D3D12ScopedShader>& allShaders);

U32 GetId() const;
U32 GetInternalId() const;

ID3D12GraphicsCommandList* GetPrimaryComputeCommandList(U32 idx) const;

const DescriptorCount& GetDescriptorCount() const;

ID3D12RootSignature* GetRootSignature() const;

const Containers::Vector<std::reference_wrapper<const D3D12ScopedShader>>& GetShaders() const;
const Containers::Vector<std::reference_wrapper<D3D12ScopedImage>>& GetInputImages() const;
const Containers::Vector<std::reference_wrapper<D3D12ScopedImage>>& GetOutputImages() const;

const Containers::Vector<DescriptorTableEntry>& GetRootSignatureTable() const;

void RecordResourceBarriersForOutputsStart(ID3D12GraphicsCommandList* commandList) const;
void RecordResourceBarriersForOutputsEnd(ID3D12GraphicsCommandList* commandList) const;
void RecordResourceBarriersForInputsStart(ID3D12GraphicsCommandList* commandList) const;
void RecordResourceBarriersForInputsEnd(ID3D12GraphicsCommandList* commandList) const;

void WaitForGPU(ID3D12CommandQueue* commandQueue);

U32 GetCommandBufferCount() const;

U32 GetInputRootDescriptorTableId() const;
U32 GetOutputRootDescriptorTableId() const;

private:
void CreateBase(
const Microsoft::WRL::ComPtr<ID3D12Device>& device,
const PipelinePassCreateInfo& createInfo,
const Containers::Vector<DescriptorSlot>& descriptorSlots,
const Containers::Vector<DescriptorTableEntry>& descriptorSetTable,
const Containers::Vector<RenderTargetCreateInfo>& pipelineBuffers,
const Containers::Vector<D3D12ScopedImage>& pipelineBufferImages,
const Containers::Vector<D3D12ScopedShader>& allShaders);

const Log log_D3D12RenderSystem;
U32 m_id;
U32 m_internalId;

U32 m_computeOutputTableIdx{ 0 };
U32 m_computeInputTableIdx{ 0 };

Containers::Vector<DescriptorTableEntry> m_rootSignatureTable;

Containers::Vector<std::reference_wrapper<const D3D12ScopedShader>> m_passShaders;

SmallVector<D3D12ScopedCommandBuffer, DEFAULT_FRAMES_IN_FLIGHT> m_commandBuffers;

Containers::Vector<std::reference_wrapper<D3D12ScopedImage>> m_computeOutputs;

Containers::Vector<std::reference_wrapper<D3D12ScopedImage>> m_computeInputs;
Containers::Vector<std::reference_wrapper<D3D12ScopedImage>> m_computeDepthInputs;
Containers::Vector<std::reference_wrapper<D3D12ScopedImage>> m_allComputeInputs;

Microsoft::WRL::ComPtr<ID3D12RootSignature> m_rootSignature;

DescriptorCount m_descriptorCount{};

U32 m_fenceValue{1};
HANDLE m_fenceEvent{};
Microsoft::WRL::ComPtr<ID3D12Fence> m_signalFence;
SmallVector<Microsoft::WRL::ComPtr<ID3D12Fence>, MAX_RENDER_PASS_INPUTS> m_waitFences;
};

} // namespace D3D12
} // namespace Azura
4 changes: 3 additions & 1 deletion Source/Azura/RenderSystem/Inc/D3D12/D3D12ScopedImage.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class D3D12ScopedImage {
ImageViewType imageView,
const Log& log_D3D12RenderSystem);

static D3D12_UNORDERED_ACCESS_VIEW_DESC GetUAV(RawStorageFormat viewFormat, ImageViewType imageView, const Log & log_D3D12RenderSystem);

static D3D12_DEPTH_STENCIL_VIEW_DESC GetDSV(RawStorageFormat viewFormat, ImageViewType imageView, const Log & log_D3D12RenderSystem);

static D3D12_RENDER_TARGET_VIEW_DESC GetRTV(RawStorageFormat viewFormat, ImageViewType imageView, const Log & log_D3D12RenderSystem);
Expand All @@ -38,7 +40,7 @@ class D3D12ScopedImage {
D3D12_RESOURCE_STATES fromState,
D3D12_RESOURCE_STATES toState) const;

void Transition(ID3D12GraphicsCommandList * commandList, D3D12_RESOURCE_STATES toState) const;
void Transition(ID3D12GraphicsCommandList * commandList, D3D12_RESOURCE_STATES toState);

private:
Microsoft::WRL::ComPtr<ID3D12Resource> m_texture;
Expand Down
12 changes: 12 additions & 0 deletions Source/Azura/RenderSystem/Inc/D3D12/D3D12ScopedPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "Memory/Allocator.h"
#include "D3D12/D3D12ScopedShader.h"
#include "D3D12/D3D12ScopedRenderPass.h"
#include "D3D12/D3D12ScopedComputePass.h"
#include <optional>

namespace Azura {
Expand All @@ -15,16 +16,22 @@ namespace D3D12 {
class D3D12ScopedPipeline {
public:
D3D12ScopedPipeline(const Microsoft::WRL::ComPtr<ID3D12Device>& device, D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc, const Log& log);
D3D12ScopedPipeline(const Microsoft::WRL::ComPtr<ID3D12Device>& device, D3D12_COMPUTE_PIPELINE_STATE_DESC psoDesc, const Log& log);
ID3D12PipelineState* GetState() const;

PipelineType GetType() const;

private:
PipelineType m_type;
Microsoft::WRL::ComPtr<ID3D12PipelineState> m_pipeline;
};

class D3D12PipelineFactory {
public:
D3D12PipelineFactory(Memory::Allocator& allocator, Log logger);

D3D12PipelineFactory & SetPipelineType(PipelineType type);

D3D12PipelineFactory& BulkAddAttributeDescription(const VertexSlot& vertexSlot, U32 binding);

D3D12PipelineFactory & SetRasterizerStage(CullMode cullMode, FrontFace faceOrder);
Expand All @@ -33,9 +40,13 @@ namespace D3D12 {

void Submit(const Microsoft::WRL::ComPtr<ID3D12Device>& device, const Containers::Vector<std::reference_wrapper<D3D12ScopedRenderPass>>& renderPasses, Containers::Vector<D3D12ScopedPipeline>& resultPipelines) const;

void Submit(const Microsoft::WRL::ComPtr<ID3D12Device>& device, const Containers::Vector<std::reference_wrapper<D3D12ScopedComputePass>>& computePasses, Containers::Vector<D3D12ScopedPipeline>& resultPipelines) const;

private:
const Log log_D3D12RenderSystem;

PipelineType m_type;

struct BindingInfo {
U32 m_offset{0};
};
Expand All @@ -48,6 +59,7 @@ namespace D3D12 {
Containers::Vector<D3D12_INPUT_ELEMENT_DESC> m_inputElementDescs;
std::optional<D3D12_SHADER_BYTECODE> m_vertexShaderModule;
std::optional<D3D12_SHADER_BYTECODE> m_pixelShaderModule;
std::optional<D3D12_SHADER_BYTECODE> m_computeShaderModule;
};


Expand Down
Loading

0 comments on commit b314aab

Please sign in to comment.