# Basic Prefix Tree implementation

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from src.sequence.gen_spectra import gen_spectrum, gen_min_ordering
from src.utils import ppm_to_da, insort_by_func

from collections import namedtuple
from more_itertools import flatten

import numpy as np

In [2]:
from pyteomics import fasta
def read_fasta(fasta_file):
    #split the name on the OS value if it exists
    get_name = lambda name: name[:name.index('OS=')-1] if 'OS=' in name else name

    prots = []

    # go through each entry in the fasta and put it in memory
    for i, entry in enumerate(fasta.read(fasta_file)):

        # take the description without the 'sp' value
        desc = entry.description.split('|')[1:] if '|' in entry.description else entry.description

        # if the id is in the description, take it
        if len(desc) > 1:
            id_ = desc[0]
            name = get_name(desc[1])

        # make the id just the number
        else:
            id_ = i
            name = get_name(desc[0])

        # get the sequence
        seq = entry.sequence

        # make the entry and add it to prots
        prots.append({'name': name, 'id': id_, 'sequence': seq})
        
    return prots

In [7]:
from heapq import merge
from math import ceil
from collections import defaultdict

def build_and_sort_kmers(fasta_file):
        
    # read in the fasta file
    prots = read_fasta(fasta_file)
    
    # for printing purposes
    plen = len(prots)
    
    # keep track of unique kmers
    kmers = defaultdict(lambda x: 0)
    
    
    for i, prot in enumerate(prots):
        print(f'On protein {i}/{plen}\r', end='')

        # do the first 1 to 30 then keep going, then continue until we get to the last 30 then iter down
        seq = prot['sequence']

        # first 30 
        for i in range(1, 30):
            kmers[seq[:i]]= None

        # all the 30 mers in between
        for i in range(len(seq) - 29):
            kmers[seq[i:i+30]] = None

        # get the last 30
        for i in range(len(seq)-30, len(seq)):
            kmers[seq[i:]] = None
            
            
    
    # break into lists of 1000000 and sort them individualy
    kmer_seqs = list(kmers.keys())
    batch_size = 1000000
    num_batches = ceil(len(kmer_seqs)/batch_size)

    # broken into lists
    lols = [kmer_seqs[i*batch_size:(i+1)*batch_size] for i in range(num_batches)]

    # sort and replace the list at each index
    for i in range(len(lols)):
        print(f'sorting list {i+1} of {len(lols)}\r', end='')
        lols[i] = sorted(lols[i], key=gen_min_ordering)
        
    # merge them in a sorted fashion
    print('\nMerging sorted lists...')
    sorted_kmers = merge(*lols, key=gen_min_ordering)
    print('Done')
    return sorted_kmers, len(kmer_seqs)
            

In [3]:
BranchResult = namedtuple(
    'BranchResult', 
    ['keys', 'value'], 
    defaults=[[], []]
)
class tree_node:
    '''
    Node that contains mass information. 
    NOTE: 
    '''
    
    def __init__(self, key, value):
        self.keys = [key]
        self.value = value
        self.children = []
        
    # creat a and return a new node after adding it to my children
    def add_child(self, key, value):
        new_child = tree_node(key, value)
        self.children.append(new_child)
        self.add_key(key) 
        return new_child
        
    # see if any of my children have value == value
    def has_child(self, value):
        return any([child.value == value for child in self.children])
    
    # if any of my children have value == value, return it. Otherwise None
    def get_child(self, value):
        if not self.has_child(value):
            return None
        
        for c in self.children:
            if c.value == value:
                return c
    
    # add a new key to myself
    def add_key(self, key):
        key not in self.keys and self.keys.append(key)
        
    def show(self, tabs=0):
        ptabs = ''.join(['\t' for _ in range(tabs)])
        print(f'{ptabs}|---> keys: {self.keys}, value: {self.value}')
        [c.show(tabs+1) for c in self.children]
        
