Skip to content

Commit

Permalink
Day 10: Draw indirect!
Browse files Browse the repository at this point in the history
We've converted both the traditional rasterization pipeline and the mesh shading
pipeline to indirect draws, which will be helpful later for GPU-side mesh
culling, etc.
  • Loading branch information
zeux committed Nov 4, 2018
1 parent f98da57 commit fda3d87
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 38 deletions.
77 changes: 49 additions & 28 deletions src/niagara.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#define VSYNC 1

bool meshShadingEnabled = false;
bool meshShadingEnabled = true;

VkInstance createInstance()
{
Expand Down Expand Up @@ -200,6 +200,7 @@ VkDevice createDevice(VkInstance instance, VkPhysicalDevice physicalDevice, uint
extensions.push_back(VK_NV_MESH_SHADER_EXTENSION_NAME);

VkPhysicalDeviceFeatures2 features = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2 };
features.features.multiDrawIndirect = true;

VkPhysicalDevice16BitStorageFeatures features16 = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES };
features16.storageBuffer16BitAccess = true;
Expand Down Expand Up @@ -518,12 +519,27 @@ struct alignas(16) Meshlet
uint8_t triangleCount;
};

struct alignas(16) MeshDraw
// Data that is the same for every draw in a frame. It is sent to the shaders
// as the push-constant block: the programs are created with sizeof(Globals)
// as their push-constant size and the frame loop pushes one Globals value
// before issuing the indirect draw. The shader-side counterpart is the
// `Globals` struct in shaders/mesh.h.
struct alignas(16) Globals
{
// Projection matrix; the frame loop fills it from perspectiveProjection().
glm::mat4 projection;
};

// Per-draw record. An array of these is uploaded into one buffer that is
// bound both as a storage buffer (shaders index it by draw id to fetch the
// transform) and as the indirect-argument buffer (the indirect draw calls
// use offsetof(MeshDraw, ...) as the starting offset and sizeof(MeshDraw)
// as the stride). The shader-side counterpart is `MeshDraw` in
// shaders/mesh.h, which sees the command words only as `uint commandData[7]`.
struct alignas(16) MeshDraw
{
// Transform applied in the shaders as: rotate(v, orientation) * scale + position.
glm::vec3 position;
float scale;
glm::quat orientation;

// The same 7 words viewed two ways: as raw uint32_t (zeroed with memset
// before the named fields are written) and as the two indirect commands.
// NOTE(review): an anonymous struct inside a union is a compiler
// extension rather than standard C++ -- confirm the supported toolchains.
union
{
uint32_t commandData[7];

struct
{
// Arguments read by vkCmdDrawIndexedIndirect (traditional pipeline).
VkDrawIndexedIndirectCommand commandIndirect; // 5 uint32_t
// Arguments read by vkCmdDrawMeshTasksIndirectNV (mesh shading pipeline).
VkDrawMeshTasksIndirectCommandNV commandIndirectMS; // 2 uint32_t
};
};
};

struct Vertex
Expand Down Expand Up @@ -900,11 +916,11 @@ int main(int argc, const char** argv)
// TODO: this is critical for performance!
VkPipelineCache pipelineCache = 0;

Program meshProgram = createProgram(device, VK_PIPELINE_BIND_POINT_GRAPHICS, { &meshVS, &meshFS }, sizeof(MeshDraw));
Program meshProgram = createProgram(device, VK_PIPELINE_BIND_POINT_GRAPHICS, { &meshVS, &meshFS }, sizeof(Globals));

Program meshProgramMS = {};
if (meshShadingSupported)
meshProgramMS = createProgram(device, VK_PIPELINE_BIND_POINT_GRAPHICS, { &meshletTS, &meshletMS, &meshFS }, sizeof(MeshDraw));
meshProgramMS = createProgram(device, VK_PIPELINE_BIND_POINT_GRAPHICS, { &meshletTS, &meshletMS, &meshFS }, sizeof(Globals));

