<a href="https://colab.research.google.com/github/ramayer/google-colab-examples/blob/main/Einx_vs_torch_einsum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Comparison of einx vs. raw PyTorch for multi-head self-attention


| Dimension/Feature         | MultiHeadAttention_einx (einx)         | MultiHeadAttention_torch (PyTorch)         |
|--------------------------|-----------------------------------------|--------------------------------------------|
| Batch                    | Implicit via "..." (no manual handling) | Explicit, must reshape/permute manually    |
| Channel                  | Split by name in the pattern, e.g. "(qkv nh dh)" or "(nh dh)" | Explicit index, must track position        |
| Number of Heads          | Named as "nh" in pattern string         | Explicit index (e.g., index 2 after permute/view) |
| Head Dimension           | Named as "dh" in pattern string         | Explicit index (e.g., index 4 after permute/view) |
| Sequence/Spatial (N/H/W) | Named as "H", "W", "N" in pattern       | Explicit index, must compute/track         |
| Rearrangement            | Declarative: einx.rearrange pattern     | Imperative: view/permute with indices      |
| Contraction (dot prod)   | Declarative: einx.dot pattern           | Imperative: torch.bmm, must match dims     |
| Softmax                  | Declarative: einx.softmax pattern       | Explicit: F.softmax(dim=...)               |
| Adding new dims/layouts  | Easy: update pattern string             | Manual: update all view/permute indices    |
| Readability              | High: dimension names, "..." for batch  | Lower: must track indices, verbose         |


In [2]:
!pip install einx
import einx
import torch
import torch.nn as nn
import torch.nn.functional as F


Collecting einx
  Downloading einx-0.3.0-py3-none-any.whl.metadata (6.9 kB)
