<a href="https://colab.research.google.com/github/ywang1110/PyTorch_Colab_Files/blob/main/CS336_LLM_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

# PyTorch

## torch.matmul(input, other, *, out=None) → Tensor
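* Computes the matrix product of two tensors. The exact behavior depends on how many dimensions each input has: dot product (1D, 1D), matrix-vector product, matrix-matrix product (2D, 2D), or batched matrix product.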
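As a quick illustration of how the result depends on the input dimensions, the sketch below (random inputs, for illustration only) compares `torch.matmul` with `torch.dot` for two 1D tensors and with `torch.mm` for two 2D tensors:

```python
import torch

torch.manual_seed(0)

# 1D @ 1D: dot product, returns a 0-dim (scalar) tensor
u = torch.randn(4)
v = torch.randn(4)
dot = torch.matmul(u, v)
print(dot.shape)                              # torch.Size([])
print(torch.allclose(dot, torch.dot(u, v)))   # True

# 2D @ 2D: ordinary matrix-matrix product, same as torch.mm
A = torch.randn(2, 4)
B = torch.randn(4, 3)
C = torch.matmul(A, B)
print(C.shape)                                # torch.Size([2, 3])
print(torch.allclose(C, torch.mm(A, B)))      # True
```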


### Special rules for 1D inputs:
When a 1D tensor is combined with a tensor of 2 or more dimensions:

* If the **1st argument** is 1D with shape `(n,)`, it is temporarily **reshaped to `(1, n)`** by prepending a 1.
* If the **2nd argument** is 1D with shape `(n,)`, it is temporarily **reshaped to `(n, 1)`** by appending a 1.

After the (batched) matrix multiply, this extra dimension is removed again.
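The two rules can be reproduced by hand with `unsqueeze` (add the temporary dimension) and `squeeze` (remove it afterwards). This is a small sketch using integer-valued float tensors so the comparison with `torch.equal` is exact:

```python
import torch

# Case 1: 1st argument is 1D -> prepend a dim: (n,) becomes (1, n)
x = torch.tensor([1.0, 2.0])                  # shape (2,)
W = torch.arange(6.0).reshape(2, 3)           # shape (2, 3)
manual_x = (x.unsqueeze(0) @ W).squeeze(0)    # (1, 2) @ (2, 3) -> (1, 3) -> (3,)
print(torch.equal(x @ W, manual_x))           # True

# Case 2: 2nd argument is 1D -> append a dim: (n,) becomes (n, 1)
y = torch.tensor([1.0, 2.0, 3.0])             # shape (3,)
manual_y = (W @ y.unsqueeze(1)).squeeze(1)    # (2, 3) @ (3, 1) -> (2, 1) -> (2,)
print(torch.equal(W @ y, manual_y))           # True

# The same rules apply when the other argument is batched (3D here)
batch = torch.arange(30.0).reshape(5, 2, 3)   # 5 matrices of shape (2, 3)
print((batch @ y).shape)                      # torch.Size([5, 2])
print((x @ batch).shape)                      # torch.Size([5, 3])
```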

In [None]:
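x = torch.tensor([1.0, 2.0])   # 1D, shape (2,)
y = torch.randn(2, 3)          # 2D, shape (2, 3)
z = x @ y                      # x treated as (1, 2): (1, 2) @ (2, 3) -> (1, 3) -> squeezed to (3,)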
print(z.shape)

torch.Size([3])


In [None]:
x = torch.randn(2, 3)              # 2D, shape (2, 3)
y = torch.tensor([1.0, 2.0, 3.0])  # 1D, shape (3,)
print(x.shape)
print(y.shape)

torch.Size([2, 3])
torch.Size([3])


In [None]:
z = x @ y  # y treated as (3, 1): (2, 3) @ (3, 1) -> (2, 1) -> squeezed to (2,)
print(z.shape)
print(z)

torch.Size([2])
tensor([ 6.8562, -2.2629])


### 🧠 Key Idea: Batched Matrix Multiply

If both arguments are at least 1D and at least one of them is N-dimensional with **N > 2**, PyTorch performs a **batched matrix multiplication**:
* The **last two dimensions** of each input are **treated as the matrix dimensions** (`n × m` for the first, `m × p` for the second).
* All preceding dimensions are **batch dimensions**, which are broadcast against each other.
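To see what "batched" means concretely, here is a small sketch (random integer-valued inputs, chosen only so the comparison is exact) checking that `torch.matmul` on 3D tensors matches a Python loop over the batch dimension and `torch.bmm`:

```python
import torch

torch.manual_seed(0)
# Small integer-valued floats keep the arithmetic exact, so torch.equal is safe
A = torch.randint(-5, 5, (4, 2, 3)).float()   # batch of 4 matrices, each (2, 3)
B = torch.randint(-5, 5, (4, 3, 5)).float()   # batch of 4 matrices, each (3, 5)

out = torch.matmul(A, B)                      # multiply the last two dims, keep the batch dim
print(out.shape)                              # torch.Size([4, 2, 5])

# The same result, written as an explicit loop over the batch dimension
looped = torch.stack([A[i] @ B[i] for i in range(A.shape[0])])
print(torch.equal(out, looped))               # True

# For two 3D inputs with the same batch size, torch.bmm computes the same thing
print(torch.equal(out, torch.bmm(A, B)))      # True
```

`torch.bmm` only accepts 3D inputs with matching batch sizes; `torch.matmul` generalizes it to more batch dimensions and to broadcasting.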

#### ❗ Important Notes:
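* The **matrix dimensions** must align as in an ordinary matrix product, `(n, m) @ (m, p) → (n, p)`: the last dimension of the first input must equal the second-to-last dimension of the second.
* The **batch dimensions** must be **broadcastable** under the normal broadcasting rules:
  * Shapes are aligned from the right, and each pair of dimensions must be **equal** or **one of them must be 1**; a missing dimension counts as 1.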
* This allows batched operations over many matrices without looping.
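The shape rules above can be turned into a small shape calculator. `matmul_shape` below is a hypothetical helper written for this notebook, not a PyTorch API; it relies on `torch.broadcast_shapes` for the batch dimensions and only handles inputs with at least 2 dimensions. It also checks that incompatible batch dimensions (3 vs 5) make `torch.matmul` raise an error:

```python
import torch

def matmul_shape(a_shape, b_shape):
    """Hypothetical helper: predict torch.matmul's output shape (inputs must be >= 2D)."""
    *a_batch, n, m = a_shape
    *b_batch, m2, p = b_shape
    if m != m2:
        raise ValueError(f"matrix dims do not align: ({n}, {m}) @ ({m2}, {p})")
    # Batch dims are aligned from the right; each pair must be equal or contain a 1
    batch = torch.broadcast_shapes(tuple(a_batch), tuple(b_batch))
    return (*batch, n, p)

print(matmul_shape((3, 1, 2, 4), (5, 4, 6)))   # (3, 5, 2, 6)
print(matmul_shape((7, 2, 4), (4, 6)))          # (7, 2, 6)

# Compare with the shape PyTorch actually produces
actual = torch.matmul(torch.randn(7, 2, 4), torch.randn(4, 6)).shape
print(actual == matmul_shape((7, 2, 4), (4, 6)))  # True

# Batch dims 3 and 5 are neither equal nor 1, so matmul fails
raised = False
try:
    torch.matmul(torch.randn(3, 2, 4), torch.randn(5, 4, 6))
except RuntimeError:
    raised = True
print("incompatible batch dims raise RuntimeError:", raised)  # True
```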

In [None]:
a = torch.randn(3, 1, 2, 4)  # batch dims (3, 1), matrix (2, 4)
b = torch.randn(5, 4, 6)     # batch dims (5,),   matrix (4, 6)

# Batch dims broadcast: (3, 1) with (5,) -> (3, 5); matrices: (2, 4) @ (4, 6) -> (2, 6)
out = torch.matmul(a, b)     # shape: (3, 5, 2, 6)
print(out.shape)

torch.Size([3, 5, 2, 6])
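To make the broadcasting in this example explicit, the sketch below (illustrative, with small integer-valued inputs so the comparison is exact) expands both batch shapes to `(3, 5)` with `Tensor.expand` and multiplies each pair of matrices in a loop; the result matches `torch.matmul`:

```python
import torch

torch.manual_seed(0)
# Integer-valued floats keep the arithmetic exact, so results can be compared with torch.equal
a = torch.randint(-3, 4, (3, 1, 2, 4)).float()
b = torch.randint(-3, 4, (5, 4, 6)).float()

out = torch.matmul(a, b)                       # shape (3, 5, 2, 6)

# Broadcasting made explicit: expand both batch shapes to (3, 5)
a_exp = a.expand(3, 5, 2, 4)                   # size-1 dim repeated 5 times
b_exp = b.expand(3, 5, 4, 6)                   # missing leading dim added, repeated 3 times

# Multiply each (i, j) pair of matrices one at a time
manual = torch.empty(3, 5, 2, 6)
for i in range(3):
    for j in range(5):
        manual[i, j] = a_exp[i, j] @ b_exp[i, j]

print(torch.equal(out, manual))                # True
# Conceptually, out[i, j] == a[i, 0] @ b[j]
print(torch.equal(out[2, 4], a[2, 0] @ b[4]))  # True
```

`expand` returns views without copying data, which is also why broadcasting in `torch.matmul` does not require materializing the repeated matrices yourself.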
