In [1]:
from torchvision import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import random
import torch
import timm
import os

In [2]:
def seed_everything(seed=42):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), and switches cuDNN to deterministic mode
    with the autotuner disabled.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    random.seed(seed)
    # Exported so that processes started afterwards inherit the value
    # (Python reads PYTHONHASHSEED at interpreter startup).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can pick different kernels from run to run, so it
    # must be off for the deterministic flag above to be meaningful.
    torch.backends.cudnn.benchmark = False


def get_dataloaders(dataset_name, transform, batch_size=32, num_workers=4, download=True):
    """Build train and test DataLoaders for CIFAR-10 or CIFAR-100.

    Both splits are stored under ``./data`` and use the same ``transform``.
    The train loader shuffles its samples; the test loader does not.

    Args:
        dataset_name: Either ``'cifar10'`` or ``'cifar100'``.
        transform: Transform handed to the torchvision dataset for both splits.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        download: Forwarded to the torchvision dataset constructor.

    Returns:
        A dict with keys ``'train'`` and ``'test'`` mapping to DataLoaders.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    supported_names = ('cifar10', 'cifar100')
    if dataset_name not in supported_names:
        raise ValueError("Dataset must be 'cifar10' or 'cifar100'")

    if dataset_name == 'cifar10':
        dataset_class = datasets.CIFAR10
    else:
        dataset_class = datasets.CIFAR100

    # The flag doubles as the dataset's `train` argument and the loader's
    # `shuffle` argument: only the training split is shuffled.
    is_train_split = {'train': True, 'test': False}

    # Build both datasets first, then wrap them in loaders.
    split_datasets = {
        split_name: dataset_class(root='./data', train=flag, download=download, transform=transform)
        for split_name, flag in is_train_split.items()
    }

    return {
        split_name: DataLoader(
            split_datasets[split_name],
            batch_size=batch_size,
            shuffle=flag,
            num_workers=num_workers,
        )
        for split_name, flag in is_train_split.items()
    }


def extract_embeddings(model, device, dataloader):
    """Run ``model`` over ``dataloader`` and collect per-sample embeddings.

    Each batch goes through ``model.forward_features`` and then
    ``model.forward_head(..., pre_logits=True)``. The whole loop runs under
    ``torch.no_grad()`` so no autograd graph is built, whether or not the
    caller has frozen the model's parameters.

    Args:
        model: Object exposing ``forward_features`` and ``forward_head``.
        device: Device the image batches are moved to before the forward pass.
        dataloader: Iterable yielding ``(images, labels)`` pairs.

    Returns:
        A dict with two lists of equal length:
        ``'embeddings'`` holds one NumPy array per sample (the batch output
        split along its first axis) and ``'labels'`` holds the matching
        labels. Tensor labels are converted to NumPy values; other label
        containers are stored element by element unchanged.
    """
    embeddings_db, labels_db = [], []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images = images.to(device)
            features = model.forward_features(images)
            pooled = model.forward_head(features, pre_logits=True)
            # detach() is a no-op under no_grad; kept as a cheap safeguard.
            embeddings_db.extend(pooled.detach().cpu().numpy())

            # Store plain NumPy values rather than one 0-d tensor per sample.
            if torch.is_tensor(labels):
                labels = labels.detach().cpu().numpy()
            labels_db.extend(labels)

    return {
        'embeddings': embeddings_db,
        'labels': labels_db
    }

In [3]:
def create_database(dataset, backbone, seed=42):
    """Extract embeddings for a dataset's train/test splits and save them.

    Steps, in order:
      1. Seed the random number generators with ``seed``.
      2. Load the pretrained timm model named ``backbone`` (created with
         ``num_classes=0``), move it to CUDA when available (CPU otherwise),
         disable gradients on it and put it in eval mode.
      3. Build the non-training transform from the model's timm data config
         and create the dataloaders with it.
      4. For each of ``'train'`` and ``'test'``, extract embeddings and write
         them with ``np.savez`` to ``<dataset>/<split>.npz``.

    Args:
        dataset: Dataset name passed to ``get_dataloaders``; also used as the
            output directory name.
        backbone: timm model identifier.
        seed: Seed forwarded to ``seed_everything``. Defaults to 42.
    """
    seed_everything(seed)

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # Frozen, eval-mode feature extractor from timm.
    model = timm.create_model(backbone, pretrained=True, num_classes=0)
    model = model.to(device).requires_grad_(False).eval()

    # Preprocessing that matches what the chosen backbone expects.
    model_data_config = timm.data.resolve_model_data_config(model)
    eval_transform = timm.data.create_transform(**model_data_config, is_training=False)

    loaders = get_dataloaders(dataset, eval_transform)

    # Output directory is named after the dataset; created if missing.
    os.makedirs(dataset, exist_ok=True)

    for split in ('train', 'test'):
        # db has the form {'embeddings': [...], 'labels': [...]}
        db = extract_embeddings(model, device, loaders[split])

        # Stored as <dataset>/train.npz and <dataset>/test.npz
        output_path = os.path.join(dataset, f'{split}.npz')
        np.savez(output_path, **db)

In [None]:
# timm model identifier used as the feature extractor.
backbone = "vit_base_patch14_dinov2.lvd142m"
# Dataset to embed; get_dataloaders accepts "cifar10" or "cifar100".
dataset = "cifar10" # or cifar100


# Extract embeddings for both splits and write them to
# <dataset>/train.npz and <dataset>/test.npz.
create_database(dataset, backbone)