# In [1]:
import numpy as np 
import pandas as pd 
import cv2
from matplotlib import pyplot as plt
import sys
import struct
import os
import functools


def showImage(img):
    """Show ``img`` in a 15x15 inch matplotlib figure using a gray colormap.

    The input is converted with ``np.uint8`` before display and both axes
    have their tick marks removed.
    """
    figure_size = (15, 15)
    plt.figure(figsize=figure_size)
    as_bytes = np.uint8(img)
    plt.imshow(as_bytes, cmap='gray')
    # Empty tick lists hide the axis ticks.
    plt.xticks([])
    plt.yticks([])
    plt.show()

# In [2]:
class Node:
    """A binary-tree node holding a value and optional left/right children.

    A node with no children is a leaf. ``left`` and ``right`` default to
    ``None`` so a leaf can be created with ``Node(value)``.
    """

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

    def __repr__(self):
        # Deliberately shallow: printing whole subtrees would be unreadable.
        is_leaf = self.left is None and self.right is None
        return "Node(value={!r}, leaf={})".format(self.value, is_leaf)

    def getLeft(self):
        """Return the left child (or ``None``)."""
        return self.left

    def getRight(self):
        """Return the right child (or ``None``)."""
        return self.right

    def getValue(self):
        """Return the value stored in this node."""
        return self.value
    
    
class HuffmanEncoding:
    """Build Huffman trees and encode/decode symbol sequences as bit strings.

    Bit strings are plain ``str`` objects made of ``'0'`` and ``'1'``.
    ``self.d`` is the symbol -> code table filled in by ``generateCodes``;
    it keeps accumulating entries until ``clearMess`` is called.
    """

    def __init__(self):
        # symbol -> bit-string code
        self.d = {}

    def clearMess(self):
        """Reset the code table to empty."""
        self.d = {}

    def huffmanEncoding(self, data):
        """Build a Huffman tree for the symbols in ``data``.

        Args:
            data: iterable of hashable symbols.

        Returns:
            The root ``Node``. Internal nodes have ``value`` set to ``None``;
            leaves carry a symbol. When ``data`` holds a single distinct
            symbol the root is itself a leaf.

        Raises:
            ValueError: if ``data`` yields no symbols.
        """
        freq = {}
        for element in data:
            freq[element] = freq.get(element, 0) + 1
        if not freq:
            raise ValueError("cannot build a Huffman tree from empty data")

        # (node, weight) pairs in descending weight order; the two lightest
        # entries are therefore always at the end of the list.
        ordered = sorted(freq.items(), key=lambda x: x[1], reverse=True)
        pending = [(Node(symbol, None, None), count) for symbol, count in ordered]
        while len(pending) > 1:
            lightest, weight1 = pending.pop()
            second, weight2 = pending.pop()
            pending.append((Node(None, lightest, second), weight1 + weight2))
            # Stable sort: a merged node lands after existing entries of
            # equal weight, which keeps the tree shape deterministic.
            pending.sort(key=lambda pair: pair[1], reverse=True)
        return pending[0][0]

    def generateCodes(self, node, s):
        """Record the code of every leaf below ``node`` in ``self.d``.

        Args:
            node: root of the (sub)tree to walk.
            s: bit-string prefix accumulated on the way to ``node``;
                pass ``""`` for the tree root.

        Returns:
            ``self.d``, the symbol -> code table.
        """
        if node.left is None and node.right is None:
            if node.value is not None:
                # A tree consisting of a single leaf has an empty path; give
                # that symbol a one-bit code so the encoding is not empty.
                self.d[node.value] = s if s else '0'
            return self.d
        if node.left is not None:
            self.generateCodes(node.left, s + '0')
        if node.right is not None:
            self.generateCodes(node.right, s + '1')
        return self.d

    def encodeData(self, data):
        """Return ``data`` encoded as a bit string.

        A tree is built from ``data`` with ``huffmanEncoding`` and its codes
        are written into ``self.d`` before encoding.

        Raises:
            ValueError: if ``data`` yields no symbols.
        """
        tree = self.huffmanEncoding(data)
        self.generateCodes(tree, "")
        return "".join(self.d[element] for element in data)

    def padEncodedData(self, encodedData):
        """Pad a bit string to a multiple of 8 bits and prefix the pad size.

        Between 1 and 8 ``'0'`` bits are appended (a full 8 when the input
        is already byte-aligned), and the number of added bits is written
        in front as an 8-bit binary field.
        """
        extra_padding = 8 - len(encodedData) % 8
        padded_info = "{0:08b}".format(extra_padding)
        return padded_info + encodedData + "0" * extra_padding

    def unpadEncodedData(self, padded_encoded_text):
        """Undo ``padEncodedData``: drop the 8-bit header and the padding."""
        extra_padding = int(padded_encoded_text[:8], 2)
        body = padded_encoded_text[8:]
        # Slice by explicit end index: ``body[:-0]`` would wrongly return ''
        # for a header that records zero padding bits.
        return body[:max(len(body) - extra_padding, 0)]

    def decodeData(self, tree, encodedData):
        """Decode a bit string back into a list of symbols.

        Args:
            tree: root node of the Huffman tree used for encoding.
            encodedData: string of ``'0'``/``'1'`` characters (unpadded).

        Returns:
            List of decoded symbols in order.
        """
        if tree.left is None and tree.right is None:
            # Single-symbol tree: every bit stands for that one symbol.
            return [tree.value] * len(encodedData)
        decoded = []
        ref = tree
        for bit in encodedData:
            ref = ref.left if bit == '0' else ref.right
            if ref.value is not None:
                decoded.append(ref.value)
                ref = tree  # restart from the root for the next symbol
        return decoded
       
    
