# BWT Implementation

- given a string with characters from the alphabet {A, C, G, T} plus the required `$` end-of-string sentinel, build its Burrows–Wheeler transform (BWT) and then compress the BWT by run-length encoding: whenever the same character appears more than once in a row, replace the run with its length followed by the character

- ex) AAACTTTTGG --> 3AC4T2G

- the constructor takes an optional string argument; if none is provided, a random string of 100 characters from {A, C, G, T}, followed by `$`, is generated by default

- *** most useful when strings are prone to character repetition, since longer runs of repeated characters lead to stronger compression

In [311]:
import numpy as np

class BWT:
    """Burrows-Wheeler transform of a string, its run-length compression and FM index.

    On construction the instance computes and stores:

    * ``string`` -- the input text (a random A/C/G/T string ending in ``$``
      when no text is given),
    * ``bwt`` -- the Burrows-Wheeler transform of ``string``,
    * ``compressed`` -- ``bwt`` run-length encoded (e.g. ``"AAAC"`` -> ``"3AC"``),
    * ``compression_rate`` -- percentage of characters saved by the encoding,
    * ``fm`` / ``offset`` -- occurrence table and first-column offsets used for
      LF-mapping (``recover_suffix``) and backward search (``find``).
    """

    def __init__(self, string=""):
        """Build the transform, its compressed form and its FM index.

        Args:
            string: text to transform. It must contain the ``$`` sentinel
                (conventionally as its last character). An empty string or
                ``None`` selects a random 100-character A/C/G/T string
                followed by ``$``.

        Raises:
            ValueError: if ``string`` does not contain ``$``. (``ValueError``
                subclasses ``Exception``, so existing ``except Exception``
                handlers still catch it.)
        """
        self.string = string
        if not self.string:
            self.string = self.generate_random(100)
        if "$" not in self.string:
            raise ValueError("Formatting error. Input string does not contain $")
        self.bwt = self.build_bwt(self.string)
        self.compressed = self.compress()
        self.compression_rate = self.get_compression_rate()
        self.fm, self.offset = self.get_FMindex()

    def summarize(self):
        """Print the input, its BWT, the compressed BWT and the compression rate."""
        print("Original String: {}".format(self.string))
        print("BWT: {}".format(self.bwt))
        print("Compressed BWT: {}".format(self.compressed))
        print("Compression Rate: {} %".format(self.compression_rate))

    def build_bwt(self, t):
        """Return the BWT of ``t``: the last column of its sorted cyclic rotations.

        This is a naive construction (n rotations of length n, then a sort),
        so it is only suitable for short strings. ``$`` sorts before letters
        in ASCII, so the rotation starting with ``$`` comes first.
        """
        rotations = sorted(t[i:] + t[:i] for i in range(len(t)))
        return "".join(r[-1] for r in rotations)

    def generate_random(self, n):
        """Return ``n`` random characters from ``"ACGT"`` followed by ``"$"``.

        One ``np.random.randint(4)`` call is made per character, so the result
        follows NumPy's global random state (seed it with ``np.random.seed``
        for reproducibility).
        """
        alphabet = "ACGT"
        return "".join(alphabet[np.random.randint(4)] for _ in range(n)) + "$"

    def compress(self):
        """Run-length encode ``self.bwt``.

        Each run of two or more identical characters becomes its length
        followed by the character. Single characters are copied unchanged,
        e.g. ``"AAACTTTTGG"`` -> ``"3AC4T2G"``. An empty BWT encodes to ``""``.
        The text must not contain digits, which would be ambiguous with counts.
        """
        import itertools

        pieces = []
        for char, run in itertools.groupby(self.bwt):
            length = sum(1 for _ in run)
            pieces.append(str(length) + char if length > 1 else char)
        return "".join(pieces)

    def decompress(self):
        """Invert :meth:`compress`, returning the uncompressed BWT string.

        A run of decimal digits is read as the repeat count of the character
        that follows it (counts may have several digits, e.g. ``"12A"``). Any
        other character, including the ``$`` sentinel, is emitted once when no
        count precedes it.

        Raises:
            ValueError: if the encoding ends with a count but no character.
        """
        digits = "0123456789"
        pieces = []
        count = ""
        for char in self.compressed:
            if char in digits:
                count += char
            else:
                pieces.append(char * (int(count) if count else 1))
                count = ""
        if count:
            raise ValueError("Compressed BWT ends with a dangling count: " + count)
        return "".join(pieces)

    def get_compression_rate(self):
        """Return the percentage of characters saved by compression, rounded to 3 places."""
        return round(((len(self.bwt) - len(self.compressed)) / len(self.bwt) * 100), 3)

    def get_FMindex(self):
        """Build the occurrence table and first-column offsets of ``self.bwt``.

        Returns:
            ``(fm, offset)``. ``fm[k][c]`` is the number of occurrences of
            ``c`` in ``self.bwt[:k]``, so ``len(fm) == len(self.bwt) + 1``.
            ``offset[c]`` is the number of BWT characters that sort strictly
            before ``c``, i.e. the first sorted-rotation row starting with ``c``.
        """
        fm = [{c: 0 for c in self.bwt}]
        for c in self.bwt:
            row = dict(fm[-1])
            row[c] += 1
            fm.append(row)
        totals = fm[-1]  # full-string counts; also safe when the BWT is empty
        offset = {}
        running = 0
        for symbol in sorted(totals):
            offset[symbol] = running
            running += totals[symbol]
        return fm, offset

    def recover_suffix(self, i):
        """Reconstruct row ``i`` of the sorted rotation matrix from the BWT alone.

        Follows the LF-mapping (``offset[c] + fm[row][c]``) backwards from
        row ``i`` until it cycles back to ``i``, collecting one character per
        step. The result is the full rotation (``len(self.bwt)`` characters),
        not just the suffix. For a string with a single trailing ``$``, the
        suffix is the part of the result up to and including that ``$``.
        """
        chars = []
        row = i
        while True:
            c = self.bwt[row]
            chars.append(c)
            row = self.offset[c] + self.fm[row][c]
            if row == i:
                break
        # Characters were collected right-to-left.
        return "".join(reversed(chars))

    def print_suffixes(self):
        """Print every sorted-rotation row as ``"%2d: <row>"``."""
        for i in range(len(self.bwt)):
            print("%2d: %s" % (i, self.recover_suffix(i)))

    def merge(self, other):
        """Merge this BWT with a second BWT by iteratively refining an interleave.

        Each pass stably sorts the current merged characters (one LF step).
        The resulting sequence of source labels (0 = ``self``, 1 = ``other``)
        is the refined interleave, and the characters are then refilled from
        each source in order. The loop stops early once a pass changes nothing.
        The stable sort keeps ``self``'s entries first among ties, e.g. the
        two ``$`` sentinels.

        Args:
            other: the BWT string of the second text, or another ``BWT``
                instance (its ``bwt`` attribute is used).

        Returns:
            The merged BWT string.
        """
        if isinstance(other, BWT):
            other = other.bwt
        sources = (self.bwt, other)
        interleave = [(c, 0) for c in self.bwt] + [(c, 1) for c in other]
        passes = min(len(self.bwt), len(other))
        for _pass in range(passes):
            cursors = [0, 0]
            next_interleave = []
            for _char, k in sorted(interleave, key=lambda pair: pair[0]):
                next_interleave.append((sources[k][cursors[k]], k))
                cursors[k] += 1
            if next_interleave == interleave:
                break
            interleave = next_interleave
        return "".join(c for c, _k in interleave)

    def _find_range(self, pattern):
        """Return the half-open row range ``(lo, hi)`` of rotations starting with ``pattern``.

        Runs backward search over the FM index, processing ``pattern`` right
        to left. A character absent from the text yields an empty range. An
        empty pattern matches every row.
        """
        lo, hi = 0, len(self.fm) - 1
        for symbol in reversed(pattern):
            if symbol not in self.offset:
                return 0, 0
            lo = self.offset[symbol] + self.fm[lo][symbol]
            hi = self.offset[symbol] + self.fm[hi][symbol]
            if lo >= hi:
                return lo, lo  # empty range stays empty; stop early
        return lo, hi

    def find(self, pattern):
        """Print each sorted-rotation row that starts with ``pattern``.

        Output uses the same ``"%2d: %s"`` format as :meth:`print_suffixes`.
        Nothing is printed when there is no occurrence, including when
        ``pattern`` contains a character absent from the text.
        """
        lo, hi = self._find_range(pattern)
        for i in range(lo, hi):
            print("%2d: %s" % (i, self.recover_suffix(i)))
    
    
    

