In [None]:
#default_exp utils.engine

In [None]:
#export
import os
import time
import timeit
import torch
import numpy as np

from abc import ABC
from fastprogress.fastprogress import master_bar, progress_bar

In [None]:
#export
class Fitter(ABC):
    """Common skeleton for fitters: names the hooks a concrete fitter provides.

    The class derives from ``ABC`` but none of its methods is marked abstract,
    so it can be instantiated as-is; every hook below is a no-op that returns
    ``None`` until a subclass overrides it.
    """

    def __init__(self):
        """Nothing to set up at this level."""
        return None

    def fit(self):
        """Hook for the full training run; does nothing here."""
        return None

    def log(self):
        """Hook for recording progress messages; does nothing here."""
        return None

    def train(self):
        """Hook for a training step; does nothing here."""
        return None

    def validate(self):
        """Hook for a validation step; does nothing here."""
        return None

In [None]:
#export
class BertFitter(Fitter):
    """Train/validate loop for a model fed dict-style batches, with file logging.

    Each batch is expected to be a mapping that contains a ``'targets'`` entry;
    every other value is passed positionally to the model, in mapping order.
    Targets are reduced with ``argmax(dim=-1)`` before being handed to the loss
    and the metric (presumably they are one-hot encoded -- confirm with the
    dataset code).
    """

    def __init__(self, model, dataloaders, optimizer, loss_func, metrics, device, log_file='training_log.txt',scheduler=None, trial=None):
        """Store the training components and start a fresh log file.

        Args:
            model: module to train; moved to ``device`` when ``fit`` runs.
            dataloaders: indexable pair ``(train_dl, valid_dl)``.
            optimizer: optimizer stepped once per training batch.
            loss_func: callable ``loss_func(model_output, target_indices)``.
            metrics: callable ``metrics(target_indices, predicted_indices)``.
            device: device the model and batches are moved to.
            log_file: file name created under ``../outputs``; an existing file
                with the same name is deleted.
            scheduler: optional LR scheduler, stepped after every optimizer step.
            trial: optional Optuna trial used for reporting and pruning.
        """
        self.model = model
        self.train_dl, self.valid_dl = dataloaders[0], dataloaders[1]
        self.optimizer = optimizer
        self.scheduler = scheduler

        output_dir = os.path.join('..', 'outputs')
        os.makedirs(output_dir, exist_ok=True)
        self.log_file = os.path.join(output_dir, f'{log_file}')
        # Every fitter starts with an empty log.
        if os.path.exists(self.log_file):
            os.remove(self.log_file)

        self.loss_func = loss_func
        self.metrics = metrics
        self.device = device
        self.trial = trial #for optuna

    def fit(self, epochs, return_metric=False, monitor='epoch train_loss valid_loss metric time', model_path=os.path.join('..', 'weights', 'model.pth'), show_graph=True):
        """Run ``epochs`` rounds of training followed by validation.

        Args:
            epochs: number of epochs; epochs are numbered from 1.
            return_metric: when true, return the best validation metric seen.
            monitor: whitespace-separated column names for the progress table.
            model_path: where to save the state dict whenever the validation
                metric improves; ``None`` disables saving.
            show_graph: when true, redraw the loss graph after every epoch.

        Returns:
            The best validation metric if ``return_metric`` is true, else ``None``.

        Raises:
            optuna.exceptions.TrialPruned: if a trial was supplied and it asks
                for pruning after an epoch.
        """
        self.model_path = model_path
        self.log(f'{time.ctime()}')
        self.log(f'Using device: {self.device}')
        mb = master_bar(range(1, epochs+1))
        mb.write(monitor.split(),table=True)

        model = self.model.to(self.device)
        optimizer = self.optimizer
        best_metric = -np.inf
        train_loss_list, valid_loss_list = [], []

        for epoch in mb:
            epoch_start = timeit.default_timer()
            start = time.time()
            # Accumulators are per-epoch: without this reset, each epoch's
            # average would include the previous epoch's average.
            train_loss, valid_loss, valid_metric = 0, 0, 0

            self.log('-'*50)
            self.log(f'Running Epoch #{epoch} {"🔥"*epoch}')
            self.log(f'{"-"*50} \n')
            self.log('TRAINING...')
            for ind, batch in enumerate(progress_bar(self.train_dl, parent=mb)):
                train_loss += self.train(batch, model, optimizer, self.device, self.scheduler)
                running_loss = train_loss / (ind + 1)  # mean over batches seen so far
                if ind % 500 == 0:
                    self.log(f'Batch: {ind}, Train loss: {running_loss}')
                mb.child.comment = f'{running_loss:.3f}'
            # max(..., 1) keeps an empty loader from dividing by zero.
            train_loss /= max(mb.child.total, 1)
            train_loss_list.append(train_loss) #for graph
            self.log(f'Training time: {round(time.time()-start, 2)} secs \n')

            self.log('EVALUATING...')
            with torch.no_grad():
                for ind, batch in enumerate(progress_bar(self.valid_dl, parent=mb)):
                    valid_loss_, valid_metric_ = self.validate(batch, model, self.device)
                    valid_loss += valid_loss_
                    valid_metric += valid_metric_
                    running_loss = valid_loss / (ind + 1)
                    if ind % 500 == 0:
                        self.log(f'Batch: {ind}, Valid loss: {running_loss}')
                    mb.child.comment = f'{running_loss:.3f}'

                valid_loss /= max(mb.child.total, 1)
                valid_metric /= max(mb.child.total, 1)
                valid_loss_list.append(valid_loss) #for graph

            if valid_metric > best_metric:
                if self.model_path is not None:
                    # Create the directory the checkpoint actually goes to,
                    # not a fixed '../weights'.
                    weights_dir = os.path.dirname(self.model_path)
                    if weights_dir:
                        os.makedirs(weights_dir, exist_ok=True)
                    self.log(f'Saving model weights at {self.model_path}')
                    torch.save(model.state_dict(), self.model_path)
                best_metric = valid_metric

            if self.trial is not None:
                self.trial.report(best_metric, epoch)

                # Handle pruning based on the intermediate value.
                if self.trial.should_prune():
                    # Imported here because the module does not import optuna
                    # at top level; only needed when a trial is in use.
                    import optuna
                    raise optuna.exceptions.TrialPruned()

            if show_graph:
                self.plot_loss_update(epoch, epochs, mb, train_loss_list, valid_loss_list) # for graph

            # ret_time covers the whole epoch (training + evaluation).
            total_time = timeit.default_timer() - epoch_start
            mins, secs = divmod(total_time, 60)
            hours, mins = divmod(mins, 60)
            ret_time = f'{int(hours)}:{int(mins)}:{int(secs)}'
            mb.write([epoch,f'{train_loss:.6f}',f'{valid_loss:.6f}',f'{valid_metric:.6f}', f'{ret_time}'],table=True)
            self.log(f'Evaluation time: {ret_time}\n')

        if return_metric: return best_metric

    @staticmethod
    def _split_batch(xy, device):
        """Return ``(inputs, targets)`` on ``device`` without mutating ``xy``.

        ``inputs`` keeps the mapping order of every entry except ``'targets'``.
        Leaving ``xy`` intact lets the same batch dict be reused across epochs.
        """
        inputs = [value.to(device) for key, value in xy.items() if key != 'targets']
        targets = xy['targets'].to(device)
        return inputs, targets

    def train(self, xy, model, opt, device, sched=None):
        """Run one optimisation step on a batch and return its loss as a float.

        Args:
            xy: batch mapping with a ``'targets'`` entry plus the model inputs.
            model: module being trained; switched to train mode.
            opt: optimizer to zero and step.
            device: device the batch is moved to.
            sched: optional scheduler, stepped right after the optimizer.
        """
        model.train()
        inputs, targets = self._split_batch(xy, device)
        opt.zero_grad()
        out = model(*inputs)
        loss = self.loss_func(out, targets.argmax(dim=-1))
        loss.backward()
        opt.step()
        if sched is not None:
            sched.step()
        return loss.item()

    def validate(self, xy, model, device):
        """Evaluate one batch and return ``(loss, metric)``.

        The model is switched to eval mode. Gradient tracking is not disabled
        here; ``fit`` wraps its calls in ``torch.no_grad()``.
        """
        model.eval()
        inputs, targets = self._split_batch(xy, device)
        out = model(*inputs)
        loss = self.loss_func(out, targets.argmax(dim=-1))

        # Metric is called as (targets, predictions) on CPU class indices
        # (sklearn-style argument order, per the original author's note).
        metrics = self.metrics(targets.cpu().argmax(dim=-1), out.cpu().argmax(dim=-1))
        return loss.item(), metrics

    def log(self, message, verbose=False):
        """Append ``message`` plus a newline to the log file; echo it if ``verbose``."""
        if verbose: print(message)
        # Explicit UTF-8: messages contain non-ASCII characters (the epoch
        # banner), which fail to encode under some locale defaults.
        with open(self.log_file, 'a', encoding='utf-8') as logger_:
            logger_.write(f'{message}\n')

    @staticmethod
    def plot_loss_update(epoch, epochs, mb, train_loss, valid_loss):
        """ dynamically print the loss plot during the training/validation loop.
            expects epoch to start from 1.

            train_loss and valid_loss hold one value per finished epoch; the
            x axis spans all ``epochs`` so the plot does not rescale as
            training progresses.
        """
        x = range(1, epoch+1)
        y = np.concatenate((train_loss, valid_loss))
        graphs = [[x,train_loss], [x,valid_loss]]
        x_margin = 0.2
        y_margin = 0.05
        x_bounds = [1-x_margin, epochs+x_margin]
        y_bounds = [np.min(y)-y_margin, np.max(y)+y_margin]

        mb.update_graph(np.array(graphs), np.array(x_bounds), np.array(y_bounds))

