# Clustering using UMAP on the Latent Space

In this notebook, we visualize and interpret the latent space generated by the ViT using UMAP embeddings.
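You can change the number of classes (`number_of_classes` in the configuration cell below) to plot different results; note that the label offset and the plotting cells further down are written for the 3-class setting. In the paper, we use 1, 3, and 10 random classes.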

In [132]:
import os
from pathlib import Path
import getpass
import numpy as np
import time
import torch
from torch import nn
from tqdm import tqdm
import random
import sys
from torch.utils.data import DataLoader
from sklearn import preprocessing

# UMAP visualization (takes a while to load)
import umap
import matplotlib.pyplot as plt

# Allow project-local imports when the notebook is run from within the project dir.
# A plain loop replaces the throwaway list comprehension, and the membership check
# keeps sys.path from growing by two entries every time this cell is re-run.
for _import_root in ('.', '..'):
    if _import_root not in sys.path:
        sys.path.append(_import_root)

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino, ViTWrapper
from src.model.data import *

# Seed every RNG in play (Python, PyTorch, NumPy) with the same value so that
# runs are repeatable.
SEED = 42
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

# Current OS user.
# NOTE(review): not used below - the data root is hard-coded; confirm whether it
# was meant to be part of DATA_PATH.
username = getpass.getuser()

# Root directory holding all datasets.
DATA_PATH = Path('/') / 'cluster' / 'scratch' / 'thobauma' / 'dl_data'

# Original (clean) validation split: label file and image folder.
ORI_PATH = DATA_PATH / 'ori/validation'
ORI_LABEL_PATH = ORI_PATH / 'labels.csv'
ORI_IMAGES_PATH = ORI_PATH / 'images'

# PGD adversarial images. The label file is the one of the original split
# (NOTE(review): presumably intentional because the images share labels - confirm).
PGD_PATH = DATA_PATH / 'adversarial_data'
PGD_LABEL_PATH = ORI_PATH / 'labels.csv'
PGD_IMAGES_PATH = PGD_PATH / 'pgd_03/validation/images/'

In [133]:
# Sample selection.
# If CLASS_SUBSET is specified, INDEX_SUBSET will be ignored. Set CLASS_SUBSET=None if you want to use indexes.
# NOTE(review): the dataset cell below passes index_subset=None, so INDEX_SUBSET is
# currently not consumed anywhere in the visible notebook.
INDEX_SUBSET = get_random_indexes(n_samples=300)  # Use this to get random samples
CLASS_SUBSET = get_random_classes(number_of_classes=3, seed=1)  # We are selecting 3 classes for this visualization

# DataLoader settings.
BATCH_SIZE = 100
NUM_WORKERS = 0

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end to end instead of failing on the first .to(DEVICE).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pinned host memory only pays off for host-to-GPU copies, so it follows DEVICE.
PIN_MEMORY = DEVICE == 'cuda'

In [134]:
# Display the class ids that were sampled for this run.
CLASS_SUBSET

array([ 38, 236, 909])

In [135]:
# Encoder fitted on exactly the sampled class ids. It is handed to the datasets
# and later used (inverse_transform) to get the original class ids back.
label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(list(CLASS_SUBSET))

LabelEncoder()

# Import DINO
Official repo: https://github.com/facebookresearch/dino

In [136]:
# Build the DINO ViT and its linear classifier. Called without arguments, it loads
# the reference pretrained weights (see the log printed below this cell).
model, linear_classifier = get_dino()

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
Embed dim 1536
We load the reference pretrained linear weights from dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth.


# Load data

In [137]:
# Original and PGD datasets, both restricted to CLASS_SUBSET and sharing the same
# label encoder; each uses its own transform.
org_dataset = ImageDataset(
    ORI_IMAGES_PATH,
    ORI_LABEL_PATH,
    ORIGINAL_TRANSFORM,
    CLASS_SUBSET,
    index_subset=None,
    label_encoder=label_encoder,
)
adv_dataset = ImageDataset(
    PGD_IMAGES_PATH,
    PGD_LABEL_PATH,
    ONLY_NORMALIZE_TRANSFORM,
    CLASS_SUBSET,
    index_subset=None,
    label_encoder=label_encoder,
)

# Loader settings common to both splits.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# NOTE(review): only the original loader shuffles. Features and labels are
# collected together per batch, so alignment is unaffected, but the row order of
# the two splits differs - confirm this is intended.
org_loader = DataLoader(org_dataset, shuffle=True, **_loader_kwargs)
adv_loader = DataLoader(adv_dataset, shuffle=False, **_loader_kwargs)

In [138]:
# Wrap the ViT and the linear classifier in one object (last 4 blocks, no
# patch-token average pooling) and move it to DEVICE in the same expression.
model_wrap = ViTWrapper(
    model,
    linear_classifier,
    DEVICE,
    n_last_blocks=4,
    avgpool_patchtokens=False,
).to(DEVICE)

# Generate input to linear layer

### Adversarial images

In [139]:
# Run every adversarial image through the ViT and collect, row-aligned:
#   result     - the feature vector that is fed to the linear layer
#   img_labels - the class id after label_encoder.inverse_transform
#   imgs       - the input image tensors (NOTE(review): not read by any cell
#                visible in this notebook; kept in case another cell needs it)
# `correct` accumulates how many predictions of the linear layer match the labels.
#
# Per-batch tensors are gathered in lists and concatenated once after the loop,
# rather than re-concatenating the ever-growing tensors on each iteration.
adv_feature_batches = []
adv_label_batches = []
adv_image_batches = []
correct = 0

with torch.no_grad():
    for images, labels, _ in tqdm(adv_loader):
        x = images.to(DEVICE)
        # Undo the label encoding so labels can be compared with the predicted
        # indices of the linear layer (presumably original class ids - confirm).
        labels = torch.tensor(label_encoder.inverse_transform(labels.cpu())).to(DEVICE)

        # forward: first token of each of the last n blocks, concatenated feature-wise
        intermediate_output = model_wrap.vits16.get_intermediate_layers(x, model_wrap.n_last_blocks)
        output = torch.cat([block_out[:, 0] for block_out in intermediate_output], dim=-1)
        if model_wrap.avgpool_patchtokens:
            # Additionally append the mean over the remaining tokens of the last block.
            output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)

        adv_feature_batches.append(output)
        adv_label_batches.append(labels)
        adv_image_batches.append(images)

        outputs = model_wrap.linear_layer(output)

        # Predicted index per sample (no need for the legacy `.data` under no_grad).
        pre = outputs.argmax(dim=1)

        correct += (pre == labels).sum()

# One concatenation per collection; stays None when the loader yielded nothing,
# exactly like the initial values of the original accumulators.
result = torch.cat(adv_feature_batches, 0) if adv_feature_batches else None
img_labels = torch.cat(adv_label_batches) if adv_label_batches else None
imgs = torch.cat(adv_image_batches, 0) if adv_image_batches else None

100%|██████████| 2/2 [00:00<00:00,  2.55it/s]


In [140]:
# Number of adversarial images whose prediction matched the label in the loop above.
correct

tensor(1, device='cuda:0')

### Original images

In [141]:
# Same extraction as for the adversarial images, now on the original images:
#   org_result     - the feature vector that is fed to the linear layer
#   org_img_labels - the class id after label_encoder.inverse_transform
# `correct` is reset and accumulates the matching predictions for this split.
#
# Per-batch tensors are gathered in lists and concatenated once after the loop,
# rather than re-concatenating the ever-growing tensors on each iteration.
org_feature_batches = []
org_label_batches = []
correct = 0

with torch.no_grad():
    for images, labels, _ in tqdm(org_loader):
        x = images.to(DEVICE)
        # Undo the label encoding so labels can be compared with the predicted
        # indices of the linear layer (presumably original class ids - confirm).
        labels = torch.tensor(label_encoder.inverse_transform(labels.cpu())).to(DEVICE)

        # forward: first token of each of the last n blocks, concatenated feature-wise
        intermediate_output = model_wrap.vits16.get_intermediate_layers(x, model_wrap.n_last_blocks)
        output = torch.cat([block_out[:, 0] for block_out in intermediate_output], dim=-1)
        if model_wrap.avgpool_patchtokens:
            # Additionally append the mean over the remaining tokens of the last block.
            output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)

        org_feature_batches.append(output)
        org_label_batches.append(labels)

        outputs = model_wrap.linear_layer(output)

        # Predicted index per sample (no need for the legacy `.data` under no_grad).
        pre = outputs.argmax(dim=1)

        correct += (pre == labels).sum()

# One concatenation per collection; stays None when the loader yielded nothing,
# exactly like the initial values of the original accumulators.
org_result = torch.cat(org_feature_batches, 0) if org_feature_batches else None
org_img_labels = torch.cat(org_label_batches) if org_label_batches else None

100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


In [142]:
# Number of original images whose prediction matched the label in the loop above.
correct

tensor(100, device='cuda:0')

In [143]:
# Map classes
# Give every class id a dense index (0, 1, 2, ...) in order of first appearance
# in the adversarial labels; inv_map_dict is the reverse lookup.
labels = img_labels.cpu().numpy()
map_dict = {}
inv_map_dict = {}

# dict.fromkeys keeps one entry per distinct id, in first-seen order.
for dense_idx, class_id in enumerate(dict.fromkeys(labels)):
    map_dict[class_id] = dense_idx
    inv_map_dict[dense_idx] = class_id

In [144]:
# Dense plotting labels. Adversarial samples are shifted past the original ones
# so that both kinds can share one label space in the combined plot.
# NOTE(review): the shift is hard-coded for the 3-class setup; with a different
# number of classes it has to change together with the plotting cells.
adv_label_shift = 3
mapped_labels = [map_dict[class_id] + adv_label_shift for class_id in labels]
org_mapped_labels = [map_dict[class_id] for class_id in org_img_labels.cpu().numpy()]

# UMAP visualization

### Adversarial images only

In [146]:
# UMAP reducer for the adversarial features only, using the Canberra metric.
reducer_adv = umap.UMAP(
    n_neighbors=100,
    metric='canberra',
    random_state=SEED,
)

In [147]:
# Fit on the adversarial features; UMAP works on NumPy arrays, hence the copy to host memory.
reducer_adv.fit(result.cpu().numpy())

UMAP(metric='canberra', n_neighbors=100, random_state=42, tqdm_kwds={'bar_format': '{desc}: {percentage:3.0f}%| {bar} {n_fmt}/{total_fmt} [{elapsed}]', 'desc': 'Epochs completed', 'disable': True})

In [148]:
# Embed the same adversarial features with the fitted reducer (one row per image).
embedding_adv = reducer_adv.transform(result.cpu().numpy())

In [149]:
# Fancy plot
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper, Legend
from bokeh.palettes import Spectral10, Category10
import pandas as pd
from io import BytesIO
from PIL import Image
import base64

# Direct Bokeh output to the notebook so that show() renders the figure inline.
output_notebook()

In [150]:
# Interactive scatter plot of the adversarial-only UMAP embedding.
#   label      - dense, offset plotting label (as string); drives the colour
#   true_label - real class id of each point; shown in hover tooltip and legend
df = pd.DataFrame(embedding_adv, columns=('x', 'y'))
df['label'] = [str(x) for x in mapped_labels]

# Take the real class ids straight from img_labels, which is row-aligned with
# `result` and therefore with the embedding. The previous
# `replace(inv_map_dict)` never matched: the plotting labels carry the
# adversarial offset while inv_map_dict is keyed by the un-offset indices, so the
# legend showed the offset labels instead of class ids.
df['true_label'] = [str(class_id) for class_id in img_labels.cpu().numpy()]

datasource = ColumnDataSource(df)
# Factors '9', '8', ..., '0' paired with the 10-colour palette.
color_mapping = CategoricalColorMapper(factors=[str(k) for k in range(9, -1, -1)],
                                       palette=Category10[10])

# NOTE(review): plot_width / plot_height are the pre-3.0 Bokeh names (newer
# releases use width / height) - confirm against the installed Bokeh version.
plot_figure = figure(
    title='',
    plot_width=600,
    plot_height=400,
    tools=('pan, wheel_zoom, reset, save')
)

plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <span style='font-size: 16px; color: #224499'>Label:</span>
        <span style='font-size: 18px'>@true_label</span>
    </div>
</div>
"""))

# Reserve a legend panel to the right of the plot area.
plot_figure.add_layout(Legend(), 'right')

plot_figure.circle(
    'x',
    'y',
    source=datasource,
    color=dict(field='label', transform=color_mapping),
    line_alpha=0.6,
    fill_alpha=0.6,
    size=6,
    legend_field='true_label'
)

# UMAP axes carry no interpretable units, so hide axes and grid lines.
plot_figure.xaxis.visible = False
plot_figure.xgrid.visible = False
plot_figure.yaxis.visible = False
plot_figure.ygrid.visible = False

plot_figure.outline_line_width = 2
plot_figure.outline_line_alpha = 0.8
plot_figure.outline_line_color = "gray"

plot_figure.legend.location = "center"
plot_figure.legend.title = 'True class'
plot_figure.legend.title_text_font_style = "bold"
plot_figure.legend.title_text_font_size = "14px"

show(plot_figure)

### Combined for 3 classes

In [206]:
# UMAP reducer for the combined (adversarial + original) features; no metric is
# passed, so UMAP's default is used.
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=100,
    random_state=SEED,
)

In [207]:
# Stack the adversarial features on top of the original ones; the label list is
# joined in the same order so rows and labels stay aligned.
combined = torch.cat((result, org_result), dim=0)
combined_labels = [*mapped_labels, *org_mapped_labels]

In [208]:
# Fit on the stacked features; UMAP works on NumPy arrays, hence the copy to host memory.
reducer.fit(combined.cpu().numpy())

UMAP(n_neighbors=100, random_state=42, tqdm_kwds={'bar_format': '{desc}: {percentage:3.0f}%| {bar} {n_fmt}/{total_fmt} [{elapsed}]', 'desc': 'Epochs completed', 'disable': True})

In [209]:
# Embed the same stacked features with the fitted reducer (adversarial rows first, then original).
embedding = reducer.transform(combined.cpu().numpy())

In [210]:
# Fancy plot
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource, CategoricalColorMapper, Legend
from bokeh.palettes import Spectral10, Category10
import pandas as pd
from io import BytesIO
from PIL import Image
import base64

# Direct Bokeh output to the notebook (same call as in the earlier Bokeh import cell).
output_notebook()

In [211]:
# Interactive scatter plot of the combined (adversarial + original) UMAP embedding.
#   label      - dense plotting label (as string); drives the colour
#   true_label - human-readable legend / hover text for that label
df = pd.DataFrame(embedding, columns=('x', 'y'))
df['label'] = [str(x) for x in combined_labels]

# Dense label -> legend text. 0-2 are the original images, 3-5 the adversarial ones.
legend_names = {
    0: "Class 0 (Orig.)",
    1: "Class 1 (Orig.)",
    2: "Class 2 (Orig.)",
    3: "Class 0 (Adv.)",
    4: "Class 1 (Adv.)",
    5: "Class 2 (Adv.)",
}
df['true_label'] = combined_labels
df['true_label'] = df['true_label'].replace(legend_names)

datasource = ColumnDataSource(df)
# Factors must be exactly the labels that occur: '5', '4', ..., '0'. The previous
# list ran from '6' down to '1', so label '0' ("Class 0 (Orig.)") had no factor
# and was never given a palette colour. Deriving the factors from legend_names
# keeps the two from drifting apart again.
color_mapping = CategoricalColorMapper(factors=[str(k) for k in sorted(legend_names, reverse=True)],
                                       palette=Category10[6])

# NOTE(review): plot_width / plot_height are the pre-3.0 Bokeh names (newer
# releases use width / height) - confirm against the installed Bokeh version.
plot_figure = figure(
    title='',
    plot_width=600,
    plot_height=400,
    tools=('pan, wheel_zoom, reset, save')
)

plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <span style='font-size: 16px; color: #224499'>Label:</span>
        <span style='font-size: 18px'>@true_label</span>
    </div>
</div>
"""))

# Reserve a legend panel to the right of the plot area.
plot_figure.add_layout(Legend(), 'right')

plot_figure.circle(
    'x',
    'y',
    source=datasource,
    color=dict(field='label', transform=color_mapping),
    line_alpha=0.6,
    fill_alpha=0.6,
    size=6,
    legend_field='true_label'
)

# UMAP axes carry no interpretable units, so hide axes and grid lines.
plot_figure.xaxis.visible = False
plot_figure.xgrid.visible = False
plot_figure.yaxis.visible = False
plot_figure.ygrid.visible = False

plot_figure.outline_line_width = 2
plot_figure.outline_line_alpha = 0.8
plot_figure.outline_line_color = "gray"

plot_figure.legend.location = "center"
plot_figure.legend.title = ""
plot_figure.legend.title_text_font_style = "bold"
plot_figure.legend.title_text_font_size = "14px"

show(plot_figure);