In [309]:
# Example: a short A/C/G/T text containing the required "$" sentinel at its end.
test = BWT(string='ACATCATACAT$')

In [305]:
# Print the input string, its BWT, the run-length-compressed BWT and the compression rate.
test.summarize()

Original String: ACAT$CATACAT
BWT: TTTCCCAA$AAA
Compressed BWT: 3T3C2A$3A
Compression Rate: 25.0 %


In [287]:
# Print every row of the sorted rotation matrix, each rebuilt from the BWT via LF-mapping.
test.print_suffixes()

 0: $ACAT$CATACAT
 1: $CATACAT$ACAT
 2: ACAT$ACAT$CAT
 3: ACAT$CATACAT$
 4: AT$ACAT$CATAC
 5: AT$CATACAT$AC
 6: ATACAT$ACAT$C
 7: CAT$ACAT$CATA
 8: CAT$CATACAT$A
 9: CATACAT$ACAT$
10: T$ACAT$CATACA
11: T$CATACAT$ACA
12: TACAT$ACAT$CA


In [289]:
# Print the sorted-rotation rows that begin with "CAT" (FM-index backward search).
test.find("CAT")

 7: CAT$ACAT$CATA
 8: CAT$CATACAT$A
 9: CATACAT$ACAT$


In [310]:
# Summarize ten instances built from freshly generated random strings
# (BWT() with no argument generates 100 random A/C/G/T characters plus "$").
for i in range(10):
    test = BWT()
    test.summarize()
    print()  # blank line between summaries

