In [1]:
# Reload imported modules automatically before each cell executes, so edits
# to local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from sklearn.metrics import average_precision_score

import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl



In [3]:
from voc_data import VOCDataset, PascalVOC
# from voc_transforms import ImageOnly, LabelsOnly

### Define Classifier

In [4]:
class VOCClassifier(pl.LightningModule):
    """Multi-label image classifier built on a pretrained ResNet-50.

    The network scores every crop of an image independently and keeps, for
    each of the 20 outputs, the maximum score over the crops. ``forward``
    returns per-output probabilities; the losses are computed from the raw
    logits for numerical stability.

    Args:
        batch_size: Stored as ``self.batch_size``; not read by any method
            defined in this class.
        learning_rate: Learning rate handed to the Adam optimizer.
        shuffle: Stored as ``self.shuffle``; not read by any method defined
            in this class.
    """

    def __init__(self, batch_size=1, learning_rate=1e-4, shuffle=True):
        super().__init__()
        # Hyperparameters
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.learning_rate = learning_rate

        # Model definition: swap the ImageNet head for a 20-way linear layer.
        self.stem = torchvision.models.resnet50(pretrained=True, progress=True)
        # Read the feature width from the existing head instead of hard-coding 2048.
        self.stem.fc = torch.nn.Linear(self.stem.fc.in_features, 20)
        self.sigmoid = torch.nn.Sigmoid()

    def _logits(self, x):
        """Return max-over-crops logits for a ``(bs, ncrops, c, h, w)`` batch."""
        bs, ncrops, c, h, w = x.size()
        # Fold the crop axis into the batch axis so the backbone sees 4-D input.
        crop_logits = self.stem(x.view(-1, c, h, w))
        # Restore the crop axis and keep the best score per output.
        return crop_logits.view(bs, ncrops, -1).max(1)[0]

    def forward(self, x):
        """Return probabilities of shape ``(bs, 20)`` for a 5-D crop batch.

        Args:
            x: Tensor of shape ``(bs, ncrops, c, h, w)``.
        """
        return self.sigmoid(self._logits(x))

    def training_step(self, batch, batch_idx):
        """Compute the binary cross-entropy loss for one training batch.

        Returns:
            Dict with the ``"loss"`` tensor and a ``"log"`` dict of metrics.
        """
        image, labels = batch
        # The logits-based loss fuses sigmoid and log, so saturated outputs
        # do not produce log(0).
        loss = F.binary_cross_entropy_with_logits(self._logits(image), labels)
        tensorboard_logs = {"loss": {"train": loss}}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Compute loss and element-wise accuracy counts for one validation batch.

        Returns:
            Dict with ``"val_loss"``, the number of ``"correct"`` label
            entries (thresholded at 0.5) and the total ``"count"`` of entries.
        """
        image, labels = batch
        logits = self._logits(image)
        loss = F.binary_cross_entropy_with_logits(logits, labels)

        # Cast the thresholded predictions to the label dtype so the equality
        # does not rely on implicit bool/float type promotion.
        predicted = (self.sigmoid(logits) > 0.5).to(labels.dtype)
        correct = (predicted == labels).sum().float()
        count = labels.numel()
        return {"val_loss": loss, "correct": correct, "count": count}

    def validation_end(self, outputs):
        """Aggregate the per-batch validation outputs into epoch metrics.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            Dict with the mean ``"val_loss"`` and a ``"log"`` dict holding
            ``"val_loss"`` and ``"val_acc"``.
        """
        # NOTE(review): later Lightning releases name this hook
        # ``validation_epoch_end`` - confirm against the installed version.
        avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        correct = torch.stack([x["correct"] for x in outputs]).sum()
        accuracy = correct / sum(x["count"] for x in outputs)
        tensorboard_logs = {"val_loss": avg_val_loss, "val_acc": accuracy}
        return {"val_loss": avg_val_loss, "log": tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

#     @staticmethod
#     def _collate_batch(batch):
#         """
#         Custom batch collation function.
#         """
#         images = torch.stack([e[0] for e in batch])
#         labels = torch.stack([e[1] for e in batch])
#         return images, labels

#     def train_dataloader(self):
#         return DataLoader(
#             self.train_dataset,
#             batch_size=self.batch_size,
#             shuffle=self.shuffle,
#             # collate_fn=self._collate_batch,
#         )

In [5]:
# Notebook-wide configuration: dataset location and training hyperparameters.
ROOT_DIR = "./data/VOCdevkit/VOC2012/"
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
SHUFFLE = True

# Seed the random number generators once, up front, so that weight
# initialisation and data shuffling are repeatable after
# Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
# VOC data helper
voc = PascalVOC(ROOT_DIR)

# Data transforms
# Channel statistics documented for torchvision's pretrained models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()


def _stack_normalized_crops(crops):
    """Convert each crop to a normalized tensor and stack them on a new first axis."""
    return torch.stack([normalize(to_tensor(crop)) for crop in crops])


# Shared by both splits: resize, take five 224x224 crops, and return them as
# one stacked tensor per image.
base_transform = transforms.Compose([
    transforms.Resize(400),
    transforms.FiveCrop(224),
    transforms.Lambda(_stack_normalized_crops),
])
# Training additionally applies random flip/rotation before the crops are taken.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(25),
    base_transform,
])

# Datasets and DataLoaders
train_dataset = VOCDataset(voc, split="train", transform=train_transform)
val_dataset = VOCDataset(voc, split="val", transform=base_transform)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=4)
# Validation is never shuffled: every validation pass must score the same
# samples in the same order so metrics are comparable between epochs
# (this matters most when only a fraction of the batches is evaluated).
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [7]:
# Build the classifier from the notebook-level configuration so the
# hyperparameters stored on the model agree with the ones used to build the
# DataLoaders and are changed in a single place.
model = VOCClassifier(batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, shuffle=SHUFFLE)
print(model)

VOCClassifier(
  (stem): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
       

In [8]:
# Experiment logger: run artifacts are written below "experiments/".
pl_logger = pl.loggers.TestTubeLogger(save_dir="experiments/")

# Trainer configuration, one option per line for easy tweaking.
trainer = pl.Trainer(
    gpus=[0],                     # train on GPU index 0
    logger=pl_logger,
    progress_bar_refresh_rate=5,
    overfit_pct=0.1,              # NOTE(review): presumably restricts runs to 10% of the data - confirm before a full run
)

In [None]:
# Fit the model using the training loader and evaluate it on the validation loader.
trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

  "Did not find hyperparameters at model.hparams. Saving checkpoint without"


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=72.0, style=Pro…

In [None]:
# Sanity check: compare the prediction for one training sample with its labels.
model.cuda()

# Remember the current mode so the check leaves the model as it found it.
was_training = model.training
# Evaluation mode makes BatchNorm use its running statistics instead of the
# statistics of this single sample's crops, and stops it from updating them.
model.eval()
try:
    # No autograd graph is needed just to look at predictions.
    with torch.no_grad():
        for image, labels in train_dataloader:
            pred = model(image.cuda()[:1])
            print("Ground Truth:", labels[:1])
            print("Predictions:", pred)
            break
finally:
    model.train(was_training)