In [1]:
import torch
import open_clip
from training.data import get_data, CsvDataset
from training.params import parse_args
from tqdm import tqdm
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

import pickle
import os

import faiss
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

In [35]:
# Root of the ImageNet-Captions CSVs and the device CLIP runs on. Both can be
# overridden through environment variables so the notebook runs on other
# machines; the defaults are the original hard-coded values.
DATA_PATH = os.environ.get('IMAGENET_CAPTIONS_PATH', '/mnt/ssd/ronak/datasets/imagenet_captions')
DEVICE = os.environ.get('CLIP_DEVICE', 'cuda:1')

### Load ViT-B/32 (pretrained on LAION-2B) and the ImageNet-Captions data

In [36]:
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.to(DEVICE)
# The model is only used for inference below; open_clip hands it back in train
# mode, so switch off train-time behaviour (dropout, stochastic depth, ...).
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [37]:
# Flag/value pairs for the training script's argument parser, pointing it at
# the ImageNet-Captions CSV splits (image path column "filepath", caption
# column "title").
csv_flags = {
    "--train-data": f"{DATA_PATH}/imagenet_captions_train_c10.csv",
    "--val-data": f"{DATA_PATH}/imagenet_captions_val_c10.csv",
    "--dataset-type": "csv",
    "--csv-img-key": "filepath",
    "--csv-caption-key": "title",
}
# Flatten {flag: value} into the argv-style list parse_args expects.
params = [token for flag_and_value in csv_flags.items() for token in flag_and_value]
args = parse_args(params)
# Single-process run: the dataloader cell only builds a DistributedSampler
# when args.distributed is true.
args.distributed = False

In [5]:
# data = get_data(
#     args,
#     (preprocess_train, preprocess_val),
#     epoch=0,
#     tokenizer=tokenizer,
# )
# dataloader = data['train'].dataloader

In [38]:
# Create the dataloader from scratch (instead of get_data) so that drop_last is
# False and every CSV row receives an embedding.
is_train = True
input_filename = args.train_data if is_train else args.val_data
assert input_filename
# Embed with the deterministic eval transform rather than preprocess_train. The
# cluster labels derived from these features are saved once per sample, so they
# should not depend on a random training augmentation (open_clip's training
# transform applies a random resized crop; confirm for the installed version).
dataset = CsvDataset(
    input_filename,
    preprocess_val,
    img_key=args.csv_img_key,
    caption_key=args.csv_caption_key,
    sep=args.csv_separator,
    tokenizer=tokenizer,
)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
# Visiting order is irrelevant: embeddings are re-sorted by dataset index later.
shuffle = is_train and sampler is None

dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=shuffle,
    num_workers=args.workers,
    pin_memory=True,
    sampler=sampler,
    drop_last=False,  # unlike get_data's training loader: keep the final partial batch
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)

### Embed the data with CLIP and quantize the embeddings with k-means

In [40]:
# Embed every (image, caption) pair with the CLIP encoders and L2-normalise the
# embeddings. `idx` records which dataset row each embedding belongs to, because
# the dataloader may visit rows in shuffled order.
all_image_features, all_text_features, all_idx = [], [], []
with torch.no_grad():
    for idx, images, texts in tqdm(dataloader, desc='embedding'):
        image_features = model.encode_image(images.to(DEVICE))
        text_features = model.encode_text(texts.to(DEVICE))
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        # Move each batch off the accelerator right away so device memory is
        # bounded by one batch instead of growing with the dataset.
        all_image_features.append(image_features.cpu())
        all_text_features.append(text_features.cpu())
        all_idx.append(idx)

# Everything was computed under no_grad, so no detach() is needed before .numpy().
all_image_features = torch.cat(all_image_features).numpy()
all_text_features = torch.cat(all_text_features).numpy()
all_idx = torch.cat(all_idx).cpu().numpy()

191it [00:34,  5.48it/s]


In [41]:
# Sanity check: image embeddings, text embeddings and dataset indices should
# all have one row per sample.
for extracted in (all_image_features, all_text_features, all_idx):
    print(extracted.shape)

(12218, 512)
(12218, 512)
(12218,)