Original String: ACTTACGACTCAGACCACTGCACATTCACAGTACGCCCCATTAACGACACGCCTCCATATCAGTGGGACTTCTCCGTCTGATTCCTCCGGATCAGTCAAA$
BWT: AAACTGCCGATTCGC$GCCCCCGTCCGTTGACTATTCCAATCCGTTGTAAAACCACCTTAAACACGGTTCCGCTAACATTGAGTCAACCCTTGCCGACAAC
Compressed BWT: 3ACTG2CGA2TCGC$G5CGT2CG2TGACTA2T2C2AT2CG2TGT4A2CA2C2T3ACAC2G2T2CGCT2ACA2TGAGTC2A3C2TG2CGAC2AC
Compression Rate: 7.921 %

Original String: CTGGGGTCGAGGGCGCTGACCCGCATGGAAAGCCCACTCATAGAGAACCTTTTAGGGGTGGAAGTCTCCTCGAACTTAGTCGATGGCGCTCCGATCGAGT$
BWT: TGGGAGGACAGTAGTGTACGCGCTGCGATCTATTTCTCGGAGTCG$ACGACGTACCCCCAGGCCTTTGATAGGGGAGAAGGATTCCCCGAGGCAGACTCTC
Compressed BWT: T3GA2GACAGTAGTGTACGCGCTGCGATCTA3TCTC2GAGTCG$ACGACGTA5CA2G2C3TGATA4GAG2A2GA2T4CGA2GCAGACTCTC
Compression Rate: 9.901 %

Original String: TACGCGTTGAGCGGATCGCTTCGCAGCCATTTGCTTGTGTATCGAACCCTCGCCGCTAAGAGGAGACGGGCTGCCTCACGAGCTAATTGCCGCATTGATA$
BWT: ATGTTACTGGACGGGGTGCACTGGCGAGGGCTATCTACTGAGGGCCGGGCAGTCATGCCATCTACCAGCTACGCTTCACC$GCATCATTTCTGTCGAATCA
Compressed BWT: ATG2TACT2GAC4GTGCACT2GCGA3GCTATCTA