In [None]:
%matplotlib inline
import os
import sys
from functools import partial

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import skimage.color as sk_color
import skimage.morphology as sk_morph
import scipy.io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

sys.path.append('../moco-v3/')
import suppix_utils.datasets_seeds as datasets

sys.path.append('../')
import cast_models.cast

### Define models and load pre-trained weights

In [None]:
# Name of the model factory to look up in `cast_models.cast`, and the
# pre-training checkpoint whose encoder weights are loaded into it.
MODEL_CLASS_NAME = 'cast_base'
CHECKPOINT_PATH = '../snapshots/moco/imagenet1k/cast_base/checkpoint_0099.pth.tar'

# Parameter-name prefix of the encoder inside the checkpoint's state dict.
ENCODER_PREFIX = 'module.base_encoder.'

model = cast_models.cast.__dict__[MODEL_CLASS_NAME]().cuda()

# Deserialize on the CPU: only part of `state_dict` is used below, and
# `load_state_dict` copies the selected tensors onto the model's device, so
# there is no need to place the whole checkpoint on the GPU.
ckpt = torch.load(CHECKPOINT_PATH, map_location='cpu')

# Keep only the encoder weights (skipping any key containing 'head') and strip
# the prefix so the keys match the bare model. `startswith` guarantees the
# prefix really is at the front of the key before it is sliced off.
state_dict = {
    k[len(ENCODER_PREFIX):]: v
    for k, v in ckpt['state_dict'].items()
    if k.startswith(ENCODER_PREFIX) and 'head' not in k
}

# strict=False tolerates key mismatches; the printed result lists any
# missing/unexpected keys so they can be inspected.
msg = model.load_state_dict(state_dict, strict=False)
print(msg)
model = model.eval()

### Prepare dataloader for loading ImageNet images

In [None]:
# Root folder of the images to segment (read through `datasets.ImageFolder`).
DATA_ROOT = '../data/PartImageNet/images/val/'
# Alternative input folder:
# DATA_ROOT = './demo_images'
# Root folder under which the predicted segmentations are written.
SAVE_ROOT = '../pred_segs/'

class ReturnIndexDataset(datasets.ImageFolder):
    """`datasets.ImageFolder` variant that also yields the sample index.

    The index lets downstream code map a batch element back to its entry in
    the dataset (e.g. to recover the source file path).
    """

    def __getitem__(self, idx):
        """Return `(img, seg, lab, idx)` for the sample at position `idx`."""
        img, seg, lab = super().__getitem__(idx)
        return img, seg, lab, idx

# Per-channel mean/std normalization. It is handed to the dataset rather than
# appended to the image pipeline below (presumably so the dataset can apply it
# after computing superpixels on the un-normalized image -- TODO confirm).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Deterministic preprocessing: resize, take the central 224x224 crop, and
# convert to a tensor.
augmentation = [transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()]

# Dataset options, including the superpixel settings.
dataset_kwargs = dict(
    normalize=normalize,
    n_segments=196,
    compactness=10.0,
    blur_ops=None,
    scale_factor=1.0,
)
train_dataset = ReturnIndexDataset(DATA_ROOT, transforms.Compose(augmentation), **dataset_kwargs)

# One image per batch, in dataset order, keeping every sample.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Number of superpixels requested per image; keep in sync with the
# `n_segments` value passed to the dataset.
N_SEGMENTS = 196
# Grouping levels to export, ordered from the finest to the coarsest.
LEVELS = (1, 2, 3, 4)

# One output directory per level, created once up front instead of once per
# (level, sample) pair inside the loop.
save_roots = {
    level: os.path.join(SAVE_ROOT, MODEL_CLASS_NAME, 'level{:d}'.format(level))
    for level in LEVELS
}
for save_root in save_roots.values():
    os.makedirs(save_root, exist_ok=True)

# Pure inference: disable autograd so no graph is recorded for the forward pass.
with torch.no_grad():
    for img, suppixel, _, img_inds in train_loader:
        img = img.cuda()  # input images
        suppixel = suppixel.cuda()  # superpixel index map for each image

        # Forward pass to return intermediate groupings
        intermediates = model.forward_features(img, suppixel)

        # prev_labels['level<k>'][b] holds, for sample b, the level-k group
        # index of every finest-level unit; it is used to propagate the next
        # (coarser) level back down to those units.
        prev_labels = {}
        for level in LEVELS:
            level_key = 'level{:d}'.format(level)
            prev_labels[level_key] = []

            # Iterate through the mini-batch
            for b in range(img.shape[0]):
                # Grouping logit for the current level; the argmax over the
                # last dimension picks one group per input unit.
                logit = intermediates['logit{:d}'.format(level)][b]
                group_ids = torch.argmax(logit, dim=-1)

                if level == 1 and len(group_ids) < N_SEGMENTS:
                    # Fewer level-1 entries than requested superpixels: pad by
                    # repeating the last assignment so the per-superpixel
                    # gather below has enough entries.
                    # NOTE(review): the target length is counted over the
                    # whole batch tensor, which equals the per-sample count
                    # only for batch_size=1 -- confirm before using larger
                    # batches.
                    npad = torch.unique(suppixel).shape[0] - logit.shape[0]
                    if npad > 0:
                        # `expand` keeps dtype and device, so the result stays
                        # a valid integer index tensor.
                        group_ids = torch.cat([group_ids, group_ids[-1:].expand(npad)])

                # Gather coarser grouping at the current level, e.g.:
                #   level-1 group of each level-0 unit:  [1, 2, 2, 0, 1, 1, 0]
                #   level-2 group of each level-1 group: [0, 1, 0]
                #   => level-2 group of each level-0 unit: [1, 0, 0, 0, 1, 1, 0]
                if level > 1:
                    finer_ids = prev_labels['level{:d}'.format(level - 1)][b]
                    group_ids = torch.gather(group_ids, 0, finer_ids.view(-1))
                prev_labels[level_key].append(group_ids)

                # Look up the group of every pixel through its superpixel index
                # and restore the spatial layout of the superpixel map.
                seg_map = torch.gather(group_ids, 0, suppixel[b].view(-1))
                seg_map = seg_map.view(suppixel.shape[-2:])

                # Save the segmentation as <SAVE_ROOT>/<model>/level<k>/<stem>.npy.
                # basename/splitext strips only the directory and the final
                # extension, so dotted file names keep their full stem.
                file_path = train_dataset.samples[int(img_inds[b])][0]
                file_name = os.path.splitext(os.path.basename(file_path))[0]
                np.save(os.path.join(save_roots[level], '{}.npy'.format(file_name)),
                        seg_map.cpu().numpy())