In [15]:
class Symbol:
  """An FSST symbol: a table slot that expands to a short string.

  Attributes:
    index: slot of this symbol in the symbol table; StateMachine uses it as
      the column offset into its flat state-transition cache.
    val: the decoded text this symbol stands for.
  """

  def __init__(self, index, val):
    self.index, self.val = index, val

  def __str__(self):
    # Rendering a symbol yields its decoded text, so ''.join(map(str, row))
    # over a mixed letter/symbol row reconstructs the raw string.
    decoded = self.val
    return decoded

class StateMachine:
  """KMP automaton that tests whether a pattern occurs in a row of letters
  and FSST symbols.

  The state is the length of the longest prefix of `pattern` that is a suffix
  of the input consumed so far; state `len(pattern)` means the pattern has
  been found and is absorbing. Letters are consumed with the KMP failure
  function `pi`. Symbols are either replayed letter by letter (`naive_match`)
  or resolved with one lookup in a per-(state, symbol) transition table
  (`fast_match`, which requires `build_state_cache` first).
  """

  def __init__(self, pattern):
    self.pattern = pattern
    self.pi = StateMachine.build_pi(pattern)
    # Filled in by build_state_cache(); None until then so the asserts in
    # accept_cached_symbol/fast_match fire instead of an AttributeError.
    self.fsst_symbols = None
    self.cache_state = None
    self.curr_state = 0

  def _is_final(self):
    """Return True once the whole pattern has been matched."""
    return self.curr_state == len(self.pattern)

  def build_state_cache(self, fsst_symbols):
    """Precompute, for every non-final state and every symbol, the next state.

    Args:
      fsst_symbols: sequence of symbols with `.index` and `.val`. The flat
        layout below requires the indices to be exactly 0..len(fsst_symbols)-1.

    The table is stored row-major by state in `self.cache_state`:
    entry `state * len(fsst_symbols) + symbol.index`.
    """
    assert self.pattern is not None

    self.fsst_symbols = fsst_symbols
    num_symbols = len(fsst_symbols)
    self.cache_state = [-1] * (len(self.pattern) * num_symbols)

    # Matching stops as soon as the final state is reached, so only states
    # 0..len(pattern)-1 ever need an outgoing row.
    for state in range(len(self.pattern)):
      for symbol in fsst_symbols:
        # Start from `state`, replay the symbol, record where we end up.
        self.init_state(state)
        self.accept_symbol(symbol)
        self.cache_state[state * num_symbols + symbol.index] = self.curr_state

  def init_state(self, pos=0):
    """Set the current automaton state (default: the start state 0)."""
    self.curr_state = pos

  def accept_letter(self, letter):
    """Advance the automaton by one letter using the KMP failure function."""
    # The final state is absorbing: once the pattern was seen, the text
    # contains it whatever follows (and pattern[len(pattern)] does not exist).
    if self._is_final():
      return
    while self.curr_state > 0 and self.pattern[self.curr_state] != letter:
      self.curr_state = self.pi[self.curr_state - 1]
    if self.pattern[self.curr_state] == letter:
      self.curr_state += 1

  def accept_symbol(self, symbol):
    """Feed each letter of `symbol.val`, stopping early at the final state."""
    for letter in symbol.val:
      self.accept_letter(letter)
      if self._is_final():
        return

  def accept_cached_symbol(self, symbol):
    """Advance by a whole symbol with a single transition-table lookup."""
    assert self.fsst_symbols is not None
    assert self.cache_state is not None

    # The table has no row for the (absorbing) final state.
    if self._is_final():
      return
    row = self.curr_state * len(self.fsst_symbols)
    self.curr_state = self.cache_state[row + symbol.index]

  @staticmethod
  def build_pi(P):
    """Compute the KMP prefix (failure) function of `P`.

    pi[q] is the length of the longest proper prefix of P[:q + 1] that is also
    a suffix of it. Returns [] for an empty pattern.
    """
    pi = [0] * len(P)  # pi[0] is always 0: nowhere to fall back to.
    k = 0
    for q in range(1, len(P)):
      while k > 0 and P[k] != P[q]:
        k = pi[k - 1]
      # Match? Then extend the current border.
      if P[k] == P[q]:
        k += 1
      pi[q] = k
    return pi

  def _match(self, text, consume_symbol):
    """Run the automaton over `text`; True as soon as the pattern is found.

    `text` items are either Symbols (fed to `consume_symbol`) or single letters.
    """
    self.init_state()
    # Only an empty pattern starts in the final state; it occurs in every text.
    if self._is_final():
      return True
    for item in text:
      if isinstance(item, Symbol):
        consume_symbol(item)
      else:
        self.accept_letter(item)
      if self._is_final():
        return True
    return False

  def naive_match(self, text):
    """Match `text`, replaying every symbol letter by letter."""
    return self._match(text, self.accept_symbol)

  def fast_match(self, text):
    """Match `text` using the precomputed symbol transitions."""
    assert self.cache_state is not None
    return self._match(text, self.accept_cached_symbol)