VkPipeline meshPipeline = createGraphicsPipeline(device, pipelineCache, renderPass, { &meshVS, &meshFS }, meshProgram.layout);
assert(meshPipeline);
Expand Down Expand Up @@ -971,13 +987,6 @@ int main(int argc, const char** argv)
uploadBuffer(device, commandPool, commandBuffer, queue, mdb, scratch, mesh.meshletdata.data(), mesh.meshletdata.size() * sizeof(uint32_t));
}

Image colorTarget = {};
Image depthTarget = {};
VkFramebuffer targetFB = 0;

double frameCpuAvg = 0;
double frameGpuAvg = 0;

uint32_t drawCount = 3000;
std::vector<MeshDraw> draws(drawCount);

Expand All @@ -994,8 +1003,25 @@ int main(int argc, const char** argv)
float angle = glm::radians((float(rand()) / RAND_MAX) * 90.f);

draws[i].orientation = glm::rotate(glm::quat(1, 0, 0, 0), angle, axis);

memset(draws[i].commandData, 0, sizeof(draws[i].commandData));
draws[i].commandIndirect.indexCount = uint32_t(mesh.indices.size());
draws[i].commandIndirect.instanceCount = 1;
draws[i].commandIndirectMS.taskCount = uint32_t(mesh.meshlets.size() / 32);
}

Buffer db = {};
createBuffer(db, device, memoryProperties, 128 * 1024 * 1024, VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);

uploadBuffer(device, commandPool, commandBuffer, queue, db, scratch, draws.data(), draws.size() * sizeof(MeshDraw));

Image colorTarget = {};
Image depthTarget = {};
VkFramebuffer targetFB = 0;

double frameCpuAvg = 0;
double frameGpuAvg = 0;

while (!glfwWindowShouldClose(window))
{
double frameCpuBegin = glfwGetTime() * 1000;
Expand Down Expand Up @@ -1057,38 +1083,30 @@ int main(int argc, const char** argv)
vkCmdSetViewport(commandBuffer, 0, 1, &viewport);
vkCmdSetScissor(commandBuffer, 0, 1, &scissor);

glm::mat4x4 projection = perspectiveProjection(glm::radians(70.f), float(swapchain.width) / float(swapchain.height), 0.01f);

for (uint32_t i = 0; i < drawCount; ++i)
draws[i].projection = projection;
Globals globals = {};
globals.projection = perspectiveProjection(glm::radians(70.f), float(swapchain.width) / float(swapchain.height), 0.01f);

if (meshShadingSupported && meshShadingEnabled)
{
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, meshPipelineMS);

DescriptorInfo descriptors[] = { vb.buffer, mb.buffer, mdb.buffer };
DescriptorInfo descriptors[] = { db.buffer, mb.buffer, mdb.buffer, vb.buffer };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, meshProgramMS.updateTemplate, meshProgramMS.layout, 0, descriptors);

for (auto& draw : draws)
{
vkCmdPushConstants(commandBuffer, meshProgramMS.layout, meshProgramMS.pushConstantStages, 0, sizeof(draw), &draw);
vkCmdDrawMeshTasksNV(commandBuffer, uint32_t(mesh.meshlets.size()) / 32, 0);
}
vkCmdPushConstants(commandBuffer, meshProgramMS.layout, meshProgramMS.pushConstantStages, 0, sizeof(globals), &globals);
vkCmdDrawMeshTasksIndirectNV(commandBuffer, db.buffer, offsetof(MeshDraw, commandIndirectMS), uint32_t(draws.size()), sizeof(MeshDraw));
}
else
{
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, meshPipeline);

DescriptorInfo descriptors[] = { vb.buffer };
DescriptorInfo descriptors[] = { db.buffer, vb.buffer };
vkCmdPushDescriptorSetWithTemplateKHR(commandBuffer, meshProgram.updateTemplate, meshProgram.layout, 0, descriptors);

vkCmdBindIndexBuffer(commandBuffer, ib.buffer, 0, VK_INDEX_TYPE_UINT32);

