# Early Stopping

* Early stopping is a form of regularization used to **avoid overfitting** on the training dataset.
* Early stopping keeps track of the validation loss, if the loss stops decreasing for several epochs in a row the training stops.

<img src="https://i.imgur.com/M8ZxF2V.png" width="500"/>



## Import

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision import models

## Dataset & DataLoader

In [2]:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
valid_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=False)

train_batch_size = 64
test_batch_size = 1000

# Create Dataloader 
train_loader = DataLoader(
    dataset = train_dataset,
    batch_size = train_batch_size,
    shuffle = True
)

valid_loader = DataLoader(
    dataset = valid_dataset,
    batch_size = test_batch_size,
    shuffle = False
)

## Simple ANN Model

In [3]:
class ANN(nn.Module):
    def __init__(self, input_size, hidden_size, num_class):
        super(ANN, self).__init__()
        self.fc1 = nn.Linear(
            in_features=input_size,
            out_features=hidden_size
        )
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(
            in_features=hidden_size,
            out_features=num_class
        )
    
    def forward(self, x):
        output = self.fc1(x)
        output = self.relu(output)
        output = self.fc2(output)
        return output

In [4]:
input_size = 28 * 28
num_class = 10
hidden_size = 500

model = ANN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_class=num_class 
)

## Train & Validation

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
# define the tainning routine
def train(model, train_loader, optimizer, loss_function, epoch):

    model.train()
    loss_total = 0

    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(device).view(-1, 28*28)
        labels = labels.to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        output = model(images)
        
        loss = loss_function(output, labels)
        loss_total += loss.item()
        loss.backward()
        
        optimizer.step()
        
    return loss_total / len(train_loader) 

In [9]:
# define the validation routine to evaluate model performance
def valid(model, valid_loader, loss_function):

    model.eval()
    loss_total = 0

    with torch.no_grad():
        for idx, (images, labels) in enumerate(train_loader):
            images = images.to(device).view(-1, 28*28)
            labels = labels.to(device)
            output = model(images)

            loss = loss_function(output, labels)
            loss_total += loss.item()
    return loss_total / len(valid_loader)

## Training & Validation Process (w/out Early Stopping)

In [10]:
import time

In [17]:
model = ANN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_class=num_class 
)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

In [18]:
init_start = time.time()
for epoch in range(0, 100):
    print(f'Epoch: {epoch}:')
    train_start_time = time.time()
    train_loss = train(model, train_loader, optimizer, loss_function, epoch)

    print("- Spend %.2f seconds on training process" % (time.time() - train_start_time))
    valid_start_time = time.time()
    val_loss = valid(model, valid_loader, loss_function)
    print("- Spend %.2f seconds on validation process" % (time.time() - valid_start_time))
    print('- Training Loss: %.3f, Validation Loss: %.3f' % (train_loss, val_loss))

print('Spend %3.f seconds on entire training process' % (time.time() - init_start))

Epoch: 0:
- Spend 6.56 seconds on training process
- Spend 5.55 seconds on validation process
- Training Loss: 0.262, Validation Loss: 11.437
Epoch: 1:
- Spend 6.41 seconds on training process
- Spend 5.50 seconds on validation process
- Training Loss: 0.103, Validation Loss: 5.996
Epoch: 2:
- Spend 6.37 seconds on training process
- Spend 5.43 seconds on validation process
- Training Loss: 0.067, Validation Loss: 4.458
Epoch: 3:
- Spend 7.05 seconds on training process
- Spend 5.48 seconds on validation process
- Training Loss: 0.047, Validation Loss: 3.413
Epoch: 4:
- Spend 6.38 seconds on training process
- Spend 5.37 seconds on validation process
- Training Loss: 0.034, Validation Loss: 2.025
Epoch: 5:
- Spend 6.09 seconds on training process
- Spend 5.06 seconds on validation process
- Training Loss: 0.026, Validation Loss: 1.669
Epoch: 6:
- Spend 6.03 seconds on training process
- Spend 5.08 seconds on validation process
- Training Loss: 0.018, Validation Loss: 1.104
Epoch: 7:
- 

## Training & Validation Process (w/ Early Stopping)

In [15]:
model = ANN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_class=num_class 
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

In [16]:
init_start = time.time()

# init min validation loss
min_valid_loss = 100
patience = 5
trap_count = 0

for epoch in range(0, 100):
    print(f'Epoch: {epoch}:')
    train_start_time = time.time()
    train_loss = train(model, train_loader, optimizer, loss_function, epoch)

    valid_start_time = time.time()
    val_loss = valid(model, valid_loader, loss_function)
    
    # Early Stopping
    if val_loss < min_valid_loss:
        min_valid_loss = val_loss
        trap_count = 0
    else:
        trap_count += 1
    if trap_count >= patience:
        print(f'Validation loss not decrease ...{trap_count} times... stop training....')
        # save model weight
        torch.save(model.state_dict(), 'best_model.pkl')
        break

    print('- Training Loss: %.3f, Validation Loss: %.3f [%s/%s]' % (train_loss, val_loss, trap_count, patience))


print('Spend %3.f seconds on entire training process' % (time.time() - init_start))

Epoch: 0:
- Training Loss: 0.265, Validation Loss: 11.252 [0/5]
Epoch: 1:
- Training Loss: 0.102, Validation Loss: 6.490 [0/5]
Epoch: 2:
- Training Loss: 0.066, Validation Loss: 5.028 [0/5]
Epoch: 3:
- Training Loss: 0.047, Validation Loss: 3.039 [0/5]
Epoch: 4:
- Training Loss: 0.033, Validation Loss: 2.634 [0/5]
Epoch: 5:
- Training Loss: 0.026, Validation Loss: 2.246 [0/5]
Epoch: 6:
- Training Loss: 0.020, Validation Loss: 1.269 [0/5]
Epoch: 7:
- Training Loss: 0.016, Validation Loss: 1.082 [0/5]
Epoch: 8:
- Training Loss: 0.012, Validation Loss: 0.868 [0/5]
Epoch: 9:
- Training Loss: 0.008, Validation Loss: 0.758 [0/5]
Epoch: 10:
- Training Loss: 0.008, Validation Loss: 0.803 [1/5]
Epoch: 11:
- Training Loss: 0.008, Validation Loss: 0.857 [2/5]
Epoch: 12:
- Training Loss: 0.008, Validation Loss: 0.946 [3/5]
Epoch: 13:
- Training Loss: 0.008, Validation Loss: 0.626 [0/5]
Epoch: 14:
- Training Loss: 0.006, Validation Loss: 0.271 [0/5]
Epoch: 15:
- Training Loss: 0.005, Validation Los