class tree:
    
    def __init__(self, da_tolerance=.01):
        self.root = tree_node(None, [])
        self.da_tolerance = da_tolerance
        
    def add_sequence(self, key, sequence):
        
        current_node = self.root
        
        for value in sequence:
            
            # if i get a child node back, make it current node and 
            # add a key 
            if current_node.get_child(value) is not None:
                current_node = current_node.get_child(value)
                current_node.add_key(key)
                
            # add a new child to current node
            # add child adds the key to the current node and creates the new child
            # that is returned
            else:
                current_node = current_node.add_child(key, value)
            
    def search(self, sequence):
        current_node = self.root
        
        for value in sequence:
            if not current_node.has_child(value):
                return []
            
            current_node = current_node.get_child(value)
            
        return current_node.keys
    
    def gap_search_rec(self, sequence, node, current_gap, gap_limit) -> list:
        
        # check to see if the total-current < 0. If so return empty value
        if gap_limit - current_gap < 0:
            return [BranchResult([], '')]
        
        # if the sequence is none
        if len(sequence) <= 0:
            return [BranchResult([], '')]
        
        # value found bool
        value_found = node.value == sequence[0]
        
                
        # if my value is the one in question, don't add a gap
        gap_addition = 0 if value_found else 1
        
        # start at the next position if I've found the correct one, otherwise don't
        seq_start = 1 if value_found else 0
        
        # if we are at the end of the sequence and we've found the correct 
        # match, return 
        if len(sequence) == 1 and value_found:
            return [BranchResult(node.keys, node.value)]
                
        # all other cases look through all my children
        return_values = list(flatten([self.gap_search_rec(sequence[seq_start:], child, current_gap + gap_addition, gap_limit)\
                        for child in node.children]))
        
        # filter out all return values that have [] as the keys
        return_values = [x for x in return_values if len(x.keys)]
        
        # if the current node is the root, don't append its stuff
        if node == self.root:
            return return_values
        
        # add node values to reconstruct the sequuence
        return [BranchResult(result.keys, node.value + result.value)\
                for result in return_values]
    
    def gap_search(self, sequence, gap):
        # incerment gap because we will be checking on the root that has no data
        gap += 1
        result = self.gap_search_rec(sequence, self.root, 0, gap) 
        
        # if result is none, return empty list
        if not len(result):
            return []
        
        longest = max(map(lambda x: len(x[1]), result))
        return [x for x in result if len(x[1]) == longest]
    
    def show(self):
        print('root')
        [c.show() for c in self.root.children]
        
                
        

## Make sure the tree actually works

In [3]:
t = tree()
t.add_sequence('A', 'ABCDEF')
t.add_sequence('B', 'ABCXYZ')
t.show()


root
|---> keys: ['A', 'B'], value: A
	|---> keys: ['A', 'B'], value: B
		|---> keys: ['A', 'B'], value: C
			|---> keys: ['A'], value: D
				|---> keys: ['A'], value: E
					|---> keys: ['A'], value: F
			|---> keys: ['B'], value: X
				|---> keys: ['B'], value: Y
					|---> keys: ['B'], value: Z


In [4]:
print(t.search('ABC'))
print(t.search('ABCDEF'))
print(t.search('ABCXYZ'))
print(t.search('XYZ'))

['A', 'B']
['A']
['B']
[]


In [5]:
print(t.gap_search('ABC', 3))
print(t.gap_search('AC', 3))

[BranchResult(keys=['A', 'B'], value='ABC')]
[BranchResult(keys=['A', 'B'], value='ABC')]


In [6]:
print(t.gap_search('ACD', 1))
print(t.gap_search('ADF', 4))
print(t.gap_search('XYZ', 3))
print(t.gap_search('XYZ', 2))

[BranchResult(keys=['A'], value='ABCD')]
[BranchResult(keys=['A'], value='ABCDEF')]
[BranchResult(keys=['B'], value='ABCXYZ')]
[]


# Modified tree to make it a mass tree

In [7]:
from __future__ import annotations
from src.utils import ppm_to_da

import sys

BranchResult = namedtuple(
    'BranchResult', 
    ['kmers', 'masses'], 
    defaults=[[], []]
)

MassNode = namedtuple(
    'MassNode', 
    ['kmers', 'mass', 'children'], 
    defaults=[[], 0, []]
)

