# 4. Loss Functions

The **Loss Function** (or Cost Function) measures how "wrong" the network's predictions are, as a single number.
Training adjusts the weights to make this number as small as possible.
Different tasks need different loss functions!
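Here is a minimal sketch of what "minimizing this number" means. It is not from the original notebook. It fits a one-input linear model with plain gradient descent on MSE. The toy data (roughly `y = 3x`), learning rate, and step count are arbitrary assumptions. The only point is that the loss shrinks as training proceeds.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data (assumed for illustration): y = 3x plus a little noise
x = torch.linspace(-1.0, 1.0, 20).unsqueeze(1)  # shape (20, 1)
y = 3.0 * x + 0.1 * torch.randn_like(x)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(100):
    optimizer.zero_grad()           # clear gradients from the previous step
    loss = criterion(model(x), y)   # how "wrong" the model currently is
    loss.backward()                 # gradients of the loss w.r.t. weight and bias
    optimizer.step()                # move the parameters to reduce the loss
    losses.append(loss.item())

print(f"Loss at step 0:  {losses[0]:.4f}")
print(f"Loss at step 99: {losses[-1]:.4f}")
print(f"Learned weight:  {model.weight.item():.3f}")
```

The learned weight should end up close to 3, the slope used to generate the toy data.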

In [None]:
import torch
import torch.nn as nn
import numpy as np

## 1. Regression: Mean Squared Error (MSE)

Used when predicting a continuous number (e.g., house price).

$MSE = \frac{1}{N} \sum_{i=1}^{N} (y_{pred,i} - y_{true,i})^2$

In [None]:
y_pred = torch.tensor([2.5, 0.0, 2.0])
y_true = torch.tensor([3.0, -0.5, 2.0])

criterion = nn.MSELoss()
loss = criterion(y_pred, y_true)

print(f"Predictions: {y_pred}")
print(f"Targets: {y_true}")
print(f"MSE Loss: {loss.item():.4f}")

# Manual calculation
manual_loss = ((y_pred - y_true)**2).mean()
print(f"Manual MSE: {manual_loss.item():.4f}")
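By default, `nn.MSELoss` averages the squared errors. Its `reduction` argument (`"mean"`, `"sum"`, or `"none"`) controls how they are combined. `"none"` is handy when you want to inspect the error of each prediction separately. Here is a short sketch using the same tensors as above:

```python
import torch
import torch.nn as nn

y_pred = torch.tensor([2.5, 0.0, 2.0])
y_true = torch.tensor([3.0, -0.5, 2.0])

# reduction controls how the per-element squared errors are combined
per_element = nn.MSELoss(reduction="none")(y_pred, y_true)  # one squared error per prediction
total = nn.MSELoss(reduction="sum")(y_pred, y_true)         # sum of squared errors
mean_loss = nn.MSELoss(reduction="mean")(y_pred, y_true)    # the default behaviour

print(per_element)   # tensor([0.2500, 0.2500, 0.0000])
print(total.item())  # 0.5
print(f"{mean_loss.item():.4f}")
```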

## 2. Classification: Cross Entropy Loss

Used when predicting one of several classes (e.g., Cat, Dog, Bird).
`nn.CrossEntropyLoss` combines **LogSoftmax** and **NLLLoss** (negative log-likelihood loss) in a single step.

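$CE = -\sum_{c=1}^{C} y_{true,c} \cdot \log(p_{pred,c})$

This is the loss for one sample with $C$ classes, where $y_{true}$ is the one-hot target and $p_{pred} = \mathrm{softmax}(\text{logits})$. By default, PyTorch averages it over the batch.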

**Important**: PyTorch's `CrossEntropyLoss` expects **raw logits** (scores before Softmax), not probabilities, because it applies LogSoftmax internally!
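Here is a quick sketch of what goes wrong if Softmax is applied before the loss. The single sample with logits `[2.0, 1.0, 0.1]` and target class 0 is an illustrative choice. The two losses differ because, in the second call, Softmax effectively gets applied twice.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 1.0, 0.1]])  # one sample, three classes
target = torch.tensor([0])                # true class index
criterion = nn.CrossEntropyLoss()

correct = criterion(logits, target)                      # intended usage: raw logits
wrong = criterion(torch.softmax(logits, dim=1), target)  # mistake: probabilities passed in

print(f"Loss from logits:        {correct.item():.4f}")
print(f"Loss from probabilities: {wrong.item():.4f}")
```

In the second call, the loss treats the probabilities as logits. Probabilities lie in [0, 1], so these "logits" differ by at most 1. As a result, that loss can never approach zero, even for a confident, correct prediction.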

In [None]:
# 3 samples, 3 classes (0=Cat, 1=Dog, 2=Bird)
logits = torch.tensor([[2.0, 1.0, 0.1],   # Pred: Cat (class 0)
                       [0.5, 2.5, 0.3],   # Pred: Dog (class 1)
                       [4.0, 0.0, 0.0]])  # Pred: Cat (class 0)

targets = torch.tensor([0, 1, 2]) # True: Cat, Dog, Bird

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)

print(f"Logits:\n{logits}")
print(f"Targets: {targets}")
print(f"Cross Entropy Loss: {loss.item():.4f}")

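# Why is the loss high? The 3rd sample predicts Cat (score 4.0) but is actually Bird (class 2).
# That sample alone contributes about 4.04, versus about 0.42 and 0.22 for the first two,
# and CrossEntropyLoss reports their mean (about 1.56).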
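The sketch below checks the claim that `CrossEntropyLoss` is LogSoftmax followed by NLLLoss, and shows where the loss comes from. It recomputes the loss three ways with `torch.nn.functional`:

* per sample, via `reduction="none"`;
* as `log_softmax` + `nll_loss`;
* directly from the formula, using one-hot targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3],
                       [4.0, 0.0, 0.0]])
targets = torch.tensor([0, 1, 2])

# Per-sample losses show which sample dominates the mean
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)
print(f"Per-sample losses: {per_sample}")

# Step 1: LogSoftmax turns logits into log-probabilities
log_probs = F.log_softmax(logits, dim=1)
# Step 2: NLLLoss takes -log p of the true class and averages over the batch
via_nll = F.nll_loss(log_probs, targets)

# The formula directly: -sum over classes of one-hot y * log p, averaged over samples
one_hot = F.one_hot(targets, num_classes=3).float()
via_formula = -(one_hot * log_probs).sum(dim=1).mean()

print(f"LogSoftmax + NLLLoss: {via_nll.item():.4f}")
print(f"Formula:              {via_formula.item():.4f}")
```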

## 3. Binary Classification: BCE Loss

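Used for Yes/No (two-class) questions, where the model outputs one score per sample.
PyTorch has `BCELoss`, which expects probabilities (Sigmoid already applied), and `BCEWithLogitsLoss`, which applies the Sigmoid internally and is more numerically stable.
Use `BCEWithLogitsLoss`, pass raw logits, and give the targets as floats (0.0 or 1.0).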
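Why is `BCEWithLogitsLoss` more stable? This sketch uses one deliberately extreme logit; the value -200 is an arbitrary choice. In float32, `sigmoid(-200)` underflows to exactly 0, so `BCELoss` would need `log(0)`. PyTorch's `BCELoss` clamps its log outputs at -100 to avoid infinities, which silently caps the loss. `BCEWithLogitsLoss` works from the logit directly and returns the true value of about 200.

```python
import torch
import torch.nn as nn

# An extreme logit: the model is very confident in the wrong answer
logit = torch.tensor([-200.0])
target = torch.tensor([1.0])

with_logits = nn.BCEWithLogitsLoss()(logit, target)  # computed from the logit directly
prob = torch.sigmoid(logit)                           # underflows to 0.0 in float32
with_probs = nn.BCELoss()(prob, target)               # log(0) is clamped to -100 inside BCELoss

print(f"sigmoid(-200) in float32: {prob.item()}")
print(f"BCEWithLogitsLoss:        {with_logits.item():.1f}")
print(f"BCELoss on sigmoid:       {with_probs.item():.1f}")
```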

In [None]:
logits = torch.tensor([2.5, -3.0, 0.0]) # High, Low, Uncertain
targets = torch.tensor([1.0, 0.0, 1.0]) # True, False, True

criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, targets)

probs = torch.sigmoid(logits)
print(f"Logits: {logits}")
print(f"Probabilities: {probs}")
print(f"Targets: {targets}")
print(f"BCE Loss: {loss.item():.4f}")
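The cell above does not write out the BCE formula. Take a logit $x_i$, its probability $p_i = \sigma(x_i)$, and a target $y_i \in \{0, 1\}$. The standard binary cross-entropy is then $BCE = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log(p_i) + (1 - y_i)\log(1 - p_i)\right]$. Here is a manual check in the same spirit as the MSE cell. It is fine for these moderate logits, but it has the underflow problem shown for extreme ones.

```python
import torch
import torch.nn as nn

logits = torch.tensor([2.5, -3.0, 0.0])
targets = torch.tensor([1.0, 0.0, 1.0])

p = torch.sigmoid(logits)  # probabilities from logits
# BCE = -mean[ y*log(p) + (1-y)*log(1-p) ]
manual = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()
builtin = nn.BCEWithLogitsLoss()(logits, targets)

print(f"Manual BCE:   {manual.item():.4f}")
print(f"Built-in BCE: {builtin.item():.4f}")
```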

## 4. Summary

| Task | Last Layer Activation | Loss Function |
|------|----------------------|---------------|
| Regression | None (Linear) | MSELoss |
| Binary Classification | None (Logits; Sigmoid is applied inside the loss) | BCEWithLogitsLoss |
| Multi-class Classification | None (Logits; LogSoftmax is applied inside the loss) | CrossEntropyLoss |
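Here is a sketch of how each row of the table fits together in code. The layer sizes and random inputs are hypothetical, chosen only to show shapes and dtypes:

* The regression and binary heads output one value per sample, squeezed to shape `(N,)`.
* The multi-class head outputs `(N, C)` logits.
* `CrossEntropyLoss` takes integer class indices as targets, while the other two losses take float targets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, features, num_classes = 4, 8, 3  # hypothetical sizes
x = torch.randn(batch, features)

# Regression: one linear output per sample, float targets
reg_head = nn.Linear(features, 1)
reg_loss = nn.MSELoss()(reg_head(x).squeeze(1), torch.randn(batch))

# Binary classification: one logit per sample, float targets in {0.0, 1.0}
bin_head = nn.Linear(features, 1)
bin_targets = torch.tensor([0.0, 1.0, 1.0, 0.0])
bin_loss = nn.BCEWithLogitsLoss()(bin_head(x).squeeze(1), bin_targets)

# Multi-class: one logit per class, integer (int64) class-index targets
mc_head = nn.Linear(features, num_classes)
mc_targets = torch.tensor([0, 2, 1, 2])
mc_loss = nn.CrossEntropyLoss()(mc_head(x), mc_targets)

print(f"MSE: {reg_loss.item():.4f}  BCE: {bin_loss.item():.4f}  CE: {mc_loss.item():.4f}")
```

None of the heads applies an activation. The two classification losses handle Sigmoid and LogSoftmax themselves.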