In [17]:
import time
from datetime import datetime

from nn import MLP

In [18]:
n = MLP(784, [128, 10])

## Step 1: Load and Prepare the Data

MNIST images are 28×28 pixels = 784 values. We need to:
1. Load the data
2. **Flatten** each image: (28, 28) → 784-element vector
3. **Normalize** pixels: 0-255 → 0.0-1.0 (keeps inputs on a small scale, which helps training)
4. Convert to Value objects (for autodiff)

In [19]:
from mnist_loader import load_mnist

# Load the data
X_train, y_train, X_test, y_test = load_mnist()

print(f"Training data: {X_train.shape}")  # (60000, 28, 28)
print(f"Training labels: {y_train.shape}")  # (60000,)
print(f"Pixel range: {X_train.min()} to {X_train.max()}")  # 0 to 255

Training data: (60000, 28, 28)
Training labels: (60000,)
Pixel range: 0 to 255


In [20]:
# Start small - use only 100 samples for fast experimentation
n_samples = 100
X_small = X_train[:n_samples]
y_small = y_train[:n_samples]

# Flatten: (100, 28, 28) → (100, 784)
# reshape(-1, 784): -1 tells NumPy to infer the number of rows (the batch size); each row has 784 elements
X_flat = X_small.reshape(-1, 784)

# Normalize: 0-255 → 0.0-1.0
# Why? Small inputs keep pre-activations and gradients in a stable range; raw 0-255 values make them large, which can destabilize training
X_norm = X_flat / 255.0

print(f"Flattened shape: {X_flat.shape}")  # (100, 784)
print(f"Normalized range: {X_norm.min():.2f} to {X_norm.max():.2f}")  # 0.00 to 1.00

Flattened shape: (100, 784)
Normalized range: 0.00 to 1.00
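To see exactly what `reshape` does to one image, here is a small NumPy check on synthetic data. The `fake_images` array is made up for illustration and is not MNIST. Flattening is row-major, so pixel `(r, c)` lands at column `r * 28 + c`, and dividing by 255.0 maps the 0-255 range onto [0, 1].

```python
import numpy as np

# Hypothetical stand-in for MNIST: 3 random 28x28 "images" with pixel values 0-255
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(3, 28, 28)).astype(np.uint8)

# Flatten each image into a 784-element row; -1 lets NumPy infer the batch size (3)
flat = fake_images.reshape(-1, 784)

# Row-major (C-order) flattening: pixel (r, c) ends up at column r * 28 + c
r, c = 5, 17
assert flat[1, r * 28 + c] == fake_images[1, r, c]

# Normalize to floats in [0, 1]
norm = flat / 255.0
print(flat.shape, norm.dtype, norm.min() >= 0.0, norm.max() <= 1.0)
```

Because the flattening is lossless, `flat.reshape(-1, 28, 28)` recovers the original images, which is handy for displaying a sample later.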


## Step 2: Define the Loss Function

**What's a loss?** A number that measures how wrong your predictions are.
- Low loss = good predictions
- High loss = bad predictions

### Loss Function Comparison

Imagine you're classifying an image of the digit "3".

Your network outputs 10 scores (one for each digit 0-9).

Let's say the raw scores are:
`[1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]`

The correct answer is digit 3 (index 3, which has score 4.5 - the highest!). Now let's see how each loss function "thinks" about this:

#### 1) **MSE Loss (Mean Squared Error)**
MSE wants your outputs to match a target vector exactly.

First, convert the label to **one-hot encoding**:
- Label: 3 → Target: `[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]`
- (Only the correct class gets 1, all others get 0)

Your output: `[1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]`

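MSE squares each difference and adds them up: `(1.2-0)² + (0.5-0)² + (2.1-0)² + (4.5-1)² + ... + (1.3-0)² = 29.79`. The "mean" part then averages (the code later in this notebook sums over the 10 outputs and averages over samples).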

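The problem:
- It's angry that digit 2 has score 2.1 instead of 0.0
- It's angry that digit 3 has score 4.5 instead of 1.0
- It treats both errors the same way, even though only a high score on a *wrong* class can change the prediction. The biggest single penalty, (4.5-1)² = 12.25, actually comes from the correct class being *too* high!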

**What MSE is saying**: "I don't care that you picked the right digit. Your numbers are wrong!"

It's like a teacher marking your test wrong because you wrote "3.0" instead of exactly "3".
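To make the arithmetic concrete, here is a small NumPy sketch (separate from the tutorial's `loss_fn`, which works on `Value` objects) that computes the squared errors for the example scores. It prints the sum and the mean, and it shows that the largest single term belongs to the correct class, index 3.

```python
import numpy as np

# Example raw scores from the text; the true label is 3
scores = np.array([1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3])
label = 3

# One-hot target: 1 at the true class, 0 everywhere else
target = np.zeros(10)
target[label] = 1.0

squared_errors = (scores - target) ** 2
print("per-class squared errors:", np.round(squared_errors, 2))
print("sum:", round(float(squared_errors.sum()), 2))    # 29.79
print("mean:", round(float(squared_errors.mean()), 3))  # 2.979
print("largest term at index:", int(np.argmax(squared_errors)))  # 3
```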

#### 2) **Hinge Loss (Multi-class SVM)**
Hinge loss asks: "Is the correct class winning by at least a margin of 1.0 against EVERY wrong class?"

**Formula**: `L = Σ(i≠y) max(0, s_i - s_y + 1)` 
- where `y` = correct class index (3)
- `s_y` = score of correct class (4.5)
- `s_i` = score of each wrong class
- We sum up penalties for ALL wrong classes that are too close

**Complete Dry Run:**

Scores: `[1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]`  
True label: 3 (so s_y = 4.5)

For each wrong class i, calculate: max(0, s_i - 4.5 + 1)

- Class 0: max(0, 1.2 - 4.5 + 1) = max(0, -2.3) = **0** ✓
- Class 1: max(0, 0.5 - 4.5 + 1) = max(0, -3.0) = **0** ✓
- Class 2: max(0, 2.1 - 4.5 + 1) = max(0, -1.4) = **0** ✓
- ~~Class 3~~: (skip - this is the correct class)
- Class 4: max(0, 1.8 - 4.5 + 1) = max(0, -1.7) = **0** ✓
- Class 5: max(0, 0.9 - 4.5 + 1) = max(0, -2.6) = **0** ✓
- Class 6: max(0, 1.1 - 4.5 + 1) = max(0, -2.4) = **0** ✓
- Class 7: max(0, 2.0 - 4.5 + 1) = max(0, -1.5) = **0** ✓
- Class 8: max(0, 0.7 - 4.5 + 1) = max(0, -2.8) = **0** ✓
- Class 9: max(0, 1.3 - 4.5 + 1) = max(0, -2.2) = **0** ✓

**Total Loss = 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 = 0** ✓

**What Hinge Loss is saying**: "You got it right by a comfortable margin (at least 1.0 from every wrong class). Perfect!"

**The limitation - A Different Example:**

Now imagine the scores were: `[2.0, 1.0, 2.0, 4.1, 2.0, 1.5, 1.0, 3.9, 1.0, 2.0]`  
True label: still 3 (so s_y = 4.1)

Notice: Class 7 has score 3.9, very close to the correct answer!

Let's recalculate:

- Class 0: max(0, 2.0 - 4.1 + 1) = max(0, -1.1) = **0** ✓
- Class 1: max(0, 1.0 - 4.1 + 1) = max(0, -2.1) = **0** ✓
- Class 2: max(0, 2.0 - 4.1 + 1) = max(0, -1.1) = **0** ✓
- ~~Class 3~~: (skip - correct class)
- Class 4: max(0, 2.0 - 4.1 + 1) = max(0, -1.1) = **0** ✓
- Class 5: max(0, 1.5 - 4.1 + 1) = max(0, -1.6) = **0** ✓
- Class 6: max(0, 1.0 - 4.1 + 1) = max(0, -2.1) = **0** ✓
- Class 7: max(0, 3.9 - 4.1 + 1) = max(0, **0.8**) = **0.8** ⚠️ PENALTY!
- Class 8: max(0, 1.0 - 4.1 + 1) = max(0, -2.1) = **0** ✓
- Class 9: max(0, 2.0 - 4.1 + 1) = max(0, -1.1) = **0** ✓

**Total Loss = 0 + 0 + 0 + 0 + 0 + 0 + 0.8 + 0 + 0 = 0.8** ⚠️

**The problem**: Even though you predicted correctly (class 3 is still the highest), hinge loss penalizes you because class 7 is too close. The margin between them is only 0.2 (4.1 - 3.9), but hinge loss wants at least 1.0.

This forces the network to be "confidently correct" - not just barely winning.
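Both dry runs can be reproduced in a few lines of plain Python. `multiclass_hinge` is a hypothetical helper written for this illustration; its `margin=1.0` default matches the "+ 1" in the formula.

```python
def multiclass_hinge(scores, label, margin=1.0):
    """Multi-class SVM loss: sum over wrong classes i of max(0, s_i - s_y + margin)."""
    s_y = scores[label]
    return sum(
        max(0.0, s_i - s_y + margin)
        for i, s_i in enumerate(scores)
        if i != label  # skip the correct class
    )


confident = [1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]
close_call = [2.0, 1.0, 2.0, 4.1, 2.0, 1.5, 1.0, 3.9, 1.0, 2.0]

print(multiclass_hinge(confident, 3))             # 0.0
print(round(multiclass_hinge(close_call, 3), 2))  # 0.8
```

Only class 7 contributes in the second case, and its penalty shrinks to 0 once `s_3 - s_7` reaches the margin of 1.0.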

#### 3) **Softmax + Cross-Entropy**
This is a two-step process:

**Step 1: Softmax** converts scores to probabilities:
```
Raw scores: [1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]
Probabilities: [0.03, 0.01, 0.06, 0.70, 0.05, 0.02, 0.02, 0.06, 0.02, 0.03]
```
(They now sum to 1.0 - these are "confidences")
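Step 1 can be sketched in NumPy as follows. `softmax` here is an illustrative helper, not part of the tutorial's `nn` or `micrograd` modules. Subtracting the largest score before exponentiating is a common stability trick that leaves the probabilities unchanged, because the shared factor cancels in the division. For the example scores, class 3 gets roughly 0.70.

```python
import numpy as np


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    z = np.asarray(scores, dtype=float)
    # Shifting by the max avoids overflow in exp() without changing the result
    exps = np.exp(z - z.max())
    return exps / exps.sum()


scores = [1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]
probs = softmax(scores)
print(np.round(probs, 2))
print("class 3:", round(float(probs[3]), 2))
```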

**Step 2: Cross-Entropy** asks: "What probability did you assign to the correct class?"

- You said: 0.70 (70% confident it's a 3)
- Loss = -log(0.70) = 0.36

If you had been more confident:
- 90% confident → Loss = -log(0.90) = 0.11 (better!)
- 99% confident → Loss = -log(0.99) = 0.01 (even better!)

If you had been less confident:
- 50% confident → Loss = -log(0.50) = 0.69 (worse!)
- 10% confident → Loss = -log(0.10) = 2.30 (much worse!)
- 1% confident → Loss = -log(0.01) = 4.61 (terrible!)

**What Cross-Entropy is saying**: "You got it right, but you're only 70% sure. Let's push you to be MORE confident in the correct answer."

It's never satisfied! Even at 99.9% it's saying "you can do better."
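In practice softmax and cross-entropy are usually computed in one step. A minimal sketch, assuming plain Python floats rather than `Value` objects: `cross_entropy_from_scores` (a hypothetical name) uses the identity `-log(e^(s_y) / Σ_j e^(s_j)) = log(Σ_j e^(s_j)) - s_y`, so it never forms the probabilities explicitly. With the exact softmax probability (about 0.705 rather than the rounded 0.70), the example loss comes out at about 0.35.

```python
import math


def cross_entropy_from_scores(scores, label):
    """Softmax + cross-entropy in one step: -log(softmax(scores)[label]).

    Computed as log-sum-exp(scores) - scores[label], with the max subtracted
    inside the log-sum-exp for numerical stability.
    """
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_sum_exp - scores[label]


scores = [1.2, 0.5, 2.1, 4.5, 1.8, 0.9, 1.1, 2.0, 0.7, 1.3]
print(f"loss: {cross_entropy_from_scores(scores, 3):.2f}")  # loss: 0.35

# The loss only approaches 0 as the true-class probability approaches 1
for p in [0.99, 0.90, 0.50, 0.10, 0.01]:
    print(f"p = {p:.2f} -> loss = {-math.log(p):.2f}")
```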

---

### Summary Comparison

| Loss Function | What it optimizes | When to use | Pros | Cons |
|--------------|-------------------|-------------|------|------|
| **MSE** | Match exact target values | Regression, simple cases | Simple, easy to implement | Not ideal for classification |
| **Hinge Loss** | Win by a margin | Binary/multi-class SVM | Forces confident boundaries | Doesn't care beyond margin |
| **Cross-Entropy** | Maximize probability of correct class | Multi-class classification | Standard choice for classification; outputs interpretable probabilities | Requires softmax, slightly more complex |

**For this tutorial**: We'll start with **MSE** (simplest to understand and implement). Once you see it working, you can experiment with the others!

In [21]:
from datetime import datetime, timezone

from micrograd import Value


def loss_fn():
    # Wrap every pixel in a Value so autodiff can track it
    inputs = [list(map(Value, xrow)) for xrow in X_norm]

    predictions = []

    for i, x in enumerate(inputs):
        t0 = time.perf_counter()
        predictions.append(n(x))  # forward pass: 10 output Values
        t1 = time.perf_counter()
        # datetime.utcnow() is deprecated since Python 3.12; use a timezone-aware UTC time
        print(f"[{i} - {datetime.now(timezone.utc).isoformat()}] Time taken: {t1 - t0:.3f}s")

    # MSE loss. For each sample:
    #   - Get the true label (y_small[i])
    #   - Create one-hot target: [0,0,0,1,0,0,0,0,0,0] if label is 3
    #   - Calculate (pred - target)^2 for all 10 outputs
    #   - Sum them up
    # Then average across all samples
    losses = []
    for i in range(n_samples):
        pred = predictions[i]  # list of 10 Values
        true_label = y_small[i]

        target = [0] * 10
        target[true_label] = 1

        sample_loss = sum((pred[j] - target[j]) ** 2 for j in range(10))
        losses.append(sample_loss)

    total_loss = sum(losses) / n_samples

    # Accuracy: a sample is correct if argmax(prediction) == true_label,
    # i.e. the output with the highest .data value is the true digit
    correct = 0
    for i in range(n_samples):
        predicted_digit = max(range(10), key=lambda j: predictions[i][j].data)
        if predicted_digit == y_small[i]:
            correct += 1

    return total_loss, correct / n_samples


loss_fn()



[0 - 2025-11-23T08:16:44.552212] Time taken: 0.168s
[1 - 2025-11-23T08:16:59.963528] Time taken: 15.411s
[2 - 2025-11-23T08:17:00.160954] Time taken: 0.197s
[3 - 2025-11-23T08:17:00.353751] Time taken: 0.193s
[4 - 2025-11-23T08:17:00.530942] Time taken: 0.177s
[5 - 2025-11-23T08:17:00.699889] Time taken: 0.169s
[6 - 2025-11-23T08:17:00.866844] Time taken: 0.167s
[7 - 2025-11-23T08:17:01.041253] Time taken: 0.174s
[8 - 2025-11-23T08:17:01.206970] Time taken: 0.166s
[9 - 2025-11-23T08:17:19.391561] Time taken: 18.185s
[10 - 2025-11-23T08:17:19.550400] Time taken: 0.159s
[11 - 2025-11-23T08:17:19.717665] Time taken: 0.167s
[12 - 2025-11-23T08:17:19.882101] Time taken: 0.164s
[13 - 2025-11-23T08:17:20.053011] Time taken: 0.171s
[14 - 2025-11-23T08:17:20.217164] Time taken: 0.164s
[15 - 2025-11-23T08:17:20.387940] Time taken: 0.171s
[16 - 2025-11-23T08:17:20.557880] Time taken: 0.170s
[17 - 2025-11-23T08:17:20.718755] Time taken: 0.161s
[18 - 2025-11-23T08:17:20.886316] Time taken: 0.167s
...

(Value(data=10.627507965669595), 0.08)
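The per-sample loops in `loss_fn` are easy to follow but slow, because every arithmetic operation on a `Value` builds a graph node in Python (roughly 0.17 s per forward pass for most samples in the log above). The 0.08 accuracy is in line with the roughly 0.10 you would expect from an untrained network guessing among 10 digits. For cross-checking such numbers, here is a NumPy sketch of the same metric: squared error summed over classes, averaged over samples, plus argmax accuracy. `mse_and_accuracy` is a hypothetical helper, and the `preds`/`labels` arrays are made-up example data, not outputs of the network.

```python
import numpy as np


def mse_and_accuracy(preds, labels, num_classes=10):
    """Same metric as loss_fn, computed on plain arrays.

    preds:  (N, num_classes) array of raw network outputs
    labels: (N,) array of integer class labels
    """
    targets = np.eye(num_classes)[labels]               # one-hot rows
    per_sample = ((preds - targets) ** 2).sum(axis=1)   # sum over classes
    loss = per_sample.mean()                            # average over samples
    accuracy = (preds.argmax(axis=1) == labels).mean()  # argmax vs. true label
    return loss, accuracy


preds = np.array([[0.1, 0.9, 0.0],
                  [0.8, 0.1, 0.1]])
labels = np.array([1, 2])
loss, accuracy = mse_and_accuracy(preds, labels, num_classes=3)
print(loss, accuracy)
```

Here the first sample is predicted correctly and the second is not, so the accuracy is 0.5.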