-
Notifications
You must be signed in to change notification settings - Fork 0
/
segment_tree_2d.hpp
105 lines (95 loc) · 3.31 KB
/
segment_tree_2d.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#pragma once
#include <cassert>
#include <vector>
/// 2D segment tree over a monoid `MS`: point updates and axis-aligned
/// rectangle aggregate queries, each in O(log H * log W).
///
/// `MS` must provide:
///   - `using S = ...;`      the element type,
///   - `static S op(S, S)`   an associative combine operation,
///   - `static S e()`        the identity element of `op`.
/// The build combines rows first and then columns, while queries interleave
/// row blocks and column blocks. The result therefore equals "the aggregate of
/// the rectangle" independent of tree shape only when `op` is also commutative.
/// NOTE(review): confirm callers only instantiate this with commutative monoids.
///
/// Storage: `d` is a (2*sizeh) x (2*sizew) grid. Each of the row index and the
/// column index is a 1-based heap index of a 1D segment tree, and the leaves
/// live at [sizeh, 2*sizeh) x [sizew, 2*sizew). Memory is therefore
/// O(sizeh * sizew), which is under 16 * H * W cells.
template <class MS> struct SegmentTree2D {
public:
    using S = typename MS::S;

    /// Empty 0 x 0 tree; all_prod() returns MS::e().
    SegmentTree2D() : SegmentTree2D(0, 0) {}

    /// height x width tree with every cell initialised to MS::e().
    SegmentTree2D(int height, int width)
        : SegmentTree2D(std::vector<std::vector<S>>(height, std::vector<S>(width, MS::e()))) {}

    /// Builds the tree from grid `v`: v.size() rows, each row at least
    /// v[0].size() wide (only the first v[0].size() entries of a row are used).
    /// An empty `v` yields a 0 x 0 tree.
    SegmentTree2D(const std::vector<std::vector<S>>& v)
        : h((int)v.size()), w(v.empty() ? 0 : (int)v[0].size()) {
        logh = ceil_log2(h);
        sizeh = 1 << logh;
        logw = ceil_log2(w);
        sizew = 1 << logw;
        d.assign(sizeh << 1, std::vector<S>(sizew << 1, MS::e()));
        for (int i = 0; i < h; i++) {
            // A shorter row would be read out of bounds below.
            assert((int)v[i].size() >= w);
            for (int j = 0; j < w; j++) d[i + sizeh][j + sizew] = v[i][j];
        }
        // Pass 1: for every leaf column, build the tree along the row axis.
        for (int i = sizeh - 1; i >= 1; i--) {
            for (int j = sizew; j < (sizew << 1); j++) update_bottom(i, j);
        }
        // Pass 2: for every row node (leaf or internal), build the tree along
        // the column axis. Row 0 is never read, so it is skipped.
        for (int i = 1; i < (sizeh << 1); i++) {
            for (int j = sizew - 1; j >= 1; j--) update_else(i, j);
        }
    }

    /// Overwrites cell (i, j) with x. Requires 0 <= i < H and 0 <= j < W.
    void set(int i, int j, const S& x) {
        assert(0 <= i && i < h && 0 <= j && j < w);
        i += sizeh;
        j += sizew;
        d[i][j] = x;
        pull_from_leaf(i, j);
    }

    /// Replaces cell (i, j) with MS::op(current value, x).
    /// Requires 0 <= i < H and 0 <= j < W.
    void chset(int i, int j, const S& x) {
        assert(0 <= i && i < h && 0 <= j && j < w);
        i += sizeh;
        j += sizew;
        d[i][j] = MS::op(d[i][j], x);
        pull_from_leaf(i, j);
    }

    /// Value of cell (i, j); same as get(i, j).
    S operator()(int i, int j) const { return get(i, j); }

    /// Value of cell (i, j). Requires 0 <= i < H and 0 <= j < W.
    S get(int i, int j) const {
        assert(0 <= i && i < h && 0 <= j && j < w);
        return d[i + sizeh][j + sizew];
    }

    /// Helper for prod(): aggregates the column range [w1, w2) of row node
    /// `row`. All three arguments are raw heap indices into `d` (already
    /// offset by sizeh / sizew), not user coordinates.
    S inner_prod(int row, int w1, int w2) const {
        S sml = MS::e(), smr = MS::e();
        while (w1 < w2) {
            if (w1 & 1) sml = MS::op(sml, d[row][w1++]);
            if (w2 & 1) smr = MS::op(d[row][--w2], smr);
            w1 >>= 1;
            w2 >>= 1;
        }
        return MS::op(sml, smr);
    }

    /// Aggregate over the half-open rectangle [h1, h2) x [w1, w2).
    /// Returns MS::e() for an empty rectangle. O(log H * log W).
    S prod(int h1, int w1, int h2, int w2) const {
        assert(0 <= h1 && h1 <= h2 && h2 <= h);
        assert(0 <= w1 && w1 <= w2 && w2 <= w);
        S sml = MS::e(), smr = MS::e();
        h1 += sizeh;
        h2 += sizeh;
        w1 += sizew;
        w2 += sizew;
        while (h1 < h2) {
            if (h1 & 1) sml = MS::op(sml, inner_prod(h1++, w1, w2));
            if (h2 & 1) smr = MS::op(inner_prod(--h2, w1, w2), smr);
            h1 >>= 1;
            h2 >>= 1;
        }
        return MS::op(sml, smr);
    }

    /// Aggregate of the whole grid (MS::e() for an empty tree).
    S all_prod() const { return d[1][1]; }

private:
    int h = 0, logh = 0, sizeh = 1, w = 0, logw = 0, sizew = 1;
    std::vector<std::vector<S>> d;

    // Smallest k with 2^k >= n (0 for n <= 1).
    static int ceil_log2(int n) {
        int k = 0;
        while ((1U << k) < (unsigned int)(n)) k++;
        return k;
    }

    // Recomputes node (i, j) from its two children along the row axis.
    void update_bottom(int i, int j) { d[i][j] = MS::op(d[(i << 1) | 0][j], d[(i << 1) | 1][j]); }

    // Recomputes node (i, j) from its two children along the column axis.
    void update_else(int i, int j) { d[i][j] = MS::op(d[i][(j << 1) | 0], d[i][(j << 1) | 1]); }

    // Re-aggregates every node whose range covers leaf (i, j), where i and j
    // are heap indices. First the leaf column's row ancestors are updated, then
    // the column ancestors of each of those rows.
    void pull_from_leaf(int i, int j) {
        for (int a = 1; a <= logh; a++) update_bottom(i >> a, j);
        for (int a = 0; a <= logh; a++) {
            for (int b = 1; b <= logw; b++) update_else(i >> a, j >> b);
        }
    }
};