Skip to content

Commit

Permalink
Update Modint
Browse files Browse the repository at this point in the history
  • Loading branch information
ruthen71 committed May 3, 2024
1 parent 71436a6 commit cc381ab
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 22 deletions.
9 changes: 9 additions & 0 deletions docs/math/dynamic_modint.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
title: Dynamic Modint
documentation_of: //math/dynamic_modint.hpp
---

コンパイル時に $ \bmod $ が確定しないときに使える

- すでに `using mintd = DynamicModint<0>;` が宣言されており、`mintd::set_mod(m)` などでセット可能
- 複数の $ \bmod $ に対して利用したい場合は `using mintd1 = DynamicModint<-1>` などとして増やす
2 changes: 1 addition & 1 deletion docs/math/modint261.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
title: ModInt $\pmod{2^{61}-1} $
title: Modint $\pmod{2^{61}-1} $
documentation_of: //math/modint261.hpp
---

Expand Down
10 changes: 10 additions & 0 deletions docs/math/static_modint.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
title: Static Modint
documentation_of: //math/static_modint.hpp
---

コンパイル時に $ \bmod $ が確定するときに使える

以下は宣言済み
- `using mint107 = StaticModint<1000000007>;`
- `using mint998 = StaticModint<998244353>;`
111 changes: 111 additions & 0 deletions math/dynamic_modint.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#pragma once

#include <utility>
#include <cassert>

