## Visualization of the superpixel images, graphs, edge connections and node labels

### Source of the superpixels
Superpixels are generated using the file `'./scripts/COCO/generate_cocosuperpixels_raw.py'`

In [None]:
import random
from torchvision import transforms, datasets

import os
import pickle
from scipy.spatial.distance import cdist
import scipy.io as sio
from scipy import ndimage
import numpy as np

import dgl
import torch
from torch.utils import data
from PIL import Image
import time
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams.update({'font.size': 22})

### Prepare COCO Images and mask

### Automatically downloading dataset from the link in this repo
The download links are taken from https://cocodataset.org/#download.  
The dataset contains 118K train images and 5K val images.

#### Before proceeding further:
- Download the repo https://github.com/cocodataset/cocoapi in the current directory 
- Then run `make` inside the `cocoapi/PythonAPI` directory  

In [None]:
# Fetch the COCO 2017 train images, val images and annotations.
# Each section is keyed only on the presence of the zip archive in the current
# directory: if the archive exists, nothing is downloaded, extracted or moved.
# The `!` lines are IPython shell escapes, so this cell only runs in a notebook.

# Train images -> cocoapi/images/train2017
if not os.path.isfile('train2017.zip'):
    print('downloading..')
    !curl http://images.cocodataset.org/zips/train2017.zip -o train2017.zip
    !unzip train2017.zip
    !mv train2017 cocoapi/images/train2017
else:
    print('File already downloaded')
    
# Val images -> cocoapi/images/val2017
if not os.path.isfile('val2017.zip'):
    print('downloading..')
    !curl http://images.cocodataset.org/zips/val2017.zip -o val2017.zip
    !unzip val2017.zip
    !mv val2017 cocoapi/images/val2017
else:
    print('File already downloaded')
    
# Annotations (the archive extracts to `annotations/`) -> cocoapi/annotations
if not os.path.isfile('annotations_trainval2017.zip'):
    print('downloading..')
    !curl http://images.cocodataset.org/annotations/annotations_trainval2017.zip -o annotations_trainval2017.zip
    !unzip annotations_trainval2017.zip
    !mv annotations cocoapi/annotations
else:
    print('File already downloaded')

In [None]:
# Render matplotlib figures inline in the notebook (IPython magic).
%matplotlib inline
# COCO annotation API, imported from the cocoapi checkout in the current directory.
from cocoapi.PythonAPI.pycocotools.coco import COCO
import skimage.io as io
import pylab
# Default figure size and font size for the plots in this notebook.
pylab.rcParams['figure.figsize'] = (8.0, 10.0)
pylab.rcParams.update({'font.size': 22})
from tqdm import tqdm

In [None]:
"""
    COCO categories: 
    person bicycle car motorcycle airplane bus train truck boat traffic light fire hydrant stop
    sign parking meter bench bird cat dog horse sheep cow elephant bear zebra giraffe backpack
    umbrella handbag tie suitcase frisbee skis snowboard sports ball kite baseball bat baseball
    glove skateboard surfboard tennis racket bottle wine glass cup fork knife spoon bowl banana
    apple sandwich orange broccoli carrot hot dog pizza donut cake chair couch potted plant bed
    dining table toilet tv laptop mouse remote keyboard cell phone microwave oven toaster sink
    refrigerator book clock vase scissors teddy bear hair drier toothbrush
"""

