In [13]:
import os
import pathlib

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.models.detection.transform import GeneralizedRCNNTransform

from pytorch_faster_rcnn_tutorial.datasets import ObjectDetectionDatasetSingle, ObjectDetectionDataSet
from pytorch_faster_rcnn_tutorial.faster_RCNN import get_faster_rcnn_resnet
from pytorch_faster_rcnn_tutorial.transformations import ComposeDouble
from pytorch_faster_rcnn_tutorial.transformations import ComposeSingle
from pytorch_faster_rcnn_tutorial.transformations import FunctionWrapperDouble
from pytorch_faster_rcnn_tutorial.transformations import FunctionWrapperSingle
from pytorch_faster_rcnn_tutorial.transformations import apply_nms, apply_score_threshold
from pytorch_faster_rcnn_tutorial.transformations import normalize_01
from pytorch_faster_rcnn_tutorial.utils import get_filenames_of_path, collate_single, save_json
from pytorch_faster_rcnn_tutorial.visual import DatasetViewer
from pytorch_faster_rcnn_tutorial.visual import DatasetViewerSingle
from pytorch_faster_rcnn_tutorial.backbone_resnet import ResNetBackbones

In [None]:
# Inference configuration: which images to predict on, where the prediction
# files are written, and which trained run (experiment directory + version)
# the checkpoint is restored from.
params = dict(
    INPUT_DIR='pytorch_faster_rcnn_tutorial/data/shelves/test',  # images to generate predictions for
    PREDICTIONS_PATH='pytorch_faster_rcnn_tutorial/data/shelves/predictions',  # directory the prediction JSON files go to
    MODEL_DIR='experiment1',  # experiment directory containing the versions
    VERSION='version_3',  # version sub-directory whose checkpoint is loaded
    # The tutorial's DOWNLOAD / DOWNLOAD_PATH / PROJECT entries are left out:
    # they seem unnecessary here because the logs are stored locally.
)

In [None]:
# Unlabeled grocery images to run inference on, in sorted order so the
# file list is deterministic between runs.
inputs = sorted(get_filenames_of_path(pathlib.Path(params['INPUT_DIR'])))

In [None]:
# Input-only preprocessing pipeline (there are no targets at inference time):
#   1. move the last axis to the front (np.moveaxis with source=-1, destination=0)
#   2. apply normalize_01
transforms = ComposeSingle(
    [
        FunctionWrapperSingle(np.moveaxis, source=-1, destination=0),
        FunctionWrapperSingle(normalize_01),
    ]
)

In [None]:
# Dataset over the input images only ("single": no targets are attached),
# with the preprocessing pipeline applied and caching switched off.
dataset = ObjectDetectionDatasetSingle(inputs=inputs, transform=transforms, use_cache=False)

In [None]:
# Loader that feeds the images to the model one at a time, in dataset order
# (no shuffling), loading in the main process (num_workers=0).
dataloader_prediction = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_single,
)

In [None]:
# Locate the checkpoint of the configured experiment/version and extract the
# trained model weights from it.
# TODO: also resolve hparams.yaml automatically from the version directory.
checkpoint_dir = pathlib.Path(os.getcwd()) / params['MODEL_DIR'] / params['VERSION'] / 'checkpoints'

# Only consider real checkpoint files. Appending every directory entry to one
# string (as done before) produced an invalid path as soon as the directory
# held more than one file (a second checkpoint, .DS_Store, ...).
checkpoint_files = [path for path in checkpoint_dir.glob('*.ckpt') if path.is_file()]
if not checkpoint_files:
    raise FileNotFoundError(f"No .ckpt file found in {checkpoint_dir}")

# With several checkpoints present, use the most recently modified one.
checkpoint_path = str(max(checkpoint_files, key=lambda path: path.stat().st_mtime))

# map_location='cpu' lets a checkpoint saved from a GPU run be loaded for the
# CPU inference below.
# NOTE: the checkpoint stores a pickled model object under
# ['hyper_parameters']['model']; unpickling executes code, so only load
# checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model_state_dict = checkpoint['hyper_parameters']['model'].state_dict()

In [28]:
# Rebuild the network that the checkpoint weights are loaded into afterwards
# (class count, backbone, anchors, min/max image size).
# NOTE(review): these values are typed in by hand; confirm they match the
# hyper-parameters of the run selected by params['VERSION'].
model = get_faster_rcnn_resnet(
    num_classes=2,
    backbone_name=ResNetBackbones.RESNET152,
    anchor_size=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),),
    fpn=False,
    min_size=1024,
    max_size=1025,
)

