In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
!pip install -qU kmeans_pytorch timm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
%cd /content/drive/MyDrive/GR

/content/drive/MyDrive/GR


In [4]:
import torch
from torchvision import datasets
from torch.utils.data import Dataset
from functools import partial
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models as torchvision_models
import os
import shutil
from pathlib import Path
from torch.cuda.amp import autocast
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from kmeans_pytorch import kmeans

import logging
import torch.nn.functional as F
import torch.distributed as dist

import numpy as np
import torchvision
from math import sqrt

import timm.models.vision_transformer
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torchvision_archs = sorted(name for name in torchvision_models.__dict__
                           if name.islower() and not name.startswith("__")
                           and callable(torchvision_models.__dict__[name]))

print(device)

cuda


In [6]:
class ReturnIndexDataset(datasets.ImageFolder):
    def __getitem__(self, idx):
        img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
        # path = super(ReturnIndexDataset, self).samples[idx]
        return idx, img, lab, idx

In [7]:
def find_class_means(X, labels, num_clusters):
    dim = X[0].shape[0]
    labels_sum = {i: torch.zeros(dim) for i in range(num_clusters)}
    labels_count = {i: 0 for i in range(num_clusters)}
    for i in range(len(X)):
        tensor = X[i]
        label = int(labels[i].item())
        labels_sum[label] += tensor
        labels_count[label] += 1
    labels_mean_tensor = torch.zeros((num_clusters, dim))
    for i in range(num_clusters):
        labels_mean_tensor[i] = labels_sum[i] / labels_count[i]
    return labels, labels_mean_tensor


In [8]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [24]:
def find_desired_samples(reps, indices, labels, base_dataset, target_dataset, cluster_centers, cluster_ids_x, quantile):
    res_values = []
    res_indices = []
    res_class_labels = []
    res_cluster_labels = []

    batch_size = 16
    num_clusters = len(cluster_centers)
    reps_dataset = CustomDataset(reps.detach())
    reps_dataloader = DataLoader(reps_dataset, batch_size=batch_size, shuffle=False)

    indices = torch.squeeze(indices)
    labels = torch.squeeze(labels)
    cluster_ids_x = torch.squeeze(cluster_ids_x)
    cluster_centers = cluster_centers.to(device)

    # calculate norm
    i = 0
    for tensor in tqdm(reps_dataloader, desc='Calculating norms'):
        tensor = tensor.to(device)
        norm_tensor = torch.linalg.norm(tensor.unsqueeze(dim=1) - cluster_centers.unsqueeze(dim=0), dim=2).detach()
        norm_tensor, norm_tensor_indecies = torch.sort(norm_tensor, dim=1)
        res_values += (norm_tensor[:, 0] - norm_tensor[:, 1] - norm_tensor[:, 2]).tolist()
        res_indices += (indices[batch_size * i: (i + 1) * batch_size]).tolist()
        res_class_labels += (labels[batch_size * i: (i + 1) * batch_size]).tolist()
        res_cluster_labels += norm_tensor_indecies[:, 0].tolist()
        i += 1

    # reordering samples and finding quantiles baesd on each class
    cluster_scores = {k: [res_values[i] for i in range(len(res_values)) if int(res_cluster_labels[i]) == k] for k in
                        range(len(cluster_centers))}

    # save representation's distribution histogram
    save_histograms(cluster_scores)

    quantiles = {k: torch.quantile(torch.tensor(cluster_scores[k]), q=quantile) for k in
                    range(num_clusters) if len(cluster_scores[k]) != 0}
    score_dicts = {int(res_indices[i]): (res_values[i], int(res_class_labels[i]), int(res_cluster_labels[i])) for i
                    in
                    range(len(res_values))}
    results_based_on_class = {i: [] for i in range(len(target_dataset.classes))}

    # finding images which are in the quantile period
    for k, v in tqdm(score_dicts.items(), desc='Finding images in quntile'):
        if v[0] > quantiles[v[2]].item():
            results_based_on_class[v[1]].append(k)

    # find path of desired samples
    img_paths = {}
    for idx, img, label, ind in tqdm(target_dataset, desc='Gathering paths of desired samples'):
        image_path = target_dataset.samples[idx][0]
        if ind in results_based_on_class[label]:
            try:
                img_paths[label].append(image_path)
            except KeyError:
                img_paths[label] = [image_path]

    return img_paths


def save_histograms(cluster_scores):
    histograms_path = 'histograms'
    if os.path.exists(histograms_path):
        shutil.rmtree(histograms_path)
    os.mkdir(histograms_path)

    for cls, scores in cluster_scores.items():
        sns.distplot(scores)
        plt.title('class :' + str(cls))
        plt.savefig(os.path.join(histograms_path, 'class ' + str(cls) + '.jpg'))
        plt.clf()


def reverse_normalization(images):
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    return un_normalize(images)