class COCO_Images_Masks(data.Dataset):
    """In-memory sample of COCO 2017 images with per-pixel category masks.

    On construction, the first ``sample_length`` images listed in the
    instance-annotation file of the chosen split are read from disk as RGB
    arrays, and a mask is built for each one from its instance annotations.

    Args:
        mode: ``'train'`` or ``'val'``; selects ``train2017`` / ``val2017``.
        root: directory containing ``annotations/`` and ``images/<split>/``.
        sample_length: maximum number of images to load. ``None`` loads every
            image in the annotation file. The value is clamped to the number
            of images available, so a small annotation file no longer raises
            ``IndexError``.

    Indexing returns ``(image_array, mask_array)``; ``len()`` is the number of
    images actually loaded.
    """

    def __init__(self, mode, root='./cocoapi', sample_length=50):
        self.root = root
        self.mode = mode
        self.sample_length = sample_length

        # Containers that this class initialises but never fills itself.
        self.all_superpixels = []
        self.all_rag_boundary_graphs = []
        self.all_sp_data = []
        self.all_sp_node_labels = []

        # Superpixel settings; only stored here, not used by this class.
        self.n_sp = 100
        self.compactness = 10
        self.seed = 41
        self.out_dir = '.'
        self.dataset = 'COCO'

        self.args = self.mode, self.seed, self.n_sp, self.compactness
        self.img_list = []
        self.mask_list = []

        self.num_images = self._pack_images_masks(mode)

    def _pack_images_masks(self, mode):
        """Load images and build masks; return how many samples were loaded."""
        # in this paper, we train on the train set and evaluate on the val set
        assert mode in ['train', 'val']

        dataType = 'val2017' if mode == 'val' else 'train2017'
        annFile = '{}/annotations/instances_{}.json'.format(self.root, dataType)

        # initialize COCO api for instance annotations
        coco = COCO(annFile)

        imgIds = coco.getImgIds()
        cat_ids = coco.getCatIds()
        all_imgs = coco.loadImgs(imgIds)

        # Never index past the images that actually exist.
        if self.sample_length is None:
            num_to_load = len(all_imgs)
        else:
            num_to_load = max(0, min(self.sample_length, len(all_imgs)))

        for img_meta_info in tqdm(all_imgs[:num_to_load]):
            # Read from the local copy rather than fetching `coco_url` each time.
            img_path = os.path.join(self.root, 'images', dataType,
                                    img_meta_info['file_name'])
            # `convert` returns a fully loaded copy, so the file can be closed.
            with Image.open(img_path) as img_file:
                img = img_file.convert('RGB')

            anns_ids = coco.getAnnIds(imgIds=img_meta_info['id'],
                                      catIds=cat_ids, iscrowd=None)
            anns = coco.loadAnns(anns_ids)

            # Where annotations overlap, the larger category id wins.
            mask = np.zeros((img_meta_info['height'], img_meta_info['width']))
            for ann in anns:
                mask = np.maximum(mask, coco.annToMask(ann) * ann['category_id'])

            self.img_list.append(np.array(img))
            self.mask_list.append(mask)
        return num_to_load

    def __getitem__(self, index):
        return self.img_list[index], self.mask_list[index]

    def __len__(self):
        return self.num_images

# Load the validation sample; the plotting functions below read val_set.img_list.
val_set = COCO_Images_Masks('val')

In [None]:
# NOTE(review): this repeats the instantiation at the end of the previous cell,
# so running both cells loads the validation sample twice.
val_set = COCO_Images_Masks('val')

### Functions definition for graph construction

In [None]:
def sigma(dists, kth=8):
    """Return a per-row scale computed from each row's smallest distances.

    For every row of ``dists`` the ``kth + 1`` smallest entries are selected,
    summed and divided by ``kth``. When ``np.partition`` rejects ``kth``
    because a row has too few entries, a column of ones is used instead.

    Args:
        dists: 2-D array of distances, one row per node.
        kth: partition index passed to ``np.partition`` and divisor of the sum.

    Returns:
        Array of shape ``(dists.shape[0], 1)`` with ``1e-8`` added so that the
        scale is never exactly zero.
    """
    try:
        # After partitioning, columns 0..kth hold the kth+1 smallest values of
        # each row; the reversed slice picks exactly those columns.
        nearest = np.partition(dists, kth, axis=-1)[:, kth::-1]
        widths = nearest.sum(axis=1).reshape((nearest.shape[0], 1)) / kth
    except ValueError:
        # Fewer entries per row than kth requires: fall back to a placeholder.
        # The original author notes this value is not used for the final
        # edge list of such small graphs.
        n_rows = dists.shape[0]
        widths = np.ones(n_rows, dtype=int).reshape(n_rows, 1)

    # Epsilon keeps later divisions by this scale well defined.
    return widths + 1e-8

