Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions vortex-cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cudarc = { workspace = true, features = ["f16"] }
futures = { workspace = true, features = ["executor"] }
itertools = { workspace = true }
kanal = { workspace = true }
num-traits = { workspace = true }
object_store = { workspace = true, features = ["fs"] }
parking_lot = { workspace = true }
prost = { workspace = true }
Expand Down Expand Up @@ -89,7 +90,3 @@ harness = false
[[bench]]
name = "throughput_cuda"
harness = false

[[bench]]
name = "transpose_patches"
harness = false
81 changes: 0 additions & 81 deletions vortex-cuda/benches/transpose_patches.rs

This file was deleted.

19 changes: 13 additions & 6 deletions vortex-cuda/kernels/src/bit_unpack_16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@
/// Unpack one 1024-element chunk of BW-bit-packed uint16 values into `out`,
/// applying sparse patches along the way.
///
/// Expected launch shape: a single warp (32 threads) per chunk; `thread_idx`
/// is the lane id in [0, 32). `__syncwarp()` is sufficient only under that
/// assumption — NOTE(review): confirm the generated launcher never uses more
/// than one warp per block.
///
/// Shared memory: 1024 * sizeof(uint16_t) = 2 KiB statically allocated.
template <int BW>
__device__ void _bit_unpack_16_device(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, int thread_idx, GPUPatches& patches) {
    __shared__ uint16_t shared_out[1024];

    // Step 1: Unpack into shared memory. Each of the 32 lanes fills 2 of the
    // 64 FL lanes (1024 values / 16 bits).
    #pragma unroll
    for (int i = 0; i < 2; i++) {
        _bit_unpack_16_lane<BW>(in, shared_out, reference, thread_idx * 2 + i);
    }
    __syncwarp();

    // Step 2: Apply this thread's share of the chunk's patches to shared
    // memory. Patch indices are disjoint across threads (the cursor hands
    // each thread a contiguous slice), so no two lanes write the same slot.
    PatchesCursor<uint16_t> cursor(patches, blockIdx.x, thread_idx, 32);
    auto patch = cursor.next();
    while (patch.index != 1024) {  // 1024 is the exhausted sentinel
        shared_out[patch.index] = patch.value;
        patch = cursor.next();
    }
    __syncwarp();

    // Step 3: Copy the fully patched chunk to global memory with coalesced
    // stores (adjacent lanes write adjacent addresses).
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        auto idx = i * 32 + thread_idx;
        out[idx] = shared_out[idx];
    }
}

Expand Down
19 changes: 13 additions & 6 deletions vortex-cuda/kernels/src/bit_unpack_32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@
/// Unpack one 1024-element chunk of BW-bit-packed uint32 values into `out`,
/// applying sparse patches along the way.
///
/// Expected launch shape: a single warp (32 threads) per chunk; `thread_idx`
/// is the lane id in [0, 32). `__syncwarp()` is sufficient only under that
/// assumption — NOTE(review): confirm the generated launcher never uses more
/// than one warp per block.
///
/// Shared memory: 1024 * sizeof(uint32_t) = 4 KiB statically allocated.
template <int BW>
__device__ void _bit_unpack_32_device(const uint32_t *__restrict in, uint32_t *__restrict out, uint32_t reference, int thread_idx, GPUPatches& patches) {
    __shared__ uint32_t shared_out[1024];

    // Step 1: Unpack into shared memory. Each of the 32 lanes fills exactly
    // one FL lane (1024 values / 32 bits).
    #pragma unroll
    for (int i = 0; i < 1; i++) {
        _bit_unpack_32_lane<BW>(in, shared_out, reference, thread_idx * 1 + i);
    }
    __syncwarp();

    // Step 2: Apply this thread's share of the chunk's patches to shared
    // memory. Patch indices are disjoint across threads (the cursor hands
    // each thread a contiguous slice), so no two lanes write the same slot.
    PatchesCursor<uint32_t> cursor(patches, blockIdx.x, thread_idx, 32);
    auto patch = cursor.next();
    while (patch.index != 1024) {  // 1024 is the exhausted sentinel
        shared_out[patch.index] = patch.value;
        patch = cursor.next();
    }
    __syncwarp();

    // Step 3: Copy the fully patched chunk to global memory with coalesced
    // stores (adjacent lanes write adjacent addresses).
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        auto idx = i * 32 + thread_idx;
        out[idx] = shared_out[idx];
    }
}

