In [None]:
from PIL import Image
import numpy as np
from matplotlib import patches
from tqdm import tqdm
import pandas as pd
import fnmatch
from multiprocessing import Pool
from pathlib import Path
from functools import partial

In [None]:
# Fetch the TokenCut code; later cells install its requirements and put it on sys.path.
# TODO(review): pin a commit (e.g. `git -C TokenCut checkout <sha>`) so reruns use the same code.
!git clone https://github.com/YangtaoWANG95/TokenCut.git

In [None]:
!cd TokenCut && pip install -r requirements.txt

In [None]:
# Folders of pre-resized 384x384 JPEGs (per the dataset name), one per split.
train_img_folder, test_img_folder = (
    '../input/jpeg-happywhale-384x384/train_images-384-384/train_images-384-384/',
    '../input/jpeg-happywhale-384x384/test_images-384-384/test_images-384-384/',
)

In [None]:
import sys
# Put the cloned TokenCut checkout first on the import path so that the
# `networks`, `datasets`, `visualizations` and `object_discovery` imports in
# the next cell resolve to it.
sys.path.insert(0, './TokenCut')

In [None]:
import os
import argparse
import random
import pickle

import torch
import datetime
import torch.nn as nn
import numpy as np

from tqdm import tqdm
from PIL import Image

from networks import get_model
from datasets import ImageDataset, Dataset, bbox_iou
from visualizations import visualize_img, visualize_eigvec, visualize_predictions, visualize_predictions_gt 
from object_discovery import ncut 
import matplotlib.pyplot as plt
import time

# torch.multiprocessing.set_start_method('spawn')

In [None]:
# ---- TokenCut settings ------------------------------------------------------
PATCH_SIZE: int = 16           # ViT patch size; images are zero-padded up to a multiple of it
WHICH_FEATURES: str = 'k'      # last-block attention projection fed to ncut: 'q', 'k' or 'v'
ARCH: str = 'vit_base'         # only tested for the substring "vit" and used in exp_name;
                               # load_model() receives its architecture explicitly

TAU: float = 0.2               # passed to ncut as `tau` (presumably the graph-edge threshold — confirm in TokenCut)
EPS: float = 1e-5              # passed to ncut as `eps`
NO_BINARY_GRAPH: bool = False  # passed to ncut as `no_binary_graph`

In [None]:
# Run label such as "TokenCut-vit_base16_k": ViT archs get "<patch size>_<features>" appended.
exp_name = f"TokenCut-{ARCH}" + (f"{PATCH_SIZE}_{WHICH_FEATURES}" if "vit" in ARCH else "")