def compute_adjacency_matrix_images(coord, feat, use_feat=True, kth=8):
    """Build a symmetric Gaussian-kernel similarity matrix between nodes.

    Pairwise Euclidean distances between coordinates (and, optionally, between
    features) are each divided by the per-row scale from ``sigma`` and squared;
    the similarity is ``exp`` of the negated sum.

    Args:
        coord: node coordinates; reshaped to ``(-1, 2)``.
        feat: node features, passed to ``cdist`` as-is; only read when
            ``use_feat`` is true.
        use_feat: include the feature-distance term in the kernel.
        kth: neighbour count forwarded to ``sigma``. It was previously accepted
            but ignored; the default matches ``sigma``'s own default, so calls
            that do not pass ``kth`` are unaffected.

    Returns:
        Square array with one row per node, symmetrised and with a zero
        diagonal (no self-similarity).
    """
    coord = coord.reshape(-1, 2)

    # Coordinate term of the exponent.
    c_dist = cdist(coord, coord)
    exponent = (c_dist / sigma(c_dist, kth)) ** 2

    if use_feat:
        # Feature term of the exponent.
        f_dist = cdist(feat, feat)
        exponent = exponent + (f_dist / sigma(f_dist, kth)) ** 2

    A = np.exp(-exponent)

    # The per-row scale makes A asymmetric; average it with its transpose.
    A = 0.5 * (A + A.T)
    np.fill_diagonal(A, 0)
    return A


def compute_edges_list(A, kth=8+1):
    """Select, for every node, the indices and weights of its most similar nodes.

    Args:
        A: square similarity matrix, one row per node.
        kth: size of the top group taken from each row. ``kth - 1`` columns
            are returned per node when the graph has more than ``kth`` nodes.

    Returns:
        ``(knns, knn_values)``: neighbour indices and the matching entries of
        ``A``, both with one row per node. Graphs with at most ``kth`` nodes
        are fully connected without self loops; a single-node graph returns
        ``[[0]]`` and ``A`` itself.
    """
    num_nodes = A.shape[0]

    # Compare against kth rather than a literal 9, so a non-default kth can
    # never produce a non-positive partition index.
    if num_nodes > kth:
        new_kth = num_nodes - kth
        # Columns from new_kth onwards hold the kth largest entries of a row.
        # NOTE(review): the slice drops the last of those columns, and
        # argpartition leaves that group unordered, so which neighbour is
        # dropped is not guaranteed - confirm this is intended.
        knns = np.argpartition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
        # Gather values at the selected indices so they always line up.
        knn_values = np.take_along_axis(A, knns, axis=-1)
    else:
        # handling for graphs with less than kth nodes
        # in such cases, the resulting graph will be fully connected
        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(num_nodes, num_nodes)
        knn_values = A

        # removing self loop
        if num_nodes != 1:
            off_diagonal = knns != np.arange(num_nodes)[:, None]
            knn_values = A[off_diagonal].reshape(num_nodes, -1)
            knns = knns[off_diagonal].reshape(num_nodes, -1)
    return knns, knn_values

### `SuperPixDGL` class for reading the superpixel files and constructing graphs

