In [31]:
import torch

expansion_rate = 4
layer_idx = 1
batch_size = 8
sequence_length = 16
d_model = 32

input_tensor = torch.randn(batch_size, sequence_length, expansion_rate, d_model)


a_m = torch.zeros((expansion_rate,1))
a_m[layer_idx % expansion_rate] = 1.0

a_r = torch.eye(expansion_rate)

beta = torch.ones(expansion_rate)

alpha = torch.cat([a_m, a_r], dim=1)

output = alpha.transpose(-1, -2) @ input_tensor

h0 = output[..., 0, :]
hr = output[..., 1:, :]

H = torch.einsum("n, bld -> blnd", beta, h0)

print(H.shape)
print(hr.shape)


out = H + hr

print(out.shape)

torch.Size([8, 16, 4, 32])
torch.Size([8, 16, 4, 32])
torch.Size([8, 16, 4, 32])


In [33]:
import torch
import torch.nn as nn

class StaticHyperConnection(nn.Module):
    def __init__(self, dim: int, rate: int, layer_id: int):
        """
        Optimized static hyper-connections implementation.
        
        Args:
            dim: Hidden dimension size
            rate: Expansion rate (n in paper)
            layer_id: Current layer index
        """
        super().__init__()
        
        # Initialize static alpha (width connections)
        init_alpha0 = torch.zeros((rate, 1))
        init_alpha0[layer_id % rate, 0] = 1.0
        alpha = torch.cat([init_alpha0, torch.eye(rate)], dim=1)
        self.static_alpha = nn.Parameter(alpha)  # [rate, rate+1]
        
        # Initialize static beta (depth connections)
        self.static_beta = nn.Parameter(torch.ones(rate))  # [rate]
        
    def forward(self, h: torch.Tensor, layer_fn: callable) -> torch.Tensor:
        """
        Forward pass with optimized einsum operations.
        
        Args:
            h: Input hidden states [batch, seq_len, rate, dim]
            layer_fn: Transformer layer function
        """
        # Width connections using efficient einsum
        # 'ij,blid->bljd' maps:
        #   i: input alpha rows
        #   j: input alpha cols
        #   b: batch
        #   l: sequence length
        #   d: hidden dimension
        mixed_h = torch.einsum('ij,blid->bljd', self.static_alpha, h)
        
        # Layer computation
        layer_out = layer_fn(mixed_h[..., 0, :])  # Use first vector as input
        
        # Depth connections using efficient einsum
        # 'n,bld->blnd' maps:
        #   n: expansion rate
        #   b: batch
        #   l: sequence length
        #   d: hidden dimension
        h = torch.einsum('n,bld->blnd', self.static_beta, layer_out) + mixed_h[..., 1:, :]
        
        return h

class TransformerWithStaticHC(nn.Module):
    def __init__(self, dim: int, num_layers: int, expansion_rate: int = 4):
        super().__init__()
        self.dim = dim
        self.expansion_rate = expansion_rate
        
        # Initialize transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(dim) for _ in range(num_layers)
        ])
        
        # Initialize static hyper-connections
        self.hyper_connections = nn.ModuleList([
            StaticHyperConnection(dim, expansion_rate, i)
            for i in range(num_layers)
        ])
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through transformer with optimized static hyper-connections.
        
        Args:
            x: Input tensor [batch, seq_len, dim]
        """
        # Initialize hyper hidden matrix H0
        batch_size, seq_len, _ = x.shape
        h = x.unsqueeze(2).expand(-1, -1, self.expansion_rate, -1)
        
        # Process through layers
        for layer, hyper_conn in zip(self.layers, self.hyper_connections):
            h = hyper_conn(h, layer)
        
        # Final output is sum of last hyper hidden vectors
        return h.sum(dim=2)

def test_efficiency():
    """Demonstrate the efficiency gain"""
    import time
    
    # Test parameters
    batch_size = 32
    seq_len = 128
    dim = 512
    rate = 4
    
    # Create test input
    x = torch.randn(batch_size, seq_len, rate, dim)
    
    # Test old style with broadcasting
    alpha = torch.randn(rate, rate+1)[None, None, ...]
    beta = torch.randn(rate)[None, None, ...]
    
    start = time.time()
    for _ in range(100):
        mixed_h = alpha.transpose(-1, -2) @ x
        h = torch.einsum('blh,bln->blnh', x, beta)
    old_time = time.time() - start
    
    # Test new style with efficient einsum
    alpha = torch.randn(rate, rate+1).cuda()
    beta = torch.randn(rate).cuda()
    
    start = time.time()
    for _ in range(100):
        mixed_h = torch.einsum('ij,blid->bljd', alpha, x)
        h = torch.einsum('n,bld->blnd', beta, x)
    new_time = time.time() - start
    
    print(f"Old style time: {old_time:.3f}s")
    print(f"New style time: {new_time:.3f}s")
    print(f"Speedup: {old_time/new_time:.2f}x")

test_efficiency()

RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given