## Imports

In [1]:
#export
from tqdm.notebook import tqdm
import torch

## Training

In [12]:
#export
def get_device(god):
    """Pick the torch device to compute on.

    Returns the CUDA device when ``torch.cuda.is_available()`` reports one;
    otherwise prints a warning and returns the CPU device.

    ``god`` is accepted for signature parity with the other functions in
    this file but is not read here.
    """
    if not torch.cuda.is_available():
        print('Warning: No cuda device found.')
        return torch.device('cpu')
    return torch.device('cuda')

In [22]:
#export
def one_episode(god):
    """Run a full training episode: every configured epoch, in order.

    Calls ``god.before_episode()``, then ``one_epoch(god)`` once per epoch
    for ``god.config['training']['n_epochs']`` epochs, then
    ``god.after_episode()``. The epoch loop is wrapped in a tqdm progress
    bar when ``god.config['training']['output']['progress_bar']['epoch']``
    is truthy.
    """
    god.before_episode()

    n_epochs = god.config['training']['n_epochs']
    show_bar = god.config['training']['output']['progress_bar']['epoch']

    # Epochs are counted from 1; the counter itself is not passed on.
    epochs = range(1, n_epochs + 1)
    if show_bar:
        epochs = tqdm(epochs)

    for _ in epochs:
        one_epoch(god)

    god.after_episode()

In [1]:
#export
def one_epoch(god):
    """Run one epoch: a train phase, then (on some epochs) a val phase.

    The train phase always runs with the model in train mode. The val
    phase runs with the model in eval mode, and only when
    ``god.state['epoch_nr']`` is a multiple of
    ``god.config['training']['val_frequency']`` (1 when the key is absent).

    Each phase that runs is bracketed by ``god.before_phase(phase)`` and
    ``god.after_phase()``; the whole epoch by ``god.before_epoch()`` and
    ``god.after_epoch()``. Batches are shown with a tqdm progress bar when
    ``god.config['training']['output']['progress_bar']['batch']`` is truthy.
    """
    god.before_epoch()

    for phase in ('train', 'val'):
        if phase == 'train':
            god.model.train()
        else:
            val_every = god.config['training'].get('val_frequency', 1)
            if god.state['epoch_nr'] % val_every != 0:
                # Not a validation epoch: skip the phase and its hooks.
                continue
            god.model.eval()

        god.before_phase(phase)

        show_bar = god.config['training']['output']['progress_bar']['batch']
        batches = god.dataloaders[phase]
        if show_bar:
            batches = tqdm(batches)

        for xb, yb in batches:
            one_batch(god, xb, yb)

        god.after_phase()

    god.after_epoch()

In [3]:
# Scratch check: the type of a tuple literal is `tuple`.
type((1,2))

tuple

In [4]:
#export
def one_batch(god, xb, yb):
    """Process one batch: forward pass, loss, and (when training) an update.

    Steps, in order:

    1. ``god.before_batch(xb, yb)`` is called with the inputs as received.
    2. ``xb`` and ``yb`` are moved to ``god.device``.
    3. The model and criterion are evaluated. Gradients are enabled only
       when ``god.state['phase'] == 'train'``.
    4. In the train phase the optimizer is zeroed, the loss is
       back-propagated, the optimizer steps, and ``god.scheduler`` (if
       truthy) steps as well, i.e. once per batch.
    5. ``god.after_batch(pred, loss)`` receives detached CPU copies.

    Args:
        god: Training context exposing ``device``, ``state``, ``model``,
            ``criterion``, ``optimizer``, ``scheduler`` and the
            ``before_batch`` / ``after_batch`` hooks.
        xb: Input batch; must support ``.to(device)``.
        yb: Target batch; must support ``.to(device)``.

    A model may return a single tensor or a sequence of tensors. For a
    sequence (tuple, tuple subclass such as a namedtuple, or list),
    ``after_batch`` gets a list with each element detached and moved to
    CPU.
    """
    god.before_batch(xb, yb)

    is_training = god.state['phase'] == 'train'

    xb, yb = xb.to(god.device), yb.to(god.device)
    with torch.set_grad_enabled(is_training):
        pred = god.model(xb)
        loss = god.criterion(pred, yb)
        if is_training:
            god.optimizer.zero_grad()
            loss.backward()
            god.optimizer.step()
            if god.scheduler:
                god.scheduler.step()

    # TODO: workaround for multi-output models; a prediction class with its
    # own detach method would be cleaner than special-casing sequences.
    # isinstance (not an exact type comparison) so that namedtuples and
    # lists are handled instead of failing on a missing `.detach`.
    if isinstance(pred, (tuple, list)):
        pred_out = [p.detach().cpu() for p in pred]
    else:
        pred_out = pred.detach().cpu()

    god.after_batch(pred_out, loss.detach().cpu())