def save_outputs(dst_path, dataset, farthest_samples_paths):
    try:
        shutil.rmtree(dst_path)
    except FileNotFoundError:
        pass
    Path(dst_path).mkdir(parents=True, exist_ok=True)
    for cls in dataset.classes:
        Path(os.path.join(dst_path, str(cls))).mkdir(parents=True, exist_ok=True)
    for cls, paths in farthest_samples_paths.items():
        for i, path in enumerate(paths):
            shutil.copy(path, os.path.join(dst_path, dataset.classes[cls]))


def generate_representations(batch_size, model, dataloader, dataset, desc=''):
    model.eval()

    reps = torch.zeros((len(dataloader) * batch_size, 1000))
    indices = torch.zeros((len(dataloader) * batch_size, 1))
    labels = torch.zeros((len(dataloader) * batch_size, 1))
    i = 0
    for idx, tensor, label, index in tqdm(dataloader, desc=desc):
        tensor = tensor.to(device)
        with autocast(enabled=True):
            feats = model(tensor)
        reps[i * batch_size: min((i + 1) * batch_size, len(dataset))] = feats.detach().cpu()
        labels[i * batch_size: min((i + 1) * batch_size, len(dataset))] = label[:, None]
        indices[i * batch_size: min((i + 1) * batch_size, len(dataset))] = index[:, None]
        i += 1
    return reps, indices, labels

In [10]:
# pretrained resnet-18 model
model = torch.hub.load('facebookresearch/swav:main', 'resnet50', pretrained=True)
model = model.to(device)

Downloading: "https://github.com/facebookresearch/swav/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar" to /root/.cache/torch/hub/checkpoints/swav_800ep_pretrain.pth.tar
100%|██████████| 108M/108M [00:00<00:00, 274MB/s]


In [11]:
# cifar_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

# # Create directories for each class
# classes = cifar_dataset.classes
# data_dir = './cifar10_data'

# for cls in classes:
#     os.makedirs(os.path.join(data_dir, cls), exist_ok=True)

# import torchvision.transforms.functional as TF

# # Move images to respective class directories
# for idx, (image, label) in enumerate(cifar_dataset):
#     class_dir = os.path.join(data_dir, classes[label])
#     image_path = os.path.join(class_dir, f"img_{idx}.jpg")
#     tensor_image = TF.to_tensor(image)  # Convert PIL image to tensor
#     torchvision.utils.save_image(tensor_image, image_path)

# print("CIFAR10 dataset downloaded and organized successfully.")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48468316.91it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
CIFAR10 dataset downloaded and organized successfully.


In [12]:
def get_data(data_path):
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset = ReturnIndexDataset(data_path, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
        shuffle=True
    )
    return dataloader, dataset

In [13]:
dataloader, dataset = get_data("/content/drive/MyDrive/GR/cifar10_data/")

In [14]:
reps, indices, labels = generate_representations(
    batch_size=16,
    model=model,
    dataloader=dataloader,
    dataset=dataset,
    desc='Generating representations'
)

  self.pid = os.fork()
Generating representations: 100%|██████████| 3125/3125 [05:46<00:00,  9.02it/s]


In [15]:
data_size, dims = reps.shape
num_clusters = len(dataset.classes)

cluster_ids_x, cluster_centers = kmeans(X=reps,
                                        num_clusters=num_clusters,
                                        distance='euclidean',
                                        device=device,
                                        tol=1e-5)

running k-means on cuda..


[running kmeans]: 83it [04:16,  3.09s/it, center_shift=0.000005, iteration=83, tol=0.000010]


In [18]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.9
                                              )

Calculating norms: 100%|██████████| 3125/3125 [00:01<00:00, 1997.71it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

<Figure size 640x480 with 0 Axes>

In [19]:
save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_90/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

In [20]:
!find cifar10_data -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

cifar10_data/airplane: 5000
cifar10_data/automobile: 5000
cifar10_data/bird: 5000
cifar10_data/cat: 5000
cifar10_data/deer: 5000
cifar10_data/dog: 5000
cifar10_data/frog: 5000
cifar10_data/horse: 5000
cifar10_data/ship: 5000
cifar10_data/truck: 5000


In [21]:
!find hard_samples_90 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

hard_samples/airplane: 480
hard_samples/automobile: 439
hard_samples/bird: 475
hard_samples/cat: 529
hard_samples/deer: 705
hard_samples/dog: 429
hard_samples/frog: 355
hard_samples/horse: 560
hard_samples/ship: 502
hard_samples/truck: 529


### Quantile 0.5

In [None]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.5
                                              )

save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_50/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████| 3125/3125 [00:01<00:00, 1797.03it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

In [None]:
!find hard_samples_50 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

### Quantile 0.1

In [25]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.1
                                              )

save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_10/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████| 3125/3125 [00:01<00:00, 2063.33it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

<Figure size 640x480 with 0 Axes>

In [27]:
!find hard_samples_10 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

hard_samples_10/airplane: 4498
hard_samples_10/automobile: 4459
hard_samples_10/bird: 4608
hard_samples_10/cat: 4481
hard_samples_10/deer: 4783
hard_samples_10/dog: 4456
hard_samples_10/frog: 4212
hard_samples_10/horse: 4506
hard_samples_10/ship: 4493
hard_samples_10/truck: 4501
