In [1]:
from types import SimpleNamespace
import wandb
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader
from fastprogress import progress_bar
from torcheval.metrics import (
    Mean,
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
    BinaryF1Score,
)

import params
from utils import get_data, set_seed, ImageDataset, load_model, to_device, get_class_name_in_snake_case as snake_case

default_cfg = SimpleNamespace(
    img_size=256,
    bs=16,
    seed=42,
    epochs=2,
    lr=2e-3,
    wd=1e-5,
    arch="resnet18",
    log_model=True,
    log_preds=False,
    # these are params that are not being changed
    image_column="file_name",
    target_column="mold",
    PROJECT_NAME=params.PROJECT_NAME,
    ENTITY=params.ENTITY,
    PROCESSED_DATA_AT=params.DATA_AT,
)


In [2]:
cfg = default_cfg


set_seed(cfg.seed)

run = wandb.init(
    project=cfg.PROJECT_NAME,
    entity=cfg.ENTITY,
    job_type="evaluation",
    tags=["staging"],
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m ([33mwandb_course[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:

wandb.config.update(cfg)

df, processed_dataset_dir = get_data(cfg.PROCESSED_DATA_AT, eval=True)

test_data = df[df["test"] == True]
val_data = df[df["test"] == False]

test_transforms = val_transforms = [
    T.Resize(cfg.img_size),
    T.ToTensor(),
]

[34m[1mwandb[0m: Downloading large artifact lemon_data:v0, 137.77MB. 2692 files... 
[34m[1mwandb[0m:   2692 of 2692 files downloaded.  
Done. 0:0:0.4


In [4]:
val_dataset = ImageDataset(
    val_data,
    processed_dataset_dir,
    image_column=cfg.image_column,
    target_column=cfg.target_column,
    transform=val_transforms,
)

test_dataset = ImageDataset(
    test_data,
    processed_dataset_dir,
    image_column=cfg.image_column,
    target_column=cfg.target_column,
    transform=val_transforms,
)

test_dataloader = DataLoader(
    test_dataset, batch_size=cfg.bs, shuffle=False, num_workers=4
)
valid_dataloader = DataLoader(
    test_dataset, batch_size=cfg.bs, shuffle=False, num_workers=4
)

In [5]:
model_artifact_name = "rca-dev/model-registry/Lemon Detector:staging"
model = load_model(model_artifact_name)


[34m[1mwandb[0m: Downloading large artifact Lemon Mold Detector:candidate, 106.19MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.0


In [6]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def cross_entropy(x, y):
    "A flattened version of nn.BCEWithLogitsLoss"
    loss_func = nn.BCEWithLogitsLoss()
    return loss_func(x.squeeze(), y.squeeze().float())


In [8]:
metrics = [BinaryAccuracy(device=device),
            BinaryPrecision(device=device),
            BinaryRecall(device=device),
            BinaryF1Score(device=device),
            ]


In [9]:
outputs  = [torch.randn(32) for _ in range(10)]
labels = [torch.randint(0, 2, (32,)).bool() for _ in range(10)]

In [10]:
for out, lbl in zip(outputs, labels):
    for metric in metrics:
        metric.update(out, lbl)

In [11]:
for m in metrics:
    print(f"{snake_case(m)}: {m.compute()}")

binary_accuracy: 17.600000381469727
binary_precision: 0.5483871102333069
binary_recall: 0.3333333432674408
binary_f1_score: 0.41463416814804077


In [17]:
@torch.inference_mode()
def evaluate(loader):
    loss_mean = Mean(device=device)
    metrics = [BinaryAccuracy(device=device),
               BinaryPrecision(device=device),
               BinaryRecall(device=device),
               BinaryF1Score(device=device),
               ]

    for b in progress_bar(loader, leave=True, total=len(loader)):
        images, labels = to_device(b, device)
        outputs = model(images).squeeze()
        loss = cross_entropy(outputs, labels)
        loss_mean.update(loss)
        for metric in metrics:
            metric.update(outputs, labels.long())


    return loss, metrics

In [18]:

valid_loss, valid_metrics = evaluate(valid_dataloader)
test_loss, test_metrics   = evaluate(test_dataloader)

In [20]:
for m in valid_metrics:
    print(f"valid_{snake_case(m)}: {m.compute()}")
    wandb.summary[f"valid_{snake_case(m)}"] = m.compute()

for m in test_metrics:
    print("test_" + snake_case(m) + ": " + str(m.compute()))
    wandb.summary[f"test_{snake_case(m)}"] = m.compute()

valid_binary_accuracy: 0.9902439117431641
valid_binary_precision: 1.0
valid_binary_recall: 0.9512194991111755
valid_binary_f1_score: 0.9749999642372131
test_binary_accuracy: tensor(0.9902, device='cuda:0')
test_binary_precision: tensor(1., device='cuda:0')
test_binary_recall: tensor(0.9512, device='cuda:0')
test_binary_f1_score: tensor(0.9750, device='cuda:0')


In [21]:
run.finish()

VBox(children=(Label(value='0.002 MB of 0.020 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.112054…

0,1
test_binary_accuracy,0.99024
test_binary_f1_score,0.975
test_binary_precision,1.0
test_binary_recall,0.95122
valid_binary_accuracy,0.99024
valid_binary_f1_score,0.975
valid_binary_precision,1.0
valid_binary_recall,0.95122


TorchEval Bug

In [None]:
metrics = [BinaryAccuracy(device=device),
            BinaryPrecision(device=device),
            BinaryRecall(device=device),
            BinaryF1Score(device=device),
            ]


In [None]:
outputs  = [torch.randn(32) for _ in range(10)]
labels = [torch.randint(0, 2, (32,)).bool() for _ in range(10)]

In [None]:
for out, lbl in zip(outputs, labels):
    for metric in metrics:
        metric.update(out, lbl)

In [None]:
for m in metrics:
    print(f"{snake_case(m)}: {m.compute()}")

binary_accuracy: 17.600000381469727
binary_precision: 0.5483871102333069
binary_recall: 0.3333333432674408
binary_f1_score: 0.41463416814804077
