In [1]:
# Standard imports
import numpy as np
import pathlib
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, fixed
import seaborn as sns
import sys
import yaml
from pathlib import Path
import shutil
import json
from typing import List, Union
import gc
import zipfile
from pathlib import Path
import re

# Deep learning imports.
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset

import torchvision
from torchvision.utils import draw_bounding_boxes, make_grid
from torchvision.ops import masks_to_boxes, box_area
import torchvision.transforms.functional as TF
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import box_iou

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

import torchmetrics
from torchmetrics.detection.mean_ap import MeanAveragePrecision, compute_area

# Additional settings.
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)

sys.path.append(sys.path[0]+'/..')
from he6_cres_deep_learning.daq import DAQ, Config
root_dir = sys.path[0]+'/config/fasterRCNN'

#### Define Dataset class that will load spec files and targets

In [2]:
class CRES_Dataset(torch.utils.data.Dataset):
    """DOCUMENT."""

    def __init__(
        self, root_dir, freq_bins=4096, max_pool=3, file_max=10, transform=None
    ):
        """
        Args:
            root_dir (string): Directory with all the spec files and targets.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        
        self.root_dir = root_dir
        self.freq_bins = freq_bins
        self.max_pool = max_pool
        self.file_max = file_max
        self.transform = transform

        self.imgs, self.targets = self.collect_imgs_and_targets()

        # Guarentee the correct type.
        self.imgs = self.imgs.type(torch.ByteTensor)

        return None

    def __getitem__(self, idx):

        img = self.imgs[idx]
        target = self.targets[idx]

        if self.transform:
            img = self.transform(img)

        return img, target

    def __len__(self):

        return len(self.imgs)

    def collect_imgs_and_targets(self):

        img_dir = self.root_dir + "/spec_files"
        target_dir = self.root_dir + "/label_files"

        # TODO: make it so directories get spec_prefix instead of files
        imgs, exp_name = self.load_spec_dir(img_dir)
#---------------------------------------------------------------------------------------------------        
        # Is this really the best way to scale the bboxes?
        targets = self.load_target_dir(target_dir, exp_name, imgs[0][0].shape)
        # targets = targets.long()

        return imgs, targets

    def load_spec_dir(self, dir_path):
        """
        Loads all of the images in a directory into torch images.

        Args:
            dir_path (str): path should point to a directory that only contains
                .JPG images. Or any image type compatible with cv2.imread().

            resize_factor (float): how to resize the image. Often one would
                like to reduce the size of the images to be easier/faster to
                use with our maskrcnn model.

        Returns:
            imgs (List[torch.ByteTensor[3, H, W]]): list of images (each a
                torch.ByteTensor of shape(3, H, W)).
        """
        path_glob = Path(dir_path).glob("**/*")
        files = [x for x in path_glob if x.is_file()]
        file_names = [str(x.name) for x in files]
        files = [str(x) for x in files]
        
        # extract experiment name to match to target file 
        exp_name = list(set(re.findall(r'[a-zA-Z0-9]+', name)[0] for name in file_names))
        
        # Extract the file index from the file name.
        file_idxs = [int(re.findall(r"\d+", name)[0]) for name in file_names]

        # Sort the files list based on the file_idx.
        files = [
            file
            for (file, file_idx) in sorted(
                zip(files, file_idxs), key=lambda pair: pair[1]
            )
        ]
        # Maxpool to use on images and labels.
        maxpool = nn.MaxPool2d(self.max_pool, self.max_pool, return_indices=False)

        if len(files) == 0:
            raise UserWarning("No files found at the input path.")

        imgs = []
        for file in files[: self.file_max]:
            img = self.spec_to_numpy(file)
            img = torch.from_numpy(img).unsqueeze(0)
            img = img.permute(0, 2, 1)

            # Apply max pooling now so we never have to hold the large images.

            imgs.append(maxpool(img.float()))

        imgs = torch.stack(imgs)

        return imgs, exp_name
    
    def load_target_dir(self, dir_path, exp_name, spec_shape): # spec_shape[0] is frequency, spec_shape[1] is time
        """
        TODO: Document
        Load bbox json files
        """
        path_glob = Path(dir_path).glob(f"{exp_name}*")
        files = [x for x in path_glob if x.is_file()]
        files = [str(x) for x in files]

        if len(files) == 0:
            raise UserWarning("No files found at the input path.")
        
        # targets will be a list of dicts that contain the bboxes, labels, and scores
        targets = []
        targets_dict = {'boxes': [],
                        'labels': []}
        for file in files[: self.file_max]:
            # read all bboxes for experiment
            with open(file, 'r') as f:
                bboxes = json.load(f)
                
                # each value corresponds to a file number
                for file_num, bbox_dict in bboxes.items():
                    # make sure we've populated the dict at least once before appending to list
                    if file_num != '0':
                        targets.append(targets_dict)
                        targets_dict = {'boxes': [],
                                        'labels': []}
                        
                    # each bbox corresponds to an event in the file
                    for bbox in bbox_dict.values():
                        # apply maxpooling reduction before appending
                        bbox = torch.tensor(bbox)/self.max_pool
                        bbox = bbox.round().int()
                        # max pooling can lead to tracks with no pixel width, avoid this
                        if bbox[3] == bbox[1]:
                            bbox[3] += 1
                        if bbox[2] == bbox[0]:
                            bbox[2] += 1
                        targets_dict['boxes'].append(bbox)
                        targets_dict['labels'].append(torch.tensor([1]))
                        
        targets.append(targets_dict)
        targets_dict = {'boxes': [],
                        'labels': []}
        for target in targets:
            target['boxes'] = torch.stack(target['boxes'])
            target['labels'] = torch.tensor(target['labels'])
        return targets
        

    def spec_to_numpy(
        self, spec_path, slices=-1, packets_per_slice=1, start_packet=None
    ):
        """
        TODO: Document.
        Making this just work for one packet per spectrum because that works for simulation in Katydid.
        * Make another function that works with 4 packets per spectrum (for reading the Kr data).
        """

        BYTES_IN_PAYLOAD = self.freq_bins
        BYTES_IN_HEADER = 32
        BYTES_IN_PACKET = BYTES_IN_PAYLOAD + BYTES_IN_HEADER

        if slices == -1:
            spec_array = np.fromfile(spec_path, dtype="uint8", count=-1).reshape(
                (-1, BYTES_IN_PACKET)
            )[:, BYTES_IN_HEADER:]
        else:
            spec_array = np.fromfile(
                spec_path, dtype="uint8", count=BYTES_IN_PAYLOAD * slices
            ).reshape((-1, BYTES_IN_PACKET))[:, BYTES_IN_HEADER:]
        
        if packets_per_slice > 1:

            spec_flat_list = [
                spec_array[(start_packet + i) % packets_per_slice :: packets_per_slice]
                for i in range(packets_per_slice)
            ]
            spec_flat = np.concatenate(spec_flat_list, axis=1)
            spec_array = spec_flat

        return spec_array

#### Define DataModule that will handle train/val/test splitting of data for modeling

In [3]:
class CRES_DM(pl.LightningDataModule):
    """
    Self contained PyTorch Lightning DataModule for testing image
    segmentation models with PyTorch Lightning. Uses the torch dataset
    ImageSegmentation_DS.

    Args:
        train_val_size (int): total size of the training and validation
            sets combined.
        train_val_split (Tuple[float, float]): should sum to 1.0. For example
            if train_val_size = 100 and train_val_split = (0.80, 0.20)
            then the training set will contain 80 imgs and the validation
            set will contain 20 imgs.
        test_size (int): the size of the test data set.
        batch_size (int): batch size to be input to dataloaders. Applies
            for training, val, and test datasets.

    Notes: For now you can decide to shuffle the entire dataset or not but
    the train is always shuffled and the val/test isn't so you can look at
    the same images easily.
    """

    def __init__(
        self,
        root_dir,
        freq_bins=4096,
        max_pool=8,
        file_max=10,
        transform=None,
        train_val_test_splits=(0.6, 0.3, 0.1),
        batch_size=1,
        shuffle_dataset=True,
        seed=42,
        num_workers=0,
        class_map={
            0: {
                "name": "background",
                "target_color": (255, 255, 255),
            },
            1: {"name": "event", "target_color": (255, 0, 0)}
        },
    ):

        super().__init__()

        # Attributes.
        self.root_dir = root_dir
        self.freq_bins = freq_bins
        self.max_pool = max_pool
        self.file_max = file_max
        self.transform = transform
        self.class_map = class_map
        self.train_val_test_splits = train_val_test_splits
        self.batch_size = batch_size
        self.shuffle_dataset = shuffle_dataset
        self.seed = seed
        self.num_workers = num_workers
        self.setup()

    def setup(self, stage=None):

        self.cres_dataset = CRES_Dataset(
            self.root_dir,
            freq_bins=self.freq_bins,
            max_pool=self.max_pool,
            file_max=self.file_max,
            transform=self.transform,
        )

        # Creating data indices for training and validation splits:
        dataset_size = len(self.cres_dataset)
        indices = list(range(dataset_size))
        splits = self.train_val_test_splits
        split_idxs = [
            int(np.floor(splits[0] * dataset_size)),
            int(np.floor((splits[0] + splits[1]) * dataset_size)),
        ]

        if self.shuffle_dataset:
            rng = np.random.default_rng(self.seed)
            rng.shuffle(indices)

        train_indices, val_indices, test_indices = (
            indices[: split_idxs[0]],
            indices[split_idxs[0] : split_idxs[1]],
            indices[split_idxs[1] :],
        )

        # Creating PT data samplers and loaders. For now only train is shuffled.
        self.train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
        self.val_sampler = torch.utils.data.SequentialSampler(val_indices)
        self.test_sampler = torch.utils.data.SequentialSampler(test_indices)

        return None

    def collate_fn(self, batch):
        """
        When dealing with lists of target dictionaries one needs to be 
        careful how the batches are collated. The default pytorch dataloader 
        behaviour is to return a single dictionary for the whole batch of 
        images which won't work as input to the mask rcnn model. Instead 
        we want a list of dictionaries; one for each image. See here for 
        more details on the dataloader collate_fn:
        https://python.plainenglish.io/understanding-collate-fn-in-pytorch-f9d1742647d3

        Returns:
            imgs (torch.UInt8Tensor[batch_size, 3, img_size, img_size]): 
                batch of images.
            targets (List[Dict[torch.Tensor]]): list of dictionaries of 
                length batch_size.

        """
        imgs = []
        targets = []

        for img, target in batch:
            imgs.append(img)
            targets.append(target)

        # Converts list of tensor images (of shape (3,H,W) and len batch_size)
        # into a tensor of shape (batch_size, 3, H, W).
        imgs = torch.stack(imgs)

        return imgs, targets
    
    def train_dataloader(self):
        return DataLoader(
            self.cres_dataset,
            batch_size=self.batch_size,
            sampler=self.train_sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        return DataLoader(
            self.cres_dataset,
            batch_size=self.batch_size,
            sampler=self.val_sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn
        )

    def test_dataloader(self):
        return DataLoader(
            self.cres_dataset,
            batch_size=self.batch_size,
            sampler=self.test_sampler,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn
        )

#### Function to load the model to be trained

In [4]:
def get_fasterrcnn(num_classes=2, pretrained=True):
    """A function for loading the PyTorch implementation of FasterRCNN.
    To not have predictor changed at all set num_classes = -1.
    See here for documentation on the input and output specifics:
    https://pytorch.org/vision/stable/models/faster_rcnn.html

    Args:
        num_classes (int): number of output classes desired.
        pretrained (bool): whether or not to load a model pretrained on the COCO dataset. 
    """

    # load Faster RCNN pre-trained model
    if pretrained:
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)

    if num_classes != -1:
        # get the number of input features
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # define a new head for the detector with required number of classes
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

#### Function for 'overfitting' of the data

In [12]:
def overfit(imgs, targets, model, optimizer,  device=None,  epochs= 100): 
    
    model = model.to(device)
    model.train()

    # Formatting for input to model. 
    imgs_normed = imgs / 255.0
    imgs_normed = imgs_normed.to(device)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    

    for epoch in range(epochs):

        loss_dict = model(imgs_normed, targets)

        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if epoch %25 == 0:
            print(f"epoch: {epoch}")
            print(f"loss {losses:.4f}\n")

    return None

In [27]:
cres_dm = CRES_DM(root_dir, max_pool=16)

In [28]:
cres_dm.setup(stage = "fit")
dataiter = iter(cres_dm.train_dataloader())

imgs, targets = dataiter.next()

print(f"Input shape:\n {imgs.shape} \n" )

print(f"Faster RCNN target for img 0:\n ")
for key, val in targets.items():
    print(f"{key}:\n {val}")

AttributeError: 'tuple' object has no attribute 'shape'

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_fasterrcnn(num_classes = 2, pretrained = True)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
overfit(imgs, targets, model, optimizer, device, epochs=400)

In [17]:
def apply_score_cut(preds, score_threshold=0.5):
    """
    Takes a list of prediction dictionaries (one for each image) and cuts
    out all instances whose score is below the score threshold.

    Args:
        preds (List[Dict[torch.Tensor]]): predictions as output by the
            torchvision implimentation of MaskRCNN or FasterRCNN. The 
            scores are in the range (0,1) and signify the certainty of 
            the model for that instance.
            See link below for details on the target/prediction formatting.
            https://pytorch.org/vision/0.12/_modules/torchvision/models/detection/mask_rcnn.html
        score_threshold (float): the threshold to apply to the identified
            objects. If an instance is below the score_threshold it will
            be removed from the score_thresholded_preds dictionary.

    Returns:
        score_thresholded_preds (List[Dict[torch.Tensor]]): predictions
            that exceed score_threshold.
    """
    score_thresholded_preds = [
        {key: value[pred["scores"] > score_threshold] for key, value in pred.items()}
        for pred in preds
    ]

    return score_thresholded_preds

#### Define LightningModule class for training

In [25]:
class CRES_LM(pl.LightningModule):

    def __init__(self, num_classes = 2, lr = 3e-4, pretrained = False):
        super().__init__()

        # LM Attributes.
        self.num_classes = num_classes
        self.pretrained = pretrained
        self.lr = lr

        # Log hyperparameters. 
        self.save_hyperparameters()

        # Metrics.
        # self.iou = JaccardIndex(task='binary')
        # self.map_bbox = MeanAveragePrecision(iou_type = "bbox", class_metrics = False)

        # Faster RCNN model. 
        self.model = self.get_fasterrcnn_model(self.num_classes, self.pretrained)

    def forward(self, imgs):
        self.model.eval()
        imgs_normed = self.norm_imgs(imgs)
        return self.model(imgs_normed)

    def training_step(self, train_batch, batch_idx):

        imgs, targets = train_batch
        imgs_normed = self.norm_imgs(imgs)

        loss_dict = self.model(imgs_normed, targets)
        losses = sum(loss for loss in loss_dict.values())

        self.log('Loss/train_loss', losses)

        return losses

    def validation_step(self, val_batch, batch_idx):

        imgs, targets = val_batch
        preds = self.forward(imgs)
        
        
        iou_list = torch.tensor([box_iou(target["boxes"], pred["boxes"]).diag().mean() for target, pred in zip(targets, preds)])
        # print(iou_list)
        self.log('IoU_bbox/val',iou_list)

        return None

    def configure_optimizers(self): 

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        return optimizer

    def get_fasterrcnn_model(self, num_classes, pretrained):
        
        # load Faster RCNN pre-trained model
        if pretrained: 
            model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
        else: 
            model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
        
        # get the number of input features 
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # define a new head for the detector with required number of classes
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 

        return model

    def norm_imgs(self, imgs): 

        imgs_normed = imgs.float() / 255.0

        return imgs_normed

In [6]:
# Define training object
cres_dm = CRES_DM(root_dir,
                  max_pool=16,
                  file_max=1000,
                  batch_size=4,
                  num_workers=4
                  )

In [7]:
cres_dm.cres_dataset.imgs.shape, len(cres_dm.cres_dataset.targets), cres_dm.cres_dataset.targets[0]

(torch.Size([1000, 1, 256, 320]),
 1000,
 {'boxes': tensor([[205, 144, 320, 146],
          [ 59, 250, 250, 252],
          [298, 189, 320, 190],
          [148, 160, 320, 167],
          [172, 118, 320, 125]], dtype=torch.int32),
  'labels': tensor([1, 1, 1, 1, 1])})

In [None]:
# Create Instance of LightningModule
cres_lm = CRES_LM(num_classes = 2, lr = 1e-4, pretrained = True)

# Create callback for ModelCheckpoints. 
checkpoint_callback = ModelCheckpoint(filename='{epoch:02d}', 
                                      save_top_k = 15, 
                                      monitor = "Loss/train_loss", 
                                      every_n_epochs = 1)

# Define Logger. 
logger = TensorBoardLogger("tb_logs", name="cres", log_graph = False)

# Set device.
device = "gpu" if torch.cuda.is_available() else "cpu"

# Create an instance of a Trainer.
trainer = pl.Trainer(logger = logger, 
                     callbacks = [checkpoint_callback], 
                     accelerator = device, 
                     max_epochs = 10, 
                     log_every_n_steps = 1, 
                     check_val_every_n_epoch= 1)

# Fit. 
trainer.fit(cres_lm, cres_dm.train_dataloader(), cres_dm.val_dataloader())

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | FasterRCNN | 41.3 M
-------------------------------------
41.1 M    Trainable params
222 K     Non-trainable params
41.3 M    Total params
165.197   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

tensor([0.0111, 0.0000, 0.0007, 0.0140])




tensor([0.0000, 0.0000, 0.0084, 0.0082])


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

tensor([0.0000, 0.1868, 0.2417, 0.0950])
tensor([0.0000, 0.3538, 0.1598, 0.1001])
tensor([0.0000, 0.0339, 0.4004, 0.0000])
tensor([0.2709, 0.1168, 0.0000, 0.0639])
tensor([0.1087, 0.2300, 0.1185, 0.1430])
tensor([0.0000, 0.1497, 0.1386, 0.1015])
tensor([0.0000, 0.0000, 0.2625, 0.2109])
tensor([0.1285, 0.4133, 0.4767, 0.1631])
tensor([0.1152, 0.0000, 0.0000, 0.0000])
tensor([0.7095, 0.0000, 0.1134, 0.1976])
tensor([0.1406, 0.4031, 0.0468, 0.0000])
tensor([0.1323, 0.0804, 0.1452, 0.0000])
tensor([0.5456, 0.3481, 0.0000, 0.2603])
tensor([0.1286, 0.2058, 0.1250, 0.0604])
tensor([0.5825, 0.0000, 0.4316, 0.6580])
tensor([0.7738, 0.1676, 0.2706, 0.1312])
tensor([0.2306, 0.0000, 0.1304, 0.1865])
tensor([0.2084, 0.1691, 0.0489, 0.2223])
tensor([0.1993, 0.0000, 0.3161, 0.0000])
tensor([0.1448, 0.6521, 0.3355, 0.2061])
tensor([0.1619, 0.1392, 0.4915, 0.1042])
tensor([0.0854, 0.1453, 0.0540, 0.0000])
tensor([0.0000, 0.0939, 0.2523, 0.0883])
tensor([0.0000, 0.3907, 0.0000, 0.2455])
tensor([0.0000, 



Validation: 0it [00:00, ?it/s]

tensor([0., 0., 0., 0.])
tensor([0.0000, 0.2493, 0.0000, 0.1468])
tensor([0.0000, 0.0718, 0.0000, 0.0533])
tensor([0.1058, 0.0000, 0.0837, 0.1630])
tensor([0.1829, 0.1629, 0.0000, 0.0000])
tensor([0.0000, 0.0000, 0.2090, 0.1406])
tensor([0.2692, 0.0000, 0.0000, 0.1244])
tensor([0.0000, 0.2645, 0.0000, 0.2215])
tensor([0.0000, 0.2548, 0.3141, 0.0000])
tensor([0.8185, 0.0000, 0.0000, 0.0000])
tensor([0.2187, 0.4513, 0.4588, 0.0000])
tensor([0., 0., 0., 0.])
tensor([0.2645, 0.2500, 0.0000, 0.2464])
tensor([0.3348, 0.1133, 0.0218, 0.2547])
tensor([0.4342, 0.1534, 0.0000, 0.7147])
tensor([0.9468, 0.1142, 0.0000, 0.0000])
tensor([0.0000, 0.1698, 0.1159, 0.0000])
tensor([0.1700, 0.0000, 0.1549, 0.1521])
tensor([0.0000, 0.1915, 0.2750, 0.1527])
tensor([0.2259, 0.3225, 0.2278, 0.0000])
tensor([0.1713, 0.1202, 0.6566, 0.1042])
tensor([0.3058, 0.3473, 0.0000, 0.1367])
tensor([0.0000, 0.3020, 0.2292, 0.2156])
tensor([0.0000, 0.4603, 0.0000, 0.0158])
tensor([0.1953, 0.0000, 0.2356, 0.3043])
tensor(

Validation: 0it [00:00, ?it/s]

tensor([0.0000, 0.1806, 0.0000, 0.1323])
tensor([0.0000, 0.2888, 0.2181, 0.1030])
tensor([0., 0., 0., 0.])
tensor([0.1258, 0.0912, 0.0000, 0.3011])
tensor([0.0000, 0.1818, 0.0000, 0.0000])
tensor([0.0685, 0.0000, 0.3122, 0.0000])
tensor([0.0000, 0.0000, 0.1327, 0.1042])
tensor([0.1752, 0.2731, 0.0000, 0.2401])
tensor([0.0000, 0.0000, 0.2599, 0.2243])
tensor([0.7536, 0.2531, 0.1135, 0.0000])
tensor([0.2229, 0.4969, 0.0927, 0.0953])
tensor([0.0000, 0.0000, 0.0000, 0.1245])
tensor([0.3087, 0.3527, 0.0000, 0.2366])
tensor([0.0000, 0.0000, 0.1087, 0.1699])
tensor([0.4361, 0.1697, 0.2079, 0.7946])
tensor([0.9018, 0.2591, 0.0000, 0.0113])
tensor([0.0000, 0.0058, 0.1640, 0.0000])
tensor([0.1597, 0.2287, 0.1349, 0.1589])
tensor([0.1840, 0.0000, 0.1750, 0.2676])
tensor([0.3157, 0.0000, 0.4286, 0.0000])
tensor([0.2621, 0.0000, 0.6190, 0.2867])
tensor([0.0862, 0.1997, 0.0000, 0.1141])
tensor([0.0000, 0.0000, 0.2854, 0.1170])
tensor([0.0200, 0.5267, 0.1965, 0.0100])
tensor([0.1408, 0.1950, 0.4524, 

Validation: 0it [00:00, ?it/s]

tensor([0.0000, 0.3163, 0.1744, 0.2661])
tensor([0.0000, 0.2915, 0.0544, 0.0876])
tensor([0.0000, 0.0287, 0.0000, 0.0811])
tensor([0.1381, 0.2881, 0.0000, 0.3403])
tensor([0.3080, 0.1473, 0.0457, 0.0000])
tensor([0.1426, 0.0000, 0.0000, 0.2068])
tensor([0.0000, 0.0000, 0.4114, 0.0000])
tensor([0.1334, 0.2983, 0.0000, 0.0932])
tensor([0.0000, 0.0000, 0.0741, 0.0000])
tensor([0.8880, 0.3714, 0.0000, 0.0000])
tensor([0.2299, 0.3486, 0.0962, 0.0000])
tensor([0.0000, 0.1179, 0.0000, 0.0000])
tensor([0.2913, 0.4541, 0.0000, 0.2394])
tensor([0.3323, 0.5938, 0.0168, 0.1211])
tensor([0.5861, 0.1771, 0.7816, 0.8452])
tensor([0.9375, 0.3355, 0.0000, 0.0000])
tensor([0.0000, 0.0000, 0.0411, 0.0000])
tensor([0.0000, 0.0000, 0.0825, 0.2710])
tensor([0.0000, 0.1671, 0.0561, 0.2628])
tensor([0.1905, 0.0000, 0.4335, 0.0000])
tensor([0.3094, 0.0000, 0.6984, 0.1745])
tensor([0.1539, 0.1248, 0.0000, 0.0000])
tensor([0.0000, 0.0000, 0.2314, 0.1280])
tensor([0.0000, 0.6190, 0.2342, 0.0000])
tensor([0.0947, 

Validation: 0it [00:00, ?it/s]

tensor([0.0000, 0.0000, 0.1056, 0.0042])
tensor([0.0000, 0.2981, 0.0822, 0.1744])
tensor([0.0000, 0.0000, 0.0000, 0.1574])
tensor([0.1408, 0.1631, 0.0000, 0.0000])
tensor([0.0000, 0.1159, 0.0000, 0.0000])
tensor([0.0706, 0.0000, 0.1603, 0.2108])
tensor([0.0910, 0.0000, 0.1619, 0.2215])
tensor([0.1616, 0.3038, 0.0000, 0.2527])
tensor([0.0000, 0.2582, 0.3598, 0.0000])
tensor([0.9059, 0.1708, 0.1431, 0.0000])
tensor([0.2229, 0.4336, 0.1240, 0.0000])
tensor([0.0000, 0.0000, 0.0000, 0.2288])
tensor([0.3259, 0.4528, 0.0000, 0.0000])
tensor([0.3184, 0.0092, 0.0000, 0.1054])
tensor([0.6764, 0.1862, 0.2048, 0.8982])
tensor([0.8870, 0.3085, 0.0000, 0.0000])
tensor([0.1201, 0.0103, 0.2136, 0.0000])
tensor([0.2871, 0.0000, 0.0000, 0.0000])
tensor([0.1629, 0.1699, 0.2363, 0.3630])
tensor([0.0000, 0.0000, 0.4758, 0.1059])
tensor([0.1219, 0.1628, 0.6869, 0.3686])
tensor([0.1657, 0.1863, 0.0000, 0.0000])
tensor([0.0000, 0.0904, 0.2798, 0.0423])
tensor([0.0000, 0.5160, 0.2867, 0.0000])
tensor([0.1052, 

In [12]:
%load_ext tensorboard
%tensorboard --logdir tb_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 2712), started 0:00:37 ago. (Use '!kill 2712' to kill it.)

In [13]:
PATH = 'tb_logs/cres/version_{}/checkpoints/epoch={:02d}.ckpt'.format(52, 2)
cres_lm = CRES_LM.load_from_checkpoint(PATH)

In [15]:
cres_dm.setup(stage = "test")
test_dataiter = iter(cres_dm.test_dataloader())

In [22]:
imgs, targets = test_dataiter.next()
preds = cres_lm(imgs)

In [23]:
def show(imgs, figsize=(10.0, 10.0)):

    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), figsize=figsize, squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = TF.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

    return None

def display_boxes(imgs, target_pred_dict, class_map, width = 1, fill=False):

    num_imgs = len(imgs)
    result_imgs = [
        draw_bounding_boxes(
            imgs[i].type(torch.uint8),
            target_pred_dict[i]["boxes"].int(),
            fill=fill,
            colors=[
                class_map[j.item()]["target_color"]
                for j in target_pred_dict[i]["labels"]
            ],
            width=width,
        )
        for i in range(num_imgs)
    ]

    return result_imgs




In [24]:
%matplotlib inline
@interact
def vizualize_targets_predictions(
                                target_box = widgets.Checkbox(value=False,description='target boxes'),
                                pred_box = widgets.Checkbox(value=False,description='prediction boxes'),
                                num_imgs= widgets.IntSlider(value=len(preds),min=0,max=len(preds),step=1, description = "num_imgs"),
                                score_thresh = widgets.FloatSlider(value=.5,min=0,max=1,step=.0001, description = "score_thresh"),
                                width =  widgets.IntSlider(value=1,min=1,max=10,step=1), 
                                display_size = widgets.IntSlider(value=20,min=2,max=50,step=1)
                                ): 

    preds_cut = apply_score_cut(preds, score_threshold=score_thresh)
    result_image = [imgs[i] for i in range(num_imgs)]

    if target_box: 
        result_image = display_boxes(result_image, targets, cres_dm.class_map, fill = True)

    if pred_box: 
        result_image = display_boxes(result_image, preds_cut, cres_dm.class_map)

    grid = make_grid(result_image)
    show(grid, figsize = (display_size, display_size))

interactive(children=(Checkbox(value=False, description='target boxes'), Checkbox(value=False, description='pr…