In [None]:
class SuperPixDGL(torch.utils.data.Dataset):
    """Dataset of DGL graphs built from pickled COCO superpixel samples.

    Args:
        data_dir: directory that contains the ``sample/`` folder of pickles.
        dataset: accepted for call compatibility; not used by this class.
        split: split name substituted into the pickle file names.
        graph_format: how edges are built -
            ``'edge_wt_only_coord'``: k-NN on superpixel coordinates,
            ``'edge_wt_coord_feat'``: k-NN on coordinates and mean-pixel features,
            ``'edge_wt_region_boundary'``: edges from the pickled region
            boundary graphs.

    Indexing returns ``(graph, node_labels)``.
    """

    def __init__(self,
                 data_dir,
                 dataset,
                 split,
                 graph_format='edge_wt_only_coord'):
        assert graph_format in ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary']
        self.split = split
        self.graph_lists = []

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(data_dir, 'sample/COCO_500sp_%s_superpixels.pkl' % split), 'rb') as f:
            self.superpixels = pickle.load(f)

        with open(os.path.join(data_dir, 'sample/COCO_500sp_%s.pkl' % split), 'rb') as f:
            self.labels, self.sp_data = pickle.load(f)
            self.sp_node_labels = self.labels

        if graph_format == 'edge_wt_region_boundary':
            with open(os.path.join(data_dir, 'sample/COCO_500sp_%s_rag_boundary_graphs.pkl' % split), 'rb') as f:
                self.region_boundary_graphs = pickle.load(f)

        self.graph_format = graph_format
        self.n_samples = len(self.labels)

        self._prepare()

    def _prepare(self):
        """Precompute node features and, for k-NN formats, adjacency and edges."""
        print("preparing %d graphs for the %s set..." % (self.n_samples, self.split.upper()))
        self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], []

        # Coordinates are normalised only when an `img_size` attribute has been
        # set on the instance. This class never sets it (the original code
        # notes that image sizes vary), so by default they are left as stored.
        img_size = getattr(self, 'img_size', None)

        for index, sample in enumerate(self.sp_data):
            mean_px, coord = sample[:2]

            if img_size is not None:
                coord = coord / img_size

            if self.graph_format == 'edge_wt_coord_feat':
                A = compute_adjacency_matrix_images(coord, mean_px) # using super-pixel locations + features
                edges_list, edge_values_list = compute_edges_list(A)
            elif self.graph_format == 'edge_wt_only_coord':
                A = compute_adjacency_matrix_images(coord, mean_px, False) # using only super-pixel locations
                edges_list, edge_values_list = compute_edges_list(A)
            elif self.graph_format == 'edge_wt_region_boundary':
                # Edges come from the pickled boundary graphs in __getitem__.
                A, edges_list, edge_values_list = None, None, None

            N_nodes = mean_px.shape[0]

            # Node feature = flattened mean-pixel values followed by (2-D) coordinates.
            mean_px = mean_px.reshape(N_nodes, -1)
            coord = coord.reshape(N_nodes, 2)
            x = np.concatenate((mean_px, coord), axis=1)

            if edge_values_list is not None:
                edge_values_list = edge_values_list.reshape(-1)

            self.node_features.append(x)
            self.edge_features.append(edge_values_list)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """Build and return ``(graph, node_labels)`` for sample ``idx``.

        For the region-boundary format this also stores the dense adjacency
        matrix of the built graph in ``self.Adj_matrices[idx]``.
        """
        if self.graph_format == 'edge_wt_region_boundary':
            if self.node_features[idx].shape[0] == 1:
                # handling for 1 node where the self loop would be the only edge
                # since, VOC Superpixels has few samples (5 samples) with only 1 node
                g = dgl.DGLGraph()
                g.add_nodes(self.node_features[idx].shape[0])
                g = dgl.add_self_loop(g)
                # dummy edge feat since no actual edge present
                g.edata['feat'] = torch.zeros(1, 2) # 1 edge and 2 feat dim
                self.Adj_matrices[idx] = g.adjacency_matrix().to_dense().numpy()
            else:
                g = dgl.from_networkx(self.region_boundary_graphs[idx].to_directed(),
                                  edge_attrs=['weight', 'count'])
                # Merge the two edge attributes into one 2-column feature.
                g.edata['feat'] = torch.stack((g.edata['weight'], g.edata['count']),-1)
                del g.edata['weight'], g.edata['count']
                self.Adj_matrices[idx] = g.adjacency_matrix().to_dense().numpy()
        else:
            g = dgl.DGLGraph()
            g.add_nodes(self.node_features[idx].shape[0])
            # One batch of outgoing edges per source node, skipping self loops.
            for src, dsts in enumerate(self.edges_lists[idx]):
                g.add_edges(src, dsts[dsts!=src])

        g.ndata['feat'] = torch.Tensor(self.node_features[idx])

        return g, self.sp_node_labels[idx]

