diff --git a/base62.py b/base62.py index cc96a6f..e02ba13 100644 --- a/base62.py +++ b/base62.py @@ -16,24 +16,18 @@ CHARSET_INVERTED = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -def encode(n, minlen=1, charset=CHARSET_DEFAULT): +def encode(n, charset=CHARSET_DEFAULT): """Encodes a given integer ``n``.""" chs = [] while n > 0: - r = n % BASE - n //= BASE + n, r = divmod(n, BASE) + chs.insert(0, charset[r]) - chs.append(charset[r]) + if not chs: + return "0" - if len(chs) > 0: - chs.reverse() - else: - chs.append("0") - - s = "".join(chs) - s = charset[0] * max(minlen - len(s), 0) + s - return s + return "".join(chs) def encodebytes(barray, charset=CHARSET_DEFAULT): @@ -45,7 +39,27 @@ def encodebytes(barray, charset=CHARSET_DEFAULT): """ _check_type(barray, bytes) - return encode(int.from_bytes(barray, "big"), charset=charset) + + # Count the number of leading zeros. + leading_zeros_count = 0 + for i in range(len(barray)): + if barray[i] != 0: + break + leading_zeros_count += 1 + + # Encode the leading zeros as "0" followed by a character indicating the count. + # This pattern may occur several times if there are many leading zeros. + n, r = divmod(leading_zeros_count, len(charset) - 1) + zero_padding = f"0{charset[-1]}" * n + if r: + zero_padding += f"0{charset[r]}" + + # Special case: the input is empty, or is entirely null bytes. + if leading_zeros_count == len(barray): + return zero_padding + + value = encode(int.from_bytes(barray, "big"), charset=charset) + return zero_padding + value def decode(encoded, charset=CHARSET_DEFAULT): @@ -56,9 +70,6 @@ def decode(encoded, charset=CHARSET_DEFAULT): """ _check_type(encoded, str) - if encoded.startswith("0z"): - encoded = encoded[2:] - l, i, v = len(encoded), 0, 0 for x in encoded: v += _value(x, charset=charset) * (BASE ** (l - (i + 1))) @@ -75,6 +86,11 @@ def decodebytes(encoded, charset=CHARSET_DEFAULT): :rtype: bytes """ + leading_null_bytes = b"" + while encoded.startswith("0") and len(encoded) >= 2: + leading_null_bytes += b"\x00" * _value(encoded[1], charset) + encoded = encoded[2:] + decoded = decode(encoded, charset=charset) buf = bytearray() while decoded > 0: @@ -82,7 +98,7 @@ def decodebytes(encoded, charset=CHARSET_DEFAULT): decoded //= 256 buf.reverse() - return bytes(buf) + return leading_null_bytes + bytes(buf) def _value(ch, charset): diff --git a/tests/test_basic.py b/tests/test_basic.py index 73f1d09..bade00f 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -4,7 +4,6 @@ bytes_int_pairs = [ - (b"\x00", 0), (b"\x01", 1), (b"\x01\x01", 0x0101), (b"\xff\xff", 0xFFFF), @@ -20,9 +19,6 @@ def test_const(): def test_basic(): assert base62.encode(0) == "0" - assert base62.encode(0, minlen=0) == "0" - assert base62.encode(0, minlen=1) == "0" - assert base62.encode(0, minlen=5) == "00000" assert base62.decode("0") == 0 assert base62.decode("0000") == 0 assert base62.decode("000001") == 1 @@ -30,19 +26,11 @@ def test_basic(): assert base62.encode(34441886726) == "base62" assert base62.decode("base62") == 34441886726 - # NOTE: For backward compatibility. When I first wrote this module in PHP, - # I used to use the `0z` prefix to denote a base62 encoded string (similar - # to `0x` for hexadecimal strings). - assert base62.decode("0zbase62") == 34441886726 - def test_basic_inverted(): kwargs = {"charset": base62.CHARSET_INVERTED} assert base62.encode(0, **kwargs) == "0" - assert base62.encode(0, minlen=0, **kwargs) == "0" - assert base62.encode(0, minlen=1, **kwargs) == "0" - assert base62.encode(0, minlen=5, **kwargs) == "00000" assert base62.decode("0", **kwargs) == 0 assert base62.decode("0000", **kwargs) == 0 assert base62.decode("000001", **kwargs) == 1 @@ -50,11 +38,6 @@ def test_basic_inverted(): assert base62.encode(10231951886, **kwargs) == "base62" assert base62.decode("base62", **kwargs) == 10231951886 - # NOTE: For backward compatibility. When I first wrote this module in PHP, - # I used to use the `0z` prefix to denote a base62 encoded string (similar - # to `0x` for hexadecimal strings). - assert base62.decode("0zbase62", **kwargs) == 10231951886 - @pytest.mark.parametrize("b, i", bytes_int_pairs) def test_bytes_to_int(b, i): @@ -77,7 +60,7 @@ def test_encodebytes_rtype(): assert isinstance(encoded, str) -@pytest.mark.parametrize("s", ["0", "1", "a", "z", "ykzvd7ga", "0z1234"]) +@pytest.mark.parametrize("s", ["0", "1", "a", "z", "ykzvd7ga"]) def test_decodebytes(s): assert int.from_bytes(base62.decodebytes(s), "big") == base62.decode(s) @@ -113,3 +96,23 @@ def test_invalid_alphabet(): def test_invalid_string(): with pytest.raises(TypeError): base62.encodebytes({}) + + +@pytest.mark.parametrize( + "input_bytes, expected_encoded_text", + ( + (b"", ""), + (b"\x00", "01"), + (b"\x00\x00", "02"), + (b"\x00\x01", "011"), + (b"\x00" * 61, "0z"), + (b"\x00" * 62, "0z01"), + ), +) +def test_leading_zeros(input_bytes, expected_encoded_text): + """Verify that leading null bytes are not lost.""" + + encoded_text = base62.encodebytes(input_bytes) + assert encoded_text == expected_encoded_text + output_bytes = base62.decodebytes(encoded_text) + assert output_bytes == input_bytes