# Protobuf `varint`

The exercise is to implement an encoder and decoder that convert unsigned 64-bit integers to and from variable-width integers (varints). In [the format used by Protobuf](https://protobuf.dev/programming-guides/encoding/#varints), each byte consists of a continuation bit (the most significant bit) followed by 7 payload bits, and the 7-bit groups are stored least significant first.

## Decoding

The Protobuf docs have a worked example of decoding `0x9601` to 150, so let's start with decoding.

First of all, let's declare some useful helper functions.

TODO: Do I even need this?

In [None]:
def num_bytes(n: int) -> int:
    # Initial version: return math.ceil(math.floor(math.log2(n) + 1) / 8)
    return (n.bit_length() + 7) // 8


def binary(data: bytes) -> str:
    # Parameter is named `data` so it doesn't shadow the built-in `bytes` type
    return "_".join(format(byte, "09_b") for byte in data)


def nibbles(n: int) -> str:
    return format(n, "09_b")


def test(got, want):
    print("got: ", got)
    print("want:", want)

There are three steps in decoding:
1. Drop continuation bits
2. Convert to big-endian
3. Concatenate and interpret as an unsigned 64-bit integer
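
Before writing a function, the three steps can be traced by hand on the docs' example. This is only an illustrative sketch; the variable names (`raw`, `payloads`, `bits`, `value`) are made up for this trace and are not used elsewhere in the notebook.

```python
# Hand trace of the three decoding steps on the docs' example, 0x96 0x01.
raw = bytes.fromhex("9601")

# 1. Drop the continuation bit (the MSB) from each byte.
payloads = [byte & 0x7F for byte in raw]  # [0x16, 0x01]

# 2. The least significant group comes first on the wire, so reverse the
#    groups to get most-significant-first (big-endian) order.
payloads.reverse()  # [0x01, 0x16]

# 3. Concatenate the 7-bit groups and read the result as an integer.
bits = "".join(format(p, "07b") for p in payloads)  # "00000010010110"
value = int(bits, 2)

print(bits, value)  # 00000010010110 150
```

Building a bit string is slow and only useful for seeing the groups line up; real decoding would stay in integer arithmetic.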

In [None]:
def decode(bytes: bytes) -> int:
    # How can I drop a bit??
    concat = 0
    for i, byte in enumerate(bytes):
        # 1. Drop the MSB/continuation bit
        msb = byte & 0b10000000  # Read continuation bit
        byte &= 0b01111111  # Unset the continuation bit

        # 2. Convert to big-endian order
        # 3. Concatenate and interpret as an integer
        concat += byte << (i * 7)

        # See if we need to keep going
        if not msb:
            break

    n = concat.to_bytes(num_bytes(concat))
    return int.from_bytes(n)

I wasn't sure how best to work with bytes in Python, but it turns out that you don't really need the `bytes` object at all. Everything can be done on integers with bitmasks and shifts. I found it confusing to think about operating on a single byte at a time, but you can visualise the operations as:

    integer: ... 01010101 10110001 -> ... 01010101
    bitmask:              01111111 ->     01111111

and you know you've processed all bytes when all bits are zero.
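
One possible reading of that picture, as a rough sketch: load the whole varint into a single integer, mask the low byte with `0b01111111`, shift the integer right by 8, and stop once nothing is left. The function name `decode_as_int` is hypothetical, and the sketch assumes the input is exactly one well-formed varint whose last byte is non-zero (otherwise the "all bits are zero" test stops early or never sees the trailing byte).

```python
def decode_as_int(varint: bytes) -> int:
    # Load the varint so that the first wire byte sits in the lowest 8 bits.
    remaining = int.from_bytes(varint, "little")
    decoded = 0
    shift = 0
    while remaining:  # All bits zero: every byte has been processed
        decoded |= (remaining & 0b01111111) << shift  # Keep the 7 payload bits
        remaining >>= 8  # Discard the byte we just handled
        shift += 7  # The next group is 7 bits more significant
    return decoded


print(decode_as_int(bytes.fromhex("9601")))  # 150
```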

A better version of `decode`:

In [None]:
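def decode(varint: bytes) -> int:
    # Assumes `varint` holds exactly one varint; continuation bits are ignored.
    decoded = 0
    for byte in reversed(varint):  # Most significant group first
        decoded <<= 7  # Make room for the next 7 payload bits
        decoded |= byte & 0x7F  # Mask off the continuation bit

    return decoded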


In [None]:
print("Decode tests:")
test(decode(bytes([0b10010110, 0b00000001])), 150)
test(decode(bytes.fromhex("9601")), 150)
test(decode(bytes([0b00000001])), 1)
test(decode((1).to_bytes()), 1)

## Encoding

Now, let's get to encoding.

In [None]:
def encode(n: int) -> bytes:
    # 1. Group into bytes, leaving room for the continuation bits
    a = 0
    i = 0
    while n:
        a += (n & 0b01111111) << i * 8
        n >>= 7
        i += 1

    # 2. Convert to little-endian order (at least one byte, so 0 encodes as 0x00)
    buf = bytearray(a.to_bytes(max(num_bytes(a), 1), "little"))

    # 3. Add continuation bit to every byte except the last
    for i in range(len(buf) - 1):
        buf[i] |= 0b10000000

    return bytes(buf)

All these conversions are unnecessary:

In [None]:
def encode(n: int) -> bytes:
    encoded = []
    while n > 127:
        b = 0x80 # Set continuation bit
        b |= n & 0x7f # Set payload
        n >>= 7
        encoded.append(b)
    encoded.append(n)

    return bytes(encoded)

In [None]:
print("Encode tests:")
test(binary(encode(150)), "1001_0110_0000_0001")
test(decode(encode(150)), 150)


## Extensions

### Benchmarks

TODO: Compare Python versions and C

### Encoding and decoding adjacent `varint`s

The whole point of the continuation bit is so you can tell when one `varint` ends and another begins. In the tests so far, we haven't covered that.
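
A minimal sketch of how that could look, assuming the input is simply several varints laid end to end in one buffer. The names `decode_one` and `decode_all` are hypothetical helpers written for this example; unlike the `decode` above, `decode_one` walks forward and stops at the first byte whose continuation bit is clear.

```python
def decode_one(buf: bytes, pos: int = 0) -> tuple[int, int]:
    """Decode one varint starting at buf[pos]; return (value, next position)."""
    decoded = 0
    shift = 0
    while True:
        if pos >= len(buf):
            raise ValueError("buffer ended in the middle of a varint")
        byte = buf[pos]
        pos += 1
        decoded |= (byte & 0x7F) << shift  # Payload bits, least significant first
        shift += 7
        if not byte & 0x80:  # Continuation bit clear: last byte of this varint
            return decoded, pos


def decode_all(buf: bytes) -> list[int]:
    """Decode every varint in buf, one after another."""
    values = []
    pos = 0
    while pos < len(buf):
        value, pos = decode_one(buf, pos)
        values.append(value)
    return values


# 150, 1 and 300 back to back: 96 01 | 01 | ac 02
print(decode_all(bytes.fromhex("960101ac02")))  # [150, 1, 300]
```

Returning the next position alongside the value keeps the decoder stateless, so the caller decides whether to keep reading from the same buffer.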

### ZigZag encoding

TODO
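
As a possible starting point: ZigZag maps signed integers onto unsigned ones so that values of small magnitude, positive or negative, stay small and therefore encode as short varints. The sketch below assumes 64-bit signed input; `zigzag_encode` and `zigzag_decode` are hypothetical names, not part of the notebook so far.

```python
def zigzag_encode(n: int) -> int:
    """Map a signed 64-bit integer to an unsigned one: 0, -1, 1, -2 -> 0, 1, 2, 3."""
    if not -(1 << 63) <= n < (1 << 63):
        raise ValueError("n does not fit in a signed 64-bit integer")
    # n >> 63 is 0 for non-negative n and -1 (all ones) for negative n.
    return (n << 1) ^ (n >> 63)


def zigzag_decode(z: int) -> int:
    """Inverse of zigzag_encode."""
    # The lowest bit carries the sign; the remaining bits carry the magnitude.
    return (z >> 1) ^ -(z & 1)


print([zigzag_encode(n) for n in (0, -1, 1, -2, 2)])  # [0, 1, 2, 3, 4]
```

The result of `zigzag_encode` is non-negative, so it can be passed straight to a varint encoder for unsigned integers.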