# Understanding `jax.lax.bitcast_convert_type`

A bitcast reinterprets the **raw bits** of an array as a different type — without changing the underlying binary data.

- **Normal cast** (`astype`): changes the bits so the **value** is kept as closely as the target type allows (`1.0` → `1`, `0.5` → `0`)
- **Bitcast**: preserves the **bits**, reinterprets them as a new type

In [1]:
import jax
import jax.numpy as jnp

## Example 1: float32 → int32 (same bit-width)

IEEE 754 `float32` for `1.0` is stored as bits `0 01111111 00000000000000000000000` = `0x3F800000` = `1065353216` in decimal.
The bitcast reads those same bits as an integer.
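
The bit pattern above can also be assembled by hand from its three fields. The sketch below is an illustration rather than part of the original notebook: the variable names are arbitrary, and it assumes the standard float32 layout of 1 sign bit, 8 exponent bits, and 23 mantissa bits.

```python
import jax
import jax.numpy as jnp

# Fields of float32 1.0: sign 0, biased exponent 127 (0b01111111), mantissa 0
sign, exponent, mantissa = 0, 0b01111111, 0

# Assemble the 32-bit pattern as a plain Python int
bits = (sign << 31) | (exponent << 23) | mantissa
print(hex(bits), bits)

# Reinterpret that integer pattern as a float32
as_float = jax.lax.bitcast_convert_type(jnp.array(bits, dtype=jnp.int32), jnp.float32)
print(as_float)
```

If the layout assumption holds, `bits` equals `1065353216` and the bitcast result is `1.0`.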

In [2]:
x = jnp.array(1.0, dtype=jnp.float32)

# Bitcast: reinterpret the raw bits as int32
y_bitcast = jax.lax.bitcast_convert_type(x, jnp.int32)

# Normal cast: convert the value
y_cast = x.astype(jnp.int32)

print(f"Original float32:     {x}")
print(f"Bitcast to int32:     {y_bitcast}  (== 0x{int(y_bitcast):08X})")
print(f"Normal cast to int32: {y_cast}")
print()
print("Notice: bitcast gives 1065353216 (the IEEE 754 bit pattern),")
print("        normal cast gives 1 (the numeric value).")

Original float32:     1.0
Bitcast to int32:     1065353216  (== 0x3F800000)
Normal cast to int32: 1

Notice: bitcast gives 1065353216 (the IEEE 754 bit pattern),
        normal cast gives 1 (the numeric value).


## Example 2: Roundtrip — bitcast is lossless

Bitcasting to `int32` and back to `float32` reproduces the original bits exactly, so the roundtrip loses nothing.

In [3]:
x = jnp.array([1.0, -2.5, 3.14], dtype=jnp.float32)

# Roundtrip: float32 → int32 → float32
as_int = jax.lax.bitcast_convert_type(x, jnp.int32)
back_to_float = jax.lax.bitcast_convert_type(as_int, jnp.float32)

print(f"Original:   {x}")
print(f"As int32:   {as_int}")
print(f"Back:       {back_to_float}")
print(f"Exact match: {jnp.array_equal(x, back_to_float)}")

Original:   [ 1.   -2.5   3.14]
As int32:   [ 1065353216 -1071644672  1078523331]
Back:       [ 1.   -2.5   3.14]
Exact match: True
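
A case where the difference between a bit-level and a value-level roundtrip becomes visible is negative zero. The following sketch is an added illustration with arbitrary variable names; it compares a bitcast roundtrip with an `astype` roundtrip on `0.0` and `-0.0`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, -0.0], dtype=jnp.float32)

# Bit-level roundtrip: float32 -> int32 -> float32
bits = jax.lax.bitcast_convert_type(x, jnp.int32)
back = jax.lax.bitcast_convert_type(bits, jnp.float32)

# Value-level roundtrip: float32 -> int32 -> float32
via_cast = x.astype(jnp.int32).astype(jnp.float32)

print(f"bits:              {bits}")
print(f"signbit (bitcast): {jnp.signbit(back)}")
print(f"signbit (astype):  {jnp.signbit(via_cast)}")
```

The two zeros differ only in the sign bit, which the bitcast roundtrip keeps and the `astype` roundtrip drops. Since `0.0 == -0.0` compares equal, a value comparison such as `jnp.array_equal` would not notice the loss. Comparing the bitcast integers is the stricter check.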


## Example 3: Different bit-widths — shape changes

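When source and target types have different widths, the **trailing dimension changes**: narrowing adds a new last dimension, and widening removes the existing last dimension, which must hold exactly one target element.

| Source → Target | Bit-width | Shape change |
|---|---|---|
| float32 → uint8 | 32 → 8 | New trailing dim of size 4 is added, e.g. `(2,)` → `(2, 4)` |
| uint8 → float32 | 8 → 32 | Trailing dim of size 4 is removed, e.g. `(1, 4)` → `(1,)` |
| bfloat16 → int32 | 16 → 32 | Trailing dim of size 2 is removed, e.g. `(2, 2)` → `(2,)` |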
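
The shape behaviour can be written down as a small rule and checked against JAX. The helper `expected_bitcast_shape` below is hypothetical (it is not a JAX function). It is a sketch of the rule as described in the `jax.lax.bitcast_convert_type` documentation, compared against the shapes JAX actually returns.

```python
import jax
import jax.numpy as jnp

def expected_bitcast_shape(shape, src_dtype, dst_dtype):
    """Hypothetical helper: predict the output shape of a bitcast."""
    src = jnp.dtype(src_dtype).itemsize  # bytes per source element
    dst = jnp.dtype(dst_dtype).itemsize  # bytes per target element
    if dst == src:
        return tuple(shape)                    # same width: shape unchanged
    if dst < src:
        return tuple(shape) + (src // dst,)    # narrowing: new trailing dim
    # Widening: the trailing dim must hold exactly one target element
    assert shape[-1] * src == dst, "trailing dim does not match the target width"
    return tuple(shape[:-1])

cases = [
    ((2,), jnp.float32, jnp.int32),
    ((2,), jnp.float32, jnp.uint8),
    ((3, 4), jnp.uint8, jnp.float32),
    ((5, 2), jnp.bfloat16, jnp.int32),
]
for shape, src, dst in cases:
    out = jax.lax.bitcast_convert_type(jnp.zeros(shape, dtype=src), dst)
    predicted = expected_bitcast_shape(shape, src, dst)
    print(f"{shape} -> {out.shape} (predicted {predicted})")
    assert out.shape == predicted
```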

In [None]:
# float32 (32-bit) → uint8 (8-bit): each float becomes 4 bytes
x = jnp.array([1.0, 2.0], dtype=jnp.float32)
y = jax.lax.bitcast_convert_type(x, jnp.uint8)

print(f"float32 shape: {x.shape} → uint8 shape: {y.shape}")
print(f"Each float32 split into 4 bytes:")
# ':02X' means
# - X — format the integer as uppercase hexadecimal (e.g., 255 → FF)
# - 02 — pad with leading zeros to at least 2 characters (e.g., 0 → 00, 10 → 0A)
for i in range(len(x)):
    decimal_bytes = ' '.join(f'{b}' for b in y[i])
    print(f"decimal_bytes:  {float(x[i])} → [{decimal_bytes}]")
    hex_bytes = ' '.join(f'0x{b:02X}' for b in y[i])
    print(f"hex_bytes:  {float(x[i])} → [{hex_bytes}]")

float32 shape: (2,) → uint8 shape: (2, 4)
Each float32 split into 4 bytes:
decimal_bytes:  1.0 → [0 0 128 63]
hex_bytes:  1.0 → [0x00 0x00 0x80 0x3F]
decimal_bytes:  2.0 → [0 0 0 64]
hex_bytes:  2.0 → [0x00 0x00 0x00 0x40]


**Question:** Why does `1.0` become `[0 0 128 63]` after `jax.lax.bitcast_convert_type(x, jnp.uint8)`?

**Answer:** In IEEE 754, float32 `1.0` has the bit pattern `0x3F800000`. Split into 4 bytes from most significant to least significant:
```
  3F  80  00  00
  63  128  0   0
```
  But the system stores bytes in little-endian order (least significant
  byte first), so in memory the order is reversed:
```
  00  00  80  3F
   0   0  128  63
```
  That's why you get [0, 0, 128, 63] instead of [63, 128, 0, 0].
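
The byte order can be cross-checked without JAX by using Python's standard `struct` module, which packs a float in an explicitly chosen endianness. This is an added sketch with arbitrary variable names; it assumes the bitcast runs on a little-endian platform, as in the output above.

```python
import struct

import jax
import jax.numpy as jnp

# "<f" packs a float32 little-endian, ">f" packs it big-endian
little = list(struct.pack("<f", 1.0))
big = list(struct.pack(">f", 1.0))

# Bytes produced by the bitcast for the same value
jax_bytes = jax.lax.bitcast_convert_type(
    jnp.array(1.0, dtype=jnp.float32), jnp.uint8
).tolist()

print(f"little-endian: {little}")
print(f"big-endian:    {big}")
print(f"bitcast bytes: {jax_bytes}")
```

The bitcast bytes match the little-endian packing, with the least significant byte at index 0.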

**Question:** Why does `2.0` become `[0 0 0 64]` after `jax.lax.bitcast_convert_type(x, jnp.uint8)`?

**Answer:** Float32 `2.0` in IEEE 754:

  - Sign: 0 (positive)
  - Exponent: 10000000 (128 biased, meaning actual exponent = 128 - 127 = 1)
  - Mantissa: 00000000000000000000000 (implicit leading 1, so value = 1.0 × 2¹ = 2.0)

  Full 32 bits: 0 10000000 00000000000000000000000 = 0x40000000

  Split into bytes (big-endian):
```
  40  00  00  00
  64   0   0   0
```
  Stored in little-endian (reversed):
```
  00  00  00  40
   0   0   0  64
```
  So you get [0, 0, 0, 64].
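
The sign, exponent, and mantissa fields can be extracted programmatically by bitcasting to `uint32` and applying shifts and masks. The helper `float32_fields` below is hypothetical and only a sketch of that decomposition.

```python
import jax
import jax.numpy as jnp

def float32_fields(x):
    """Hypothetical helper: split float32 values into (sign, biased exponent, mantissa)."""
    bits = jax.lax.bitcast_convert_type(jnp.asarray(x, dtype=jnp.float32), jnp.uint32)
    sign = bits >> 31                # top bit
    exponent = (bits >> 23) & 0xFF   # next 8 bits, still biased by 127
    mantissa = bits & 0x7FFFFF       # low 23 bits (the implicit leading 1 is not stored)
    return sign, exponent, mantissa

sign, exponent, mantissa = float32_fields([1.0, 2.0, -2.5])
for s, e, m in zip(sign.tolist(), exponent.tolist(), mantissa.tolist()):
    print(f"sign={s}  exponent={e} (unbiased {e - 127})  mantissa=0x{m:06X}")
```

For `2.0` this gives sign 0, biased exponent 128, and mantissa 0, matching the breakdown above.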

In [6]:
# Reverse: uint8 → float32 (4 bytes collapse into 1 float)
# These are the little-endian bytes for float32 1.0 = 0x3F800000
a = jnp.array([[0x00, 0x00, 0x80, 0x3F]], dtype=jnp.uint8)
b = jax.lax.bitcast_convert_type(a, jnp.float32)

print(f"uint8 shape: {a.shape} → float32 shape: {b.shape}")
print(f"bytes [00, 00, 80, 3F] → {b}")

uint8 shape: (1, 4) → float32 shape: (1,)
bytes [00, 00, 80, 3F] → [1.]
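
Because the widening direction consumes the trailing dimension, a flat byte buffer has to be reshaped first so that each row holds exactly one target element. The sketch below is an added illustration: the buffer contents are the little-endian bytes of `1.0` and `2.0` shown earlier, and the variable names are arbitrary.

```python
import jax
import jax.numpy as jnp

# Flat little-endian byte buffer holding two float32 values (1.0 and 2.0)
raw = jnp.array([0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40], dtype=jnp.uint8)

# Group every 4 bytes into one row so the trailing dim matches float32's width
grouped = raw.reshape(-1, 4)
floats = jax.lax.bitcast_convert_type(grouped, jnp.float32)

print(f"raw shape: {raw.shape} → grouped: {grouped.shape} → float32: {floats.shape}")
print(floats)
```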


## Example 4: Practical use — packing bf16 into int32

This is common in TPU code for efficient storage and communication.
Two `bfloat16` values (16 bits each) fit into one `int32` (32 bits).

In [7]:
x = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.bfloat16)
print(f"Original bf16: {x}  shape={x.shape}")

# Reshape to group pairs along the last dim
x_pairs = x.reshape(2, 2)
print(f"Paired:        {x_pairs}  shape={x_pairs.shape}")

# Pack: two bf16 → one int32
packed = jax.lax.bitcast_convert_type(x_pairs, jnp.int32)
print(f"Packed int32:  {packed}  shape={packed.shape}")

# Unpack: one int32 → two bf16
unpacked = jax.lax.bitcast_convert_type(packed, jnp.bfloat16)
print(f"Unpacked bf16: {unpacked}  shape={unpacked.shape}")
print(f"Exact match:   {jnp.array_equal(x_pairs, unpacked)}")

Original bf16: [1 2 3 4]  shape=(4,)
Paired:        [[1 2]
 [3 4]]  shape=(2, 2)
Packed int32:  [1073758080 1082146880]  shape=(2,)
Unpacked bf16: [[1 2]
 [3 4]]  shape=(2, 2)
Exact match:   True
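
To see which `bfloat16` of each pair ends up in which half of the 32-bit word, the packing can be reproduced by hand with shifts. This is an added sketch with arbitrary variable names; it uses `uint32` instead of `int32` to keep the shift arithmetic unsigned, and it assumes the little-endian layout seen in Example 3.

```python
import jax
import jax.numpy as jnp

x_pairs = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)

# Raw 16-bit patterns of each bf16, widened so they can be shifted
halves = jax.lax.bitcast_convert_type(x_pairs, jnp.uint16).astype(jnp.uint32)

# First element of each pair goes to the low 16 bits, second to the high 16 bits
manual = halves[:, 0] | (halves[:, 1] << 16)

# Same packing done in one step by the bitcast
direct = jax.lax.bitcast_convert_type(x_pairs, jnp.uint32)

print([f"0x{v:08X}" for v in manual.tolist()])
print(f"manual == direct: {jnp.array_equal(manual, direct)}")
```

The first row packs to `0x40003F80`: `0x3F80` (bf16 `1.0`) sits in the low half and `0x4000` (bf16 `2.0`) in the high half, which is the value `1073758080` printed above.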


## Example 6: Bitcast vs astype — they are very different!

The table below runs both conversions on the same inputs so the results can be compared side by side.

In [10]:
values = jnp.array([0.0, 0.5, 1.0, 2.0, 100.0], dtype=jnp.float32)

bitcasted = jax.lax.bitcast_convert_type(values, jnp.int32)
casted = values.astype(jnp.int32)

print(f"{'float32':>8s}  {'bitcast int32':>14s}  {'astype int32':>13s}")
print("-" * 40)
for v, bc, c in zip(values, bitcasted, casted):
    print(f"{float(v):>8.1f}  {int(bc):>14d}  {int(c):>13d}")

print()
print("astype converts the VALUE, truncating toward zero (0.5 → 0, 100.0 → 100).")
print("bitcast preserves the BITS (gives the IEEE 754 encoding).")

 float32   bitcast int32   astype int32
----------------------------------------
     0.0               0              0
     0.5      1056964608              0
     1.0      1065353216              1
     2.0      1073741824              2
   100.0      1120403456            100

astype converts the VALUE, truncating toward zero (0.5 → 0, 100.0 → 100).
bitcast preserves the BITS (gives the IEEE 754 encoding).


## Playground

Try your own experiments below!

In [None]:
# Try it yourself!
x = jnp.array([42.0], dtype=jnp.float32)
print(f"Your value: {x}")
print(f"As int32 bits: {jax.lax.bitcast_convert_type(x, jnp.int32)}")
print(f"As uint8 bytes: {jax.lax.bitcast_convert_type(x, jnp.uint8)}")