In [3]:
def xor_bytes(lhs, rhs):
    return bytes(map(lambda x: x[0] ^^ x[1], zip(lhs, rhs)))

def hmac(hash_func, key: bytes, message: bytes):
    digest_size = hash_func().digest_size
    
    key_with_digest_size = key
    if len(key) < digest_size:
        key_with_digest_size += bytes(digest_size - len(key))
    elif len(key) > digest_size:
        key_with_digest_size = hash_func(key).digest()
    
    ipad = bytes([0x36 for _ in range(digest_size)])
    opad = bytes([0x5c for _ in range(digest_size)])
    
    key_with_ipad = xor_bytes(key_with_digest_size, ipad)
    key_with_opad = xor_bytes(key_with_digest_size, opad)
    
    return hash_func(key_with_opad + hash_func(key_with_ipad + message).digest()).digest()

In [7]:
def prf(hash_func, secret: bytes, label: bytes, init_seed: bytes, N):
    seed = label + init_seed
    
    random_bytes = b''
    
    a = seed
    while len(random_bytes) < N:
        a = hmac(hash_func, secret, a)
        random_bytes += hmac(hash_func, secret, a + seed)
    
    return random_bytes[:N]

In [8]:
prf(hashlib.sha256, b'secret', b'label', b'seed', 40)

b'\xb6~\xe1\xd0\x96\x9f\xa5\xf9\xf6\xe6/\x07UN\xa0\xcb^\x04.\xa4\xe4f\xe2\xb0\x8b"=\x98rz\xd2\xce\xb0\x82\x99%-U\xf3X'