# 文字列検索のライブラリ

https://drken1215.hatenablog.com/entry/2019/09/16/014600

# Suffix Array

Suffix Array の LCP (Longest Common Prefix)配列

https://atcoder.jp/contests/abc141/tasks/abc141_e

In [1]:
class RMQ:
    """Segment tree for range-minimum queries over a fixed-length sequence.

    The tree is stored as an implicit 0-indexed binary heap in ``dat``: the
    root is ``dat[0]`` and node ``k`` has children ``2k + 1`` and ``2k + 2``.
    Element ``i`` of the input lives in leaf ``dat[i + n - 1]``.  Leaves past
    the end of the input, and empty query ranges, evaluate to ``inf``.
    """

    # Identity element for min(); query results are only meaningful when
    # every stored value is smaller than this sentinel.
    inf = 10 ** 9 + 7

    def __init__(self, x):
        """Build the tree over the sized sequence ``x`` in O(len(x))."""
        # Leaf count: the smallest power of two strictly greater than len(x).
        self.n = 2 ** len(x).bit_length()
        self.dat = [self.inf] * (2 * self.n - 1)

        # Write every leaf first, then fill the parents bottom-up.  This gives
        # exactly the tree that per-element update() calls would, without the
        # O(n log n) cost of walking to the root for every element.
        offset = self.n - 1
        for k, v in enumerate(x):
            self.dat[offset + k] = v
        for k in range(offset - 1, -1, -1):
            self.dat[k] = min(self.dat[2 * k + 1], self.dat[2 * k + 2])

    def update(self, k, v):
        """Set element ``k`` (0-indexed) to ``v`` and refresh its ancestors."""
        k += self.n - 1
        self.dat[k] = v
        while k > 0:
            k = (k - 1) // 2
            self.dat[k] = min(self.dat[2 * k + 1], self.dat[2 * k + 2])

    def query(self, a, b, k=0, l=0, r=None):
        """Return the minimum of elements ``[a, b)`` (``inf`` if none).

        ``k``, ``l`` and ``r`` are recursion state -- node ``k`` covers the
        leaf range ``[l, r)`` -- and should be left at their defaults.
        """
        if r is None:
            r = self.n

        # Node range is disjoint from the query.
        if b <= l or r <= a:
            return self.inf

        # Node range lies entirely inside the query.
        if a <= l and r <= b:
            return self.dat[k]

        mid = (l + r) // 2
        vl = self.query(a, b, k * 2 + 1, l, mid)
        vr = self.query(a, b, k * 2 + 2, mid, r)
        return min(vl, vr)

    def __repr__(self):
        return ', '.join([str(v) for v in self.dat])