### Only coordinates for knn graph construction
This is done by setting the `graph_format` option to `'edge_wt_only_coord'`.

In [None]:
# Load the 'val' split for sample visualization, with k-NN edges built from
# superpixel coordinates only, and report how long construction takes.
graph_format = 'edge_wt_only_coord'
tt = time.time()
data_only_coord_knn = SuperPixDGL("./", 
                                  dataset='COCO',
                                  split='val', 
                                  graph_format=graph_format)

print("Time taken: {:.4f}s".format(time.time()-tt))

### Both coordinates and features for knn graph construction

In [None]:
# Load the 'val' split with k-NN edges built from superpixel coordinates and
# mean-pixel features, and report how long construction takes.
graph_format = 'edge_wt_coord_feat'
tt = time.time()
data_coord_feat_knn = SuperPixDGL("./", 
                                  dataset='COCO',
                                  split='val', 
                                  graph_format=graph_format)

print("Time taken: {:.4f}s".format(time.time()-tt))

### Region Boundary based graph construction with variable edges for every node

In [None]:
# Load the 'val' split with edges taken from the pickled region boundary
# graphs, and report how long construction takes.
graph_format = 'edge_wt_region_boundary'
tt = time.time()
data_region_boundary = SuperPixDGL("./", 
                                  dataset='COCO',
                                  split='val', 
                                  graph_format=graph_format)

print("Time taken: {:.4f}s".format(time.time()-tt))

### Superpixels plot function definition

In [None]:
from scipy.spatial.distance import pdist, squareform
from pylab import rcParams
from skimage.segmentation import mark_boundaries

def show_image(plt, idx, alpha=1.0):
    """Draw the original image ``val_set.img_list[idx]`` on the given axes.

    Args:
        plt: the matplotlib axes to draw on (an axes object such as the one
            returned by ``Figure.add_subplot``, not the ``pyplot`` module).
        idx: index of the image in ``val_set.img_list``.
        alpha: opacity passed to ``imshow``. It was previously accepted but
            ignored; the default of 1.0 keeps the image fully opaque.
    """
    plt.imshow(val_set.img_list[idx], alpha=alpha)

    plt.axis('off')
    plt.title.set_text(" Original Image")