In [16]:
#OUTDATED AND NOT USED
def train_model(config, model, criterion, optimizer, dataloaders, n_epochs=25, scheduler=None, silent=False, advanced_metrics=False):
    """Legacy all-in-one train/validate loop (marked outdated and unused).

    Iterates epochs ``0..n_epochs`` inclusive. Epoch 0 skips the train
    phase and only evaluates, so the first validation entry is measured
    before any update. For every phase it accumulates the sample-weighted
    loss and the accuracy, optionally macro precision/recall/F1 over
    labels ``1..len(config['id2label'])-1``, and keeps a deep copy of the
    weights from the epoch with the best validation accuracy.

    Args:
        config: Project config; ``config['id2label']`` is read when
            ``advanced_metrics`` is set, and the whole object is passed to
            ``get_pred_class`` and ``save_results``.
        model: Model to train; switched between train/eval mode per phase.
        criterion: Loss callable invoked as ``criterion(pred, yb)``.
        optimizer: Optimizer; zeroed before every batch.
        dataloaders: Mapping with ``'train'`` and ``'val'`` iterables
            yielding ``(xb, yb)`` pairs.
        n_epochs: Number of training epochs (epoch 0 is evaluation only).
        scheduler: Optional scheduler, stepped once per training batch.
        silent: When true, suppresses the progress bar and all prints.
        advanced_metrics: When true, also computes precision/recall/F1.

    Returns:
        ``(model, train_acc, val_acc, train_precision, train_recall,
        train_f1, val_precision, val_recall, val_f1)``; the metric lists
        are empty when ``advanced_metrics`` is false.

    NOTE(review): ``device``, ``get_pred_class``, ``precision_score``,
    ``recall_score``, ``f1_score`` and ``save_results`` are not defined in
    the visible part of this file; they must be provided elsewhere before
    this function can run. TODO confirm or retire this function.
    """
    # Function-scope import: `copy` is not imported at the top of this file.
    import copy

    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch, best_acc, best_train_acc = 1, 0., 0.
    train_loss, train_acc, val_loss, val_acc = [], [], [], []
    train_precision, train_recall, val_precision, val_recall = [], [], [], []
    train_f1, val_f1 = [], []

    epochs = range(0, n_epochs + 1)
    for epoch in (epochs if silent else tqdm(epochs)):
        if not silent:
            print('-' * 10)
            print(f'Epoch {epoch}/{n_epochs}')

        for phase in ('train', 'val'):
            # Branch on the loop variable: this function has no `god`
            # object, and nothing here would keep a shared phase state in
            # sync with the loop anyway.
            is_training = phase == 'train'
            if is_training:
                if epoch == 0:
                    continue  # epoch 0 is a validation-only baseline
                model.train()
            else:
                model.eval()

            n = 0
            running_loss = 0.0
            running_corrects = 0

            # Per-sample labels collected for precision/recall/F1.
            true_labels = []
            pred_labels = []

            for xb, yb in dataloaders[phase]:
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(is_training):
                    pred = model(xb)
                    loss = criterion(pred, yb)
                    if is_training:
                        loss.backward()
                        optimizer.step()
                        if scheduler:
                            scheduler.step()

                n += xb.size(0)
                # Weight the batch loss by the batch size so the epoch
                # value is a per-sample mean.
                running_loss += loss.item() * xb.size(0)
                running_corrects += (get_pred_class(config, pred) == yb).int().sum().item()

                true_labels.extend(yb.cpu().numpy())
                pred_labels.extend(get_pred_class(config, pred).cpu().numpy())

            epoch_loss = running_loss / n
            epoch_acc = running_corrects / n

            if advanced_metrics:
                # Label 0 is left out of the macro average.
                labels_to_consider = list(range(1, len(config['id2label'])))
                epoch_precision = precision_score(true_labels, pred_labels, labels=labels_to_consider, average='macro')
                epoch_recall = recall_score(true_labels, pred_labels, labels=labels_to_consider, average='macro')
                epoch_f1 = f1_score(true_labels, pred_labels, labels=labels_to_consider, average='macro')

            if not silent:
                if advanced_metrics:
                    print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} Precision: {epoch_precision:.4f} Recall: {epoch_recall:.4f} F1-score: {epoch_f1:.4f}')
                else:
                    print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if is_training:
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
                if advanced_metrics:
                    train_precision.append(epoch_precision)
                    train_recall.append(epoch_recall)
                    train_f1.append(epoch_f1)
                best_train_acc = max(epoch_acc, best_train_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)
                if advanced_metrics:
                    val_precision.append(epoch_precision)
                    val_recall.append(epoch_recall)
                    val_f1.append(epoch_f1)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    best_epoch = epoch

        if not silent:
            print()

    save_results(config, model, train_loss, train_acc, val_loss, val_acc, best_model_wts, best_acc, best_train_acc, best_epoch, n_epochs)
    return model, train_acc, val_acc, train_precision, train_recall, train_f1, val_precision, val_recall, val_f1

