# fast motif scanning with an Aho-Corasick automaton

This has been my (Maarten's) attempt at implementing an Aho-Corasick automaton for **motif scanning**. The automaton is essentially an implementation of the method described in the paper [Finding Significant Matches of Position Weight Matrices in Linear Time](https://ieeexplore.ieee.org/document/4803829), of which [MOODS](https://github.com/jhkorhonen/MOODS) is already a Python implementation. According to the paper, this approach should scan roughly **100 times faster** than a naive-ish scan like the one gimmemotifs uses. It would be awesome to have this incorporated, but there are some problems I haven't been able to solve... (yet?)

Normal motif scanning takes O(mn) time, where m is the motif length and n the sequence length, whereas Aho-Corasick scanning takes only O(n) time (plus the number of reported hits). This means that the time it takes to scan a sequence depends only on the sequence you are scanning, not on the motif length! How does this work?! By not caring about poor matches. Do we actually care about poor matches of a motif against the sequence we are scanning? Not really, right? So if we set a threshold on the minimum motif score, we can compute **all** sequences of length m that would pass that threshold. We can then build an Aho-Corasick automaton from those sequences (I like [this](https://www.youtube.com/watch?v=O7_w001f58c) explanation on YouTube) and scan in linear time!
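
To make the "build an automaton, then scan in linear time" step concrete, here is a minimal pure-Python sketch of an Aho-Corasick automaton over plain strings. It is a teaching stand-in, not how pyahocorasick is implemented, and the names `build_automaton` and `scan_automaton` are hypothetical. The trie holds all the precomputed kmers, the failure links let the scan continue after a mismatch without ever moving backwards in the sequence, and each state lists every pattern that ends there.

```python
from collections import deque


def build_automaton(patterns):
    """Build the trie, failure links and outputs of an Aho-Corasick automaton."""
    goto = [{}]    # goto[state][nucleotide] -> child state (the trie edges)
    fail = [0]     # fail[state] -> state of the longest proper suffix that is also in the trie
    output = [[]]  # output[state] -> patterns that end at this state
    for pattern in patterns:
        state = 0
        for nuc in pattern:
            if nuc not in goto[state]:
                goto.append({})
                fail.append(0)
                output.append([])
                goto[state][nuc] = len(goto) - 1
            state = goto[state][nuc]
        output[state].append(pattern)

    # breadth-first, so every failure link points to an already finished (shallower) state
    queue = deque(goto[0].values())  # children of the root keep fail = 0
    while queue:
        state = queue.popleft()
        for nuc, child in goto[state].items():
            queue.append(child)
            link = fail[state]
            while link and nuc not in goto[link]:
                link = fail[link]
            fail[child] = goto[link].get(nuc, 0)
            # a match at `child` also implies matches of the patterns at its failure state
            output[child] = output[child] + output[fail[child]]
    return goto, fail, output


def scan_automaton(sequence, automaton):
    """Yield (end_index, pattern) for every occurrence of every pattern in `sequence`."""
    goto, fail, output = automaton
    state = 0
    for pos, nuc in enumerate(sequence):
        # on a mismatch, fall back along failure links instead of restarting the scan
        while state and nuc not in goto[state]:
            state = fail[state]
        state = goto[state].get(nuc, 0)
        for pattern in output[state]:
            yield pos, pattern
```

For example, scanning `"ACGTTA"` for `["ACG", "CGT", "GT", "TTA"]` reports `(2, "ACG")`, `(3, "CGT")`, `(3, "GT")` and `(5, "TTA")`. Like pyahocorasick's `Automaton.iter`, hits are reported at the index of the kmer's last character; in the notebook the stored values are tuples of motif indices rather than the patterns themselves.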

One of the most obvious problems with this approach is that precomputing all sequences that pass a threshold takes an enormous amount of memory! This is the biggest problem I wasn't able to solve. I'm not sure what would help here; perhaps splitting the sequences across multiple automata and loading them one by one, although that would probably cost a lot of performance. A related problem: because the full set of kmers takes so much memory, we have to scan for shorter sub-motifs inside longer motifs. For example, a motif of length 20 is too long to enumerate all kmers above a certain threshold (memory-wise), so instead we only compute all kmers of length 10 that could still pass that threshold, assuming the best possible score on the remaining 10 positions. However, this gives so many false positives that we end up rescanning practically the whole sequence anyway.
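
Before paying the memory cost, it can help to know how many kmers a motif would contribute. A minimal sketch, not part of the notebook: the hypothetical `count_kmers_above` counts them with dynamic programming over the distribution of partial scores, so it never materializes the kmers themselves. The assumption is that log-odds are rounded to multiples of `1/scale` (hypothetical default `scale=100`), which makes the count approximate for real-valued PWMs.

```python
from collections import defaultdict

import numpy as np


def count_kmers_above(logodds, threshold, scale=100):
    """Count kmers scoring >= threshold without enumerating them.

    Log-odds are rounded to multiples of 1/scale, so the count is only exact
    when the scores already are such multiples; otherwise it is approximate.
    """
    int_logodds = np.rint(np.asarray(logodds, dtype=float) * scale).astype(int)
    counts = {0: 1}  # rounded partial score -> number of prefixes with that score
    for row in int_logodds:
        new_counts = defaultdict(int)
        for score, n in counts.items():
            for value in row:
                new_counts[score + int(value)] += n
        counts = new_counts
    # small epsilon so that e.g. a threshold of 2.0 maps to exactly 200 (with scale=100)
    int_threshold = int(np.ceil(threshold * scale - 1e-9))
    return sum(n for score, n in counts.items() if score >= int_threshold)
```

Applied to something like `np.array(motifs[name].logodds)` and `thresholds[name]`, this would flag up front which motifs would blow up the automaton. The number of distinct rounded scores grows only linearly with motif length, so the count stays cheap even for a motif of length 20, where the enumeration itself is out of reach.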

Setting a low fpr (i.e. a stricter threshold) somewhat alleviates these problems, but then short motifs can no longer be scanned, because hits for them become (nearly) impossible, which makes it impractical for real-life use. Another option I can think of is reimplementing an Aho-Corasick automaton from scratch that reserves only 2 bits per nucleotide, instead of the roughly 32/64 bits it uses now (the current implementation uses [pyahocorasick](https://github.com/WojciechMula/pyahocorasick), which is awesome). But is that really worth it? Probably not! Would our own automaton actually be fast? Again, probably not! Maybe there is a very obvious optimization I haven't seen or thought of yet!
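
To get a feel for the 2-bit idea, here is a minimal sketch (one assumed way of doing it, not pyahocorasick internals) that packs each nucleotide into 2 bits, so a kmer of up to 32 nucleotides fits in a single 64-bit integer. Because all kmers of one motif have the same length, this sketch swaps the automaton for a set lookup on a rolling 2-bit window, which is also linear in the sequence length. That is a different technique from Aho-Corasick and only works for fixed-length patterns. The names `encode_kmer`, `decode_kmer` and `scan_packed` are hypothetical, and uppercase sequences are assumed.

```python
NUC2BITS = {"A": 0, "C": 1, "G": 2, "T": 3}


def encode_kmer(kmer):
    """Pack a kmer into an integer, 2 bits per nucleotide (first nucleotide in the high bits)."""
    code = 0
    for nuc in kmer:
        code = (code << 2) | NUC2BITS[nuc]
    return code


def decode_kmer(code, k):
    """Inverse of encode_kmer for a kmer of length k."""
    nucs = []
    for _ in range(k):
        nucs.append("ACGT"[code & 3])
        code >>= 2
    return "".join(reversed(nucs))


def scan_packed(sequence, kmer_codes, k):
    """Yield the (inclusive) end position of every window whose packed code is in kmer_codes."""
    mask = (1 << (2 * k)) - 1  # keep only the last k nucleotides
    code = 0
    valid = 0  # number of consecutive A/C/G/T characters seen
    for pos, nuc in enumerate(sequence):
        bits = NUC2BITS.get(nuc)
        if bits is None:  # e.g. an N: no window can span it, so restart
            code, valid = 0, 0
            continue
        code = ((code << 2) | bits) & mask
        valid += 1
        if valid >= k and code in kmer_codes:
            yield pos
```

A dense bitmap over all 4^m codes would need 4^14 bits = 32 MiB for m = 14, so whether a set of codes or a bitmap is smaller depends on how many kmers actually pass the threshold.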

Anyway, the current implementation is pretty darn fast for short-ish motifs (length up to about 14; around 35 times faster!), but the exploding memory usage makes it unusable for longer motifs... :( A hybrid method that picks the scanner based on motif length seems like a pain to implement, and I'm not sure it would even improve much given the complexities of O(mn) and O(n).

```python
import time
import random
from multiprocessing import Pool

import ahocorasick
import numpy as np
import gimmemotifs as gimme
from gimmemotifs.c_metrics import pwmscan
```

```python
# load our gimme motifs
motifs = gimme.motif.read_motifs(as_dict=True)
```

```python
# get our false positive thresholds from gimme (takes a while the first time)
s = gimme.scanner.Scanner()
s.set_genome("hg38")
s.set_motifs([motif for motif in motifs.values()])
s.set_threshold(fpr=0.01)
seqs = ["".join([random.choice(["A", "C", "G", "T"]) for _ in range(200)]) for _ in range(20_000)]
thresholds = s.get_gc_thresholds(seqs)
thresholds
```

```
{'GM.5.0.Sox.0001': None,
 'GM.5.0.Homeodomain.0001': 8.957283373141681,
 'GM.5.0.Mixed.0001': 7.216143228712358,
 'GM.5.0.Nuclear_receptor.0001': 6.533246258006252,
 'GM.5.0.Mixed.0002': 8.458549842945988,
 'GM.5.0.Nuclear_receptor.0002': 6.998579776560598,
 'GM.5.0.bHLH.0001': None,
 'GM.5.0.Myb_SANT.0001': 6.606636021267626,
 'GM.5.0.C2H2_ZF.0001': 8.608872504295903,
 'GM.5.0.GATA.0001': 8.92030307722338,
 'GM.5.0.C2H2_ZF.0002': 9.363668906652974,
 'GM.5.0.bZIP.0001': 7.585039441313281,
 'GM.5.0.Homeodomain.0002': None,
 'GM.5.0.Ets.0001': None,
 'GM.5.0.bZIP.0002': 7.486839360932254,
 'GM.5.0.IRF.0001': 9.557378169169375,
 'GM.5.0.bHLH.0002': 8.43394830235204,
 'GM.5.0.Nuclear_receptor.0003': 12.464078241477036,
 'GM.5.0.bZIP.0003': 8.419496266209979,
 'GM.5.0.bHLH.0003': 9.68711545390085,
 'GM.5.0.HSF.0001': 12.100057567291481,
 'GM.5.0.Homeodomain.0003': 10.04677662189859,
 'GM.5.0.C2H2_ZF.0003': 8.908807970364814,
 'GM.5.0.ARID_BRIGHT.0001': None,
 'GM.5.0.Homeodomain_POU.0001':
```

```python
# make random sequences to benchmark against
nr_seq = 2_000
len_seq = 200

sequences = ["".join(random.choice("ACGT") for _ in range(len_seq)) for _ in range(nr_seq)]
```

## Get all potential kmers for a motif and threshold

Remember that an Aho-Corasick automaton needs all the kmers that pass a motif's threshold to be precomputed? Well... that's a pain to compute. I implemented a recursive (depth-first) function that does it, pruning every prefix that cannot reach the threshold even with the best-scoring nucleotides at all remaining positions. It works decently well, but for long kmers it takes really long (though there are also a lot of possible matches).

```python
def all_kmers(motif, threshold, max_length=None, reverse_complement=False):
    """
    Get all kmers for a motif that score at or above a threshold.

    If max_length is shorter than the motif, only the max_length-wide window with
    the highest attainable score is enumerated, and the threshold is lowered by the
    best possible score of the positions outside that window.
    """
    # pre-define our private helper function that yields everything we need
    def depth_first_search(logodds, threshold, reverse_complement=False):
        """
        Yield every kmer (and optionally its reverse complement) whose
        log-odds score reaches the threshold.
        """
        def _depth_first_search(sequence, score):
            """
            Extend the prefix `sequence` (partial score `score`) one nucleotide at a
            time, pruning branches that cannot reach the threshold even with the
            best remaining nucleotides.
            """
            seq_len = len(sequence)

            # cumlogodds[seq_len] is the best score attainable from position seq_len onwards;
            # the 1e-5 tolerance guards against floating-point rounding
            if (score + cumlogodds[seq_len]) + 1.e-5 >= threshold:
                for nuc_pos, nuc_logodds in enumerate(logodds[seq_len]):
                    newseq = sequence + idx2nuc[nuc_pos]
                    newscore = score + nuc_logodds

                    if len(newseq) == logodds.shape[0]:
                        if newscore + 1.e-5 >= threshold:
                            yield newseq
                    else:
                        yield from _depth_first_search(newseq, newscore)

        idx2nuc = ["A", "C", "G", "T"]
        if reverse_complement:
            reverse_dict = {"A": "T", "C": "G", "G": "C", "T": "A"}
            for kmer in _depth_first_search("", 0):
                yield kmer
                rc_kmer = "".join(reverse_dict[nuc] for nuc in reversed(kmer))
                # palindromic kmers are their own reverse complement; don't yield them twice
                if rc_kmer != kmer:
                    yield rc_kmer
        else:
            yield from _depth_first_search("", 0)

    logodds = np.array(motif.logodds)

    if max_length is None or max_length > logodds.shape[0]:
        max_length = logodds.shape[0]

    if max_length < logodds.shape[0]:
        # pick the window of width max_length with the highest attainable score
        maxes = np.max(logodds, axis=1)
        sums = [np.sum(maxes[i:i + max_length]) for i in range(logodds.shape[0] - max_length + 1)]
        start = int(np.argmax(sums))
        # presumably meant as offsets from a hit back to the full motif (not used below)
        extend = (-max_length - start + 1, -max_length + logodds.shape[0] + start + 1)
        logodds = logodds[start:start + max_length]
        # lower the threshold by the best possible score outside the window
        threshold -= np.sum(maxes[:start]) + np.sum(maxes[start + max_length:])
    else:
        extend = (-logodds.shape[0] + 1, + 1)

    # best score attainable from each position to the end of the (windowed) motif
    cumlogodds = np.flip(np.cumsum(np.flip(np.max(logodds, axis=1))))

    kmers = list(depth_first_search(logodds, threshold, reverse_complement))
    return kmers

# kmers = all_kmers(motifs["GM.5.0.MADS_box.0008"], thresholds["GM.5.0.MADS_box.0008"])
# print(len(kmers), 4**13)
# # print(kmers)
```
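
Since the pruning in `all_kmers` is easy to get subtly wrong, a brute-force reference can be handy for checking it on short motifs. The helper `brute_force_kmers` below is hypothetical and not part of the notebook; it scores all 4^m candidates, so it is only practical for short motifs (say, up to about 8 positions).

```python
from itertools import product

import numpy as np


def brute_force_kmers(logodds, threshold):
    """All kmers scoring >= threshold, found by scoring every one of the 4**m candidates."""
    logodds = np.asarray(logodds, dtype=float)
    m = logodds.shape[0]
    positions = np.arange(m)
    kmers = []
    for nuc_idxs in product(range(4), repeat=m):
        score = logodds[positions, list(nuc_idxs)].sum()
        # same floating-point tolerance as all_kmers
        if score + 1.e-5 >= threshold:
            kmers.append("".join("ACGT"[i] for i in nuc_idxs))
    return kmers
```

For such a short motif, `set(all_kmers(motif, t))` should match `set(brute_force_kmers(motif.logodds, t))` when `all_kmers` is called with the defaults `max_length=None` and `reverse_complement=False`.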

## make the automaton

To make scanning much faster, we combine the kmers of all motifs into a single automaton, where each kmer stores the indices of every motif it belongs to. This means we only have to scan each sequence once, but still get the results for all motifs at the same time!

```python
# make a set of motifs we want to test against
# we filter here for a maximum motif length, but this is not strictly necessary;
# motifs without a threshold (None) are skipped
MAX_MOTIF_LENGTH = 14
testset = [k for k, v in thresholds.items() if v is not None and len(motifs[k]) < MAX_MOTIF_LENGTH]

print(len(testset))
```

```
897
```


```python
MAX_KMER_LENGTH = 14

motif2index = {motif: i for i, motif in enumerate(testset)}
index2motif = {i: motif for motif, i in motif2index.items()}

now = time.time()

# enumerate the passing kmers (both strands) of every motif in parallel
pool = Pool(18)
jobs = []
for motif_name in testset:
    job = pool.apply_async(all_kmers, (motifs[motif_name], thresholds[motif_name], MAX_KMER_LENGTH, True))
    jobs.append(job)

# add every kmer to a single automaton; its value is a tuple with the indices of all motifs it belongs to
A = ahocorasick.Automaton()
for i, job in enumerate(jobs):
    # print(f"Getting the kmers of {index2motif[i]}")
    matching_kmers = job.get()
    for kmer in matching_kmers:
        if kmer in A:
            A.add_word(kmer, (*A.get(kmer), i))
        else:
            A.add_word(kmer, (i,))

print(time.time() - now)
A.make_automaton()  # build the failure links
print(time.time() - now)
pool.close()
```

```
1.2250134944915771
1.4587273597717285
```


```python
# get the stats of the automaton
A.get_stats()
```

```
{'nodes_count': 1190969,
 'words_count': 479240,
 'longest_word': 13,
 'links_count': 1190968,
 'sizeof_node': 32,
 'total_size': 47638752}
```

```python
# the size of the automaton in GiB
A.get_stats()["total_size"] / 1024 ** 3
```

```
0.044367045164108276
```

## benchmark

We scan once with gimmemotifs and once with the automaton. To make the comparison fair, whenever the automaton reports a hit we rescan that hit's subsequence with the gimmemotifs `pwmscan` (what we are actually comparing against), so both approaches should give identical results.
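
The rescan step can be pictured with a plain-Python stand-in for `pwmscan`. The function `naive_best_score` below is hypothetical and is not gimmemotifs' C implementation: it scores every window on both strands, which is exactly the O(mn) work the automaton skips for positions without a hit. The equally hypothetical `hit_window` turns an automaton hit, reported at the kmer's last index, back into the m-character slice that produced it.

```python
import numpy as np

NUC2IDX = {"A": 0, "C": 1, "G": 2, "T": 3}
COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def naive_best_score(sequence, logodds, scan_rc=True):
    """Best log-odds score over every window (optionally on both strands): O(m * n).

    Returns None when the sequence is shorter than the motif.
    """
    logodds = np.asarray(logodds, dtype=float)
    m = logodds.shape[0]
    strands = [sequence]
    if scan_rc:
        strands.append("".join(COMPLEMENT[nuc] for nuc in reversed(sequence)))
    best = None
    for strand in strands:
        for start in range(len(strand) - m + 1):
            window = strand[start:start + m]
            score = sum(logodds[j, NUC2IDX[nuc]] for j, nuc in enumerate(window))
            if best is None or score > best:
                best = score
    return best


def hit_window(sequence, end_pos, motif_len):
    """Slice of `sequence` holding the motif-length kmer that ends at end_pos (inclusive)."""
    start = max(0, end_pos - motif_len + 1)
    return sequence[start:end_pos + 1]
```

For instance, `naive_best_score(hit_window(seq, pos, m), logodds)` re-scores a single hit. The notebook's scanning loop slices one extra character to the left, which only adds one more window for `pwmscan` to check.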

```python
# scan with gimmemotifs
now = time.time()
result_pwmscan = np.zeros((nr_seq, len(testset)))
for i, sequence in enumerate(sequences):
    for j, motif_name in enumerate(testset):
        threshold = thresholds[motif_name]
        # report only the best hit (nreport=1), scanning both strands (scan_rc=True)
        result = pwmscan(sequence, motifs[motif_name].logodds, threshold, 1, True)
        # store the score of the best hit, or 0 when nothing passes the threshold
        if len(result):
            result_pwmscan[i, j] = result[0][0]
        else:
            result_pwmscan[i, j] = 0

gimme_time = time.time() - now
print(f"it took gimme {gimme_time}s to scan.")
```

```
it took gimme 9.24500322341919s to scan.
```


```python
# scan with the automaton
now = time.time()

result_ahocorasick = np.zeros((nr_seq, len(testset)))
motif2index = {motif: i for i, motif in enumerate(testset)}
index2motif = {i: motif for motif, i in motif2index.items()}

for i, sequence in enumerate(sequences):
    # A.iter yields (end_index, value) for every kmer found in the sequence
    for pos, result_motifs in A.iter(sequence):
        for motif_index in result_motifs:
            motif_name = index2motif[motif_index]

            # the kmer ends at pos, so rescan a window of len(motif) + 1 characters ending there
            lo_idx = max(0, pos - len(motifs[motif_name]))
            hi_idx = pos + 1

            subseq = sequence[lo_idx:hi_idx]
            result = pwmscan(subseq, motifs[motif_name].logodds, thresholds[motif_name], 1, True)
            # keep the best score seen so far for this sequence/motif pair
            if len(result) and result[0][0] > result_ahocorasick[i, motif_index]:
                result_ahocorasick[i, motif_index] = result[0][0]

automaton_time = time.time() - now
print(f"it took the automaton {automaton_time}s to scan.")
```

```
it took the automaton 0.21793246269226074s to scan.
```


```python
# confirm they give the same result (should return empty arrays!)
np.where(~np.isclose(result_pwmscan, result_ahocorasick))
```

```
(array([], dtype=int64), array([], dtype=int64))
```

```python
# find where the results differ, for debugging (prints only the first mismatch)
for x, y in zip(*np.where(~np.isclose(result_pwmscan, result_ahocorasick))):
    print(x, y, result_ahocorasick[x, y], result_pwmscan[x, y], thresholds[index2motif[y]], index2motif[y], len(motifs[index2motif[y]]))
    break
```

```python
f"Scanning with the automaton was {gimme_time/automaton_time} faster compared to scanning with gimmemotifs."
```

```
'Scanning with the automaton was 42.4214139977573 faster compared to scanning with gimmemotifs.'
```