In [3]:
import torch
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm import tqdm
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, SequentialSampler


import pickle
import os

import numpy as np
from sklearn.decomposition import PCA

import sys
sys.path.append("..")
from src.image_data import ImageClassificationDataset
from src.quantize import cluster_feat

## Load ConvNeXt Model and Extract Features

In [None]:
# Directory holding the CIFAR-10 arrays (x_train.npy / y_train.npy); the
# extracted features are written back to the same directory.
DATA_PATH = '/mnt/ssd/ronak/datasets/cifar10'
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at model load.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [3]:
# Load the ImageNet-pretrained ConvNeXt-Base and put it in inference mode.
# A freshly constructed nn.Module is in training mode, and the feature
# extractor built from it inherits that flag; ConvNeXt's stochastic depth would
# then randomly drop residual branches and make the extracted features
# non-deterministic. `.eval()` returns the module itself, so it can be chained.
model = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1).to(DEVICE).eval()
# Names of the graph nodes (for train mode and eval mode respectively) that can
# be requested as outputs from the feature extractor.
train_nodes, eval_nodes = get_graph_node_names(model)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [4]:
# Map graph node name -> key under which its activation is returned.
# Only the output of the 'avgpool' node is requested, exposed as 'features'.
return_nodes = dict(avgpool='features')
# Wrap the model so that a forward pass returns a dict of the requested nodes.
body = create_feature_extractor(
    model,
    return_nodes=return_nodes,
)

In [5]:
# Read the training arrays from disk. The images are moved from
# channels-last to channels-first layout (N, H, W, C) -> (N, C, H, W).
root = DATA_PATH
x_train = np.load(os.path.join(root, "x_train.npy")).transpose(0, 3, 1, 2)
y_train = np.load(os.path.join(root, "y_train.npy"))
print(x_train.shape)

# Preprocessing bundled with the pretrained weights.
transforms = ConvNeXt_Base_Weights.IMAGENET1K_V1.transforms()
batch_size = 256
train_dataset = ImageClassificationDataset(x_train, y_train, transforms)
# Sequential sampling: batches are visited in dataset order, no shuffling.
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=SequentialSampler(train_dataset),
)

(50000, 3, 32, 32)


In [6]:
# Embed every image with the feature extractor and collect, per sample, the
# L2-normalised feature vector, the class label and the dataset index.

# Make sure the extractor is in inference mode: in training mode stochastic
# depth is active and the features would differ from run to run.
body.eval()

all_image_features, all_labels, all_idx = [], [], []
with torch.no_grad():
    # Iterating the DataLoader directly lets tqdm read its length and show
    # real progress instead of a bare iteration counter.
    for idx, images, labels in tqdm(dataloader):
        # flatten(1) collapses everything after the batch dimension. Unlike
        # squeeze(), it keeps the batch dimension when a batch holds a single
        # sample, so the torch.cat below always sees 2-D tensors.
        image_features = body(images.to(DEVICE))['features'].flatten(1)
        # Scale each feature vector to unit L2 norm.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Move each batch off the accelerator right away so device memory does
        # not grow with the size of the dataset.
        all_image_features.append(image_features.cpu())
        all_labels.append(labels)
        all_idx.append(idx)

# Stack the per-batch results and convert them to NumPy arrays.
all_image_features = torch.cat(all_image_features).numpy()
all_labels = torch.cat(all_labels).cpu().detach().numpy()
all_idx = torch.cat(all_idx).cpu().detach().numpy()

196it [06:56,  2.12s/it]


In [7]:
# Persist the extracted arrays next to the dataset so the quantization step
# can be run without repeating the forward passes.
# NOTE: these are NumPy arrays written through torch.save (i.e. pickled).
extraction_outputs = {
    "convnext_base_features.pt": all_image_features,
    "convnext_base_labels.pt": all_labels,
    "convnext_base_idx.pt": all_idx,
}
for file_name, array in extraction_outputs.items():
    torch.save(array, os.path.join(DATA_PATH, file_name))

## Perform Quantization

In [12]:
# --- Quantization settings ---
DATASET = "cifar10"                                        # dataset sub-directory name
DATA_PATH = os.path.join('/mnt/ssd/ronak/datasets', DATASET)  # where the extracted features live
NUM_CLUSTERS = 50                                          # number of clusters passed to cluster_feat
SEED = 11182023                                            # seed passed to cluster_feat

In [13]:
# Reload the arrays written by the feature-extraction step.
# The files contain pickled NumPy arrays rather than tensors, which the
# restricted unpickler used by `weights_only=True` (the default from
# PyTorch 2.6 on) refuses to load. Full unpickling is requested explicitly;
# this is acceptable only because the files are produced by this notebook.
# Do not point these paths at untrusted files.
all_image_features = torch.load(
    os.path.join(DATA_PATH, "convnext_base_features.pt"), weights_only=False
)
all_labels = torch.load(
    os.path.join(DATA_PATH, "convnext_base_labels.pt"), weights_only=False
)
all_idx = torch.load(
    os.path.join(DATA_PATH, "convnext_base_idx.pt"), weights_only=False
)

In [14]:
# Cluster the image features into NUM_CLUSTERS groups.
# NOTE(review): presumably `image_labels` holds one cluster id per image and
# `image_cluster` the fitted clustering object -- confirm in src.quantize.
image_labels, image_cluster = cluster_feat(all_image_features, NUM_CLUSTERS, seed=SEED)

# Permutation that arranges the samples by ascending dataset index.
label_to_idx = np.argsort(all_idx)
# Visual sanity check: the reordered indices should read 0, 1, 2, ...
print(all_idx[label_to_idx])

# Reorder both label arrays so that position i corresponds to dataset index i.
image_labels_sorted, class_labels_sorted = (
    image_labels[label_to_idx],
    all_labels[label_to_idx],
)

for sorted_labels in (image_labels_sorted, class_labels_sorted):
    print(sorted_labels.shape)

[    0     1     2 ... 49997 49998 49999]
(50000,)
(50000,)


In [15]:
# Write the quantization results (per-sample labels and marginals) to disk.
model_name = "convnext_base"
# Build the output directory from DATA_PATH rather than repeating the dataset
# root literal, so that changing the data location in the config cell is
# enough to redirect the outputs as well.
save_dir = os.path.join(DATA_PATH, "quantization", f"{model_name}_kmeans_{NUM_CLUSTERS}")

os.makedirs(save_dir, exist_ok=True)

# Per-sample labels, already ordered by dataset index.
np.save(os.path.join(save_dir, 'image_labels.npy'), image_labels_sorted)
np.save(os.path.join(save_dir, 'class_labels.npy'), class_labels_sorted)

# Empirical class marginal: relative frequency of each distinct class label.
# np.unique only reports labels that actually occur, so a class with no
# samples would be missing from y_marginal rather than appear as 0.
_, counts = np.unique(all_labels, return_counts=True)
y_marginal = counts / counts.sum()
# Marginal over clusters as exposed by the clustering object.
x_marginal = image_cluster.marginal

np.save(os.path.join(save_dir, 'image_marginal.npy'), x_marginal)
np.save(os.path.join(save_dir, 'class_marginal.npy'), y_marginal)