# In [4]:
import torch
from torch import nn


class EarlyStopper:
    """Signal when training should stop because a validation metric stalled.

    Every time the tracked metric reaches a new best value, the model's
    ``state_dict`` is saved to ``best_model.pth`` (relative to the current
    working directory) and the patience counter is reset to zero.  The
    counter is incremented only when the metric is worse than the best value
    by more than ``min_delta``; a value that is not a new best but lies
    within ``min_delta`` of it leaves the counter unchanged.

    Args:
        metric_name: Name of the tracked metric.  It is read from the
            metrics dictionary under the key ``'val_' + metric_name``.
        patience: Counter value at which ``early_stop`` starts returning
            True.
        min_delta: How much worse than the best value the metric must be
            before an epoch is counted against the patience budget.
        minimize: True when a lower metric is better, False when a higher
            metric is better.
    """

    def __init__(self, metric_name, patience=1, min_delta=0, minimize=True):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        # Best values seen so far; only the one matching `minimize` is used.
        self.min_val_metric = None
        self.max_val_metric = None
        self.name = metric_name
        self.minimize = minimize

    def _record_best(self, model):
        """Reset the patience counter and checkpoint the current weights."""
        self.counter = 0
        torch.save(model.state_dict(), "best_model.pth")

    def _count_degradation(self):
        """Count one degraded epoch; return True once patience is used up."""
        self.counter += 1
        return self.counter >= self.patience

    def min_criteria(self, val_metric, model):
        """Update the state for a lower-is-better metric.

        Returns:
            True if training should stop, False otherwise.
        """
        if self.min_val_metric is None or val_metric < self.min_val_metric:
            self.min_val_metric = val_metric
            self._record_best(model)
        elif val_metric > (self.min_val_metric + self.min_delta):
            return self._count_degradation()
        return False

    def max_criteria(self, val_metric, model):
        """Update the state for a higher-is-better metric.

        Returns:
            True if training should stop, False otherwise.
        """
        if self.max_val_metric is None or val_metric > self.max_val_metric:
            self.max_val_metric = val_metric
            self._record_best(model)
        elif val_metric < (self.max_val_metric - self.min_delta):
            return self._count_degradation()
        return False

    def early_stop(self, val_metrics, model):
        """Feed the latest validation metrics and report whether to stop.

        Args:
            val_metrics: Mapping that contains the key
                ``'val_' + metric_name``.
            model: Model whose ``state_dict`` is saved on a new best value.

        Returns:
            True if training should stop, False otherwise.

        Raises:
            KeyError: If the tracked metric is missing from ``val_metrics``.
        """
        val_metric = val_metrics['val_' + self.name]
        if self.minimize:
            return self.min_criteria(val_metric, model)
        return self.max_criteria(val_metric, model)

def train_one_epoch(model, loss, optimizer, ds, l2_regularization=None):
    """Run one optimisation pass over ``ds`` and return the mean training loss.

    Batches are moved to the device that holds the model's parameters, so the
    function works on CPU as well as on any GPU the model lives on.

    Args:
        model: PyTorch model to train; it is switched to training mode.
        loss: Callable ``loss(output, target)`` returning a scalar tensor.
        optimizer: PyTorch optimizer updating ``model``'s parameters.
        ds: Iterable yielding ``(x, y)`` batches of tensors.
        l2_regularization: Optional object whose ``calculate(model)`` result
            is added to every batch loss before the backward pass.

    Returns:
        The per-batch losses (including the regularisation term, if any)
        weighted by batch size ``x.shape[0]`` and divided by the total
        number of samples seen.

    Raises:
        ZeroDivisionError: If ``ds`` yields no samples.
    """
    model.train()

    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # A parameter-free model gives no hint; prefer CUDA when it exists.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loss = 0  # running sum of batch losses weighted by batch size
    ds_len = 0  # number of samples processed so far

    for x, y in ds:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        output = model(x)

        batch_loss = loss(output, y)
        if l2_regularization:
            batch_loss = batch_loss + l2_regularization.calculate(model)

        batch_loss.backward()
        optimizer.step()

        batch_size = x.shape[0]
        train_loss += batch_loss.item() * batch_size
        ds_len += batch_size

    return train_loss / ds_len


