In [2]:
%run BurrowsWheelerTransform.ipynb

In [3]:
class FMIndex():
    """FM-index over the Burrows-Wheeler transform (BWT) of a sequence.

    Attributes:
        bwt: BWT of the '$'-terminated sequence, as returned by BWTViaSA
            (defined outside this class).
        step: distance, in BWT positions, between two stored checkpoints.
        first: maps each symbol to the index of the first sorted rotation
            that starts with that symbol.
        checkpoints: checkpoints[k] maps a symbol to its number of
            occurrences in bwt[0:k * step].
        offset: memo of already resolved rows (row index -> text offset).
    """

    def __init__(self, seq, step = 32):
        """Build the index for seq.

        Args:
            seq: sequence to index; the terminator '$' is appended when the
                sequence does not already end with it.
            step: checkpoint spacing, must be a positive integer.

        Raises:
            ValueError: if step is smaller than 1.
        """
        if step < 1:
            raise ValueError("step must be a positive integer")

        # `not seq` guards the empty sequence, for which seq[-1] would raise.
        if not seq or seq[-1] != '$':
            seq += '$'
        self.bwt = BWTViaSA(seq)
        self.offset = {}
        self.step = step

        # Number of occurrences of every symbol in the BWT.
        elemCount = {}
        for val in self.bwt:
            elemCount[val] = elemCount.get(val, 0) + 1

        # Walking the symbols in sorted order and accumulating their counts
        # yields the first row at which each symbol starts a rotation.
        self.first = {}
        idx = 0
        for c in sorted(elemCount):
            self.first[c] = idx
            idx += elemCount[c]

        self.CreateCheckpoints()

    def CreateCheckpoints(self):
        """Store a snapshot of the running symbol counts every step-th position.

        The snapshot taken at position p is copied before bwt[p] is counted,
        so it describes bwt[0:p].
        """
        occurrenceCounter = {}
        checkpoints = []
        for idx, val in enumerate(self.bwt):
            if idx % self.step == 0:
                checkpoints.append(occurrenceCounter.copy())
            occurrenceCounter[val] = occurrenceCounter.get(val, 0) + 1
        self.checkpoints = checkpoints

    def Count(self, idx, letter):
        """Return how many times letter occurs in bwt[0:idx].

        The nearest checkpoint supplies a base count, which is then corrected
        by scanning only the BWT symbols between that checkpoint and idx.
        """
        step = self.step

        # Index of the nearest checkpoint, clamped to the last stored one.
        # Integer arithmetic gives the same result as rounding
        # (idx + step / 2) / step down, without going through floats.
        check = min((idx + step // 2) // step, len(self.checkpoints) - 1)
        pos = check * step  # BWT position described by that checkpoint

        count = self.checkpoints[check].get(letter, 0)

        if pos < idx:
            # Checkpoint lies before idx: add what occurs in bwt[pos:idx].
            for i in range(pos, idx):
                if self.bwt[i] == letter:
                    count += 1
        else:
            # Checkpoint lies at or after idx: remove what occurs in bwt[idx:pos].
            for i in range(idx, pos):
                if self.bwt[i] == letter:
                    count -= 1

        return count

    def Rank(self, idx, letter):
        """Return first[letter] + Count(idx, letter).

        A letter that does not occur in the BWT contributes a base of 0.
        """
        return self.first.get(letter, 0) + self.Count(idx, letter)

    def Resolve(self, idx):
        """Return the number of Rank steps from row idx to the row whose BWT symbol is '$'.

        Search and HasSuffix use this value as the position of the match in
        the indexed sequence. Results are memoized in self.offset, and a walk
        stops early as soon as it reaches a row that was resolved before.
        """
        r = 0
        i = idx
        while self.bwt[i] != '$':
            # Membership test instead of truthiness, so that a stored 0 is
            # still recognized as a cached value.
            if i in self.offset:
                r += self.offset[i]
                break
            r += 1
            i = self.Rank(i, self.bwt[i])

        # Cache under the row that was asked for, not the row the walk ended on.
        self.offset[idx] = r
        return r

    def Range(self, pattern):
        """Return the half-open row range (left, right) of rotations starting with pattern.

        Returns (-1, -1) when pattern does not occur. An empty pattern yields
        the full range (0, len(bwt)).
        """
        left = 0
        right = len(self.bwt)
        # Backward search: extend the match one symbol at a time, starting
        # from the last symbol of the pattern.
        for c in reversed(pattern):
            left = self.Rank(left, c)
            right = self.Rank(right, c)
            if left == right:
                return (-1, -1)
        return (left, right)

    def Search(self, pattern):
        """Return the positions of all occurrences of pattern, in row order (not sorted)."""
        left, right = self.Range(pattern)
        # range(-1, -1) is empty, so the "not found" sentinel gives [].
        return [self.Resolve(i) for i in range(left, right)]

    def HasSubstring(self, pattern):
        """Return True if pattern occurs in the indexed sequence."""
        left, right = self.Range(pattern)
        return right > left

    def HasSuffix(self, pattern):
        """Return True if the indexed sequence (without '$') ends with pattern."""
        left, right = self.Range(pattern)
        if right <= left:
            # No occurrence at all; avoids resolving the (-1, -1) sentinel.
            return False
        # Only the first row of the range is inspected: '$' sorts before every
        # other symbol, so an occurrence directly followed by '$' comes first.
        offset = self.Resolve(left)
        return offset + len(pattern) == len(self.bwt) - 1