template <int id> struct DynamicModint {
using mint = DynamicModint;
unsigned int _v;

static int m;
static void set_mod(const int _m) {
assert(_m >= 1);
m = _m;
}

static int mod() { return m; }
static unsigned int umod() { return m; }

DynamicModint() : _v(0) {}

template <class T> DynamicModint(T v) {
long long x = (long long)(v % (long long)(umod()));
if (x < 0) x += umod();
_v = (unsigned int)(x);
}

unsigned int val() const { return _v; }

mint &operator++() {
_v++;
if (_v == umod()) _v = 0;
return *this;
}
mint &operator--() {
if (_v == 0) _v = umod();
_v--;
return *this;
}
mint operator++(int) {
mint result = *this;
++*this;
return result;
}
mint operator--(int) {
mint result = *this;
--*this;
return result;
}

mint &operator+=(const mint &rhs) {
_v += rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint &operator-=(const mint &rhs) {
_v -= rhs._v;
if (_v >= umod()) _v += umod();
return *this;
}
mint &operator*=(const mint &rhs) {
unsigned long long z = _v;
z *= rhs._v;
_v = (unsigned int)(z % umod());
return *this;
}
mint &operator/=(const mint &rhs) { return (*this *= rhs.inv()); }

mint operator+() const { return *this; }
mint operator-() const { return mint() - *this; }

mint pow(long long n) const {
assert(n >= 0);
mint x = *this, r = 1;
while (n) {
if (n & 1) r *= x;
x *= x;
n >>= 1;
}
return r;
}

mint inv() const {
auto eg = inv_gcd(_v, mod());
assert(eg.first == 1);
return eg.second;
}

friend mint operator+(const mint &lhs, const mint &rhs) { return mint(lhs) += rhs; }
friend mint operator-(const mint &lhs, const mint &rhs) { return mint(lhs) -= rhs; }
friend mint operator*(const mint &lhs, const mint &rhs) { return mint(lhs) *= rhs; }
friend mint operator/(const mint &lhs, const mint &rhs) { return mint(lhs) /= rhs; }
friend bool operator==(const mint &lhs, const mint &rhs) { return lhs._v == rhs._v; }
friend bool operator!=(const mint &lhs, const mint &rhs) { return lhs._v != rhs._v; }
friend std::ostream &operator<<(std::ostream &os, const mint &v) { return os << v.val(); }

static constexpr std::pair<int, int> inv_gcd(int a, int b) {
if (a == 0) return {b, 0};
int s = b, t = a, m0 = 0, m1 = 1;
while (t) {
const int u = s / t;
s -= t * u;
m0 -= m1 * u;
std::swap(s, t);
std::swap(m0, m1);
}
if (m0 < 0) m0 += b / s;
return {s, m0};
}
};
template <int id> int DynamicModint<id>::m = 998244353;
using mintd = DynamicModint<0>;
87 changes: 66 additions & 21 deletions math/static_modint.hpp
Original file line number Diff line number Diff line change
@@ -1,65 +1,68 @@
#pragma once

#include <utility>

// constexpr ... for constexpr bool prime()
template <int m> struct StaticModint {
using mint = StaticModint;
unsigned int _v;

static constexpr int mod() { return m; }
static constexpr unsigned int umod() { return m; }

StaticModint() : _v(0) {}
constexpr StaticModint() : _v(0) {}

template <class T> StaticModint(T v) {
template <class T> constexpr StaticModint(T v) {
long long x = (long long)(v % (long long)(umod()));
if (x < 0) x += umod();
_v = (unsigned int)(x);
}

unsigned int val() const { return _v; }
constexpr unsigned int val() const { return _v; }

mint &operator++() {
constexpr mint &operator++() {
_v++;
if (_v == umod()) _v = 0;
return *this;
}
mint &operator--() {
constexpr mint &operator--() {
if (_v == 0) _v = umod();
_v--;
return *this;
}
mint operator++(int) {
constexpr mint operator++(int) {
mint result = *this;
++*this;
return result;
}
mint operator--(int) {
constexpr mint operator--(int) {
mint result = *this;
--*this;
return result;
}

mint &operator+=(const mint &rhs) {
constexpr mint &operator+=(const mint &rhs) {
_v += rhs._v;
if (_v >= umod()) _v -= umod();
return *this;
}
mint &operator-=(const mint &rhs) {
constexpr mint &operator-=(const mint &rhs) {
_v -= rhs._v;
if (_v >= umod()) _v += umod();
return *this;
}
mint &operator*=(const mint &rhs) {
constexpr mint &operator*=(const mint &rhs) {
unsigned long long z = _v;
z *= rhs._v;
_v = (unsigned int)(z % umod());
return *this;
}
mint &operator/=(const mint &rhs) { return (*this *= rhs.inv()); }
constexpr mint &operator/=(const mint &rhs) { return (*this *= rhs.inv()); }

mint operator+() const { return *this; }
mint operator-() const { return mint() - *this; }
constexpr mint operator+() const { return *this; }
constexpr mint operator-() const { return mint() - *this; }

mint pow(long long n) const {
constexpr mint pow(long long n) const {
assert(n >= 0);
mint x = *this, r = 1;
while (n) {
Expand All @@ -70,15 +73,57 @@ template <int m> struct StaticModint {
return r;
}

mint inv() const { return pow(umod() - 2); }
constexpr mint inv() const {
if (prime) {
assert(_v);
return pow(umod() - 2);
} else {
auto eg = inv_gcd(_v, m);
assert(eg.first == 1);
return eg.second;
}
}

friend mint operator+(const mint &lhs, const mint &rhs) { return mint(lhs) += rhs; }
friend mint operator-(const mint &lhs, const mint &rhs) { return mint(lhs) -= rhs; }
friend mint operator*(const mint &lhs, const mint &rhs) { return mint(lhs) *= rhs; }
friend mint operator/(const mint &lhs, const mint &rhs) { return mint(lhs) /= rhs; }
friend bool operator==(const mint &lhs, const mint &rhs) { return lhs._v == rhs._v; }
friend bool operator!=(const mint &lhs, const mint &rhs) { return lhs._v != rhs._v; }
friend constexpr mint operator+(const mint &lhs, const mint &rhs) { return mint(lhs) += rhs; }
friend constexpr mint operator-(const mint &lhs, const mint &rhs) { return mint(lhs) -= rhs; }
friend constexpr mint operator*(const mint &lhs, const mint &rhs) { return mint(lhs) *= rhs; }
friend constexpr mint operator/(const mint &lhs, const mint &rhs) { return mint(lhs) /= rhs; }
friend constexpr bool operator==(const mint &lhs, const mint &rhs) { return lhs._v == rhs._v; }
friend constexpr bool operator!=(const mint &lhs, const mint &rhs) { return lhs._v != rhs._v; }
friend std::ostream &operator<<(std::ostream &os, const mint &v) { return os << v.val(); }

static constexpr bool prime = []() -> bool {
if (m == 1) return false;
if (m == 2 || m == 7 || m == 61) return true;
if (m % 2 == 0) return false;
unsigned int d = m - 1;
while (d % 2 == 0) d /= 2;
for (unsigned int a : {2, 7, 61}) {
unsigned int t = d;
mint y = mint(a).pow(t);
while (t != m - 1 and y != 1 and y != m - 1) {
y *= y;
t <<= 1;
}
if (y != m - 1 and t % 2 == 0) {
return false;
}
}
return true;
}();
static constexpr std::pair<int, int> inv_gcd(int a, int b) {
if (a == 0) return {b, 0};
int s = b, t = a, m0 = 0, m1 = 1;
while (t) {
const int u = s / t;
s -= t * u;
m0 -= m1 * u;
std::swap(s, t);
std::swap(m0, m1);
}
if (m0 < 0) m0 += b / s;
return {s, m0};
}
};
using mint107 = StaticModint<1000000007>;
using mint998 = StaticModint<998244353>;

0 comments on commit cc381ab

Please sign in to comment.