In [2]:
#imports
import os
import pathlib

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.models.detection.transform import GeneralizedRCNNTransform

from pytorch_faster_rcnn.datasets import ObjectDetectionDatasetSingle, ObjectDetectionDataSet
from pytorch_faster_rcnn.faster_RCNN import get_faster_rcnn_resnet
from pytorch_faster_rcnn.transformations import ComposeDouble
from pytorch_faster_rcnn.transformations import ComposeSingle
from pytorch_faster_rcnn.transformations import FunctionWrapperDouble
from pytorch_faster_rcnn.transformations import FunctionWrapperSingle
from pytorch_faster_rcnn.transformations import apply_nms, apply_score_threshold
from pytorch_faster_rcnn.transformations import normalize_01
from pytorch_faster_rcnn.utils import get_filenames_of_path, collate_single, save_json
from pytorch_faster_rcnn.visual import DatasetViewer
from pytorch_faster_rcnn.visual import DatasetViewerSingle
from pytorch_faster_rcnn.backbone_resnet import ResNetBackbones

In [3]:
#initiate parameters for inference
params = {'INPUT_DIR': 'pytorch_faster_rcnn/data/shelves/test',  # input files for which to generate prediction
          'PREDICTIONS_PATH': 'pytorch_faster_rcnn/data/shelves/predictions',  #predictions save directory
          'MODEL_DIR': 'experiment1',  #directory to load trained models from
          'VERSION': 'version_6', #specific version to use for inference
          }

In [4]:
#load input files for inference/prediction (unlabeled grocery images)
inputs = get_filenames_of_path(pathlib.Path(params['INPUT_DIR']))
inputs.sort()

In [5]:
#transformations to apply to input data
transforms = ComposeSingle([
    FunctionWrapperSingle(np.moveaxis, source=-1, destination=0),
    FunctionWrapperSingle(normalize_01)
    ])

In [6]:
# create dataset (single because it has no targets)
dataset = ObjectDetectionDatasetSingle(inputs=inputs,
                                       transform=transforms,
                                       use_cache=False,
                                       )

In [7]:
# create dataloader
dataloader_prediction = DataLoader(dataset=dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   num_workers=0,
                                   collate_fn=collate_single)

In [11]:
#load model from version indicated in params
checkpoint_path = str(os.getcwd()) +"/" + params['MODEL_DIR'] + "/" + params['VERSION'] + "/checkpoints"
for file in os.listdir(checkpoint_path):
    checkpoint_path += str("/" + os.fsdecode(file))
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model_state_dict = checkpoint['hyper_parameters']['model'].state_dict()

In [12]:
#initiate model with same parameters as version indicated in params (same backbone, anchor sizes, aspect_ratios, etc.)
model = get_faster_rcnn_resnet(num_classes=2,
                               backbone_name= ResNetBackbones.RESNET50,  
                               anchor_size=((128, 256, 512),),
                               aspect_ratios=((0.5, 1.0, 2.0),),
                               fpn=False,
                               min_size=1024,
                               max_size=1025
                               )

In [13]:
#load weights onto model (!only works if params specified above match the one used in the trained version that is loaded)
model.load_state_dict(model_state_dict)

