Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added some typing #13

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions dahuffman/huffmancodec.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,47 @@
import collections
import itertools
from io import IOBase
import sys
from heapq import heappush, heappop, heapify

import logging
import pickle
from pathlib import Path
from typing import Union, Any
from typing import (Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, Mapping, Iterable)

_log = logging.getLogger(__name__)


SymElementT = Union[str, int, bytes]
SymT = Union[Tuple[SymElementT, ...], SymElementT, "_EndOfFileSymbol"]
CodeT = Tuple[int, ...]
CodeTableT = Dict[SymT, CodeT]
SymTSeq = Iterable[SymT]


class _EndOfFileSymbol:
"""
Internal class for "end of file" symbol to be able
to detect the end of the encoded bit stream,
which does not necessarily align with byte boundaries.
"""

def __repr__(self):
def __repr__(self) -> str:
return '_EOF'

# Because _EOF will be compared with normal symbols (strings, bytes),
# we have to provide a minimal set of comparison methods.
# We'll make _EOF smaller than the rest (meaning lowest frequency)
def __lt__(self, other):
def __lt__(self, other: SymT) -> bool:
return True

def __gt__(self, other):
def __gt__(self, other: SymT) -> bool:
return False

def __eq__(self, other):
def __eq__(self, other: SymT) -> bool:
return other.__class__ == self.__class__

def __hash__(self):
def __hash__(self) -> int:
return hash(self.__class__)


Expand All @@ -44,7 +52,7 @@ def __hash__(self):
# TODO store/load code table from file
# TODO Directly encode to and decode from file

def _guess_concat(data):
def _guess_concat(data: Any) -> Union[Type[bytes], Type[list], Callable]:
"""
Guess concat function from given data
"""
Expand All @@ -67,7 +75,7 @@ class PrefixCodec:
Prefix code codec, using given code table.
"""

def __init__(self, code_table, concat=list, check=True, eof=_EOF):
def __init__(self, code_table: CodeTableT, concat: Callable = list, check: bool = True, eof: SymT = _EOF) -> None:
"""
Initialize codec with given code table.

Expand All @@ -87,14 +95,14 @@ def __init__(self, code_table, concat=list, check=True, eof=_EOF):
)
# TODO check if code table is actually a prefix code

def get_code_table(self):
def get_code_table(self) -> CodeTableT:
"""
Get code table
:return: dictionary mapping symbol to code tuple (bitsize, value)
"""
return self._table

def print_code_table(self, out=sys.stdout):
def print_code_table(self, out: IOBase = sys.stdout) -> None:
"""
Print code table overview
"""
Expand All @@ -113,7 +121,7 @@ def print_code_table(self, out=sys.stdout):
for row in zip(*columns):
out.write(template.format(*row))

def encode(self, data):
def encode(self, data: Any) -> bytes:
"""
Encode given data.

Expand All @@ -122,7 +130,7 @@ def encode(self, data):
"""
return bytes(self.encode_streaming(data))

def encode_streaming(self, data):
def encode_streaming(self, data: Any) -> Iterator[int]:
"""
Encode given data in streaming fashion.

Expand Down Expand Up @@ -161,7 +169,7 @@ def encode_streaming(self, data):
byte = buffer << (8 - size)
yield byte

def decode(self, data, concat=None):
def decode(self, data: Iterable[int], concat: Optional[Callable] = None) -> Any:
"""
Decode given data.

Expand All @@ -171,7 +179,7 @@ def decode(self, data, concat=None):
"""
return (concat or self._concat)(self.decode_streaming(data))

def decode_streaming(self, data):
def decode_streaming(self, data: Iterable[int]) -> Iterator[SymT]:
"""
Decode given data in streaming fashion

Expand All @@ -195,7 +203,7 @@ def decode_streaming(self, data):
buffer = 0
size = 0

def save(self, path: Union[str, Path], metadata: Any = None):
def save(self, path: Union[str, Path], metadata: Any = None) -> None:
"""
Persist the code table to a file.
:param path: file path to persist to
Expand All @@ -220,7 +228,7 @@ def save(self, path: Union[str, Path], metadata: Any = None):
))

@staticmethod
def load(path: Union[str, Path]) -> 'PrefixCodec':
def load(path: Union[str, Path]) -> "PrefixCodec":
"""
Load a persisted PrefixCodec
:param path: path to serialized PrefixCodec code table data.
Expand All @@ -245,7 +253,7 @@ class HuffmanCodec(PrefixCodec):
"""

@classmethod
def from_frequencies(cls, frequencies, concat=None, eof=_EOF):
def from_frequencies(cls, frequencies: Union[collections.Counter, Mapping[SymT, int]], concat: Optional[Callable] = None, eof: SymT = _EOF) -> "HuffmanCodec":
"""
Build Huffman code table from given symbol frequencies
:param frequencies: symbol to frequency mapping
Expand Down Expand Up @@ -280,7 +288,7 @@ def from_frequencies(cls, frequencies, concat=None, eof=_EOF):
return cls(table, concat=concat, check=False, eof=eof)

@classmethod
def from_data(cls, data):
def from_data(cls, data: Any) -> "HuffmanCodec":
"""
Build Huffman code table from symbol sequence

Expand Down