## Imports

In [1]:
#export
import torch

## Get prediction functions and results

In [1]:
#export
def get_pred_class_function(god):
    """Build the callable that turns raw model output into predicted class labels.

    The strategy is read from ``god.config['model']['predict_class_function']``,
    a dict whose ``'name'`` key selects one of the options below; the other keys
    are strategy-specific parameters. They are looked up each time the returned
    callable runs, not when it is created.

    Supported names:
      * ``'argmax'``: ``pred.argmax(dim=cfg['dim'])``.
      * ``'binary_threshold'``: squeezes the last axis, then returns an int tensor
        that is 1 where the value is ``> cfg['pred_1_threshold']`` and 0 elsewhere.
      * ``'fish_hierarchical_preds'``: ``pred`` must be indexable as
        ``pred[0]`` / ``pred[1]``; returns
        ``argmax(pred[0], dim=1) * (argmax(pred[1], dim=1) + 1)``.
      * ``'argmax_threshold'``: ``argmax(dim=1) + 1`` for rows where any value is
        ``>= cfg['threshold']``, and 0 for rows where none is.

    Args:
        god: object exposing a ``config`` mapping laid out as described above.

    Returns:
        A callable ``pred -> tensor of class labels``.

    Raises:
        ValueError: if the configured name is not supported. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    pred_func_config = god.config['model']['predict_class_function']
    name = pred_func_config['name']

    if name == 'argmax':
        return lambda pred: pred.argmax(dim=pred_func_config['dim'])

    if name == 'binary_threshold':
        return lambda pred: (pred.squeeze(-1) > pred_func_config['pred_1_threshold']).int()

    if name == 'fish_hierarchical_preds':
        # torch.max(..., dim=1)[1] is the argmax index along dim 1.
        return lambda pred: torch.max(pred[0], dim=1)[1] * (torch.max(pred[1], dim=1)[1] + 1)

    if name == 'argmax_threshold':
        # Shift classes to start at 1 so that 0 can mean "no value reached the threshold".
        return lambda pred: (pred.argmax(dim=1) + 1) * ((pred >= pred_func_config['threshold']).any(dim=1).long())

    raise ValueError(f"predict_class_function with name '{pred_func_config['name']}' not supported.")

In [25]:
#export
def get_pred_dict(god, model, dl, include_x=False, pred_class_function=None, device='cuda'):
    """Run ``model`` over every batch of ``dl`` and collect per-sample results.

    Args:
        god: object exposing ``criterion`` (called as ``criterion(pred, target)``
            on one-sample batches) and ``pred_class_function`` (used when
            ``pred_class_function`` is not given).
        model: module evaluated in eval mode under ``torch.no_grad()``.
        dl: iterable of ``(xb, yb)`` tensor batches.
        include_x: if True, the inputs are returned under ``'x'``.
        pred_class_function: callable mapping raw model output to class labels;
            defaults to ``god.pred_class_function``.
        device: device the model and batches are moved to during evaluation.

    Returns:
        A dict of plain Python lists, one entry per sample in ``dl`` order:
        ``'idx'`` (0..n-1), ``'x'`` (``None`` unless ``include_x``), ``'y'``,
        ``'pred'``, ``'pred_class'`` and ``'loss'`` (the criterion applied to
        each sample on its own). An empty ``dl`` yields empty lists.

    Note:
        The model is left in eval mode and is moved to the CPU at the end,
        even if evaluation raises.
        NOTE(review): assumes ``model(xb)`` returns a single tensor. A tuple
        output fails at ``pred.cpu()``, and the 'fish_hierarchical_preds' class
        function expects one.
    """
    if pred_class_function is None:
        pred_class_function = god.pred_class_function

    def _cat_to_list(chunks):
        # torch.cat rejects an empty list, which happens when `dl` yields no batches.
        return torch.cat(chunks).tolist() if chunks else []

    model = model.to(device)
    model.eval()
    n = 0
    all_xb, all_yb, all_pred, all_pred_class, all_loss = [], [], [], [], []
    try:
        with torch.no_grad():
            for xb, yb in dl:
                xb, yb = xb.to(device), yb.to(device)

                n += yb.size(0)
                pred = model(xb)
                pred_class = pred_class_function(pred)
                # Apply the criterion to each sample as a batch of one to get a per-sample loss.
                loss = torch.tensor([god.criterion(p.unsqueeze(0), y.unsqueeze(0)) for p, y in zip(pred, yb)])

                if include_x:
                    all_xb.append(xb.cpu())
                all_yb.append(yb.cpu())
                all_pred.append(pred.cpu())
                all_pred_class.append(pred_class.cpu())
                all_loss.append(loss.cpu())
    finally:
        # Move the model back to the CPU even if evaluation fails part-way.
        model = model.to('cpu')

    return {
        'idx': list(range(n)),
        'x': torch.cat(all_xb).tolist() if all_xb else None,
        'y': _cat_to_list(all_yb),
        'pred': _cat_to_list(all_pred),
        'pred_class': _cat_to_list(all_pred_class),
        'loss': _cat_to_list(all_loss),
    }

# WORK IN PROGRESS

## Predictions

In [3]:
#OUTDATED AND NOT USED
def get_preds(config, model, dl, get_x=True):
    """(Outdated) Run ``model`` over ``dl`` and return one result dict per sample.

    NOTE(review): depends on ``device``, ``get_pred_class`` and ``get_criterion``,
    none of which are defined in the visible notebook. Confirm they exist before
    reviving this function.

    Each result has the keys ``'x'``, ``'y'``, ``'pred'``, ``'pred_class'`` and
    ``'loss'``. ``'x'`` holds the input tensor on the CPU, or, when ``get_x`` is
    False, the sample's position in ``dl`` order. That position counts across
    batches, so it identifies the sample uniquely. The other values are
    converted with ``.tolist()``.
    """
    model = model.to(device)
    model.eval()
    results = []
    sample_idx = 0  # running position over all batches, not just the current one
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)
        with torch.no_grad():
            pred = model(xb)
            pred_class = get_pred_class(config, pred)
            loss = get_criterion(config, reduction='none')(pred, yb)
            for i in range(len(xb)):
                results.append({
                    'x': xb[i].cpu() if get_x else sample_idx,
                    'y': yb[i].tolist(),
                    'pred': pred[i].tolist(),
                    'pred_class': pred_class[i].tolist(),
                    'loss': loss[i].tolist()
                })
                sample_idx += 1
    return results

In [4]:
def get_wrong_preds(config, model, dl):
    """Return the ``get_preds`` results whose predicted class differs from the target.

    Compares ``'pred_class'`` with ``'y'``. The raw ``'pred'`` entry holds model
    outputs rather than class labels, so comparing it with ``'y'`` would not
    detect misclassifications.
    """
    return [result for result in get_preds(config, model, dl) if result['y'] != result['pred_class']]

In [5]:
def get_highest_loss_preds(config, model, dl, max_n=0):
    """Return the ``get_preds`` results ordered from highest to lowest ``'loss'``.

    If ``max_n`` is positive, only the ``max_n`` highest-loss results are kept.
    Otherwise every result is returned.
    """
    ranked = sorted(
        get_preds(config, model, dl),
        key=lambda res: res['loss'],
        reverse=True,
    )
    return ranked[:max_n] if max_n > 0 else ranked

In [6]:
# WIP
def WIP_ensemble_prediction(config, model_ensemble, mode=True):
    """(Work in progress) Combine the class predictions of several models on ``datasets['val']``.

    NOTE(review): relies on ``DataLoader``, ``datasets``, ``device`` and
    ``get_preds`` from outside this cell. The first three are not defined in the
    visible notebook.

    Args:
        config: config dict; ``config['hp']['batch_size']`` sets the loader batch size.
        model_ensemble: iterable of models, each evaluated in turn.
        mode: if True, each model's predictions are thresholded at 0.5 and the
            per-sample ``torch.mode`` across models is taken. Otherwise the
            predictions are averaged across models and the mean is thresholded
            at 0.5. Presumably intended for 0/1 class predictions; confirm.

    Returns:
        ``(ensemble_pred, y)``: an int tensor of combined predictions and the
        concatenated targets from the (unshuffled) validation loader.
    """
    dl = DataLoader(datasets['val'], batch_size=config['hp']['batch_size'], shuffle=False, num_workers=12)
    all_preds = []
    for model in model_ensemble:
        model.to(device)
        preds_list = get_preds(config, model, dl)
        # get_preds stores 'pred_class' as plain Python values (via .tolist()),
        # so build a tensor from them; torch.stack only accepts tensors.
        all_preds.append(torch.tensor([p['pred_class'] for p in preds_list]).squeeze(-1))
        model.to('cpu')
    y = torch.cat([yb for _, yb in dl])

    stacked = torch.stack(all_preds)
    if mode: return (stacked > 0.5).int().mode(dim=0).values, y
    # Cast to float first: torch.mean is not defined for integer tensors.
    return (stacked.float().mean(dim=0) > 0.5).int(), y

## Prediction metrics

In [7]:
def _pred_records(pred_dict):
    """Normalize prediction results to a sequence of per-sample dicts with ``'y'`` and ``'pred_class'``.

    Accepts either a list of per-sample dicts (as produced by ``get_preds``) or a
    dict of parallel lists (as produced by ``get_pred_dict``).
    """
    if isinstance(pred_dict, dict):
        return [{'y': y, 'pred_class': pc} for y, pc in zip(pred_dict['y'], pred_dict['pred_class'])]
    return pred_dict

def pred_dict_accuracy(pred_dict):
    """Fraction of samples whose ``'pred_class'`` equals ``'y'`` (nan if there are no samples)."""
    return torch.tensor([p['pred_class'] == p['y'] for p in _pred_records(pred_dict)]).float().mean().item()

def pred_dict_precision(pred_dict, target_class=1):
    """Among samples predicted as ``target_class``, the fraction whose target is ``target_class``.

    Returns nan if no sample was predicted as ``target_class``.
    """
    return torch.tensor([p['y'] == target_class for p in _pred_records(pred_dict) if p['pred_class'] == target_class]).float().mean().item()

def pred_dict_recall(pred_dict, target_class=1):
    """Among samples whose target is ``target_class``, the fraction predicted as ``target_class``.

    Returns nan if no sample has ``target_class`` as its target.
    """
    return torch.tensor([p['pred_class'] == target_class for p in _pred_records(pred_dict) if p['y'] == target_class]).float().mean().item()

## Test+Export

In [8]:
!python _notebook2script.py 06_Prediction.ipynb

Converted 06_Prediction.ipynb to exp/Prediction.py
