In [1]:
import numpy as np
from bitstream import BitStream
from numpy import array


In [2]:
class ArithmeticEncoder(object):
    """Integer arithmetic encoder that keeps the whole code word in ``low``.

    Unlike a streaming coder, ``low`` is never masked: every renormalisation
    shifts it left by one bit, so it grows without bound and ends up holding
    the complete code word.  ``s`` counts those shifts; the caller is expected
    to format the final ``low`` with ``s + bit_prec`` binary digits.

    Attributes:
        bit_prec:    working precision in bits.
        max_range:   ``2 ** bit_prec``.
        mask:        ``max_range - 1`` (initial value of ``high``).
        renorm:      ``max_range >> 1``; the single bit compared during
                     renormalisation.
        second_mask: same value as ``renorm``; not used by this class.
        low, high:   current integer interval bounds.
        s:           number of renormalisation shifts performed so far.
        output:      optional sink with a ``write(bit)`` method used only by
                     :meth:`finish`; ``None`` when no sink is attached.
    """

    def __init__(self, bitlen):
        """Create an encoder working with ``bitlen`` bits of precision."""
        self.bit_prec = bitlen   #bit precision
        self.max_range = 1 << self.bit_prec  #max range based on bit precision 2^bit_prec
        self.mask = self.max_range - 1  #max range index starting at 0
        self.renorm = self.max_range >> 1  #renormalization threshold
        self.second_mask = self.max_range >> 1
        self.low = 0  #initial low
        self.high = self.mask  #initial high
        self.s = 0  #renormalisation shift counter
        self.output = None  #no bit sink attached by default (see finish)

    def update(self, sym, c):
        """Narrow the interval to symbol ``sym`` and renormalise.

        Args:
            sym: index of the symbol; its interval is ``c[sym] .. c[sym + 1]``.
            c:   cumulative table; ``c[-1]`` is used as the total.

        Raises:
            IndexError: if ``sym`` does not satisfy ``0 <= sym < len(c) - 1``.
        """
        # A negative index would silently wrap around and select a bogus
        # interval, so reject it explicitly together with too-large indices.
        if not 0 <= sym < len(c) - 1:
            raise IndexError("symbol index %r is outside the cumulative table" % (sym,))

        total = c[-1]  #cumulative probabilities total
        span = self.high - self.low + 1

        # NOTE(review): when `c` holds floats these offsets are computed in
        # floating point and added to `low` as floats, which cannot represent
        # every bit of a large `low` -- confirm this is acceptable for long inputs.
        new_low = self.low + c[sym] * span // total  #low in arithmetic integer
        new_high = self.low + c[sym + 1] * span // total - 1  #high in arithmetic integer
        self.low = int(new_low)
        self.high = int(new_high)
        span = self.high - self.low

        #renormalization: shift while low and high agree on the `renorm` bit
        while ((self.low ^ self.high) & self.renorm) == 0:
            self.low = self.low << 1
            span = span << 1 | 1
            self.high = self.low + span
            self.s = self.s + 1

    def write(self, c, sym):
        """Encode one symbol and return ``[low, s]``.

        NOTE: the parameter names are swapped relative to their meaning.  The
        arguments are forwarded positionally to ``update(sym, c)``, so the
        FIRST argument (named ``c`` here) is the symbol index and the SECOND
        (named ``sym`` here) is the cumulative table.  The names are kept
        unchanged so existing callers continue to work.
        """
        self.update(c, sym)
        return [self.low, self.s]

    def finish(self):
        """Write a terminating ``1`` to ``self.output`` if a sink is attached.

        Nothing is written when no sink is attached.  ``getattr`` is used
        because subclasses that do not call this class's ``__init__`` never
        define ``output``.
        """
        output = getattr(self, "output", None)
        if output is not None:
            output.write(1)

