In [51]:
# https://www.codingninjas.com/studio/problems/stringmatch-rabincarp_1115738?utm_source=striver&utm_medium=website&utm_campaign=a_zcoursetuf&leftPanelTabValue=PROBLEM

In [52]:
# Rabin Karp

# There are 26 lowercase English letters, so the polynomial hash uses base 26.
from typing import List


# Radix of the polynomial hash; hash_char maps 'a'..'z' to the digit values 1..26.
BASE = 26


def hash_char(ch):
    """Map a single character to its 1-based offset from 'a'.

    'a' -> 1, 'b' -> 2, ..., 'z' -> 26. Characters outside 'a'..'z' are not
    rejected; they simply yield values outside the 1..26 range.
    """
    offset_from_a = ord(ch) - ord('a')
    return offset_from_a + 1


def get_power(n: int):
    """Return the first ``n`` powers of BASE as a list.

    The result is ``[BASE**0, BASE**1, ..., BASE**(n - 1)]``; it is empty
    when ``n`` is 0 (or negative).
    """
    powers = []
    for _ in range(n):
        # The first entry is BASE**0; every later entry is the previous one times BASE.
        powers.append(powers[-1] * BASE if powers else 1)
    return powers


def get_hash(string):
    """Return the polynomial hash of ``string`` in base BASE.

    Each character contributes ``hash_char(ch) * BASE**k``, where the first
    character gets the highest exponent (``len(string) - 1``) and the last
    character gets exponent 0. The empty string hashes to 0.
    """
    # Walk the power table from the highest exponent down so it lines up
    # with the characters read left to right.
    weights = reversed(get_power(len(string)))
    return sum(hash_char(ch) * weight for ch, weight in zip(string, weights))


def stringMatch(text: str, pattern: str) -> List[int]:
    """Find every occurrence of ``pattern`` in ``text`` with Rabin-Karp.

    A rolling polynomial hash (base BASE, reduced modulo a large prime) is
    slid across ``text``. Whenever the window hash equals the pattern hash,
    the window is compared with ``pattern`` character by character, so a
    hash collision can never be reported as a match.

    Args:
        text: The string to search in.
        pattern: The string to search for.

    Returns:
        The 1-based start positions of all (possibly overlapping) matches,
        in increasing order. An empty pattern, or a pattern longer than the
        text, yields an empty list.
    """
    n, m = len(pattern), len(text)

    # No window exists for an empty pattern, and a pattern longer than the
    # text cannot match anywhere.
    if n == 0 or n > m:
        return []

    # Reducing modulo a prime keeps every hash a bounded integer, so each
    # rolling update costs O(1) instead of growing with the pattern length.
    modulus = (1 << 61) - 1

    # Weight of the window's leftmost character: BASE**(n - 1) mod modulus.
    high_power = pow(BASE, n - 1, modulus)

    pattern_hash = get_hash(pattern) % modulus
    rolling_hash = get_hash(text[:n]) % modulus

    ans = []
    last_start = m - n

    for start in range(last_start + 1):
        # Equal hashes only make a match likely; the slice comparison confirms it.
        if rolling_hash == pattern_hash and text[start:start + n] == pattern:
            ans.append(start + 1)

        if start < last_start:
            outgoing = hash_char(text[start])
            incoming = hash_char(text[start + n])

            # Drop the leftmost character, shift the window one digit to the
            # left, then append the new rightmost character. Python's %
            # always returns a non-negative result, so the subtraction is safe.
            rolling_hash = (
                (rolling_hash - outgoing * high_power) * BASE + incoming
            ) % modulus

    return ans

In [53]:
# Sanity check: "xyz" occurs three times; expected 1-based positions are [2, 7, 13].
text = "cxyzghxyzvjkxyz"
pattern = "xyz"
stringMatch(text, pattern)

[2, 7, 13]

In [54]:
# Sanity check with overlapping matches; expected 1-based positions are [1, 3, 7].
text = "ababacabab"
pattern = "aba"
stringMatch(text, pattern)

[1, 3, 7]

In [55]:
# KMP (Knuth-Morris-Pratt) solution to the same string-matching problem

def _build_lps(pattern: str) -> List[int]:
    """Return the KMP failure table for ``pattern``.

    ``lps[i]`` is the length of the longest proper prefix of
    ``pattern[:i + 1]`` that is also a suffix of it.
    """
    lps = [0] * len(pattern)
    prefix_len = 0
    i = 1
    while i < len(pattern):
        if pattern[i] == pattern[prefix_len]:
            prefix_len += 1
            lps[i] = prefix_len
            i += 1
        elif prefix_len:
            # Fall back to the next shorter border and retry the same i.
            prefix_len = lps[prefix_len - 1]
        else:
            # No border can be extended; lps[i] stays 0.
            i += 1
    return lps


def stringMatch(text: str, pattern: str) -> List[int]:
    """Find every occurrence of ``pattern`` in ``text`` with KMP.

    Args:
        text: The string to search in.
        pattern: The string to search for.

    Returns:
        The 1-based start positions of all (possibly overlapping) matches,
        in increasing order. An empty pattern, or a pattern longer than the
        text, yields an empty list.
    """
    # pattern[0] does not exist for an empty pattern, and a pattern longer
    # than the text cannot match anywhere.
    if not pattern or len(pattern) > len(text):
        return []

    lps = _build_lps(pattern)
    pattern_len = len(pattern)
    ans = []
    j = 0  # number of pattern characters currently matched

    for i, ch in enumerate(text):
        # On a mismatch, shrink the matched prefix via the failure table
        # instead of moving backwards in the text.
        while j and ch != pattern[j]:
            j = lps[j - 1]

        if ch == pattern[j]:
            j += 1

        if j == pattern_len:
            # The match ends at 0-based index i, so it starts at
            # i - pattern_len + 1 (0-based), i.e. i - pattern_len + 2 (1-based).
            ans.append(i - pattern_len + 2)
            # Keep the longest border so overlapping matches are still found
            # without rescanning any text character.
            j = lps[j - 1]

    return ans

In [56]:
# Same input as the Rabin-Karp check, now run against the KMP version
# (assuming cells run in order); expected 1-based positions are [2, 7, 13].
text = "cxyzghxyzvjkxyz"
pattern = "xyz"
stringMatch(text, pattern)

[2, 7, 13]

In [57]:
# Overlapping-match check against the KMP version (assuming cells run in
# order); expected 1-based positions are [1, 3, 7].
text = "ababacabab"
pattern = "aba"
stringMatch(text, pattern)

[1, 3, 7]