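# First-level model predictions analysis

Evaluates a trained `HPAResNet50` checkpoint (loaded from an MLflow run) on one validation fold. Predictions are averaged over several test-time-augmentation (TTA) passes. The notebook reports overall accuracy, precision, recall and F1, plus per-tag precision and recall.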

In [1]:
from pathlib import Path

import sys
# Prepend the sibling `common` directory to the import path.
# NOTE(review): presumably this is where the project packages imported later
# (`dataflow`, `models`, `custom_ignite`) live — confirm.
sys.path.insert(0, (Path(".").resolve().parent / "common").as_posix())

In [2]:
# Data directory next to the notebook's parent: holds `train.csv` and the
# `train` directory passed to HPADataset below.
INPUT_PATH = Path(".").resolve().parent / "input/"

In [3]:
import pandas as pd

import torch

from albumentations import Compose, RandomCrop, RandomCropNearBBox, ShiftScaleRotate, GaussNoise, ElasticTransform
from albumentations import CenterCrop, Rotate, RandomRotate90, Flip
from albumentations.pytorch import ToTensor


from dataflow.datasets import HPADataset
from dataflow.dataloaders import get_base_train_val_loaders_by_fold, get_train_val_indices, Subset, TransformedDataset, DataLoader
from models.resnet import HPAResNet50

In [28]:
# --- Evaluation configuration ----------------------------------------------
seed = 12  # random_state for the fold split below; also seeds torch's RNG
device = "cuda"
debug = False

val_fold_index = 0  # which of the `n_folds` folds is evaluated
n_folds = 3

n_tta = 4  # number of test-time-augmentation passes (one engine epoch each)

# Seed torch's global RNG so that the DataLoader derives the same worker seeds
# on every fresh run and the random TTA flips/rotations are repeatable.
# NOTE(review): assumes the albumentations transforms draw from an RNG that
# PyTorch re-seeds inside each worker (python `random`) — confirm for the
# installed albumentations/PyTorch versions.
torch.manual_seed(seed)

# Random flip + random 90-degree rotation, then a 320x320 center crop and
# tensor conversion. The random parts are re-drawn on every pass, which is
# what makes the `n_tta` passes differ from each other.
tta_transforms = Compose([
    Flip(),
    RandomRotate90(),
    CenterCrop(320, 320),
    ToTensor()
])
# `dp` is one dataset item: dp[0] is passed as the image and dp[1] as the
# tags (cast to float32).
tta_transform_fn = lambda dp: tta_transforms(**{"image": dp[0], "tags": dp[1].astype('float32')})


batch_size = 128
num_workers = 8

trainval_df = pd.read_csv(INPUT_PATH / "train.csv")
trainval_ds = HPADataset(trainval_df, INPUT_PATH / "train")
# Recompute the fold split with the same seed/n_splits and keep only the
# validation indices of fold `val_fold_index`.
# NOTE(review): presumably matches the split used at training time — confirm.
_, val_fold_indices = get_train_val_indices(trainval_df,
                                            fold_index=val_fold_index,
                                            n_splits=n_folds,
                                            random_state=seed)


val_ds = Subset(trainval_ds, val_fold_indices)
val_ds = TransformedDataset(val_ds, transform_fn=tta_transform_fn)


# shuffle=False keeps the batch order identical across TTA passes; the
# per-batch accumulation of TTA predictions depends on that.
val_loader = DataLoader(val_ds, shuffle=False,
                        batch_size=batch_size, num_workers=num_workers,
                        pin_memory="cuda" in device, drop_last=False)

model = HPAResNet50(num_classes=HPADataset.num_tags)

In [29]:
# MLflow run whose checkpoint is evaluated in this notebook.
run_uuid = "6bf2701872df4bd190a9c517a5e52f32"
# Human-readable label only; not referenced by the cells shown here.
run_name = "resnet50_val_acc_0.37"

# Checkpoint file name inside the run's artifact directory.
weights_filename = "model_HPAResNet50_162_val_loss=0.07056979.pth"

In [30]:
def weights_path(mlflow_client, run_uuid, weights_filename):
    """Return the local path of a checkpoint stored as an MLflow run artifact.

    Args:
        mlflow_client: MLflow tracking client. Its ``tracking_uri`` is
            interpreted as a local directory (presumably a file-based store).
        run_uuid: id of the run whose artifacts contain the checkpoint.
        weights_filename: checkpoint file name inside the artifact directory.

    Returns:
        The checkpoint path as a POSIX-style string.

    Raises:
        FileNotFoundError: if the resolved checkpoint file does not exist.
    """
    run_info = mlflow_client.get_run(run_id=run_uuid)
    artifact_uri = run_info.info.artifact_uri
    # Drop everything up to and including the first "/" of the artifact URI
    # and resolve the remainder relative to the tracking directory.
    # NOTE(review): this mapping depends on the artifact URI layout produced by
    # the installed MLflow version — verify if MLflow is upgraded.
    relative_artifacts = artifact_uri[artifact_uri.find("/") + 1:]
    path = Path(mlflow_client.tracking_uri) / relative_artifacts / weights_filename
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not path.exists():
        raise FileNotFoundError("File is not found at {}".format(path.as_posix()))
    return path.as_posix()

In [31]:
import mlflow

# Tracking store that holds the training runs and their artifacts.
client = mlflow.tracking.MlflowClient(tracking_uri="../output")
# Deserialize onto the CPU: the model is moved to `device` afterwards anyway,
# and map_location avoids requiring the exact GPU the checkpoint was saved on
# (and avoids a redundant GPU copy before load_state_dict copies the weights).
model.load_state_dict(torch.load(weights_path(client, run_uuid, weights_filename), map_location="cpu"))

In [32]:
# Move the weights to the target device and switch the model to evaluation
# mode (assigning to `_` only suppresses the module repr in the cell output).
model.to(device)
_ = model.eval()

In [33]:
from custom_ignite.metrics.accuracy import Accuracy
from custom_ignite.metrics.precision import Precision
from custom_ignite.metrics.recall import Recall
from ignite.metrics import MetricsLambda

In [34]:
# Multilabel metrics over all tags. With average=False, precision/recall
# presumably yield one value per tag (they are reduced with torch.mean later)
# — confirm against custom_ignite.
accuracy_metric = Accuracy(is_multilabel=True)
precision_metric = Precision(average=False, is_multilabel=True)
recall_metric = Recall(average=False, is_multilabel=True)

# F1 = 2PR / (P + R); the tiny epsilon avoids 0/0 where P = R = 0.
f1_metric = precision_metric * recall_metric * 2 / (recall_metric + precision_metric + 1e-20)
# Collapse the F1 tensor to a single float by its mean (macro-F1 if the
# entries are per tag).
f1_metric = MetricsLambda(lambda t: torch.mean(t).item(), f1_metric)

metrics = {
    "accuracy": accuracy_metric,
    "precision": precision_metric,
    "recall": recall_metric,
    "f1": f1_metric
}

# Add Precision/Recall per tag
from functools import partial


def thresholded_output_transform_per_tag(output, k):
    """Binarize a single tag's predictions for a per-tag binary metric.

    Args:
        output: ``(y_pred, y)`` pair; both are sliced as ``[:, k]``, so tags
            lie along dim 1. ``y_pred`` is treated as logits.
        k: column (tag) index to keep.

    Returns:
        ``(pred_k, y_k)``: ``pred_k = round(sigmoid(y_pred[:, k]))`` (an exact
        0.5 rounds to 0 under torch's round-half-to-even) and ``y_k = y[:, k]``.
    """
    logits, targets = output
    tag_probas = torch.sigmoid(logits[:, k])
    return tag_probas.round(), targets[:, k]


# Register one binary Precision and one binary Recall per tag, each looking
# only at that tag's column via the per-tag output transform.
# NOTE(review): the validator further down is fed predictions that were
# already rounded to 0/1, and the transform re-applies sigmoid+round. That
# maps 0 -> 0 (sigmoid(0) = 0.5 rounds to 0 under round-half-to-even) and
# 1 -> 1, so the result is unchanged, but only by that coincidence.
for i, t in enumerate(HPADataset.tags):
    metrics['pr_{}'.format(t)] = Precision(
        output_transform=partial(thresholded_output_transform_per_tag, k=i)
    )
    metrics['re_{}'.format(t)] = Recall(
        output_transform=partial(thresholded_output_transform_per_tag, k=i)
    )

In [35]:
from ignite.engine import create_supervised_evaluator, convert_tensor, Events, Engine
from ignite.contrib.handlers import ProgressBar

In [36]:
def prepare_batch(batch, device=None, non_blocking=False):
    """Split a batch dict into ``(images, tags)`` for the supervised evaluator.

    Only ``batch['image']`` goes through ``convert_tensor`` (i.e. is moved to
    ``device``, optionally non-blocking). ``batch['tags']`` is returned as-is,
    so the targets stay wherever the DataLoader produced them.
    """
    images = batch['image']
    tags = batch['tags']
    return convert_tensor(images, device=device, non_blocking=non_blocking), tags

In [38]:
# Evaluator that runs the model over the validation loader; `prepare_batch`
# moves only the images to `device`, so the targets stay as loaded.
predictor_tta = create_supervised_evaluator(model, device=device, non_blocking="cuda" in device, prepare_batch=prepare_batch)
ProgressBar(desc='Predict TTA', persist=True).attach(predictor_tta)


# One running-mean slot per validation batch: each of the `n_tta` epochs
# (one per TTA pass) adds probas / n_tta into the slot of the same batch.
y_probas_mean_tta = [0 for _ in range(len(val_loader))]
# Ground-truth tags per batch, collected during the first TTA pass only.
y_true = []


@predictor_tta.on(Events.STARTED)
def reset_tta_predictions(engine):
    """Clear the TTA accumulators at the start of every `run` call.

    Without this, calling `predictor_tta.run(...)` again without re-executing
    this cell keeps adding into the previous sums and appends a second copy
    of the targets to `y_true`.
    """
    y_probas_mean_tta[:] = [0 for _ in range(len(val_loader))]
    y_true.clear()


@predictor_tta.on(Events.ITERATION_COMPLETED)
def save_tta_predictions(engine):
    """Add one batch's sigmoid probabilities (scaled by 1/n_tta) to its slot.

    The engine's iteration counter is folded into a per-batch index with
    `% len(val_loader)`. This relies on `val_loader` using shuffle=False, so
    batch order is identical in every TTA pass.
    """
    output = engine.state.output
    batch_index = (engine.state.iteration - 1) % len(val_loader)
    y_probas = torch.sigmoid(output[0].detach())

    y_probas_mean_tta[batch_index] += y_probas / n_tta

    # Targets are the same in every pass, so keep them from the first epoch only.
    if engine.state.epoch == 1:
        y_true.append(output[1])

In [39]:
# One engine epoch per TTA pass over the (unshuffled) validation loader.
predictor_tta.run(val_loader, max_epochs=n_tta)

Predict TTA[81/81] 100%|██████████ [00:31<00:00]
Predict TTA[81/81] 100%|██████████ [00:31<00:00]
Predict TTA[81/81] 100%|██████████ [00:32<00:00]
Predict TTA[81/81] 100%|██████████ [00:32<00:00]


<ignite.engine.engine.State at 0x7f6889d4bba8>

In [40]:
# Binarize the TTA-averaged probabilities at 0.5 (via rounding) and move them to the CPU.
y_preds_mean_tta = [torch.round(batch_probas).cpu() for batch_probas in y_probas_mean_tta]

In [41]:
def validate(engine, batch):
    """Process function that replays a precomputed ``(y_pred, y)`` pair.

    ``engine`` is unused. ``batch`` must unpack into exactly two items; they
    are returned unchanged, as a tuple, so attached metrics receive them as
    the engine output.
    """
    predictions, targets = batch
    return (predictions, targets)


# Second-stage engine that just replays precomputed (y_pred, y) pairs, so the
# attached metrics are computed on the TTA-averaged predictions.
validator = Engine(validate)
ProgressBar(desc='Validation').attach(validator)

for name, metric in metrics.items():
    metric.attach(validator, name)


from collections import defaultdict

# Positive ground-truth count per tag, for the current validation epoch.
tags_counter = defaultdict(int)


@validator.on(Events.EPOCH_STARTED)
def reset_tags_counter(engine):
    """Zero the per-tag counts so repeated `validator.run` calls do not accumulate."""
    tags_counter.clear()


@validator.on(Events.ITERATION_COMPLETED)
def count_tags(engine):
    """Add each tag's column sum of the current batch targets to `tags_counter`."""
    _, y = engine.state.output
    # One column-wise reduction per batch instead of one reduction per tag.
    positives_per_tag = torch.sum(y, dim=0)
    for i, t in enumerate(HPADataset.tags):
        tags_counter[t] += positives_per_tag[i].item()

In [50]:
# Pair every batch of binarized TTA predictions with its ground-truth tags.
# zip() would silently drop trailing batches on a length mismatch (e.g. after
# re-running the TTA pass without re-creating `y_true`), so check explicitly.
if len(y_preds_mean_tta) != len(y_true):
    raise ValueError("Got {} prediction batches but {} target batches".format(
        len(y_preds_mean_tta), len(y_true)))
data = [(y_pred, y) for y_pred, y in zip(y_preds_mean_tta, y_true)]

validator.run(data, max_epochs=1)



<ignite.engine.engine.State at 0x7f688a0400b8>

In [51]:
# Headline scores. Precision/recall tensors are reduced with torch.mean;
# F1 was already reduced to a float by its MetricsLambda.
acc_score = validator.state.metrics['accuracy']
pr_score = torch.mean(validator.state.metrics['precision']).item()
re_score = torch.mean(validator.state.metrics['recall']).item()
f1_score = validator.state.metrics['f1']

In [52]:
# (accuracy, precision, recall, F1)
acc_score, pr_score, re_score, f1_score

(0.37253951323572193, 0.690301884353082, 0.586601053686286, 0.6110574725381293)

In [56]:
# Spot-check: precision of the first tag as a Python float.
validator.state.metrics['pr_{}'.format(HPADataset.tags[0])].item()

0.8385391638635271

In [53]:
# One line per tag: precision, recall, positive ground-truth count, tag name.
row_fmt = "Pr={:.4f} | Re={:.4f} | #={:4} - {}"
for t in HPADataset.tags:
    tag_precision = validator.state.metrics['pr_{}'.format(t)]
    tag_recall = validator.state.metrics['re_{}'.format(t)]
    print(row_fmt.format(tag_precision, tag_recall, int(tags_counter[t]), t))

Pr=0.8385 | Re=0.8126 | #=4295 - Actin filaments
Pr=0.8987 | Re=0.6794 | #= 418 - Aggresome
Pr=0.8815 | Re=0.6164 | #=1207 - Cell junctions
Pr=0.8000 | Re=0.4769 | #= 520 - Centrosome
Pr=0.8449 | Re=0.6065 | #= 620 - Cytokinetic bridge
Pr=0.8082 | Re=0.3520 | #= 838 - Cytoplasmic bodies
Pr=0.6993 | Re=0.2976 | #= 336 - Cytosol
Pr=0.8149 | Re=0.5473 | #= 941 - Endoplasmic reticulum
Pr=0.5000 | Re=0.0556 | #=  18 - Endosomes
Pr=0.0000 | Re=0.0000 | #=  15 - Focal adhesion sites
Pr=0.0000 | Re=0.0000 | #=  10 - Golgi apparatus
Pr=0.7434 | Re=0.3096 | #= 365 - Intermediate filaments
Pr=0.6250 | Re=0.1528 | #= 229 - Lipid droplets
Pr=0.6531 | Re=0.1788 | #= 179 - Lysosomes
Pr=0.8990 | Re=0.7753 | #= 356 - Microtubule ends
Pr=0.0000 | Re=0.0000 | #=   7 - Microtubule organizing center
Pr=0.0000 | Re=0.0000 | #= 177 - Microtubules
Pr=0.4286 | Re=0.0429 | #=  70 - Mitochondria
Pr=0.5294 | Re=0.0598 | #= 301 - Mitotic spindle
Pr=0.6667 | Re=0.1255 | #= 494 - Nuclear bodies
Pr=0.0000 | Re=0.0000

In [2]:
import numpy as np

# Random 0/1 prediction matrix (10 samples x 28 tags) used to try out the
# label-string formatting below. A fixed seed makes the demo reproducible.
demo_rng = np.random.RandomState(12)
y_preds = demo_rng.randint(0, 2, size=(10, 28))

In [11]:
# Space-separated indices of the tags predicted positive for the first sample.
" ".join(map(str, np.nonzero(y_preds[0] > 0)[0]))

'0 1 3 4 6 8 12 13 15 16 17 21 24 25 26 27'

In [6]:
# Inspect the random demo matrix (10 samples x 28 tags).
y_preds

array([[1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1,
        0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0,
        1, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1,
        1, 0, 1, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        1, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1,
        0, 1, 1, 0, 0, 1],
       [1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0],
       [1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 1, 0, 1],
       [1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
        0, 0, 1, 0, 1, 0],
       [1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0,
        0, 1, 1, 0, 1, 1],
       [0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0,
        1, 0, 1, 