// dummy_2.test.cpp
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A"
#include <cassert>
#include <iostream>
#include <random>
#include <set>
#include <vector>
#include <atcoder/modint>
#include <atcoder/convolution>
using mint = atcoder::modint998244353;
#include "library/polynomial/formal_power_series.hpp"
#include "library/polynomial/lagrange_interpolation.hpp"
template <int N>
void test() {
std::mt19937 rng{ std::random_device{}() };
std::uniform_int_distribution<int> dist(0, mint::mod() - 1);
std::vector<mint> f(N);
for (int i = 0; i < N; ++i) f[i] = dist(rng);
auto eval = [&f](mint x) -> mint {
mint y = 0;
for (int i = N - 1; i >= 0; --i) y = y * x + f[i];
return y;
};
std::vector<mint> xs(N), ys(N);
[&] {
std::set<int> st;
for (int i = 0; i < N; ++i) {
do xs[i] = dist(rng); while (st.count(xs[i].val()));
st.insert(xs[i].val());
ys[i] = eval(xs[i]);
}
}();
auto check = [&](mint t) {
mint expected = eval(t);
mint actual_fast = suisen::lagrange_interpolation<suisen::FormalPowerSeries<mint>>(xs, ys, t);
mint actual_naive = suisen::lagrange_interpolation_naive(xs, ys, t);
assert(expected == actual_naive);
assert(expected == actual_fast);
};
for (int i = 0; i < N; ++i) {
check(xs[i]);
}
for (int i = 0; i < N; ++i) {
check(dist(rng));
}
}
template <int N>
void test_arithmetic_progression() {
std::mt19937 rng{ std::random_device{}() };
std::uniform_int_distribution<int> dist(0, mint::mod() - 1);
std::vector<mint> f(N);
for (int i = 0; i < N; ++i) f[i] = dist(rng);
auto eval = [&f](mint x) -> mint {
mint y = 0;
for (int i = N - 1; i >= 0; --i) y = y * x + f[i];
return y;
};
auto do_test = [&](mint a, mint b) {
std::vector<mint> xs(N), ys(N);
for (int i = 0; i < N; ++i) {
xs[i] = a * i + b;
ys[i] = eval(xs[i]);
}
auto check = [&](mint t) {
mint expected = eval(t);
mint actual_arith = suisen::lagrange_interpolation_arithmetic_progression(a, b, ys, t);
mint actual_fast = suisen::lagrange_interpolation<suisen::FormalPowerSeries<mint>>(xs, ys, t);
mint actual_naive = suisen::lagrange_interpolation_naive(xs, ys, t);
assert(expected == actual_arith);
assert(expected == actual_naive);
assert(expected == actual_fast);
};
for (int i = 0; i < N; ++i) {
check(xs[i]);
}
for (int i = 0; i < N; ++i) {
check(dist(rng));
}
};
mint a = dist(rng);
while (a == 0) a = dist(rng);
do_test(a, dist(rng));
}
void test_arithmetic_progression_zero() {
std::mt19937 rng{ std::random_device{}() };
std::uniform_int_distribution<int> dist(0, mint::mod() - 1);
mint a = 0, b = dist(rng), y = dist(rng), t = dist(rng);
mint expected = y;
mint actual = suisen::lagrange_interpolation_arithmetic_progression(a, b, { y }, t);
assert(expected == actual);
}
/**
 * Runs every interpolation check, then prints the fixed answer expected by
 * the judge problem named in PROBLEM.
 */
int main() {
    // Number of coefficients / sample points used by the randomized checks.
    constexpr int kSampleCount = 100;

    test<kSampleCount>();
    test_arithmetic_progression<kSampleCount>();
    test_arithmetic_progression_zero();

    std::cout << "Hello World" << std::endl;
    return 0;
}