In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from itertools import chain
from pathlib import Path
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import average_precision_score
from PIL import Image
import shutil

import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl



In [3]:
from voc_data import VOCDataset, PascalVOC
# from voc_transforms import ImageOnly, LabelsOnly

### Define Classifier

In [4]:
class VOCClassifier(pl.LightningModule):
    def __init__(self, batch_size=1, learning_rate=1e-4, shuffle=True, log_every_n_steps=10):
        super().__init__()
        self.log_every_n_steps = log_every_n_steps
        # Hyperparameters
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Model definition
        self.stem = torchvision.models.resnet50(pretrained=True, progress=True)
        self.stem.fc = torch.nn.Linear(2048, 20)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        if len(x.shape) == 5:
            bs, ncrops, c, h, w = x.size()
            x = self.stem(x.view(-1, c, h, w))
            x = x.view(bs, ncrops, -1).max(1)[0]
        elif len(x.shape) == 4:
            x = self.stem(x)
        else:
            raise ValueError(f"Expected input to have rank 4 or 5, got {x.shape} (rank {len(x.shape)}) instead")
        x = self.sigmoid(x)
        return x

    @staticmethod
    def _multi_label_loss(pred, labels):
        loss = 0
        for i in range(labels.shape[1]):
            loss += F.binary_cross_entropy(pred[:, i].double(), labels[:, i].double())
        return loss / labels.shape[1]

    def training_step(self, batch, batch_idx):
        image, labels, _ = batch
        pred = self.forward(image)
#         loss = F.binary_cross_entropy(pred, labels)
        loss = self._multi_label_loss(pred, labels)
        output = {"loss": loss}
        if self.log_every_n_steps and batch_idx % self.log_every_n_steps == 0:
            tensorboard_logs = {"train_loss": loss}
            output["log"] = tensorboard_logs
        return output

    def validation_step(self, batch, batch_idx):
        image, labels, _ = batch
        pred = self.forward(image)
#         loss = F.binary_cross_entropy(pred, labels)
        loss = self._multi_label_loss(pred, labels)
        correct = ((pred > 0.5) == labels).sum().float()
        count = labels.shape[0] * labels.shape[1]
        output = {
            "val_loss": loss,
            "correct": correct, "count": count,
            "pred": pred, "labels": labels
        }
        return output

    def validation_end(self, outputs):
        avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        correct = torch.stack([x["correct"] for x in outputs]).sum()
        accuracy = correct / sum([x["count"] for x in outputs])
        all_labels = torch.cat([x["labels"] for x in outputs]).cpu().detach().numpy()
        all_pred = torch.cat([x["pred"] for x in outputs]).cpu().detach().numpy()
        mean_ap = average_precision_score(all_labels, all_pred, None).mean()

        tensorboard_logs = {"val_loss": avg_val_loss, "val_acc": accuracy}
        if np.isfinite(mean_ap):
            tensorboard_logs["mean_ap"] = mean_ap
        return {"val_loss": avg_val_loss, "log": tensorboard_logs}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [5]:
ROOT_DIR = "./data/VOCdevkit/VOC2012/"
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
SHUFFLE = True

In [6]:
# VOC data helper
voc = PascalVOC(ROOT_DIR)

# Data transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()
base_transform = transforms.Compose([
    transforms.Resize(500),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack([normalize(to_tensor(crop)) for crop in crops])),
])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(25),
    transforms.ColorJitter(contrast=0.25, saturation=0.25),
    base_transform,
])

# Datasets and DataLoaders
train_dataset = VOCDataset(voc, split="train", transform=train_transform)
val_dataset = VOCDataset(voc, split="val", transform=base_transform)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=2)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=2)

In [7]:
model = VOCClassifier(batch_size=1, shuffle=True)
print(model)