for (auto& draw : draws)
{
vkCmdPushConstants(commandBuffer, meshProgram.layout, meshProgram.pushConstantStages, 0, sizeof(draw), &draw);
vkCmdDrawIndexed(commandBuffer, uint32_t(mesh.indices.size()), 1, 0, 0, 0);
}
vkCmdPushConstants(commandBuffer, meshProgram.layout, meshProgram.pushConstantStages, 0, sizeof(globals), &globals);
vkCmdDrawIndexedIndirect(commandBuffer, db.buffer, offsetof(MeshDraw, commandIndirect), uint32_t(draws.size()), sizeof(MeshDraw));
}

vkCmdEndRenderPass(commandBuffer);
Expand Down Expand Up @@ -1153,9 +1171,10 @@ int main(int argc, const char** argv)
frameGpuAvg = frameGpuAvg * 0.95 + (frameGpuEnd - frameGpuBegin) * 0.05;

double trianglesPerSec = double(drawCount) * double(mesh.indices.size() / 3) / double(frameGpuAvg * 1e-3);
double kittensPerSec = double(drawCount) / double(frameGpuAvg * 1e-3);

char title[256];
sprintf(title, "cpu: %.2f ms; gpu: %.2f ms; triangles %d; meshlets %d; mesh shading %s; %.1fB tri/sec", frameCpuAvg, frameGpuAvg, int(mesh.indices.size() / 3), int(mesh.meshlets.size()), meshShadingSupported && meshShadingEnabled ? "ON" : "OFF", trianglesPerSec * 1e-9);
sprintf(title, "cpu: %.2f ms; gpu: %.2f ms; triangles %d; meshlets %d; mesh shading %s; %.1fB tri/sec, %.1fM kittens/sec", frameCpuAvg, frameGpuAvg, int(mesh.indices.size() / 3), int(mesh.meshlets.size()), meshShadingSupported && meshShadingEnabled ? "ON" : "OFF", trianglesPerSec * 1e-9, kittensPerSec * 1e-6);
glfwSetWindowTitle(window, title);
}

Expand All @@ -1168,6 +1187,8 @@ int main(int argc, const char** argv)
if (targetFB)
vkDestroyFramebuffer(device, targetFB, 0);

destroyBuffer(db, device);

if (meshShadingSupported)
{
destroyBuffer(mb, device);
Expand Down
8 changes: 7 additions & 1 deletion src/shaders/mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ struct Meshlet
uint8_t triangleCount;
};

struct MeshDraw
// Per-frame data shared by all draws; the vertex and mesh shaders declare it
// inside their push_constant block. Mirrors the C++ `Globals` struct.
struct Globals
{
// Projection matrix applied to the transformed vertex position.
mat4 projection;
};

// One element of the `Draws` storage buffer; shaders fetch the record for the
// current draw with draws[gl_DrawIDARB]. Mirrors the C++ `MeshDraw` struct.
struct MeshDraw
{
// Transform: rotateQuat(v, orientation) * scale + position.
vec3 position;
float scale;
vec4 orientation;

// Indirect draw command words written by the CPU (5 for the indexed draw
// followed by 2 for the mesh-task draw on the C++ side). The shaders shown
// in this change do not read them; the array keeps the element layout in
// step with the C++ struct.
uint commandData[7];
};

vec3 rotateQuat(vec3 v, vec4 q)
Expand Down
15 changes: 12 additions & 3 deletions src/shaders/mesh.vert.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@

#extension GL_GOOGLE_include_directive: require

#extension GL_ARB_shader_draw_parameters: require

#include "mesh.h"

layout(push_constant) uniform block
{
MeshDraw meshDraw;
Globals globals;
};

layout(binding = 0) readonly buffer Vertices
layout(binding = 0) readonly buffer Draws
{
MeshDraw draws[];
};

layout(binding = 1) readonly buffer Vertices
{
Vertex vertices[];
};
Expand All @@ -21,11 +28,13 @@ layout(location = 0) out vec4 color;