In [None]:
def _extract_qkv_features(model, img):
    """Return the last attention block's q, k or v token features (chosen by WHICH_FEATURES).

    A forward hook on ``blocks[-1].attn.qkv`` captures that layer's output
    while ``model.get_last_selfattention`` runs. The hook is removed right
    afterwards so that repeated calls do not pile hooks onto the shared model.

    Args:
        model: ViT backbone exposing ``get_last_selfattention`` and
            ``blocks[-1].attn.qkv``.
        img: padded image tensor (C, H, W), already on the model's device.

    Returns:
        Tensor reshaped to ``(nb_im, nb_tokens, -1)`` for the selected projection.
    """
    feat_out = {}

    def hook_fn_forward_qkv(module, input, output):
        feat_out["qkv"] = output

    qkv_layer = model._modules["blocks"][-1]._modules["attn"]._modules["qkv"]
    handle = qkv_layer.register_forward_hook(hook_fn_forward_qkv)

    with torch.no_grad():
        try:
            attentions = model.get_last_selfattention(img[None, :, :, :])
        finally:
            handle.remove()

        nb_im = attentions.shape[0]      # batch size
        nh = attentions.shape[1]         # number of heads
        nb_tokens = attentions.shape[2]  # number of tokens

        # Split the fused qkv output into (3, batch, heads, tokens, per-head dim);
        # -1 lets reshape infer the per-head dimension.
        qkv = (
            feat_out["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, -1)
            .permute(2, 0, 3, 1, 4)
        )
        selected = qkv[{"q": 0, "k": 1, "v": 2}[WHICH_FEATURES]]
        return selected.transpose(1, 2).reshape(nb_im, nb_tokens, -1)


def get_bounding_box(img_path, model):
    """Run TokenCut on one image and return its predicted object box.

    Args:
        img_path: image path (as ``str``), loaded through TokenCut's ``ImageDataset``.
        model: ViT backbone returned by ``load_model``.

    Returns:
        ``(img, pred)``. ``img`` is the zero-padded image tensor on the
        model's device; its two spatial dims are rounded up to multiples of
        ``PATCH_SIZE``. ``pred`` is the first value returned by ``ncut``;
        ``get_preds`` unpacks it as ``xmin, ymin, xmax, ymax``.

    Raises:
        ValueError: if ``ARCH`` is not a ViT, if ``WHICH_FEATURES`` is not
            'q', 'k' or 'v', or if the dataset yields no usable image.
    """
    if "vit" not in ARCH:
        raise ValueError("Unknown model.")
    if WHICH_FEATURES not in ("q", "k", "v"):
        raise ValueError(f"WHICH_FEATURES must be 'q', 'k' or 'v', got {WHICH_FEATURES!r}")

    dataset = ImageDataset(img_path)
    # Follow the model's device instead of hard-coding CUDA, so that the CPU
    # fallback in load_model works end to end.
    device = next(model.parameters()).device

    for inp in dataset.dataloader:
        img = inp[0]
        init_image_size = img.shape  # pre-padding size, passed to ncut

        im_name = dataset.get_image_name(inp[1])
        if im_name is None:
            continue

        # Zero-pad the two spatial dims up to multiples of PATCH_SIZE.
        padded_size = (
            img.shape[0],
            int(np.ceil(img.shape[1] / PATCH_SIZE) * PATCH_SIZE),
            int(np.ceil(img.shape[2] / PATCH_SIZE) * PATCH_SIZE),
        )
        padded = torch.zeros(padded_size)
        padded[:, : img.shape[1], : img.shape[2]] = img
        img = padded.to(device, non_blocking=True)

        # Token-grid size (TokenCut's naming: shape[-2] -> "w", shape[-1] -> "h").
        w_featmap = img.shape[-2] // PATCH_SIZE
        h_featmap = img.shape[-1] // PATCH_SIZE

        feats = _extract_qkv_features(model, img)
        scales = [PATCH_SIZE, PATCH_SIZE]

        pred, _objects, _foreground, _seed, _bins, _eigenvector = ncut(
            feats, [w_featmap, h_featmap], scales, init_image_size, TAU, EPS,
            im_name=im_name, no_binary_graph=NO_BINARY_GRAPH,
        )
        # The dataset is built from a single path, so presumably it holds one
        # image; return the result for the first usable entry.
        return img, pred

    raise ValueError(f"No usable image produced by ImageDataset({img_path!r})")

In [None]:
def convert_to_rect(img, label, color='b', linewidth=1):
    """Build an unfilled matplotlib Rectangle for a box given in pixel coordinates.

    NOTE: a later cell redefines ``convert_to_rect`` for *normalised* boxes
    (with ``linewidth`` as its third positional argument). Once that cell has
    run, this definition is shadowed.

    Args:
        img: not read; kept for call-site compatibility. Since it is not read,
            PIL images and arrays both work.
        label: ``(xmin, ymin, xmax, ymax)`` in pixels.
        color: edge colour.
        linewidth: edge width.

    Returns:
        ``matplotlib.patches.Rectangle`` anchored at ``(xmin, ymin)``.
    """
    xmin, ymin, xmax, ymax = label[0], label[1], label[2], label[3]
    return patches.Rectangle(
        (xmin, ymin),
        xmax - xmin,
        ymax - ymin,
        linewidth=linewidth, edgecolor=color, facecolor='none',
    )

In [None]:
def _get_bb(path):
    img, pred = get_bounding_box(str(path))
    return path.name, pred

In [None]:
def load_model(arch, patch_size):
    """Build TokenCut backbone ``arch`` with ``patch_size`` via ``get_model``.

    The model is requested on CUDA when a GPU is available, otherwise on CPU.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return get_model(arch, patch_size, target)

In [None]:
def get_preds(model, img_folder, n=None):
    """Run TokenCut over the ``*.jpg`` images in ``img_folder``.

    Images are processed in sorted file-name order. This makes runs with
    different models line up row for row, which ``show_img_grid`` relies on.

    Args:
        model: backbone from ``load_model``.
        img_folder: directory containing the images.
        n: if truthy, process only the first ``n`` images; ``None`` or 0 means all.

    Returns:
        List of ``[file_name, xmin, ymin, xmax, ymax]`` rows. x coordinates
        are divided by the image width and y coordinates by the height.
    """
    paths = sorted(Path(img_folder).glob('*.jpg'))
    if n:
        paths = paths[:n]

    preds = []
    for path in tqdm(paths):
        img, out_pred = get_bounding_box(str(path), model)
        # img is (C, H, W): x pairs with W = shape[2], y pairs with H = shape[1].
        # NOTE(review): img is zero-padded to a multiple of PATCH_SIZE, so for
        # sizes that are not a patch multiple this divides by the padded size.
        height, width = float(img.shape[1]), float(img.shape[2])
        xmin, ymin, xmax, ymax = out_pred
        preds.append([path.name, xmin / width, ymin / height, xmax / width, ymax / height])
    return preds

In [None]:
def show_preds(preds):
    """Tabulate ``get_preds`` rows as a DataFrame (image, xmin, ymin, xmax, ymax)."""
    column_names = ['image', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(data=preds, columns=column_names)

In [None]:
def convert_to_rect(img, label, linewidth=1, color='b'):
    """Build an unfilled matplotlib Rectangle for a normalised box drawn on ``img``.

    Args:
        img: image whose ``.size`` is ``(width, height)`` (e.g. a PIL image).
        label: ``(xmin, ymin, xmax, ymax)`` as fractions of width and height.
        linewidth: edge width.
        color: edge colour.

    Returns:
        ``matplotlib.patches.Rectangle`` in the image's pixel coordinates.
    """
    img_w, img_h = img.size[0], img.size[1]
    x0, y0, x1, y1 = (label[k] for k in range(4))
    return patches.Rectangle(
        (x0 * img_w, y0 * img_h),
        (x1 - x0) * img_w,
        (y1 - y0) * img_h,
        linewidth=linewidth,
        edgecolor=color,
        facecolor='none',
    )

In [None]:
color_list = ['r', 'g', 'b']


def show_img_grid(preds, dataset, row=10, col=4,
                  img_root='../input/happy-whale-and-dolphin'):
    """Plot a ``row`` x ``col`` grid of images with every model's boxes overlaid.

    Args:
        preds: list of prediction lists, one per model, as returned by
            ``get_preds``. Entry ``j`` of every list must describe the same
            image. Boxes are drawn in ``color_list`` order, cycling when there
            are more models than colours.
        dataset: split name used to build the folder
            ``{img_root}/{dataset}_images`` (e.g. ``'train'`` or ``'test'``).
        row, col: grid shape; at most ``row * col`` images are shown.
        img_root: root folder holding the ``*_images`` directories.

    Raises:
        ValueError: if ``preds`` is empty or the lists are not aligned by image.
    """
    if not preds:
        raise ValueError("preds must contain at least one list of predictions")

    reference = preds[0]
    n_shown = min(row * col, min(len(p) for p in preds))

    plt.figure(figsize=(20, int(20 * row / col)))
    for j in range(n_shown):
        image_name = reference[j][0]
        # Use a context manager so each image file is closed after drawing.
        with Image.open(f'{img_root}/{dataset}_images/{image_name}') as img:
            plt.subplot(row, col, j + 1)
            plt.axis('off')
            plt.imshow(img)
            ax = plt.gca()
            for i, pred in enumerate(preds):
                if pred[j][0] != image_name:
                    raise ValueError(
                        f"prediction list {i} row {j} is {pred[j][0]!r}, expected {image_name!r}"
                    )
                color = color_list[i % len(color_list)]
                # Boxes are normalised, so convert_to_rect scales them to this image.
                ax.add_patch(convert_to_rect(img, pred[j][1:], 3, color=color))
    plt.tight_layout()
    plt.show()

In [None]:
# Preview: vit_small backbone on the first 100 train images.
model = load_model(arch='vit_small', patch_size=16)
preds_small = get_preds(model, img_folder=train_img_folder, n=100)

In [None]:
# Preview: vit_base backbone on the first 100 train images.
model = load_model(arch='vit_base', patch_size=16)
preds_base = get_preds(model, img_folder=train_img_folder, n=100)

In [None]:
# Skipped: moco_vit_small gave only average-quality boxes.
# model = load_model('moco_vit_small', 16)
# preds_moco_vit_small = get_preds(model, train_img_folder, n=100)

In [None]:
# Preview: moco_vit_base backbone on the first 100 train images.
model = load_model(arch='moco_vit_base', patch_size=16)
preds_moco_vit_base = get_preds(model, img_folder=train_img_folder, n=100)

In [None]:
# model = load_model('mae_vit_base', 16)
# preds_mae_vit_base = get_preds(model, train_img_folder, n=100)
# Skipped: mae_vit_base boxes were not good.

In [None]:
# Overlay colours follow color_list: red = vit_small, green = vit_base, blue = moco_vit_base.
show_img_grid(preds=[preds_small, preds_base, preds_moco_vit_base], dataset='train')

On these training images, the boxes from all three backbones look nearly flawless (red: `vit_small`, green: `vit_base`, blue: `moco_vit_base`).

Next, run the same three backbones on the test images.

In [None]:
# Preview the same three backbones on the first 100 test images. Results go
# into their own dict so that the train previews above are not overwritten.
test_preview_preds = {}
for arch in ('vit_small', 'vit_base', 'moco_vit_base'):
    model = load_model(arch, 16)
    test_preview_preds[arch] = get_preds(model, test_img_folder, n=100)

show_img_grid(list(test_preview_preds.values()), dataset='test')

In [None]:
# Full train-set inference: one CSV of normalised boxes per backbone
# (train_vit_small.csv, train_vit_base.csv, train_moco_vit_base.csv).
for arch in ('vit_small', 'vit_base', 'moco_vit_base'):
    model = load_model(arch, 16)
    preds = get_preds(model, train_img_folder)
    pd.DataFrame(
        preds, columns=['image', 'xmin', 'ymin', 'xmax', 'ymax']
    ).to_csv(f'train_{arch}.csv')

In [None]:
# Full test-set inference: one CSV of normalised boxes per backbone
# (test_vit_small.csv, test_vit_base.csv, test_moco_vit_base.csv).
for arch in ('vit_small', 'vit_base', 'moco_vit_base'):
    model = load_model(arch, 16)
    test_preds = get_preds(model, test_img_folder)
    pd.DataFrame(
        test_preds, columns=['image', 'xmin', 'ymin', 'xmax', 'ymax']
    ).to_csv(f'test_{arch}.csv')

In [None]:
# List the working directory to confirm the exported CSV files were written.
!ls -l ./