In [2]:
class SuffixArray:
    """Suffix array of a string, with its LCP array and LCP queries.

    Construction follows Manber & Myers prefix doubling, ``O(n log^2 n)``
    because every round re-sorts with a comparison sort.  The empty suffix is
    included, so ``len(self) == len(s) + 1`` and ``self.sa[0] == len(s)``.

    Attributes (after construction):
        array: the input string.
        sa: ``sa[r]`` is the start position of the ``r``-th smallest suffix.
        rank: inverse permutation of ``sa`` (``rank[sa[r]] == r``).
        lcp: ``lcp[r]`` is the LCP length of suffixes ``sa[r]`` and
            ``sa[r + 1]``; for non-empty input the last entry stays -1.
        rmq: range-minimum structure over ``lcp`` (indexed by rank).
    """
    def __init__(self, s):
        # NOTE: TypeError would be more idiomatic, but ValueError is kept so
        # existing callers that catch it keep working.
        if not isinstance(s, str):
            raise ValueError("Input must be string")

        # Input array
        self.array = s
        self.len = len(s) + 1

        # Suffix array; the trailing -1 makes the empty suffix sort first.
        self.sa = list(range(self.len))
        self.rank = list(map(ord, self.array)) + [-1]

        # Longest Common Prefix between rank-adjacent suffixes
        self.lcp = [-1] * self.len

        # RMQ over self.lcp, used by get_lcp()
        self.rmq = None

        # Initialize
        self._construct()
        self._calc_lcp()
        self._init_rmq()

    def __getitem__(self, index):
        """Return the suffix with rank ``index``."""
        return self.array[self.sa[index]:]

    def __len__(self):
        return self.len

    def _construct(self):
        """Sort suffixes by prefix doubling.

        After the round for ``k``, ``sa`` is ordered by the first ``2k``
        characters of each suffix and ``rank`` holds the dense ranks.
        """
        n = self.len
        k = 1
        while k <= n:
            self.sa.sort(key=lambda x: self._key(x, k))

            # Dense re-ranking: equal keys share a rank.
            tmp = [0] * n
            for i in range(1, n):
                tmp[self.sa[i]] = (tmp[self.sa[i - 1]]
                                   + int(self.compare(i - 1, i, k)))
            self.rank = tmp

            # Every rank distinct: the order is final and further rounds
            # would leave both sa and rank unchanged.
            if tmp[self.sa[-1]] == n - 1:
                break
            k *= 2

    def _key(self, i, k):
        """Returns key for sort: (rank of suffix i, rank of suffix i + k)."""
        v1 = self.rank[i]
        v2 = self.rank[i + k] if i + k < self.len else -1
        return (v1, v2)

    def compare(self, i, j, k):
        """Return True if the suffix at rank ``i`` sorts strictly before the one at rank ``j``."""
        key1 = self._key(self.sa[i], k)
        key2 = self._key(self.sa[j], k)

        # Compare rank[i] and rank[j]
        if key1[0] != key2[0]:
            return key1[0] < key2[0]

        # Compare rank[i + k] and rank[j + k]
        return key1[1] < key2[1]

    def _calc_lcp(self):
        """Fill ``lcp`` with Kasai's algorithm and make ``rank`` the inverse of ``sa``."""
        n = self.len
        for i in range(n):
            self.rank[self.sa[i]] = i

        self.lcp[0] = 0
        h = 0
        # Visit suffixes in text order: the LCP with the rank-predecessor
        # drops by at most one from one position to the next.
        for i in range(n - 1):
            j = self.sa[self.rank[i] - 1]

            if h > 0:
                h -= 1
            while j + h < n - 1 and i + h < n - 1:
                if self.array[j + h] != self.array[i + h]:
                    break
                h += 1

            self.lcp[self.rank[i] - 1] = h

    def _init_rmq(self):
        # get_lcp() queries the *rank* interval [rank[i], rank[j]), so the
        # RMQ must be indexed by rank -- i.e. built directly over self.lcp.
        self.rmq = RMQ(self.lcp)

    def get_lcp(self, i, j):
        """Return the length of the longest common prefix of s[i:] and s[j:]."""
        ri, rj = self.rank[i], self.rank[j]
        if i == j:
            # A suffix shares its whole length with itself.
            return self.len - 1 - i
        if ri > rj:
            ri, rj = rj, ri

        # LCP of two suffixes = min of the adjacent LCPs between their ranks.
        return self.rmq.query(ri, rj)

In [3]:
sa = SuffixArray("abracadabra")

# One row per rank: rank, start position, LCP with the next suffix, the suffix.
for rank, (start, lcp_next) in enumerate(zip(sa.sa, sa.lcp)):
    print(rank, start, lcp_next, sa[rank])

0 11 0 
1 10 1 a
2 7 4 abra
3 0 1 abracadabra
4 3 1 acadabra
5 5 0 adabra
6 8 3 bra
7 1 0 bracadabra
8 4 0 cadabra
9 6 0 dabra
10 9 2 ra
11 2 -1 racadabra


In [4]:
# Walk the suffixes in text order via their ranks.
for r in sa.rank:
    print(sa[r], sa.lcp[r])

abracadabra 1
bracadabra 0
racadabra -1
acadabra 1
cadabra 0
adabra 0
dabra 0
abra 4
bra 3
ra 2
a 1
 0


重ならずに2回以上現れる部分文字列の最大長（ABC141 E）

In [5]:
# ABC141 E: longest substring occurring twice without overlap.
# For start positions i < j the usable length is min(LCP(i, j), j - i).
# len(sa) also counts the empty suffix, so iterate over real positions only,
# starting from 0 so that the suffix at position 0 is not skipped.
n = len(sa) - 1
res = 0
for i in range(n):
    for j in range(i + 1, n):
        res = max(res, min(sa.get_lcp(i, j), j - i))
print(res, "(ans=4)")

4 (ans=4)


# Z-algorithm

各`i`について`LCP(0, i)`（`S[0:]`と`S[i:]`との最長共通接頭辞の長さ）を$O(|S|)$で求める．（下の実装では`i = 0`の値は`|S|`ではなく0のままにしている．）

