In [4]:
# !pip install timm -q

In [5]:
# import timm
import torch
import numpy as np
import os, sys, shutil
import transformers
import requests
from PIL import Image
from pathlib import Path
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm


In [6]:
import torchvision
from matplotlib import pyplot as plt

def show_tensor(tensor, transpose=None, normalize=None, figsize=(10,10), nrow=None, padding=2, verbose=True, **kwargs):
    '''Flexible tool for visualizing image tensors as a grid. Supports batch_size >= 1.

    Accepts a torch tensor or anything ``np.array`` can convert (e.g. a PIL image
    or a list of equally sized PIL images).

    Supported shapes:
        (H, W)                        -- single grayscale image
        (C, H, W) or (H, W, C)        -- single image
        (N, C, H, W) or (N, H, W, C)  -- batch of images

    Args:
        tensor: image data to display.
        transpose: if True, treat the data as channels-last and permute it to
            channels-first. ``None`` guesses: channels-last unless dim 1 has size 3.
        normalize: forwarded to ``torchvision.utils.make_grid``. ``None`` picks
            False when all values already lie in [0, 1], True otherwise.
        figsize: matplotlib figure size.
        nrow: images per grid row; defaults to ceil(sqrt(N)).
        padding: pixels between grid cells.
        verbose: print notes about the conversions applied.
        **kwargs: extra keyword arguments for ``make_grid``.

    Returns:
        The object returned by ``plt.imshow`` for the assembled grid.

    Raises:
        ValueError: if the data is not 2-, 3- or 4-dimensional.
    '''
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.tensor(np.array(tensor))
    tensor = tensor.detach().cpu().float()

    # Lift every input to a 4D batch first, so the channel checks below see the
    # same layout regardless of the input rank.
    if tensor.ndim == 2:
        tensor = tensor[None, None]      # (H, W) -> (1, 1, H, W)
    elif tensor.ndim == 3:
        tensor = tensor.unsqueeze(0)     # (C, H, W) / (H, W, C) -> batch of one
    elif tensor.ndim != 4:
        raise ValueError(f'expected a 2D, 3D or 4D tensor, got shape {tuple(tensor.shape)}')

    # Single-channel channels-first data: replicate to RGB. This must happen before
    # the transpose guess, otherwise (1, H, W) / (N, 1, H, W) would be misread as
    # channels-last and permuted into a nonsensical layout.
    if tensor.shape[1] == 1:
        if verbose: print('processing as black&white')
        tensor = tensor.repeat(1, 3, 1, 1)

    if normalize is None:
        if tensor.max() <= 1.0 and tensor.min() >= 0.0:
            normalize = False
        else:
            if verbose: print('tensor has been normalized to [0., 1.]')
            normalize = True

    if transpose is None:
        transpose = tensor.shape[1] != 3
    if transpose:
        tensor = tensor.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

    if nrow is None:
        nrow = int(np.ceil(np.sqrt(tensor.shape[0])))

    grid = torchvision.utils.make_grid(tensor, normalize=normalize, nrow=nrow, padding=padding, **kwargs)
    plt.figure(figsize=figsize)
    return plt.imshow(grid.permute(1, 2, 0))  # matplotlib expects (H, W, C)

In [21]:
class CarsDataset(torch.utils.data.Dataset):
    '''Image files from a directory paired with per-image labels.

    Items are ``(RGB PIL image, label)`` tuples. No transform is applied, so a
    DataLoader over this dataset needs a collate function that keeps PIL images.

    Args:
        train_path: directory containing the image files (every ``*.*`` entry is used).
        train_targets: labels indexed like the files sorted by path.
            NOTE(review): assumes the label source lists images in file-name order
            (00001.jpg, 00002.jpg, ...) -- confirm against the dataset devkit.
    '''

    def __init__(self, train_path, train_targets):
        self.train_path = train_path
        # Path.glob yields entries in arbitrary file-system order; sort so that
        # index i deterministically refers to the i-th file name and thus to
        # train_targets[i].
        self.files = sorted(map(str, Path(train_path).glob('*.*')))
        self.train_targets = train_targets

    def __getitem__(self, idx):
        # `with` closes the underlying file once the RGB copy has been made.
        with Image.open(self.files[idx]) as raw:
            img = raw.convert('RGB')
        # Use this instance's labels, not a same-named notebook global.
        return img, self.train_targets[idx]

    def __len__(self):
        return len(self.files)
    
    
