In [9]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from pathlib import Path
import shutil

import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger

import src.utils as utils
from src.data.dataloader import mtlDataModule
from src.models.lightning_frame import mtlMayhemModule
from src.data.manifests import generate_manifest

import torchvision.transforms as T
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from PIL import Image

from src.utils import load_yaml
import torch

# Repository roots on the two machines this notebook has been run on
# (HPC cluster and a local laptop). Only the HPC root is used below.
HOME_DIR_HPC = "/zhome/3b/d/154066/repos/multitask-mayhem"
HOME_DIR_X1C7 = "/home/bbejczy/repos/multitask-mayhem"

# Work from the repo root so relative paths (e.g. the "models/..." checkpoint
# loaded further down) resolve correctly.
os.chdir(path=HOME_DIR_HPC)

In [10]:
# Load the experiment config and (re)generate the data manifest for its collections.
# NOTE(review): this is the Faster R-CNN config, yet config["detection_classes"] is later
# used to size the SSDLite classification head — confirm the class count matches
# ssdlite_batchsize.yaml.
config = utils.load_yaml(f"{HOME_DIR_HPC}/configs/fasterrcnn_od_hpc.yaml")
manifest = generate_manifest(
    collections=config["collections"],
    data_root=config["data_root"],
)

  0%|          | 0/8 [00:00<?, ?it/s]

Cleaning mask directory


 12%|█▎        | 1/8 [00:09<01:04,  9.23s/it]

Cleaning mask directory


 25%|██▌       | 2/8 [00:18<00:56,  9.48s/it]

Cleaning mask directory


 38%|███▊      | 3/8 [00:22<00:34,  6.91s/it]

Cleaning mask directory


 50%|█████     | 4/8 [00:25<00:20,  5.17s/it]

Cleaning mask directory


 62%|██████▎   | 5/8 [00:25<00:10,  3.52s/it]

Cleaning mask directory


 75%|███████▌  | 6/8 [00:26<00:05,  2.70s/it]

Cleaning mask directory


 88%|████████▊ | 7/8 [00:28<00:02,  2.17s/it]

Cleaning mask directory


100%|██████████| 8/8 [00:28<00:00,  3.61s/it]


## Sanity check bounding boxes

Load the datamodule and build the validation and test dataloaders

In [11]:
# Datamodule driven by the SSDLite config; prepare once, then set up each stage
# right before requesting its dataloader.
data_module = mtlDataModule(config_path=f"{HOME_DIR_HPC}/configs/ssdlite_batchsize.yaml")
data_module.prepare_data()

# Validation split (used for the ground-truth sanity check).
data_module.setup(stage="validate")
valid_dataloader = data_module.val_dataloader()

# Test split (used for the prediction sanity check).
data_module.setup(stage="test")
test_dataloader = data_module.test_dataloader()

100%|██████████| 8/8 [00:17<00:00,  2.15s/it]


Prepare a clean landing folder for the rendered images and load the label-name lookup

In [16]:
# Output folder for the rendered sanity-check images.
sanity_check_folder = HOME_DIR_HPC+"/notebooks/sanity_check"

# Wipe images from previous runs. Guarded so the very first run, when the folder
# does not exist yet, does not fail with FileNotFoundError.
if os.path.isdir(sanity_check_folder):
    shutil.rmtree(sanity_check_folder)

os.makedirs(sanity_check_folder, exist_ok=True)

# Label-id -> class-name lookup; its "bbox_rev" table is used when drawing box labels.
class_lookup = load_yaml(HOME_DIR_HPC+"/configs/class_lookup.yaml")

Loop through the validation images and draw the ground-truth boxes and masks (use `image_pil.show()` instead of saving when not on HPC)

In [17]:
# Optional: render the ground-truth boxes and masks of the validation set into
# `sanity_check_folder`. Disabled by default; set to True to run it.
# Note: files are named "{i}.png", the same scheme as the prediction cell below,
# so running both overwrites these images.
DRAW_GROUND_TRUTH = False