VOCClassifier(
  (stem): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
       

In [8]:
pl_logger = pl.loggers.TestTubeLogger(save_dir="experiments/")
trainer = pl.Trainer(
    gpus=[0],
    logger=pl_logger,
    progress_bar_refresh_rate=10,
#     val_check_interval=0.1,
#     val_percent_check=0.25,
    overfit_pct=0.1,
)

In [9]:
trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



  recall = tps / tps[-1]


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…




1

## Save Model

In [11]:
model_path = Path("models/model_fivecrop_flip.pt")

In [13]:
if model_path.exists():
    raise ValueError("Refusing to overwrite existing model.")
torch.save(model.state_dict(), model_path)

ValueError: Refusing to overwrite existing model.

## Load Model

In [10]:
sd = torch.load(model_path)
loaded_model = VOCClassifier(batch_size=1, shuffle=True)
loaded_model.load_state_dict(sd)

<All keys matched successfully>

## Evaluate Model

In [None]:
def visualize_top_and_bottom_k(
    pred, paths, cat_names,
    k=50, num_classes=5, total_classes=20,
    output_dir="data/output",
):
    # Top and bottom K image visualization (on randomly chosen classes)
    top_k = 50
    output_path = Path(output_dir)
    class_indices = np.random.choice(list(range(total_classes)), size=num_classes, replace=False)
    print(f"Saving top and bottom {k} images for {num_classes} classes...")
    for class_index in class_indices:
        sorted_indices = np.argsort(pred[:, class_index])
        # Top K
        top = sorted_indices[-top_k:]
        for i, src in enumerate(paths[top]):
            dst = output_path / f"top/{cat_names[class_index]}/{top_k - i}.jpg"
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dst)
        # Bottom K
        bottom = sorted_indices[:top_k]
        for i, src in enumerate(paths[bottom]):
            dst = output_path / f"bottom/{cat_names[class_index]}/{i + 1}.jpg"
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dst)
    print("Done!")

def show_tail_accuracy(pred, labels, start, num_steps=20):
    end = pred.max(0)[0].min()
    steps = np.linspace(start, end, num=num_steps)
    
    tail_accuracies = []
    for step in steps:
        classwise_tailacc = np.reshape(
            (np.sum((pred > step) * labels, 0) / np.sum(pred > step, 0)),
            (1, -1)
        )
        tail_accuracies.append(classwise_tailacc)

def get_predictions(model, dataloader, cuda=True, subset=None):
    all_labels = []
    all_pred = []
    all_paths = []

    if subset:
        assert isinstance(subset, float) and 0 < subset < 1
    max_samples = int(np.ceil(len(dataloader) * subset)) if subset else len(dataloader)

    # Get predictions
    with torch.no_grad():
        for i, (image, labels, path) in tqdm(enumerate(dataloader), total=max_samples, ncols='100%'):
            if cuda:
                model.cuda()
                image = image.cuda()
                labels = labels.cuda()
            pred = model(image)
            all_labels.append(labels.cpu().numpy())
            all_pred.append(pred.cpu().numpy())
            all_paths.append(path)

            if subset and i >= max_samples:
                break

    labels = np.concatenate(all_labels, 0)
    pred = np.concatenate(all_pred, 0)
    paths = np.concatenate(all_paths)

    return pred, labels, paths

def evaluate_model(pred, labels, paths, voc):
    cat_names = voc.list_image_sets()
    
    # Binarize predictions
    pred_binary = pred > 0.5

    # Average precision
    ap = average_precision_score(labels, pred, None)
    mean_ap = ap.mean()

    # Accuracy
    correct = np.sum(pred_binary == labels)
    total = labels.shape[0] * labels.shape[1]
    accuracy = correct / total

    print("Accuracy:", accuracy)
    print("Average Precision:", ap)
    print("Mean Average Precision:", mean_ap)
    visualize_top_and_bottom_k(
        pred, paths, cat_names,
        k=50, num_classes=5, total_classes=20,
        output_dir="data/output",
    )
    show_tail_accuracy(pred, labels, 0)

#     return accuracy, ap, mean_ap, pred, labels, paths

In [21]:
evaluate_model(model, val_dataloader, voc, cuda=True, subset=0.1)

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=73.0), HTML(value='')), layout=Layout(dis…

Accuracy: 0.3295608108108108
Average Precision: [0.13573478 0.04033112 0.05543846 0.11831555 0.04656833 0.04056815
 0.13643402 0.10728754 0.14336018 0.02989256 0.04922072 0.20472271
 0.02698561 0.03587689 0.33603921 0.08437039 0.02425837 0.0473138
 0.02818547 0.07062769]
Mean Average Precision: 0.08807657882196558
Saving top and bottom 50 images for 5 classes...


In [None]:
model.cuda()
for image, labels in train_dataloader:
    pred = model(image.cuda()[:1])
    print("Ground Truth:", labels[:1])
    print("Predictions:", pred)
    break