# Visualization of results

Import modules and define constants (repository root and class lookup table)

In [1]:
import shutil
import os
import src.utils as utils
from src.data.dataloader import mtlDataModule 
import random

from src.data.manifests import generate_manifest

import torchvision.transforms as T
from torchvision.utils import draw_segmentation_masks
from src.visualization.draw_things import draw_bounding_boxes
from PIL import Image
from pathlib import Path

from src.utils import load_yaml
from src.models.model_loader import ModelLoader
import torch

# Repository root. Defaults to the original author's checkout; set the MTL_ROOT_DIR
# environment variable to run this notebook from a different machine or location.
ROOT_DIR = os.environ.get("MTL_ROOT_DIR", "/home/bnbj/repos/multitask-mayhem")

# Relative paths used later in the notebook (e.g. "data/test/...") resolve against the repo root.
os.chdir(ROOT_DIR)

# Class-id <-> name mapping used to label drawn predictions.
CLASS_LOOKUP = load_yaml(ROOT_DIR+"/configs/class_lookup.yaml")

  from .autonotebook import tqdm as notebook_tqdm


Load the run's config, initialize the `pl.DataModule`, restore the trained weights, and create the dataloaders

In [3]:
# Trained run to visualize. Its folder holds both the run's config YAML and weights/best.pth,
# so the run id is written once here instead of being duplicated inside the path.
MODEL_RUN = "fasterrcnn_mobilenetv3_baseline_22-12-04T190124"
config_file = ROOT_DIR + f"/models/{MODEL_RUN}/{MODEL_RUN}.yaml"

data_module = mtlDataModule(config_path=config_file)

# Select the drawing routine: bounding boxes for detectors, masks for DeepLabV3.
seg, det = False, False
if data_module.config["model"] in ["fasterrcnn", "fasterrcnn_mobilenetv3", "ssdlite"]:
    det = True
elif data_module.config["model"] in ["deeplabv3"]:
    seg = True

# Visualize one image at a time, in a fixed order, with no worker processes.
# NOTE(review): these overrides are applied after the DataModule is constructed; they only
# take effect if mtlDataModule reads its config when building the dataloaders — confirm.
data_module.config["batch_size"] = 1
data_module.config["num_workers"] = 0
data_module.config["shuffle"] = False

model_loader = ModelLoader(config=data_module.config)
model = model_loader.grab_model()
model_folder = str(Path(config_file).parent)
# torch.load unpickles the checkpoint, so only point this at weights from trusted runs.
model.load_state_dict(torch.load(model_folder+"/weights/best.pth", map_location=torch.device('cpu')))

data_module.prepare_data()
data_module.setup(stage="fit")
train_dataloader = data_module.train_dataloader()
data_module.setup(stage="validate")
valid_dataloader = data_module.val_dataloader()
data_module.setup(stage="test")
test_dataloader = data_module.test_dataloader()

100%|██████████| 8/8 [00:05<00:00,  1.52it/s]


Prepare the landing folder for the rendered predictions (label names come from `CLASS_LOOKUP`, loaded in the first cell)

In [4]:
# Landing folder for the rendered prediction PNGs.
test_inference = ROOT_DIR+"/notebooks/test_inference"

# Begin each run from an empty folder so renders from earlier runs do not linger.
if Path(test_inference).exists():
    shutil.rmtree(test_inference)
Path(test_inference).mkdir(parents=True, exist_ok=True)

In [5]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    """Fixed random subset of PNG frames from a test sequence, served as float image tensors.

    The subset is drawn with a private, seeded RNG. This makes the selection repeatable
    without re-seeding Python's global ``random`` module, which would silently change
    the behavior of other cells that use ``random``.

    Args:
        image_dir: folder scanned for ``.png`` files.
        num_samples: number of images to keep. Clamped to the number of files found.
        seed: seed for the private RNG that picks the subset.
    """

    def __init__(
        self,
        image_dir: str = "data/test/2022-09-23-10-07-37/synchronized_l515_image/",
        num_samples: int = 200,
        seed: int = 42,
    ) -> None:
        super().__init__()
        # Sort first: the helper's listing order is not guaranteed here, and the seeded
        # sample is only reproducible across machines if the input order is fixed.
        test_set = sorted(utils.list_files_with_extension(image_dir, ".png", "path"))
        rng = random.Random(seed)
        self.image_list = rng.sample(test_set, min(num_samples, len(test_set)))
        self.transforms = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        # The context manager closes the file handle once ToTensor has read the pixels.
        with Image.open(self.image_list[idx]) as pil_image:
            image = self.transforms(pil_image)
        return image.type(torch.FloatTensor)

# Wrap the sampled images in a DataLoader (DataLoader defaults: batch_size=1, no shuffling).
# NOTE(review): the inference cell below iterates `test_dataloader`, not this loader.
img_dataset = ImageDataset()
test_set_dataloader = DataLoader(dataset=img_dataset)

In [6]:
def tuple_of_tensors_to_tensor(tuple_of_tensors):
    """Stack equally-shaped tensors into a single tensor along a new leading axis.

    Args:
        tuple_of_tensors: any iterable of tensors that all share the same shape.

    Returns:
        A tensor of shape ``(number_of_tensors, *tensor_shape)``.
    """
    tensors = [tensor for tensor in tuple_of_tensors]
    return torch.stack(tensors)  # stacks along dim 0 by default

In [7]:
# Render the model's predictions on every test-set image and save each one as <index>.png.
SCORE_THRESHOLD = 0.1  # keep detections whose confidence score exceeds this
MASK_THRESHOLD = 0.5   # per-pixel probability cut-off applied after the sigmoid

# Without this guard `drawn_image` would be undefined (NameError) on the first image,
# or silently reused from a previous iteration.
if not (det or seg):
    raise ValueError(
        f"No visualization defined for model {data_module.config['model']!r}; "
        "expected a detection or segmentation model."
    )

to_pil = T.ToPILImage()
model.eval()
# Pure inference: skip building the autograd graph to save memory and time.
with torch.no_grad():
    for i, (image, _target) in enumerate(test_dataloader):
        image = tuple_of_tensors_to_tensor(image)
        preds = model(image)

        # Drawing utilities need uint8 images; assumes inputs are floats in [0, 1].
        img = image.mul(255).type(torch.uint8)

        if det:
            det_preds = preds[0]  # batch_size is 1, so only the first entry exists
            score_mask = det_preds["scores"] > SCORE_THRESHOLD
            boxes_filtered = det_preds["boxes"][score_mask]
            labels_filtered = det_preds["labels"][score_mask]
            scores_filtered = det_preds["scores"][score_mask]
            label_names = [CLASS_LOOKUP["bbox_rev"][label.item()] for label in labels_filtered]

            drawn_image = draw_bounding_boxes(
                image=img.squeeze(0),
                boxes=boxes_filtered,
                labels=label_names,
                scores=scores_filtered,
            )
        else:  # seg
            masks = torch.sigmoid(preds["out"]) > MASK_THRESHOLD
            drawn_image = draw_segmentation_masks(img.squeeze(0), masks.squeeze(0), alpha=0.5, colors="green")

        to_pil(drawn_image).save(os.path.join(test_inference, f"{i}.png"))



In [3]:
import random

# Scratch check: sampling zipped pairs keeps each letter aligned with its number.
list1 = ["a", "b", "c", "d", "f"]
list2 = [1, 2, 3, 4, 5]

list_zip = [pair for pair in zip(list1, list2)]
list_random = random.sample(list_zip, k=3)
list_random

[('f', 5), ('b', 2), ('d', 4)]