In [None]:
import torch
import torch.nn as nn

class EarlyStopping:
    """Track a monitored score and flag when training should stop.

    Call the instance once per evaluation round with the latest score and
    the model.  The first score becomes the baseline; afterwards a score
    counts as an improvement only if it beats the best score by more than
    ``min_delta`` in the direction given by ``mode``.  After ``round``
    consecutive non-improving calls, ``early_stop`` is set to ``True``.

    Calling the instance never modifies the model.  When
    ``restore_best_weights`` is true a copy of the model's ``state_dict``
    is kept for the best score; the caller restores it explicitly with
    ``__load_best_weights__(model)``.

    Args:
        metric: Name of the monitored score; used only in verbose messages.
        round: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.  (The name shadows the builtin, but is
            kept because it is part of the keyword interface.)
        min_delta: Margin by which a score must beat the best one to count
            as an improvement.
        verbose: Print progress messages when true.
        restore_best_weights: Keep a snapshot of the best-scoring weights.
        mode: ``'max'`` if larger scores are better, ``'min'`` if smaller.

    Attributes:
        best_score: Best score seen so far, or ``None`` before the first call.
        count: Current number of consecutive non-improving calls.
        early_stop: ``True`` once ``count`` has reached ``round``.
        best_weights: Snapshot taken at ``best_score``, or ``None``.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, metric = 'val_auc', round = 10, min_delta = 1e-3,
                 verbose = False, restore_best_weights = True, mode = 'max'):
        # Validate before touching any state so a bad ``mode`` fails fast.
        if mode not in ('min', 'max'):
            raise ValueError("mode must be 'min' or 'max'")

        self.metric = metric
        self.round = round
        self.min_delta = min_delta
        self.verbose = verbose
        self.restore_best_weights = restore_best_weights
        self.mode = mode
        self.best_score = None
        self.count = 0
        self.early_stop = False
        self.best_weights = None

    def is_improvement(self, current, best):
        """Return whether ``current`` beats ``best`` by more than ``min_delta``.

        This is a regular method rather than a lambda stored on the instance,
        so instances can be pickled and copied clones read their own
        ``min_delta`` instead of the original object's.
        """
        if self.mode == 'min':
            return current < best - self.min_delta
        return current > best + self.min_delta

    def __call__(self, current_score, model):
        """Record ``current_score`` and update the early-stopping state.

        Args:
            current_score: Latest value of the monitored score.
            model: Object exposing ``state_dict()``; only read, and only
                when ``restore_best_weights`` is true and a new best score
                is recorded.
        """
        if self.best_score is None:
            # First observation: it is the baseline by definition.
            self.best_score = current_score
            if self.verbose:
                print(f'Initial {self.metric} is set as {self.best_score:.4f}')
            self._snapshot(model)
            return

        if self.is_improvement(current_score, self.best_score):
            if self.verbose:
                improvement = 'raise' if self.mode == 'max' else 'fall'
                print(f'{self.metric} {improvement} from {self.best_score:.4f} to {current_score:.4f}, early stopping count reset')
            self._snapshot(model)
            self.best_score = current_score
            self.count = 0
            return

        self.count += 1
        if self.verbose:
            print(f'Early stopping count {self.count}/{self.round}')
        if self.count >= self.round:
            if self.verbose:
                print(f'{self.metric} early stopping triggered after {self.count} rounds')
            self.early_stop = True

    def _snapshot(self, model):
        """Store a copy of ``model``'s weights if snapshots are enabled."""
        if self.restore_best_weights:
            self.best_weights = self.__get_weights__(model.state_dict())

    def __get_weights__(self, state_dict):
        """Return a detached copy of every entry in ``state_dict``.

        Cloning matters: the snapshot must not change when the model's own
        tensors are updated later.
        """
        return {k: v.clone().detach() for k, v in state_dict.items()}

    def __load_best_weights__(self, model):
        """Load the stored snapshot into ``model``; no-op if none was taken."""
        # Compare against None explicitly: an empty snapshot is still a
        # snapshot, whereas None means nothing was ever stored.
        if self.best_weights is not None:
            model.load_state_dict(self.best_weights)