class MassTree:
    '''
    Tree for mass searches

    Properties:
        root:   (MassNode) the root of the tree. Contains no data
    '''
    
    def __init__(self, ppm_tol=20):
        self.root = MassNode(None, 0, [])
        self.ppm_tol = ppm_tol
        
    def __get_child_with_mass_equal(self, node: MassNode, mass: float) -> MassNode:
        '''
        Return a child if it has as mass equal to the input mass

        Inputs:
            node:   (MassNode) the node who's children to check
            mass:   (float) the mass of the child to get
        Outputs:
            (MassNode) child with the mass. If child not found, None is returned
        '''
        for c in node.children:
            if c.mass == mass:
                return c

        return None
    
    def __gap_search_rec(self, boundaries: list, node: MassNode, current_gap: int, gap_limit: int) -> list:
        '''
        Recursive search of the tree for a sequence while allowing for missed
        values in the sequence up to gap_limit

        Input:
            boundaries:     (list) tuples of upper and lower bound masses to search for
            node:           (MassNode) the current node to look at
            current_gap:    (int) the number of gaps we've used up to this point
            gap_limit:      (int) the maximum number of gaps allowed 
        Outputs:
            (list) BranchResult reconstructed sequence without gaps found with all associated keys
        '''
        # check to see if the total-current < 0. If so return empty mass
        if gap_limit - current_gap < 0:
            return [BranchResult([], [])]
        
        # if the sequence is none
        if len(boundaries) <= 0:
            return [BranchResult([], [])]
        
        # see if we can find ANY of the future masses with a tolerance that
        # contains my mass
        mass_found = any([b[0] <= node.mass <= b[1] for b in boundaries])
                
        # if my mass is the one in question, don't add a gap
        gap_addition = 0 if mass_found else 1
        
        # start at the position of the boundary found + 1 to look through
        seq_start = [i for i, b in enumerate(boundaries) if b[0] <= node.mass <= b[1]][0] + 1\
             if mass_found else 0
        
        # if we are at the end of the sequence and we've found the correct 
        # match, return 
        if mass_found and seq_start >= len(boundaries):
            return [BranchResult(node.kmers, [node.mass])]
                
        # all other cases look through all my children
        return_masses = list(flatten(
            [self.__gap_search_rec(
                boundaries[seq_start:], 
                child, 
                current_gap + gap_addition, 
                gap_limit
            ) for child in node.children]))
        
        # filter out all return masss that have [] as the kmers
        return_masses = [x for x in return_masses if len(x.kmers)]
        
        # if return masses is empty but at this depth we found something, return my value
        if not len(return_masses) and mass_found:
            return [BranchResult(node.kmers, [node.mass])]
        
        # if the current node is the root, don't append its stuff
        if node == self.root:
            return return_masses
        
        # add node mass to reconstrunct the sequence
        return [BranchResult(result.kmers, [node.mass] + result.masses)\
                for result in return_masses]
    
    def __show_rec(self, node: MassNode, spaces=0, stop_level=None):
        '''
        Recursive function for printing tree
        '''
        if stop_level is not None and spaces / 2 >= stop_level:
            return
        
        p_spaces = ''.join([' ' for _ in range(spaces)])
        print(f'{p_spaces}|---> kmers: {node.kmers}, mass: {node.mass}')
        [self.__show_rec(c, spaces + 2, stop_level) for c in node.children]
        
        
    ############################ Public methods ############################
    
    def add_sequence(self, kmer: str, sequence: list) -> None:
        '''
        Add a sequence with a kmer sequence to the tree

        Inputs:
            kmer:       (str) the kmer associated with the sequence of masses
            sequence:   (list) masses to add to the tree
        Outputs:
            None
        '''
        current_node = self.root
        
        # go through each mass in the sequence 
        for i, mass in enumerate(sequence):
            
            # if i get a child node back, make it current node and 
            # add the kmer up to the current value
            child_node = self.__get_child_with_mass_equal(current_node, mass)
            if child_node is not None:
                current_node = child_node
                
                # only add a new kmer if its not already in the kmers
                if kmer[:i+1] not in current_node.kmers:
                    current_node.kmers.append(kmer[:i+1])
                
            # add a new child to current node
            # add child adds the kmer to the current node and creates the new child
            # that is returned
            else:
                new_child = MassNode([kmer[:i+1]], mass, [])
                current_node = current_node.children.append(new_child)
                current_node = new_child
            
    def search(self, sequence: list) -> list:
        '''
        Search for the exact sequence provided in the tree

        Inputs:
            sequence:   (list) the sequence to look for in the tree
        Outputs:
            (list) kmers that are associated with the exact sequence provide. Empty list
                    is returned if the sequence is not found
        '''
        current_node = self.root
        
        # go through each sequence and search for the exact mass
        for mass in sequence:

            # if we cannot find the next mass, return empty list
            if not self.__get_child_with_mass_equal(current_node, mass):
                return []
            
            current_node = current_node.__get_child_with_mass_equal(current_node, mass)
            
        return current_node.kmers
    
    def gap_search(self, sequence: list, gap: int) -> list:
        '''
        Search the tree for the sequence allowing gap number of missed values in the sequence

        Inputs:
            sequence:   (list) the sequence to look for
            gap:        (int) the number of gaps to allow when searching for the sequence
        Outputs:
            (list) BranchResult reconstructed sequence without gaps found with all associated keys
        '''
        # incerment gap because we will be checking on the root that has no data
        gap += 1

        # create a list of ranges from our sequence to pass into recursive search
        get_boundaries = lambda mz: (mz - ppm_to_da(mz, self.ppm_tol), mz + ppm_to_da(mz, self.ppm_tol))
        boundaries = [get_boundaries(mz) for mz in sequence]

        # recursive search for this
        result = self.__gap_search_rec(boundaries, self.root, 0, gap) 
        
        # if result is none, return empty list
        if not len(result):
            return []
        
        longest = max(map(lambda x: len(x[1]), result))
        return [x for x in result if len(x[1]) == longest]
    
    def show(self, stop_level=None):
        '''
        Print the tree to console
        '''
        print('root')
        [self.__show_rec(c, 0, stop_level) for c in self.root.children]
        
                
        

## Make sure that this thing actually works

In [8]:
# import the stuff for sequence generation
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from src.sequence.gen_spectra import gen_spectrum

In [9]:
kmer1 = 'MALWAR'
kmer2 = 'QQSNPP'
bs1 = gen_spectrum(kmer1, ion='b', charge=1)['spectrum']
bd1 = gen_spectrum(kmer1, ion='b', charge=2)['spectrum']
bs2 = gen_spectrum(kmer2, ion='b', charge=1)['spectrum']
bd2 = gen_spectrum(kmer2, ion='b', charge=2)['spectrum']

