In [None]:
%matplotlib inline
import os
import sys
from functools import partial

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import skimage.color as sk_color
import skimage.morphology as sk_morph
import scipy.io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

sys.path.append('../moco-v3/')
import suppix_utils.datasets_seeds as datasets

sys.path.append('../vits_vis_utils/')
import vits_vis_attn
import segsort
from clustering import SphericalKMeans

### Define models and load pre-trained weights

In [None]:
# Build the backbone named by MODEL_CLASS_NAME, load the pre-trained weights
# from CHECKPOINT_PATH, and switch the model to evaluation mode.
MODEL_CLASS_NAME = "vit_base"
CHECKPOINT_PATH = '../ckpt/cast_base.pth.tar'

# Look the constructor up by name in the vits_vis_attn module and move the
# freshly built model to the GPU.
model = vits_vis_attn.__dict__[MODEL_CLASS_NAME]().cuda()

# The model is wrapped in DataParallel only for the duration of the load.
# NOTE(review): presumably the checkpoint keys carry the "module." prefix that
# DataParallel adds -- confirm against the checkpoint.
model = nn.DataParallel(model)
if os.path.isfile(CHECKPOINT_PATH):
    ckpt = torch.load(CHECKPOINT_PATH, map_location='cuda:0')
    # strict=True: fail loudly on any missing or unexpected key instead of
    # silently leaving part of the network untouched.
    model.load_state_dict(ckpt['state_dict'], strict=True)
else:
    # Do not fail hard (the notebook can still be run end-to-end), but make it
    # obvious that the results below do not come from the pre-trained model.
    print("WARNING: checkpoint '{}' not found; the model keeps the weights it "
          "was constructed with.".format(CHECKPOINT_PATH))

# Unwrap again so that the custom method forward_token_features can be called
# directly on the model, and disable training-time behaviour.
model = model.module
model = model.eval()

### Prepare dataloader for loading ImageNet images

In [None]:
# Root folder handed to the dataset below. The commented-out line is an
# alternative location (PartImageNet validation images); the active one points
# at the demo images next to this notebook.
# DATA_ROOT = '../data/PartImageNet/images/val/'
DATA_ROOT = './demo_images'
# Root folder under which the predicted segmentations are saved
# (one sub-folder per model name and hierarchy level).
SAVE_ROOT = '../pred_segs/'

class ReturnIndexDataset(datasets.ImageFolder):
    """Image-folder dataset that also reports the index of each sample.

    The parent's ``__getitem__`` must return exactly three items; they are
    passed through unchanged and the requested index is appended, so callers
    can map a sample back to its entry in the dataset.
    """

    def __getitem__(self, idx):
        """Return ``(img, seg, lab, idx)`` for the sample at position ``idx``."""
        image, segment, target = super().__getitem__(idx)
        return image, segment, target, idx

# Per-channel mean / std used by the normalisation transform. The transform is
# handed to the dataset through its `normalize` keyword rather than being
# appended to the transform list below.
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD)

# Deterministic preprocessing: resize to 256, take the central 224x224 crop,
# and convert to a tensor.
augmentation = [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]

# Remaining keyword arguments forwarded to the project dataset class.
dataset_options = dict(
    n_segments=196,
    compactness=10.0,
    blur_ops=None,
    scale_factor=1.0,
)
train_dataset = ReturnIndexDataset(
    DATA_ROOT, transforms.Compose(augmentation), normalize=normalize, **dataset_options)

# One image per batch, in dataset order, keeping every sample.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=1, shuffle=False, drop_last=False)

In [None]:
# Infer a hierarchy of segmentations for every image in train_loader and save
# each level as "<SAVE_ROOT>/<MODEL_CLASS_NAME>/level<k>/<image stem>.npy".
#
# Per batch:
#   1. run the model to obtain one feature vector per patch token,
#   2. bilinearly upsample the token grid by PATCH_SIZE so that every pixel
#      has a feature vector,
#   3. cluster the pixels with K-Means at the finest granularity, then
#      repeatedly cluster the per-segment mean features ("prototypes") to
#      obtain the coarser levels.
#
# After the loop, `segmentations` maps each granularity to an array of shape
# (batch, height, width) for the last processed batch.