def get_features(model, processor, images, pool_mode='auto', device=None):
    '''Embed a batch of images with a Hugging Face vision (or vision-language) model.

    Args:
        model: either exposes ``get_image_features`` (used for CLIP / FLAVA here), or
            returns an output with ``pooler_output`` / ``last_hidden_state`` from its
            forward pass (used for ConvNeXt / BEiT here).
        processor: callable turning ``images`` into a dict of tensors for ``model``.
        images: batch of images accepted by ``processor``.
        pool_mode: only used for models without ``get_image_features``.
            'auto' -> ``pooler_output`` (mean of ``last_hidden_state`` over dim 1
            if the output has no pooler); 'mean' -> mean of ``last_hidden_state``
            over dim 1.
        device: device for the processed inputs. ``None`` uses the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        Feature tensor; a 3D result is additionally mean-pooled over dim 1.

    Raises:
        ValueError: for an unknown ``pool_mode``.
    '''
    if pool_mode not in ('auto', 'mean'):
        raise ValueError(f"pool_mode must be 'auto' or 'mean', got {pool_mode!r}")
    if device is None:
        # Follow the model rather than relying on a notebook-global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    inputs = processor(images=images, return_tensors="pt", padding=True)
    inputs = {k: x.to(device) for k, x in inputs.items()}

    with torch.no_grad():
        if hasattr(model, 'get_image_features'):
            # for CLIP, FLAVA
            out = model.get_image_features(**inputs)
        else:
            # for ConvNeXt, BEiT
            out_raw = model(**inputs)
            pooled = getattr(out_raw, 'pooler_output', None) if pool_mode == 'auto' else None
            out = pooled if pooled is not None else out_raw.last_hidden_state.mean(1)
        if out.ndim == 3:
            # Token-level output: average over dim 1 to get one vector per image.
            out = out.mean(1)

    return out

    
def get_all_features(loader, model, processor, debug=False, first_for_debug=True, pool_mode='auto'):
    '''Run ``get_features`` over every batch of ``loader`` and concatenate the results.

    Args:
        loader: iterable of batches, each a sequence of ``(image, target)`` pairs
            (e.g. a DataLoader with an identity ``collate_fn``). Targets are ignored.
        model, processor: forwarded to ``get_features``.
        debug: stop after the first batch (quick smoke test).
        first_for_debug: before the main loop, embed the first batch once and print
            the resulting shape. This costs one extra forward pass.
        pool_mode: forwarded to ``get_features`` ('auto' or 'mean').

    Returns:
        The per-batch feature tensors concatenated along dim 0, in loader order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    '''
    def _images(batch):
        # Feature extraction only needs the images; drop the targets.
        return [img for img, _ in batch]

    if first_for_debug:
        first_batch = next(iter(loader), None)
        if first_batch is not None:
            print(get_features(model, processor, _images(first_batch), pool_mode=pool_mode).shape)

    all_features = []
    for batch in tqdm(loader):
        all_features.append(get_features(model, processor, _images(batch), pool_mode=pool_mode))
        if debug:
            break

    if not all_features:
        raise ValueError('loader yielded no batches')
    return torch.cat(all_features, 0)

In [8]:
# Labels for the training images: one integer per line of the devkit file.
# NOTE(review): presumably 1-based class ids listed in image file-name order, so the
# class name would be car_names[label - 1] -- confirm against the devkit README.
with open("/kaggle/input/stanford-cars-dataset/car_devkit/devkit/train_perfect_preds.txt") as f:
    train_targets = np.array([int(line) for line in f])

# Human-readable class names stored in the devkit's MATLAB metadata file.
car_names = loadmat("/kaggle/input/stanford-cars-dataset/car_devkit/devkit/cars_meta.mat")['class_names'][0]

In [9]:
dataset = CarsDataset("/kaggle/input/stanford-cars-dataset/cars_train/cars_train", train_targets)
loader = DataLoader(
    dataset,
    batch_size=8,
    # Keep dataset order: extracted features are saved row-by-row in this order.
    shuffle=False,
    num_workers=2,
    # Identity collate: each batch stays a list of (PIL image, target) pairs,
    # which the default collate function does not handle.
    collate_fn=lambda batch: batch,
)