print(f'bs1: {bs1}\nbd1: {bd1}\nbs2: {bs2}\nbd2: {bd2}\n')

bs1: [132.04776143499998, 203.084875435, 316.16893943499997, 502.248252435, 573.285366435, 729.386477435]
bd1: [66.52751893499999, 102.04607593499999, 158.58810793499998, 251.627764435, 287.146321435, 365.196876935]
bs2: [129.065854435, 257.124432435, 344.156460435, 458.199387435, 555.252151435, 652.304915435]
bd2: [65.036565435, 129.065854435, 172.581868435, 229.603331935, 278.129713935, 326.656095935]



In [10]:
# make a bs and bd tree and add each of them respectively
bs_tree = MassTree()
bd_tree = MassTree()

In [11]:
bs_tree.add_sequence(kmer1, bs1)
bs_tree.add_sequence(kmer2, bs2)
bs_tree.show()

root
|---> kmers: ['M'], mass: 132.04776143499998
  |---> kmers: ['MA'], mass: 203.084875435
    |---> kmers: ['MAL'], mass: 316.16893943499997
      |---> kmers: ['MALW'], mass: 502.248252435
        |---> kmers: ['MALWA'], mass: 573.285366435
          |---> kmers: ['MALWAR'], mass: 729.386477435
|---> kmers: ['Q'], mass: 129.065854435
  |---> kmers: ['QQ'], mass: 257.124432435
    |---> kmers: ['QQS'], mass: 344.156460435
      |---> kmers: ['QQSN'], mass: 458.199387435
        |---> kmers: ['QQSNP'], mass: 555.252151435
          |---> kmers: ['QQSNPP'], mass: 652.304915435


In [12]:
bd_tree.add_sequence(kmer1, bd1)
bd_tree.add_sequence(kmer2, bd2)
bd_tree.show()

root
|---> kmers: ['M'], mass: 66.52751893499999
  |---> kmers: ['MA'], mass: 102.04607593499999
    |---> kmers: ['MAL'], mass: 158.58810793499998
      |---> kmers: ['MALW'], mass: 251.627764435
        |---> kmers: ['MALWA'], mass: 287.146321435
          |---> kmers: ['MALWAR'], mass: 365.196876935
|---> kmers: ['Q'], mass: 65.036565435
  |---> kmers: ['QQ'], mass: 129.065854435
    |---> kmers: ['QQS'], mass: 172.581868435
      |---> kmers: ['QQSN'], mass: 229.603331935
        |---> kmers: ['QQSNP'], mass: 278.129713935
          |---> kmers: ['QQSNPP'], mass: 326.656095935


In [13]:
# search the tree for the sequence
bs_tree.gap_search([129.065854435, 200, 257.124432435, 275, 300, 325, 344.156460435, 555.252151435, 652.304915435, 800], 0)

[BranchResult(kmers=['QQS'], masses=[129.065854435, 257.124432435, 344.156460435])]

# Build the tree with the entire fasta (filtered) database

In [15]:
%%time
ions = ['bs', 'bd', 'ys', 'yd']
trees = namedtuple('trees', ions)

ts = trees(MassTree(), MassTree(), MassTree(), MassTree())
prots = read_fasta('/Users/zacharymcgrath/Desktop/nod2 data/filteredNOD2.fasta')
plen = len(prots)
for i, prot in enumerate(prots):
    
    print(f'On protein {i}/{plen}\r', end='')
    
    seq = prot['sequence']
    # get each 30-mer both left and right
    kmers = []
    for i in range(len(seq) - 1):
        kmer_len = 30 if i + 30 < len(seq) else len(seq) - i
        kmers.append(seq[i:i+kmer_len])
    
    for i in range(len(seq)-1, -1, -1):
        kmer_len = 30 if i - 30 > 0 else i
        kmers.append(seq[i-kmer_len:i])
        
    kmers = list(set(kmers))
    
    # for each kmer, build the spectra and add it to each of the trees
    for kmer in kmers:
        bs = gen_spectrum(kmer, charge=1, ion='b')['spectrum']
        bd = gen_spectrum(kmer, charge=2, ion='b')['spectrum']
        ys = gen_spectrum(kmer, charge=1, ion='y')['spectrum']
        yd = gen_spectrum(kmer, charge=2, ion='y')['spectrum']
        
        ts.bs.add_sequence(kmer, bs)
        ts.bd.add_sequence(kmer, bd)
        ts.ys.add_sequence(kmer, ys)
        ts.yd.add_sequence(kmer, yd)

On protein 172/279

KeyboardInterrupt: 

In [13]:
ts.bs.show(2)

