In [1]:
from pathlib import Path

import numpy as np
import torch
%matplotlib inline
import matplotlib.pyplot as plt

from super_segmenter.models import UNet
from super_segmenter.params import Registry
from super_segmenter.training.data import PascalPartDataset
from super_segmenter.utils.metrics import intesection_over_union

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
PARAMS_SET = "u_net_baseline"
MODEL_PATH = Path("/root/models-small/unet/checkpoint-4488")
params = Registry.get_params(PARAMS_SET)

In [3]:
valid_dataset = PascalPartDataset(
    ids_path=params.data_params.val_ids_path,
    images_dir_path=params.data_params.images_dir_path,
    masks_dir_path=params.data_params.gt_masks_dir_path,
    img_size=params.model_params.image_size
)

In [4]:
torch.unique(valid_dataset[0][1])

tensor([0, 1, 2, 3, 4, 5, 6])

In [5]:
model = UNet(params.model_params).to("cpu")
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"))
)
model.eval()

UNet(
  (_encoder): UnetEncoder(
    (_layers): ModuleList(
      (0): UNetConvblock(
        (_block): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
          (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): ReLU()
          (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): UNetConvblock(
        (_block): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): ReLU()
          (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
    

In [6]:
def infer(image: torch.Tensor, transform=None) -> torch.Tensor:
   if transform:
      image = transform(image)
   
   probs = torch.softmax(model(image.unsqueeze(dim=0)), dim=1)
   return torch.argmax(probs, dim=1)

In [8]:
num_vals = 10
fig, axes = plt.subplots(num_vals, 3, figsize=(3*5, num_vals*5))

for i in range(num_vals):
    x_val, y_val = valid_dataset[np.random.randint(len(valid_dataset))]
    out = infer(image=x_val)
    
    image = x_val.permute(1, 2, 0).cpu().detach().numpy()
    label_class = y_val.cpu().detach()
    label_class_predicted = infer(image=x_val).squeeze().cpu().detach()
    mean_iou = intesection_over_union(
        gt_masks=label_class, pred_masks=label_class_predicted)
    
    axes[i, 0].imshow(image)
    axes[i, 0].set_title("Image")
    axes[i, 1].imshow(label_class.numpy())
    axes[i, 1].set_title("Label Class")
    axes[i, 2].imshow(label_class_predicted.numpy())
    axes[i, 2].set_title(f"Label Class - Predicted mIOU: {mean_iou:.5f}")