In [24]:
# Fall back to CPU so Restart & Run All still works on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face checkpoints to extract image features from.
model_names = [
    "openai/clip-vit-base-patch32",
    "facebook/flava-full",
    "facebook/convnext-xlarge-224-22k",
    "microsoft/beit-large-patch16-224-pt22k",
]
# Checkpoint for the single-model exploration cells; pick any entry of model_names.
model_name = "facebook/flava-full"

# 'auto' (pooler output) or 'mean' (average of last_hidden_state) for models
# without get_image_features.
# NOTE(review): not currently passed to the feature-extraction calls in this notebook.
pool_mode = 'auto'

In [12]:
# Load the checkpoint chosen in the config cell (inference mode) and report its size.
model = transformers.AutoModel.from_pretrained(model_name)
model = model.to(device).eval()
processor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
print(model_name, sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

Some weights of the model checkpoint at facebook/flava-full were not used when initializing FlavaModel: ['mlm_head.decoder.weight', 'image_codebook.blocks.group_2.group.block_1.res_path.path.conv_4.bias', 'image_codebook.blocks.output.conv.bias', 'mim_head.transform.LayerNorm.weight', 'image_codebook.blocks.group_3.group.block_2.res_path.path.conv_2.weight', 'mmm_image_head.transform.dense.weight', 'image_codebook.blocks.group_3.group.block_1.res_path.path.conv_3.weight', 'mim_head.bias', 'image_codebook.blocks.group_1.group.block_2.res_path.path.conv_4.bias', 'image_codebook.blocks.group_4.group.block_2.res_path.path.conv_3.weight', 'image_codebook.blocks.group_2.group.block_2.res_path.path.conv_4.weight', 'mmm_image_head.transform.LayerNorm.weight', 'mim_head.decoder.bias', 'mlm_head.transform.dense.weight', 'mmm_text_head.transform.LayerNorm.weight', 'image_codebook.blocks.group_2.group.block_2.res_path.path.conv_2.bias', 'image_codebook.blocks.group_3.group.block_1.res_path.path.co

facebook/flava-full 241.356289 M parameters


In [22]:
# Smoke test: embed only the first batch to check shapes before the full run.
f = get_all_features(loader, model, processor, debug=True)

torch.Size([8, 768])


  0%|          | 0/1018 [00:00<?, ?it/s]

In [25]:
all_models_features = []

# NOTE(review): only the last checkpoint is processed; iterate over model_names to extract all.
for model_name in model_names[-1:]:
    model = transformers.AutoModel.from_pretrained(model_name).to(device).eval()
    processor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
    print(model_name, sum(map(torch.numel, model.parameters()))/1e6, 'M parameters')

    # Keep a CPU copy: torch.load restores tensors to the device they were saved
    # from, so a CUDA-saved file would not load on a CPU-only machine.
    features = get_all_features(loader, model, processor).cpu()
    torch.save(features, model_name.replace('/', '_') + '.pt')
    all_models_features.append(features)

Downloading:   0%|          | 0.00/737 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.16G [00:00<?, ?B/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of the model checkpoint at microsoft/beit-large-patch16-224-pt22k were not used when initializing BeitModel: ['lm_head.bias', 'layernorm.weight', 'layernorm.bias', 'lm_head.weight']
- This IS expected if you are initializing BeitModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BeitModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BeitModel were not initialized from the model checkpoint at microsoft/beit-large-patch16-224-pt22k and are newly initialized: ['beit.pooler.layernorm.weight', 'beit.pooler.layernorm.bias']
You should probably TRAIN this model on a down-stream task to be able to use

Downloading:   0%|          | 0.00/276 [00:00<?, ?B/s]

microsoft/beit-large-patch16-224-pt22k 303.137216 M parameters
torch.Size([8, 1024])


  0%|          | 0/1018 [00:00<?, ?it/s]

In [26]:
# Invariant: one embedding row per dataset image, in dataset order.
assert features.shape[0] == len(dataset), (tuple(features.shape), len(dataset))
features.shape

torch.Size([8144, 1024])