root
|---> kmers: ['N'], mass: 115.050203435
  |---> kmers: ['NN'], mass: 229.093130435
  |---> kmers: ['NP'], mass: 212.10296743499998
  |---> kmers: ['NM'], mass: 246.090688435
  |---> kmers: ['NL', 'NI'], mass: 228.134267435
  |---> kmers: ['NA'], mass: 186.087317435
  |---> kmers: ['NT'], mass: 216.097882435
  |---> kmers: ['NR'], mass: 271.151314435
  |---> kmers: ['NY'], mass: 278.113523435
  |---> kmers: ['NF'], mass: 262.118617435
  |---> kmers: ['NE'], mass: 244.09279643500003
  |---> kmers: ['NG'], mass: 172.071667435
  |---> kmers: ['NW'], mass: 301.129516435
  |---> kmers: ['NV'], mass: 214.118617435
  |---> kmers: ['NQ'], mass: 243.10878143500003
  |---> kmers: ['NS'], mass: 202.08223143499998
  |---> kmers: ['NH'], mass: 252.10911543499998
  |---> kmers: ['ND'], mass: 230.077146435
  |---> kmers: ['NK'], mass: 243.14516643500002
  |---> kmers: ['NC'], mass: 218.05938843500002
|---> kmers: ['M'], mass: 132.04776143499998
  |---> kmers: ['MA'], mass: 203.084875435
  |---> k

# DAWG: next step
The tree works well and having my own search algorithm is important to the problem at hand. The problem is that it is very large. For example, branch `A` and branch `N` SHOULD be separate because they have different masses. However, `AN` and `NA` have the same mass. Since they have the same mass, we should combine those two nodes. This makes a DAWG or a DAFSA (or something like that). So the next step is to implement that in order to use less memory

### Test this DAWG representation for the protein database (all of 30-mers)

In [8]:
'''
Original code from: http://stevehanov.ca/blog/?id=115
By Steve Hanov, 2011. Released to the public domain.

Modified by:
Zachary McGrath
'''

from collections import namedtuple

# This class represents a node in the directed acyclic word graph (DAWG). It
# has a list of edges to other nodes. It has functions for testing whether it
# is equivalent to another node. Nodes are equivalent if they have identical
# edges, and each identical edge leads to identical states. The __hash__ and
# __eq__ functions allow it to be used as a key in a python dictionary.
DawgNode = namedtuple(
    'DawgNode',
    ['final', 'edges', 'kmers'], 
    defaults=[False, {}, []]
)

PreviousSequence = namedtuple(
    'PreviousSequence',
    ['sequence', 'nodes'], 
    defaults=[[], []]
)

def DawgNodeHash(node):
    arr = []
    if node.final: 
        arr.append("1")
    else:
        arr.append("0")

    for (label, node) in node.edges.items():
        arr.append( label )
        arr.append( str( node.kmers ) )

    return "_".join([str(a) for a in arr]).__hash__()


class Dawg:
    def __init__(self):
        self.previousWord = PreviousSequence()
        self.root = DawgNode()
        
        # Here is a list of nodes that have not been checked for duplication.
        self.uncheckedNodes = []

        # Here is a list of unique nodes that have been checked for
        # duplication.
        self.minimizedNodes = {}

    def insert(self, word, kmer):
        if word < self.previousWord.sequence:
            raise Exception("Error: Words must be inserted in alphabetical order.")

        # find common prefix between word and previous word
        commonPrefix = 0
        for i in range( min( len( word ), len( self.previousWord.sequence ) ) ):
            if word[i] != self.previousWord.sequence[i]:
                break
                
            # update the kmers at this node to include this new one
            # if its not already in the list
            new_kmer = kmer[:i+1]
            new_kmer not in self.previousWord.nodes[i].kmers \
                and self.previousWord.nodes[i].kmers.append(new_kmer)
            
            # increment common prefix to keep track of how much is in common
            commonPrefix += 1

        # Check the uncheckedNodes for redundant nodes, proceeding from last
        # one down to the common prefix size. Then truncate the list at that
        # point.
        self._minimize( commonPrefix )

        # add the suffix, starting from the correct node mid-way through the
        # graph
        if len(self.uncheckedNodes) == 0:
            node = self.root
        else:
            node = self.uncheckedNodes[-1][2]
            
        # update previous sequence to have only the nodes up until common prefix
        self.previousWord = self.previousWord._replace(
            nodes=self.previousWord.nodes[:commonPrefix]
        )

        for letter in word[commonPrefix:]:
            # create a new node
            nextNode = DawgNode(False, {}, [kmer[:commonPrefix+1]])
            
            # update the edges at this new value to the new node
            node.edges[letter] = nextNode
            
            # append the new node to previous nodes
            self.previousWord.nodes.append(nextNode)
        
            # we have not checked these new nodes for duplicates
            self.uncheckedNodes.append( (node, letter, nextNode) )
            node = nextNode

        # set the last node to be true for final AND update the last sequence to be this sequence
        node = node._replace(final=True)
        self.previousWord = self.previousWord._replace(sequence=word)

    def finish( self ):
        # minimize all uncheckedNodes
        self._minimize( 0 );
        self.minimizedNodes = None
        self.uncheckedNodes = None

    def _minimize( self, downTo ):
        # proceed from the leaf up to a certain point
        for i in range( len(self.uncheckedNodes) - 1, downTo - 1, -1 ):
            (parent, letter, child) = self.uncheckedNodes[i];
      
            # hash the node once
            dnh = DawgNodeHash(child)
    
            if dnh in self.minimizedNodes:
                # replace the child with the previously encountered one
                parent.edges[letter] = self.minimizedNodes[dnh]
            else:
                # add the state to the minimized nodes.
                self.minimizedNodes[dnh] = child;
            self.uncheckedNodes.pop()

    def lookup( self, word ):
        node = self.root
        for letter in word:
            if letter not in node.edges: return False
            node = node.edges[letter]

        return node.kmers
                
        
    def __fuzzy_lookup_rec(self, node, word, current_gap, gap_limit):
        # check to see if the total-current < 0. If so return empty mass
        if gap_limit - current_gap < 0:
            return []
        
        # if the sequence is none
        if len(word) <= 0:
            return []
        
        # see if we can find ANY of the future masses with a tolerance that
        # contains my mass
        mass_found = any([w in node.edges for w in word])
                
        # if my mass is the one in question, don't add a gap
        gap_addition = 0 if mass_found else 1
        
        # start at the position of the boundary found + 1 to look through
        seq_start = [i for i, v in enumerate(word) if v in node.edges][0] + 1 \
            if mass_found else 0
                
        # if we are at the end of the sequence and we've found the correct 
        # match, return 
        if mass_found and seq_start >= len(word):
            return node.edges[word[0]].kmers
                
        # all other cases look through all my children
        return_kmers = list(flatten(
            [self.__fuzzy_lookup_rec(
                c,
                word[seq_start:],  
                current_gap + gap_addition, 
                gap_limit
            ) for _, c in node.edges.items()]))
        
        # filter out all return masss that have [] as the kmers
        return_kmers = [x for x in return_kmers if len(x)]
        
        # if return masses is empty but at this depth we found something, return my value
        if not len(return_kmers) and mass_found:
            return node.kmers
        
        # if the current node is the root, don't append its stuff
        if node == self.root:
            return return_kmers
        
        # add node mass to reconstrunct the sequence
        return return_kmers
        
        
    def fuzzy_lookup(self, word, gap):
        # increment gap because we'll lose one due to root
        gap += 1
        return self.__fuzzy_lookup_rec(self.root, word, 0, gap)
        
        
        
                