解説

* https://qiita.com/Pro_ktmr/items/16904c9570aa0953bf05
* https://snuke.hatenablog.com/entry/2014/12/03/214243

In [9]:
def z_algorithm(s):
    """Return the Z-array of ``s`` in O(len(s)).

    For ``i >= 1``, ``z[i]`` is the length of the longest common prefix of
    ``s`` and ``s[i:]``.  ``z[0]`` is left as 0 (not ``len(s)``).
    """
    size = len(s)
    z = [0] * size
    pos, matched = 1, 0
    while pos < size:
        # Extend the match between s[0:] and s[pos:] as far as it goes.
        while pos + matched < size and s[matched] == s[pos + matched]:
            matched += 1
        z[pos] = matched

        if not matched:
            pos += 1
            continue

        # Inside the window s[pos:pos + matched], which mirrors s[0:matched],
        # a z-value can be copied while it stays strictly inside the window.
        step = 1
        while pos + step < size and step + z[step] < matched:
            z[pos + step] = z[step]
            step += 1

        pos += step
        matched -= step

    return z

In [17]:
# Z-array of "abracadabra": for i >= 1, entry i is the LCP of the string and its suffix at i.
z_algorithm("abracadabra")

[0, 0, 0, 1, 0, 1, 0, 4, 0, 0, 1]

In [19]:
S = "abracadabra"
N = len(S)

# For each start i, the Z-array of S[i:] gives LCP(S[i:], S[i+j:]) at index j;
# capping it by the shift j keeps the two occurrences from overlapping.
res = 0
for i in range(N):
    z = z_algorithm(S[i:])
    res = max(res, max(min(length, shift) for shift, length in enumerate(z)))

print(res, "(ans=4)")

4 (ans=4)


# ローリングハッシュ＋二分探索

In [1]:
class RollingHash:
    """Double polynomial rolling hash over a string.

    Prefix hashes are kept modulo two primes with different bases, so two
    substrings are treated as equal only when both hashes agree.
    ``get_hash`` is O(1); ``get_lcp`` is O(log n).
    """
    base1 = 1007
    base2 = 2009
    mod1 = 10 ** 9 + 7
    mod2 = 10 ** 9 + 9

    def __init__(self, s):
        self.s = s
        self.len = len(s) + 1

        # hashX[i] is the hash of s[:i]; powerX[i] is baseX ** i (mod modX).
        self.hash1 = [0] * self.len
        self.hash2 = [0] * self.len
        self.power1 = [1] * self.len
        self.power2 = [1] * self.len
        b1, b2 = self.base1, self.base2
        m1, m2 = self.mod1, self.mod2
        for i, ch in enumerate(s):
            code = ord(ch)
            self.hash1[i + 1] = (self.hash1[i] * b1 + code) % m1
            self.hash2[i + 1] = (self.hash2[i] * b2 + code) % m2
            self.power1[i + 1] = self.power1[i] * b1 % m1
            self.power2[i + 1] = self.power2[i] * b2 % m2

    def get_hash(self, l, r):
        """Return the hash pair of s[l:r], each component in [0, mod)."""
        width = r - l
        first = (self.hash1[r] - self.hash1[l] * self.power1[width]) % self.mod1
        second = (self.hash2[r] - self.hash2[l] * self.power2[width]) % self.mod2
        return (first, second)

    def get_lcp(self, a, b):
        """Return the length of the longest common prefix of s[a:] and s[b:]."""
        # Invariant: a common prefix of length `ok` exists, one of length `ng`
        # does not (`ng` starts one past the longest length both suffixes have).
        ok, ng = 0, self.len - max(a, b)
        while ng - ok > 1:
            mid = (ok + ng) // 2
            if self.get_hash(a, a + mid) == self.get_hash(b, b + mid):
                ok = mid
            else:
                ng = mid
        return ok

    def bisect(self, d):
        """Return True if some length-``d`` substring occurs twice without overlap.

        That is, whether there are ``i < j`` with ``s[i:i+d] == s[j:j+d]``
        (compared by hash) and ``i + d <= j``.  Intended as the predicate of a
        binary search over ``d``.
        """
        first_seen = {}
        for i in range(self.len - d):
            key = self.get_hash(i, i + d)
            if key not in first_seen:
                # Keep the earliest start: it maximises the distance to later ones.
                first_seen[key] = i
            elif i - first_seen[key] >= d:
                return True
        return False