def plot_superpixels_graph(plt, data, idx, overlay=None):
    """Draw the superpixel graph of sample ``idx`` on the given axes.

    Args:
        plt: the matplotlib axes to draw on (an axes object, not ``pyplot``).
        data: dataset exposing ``sp_data``, ``sp_node_labels``, ``superpixels``
            and ``graph_format``, and returning ``(graph, labels)`` on indexing.
        idx: sample index, used for ``data`` and for ``val_set.img_list``.
        overlay: ``None`` draws only the graph (faint edges);
            ``"image"`` draws the image with superpixel boundaries and the
            graph on top (white edges);
            ``"slic"`` draws the image with boundaries and nodes, no edges.
    """
    with_edges = True
    sp_data = data.sp_data[idx]
    node_labels = data.sp_node_labels[idx]
    g = data[idx][0]

    x_coord = sp_data[1]
    # G only needs to supply the node set; the edges drawn below come from the
    # DGL graph. `from_numpy_array` replaces `from_numpy_matrix`, which was
    # removed in NetworkX 3.0.
    Y = squareform(pdist(x_coord, 'euclidean'))
    G = nx.from_numpy_array(Y)

    # rotate the coords by 90 degree
    rotated_pos = {node: (y, -x) for node, (x, y) in enumerate(x_coord.tolist())}

    if overlay is not None:
        # The boundary-marked image fully covers the axes, so the plain image
        # is not drawn separately underneath it.
        plt.imshow(mark_boundaries(val_set.img_list[idx], data.superpixels[idx],
                                   color=[0, 1, 0], outline_color=[0, 1, 0]))
        # reflect the graph on x-axis for overlaying
        rotated_pos = {node: (x, -y) for node, (x, y) in rotated_pos.items()}

    edge_list = torch.stack(g.edges(), 0).T.tolist()

    # Draw on the axes that was passed in rather than the implicit current axes.
    nx.draw_networkx_nodes(G, rotated_pos, node_color=node_labels, node_size=40, ax=plt)
    if with_edges and overlay == "image":
        nx.draw_networkx_edges(G, rotated_pos, edge_list, edge_color='w', alpha=0.65, ax=plt)
    elif with_edges and overlay != "slic":
        nx.draw_networkx_edges(G, rotated_pos, edge_list, alpha=0.15, ax=plt)

    if data.graph_format == 'edge_wt_region_boundary':
        graph_name = "rag-boundary"
    elif data.graph_format == 'edge_wt_only_coord':
        graph_name = "coo"
    else:
        graph_name = "coo-feat"

    if overlay == "slic":
        title = "SLIC SP (compactness=30)"
    elif overlay is None:
        title = "final `" + graph_name + "` graph"
    else:
        title = " `" + graph_name + "` graph overlay on SLIC SP"

    plt.title.set_text(title)

# Plotting sample superpixels, and graphs

In [None]:
num_samples_plot = 4
# To draw random samples instead of the fixed index below:
# sample_indices = np.random.choice(int(len(data_only_coord_knn)/2), num_samples_plot, replace=False)
sample_indices = np.array([38])  # Set manually
print(sample_indices)

os.makedirs('./coco_viz_files', exist_ok=True)


def _render_row(idx, row, draw_left, draw_right):
    """Draw one two-panel figure, save it as ``coco_<idx>_row<row>.pdf`` and show it.

    ``draw_left`` / ``draw_right`` are called with the axes they should draw on.
    Each row gets a fresh figure that is closed afterwards, so one row can
    never be drawn on top of a previous one, whatever the backend.
    """
    fig = plt.figure(figsize=(20, 7))
    # Each panel is drawn right after its axes is created, so that axes is
    # also the current one while the panel is drawn.
    draw_left(fig.add_subplot(121))
    draw_right(fig.add_subplot(122))

    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    fig.savefig('coco_viz_files/coco_' + str(idx) + '_row' + str(row) + '.pdf', dpi=300)
    plt.show()
    plt.close(fig)


for f_idx, idx in enumerate(sample_indices):
    print()
    print("------ Image ID ------ : ", idx)
    print("Num nodes: ", data_region_boundary[idx][0].number_of_nodes())
    print("Num edges Graph: edge_wt_only_coord: ", data_only_coord_knn[idx][0].number_of_edges())
    print("Num edges Graph: edge_wt_coord_feat: ", data_coord_feat_knn[idx][0].number_of_edges())
    print("Num edges Graph: Region boundary graph: ", data_region_boundary[idx][0].number_of_edges())

    # Row 1: original image | SLIC superpixels with graph nodes.
    _render_row(
        idx, 1,
        lambda ax, i=idx: show_image(ax, i),
        lambda ax, i=idx: plot_superpixels_graph(ax, data_only_coord_knn, i, overlay="slic"),
    )

    # Rows 2-4: graph overlaid on the image | graph alone, per graph format.
    graph_datasets = (data_only_coord_knn, data_coord_feat_knn, data_region_boundary)
    for row, graph_data in enumerate(graph_datasets, start=2):
        _render_row(
            idx, row,
            lambda ax, d=graph_data, i=idx: plot_superpixels_graph(ax, d, i, overlay="image"),
            lambda ax, d=graph_data, i=idx: plot_superpixels_graph(ax, d, i),
        )