# DAPE — Data-Adaptive Positional Encoding Demo

**Paper:** Zheng et al. (NeurIPS 2024) — *DAPE: Data-Adaptive Positional Encoding for Length Extrapolation*  
**Formula:** `PE_DAPE(x,i) = (1 + alpha(x)) * PE_sin(i) + beta(x)`  
**Properties:** Dynamic (adapts per input) | Small adaptation network | Variable-length friendly  
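**Note:** This is a simplified, additive variant (the adapted encoding is added directly to the input embeddings) that is inspired by the paper rather than reproducing its exact formulation; it is used here so DAPE can be compared with other PE methods under one interface.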

In [None]:
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

In [None]:
class DAPEPositionalEncoding(nn.Module):
    '''
    Data-Adaptive PE (simplified version, inspired by Zheng et al. NeurIPS 2024).
    Modulates sinusoidal PE using a small network conditioned on input statistics.

    PE_DAPE(x, i) = (1 + alpha(x)) * PE_sin(i) + beta(x)
    where [alpha, beta] = AdaptNet([mean, std, norm_length])

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Odd
            values are supported.
        max_seq_len: Number of positions pre-computed in the sinusoidal
            table; also the divisor used to normalise the sequence length
            fed to the adaptation network.
        dropout: Dropout probability applied to ``x + PE``.

    Raises:
        ValueError: If ``max_seq_len`` is smaller than 1.
    '''
    def __init__(self, d_model, max_seq_len=512, dropout=0.1):
        super().__init__()
        if max_seq_len < 1:
            # max_seq_len is used as a divisor in forward().
            raise ValueError(f'max_seq_len must be >= 1, got {max_seq_len}')

        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Base sinusoidal encoding, stored as a non-trainable buffer so it
        # follows the module across .to()/.cuda() calls.
        pe = self._sinusoidal_table(max_seq_len, d_model)
        self.register_buffer('pe_base', pe.unsqueeze(0))  # (1, max_seq_len, d_model)

        # Adaptation network: [mean, std, norm_len] -> (alpha, beta) per dim
        self.adapt_net = nn.Sequential(
            nn.Linear(3, 32),
            nn.ReLU(),
            nn.Linear(32, 2 * d_model),
        )

        adapt_params = sum(p.numel() for p in self.adapt_net.parameters())
        print(f'DAPE: base=sinusoidal (0 params) + adapt_net ({adapt_params} params)')

    @staticmethod
    def _sinusoidal_table(length, d_model, device=None):
        '''Return a (length, d_model) table of sinusoidal position encodings.'''
        position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device).float()
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(length, d_model, device=device)
        table[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the frequencies to match the target slice.
        table[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return table

    def forward(self, x):
        '''
        Add the input-adapted positional encoding to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor with the same shape as ``x``: ``dropout(x + PE_DAPE)``.
        '''
        batch_size, seq_len, _ = x.shape

        # The summary statistics are treated as constants: no gradient flows
        # back into x through them.
        with torch.no_grad():
            stats = torch.stack([
                x.mean(dim=[1, 2]),
                x.std(dim=[1, 2]).clamp(min=1e-6),
                # Normalise by the configured maximum length instead of a
                # hard-coded 512 so the feature matches the constructor.
                torch.full((batch_size,), seq_len / self.max_seq_len, device=x.device),
            ], dim=1)  # (batch, 3)

        mod = self.adapt_net(stats)                       # (batch, 2*d_model)
        alpha = mod[:, :self.d_model].unsqueeze(1)        # (batch, 1, d_model)
        beta  = mod[:, self.d_model:].unsqueeze(1)        # (batch, 1, d_model)

        if seq_len <= self.pe_base.size(1):
            pe_sin = self.pe_base[:, :seq_len, :]
        else:
            # Sequence is longer than the cached table: compute the encoding
            # on the fly rather than failing with a broadcast error.
            pe_sin = self._sinusoidal_table(
                seq_len, self.d_model, device=self.pe_base.device
            ).to(self.pe_base.dtype).unsqueeze(0)

        pe_adaptive = (1 + alpha) * pe_sin + beta
        return self.dropout(x + pe_adaptive)

print('DAPEPositionalEncoding defined.')

In [None]:
# Sanity check: inputs with different statistics should receive different PEs.
d_model, seq_len = 64, 50
pe_layer = DAPEPositionalEncoding(d_model, max_seq_len=512, dropout=0.0)

# One constant (all-zero) input and one random input scaled up by 10.
x_zeros = torch.zeros((1, seq_len, d_model))
x_large = 10 * torch.randn((1, seq_len, d_model))

with torch.no_grad():
    out_zeros = pe_layer(x_zeros)
    out_large = pe_layer(x_large)
    # Dropout is 0.0 here, so subtracting the input leaves the PE term.
    pe_zeros = out_zeros - x_zeros
    pe_large = out_large - x_large

# Mean absolute gap between the two recovered encodings.
diff = torch.mean(torch.abs(pe_zeros - pe_large)).item()
n_learnable = sum(param.numel() for param in pe_layer.parameters())

print(f'Input shape : {x_zeros.shape}')
print(f'Learnable params: {n_learnable}')
print(f'\nPE difference between x_zeros and x_large: {diff:.4f}')
print('(Should be > 0 if DAPE adapts to input statistics)')

In [None]:
# Heatmap: static sinusoidal base next to the DAPE encoding for one random input
d_model, seq_len = 64, 60
pe_layer = DAPEPositionalEncoding(d_model, max_seq_len=512, dropout=0.0)

x = torch.randn(1, seq_len, d_model)
with torch.no_grad():
    encoded = pe_layer(x)

# Both arrays are (seq_len, d_model); they are transposed below so that
# position runs along the horizontal axis.
base_pe = pe_layer.pe_base[0, :seq_len, :].numpy()   # sinusoidal base
pe_component = (encoded - x)[0].numpy()              # adaptive PE (output minus input)

heatmap_kwargs = dict(aspect='auto', cmap='RdYlBu', origin='lower')
fig, axes = plt.subplots(ncols=2, figsize=(16, 4))

im0 = axes[0].imshow(base_pe.T, **heatmap_kwargs)
axes[0].set_title('Base Sinusoidal PE (static)')

im1 = axes[1].imshow(pe_component.T, **heatmap_kwargs)
axes[1].set_title('DAPE Adapted PE (dynamic, changes per input)')

# Axis labels and colorbars are the same for both panels.
for ax, im in zip(axes, (im0, im1)):
    ax.set_xlabel('Position')
    ax.set_ylabel('Dimension')
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.savefig('demo_dape_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: demo_dape_heatmap.png')

In [None]:
d_model = 64
# Hand-computed size of the adaptation network (weights + biases of both
# linear layers): Linear(3, 32) and Linear(32, 2 * d_model).
adapt_params = (3 + 1) * 32 + (32 + 1) * 2 * d_model

print('=== DAPE Summary ===')
# Build a throwaway layer only to count its parameters; the constructor
# prints its own one-line report.
summary_layer = DAPEPositionalEncoding(d_model, dropout=0.0)
print(f'Learnable parameters: {sum(w.numel() for w in summary_layer.parameters())}')
for summary_line in (
    'Base: sinusoidal (no params) + adaptation network',
    'Input features: [mean, std, normalized_length]',
    'Output: per-dimension scale (alpha) and shift (beta)',
    'Dynamic: YES — different inputs get different PEs',
    'Length-adaptive: YES',
    'Original paper: Zheng et al. NeurIPS 2024',
):
    print(summary_line)
print()
print('Ready to use in experiments.')
print('Import: from PE.dape import DAPEPositionalEncoding')