In [42]:
class KMeans:
    """Frozen clustering model produced by ``cluster_feat``.

    Re-applies the preprocessing used when the clusters were fitted: optional
    row normalisation, then a PCA projection truncated to the first
    ``idx + 1`` components. It then assigns each row to its nearest centroid
    in ``index``.

    Attributes:
        norm: rows are normalised with this norm only when it is ``'l2'`` or
            ``'l1'``; any other value (``'none'``, ``None``) skips it.
        pca: fitted object exposing ``transform`` (an sklearn ``PCA`` when
            built by ``cluster_feat``).
        idx: index of the last PCA component to keep.
        index: centroid index queried with ``search(x, 1)``.
        marginal: fraction of the fitting rows assigned to each cluster.
    """

    def __init__(self, norm, pca, idx, index, marginal):
        self.norm = norm
        self.pca = pca
        self.idx = idx
        self.index = index
        self.marginal = marginal

    def clustering(self, features):
        """Return the nearest-centroid label for every row of ``features``.

        Args:
            features: 2-D array-like, one sample per row.

        Returns:
            1-D array of cluster ids (as returned by ``index.search``), one
            per row.
        """
        if self.norm in ('l2', 'l1'):
            features = normalize(features, norm=self.norm, axis=1)
        n_keep = self.idx + 1
        projected = self.pca.transform(features)[:, :n_keep].astype(np.float32)
        _, nearest = self.index.search(projected, 1)
        return nearest.reshape(-1)


def cluster_feat(features, num_clusters,
                 norm='none', whiten=True,
                 pca_max_data=-1,
                 explained_variance=0.9,
                 num_redo=5, max_iter=500, seed=0,
                 min_cluster_size=50):
    """Cluster feature vectors with PCA + faiss k-means, pruning small clusters.

    Pipeline:
      1. Optionally normalise rows.
      2. Apply PCA (optionally whitened) and keep the fewest leading components
         whose cumulative explained-variance ratio reaches
         ``explained_variance``.
      3. Run faiss k-means.
      4. Remove every centroid that owns fewer than ``min_cluster_size``
         points, then re-assign each point to its nearest surviving centroid.

    Args:
        features: 2-D array, one sample per row.
        num_clusters: number of k-means centroids to train (before pruning).
        norm: 'l2' / 'l1' to normalise rows before PCA; 'none' or None to skip.
        whiten: forwarded to ``PCA(whiten=...)``.
        pca_max_data: negative (default) or >= number of rows fits PCA on all
            rows; a value in (0, n_rows) fits it on a seeded random subset of
            that size. 0 is rejected.
        explained_variance: target cumulative explained-variance ratio, in (0, 1).
        num_redo: faiss ``nredo`` (number of k-means restarts).
        max_iter: faiss ``niter``.
        seed: base seed. PCA uses ``seed+1``, faiss ``seed+2`` and the PCA
            subsample ``seed+5``.
        min_cluster_size: clusters with fewer assigned points are removed.
            The default, 50, is the previously hard-coded value.

    Returns:
        (labels, cluster). ``labels`` is a 1-D array holding the cluster id of
        every row. ``cluster`` is a ``KMeans`` wrapper that can label new
        features; ``cluster.marginal[k]`` is the fraction of rows in cluster k.

    Raises:
        ValueError: if ``pca_max_data`` is 0, or if every cluster would be pruned.
    """
    assert 0 < explained_variance < 1
    assert norm in ['none', 'l2', 'l1', None]
    data1 = features
    if norm in ['l2', 'l1']:
        data1 = normalize(data1, norm=norm, axis=1)

    # Fit PCA on all rows, or on a seeded random subset when requested.
    pca = PCA(n_components=None, whiten=whiten, random_state=seed+1)
    n_rows = data1.shape[0]
    if pca_max_data < 0 or pca_max_data >= n_rows:
        pca.fit(data1)
    elif 0 < pca_max_data < n_rows:
        rng = np.random.RandomState(seed+5)
        idxs = rng.choice(n_rows, size=pca_max_data, replace=False)
        pca.fit(data1[idxs])
    else:
        raise ValueError(f'Invalid argument pca_max_data={pca_max_data} with {n_rows} datapoints')

    # idx = last component to keep: the first position where the cumulative
    # explained variance reaches the target. Round-off can keep the cumulative
    # sum below a target very close to 1. In that case np.argmax over an
    # all-False mask would return 0 and keep a single component, so keep every
    # component instead.
    cum_ratio = np.cumsum(pca.explained_variance_ratio_)
    reached = cum_ratio >= explained_variance
    idx = int(np.argmax(reached)) if reached.any() else len(cum_ratio) - 1
    data1 = pca.transform(data1)[:, :idx+1].astype(np.float32)

    # k-means in the reduced space. faiss' min_points_per_centroid shares the
    # pruning threshold; both were hard-coded to the same value originally.
    kmeans = faiss.Kmeans(data1.shape[1], num_clusters, niter=max_iter,
                          nredo=num_redo, update_index=True, seed=seed+2,
                          min_points_per_centroid=min_cluster_size)
    kmeans.train(data1)
    index = kmeans.index
    _, labels = index.search(data1, 1)
    labels = labels.reshape(-1)

    # Drop clusters with low frequency. np.bincount (unlike np.unique) also
    # reports centroids that received no points. Empty clusters are therefore
    # pruned too, and counts[k] always refers to cluster id k.
    counts = np.bincount(labels, minlength=index.ntotal)
    to_remove = np.flatnonzero(counts < min_cluster_size)
    if len(to_remove) == index.ntotal:
        raise ValueError(
            f'All {index.ntotal} clusters have fewer than {min_cluster_size} points; '
            'lower min_cluster_size or num_clusters')
    if len(to_remove) > 0:
        # Removing from a flat faiss index compacts it, so the survivors are
        # renumbered 0..ntotal-1. Re-assign every point to its nearest survivor.
        # Survivors only gain points here, so none drops below the threshold.
        index.remove_ids(to_remove.astype(np.int64))
        _, labels = index.search(data1, 1)
        labels = labels.reshape(-1)
        counts = np.bincount(labels, minlength=index.ntotal)

    cluster = KMeans(norm, pca, idx, index, counts / counts.sum())
    return labels, cluster