In [8]:
S = "abracadabra"
N = len(S)
rh = RollingHash(S)

# Try every pair of starts i < j; the usable repeat length is the LCP
# capped by j - i so that the two occurrences do not overlap.
res = max(
    (min(rh.get_lcp(i, j), j - i) for i in range(N) for j in range(i + 1, N)),
    default=0,
)

print(res, "(ans=4)")

4 (ans=4)


# ローリングハッシュ＋二分探索の高速化

In [2]:
S = "abracadabra"
N = len(S)
rh = RollingHash(S)

# Invariant: a non-overlapping repeat of length `left` exists and one of
# length `right` does not (a repeat of length d needs 2 * d <= N).
left, right = 0, N // 2 + 1
while right - left > 1:
    mid = (left + right) // 2
    if not rh.bisect(mid):
        right = mid
    else:
        left = mid

print(left, "(ans=4)")

4 (ans=4)


# DPで文字列検索

`dp[i][j]`:= `S[i:]`と`S[j:]`（i文字目から始まる部分とj文字目から始まる部分）の最長共通接頭辞の長さ（`i < j`）．重なりを避けるため，答えは`min(dp[i][j], j - i)`の最大値．

In [8]:
S = "abracadabra"
N = len(S)

# dp[i][j]: length of the longest common prefix of S[i:] and S[j:] (i < j);
# row/column N is a zero sentinel for the empty suffix.
dp = [[0] * (N + 1) for _ in range(N + 1)]
res = 0
for i in range(N - 1, -1, -1):
    for j in range(N - 1, i, -1):
        if S[i] == S[j]:
            dp[i][j] = dp[i + 1][j + 1] + 1
        # Cap by the distance so the two occurrences do not overlap.
        res = max(res, min(dp[i][j], j - i))

print(res, "(ans=4)")

4 (ans=4)


# KMP法

Knuth-Morris-Pratt algorithm

* https://snuke.hatenablog.com/entry/2014/12/01/235807
* https://snuke.hatenablog.com/entry/2017/07/18/101026

#### MP法

各`i`について「文字列`S[0:i]`（先頭`i`文字）の，それ自身とは異なる接頭辞で接尾辞にも一致するものの最大長」の配列を$O(|S|)$で求める（`a[0] = -1`）．

In [12]:
def mp(s):
    """Return the Morris-Pratt failure table of ``s`` in O(len(s)).

    The result has ``len(s) + 1`` entries: ``table[0] == -1`` and, for
    ``i >= 1``, ``table[i]`` is the length of the longest proper prefix of
    ``s[:i]`` that is also a suffix of ``s[:i]``.
    """
    table = [-1] + [0] * len(s)
    matched = -1
    for pos, ch in enumerate(s):
        # Fall back along shorter borders until the next character matches.
        while matched >= 0 and ch != s[matched]:
            matched = table[matched]
        matched += 1
        table[pos + 1] = matched
    return table

In [13]:
S = "aabaabaaa"
# a[i] = length of the longest proper prefix of S[:i] that is also its suffix; a[0] = -1.
a = mp(S)

a

[-1, 0, 1, 0, 1, 2, 3, 4, 5, 2]

#### KMP法

In [1]:
def kmp(s):
    """Return the Knuth-Morris-Pratt failure table of ``s`` in O(len(s)).

    Like ``mp``, the table has ``len(s) + 1`` entries and starts with -1.
    For an inner position, when the character after the current border equals
    the character at that position, the entry falls back further (to the
    border's own entry), so a mismatch is never retried against the same
    character.  The last entry is the plain border length of all of ``s``.
    """
    n = len(s)
    table = [-1] + [0] * n
    matched = -1
    for pos, ch in enumerate(s):
        while matched >= 0 and ch != s[matched]:
            matched = table[matched]
        matched += 1

        nxt = pos + 1
        # Falling back to `matched` would compare against the same character
        # again, so inherit that position's fallback instead.
        if nxt < n and s[nxt] == s[matched]:
            table[nxt] = table[matched]
        else:
            table[nxt] = matched
    return table

In [2]:
S = "aabaabaaa"
# KMP table: where the character after a border equals S[i], entry i falls
# back further than the MP table does; the last entry equals the MP value.
a = kmp(S)

a

[-1, -1, 1, -1, -1, 1, -1, -1, 5, 2]

In [6]:
S = "abracadabra"
# KMP failure table of "abracadabra".
a = kmp(S)