In [23]:
import random
import string
import time

# Fix the RNG seed so the synthetic corpus (and thus the benchmark) is reproducible.
random.seed(0)

# Alphabet for both raw letters and the text of FSST symbols.
letter_set = list('abc')

def generate_random_string(fsst_symbols, length=20, symbol_prob=0.2):
  """Generate one compressed row as a list of letters and FSST symbols.

  Args:
    fsst_symbols: symbol table to draw symbols from.
    length: number of items (letters + symbols) in the row.
    symbol_prob: probability that a position holds a symbol rather than a
      single letter from `letter_set`.

  Returns:
    A list mixing `str` letters and symbol objects.
  """
  # random.random() is drawn before random.choice() at every position, exactly
  # as before, so the seeded corpus is unchanged with the default arguments.
  return [random.choice(fsst_symbols if random.random() < symbol_prob else letter_set)
          for _ in range(length)]

def generate_random_symbol_string(length=8):
  """Return the text of a random FSST symbol.

  Args:
    length: number of letters, drawn with replacement from `letter_set`.
  """
  return ''.join(random.choices(letter_set, k=length))

def python_match(column_data, pattern):
  """Baseline: indices of rows whose decoded text contains `pattern`.

  Works on the raw (decompressed) strings using Python's built-in `in`.
  """
  return [row for row, raw_text in enumerate(column_data) if pattern in raw_text]

def eval_naive_match(column_data, pattern):
  """Indices of rows matched by the KMP automaton, replaying each symbol letter by letter."""
  matcher = StateMachine(pattern)
  return [row for row, text in enumerate(column_data) if matcher.naive_match(text)]

def eval_fast_match(column_data, pattern, symbols=None):
  """Indices of rows matched by the KMP automaton using cached symbol transitions.

  Args:
    column_data: rows, each a sequence of letters and Symbols.
    pattern: substring to search for.
    symbols: symbol table used to build the transition cache. Defaults to the
      notebook-level `fsst_symbols` (the previous behaviour); pass it
      explicitly to avoid depending on a global defined later in the cell.

  Building the cache (one simulation per (state, symbol) pair) is part of the
  cost whenever this function is timed.
  """
  if symbols is None:
    symbols = fsst_symbols
  matcher = StateMachine(pattern)
  matcher.build_state_cache(symbols)
  return [row for row, text in enumerate(column_data) if matcher.fast_match(text)]

FSST_SIZE = 256          # entries in the symbol table
N = 10_000               # rows in the synthetic column
MAX_PATTERN_LENGTH = 9   # benchmark patterns of length 1..MAX_PATTERN_LENGTH

# Same RNG draw order as before: symbol table first, then the rows.
fsst_symbols = [Symbol(i, generate_random_symbol_string()) for i in range(FSST_SIZE)]
strings = [generate_random_string(fsst_symbols) for _ in range(N)]

# Decompressed form of every row, used by the built-in substring baseline.
raw_strings = [''.join(map(str, row)) for row in strings]

for k in range(1, MAX_PATTERN_LENGTH + 1):
  pattern = ''.join(random.choices(letter_set, k=k))

  # perf_counter_ns is monotonic and high-resolution; time_ns is wall-clock.
  t0 = time.perf_counter_ns()
  ret0 = python_match(raw_strings, pattern)
  t1 = time.perf_counter_ns()
  ret1 = eval_naive_match(strings, pattern)
  t2 = time.perf_counter_ns()
  ret2 = eval_fast_match(strings, pattern)
  t3 = time.perf_counter_ns()

  print(f'pattern= {pattern}:')
  assert ret0 == ret1, f'naive match disagrees with Python for {pattern!r}'
  assert ret0 == ret2, f'fast match disagrees with Python for {pattern!r}'

  print(f'Python: {(t1 - t0) / 1e6} ms')
  print(f'Naive: {(t2 - t1) / 1e6} ms')
  print(f'Fast: {(t3 - t2) / 1e6} ms')

patttern= c:
Python: 0.669 ms
Naive: 8.007 ms
Fast: 6.562 ms
patttern= cb:
Python: 1.113 ms
Naive: 21.972 ms
Fast: 13.344 ms
patttern= cbc:
Python: 1.711 ms
Naive: 55.511 ms
Fast: 28.247 ms
patttern= baca:
Python: 3.131 ms
Naive: 91.087 ms
Fast: 43.437 ms
patttern= bbaca:
Python: 3.515 ms
Naive: 101.606 ms
Fast: 48.529 ms
patttern= bccacb:
Python: 2.412 ms
Naive: 110.959 ms
Fast: 51.632 ms
patttern= cbcbacc:
Python: 4.375 ms
Naive: 111.695 ms
Fast: 52.691 ms
patttern= bbbbaacc:
Python: 4.555 ms
Naive: 109.001 ms
Fast: 51.984 ms
patttern= bbacbbabc:
Python: 2.336 ms
Naive: 109.475 ms
Fast: 53.076 ms