In [25]:
class ArithmeticDecoder(ArithmeticEncoder):
    """Decoder for the bit string produced from ``ArithmeticEncoder``.

    The decoder keeps a list of integer thresholds (``thresh``), one per entry
    of the cumulative table, and finds the interval that contains the current
    code value ``low``.

    Notes:
        * ``ArithmeticEncoder.__init__`` is not called, so encoder-only state
          (``bit_prec``, ``high``, ``s``) is not defined on a decoder.
        * Only the first ``statesize`` characters of ``bitin`` are ever read;
          renormalisation shifts zeros into ``low`` rather than reading more
          input bits.
    """

    def __init__(self, statesize, bitin, cum_probs=None):
        """Prepare decoding of ``bitin``.

        Args:
            statesize: number of leading bits of ``bitin`` that form the
                initial code value; also sets the working range.
            bitin: string of ``'0'``/``'1'`` characters.
            cum_probs: cumulative table used to build the thresholds.  When
                omitted, the notebook-level global ``c`` is used, which is the
                original behaviour.
        """
        self.input = bitin
        self.bitstream = [int(d) for d in str(self.input)]
        self.code = 0
        self.max_range = 1 << statesize
        self.renorm = self.max_range >> 1
        self.mask = self.max_range - 1

        self.stream = self.input[0:statesize]
        self.low = int(self.stream, 2)
        self.t = statesize
        self.cum_probs = cum_probs
        self.thresh = [round(p * self.mask) for p in self._cumulative()]

    def _cumulative(self):
        """Return the explicit cumulative table, else the global ``c``."""
        if self.cum_probs is not None:
            return self.cum_probs
        return c  # notebook-global fallback, looked up at call time

    def decode(self):
        """Decode and return the next symbol index.

        Finds the interval ``thresh[i] <= low < thresh[i + 1]``, doubles that
        interval (and ``low``) until its width reaches ``renorm``, then
        rebuilds ``thresh`` inside the selected interval.

        Raises:
            ValueError: if ``low`` falls inside none of the intervals.
        """
        table = self._cumulative()
        for i in range(len(self.thresh) - 1):
            lower = self.thresh[i]
            upper = self.thresh[i + 1]
            if not (lower <= self.low < upper):
                continue

            width = int(upper - lower)
            #renormalization: keep the interval at least `renorm` wide
            while width < self.renorm:
                width = width << 1
                self.low = self.low << 1
                lower = int(lower) << 1
                upper = width + lower

            # subdivide the selected interval for the next symbol
            self.thresh = [round(p * (upper - lower)) + lower for p in table]
            return i

        raise ValueError(
            "code value %r lies outside every symbol interval [%r, %r)"
            % (self.low, self.thresh[0], self.thresh[-1])
        )


In [29]:
#test bit stream
bitprecision = 64

#fake input stream -- use real data
in_stream = array(['a','c','g','t','t', 'a', 'c', 't', 'g', 'a','t','a','a','t','a','c','c','g','t'])
#fake probability array --use output from RNN.
prob = array([0.5, 0.2, 0.2, 0.1])

#make an array of cumulative probabilities: c[k] = prob[0] + ... + prob[k-1]
#(running sum; same left-to-right additions as summing each prefix)
c = [0]
for p in prob:
    c.append(c[-1] + p)

#letter <-> index of its interval in the cumulative table
SYMBOL_TO_INDEX = {'a': 0, 'c': 1, 'g': 2, 't': 3}
INDEX_TO_SYMBOL = {index: letter for letter, index in SYMBOL_TO_INDEX.items()}

#convert letters into numbers that correspond to cum probability index
#(an unknown letter raises KeyError instead of being dropped silently)
sym = [SYMBOL_TO_INDEX[str(letter)] for letter in in_stream]

print("symbol stream input:", in_stream)
print("num stream:", sym)
print("probability:", prob)
print("cumulative prob:", c)

sym.append(1) #fake symbol at end to discard in decoder

#BEGIN ENCODER
enc = ArithmeticEncoder(bitprecision)
for symbol in sym:
    new = enc.write(symbol, c)  #first argument is the symbol, second the table
low_final, s_final = new  #final low value, count of renormalizations

#convert low to bitstream, left-padded with zeros to the proper number of bits
bitstream = format(int(low_final), 'b').zfill(s_final + bitprecision)
print("encoded bit length:", len(bitstream))


#START DECODER
dec = ArithmeticDecoder(bitprecision, bitstream)

symdec = [dec.decode() for _ in range(len(sym))]
print("decoded num stream:", symdec)

#convert numbers back to letters, then drop the fake trailing symbol
dec_stream = [INDEX_TO_SYMBOL[index] for index in symdec]
dec_stream = dec_stream[:-1]
print("decoded symbol stream:", dec_stream)

#the codec must be lossless on this input
assert dec_stream == [str(letter) for letter in in_stream], "round trip failed"

symbol stream input: ['a' 'c' 'g' 't' 't' 'a' 'c' 't' 'g' 'a' 't' 'a' 'a' 't' 'a' 'c' 'c' 'g'
 't']
num stream: [0, 1, 2, 3, 3, 0, 1, 3, 2, 0, 3, 0, 0, 3, 0, 1, 1, 2, 3]
probability: [0.5 0.2 0.2 0.1]
cumulative prob: [0, 0.5, 0.7, 0.8999999999999999, 0.9999999999999999]
encoded bit length: 106
decoded num stream: [0, 1, 2, 3, 3, 0, 1, 3, 2, 0, 3, 0, 0, 3, 0, 1, 1, 2, 3, 1]
decoded symbol stream: ['a', 'c', 'g', 't', 't', 'a', 'c', 't', 'g', 'a', 't', 'a', 'a', 't', 'a', 'c', 'c', 'g', 't']
