# Transformer Training Shape Walkthrough

This file describes all tensor shapes during **training**, where the decoder receives the full target sequence at once (teacher forcing).  
Notation:
- B = batch size
- I = encoder input length
- O = decoder target length
- E = embedding/model dimension
- H = number of heads (E must be divisible by H)
- FFN hidden dimension = 2048
- Vocab = vocabulary size

---

# 1. Encoder (Training)

## Input
```
(B, I)
```

## Embedding + Positional Encoding
```
(B, I, E)
```
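
A minimal NumPy sketch of this step. The sizes, the random embedding table, and the random positional table are all placeholders chosen for illustration; the walkthrough does not specify a positional scheme.

```python
import numpy as np

B, I, E, vocab = 2, 5, 8, 100   # illustrative sizes, not taken from the walkthrough
rng = np.random.default_rng(0)

tokens = rng.integers(0, vocab, size=(B, I))    # (B, I) integer token ids
embedding_table = rng.normal(size=(vocab, E))   # (Vocab, E)
positional = rng.normal(size=(I, E))            # (I, E), stand-in for any positional scheme

embedded = embedding_table[tokens]              # row lookup gives (B, I, E)
x = embedded + positional                       # (I, E) broadcasts over the batch axis
print(tokens.shape, x.shape)
```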

## Self-Attention Block

### Q, K, V projections
```
Q = (B, I, E)
K = (B, I, E)
V = (B, I, E)
```

### Split across heads
```
(B, H, I, E/H)
```
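
One common way to get this shape, sketched in NumPy with illustrative sizes: reshape the last axis into `(H, E/H)`, then move the head axis ahead of the sequence axis.

```python
import numpy as np

B, I, E, H = 2, 5, 8, 4         # illustrative sizes; E must be divisible by H
d = E // H                      # per-head dimension E/H

q = np.arange(B * I * E, dtype=float).reshape(B, I, E)   # (B, I, E)

# (B, I, E) -> (B, I, H, E/H) -> (B, H, I, E/H)
q_heads = q.reshape(B, I, H, d).transpose(0, 2, 1, 3)
print(q_heads.shape)

# Head h holds columns h*d .. (h+1)*d of the original E axis.
head_1_matches = np.array_equal(q_heads[:, 1], q[:, :, d:2 * d])
```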

### Scaled dot-product
Scores:
```
(B, H, I, I)
```

Attention output:
```
(B, H, I, E/H)
```

Concat heads:
```
(B, I, E)
```

Final linear:
```
(B, I, E)
```
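
The shapes from scores through the final linear layer can be traced with a short NumPy sketch. It starts from Q, K, V that are already split into heads; the sizes and the random weight `w_o` are placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along one axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

B, I, E, H = 2, 5, 8, 4         # illustrative sizes
d = E // H
rng = np.random.default_rng(0)

# Q, K, V already split into heads: (B, H, I, E/H)
q = rng.normal(size=(B, H, I, d))
k = rng.normal(size=(B, H, I, d))
v = rng.normal(size=(B, H, I, d))
w_o = rng.normal(size=(E, E))   # final linear weight (placeholder values)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)      # (B, H, I, I)
weights = softmax(scores)                               # (B, H, I, I), each row sums to 1
attn = weights @ v                                      # (B, H, I, E/H)
concat = attn.transpose(0, 2, 1, 3).reshape(B, I, E)    # (B, I, E)
final = concat @ w_o                                    # (B, I, E)
```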

Residual + LayerNorm:
```
(B, I, E)
```

## Feed-Forward Network
```
(B, I, 2048) → (B, I, E)
```

Residual + LayerNorm:
```
(B, I, E)
```
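
A sketch of the FFN sub-layer with its residual connection, under a few assumptions not stated above: a ReLU activation, post-norm ordering, and a LayerNorm without the learnable scale and shift. Weights are random placeholders.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize over the last axis (learnable scale/shift omitted)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

B, I, E, F = 2, 5, 8, 2048      # F is the FFN hidden dimension
rng = np.random.default_rng(0)

x = rng.normal(size=(B, I, E))              # sub-layer input
w1 = rng.normal(size=(E, F)) * 0.02
b1 = np.zeros(F)
w2 = rng.normal(size=(F, E)) * 0.02
b2 = np.zeros(E)

hidden = np.maximum(0.0, x @ w1 + b1)       # (B, I, 2048), ReLU assumed
ffn_out = hidden @ w2 + b2                  # (B, I, E)
out = layer_norm(x + ffn_out)               # residual + LayerNorm, still (B, I, E)
```

The same sketch applies to the decoder FFN with `I` replaced by `O`.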

Encoder output:
```
(B, I, E)
```

---

# 2. Decoder (Training Mode, Full Sequence Provided)

## Input target tokens
```
(B, O)
```

## Embedding + Positional Encoding
```
(B, O, E)
```

---

## A. Masked Self-Attention

### Q, K, V
```
(B, O, E)
```

### Split into heads
```
(B, H, O, E/H)
```

### Scores
```
(B, H, O, O)
```

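(A causal mask is applied to the scores before the softmax, so each position attends only to itself and earlier positions.)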
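
A minimal sketch of one way to apply the mask, assuming the usual approach of setting future positions to negative infinity before the softmax. Sizes and scores are random placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along one axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

B, H, O = 2, 4, 5               # illustrative sizes
rng = np.random.default_rng(0)
scores = rng.normal(size=(B, H, O, O))                 # (B, H, O, O)

# True strictly above the diagonal marks "future" key positions.
future = np.triu(np.ones((O, O), dtype=bool), k=1)     # (O, O)

# The (O, O) mask broadcasts over the batch and head axes.
masked = np.where(future, -np.inf, scores)             # (B, H, O, O)
weights = softmax(masked)                              # masked entries become exactly 0
```

The mask changes values only; the score shape stays `(B, H, O, O)`.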

### Output
```
(B, O, E)
```

Residual + LayerNorm:
```
(B, O, E)
```

---

## B. Cross-Attention

Q from decoder:
```
(B, O, E)
```

K, V from encoder:
```
(B, I, E)
```

### Head split
```
Q: (B, H, O, E/H)
K: (B, H, I, E/H)
V: (B, H, I, E/H)
```

### Scores
```
(B, H, O, I)
```

Output:
```
(B, O, E)
```

Residual + LayerNorm:
```
(B, O, E)
```
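
A sketch that makes the `O` versus `I` axes explicit with `np.einsum`. To keep it short, the learned Q/K/V and output projections are omitted, and the sizes are illustrative with `O != I` on purpose.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along one axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def split_heads(t, num_heads):
    """(B, L, E) -> (B, H, L, E/H)."""
    b, length, e = t.shape
    return t.reshape(b, length, num_heads, e // num_heads).transpose(0, 2, 1, 3)

B, H, O, I, E = 2, 4, 3, 5, 8   # illustrative sizes
rng = np.random.default_rng(0)
dec = rng.normal(size=(B, O, E))    # decoder states supply Q
enc = rng.normal(size=(B, I, E))    # encoder output supplies K and V

q = split_heads(dec, H)             # (B, H, O, E/H)
k = split_heads(enc, H)             # (B, H, I, E/H)
v = split_heads(enc, H)             # (B, H, I, E/H)

scores = np.einsum("bhod,bhid->bhoi", q, k) / np.sqrt(E // H)   # (B, H, O, I)
weights = softmax(scores)                                        # normalized over I
heads = np.einsum("bhoi,bhid->bhod", weights, v)                 # (B, H, O, E/H)
out = heads.transpose(0, 2, 1, 3).reshape(B, O, E)               # (B, O, E)
```

The output length follows the queries (`O`), while the softmax runs over the encoder positions (`I`).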

---

## C. Feed-Forward Network
```
(B, O, 2048) → (B, O, E)
```

Residual + LayerNorm:
```
(B, O, E)
```

---

# 3. Final Linear + Softmax

Logits:
```
(B, O, Vocab)
```

Loss is computed against target labels of shape:
```
(B, O)
```
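
A sketch of how the two shapes meet in the loss, assuming token-level cross-entropy averaged over all `B * O` positions (padding is ignored here for brevity). Sizes, weights, and labels are random placeholders.

```python
import numpy as np

B, O, E, vocab = 2, 3, 8, 11    # illustrative sizes
rng = np.random.default_rng(0)

dec_out = rng.normal(size=(B, O, E))            # decoder output before the linear layer
w_out = rng.normal(size=(E, vocab))             # final linear weight (placeholder)
labels = rng.integers(0, vocab, size=(B, O))    # (B, O) target token ids

logits = dec_out @ w_out                        # (B, O, Vocab)

# Log-softmax over the vocabulary axis, computed stably.
shifted = logits - logits.max(axis=-1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))   # (B, O, Vocab)

# Pick the log-probability of the correct token at every position.
picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]  # (B, O)
loss = -picked.mean()                           # scalar
```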

---

# Summary Table

| Stage | Shape |
|-------|--------|
| Encoder input | (B, I) |
| Encoder output | (B, I, E) |
| Decoder input | (B, O) |
| Decoder output before linear | (B, O, E) |
| Final logits | (B, O, Vocab) |

# Transformer Inference Shape Walkthrough

# Encoder (Inference, Same as Training)

The encoder shapes match training, shown here with a batch size of 1.
- Input: `(1, I)`
- Embedding + Positional: `(1, I, E)`
- Encoder output: `(1, I, E)`

---

# Decoder – Autoregressive Inference

## Step 1 — First Token
The decoder receives `[sos]`:

- Decoder input: `(1,1)`
- Embedding: `(1,1,E)`
- Masked self-attention: `(1,1,E)`
- Cross-attention: `(1,1,E)`
- FFN: `(1,1,E)`
- Final logits: `(1,1,Vocab)`

---

## Step 2 — Second Token
The decoder now receives `[sos, predicted_token_1]`:

- Decoder input: `(1,2)`
- Embedding: `(1,2,E)`
- Masked self-attention: `(1,2,E)`
- Cross-attention: `(1,2,E)`
- FFN: `(1,2,E)`
- Final logits: `(1,2,Vocab)` (only the last position is used to pick the next token)

---

## General Step t
- Decoder input: `(1,t)`
- Embedding: `(1,t,E)`
- Masked self-attention: `(1,t,E)`
- Cross-attention: `(1,t,E)`
- FFN: `(1,t,E)`
- Final logits: `(1,t,Vocab)` → use the last position, `(1,Vocab)`

---

## Final Output for O Tokens
- Output tokens: `(1,O)`
- Final logits: `(1,O,Vocab)` (from the last step, one row per decoder input position)
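
The step-by-step growth above can be sketched as a greedy decoding loop. `toy_decoder` is a hypothetical stand-in for the full decoder stack (it only reproduces the input and output shapes), and the sizes, the `sos` id, and the random weights are placeholders.

```python
import numpy as np

E, vocab, I, max_len, sos = 8, 11, 5, 4, 0   # illustrative values
rng = np.random.default_rng(0)
embedding = rng.normal(size=(vocab, E))
w_out = rng.normal(size=(E, vocab))
enc_out = rng.normal(size=(1, I, E))         # encoder output, computed once

def toy_decoder(tokens, enc_out):
    """Hypothetical stand-in for the decoder: (1, t) -> (1, t, Vocab)."""
    h = embedding[tokens] + enc_out.mean(axis=1, keepdims=True)   # (1, t, E)
    return h @ w_out                                              # (1, t, Vocab)

tokens = np.array([[sos]])                   # (1, 1), starts with [sos]
logit_shapes = []
for _ in range(max_len):
    logits = toy_decoder(tokens, enc_out)    # (1, t, Vocab)
    logit_shapes.append(logits.shape)
    next_token = logits[:, -1, :].argmax(axis=-1)                    # last position only, (1,)
    tokens = np.concatenate([tokens, next_token[:, None]], axis=1)   # (1, t + 1)

generated = tokens[:, 1:]                    # drop [sos]: (1, max_len)
print(logit_shapes, generated.shape)
```

This loop re-runs the decoder on the whole prefix at every step, which matches the shapes listed above.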