# Binary Positional Encoding — Demo

**Idea:** Represent position `i` as its binary bit pattern, then project to `d_model`.  
**Formula:** `bit_k(i) = (i >> k) & 1` for `k = 0 … n_bits-1`, with `n_bits = ceil(log2(max_seq_len + 1))`; the bit vector is projected via `Linear(n_bits, d_model, bias=False)`  
**Properties:** No trig ops (fast) | Deterministic bit pattern | One learnable projection layer  
**Best for:** Short sequences (length ≤ `max_seq_len`), edge/resource-constrained devices

In [None]:
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
class BinaryPositionalEncoding(nn.Module):
    '''
    Binary PE: encodes position i as its binary bit pattern,
    then projects from n_bits -> d_model via a learned linear layer.

    Args:
        d_model: output width of the (bias-free) learned projection.
        max_seq_len: number of positions (0 .. max_seq_len - 1) stored in the
            precomputed bit table; forward() rejects longer sequences.
        dropout: dropout probability applied after adding the encoding.

    Raises:
        ValueError: if max_seq_len < 1.

    Note:
        n_bits = ceil(log2(max_seq_len + 1)). When max_seq_len is a power of
        two, this is one more bit than positions 0 .. max_seq_len - 1 need
        (that top bit is always 0). The formula is intentionally unchanged so
        the `pe_binary` buffer and projection weight keep the same shapes as
        existing checkpoints.
    '''
    def __init__(self, d_model, max_seq_len=512, dropout=0.1):
        super().__init__()
        if max_seq_len < 1:
            # 0 would yield a degenerate 0-bit layer; negatives break log2.
            raise ValueError(f'max_seq_len must be >= 1, got {max_seq_len}')
        self.dropout = nn.Dropout(p=dropout)

        n_bits = math.ceil(math.log2(max_seq_len + 1))

        # Bit table, vectorised: entry [pos, k] = (pos >> k) & 1, i.e. bit k of
        # pos, least-significant bit first. Same values as a Python double loop.
        positions = torch.arange(max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        bit_idx = torch.arange(n_bits).unsqueeze(0)         # (1, n_bits)
        pe_binary = ((positions >> bit_idx) & 1).to(torch.get_default_dtype())

        self.register_buffer('pe_binary', pe_binary.unsqueeze(0))  # (1, max_seq_len, n_bits)
        self.projection = nn.Linear(n_bits, d_model, bias=False)   # learned projection

        print(f'Binary PE: {n_bits} bits -> {d_model} dims  |  projection params: {n_bits * d_model}')

    def forward(self, x):
        '''
        Add the projected binary encoding to x, then apply dropout.

        Args:
            x: tensor whose dim 1 is the sequence axis. Its trailing dim must
               broadcast against d_model; the demo uses (batch, seq_len, d_model).

        Returns:
            Tensor of x's (broadcast) shape: dropout(x + projected PE).

        Raises:
            ValueError: if x.size(1) exceeds the max_seq_len the table was built for.
        '''
        seq_len = x.size(1)
        max_len = self.pe_binary.size(1)
        if seq_len > max_len:
            # Slicing past the table would silently return fewer rows and then
            # fail with an opaque broadcast error; fail clearly instead.
            raise ValueError(
                f'seq_len ({seq_len}) exceeds max_seq_len ({max_len}) '
                f'of BinaryPositionalEncoding'
            )
        pe = self.projection(self.pe_binary[:, :seq_len, :])  # (1, seq_len, d_model)
        return self.dropout(x + pe)

print('BinaryPositionalEncoding defined.')

In [None]:
# Sanity check — shape test
# Push an all-zero batch through the layer: the output must keep the input
# shape, and the only learnable tensor is the bias-free projection.
d_model, seq_len, batch = 64, 50, 4
pe_layer = BinaryPositionalEncoding(d_model, max_seq_len=512, dropout=0.0)

x = torch.zeros(batch, seq_len, d_model)
out = pe_layer(x)

for label, shape in (('Input shape ', x.shape), ('Output shape', out.shape)):
    print(f'{label}: {shape}')
param_count = sum(p.numel() for p in pe_layer.parameters())
print(f'Learnable params: {param_count}')

# Raw bit vectors of the first 5 positions, before projection.
# Each list is least-significant bit first (index 0 = bit 0).
print('\nFirst 5 positions binary patterns (raw bits):')
n_bits = pe_layer.pe_binary.shape[-1]
for pos, row in enumerate(pe_layer.pe_binary[0, :5]):
    bits = [int(row[b]) for b in range(n_bits)]
    print(f'  pos {pos}: {bits}')

In [None]:
# Heatmap — visualize raw binary patterns and projected PE
d_model, seq_len = 64, 60
pe_layer = BinaryPositionalEncoding(d_model, max_seq_len=512, dropout=0.0)

# Raw bits before projection, shape (seq_len, n_bits).
raw_binary = pe_layer.pe_binary[0, :seq_len, :].detach().numpy()

# With an all-zero input and dropout=0.0 the layer output is exactly the
# projected PE. Note the projection here is freshly constructed (untrained).
dummy = torch.zeros(1, seq_len, d_model)
with torch.no_grad():
    projected = pe_layer(dummy)
pe_matrix = projected[0].numpy()  # (seq_len, d_model)

fig, axes = plt.subplots(1, 2, figsize=(16, 4))

# One (data, colormap, y-label, title) spec per panel: raw bits, then projected PE.
panels = (
    (raw_binary, 'binary', 'Bit Index', 'Binary Patterns (raw bits, before projection)'),
    (pe_matrix, 'RdYlBu', 'Dimension', 'Binary PE Heatmap (after projection to d_model)'),
)
for ax, (data, cmap, ylabel, title) in zip(axes, panels):
    im = ax.imshow(data.T, aspect='auto', cmap=cmap, origin='lower')
    ax.set_xlabel('Position')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# `im` now refers to the projected-PE image (the last panel drawn).
plt.colorbar(im, ax=axes[1])

plt.tight_layout()
plt.savefig('demo_binary_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: demo_binary_heatmap.png')

In [None]:
# Summary — key facts about Binary PE for the default configuration.
d_model, max_seq_len = 64, 512
# Same formula BinaryPositionalEncoding uses to size its bit table/projection.
n_bits = math.ceil(math.log2(max_seq_len + 1))
# Fewest bits that can represent positions 0 .. max_seq_len - 1; for a
# power-of-two max_seq_len this is one less than what the layer allocates.
min_bits = max(1, (max_seq_len - 1).bit_length())
print('=== Binary PE Summary ===')
print(f'Bits allocated for {max_seq_len} positions: {n_bits} (minimum needed: {min_bits})')
print(f'Projection params: {n_bits} x {d_model} = {n_bits * d_model}')
print('No trig functions: YES (just bit shifts)')
# forward() only covers the precomputed max_seq_len rows, so longer inputs
# are not supported.
print(f'Generalizes to unseen lengths: NO (bit table holds only {max_seq_len} positions)')
print('Task-adaptive: PARTIAL (bit pattern is fixed, projection is learned)')
print()
print('Ready to use in experiments.')
print('Import: from PE.binary_pe import BinaryPositionalEncoding')