def test(model, ds):
    """Run ``model`` over every batch of ``ds`` and return its outputs.

    Inputs are moved to the device that holds the model's parameters, so the
    function works on CPU as well as on any GPU the model lives on.  The
    labels yielded by ``ds`` are ignored.

    Args:
        model: PyTorch model; it is switched to evaluation mode.
        ds: Iterable yielding ``(x, y)`` batches of tensors.

    Returns:
        A CPU tensor holding the per-batch outputs concatenated along the
        first dimension, in iteration order.
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # A parameter-free model gives no hint; prefer CUDA when it exists.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    final_output = []
    # Gradients are not needed for inference.
    with torch.no_grad():
        for x, _ in ds:
            output = model(x.to(device))
            final_output.append(output.detach().cpu())

    return torch.cat(final_output)



def validate(model, val_metrics, ds):
    """Evaluate ``model`` on ``ds`` and compute every metric in ``val_metrics``.

    Predictions and labels are collected in a single pass over ``ds``.  This
    keeps them aligned even when ``ds`` yields batches in a different order
    on each iteration, and it also allows one-shot iterables.  Inputs are
    moved to the device that holds the model's parameters; predictions are
    brought back to the CPU, labels are kept as yielded by ``ds``.

    Args:
        model: PyTorch model; it is switched to evaluation mode.
        val_metrics: Mapping ``name -> metric`` where each metric is called
            as ``metric(predictions, labels)``.
        ds: Iterable yielding ``(x, y)`` batches of tensors.

    Returns:
        Dictionary mapping ``'val_' + name`` to the value returned by the
        corresponding metric.
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # A parameter-free model gives no hint; prefer CUDA when it exists.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    predictions = []
    labels = []
    # Gradients are not needed for inference.
    with torch.no_grad():
        for x, y in ds:
            predictions.append(model(x.to(device)).detach().cpu())
            labels.append(y)

    model_pred = torch.cat(predictions)
    y_test = torch.cat(labels)

    return {
        'val_' + name: metric(model_pred, y_test)
        for name, metric in val_metrics.items()
    }


def train(model, loss, val_metrics, optimizer, train_ds, dev_ds, num_epochs=10,
          early_stopper=None, l2_regularization=None):
    """Train ``model`` for up to ``num_epochs`` epochs, validating after each.

    Args:
        model: PyTorch model.
        loss: PyTorch loss.
        val_metrics: Dictionary of validation metrics keyed by name.
        optimizer: PyTorch optimizer.
        train_ds: Dataset (iterable of batches) used for training.
        dev_ds: Dataset (iterable of batches) evaluated after every epoch.
        num_epochs: Maximum number of full passes over ``train_ds``.
        early_stopper: Optional object whose
            ``early_stop(val_metric_out, model)`` is called after every
            epoch; a truthy result ends training.
        l2_regularization: Optional regulariser forwarded to
            ``train_one_epoch``.

    Returns:
        Dictionary with a ``"train_loss"`` list and one ``'val_' + name``
        list per validation metric, each holding one entry per completed
        epoch.
    """
    history = {"train_loss": []}
    history.update(('val_' + metric_name, []) for metric_name in val_metrics)

    for epoch_number in range(1, num_epochs + 1):
        print('=========')

        epoch_loss = train_one_epoch(
            model=model,
            loss=loss,
            optimizer=optimizer,
            ds=train_ds,
            l2_regularization=l2_regularization,
        )
        epoch_metrics = validate(model=model, val_metrics=val_metrics, ds=dev_ds)

        # Record this epoch's numbers.
        history["train_loss"].append(epoch_loss)
        for key, value in epoch_metrics.items():
            history[key].append(value)

        # One progress line per epoch.
        metrics_text = " ".join(
            "{}: {:.4f}".format(key, value) for key, value in epoch_metrics.items()
        )
        print("epoch {} train loss: {:.4f} ".format(epoch_number, epoch_loss)
              + metrics_text)

        if early_stopper is None:
            continue
        if early_stopper.early_stop(epoch_metrics, model):
            print("EARLY STOPPING ")
            # Reload the checkpoint file before handing the model back.
            model.load_state_dict(torch.load("best_model.pth"))
            break

    return history


class L2Regularization:
    """Weight penalty based on the norms of a model's parameters.

    The penalty is ``l2_lambda`` times the sum of ``torch.norm(param)`` over
    all parameters, i.e. the sum of the (unsquared) 2-norms of the parameter
    tensors.

    Args:
        l2_lambda: Multiplier applied to the summed norms.
    """

    def __init__(self, l2_lambda=0.01):
        self.l2_lambda = l2_lambda

    def calculate(self, model):
        """Return the penalty for ``model`` as a scalar tensor.

        The result is accumulated on the device of the model's parameters
        rather than on a fixed CUDA device, so it works on CPU as well.

        Args:
            model: Object exposing ``parameters()`` (e.g. ``nn.Module``).

        Returns:
            Zero-dimensional tensor equal to ``l2_lambda`` times the summed
            parameter norms; a CPU zero tensor if there are no parameters.
        """
        l2_reg = None
        for param in model.parameters():
            param_norm = torch.norm(param)
            if l2_reg is None:
                l2_reg = param_norm
            else:
                # Keep the running sum on the first parameter's device.
                l2_reg = l2_reg + param_norm.to(l2_reg.device)

        if l2_reg is None:
            l2_reg = torch.tensor(0.)

        return l2_reg * self.l2_lambda