void main()
{
MeshDraw meshDraw = draws[gl_DrawIDARB];

vec3 position = vec3(vertices[gl_VertexIndex].vx, vertices[gl_VertexIndex].vy, vertices[gl_VertexIndex].vz);
vec3 normal = vec3(int(vertices[gl_VertexIndex].nx), int(vertices[gl_VertexIndex].ny), int(vertices[gl_VertexIndex].nz)) / 127.0 - 1.0;
vec2 texcoord = vec2(vertices[gl_VertexIndex].tu, vertices[gl_VertexIndex].tv);

gl_Position = meshDraw.projection * vec4(rotateQuat(position, meshDraw.orientation) * meshDraw.scale + meshDraw.position, 1);
gl_Position = globals.projection * vec4(rotateQuat(position, meshDraw.orientation) * meshDraw.scale + meshDraw.position, 1);

color = vec4(normal * 0.5 + vec3(0.5), 1.0);
}
17 changes: 13 additions & 4 deletions src/shaders/meshlet.mesh.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#extension GL_GOOGLE_include_directive: require

#extension GL_ARB_shader_draw_parameters: require

#include "mesh.h"

#define DEBUG 0
Expand All @@ -15,12 +17,12 @@ layout(triangles, max_vertices = 64, max_primitives = 124) out;

layout(push_constant) uniform block
{
MeshDraw meshDraw;
Globals globals;
};

layout(binding = 0) readonly buffer Vertices
layout(binding = 0) readonly buffer Draws
{
Vertex vertices[];
MeshDraw draws[];
};

layout(binding = 1) readonly buffer Meshlets
Expand All @@ -33,6 +35,11 @@ layout(binding = 2) readonly buffer MeshletData
uint meshletData[];
};

layout(binding = 3) readonly buffer Vertices
{
Vertex vertices[];
};

in taskNV block
{
uint meshletIndices[32];
Expand All @@ -56,6 +63,8 @@ void main()
uint ti = gl_LocalInvocationID.x;
uint mi = meshletIndices[gl_WorkGroupID.x];

MeshDraw meshDraw = draws[gl_DrawIDARB];

uint vertexCount = uint(meshlets[mi].vertexCount);
uint triangleCount = uint(meshlets[mi].triangleCount);
uint indexCount = triangleCount * 3;
Expand All @@ -78,7 +87,7 @@ void main()
vec3 normal = vec3(int(vertices[vi].nx), int(vertices[vi].ny), int(vertices[vi].nz)) / 127.0 - 1.0;
vec2 texcoord = vec2(vertices[vi].tu, vertices[vi].tv);

gl_MeshVerticesNV[i].gl_Position = meshDraw.projection * vec4(rotateQuat(position, meshDraw.orientation) * meshDraw.scale + meshDraw.position, 1);
gl_MeshVerticesNV[i].gl_Position = globals.projection * vec4(rotateQuat(position, meshDraw.orientation) * meshDraw.scale + meshDraw.position, 1);
color[i] = vec4(normal * 0.5 + vec3(0.5), 1.0);

#if DEBUG
Expand Down
8 changes: 6 additions & 2 deletions src/shaders/meshlet.task.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@

#extension GL_KHR_shader_subgroup_ballot: require

#extension GL_ARB_shader_draw_parameters: require

#include "mesh.h"

#define CULL 1

layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform block
layout(binding = 0) readonly buffer Draws
{
MeshDraw meshDraw;
MeshDraw draws[];
};

layout(binding = 1) readonly buffer Meshlets
Expand All @@ -42,6 +44,8 @@ void main()
uint mgi = gl_WorkGroupID.x;
uint mi = mgi * 32 + ti;

MeshDraw meshDraw = draws[gl_DrawIDARB];

#if CULL
vec3 center = rotateQuat(meshlets[mi].center, meshDraw.orientation) * meshDraw.scale + meshDraw.position;
float radius = meshlets[mi].radius * meshDraw.scale;
Expand Down

0 comments on commit fda3d87

Please sign in to comment.