In [1]:
import numpy as np
import tensorflow as tf

class PerformanceScheduling:
    """Naive performance-based learning-rate scheduler.

    Each call compares the latest validation error with the best one seen so
    far. On an improvement, the new best is recorded and the learning rate is
    passed through unchanged. On anything else, including a tie, the learning
    rate is divided by ``factor`` immediately; there is no patience window.

    Args:
        factor: Divisor applied to the learning rate on a non-improving step.
    """

    def __init__(self, factor=10):
        self.factor = factor
        # Start at +inf so any finite first validation error counts as an improvement.
        self.best_val = np.inf

    def get_lr(self, lr, val_error):
        """Return the learning rate to use after observing ``val_error``.

        Args:
            lr: Current learning rate.
            val_error: Validation error for the epoch just finished.

        Returns:
            ``lr`` unchanged if ``val_error`` is a new best, else ``lr / factor``.
        """
        improved = val_error < self.best_val
        if not improved:
            return lr / self.factor
        self.best_val = val_error
        return lr

In [2]:
val_errors = [0.5, 0.4, 0.42, 0.41, 0.39, 0.4]
scheduler = PerformanceScheduling(factor=2)
lr = 0.1

# Feed one validation error per epoch and log the learning rate chosen for it.
for epoch in range(len(val_errors)):
    val_error = val_errors[epoch]
    lr = scheduler.get_lr(lr, val_error)
    print(f"Epoch {epoch}: val_error={val_error}, lr={lr:.4f}")

Epoch 0: val_error=0.5, lr=0.1000
Epoch 1: val_error=0.4, lr=0.1000
Epoch 2: val_error=0.42, lr=0.0500
Epoch 3: val_error=0.41, lr=0.0250
Epoch 4: val_error=0.39, lr=0.0250
Epoch 5: val_error=0.4, lr=0.0125


In [3]:
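# Refinement: instead of reducing the learning rate on every non-improving epoch, wait for `patience` consecutive epochs without improvement first.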

In [4]:
class PerformanceScheduling:
    """Reduce-on-plateau learning-rate scheduler with a patience window.

    The learning rate is divided by ``factor`` only after ``patience``
    consecutive calls in which the validation error failed to improve on the
    best value seen so far. The patience counter resets on every improvement
    and after every reduction.

    Args:
        patience: Number of consecutive non-improving epochs after which the
            learning rate is reduced.
        factor: Divisor applied to the learning rate on reduction; must be > 0.
        min_delta: Minimum decrease of the validation error (relative to the
            best so far) that counts as an improvement. The default ``0.0``
            accepts any strict decrease. Must be >= 0.
        min_lr: Optional lower bound for a reduced learning rate. ``None``
            (default) applies no floor.

    Raises:
        ValueError: If ``factor`` is not positive or ``min_delta`` is negative.
    """

    def __init__(self, patience=10, factor=10, min_delta=0.0, min_lr=None):
        # Reject divisors that would crash (0) or flip the sign of the rate (< 0).
        if factor <= 0:
            raise ValueError(f"factor must be > 0, got {factor}")
        if min_delta < 0:
            raise ValueError(f"min_delta must be >= 0, got {min_delta}")
        self.patience = patience
        self.factor = factor
        self.min_delta = min_delta
        self.min_lr = min_lr
        # +inf so any finite first validation error counts as an improvement.
        self.best_val = np.inf
        # Consecutive non-improving calls since the last improvement or reduction.
        self.wait = 0

    def get_lr(self, lr, val_error):
        """Return the learning rate to use after observing ``val_error``.

        Args:
            lr: Current learning rate.
            val_error: Validation error for the epoch just finished.

        Returns:
            ``lr`` unchanged unless the patience window has just been
            exhausted, in which case ``lr / factor`` (clamped to ``min_lr``
            when a floor is set).
        """
        if val_error < self.best_val - self.min_delta:
            self.best_val = val_error
            self.wait = 0
            return lr

        self.wait += 1
        if self.wait < self.patience:
            return lr

        # Patience exhausted: reduce, then restart the count so the next
        # reduction needs another full window of non-improving epochs.
        self.wait = 0
        new_lr = lr / self.factor
        if self.min_lr is not None:
            new_lr = max(new_lr, self.min_lr)
        return new_lr

In [5]:
lr = 0.1
scheduler = PerformanceScheduling(patience=3, factor=2)

val_losses = [0.5, 0.4, 0.42, 0.45, 0.46, 0.43, 0.41, 0.39, 0.4, 0.41]

# With patience=3 and factor=2, the rate halves only after three
# non-improving epochs in a row.
for epoch, val_loss in zip(range(len(val_losses)), val_losses):
    lr = scheduler.get_lr(lr, val_loss)
    print(f"Epoch {epoch}: val_loss={val_loss:.2f}, lr={lr:.4f}")

Epoch 0: val_loss=0.50, lr=0.1000
Epoch 1: val_loss=0.40, lr=0.1000
Epoch 2: val_loss=0.42, lr=0.1000
Epoch 3: val_loss=0.45, lr=0.1000
Epoch 4: val_loss=0.46, lr=0.0500
Epoch 5: val_loss=0.43, lr=0.0500
Epoch 6: val_loss=0.41, lr=0.0500
Epoch 7: val_loss=0.39, lr=0.0500
Epoch 8: val_loss=0.40, lr=0.0500
Epoch 9: val_loss=0.41, lr=0.0500
