One key thing to understand is how torch.matmul behaves depending on the shapes of its input tensors.

For example, if you have two 2D tensors representing matrices, say A with shape (m, n) and B with shape (n, p), it computes the standard matrix product, resulting in a tensor of shape (m, p).
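A quick way to confirm the 2D case is to check the shapes directly. The sizes m, n, p below are arbitrary values chosen for illustration, and `torch.randn` just fills the matrices with random numbers.

```python
import torch

# Arbitrary illustrative sizes
m, n, p = 3, 4, 5
A = torch.randn(m, n)  # (m, n)
B = torch.randn(n, p)  # (n, p)

C = torch.matmul(A, B)  # standard matrix product
print(C.shape)  # torch.Size([3, 5]), i.e. (m, p)

# The @ operator performs the same operation
print(torch.allclose(C, A @ B))  # True
```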

If both inputs are 1D tensors of the same length, which you can think of as vectors, it performs a dot product and returns a scalar (a 0-dimensional tensor).
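A small sketch of the 1D case, using hand-picked values so the dot product is easy to check by hand. Note that the "scalar" comes back as a 0-dimensional tensor; `.item()` converts it to a plain Python number.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])  # 1D, shape (3,)
b = torch.tensor([4.0, 5.0, 6.0])  # 1D, shape (3,)

d = torch.matmul(a, b)  # dot product: 1*4 + 2*5 + 3*6
print(d)         # tensor(32.)
print(d.shape)   # torch.Size([]), a 0-dimensional tensor
print(d.item())  # 32.0, as a plain Python float
```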

Things get more interesting with higher dimensions. The rules are similar to standard tensor broadcasting for **elementwise ops**, but with a few twists.

When both tensors have at least two dimensions, PyTorch treats **the last two dimensions** of each tensor as the "core matrices" to multiply, while **any preceding dimensions** are treated as batch dimensions that can be **broadcast together**.

The "core matrices" follow matrix multiplication rules.

Everything else (the batch dims) follows standard broadcasting rules.
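The split between batch dims and core matrices can be seen with two 3D/4D tensors whose batch dims differ in length and contain a size-1 entry. The sizes are illustrative; the pattern mirrors the (j, 1, n, m) @ (k, m, p) case described in the PyTorch docs.

```python
import torch

# Illustrative sizes: batch dims (j, 1) and (k,), core matrices (n, m) and (m, p)
j, k, n, m, p = 2, 3, 4, 5, 6
x = torch.randn(j, 1, n, m)
y = torch.randn(k, m, p)

out = torch.matmul(x, y)
# Batch dims (j, 1) and (k,) broadcast to (j, k); core (n, m) @ (m, p) -> (n, p)
print(out.shape)  # torch.Size([2, 3, 4, 6])

# Each batch entry is an ordinary 2D matrix product
# (x's size-1 batch dim is reused for every index of y's batch dim)
print(torch.allclose(out[1, 2], x[1, 0] @ y[2], atol=1e-5))  # True
```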

Broadcasting happens only in the **batch dimensions**.

The core matrix dimensions, **the last two of each tensor**, are never broadcast. For the multiplication to be valid, **the inner dimensions need to be equal**: the last dim of the first tensor must match the second-to-last dim of the second tensor.
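A short check of the inner-dimension rule. Only the inner sizes must agree; the outer sizes (3 and 6 here) are free to differ. A mismatch, including a size-1 inner dim that one might expect to broadcast, raises a `RuntimeError`. The shapes are illustrative.

```python
import torch

A = torch.randn(3, 4)

# Valid: inner dims agree (4 == 4); the outer dims 3 and 6 may differ
print(torch.matmul(A, torch.randn(4, 6)).shape)  # torch.Size([3, 6])

failures = []
# Inner dims 4 vs 5 differ; a size-1 inner dim (4 vs 1) is not broadcast either
for other in (torch.randn(5, 6), torch.randn(1, 6)):
    try:
        torch.matmul(A, other)
    except RuntimeError as err:
        failures.append(tuple(other.shape))
        print("RuntimeError:", err)
```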

Summary:

The key difference:

	• In elementwise ops, every dimension must broadcast to a common size (the sizes are equal, or one of them is 1).
	
	• In matmul, only the batch dimensions broadcast; the matrix dimensions must follow the linear algebra rule (…, n, m) @ (…, m, p) → (…, n, p).
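The difference between the two bullets shows up clearly with one pair of shapes that fails under elementwise broadcasting but is fine for matmul. The shapes are chosen for illustration.

```python
import torch

a = torch.randn(2, 3, 4)
b = torch.randn(1, 4, 5)

# Elementwise: all dims are aligned from the right, and 4 vs 5 cannot broadcast
elementwise_ok = True
try:
    a * b
except RuntimeError as err:
    elementwise_ok = False
    print("elementwise failed:", err)

# matmul: only the batch dims (2,) and (1,) broadcast -> (2,);
# the core matrices follow (3, 4) @ (4, 5) -> (3, 5)
out = torch.matmul(a, b)
print(out.shape)  # torch.Size([2, 3, 5])
```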

So if the last two dims of your shapes don't fit the (n, m) @ (m, p) pattern, it's not a broadcasting issue; it's a linear algebra dimension mismatch.
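To make the two failure modes explicit, here is a sketch of a helper that predicts the output shape of torch.matmul and reports which rule was broken. `matmul_shape` is a hypothetical name used only for illustration; it assumes both inputs have at least two dimensions and relies on `torch.broadcast_shapes` (available in recent PyTorch versions) for the batch dims.

```python
import torch


def matmul_shape(shape_a, shape_b):
    """Hypothetical helper: predict torch.matmul's output shape.

    Assumes both shapes have at least 2 dims. Raises ValueError naming the
    rule that was broken.
    """
    *batch_a, n, m = shape_a
    *batch_b, m2, p = shape_b
    # Linear algebra rule: inner dims must be equal (no broadcasting here)
    if m != m2:
        raise ValueError(f"matrix mismatch: ({n}, {m}) @ ({m2}, {p})")
    # Broadcasting rule: batch dims follow ordinary broadcasting
    try:
        batch = torch.broadcast_shapes(tuple(batch_a), tuple(batch_b))
    except RuntimeError as err:
        raise ValueError(f"batch dims do not broadcast: {batch_a} vs {batch_b}") from err
    return (*batch, n, p)


print(matmul_shape((2, 1, 3, 4), (5, 4, 6)))  # (2, 5, 3, 6)
try:
    matmul_shape((2, 3, 4), (2, 5, 6))
except ValueError as err:
    print(err)  # matrix mismatch: (3, 4) @ (5, 6)
```

Checking the inner dims first mirrors the point above: a mismatch there is a linear algebra error no matter what the batch dims look like.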

In [None]:
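import torch
# Illustrative sizes
batch, num_heads, seq_len, head_dim = 2, 8, 10, 64
# Q: (batch, num_heads, seq_len, head_dim)
Q = torch.randn(batch, num_heads, seq_len, head_dim)
# K: (batch, num_heads, seq_len, head_dim)
K = torch.randn(batch, num_heads, seq_len, head_dim)
# Attention scores: Q @ K^T. transpose(-2, -1) makes K's core matrix (head_dim, seq_len),
# so each (seq_len, head_dim) @ (head_dim, seq_len) product is valid; (batch, num_heads) are batch dims
scores = torch.matmul(Q, K.transpose(-2, -1))
# Result: (batch, num_heads, seq_len, seq_len)
print(scores.shape)  # torch.Size([2, 8, 10, 10])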