Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .verify-helper/timestamps.remote.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"tests/library_checker_aizu_tests/graphs/strongly_connected_components_aizu.test.cpp": "2026-01-17 13:05:42 -0700",
"tests/library_checker_aizu_tests/graphs/strongly_connected_components_lib_checker.test.cpp": "2026-01-17 13:05:42 -0700",
"tests/library_checker_aizu_tests/graphs/two_edge_components.test.cpp": "2026-02-27 11:16:07 -0700",
"tests/library_checker_aizu_tests/handmade_tests/count_paths.test.cpp": "2026-03-16 13:39:47 -0600",
"tests/library_checker_aizu_tests/handmade_tests/count_paths.test.cpp": "2026-03-16 13:52:41 -0600",
"tests/library_checker_aizu_tests/handmade_tests/dsu.test.cpp": "2026-01-22 10:08:22 -0700",
"tests/library_checker_aizu_tests/handmade_tests/edge_cd_small_trees.test.cpp": "2026-03-16 13:39:47 -0600",
"tests/library_checker_aizu_tests/handmade_tests/fib_matrix_expo.test.cpp": "2026-01-28 21:48:16 -0700",
Expand Down
29 changes: 12 additions & 17 deletions library/trees/centroid_decomp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,28 @@
//! @endcode
//! @time O(n log n)
//! @space O(n)
template<class F, class G> struct centroid {
G adj;
F f;
vi siz;
centroid(const G& adj, F f):
adj(adj), f(f), siz(sz(adj), -1) {
dfs(0, -1);
}
void calc_sz(int u, int p) {
void centroid(auto& adj, const auto& f) {
vi siz(sz(adj));
auto calc_sz = [&](auto&& self, int u, int p) -> void {
siz[u] = 1;
for (int v : adj[u])
if (v != p) calc_sz(v, u), siz[u] += siz[v];
}
void dfs(int u, int p) {
calc_sz(u, -1);
if (v != p) self(self, v, u), siz[u] += siz[v];
};
auto dfs = [&](auto&& self, int u, int p) -> void {
calc_sz(calc_sz, u, -1);
for (int w = -1, sz_root = siz[u];;) {
auto big_ch = ranges::find_if(adj[u], [&](int v) {
return v != w && 2 * siz[v] > sz_root;
});
if (big_ch == end(adj[u])) break;
w = u, u = *big_ch;
}
f(adj, u, p);
f(u, p);
for (int v : adj[u]) {
iter_swap(ranges::find(adj[v], u), rbegin(adj[v]));
adj[v].pop_back();
dfs(v, u);
self(self, v, u);
}
}
};
};
dfs(dfs, 0, -1);
}
28 changes: 0 additions & 28 deletions tests/library_checker_aizu_tests/cd_asserts.hpp
Original file line number Diff line number Diff line change
@@ -1,29 +1 @@
#pragma once
#include "../../library/trees/centroid_decomp.hpp"
void cd_asserts(const vector<vector<int>>& adj) {
vector<int> decomp_size(sz(adj), -1);
vector<int> naive_par_decomp(sz(adj), -1);
centroid(adj,
[&](const vector<vector<int>>& cd_adj, int cent,
int par_cent) -> void {
assert(naive_par_decomp[cent] == par_cent);
assert(decomp_size[cent] == -1);
auto dfs = [&](auto&& self, int u, int p) -> int {
naive_par_decomp[u] = cent;
int sub_size = 1;
for (int v : cd_adj[u])
if (v != p) sub_size += self(self, v, u);
return sub_size;
};
decomp_size[cent] = dfs(dfs, cent, -1);
if (par_cent != -1)
assert(1 <= decomp_size[cent] &&
2 * decomp_size[cent] <= decomp_size[par_cent]);
for (int u : cd_adj[cent]) {
int sz_subtree = dfs(dfs, u, cent);
assert(1 <= sz_subtree &&
2 * sz_subtree <= decomp_size[cent]);
}
});
rep(i, 0, sz(adj)) assert(decomp_size[i] >= 1);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,34 @@
#include "../template.hpp"
#include "../../../library/contest/random.hpp"
#include "../../../library/trees/lca_rmq.hpp"
#include "../cd_asserts.hpp"
#include "../../../kactl/content/numerical/FastFourierTransform.h"
#include "../../../library/trees/edge_cd.hpp"
#include "../../../library/trees/centroid_decomp.hpp"
void cd_asserts(vector<vector<int>> adj) {
vector<int> decomp_size(sz(adj), -1);
vector<int> naive_par_decomp(sz(adj), -1);
centroid(adj, [&](int cent, int par_cent) -> void {
assert(naive_par_decomp[cent] == par_cent);
assert(decomp_size[cent] == -1);
auto dfs = [&](auto&& self, int u, int p) -> int {
naive_par_decomp[u] = cent;
int sub_size = 1;
for (int v : adj[u])
if (v != p) sub_size += self(self, v, u);
return sub_size;
};
decomp_size[cent] = dfs(dfs, cent, -1);
if (par_cent != -1)
assert(1 <= decomp_size[cent] &&
2 * decomp_size[cent] <= decomp_size[par_cent]);
for (int u : adj[cent]) {
int sz_subtree = dfs(dfs, u, cent);
assert(1 <= sz_subtree &&
2 * sz_subtree <= decomp_size[cent]);
}
});
rep(i, 0, sz(adj)) assert(decomp_size[i] >= 1);
}
//! @param adj unrooted, connected forest
//! @param k number of edges
//! @returns array `num_paths` where `num_paths[i]` =
Expand All @@ -14,37 +39,35 @@
//! @time O(n log n)
//! @space this function allocates/returns various vectors
//! which are all O(n)
vector<ll> count_paths_per_node(const vector<vi>& adj,
int k) {
vector<ll> count_paths_per_node(vector<vi> adj, int k) {
vector<ll> num_paths(sz(adj));
centroid(adj,
[&](const vector<vi>& cd_adj, int cent, int) {
vector pre_d{1}, cur_d{0};
auto dfs = [&](auto&& self, int u, int p,
int d) -> ll {
if (d > k) return 0LL;
if (sz(cur_d) <= d) cur_d.push_back(0);
cur_d[d]++;
ll cnt = 0;
if (k - d < sz(pre_d)) cnt += pre_d[k - d];
for (int c : cd_adj[u])
if (c != p) cnt += self(self, c, u, d + 1);
num_paths[u] += cnt;
return cnt;
};
auto dfs_child = [&](int child) -> ll {
ll cnt = dfs(dfs, child, cent, 1);
pre_d.resize(sz(cur_d));
for (int i = 1; i < sz(cur_d) && cur_d[i]; i++)
pre_d[i] += cur_d[i], cur_d[i] = 0;
return cnt;
};
for (int child : cd_adj[cent])
num_paths[cent] += dfs_child(child);
pre_d = cur_d = {0};
for (int child : cd_adj[cent] | views::reverse)
dfs_child(child);
});
centroid(adj, [&](int cent, int) {
vector pre_d{1}, cur_d{0};
auto dfs = [&](auto&& self, int u, int p,
int d) -> ll {
if (d > k) return 0LL;
if (sz(cur_d) <= d) cur_d.push_back(0);
cur_d[d]++;
ll cnt = 0;
if (k - d < sz(pre_d)) cnt += pre_d[k - d];
for (int c : adj[u])
if (c != p) cnt += self(self, c, u, d + 1);
num_paths[u] += cnt;
return cnt;
};
auto dfs_child = [&](int child) -> ll {
ll cnt = dfs(dfs, child, cent, 1);
pre_d.resize(sz(cur_d));
for (int i = 1; i < sz(cur_d) && cur_d[i]; i++)
pre_d[i] += cur_d[i], cur_d[i] = 0;
return cnt;
};
for (int child : adj[cent])
num_paths[cent] += dfs_child(child);
pre_d = cur_d = {0};
for (int child : adj[cent] | views::reverse)
dfs_child(child);
});
return num_paths;
}
vector<vector<ll>> naive(const vector<vi>& adj) {
Expand Down