# Early Stopping

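Early stopping monitors validation loss rather than training loss. Training loss usually keeps falling as the model memorizes the training data. Validation loss stops improving once the model no longer generalizes better to unseen data. Stopping at that point keeps the model that generalizes best, rather than one that merely fits the training set more closely.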

In [15]:
class EarlyStopper:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch with that epoch's validation loss.
    After ``patience`` consecutive epochs without an improvement,
    ``early_stop`` becomes True.

    An epoch counts as an improvement only if its loss is lower than the best
    recorded loss by more than ``min_delta``. Every other epoch counts toward
    ``patience``. This includes a loss that merely equals the best, one that
    beats it by less than ``min_delta``, and a NaN loss. As a result, a plateau
    in validation loss also triggers a stop.

    Args:
        patience: Number of consecutive non-improving epochs tolerated before
            ``early_stop`` is set.
        min_delta: Minimum decrease in validation loss that counts as an
            improvement. Expected to be >= 0.

    Attributes:
        counter: Consecutive epochs without a significant improvement.
        min_val_loss: Lowest validation loss recorded as an improvement.
        early_stop: True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                  # Consecutive epochs with no significant improvement
        self.min_val_loss = float('inf')  # Best loss recorded as an improvement so far
        self.early_stop = False           # Set to True when the stopping condition is met

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update the stopping state.

        Args:
            val_loss: Validation loss for the current epoch.

        Returns:
            The updated ``early_stop`` flag, so callers may write
            ``if early_stopper(val_loss): break``.
        """
        if val_loss < self.min_val_loss - self.min_delta:
            # Significant improvement: remember it and restart the patience window.
            self.min_val_loss = val_loss
            self.counter = 0
            self.early_stop = False
        else:
            # No significant improvement. This branch also handles an exact
            # plateau, and a NaN loss (NaN compares False against everything).
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


In [16]:
# Simulated per-epoch validation losses. The loss improves for three epochs,
# then rises and plateaus; with patience=3 the run halts at epoch 5.
val_loss_arr = [9, 8, 7, 8, 9, 9, 9, 9]
early_stopper = EarlyStopper(patience=3)

for epoch in range(len(val_loss_arr)):
    val_loss = val_loss_arr[epoch]
    print(f"loss: {val_loss}")
    early_stopper(val_loss)
    if not early_stopper.early_stop:
        continue  # keep "training"
    print("Stopped at epoch:", epoch)
    break

loss: 9
loss: 8
loss: 7
loss: 8
loss: 9
loss: 9
Stopped at epoch: 5
