# CSE 25 – Introduction to Artificial Intelligence  
## Worksheet 11: Evaluation for Classification Tasks

**Today’s focus:** 
Once a classification model is trained, how do we **evaluate** it and decide whether it **generalizes** to new data?

>Generalization = performance on new, unseen data.

### Guiding Questions
1. What does **accuracy** measure, and when can it be misleading?
2. What do **precision**, **recall**, and **F1** measure?
3. Why do we use **train / validation / test** splits?
4. What is **overfitting**, and how do we detect and deal with it?

### Learning Objectives
By the end of this worksheet, you will be able to:
- Compute **accuracy** from predicted and true class labels
- Use a **confusion matrix** to compute precision, recall, and F1 score
- Interpret a **training vs validation** curve and identify overfitting
- Explain the purpose of **train / validation / test** splits


**Instructions:**

Create a copy of this notebook and complete it during class.  
Work through the cells below **in order**.

You may discuss with your neighbors, but make sure you understand  
what each step is doing and why.

**Submission**
When finished, download the notebook as a PDF and upload it to Gradescope under  
`In-Class – Week 7 Tuesday`.

To download as a PDF on DataHub:  
`File -> Save and Export Notebook As -> PDF`


### Accuracy

In **classification**, we often use **accuracy** to evaluate a model. It measures how often a model predicts the correct class label.

$$
\text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}}
$$

We will work with a **binary classification** setting where labels are:
- `1` = positive class
- `0` = negative class

A model outputs a **predicted label** for each example.

In [None]:
# Toy example: True labels and model predictions
y_true = [1, 0, 1, 0, 0, 1, 1, 0] # true labels
y_pred = [1, 0, 0, 0, 0, 1, 0, 0] # model predictions

# Let's look at them side-by-side
list(zip(y_true, y_pred))

Q. Compute the **accuracy** for the toy example above.
Write your answer below (as a fraction).

`YOUR ANSWER HERE`


Q. Complete the next code cell by writing the function `accuracy(y_true, y_pred)` that returns:

$$
\text{accuracy} = \frac{\#\text{ correct predictions}}{\#\text{ total predictions}}
$$

In [None]:
def accuracy(y_true, y_pred):
    """
    Return accuracy = (# correct) / (total).
    """
    # YOUR CODE HERE

In [None]:
# Test cases for the accuracy function
assert accuracy([1, 0, 1], [1, 0, 1]) == 1.0  # all correct
assert accuracy([1, 0, 1], [0, 0, 1]) == 2/3  # one incorrect
assert accuracy([0, 0, 0], [1, 1, 1]) == 0.0  # all incorrect
assert accuracy([1, 1, 0, 0], [1, 0, 0, 1]) == 0.5  # two correct

print("All accuracy tests passed.")

In [None]:
print(accuracy(y_true, y_pred))

Q. Why might accuracy be misleading in some cases, e.g. on **imbalanced** datasets?

*Hint: Suppose 95% of examples are negative (0), and a model predicts **all zeros**.*


`YOUR ANSWER HERE`

The next cell evaluates the imbalanced example by calling your `accuracy` function on `imbalanced_y_true` and `imbalanced_y_pred`.

In [None]:
# Now let's see what happens when we have an imbalanced dataset
imbalanced_y_true = [0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1] # true labels                   
imbalanced_y_pred = [0]*21 # model predictions

print("Accuracy on imbalanced dataset:", accuracy(imbalanced_y_true, imbalanced_y_pred))

What accuracy do you get? Is accuracy the right measure in this case?

`YOUR ANSWER HERE`

This commonly happens in problems where one class is rare but important, such as:
- *fraud detection*, 
- *disease screening*, 
- *spam filtering*, or 
- *anomaly detection*.

Because most examples belong to the majority class, a model can predict only that class and still achieve very high accuracy. The metric hides the fact that the model completely fails to detect the rare events we actually care about.
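The size of this effect is easy to check. The sketch below scores a "model" that always predicts the most common label. The helper name `majority_baseline_accuracy` is hypothetical (it is not part of the worksheet), and the 95/5 split mirrors the hint above. This baseline accuracy is a useful reference point whenever you read an accuracy number on imbalanced data.

```python
from collections import Counter

def majority_baseline_accuracy(labels):
    """Hypothetical helper: accuracy of always predicting the most frequent label."""
    counts = Counter(labels)
    majority_label, majority_count = counts.most_common(1)[0]
    return majority_label, majority_count / len(labels)

# 95 negatives (0) and 5 positives (1), as in the hint above
hint_labels = [0] * 95 + [1] * 5
baseline_label, baseline_acc = majority_baseline_accuracy(hint_labels)
print(f"Always predict {baseline_label}: accuracy = {baseline_acc:.2f}")
# Always predict 0: accuracy = 0.95  (yet all 5 positives are missed)
```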

### Confusion Matrix

Accuracy tells us overall correctness. A **confusion matrix** tells us *what kinds* of mistakes the model makes.

For binary classification:

- **TP**: true positive (true 1, predicted 1)  
- **FP**: false positive (true 0, predicted 1)  
- **FN**: false negative (true 1, predicted 0)  
- **TN**: true negative (true 0, predicted 0)


$$
\begin{array}{c|cc}
 & \textbf{Predicted 1} & \textbf{Predicted 0} \\
\hline
\textbf{Actual 1} & TP & FN \\
\textbf{Actual 0} & FP & TN \\
\end{array}
$$
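Libraries do not always lay the matrix out this way. This example assumes scikit-learn is installed, which the later cells in this worksheet also require. `sklearn.metrics.confusion_matrix` puts true labels on the rows and predicted labels on the columns. By default both are ordered by sorted label, so the top-left cell is TN rather than TP. Passing `labels=[1, 0]` reproduces the table above. The labels below are made up for illustration.

```python
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration: TP=1, FN=2, FP=1, TN=4
y_true_demo = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred_demo = [1, 0, 0, 1, 0, 0, 0, 0]

# Default ordering is sorted labels [0, 1]: [[TN, FP], [FN, TP]]
cm_default = confusion_matrix(y_true_demo, y_pred_demo)

# labels=[1, 0] matches the worksheet table: [[TP, FN], [FP, TN]]
cm_table = confusion_matrix(y_true_demo, y_pred_demo, labels=[1, 0])

print(cm_default)
print(cm_table)
```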

Q. Complete the next code cell for the function `confusion_counts(y_true, y_pred)`.

In [None]:
def confusion_counts(y_true, y_pred):
    TP = FP = FN = TN = 0
    for yt, yp in zip(y_true, y_pred):
        pass
        # YOUR CODE HERE 
    return {"TP": TP, "FP": FP, "FN": FN, "TN": TN}

In [None]:
# Test cases for the confusion_counts function

# Case 1: All correct predictions
assert confusion_counts([1, 0, 1, 0], [1, 0, 1, 0]) == {'TP': 2, 'FP': 0, 'FN': 0, 'TN': 2}

# Case 2: All incorrect predictions
assert confusion_counts([1, 1, 0, 0], [0, 0, 1, 1]) == {'TP': 0, 'FP': 2, 'FN': 2, 'TN': 0}

# Case 3: Only true positives
assert confusion_counts([1, 1, 1], [1, 1, 1]) == {'TP': 3, 'FP': 0, 'FN': 0, 'TN': 0}

# Case 4: Only true negatives
assert confusion_counts([0, 0, 0], [0, 0, 0]) == {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 3}

# Case 5: Mixed predictions
assert confusion_counts([1, 0, 1, 0], [0, 1, 1, 0]) == {'TP': 1, 'FP': 1, 'FN': 1, 'TN': 1}

print("All confusion_counts tests passed.")

In [None]:
counts = confusion_counts(y_true, y_pred)
print(counts)

In [None]:
imbalanced_counts = confusion_counts(imbalanced_y_true, imbalanced_y_pred)
print(imbalanced_counts)

We can also compute the accuracy from these counts:

$$
\text{Accuracy}=\frac{TP+TN}{TP+TN+FP+FN}
$$

Accuracy is the fraction of all predictions that are correct (correct positives + correct negatives over all cases).
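As a quick worked example, you can evaluate the formula directly from counts. The helper `accuracy_from_counts` and the counts are hypothetical. The second call shows that accuracy can be high even when TP = 0, which is the imbalanced-data problem from earlier.

```python
def accuracy_from_counts(tp, fp, fn, tn):
    """Hypothetical helper: (TP + TN) / (TP + TN + FP + FN)."""
    correct = tp + tn
    total = tp + tn + fp + fn
    return correct / total

# Hypothetical counts: (3 + 5) / (3 + 5 + 1 + 1) = 8 / 10
print(accuracy_from_counts(tp=3, fp=1, fn=1, tn=5))   # 0.8

# No positives found at all, but 95 of 100 predictions are still correct
print(accuracy_from_counts(tp=0, fp=0, fn=5, tn=95))  # 0.95
```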

#### Precision, Recall, F1

- **Precision**: Of all the examples the model *predicted* as positive, how many were actually positive?

$$
\text{Precision} = \frac{TP}{TP + FP}
$$

- **Recall**: Of all the examples that are *actually* positive, how many did the model correctly identify? Recall is also known as *True Positive Rate (TPR)*.

$$
\text{Recall} = \frac{TP}{TP + FN}
$$


- **F1**: a single score that combines precision and recall (their harmonic mean). It is high only when both precision and recall are high.
$$
F_1 = 2\cdot\frac{\text{Precision}\cdot\text{Recall}}{\text{Precision}+\text{Recall}}
$$


![Precision vs Recall](images/Precisionrecall.png)

<sub>
Image source: "Precisionrecall.svg" via Wikimedia Commons. Licensed under CC BY-SA 3.0.  
https://commons.wikimedia.org/wiki/File:Precisionrecall.svg
</sub>
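Why use the harmonic mean instead of a plain average? A small numerical example helps. The helper `f1_from_pr` is hypothetical and takes precision and recall directly. It returns 0.0 when both are 0, matching this worksheet's convention for undefined metrics.

```python
def f1_from_pr(precision, recall):
    """Hypothetical helper: harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0  # undefined -> 0.0, as in this worksheet
    return 2 * precision * recall / (precision + recall)

# Perfect precision, but the model finds only 10% of the positives
prec_demo, rec_demo = 1.0, 0.1
print(f"arithmetic mean = {(prec_demo + rec_demo) / 2:.3f}")        # 0.550
print(f"F1 (harmonic)   = {f1_from_pr(prec_demo, rec_demo):.3f}")   # 0.182
```

The plain average still looks moderate, while F1 is pulled down toward the weaker of the two numbers.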


In [None]:
def precision_recall_f1(y_true, y_pred):
    """
    Return (precision, recall, f1) for binary classification.
    If a metric is undefined (division by zero), return 0.0 for that metric.
    """
    precision, recall, f1 = 0.0, 0.0, 0.0
    # YOUR CODE HERE
    return precision, recall, f1

In [None]:
# Test cases for precision_recall_f1

def _close(a, b, tol=1e-9):
    return abs(a - b) < tol

# Case 1: Existing toy example
p, r, f = precision_recall_f1(y_true, y_pred)
assert _close(p, 1.0)
assert _close(r, 0.5)
assert _close(f, 2/3)

# Case 2: Imbalanced example (predict all zeros)
p, r, f = precision_recall_f1(imbalanced_y_true, imbalanced_y_pred)
assert _close(p, 0.0)
assert _close(r, 0.0)
assert _close(f, 0.0)

# Case 3: All predictions are positive
p, r, f = precision_recall_f1([1, 0, 1, 0], [1, 1, 1, 1])
assert _close(p, 0.5)     # 2 TP / (2 TP + 2 FP)
assert _close(r, 1.0)     # 2 TP / (2 TP + 0 FN)
assert _close(f, 2/3)

# Case 4: No actual positives
p, r, f = precision_recall_f1([0, 0, 0], [0, 1, 0])
assert _close(p, 0.0)     # TP=0, FP=1
assert _close(r, 0.0)     # TP=0, FN=0 -> handled as 0.0
assert _close(f, 0.0)

# Case 5: No predicted positives
p, r, f = precision_recall_f1([1, 0, 1], [0, 0, 0])
assert _close(p, 0.0)     # TP=0, FP=0 -> handled as 0.0
assert _close(r, 0.0)     # TP=0, FN=2
assert _close(f, 0.0)

print("All precision_recall_f1 tests passed.")

In [None]:
prec, rec, f1 = precision_recall_f1(y_true, y_pred)
prec, rec, f1

In [None]:
prec_imb, rec_imb, f1_imb = precision_recall_f1(imbalanced_y_true, imbalanced_y_pred)
prec_imb, rec_imb, f1_imb

##### Precision vs Recall in Real-World Scenarios

**Discuss and answer the following**

Q. Look at the formulas for precision and recall above and answer the following in plain English:

- If **precision** is high, what does that tell us about the model's positive predictions?
- If **recall** is high, what does that tell us about the model's ability to find actual positives?

`YOUR ANSWER HERE`

Below are two example situations. Read each one and think about which type of mistake matters more.


**Example 1**

A medical screening system checks images for signs of a serious disease. Missing a real case could delay treatment, but extra follow-up tests are acceptable.

- Would you prioritize **precision** or **recall** here?  
- Which mistake is worse: a false positive or a false negative?  
- Briefly explain your reasoning.

`YOUR ANSWER HERE`


**Example 2**

An automated system flags students for academic misconduct.  A false accusation could unfairly penalize a student and cause stress.

- Would you prioritize **precision** or **recall** here?  
- Which mistake is worse: a false positive or a false negative?  
- Briefly explain your reasoning.

`YOUR ANSWER HERE`


##### Your Turn

Come up with **two real-world scenarios**:

- One where you would prefer **high precision**.
- One where you would prefer **high recall**.

For each scenario, describe:

- What counts as a positive prediction?
- What is a false positive?
- What is a false negative?
- Which mistake matters more, and why?

`YOUR ANSWER HERE`


#### HOMEWORK (No Submission Required)

Run the interactive demo (the code cell after the questions below) to explore how evaluation metrics change. This is for your own understanding and practice. You do **not** need to submit your answers.

##### Explore the Interactive Threshold Demo

Move the sliders and observe how the confusion matrix and metrics change.

Answer the following questions as you explore.


**1. Threshold effects**

As you slowly increase the threshold from left to right:

- What trend do you notice in **precision**?
- What trend do you notice in **recall**?
- Explain your observations using TP, FP, and FN.

`YOUR ANSWER HERE`


**2. Extreme thresholds**

- Set the threshold very close to **0** (almost everything predicted positive).  
  What happens to recall? What happens to precision?
- Set the threshold very close to **1** (almost nothing predicted positive).  
  What happens to recall? What happens to precision?

`YOUR ANSWER HERE`


**3. Imbalanced data (prevalence slider)**

- Lower the positive rate so that positives become rare.
- What happens to **accuracy** compared to precision and recall?
- Why might accuracy remain high even when recall is low?

`YOUR ANSWER HERE`


**4. Separability (model quality)**

- Increase separability. What changes do you observe in the confusion matrix?
- Does changing the threshold matter as much when separability is high? Why or why not?

`YOUR ANSWER HERE`


**5. Color mode (TP/FP/FN/TN)**

Switch to the outcome coloring.

- Where do most **false positives** appear relative to the threshold line?
- Where do most **false negatives** appear?

`YOUR ANSWER HERE`

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output


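# Note: this cell redefines confusion_counts from earlier in the worksheet.
# This NumPy version returns a tuple (tp, fp, fn, tn) instead of a dict,
# so re-run your earlier cell if you need your own version afterwards.
def confusion_counts(y_true, y_pred):
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    return tp, fp, fn, tn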

def safe_div(num, den):
    return float(num) / float(den) if den != 0 else np.nan

def compute_metrics(tp, fp, fn, tn):
    acc  = safe_div(tp + tn, tp + fp + fn + tn)
    prec = safe_div(tp, tp + fp)
    rec  = safe_div(tp, tp + fn)   # Recall / TPR
    # F1 is undefined (NaN) if precision or recall is undefined, or if both are 0
    if np.isnan(prec) or np.isnan(rec) or (prec + rec) == 0:
        f1 = np.nan
    else:
        f1 = 2 * prec * rec / (prec + rec)
    return acc, prec, rec, f1

def fmt(x):
    return "undefined" if np.isnan(x) else f"{x:.3f}"

def make_dataset(n=100, prevalence=0.2, separability=2.0):
    rng = np.random.default_rng(0)  # fixed seed
    n_pos = int(round(n * prevalence))
    n_neg = n - n_pos

    pos_raw = rng.normal(loc=+separability/2, scale=1.0, size=n_pos)
    neg_raw = rng.normal(loc=-separability/2, scale=1.0, size=n_neg)

    sigmoid = lambda z: 1 / (1 + np.exp(-z))
    pos_scores = sigmoid(pos_raw)
    neg_scores = sigmoid(neg_raw)

    y_true  = np.array([1]*n_pos + [0]*n_neg)
    y_score = np.concatenate([pos_scores, neg_scores])

    idx = rng.permutation(n)
    return y_true[idx], y_score[idx]


def plot_confusion(ax, tp, fp, fn, tn):
    cm = np.array([[tp, fn],
                   [fp, tn]], dtype=float)

    vmax = max(cm.max(), 1.0)
    im = ax.imshow(cm, vmin=0, vmax=vmax, cmap="Blues")

    ax.set_title("Confusion Matrix")
    ax.set_xticks([0, 1], labels=["Pred 1", "Pred 0"])
    ax.set_yticks([0, 1], labels=["Actual 1", "Actual 0"])

    threshold = vmax * 0.55
    for (i, j), v in np.ndenumerate(cm):
        color = "white" if v >= threshold else "black"
        ax.text(j, i, f"{int(v)}", ha="center", va="center",
                color=color, fontsize=12, fontweight="bold")

    cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cb.set_label("Count", rotation=270, labelpad=12)

def plot_score_hist(ax, y_true, y_score, y_pred, thr, color_mode):
    bins = np.linspace(0, 1, 21)

    if color_mode == "By true label":
        pos = y_score[y_true == 1]
        neg = y_score[y_true == 0]
        ax.hist(neg, bins=bins, alpha=0.6, label="Actual 0 (neg)")
        ax.hist(pos, bins=bins, alpha=0.6, label="Actual 1 (pos)")

    elif color_mode == "By outcome (TP/FP/FN/TN)":
        tp_mask = (y_true == 1) & (y_pred == 1)
        fp_mask = (y_true == 0) & (y_pred == 1)
        fn_mask = (y_true == 1) & (y_pred == 0)
        tn_mask = (y_true == 0) & (y_pred == 0)

        ax.hist(y_score[tp_mask], bins=bins, alpha=0.7, label="TP (correct +)")
        ax.hist(y_score[tn_mask], bins=bins, alpha=0.7, label="TN (correct -)")
        ax.hist(y_score[fp_mask], bins=bins, alpha=0.7, label="FP (false alarm)")
        ax.hist(y_score[fn_mask], bins=bins, alpha=0.7, label="FN (missed +)")

    ax.axvline(thr, linewidth=2, label=f"threshold = {thr:.2f}")
    ax.set_xlim(0, 1)
    ax.set_title("Score Distributions + Threshold")
    ax.set_xlabel("Model score / probability")
    ax.set_ylabel("Count")
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)


def render(threshold, prevalence, separability, color_mode):
    y_true, y_score = make_dataset(n=100,
                                   prevalence=float(prevalence),
                                   separability=float(separability))

    y_pred = (y_score >= threshold).astype(int)

    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    acc, prec, rec, f1 = compute_metrics(tp, fp, fn, tn)

    clear_output(wait=True)

    fig, axes = plt.subplots(1, 2, figsize=(13, 4))
    plot_confusion(axes[0], tp, fp, fn, tn)
    plot_score_hist(axes[1], y_true, y_score, y_pred, threshold, color_mode)
    plt.show()

    n_pos = int(np.sum(y_true == 1))
    n_neg = int(np.sum(y_true == 0))
    print(f"Dataset: n=100  positives={n_pos}  negatives={n_neg}  prevalence={n_pos/100:.2f}")
    print(f"Threshold: {threshold:.2f}")
    print(f"TP={tp}  FP={fp}  FN={fn}  TN={tn}\n")
    print(f"Accuracy   = {fmt(acc)}")
    print(f"Precision  = {fmt(prec)}")
    print(f"Recall/TPR = {fmt(rec)}")
    print(f"F1         = {fmt(f1)}")

thr = widgets.FloatSlider(
    value=0.50, min=0.0, max=1.0, step=0.01,
    description="Threshold",
    continuous_update=True,
    readout_format=".2f",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="650px")
)

prevalence = widgets.FloatSlider(
    value=0.20, min=0.01, max=0.99, step=0.01,
    description="Positive rate (prevalence)",
    continuous_update=False,
    readout_format=".2f",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="650px")
)

separability = widgets.FloatSlider(
    value=2.0, min=0.0, max=5.0, step=0.1,
    description="Separability (hard → easy)",
    continuous_update=False,
    readout_format=".1f",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="650px")
)

color_mode = widgets.ToggleButtons(
    options=["By true label", "By outcome (TP/FP/FN/TN)"],
    value="By true label",
    description="Histogram colors",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="650px")
)

ui = widgets.VBox([thr, prevalence, separability, color_mode])
out = widgets.interactive_output(render, {
    "threshold": thr,
    "prevalence": prevalence,
    "separability": separability,
    "color_mode": color_mode
})

display(ui, out)


Use the controls at the top to explore how model behavior changes. This interactive graph shows how predictions and evaluation metrics depend on the **decision threshold**, the **data distribution**, and how well the model separates the classes.

**Controls**

- **Threshold**: changes how strict the classifier is.  
  - Lower threshold -> more positive predictions (recall goes up or stays the same; typically more FP).  
  - Higher threshold -> fewer positive predictions (precision usually goes up; more FN).
- **Positive rate (prevalence)**: changes how common positives are, illustrating class imbalance.
- **Separability**: controls how easy the classification task is (how well scores separate the classes).
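You can check the threshold trade-off by hand on a tiny example. The scores and labels below are made up and are not the demo's data. `precision_recall_at` is a hypothetical helper that predicts 1 when score >= threshold, the same rule the demo uses. On real data, precision does not always rise strictly as the threshold increases, but the general trend is the same.

```python
# Made-up scores and labels (4 actual positives out of 10)
demo_scores = [0.95, 0.85, 0.70, 0.60, 0.55, 0.40, 0.35, 0.20, 0.10, 0.05]
demo_labels = [1,    1,    0,    1,    0,    1,    0,    0,    0,    0]

def precision_recall_at(threshold, scores, labels):
    """Hypothetical helper: predict 1 if score >= threshold, return (precision, recall).
    Undefined values are reported as 0.0, following this worksheet's convention."""
    preds = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return precision, recall

for t in (0.3, 0.5, 0.8):
    prec_t, rec_t = precision_recall_at(t, demo_scores, demo_labels)
    print(f"threshold={t:.1f}  precision={prec_t:.3f}  recall={rec_t:.3f}")
# threshold=0.3  precision=0.571  recall=1.000
# threshold=0.5  precision=0.600  recall=0.750
# threshold=0.8  precision=1.000  recall=0.500
```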

**What you are seeing**

- Each example has a **model score** (a confidence value between 0 and 1).
- The histogram groups examples by score; each bar counts the examples whose scores fall in that range.
- The vertical line is the current **threshold**.  
- Examples with scores at or above the threshold (on or to the right of the line) are predicted positive.

**Color Modes**

- **By true label**: Colors show the actual class (positive vs negative).
- **By outcome (TP/FP/FN/TN)**: Colors show whether predictions are correct or incorrect:
  - TP and TN are correct predictions.
  - FP are false alarms.
  - FN are missed positives.

**Confusion Matrix**

- **TP**: correct positive predictions  
- **FP**: false alarms  
- **FN**: missed positives  
- **TN**: correct negative predictions  

As you adjust the controls, examples move between these categories, which changes the metrics.

*Note: In this course, we focus on **accuracy**, **precision**, **recall**, and the **F1 score**. In practice, other metrics such as specificity, false positive rate (FPR), ROC-AUC, and PR-AUC may also be used depending on the application. Different metrics emphasize different types of mistakes, so the “best” metric depends on the goal of the model.*
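For reference, two of those extra metrics come straight from the confusion matrix counts. Specificity is the recall of the negative class, and the false positive rate (FPR) is its complement. The function name and the counts below are hypothetical.

```python
def specificity_and_fpr(fp, tn):
    """Hypothetical helper: specificity = TN / (TN + FP), FPR = FP / (FP + TN)."""
    negatives = fp + tn  # every example whose true label is 0
    if negatives == 0:
        return float("nan"), float("nan")  # undefined with no actual negatives
    specificity = tn / negatives  # fraction of negatives correctly rejected
    fpr = fp / negatives          # fraction of negatives raised as false alarms
    return specificity, fpr

# Hypothetical counts: 90 actual negatives, 9 of them flagged as positive
spec, fpr = specificity_and_fpr(fp=9, tn=81)
print(f"specificity = {spec:.2f}, FPR = {fpr:.2f}")  # specificity = 0.90, FPR = 0.10
```

A ROC curve plots recall (TPR) against FPR as the threshold is varied.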

### Train / Validation / Test Split

Earlier in the course, we introduced the idea of splitting data into a **training set** and a **test set**.  
We trained a model on one portion of the data and then evaluated it on *new, unseen examples* to ask:

*Does the model **generalize**, or did it just memorize the training data?*

We observed that:

- Training performance is often **better** than test performance.
- A model that fits the training data well may still struggle on new data.
- Evaluation on unseen data helps us understand whether the model truly learned patterns.

Today, we refine that idea by introducing a **third split** called the **validation set**.


#### Training set
Used to fit model parameters (weights and biases) through gradient descent. The model sees these examples during learning.

#### Validation set
Used during development to monitor performance, compare models, tune *hyperparameters*, and detect **overfitting** (when a model learns the training data too closely and performs worse on new data). The model does **not** update its parameters on this data.

#### Test set
Used only once at the very end to report final performance. It acts as an unbiased estimate of how the model performs on truly new data.



> **Caution: Protect Your Test Set**
>
> The test set should be used **only once**, after all model decisions are finalized.
>
> Do **not** use the test set to:
>
> - tune hyperparameters  
> - choose model architectures  
> - decide when to stop training  
> - compare multiple versions during development
>
> If you repeatedly look at test performance while making decisions, the model indirectly adapts to the test data, and the reported results are no longer an unbiased measure of generalization.

> **Note:** Using your test set to make model or training decisions is a form of *data leakage*. It can give a false sense of how well your model actually works.


- **Training set** teaches the model.
- **Validation set** guides development decisions.
- **Test set** is the final report card.
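One simple way to make the three splits is to shuffle the example indices once and cut them into three pieces. This is a minimal sketch. `split_indices` is a hypothetical helper, and the 60/20/20 proportions are only an example. In practice, scikit-learn's `train_test_split` (used later in this worksheet) can be called twice to get the same effect.

```python
import random

def split_indices(n, val_frac=0.2, test_frac=0.2, seed=0):
    """Hypothetical helper: shuffle indices 0..n-1 and cut them into train / val / test."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    n_test = round(n * test_frac)
    n_val = round(n * val_frac)
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]  # everything left over is for training
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))  # 60 20 20
```

For classification, a *stratified* split (for example `train_test_split(..., stratify=y)`, as in the later cells) keeps the class proportions similar in each split, which matters for imbalanced data.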

##### Hyperparameters vs Parameters

**Parameters:**
- Weights and biases (learned via gradient descent)

**Some Hyperparameters:**
- Number of layers
- Number of neurons
- Learning rate
- Number of epochs
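The difference between the two shows up in how each is chosen. In the sketch below, the polynomial coefficients are *parameters* fit on the training data. The polynomial degree is a *hyperparameter*, picked by comparing validation error. The test set is used once, after the choice is made. The data is synthetic (a noisy sine wave, like the degree demo later in this worksheet), and all names are illustrative.

```python
import numpy as np

sel_rng = np.random.default_rng(1)  # fixed seed so the example is reproducible

def noisy_sine(n):
    """Synthetic data: y = sin(2*pi*x) + Gaussian noise."""
    x = sel_rng.uniform(0, 1, n)
    y = np.sin(2 * np.pi * x) + sel_rng.normal(0, 0.25, n)
    return x, y

def mse(y, yhat):
    return float(np.mean((y - yhat) ** 2))

x_tr, y_tr = noisy_sine(20)   # training set
x_va, y_va = noisy_sine(50)   # validation set
x_te, y_te = noisy_sine(50)   # test set

# Parameters (coefficients) are fit on train; each degree is scored on validation
val_mse_by_degree = {}
for degree in range(10):
    coeffs = np.polyfit(x_tr, y_tr, deg=degree)
    val_mse_by_degree[degree] = mse(y_va, np.polyval(coeffs, x_va))

best_degree = min(val_mse_by_degree, key=val_mse_by_degree.get)

# Only after the degree is fixed: one evaluation on the test set
final_coeffs = np.polyfit(x_tr, y_tr, deg=best_degree)
test_mse = mse(y_te, np.polyval(final_coeffs, x_te))
print(f"chosen degree = {best_degree}, test MSE = {test_mse:.3f}")
```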

#### Overfitting via Train vs Validation Curves

Overfitting often looks like:

- training metric keeps improving
- validation metric improves, then gets worse

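The next cell shows this pattern in a simple setting. It fits polynomials of different degrees to a small, noisy training set. It then reports the **mean squared error (MSE)** on the training data and, when **show val data** is checked, on a separate validation set. For MSE, lower is better.

If training error keeps decreasing but validation error starts to increase, the model is overfitting.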
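You can also check this pattern numerically from a list of per-epoch losses, where lower is better. This is a minimal sketch with hypothetical loss values and a hypothetical helper name. It finds the epoch with the lowest validation loss. It then flags overfitting if validation loss has gone up since that epoch while training loss kept going down.

```python
def overfitting_report(train_loss, val_loss):
    """Hypothetical helper: return (best epoch, overfitting signal)."""
    best = min(range(len(val_loss)), key=lambda i: val_loss[i])
    val_rose = val_loss[-1] > val_loss[best]        # validation got worse since the best epoch
    train_fell = train_loss[-1] < train_loss[best]  # training kept improving
    return best + 1, (val_rose and train_fell)      # epochs counted from 1

# Hypothetical loss values for 8 epochs
demo_train_loss = [1.00, 0.70, 0.50, 0.38, 0.30, 0.24, 0.20, 0.17]
demo_val_loss   = [1.05, 0.80, 0.62, 0.55, 0.54, 0.58, 0.63, 0.70]
best_epoch, overfit = overfitting_report(demo_train_loss, demo_val_loss)
print(f"lowest val loss at epoch {best_epoch}; overfitting signal: {overfit}")
# lowest val loss at epoch 5; overfitting signal: True
```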

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

rng = np.random.default_rng(0)
n_train, n_val = 12, 60

x_train = np.sort(rng.uniform(0, 1, n_train))
y_train = np.sin(2*np.pi*x_train) + rng.normal(0, 0.25, n_train)

x_val = np.sort(rng.uniform(0, 1, n_val))
y_val = np.sin(2*np.pi*x_val) + rng.normal(0, 0.25, n_val)

x_plot = np.linspace(0, 1, 300)

def mse(y, yhat):
    return float(np.mean((np.asarray(y) - np.asarray(yhat))**2))


deg_slider = widgets.IntSlider(value=1, min=0, max=10, step=1, description="degree")
show_true = widgets.Checkbox(value=False, description="show true curve")
show_val  = widgets.Checkbox(value=False, description="show val data")
out = widgets.Output()

def redraw(degree, show_true_curve, show_val_data):
    with out:
        out.clear_output(wait=True)

        coeffs = np.polyfit(x_train, y_train, deg=degree)

        yhat_train = np.polyval(coeffs, x_train)
        yhat_val   = np.polyval(coeffs, x_val)
        yhat_plot  = np.polyval(coeffs, x_plot)

        train_mse = mse(y_train, yhat_train)
        val_mse   = mse(y_val, yhat_val)

        plt.figure(figsize=(7,4))

        plt.scatter(x_train, y_train, label="train")

        # show validation points only if toggled
        if show_val_data:
            plt.scatter(x_val, y_val, alpha=0.35, label="val")

        plt.plot(x_plot, yhat_plot, linewidth=2, label=f"fit (deg={degree})")

        if show_true_curve:
            plt.plot(x_plot, np.sin(2*np.pi*x_plot), linestyle="--", label="true: sin(2πx)")

        if show_val_data:
            title = f"Train MSE: {train_mse:.3f} | Val MSE: {val_mse:.3f}"
        else:
            title = f"Train MSE: {train_mse:.3f}"

        plt.title(title)
        plt.ylim(-2, 2)
        plt.legend()
        plt.show()

ui = widgets.VBox([deg_slider, show_true, show_val, out])
display(ui)

# redraw() writes into `out`, so the widget returned here does not need to be displayed
_ = widgets.interactive_output(
    redraw,
    {
        "degree": deg_slider,
        "show_true_curve": show_true,
        "show_val_data": show_val
    }
)

Q. In the above graph, which model degree would you choose based on the validation MSE? (Check **show val data** to display it.)

`YOUR ANSWER HERE`

#### What are we seeing here?

We are fitting a **curve** to a small set of training data.

Use the **degree** slider to change how flexible the model is:

- Small degree -> simple curve
- Large degree -> very bendy curve

The model is trying to follow the pattern in the training points.

The polynomial degree is a *hyperparameter* that controls model capacity.

Try this:

1. Increase the degree until the curve looks very wiggly.
2. Turn on **show val data**.
3. Notice that the complicated curve may not match the validation points well.

Even though the model fits training data very closely, it may perform worse on new data.
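This is why a very flexible curve can fit the training points "too well". With n training points at distinct x values, a polynomial of degree n - 1 has n coefficients and can pass through every point. Its training error is therefore essentially zero, even though the targets are noisy. This sketch uses 6 made-up points from the same kind of noisy sine wave as the demo.

```python
import numpy as np

interp_rng = np.random.default_rng(0)
x_pts = np.linspace(0, 1, 6)                                        # 6 training points
y_pts = np.sin(2 * np.pi * x_pts) + interp_rng.normal(0, 0.25, 6)   # noisy targets

train_mse_by_degree = {}
for degree in (1, 3, 5):
    coeffs = np.polyfit(x_pts, y_pts, deg=degree)   # degree + 1 coefficients
    residuals = np.polyval(coeffs, x_pts) - y_pts
    train_mse_by_degree[degree] = float(np.mean(residuals ** 2))
    print(f"degree {degree}: {len(coeffs)} coefficients, "
          f"train MSE = {train_mse_by_degree[degree]:.4f}")
```

Degree 5 drives the training MSE to (numerically) zero by also fitting the noise. The validation points in the demo above show what that costs on new data.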

##### Some Ways We Can Reduce Overfitting

- **Use a validation set (early stopping)**  
  Monitor validation loss during training and stop when it begins to worsen.

- **Simplify the model**  
  Use fewer parameters (for example, fewer layers, fewer neurons, or a lower-degree model).

- **Add regularization**  
  Add a small penalty to discourage very large weights, encouraging smoother solutions.

- **Increase the training data**  
  More data makes it harder for the model to memorize noise.

- **Use data augmentation**  
  Create additional training examples by slightly modifying existing ones (for example, flipping or rotating images).
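Early stopping is often implemented with a *patience* setting. Training stops once validation loss has not improved for a set number of epochs, and the weights from the best epoch are kept. The sketch below applies that rule to a hypothetical list of validation losses; the function name is illustrative. scikit-learn's `MLPClassifier`, used in the training-curve cell below, offers `early_stopping` and `n_iter_no_change` options for a similar idea.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Hypothetical helper: return the (1-based) epoch whose weights we would keep.

    Training stops once validation loss has failed to improve for
    `patience` consecutive epochs; the best epoch seen so far is returned."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # stop training here
    return best_epoch

val_history = [0.90, 0.70, 0.60, 0.58, 0.59, 0.61, 0.62, 0.55, 0.70]
print(early_stopping_epoch(val_history, patience=3))  # 4
print(early_stopping_epoch(val_history, patience=4))  # 8
```

With `patience=3`, training stops at epoch 7 and keeps epoch 4. A patience of 4 waits long enough to reach the lower loss at epoch 8. A larger patience tolerates noisy validation curves better, but it trains for longer.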

#### Visual Examples: Learning Curve and Training Curve

In the next two cells, focus on interpreting the **graphs** (not the code).

- The **Learning Curve** shows how training and validation accuracy change as training set size increases.
- The **Training Curve** shows how training loss and validation loss change across epochs.

Use these plots to connect ideas from this worksheet:

- **Generalization** (how well the model performs on unseen data)
- **Overfitting** (training keeps improving while validation stops improving or worsens)
- **Model selection / early stopping** (choosing settings based on validation behavior)

As you read the plots, ask:

1. Is there a gap between train and validation performance?
2. Does the gap shrink as more data is used?
3. At what epoch does validation loss reach its minimum?
4. What might happen if training continues far beyond that point?

##### Learning Curve

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the 8x8 handwritten digits dataset (scikit-learn's copy of the UCI digits data, not MNIST)
X, y = load_digits(return_X_y=True)

# Shuffle the dataset to ensure random distribution of classes
rng = np.random.default_rng(42) # fixed seed for reproducibility
perm = rng.permutation(len(X))
X, y = X[perm], y[perm]

# Train / validation split 
X_train_full, X_val, y_train_full, y_val = train_test_split(
    X, y, test_size=300, stratify=y, random_state=0
)

# Standardize features for better convergence of logistic regression
scaler = StandardScaler()
X_train_full = scaler.fit_transform(X_train_full)
X_val = scaler.transform(X_val)

# Shuffle the training data to ensure random distribution of classes for learning curve
perm = rng.permutation(len(X_train_full))
X_train_full = X_train_full[perm]
y_train_full = y_train_full[perm]

# Define training set sizes for learning curve
train_sizes = np.arange(100, len(X_train_full), 80)

# Compute training and validation accuracy for each training set size
train_acc, val_acc = [], []

# Loop over different training set sizes, fit the model, and evaluate accuracy
for n in train_sizes:
    X_train = X_train_full[:n]
    y_train = y_train_full[:n]

    model = LogisticRegression(
        solver="lbfgs",
        C=0.1,
        max_iter=500
    )

    model.fit(X_train, y_train)

    train_acc.append(model.score(X_train, y_train))
    val_acc.append(model.score(X_val, y_val))

plt.figure(figsize=(10,6))
plt.plot(train_sizes, train_acc, "-o", linewidth=3, label="Train")
plt.plot(train_sizes, val_acc, "-s", linewidth=3, label="Validation")
plt.text(
    0.2, 0.05,
    "Validation set kept constant while training set size increases.",
    transform=plt.gca().transAxes,
    fontsize=10,
    bbox=dict(boxstyle="round", alpha=0.2)
)
plt.figtext(
    0.5, -0.005,
    "Learning curve: training vs validation accuracy as training set size increases.",
    ha="center",
    fontsize=11
)
plt.xlabel("Training Set Size")
plt.ylabel("Accuracy")
plt.title("Softmax Logistic Regression on 8×8 Digits")
plt.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
plt.grid(alpha=0.3)
plt.savefig("images/learning_curve_digits.png", dpi=300, bbox_inches="tight")

plt.show()

##### Training Curve

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import log_loss, accuracy_score

import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)


X, y = load_digits(return_X_y=True)

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=300, stratify=y, random_state=0
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val   = scaler.transform(X_val)

max_epochs = 500

mlp = MLPClassifier(
    hidden_layer_sizes=(8,),
    activation="relu",
    solver="sgd",
    learning_rate_init=0.08,
    batch_size=128,
    max_iter=1,      # one epoch per .fit()
    warm_start=True, # continue training
    shuffle=True,
    random_state=0
)

train_loss, val_loss = [], []
val_acc = [] 

for epoch in range(1, max_epochs + 1):
    mlp.fit(X_train, y_train)

    # training loss from sklearn
    train_loss.append(mlp.loss_)

    # validation loss computed from predicted probabilities
    val_probs = mlp.predict_proba(X_val)
    val_loss.append(log_loss(y_val, val_probs))

    # optional: validation accuracy (not plotted)
    val_acc.append(accuracy_score(y_val, mlp.predict(X_val)))

epochs = np.arange(1, max_epochs + 1)


best_loss_epoch = int(np.argmin(val_loss) + 1)
best_loss = float(np.min(val_loss))
acc_at_best_loss = float(val_acc[best_loss_epoch - 1])


plt.figure(figsize=(8,4.5))
plt.plot(epochs, train_loss, label="training loss", linewidth=2)
plt.plot(epochs, val_loss, label="validation loss", linewidth=2)

plt.axvline(best_loss_epoch, linestyle="--", linewidth=2,
            label=f"lowest val loss at epoch {best_loss_epoch}")
plt.scatter([best_loss_epoch], [best_loss], s=60)

plt.xlabel("Epoch")
plt.ylabel("Cross-Entropy Loss")
plt.title("Neural Network on 8×8 Digits: Loss vs Epoch")


plt.figtext(
    0.5, -0.01,
    "Training Curve: Loss vs Epoch",
    ha="center",
    fontsize=11
)
plt.grid(alpha=0.3)
plt.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
plt.tight_layout()
plt.savefig("images/loss_curve_digits.png", dpi=300, bbox_inches="tight")
plt.show()

print(f"Lowest validation loss = {best_loss:.3f} at epoch {best_loss_epoch}")
print(f"Validation accuracy at that epoch = {acc_at_best_loss:.3f}")
print(f"Final validation loss = {val_loss[-1]:.3f}")