In [None]:
# %pip install \
#     torch==1.4.0 \
#     torchvision==0.5.0 \
#     umap-learn==0.4.4 \
#     tqdm==4.46.1 \
#     matplotlib==3.1.1 \
#     natsort==6.0.0

In [None]:
from collections import OrderedDict
from datetime import datetime
from glob import glob
import json
import math
import os,sys
from pathlib import Path
from random import sample, seed
from time import time

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets
from torchvision import transforms as T
from torch.utils.data import DataLoader 
from torchvision.utils import make_grid
from torchvision.utils import save_image

from natsort import natsorted
import pandas as pd
from umap import UMAP

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
plt.style.use('ggplot')
# The 'seaborn-colorblind' style was renamed 'seaborn-v0_8-colorblind' in matplotlib 3.6
# and the old name was removed in 3.8 (unknown style names raise OSError), so fall back.
try:
    plt.style.use('seaborn-colorblind')
except OSError:
    plt.style.use('seaborn-v0_8-colorblind')
from IPython.display import clear_output

In [None]:
## config
dataset_dir = './dataset/imagenet/val'  # ImageFolder root passed to datasets.ImageFolder
out_dir = './out'                       # all generated files go below this directory

# ndim           -> UMAP output dimensionality (n_components)
# npoint         -> number of leading dataset images that get embedded
# maxDim         -> 4-D activations with more values per sample than this are pooled
# maxTargetWidth -> upper bound on the pooled spatial width

## full-scale ("real") settings:
# ndim, npoint, maxDim, maxTargetWidth = 15, 50000, 2048 * 5 * 5, 16

## quick test settings (active):
ndim, npoint, maxDim, maxTargetWidth = 2, 1000, 100, 2

In [None]:
# Create the output directory on the first run only
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# Use the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'PyTorch version: {torch.__version__}')
print(f'Device: {device}')
print(f'loading dataset from: {dataset_dir}')
print(f'output directory: {out_dir}')

## Utility functions

In [None]:
##dummy layers
class Flatten(nn.Module):
    """Layer that keeps the batch axis and merges every remaining axis into one.

    Lets a flattening step appear as its own entry in a list of modules.
    """

    def forward(self, x):
        # dim 0 is the batch; everything after it is collapsed
        return x.flatten(start_dim=1)
    
class Input(nn.Module):
    """No-op layer standing in for the raw network input.

    Placing it first in a module list lets the input itself be handled like any
    other layer.
    """

    def forward(self, x):
        # identity: the very same tensor object is handed back
        return x

In [None]:
def saveLabels(targets, out_dir):
    """Store class labels as raw uint16 in ``<out_dir>/labels.bin``, or reload them if present.

    Args:
        targets: sequence of integer class indices (e.g. ``ImageFolder.targets``).
        out_dir: directory in which ``labels.bin`` lives.

    Returns:
        np.ndarray with dtype uint16. The write path and the reload path now return
        the same dtype. Previously the write path returned the raw int array.

    Raises:
        ValueError: if a label is outside the uint16 range. It would otherwise wrap
            silently when written.
    """
    print('Getting labels...')
    fn_out = os.path.join(out_dir, 'labels.bin')
    if os.path.exists(fn_out):
        labels = np.fromfile(fn_out, dtype=np.uint16)
        print('Skipped because labels.bin exists:', fn_out)
        return labels

    labels = np.asarray(targets)
    limit = np.iinfo(np.uint16).max
    if labels.size and (labels.min() < 0 or labels.max() > limit):
        raise ValueError(f'labels must lie in [0, {limit}] to be stored as uint16')
    labels = labels.astype(np.uint16)
    labels.tofile(fn_out)
    return labels


def _pool_activation(act, pooling, maxDim, maxTargetWidth):
    """Adaptively pool a 4-D activation whose per-sample size exceeds ``maxDim``.

    Returns ``act`` unchanged when ``pooling`` is None, when it is not 4-D, or when
    it is already small enough. Otherwise the square target width is
    ``floor(sqrt(maxDim / channels))``, clamped to ``[1, maxTargetWidth]``.
    ``pooling`` is 'average', 'max', or a callable module that is applied as-is.
    """
    if pooling is None or act.dim() != 4 or act[0].numel() <= maxDim:
        return act
    target_width = math.floor(math.sqrt(maxDim / act.size(1)))
    target_width = max(1, min(target_width, maxTargetWidth))
    if pooling == 'average':
        pool = nn.AdaptiveAvgPool2d([target_width, target_width])
    elif pooling == 'max':
        pool = nn.AdaptiveMaxPool2d([target_width, target_width])
    else:
        pool = pooling
    return pool(act)


def _collect_activations(loader, module, npoint, pooling, maxDim, maxTargetWidth):
    """Run ``module`` over ``loader`` and return the (pooled) outputs of the first ``npoint`` samples.

    Raises:
        ValueError: if the loader yields fewer than ``npoint`` samples. The buffer
            would otherwise keep zero-filled rows that UMAP would embed as real points.
    """
    maxBatchCount = math.ceil(npoint / loader.batch_size)
    acts = None
    total = 0
    for imgs, _targets in tqdm(loader, total=min(len(loader), maxBatchCount)):
        act = module(imgs.to(device)).detach().cpu()
        act = _pool_activation(act, pooling, maxDim, maxTargetWidth)
        if acts is None:
            # allocate lazily, once the per-sample shape is known
            acts = torch.zeros(npoint, *act.shape[1:])
        take = min(act.size(0), npoint - total)
        acts[total:total + take] = act[:take]
        total += take
        del act
        if total >= npoint:  # stop as soon as enough samples are collected
            break
    if total < npoint:
        raise ValueError(f'loader yielded only {total} samples, but npoint={npoint}')
    return acts


def getUmap(loader, modules, layerNames, 
    npoint, ndim, 
    maxDim, maxTargetWidth, 
    out_dir, xy_prev = None,
    pooling='average',
    logging=True,
    seed=None
):
    """Embed the activations of each module with UMAP and cache the result on disk.

    For each ``modules[i]`` (named ``layerNames[i]``):
      1. Run it over the first ``npoint`` samples of ``loader`` under ``torch.no_grad()``.
      2. Adaptively pool 4-D outputs that are larger than ``maxDim`` values per sample.
      3. Flatten the result and fit ``UMAP(n_components=ndim, init='spectral',
         random_state=seed)``.
      4. Write the embedding as raw float32 to ``<out_dir>/<name>.bin``.
    Layers whose ``.bin`` file already exists are skipped. Messages are printed and,
    when ``logging`` is True, also appended with a timestamp to ``<out_dir>/log.txt``.

    Args:
        loader: DataLoader yielding (images, targets); ``batch_size`` must be set.
        modules: callables mapping an input batch to the activation of interest.
        layerNames: names used for the output files (indexed in parallel with ``modules``).
        npoint: number of leading samples to embed.
        ndim: UMAP output dimensionality.
        maxDim, maxTargetWidth: pooling limits (see ``_pool_activation``).
        out_dir: output directory.
        xy_prev: previous embedding. It is tracked but not used for initialization;
            spectral init is always used.
        pooling: 'average', 'max', None (no pooling), or a pooling module.
        logging: also write messages to ``log.txt``.
        seed: ``random_state`` passed to UMAP.

    Raises:
        ValueError: if the loader provides fewer than ``npoint`` samples.
    """
    f = None
    if logging:
        f = open(f'{out_dir}/log.txt', 'a')
        def log(*args): 
            print(*args)
            dt=datetime.now().strftime('[%Y-%m-%d %H:%M:%S] ')
            f.write(dt + ' '.join([str(a) for a in args]) + '\n')
            f.flush()
    else:
        log = print

    try:
        with torch.no_grad():
            log('Getting Activations...')
            for i, module in enumerate(modules):
                name = layerNames[i]
                fn_out = out_dir + f'/{name}.bin'
                if Path(fn_out).exists():
                    log(f'skipped {name} because file exists: {fn_out}')
                    # reloading also checks that the cached file has the expected size
                    xy_prev = np.fromfile(fn_out, dtype=np.float32).reshape([npoint, ndim])
                    continue

                log(name, '...')
                t0 = time()
                acts = _collect_activations(loader, module, npoint, pooling, maxDim, maxTargetWidth)
                log(acts.size())
                log(f'activation done in {(time()-t0):.2f} sec')

                log('UMAP...')
                acts = acts.reshape(acts.size(0), -1)
                t0 = time()
                # NOTE: seeding `init` with xy_prev was tried and disabled; spectral init is used.
                umap = UMAP(n_components=ndim, init='spectral', random_state=seed)
                log(f'UMAP input size={acts.shape}')
                xy = umap.fit_transform(acts.numpy())
                xy.astype(np.float32).tofile(fn_out)
                log(f'UMAP done in {(time()-t0):.2f} sec')
                xy_prev = xy
                log('=' * 80)
    finally:
        if f is not None:
            f.close()


def detail2coarse(dir0, data_dir='data'):
    """Convert the fine labels in ``<dir0>/labels.bin`` into coarse-category indices.

    Reads two files from ``data_dir``:
      - ``imagenet_class_index.json``: maps class index (string) to ``[wnid, name]``.
      - ``imagenet_coarse_categories.csv``: one header line, then rows of
        ``coarse,detail,detail,...``, possibly padded with empty fields.
    Coarse indices follow the CSV row order. The result is written as uint16 to
    ``<dir0>/labels-coarse.bin``, and the output path is printed.

    Args:
        dir0: directory containing ``labels.bin``; the output is written there too.
        data_dir: directory with the two metadata files (default 'data', as before).

    Raises:
        KeyError: if a CSV detail name is not in the class index, or a label has no
            coarse category.
    """
    fn_in = os.path.join(dir0, 'labels.bin')
    fn_out = os.path.join(dir0, 'labels-coarse.bin')

    with open(os.path.join(data_dir, 'imagenet_class_index.json'), encoding='utf-8') as f:
        index2wnidlabel = json.load(f)

    coarse2detail = OrderedDict()
    with open(os.path.join(data_dir, 'imagenet_coarse_categories.csv'), encoding='utf-8') as f:
        f.readline()  ## skip the header
        for line in f:
            fields = line.strip().split(',')
            try:
                stop = fields.index('')  # drop the empty padding at the end of a row
            except ValueError:
                stop = len(fields)
            coarse2detail[fields[0]] = fields[1:stop]

    ## fix some labels
    # NOTE(review): these positions presumably hold ambiguous duplicate names in the CSV;
    # they are replaced with the disambiguated names used by the class-index file.
    coarse2detail['vehicle'][8] = 'crane (machine)'
    coarse2detail['clothing'][25] = 'maillot (tank suit)'
    coarse_to_coarseIndex = {k: i for i, k in enumerate(coarse2detail)}
    detail_to_detailIndex = {v[1]: int(k) for k, v in index2wnidlabel.items()}
    detail_to_detailIndex['crane (bird)'] = detail_to_detailIndex['crane']  ## fix

    # fine class index -> coarse index (later CSV rows win on duplicates, as before)
    detail_to_coarse = {
        detail_to_detailIndex[d]: coarse_to_coarseIndex[c]
        for c, ds in coarse2detail.items()
        for d in ds
    }

    labels = np.fromfile(fn_in, dtype=np.uint16)
    labels_coarse = np.array([detail_to_coarse[int(l)] for l in labels], dtype=np.uint16)

    labels_coarse.tofile(fn_out)
    print(fn_out)
    
    
def saveSchema(out_dir, layerNames, npoint, ndim):
    """Write ``<out_dir>/schema.json`` describing the exported layers, then print its path.

    The schema has three keys, in this order:
      - ``archIndex``: always 0.
      - ``labels``: the label file id, ``'labels-coarse'`` (the coarse labels are used
        rather than ``'labels'``), with shape ``npoint``.
      - ``layers``: one ``{'id': name, 'shape': [npoint, ndim]}`` entry per layer name.
    """
    # TODO: per-layer 'residual' flag is not emitted yet
    layer_entries = [{'id': name, 'shape': [npoint, ndim]} for name in layerNames]
    schema = OrderedDict([
        ('archIndex', 0),
        ('labels', {'id': 'labels-coarse', 'shape': npoint}),
        ('layers', layer_entries),
    ])

    schema_path = out_dir + '/schema.json'
    with open(schema_path, 'w') as fh:
        json.dump(schema, fh, indent=2)
    print(schema_path)
    


def procrustes_similarity(traveler, bed, centralize=True):
    """Orthogonal-Procrustes similarity between two point sets with matching rows.

    Computes ``||T^T B||_* / sqrt(||T^T T||_* * ||B^T B||_*)``, where ``||.||_*`` is
    the nuclear norm. If ``centralize`` is True, each column mean is subtracted first.
    The value is 1 when ``bed`` is a scaled rotation/reflection of ``traveler``.

    Accepts numpy arrays (uses ``np.linalg.norm``) or torch tensors (uses
    ``torch.linalg.norm``). Centering is done out of place, so the caller's inputs
    are never modified. Previously ``-=`` mutated them, and integer numpy input
    raised a casting error.
    """
    if isinstance(traveler, np.ndarray):
        norm = np.linalg.norm
    else:
        norm = torch.linalg.norm

    if centralize:
        traveler = traveler - traveler.mean(0)
        bed = bed - bed.mean(0)

    sim = norm(traveler.T @ bed, 'nuc') / (
        norm(traveler.T @ traveler, 'nuc') * norm(bed.T @ bed, 'nuc')
    )**0.5
    return sim

In [None]:
def saveAttributes(loader, model, npoint, out_dir, device, data_dir='data'):
    """Write per-image attributes of the first ``npoint`` samples to ``<out_dir>/attributes.csv``.

    Columns:
      - ``label``: fine class name from ``imagenet_class_index.json``.
      - ``label_coarse``: coarse category name from ``imagenet_coarse_categories.csv``.
      - ``confidence``: softmax probability the model assigns to the true class.
      - ``error``: per-sample cross-entropy loss.
    Rows follow the loader's iteration order. Numeric columns are kept numeric.
    Previously ``np.stack`` turned every column into strings.

    Args:
        loader: iterable of (images, integer targets) batches.
        model: classifier producing logits; it is put in eval mode and moved to ``device``.
        npoint: number of leading samples to record.
        out_dir: output directory.
        device: torch device used for inference.
        data_dir: directory holding the two metadata files (default 'data', as before).

    Raises:
        ValueError: if the loader yields no batches.
    """
    ### label lookups
    with open(os.path.join(data_dir, 'imagenet_class_index.json'), encoding='utf-8') as f:
        index2wnidlabel = json.load(f)

    coarse2detail = OrderedDict()
    with open(os.path.join(data_dir, 'imagenet_coarse_categories.csv'), encoding='utf-8') as f:
        f.readline()  ## skip the header
        for line in f:
            fields = line.strip().split(',')
            try:
                stop = fields.index('')  # drop the empty padding at the end of a row
            except ValueError:
                stop = len(fields)
            coarse2detail[fields[0]] = fields[1:stop]

    # NOTE(review): same disambiguation overrides as in detail2coarse
    coarse2detail['vehicle'][8] = 'crane (machine)'
    coarse2detail['clothing'][25] = 'maillot (tank suit)'

    coarse_to_coarseIndex = {k: i for i, k in enumerate(coarse2detail)}
    coarseIndex_to_coarse = {v: k for k, v in coarse_to_coarseIndex.items()}

    detail_to_detailIndex = {v[1]: int(k) for k, v in index2wnidlabel.items()}
    detail_to_detailIndex['crane (bird)'] = detail_to_detailIndex['crane']  ## fix

    detail_to_coarse = {
        detail_to_detailIndex[d]: coarse_to_coarseIndex[c]
        for c, ds in coarse2detail.items()
        for d in ds
    }

    ## 'confidence' and error w.r.t. the correct label, from the softmax output
    model.eval()
    model.to(device)
    softmax = nn.Softmax(1)
    ce = nn.CrossEntropyLoss(reduction='none')

    labels = []
    labels_coarse = []
    confidence = []
    error = []
    total = 0
    with torch.no_grad():
        for img, target in tqdm(loader):
            img, target = img.to(device), target.to(device)
            out = model(img)
            error.append(ce(out, target).cpu())
            conf = softmax(out).gather(1, target.view(-1, 1)).view(-1)
            confidence.append(conf.cpu())
            for i in target.tolist():
                labels.append(index2wnidlabel[str(i)][1])
                labels_coarse.append(coarseIndex_to_coarse[detail_to_coarse[i]])
            total += img.size(0)
            if total >= npoint:  # enough samples; avoid an extra forward pass
                break

    if not error:
        raise ValueError('loader yielded no samples')

    df = pd.DataFrame({
        'label': labels[:npoint],
        'label_coarse': labels_coarse[:npoint],
        'confidence': torch.cat(confidence)[:npoint].numpy(),
        'error': torch.cat(error)[:npoint].numpy(),
    })
    df.to_csv(f'{out_dir}/attributes.csv')

## Prepare dataset

In [None]:
print(f'loading dataset from: {dataset_dir}', end='... ')

# per-channel ImageNet mean / std used to normalize the input tensors
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# shorter side -> 256 px, central 224x224 crop, then tensor conversion and normalization
transform = T.Compose(
    [T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean, std)]
)
dataset = datasets.ImageFolder(dataset_dir, transform=transform)
# shuffle=False: a fixed sample order, so later cells see the same first npoint images
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

print('DONE')

## Load model

In [None]:
print('loading model', end='... ')

modelName = 'googlenet'
model = models.googlenet(pretrained=True, aux_logits=False)

# alternative backbone:
# modelName = 'resnet50'
# model = models.resnet50(pretrained=True)

# .to() and .eval() both return the module itself, so they can be chained
model = model.to(device).eval()

print(f'{modelName} with ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters, DONE')

## Define layers of interest

In [None]:
##resnet 50/101
# modules = [
#     Input(),
#     model.conv1, 
#     model.bn1, 
#     model.relu, 
#     model.maxpool, 
#     *model.layer1,
#     *model.layer2,
#     *model.layer3,
#     *model.layer4,
#     model.avgpool,
#     Flatten(),
#     model.fc,
#     nn.Softmax(dim=1)
# ]


##GoogLeNet (inception v1)
modules = [
    Input(),
    model.conv1,
    model.maxpool1,
    model.conv2,
    model.conv3,
    model.maxpool2,
    model.inception3a,
    model.inception3b,
    model.maxpool3,
    model.inception4a,
    model.inception4b,
    model.inception4c,
    model.inception4d,
    model.inception4e,
    model.maxpool4,
    model.inception5a,
    model.inception5b,
    model.avgpool,
    Flatten(),
#     model.dropout,
    model.fc,
    nn.Softmax(dim=1)
]


class _GoogLeNetInputTransform(nn.Module):
    """Per-channel re-normalization that torchvision's GoogLeNet applies at the start of
    ``forward()`` when ``transform_input=True``. It maps ImageNet mean/std-normalized
    input to the (x - 0.5) / 0.5 normalization the ported weights expect."""

    def forward(self, x):
        x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        return torch.cat((x_ch0, x_ch1, x_ch2), 1)


# torchvision's googlenet(pretrained=True) defaults to transform_input=True. The prefixes
# below start directly at conv1, so without this step the activations (and the final
# Softmax layer) would not match model(img) as used by saveAttributes.
# Models without the attribute (e.g. resnet) are unaffected.
# NOTE: layer .bin files cached from earlier runs were computed without this step.
needs_input_transform = bool(getattr(model, 'transform_input', False))


def _prefix(i):
    """nn.Sequential mapping a raw input batch to the output of modules[i]."""
    head = list(modules[:i + 1])
    if needs_input_transform and i >= 1:
        head.insert(1, _GoogLeNetInputTransform())  # right after Input(), before conv1
    return nn.Sequential(*head)


name_template = '{}-{}'
layerNames = [name_template.format(i, type(m).__name__) for i, m in enumerate(modules)]
modules_concat = [_prefix(i) for i in range(len(modules))]
print('layers:', layerNames)

In [None]:
## show pooling dimensions
# Dry run with one random 224x224 image: print each layer's output shape, and the shape
# after the same target-width rule that getUmap's default 'average' pooling applies.
print('Pooling dimensions:')

with torch.no_grad():  # shape inspection only; no autograd graph needed
    test_input = torch.rand([1, 3, 224, 224]).to(device)
    y = test_input
    for name, m in zip(layerNames, modules):
        y = m(y)
        print(name)
        print('original -->', y.size())
        if y.dim() == 4:
            if y[0].numel() > maxDim:
                nchannel = y.size(1)
                target_width = math.floor(math.sqrt(maxDim / nchannel))
                target_width = max(1, min(maxTargetWidth, target_width))
                y_pool = nn.AdaptiveAvgPool2d([target_width, target_width])(y)
            else:
                y_pool = y
            print('  pooled -->', y_pool.size())
            print(' flatten -->', y_pool.reshape(1, -1).size())
        print('-' * 20)

## UMAP

In [None]:
# embed layers in [startLayer, stopLayer); getUmap skips layers whose .bin already exists
startLayer, stopLayer = 0, len(modules)

out_subdir = f'{out_dir}/umap-{modelName}-{npoint}x{ndim}D/'
if not os.path.exists(out_subdir):
    os.makedirs(out_subdir)
print('Output dir:', out_subdir)

saveLabels(dataset.targets, out_subdir)
detail2coarse(out_subdir)

layer_slice = slice(startLayer, stopLayer)
getUmap(
    loader,
    modules_concat[layer_slice],
    layerNames[layer_slice],
    npoint, ndim, maxDim, maxTargetWidth,
    out_subdir,
    seed=42,
)
saveSchema(out_subdir, layerNames, npoint, ndim)
saveAttributes(loader, model, npoint, out_subdir, device)

## save image tiles

In [None]:
# Export the first npoint images as PNG sheets of ncol rows x nrow columns of
# imageSize-px thumbnails. The order matches the UMAP cells: same directory, shuffle=False.
imageSize = 32
nrow = 100 ## num of images per row
ncol = 100 ## num of rows per sheet


dir_in = dataset_dir
dir_out = f'{out_dir}/images'
os.makedirs(dir_out, exist_ok=True)

# Separate names, so the normalized `dataset` / `transform` used by the UMAP cells
# are not overwritten when this cell runs.
image_transform = T.Compose([
    T.Resize(256), 
    T.CenterCrop(224),
    T.Resize(imageSize), 
    T.ToTensor(),
])
image_dataset = datasets.ImageFolder(dir_in, transform=image_transform)
image_loader = DataLoader(image_dataset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)

# batches of *image_loader* (not the UMAP loader) needed to cover npoint images
maxBatchCount = math.ceil(npoint / image_loader.batch_size)

imagePerPatch = nrow*ncol

all_imgs = []
n_loaded = 0
for imgs, _targets in tqdm(image_loader, total=min(len(image_loader), maxBatchCount)):
    all_imgs.append(imgs)
    n_loaded += imgs.size(0)
    if n_loaded >= npoint:  # stop once enough images are loaded; no extra batch
        break
all_imgs = torch.cat(all_imgs)[:npoint]

# based on the images actually loaded, so a dataset smaller than npoint cannot
# produce an empty slice for make_grid
patchCount = math.ceil(len(all_imgs)/imagePerPatch)
for i in tqdm(range(patchCount)):
    grid = make_grid(all_imgs[i*imagePerPatch:(i+1)*imagePerPatch], nrow=nrow, padding=0)
    if i < patchCount-1:
        canvas = grid
    else:
        # the last sheet may be partial: pad it with black to the full sheet size
        canvas = torch.zeros([3, ncol*imageSize, nrow*imageSize])
        canvas[:, :grid.shape[1], :grid.shape[2]] = grid
    save_image(canvas, f'{dir_out}/imagenet-{imageSize}-{ncol}x{nrow}-{npoint}-{i}.png')

## (Optional) Similarity matrix using UMAP+Procrustes

In [None]:
# modelName1 = 'inceptionv3'
# modelName2 = 'vgg16'

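# dir1 = f'{out_dir}/umap-{modelName1}-{npoint}x{ndim}D'  # TODO: point at an existing UMAP output dir
# dir2 = f'{out_dir}/umap-{modelName2}-{npoint}x{ndim}D'  # TODO: point at an existing UMAP output dir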
# fns1 = natsorted([f for f in glob(f'{dir1}/*.bin') if 'labels' not in f])
# fns2 = natsorted([f for f in glob(f'{dir2}/*.bin') if 'labels' not in f])
# fns1, fns2

### npoint, ndim = 50000, 15
# proMatrix = np.zeros([len(fns1),len(fns2)])

In [None]:
# for i, fn_i in enumerate(tqdm(fns1)):
#     act_i = np.fromfile(fn_i, dtype=np.float32).reshape([-1, ndim])
#     for j, fn_j in enumerate(fns2):
#         if proMatrix[i,j] > 0:## if already computed
#             continue
#         else:
#             act_j = np.fromfile(fn_j, dtype=np.float32).reshape([-1, ndim])
#             proMatrix[i,j] = procrustes_similarity(
#                 act_i, 
#                 act_j, 
#                 centralize=True
#             )
# fn = f'data/vis/similarity-matrix/sim-{modelName1}-vs-{modelName2}-{len(fns1)}x{len(fns2)}.bin'
# proMatrix.astype(np.float32).tofile(fn)

# plt.imshow(proMatrix)
# print(fn)