In [5]:
import clip
import json
import torch
import numpy as np
from torch import nn
from itertools import chain

In [6]:
with open('datasets/CUB/concepts_generated.json', 'r') as fp:
    raw_concepts = json.load(fp)

all_concepts = []
for class_name, concept_dict in raw_concepts.items():
    all_concepts += chain.from_iterable(concept_dict.values())

In [9]:
len(all_concepts), len(set(all_concepts))

(2273, 1612)

In [11]:
all_concepts_sorted = sorted(set(all_concepts))

In [14]:
with open('datasets/CUB/concepts.txt', 'w') as fp:
    fp.write('\n'.join(all_concepts_sorted))

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model, clip_preprocess = clip.load('ViT-B/16', device=device)

In [18]:
clip_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x16e3e2700>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [27]:
import os
import torch.nn.functional as f
from data.cub.cub_dataset import CUBDataset
from torch.utils.data import DataLoader

dataset_train = CUBDataset(os.path.join('datasets', 'CUB'), num_attrs=312,
                           split='train', transforms=clip_preprocess)
dataset_val = CUBDataset(os.path.join('datasets', 'CUB'), num_attrs=312,
                         split='val', transforms=clip_preprocess)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=4,
                              shuffle=False, num_workers=8)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=4,
                            shuffle=False, num_workers=8)

In [26]:
with open('datasets/CUB/concepts.txt', 'r') as fp:
    concepts = fp.read().splitlines()
concepts_tokenized = clip.tokenize(concepts)
concepts_tokenized.shape

torch.Size([1612, 77])

In [28]:
with torch.no_grad():
    text_features = clip_model.encode_text(concepts_tokenized)
    text_features_norm = f.normalize(text_features, dim=-1)
    all_similarities = []
    class_ids = []
    all_image_features = []
    for i, batch in enumerate(dataloader_val):
        image_features = clip_model.encode_image(batch['pixel_values'])
        image_features_norm = f.normalize(image_features, dim=-1)
        similarities = image_features_norm @ text_features_norm.T
        
        all_image_features.append(image_features_norm)
        all_similarities.append(similarities)
        class_ids.append(batch['class_ids'])
        if i == 2:
            break

In [29]:
torch.cat(all_similarities).shape

torch.Size([12, 1612])

In [37]:
batch['class_ids']

tensor([0, 0, 0, 0])

In [38]:
all_similarities[0].cpu().numpy()

array([[0.2154929 , 0.21587837, 0.19058917, ..., 0.19989917, 0.19759679,
        0.16969772],
       [0.23307317, 0.24768081, 0.2046141 , ..., 0.22584477, 0.21760248,
        0.18832101],
       [0.24319793, 0.21682146, 0.21479067, ..., 0.2012999 , 0.1935352 ,
        0.15541677],
       [0.22313201, 0.26847357, 0.20930922, ..., 0.22100756, 0.21872948,
        0.19040647]], dtype=float32)