In [None]:
# Copy the weights extracted from the checkpoint into the freshly built model.
# Left as the cell's last expression so the notebook shows the call's return value.
model.load_state_dict(model_state_dict)

In [29]:
# Mark every parameter of the model as trainable and print each parameter
# tensor while doing so.
for weight in model.parameters():
    weight.requires_grad_(True)
    print(weight)

Parameter containing:
tensor([[[[ 4.7132e-07,  6.3123e-07,  6.1915e-07,  ...,  2.9313e-07,
            2.1123e-07,  1.3036e-07],
          [ 4.8263e-07,  7.1548e-07,  7.1251e-07,  ...,  3.0581e-07,
            2.6611e-07,  2.3413e-07],
          [ 4.9888e-07,  6.3326e-07,  6.1920e-07,  ...,  1.2629e-07,
            1.8429e-07,  2.0732e-07],
          ...,
          [ 5.5013e-07,  3.1735e-07,  4.1098e-07,  ...,  3.1079e-07,
            3.4928e-07,  3.4718e-07],
          [ 6.2982e-07,  4.0325e-07,  3.4432e-07,  ...,  4.8297e-07,
            6.4529e-07,  5.4214e-07],
          [ 7.1402e-07,  5.0883e-07,  4.4785e-07,  ...,  6.2946e-07,
            6.5617e-07,  5.0979e-07]],

         [[ 5.0878e-07,  6.8802e-07,  6.1782e-07,  ...,  2.2142e-07,
            2.1541e-07,  1.8464e-07],
          [ 4.2393e-07,  6.5220e-07,  6.2894e-07,  ...,  2.8318e-07,
            2.5690e-07,  2.3177e-07],
          [ 4.6649e-07,  6.4230e-07,  6.2854e-07,  ...,  1.3226e-07,
            2.2451e-07,  2.1060e-07]

In [30]:
# Sanity check for name-based freezing: every still-trainable parameter whose
# name contains 'backbone' is frozen; each parameter is printed afterwards.
for param_name, weight in model.named_parameters():
    is_trainable_backbone = 'backbone' in param_name and weight.requires_grad
    if is_trainable_backbone:
        weight.requires_grad = False
    print(weight)

Parameter containing:
tensor([[[[ 4.7132e-07,  6.3123e-07,  6.1915e-07,  ...,  2.9313e-07,
            2.1123e-07,  1.3036e-07],
          [ 4.8263e-07,  7.1548e-07,  7.1251e-07,  ...,  3.0581e-07,
            2.6611e-07,  2.3413e-07],
          [ 4.9888e-07,  6.3326e-07,  6.1920e-07,  ...,  1.2629e-07,
            1.8429e-07,  2.0732e-07],
          ...,
          [ 5.5013e-07,  3.1735e-07,  4.1098e-07,  ...,  3.1079e-07,
            3.4928e-07,  3.4718e-07],
          [ 6.2982e-07,  4.0325e-07,  3.4432e-07,  ...,  4.8297e-07,
            6.4529e-07,  5.4214e-07],
          [ 7.1402e-07,  5.0883e-07,  4.4785e-07,  ...,  6.2946e-07,
            6.5617e-07,  5.0979e-07]],

         [[ 5.0878e-07,  6.8802e-07,  6.1782e-07,  ...,  2.2142e-07,
            2.1541e-07,  1.8464e-07],
          [ 4.2393e-07,  6.5220e-07,  6.2894e-07,  ...,  2.8318e-07,
            2.5690e-07,  2.3177e-07],
          [ 4.6649e-07,  6.4230e-07,  6.2854e-07,  ...,  1.3226e-07,
            2.2451e-07,  2.1060e-07]

In [None]:
# Explore the model structure through the keys of its state dict.
# `state.keys()` is left as the cell's last expression so the notebook displays it.
state = model.state_dict()
state.keys()

In [31]:
# Summarise the model size: element count over all parameters, over the
# parameters with requires_grad set, and the difference between the two.
total_params = sum(weight.numel() for weight in model.parameters())
trainable_params = sum(weight.numel() for weight in model.parameters() if weight.requires_grad)
print(f"Total params: {total_params}")
print(f"Trainable params: {trainable_params}")
print(f"Non-Trainable params: {total_params - trainable_params}")

Total params: 199869589
Trainable params: 141725781
Non-Trainable params: 58143808


In [32]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-parameter size table for ``model`` and return the total size.

    One table row is emitted per entry of ``model.named_parameters()`` with
    the parameter name, its number of elements and whether it is trainable
    (``requires_grad``). The overall and the trainable element counts are
    printed below the table.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs; each parameter must provide
            ``numel()`` and ``requires_grad``.

    Returns:
        Total number of elements over all parameters, trainable or not.
    """
    table = PrettyTable(["Modules", "Parameters", "Trainable"])
    n_total = 0
    n_trainable = 0
    for param_name, tensor in model.named_parameters():
        size = tensor.numel()
        is_trainable = tensor.requires_grad
        # Frozen parameters are listed too; they only stay out of the trainable sum.
        table.add_row([param_name, size, is_trainable])
        n_total += size
        n_trainable += size if is_trainable else 0
    print(table)
    print(f"Total Params: {n_total}")
    print(f"Total Trainable Params: {n_trainable}")
    return n_total

# Print the per-parameter table for the loaded model; the returned total is
# shown as the cell's output value.
count_parameters(model)

+------------------------------------------+------------+-----------+
|                 Modules                  | Parameters | Trainable |
+------------------------------------------+------------+-----------+
|            backbone.0.weight             |    9408    |   False   |
|            backbone.1.weight             |     64     |   False   |
|             backbone.1.bias              |     64     |   False   |
|        backbone.4.0.conv1.weight         |    4096    |   False   |
|         backbone.4.0.bn1.weight          |     64     |   False   |
|          backbone.4.0.bn1.bias           |     64     |   False   |
|        backbone.4.0.conv2.weight         |   36864    |   False   |
|         backbone.4.0.bn2.weight          |     64     |   False   |
|          backbone.4.0.bn2.bias           |     64     |   False   |
|        backbone.4.0.conv3.weight         |   16384    |   False   |
|         backbone.4.0.bn3.weight          |    256     |   False   |
|          backbone.

199869589

In [None]:
# Inspect the (untargeted) input dataset in the napari-based viewer;
# no R-CNN transform is passed to the viewer.
datasetviewer = DatasetViewerSingle(
    dataset,
    rccn_transform=None,
)
datasetviewer.napari()

In [None]:
# inference (cpu)
# Run the model over every input image and write one JSON file with the
# predictions per image into params['PREDICTIONS_PATH'].

# The output directory does not depend on the sample: resolve and create it once.
save_dir = pathlib.Path(os.getcwd()) / params['PREDICTIONS_PATH']
save_dir.mkdir(parents=True, exist_ok=True)

model.eval()
with torch.no_grad():  # prediction only, no gradients required
    for sample in dataloader_prediction:
        x, x_name = sample
        # The loader uses batch_size=1, so the single prediction is at index 0.
        pred = model(x)
        pred = {key: value.numpy() for key, value in pred[0].items()}
        name = pathlib.Path(x_name[0])
        pred_list = {key: value.tolist() for key, value in pred.items()}  # numpy arrays are not serializable -> .tolist()
        # Use only the file-name part so the JSON always ends up inside
        # save_dir, even if the sample name carries a directory or absolute
        # path (presumably it is a bare file name; then nothing changes).
        save_json(pred_list, path=save_dir / name.with_suffix('.json').name)

In [None]:
# Prediction files (json) from the predictions directory, in sorted order.
predictions = sorted(get_filenames_of_path(pathlib.Path(os.getcwd()) / params['PREDICTIONS_PATH']))

In [None]:
# Pair the test images with their saved predictions, using the prediction
# files as the targets.
iou_threshold = 0.25  # handed to apply_nms
score_threshold = 0.3  # handed to apply_score_threshold

# The first two steps are the same preprocessing as for inference; the last
# two are applied to the target only (input=False, target=True).
transforms_prediction = ComposeDouble(
    [
        FunctionWrapperDouble(np.moveaxis, source=-1, destination=0),
        FunctionWrapperDouble(normalize_01),
        FunctionWrapperDouble(apply_nms, input=False, target=True, iou_threshold=iou_threshold),
        FunctionWrapperDouble(apply_score_threshold, input=False, target=True, score_threshold=score_threshold),
    ]
)

dataset_prediction = ObjectDetectionDataSet(
    inputs=inputs,
    targets=predictions,
    transform=transforms_prediction,
    use_cache=False,
)

In [None]:
# Color lookup handed to the prediction viewer; the only entry maps key 1 to 'red'.
color_mapping = {1: 'red'}

In [None]:
# visualize predictions
# Open the prediction dataset in the viewer; color_mapping is passed along
# (presumably label -> box color, verify against DatasetViewer).
datasetviewer_prediction = DatasetViewer(dataset_prediction, color_mapping)
datasetviewer_prediction.napari()
# add text properties gui
# NOTE(review): shape_layer is read only after napari() has been called;
# presumably the layer is created by that call, so keep this order.
datasetviewer_prediction.gui_text_properties(datasetviewer_prediction.shape_layer)