In [43]:
# k-means centroids trained per modality (before small-cluster pruning), and the
# base seed passed to cluster_feat.
NUM_CLUSTERS, SEED = 50, 4282022

In [44]:
# Quantize the image embeddings; labels follow the dataloader's visiting order.
image_labels, image_cluster = cluster_feat(
    all_image_features,
    num_clusters=NUM_CLUSTERS,
    seed=SEED,
)

In [45]:
# Quantize the caption embeddings with the same settings as the images.
text_labels, text_cluster = cluster_feat(
    all_text_features,
    num_clusters=NUM_CLUSTERS,
    seed=SEED,
)

In [49]:
# Permutation that puts the embeddings (in dataloader order) back into dataset order.
label_to_idx = np.argsort(all_idx)
# With drop_last=False every row is visited exactly once, so the sorted indices
# must be exactly 0..N-1. This assumes the dataset yields positional row indices.
assert np.array_equal(all_idx[label_to_idx], np.arange(len(all_idx))), \
    'dataloader indices are not a permutation of 0..N-1'
all_idx[label_to_idx]

array([    0,     1,     2, ..., 12215, 12216, 12217])

In [51]:
# Reorder the labels so that position i holds the cluster of dataset row i
# (label_to_idx maps sorted position -> dataloader position).
image_labels_sorted, text_labels_sorted = (
    labels[label_to_idx] for labels in (image_labels, text_labels)
)

In [53]:
# NOTE(review): this re-binds DATA_PATH from the dataset root to the output
# directory. Re-running the argument-parsing cell afterwards would build the
# train/val CSV paths under this directory, so re-run the config cell first.
DATA_PATH = f'/mnt/ssd/ronak/datasets/imagenet_captions/quantization/vit_b32_laion2b_kmeans_{NUM_CLUSTERS}'
# np.save does not create parent directories.
os.makedirs(DATA_PATH, exist_ok=True)

# with open(os.path.join(DATA_PATH, f'vit_b32_laion2b_kmeans_{NUM_CLUSTERS}_image.p'), 'wb+') as f:
#     pickle.dump(image_cluster, f)

# with open(os.path.join(DATA_PATH, f'vit_b32_laion2b_kmeans_{NUM_CLUSTERS}_text.p'), 'wb+') as f:
#     pickle.dump(text_cluster, f)

# Per-sample cluster ids, indexed by dataset row.
np.save(os.path.join(DATA_PATH, 'image_labels.npy'), image_labels_sorted)
np.save(os.path.join(DATA_PATH, 'text_labels.npy'), text_labels_sorted)

In [54]:
# Number of image clusters that survived small-cluster pruning.
image_cluster.marginal.shape[0]

48

In [55]:
# Save each modality's cluster marginal (fraction of rows per surviving cluster).
for file_name, fitted in (('image_marginal.npy', image_cluster),
                          ('text_marginal.npy', text_cluster)):
    np.save(os.path.join(DATA_PATH, file_name), fitted.marginal)