In [104]:
class Node:
    """Binary-tree node carrying a value, two children and a code string.

    Attributes:
        val:   the value handed to the constructor.
        left:  the left child handed to the constructor.
        right: the right child handed to the constructor.
        code:  a string, initialised to '' and filled in later by whoever
               walks the tree.
    """

    def __init__(self, val, left, right):
        self.val, self.left, self.right = val, left, right
        self.code = ''

    def __lt__(self, other):
        # Always True: a node compares as "less than" anything, so ordering
        # comparisons that involve nodes never raise, whatever they contain.
        return True

In [105]:
from collections import Counter
from queue import PriorityQueue

class HuffmannEncoder():
    """Huffman coder whose code table is derived from a sample text.

    After construction:
        encodes: dict mapping each symbol of the sample to its bit string
                 (a str made of '0' and '1').
        decodes: the inverse dict, bit string -> symbol.
    """

    def __init__(self, init_text):
        """Count the symbols of ``init_text`` and build the code table.

        Args:
            init_text: iterable of hashable symbols (typically a str).
                An empty input yields empty tables.
        """
        self.encodes = dict()
        self.decodes = dict()
        counts = Counter(init_text)
        if not counts:
            # No symbols at all: calling get() on an empty PriorityQueue
            # would block forever, so stop with empty tables instead.
            return

        q = PriorityQueue()
        # Queue entries are (count, sequence, node). The unique sequence
        # number settles ties between equal counts, so the nodes themselves
        # are never compared. Iterating the Counter (first-occurrence order)
        # rather than a set keeps the insertion order, and hence the code
        # table, identical from run to run.
        sequence = 0
        for letter, count in counts.items():
            q.put((count, sequence, Node(letter, None, None)))
            sequence += 1

        # Repeatedly merge the two least frequent subtrees until one is left.
        while q.qsize() > 1:
            count_a, _, a = q.get()
            count_b, _, b = q.get()
            q.put((count_a + count_b, sequence, Node(None, a, b)))
            sequence += 1
        root = q.get()[2]

        self.add_codes(root)

    def add_codes(self, root):
        """Walk the tree below ``root`` and fill ``encodes`` / ``decodes``.

        Each left child gets its parent's code plus "0", each right child
        its parent's code plus "1". Every node whose ``val`` is not None is
        recorded in both tables.

        Args:
            root: node exposing ``val``, ``left``, ``right`` and ``code``.
        """
        is_leaf = root.left is None and root.right is None
        if is_leaf and not root.code:
            # A leaf that still has the empty code is the whole tree (a
            # one-symbol alphabet). Give it a one-bit code so that encoded
            # output is non-empty and can be decoded again.
            root.code = "0"
        # Compare against None rather than truthiness so that falsy symbols
        # (for example the integer 0) are not silently skipped.
        if root.val is not None:
            self.encodes[root.val] = root.code
            self.decodes[root.code] = root.val
        if root.left is not None:
            root.left.code = root.code + "0"
            self.add_codes(root.left)
        if root.right is not None:
            root.right.code = root.code + "1"
            self.add_codes(root.right)

    def encode(self, text):
        """Return the concatenated codes of the symbols in ``text``.

        Args:
            text: iterable of symbols present in ``encodes``.

        Returns:
            str of '0'/'1' characters ('' for empty input).

        Raises:
            ValueError: if ``text`` contains a symbol with no code. Failing
                loudly avoids returning output that silently lacks part of
                the input.
        """
        try:
            return ''.join(self.encodes[symbol] for symbol in text)
        except KeyError as missing:
            raise ValueError(
                "symbol %r is not in the code table" % (missing.args[0],)
            ) from None

    def decode(self, text):
        """Turn a bit string produced by ``encode`` back into text.

        Bits are accumulated until they match a key of ``decodes``; the
        matching symbol is emitted and accumulation restarts. Trailing bits
        that never complete a code are ignored.

        Args:
            text: iterable of single-character strings (normally a str of
                '0' and '1').

        Returns:
            str built by joining the decoded symbols.
        """
        pieces = []
        current = ''
        for bit in text:
            current += bit
            if current in self.decodes:
                pieces.append(self.decodes[current])
                current = ''
        return ''.join(pieces)
            
            
            
                  

In [106]:
# Read the whole sample corpus into memory. The encoding is stated explicitly
# so the result does not depend on the platform's default locale.
# NOTE(review): assumes the sample file is UTF-8 (or plain ASCII) - confirm.
with open('./data/norm_wiki_sample.txt', 'r', encoding='utf-8') as file:
    text = file.read()

In [107]:
# Build the Huffman code table from the symbol counts of the loaded text.
encoder = HuffmannEncoder(text)

In [108]:
# Round-trip check: encoding and then decoding must give back the input text.
assert encoder.decode(encoder.encode(text)) == text

In [109]:
# Baseline: 6 bits for every character of the text.
# NOTE(review): 6 bits presumably assumes an alphabet of at most 64 symbols -
# confirm against the corpus.
fixed_width_bits = len(text) * 6
print("Number of bits required for text:", fixed_width_bits)

# Huffman size: one character of the encoded string per bit.
huffman_bits = len(encoder.encode(text))
print("Number of bits required for encoded text:", huffman_bits)

Number of bits required for text: 64733646
Number of bits required for encoded text: 46489714