## Test on smaller dataset

In [9]:
dawg = Dawg()
fasta_file = '/Users/zacharymcgrath/Desktop/nod2 data/filteredNOD2.fasta'
kmers, num_kmers = build_and_sort_kmers(fasta_file)


sorting list 1 of 1
Merging sorted lists...
Done


In [10]:
%%time
last = ''
for c, kmer in enumerate(kmers):
    
    print(f'On kmer {c}/{num_kmers} [{int(100 * c / num_kmers)}%]\r', end='')
        
    mzs = gen_spectrum(kmer, ion='b', charge=1)['spectrum']
    
    try:
        dawg.insert(mzs, kmer)
    except:
        print(f'last is not less than current:\nlast: {last}\ncurrent: {kmer}')
        a=1/0
    last = kmer

dawg.finish()

On kmer 0/109771 [0%]On kmer 1/109771 [0%]On kmer 2/109771 [0%]On kmer 3/109771 [0%]On kmer 4/109771 [0%]On kmer 5/109771 [0%]On kmer 6/109771 [0%]On kmer 7/109771 [0%]On kmer 8/109771 [0%]On kmer 9/109771 [0%]On kmer 10/109771 [0%]On kmer 11/109771 [0%]On kmer 12/109771 [0%]On kmer 13/109771 [0%]On kmer 14/109771 [0%]On kmer 15/109771 [0%]On kmer 16/109771 [0%]On kmer 17/109771 [0%]On kmer 18/109771 [0%]On kmer 19/109771 [0%]On kmer 20/109771 [0%]On kmer 21/109771 [0%]On kmer 22/109771 [0%]On kmer 23/109771 [0%]On kmer 24/109771 [0%]On kmer 25/109771 [0%]On kmer 26/109771 [0%]On kmer 27/109771 [0%]On kmer 28/109771 [0%]On kmer 29/109771 [0%]On kmer 30/109771 [0%]On kmer 31/109771 [0%]On kmer 32/109771 [0%]On kmer 33/109771 [0%]On kmer 34/109771 [0%]On kmer 35/109771 [0%]On kmer 36/109771 [0%]On kmer 37/109771 [0%]On kmer 38/109771 [0%]On kmer 39/109771 [0%]On kmer 40/109771 [0%]On kmer 41/109771 [0%]On kmer 42/109771 [0%]On kmer 43/109771 [0%

On kmer 420/109771 [0%]On kmer 421/109771 [0%]On kmer 422/109771 [0%]On kmer 423/109771 [0%]On kmer 424/109771 [0%]On kmer 425/109771 [0%]On kmer 426/109771 [0%]On kmer 427/109771 [0%]On kmer 428/109771 [0%]On kmer 429/109771 [0%]On kmer 430/109771 [0%]On kmer 431/109771 [0%]On kmer 432/109771 [0%]On kmer 433/109771 [0%]On kmer 434/109771 [0%]On kmer 435/109771 [0%]On kmer 436/109771 [0%]On kmer 437/109771 [0%]On kmer 438/109771 [0%]On kmer 439/109771 [0%]On kmer 440/109771 [0%]On kmer 441/109771 [0%]On kmer 442/109771 [0%]On kmer 443/109771 [0%]On kmer 444/109771 [0%]On kmer 445/109771 [0%]On kmer 446/109771 [0%]On kmer 447/109771 [0%]On kmer 448/109771 [0%]On kmer 449/109771 [0%]On kmer 450/109771 [0%]On kmer 451/109771 [0%]On kmer 452/109771 [0%]On kmer 453/109771 [0%]On kmer 454/109771 [0%]On kmer 455/109771 [0%]On kmer 456/109771 [0%]On kmer 457/109771 [0%]On kmer 458/109771 [0%]On kmer 459/109771 [0%]On kmer 460/109771 [0%]On kmer 461/1097

CPU times: user 31.7 s, sys: 4.23 s, total: 35.9 s
Wall time: 40.4 s


In [11]:
dawg.lookup(gen_spectrum('LAT', charge=1, ion='b')['spectrum'])


['IAT', 'LAT']

In [12]:
dawg.fuzzy_lookup(gen_spectrum('MACGLVAS', charge=1, ion='b')['spectrum'], 3)

['GVMV',
 'GVMVG',
 'AGMAM',
 'AC',
 'CVQM',
 'CM',
 'LMAM',
 'IMAM',
 'MACGLVAS',
 'MVGM']

In [13]:
gen_spectrum('MACGLVAS', ion='b', charge=1)['spectrum']

[132.04776143499998,
 203.084875435,
 306.09406043499996,
 363.11552443499994,
 476.19958843499995,
 575.268002435,
 646.305116435,
 733.3371444349999]

In [14]:
dawg.fuzzy_lookup(
    [
        132.04776143499998,
        150,
        203.084875435,
        363.11552443499994,
        400, 
        450,
        476.19958843499995,
        733.3371444349999
    ], 
    3)

['A', 'MACGLVAS', 'MVGM']

# Cython-ized DAWG for lower memory use (hopeful)

In [15]:
%load_ext Cython

In [36]:
%%cython

'''
Original code from: http://stevehanov.ca/blog/?id=115
By Steve Hanov, 2011. Released to the public domain.

Modified by:
Zachary McGrath
'''

from collections import namedtuple
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free

# This class represents a node in the directed acyclic word graph (DAWG). It
# has a list of edges to other nodes. It has functions for testing whether it
# is equivalent to another node. Nodes are equivalent if they have identical
# edges, and each identical edge leads to identical states. The __hash__ and
# __eq__ functions allow it to be used as a key in a python dictionary.


cdef struct DawgNode:
    int final
    Edge* edges
    char** kmers
    
cdef struct Edge:
    DawgNode child
    float mass

cdef DawgNode new_dawg_node():
    cdef DawgNode node
    node.final = 0
    node.edges = []
    node.kmers = []
    return node
    
cdef struct PreviousSequence:
    char* sequence
    DawgNode* nodes
    

cdef str DawgNodeHash(DawgNode node):
    arr = []
    if node.final: 
        arr.append("1")
    else:
        arr.append("0")

    for edge in node.edges:
        arr.append( edge.mass )
        arr += [str(kmer) for kmer in node.kmers]

    return "_".join([str(a) for a in arr]).__hash__()


cdef class Dawg:
    cdef PreviousSequence previousWord
    cdef DawgNode root
    cdef list uncheckedNodes
    cdef dict minimizedNodes
    
    def __init__(self):
        pass

    def insert(self, word, kmer):
        if word < self.previousWord.sequence:
            raise Exception("Error: Words must be inserted in alphabetical order.")

        # find common prefix between word and previous word
        commonPrefix = 0
        for i in range(min(
            len(word), 
            len(self.previousWord.sequence )
        )):
            if word[i] != self.previousWord.sequence[i]:
                break
                
            # update the kmers at this node to include this new one
            # if its not already in the list
            new_kmer = kmer[:i+1]
            if new_kmer not in self.previousWord.nodes[i].kmers:
                # malloc and add the new kmer
                new_kmers = <char **> PyMem_Realloc(
                    self.previousWord.nodes[i].kmers,
                    sizeof(self.previousWord.nodes[i].kmers) + sizeof(char*)
                )
                new_kmers[-1] = new_kmer
                self.previousWord.nodes[i].kmers = new_kmers
                      
            # increment common prefix to keep track of how much is in common
            commonPrefix += 1

        # Check the uncheckedNodes for redundant nodes, proceeding from last
        # one down to the common prefix size. Then truncate the list at that
        # point.
        self._minimize(commonPrefix)

        # add the suffix, starting from the correct node mid-way through the
        # graph
        if len(self.uncheckedNodes) == 0:
            node = self.root
        else:
            node = self.uncheckedNodes[-1][2]
            
        # update previous sequence to have only the nodes up until common prefix
        self.previousWord.nodes = self.previoiusWord.nodes[:commonPrefix]

        for letter in word[commonPrefix:]:
            # create a new node
            nextNode = new_dog_node()
            nextNode.kmers = [kmer[:commonPrefix+1]]
            
            # update the edges at this new value to the new node
            node.edges.append((letter, nextNode))
            
            # append the new node to previous nodes
            self.previousWord.nodes.append(nextNode)
        
            # we have not checked these new nodes for duplicates
            self.uncheckedNodes.append((node, letter, nextNode))
            node = nextNode

        # set the last node to be true for final AND update the last sequence to be this sequence
        node.final = 1
        self.previousWord.sequence = word

    def finish( self ):
        # minimize all uncheckedNodes
        self._minimize(0);
        self.minimizedNodes = None
        self.uncheckedNodes = None

    def _minimize( self, downTo ):
        # proceed from the leaf up to a certain point
        for i in range( len(self.uncheckedNodes) - 1, downTo - 1, -1 ):
            (parent, letter, child) = self.uncheckedNodes[i];
      
            # hash the node once
            dnh = DawgNodeHash(child)
    
            if dnh in self.minimizedNodes:
                # replace the child with the previously encountered one
                parent.edges[letter] = self.minimizedNodes[dnh]
            else:
                # add the state to the minimized nodes.
                self.minimizedNodes[dnh] = child;
            self.uncheckedNodes.pop()

    def lookup(self, word):
        node = self.root
        for letter in word:
            if not any([letter == edgeLetter for edgeLetter, _ in node.edges]):
                return False
            
            for edgeLetter, nextNode in node.edges:
                if edgeLetter == letter:
                    node = nextNode

        return node.kmers
                
        
    def __fuzzy_lookup_rec(self, node, word, current_gap, gap_limit):
        # check to see if the total-current < 0. If so return empty mass
        if gap_limit - current_gap < 0:
            return []
        
        # if the sequence is none
        if len(word) <= 0:
            return []
        
        # see if we can find ANY of the future masses with a tolerance that
        # contains my mass
        mass_found = any([edgeLetter == w for edgeLetter, _ in node.edges for w in word])
                
        # if my mass is the one in question, don't add a gap
        gap_addition = 0 if mass_found else 1
        
        # start at the position of the boundary found + 1 to look through
        seq_start = [i for i, v in enumerate(word) if any([v == edgeLetter for edgeLetter, _ in node.edges])][0] + 1\
            if mass_found else 0
                
        # if we are at the end of the sequence and we've found the correct 
        # match, return 
        if mass_found and seq_start >= len(word):
            for i, (edgeLetter, _) in enumerate(node.edges):
                if word[0] == edgeLetter:
                    return node.edges[i].kmers
                
        # all other cases look through all my children
        return_kmers = list(flatten(
            [self.__fuzzy_lookup_rec(
                c,
                word[seq_start:],  
                current_gap + gap_addition, 
                gap_limit
            ) for _, c in node.edges]))
        
        # filter out all return masss that have [] as the kmers
        return_kmers = [x for x in return_kmers if len(x)]
        
        # if return masses is empty but at this depth we found something, return my value
        if not len(return_kmers) and mass_found:
            return node.kmers
        
        # if the current node is the root, don't append its stuff
        if node == self.root:
            return return_kmers
        
        # add node mass to reconstrunct the sequence
        return return_kmers
        
        
    def fuzzy_lookup(self, word, gap):
        # increment gap because we'll lose one due to root
        gap += 1
        return self.__fuzzy_lookup_rec(self.root, word, 0, gap)
        
        
        
                


Error compiling Cython file:
------------------------------------------------------------
...
                break
                
            # update the kmers at this node to include this new one
            # if its not already in the list
            new_kmer = kmer[:i+1]
            if new_kmer not in self.previousWord.nodes[i].kmers:
                       ^
------------------------------------------------------------

/Users/zacharymcgrath/.ipython/cython/_cython_magic_f6ea738979a15a0a78ec13d56f389a5e.pyx:80:24: Compiler crash in AnalyseExpressionsTransform

ModuleNode.body = StatListNode(_cython_magic_f6ea738979a15a0a78ec13d56f389a5e.pyx:10:0)
StatListNode.stats[3] = CClassDefNode(_cython_magic_f6ea738979a15a0a78ec13d56f389a5e.pyx:55:5,
    as_name = 'Dawg',
    class_name = 'Dawg',
    module_name = '',
    visibility = 'private')
CClassDefNode.body = StatListNode(_cython_magic_f6ea738979a15a0a78ec13d56f389a5e.pyx:56:4)
StatListNode.stats[1] = DefNode(_cython_magic_f6ea738