Skip to content

Commit

Permalink
Issue #12 Add type annotations based on PR #13 by KOLANICH
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Dec 2, 2021
1 parent f20c2a2 commit 4b7e087
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions dahuffman/huffmancodec.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import collections
import itertools
from io import IOBase
import sys
from heapq import heappush, heappop, heapify

import logging
import pickle
from pathlib import Path
from typing import Union, Any
from typing import Union, Any, Callable, Iterator, Optional, Mapping, Iterable

_log = logging.getLogger(__name__)

Expand All @@ -18,22 +19,22 @@ class _EndOfFileSymbol:
which does not necessarily align with byte boundaries.
"""

def __repr__(self):
def __repr__(self) -> str:
return '_EOF'

# Because _EOF will be compared with normal symbols (strings, bytes),
# we have to provide a minimal set of comparison methods.
# We'll make _EOF smaller than the rest (meaning lowest frequency)
def __lt__(self, other):
def __lt__(self, other) -> bool:
return True

def __gt__(self, other):
def __gt__(self, other) -> bool:
return False

def __eq__(self, other):
def __eq__(self, other) -> bool:
return other.__class__ == self.__class__

def __hash__(self):
def __hash__(self) -> int:
return hash(self.__class__)


Expand All @@ -44,7 +45,7 @@ def __hash__(self):
# TODO store/load code table from file
# TODO Directly encode to and decode from file

def _guess_concat(data):
def _guess_concat(data: Any) -> Callable:
"""
Guess concat function from given data
"""
Expand All @@ -67,7 +68,7 @@ class PrefixCodec:
Prefix code codec, using given code table.
"""

def __init__(self, code_table, concat=list, check=True, eof=_EOF):
def __init__(self, code_table: dict, concat: Callable = list, check: bool = True, eof=_EOF):
"""
Initialize codec with given code table.
Expand All @@ -87,14 +88,14 @@ def __init__(self, code_table, concat=list, check=True, eof=_EOF):
)
# TODO check if code table is actually a prefix code

def get_code_table(self):
def get_code_table(self) -> dict:
"""
Get code table
:return: dictionary mapping symbol to code tuple (bitsize, value)
"""
return self._table

def print_code_table(self, out=sys.stdout):
def print_code_table(self, out: IOBase = sys.stdout) -> None:
"""
Print code table overview
"""
Expand All @@ -113,7 +114,7 @@ def print_code_table(self, out=sys.stdout):
for row in zip(*columns):
out.write(template.format(*row))

def encode(self, data):
def encode(self, data: Union[str, bytes, Iterable]) -> bytes:
"""
Encode given data.
Expand All @@ -122,12 +123,12 @@ def encode(self, data):
"""
return bytes(self.encode_streaming(data))

def encode_streaming(self, data):
def encode_streaming(self, data: Union[str, bytes, Iterable]) -> Iterator[int]:
"""
Encode given data in streaming fashion.
:param data: sequence of symbols (e.g. byte string, unicode string, list, iterator)
:return: generator of bytes (single character strings in Python2, ints in Python 3)
:return: generator of bytes
"""
# Buffer value and size
buffer = 0
Expand Down Expand Up @@ -161,7 +162,9 @@ def encode_streaming(self, data):
byte = buffer << (8 - size)
yield byte

def decode(self, data, concat=None):
def decode(
self, data: Union[bytes, Iterable[int]], concat: Optional[Callable] = None
) -> Union[str, bytes, Iterable]:
"""
Decode given data.
Expand All @@ -171,7 +174,7 @@ def decode(self, data, concat=None):
"""
return (concat or self._concat)(self.decode_streaming(data))

def decode_streaming(self, data):
def decode_streaming(self, data: Union[bytes, Iterable[int]]) -> Iterator:
"""
Decode given data in streaming fashion
Expand All @@ -195,7 +198,7 @@ def decode_streaming(self, data):
buffer = 0
size = 0

def save(self, path: Union[str, Path], metadata: Any = None):
def save(self, path: Union[str, Path], metadata: Any = None) -> None:
"""
Persist the code table to a file.
:param path: file path to persist to
Expand All @@ -220,7 +223,7 @@ def save(self, path: Union[str, Path], metadata: Any = None):
))

@staticmethod
def load(path: Union[str, Path]) -> 'PrefixCodec':
def load(path: Union[str, Path]) -> "PrefixCodec":
"""
Load a persisted PrefixCodec
:param path: path to serialized PrefixCodec code table data.
Expand All @@ -245,7 +248,9 @@ class HuffmanCodec(PrefixCodec):
"""

@classmethod
def from_frequencies(cls, frequencies, concat=None, eof=_EOF):
def from_frequencies(
cls, frequencies: Union[dict, Mapping], concat: Optional[Callable] = None, eof=_EOF
) -> "HuffmanCodec":
"""
Build Huffman code table from given symbol frequencies
:param frequencies: symbol to frequency mapping
Expand Down Expand Up @@ -280,7 +285,7 @@ def from_frequencies(cls, frequencies, concat=None, eof=_EOF):
return cls(table, concat=concat, check=False, eof=eof)

@classmethod
def from_data(cls, data):
def from_data(cls, data: Union[str, bytes, Iterable]) -> "HuffmanCodec":
"""
Build Huffman code table from symbol sequence
Expand Down

0 comments on commit 4b7e087

Please sign in to comment.