Expand Down
19 changes: 13 additions & 6 deletions vortex-cuda/kernels/src/bit_unpack_64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@
/// Unpack one 1024-element chunk of BW-bit-packed uint64 values into `out`,
/// applying sparse patches along the way.
///
/// Expected launch shape: 16 threads per chunk (half a warp); `thread_idx`
/// is in [0, 16). `__syncwarp()` is sufficient only if all participating
/// threads live in one warp — NOTE(review): confirm the generated launcher
/// matches this.
///
/// Shared memory: 1024 * sizeof(uint64_t) = 8 KiB statically allocated.
template <int BW>
__device__ void _bit_unpack_64_device(const uint64_t *__restrict in, uint64_t *__restrict out, uint64_t reference, int thread_idx, GPUPatches& patches) {
    __shared__ uint64_t shared_out[1024];

    // Step 1: Unpack into shared memory. Each of the 16 threads fills exactly
    // one FL lane (1024 values / 64 bits).
    #pragma unroll
    for (int i = 0; i < 1; i++) {
        _bit_unpack_64_lane<BW>(in, shared_out, reference, thread_idx * 1 + i);
    }
    __syncwarp();

    // Step 2: Apply this thread's share of the chunk's patches to shared
    // memory. Patch indices are disjoint across threads (the cursor hands
    // each thread a contiguous slice), so no two threads write the same slot.
    PatchesCursor<uint64_t> cursor(patches, blockIdx.x, thread_idx, 16);
    auto patch = cursor.next();
    while (patch.index != 1024) {  // 1024 is the exhausted sentinel
        shared_out[patch.index] = patch.value;
        patch = cursor.next();
    }
    __syncwarp();

    // Step 3: Copy the fully patched chunk to global memory with coalesced
    // stores (adjacent threads write adjacent addresses); 64 iterations of
    // 16 threads covers all 1024 elements.
    #pragma unroll
    for (int i = 0; i < 64; i++) {
        auto idx = i * 16 + thread_idx;
        out[idx] = shared_out[idx];
    }
}

Expand Down
19 changes: 13 additions & 6 deletions vortex-cuda/kernels/src/bit_unpack_8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@
/// Unpack one 1024-element chunk of BW-bit-packed uint8 values into `out`,
/// applying sparse patches along the way.
///
/// Expected launch shape: a single warp (32 threads) per chunk; `thread_idx`
/// is the lane id in [0, 32). `__syncwarp()` is sufficient only under that
/// assumption — NOTE(review): confirm the generated launcher never uses more
/// than one warp per block.
///
/// Shared memory: 1024 * sizeof(uint8_t) = 1 KiB statically allocated.
template <int BW>
__device__ void _bit_unpack_8_device(const uint8_t *__restrict in, uint8_t *__restrict out, uint8_t reference, int thread_idx, GPUPatches& patches) {
    __shared__ uint8_t shared_out[1024];

    // Step 1: Unpack into shared memory. Each of the 32 lanes fills 4 of the
    // 128 FL lanes (1024 values / 8 bits).
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        _bit_unpack_8_lane<BW>(in, shared_out, reference, thread_idx * 4 + i);
    }
    __syncwarp();

    // Step 2: Apply this thread's share of the chunk's patches to shared
    // memory. Patch indices are disjoint across threads (the cursor hands
    // each thread a contiguous slice), so no two lanes write the same slot.
    PatchesCursor<uint8_t> cursor(patches, blockIdx.x, thread_idx, 32);
    auto patch = cursor.next();
    while (patch.index != 1024) {  // 1024 is the exhausted sentinel
        shared_out[patch.index] = patch.value;
        patch = cursor.next();
    }
    __syncwarp();

    // Step 3: Copy the fully patched chunk to global memory with coalesced
    // stores (adjacent lanes write adjacent addresses).
    #pragma unroll
    for (int i = 0; i < 32; i++) {
        auto idx = i * 32 + thread_idx;
        out[idx] = shared_out[idx];
    }
}

Expand Down
90 changes: 69 additions & 21 deletions vortex-cuda/kernels/src/patches.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@

#include "patches.h"

/// Load a single chunk-offset value, dispatching on the runtime element type.
///
/// The chunk_offsets array may be stored as any unsigned integer width
/// (CO_U8/U16/U32/U64). Only a handful of entries are read per cursor setup,
/// so per-element runtime dispatch is cheaper than casting the whole array
/// to a common width up front.
///
/// For CO_U64 the value is truncated to uint32_t — NOTE(review): this assumes
/// offsets fit in 32 bits; confirm against the host-side encoding.
__device__ inline uint32_t load_chunk_offset(const GPUPatches &patches, uint32_t idx) {
    switch (patches.chunk_offset_type) {
        case CO_U8:
            return reinterpret_cast<const uint8_t *>(patches.chunk_offsets)[idx];
        case CO_U16:
            return reinterpret_cast<const uint16_t *>(patches.chunk_offsets)[idx];
        case CO_U32:
            return reinterpret_cast<const uint32_t *>(patches.chunk_offsets)[idx];
        case CO_U64:
            return static_cast<uint32_t>(reinterpret_cast<const uint64_t *>(patches.chunk_offsets)[idx]);
    }
    // Unreachable for valid chunk_offset_type values; keeps the compiler quiet.
    return 0;
}

