In [1]:
import deeplake
from torchvision import transforms, models

# Load the PACS train and test splits from Activeloop's hub (train first, then test).
ds_train, ds_test = map(
    deeplake.load,
    ('hub://activeloop/pacs-train', 'hub://activeloop/pacs-test'),
)



hub://activeloop/pacs-train loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/pacs-train
hub://activeloop/pacs-test loaded successfully.
This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/pacs-test


In [2]:
import torch
from torch import nn 
from torch.utils.data import DataLoader

import clip

# Prefer the GPU when one is visible; ``device`` stays a plain string either way.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ``preprocess`` is the image transform that ships with this CLIP checkpoint.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
@torch.no_grad()
def get_image_features(clip_model, images):
    """Return the class-token activation after every residual block of CLIP's visual tower.

    Replays the embedding steps of ``clip_model.visual`` (patch conv, class
    token, positional embedding, ``ln_pre``), then runs the transformer one
    residual block at a time so that each block's output can be captured.

    Runs under ``torch.no_grad()``. The results were always detached, so
    recording an autograd graph (and keeping every block's activations for a
    backward pass) only wasted memory.

    Args:
        clip_model: loaded CLIP model with a ViT visual encoder.
        images: batched image tensor; moved to the device of the model's
            weights. Assumed (N, 3, H, W) and already preprocessed for the
            model. TODO confirm.

    Returns:
        Tensor of shape (num_layers, N, width): row ``l`` is the class-token
        output of residual block ``l``.
    """
    visual = clip_model.visual
    num_image_layer = visual.transformer.layers
    # Follow the model's own device instead of the notebook-level ``device``
    # global, so the helper works wherever the model lives.
    images = images.to(visual.conv1.weight.device)

    out_list = []
    x = visual.conv1(images.type(clip_model.dtype))
    x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]
    # Prepend the learned class token to every sequence in the batch.
    cls_token = visual.class_embedding.to(x.dtype) + torch.zeros(
        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
    )
    x = torch.cat([cls_token, x], dim=1)        # shape = [*, grid ** 2 + 1, width]
    x = x + visual.positional_embedding.to(x.dtype)
    x = visual.ln_pre(x)
    x = x.permute(1, 0, 2)                      # NLD -> LND (the resblocks' layout)

    for i in range(num_image_layer):
        x = visual.transformer.resblocks[i](x)
        # Back to NLD and keep only the class token (sequence position 0).
        out_list.append(x.permute(1, 0, 2)[:, 0, :])

    # torch.stack copies, so the per-block slices above do not keep the full
    # activations alive after return.
    image_features = torch.stack(out_list)

    return image_features

@torch.no_grad()
def get_text_features(clip_model, texts):
    """Return the full token sequence after every residual block of CLIP's text tower.

    Applies the token embedding and positional embedding, then runs the
    transformer one residual block at a time and records each block's output.
    No ``ln_final``, end-of-text pooling or text projection is applied. Callers
    get the raw per-block activations for every token position.

    Runs under ``torch.no_grad()``. The outputs were always detached, so
    building an autograd graph only cost memory.

    Args:
        clip_model: loaded CLIP model.
        texts: token-id tensor (e.g. from ``clip.tokenize``); moved to the
            device of the model's token embedding.

    Returns:
        Tensor of shape (num_layers, N, n_ctx, width).
    """
    num_text_layer = clip_model.transformer.layers
    # Use the model's device rather than the notebook-level ``device`` global.
    texts = texts.to(clip_model.token_embedding.weight.device)

    out_list = []
    x = clip_model.token_embedding(texts).type(clip_model.dtype)  # [batch_size, n_ctx, d_model]
    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)                  # NLD -> LND (the resblocks' layout)

    for i in range(num_text_layer):
        x = clip_model.transformer.resblocks[i](x)
        out_list.append(x.permute(1, 0, 2))  # LND -> NLD

    text_features = torch.stack(out_list)

    return text_features

In [4]:
def get_image_features_and_labels(data_loader):
    """Extract per-layer CLIP image features and labels for every batch of a loader.

    Args:
        data_loader: iterable of dict batches with ``'images'`` and
            ``'labels'`` tensors (the deeplake loaders built in this notebook).

    Returns:
        (image_features_list, labels_list): one entry per batch, both on
        ``device``. Each features entry is the output of
        ``get_image_features`` for that batch. Each labels entry keeps the
        batch dimension: shape (batch,) when every image has a single label.
    """
    image_features_list = []
    labels_list = []

    for data in data_loader:
        images = data['images'].to(device)
        # torch.squeeze() would collapse a final batch of one image into a 0-d
        # tensor, which breaks later concatenation. Flatten per sample, then
        # drop only a trailing size-1 label dim.
        labels = data['labels'].reshape(images.shape[0], -1).squeeze(1).to(device)

        image_features = get_image_features(clip_model, images)

        labels_list.append(labels)
        image_features_list.append(image_features)

    return image_features_list, labels_list

In [5]:
batch_size = 32
# Feed images through CLIP's own preprocessing, i.e. the transform that
# clip.load returned next to the model, so the frozen encoder sees inputs
# scaled the way it was trained. The previous ToTensor + Normalize([0.5], [0.5])
# used a different per-channel mean/std, so every extracted feature was computed
# on off-distribution inputs. ``preprocess`` consumes PIL images, which is why
# decode_method stays 'pil'.
tform = preprocess