# Upsampling factor from the token grid back to pixel resolution
# (a 224x224 input gives a 14x14 token grid).
PATCH_SIZE = 16

# Pure inference: disable autograd so no graph is kept for the forward pass
# and the clustering.
with torch.no_grad():
    # The third item of each batch is named `class_label` (not `label`) so it
    # is not silently overwritten by the segmentation map computed below.
    for i, (img, suppixel, class_label, img_inds) in enumerate(train_loader):
        img = img.cuda() # input images

        # Forward pass to return token features
        x = model.forward_token_features(img)

        # Rescale token features back to image resolution
        grid_h = img.shape[-2] // PATCH_SIZE
        grid_w = img.shape[-1] // PATCH_SIZE
        x = x.unflatten(1, (grid_h, grid_w)).permute(0, 3, 1, 2)
        x = F.interpolate(x, scale_factor=(PATCH_SIZE, PATCH_SIZE), mode='bilinear')
        # Spatial size of the upsampled map; used to reshape the flat labels.
        out_h, out_w = x.shape[-2:]
        x = x.permute(0, 2, 3, 1).flatten(1, 2)

        # Perform spherical K-Means to infer segmentations
        # NOTE(review): norm_x is computed but not consumed in this cell -- the
        # clustering below receives the un-normalised `x`. Confirm that
        # SphericalKMeans normalises its input internally.
        norm_x = F.normalize(x, dim=-1)
        segmentations = {}
        num_segs = [64, 32, 16, 8] # We use the same granularities as CAST
        # Iterate through the finest to the coarsest scales
        for num_ind, num_seg in enumerate(num_segs):
            # The output folder depends only on the level, so create it once
            # per level instead of once per image.
            save_root = os.path.join(SAVE_ROOT, MODEL_CLASS_NAME, 'level{:d}'.format(num_ind + 1))
            os.makedirs(save_root, exist_ok=True)

            level_labels = []
            # Iterate through the mini-batch
            for b in range(img.shape[0]):
                # Set up K-Means clustering arguments
                kmeans = SphericalKMeans(K=num_seg, concentration=5, iterations=30)
                if num_ind > 0:
                    # Gather coarser grouping at the current level
                    # The level-1 grouping index for each level-0 group is [1, 2, 2, 0, 1, 1, 0]
                    # The level-2 grouping index for each level-1 group is [0, 1, 0]
                    # We infer the level-2 grouping for each level-0 group as [1, 0, 0, 0, 1, 1, 0]

                    # To get coarser grouping, we calculate "prototypes"--the mean-average token features,
                    # within each finer group
                    prev_label = torch.from_numpy(
                        segmentations[num_segs[num_ind - 1]][b]).view(-1).to(x.device)
                    prototypes = segsort.calculate_prototypes_from_labels(x[b], prev_label)
                    # Perform K-Means clustering
                    kmeans_label, _ = kmeans(prototypes)
                    # Re-index the cluster ids so they are contiguous (0..k-1)
                    _, kmeans_label = torch.unique(kmeans_label, return_inverse=True)
                    # Broadcast each finer segment's new id back to its pixels
                    kmeans_label = kmeans_label[prev_label]
                else:
                    kmeans_label, _ = kmeans(x[b])
                    _, kmeans_label = torch.unique(kmeans_label, return_inverse=True)
                label = kmeans_label.view(out_h, out_w)
                label_np = label.cpu().numpy()

                # Derive the output name from the image path: strip the folder
                # and only the final extension, so names that contain extra
                # dots are neither truncated nor made to collide.
                file_name = train_dataset.samples[int(img_inds[b])][0]
                file_name = os.path.splitext(os.path.basename(file_name))[0]
                with open(os.path.join(save_root, '{}.npy'.format(file_name)), 'wb') as f:
                    np.save(f, label_np)

                level_labels.append(label_np)
            # Stack per-image maps so the next (coarser) level can index them by b
            segmentations[num_seg] = np.stack(level_labels, axis=0)