In [None]:
#export
def get_preds(test_ds, test_dl, model,device=None, ensemble_proba=False):
    """Run ``model`` over ``test_dl`` and return the stacked predictions.

    Args:
        test_ds: dataset behind ``test_dl``; kept for interface compatibility,
            not read here.
        test_dl: iterable of batch mappings; each is unpacked into the model
            as keyword arguments.
        model: module to run; it is moved to ``device`` and put in eval mode.
        device: target device; defaults to CUDA when available, else CPU.
        ensemble_proba: when true, return softmax probabilities instead of
            argmax class indices.

    Returns:
        A numpy array with all batches concatenated along the first axis, or
        an empty array when ``test_dl`` yields nothing.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Put model and data on the same device; the original resolved `device`
    # but never used it.
    model = model.to(device)
    # Inference mode: disables dropout / uses running batch-norm statistics.
    model.eval()

    test_preds = []
    with torch.no_grad():
        for batch in progress_bar(test_dl):
            # Values without a `.to` method are passed through unchanged.
            batch = {key: (value.to(device) if hasattr(value, 'to') else value)
                     for key, value in batch.items()}
            out = model(**batch).softmax(dim=-1)
            if not ensemble_proba:
                out = out.argmax(dim=-1)
            test_preds.append(out.cpu().numpy())

    if not test_preds:
        # np.concatenate raises on an empty list.
        return np.array([])
    return np.concatenate(test_preds)