train_loader = ds_train.pytorch(num_workers = 0, shuffle = True, 
                                transform = {'images': tform, 'labels': None}, 
                                batch_size = batch_size, decode_method = {'images': 'pil'})
test_loader = ds_test.pytorch(num_workers = 0, transform = {'images': tform, 'labels': None}, 
                                batch_size = batch_size, decode_method = {'images': 'pil'})

In [None]:
# Run both splits through the CLIP visual tower, one batch at a time.
train_image_features_list, train_labels_list = get_image_features_and_labels(data_loader=train_loader)
test_image_features_list, test_labels_list = get_image_features_and_labels(data_loader=test_loader)

In [None]:
import pickle

## save pickle
# One (file name, object) pair per artefact; each is written to its own pickle.
for pickle_path, pickle_obj in (
    ('train_image_features_list.pickle', train_image_features_list),
    ('train_labels_list.pickle', train_labels_list),
    ('test_image_features_list.pickle', test_image_features_list),
    ('test_labels_list.pickle', test_labels_list),
):
    with open(pickle_path, 'wb') as fw:
        pickle.dump(pickle_obj, fw)

In [6]:
import pickle

def save_images(data_loader, split=1, type='train'):
    """Stream a loader's image batches to disk in pickle chunks; return all labels.

    Image batches are accumulated and written to
    ``f'{type}_images_list{idx}.pickle'`` every ``int(len(data_loader) / split)``
    batches (at least 1). Whatever remains after the loop, possibly an empty
    list, is always written to one final file. Labels are not chunked: every
    label batch is collected and returned so the caller can persist them.

    Args:
        data_loader: sized iterable of dict batches with 'images' and 'labels'.
        split: requested number of chunks; must be >= 1.
        type: file-name prefix (e.g. 'train' / 'test'). It shadows the builtin
            ``type`` inside this function; the name is kept for callers.

    Returns:
        List of label tensors on ``device``, one per batch, each keeping its
        batch dimension.

    Raises:
        ValueError: if ``split`` < 1.
    """
    if split < 1:
        raise ValueError(f'split must be >= 1, got {split}')

    images_list = []
    labels_list = []
    save_idx = 0
    # With fewer batches than ``split`` the division rounds down to 0 and the
    # modulo below would raise ZeroDivisionError; clamp to one batch per chunk.
    size = max(1, int(len(data_loader) / split))

    for i, data in enumerate(data_loader):
        images = data['images'].to(device)
        # torch.squeeze() would turn a final batch of one image into a 0-d
        # tensor; flatten per sample and drop only a trailing size-1 dim.
        labels = data['labels'].reshape(images.shape[0], -1).squeeze(1).to(device)

        images_list.append(images)
        labels_list.append(labels)

        if (i + 1) % size == 0:
            with open(f'{type}_images_list{save_idx}.pickle', 'wb') as fw:
                pickle.dump(images_list, fw)
            images_list = []
            save_idx += 1

    # Leftover batches (may be empty) are still written, so readers that expect
    # the trailing file keep finding it.
    with open(f'{type}_images_list{save_idx}.pickle', 'wb') as fw:
        pickle.dump(images_list, fw)

    return labels_list

def load_images_list(type='train', i=0):
    """Unpickle chunk ``i`` written by ``save_images`` for the given prefix.

    Reads ``f'{type}_images_list{i}.pickle'`` from the current directory.
    Only load files this notebook produced: unpickling runs arbitrary code.
    """
    chunk_path = f'{type}_images_list{i}.pickle'
    with open(chunk_path, 'rb') as handle:
        return pickle.load(handle)

In [8]:
# Chunk the training images to disk (3 chunks); labels come back in memory.
train_labels_list = save_images(train_loader, split=3, type="train")

# Persist the full per-batch label list alongside the image chunks.
with open('train_labels_list0.pickle', 'wb') as fw:
    pickle.dump(train_labels_list, fw)

In [7]:
# Chunk the test images to disk (3 chunks); labels come back in memory.
test_labels_list = save_images(test_loader, split=3, type="test")

# Persist the full per-batch label list alongside the image chunks.
with open('test_labels_list0.pickle', 'wb') as fw:
    pickle.dump(test_labels_list, fw)

In [9]:
# PACS class names, in label-index order, phrased with their article.
pacs_class = [
    'a dog', 'an elephant', 'a giraffe', 'a guitar', 'a horse', 'a house', 'a person'
]

def prompt(idx):
    """Return the CLIP text prompt for PACS class index ``idx``."""
    return "An image of " + pacs_class[idx]

# One prompt per class, derived from the class list rather than a hard-coded count.
prompts = list(map(prompt, range(len(pacs_class))))
print(prompts)

['An image of a dog', 'An image of an elephant', 'An image of a giraffe', 'An image of a guitar', 'An image of a horse', 'An image of a house', 'An image of a person']


In [11]:
# Tokenise the per-class prompts built above (the same seven strings) for the text encoder.
class_tokens = clip.tokenize(prompts).to(device)
with open('class_tokens.pickle', 'wb') as fw:
    pickle.dump(class_tokens, fw)

In [10]:
# Per-block text activations for every class prompt, cached to disk.
class_features = get_text_features(clip_model=clip_model, texts=class_tokens)
with open('class_features.pickle', mode='wb') as fw:
    pickle.dump(class_features, fw)