RuntimeError: Error(s) in loading state_dict for FasterRCNN:
	Unexpected key(s) in state_dict: "backbone.5.4.conv1.weight", "backbone.5.4.bn1.weight", "backbone.5.4.bn1.bias", "backbone.5.4.bn1.running_mean", "backbone.5.4.bn1.running_var", "backbone.5.4.bn1.num_batches_tracked", "backbone.5.4.conv2.weight", "backbone.5.4.bn2.weight", "backbone.5.4.bn2.bias", "backbone.5.4.bn2.running_mean", "backbone.5.4.bn2.running_var", "backbone.5.4.bn2.num_batches_tracked", "backbone.5.4.conv3.weight", "backbone.5.4.bn3.weight", "backbone.5.4.bn3.bias", "backbone.5.4.bn3.running_mean", "backbone.5.4.bn3.running_var", "backbone.5.4.bn3.num_batches_tracked", "backbone.5.5.conv1.weight", "backbone.5.5.bn1.weight", "backbone.5.5.bn1.bias", "backbone.5.5.bn1.running_mean", "backbone.5.5.bn1.running_var", "backbone.5.5.bn1.num_batches_tracked", "backbone.5.5.conv2.weight", "backbone.5.5.bn2.weight", "backbone.5.5.bn2.bias", "backbone.5.5.bn2.running_mean", "backbone.5.5.bn2.running_var", "backbone.5.5.bn2.num_batches_tracked", "backbone.5.5.conv3.weight", "backbone.5.5.bn3.weight", "backbone.5.5.bn3.bias", "backbone.5.5.bn3.running_mean", "backbone.5.5.bn3.running_var", "backbone.5.5.bn3.num_batches_tracked", "backbone.5.6.conv1.weight", "backbone.5.6.bn1.weight", "backbone.5.6.bn1.bias", "backbone.5.6.bn1.running_mean", "backbone.5.6.bn1.running_var", "backbone.5.6.bn1.num_batches_tracked", "backbone.5.6.conv2.weight", "backbone.5.6.bn2.weight", "backbone.5.6.bn2.bias", "backbone.5.6.bn2.running_mean", "backbone.5.6.bn2.running_var", "backbone.5.6.bn2.num_batches_tracked", "backbone.5.6.conv3.weight", "backbone.5.6.bn3.weight", "backbone.5.6.bn3.bias", "backbone.5.6.bn3.running_mean", "backbone.5.6.bn3.running_var", "backbone.5.6.bn3.num_batches_tracked", "backbone.5.7.conv1.weight", "backbone.5.7.bn1.weight", "backbone.5.7.bn1.bias", "backbone.5.7.bn1.running_mean", "backbone.5.7.bn1.running_var", "backbone.5.7.bn1.num_batches_tracked", "backbone.5.7.conv2.weight", "backbone.5.7.bn2.weight", "backbone.5.7.bn2.bias", "backbone.5.7.bn2.running_mean", "backbone.5.7.bn2.running_var", "backbone.5.7.bn2.num_batches_tracked", "backbone.5.7.conv3.weight", "backbone.5.7.bn3.weight", "backbone.5.7.bn3.bias", "backbone.5.7.bn3.running_mean", "backbone.5.7.bn3.running_var", "backbone.5.7.bn3.num_batches_tracked", "backbone.6.6.conv1.weight", "backbone.6.6.bn1.weight", "backbone.6.6.bn1.bias", "backbone.6.6.bn1.running_mean", "backbone.6.6.bn1.running_var", "backbone.6.6.bn1.num_batches_tracked", "backbone.6.6.conv2.weight", "backbone.6.6.bn2.weight", "backbone.6.6.bn2.bias", "backbone.6.6.bn2.running_mean", "backbone.6.6.bn2.running_var", "backbone.6.6.bn2.num_batches_tracked", "backbone.6.6.conv3.weight", "backbone.6.6.bn3.weight", "backbone.6.6.bn3.bias", "backbone.6.6.bn3.running_mean", "backbone.6.6.bn3.running_var", "backbone.6.6.bn3.num_batches_tracked", "backbone.6.7.conv1.weight", "backbone.6.7.bn1.weight", "backbone.6.7.bn1.bias", "backbone.6.7.bn1.running_mean", "backbone.6.7.bn1.running_var", "backbone.6.7.bn1.num_batches_tracked", "backbone.6.7.conv2.weight", "backbone.6.7.bn2.weight", "backbone.6.7.bn2.bias", "backbone.6.7.bn2.running_mean", "backbone.6.7.bn2.running_var", "backbone.6.7.bn2.num_batches_tracked", "backbone.6.7.conv3.weight", "backbone.6.7.bn3.weight", "backbone.6.7.bn3.bias", "backbone.6.7.bn3.running_mean", "backbone.6.7.bn3.running_var", "backbone.6.7.bn3.num_batches_tracked", "backbone.6.8.conv1.weight", "backbone.6.8.bn1.weight", "backbone.6.8.bn1.bias", "backbone.6.8.bn1.running_mean", "backbone.6.8.bn1.running_var", "backbone.6.8.bn1.num_batches_tracked", "backbone.6.8.conv2.weight", "backbone.6.8.bn2.weight", "backbone.6.8.bn2.bias", "backbone.6.8.bn2.running_mean", "backbone.6.8.bn2.running_var", "backbone.6.8.bn2.num_batches_tracked", "backbone.6.8.conv3.weight", "backbone.6.8.bn3.weight", "backbone.6.8.bn3.bias", "backbone.6.8.bn3.running_mean", "backbone.6.8.bn3.running_var", "backbone.6.8.bn3.num_batches_tracked", "backbone.6.9.conv1.weight", "backbone.6.9.bn1.weight", "backbone.6.9.bn1.bias", "backbone.6.9.bn1.running_mean", "backbone.6.9.bn1.running_var", "backbone.6.9.bn1.num_batches_tracked", "backbone.6.9.conv2.weight", "backbone.6.9.bn2.weight", "backbone.6.9.bn2.bias", "backbone.6.9.bn2.running_mean", "backbone.6.9.bn2.running_var", "backbone.6.9.bn2.num_batches_tracked", "backbone.6.9.conv3.weight", "backbone.6.9.bn3.weight", "backbone.6.9.bn3.bias", "backbone.6.9.bn3.running_mean", "backbone.6.9.bn3.running_var", "backbone.6.9.bn3.num_batches_tracked", "backbone.6.10.conv1.weight", "backbone.6.10.bn1.weight", "backbone.6.10.bn1.bias", "backbone.6.10.bn1.running_mean", "backbone.6.10.bn1.running_var", "backbone.6.10.bn1.num_batches_tracked", "backbone.6.10.conv2.weight", "backbone.6.10.bn2.weight", "backbone.6.10.bn2.bias", "backbone.6.10.bn2.running_mean", "backbone.6.10.bn2.running_var", "backbone.6.10.bn2.num_batches_tracked", "backbone.6.10.conv3.weight", "backbone.6.10.bn3.weight", "backbone.6.10.bn3.bias", "backbone.6.10.bn3.running_mean", "backbone.6.10.bn3.running_var", "backbone.6.10.bn3.num_batches_tracked", "backbone.6.11.conv1.weight", "backbone.6.11.bn1.weight", "backbone.6.11.bn1.bias", "backbone.6.11.bn1.running_mean", "backbone.6.11.bn1.running_var", "backbone.6.11.bn1.num_batches_tracked", "backbone.6.11.conv2.weight", "backbone.6.11.bn2.weight", "backbone.6.11.bn2.bias", "backbone.6.11.bn2.running_mean", "backbone.6.11.bn2.running_var", "backbone.6.11.bn2.num_batches_tracked", "backbone.6.11.conv3.weight", "backbone.6.11.bn3.weight", "backbone.6.11.bn3.bias", "backbone.6.11.bn3.running_mean", "backbone.6.11.bn3.running_var", "backbone.6.11.bn3.num_batches_tracked", "backbone.6.12.conv1.weight", "backbone.6.12.bn1.weight", "backbone.6.12.bn1.bias", "backbone.6.12.bn1.running_mean", "backbone.6.12.bn1.running_var", "backbone.6.12.bn1.num_batches_tracked", "backbone.6.12.conv2.weight", "backbone.6.12.bn2.weight", "backbone.6.12.bn2.bias", "backbone.6.12.bn2.running_mean", "backbone.6.12.bn2.running_var", "backbone.6.12.bn2.num_batches_tracked", "backbone.6.12.conv3.weight", "backbone.6.12.bn3.weight", "backbone.6.12.bn3.bias", "backbone.6.12.bn3.running_mean", "backbone.6.12.bn3.running_var", "backbone.6.12.bn3.num_batches_tracked", "backbone.6.13.conv1.weight", "backbone.6.13.bn1.weight", "backbone.6.13.bn1.bias", "backbone.6.13.bn1.running_mean", "backbone.6.13.bn1.running_var", "backbone.6.13.bn1.num_batches_tracked", "backbone.6.13.conv2.weight", "backbone.6.13.bn2.weight", "backbone.6.13.bn2.bias", "backbone.6.13.bn2.running_mean", "backbone.6.13.bn2.running_var", "backbone.6.13.bn2.num_batches_tracked", "backbone.6.13.conv3.weight", "backbone.6.13.bn3.weight", "backbone.6.13.bn3.bias", "backbone.6.13.bn3.running_mean", "backbone.6.13.bn3.running_var", "backbone.6.13.bn3.num_batches_tracked", "backbone.6.14.conv1.weight", "backbone.6.14.bn1.weight", "backbone.6.14.bn1.bias", "backbone.6.14.bn1.running_mean", "backbone.6.14.bn1.running_var", "backbone.6.14.bn1.num_batches_tracked", "backbone.6.14.conv2.weight", "backbone.6.14.bn2.weight", "backbone.6.14.bn2.bias", "backbone.6.14.bn2.running_mean", "backbone.6.14.bn2.running_var", "backbone.6.14.bn2.num_batches_tracked", "backbone.6.14.conv3.weight", "backbone.6.14.bn3.weight", "backbone.6.14.bn3.bias", "backbone.6.14.bn3.running_mean", "backbone.6.14.bn3.running_var", "backbone.6.14.bn3.num_batches_tracked", "backbone.6.15.conv1.weight", "backbone.6.15.bn1.weight", "backbone.6.15.bn1.bias", "backbone.6.15.bn1.running_mean", "backbone.6.15.bn1.running_var", "backbone.6.15.bn1.num_batches_tracked", "backbone.6.15.conv2.weight", "backbone.6.15.bn2.weight", "backbone.6.15.bn2.bias", "backbone.6.15.bn2.running_mean", "backbone.6.15.bn2.running_var", "backbone.6.15.bn2.num_batches_tracked", "backbone.6.15.conv3.weight", "backbone.6.15.bn3.weight", "backbone.6.15.bn3.bias", "backbone.6.15.bn3.running_mean", "backbone.6.15.bn3.running_var", "backbone.6.15.bn3.num_batches_tracked", "backbone.6.16.conv1.weight", "backbone.6.16.bn1.weight", "backbone.6.16.bn1.bias", "backbone.6.16.bn1.running_mean", "backbone.6.16.bn1.running_var", "backbone.6.16.bn1.num_batches_tracked", "backbone.6.16.conv2.weight", "backbone.6.16.bn2.weight", "backbone.6.16.bn2.bias", "backbone.6.16.bn2.running_mean", "backbone.6.16.bn2.running_var", "backbone.6.16.bn2.num_batches_tracked", "backbone.6.16.conv3.weight", "backbone.6.16.bn3.weight", "backbone.6.16.bn3.bias", "backbone.6.16.bn3.running_mean", "backbone.6.16.bn3.running_var", "backbone.6.16.bn3.num_batches_tracked", "backbone.6.17.conv1.weight", "backbone.6.17.bn1.weight", "backbone.6.17.bn1.bias", "backbone.6.17.bn1.running_mean", "backbone.6.17.bn1.running_var", "backbone.6.17.bn1.num_batches_tracked", "backbone.6.17.conv2.weight", "backbone.6.17.bn2.weight", "backbone.6.17.bn2.bias", "backbone.6.17.bn2.running_mean", "backbone.6.17.bn2.running_var", "backbone.6.17.bn2.num_batches_tracked", "backbone.6.17.conv3.weight", "backbone.6.17.bn3.weight", "backbone.6.17.bn3.bias", "backbone.6.17.bn3.running_mean", "backbone.6.17.bn3.running_var", "backbone.6.17.bn3.num_batches_tracked", "backbone.6.18.conv1.weight", "backbone.6.18.bn1.weight", "backbone.6.18.bn1.bias", "backbone.6.18.bn1.running_mean", "backbone.6.18.bn1.running_var", "backbone.6.18.bn1.num_batches_tracked", "backbone.6.18.conv2.weight", "backbone.6.18.bn2.weight", "backbone.6.18.bn2.bias", "backbone.6.18.bn2.running_mean", "backbone.6.18.bn2.running_var", "backbone.6.18.bn2.num_batches_tracked", "backbone.6.18.conv3.weight", "backbone.6.18.bn3.weight", "backbone.6.18.bn3.bias", "backbone.6.18.bn3.running_mean", "backbone.6.18.bn3.running_var", "backbone.6.18.bn3.num_batches_tracked", "backbone.6.19.conv1.weight", "backbone.6.19.bn1.weight", "backbone.6.19.bn1.bias", "backbone.6.19.bn1.running_mean", "backbone.6.19.bn1.running_var", "backbone.6.19.bn1.num_batches_tracked", "backbone.6.19.conv2.weight", "backbone.6.19.bn2.weight", "backbone.6.19.bn2.bias", "backbone.6.19.bn2.running_mean", "backbone.6.19.bn2.running_var", "backbone.6.19.bn2.num_batches_tracked", "backbone.6.19.conv3.weight", "backbone.6.19.bn3.weight", "backbone.6.19.bn3.bias", "backbone.6.19.bn3.running_mean", "backbone.6.19.bn3.running_var", "backbone.6.19.bn3.num_batches_tracked", "backbone.6.20.conv1.weight", "backbone.6.20.bn1.weight", "backbone.6.20.bn1.bias", "backbone.6.20.bn1.running_mean", "backbone.6.20.bn1.running_var", "backbone.6.20.bn1.num_batches_tracked", "backbone.6.20.conv2.weight", "backbone.6.20.bn2.weight", "backbone.6.20.bn2.bias", "backbone.6.20.bn2.running_mean", "backbone.6.20.bn2.running_var", "backbone.6.20.bn2.num_batches_tracked", "backbone.6.20.conv3.weight", "backbone.6.20.bn3.weight", "backbone.6.20.bn3.bias", "backbone.6.20.bn3.running_mean", "backbone.6.20.bn3.running_var", "backbone.6.20.bn3.num_batches_tracked", "backbone.6.21.conv1.weight", "backbone.6.21.bn1.weight", "backbone.6.21.bn1.bias", "backbone.6.21.bn1.running_mean", "backbone.6.21.bn1.running_var", "backbone.6.21.bn1.num_batches_tracked", "backbone.6.21.conv2.weight", "backbone.6.21.bn2.weight", "backbone.6.21.bn2.bias", "backbone.6.21.bn2.running_mean", "backbone.6.21.bn2.running_var", "backbone.6.21.bn2.num_batches_tracked", "backbone.6.21.conv3.weight", "backbone.6.21.bn3.weight", "backbone.6.21.bn3.bias", "backbone.6.21.bn3.running_mean", "backbone.6.21.bn3.running_var", "backbone.6.21.bn3.num_batches_tracked", "backbone.6.22.conv1.weight", "backbone.6.22.bn1.weight", "backbone.6.22.bn1.bias", "backbone.6.22.bn1.running_mean", "backbone.6.22.bn1.running_var", "backbone.6.22.bn1.num_batches_tracked", "backbone.6.22.conv2.weight", "backbone.6.22.bn2.weight", "backbone.6.22.bn2.bias", "backbone.6.22.bn2.running_mean", "backbone.6.22.bn2.running_var", "backbone.6.22.bn2.num_batches_tracked", "backbone.6.22.conv3.weight", "backbone.6.22.bn3.weight", "backbone.6.22.bn3.bias", "backbone.6.22.bn3.running_mean", "backbone.6.22.bn3.running_var", "backbone.6.22.bn3.num_batches_tracked", "backbone.6.23.conv1.weight", "backbone.6.23.bn1.weight", "backbone.6.23.bn1.bias", "backbone.6.23.bn1.running_mean", "backbone.6.23.bn1.running_var", "backbone.6.23.bn1.num_batches_tracked", "backbone.6.23.conv2.weight", "backbone.6.23.bn2.weight", "backbone.6.23.bn2.bias", "backbone.6.23.bn2.running_mean", "backbone.6.23.bn2.running_var", "backbone.6.23.bn2.num_batches_tracked", "backbone.6.23.conv3.weight", "backbone.6.23.bn3.weight", "backbone.6.23.bn3.bias", "backbone.6.23.bn3.running_mean", "backbone.6.23.bn3.running_var", "backbone.6.23.bn3.num_batches_tracked", "backbone.6.24.conv1.weight", "backbone.6.24.bn1.weight", "backbone.6.24.bn1.bias", "backbone.6.24.bn1.running_mean", "backbone.6.24.bn1.running_var", "backbone.6.24.bn1.num_batches_tracked", "backbone.6.24.conv2.weight", "backbone.6.24.bn2.weight", "backbone.6.24.bn2.bias", "backbone.6.24.bn2.running_mean", "backbone.6.24.bn2.running_var", "backbone.6.24.bn2.num_batches_tracked", "backbone.6.24.conv3.weight", "backbone.6.24.bn3.weight", "backbone.6.24.bn3.bias", "backbone.6.24.bn3.running_mean", "backbone.6.24.bn3.running_var", "backbone.6.24.bn3.num_batches_tracked", "backbone.6.25.conv1.weight", "backbone.6.25.bn1.weight", "backbone.6.25.bn1.bias", "backbone.6.25.bn1.running_mean", "backbone.6.25.bn1.running_var", "backbone.6.25.bn1.num_batches_tracked", "backbone.6.25.conv2.weight", "backbone.6.25.bn2.weight", "backbone.6.25.bn2.bias", "backbone.6.25.bn2.running_mean", "backbone.6.25.bn2.running_var", "backbone.6.25.bn2.num_batches_tracked", "backbone.6.25.conv3.weight", "backbone.6.25.bn3.weight", "backbone.6.25.bn3.bias", "backbone.6.25.bn3.running_mean", "backbone.6.25.bn3.running_var", "backbone.6.25.bn3.num_batches_tracked", "backbone.6.26.conv1.weight", "backbone.6.26.bn1.weight", "backbone.6.26.bn1.bias", "backbone.6.26.bn1.running_mean", "backbone.6.26.bn1.running_var", "backbone.6.26.bn1.num_batches_tracked", "backbone.6.26.conv2.weight", "backbone.6.26.bn2.weight", "backbone.6.26.bn2.bias", "backbone.6.26.bn2.running_mean", "backbone.6.26.bn2.running_var", "backbone.6.26.bn2.num_batches_tracked", "backbone.6.26.conv3.weight", "backbone.6.26.bn3.weight", "backbone.6.26.bn3.bias", "backbone.6.26.bn3.running_mean", "backbone.6.26.bn3.running_var", "backbone.6.26.bn3.num_batches_tracked", "backbone.6.27.conv1.weight", "backbone.6.27.bn1.weight", "backbone.6.27.bn1.bias", "backbone.6.27.bn1.running_mean", "backbone.6.27.bn1.running_var", "backbone.6.27.bn1.num_batches_tracked", "backbone.6.27.conv2.weight", "backbone.6.27.bn2.weight", "backbone.6.27.bn2.bias", "backbone.6.27.bn2.running_mean", "backbone.6.27.bn2.running_var", "backbone.6.27.bn2.num_batches_tracked", "backbone.6.27.conv3.weight", "backbone.6.27.bn3.weight", "backbone.6.27.bn3.bias", "backbone.6.27.bn3.running_mean", "backbone.6.27.bn3.running_var", "backbone.6.27.bn3.num_batches_tracked", "backbone.6.28.conv1.weight", "backbone.6.28.bn1.weight", "backbone.6.28.bn1.bias", "backbone.6.28.bn1.running_mean", "backbone.6.28.bn1.running_var", "backbone.6.28.bn1.num_batches_tracked", "backbone.6.28.conv2.weight", "backbone.6.28.bn2.weight", "backbone.6.28.bn2.bias", "backbone.6.28.bn2.running_mean", "backbone.6.28.bn2.running_var", "backbone.6.28.bn2.num_batches_tracked", "backbone.6.28.conv3.weight", "backbone.6.28.bn3.weight", "backbone.6.28.bn3.bias", "backbone.6.28.bn3.running_mean", "backbone.6.28.bn3.running_var", "backbone.6.28.bn3.num_batches_tracked", "backbone.6.29.conv1.weight", "backbone.6.29.bn1.weight", "backbone.6.29.bn1.bias", "backbone.6.29.bn1.running_mean", "backbone.6.29.bn1.running_var", "backbone.6.29.bn1.num_batches_tracked", "backbone.6.29.conv2.weight", "backbone.6.29.bn2.weight", "backbone.6.29.bn2.bias", "backbone.6.29.bn2.running_mean", "backbone.6.29.bn2.running_var", "backbone.6.29.bn2.num_batches_tracked", "backbone.6.29.conv3.weight", "backbone.6.29.bn3.weight", "backbone.6.29.bn3.bias", "backbone.6.29.bn3.running_mean", "backbone.6.29.bn3.running_var", "backbone.6.29.bn3.num_batches_tracked", "backbone.6.30.conv1.weight", "backbone.6.30.bn1.weight", "backbone.6.30.bn1.bias", "backbone.6.30.bn1.running_mean", "backbone.6.30.bn1.running_var", "backbone.6.30.bn1.num_batches_tracked", "backbone.6.30.conv2.weight", "backbone.6.30.bn2.weight", "backbone.6.30.bn2.bias", "backbone.6.30.bn2.running_mean", "backbone.6.30.bn2.running_var", "backbone.6.30.bn2.num_batches_tracked", "backbone.6.30.conv3.weight", "backbone.6.30.bn3.weight", "backbone.6.30.bn3.bias", "backbone.6.30.bn3.running_mean", "backbone.6.30.bn3.running_var", "backbone.6.30.bn3.num_batches_tracked", "backbone.6.31.conv1.weight", "backbone.6.31.bn1.weight", "backbone.6.31.bn1.bias", "backbone.6.31.bn1.running_mean", "backbone.6.31.bn1.running_var", "backbone.6.31.bn1.num_batches_tracked", "backbone.6.31.conv2.weight", "backbone.6.31.bn2.weight", "backbone.6.31.bn2.bias", "backbone.6.31.bn2.running_mean", "backbone.6.31.bn2.running_var", "backbone.6.31.bn2.num_batches_tracked", "backbone.6.31.conv3.weight", "backbone.6.31.bn3.weight", "backbone.6.31.bn3.bias", "backbone.6.31.bn3.running_mean", "backbone.6.31.bn3.running_var", "backbone.6.31.bn3.num_batches_tracked", "backbone.6.32.conv1.weight", "backbone.6.32.bn1.weight", "backbone.6.32.bn1.bias", "backbone.6.32.bn1.running_mean", "backbone.6.32.bn1.running_var", "backbone.6.32.bn1.num_batches_tracked", "backbone.6.32.conv2.weight", "backbone.6.32.bn2.weight", "backbone.6.32.bn2.bias", "backbone.6.32.bn2.running_mean", "backbone.6.32.bn2.running_var", "backbone.6.32.bn2.num_batches_tracked", "backbone.6.32.conv3.weight", "backbone.6.32.bn3.weight", "backbone.6.32.bn3.bias", "backbone.6.32.bn3.running_mean", "backbone.6.32.bn3.running_var", "backbone.6.32.bn3.num_batches_tracked", "backbone.6.33.conv1.weight", "backbone.6.33.bn1.weight", "backbone.6.33.bn1.bias", "backbone.6.33.bn1.running_mean", "backbone.6.33.bn1.running_var", "backbone.6.33.bn1.num_batches_tracked", "backbone.6.33.conv2.weight", "backbone.6.33.bn2.weight", "backbone.6.33.bn2.bias", "backbone.6.33.bn2.running_mean", "backbone.6.33.bn2.running_var", "backbone.6.33.bn2.num_batches_tracked", "backbone.6.33.conv3.weight", "backbone.6.33.bn3.weight", "backbone.6.33.bn3.bias", "backbone.6.33.bn3.running_mean", "backbone.6.33.bn3.running_var", "backbone.6.33.bn3.num_batches_tracked", "backbone.6.34.conv1.weight", "backbone.6.34.bn1.weight", "backbone.6.34.bn1.bias", "backbone.6.34.bn1.running_mean", "backbone.6.34.bn1.running_var", "backbone.6.34.bn1.num_batches_tracked", "backbone.6.34.conv2.weight", "backbone.6.34.bn2.weight", "backbone.6.34.bn2.bias", "backbone.6.34.bn2.running_mean", "backbone.6.34.bn2.running_var", "backbone.6.34.bn2.num_batches_tracked", "backbone.6.34.conv3.weight", "backbone.6.34.bn3.weight", "backbone.6.34.bn3.bias", "backbone.6.34.bn3.running_mean", "backbone.6.34.bn3.running_var", "backbone.6.34.bn3.num_batches_tracked", "backbone.6.35.conv1.weight", "backbone.6.35.bn1.weight", "backbone.6.35.bn1.bias", "backbone.6.35.bn1.running_mean", "backbone.6.35.bn1.running_var", "backbone.6.35.bn1.num_batches_tracked", "backbone.6.35.conv2.weight", "backbone.6.35.bn2.weight", "backbone.6.35.bn2.bias", "backbone.6.35.bn2.running_mean", "backbone.6.35.bn2.running_var", "backbone.6.35.bn2.num_batches_tracked", "backbone.6.35.conv3.weight", "backbone.6.35.bn3.weight", "backbone.6.35.bn3.bias", "backbone.6.35.bn3.running_mean", "backbone.6.35.bn3.running_var", "backbone.6.35.bn3.num_batches_tracked". 
	size mismatch for rpn.head.cls_logits.weight: copying a param with shape torch.Size([15, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([9, 2048, 1, 1]).
	size mismatch for rpn.head.cls_logits.bias: copying a param with shape torch.Size([15]) from checkpoint, the shape in current model is torch.Size([9]).
	size mismatch for rpn.head.bbox_pred.weight: copying a param with shape torch.Size([60, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([36, 2048, 1, 1]).
	size mismatch for rpn.head.bbox_pred.bias: copying a param with shape torch.Size([60]) from checkpoint, the shape in current model is torch.Size([36]).

In [27]:
# inference (cpu) --> makes the predictions and saves then in predictions_path as specified in params
model.eval()
for sample in dataloader_prediction:
    x, x_name = sample
    with torch.no_grad():
        pred = model(x)
        pred = {key: value.numpy() for key, value in pred[0].items()}
        name = pathlib.Path(x_name[0])
        save_dir = pathlib.Path(os.getcwd()) / params['PREDICTIONS_PATH']
        save_dir.mkdir(parents=True, exist_ok=True)
        pred_list = {key: value.tolist() for key, value in pred.items()}  # numpy arrays are not serializable -> .tolist()
        save_json(pred_list, path=save_dir / name.with_suffix('.json'))

KeyboardInterrupt: 

In [28]:
# get prediction files (json) & sort them
predictions = get_filenames_of_path(pathlib.Path(os.getcwd()) / params['PREDICTIONS_PATH'])
predictions.sort()

In [37]:
# set IoU threshold & score threshold
iou_threshold = 0.25
score_threshold = 0.5

# set transformations applied to dataset with predictions (can experiment with nms and score threshold)
transforms_prediction = ComposeDouble([
    FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),
    FunctionWrapperDouble(normalize_01),
    FunctionWrapperDouble(apply_nms, input=False, target=True, iou_threshold=iou_threshold),
    FunctionWrapperDouble(apply_score_threshold, input=False, target=True, score_threshold=score_threshold)
])

#create dataset with predictions
dataset_prediction = ObjectDetectionDataSet(inputs=inputs,
                                            targets=predictions,
                                            transform=transforms_prediction,
                                            use_cache=False)

In [31]:
# mapping
color_mapping = {
    1: 'red',
}

# visualize predictions
datasetviewer_prediction = DatasetViewer(dataset_prediction, color_mapping)
datasetviewer_prediction.napari()
# add text properties gui
datasetviewer_prediction.gui_text_properties(datasetviewer_prediction.shape_layer)