In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from pathlib import Path
import shutil

import pytorch_lightning as pl
from pytorch_lightning.loggers.wandb import WandbLogger

import src.utils as utils
from src.data.dataloader import mtlDataModule
from src.models.lightning_frame import mtlMayhemModule
from src.data.manifests import generate_manifest

import torchvision.transforms as T
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from PIL import Image

from src.utils import load_yaml
import torch

# Repository roots: presumably the HPC cluster checkout and a local machine checkout.
# Later cells build config paths from HOME_DIR_HPC, so switching machines means
# editing these constants.
HOME_DIR_X1C7 = "/home/bbejczy/repos/multitask-mayhem"
HOME_DIR_HPC = "/zhome/3b/d/154066/repos/multitask-mayhem"

# Work from the repo root so relative paths (e.g. the checkpoint under "models/") resolve.
os.chdir(HOME_DIR_HPC)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the Faster R-CNN HPC config, then build the data manifest from the
# collections and data root it lists.
config_path = HOME_DIR_HPC + "/configs/fasterrcnn_od_hpc.yaml"
config = utils.load_yaml(config_path)
manifest = generate_manifest(
    collections=config["collections"],
    data_root=config["data_root"],
)

100%|██████████| 8/8 [00:24<00:00,  3.08s/it]


## Sanity check bounding boxes and segmentation masks

Load the datamodule and build the validation and test dataloaders.

In [10]:
# A single datamodule serves both splits: each stage is set up right before
# its dataloader is requested.
datamodule_config_path = HOME_DIR_HPC + "/configs/fasterrcnn_od_hpc.yaml"
data_module = mtlDataModule(config=datamodule_config_path)
data_module.prepare_data()

# Validation split -> ground-truth sanity check.
data_module.setup(stage="validate")
valid_dataloader = data_module.val_dataloader()

# Test split -> model predictions.
data_module.setup(stage="test")
test_dataloader = data_module.test_dataloader()

100%|██████████| 8/8 [00:24<00:00,  3.06s/it]


Prepare the output folder and load the class-name lookup table.

In [36]:
# Where rendered sanity-check images are written; created up front so saves succeed.
sanity_check_folder = HOME_DIR_HPC + "/notebooks/sanity_check"
os.makedirs(sanity_check_folder, exist_ok=True)

# Class lookup tables; the drawing cells use the "bbox_rev" entry to turn label ids into names.
class_lookup_path = HOME_DIR_HPC + "/configs/class_lookup.yaml"
class_lookup = load_yaml(class_lookup_path)

Loop through the validation images and draw the ground-truth boxes, labels, and masks (uncomment `image_pil.show()` when not running on the HPC).

In [11]:
# Render ground-truth boxes and masks for the validation split so annotations can be
# eyeballed. One PNG per batch is written to `sanity_check_folder`; uncomment `.show()`
# when running somewhere with a display.

# The cleanup cell may already have removed the folder; recreate it so `save` cannot fail.
os.makedirs(sanity_check_folder, exist_ok=True)

for i, batch in enumerate(valid_dataloader):
    image, targets = batch

    # Only the first sample of each batch is rendered.
    # NOTE(review): presumably the validation loader uses batch_size=1 — confirm.
    boxes = targets[0]["boxes"]
    labels = targets[0]["labels"]
    # draw_segmentation_masks only accepts boolean masks; cast in case the
    # dataset stores them as uint8/0-1 integers.
    masks = targets[0]["masks"].to(torch.bool)
    label_names = [class_lookup["bbox_rev"][label.item()] for label in labels]

    # Assumes a float image in [0, 1]; clamp so out-of-range values saturate
    # instead of wrapping around when cast to uint8.
    img = image[0].mul(255).clamp(0, 255).to(torch.uint8)
    drawn_image = draw_bounding_boxes(img, boxes, label_names)
    drawn_image = draw_segmentation_masks(drawn_image, masks, alpha=0.5, colors="green")
    image_pil = T.ToPILImage()(drawn_image)
    image_pil.save(sanity_check_folder + "/{}.png".format(i))
    # image_pil.show()



Clean up the output folder

In [35]:
# Remove the rendered sanity-check images. Guarded so re-running this cell, or running
# it before any images were written, does not raise FileNotFoundError.
if os.path.isdir(sanity_check_folder):
    shutil.rmtree(sanity_check_folder)


In [8]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights

# Output classes of the fine-tuned box head (must match the checkpoint).
# NOTE(review): presumably background + 4 object classes — confirm against class_lookup.
NUM_CLASSES = 5
# Relative to the repo root set by the first cell's os.chdir.
CHECKPOINT_PATH = "models/fasterrcnn_od_hpc_22-11-21T153530/weights/best.pth"

# Rebuild the architecture the checkpoint was trained with. `weights=DEFAULT` is kept
# (rather than None) so the model is constructed the same way as when pretrained weights
# are requested, presumably matching training; every parameter is then overwritten by the
# checkpoint below. The legacy `pretrained=True` flag is dropped: `weights` supersedes it.
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

# Swap the COCO box predictor for one sized to this project's classes.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)

# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# torch.load unpickles the file — only load checkpoints from trusted sources.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu")))

<All keys matched successfully>

In [38]:
# Visualise the fine-tuned detector's predictions on the test split. Only detections
# scoring above SCORE_THRESHOLD are drawn; one PNG per batch is written to
# `sanity_check_folder`.
SCORE_THRESHOLD = 0.8

# The cleanup cell may already have removed the folder; recreate it so `save` cannot fail.
os.makedirs(sanity_check_folder, exist_ok=True)

model.eval()
# Pure inference: disabling autograd avoids building a graph for every batch, which
# saves memory and time in this CPU-only loop.
with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        # Keep ground truth and predictions under separate names.
        image, targets = batch

        predictions = model(image)

        # Only the first image of each batch is rendered.
        # NOTE(review): presumably the test loader uses batch_size=1 — confirm.
        boxes = predictions[0]["boxes"]
        labels = predictions[0]["labels"]
        scores = predictions[0]["scores"]
        score_mask = scores > SCORE_THRESHOLD

        boxes_filtered = boxes[score_mask]
        labels_filtered = labels[score_mask]

        label_names = [class_lookup["bbox_rev"][label.item()] for label in labels_filtered]

        # Assumes a float image in [0, 1]; clamp so out-of-range values saturate
        # instead of wrapping around when cast to uint8.
        img = image[0].mul(255).clamp(0, 255).to(torch.uint8)
        drawn_image = draw_bounding_boxes(img, boxes_filtered, label_names)
        image_pil = T.ToPILImage()(drawn_image)
        image_pil.save(sanity_check_folder + "/{}.png".format(i))



KeyboardInterrupt: 