In [1]:
import torch
import open_clip
from tqdm import tqdm
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, SequentialSampler

import pickle
import os

import faiss
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

from datasets import load_dataset

import sys
sys.path.append("..")
from src.text_data import TokenizedTextClassificationDataset
from src.quantize import cluster_feat

In [2]:
DATA_PATH = "/mnt/ssd/ronak/datasets/sst2/"
DEVICE = 'cuda:0'
MODEL_NAME = "vit_b32_laion2b"
DATASET = "sst2"
SEED = 11182023

### Load SST-2 Dataset

In [3]:
%%capture
train = load_dataset('sst2', split='train', cache_dir=DATA_PATH)
val = load_dataset('sst2', split='validation', cache_dir=DATA_PATH)
test = load_dataset('sst2', split='test', cache_dir=DATA_PATH)

In [5]:
np.unique(np.array(test['label']))

array([-1])

In [6]:
train.features
type(train['sentence'])
type(train['label'])

{'idx': Value(dtype='int32', id=None),
 'sentence': Value(dtype='string', id=None),
 'label': ClassLabel(names=['negative', 'positive'], id=None)}

### Use CLIP Tokenizer

In [6]:
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.to(DEVICE)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [18]:
def get_tokenized_dataset(dset):
    texts = []
    labels = torch.tensor(dset['label']).long()
    for sent in tqdm(dset['sentence']):
        texts.append(tokenizer(sent)[0])
    texts = torch.stack(texts)
    print(texts.shape)
    print(labels.shape)
    return texts.numpy(), labels.numpy()

In [19]:
x_train, y_train = get_tokenized_dataset(train)

100%|██████████| 67349/67349 [00:09<00:00, 7404.64it/s]


torch.Size([67349, 77])
torch.Size([67349])


In [20]:
x_val, y_val = get_tokenized_dataset(val)
x_test, y_test = get_tokenized_dataset(test)

100%|██████████| 872/872 [00:00<00:00, 4515.70it/s]


torch.Size([872, 77])
torch.Size([872])


100%|██████████| 1821/1821 [00:00<00:00, 5026.50it/s]

torch.Size([1821, 77])
torch.Size([1821])





In [23]:
np.save(os.path.join(DATA_PATH, "x_train.npy"), x_train)
np.save(os.path.join(DATA_PATH, "y_train.npy"), y_train)
np.save(os.path.join(DATA_PATH, "x_val.npy"),   x_val)
np.save(os.path.join(DATA_PATH, "y_val.npy"),   y_val)
np.save(os.path.join(DATA_PATH, "x_test.npy"),  x_test)
np.save(os.path.join(DATA_PATH, "y_test.npy"),  y_test)

### Use CLIP Quantization

In [7]:
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.to(DEVICE)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [8]:
x_train = np.load(os.path.join(DATA_PATH, "x_train.npy"))
y_train = np.load(os.path.join(DATA_PATH, "y_train.npy"))

print(x_train.shape)
print(y_train.shape)

(67349, 77)
(67349,)


In [9]:
batch_size = 256
train_dataset = TokenizedTextClassificationDataset(x_train, y_train)
dataloader = DataLoader(
    train_dataset, sampler=SequentialSampler(train_dataset), batch_size=batch_size
)

In [11]:
all_text_features, all_labels, all_idx = [], [], []
with torch.no_grad():
    for i, batch in tqdm(enumerate(dataloader)):
        idx, texts, labels = batch
        text_features = model.encode_text(texts.to(DEVICE))
        text_features /= text_features.norm(dim=-1, keepdim=True)
        all_text_features.append(text_features)
        all_labels.append(labels)
        all_idx.append(idx)
        
all_text_features = torch.cat(all_text_features).cpu().detach().numpy()
all_labels = torch.cat(all_labels).cpu().detach().numpy()
all_idx = torch.cat(all_idx).cpu().detach().numpy()

torch.save(all_text_features, os.path.join(DATA_PATH, f"{MODEL_NAME}_features.pt"))
torch.save(all_labels, os.path.join(DATA_PATH, f"{MODEL_NAME}_labels.pt"))
torch.save(all_idx, os.path.join(DATA_PATH, f"{MODEL_NAME}_idx.pt"))

264it [01:15,  3.49it/s]


In [16]:
all_text_features = torch.load(os.path.join(DATA_PATH, f"{MODEL_NAME}_features.pt"))
all_labels = torch.load(os.path.join(DATA_PATH, f"{MODEL_NAME}_labels.pt"))
all_idx = torch.load(os.path.join(DATA_PATH, f"{MODEL_NAME}_idx.pt"))

In [20]:
num_clusters = 100

text_labels, text_cluster = cluster_feat(all_text_features, num_clusters, seed=SEED)

label_to_idx = np.argsort(all_idx)
print(all_idx[label_to_idx])

# have the labels correspond to the indices in order.
text_labels_sorted = text_labels[label_to_idx]
class_labels_sorted = all_labels[label_to_idx]

print(text_labels_sorted.shape)
print(class_labels_sorted.shape)

[    0     1     2 ... 67346 67347 67348]
(67349,)
(67349,)


In [21]:
save_dir = f'/mnt/ssd/ronak/datasets/{DATASET}/quantization/{MODEL_NAME}_kmeans_{num_clusters}'

os.makedirs(save_dir, exist_ok=True)

np.save(os.path.join(save_dir, f'text_labels.npy'), text_labels_sorted)
np.save(os.path.join(save_dir, f'class_labels.npy'), class_labels_sorted)

In [28]:
print(x_train.max())

49407


In [27]:
x_train

array([[49406,  9379,   686, ...,     0,     0,     0],
       [49406, 12844,   871, ...,     0,     0,     0],
       [49406,   682,  3925, ...,     0,     0,     0],
       ...,
       [49406,   536, 16120, ...,     0,     0,     0],
       [49406,   320,  6262, ...,     0,     0,     0],
       [49406,   589,   686, ...,     0,     0,     0]])

In [29]:
x_test = np.load(os.path.join(DATA_PATH, "x_test.npy"))
print(x_test.max())

49407