## Training util

In [17]:
#WIP
def old_save_results(config, model, train_loss, train_acc, val_loss, val_acc, best_model_wts, best_acc, best_train_acc, best_epoch, n_epochs):
    """Append one training run's summary to the JSON history file (WIP).

    The history file lives at ``os.path.join(base_path,
    config['history_json'])`` and holds ``{"history": [...]}``. It is
    created with an empty list when missing, then the new entry is
    appended and the file rewritten in place.

    Args:
        config: Project config; ``config['history_json']`` and
            ``config['hp']`` are read.
        model: Trained model; only ``model.name`` is recorded.
        train_loss, train_acc, val_loss, val_acc: Per-epoch metric lists,
            stored under ``'metrics'``.
        best_model_wts: Accepted but not used; the best weights are not
            loaded back into ``model`` here.
        best_acc: Best validation accuracy.
        best_train_acc: Best training accuracy.
        best_epoch: Epoch at which ``best_acc`` was reached.
        n_epochs: Total epoch count, recorded as ``"best_epoch/n_epochs"``.

    NOTE(review): ``base_path`` is expected to be a module-level name
    defined elsewhere; it is not visible in this part of the file.
    """
    # Function-scope imports: the file's import cell only brings in tqdm
    # and torch, so these names would otherwise be undefined.
    import json
    import os
    from datetime import datetime

    history_json = os.path.join(base_path, config['history_json'])

    # Seed the file so the read-modify-write below always finds valid JSON.
    if not os.path.isfile(history_json):
        with open(history_json, 'w') as f:
            json.dump({'history': []}, f)

    metrics = {'train_loss': train_loss, 'train_acc': train_acc, 'val_loss': val_loss, 'val_acc': val_acc}
    result = {
        'name': model.name,
        'date': datetime.now().strftime("%Y-%m-%dT%H:%M:%S"),
        'best_val_acc': best_acc,
        'best_val_acc_epoch': f'{best_epoch}/{n_epochs}',
        'best_train_acc': best_train_acc,
        'metrics': metrics,
        'hp': str(config['hp']),
    }

    with open(history_json, 'r+') as f:
        root = json.load(f)
        root['history'].append(result)
        # Rewind and overwrite; truncate drops leftover bytes in case the
        # new document is shorter than the old one.
        f.seek(0)
        json.dump(root, f)
        f.truncate()

In [18]:
#WIP
def save_model(config, model, name):
    """Save a model, its parameters and its config under ``weights/<name>/`` (WIP).

    Files written into the new directory:

    * ``<name>.pth``: the whole model object via ``torch.save``.
    * ``params_<name>.pth``: only ``model.state_dict()``.
    * ``config_<name>.pth``: ``config`` via ``torch.save``.
    * ``config_raw_<name>.txt``: ``str(config)`` as plain text.

    Args:
        config: Configuration object to store alongside the weights.
        model: Model to save; must provide ``state_dict()``.
        name: Run name, used for the directory and the file names.

    Raises:
        FileExistsError: If ``weights/<name>`` already exists. This is kept
            on purpose so an earlier run's weights are never overwritten
            silently.
    """
    # Function-scope import: `os` is not imported at the top of this file.
    import os

    output_dir = os.path.join('weights', name)
    os.makedirs(output_dir)

    torch.save(model, os.path.join(output_dir, f'{name}.pth'))
    # The "params_" file holds the state dict, not a second copy of the
    # full model object.
    torch.save(model.state_dict(), os.path.join(output_dir, f'params_{name}.pth'))
    torch.save(config, os.path.join(output_dir, f'config_{name}.pth'))
    with open(os.path.join(output_dir, f'config_raw_{name}.txt'), 'w') as f:
        f.write(str(config))

## Test+Export

In [19]:
!python _notebook2script.py 05_Training.ipynb

Converted 05_Training.ipynb to exp/Training.py