if DRAW_GROUND_TRUTH:
    for i, batch in enumerate(valid_dataloader):
        image, targets = batch

        # Only the first image of each batch is drawn.
        boxes = targets[0]["boxes"]
        labels = targets[0]["labels"]
        # NOTE(review): draw_segmentation_masks expects boolean masks — confirm the
        # dataset's "masks" dtype (cast with .bool() if needed).
        masks = targets[0]["masks"]
        label_names = [class_lookup["bbox_rev"][label.item()] for label in labels]

        img = image[0].mul(255).type(torch.uint8)
        drawn_image = draw_bounding_boxes(img, boxes, label_names)
        drawn_image = draw_segmentation_masks(drawn_image, masks, alpha=0.5, colors="green")
        image_pil = T.ToPILImage()(drawn_image)
        image_pil.save(sanity_check_folder+"/{}.png".format(i))
        # image_pil.show()  # use instead of save when not on HPC

Build the SSDLite model and restore its trained checkpoint (the landing folder was already cleaned up above)

In [18]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn, ssdlite320_mobilenet_v3_large
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead, SSDLiteHead
import torchvision.models.detection._utils as det_utils
from functools import partial
import torch.nn as nn


# Alternative (currently unused): Faster R-CNN with a freshly sized box predictor.
# model = fasterrcnn_resnet50_fpn(
#     pretrained=True, weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
# )
# in_features = model.roi_heads.box_predictor.cls_score.in_features
# model.roi_heads.box_predictor = FastRCNNPredictor(
#     in_features, 5
# )

# Checkpoint to evaluate; relative to the repo root (the working directory set at the top).
CHECKPOINT_PATH = "models/ssdlite_batch8_22-11-30T214955/weights/best.pth"

# Rebuild the SSDLite architecture, swap in a classification head sized for our
# classes, then overwrite all parameters with the trained checkpoint.
# `pretrained=True` was dropped: it is deprecated in torchvision and redundant with
# `weights=`. NOTE(review): keep weights="DEFAULT" — the backbone variant torchvision
# builds may depend on it, and a different variant would not match the checkpoint.
model = ssdlite320_mobilenet_v3_large(
    weights="DEFAULT",
)
in_features = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)  # NOTE: taken from a blog post
# NOTE(review): `config` was loaded from the Faster R-CNN yaml — confirm its
# detection_classes matches what this SSDLite checkpoint was trained with.
model.head.classification_head = SSDLiteClassificationHead(
    in_channels=in_features,
    num_classes=config["detection_classes"],
    num_anchors=num_anchors,
    norm_layer=norm_layer,
)

# Load on CPU; the returned key-match report is the cell's displayed output.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu')))

<All keys matched successfully>

In [19]:
# Run the model over the test set and save each batch's first image with its
# confident detections drawn on, as "{batch_index}.png" in sanity_check_folder.
SCORE_THRESHOLD = 0.7  # detections scoring at or below this are not drawn
MAX_BATCHES = None  # set to an int to stop early; a full pass on CPU is slow

model.eval()
# Inference only: disabling autograd avoids building a graph per batch (time + memory).
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_dataloader, desc="predicting")):
        if MAX_BATCHES is not None and i >= MAX_BATCHES:
            break
        image, _ = batch  # ground-truth targets are not needed here

        predictions = model(image)

        # Only the first image of each batch is drawn.
        # NOTE(review): with batch size > 1 the other images are predicted but never saved.
        prediction = predictions[0]
        keep = prediction["scores"] > SCORE_THRESHOLD
        boxes_filtered = prediction["boxes"][keep]
        labels_filtered = prediction["labels"][keep]

        label_names = [class_lookup["bbox_rev"][label.item()] for label in labels_filtered]

        img = image[0].mul(255).type(torch.uint8)
        drawn_image = draw_bounding_boxes(img, boxes_filtered, label_names)
        image_pil = T.ToPILImage()(drawn_image)
        image_pil.save(sanity_check_folder+"/{}.png".format(i))

KeyboardInterrupt: 