In [61]:
from typing import Any, Literal

import torch
from torch import nn

def remainder2bit(remainder: torch.Tensor, num_bits: int = 127) -> torch.Tensor:
    """Expand fractional values into their leading binary digits.

    Args:
        remainder: Tensor of fractional parts (values in ``[0, 1)``), any shape.
        num_bits: How many binary digits after the point to produce.

    Returns:
        Tensor of shape ``remainder.shape + (num_bits,)`` in the same tensor
        type as ``remainder``; entry ``k`` along the last axis is the
        ``(k + 1)``-th binary digit of the fraction (most significant first).
    """
    # One left-shift amount per output digit, in the input's tensor type.
    shifts = torch.arange(num_bits, device=remainder.device).type(remainder.type())
    # Broadcasting the trailing axis pairs every value with every shift.
    shifted = remainder.unsqueeze(-1) * 2 ** shifts
    # After shifting, the digit of interest is the first one after the point.
    return torch.floor(2 * (shifted % 1))


def integer2bit(integer: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Expand integer-valued entries into ``num_bits`` binary digits.

    Args:
        integer: Tensor holding whole numbers, any shape.
        num_bits: Width of the binary expansion.

    Returns:
        Tensor of shape ``integer.shape + (num_bits,)`` with the most
        significant digit first along the last axis.
    """
    # Powers of two to divide by, from 2**(num_bits - 1) down to 2**0.
    powers = torch.arange(num_bits - 1, -1, -1, device=integer.device).type(integer.type())
    quotient = integer.unsqueeze(-1) / 2 ** powers
    # Strip the fractional part, then the parity of what is left is the digit.
    whole = quotient - (quotient % 1)
    return whole % 2

# https://github.com/SymposiumOrganization/NeuralSymbolicRegressionThatScales/blob/main/src/nesymres/architectures/set_encoder.py
def float2bit(f: torch.Tensor, num_e_bits: int = 5, num_m_bits: int = 10, bias: int = 127, dtype: Any = torch.float32) -> torch.Tensor:
    """Encode every element of ``f`` as sign, exponent and mantissa bits.

    The last axis of the result is laid out as ``[sign | exponent | mantissa]``
    and holds only 0s and 1s.

    Args:
        f: Tensor of values to encode; any shape.
        num_e_bits: Number of exponent bits. The exponent is stored with an
            offset of ``2 ** (num_e_bits - 1) - 1``.
        num_m_bits: Number of mantissa (fraction) bits.
        bias: Unused; kept so callers that pass it keep working. It used to set
            how many fraction bits were computed before truncating to
            ``num_m_bits``, which never changed the kept bits.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``f.shape + (1 + num_e_bits + num_m_bits,)``.

    Special values:
        * zero  -> all bits 0
        * NaN   -> sign 0, exponent all 1s, first mantissa bit 1
        * +inf  -> sign 0, exponent all 1s, mantissa 0
        * -inf  -> sign 1, exponent all 1s, mantissa 0

    Note:
        Finite values whose exponent does not fit in ``num_e_bits`` are not
        clamped.
    """
    exp_offset = 2 ** (num_e_bits - 1) - 1
    total_bits = 1 + num_e_bits + num_m_bits
    exp_slice = slice(1, 1 + num_e_bits)

    # Output has the input's shape plus a trailing bits axis.
    result = torch.zeros(list(f.shape) + [total_bits], device=f.device, dtype=dtype)

    # Classify the special values up front.
    is_nan = torch.isnan(f)
    is_inf = torch.isinf(f)
    is_neg_inf = is_inf & (f < 0)
    is_pos_inf = is_inf & (f > 0)
    is_normal = ~(is_nan | is_inf)

    if torch.any(is_normal):
        normal_vals = f[is_normal]
        magnitude = torch.abs(normal_vals)

        # SIGN BIT: 1 for negative values. A direct comparison is exact, unlike
        # the earlier ``sign(x + 0.001)`` trick, which reported values in
        # (-0.001, 0) as positive and could yield 0.5 instead of a bit.
        s = (normal_vals < 0).unsqueeze(-1)

        # EXPONENT BITS: floor(log2(|x|)) plus the offset; zero has log2 == -inf
        # and is mapped to the smallest exponent so its bits come out all 0.
        e_scientific = torch.floor(torch.log2(magnitude))
        e_scientific[e_scientific == float("-inf")] = -exp_offset
        e = integer2bit(e_scientific + exp_offset, num_bits=num_e_bits)

        # MANTISSA BITS: fractional part of the significand, expanded to exactly
        # the number of bits that are kept.
        significand = magnitude / 2 ** e_scientific
        m = remainder2bit(significand % 1, num_bits=num_m_bits)

        result[is_normal] = torch.cat([s.type(dtype), e.type(dtype), m.type(dtype)], dim=-1)

    if torch.any(is_nan):
        # NaN: exponent all 1s and a non-zero mantissa (first mantissa bit set).
        nan_pattern = torch.zeros(total_bits, device=f.device, dtype=dtype)
        nan_pattern[exp_slice] = 1
        nan_pattern[1 + num_e_bits] = 1
        result[is_nan] = nan_pattern

    if torch.any(is_pos_inf):
        # +inf: sign 0, exponent all 1s, mantissa all 0s.
        inf_pattern = torch.zeros(total_bits, device=f.device, dtype=dtype)
        inf_pattern[exp_slice] = 1
        result[is_pos_inf] = inf_pattern

    if torch.any(is_neg_inf):
        # -inf: sign 1, exponent all 1s, mantissa all 0s.
        neg_inf_pattern = torch.zeros(total_bits, device=f.device, dtype=dtype)
        neg_inf_pattern[0] = 1
        neg_inf_pattern[exp_slice] = 1
        result[is_neg_inf] = neg_inf_pattern

    return result

In [67]:
# Draw three standard-normal samples, then overwrite the first entry with NaN
# so the encoder's special-value branch is exercised.
X1 = torch.randn(3)
X1[0] = float("nan")

print(X1)

tensor([    nan, -0.5378,  0.2399])


In [None]:
# Encode the 1-D sample with the default bit widths; the bare name on the last
# line displays the resulting bit tensor as the cell output.
f1 = float2bit(X1)
f1

In [69]:
# Same experiment on a 3x4 matrix: only the top-left entry is replaced by NaN.
X2 = torch.randn(3, 4)
X2[0, 0] = float("nan")

print(X2)

tensor([[    nan, -1.1179,  0.3784,  0.2911],
        [ 2.5097, -1.0462, -0.1275, -0.6664],
        [-1.5127, -0.7032, -0.9884, -0.7022]])


In [None]:
# Encode the 2-D sample; the bare name on the last line displays the bit tensor
# as the cell output.
f2 = float2bit(X2)
f2

In [74]:
# Bit pattern produced for the NaN that was written to X2[0, 0].
f2[0, 0]

tensor([0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [75]:
# The NaN encoding must not depend on the shape of the input tensor.
assert torch.equal(f1[0], f2[0, 0])