class FileStorage():
    """Write and read serialized Huffman trees and packed bit strings.

    Layout produced by the ``store*`` methods (sizes use ``struct`` native
    ``'Q'``, floats use native ``'f'``)::

        tree section: <Q item count> '*' <items> '*'
        data section: <Q byte count> '*' <bytes>

    A tree item is either the single byte ``'#'`` (internal node marker) or
    a 4-byte float (leaf value).

    NOTE(review): a float whose first packed byte equals ``'#'`` cannot be
    told apart from the marker when reading; that is a limitation of the
    file format itself.
    """

    def __init__(self):
        self.data = []   # items accumulated by serialize()
        self.index = 0   # read cursor used by deserialize()

    def clearMess(self):
        """Reset the serialization buffer and the deserialization cursor."""
        self.data = []
        self.index = 0

    def _expectSeparator(self, infile):
        """Consume one byte from ``infile`` and require it to be ``'*'``."""
        marker = infile.read(1)
        if marker != b'*':
            raise ValueError(
                "expected '*' separator, found {!r}".format(marker))

    def serialize(self, node):
        """Append the pre-order form of ``node``'s tree to ``self.data``.

        Nodes whose value is ``None`` are written as ``'#'`` followed by
        their left and right subtrees; other nodes are written as their
        value. Returns ``self.data``, which is not cleared between calls.
        """
        if node.value is None:
            self.data.append('#')
            self.serialize(node.left)
            self.serialize(node.right)
        else:
            self.data.append(node.value)
        return self.data

    def deserialize(self, list):
        """Rebuild a tree from a pre-order item list made by ``serialize``.

        Reading starts at ``self.index`` and advances it, so call
        ``clearMess`` before deserializing a new list.
        """
        # NOTE: the parameter name shadows the builtin ``list``; it is kept
        # so that existing keyword callers continue to work.
        if list[self.index] == '#':
            node = Node(None, None, None)
            self.index += 1
            node.left = self.deserialize(list)
            node.right = self.deserialize(list)
        else:
            node = Node(list[self.index], None, None)
            self.index += 1
        return node

    def printTree(self, node):
        """Print the tree in pre-order, one value per line."""
        if node.value is None:
            print(None)
            self.printTree(node.left)
            self.printTree(node.right)
        else:
            print(node.value)

    def storeTreeInFile(self, outfile, serializedTree):
        """Write the tree section to the binary file ``outfile``.

        Args:
            outfile: binary file object opened for writing.
            serializedTree: item list as returned by ``serialize``.

        Returns:
            ``outfile``.
        """
        outfile.write(struct.pack('Q', len(serializedTree)))
        outfile.write(b'*')  # separates the size field from the items
        for item in serializedTree:
            if isinstance(item, str):
                outfile.write(bytes(item, 'utf8'))
            else:
                outfile.write(struct.pack('f', item))
        outfile.write(b'*')  # terminates the tree section
        return outfile

    def storeDataInFile(self, outfile, padded_encoded_text):
        """Pack a bit string into bytes and write the data section.

        Args:
            outfile: binary file object opened for writing.
            padded_encoded_text: string of ``'0'``/``'1'``; each group of
                8 characters becomes one byte.

        Returns:
            ``outfile``.
        """
        data = bytes(
            int(padded_encoded_text[i:i + 8], 2)
            for i in range(0, len(padded_encoded_text), 8)
        )
        outfile.write(struct.pack('Q', len(data)))
        outfile.write(b'*')  # separates the size field from the payload
        outfile.write(data)
        return outfile

    def fetchTreeFromFile(self, infile):
        """Read a tree section starting at the current file position.

        Exactly the recorded number of items is read, so a float whose
        first byte happens to be ``'*'`` no longer ends the section early.

        Returns:
            ``(items, infile)`` where ``items`` is a list of ``'#'`` markers
            and floats; the file is left just past the closing ``'*'``.

        Raises:
            ValueError: if a separator byte is missing.
            struct.error: if the file ends before all items are read.
        """
        print("Inside Tree", infile.tell())
        size = struct.unpack('Q', infile.read(8))[0]
        self._expectSeparator(infile)

        treeArray = []
        while len(treeArray) < size:
            first = infile.read(1)
            if first == b'#':
                treeArray.append('#')
            else:
                # Not a marker: this byte starts a 4-byte float.
                treeArray.append(struct.unpack('f', first + infile.read(3))[0])
        self._expectSeparator(infile)
        return treeArray, infile

    def fetchDataFromFile(self, infile):
        """Read a data section and return it as a bit string.

        Returns:
            ``(bitString, infile)``; each payload byte contributes 8
            characters and the file is left just past the payload.

        Raises:
            ValueError: if the separator is missing or the payload is
                shorter than the recorded size.
        """
        size = struct.unpack('Q', infile.read(8))[0]
        print("In Fetch Data function Size fetched", size)
        self._expectSeparator(infile)

        # One bulk read leaves the position exactly after the payload, even
        # when the payload ends at end-of-file.
        payload = infile.read(size)
        if len(payload) != size:
            raise ValueError(
                "data section truncated: expected {} bytes, got {}".format(
                    size, len(payload)))
        bitString = "".join("{0:08b}".format(value) for value in payload)
        # The first number keeps the historical "bytes read + 1" counter.
        print("In Fetch Data function count, size", size + 1, size)
        return bitString, infile