Downloading einx-0.3.0-py3-none-any.whl (102 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einx
Successfully installed einx-0.3.0


In [7]:
class MultiHeadAttention_torch(nn.Module):
    """Multi-head self-attention over the spatial positions of a 2-D feature map.

    The imperative counterpart of the einx version: every reshape is an
    explicit ``view``/``permute`` and both contractions use ``torch.bmm``
    (no einsum).

    Args:
        channels: number of input/output channels ``C``; must be divisible
            by ``num_heads`` (enforced with an ``assert``).
        num_heads: number of attention heads.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        # 1/sqrt(head_dim) scaling applied to the raw dot-product scores
        self.scale = self.head_dim ** -0.5
        # a single 1x1 conv emits q, k and v stacked along the channel axis
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Attend over the H*W positions of ``x`` (B, C, H, W); returns (B, C, H, W)."""
        batch, _, height, width = x.shape
        heads, dim = self.num_heads, self.head_dim
        positions = height * width
        groups = batch * heads  # bmm batches over every (sample, head) pair

        # (B, 3*C, H, W) -> (B, 3, heads, head_dim, H*W): channels are laid
        # out as [q | k | v], each block split into `heads` chunks of head_dim.
        stacked = self.qkv(x).view(batch, 3, heads, dim, positions)
        # -> (3, B, heads, H*W, head_dim) so each of q/k/v is one leading slice
        stacked = stacked.permute(1, 0, 2, 4, 3).contiguous()
        q, k, v = (stacked[i].view(groups, positions, dim) for i in range(3))

        # one (H*W, H*W) attention matrix per (sample, head)
        weights = torch.bmm(q, k.transpose(1, 2)) * self.scale
        weights = F.softmax(weights, dim=-1)
        mixed = torch.bmm(weights, v)  # (B*heads, H*W, head_dim)

        # (B*heads, H*W, head_dim) -> (B, heads, head_dim, H*W) -> (B, C, H, W)
        mixed = mixed.view(batch, heads, positions, dim).permute(0, 1, 3, 2)
        return self.proj(mixed.contiguous().view(batch, heads * dim, height, width))

In [8]:
class MultiHeadAttention_einx(nn.Module):
    """Multi-head self-attention over the spatial positions of a 2-D feature map.

    The einx counterpart of ``MultiHeadAttention_torch``: same parameters and
    same math, but every reshape, contraction and softmax is a named-axis
    einx pattern. Leading (batch) axes are matched by the anonymous ellipsis
    ``...``, so ``forward`` only reads the trailing spatial size ``(H, W)``
    from the input.

    Args:
        channels: number of input/output channels; must be divisible by
            ``num_heads`` (enforced with an ``assert``).
        num_heads: number of attention heads.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        # 1/sqrt(head_dim) scaling applied to the raw dot-product scores
        self.scale = self.head_dim ** -0.5
        # parameter names match MultiHeadAttention_torch, so a state_dict from
        # one class can be loaded into the other
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Attend over the H*W positions of ``x`` (B, C, H, W); returns the same shape.

        Only the spatial size is read from ``x.shape``; the batch axis is left
        to ``...`` in the patterns.
        NOTE(review): since nn.Conv2d also accepts unbatched (C, H, W) input,
        that case should now work too, provided einx accepts an empty
        ellipsis — confirm before relying on it.
        """
        H, W = x.shape[-2:]
        qkv = self.qkv(x)  # (..., 3*C, H, W)
        # split channels into (q/k/v, head, head_dim); flatten space into one sequence axis
        qkv = einx.rearrange('... (qkv nh dh) H W -> qkv ... nh (H W) dh',
                             qkv, qkv=3, nh=self.num_heads, dh=self.head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (..., nh, H*W, dh)
        # dh appears in both inputs but not in the output, so it is contracted
        attn = einx.dot('... N dh, ... M dh -> ... N M', q, k) * self.scale
        # softmax over the bracketed key axis M
        attn = einx.softmax('... N [M]', attn)
        out = einx.dot('... N M, ... M dh -> ... N dh', attn, v)
        # merge heads back into channels and unflatten the sequence axis into (H, W)
        out = einx.rearrange('... nh (H W) dh -> ... (nh dh) H W',
                             out, nh=self.num_heads, dh=self.head_dim, H=H, W=W)
        return self.proj(out)

In [None]:
# Prompt the reader to compare the two forward() implementations above.
print("\nNotice the differences in the models above." + "\n" * 3)

In [9]:
# Build both implementations, give them identical weights, and check that
# their forward passes agree on the same random input.
torch.manual_seed(0)
channels, num_heads = 16, 4
B, H, W = 2, 8, 8

m1 = MultiHeadAttention_torch(channels, num_heads=num_heads)
m2 = MultiHeadAttention_einx(channels, num_heads=num_heads)

# both classes name their parameters qkv/proj, so m1's weights load into m2
m2.load_state_dict(m1.state_dict())
m1.eval()
m2.eval()

torch.manual_seed(1)
x = torch.randn(B, channels, H, W)
out1, out2 = m1(x), m2(x)

max_diff = torch.max(torch.abs(out1 - out2)).item()
print("max_abs_diff:", max_diff)
assert torch.allclose(out1, out2, atol=1e-6, rtol=1e-5), "Forward outputs differ"
print("Verification passed: forward outputs are identical within tolerance.")

max_abs_diff: 0.0
Verification passed: forward outputs are identical within tolerance.


In [10]:
import unittest
import torch


class TestMultiHeadAttention(unittest.TestCase):
    """Shape, gradient and determinism checks, run against both implementations."""

    @staticmethod
    def _implementations():
        # Looked up at call time, so defining this class does not require the
        # model classes to exist yet.
        return (MultiHeadAttention_torch, MultiHeadAttention_einx)

    def test_output_shape(self):
        # channels must be divisible by num_heads; H != W so that swapped
        # spatial axes in a reshape would show up as a shape mismatch
        channels, num_heads = 16, 4
        B, H, W = 2, 8, 6
        x = torch.randn(B, channels, H, W)
        for cls in self._implementations():
            with self.subTest(impl=cls.__name__):
                model = cls(channels, num_heads=num_heads)
                self.assertEqual(model(x).shape, x.shape)

    def test_backward_computes_gradients(self):
        channels, num_heads = 12, 3
        B, H, W = 2, 6, 6
        for cls in self._implementations():
            with self.subTest(impl=cls.__name__):
                model = cls(channels, num_heads=num_heads)
                x = torch.randn(B, channels, H, W, requires_grad=True)
                loss = model(x).sum()
                loss.backward()
                # x was created with requires_grad=True, so it must get a gradient
                self.assertIsNotNone(x.grad)
                # every parameter is used in forward, so each one must get a finite gradient
                for name, p in model.named_parameters():
                    self.assertIsNotNone(p.grad, name)
                    self.assertTrue(torch.isfinite(p.grad).all(), name)
                # at least one parameter should have a non-zero gradient
                self.assertTrue(any(p.grad.abs().sum().item() > 0
                                    for p in model.parameters()))

    def test_deterministic_forward_given_same_inputs(self):
        torch.manual_seed(0)
        channels, num_heads = 8, 2
        B, H, W = 1, 4, 4
        x = torch.randn(B, channels, H, W)
        for cls in self._implementations():
            with self.subTest(impl=cls.__name__):
                model = cls(channels, num_heads=num_heads)
                out1 = model(x)
                out2 = model(x)
                self.assertTrue(torch.allclose(out1, out2))

if __name__ == "__main__":
    # Notebook cells run as __main__. argv=[""] stops unittest from parsing
    # the kernel's own command-line arguments; exit=False stops it from
    # calling sys.exit() once the tests finish.
    unittest.main(exit=False, argv=[""])



...
----------------------------------------------------------------------
Ran 3 tests in 0.664s

OK