/// A single patch: a within-chunk index and its replacement value.
/// A sentinel patch has index == 1024, which can never match a valid
/// within-chunk position (0–1023).
Expand All @@ -14,54 +29,87 @@ struct Patch {
T value;
};

/// Cursor for iterating over a single thread's portion of patches within a
/// chunk.
///
/// The chunk's patches are divided evenly among the participating threads.
/// Each thread applies its slice of patches to shared memory, then all
/// threads sync and copy the result to global memory.
///
/// Usage in the generated kernel:
///
///     PatchesCursor<uint32_t> cursor(patches, blockIdx.x, thread_idx, 32);
///     auto patch = cursor.next();
///     while (patch.index != 1024) {
///         shared_out[patch.index] = patch.value;
///         patch = cursor.next();
///     }
template <typename T>
class PatchesCursor {
public:
    /// Construct a cursor over this thread's portion of the patches that
    /// fall inside `chunk`. `n_threads` is the number of cooperating threads
    /// (a compile-time constant in the generated kernels, 16 or 32).
    __device__
    PatchesCursor(const GPUPatches &patches, uint32_t chunk, uint32_t thread_idx, uint32_t n_threads) {
        // No patches at all for this array: leave the cursor empty.
        if (patches.chunk_offsets == nullptr) {
            indices = nullptr;
            values = nullptr;
            remaining = 0;
            return;
        }

        // Mirrors the logic in vortex-array/src/arrays/primitive/array/patch.rs.

        // chunk_offsets may not start at zero; normalize against entry 0.
        uint32_t base_offset = load_chunk_offset(patches, 0);

        // First patch belonging to this chunk, adjusted for the slice offset
        // within the first chunk (min() guards against underflow).
        uint32_t patches_start_idx = load_chunk_offset(patches, chunk) - base_offset;
        patches_start_idx -= min(patches_start_idx, patches.offset_within_chunk);

        // One-past-last patch for this chunk: the next chunk's start, or the
        // total patch count when this is the final chunk.
        uint32_t patches_end_idx;
        if ((chunk + 1) < patches.n_chunks) {
            patches_end_idx = load_chunk_offset(patches, chunk + 1) - base_offset;
            patches_end_idx -= min(patches_end_idx, patches.offset_within_chunk);
        } else {
            patches_end_idx = patches.num_patches;
        }

        uint32_t num_patches = patches_end_idx - patches_start_idx;

        // Divide the chunk's patches evenly among threads (ceiling division);
        // min() clamps the trailing threads to an empty or short slice.
        uint32_t patches_per_thread = (num_patches + n_threads - 1) / n_threads;
        uint32_t my_start = min(thread_idx * patches_per_thread, num_patches);
        uint32_t my_end = min((thread_idx + 1) * patches_per_thread, num_patches);

        uint32_t start = patches_start_idx + my_start;
        remaining = my_end - my_start;
        indices = patches.indices + start;
        values = reinterpret_cast<const T *>(patches.values) + start;

        // next() returns indices relative to the start of the chunk.
        // chunk_base is the global index of the chunk's first element,
        // accounting for the slice offset.
        chunk_base = chunk * 1024 + patches.offset;
        chunk_base -= min(chunk_base, patches.offset % 1024);
    }

    /// Return the current patch (index rewritten to its within-chunk
    /// position, 0-1023) and advance, or the sentinel {1024, 0} if exhausted.
    __device__ Patch<T> next() {
        if (remaining == 0) {
            return {1024, T {}};
        }
        uint16_t within_chunk = static_cast<uint16_t>(*indices - chunk_base);
        Patch<T> patch = {within_chunk, *values};
        indices++;
        values++;
        remaining--;
        return patch;
    }

private:
    const uint32_t *indices;
    const T *values;
    // uint32_t (not uint8_t): the constructor assigns a uint32_t slice length,
    // so keep the counter wide enough that no truncation is possible.
    uint32_t remaining;
    uint32_t chunk_base;
};
Loading
Loading