# Visualization of results

Load modules and constants

In [1]:
import shutil
import os
import src.utils as utils
from src.data.dataloader import mtlDataModule

from src.data.manifests import generate_manifest

import torchvision.transforms as T
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from PIL import Image

from src.utils import load_yaml
import torch

# Repository root. Defaults to the original HPC checkout; set the MULTITASK_MAYHEM_ROOT
# environment variable to run the notebook from a different checkout.
ROOT_DIR = os.environ.get("MULTITASK_MAYHEM_ROOT", "/zhome/3b/d/154066/repos/multitask-mayhem")

# Relative paths used later in the notebook (e.g. the model checkpoint) resolve from the repo root.
os.chdir(ROOT_DIR)

# Class id <-> name lookup; its "bbox_rev" entry is used to name drawn bounding boxes.
CLASS_LOOKUP = load_yaml(os.path.join(ROOT_DIR, "configs", "class_lookup.yaml"))

  from .autonotebook import tqdm as notebook_tqdm


Load the config, initialize the `pl.LightningDataModule`, and create the train/validation/test dataloaders

In [2]:
# Build the data module from the debug config.
config_file = f"{ROOT_DIR}/configs/debug_foo.yaml"
data_module = mtlDataModule(config_path=config_file)

# Decide which annotation type the configured model uses (at most one flag is set).
DETECTION_MODELS = ("fasterrcnn", "fasterrcnn_mobilenetv3", "ssdlite")
SEGMENTATION_MODELS = ("deeplabv3",)
model_name = data_module.config["model"]
det = model_name in DETECTION_MODELS
seg = not det and model_name in SEGMENTATION_MODELS

# One image per batch, loaded in-process and in a fixed order.
# NOTE(review): these overrides are applied after construction; confirm mtlDataModule
# reads them when the dataloaders are built rather than in __init__.
for key, value in {"batch_size": 1, "num_workers": 0, "shuffle": False}.items():
    data_module.config[key] = value
print(torch.cuda.is_available())

# Prepare the data once, then set up each stage before requesting its dataloader.
data_module.prepare_data()

data_module.setup(stage="fit")
train_dataloader = data_module.train_dataloader()

data_module.setup(stage="validate")
valid_dataloader = data_module.val_dataloader()

data_module.setup(stage="test")
test_dataloader = data_module.test_dataloader()

100%|██████████| 1/1 [00:01<00:00,  1.35s/it]

False





Prepare the output folders (any existing contents are deleted)

In [3]:
# Output folders for the ground-truth sanity check and for the test-set predictions.
sanity_check = f"{ROOT_DIR}/notebooks/sanity_check"
test_inference = f"{ROOT_DIR}/notebooks/test_inference"

# Wipe both folders first, then recreate them empty, so no images from earlier runs remain.
output_folders = (sanity_check, test_inference)
for folder in output_folders:
    if os.path.exists(folder):
        shutil.rmtree(folder)
for folder in output_folders:
    os.makedirs(folder, exist_ok=True)

## Sanity check labels

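Loop through the training images, draw their ground-truth labels, and save the results to `sanity_check/` (call `image_pil.show()` instead when running locally rather than on the HPC)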

In [5]:
# Draw the ground-truth annotations on every training image and save the result
# to the sanity_check folder as <index>.png.

# Without this guard `drawn_image` would be undefined (or stale from a previous run).
if not (det or seg):
    raise ValueError(
        f"Unsupported model {data_module.config['model']!r}: expected a detection or segmentation model"
    )

to_pil = T.ToPILImage()
for i, (image, targets) in enumerate(train_dataloader):
    # batch_size is set to 1 in the data-module cell, so only the first sample is drawn.
    # NOTE(review): assumes images are floats in [0, 1]; normalized inputs would be clipped wrongly.
    img = image[0].mul(255).type(torch.uint8)

    # NOTE(review): draw_segmentation_masks requires boolean masks; confirm the dataset yields bool tensors.
    if det:
        target = targets[0]
        label_names = [CLASS_LOOKUP["bbox_rev"][label.item()] for label in target["labels"]]
        drawn_image = draw_bounding_boxes(img, target["boxes"], label_names)
        drawn_image = draw_segmentation_masks(drawn_image, target["masks"], alpha=0.5, colors="green")
    else:  # segmentation model
        drawn_image = draw_segmentation_masks(img, targets[0], alpha=0.5, colors="green")

    image_pil = to_pil(drawn_image)
    image_pil.save(os.path.join(sanity_check, f"{i}.png"))
    # image_pil.show()  # uncomment when running locally (not on the HPC)

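## Test inference

Rebuild the SSDLite model, load the trained checkpoint on CPU, then run it on the test split and save the detections to `test_inference/`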

In [18]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn, ssdlite320_mobilenet_v3_large
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead, SSDLiteHead
import torchvision.models.detection._utils as det_utils
from functools import partial
import torch.nn as nn

# Rebuild the SSDLite architecture used in training and load the trained weights onto the CPU.
# For a Faster R-CNN checkpoint, build fasterrcnn_resnet50_fpn instead and replace
# model.roi_heads.box_predictor with FastRCNNPredictor(in_features, num_classes).

# Resolved relative to the working directory (ROOT_DIR after the os.chdir in the first cell).
CHECKPOINT_PATH = "models/ssdlite_batch8_22-11-30T214955/weights/best.pth"

# `pretrained=` is deprecated in favour of `weights=` (torchvision >= 0.13). The pretrained
# parameters are overwritten by the checkpoint loaded below.
model = ssdlite320_mobilenet_v3_large(weights="DEFAULT")

# Swap the classification head so its output size matches the project's detection classes.
in_features = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)  # same BatchNorm settings as torchvision's SSDLite builder
model.head.classification_head = SSDLiteClassificationHead(
    in_channels=in_features,
    # `config` is not defined in this notebook; the settings live on the data module.
    num_classes=data_module.config["detection_classes"],
    num_anchors=num_anchors,
    norm_layer=norm_layer,
)

model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu")))

<All keys matched successfully>

In [19]:
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

# Run the trained detector on the test split and save each image, with its confident
# detections drawn, to the test_inference folder as <index>.png.
SCORE_THRESHOLD = 0.7  # keep only detections scoring above this confidence

model.eval()
to_pil = T.ToPILImage()
with torch.no_grad():  # inference only: skip autograd bookkeeping to save time and memory
    for i, (image, _) in enumerate(test_dataloader):
        # batch_size is 1, so only the first prediction is used.
        prediction = model(image)[0]
        keep = prediction["scores"] > SCORE_THRESHOLD
        boxes_filtered = prediction["boxes"][keep]
        labels_filtered = prediction["labels"][keep]
        label_names = [CLASS_LOOKUP["bbox_rev"][label.item()] for label in labels_filtered]

        # NOTE(review): assumes images are floats in [0, 1]; confirm against the dataset transforms.
        img = image[0].mul(255).type(torch.uint8)
        drawn_image = draw_bounding_boxes(img, boxes_filtered, label_names)
        to_pil(drawn_image).save(os.path.join(test_inference, f"{i}.png"))

KeyboardInterrupt: 