Skip to content

Commit

Permalink
improved logic
Browse files Browse the repository at this point in the history
  • Loading branch information
tnagler committed Jun 15, 2024
1 parent bf65bb0 commit fe57abb
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 43 deletions.
36 changes: 19 additions & 17 deletions inst/include/vinecopulib/bicop/implementation/kernel.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,44 @@ namespace vinecopulib {
{
using namespace tools_stats;

// for better cache behavior
using MatrixXd = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

size_t n = u.rows();
auto seed_state = SeedState(seeds);

auto w = simulate_uniform(n, 2, false, seed_state.next());
Eigen::MatrixXd uu = w.array() * u.leftCols(2).array() +
MatrixXd uu = w.array() * u.leftCols(2).array() +
(1 - w.array()) * u.rightCols(2).array();

Eigen::MatrixXd x = qnorm(uu);
MatrixXd x = qnorm(uu);
auto depth = std::ceil(std::log2(10 / b));
QuadTree quadtree(BoundingBox(-4.5, -4.5, 4.5, 4.5), depth, seeds);
for (size_t i = 0; i < n; i++) {
quadtree.insert(Point{x(i, 0), x(i, 1)});
}

Eigen::MatrixXd lb = safe_qnorm(u.rightCols(2)).array() - b;
Eigen::MatrixXd ub = safe_qnorm(u.leftCols(2)).array() + b;
MatrixXd lb = safe_qnorm(u.rightCols(2)).array() - b;
MatrixXd ub = safe_qnorm(u.leftCols(2)).array() + b;

Eigen::MatrixXd norm_sim(n, 2);
MatrixXd norm_sim(n, 2);
Point new_sample, old_sample;

QuadTree quadtree(BoundingBox(-4.5, -4.5, 4.5, 4.5), depth, seed_state.next());
for (size_t i = 0; i < n; i++) {
quadtree.insert(Point{x(i, 0), x(i, 1), i});
}

for (size_t it = 0; it < niter; it++) {
// x = qnorm(to_pseudo_obs(x));

norm_sim = simulate_normal(n, 2, seed_state.next()).array() * b;

for (size_t i = 0; i < n; i++) {
try {
new_sample = quadtree.sample(
BoundingBox(lb(i, 0), lb(i, 1), ub(i, 0), ub(i, 1))
);

old_sample.x = x(i, 0);
old_sample.y = x(i, 1);
quadtree.remove(old_sample);
new_sample = quadtree.sample(
BoundingBox(lb(i, 0), lb(i, 1), ub(i, 0), ub(i, 1))
);
quadtree.remove(Point{x(i, 0), x(i, 1), i});

new_sample.x += norm_sim(i, 0);
new_sample.y += norm_sim(i, 1);
new_sample.index = i;
x(i, 0) = new_sample.x;
x(i, 1) = new_sample.y;
quadtree.insert(new_sample);
Expand Down
104 changes: 78 additions & 26 deletions inst/include/vinecopulib/misc/quadtree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ namespace vinecopulib {

struct Point {
double x, y;
size_t index;
bool operator==(const Point& other) const {
return x == other.x && y == other.y && index == other.index;
}
};

struct PointID {
Point* ptr;
size_t index;
};

// A set of points with O(1) insert, remove, and uniform sampling;
Expand All @@ -18,25 +27,37 @@ class PointSet

void insert(const Point& p)
{
if (ptr_set_.find(&p) == ptr_set_.end()) { // point not already in the set
vector_.push_back(p);
ptr_set_.insert(&(*vector_.crbegin()));
if (vector_.size() == vector_.capacity()) {
vector_.reserve(vector_.size() * 2);
set_.clear();
for (auto& point : vector_) {
set_.insert(PointID{&point, point.index});
}
}
vector_.push_back(p);
set_.insert(PointID{&vector_.back(), p.index});
}

void remove(const Point& p)
{
auto el = ptr_set_.find(&p);
if (el != ptr_set_.end()) {
// first replace element by last entry of vector_, then pop last entry
if (*el != &(*vector_.crbegin())) {
*(const_cast<Point*>(*el)) = *vector_.rbegin();
}
vector_.pop_back();
auto el = set_.lower_bound(PointID{nullptr, p.index});
if (el->index != p.index) {
throw std::runtime_error("Cannot remove point that's not yet in the tree.");
}

auto last = set_.lower_bound(PointID{nullptr, vector_.back().index});
if (el->index != last->index) {
auto new_id = PointID{el->ptr, last->index};
*el->ptr = vector_.back();
set_.erase(last);
set_.insert(new_id);
}

vector_.pop_back();
set_.erase(el);
}

const Point& sample(std::mt19937* rng_ptr_) const
const Point& sample(std::mt19937* rng_ptr_)
{
auto n = vector_.size();
if (n == 0) {
Expand All @@ -48,7 +69,13 @@ class PointSet
}

private:
std::unordered_set<const Point*> ptr_set_;
struct IndexCompare {
bool operator()(const PointID& a, const PointID& b) const {
return a.index < b.index;
}
};

std::set<PointID, IndexCompare> set_;
std::vector<Point> vector_;
};

Expand Down Expand Up @@ -78,19 +105,22 @@ class QuadTree {
struct Node {
BoundingBox boundary;
PointSet points;
size_t point_count = 0;
Node* parent_ptr = nullptr;

std::array<Node*, 4> children;
Node(const BoundingBox& boundary) : boundary(boundary) {}
Node(const BoundingBox& boundary, Node* parent_ptr = nullptr) :
boundary(boundary), parent_ptr(parent_ptr) {}
};

void construct_children(Node* node, uint16_t depth, int& node_idx) {
if (depth >= depth_) return;

const auto& b = node->boundary;
nodes_.emplace_back(BoundingBox{b.x_min, b.y_min, b.x_mid, b.y_mid});
nodes_.emplace_back(BoundingBox{b.x_mid, b.y_min, b.x_max, b.y_mid});
nodes_.emplace_back(BoundingBox{b.x_min, b.y_mid, b.x_mid, b.y_max});
nodes_.emplace_back(BoundingBox{b.x_mid, b.y_mid, b.x_max, b.y_max});
nodes_.emplace_back(BoundingBox{b.x_min, b.y_min, b.x_mid, b.y_mid}, node);
nodes_.emplace_back(BoundingBox{b.x_mid, b.y_min, b.x_max, b.y_mid}, node);
nodes_.emplace_back(BoundingBox{b.x_min, b.y_mid, b.x_mid, b.y_max}, node);
nodes_.emplace_back(BoundingBox{b.x_mid, b.y_mid, b.x_max, b.y_max}, node);

for (int i = 0; i < 4; ++i) {
node->children[i] = &nodes_[node_idx++];
Expand Down Expand Up @@ -140,18 +170,22 @@ class QuadTree {
void insert(const Point& point)
{
Node* node = &nodes_[0]; // start at root
for (int d = 0; d <= depth_; ++d) {
node->points.insert(point);
node->point_count++;
for (int d = 1; d <= depth_; ++d) {
node = find_child_with_point(node, point);
node->point_count++;
}
node->points.insert(point);
}

void remove(const Point& point) {
Node* node = &nodes_[0]; // start at root
for (int d = 0; d <= depth_; ++d) {
node->points.remove(point);
node->point_count--;
for (int d = 1; d <= depth_; ++d) {
node = find_child_with_point(node, point);
node->point_count--;
}
node->points.remove(point);
}

// Helper function to recursively calculate points within the range
Expand All @@ -167,7 +201,7 @@ class QuadTree {
}

// If the node is fully contained in the range, all its points are
if (range.contains(node->boundary)) {
if (range.contains(node->boundary) || depth == depth_) {
terminal_nodes_.push_back(node);
return;
}
Expand All @@ -191,7 +225,7 @@ class QuadTree {
int total_points_in_range = 0;
for (auto node : terminal_nodes_) {
auto c = node;
total_points_in_range += node->points.size();
total_points_in_range += node->point_count;
}
if (total_points_in_range == 0) {
throw std::runtime_error("No points in the specified range.");
Expand All @@ -202,15 +236,33 @@ class QuadTree {

Node* sampled_node;
for (auto& node : terminal_nodes_) {
if (r < node->points.size()) {
if (r < node->point_count) {
sampled_node = node;
break;
}
r -= node->points.size();
r -= node->point_count;
}

while (sampled_node->points.size() == 0) {
size_t total_child_count = 0;
for (auto& child : sampled_node->children) {
total_child_count += child->point_count;
}
std::uniform_int_distribution<int> dist(0, total_child_count - 1);
r = dist(rng_);

for (auto& child : sampled_node->children) {
if (r < child->point_count) {
sampled_node = child;
break;
}
r -= child->point_count;
}
}

return sampled_node->points.sample(&rng_);
}
};

};

} // namespace vinecopulib

0 comments on commit fe57abb

Please sign in to comment.