Skip to content

Commit

Permalink
Add documentation & minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
osreboot committed Nov 30, 2023
1 parent 7eb61d1 commit b2e5499
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 58 deletions.
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ project(GaussianSplatterer LANGUAGES C CXX CUDA)

set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CUDA_STANDARD 14)

add_definitions(-DNDEBUG)
set(CMAKE_CUDA_ARCHITECTURES 75)

add_executable(${PROJECT_NAME} WIN32 ${SOURCES})

Expand Down
8 changes: 8 additions & 0 deletions src/Config.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@
#define M_PI 3.1415926535897932384626433832795
#endif

#define VERSION "1.0.0"

// Number of allowed training steps per second
#define AUTO_TRAIN_BUDGET 100.0f

// Resolution of truth images used by the training process
#define RENDER_RESOLUTION_X 1024
#define RENDER_RESOLUTION_Y 1024

// Maximum number of splats that a model can have (splats will automatically stop subdividing after reaching this limit)
#define SPLATS_LIMIT 1000000

#define SPLATS_SH_DEGREE 1
#define SPLATS_SH_COEF 4
15 changes: 8 additions & 7 deletions src/ModelSplatsHost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ ModelSplatsHost::ModelSplatsHost(const std::vector<float>& locationsArg, const s
shDegree = (((int)shsArg.size() / (3 * count)) - 1) / 3;
shCoeffs = ((int)shsArg.size() / (3 * count));

assert(locationsArg.size() == count * 3 &&
shsArg.size() == count * 3 * shCoeffs &&
scalesArg.size() == count * 3 &&
opacitiesArg.size() == count &&
rotationsArg.size() == count * 4);
if (locationsArg.size() != count * 3 || shsArg.size() != count * 3 * shCoeffs || scalesArg.size() != count * 3 ||
opacitiesArg.size() != count || rotationsArg.size() != count * 4) {
throw std::runtime_error("Inconsistent feature dimensions supplied when creating a host model!");
}

locations = new float[capacity * 3];
shs = new float[capacity * 3 * shCoeffs];
Expand All @@ -64,7 +63,7 @@ ModelSplatsHost::~ModelSplatsHost() {
}

void ModelSplatsHost::pushBack(glm::vec3 location, std::vector<float> sh, glm::vec3 scale, float opacity, glm::quat rotation) {
assert(count < capacity);
if (count >= capacity) throw std::runtime_error("Model ran out of capacity!");

memcpy(&locations[count * 3], &location, 3 * sizeof(float));
for (int i = 0; i < shCoeffs * 3; i++) {
Expand All @@ -78,7 +77,9 @@ void ModelSplatsHost::pushBack(glm::vec3 location, std::vector<float> sh, glm::v
}

void ModelSplatsHost::copy(int indexTo, int indexFrom) {
assert(indexTo >= 0 && indexTo < count && indexFrom > 0 && indexFrom < count);
if (indexTo < 0 || indexTo >= count || indexFrom < 0 || indexFrom >= count) {
throw std::runtime_error("Can't copy splat in model, incorrect bounds and/or no capacity!");
}

memcpy(&locations[indexTo * 3], &locations[indexFrom * 3], 3 * sizeof(float));
for (int i = 0; i < shCoeffs * 3; i++) {
Expand Down
89 changes: 64 additions & 25 deletions src/Trainer.cu

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions src/Trainer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,37 @@ class RtxHost;
class Trainer {

private:
// Variables to track how many splats/truth frames were present last iteration. If these don't match the current
// values, we need to reallocate all host and device buffers.
int lastCount = -1;
int lastTruthCount = -1;

// Device training data (used between all models/projects)
float* devBackground;
float* devMatView;
float* devMatProjView;
float* devCameraLocation;
float* devRasterized;

// Device averaged/accumulated training data (used across all truth frames for a single training iteration)
float* devVarLocations;
float* devAvgGradLocations;
float* devAvgGradShs;
float* devAvgGradScales;
float* devAvgGradOpacities;
float* devAvgGradRotations;

// Device per-pixel loss for a single truth frame
float* devLossPixels;

// Device per-frame training data
float* devGradLocations;
float* devGradShs;
float* devGradScales;
float* devGradOpacities;
float* devGradRotations;

// Device data used by the backward rasterizer
float* devGradMean2D;
float* devGradConic;
float* devGradColor;
Expand All @@ -42,8 +49,8 @@ private:
public:
ModelSplatsDevice* model;

std::vector<uint32_t*> truthFrameBuffersW;
std::vector<uint32_t*> truthFrameBuffersB;
std::vector<uint32_t*> truthFrameBuffersW; // Truth FBO pointers (with white backgrounds)
std::vector<uint32_t*> truthFrameBuffersB; // Truth FBO pointers (with black backgrounds)
std::vector<Camera> truthCameras;

Trainer();
Expand All @@ -55,10 +62,14 @@ public:

~Trainer();

// Render the current model to a given FBO. Size parameters represent the FBO resolution. Uses the given camera and
// the debug scale modifier for splats.
void render(uint32_t* frameBuffer, int sizeX, int sizeY, float splatScale, const Camera& camera);

// Capture truth frames using the given project settings and ray tracer
void captureTruths(const Project& project, RtxHost& rtx);

// Advance one training iteration (assumes model is loaded and truth data is present)
void train(Project& project, bool densify);

};
53 changes: 37 additions & 16 deletions src/ui/UiFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
using namespace std;

UiFrame::UiFrame() :
wxFrame(nullptr, wxID_ANY, "Gaussian Splatterer") {
wxFrame(nullptr, wxID_ANY, format("Gaussian Splatterer - v{}", VERSION)) {
panel = new wxPanel(this);
sizer = new wxBoxSizer(wxVERTICAL);
panel->SetSizerAndFit(sizer);
Expand Down Expand Up @@ -101,14 +101,15 @@ void UiFrame::initProject() {
}

void UiFrame::initFieldGrid() {
ModelSplatsHost modelHost(SPLATS_LIMIT, 1, 4);
ModelSplatsHost modelHost(SPLATS_LIMIT, SPLATS_SH_DEGREE, SPLATS_SH_COEF);

static const float dim = 4.0f;
static const float step = 0.5f;

std::vector<float> shs;
for(int i = 0; i < 3 * modelHost.shCoeffs; i++) shs.push_back(0.0f);

// Create a scene-sized grid of splats
for(float x = -dim; x <= dim; x += step){
for(float y = -dim; y <= dim; y += step){
for(float z = -dim; z <= dim; z += step){
Expand All @@ -118,20 +119,23 @@ void UiFrame::initFieldGrid() {
}
}

// Send new splats to the GPU for training
delete trainer->model;
trainer->model = new ModelSplatsDevice(modelHost);
project->iterations = 0;
}

void UiFrame::initFieldMono() {
ModelSplatsHost modelHost(SPLATS_LIMIT, 1, 4);
ModelSplatsHost modelHost(SPLATS_LIMIT, SPLATS_SH_DEGREE, SPLATS_SH_COEF);

std::vector<float> shs;
for(int i = 0; i < 3 * modelHost.shCoeffs; i++) shs.push_back(0.0f);

// Create a single giant splat
modelHost.pushBack({0.0f, 0.0f, 0.0f}, shs, {0.3f, 0.3f, 0.3f}, 1.0f,
glm::angleAxis(0.0f, glm::vec3(0.0f, 1.0f, 0.0f)));

// Send new splats to the GPU for training
delete trainer->model;
trainer->model = new ModelSplatsDevice(modelHost);
project->iterations = 0;
Expand All @@ -144,6 +148,7 @@ void UiFrame::initFieldModel() {
return;
}

// Parse the model OBJ file and accumulate vertices & triangles
std::vector<owl::vec3f> vertices;
std::vector<owl::vec3i> triangles;

Expand Down Expand Up @@ -191,57 +196,67 @@ void UiFrame::initFieldModel() {
}
}

ModelSplatsHost modelHost(SPLATS_LIMIT, 1, 4);
ModelSplatsHost modelHost(SPLATS_LIMIT, SPLATS_SH_DEGREE, SPLATS_SH_COEF);

// Create one splat per triangle, with a scale/rotation that matches the triangle's orientation
for(owl::vec3i triangle : triangles) {
glm::vec3 v0(vertices[triangle.x].x, vertices[triangle.x].y, vertices[triangle.x].z);
glm::vec3 v1(vertices[triangle.y].x, vertices[triangle.y].y, vertices[triangle.y].z);
glm::vec3 v2(vertices[triangle.z].x, vertices[triangle.z].y, vertices[triangle.z].z);

// Location is the average of the triangle's vertices
glm::vec3 location = (v0 + v1 + v2) / 3.0f;

// Very thin splat, estimate the planar dimensions based on the triangle's edge lengths
glm::vec3 scale(glm::length(v1 - v0), glm::length(v2 - v0), 0.005f);
scale *= 0.2f;

std::vector<float> shs;
for(int i = 0; i < 3 * modelHost.shCoeffs; i++) shs.push_back(0.0f);

glm::vec3 up = glm::vec3(0.0f, 0.0f, 1.0f);
glm::vec3 dir = glm::normalize(glm::cross(v1 - v0, v2 - v0));

glm::vec3 axis = glm::cross(up, dir);
float angle = glm::acos(glm::dot(up, dir));
// Calculate axis/angle parameters so we can rotate the splat to face the source triangle's normal
glm::vec3 splatUp = glm::vec3(0.0f, 0.0f, 1.0f);
glm::vec3 triNormal = glm::normalize(glm::cross(v1 - v0, v2 - v0));
glm::vec3 axis = glm::cross(splatUp, triNormal);
float angle = glm::acos(glm::dot(splatUp, triNormal));

modelHost.pushBack(location, shs, scale, 1.0f, glm::angleAxis(angle, axis));
}

// Send new splats to the GPU for training
delete trainer->model;
trainer->model = new ModelSplatsDevice(modelHost);
project->iterations = 0;
}

void UiFrame::update() {
// Calculate how much time has passed since the last update
timeNow = chrono::high_resolution_clock::now();
float delta = (float)chrono::duration_cast<chrono::nanoseconds>(timeNow - timeLastUpdate).count() / 1000000000.0f;
//delta = min(delta, 0.2f);
timeLastUpdate = timeNow;

project->previewTimer += delta;

if(autoTraining) {
autoTrainingBudget = min(1.0f, autoTrainingBudget + delta * 100.0f);
// Advance the training capacity if we have room, but do not exceed one potential iteration (so as not to
// accumulate over time)
autoTrainingBudget = min(1.0f, autoTrainingBudget + delta * AUTO_TRAIN_BUDGET);

if(autoTrainingBudget >= 1.0f) {
if(autoTrainingBudget >= 1.0f) { // Run a training iteration
autoTrainingBudget = 0.0f;

// Check if this is a special iteration
const bool capture = project->intervalCapture > 0 && project->iterations % project->intervalCapture == 0;
const bool densify = project->intervalDensify > 0 && project->iterations % project->intervalDensify == 0;

if(capture) {
if(capture) { // Randomize camera sphere rotations & collect new truth data
wxCommandEvent eventFake = wxCommandEvent(wxEVT_NULL, 0);
panelTools->panelTruth->onButtonRandomRotate(eventFake);
panelTools->panelTruth->onButtonCapture(eventFake);
}

trainer->train(*project, densify);

panelOutput->refreshText();
panelTools->panelTrain->refreshText();
}
Expand Down Expand Up @@ -272,11 +287,13 @@ void UiFrame::saveSettings(const std::string& path) const {

void UiFrame::saveSplats(const std::string& path) const {
wxProgressDialog dialog("Saving Gaussian Splats", "Writing splats to \"" + path + "\"...", trainer->model->count + 1000, panel, wxPD_AUTO_HIDE);
ModelSplatsHost model(*trainer->model);

// Initializing the model takes time, so count this as (the equivalent of) 1,000 line reads
ModelSplatsHost model(*trainer->model);
int progress = 1000;
dialog.Update(progress);

// Write splats to the custom Gaussian OBJ (.gobj) file format
std::ofstream file(path);
for (int i = 0; i < model.count; i++) {
file << "v " << model.locations[i * 3] << " " << model.locations[i * 3 + 1] << " " << model.locations[i * 3 + 2] << "\n";
Expand Down Expand Up @@ -323,6 +340,7 @@ void UiFrame::loadSplats(const std::string& path) {

int progress = 0;

// Read splats from the custom Gaussian OBJ (.gobj) file format
std::optional<int> shCoeffs = nullopt;

std::vector<float> locations;
Expand Down Expand Up @@ -351,8 +369,10 @@ void UiFrame::loadSplats(const std::string& path) {
shs.push_back(x);
shCoeffsCount++;
}

// All spherical harmonic properties (splat color) need to have the same dimension throughout the model
if (!shCoeffs) shCoeffs = shCoeffsCount;
else assert(shCoeffs == shCoeffsCount);
else if (shCoeffs != shCoeffsCount) throw std::runtime_error("Inconsistent SH degree!");
} else if (prefix == "s") {
for (int f = 0; f < 3; f++) {
float x;
Expand All @@ -374,11 +394,12 @@ void UiFrame::loadSplats(const std::string& path) {
dialog.Update(++progress);
}

// Initializing the model takes time, so count this as (the equivalent of) 1,000 line reads
ModelSplatsHost modelHost(locations, shs, scales, opacities, rotations);

progress += 1000;
dialog.Update(progress);

// Send new splats to the GPU for training
delete trainer->model;
trainer->model = new ModelSplatsDevice(modelHost);
}
Expand Down
12 changes: 6 additions & 6 deletions src/ui/UiFrame.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class UiFrame : public wxFrame {
wxBoxSizer* sizer;
wxBoxSizer* sizerViews;

float autoTrainingBudget = 0.0f;
float autoTrainingBudget = 0.0f; // Number of allowed training steps per second

public:
Project* project = nullptr;
Expand All @@ -47,14 +47,14 @@ class UiFrame : public wxFrame {
~UiFrame() override;

private:
void initProject();
void initFieldGrid();
void initFieldMono();
void initFieldModel();
void initProject(); // Reset to a new project
void initFieldGrid(); // Initialize splats with a scene-sized grid
void initFieldMono(); // Initialize splats with a single giant splat
void initFieldModel(); // Initialize splats with one splat per model triangle

void update();

void refreshProject();
void refreshProject(); // Called when the project data gets changed, update all text fields & spinners

void saveSettings(const std::string& path) const;
void saveSplats(const std::string& path) const;
Expand Down

0 comments on commit b2e5499

Please sign in to comment.